diff --git a/listener/listener.go b/listener/listener.go index 27bbd57..6c00a9f 100644 --- a/listener/listener.go +++ b/listener/listener.go @@ -2,6 +2,7 @@ package listener import ( "context" + "crypto/tls" "errors" "log" "net" @@ -19,6 +20,7 @@ type Listener struct { activeConns map[net.Conn]struct{} mu sync.Mutex wg sync.WaitGroup + tlsConfig *tls.Config } func NewListener(px ProxyIO) *Listener { @@ -31,6 +33,20 @@ func NewListener(px ProxyIO) *Listener { } } +func NewTLSListener(px ProxyIO, cfg *tls.Config) *Listener { + if px == nil { + panic("proxy cannot be nil") + } + if cfg == nil { + panic("tls config cannot be nil") + } + return &Listener{ + proxy: px, + activeConns: make(map[net.Conn]struct{}), + tlsConfig: cfg, + } +} + func (l *Listener) Listen(ln net.Listener) { l.ln = ln defer ln.Close() @@ -62,14 +78,24 @@ func (l *Listener) Listen(ln net.Listener) { delete(l.activeConns, conn) l.mu.Unlock() }() - err := l.proxy.Handle(conn) - if err != nil { + if err := l.handleConn(conn); err != nil { log.Printf("connection %s: %v", conn.RemoteAddr(), err) } }() } } +func (l *Listener) handleConn(conn net.Conn) error { + if l.tlsConfig != nil { + server := tls.Server(conn, l.tlsConfig) + if err := server.Handshake(); err != nil { + return err + } + conn = server + } + return l.proxy.Handle(conn) +} + func (l *Listener) GracefulShutdown(ctx context.Context) { l.ln.Close() diff --git a/listener/listener_test.go b/listener/listener_test.go index 88166fd..8a8d21f 100644 --- a/listener/listener_test.go +++ b/listener/listener_test.go @@ -2,6 +2,12 @@ package listener import ( "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "math/big" "net" "sync" "testing" @@ -125,6 +131,68 @@ func TestGracefulShutdownWaitsForConnections(t *testing.T) { } } +func generateTLSConfigs(t *testing.T) (serverCfg *tls.Config, clientCfg *tls.Config) { + t.Helper() + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatal(err) + } + cert := tls.Certificate{Certificate: [][]byte{certDER}, PrivateKey: key} + return &tls.Config{Certificates: []tls.Certificate{cert}}, + &tls.Config{InsecureSkipVerify: true} +} + +func TestTLSListenerHandshake(t *testing.T) { + serverCfg, clientCfg := generateTLSConfigs(t) + proxy := newStubProxy() + close(proxy.block) + l := NewTLSListener(proxy, serverCfg) + addr := startListener(t, l) + defer l.GracefulShutdown(context.Background()) + + conn, err := tls.Dial("tcp", addr, clientCfg) + if err != nil { + t.Fatalf("tls dial failed: %v", err) + } + defer conn.Close() + + select { + case <-proxy.called: + case <-time.After(time.Second): + t.Error("proxy Handle was not called after TLS handshake") + } +} + +func TestTLSListenerRejectsPlainConn(t *testing.T) { + serverCfg, _ := generateTLSConfigs(t) + proxy := newStubProxy() + l := NewTLSListener(proxy, serverCfg) + addr := startListener(t, l) + defer l.GracefulShutdown(context.Background()) + + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + select { + case <-proxy.called: + t.Error("expected Handle not to be called on failed handshake") + case <-time.After(200 * time.Millisecond): + } +} + func TestGracefulShutdownForcesCloseOnDeadline(t *testing.T) { proxy := newStubProxy() l := NewListener(proxy)