target based upstream options
This commit is contained in:
@@ -3,6 +3,7 @@ package http
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"edgaru089.ink/go/regolith/internal/perm"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -21,6 +23,8 @@ const (
|
||||
var (
|
||||
dialer = net.Dialer{Timeout: outgoing_client_timeout}
|
||||
http_client = http.Client{Timeout: outgoing_client_timeout}
|
||||
|
||||
dialer_func func(ctx context.Context, network, addr string) (net.Conn, error) = dialer.DialContext
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
@@ -58,6 +62,22 @@ func (s *Server) Serve(listener net.Listener) (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetDialer sets the global dialer every Server uses to dial outgoing HTTP/CONNECT connections.
|
||||
//
|
||||
// When set to nil, net.Dial is used instead.
|
||||
func (s *Server) SetDialer(dialer proxy.ContextDialer) {
|
||||
if http_client.Transport == nil {
|
||||
http_client.Transport = http.DefaultTransport
|
||||
}
|
||||
|
||||
http_client.Transport.(*http.Transport).DialContext = dialer.DialContext
|
||||
if dialer != nil {
|
||||
dialer_func = dialer.DialContext
|
||||
} else {
|
||||
dialer_func = dialer.DialContext
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) dispatch(conn net.Conn) {
|
||||
buf := bufio.NewReader(conn)
|
||||
for {
|
||||
@@ -156,7 +176,9 @@ func (s *Server) handle_connect(conn net.Conn, req *http.Request) {
|
||||
}()
|
||||
|
||||
// dial
|
||||
remote_conn, err := dialer.Dial("tcp", req_addr)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), outgoing_client_timeout)
|
||||
remote_conn, err := dialer_func(ctx, "tcp", req_addr)
|
||||
cancel()
|
||||
if err != nil {
|
||||
simple_respond(conn, req, http.StatusBadGateway)
|
||||
close_err = err
|
||||
|
143
internal/out/out.go
Normal file
143
internal/out/out.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package out
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
|
||||
"edgaru089.ink/go/regolith/internal/util"
|
||||
)
|
||||
|
||||
// Configures upstream TCP connections.
|
||||
//
|
||||
// Does not support HTTP proxies as upstream. They're
|
||||
// just a pain to use besides CONNECT commands.
|
||||
//
|
||||
// Upstreams are specified as 'direct' for direct dialing
|
||||
// by the Go net package, or 'socks5://user:pass@host:port'
|
||||
// for SOCKS5 upstreams.
|
||||
//
|
||||
// Numeric addresses are always dialed directly.
|
||||
type Config struct {
|
||||
Default string // Default URI
|
||||
Proxied map[string][]string // Proxy URI to hostname array mapping
|
||||
}
|
||||
|
||||
// Dialer implements x/net/proxy.Dialer from the Config given to it.
|
||||
type Dialer struct {
|
||||
lock sync.RWMutex
|
||||
|
||||
hostnames map[string]string
|
||||
uris map[string]interface {
|
||||
proxy.Dialer
|
||||
proxy.ContextDialer
|
||||
}
|
||||
|
||||
def interface {
|
||||
proxy.Dialer
|
||||
proxy.ContextDialer
|
||||
}
|
||||
}
|
||||
|
||||
var _ proxy.ContextDialer = &Dialer{} // *Dialer implements x/net/proxy.ContextDialer
|
||||
var _ proxy.Dialer = &Dialer{} // *Dialer implements x/net/proxy.Dialer
|
||||
|
||||
// New creates a new Dialer from a Config.
|
||||
//
|
||||
// You can also &out.Dialer{} and then call Load on it.
|
||||
func New(c Config) (d *Dialer) {
|
||||
d = &Dialer{}
|
||||
d.Load(c)
|
||||
return
|
||||
}
|
||||
|
||||
// Load loads/reloads the Dialer struct.
|
||||
func (d *Dialer) Load(c Config) {
|
||||
d.lock.Lock()
|
||||
defer d.lock.Unlock()
|
||||
|
||||
if len(c.Default) == 0 {
|
||||
c.Default = "direct"
|
||||
}
|
||||
log.Print("default dialer ", c.Default)
|
||||
|
||||
if d.hostnames == nil {
|
||||
d.hostnames = make(map[string]string)
|
||||
d.uris = make(map[string]interface {
|
||||
proxy.Dialer
|
||||
proxy.ContextDialer
|
||||
})
|
||||
} else {
|
||||
clear(d.hostnames)
|
||||
clear(d.uris)
|
||||
}
|
||||
|
||||
for uri, hosts := range c.Proxied {
|
||||
d.uris[uri] = nil
|
||||
for _, host := range hosts {
|
||||
d.hostnames[host] = uri
|
||||
log.Printf("dialer( %s ) => %s", host, uri)
|
||||
}
|
||||
}
|
||||
d.uris[c.Default] = nil
|
||||
|
||||
// construct the Dialers
|
||||
for uri := range d.uris {
|
||||
if uri == "direct" {
|
||||
d.uris[uri] = proxy.Direct
|
||||
} else {
|
||||
url, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
log.Print("URI parse error: ", err)
|
||||
d.uris[uri] = proxy.Direct
|
||||
continue
|
||||
}
|
||||
dx, err := proxy.FromURL(url, proxy.Direct)
|
||||
if err != nil {
|
||||
log.Print("Proxy.FromURL error: ", err)
|
||||
d.uris[uri] = proxy.Direct
|
||||
continue
|
||||
}
|
||||
|
||||
var ok bool
|
||||
if d.uris[uri], ok = dx.(interface {
|
||||
proxy.ContextDialer
|
||||
proxy.Dialer
|
||||
}); !ok {
|
||||
log.Print("Proxy.FromURL unable to cast to proxy.Dialer+ContextDialer")
|
||||
d.uris[uri] = proxy.Direct
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
d.def = d.uris[c.Default]
|
||||
}
|
||||
|
||||
func (d *Dialer) findDialer(host string) interface {
|
||||
proxy.Dialer
|
||||
proxy.ContextDialer
|
||||
} {
|
||||
d.lock.RLock()
|
||||
defer d.lock.RUnlock()
|
||||
|
||||
if dx, ok := d.uris[d.hostnames[host]]; ok {
|
||||
return dx
|
||||
} else {
|
||||
return d.def
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
||||
host, _ := util.SplitHostPort(address)
|
||||
return d.findDialer(host).Dial(network, address)
|
||||
}
|
||||
|
||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
host, _ := util.SplitHostPort(address)
|
||||
return d.findDialer(host).DialContext(ctx, network, address)
|
||||
}
|
@@ -69,7 +69,7 @@ func (p *Perm) Load(cs map[string]Config) {
|
||||
|
||||
// loop around the Match map
|
||||
for addrport, act := range c.Match {
|
||||
addr, port := splitHostPort(addrport)
|
||||
addr, port := util.SplitHostPort(addrport)
|
||||
if port != "" {
|
||||
insert(net.JoinHostPort(addr, port), act)
|
||||
} else {
|
||||
@@ -81,7 +81,7 @@ func (p *Perm) Load(cs map[string]Config) {
|
||||
}
|
||||
}
|
||||
for _, glob := range c.MatchWildcard {
|
||||
addr, port := splitHostPort(glob.Glob)
|
||||
addr, port := util.SplitHostPort(glob.Glob)
|
||||
if port != "" {
|
||||
log.Printf("loading glob target %s, action %s", glob.Glob, glob.Act)
|
||||
p_int.match_glob = append(
|
||||
|
38
internal/util/hostport.go
Normal file
38
internal/util/hostport.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package util
|
||||
|
||||
import "strings"
|
||||
|
||||
// ValidOptionalPort reports whether port is either an empty string
|
||||
// or matches /^:\d*$/
|
||||
func ValidOptionalPort(port string) bool {
|
||||
if port == "" {
|
||||
return true
|
||||
}
|
||||
if port[0] != ':' {
|
||||
return false
|
||||
}
|
||||
for _, b := range port[1:] {
|
||||
if b < '0' || b > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// SplitHostPort separates host and port. If the port is not valid, it returns
|
||||
// the entire input as host, and it doesn't check the validity of the host.
|
||||
// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric.
|
||||
func SplitHostPort(hostPort string) (host, port string) {
|
||||
host = hostPort
|
||||
|
||||
colon := strings.LastIndexByte(host, ':')
|
||||
if colon != -1 && ValidOptionalPort(host[colon:]) {
|
||||
host, port = host[:colon], host[colon+1:]
|
||||
}
|
||||
|
||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
host = host[1 : len(host)-1]
|
||||
}
|
||||
|
||||
return
|
||||
}
|
Reference in New Issue
Block a user