Compare commits

..

2 Commits

Author SHA1 Message Date
5d7e37ab7c Working config & permissions 2025-03-27 23:28:12 +08:00
5896d5fcd2 Use per-source address configs 2025-03-27 22:25:26 +08:00
7 changed files with 223 additions and 54 deletions

4
config.json Normal file
View File

@ -0,0 +1,4 @@
{
"ListenAddress": "127.0.0.1:3128",
"ListenType": "tcp4"
}

7
internal/conf/config.go Normal file
View File

@ -0,0 +1,7 @@
package conf
type Config struct {
ListenAddress string // Address to listen on, passed to net.Listen
ListenType string // Type of network to listen on, passed to net.Listen. One of tcp, tcp4 and tcp6
}

View File

@ -10,6 +10,8 @@ import (
"net/http"
"strings"
"time"
"edgaru089.ink/go/regolith/internal/perm"
)
const (
@ -22,6 +24,24 @@ var (
)
type Server struct {
Perm *perm.Perm
}
// checkPerm invokes perm.Match.
// If s.perm is nil, then every request is Accept-ed.
//
// It also logs if the action is Deny.
func (s *Server) checkPerm(src, dest string) (act perm.Action) {
if s.Perm == nil {
return perm.ActionAccept
}
src_host, _, _ := net.SplitHostPort(src)
act = s.Perm.Match(src_host, dest)
if act == perm.ActionDeny {
log.Printf("denied: [%s] -> [%s]", src, dest)
}
return
}
func (s *Server) Serve(listener net.Listener) (err error) {
@ -120,6 +140,12 @@ func (s *Server) handle_connect(conn net.Conn, req *http.Request) {
}
req_addr := net.JoinHostPort(req.URL.Hostname(), port)
// check permission
if s.checkPerm(conn.RemoteAddr().String(), req_addr) != perm.ActionAccept {
_ = simple_respond(conn, req, http.StatusBadGateway)
return
}
// prep for error log on close
var close_err error
defer func() {
@ -164,11 +190,22 @@ func (s *Server) handle_request(conn net.Conn, req *http.Request) (should_keepal
remove_hop_headers(req.Header)
remove_extra_host_port(req)
if req.URL.Scheme == "" || req.URL.Host == "" {
if req.URL.Scheme != "http" || req.URL.Host == "" {
_ = simple_respond(conn, req, http.StatusBadRequest)
return false
}
// Check permission
req_hostname := req.URL.Hostname()
req_port := req.URL.Port()
if req_port == "" {
req_port = "80"
}
if s.checkPerm(conn.RemoteAddr().String(), net.JoinHostPort(req_hostname, req_port)) != perm.ActionAccept {
_ = simple_respond(conn, req, http.StatusBadGateway)
return false
}
// Request & error log
var close_err error
defer func() {

View File

@ -27,6 +27,20 @@ func MostSevere(a, b Action) Action {
return min(a, b)
}
func (a Action) String() (name string) {
switch a {
case ActionDeny:
name = "deny"
case ActionIgnore:
name = "ignore"
case ActionAccept:
name = "accept"
default:
name = "<" + strconv.Itoa(int(a)) + ">"
}
return
}
// Marshal/Unmarshal for Action
func (a *Action) UnmarshalText(text []byte) error {
switch strings.ToLower(string(text)) {
@ -42,25 +56,14 @@ func (a *Action) UnmarshalText(text []byte) error {
return nil
}
func (a Action) MarshalText() ([]byte, error) {
var name string
switch a {
case ActionDeny:
name = "deny"
case ActionIgnore:
name = "ignore"
case ActionAccept:
name = "accept"
default:
name = "<" + strconv.Itoa(int(a)) + ">"
}
return []byte(name), nil
return []byte(a.String()), nil
}
// Config is a list of address and actions.
// Config is a list of address and actions, for each source address.
// It can just be Marshal/Unmarshaled into/from json.
type Config struct {
DefaultAction Action // What we should do when no action is matched.
DefaultPort []uint // Port numbers to add to address without port numbers already in them. Don't put too many entries in here.
DefaultPort []int // Port numbers to add to address without port numbers already in them. Don't put too many entries in here.
// Object which holds addresses and optionally ports, mapping to actions.
//

View File

@ -1,74 +1,102 @@
package perm
import (
"log"
"net"
"strconv"
"strings"
"sync"
)
type int_perm struct {
match map[string]Action
def Action
}
// Perm matches address:port strings to Actions
// loaded from a Config.
//
// global permissions are stored under the "$global" address.
//
// It's thread safe.
type Perm struct {
lock sync.RWMutex
perm map[string]Action
def Action
global int_perm
source map[string]int_perm
}
// New creates a new Perm struct from a Config.
//
// You can also &perm.Perm{} and then call Load on it.
func New(c *Config) (p *Perm) {
func New(c map[string]Config) (p *Perm) {
p = &Perm{}
p.Load(c)
return
}
// Load loads/reloads the Perm struct.
func (p *Perm) Load(c *Config) {
func (p *Perm) Load(cs map[string]Config) {
p.lock.Lock()
defer p.lock.Unlock()
if p.perm == nil {
p.perm = make(map[string]Action)
if p.source == nil {
p.source = make(map[string]int_perm)
} else {
clear(p.perm)
clear(p.source)
}
p.def = c.DefaultAction
// helper to load every source addr
load_per_source := func(c Config) (p_int int_perm) {
p_int.match = make(map[string]Action)
p_int.def = c.DefaultAction
log.Printf("default action %s", p_int.def)
// insert helper to use the most severe action existing
insert := func(addrport string, action Action) {
existing_action, ok := p.perm[addrport]
if ok {
p.perm[addrport] = MostSevere(existing_action, action)
} else {
p.perm[addrport] = action
}
}
// loop around the Match map
for addrport, act := range c.Match {
addr, port := splitHostPort(addrport)
if port != "" {
insert(net.JoinHostPort(addr, port), act)
} else {
// so this is why def_port shouldn't be that big
// TODO change this to sth faster
for def_port := range c.DefaultPort {
insert(net.JoinHostPort(addr, strconv.Itoa(def_port)), act)
// insert helper to use the most severe action existing
insert := func(addrport string, action Action) {
log.Printf("loading target %s, action %s", addrport, action)
existing_action, ok := p_int.match[addrport]
if ok {
p_int.match[addrport] = MostSevere(existing_action, action)
} else {
p_int.match[addrport] = action
}
}
// loop around the Match map
for addrport, act := range c.Match {
addr, port := splitHostPort(addrport)
if port != "" {
insert(net.JoinHostPort(addr, port), act)
} else {
// so this is why def_port shouldn't be that big
// TODO change this to sth faster
for _, def_port := range c.DefaultPort {
insert(net.JoinHostPort(addr, strconv.Itoa(def_port)), act)
}
}
}
return
}
for src, c := range cs {
log.Printf("loading source %s", src)
if strings.EqualFold(src, "$global") {
p.global = load_per_source(c)
} else {
p.source[src] = load_per_source(c)
}
}
return
}
// Match matches an address to an action.
// addr must be in net.JoinHostPort format.
func (p *Perm) Match(addr string) Action {
//
// src must be a host (either ipv4 or v6), while
// dest must be in net.JoinHostPort format.
func (p *Perm) Match(src, dest string) Action {
// sanity check
if p == nil {
return ActionDeny
@ -77,14 +105,28 @@ func (p *Perm) Match(addr string) Action {
p.lock.RLock()
defer p.lock.RUnlock()
// sanity check no.2
if p.perm == nil {
return p.def
// find its source struct
p_int, ok_int := p.source[src]
// only check if dest is directly listed
if ok_int && p_int.match != nil {
if action, ok := p_int.match[dest]; ok {
return action
}
}
action, ok := p.perm[addr]
if !ok {
return p.def
// then check the global struct, also only directly listed
if p.global.match != nil {
if action, ok := p.global.match[dest]; ok {
return action
}
}
// directly listed in neither.
if ok_int {
// if source struct exists, use source struct default.
return p_int.def
} else {
// if not exist, use global default.
return p.global.def
}
return action
}

59
main.go
View File

@ -1,29 +1,84 @@
package main
import (
"encoding/json"
"fmt"
"log"
"net"
"os"
"os/signal"
"syscall"
"edgaru089.ink/go/regolith/internal/conf"
"edgaru089.ink/go/regolith/internal/http"
"edgaru089.ink/go/regolith/internal/perm"
)
func main() {
var s *http.Server
listener, err := net.Listen("tcp", ":3128")
{
perm_buf, err := os.ReadFile("perm.json")
if err != nil {
panic(err)
}
perm_json := make(map[string]perm.Config)
err = json.Unmarshal(perm_buf, &perm_json)
if err != nil {
panic(err)
}
s = &http.Server{
Perm: perm.New(perm_json),
}
}
var conf conf.Config
{
conf_buf, err := os.ReadFile("config.json")
if err != nil {
panic(err)
}
err = json.Unmarshal(conf_buf, &conf)
if err != nil {
panic(err)
}
}
listener, err := net.Listen(conf.ListenType, conf.ListenAddress)
if err != nil {
panic(err)
}
log.Printf("listeneing on [%s], type %s", conf.ListenAddress, conf.ListenType)
sigint_chan := make(chan os.Signal, 1)
signal.Notify(sigint_chan, os.Interrupt)
go func() {
<-sigint_chan
log.Printf("SIGINT received, quitting")
listener.Close()
}()
s := &http.Server{}
sighup_chan := make(chan os.Signal, 1)
signal.Notify(sighup_chan, syscall.SIGHUP)
go func() {
for {
<-sighup_chan
log.Printf("SIGHUP received, reloading permissions")
perm_buf, err := os.ReadFile("perm.json")
if err != nil {
log.Printf("skipping reload: error opening perm.json: %e", err)
continue
}
perm_json := make(map[string]perm.Config)
err = json.Unmarshal(perm_buf, &perm_json)
if err != nil {
log.Printf("skipping reload: error unmarshaling perm.json: %e", err)
continue
}
s.Perm.Load(perm_json)
}
}()
err = s.Serve(listener)
if err != nil {
fmt.Println(err)

21
perm.json Normal file
View File

@ -0,0 +1,21 @@
{
"$global": {
"DefaultAction": "deny",
"DefaultPort": [443],
"Match": {
"mirrors.tuna.tsinghua.edu.cn": "accept",
"mirrors6.tuna.tsinghua.edu.cn": "accept",
"incoming.telemetry.mozilla.org": "ignore"
}
},
"127.0.0.1": {
"DefaultAction": "deny",
"DefaultPort": [443],
"Match": {
"pkg.go.dev": "accept",
"go.dev": "accept"
}
}
}