// regolith/internal/http/server.go
// (viewer metadata: 278 lines, 6.9 KiB, Go)

package http
import (
"bufio"
"bytes"
"errors"
"io"
"log"
"net"
"net/http"
"strings"
"time"
"edgaru089.ink/go/regolith/internal/perm"
)
const (
	// outgoing_client_timeout bounds how long the proxy waits on an origin
	// server while establishing a connection and, for plain-HTTP requests,
	// while waiting for the response headers.
	outgoing_client_timeout = 10 * time.Second
)

var (
	// dialer opens TCP connections to origin servers. It is used directly
	// for CONNECT tunnels and by http_client's transport.
	dialer = net.Dialer{Timeout: outgoing_client_timeout}

	// http_client forwards plain-HTTP requests to origin servers.
	//
	// It deliberately has no Client.Timeout: that timer keeps running while
	// the response body is read, so it would cut off any download or
	// streamed response that takes longer than the timeout to relay.
	// Only connection setup and the wait for response headers are bounded.
	http_client = http.Client{
		Transport: &http.Transport{
			// Proxy, MaxIdleConns, IdleConnTimeout and ExpectContinueTimeout
			// mirror the values used by http.DefaultTransport.
			Proxy:                 http.ProxyFromEnvironment,
			DialContext:           dialer.DialContext,
			MaxIdleConns:          100,
			IdleConnTimeout:       90 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
			ResponseHeaderTimeout: outgoing_client_timeout,
			// Relay bodies exactly as the origin sent them. Transparent gzip
			// decoding would also drop Content-Length from the response
			// (see http.Response.Uncompressed).
			DisableCompression: true,
		},
		// A forward proxy must pass redirects back to the client instead of
		// following them; following them here would also fetch the redirect
		// target without any permission check on it.
		CheckRedirect: func(*http.Request, []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}
)
// Server is a forward HTTP proxy. It proxies plain requests that carry an
// absolute "http://" URL and tunnels CONNECT requests over raw TCP, for
// connections accepted by Serve.
type Server struct {
	// Perm decides which destinations a client may reach.
	// A nil Perm accepts every request.
	Perm *perm.Perm
}
// checkPerm decides whether the client at src may reach dest.
//
// src is expected to be a "host:port" remote address. Only its host part is
// passed to s.Perm.Match, together with dest. A nil s.Perm accepts
// everything. Every ActionDeny verdict is logged.
//
// NOTE(review): if src cannot be split (e.g. it has no port), an empty host
// is matched. Confirm that is what perm.Match expects.
func (s *Server) checkPerm(src, dest string) (act perm.Action) {
	if s.Perm != nil {
		host, _, _ := net.SplitHostPort(src)
		if act = s.Perm.Match(host, dest); act == perm.ActionDeny {
			log.Printf("denied: [%s] -> [%s]", src, dest)
		}
		return act
	}
	return perm.ActionAccept
}
// Serve accepts connections from listener and handles each one on its own
// goroutine. It blocks until Accept fails. If the listener was closed, it
// returns nil; any other Accept error is returned as-is.
func (s *Server) Serve(listener net.Listener) (err error) {
	for {
		var conn net.Conn
		if conn, err = listener.Accept(); err != nil {
			break
		}
		go s.dispatch(conn)
	}
	if errors.Is(err, net.ErrClosed) {
		return nil
	}
	return err
}
// dispatch serves a single accepted client connection and makes sure it is
// closed when done.
//
// Requests are read from conn in a loop:
//   - A CONNECT request turns the connection into a tunnel via
//     handle_connect. From then on handle_connect owns the connection and
//     closes it.
//   - Any other request is proxied by handle_request. The loop continues
//     only while handle_request reports that the connection may be kept
//     alive.
func (s *Server) dispatch(conn net.Conn) {
	// How long to wait for the next request before giving up on the
	// connection. This matches the "Keep-Alive: timeout=60" hint that
	// handle_request sends to keep-alive clients.
	const idle_timeout = 60 * time.Second

	buf := bufio.NewReader(conn)
	for {
		// Bound the wait for a request so that idle or stalled clients cannot
		// pin this goroutine and its socket forever. Best effort: if setting
		// the deadline fails, we simply wait as before.
		_ = conn.SetReadDeadline(time.Now().Add(idle_timeout))
		req, err := http.ReadRequest(buf)
		if err != nil {
			// Malformed request, client hung up, or idle timeout.
			_ = conn.Close()
			return
		}
		// Request bodies and tunnels may legitimately take longer than
		// idle_timeout, so lift the deadline before handing off.
		_ = conn.SetReadDeadline(time.Time{})

		if req.Method == http.MethodConnect {
			var tunnel net.Conn = conn
			if n := buf.Buffered(); n > 0 {
				// The reader already holds bytes the client sent after the
				// CONNECT request header. Replay them through a cached_conn
				// so that handle_connect forwards them first.
				data := make([]byte, n)
				if _, err := io.ReadFull(buf, data); err != nil {
					// Reading already-buffered data should not fail.
					_ = conn.Close()
					return
				}
				tunnel = &cached_conn{
					Conn:   conn,
					Buffer: *bytes.NewBuffer(data),
				}
			}
			// handle_connect only returns once the tunnel is finished and the
			// connection closed, so no further requests can follow.
			s.handle_connect(tunnel, req)
			return
		}

		// handle_request serves exactly one request and never closes conn
		// itself. It reports whether the connection may be reused.
		if !s.handle_request(conn, req) {
			_ = conn.Close()
			return
		}
	}
}
// cached_conn wraps a net.Conn so that bytes already pulled off the wire
// (held in Buffer) are returned by Read before any further data is read from
// the underlying connection. All other net.Conn methods, including Write
// and Close, go straight to the embedded Conn.
type cached_conn struct {
	net.Conn
	Buffer bytes.Buffer
}

// Read serves from Buffer while it still holds data, and only afterwards
// falls through to the wrapped connection. A single call never mixes the two
// sources, so it may return fewer bytes than len(b).
func (c *cached_conn) Read(b []byte) (int, error) {
	if c.Buffer.Len() == 0 {
		return c.Conn.Read(b)
	}
	n, err := c.Buffer.Read(b)
	if err == io.EOF {
		// An exhausted buffer is not the end of the stream: the connection
		// may still have more to say.
		err = nil
	}
	return n, err
}
// handle_connect serves a CONNECT request. It opens a TCP connection to the
// requested host:port and relays bytes between it and conn in both
// directions.
//
// It returns as soon as either direction of the relay stops (EOF or error).
// It always closes conn before returning, so callers must not close it
// again. If the target is denied by s.Perm or cannot be dialed, the client
// gets 502 Bad Gateway instead of a tunnel.
func (s *Server) handle_connect(conn net.Conn, req *http.Request) {
	defer conn.Close()

	// Fall back to port 80 when the CONNECT target carries no port.
	port := req.URL.Port()
	if port == "" {
		port = "80"
	}
	req_addr := net.JoinHostPort(req.URL.Hostname(), port)

	// check permission
	if s.checkPerm(conn.RemoteAddr().String(), req_addr) != perm.ActionAccept {
		_ = simple_respond(conn, req, http.StatusBadGateway)
		return
	}

	// dial
	remote_conn, err := dialer.Dial("tcp", req_addr)
	if err != nil {
		log.Printf("[%s] -> [%s] error dialing remote: %v", conn.RemoteAddr(), req_addr, err)
		_ = simple_respond(conn, req, http.StatusBadGateway)
		return
	}
	defer remote_conn.Close()
	log.Printf("[%s] -> [%s] connected", conn.RemoteAddr(), req_addr)

	// Tell the client the tunnel is up. If that fails the client is gone and
	// there is nothing to relay.
	if err := simple_respond(conn, req, http.StatusOK); err != nil {
		return
	}

	// Relay in both directions. The first copy to finish ends the tunnel.
	// The deferred Closes then unblock the other copy, whose error goes into
	// the buffered channel and is never read.
	err_chan := make(chan error, 2)
	go func() {
		_, err := io.Copy(remote_conn, conn)
		err_chan <- err
	}()
	go func() {
		_, err := io.Copy(conn, remote_conn)
		err_chan <- err
	}()
	if relay_err := <-err_chan; relay_err != nil && !errors.Is(relay_err, net.ErrClosed) {
		// log non-closed errors
		log.Printf("[%s] -> [%s] error relaying: %v", conn.RemoteAddr(), req_addr, relay_err)
	}
}
// handle_request proxies one non-CONNECT request to its origin server and
// writes the response back to conn.
//
// The request must carry an absolute "http://" URL, as clients send to a
// forward proxy; anything else gets 400 Bad Request. Requests denied by
// s.Perm, and requests the upstream round trip failed for, get
// 502 Bad Gateway.
//
// It returns whether conn may carry another request. It never closes conn
// itself; that is left to the caller.
func (s *Server) handle_request(conn net.Conn, req *http.Request) (should_keepalive bool) {
	// Some clients use Connection, some use Proxy-Connection
	// https://www.oreilly.com/library/view/http-the-definitive/1565925092/re40.html
	keep_alive := req.ProtoAtLeast(1, 1) &&
		(strings.EqualFold(req.Header.Get("Proxy-Connection"), "keep-alive") ||
			strings.EqualFold(req.Header.Get("Connection"), "keep-alive"))

	req.RequestURI = "" // Outgoing request should not have RequestURI
	remove_hop_headers(req.Header)
	remove_extra_host_port(req)
	if req.URL.Scheme != "http" || req.URL.Host == "" {
		_ = simple_respond(conn, req, http.StatusBadRequest)
		return false
	}

	// Check permission
	req_hostname := req.URL.Hostname()
	req_port := req.URL.Port()
	if req_port == "" {
		req_port = "80"
	}
	if s.checkPerm(conn.RemoteAddr().String(), net.JoinHostPort(req_hostname, req_port)) != perm.ActionAccept {
		_ = simple_respond(conn, req, http.StatusBadGateway)
		return false
	}

	// Request & error log
	var close_err error
	defer func() {
		if close_err != nil && !errors.Is(close_err, net.ErrClosed) {
			// log non-closed errors
			log.Printf("[%s] -> [%s] error: %v", conn.RemoteAddr(), req.URL, close_err)
		}
	}()

	// Do the request and send the response back
	resp, err := http_client.Do(req)
	if err != nil {
		close_err = err
		_ = simple_respond(conn, req, http.StatusBadGateway)
		return false
	}
	// Response.Write normally closes the body after copying it, but not when
	// it fails earlier (e.g. while writing the header), so close it here.
	defer resp.Body.Close()

	// The client can only send another request on this connection if it can
	// tell where this response ends. A body with neither a Content-Length nor
	// chunked encoding is terminated by closing the connection instead.
	chunked := len(resp.TransferEncoding) > 0 && resp.TransferEncoding[0] == "chunked"
	if resp.ContentLength < 0 && !chunked {
		keep_alive = false
	}

	remove_hop_headers(resp.Header)
	// resp.Close describes the origin connection. For the client connection
	// it must reflect our own decision, because resp.Write emits
	// "Connection: close" whenever it is set.
	resp.Close = !keep_alive
	if keep_alive {
		resp.Header.Set("Connection", "keep-alive")
		resp.Header.Set("Proxy-Connection", "keep-alive")
		resp.Header.Set("Keep-Alive", "timeout=60")
	}
	close_err = resp.Write(conn)
	return close_err == nil && keep_alive
}
// remove_hop_headers deletes hop-by-hop headers from header in place.
//
// These headers describe a single connection (client to proxy, or proxy to
// origin) and must not be forwarded. Besides the fixed list below, every
// header named in a Connection header value is hop-by-hop as well
// (RFC 7230 §6.1), so those are removed first.
func remove_hop_headers(header http.Header) {
	// Must run before "Connection" itself is deleted below.
	for _, value := range header.Values("Connection") {
		for _, name := range strings.Split(value, ",") {
			if name = strings.TrimSpace(name); name != "" {
				header.Del(name)
			}
		}
	}

	header.Del("Proxy-Connection") // Not in RFC but common
	// https://www.ietf.org/rfc/rfc2616.txt
	header.Del("Connection")
	header.Del("Keep-Alive")
	header.Del("Proxy-Authenticate")
	header.Del("Proxy-Authorization")
	header.Del("TE")
	// RFC 2616 lists "Trailers", but the header is actually named "Trailer"
	// (RFC 2616 erratum 4522). Delete both.
	header.Del("Trailer")
	header.Del("Trailers")
	header.Del("Transfer-Encoding")
	header.Del("Upgrade")
}
// remove_extra_host_port normalizes the request's target host by dropping
// an explicit default port ":80". It writes the result to both req.Host and
// req.URL.Host.
//
// req.Host is used as the source when set, otherwise req.URL.Host.
// Non-default ports and hosts without a port are left as they are.
func remove_extra_host_port(req *http.Request) {
	host := req.Host
	if host == "" {
		host = req.URL.Host
	}
	// SplitHostPort validates the "host:port" shape. The suffix is then
	// trimmed, rather than the host rebuilt from SplitHostPort's output, so
	// IPv6 literals keep their brackets ("[::1]:80" -> "[::1]", not "::1").
	if _, port, err := net.SplitHostPort(host); err == nil && port == "80" {
		host = strings.TrimSuffix(host, ":80")
	}
	req.Host = host
	req.URL.Host = host
}
// simple_respond writes a bodiless response with the given status code to
// conn, using the same protocol version as req. It returns the error from
// writing the response.
//
// The response carries no Content-Length and no "Connection: close" header.
// This relies on http.Response.Write's behavior: it emits
// "Connection: close" for an unknown-length HTTP/1.1 response unless
// Uncompressed is set.
func simple_respond(conn net.Conn, req *http.Request, statusCode int) error {
	resp := new(http.Response)
	resp.StatusCode = statusCode
	resp.Status = http.StatusText(statusCode)
	resp.Proto, resp.ProtoMajor, resp.ProtoMinor = req.Proto, req.ProtoMajor, req.ProtoMinor
	resp.Header = make(http.Header)

	// An unknown length (-1) suppresses "Content-Length: 0", which some
	// clients (e.g. ffmpeg) reject.
	resp.ContentLength = -1
	// Together with Uncompressed, keep Write from adding "Connection: close".
	resp.Close = false
	resp.Uncompressed = true

	return resp.Write(conn)
}