278 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			278 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package http
 | |
| 
 | |
| import (
 | |
| 	"bufio"
 | |
| 	"bytes"
 | |
| 	"errors"
 | |
| 	"io"
 | |
| 	"log"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 	"strings"
 | |
| 	"time"
 | |
| 
 | |
| 	"edgaru089.ink/go/regolith/internal/perm"
 | |
| )
 | |
| 
 | |
const (
	// outgoing_client_timeout bounds both dialing a CONNECT target and a
	// complete proxied plain-HTTP exchange.
	outgoing_client_timeout = 10 * time.Second
)

var (
	// dialer opens the upstream TCP connection for CONNECT tunnels.
	dialer = net.Dialer{Timeout: outgoing_client_timeout}

	// http_client performs proxied plain-HTTP requests.
	//
	// Redirects are deliberately NOT followed: a forward proxy has to relay
	// 3xx responses to the client unchanged, and following them here would
	// also fetch the redirect target without it ever passing a permission
	// check.
	http_client = http.Client{
		Timeout: outgoing_client_timeout,
		CheckRedirect: func(*http.Request, []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}
)
 | |
| 
 | |
| type Server struct {
 | |
| 	Perm *perm.Perm
 | |
| }
 | |
| 
 | |
| // checkPerm invokes perm.Match.
 | |
| // If s.perm is nil, then every request is Accept-ed.
 | |
| //
 | |
| // It also logs if the action is Deny.
 | |
| func (s *Server) checkPerm(src, dest string) (act perm.Action) {
 | |
| 	if s.Perm == nil {
 | |
| 		return perm.ActionAccept
 | |
| 	}
 | |
| 
 | |
| 	src_host, _, _ := net.SplitHostPort(src)
 | |
| 	act = s.Perm.Match(src_host, dest)
 | |
| 	if act == perm.ActionDeny {
 | |
| 		log.Printf("denied: [%s] -> [%s]", src, dest)
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (s *Server) Serve(listener net.Listener) (err error) {
 | |
| 	for {
 | |
| 		conn, err := listener.Accept()
 | |
| 		if err != nil {
 | |
| 			if errors.Is(err, net.ErrClosed) {
 | |
| 				return nil
 | |
| 			}
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		go s.dispatch(conn)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (s *Server) dispatch(conn net.Conn) {
 | |
| 	buf := bufio.NewReader(conn)
 | |
| 	for {
 | |
| 		req, err := http.ReadRequest(buf)
 | |
| 		if err != nil {
 | |
| 			// Invalid request
 | |
| 			_ = conn.Close()
 | |
| 			break
 | |
| 		}
 | |
| 
 | |
| 		if req.Method == http.MethodConnect {
 | |
| 			if buf.Buffered() > 0 {
 | |
| 				// There is still data in the buffered reader.
 | |
| 				// We need to get it out and put it into a cachedConn,
 | |
| 				// so that handleConnect can read it.
 | |
| 				data := make([]byte, buf.Buffered())
 | |
| 				_, err := io.ReadFull(buf, data)
 | |
| 				if err != nil {
 | |
| 					// Read from buffer failed, is this possible?
 | |
| 					_ = conn.Close()
 | |
| 					return
 | |
| 				}
 | |
| 				cachedConn := &cached_conn{
 | |
| 					Conn:   conn,
 | |
| 					Buffer: *bytes.NewBuffer(data),
 | |
| 				}
 | |
| 				s.handle_connect(cachedConn, req)
 | |
| 			} else {
 | |
| 				// No data in the buffered reader, we can just pass the original connection.
 | |
| 				s.handle_connect(conn, req)
 | |
| 			}
 | |
| 			// handle_connect will take over the connection,
 | |
| 			// i.e. it will not return until the connection is closed.
 | |
| 			// When it returns, there will be no more requests from this connection,
 | |
| 			// so we simply exit the loop.
 | |
| 			break
 | |
| 		} else {
 | |
| 			// handle_request on the other hand handles one request at a time,
 | |
| 			// and returns when the request is done. It returns a bool indicating
 | |
| 			// whether the connection should be kept alive, but itself never closes
 | |
| 			// the connection.
 | |
| 			kept_alive := s.handle_request(conn, req)
 | |
| 			if !kept_alive {
 | |
| 				_ = conn.Close()
 | |
| 				return
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
// cached_conn wraps a net.Conn so that reads are served from Buffer first;
// once Buffer is empty, reads go to the wrapped connection. All other
// net.Conn methods are those of the embedded Conn.
type cached_conn struct {
	net.Conn
	Buffer bytes.Buffer
}

// Read returns buffered bytes while any remain, and reads from the
// underlying connection afterwards. A single call never mixes both sources.
func (c *cached_conn) Read(b []byte) (int, error) {
	if c.Buffer.Len() == 0 {
		return c.Conn.Read(b)
	}

	n, err := c.Buffer.Read(b)
	if err == io.EOF {
		// The buffer running dry is not the end of the stream.
		return n, nil
	}
	return n, err
}
 | |
| 
 | |
| // handle_connect returns until the connection is closed by
 | |
| // the client, or errors. You don't need to close it again.
 | |
| func (s *Server) handle_connect(conn net.Conn, req *http.Request) {
 | |
| 	conn.RemoteAddr()
 | |
| 
 | |
| 	defer conn.Close()
 | |
| 
 | |
| 	port := req.URL.Port()
 | |
| 	if port == "" {
 | |
| 		port = "80"
 | |
| 	}
 | |
| 	req_addr := net.JoinHostPort(req.URL.Hostname(), port)
 | |
| 
 | |
| 	// check permission
 | |
| 	if s.checkPerm(conn.RemoteAddr().String(), req_addr) != perm.ActionAccept {
 | |
| 		_ = simple_respond(conn, req, http.StatusBadGateway)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	// prep for error log on close
 | |
| 	var close_err error
 | |
| 	defer func() {
 | |
| 		if close_err != nil && !errors.Is(close_err, net.ErrClosed) {
 | |
| 			// log non-closed errors
 | |
| 			log.Printf("[%s] -> [%s] error dialing remote: %v", conn.RemoteAddr(), req_addr, close_err)
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	// dial
 | |
| 	remote_conn, err := dialer.Dial("tcp", req_addr)
 | |
| 	if err != nil {
 | |
| 		simple_respond(conn, req, http.StatusBadGateway)
 | |
| 		close_err = err
 | |
| 		return
 | |
| 	}
 | |
| 	defer remote_conn.Close()
 | |
| 
 | |
| 	log.Printf("[%s] -> [%s] connected", conn.RemoteAddr(), req_addr)
 | |
| 	// send a 200 OK and start copying
 | |
| 	_ = simple_respond(conn, req, http.StatusOK)
 | |
| 	err_chan := make(chan error, 2)
 | |
| 	go func() {
 | |
| 		_, err := io.Copy(remote_conn, conn)
 | |
| 		err_chan <- err
 | |
| 	}()
 | |
| 	go func() {
 | |
| 		_, err := io.Copy(conn, remote_conn)
 | |
| 		err_chan <- err
 | |
| 	}()
 | |
| 	close_err = <-err_chan
 | |
| }
 | |
| 
 | |
| func (s *Server) handle_request(conn net.Conn, req *http.Request) (should_keepalive bool) {
 | |
| 	// Some clients use Connection, some use Proxy-Connection
 | |
| 	// https://www.oreilly.com/library/view/http-the-definitive/1565925092/re40.html
 | |
| 	keep_alive := req.ProtoAtLeast(1, 1) &&
 | |
| 		(strings.EqualFold(req.Header.Get("Proxy-Connection"), "keep-alive") ||
 | |
| 			strings.EqualFold(req.Header.Get("Connection"), "keep-alive"))
 | |
| 	req.RequestURI = "" // Outgoing request should not have RequestURI
 | |
| 
 | |
| 	remove_hop_headers(req.Header)
 | |
| 	remove_extra_host_port(req)
 | |
| 
 | |
| 	if req.URL.Scheme != "http" || req.URL.Host == "" {
 | |
| 		_ = simple_respond(conn, req, http.StatusBadRequest)
 | |
| 		return false
 | |
| 	}
 | |
| 
 | |
| 	// Check permission
 | |
| 	req_hostname := req.URL.Hostname()
 | |
| 	req_port := req.URL.Port()
 | |
| 	if req_port == "" {
 | |
| 		req_port = "80"
 | |
| 	}
 | |
| 	if s.checkPerm(conn.RemoteAddr().String(), net.JoinHostPort(req_hostname, req_port)) != perm.ActionAccept {
 | |
| 		_ = simple_respond(conn, req, http.StatusBadGateway)
 | |
| 		return false
 | |
| 	}
 | |
| 
 | |
| 	// Request & error log
 | |
| 	var close_err error
 | |
| 	defer func() {
 | |
| 		if close_err != nil && !errors.Is(close_err, net.ErrClosed) {
 | |
| 			// log non-closed errors
 | |
| 			log.Printf("[%s] -> [%s] error: %v", conn.RemoteAddr(), req.URL, close_err)
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	// Do the request and send the response back
 | |
| 	resp, err := http_client.Do(req)
 | |
| 	if err != nil {
 | |
| 		close_err = err
 | |
| 		_ = simple_respond(conn, req, http.StatusBadGateway)
 | |
| 		return false
 | |
| 	}
 | |
| 
 | |
| 	remove_hop_headers(resp.Header)
 | |
| 	if keep_alive {
 | |
| 		resp.Header.Set("Connection", "keep-alive")
 | |
| 		resp.Header.Set("Proxy-Connection", "keep-alive")
 | |
| 		resp.Header.Set("Keep-Alive", "timeout=60")
 | |
| 	}
 | |
| 
 | |
| 	close_err = resp.Write(conn)
 | |
| 	return close_err == nil && keep_alive
 | |
| }
 | |
| 
 | |
// remove_hop_headers deletes hop-by-hop header fields from header in place,
// so that they are not forwarded across the proxy.
//
// Besides the fixed list below, every field nominated by a Connection
// header is removed as well (RFC 7230 section 6.1). A nil header is a no-op.
func remove_hop_headers(header http.Header) {
	// Fields named in Connection are hop-by-hop too. This has to run before
	// Connection itself is deleted below.
	for _, line := range header["Connection"] {
		for _, name := range strings.Split(line, ",") {
			if name = strings.TrimSpace(name); name != "" {
				header.Del(name)
			}
		}
	}

	for _, name := range [...]string{
		"Proxy-Connection", // Not in RFC but common
		// https://www.ietf.org/rfc/rfc2616.txt
		"Connection",
		"Keep-Alive",
		"Proxy-Authenticate",
		"Proxy-Authorization",
		"TE",
		"Trailer",  // the actual field name; RFC 2616 13.5.1 misspells it
		"Trailers", // kept for the RFC 2616 spelling
		"Transfer-Encoding",
		"Upgrade",
	} {
		header.Del(name)
	}
}
 | |
| 
 | |
// remove_extra_host_port normalizes the target host of req: it takes
// req.Host (falling back to req.URL.Host when that is empty), drops an
// explicit ":80" port, and stores the result in both req.Host and
// req.URL.Host. Any other port, or a host without a port, is kept as is.
func remove_extra_host_port(req *http.Request) {
	host := req.Host
	if host == "" {
		host = req.URL.Host
	}
	if _, port, err := net.SplitHostPort(host); err == nil && port == "80" {
		// Trim the suffix instead of using the split host: SplitHostPort
		// strips the brackets of an IPv6 literal, and "::1" without them
		// is not a valid URL host.
		host = strings.TrimSuffix(host, ":80")
	}
	req.Host = host
	req.URL.Host = host
}
 | |
| 
 | |
// simple_respond writes a body-less HTTP response carrying statusCode to
// conn, using the protocol version of req, and returns the write error.
func simple_respond(conn net.Conn, req *http.Request, statusCode int) error {
	reply := http.Response{
		Status:     http.StatusText(statusCode),
		StatusCode: statusCode,
		Proto:      req.Proto,
		ProtoMajor: req.ProtoMajor,
		ProtoMinor: req.ProtoMinor,
		Header:     make(http.Header),

		// -1 keeps "Content-Length: 0" out of the response; some clients
		// (e.g. ffmpeg) may not like it.
		ContentLength: -1,

		// Close stays false and Uncompressed is set so that no
		// "Connection: close" header is emitted either.
		Close:        false,
		Uncompressed: true,
	}
	return reply.Write(conn)
}
 |