diff --git a/go.mod b/go.mod index 444e1f5..9ce5494 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module edgaru089.ink/go/regolith go 1.24.0 + +require golang.org/x/net v0.43.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8028634 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= diff --git a/internal/http/server.go b/internal/http/server.go index 6e8849c..e629626 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -3,6 +3,7 @@ package http import ( "bufio" "bytes" + "context" "errors" "io" "log" @@ -12,6 +13,7 @@ import ( "time" "edgaru089.ink/go/regolith/internal/perm" + "golang.org/x/net/proxy" ) const ( @@ -21,6 +23,8 @@ const ( var ( dialer = net.Dialer{Timeout: outgoing_client_timeout} http_client = http.Client{Timeout: outgoing_client_timeout} + + dialer_func func(ctx context.Context, network, addr string) (net.Conn, error) = dialer.DialContext ) type Server struct { @@ -58,6 +62,22 @@ func (s *Server) Serve(listener net.Listener) (err error) { } } +// SetDialer sets the global dialer every Server uses to dial outgoing HTTP/CONNECT connections. +// +// When set to nil, net.Dial is used instead. +func (s *Server) SetDialer(dialer proxy.ContextDialer) { + if http_client.Transport == nil { + http_client.Transport = http.DefaultTransport + } + + http_client.Transport.(*http.Transport).DialContext = dialer.DialContext + if dialer != nil { + dialer_func = dialer.DialContext + } else { + dialer_func = dialer.DialContext + } +} + func (s *Server) dispatch(conn net.Conn) { buf := bufio.NewReader(conn) for { @@ -156,7 +176,9 @@ func (s *Server) handle_connect(conn net.Conn, req *http.Request) { }() // dial - remote_conn, err := dialer.Dial("tcp", req_addr) + ctx, cancel := context.WithTimeout(context.Background(), outgoing_client_timeout) + remote_conn, err := dialer_func(ctx, "tcp", req_addr) + cancel() if err != nil { simple_respond(conn, req, http.StatusBadGateway) close_err = err diff --git a/internal/out/out.go b/internal/out/out.go new file mode 100644 index 0000000..507d382 --- /dev/null +++ b/internal/out/out.go @@ -0,0 +1,143 @@ +package out + +import ( + "context" + "log" + "net" + "net/url" + "sync" + + "golang.org/x/net/proxy" + + "edgaru089.ink/go/regolith/internal/util" +) + +// Configures upstream TCP connections. +// +// Does not support HTTP proxies as upstream. They're +// just a pain to use besides CONNECT commands. +// +// Upstreams are specified as 'direct' for direct dialing +// by the Go net package, or 'socks5://user:pass@host:port' +// for SOCKS5 upstreams. +// +// Numeric addresses are always dialed directly. +type Config struct { + Default string // Default URI + Proxied map[string][]string // Proxy URI to hostname array mapping +} + +// Dialer implements x/net/proxy.Dialer from the Config given to it. +type Dialer struct { + lock sync.RWMutex + + hostnames map[string]string + uris map[string]interface { + proxy.Dialer + proxy.ContextDialer + } + + def interface { + proxy.Dialer + proxy.ContextDialer + } +} + +var _ proxy.ContextDialer = &Dialer{} // *Dialer implements x/net/proxy.ContextDialer +var _ proxy.Dialer = &Dialer{} // *Dialer implements x/net/proxy.Dialer + +// New creates a new Dialer from a Config. +// +// You can also &out.Dialer{} and then call Load on it. +func New(c Config) (d *Dialer) { + d = &Dialer{} + d.Load(c) + return +} + +// Load loads/reloads the Dialer struct. +func (d *Dialer) Load(c Config) { + d.lock.Lock() + defer d.lock.Unlock() + + if len(c.Default) == 0 { + c.Default = "direct" + } + log.Print("default dialer ", c.Default) + + if d.hostnames == nil { + d.hostnames = make(map[string]string) + d.uris = make(map[string]interface { + proxy.Dialer + proxy.ContextDialer + }) + } else { + clear(d.hostnames) + clear(d.uris) + } + + for uri, hosts := range c.Proxied { + d.uris[uri] = nil + for _, host := range hosts { + d.hostnames[host] = uri + log.Printf("dialer( %s ) => %s", host, uri) + } + } + d.uris[c.Default] = nil + + // construct the Dialers + for uri := range d.uris { + if uri == "direct" { + d.uris[uri] = proxy.Direct + } else { + url, err := url.Parse(uri) + if err != nil { + log.Print("URI parse error: ", err) + d.uris[uri] = proxy.Direct + continue + } + dx, err := proxy.FromURL(url, proxy.Direct) + if err != nil { + log.Print("Proxy.FromURL error: ", err) + d.uris[uri] = proxy.Direct + continue + } + + var ok bool + if d.uris[uri], ok = dx.(interface { + proxy.ContextDialer + proxy.Dialer + }); !ok { + log.Print("Proxy.FromURL unable to cast to proxy.Dialer+ContextDialer") + d.uris[uri] = proxy.Direct + continue + } + } + } + + d.def = d.uris[c.Default] +} + +func (d *Dialer) findDialer(host string) interface { + proxy.Dialer + proxy.ContextDialer +} { + d.lock.RLock() + defer d.lock.RUnlock() + + if dx, ok := d.uris[d.hostnames[host]]; ok { + return dx + } else { + return d.def + } +} + +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + host, _ := util.SplitHostPort(address) + return d.findDialer(host).Dial(network, address) +} + +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + host, _ := util.SplitHostPort(address) + return d.findDialer(host).DialContext(ctx, network, address) +} diff --git a/internal/perm/perm.go b/internal/perm/perm.go index 9926ade..45a0e0a 100644 --- a/internal/perm/perm.go +++ b/internal/perm/perm.go @@ -69,7 +69,7 @@ func (p *Perm) Load(cs map[string]Config) { // loop around the Match map for addrport, act := range c.Match { - addr, port := splitHostPort(addrport) + addr, port := util.SplitHostPort(addrport) if port != "" { insert(net.JoinHostPort(addr, port), act) } else { @@ -81,7 +81,7 @@ func (p *Perm) Load(cs map[string]Config) { } } for _, glob := range c.MatchWildcard { - addr, port := splitHostPort(glob.Glob) + addr, port := util.SplitHostPort(glob.Glob) if port != "" { log.Printf("loading glob target %s, action %s", glob.Glob, glob.Act) p_int.match_glob = append( diff --git a/internal/util/hostport.go b/internal/util/hostport.go new file mode 100644 index 0000000..d617f2f --- /dev/null +++ b/internal/util/hostport.go @@ -0,0 +1,38 @@ +package util + +import "strings" + +// ValidOptionalPort reports whether port is either an empty string +// or matches /^:\d*$/ +func ValidOptionalPort(port string) bool { + if port == "" { + return true + } + if port[0] != ':' { + return false + } + for _, b := range port[1:] { + if b < '0' || b > '9' { + return false + } + } + return true +} + +// SplitHostPort separates host and port. If the port is not valid, it returns +// the entire input as host, and it doesn't check the validity of the host. +// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric. +func SplitHostPort(hostPort string) (host, port string) { + host = hostPort + + colon := strings.LastIndexByte(host, ':') + if colon != -1 && ValidOptionalPort(host[colon:]) { + host, port = host[:colon], host[colon+1:] + } + + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + host = host[1 : len(host)-1] + } + + return +} diff --git a/main.go b/main.go index 1e09096..af75a5f 100644 --- a/main.go +++ b/main.go @@ -13,9 +13,25 @@ import ( "edgaru089.ink/go/regolith/internal/conf" "edgaru089.ink/go/regolith/internal/http" + "edgaru089.ink/go/regolith/internal/out" "edgaru089.ink/go/regolith/internal/perm" ) +func readDialer() *out.Dialer { + out_buf, err := os.ReadFile("out.json") + if err != nil { + log.Print("error reading out.json: ", err) + return nil + } + var outcfg out.Config + err = json.Unmarshal(out_buf, &outcfg) + if err != nil { + log.Print("error unmarshaling out.json: ", err) + return nil + } + return out.New(outcfg) +} + func main() { var s *http.Server @@ -34,6 +50,9 @@ func main() { } } + dialer := readDialer() + s.SetDialer(dialer) + var conf conf.Config { conf_buf, err := os.ReadFile("config.json") @@ -75,7 +94,7 @@ func main() { go func() { for { <-sighup_chan - log.Printf("SIGHUP received, reloading permissions") + log.Printf("SIGHUP received, reloading") perm_buf, err := os.ReadFile("perm.json") if err != nil { log.Printf("skipping reload: error opening perm.json: %e", err) @@ -88,6 +107,9 @@ func main() { continue } s.Perm.Load(perm_json) + + dialer := readDialer() + s.SetDialer(dialer) } }() diff --git a/out.json b/out.json new file mode 100644 index 0000000..bbc3830 --- /dev/null +++ b/out.json @@ -0,0 +1,10 @@ +{ + "Default": "direct", + "Proxied": { + "socks5://192.168.1.3:1080": [ + "github.com", + "go.dev", + "pkg.go.dev" + ] + } +} diff --git a/perm.json b/perm.json index 09667f0..bb76d76 100644 --- a/perm.json +++ b/perm.json @@ -13,9 +13,10 @@ ] }, "127.0.0.1": { - "DefaultAction": "deny", + "DefaultAction": "accept", "DefaultPort": [443], "Match": { + "github.com": "accept", "pkg.go.dev": "accept", "go.dev": "accept" },