diff --git a/internal/config/config.go b/internal/config/config.go index debd462..0618f25 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -41,6 +41,11 @@ type Config struct { TLS TLSConfig `yaml:"tls"` Kubernetes *KubernetesConfig `yaml:"kubernetes,omitempty"` SSH *SSHConfig `yaml:"ssh,omitempty"` + WebApp *WebAppConfig `yaml:"webApp,omitempty"` +} + +type WebAppConfig struct { + Headers map[string]string `yaml:"headers,omitempty"` } type TwingateConfig struct { @@ -269,8 +274,8 @@ func (c *Config) Validate() error { } // Check that at least one protocol is configured - if c.Kubernetes == nil && c.SSH == nil { - return fmt.Errorf("%w: at least one protocol (Kubernetes or SSH) must be configured", ErrRequired) + if c.Kubernetes == nil && c.SSH == nil && c.WebApp == nil { + return fmt.Errorf("%w: at least one protocol (Kubernetes, SSH, or WebApp) must be configured", ErrRequired) } return nil diff --git a/internal/config/config_test.go b/internal/config/config_test.go index b6898e4..9d5295f 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -304,6 +304,9 @@ func TestConfig_Validate(t *testing.T) { PrivateKeyFile: "tls.key", }, Kubernetes: &KubernetesConfig{}, + WebApp: &WebAppConfig{ + Headers: map[string]string{"Authorization": "Bearer {{twingate.jwt}}"}, + }, }, wantErr: false, }, diff --git a/internal/connect/conn.go b/internal/connect/conn.go index 2534765..88d915b 100644 --- a/internal/connect/conn.go +++ b/internal/connect/conn.go @@ -32,13 +32,6 @@ func httpResponseString(httpCode int) string { return fmt.Sprintf("HTTP/1.1 %d %s\r\n\r\n", httpCode, http.StatusText(httpCode)) } -type TransportProtocol int - -const ( - TransportTLS TransportProtocol = iota - TransportSSH -) - // Conn is a custom connection that wraps the underlying TCP net.Conn, handling downstream // proxy (Twingate Client)'s authentication via the initial CONNECT message. It handles 2 TLS // upgrades: with downstream proxy and then optionally with downstream client e.g. `kubectl`. @@ -47,8 +40,8 @@ type Conn interface { GATClaims() *token.GATClaims GetID() string GetAddress() string + GetToken() string Authenticate() error - TransportProtocol() TransportProtocol UpgradeToTLS() error Close() error @@ -64,6 +57,7 @@ type ProxyConn struct { ID string Address string Claims *token.GATClaims + Token string Timer *time.Timer Mu sync.Mutex @@ -97,14 +91,6 @@ func (p *ProxyConn) Close() error { return p.Conn.Close() } -func (p *ProxyConn) TransportProtocol() TransportProtocol { - if p.GATClaims().Resource.Type == token.ResourceTypeSSH { - return TransportSSH - } - - return TransportTLS -} - func (p *ProxyConn) GATClaims() *token.GATClaims { return p.Claims } @@ -117,6 +103,10 @@ func (p *ProxyConn) GetAddress() string { return p.Address } +func (p *ProxyConn) GetToken() string { + return p.Token +} + // Authenticate sets up TLS and processes the CONNECT message for authentication. func (p *ProxyConn) Authenticate() error { _ = p.SetDeadline(time.Now().Add(defaultTimeout)) @@ -225,7 +215,10 @@ func (p *ProxyConn) Authenticate() error { p.tracker.RecordConnectMetrics(httpCode) - p.Logger.Info("Authenticated connection", zap.String("resource_address", connectInfo.Claims.Resource.Address)) + p.Logger.Info("Authenticated connection", + zap.String("resource_type", string(connectInfo.Claims.Resource.Type)), + zap.String("resource_address", connectInfo.Claims.Resource.Address), + ) p.setConnectInfo(connectInfo) return nil @@ -249,6 +242,7 @@ func (p *ProxyConn) setConnectInfo(connectInfo Info) { p.ID = connectInfo.ConnID p.Address = connectInfo.Address p.Claims = connectInfo.Claims + p.Token = connectInfo.Token p.Timer = time.AfterFunc(time.Until(connectInfo.Claims.ExpiresAt.Time), func() { _ = p.Close() }) diff --git a/internal/connect/conn_test.go b/internal/connect/conn_test.go index 84e2787..0bc06c3 100644 --- a/internal/connect/conn_test.go +++ b/internal/connect/conn_test.go @@ -83,33 +83,6 @@ func TestProxyConn_setConnectInfo(t *testing.T) { }) } -func TestProxyConn_TransportProtocol(t *testing.T) { - tests := []struct { - resourceType token.ResourceType - expected TransportProtocol - }{ - { - resourceType: token.ResourceTypeKubernetes, - expected: TransportTLS, - }, - { - resourceType: token.ResourceTypeSSH, - expected: TransportSSH, - }, - } - - for _, tt := range tests { - t.Run(tt.resourceType, func(t *testing.T) { - claims := &token.GATClaims{ - Resource: token.Resource{Type: tt.resourceType}, - } - proxyConn := &ProxyConn{Claims: claims} - - assert.Equal(t, tt.expected, proxyConn.TransportProtocol()) - }) - } -} - func TestProxyConn_Close(t *testing.T) { conn := &mockConn{} timer := time.NewTimer(0 * time.Millisecond) diff --git a/internal/connect/connect.go b/internal/connect/connect.go index 8ad610e..acd90d7 100644 --- a/internal/connect/connect.go +++ b/internal/connect/connect.go @@ -28,6 +28,7 @@ type Info struct { Address string Claims *token.GATClaims ConnID string + Token string } type HTTPError struct { @@ -156,5 +157,6 @@ func (v *MessageValidator) ParseConnect(req *http.Request, ekm []byte) (connectI Address: address, Claims: gatClaims, ConnID: connID, + Token: bearerToken, }, nil } diff --git a/internal/connect/connect_test.go b/internal/connect/connect_test.go index 5d8e868..0c38faa 100644 --- a/internal/connect/connect_test.go +++ b/internal/connect/connect_test.go @@ -110,6 +110,7 @@ func TestConnectValidator_ParseConnect(t *testing.T) { require.NoError(t, err) assert.Equal(t, *connectInfo.Claims, gatClaims) assert.Equal(t, "conn-id", connectInfo.ConnID) + assert.Equal(t, signedToken, connectInfo.Token) }) t.Run("Non-CONNECT method", func(t *testing.T) { @@ -127,6 +128,7 @@ func TestConnectValidator_ParseConnect(t *testing.T) { assert.Contains(t, httpErr.Error(), "expected CONNECT request") assert.Nil(t, connectInfo.Claims) assert.Empty(t, connectInfo.ConnID) + assert.Empty(t, connectInfo.Token) }) t.Run("Missing auth header", func(t *testing.T) { diff --git a/internal/connect/listener.go b/internal/connect/listener.go index c66bf62..f7a8cd0 100644 --- a/internal/connect/listener.go +++ b/internal/connect/listener.go @@ -70,7 +70,7 @@ func (l *ProtocolListener) Addr() net.Addr { type ConnFactory func(net.Conn, *tls.Config, Validator, *zap.Logger) Conn type Listener struct { - channels map[TransportProtocol]chan<- Conn + channels map[token.ResourceType]chan<- Conn tokenParser *token.Parser certReloader *CertReloader @@ -88,7 +88,7 @@ type Listener struct { func NewListener( twingateConfig config.TwingateConfig, tlsCfg config.TLSConfig, - channels map[TransportProtocol]chan<- Conn, + channels map[token.ResourceType]chan<- Conn, registry *prometheus.Registry, logger *zap.Logger, ) (*Listener, error) { @@ -172,19 +172,18 @@ func (l *Listener) Serve(ctx context.Context, listener net.Listener) error { return } - tp := proxyConn.TransportProtocol() - channel, exists := l.channels[tp] + resourceType := proxyConn.GATClaims().Resource.Type + channel, exists := l.channels[resourceType] if !exists { - l.logger.Error("Unsupported transport protocol", zap.Int("transport", int(tp))) + l.logger.Error("Unsupported resource type", zap.String("resource_type", string(resourceType))) _ = proxyConn.Close() return } - // For non-SSH protocols, upgrade to TLS - if tp != TransportSSH { + if proxyConn.GATClaims().ShouldUpgradeTLS() { if err := proxyConn.UpgradeToTLS(); err != nil { l.logger.Error("Failed to upgrade to TLS", zap.Error(err)) diff --git a/internal/connect/listener_test.go b/internal/connect/listener_test.go index f674e35..0044ba8 100644 --- a/internal/connect/listener_test.go +++ b/internal/connect/listener_test.go @@ -25,8 +25,7 @@ import ( type mockProxyConn struct { net.Conn - transportProtocol TransportProtocol - Claims *token.GATClaims + Claims *token.GATClaims isClosed atomic.Bool @@ -44,10 +43,6 @@ func (m *mockProxyConn) IsClosed() bool { return m.isClosed.Load() } -func (m *mockProxyConn) TransportProtocol() TransportProtocol { - return m.transportProtocol -} - func (m *mockProxyConn) GATClaims() *token.GATClaims { return m.Claims } @@ -60,6 +55,10 @@ func (m *mockProxyConn) GetAddress() string { return "mock" } +func (m *mockProxyConn) GetToken() string { + return "" +} + func (m *mockProxyConn) Authenticate() error { if m.isHealthz { // write health check response @@ -89,16 +88,20 @@ func createMockListener(t *testing.T) (net.Listener, string) { return listener, addr } -var listenerClaims = &token.GATClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), - }, - User: token.User{ - ID: "user-1", - Username: "user@acme.com", - Groups: []string{"Everyone", "Engineering"}, - }, - Resource: token.Resource{ID: "resource-1", Type: token.ResourceTypeKubernetes, Address: "https://api.acme.com"}, +func createClaims(t *testing.T, resourceType token.ResourceType) *token.GATClaims { + t.Helper() + + return &token.GATClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), + }, + User: token.User{ + ID: "user-1", + Username: "user@acme.com", + Groups: []string{"Everyone", "Engineering"}, + }, + Resource: token.Resource{ID: "resource-1", Type: resourceType, Address: "https://api.acme.com"}, + } } type testListenerFixtures struct { @@ -117,9 +120,9 @@ func createTestListenerWithChannels(t *testing.T) *testListenerFixtures { // Create channels for testing httpChannel := make(chan Conn, 1) sshChannel := make(chan Conn, 1) - channels := map[TransportProtocol]chan<- Conn{ - TransportTLS: httpChannel, - TransportSSH: sshChannel, + channels := map[token.ResourceType]chan<- Conn{ + token.ResourceTypeKubernetes: httpChannel, + token.ResourceTypeSSH: sshChannel, } registry := prometheus.NewRegistry() @@ -146,11 +149,12 @@ func createTestListenerWithChannels(t *testing.T) *testListenerFixtures { func TestListener_Serve_HTTPS(t *testing.T) { fixtures := createTestListenerWithChannels(t) + kubernetesClaims := createClaims(t, token.ResourceTypeKubernetes) + fixtures.listener.proxyConnFactory = func(conn net.Conn, _ *tls.Config, _ Validator, _ *zap.Logger) Conn { return &mockProxyConn{ - Conn: conn, - transportProtocol: TransportTLS, - Claims: listenerClaims, + Conn: conn, + Claims: kubernetesClaims, } } @@ -182,8 +186,7 @@ func TestListener_Serve_HTTPS(t *testing.T) { select { case conn := <-fixtures.httpChannel: require.False(t, conn.(*mockProxyConn).IsClosed()) - require.Equal(t, TransportTLS, conn.TransportProtocol()) - require.Equal(t, listenerClaims, conn.GATClaims()) + require.Equal(t, kubernetesClaims, conn.GATClaims()) case <-time.After(1 * time.Second): t.Fatal("timeout waiting for HTTP connection") } @@ -198,11 +201,12 @@ func TestListener_Serve_HTTPS(t *testing.T) { func TestListener_Serve_SSH(t *testing.T) { fixtures := createTestListenerWithChannels(t) + sshClaims := createClaims(t, token.ResourceTypeSSH) + fixtures.listener.proxyConnFactory = func(conn net.Conn, _ *tls.Config, _ Validator, _ *zap.Logger) Conn { return &mockProxyConn{ - Conn: conn, - transportProtocol: TransportSSH, - Claims: listenerClaims, + Conn: conn, + Claims: sshClaims, } } @@ -234,8 +238,7 @@ func TestListener_Serve_SSH(t *testing.T) { select { case conn := <-fixtures.sshChannel: require.False(t, conn.(*mockProxyConn).IsClosed()) - require.Equal(t, TransportSSH, conn.TransportProtocol()) - require.Equal(t, listenerClaims, conn.GATClaims()) + require.Equal(t, sshClaims, conn.GATClaims()) case <-time.After(1 * time.Second): t.Fatal("timeout waiting for SSH connection") } @@ -321,13 +324,13 @@ func TestListener_Serve_Healthz(t *testing.T) { waitGroup.Wait() } -func TestListener_UnsupportedTransport(t *testing.T) { +func TestListener_UnsupportedResourceType(t *testing.T) { tcpListener, addr := createMockListener(t) // Create channels but omit one transport type httpChannel := make(chan Conn, 1) - channels := map[TransportProtocol]chan<- Conn{ - TransportTLS: httpChannel, + channels := map[token.ResourceType]chan<- Conn{ + token.ResourceTypeKubernetes: httpChannel, // SSH not included - unsupported } @@ -342,13 +345,14 @@ func TestListener_UnsupportedTransport(t *testing.T) { certReloader: certReloader, } + sshClaims := createClaims(t, token.ResourceTypeSSH) + // Use a channel to safely pass the connection from the factory connCreated := make(chan *mockProxyConn, 1) listener.proxyConnFactory = func(conn net.Conn, _ *tls.Config, _ Validator, _ *zap.Logger) Conn { mockConn := &mockProxyConn{ - Conn: conn, - transportProtocol: TransportSSH, // This transport is not supported - Claims: listenerClaims, + Conn: conn, + Claims: sshClaims, } connCreated <- mockConn @@ -383,12 +387,14 @@ func TestListener_Serve_GracefulShutdown(t *testing.T) { // Use unbuffered channel so the goroutine inside Serve blocks on send httpChannel := make(chan Conn) - channels := map[TransportProtocol]chan<- Conn{ - TransportTLS: httpChannel, + channels := map[token.ResourceType]chan<- Conn{ + token.ResourceTypeKubernetes: httpChannel, } logger := zap.NewNop() + kubernetesClaims := createClaims(t, token.ResourceTypeKubernetes) + listener := &Listener{ channels: channels, logger: logger, @@ -398,9 +404,8 @@ func TestListener_Serve_GracefulShutdown(t *testing.T) { listener.proxyConnFactory = func(conn net.Conn, _ *tls.Config, _ Validator, _ *zap.Logger) Conn { return &mockProxyConn{ - Conn: conn, - transportProtocol: TransportTLS, - Claims: listenerClaims, + Conn: conn, + Claims: kubernetesClaims, } } @@ -457,9 +462,10 @@ func TestProtocolListener(t *testing.T) { assert.Equal(t, addr, protocolListener.Addr().String()) // Send a mock connection + kubernetesClaims := createClaims(t, token.ResourceTypeKubernetes) + mockConn := &mockProxyConn{ - transportProtocol: TransportTLS, - Claims: listenerClaims, + Claims: kubernetesClaims, } ch <- mockConn diff --git a/internal/httpproxy/proxy.go b/internal/httpproxy/proxy.go index 51d9cf0..c3609e5 100644 --- a/internal/httpproxy/proxy.go +++ b/internal/httpproxy/proxy.go @@ -9,7 +9,6 @@ import ( "net/http" "time" - "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" "gateway/internal/connect" @@ -28,9 +27,10 @@ func ProxyConnFromContext(ctx context.Context) *connect.ProxyConn { } type Config struct { - Handler http.Handler - Registry *prometheus.Registry - Logger *zap.Logger + Handler http.Handler + Metrics *metrics.HTTPMetrics + Logger *zap.Logger + ResourceType metrics.ResourceType } type Proxy struct { @@ -38,13 +38,10 @@ type Proxy struct { } func NewProxy(cfg Config) *Proxy { - handler := metrics.HTTPMiddleware(metrics.HTTPMiddlewareConfig{ - Registry: cfg.Registry, - Next: auditMiddleware(auditMiddlewareConfig{ - next: cfg.Handler, - logger: cfg.Logger, - }), - }) + handler := metrics.HTTPMiddleware(cfg.Metrics, cfg.ResourceType, auditMiddleware(auditMiddlewareConfig{ + next: cfg.Handler, + logger: cfg.Logger, + })) mux := http.NewServeMux() mux.Handle("/", handler) diff --git a/internal/httpproxy/proxy_test.go b/internal/httpproxy/proxy_test.go index 22fce14..5ede9e0 100644 --- a/internal/httpproxy/proxy_test.go +++ b/internal/httpproxy/proxy_test.go @@ -16,6 +16,7 @@ import ( "go.uber.org/zap" "gateway/internal/connect" + "gateway/internal/metrics" "gateway/internal/token" ) @@ -33,7 +34,10 @@ func (l *mockConnListener) Accept() (net.Conn, error) { proxyConn := connect.NewProxyConn(conn, nil, nil, zap.NewNop(), connMetrics) proxyConn.ID = "test-conn" proxyConn.Address = "localhost" - proxyConn.Claims = &token.GATClaims{User: token.User{Username: "test@acme.com"}} + proxyConn.Claims = &token.GATClaims{ + User: token.User{Username: "test@acme.com"}, + Resource: token.Resource{Type: token.ResourceTypeKubernetes}, + } return proxyConn, nil } @@ -65,9 +69,10 @@ func TestProxy_ForwardRequest(t *testing.T) { require.NoError(t, err) proxy := NewProxy(Config{ - Registry: prometheus.NewRegistry(), - Handler: handler, - Logger: zap.NewNop(), + Metrics: metrics.RegisterHTTPMetrics(prometheus.NewRegistry()), + ResourceType: metrics.ResourceTypeKubernetes, + Handler: handler, + Logger: zap.NewNop(), }) go func() { @@ -99,9 +104,10 @@ func TestProxy_Shutdown(t *testing.T) { require.NoError(t, err) proxy := NewProxy(Config{ - Handler: handler, - Registry: prometheus.NewRegistry(), - Logger: zap.NewNop(), + Metrics: metrics.RegisterHTTPMetrics(prometheus.NewRegistry()), + ResourceType: metrics.ResourceTypeKubernetes, + Handler: handler, + Logger: zap.NewNop(), }) done := make(chan error, 1) diff --git a/internal/kuberneteshandler/kubernetes.go b/internal/kuberneteshandler/kubernetes.go index 433dcd4..ffd6e94 100644 --- a/internal/kuberneteshandler/kubernetes.go +++ b/internal/kuberneteshandler/kubernetes.go @@ -44,7 +44,7 @@ func NewHandler(cfg Config) (*Handler, error) { conn := httpproxy.ProxyConnFromContext(r.In.Context()) rewrite(r, conn) }, - Transport: metrics.InstrumentRoundTripper(cfg.roundTripperMetrics, transport), + Transport: metrics.InstrumentRoundTripper(cfg.roundTripperMetrics, metrics.ResourceTypeKubernetes, transport), } handler := &Handler{ diff --git a/internal/metrics/http_middleware.go b/internal/metrics/http_middleware.go index ee82590..d71626f 100644 --- a/internal/metrics/http_middleware.go +++ b/internal/metrics/http_middleware.go @@ -28,71 +28,84 @@ const ( requestTypeUnknown = "unknown" ) -type HTTPMiddlewareConfig struct { - Registry *prometheus.Registry - Next http.Handler +type contextKey struct{} + +type HTTPMetrics struct { + requestsTotal *prometheus.CounterVec + activeRequests *prometheus.GaugeVec + requestDuration *prometheus.HistogramVec + requestSizeBytes *prometheus.HistogramVec + responseSizeBytes *prometheus.HistogramVec } -type contextKey struct{} +func RegisterHTTPMetrics(registry *prometheus.Registry) *HTTPMetrics { + m := &HTTPMetrics{ + requestsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: Namespace, + Name: "http_requests_total", + Help: "Total number of HTTP requests processed", + }, []string{labelResourceType, "type", "method", "code"}), -func HTTPMiddleware(config HTTPMiddlewareConfig) http.HandlerFunc { - requestsTotal := prometheus.NewCounterVec(prometheus.CounterOpts{ - Namespace: Namespace, - Name: "http_requests_total", - Help: "Total number of HTTP requests processed", - }, []string{"type", "method", "code"}) - - activeRequests := prometheus.NewGaugeVec(prometheus.GaugeOpts{ - Namespace: Namespace, - Name: "http_active_requests", - Help: "Number of currently active HTTP requests", - }, []string{"type"}) - - requestDuration := prometheus.NewHistogramVec( - prometheus.HistogramOpts{ + activeRequests: prometheus.NewGaugeVec(prometheus.GaugeOpts{ Namespace: Namespace, - Name: "http_request_duration_seconds", - Help: "Latencies of HTTP requests in seconds", - Buckets: []float64{0.1, 0.25, 0.5, 1, 2, 5, 10, 30, 60, 120, 300, 600, 1800, 3600}, - }, []string{"type", "method", "code"}) - - requestSizeBytes := prometheus.NewHistogramVec(prometheus.HistogramOpts{ - Namespace: Namespace, - Name: "http_request_size_bytes", - Help: "Size of incoming HTTP request in bytes", - Buckets: prometheus.ExponentialBuckets(100, 10, 6), - }, []string{"type", "method", "code"}) - - responseSizeBytes := prometheus.NewHistogramVec(prometheus.HistogramOpts{ - Namespace: Namespace, - Name: "http_response_size_bytes", - Help: "Size of outgoing HTTP response in bytes", - Buckets: prometheus.ExponentialBuckets(100, 10, 6), - }, []string{"type", "method", "code"}, - ) + Name: "http_active_requests", + Help: "Number of currently active HTTP requests", + }, []string{labelResourceType, "type"}), + + requestDuration: prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: Namespace, + Name: "http_request_duration_seconds", + Help: "Latencies of HTTP requests in seconds", + Buckets: []float64{0.1, 0.25, 0.5, 1, 2, 5, 10, 30, 60, 120, 300, 600, 1800, 3600}, + }, []string{labelResourceType, "type", "method", "code"}), + + requestSizeBytes: prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: Namespace, + Name: "http_request_size_bytes", + Help: "Size of incoming HTTP request in bytes", + Buckets: prometheus.ExponentialBuckets(100, 10, 6), + }, []string{labelResourceType, "type", "method", "code"}), - config.Registry.MustRegister(requestsTotal, activeRequests, requestDuration, requestSizeBytes, responseSizeBytes) + responseSizeBytes: prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: Namespace, + Name: "http_response_size_bytes", + Help: "Size of outgoing HTTP response in bytes", + Buckets: prometheus.ExponentialBuckets(100, 10, 6), + }, []string{labelResourceType, "type", "method", "code"}), + } + + registry.MustRegister(m.requestsTotal, m.activeRequests, m.requestDuration, m.requestSizeBytes, m.responseSizeBytes) + + return m +} - opts := promhttp.WithLabelFromCtx(labelRequestType, getRequestTypeFromContext) +func HTTPMiddleware(metrics *HTTPMetrics, resourceType ResourceType, next http.Handler) http.HandlerFunc { + resourceTypeOpt := promhttp.WithLabelFromCtx(labelResourceType, func(_ context.Context) string { return string(resourceType) }) + requestTypeOpt := promhttp.WithLabelFromCtx(labelRequestType, getRequestTypeFromContext) base := promhttp.InstrumentHandlerCounter( - requestsTotal, - instrumentHandlerInFlight(activeRequests, + metrics.requestsTotal, + instrumentHandlerInFlight(metrics.activeRequests, string(resourceType), promhttp.InstrumentHandlerDuration( - requestDuration, + metrics.requestDuration, promhttp.InstrumentHandlerRequestSize( - requestSizeBytes, + metrics.requestSizeBytes, promhttp.InstrumentHandlerResponseSize( - responseSizeBytes, - config.Next, - opts, + metrics.responseSizeBytes, + next, + resourceTypeOpt, + requestTypeOpt, ), - opts, + resourceTypeOpt, + requestTypeOpt, ), - opts, + resourceTypeOpt, + requestTypeOpt, ), ), - opts, + resourceTypeOpt, + requestTypeOpt, ) return func(w http.ResponseWriter, r *http.Request) { @@ -100,12 +113,12 @@ func HTTPMiddleware(config HTTPMiddlewareConfig) http.HandlerFunc { } } -func instrumentHandlerInFlight(activeRequests *prometheus.GaugeVec, next http.Handler) http.Handler { +func instrumentHandlerInFlight(activeRequests *prometheus.GaugeVec, resourceType string, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestType := getRequestTypeFromContext(r.Context()) - activeRequests.WithLabelValues(requestType).Inc() - defer activeRequests.WithLabelValues(requestType).Dec() + activeRequests.WithLabelValues(resourceType, requestType).Inc() + defer activeRequests.WithLabelValues(resourceType, requestType).Dec() next.ServeHTTP(w, r) }) diff --git a/internal/metrics/http_middleware_test.go b/internal/metrics/http_middleware_test.go index e441206..9084d49 100644 --- a/internal/metrics/http_middleware_test.go +++ b/internal/metrics/http_middleware_test.go @@ -108,14 +108,14 @@ func TestWithRequestType(t *testing.T) { func TestHTTPMiddleware(t *testing.T) { testRegistry := prometheus.NewRegistry() + httpMetrics := RegisterHTTPMetrics(testRegistry) server := httptest.NewServer(HTTPMiddleware( - HTTPMiddlewareConfig{ - Registry: testRegistry, - Next: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - }), - }, + httpMetrics, + ResourceTypeKubernetes, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }), )) defer server.Close() @@ -133,27 +133,32 @@ func TestHTTPMiddleware(t *testing.T) { labelsByMetric := testutil.ExtractLabelsFromMetrics(metricFamilies) expectedLabels := map[string]map[string]string{ "twingate_gateway_http_requests_total": { - "type": "http", - "method": "get", - "code": "200", + "resource_type": "kubernetes", + "type": "http", + "method": "get", + "code": "200", }, "twingate_gateway_http_active_requests": { - "type": "http", + "resource_type": "kubernetes", + "type": "http", }, "twingate_gateway_http_request_duration_seconds": { - "type": "http", - "method": "get", - "code": "200", + "resource_type": "kubernetes", + "type": "http", + "method": "get", + "code": "200", }, "twingate_gateway_http_request_size_bytes": { - "type": "http", - "method": "get", - "code": "200", + "resource_type": "kubernetes", + "type": "http", + "method": "get", + "code": "200", }, "twingate_gateway_http_response_size_bytes": { - "type": "http", - "method": "get", - "code": "200", + "resource_type": "kubernetes", + "type": "http", + "method": "get", + "code": "200", }, } assert.Equal(t, expectedLabels, labelsByMetric) diff --git a/internal/metrics/round_tripper.go b/internal/metrics/round_tripper.go index 066f6cb..af1955b 100644 --- a/internal/metrics/round_tripper.go +++ b/internal/metrics/round_tripper.go @@ -4,12 +4,25 @@ package metrics import ( + "context" "net/http" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" ) +// Metric label names. + +const labelResourceType = "resource_type" + +// ResourceType identifies the upstream resource type in metric labels. +type ResourceType string + +const ( + ResourceTypeKubernetes ResourceType = "kubernetes" + ResourceTypeWebApp ResourceType = "web_app" +) + type RoundTripperMetrics struct { requestsTotal *prometheus.CounterVec activeRequests *prometheus.GaugeVec @@ -21,22 +34,22 @@ func RegisterRoundTripperMetrics(registry *prometheus.Registry) *RoundTripperMet requestsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{ Namespace: Namespace, Name: "api_server_requests_total", - Help: "Total number of requests from Gateway to API Server processed", - }, []string{"type", "method", "code"}), + Help: "Total number of requests from Gateway to HTTP server processed", + }, []string{labelResourceType, "type", "method", "code"}), activeRequests: prometheus.NewGaugeVec(prometheus.GaugeOpts{ Namespace: Namespace, Name: "api_server_active_requests", - Help: "Number of currently active requests from Gateway to API Server", - }, []string{"type"}), + Help: "Number of currently active requests from Gateway to HTTP server", + }, []string{labelResourceType, "type"}), requestDuration: prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: Namespace, Name: "api_server_request_duration_seconds", - Help: "Measures the initial HTTP request-response latency between Gateway and API Server in seconds. For HTTP streaming, WebSocket, and SPDY connections, this metric captures only the setup time and not the duration of the data transfer.", + Help: "Measures the initial HTTP request-response latency between Gateway and HTTP server in seconds. For HTTP streaming, WebSocket, and SPDY connections, this metric captures only the setup time and not the duration of the data transfer.", Buckets: prometheus.DefBuckets, - }, []string{"type", "method", "code"}), + }, []string{labelResourceType, "type", "method", "code"}), } registry.MustRegister(c.requestsTotal, c.activeRequests, c.requestDuration) @@ -44,20 +57,24 @@ func RegisterRoundTripperMetrics(registry *prometheus.Registry) *RoundTripperMet return c } -func InstrumentRoundTripper(metrics *RoundTripperMetrics, next http.RoundTripper) promhttp.RoundTripperFunc { - opts := promhttp.WithLabelFromCtx(labelRequestType, getRequestTypeFromContext) +func InstrumentRoundTripper(metrics *RoundTripperMetrics, resourceType ResourceType, next http.RoundTripper) promhttp.RoundTripperFunc { + resourceTypeOpt := promhttp.WithLabelFromCtx(labelResourceType, func(_ context.Context) string { return string(resourceType) }) + requestTypeOpt := promhttp.WithLabelFromCtx(labelRequestType, getRequestTypeFromContext) base := promhttp.InstrumentRoundTripperCounter( metrics.requestsTotal, instrumentRoundTripperInFlight( metrics.activeRequests, + resourceType, promhttp.InstrumentRoundTripperDuration( metrics.requestDuration, next, - opts, + resourceTypeOpt, + requestTypeOpt, ), ), - opts, + resourceTypeOpt, + requestTypeOpt, ) return func(r *http.Request) (*http.Response, error) { @@ -65,12 +82,12 @@ func InstrumentRoundTripper(metrics *RoundTripperMetrics, next http.RoundTripper } } -func instrumentRoundTripperInFlight(activeRequests *prometheus.GaugeVec, next http.RoundTripper) promhttp.RoundTripperFunc { +func instrumentRoundTripperInFlight(activeRequests *prometheus.GaugeVec, resourceType ResourceType, next http.RoundTripper) promhttp.RoundTripperFunc { return func(r *http.Request) (*http.Response, error) { requestType := getRequestTypeFromContext(r.Context()) - activeRequests.WithLabelValues(requestType).Inc() - defer activeRequests.WithLabelValues(requestType).Dec() + activeRequests.WithLabelValues(string(resourceType), requestType).Inc() + defer activeRequests.WithLabelValues(string(resourceType), requestType).Dec() return next.RoundTrip(r) } diff --git a/internal/metrics/round_tripper_test.go b/internal/metrics/round_tripper_test.go index 62c60bb..6b05052 100644 --- a/internal/metrics/round_tripper_test.go +++ b/internal/metrics/round_tripper_test.go @@ -23,7 +23,7 @@ func TestInstrumentRoundTripper(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) - transport := InstrumentRoundTripper(collectors, promhttp.RoundTripperFunc(func(r *http.Request) (*http.Response, error) { + transport := InstrumentRoundTripper(collectors, ResourceTypeKubernetes, promhttp.RoundTripperFunc(func(r *http.Request) (*http.Response, error) { return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody, Request: r}, nil })) @@ -38,17 +38,20 @@ func TestInstrumentRoundTripper(t *testing.T) { labelsByMetric := testutil.ExtractLabelsFromMetrics(metricFamilies) expectedLabels := map[string]map[string]string{ "twingate_gateway_api_server_requests_total": { - "type": "http", - "method": "get", - "code": "200", + "resource_type": "kubernetes", + "type": "http", + "method": "get", + "code": "200", }, "twingate_gateway_api_server_active_requests": { - "type": "http", + "resource_type": "kubernetes", + "type": "http", }, "twingate_gateway_api_server_request_duration_seconds": { - "type": "http", - "method": "get", - "code": "200", + "resource_type": "kubernetes", + "type": "http", + "method": "get", + "code": "200", }, } assert.Equal(t, expectedLabels, labelsByMetric) @@ -64,17 +67,17 @@ func TestInstrumentRoundTripper_MultipleTransports(t *testing.T) { }) // Instrumenting multiple transports should not panic - transport1 := InstrumentRoundTripper(collectors, mockTransport) - transport2 := InstrumentRoundTripper(collectors, mockTransport) + k8sTransport := InstrumentRoundTripper(collectors, ResourceTypeKubernetes, mockTransport) + webAppTransport := InstrumentRoundTripper(collectors, ResourceTypeWebApp, mockTransport) req := httptest.NewRequest(http.MethodGet, "/", nil) - resp1, err := transport1.RoundTrip(req) + resp1, err := k8sTransport.RoundTrip(req) require.NoError(t, err) defer resp1.Body.Close() - resp2, err := transport2.RoundTrip(req) + resp2, err := webAppTransport.RoundTrip(req) require.NoError(t, err) defer resp2.Body.Close() diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 5aaee99..3e068b8 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -25,6 +25,8 @@ import ( "gateway/internal/metrics" "gateway/internal/sessionrecorder" "gateway/internal/sshhandler" + "gateway/internal/token" + "gateway/internal/webapphandler" ) const shutdownTimeout = 30 * time.Second @@ -34,7 +36,7 @@ type Proxy struct { registry *prometheus.Registry logger *zap.Logger - httpProxy *httpproxy.Proxy + httpProxies map[token.ResourceType]*httpproxy.Proxy sshProxy *sshhandler.SSHProxy metricsServer *metrics.Server @@ -43,11 +45,19 @@ type Proxy struct { } func NewProxy(config *gatewayconfig.Config, registry *prometheus.Registry, logger *zap.Logger) (*Proxy, error) { - var httpProxy *httpproxy.Proxy + httpProxies := make(map[token.ResourceType]*httpproxy.Proxy) - if config.Kubernetes != nil { - roundTripperMetrics := metrics.RegisterRoundTripperMetrics(registry) + var ( + roundTripperMetrics *metrics.RoundTripperMetrics + httpMetrics *metrics.HTTPMetrics + ) + + if config.Kubernetes != nil || config.WebApp != nil { + roundTripperMetrics = metrics.RegisterRoundTripperMetrics(registry) + httpMetrics = metrics.RegisterHTTPMetrics(registry) + } + if config.Kubernetes != nil { k8sConfig, err := kuberneteshandler.NewConfig(&config.AuditLog, config.Kubernetes, roundTripperMetrics, logger) if err != nil { return nil, fmt.Errorf("failed to create Kubernetes config: %w", err) @@ -58,10 +68,27 @@ func NewProxy(config *gatewayconfig.Config, registry *prometheus.Registry, logge return nil, fmt.Errorf("failed to create Kubernetes handler: %w", err) } - httpProxy = httpproxy.NewProxy(httpproxy.Config{ - Handler: k8sHandler, - Registry: registry, - Logger: logger, + httpProxies[token.ResourceTypeKubernetes] = httpproxy.NewProxy(httpproxy.Config{ + Handler: k8sHandler, + Metrics: httpMetrics, + Logger: logger, + ResourceType: metrics.ResourceTypeKubernetes, + }) + } + + if config.WebApp != nil { + webAppCfg, err := webapphandler.NewConfig(config.WebApp.Headers, roundTripperMetrics, logger) + if err != nil { + return nil, fmt.Errorf("failed to create web app config: %w", err) + } + + webAppHandler := webapphandler.NewHandler(*webAppCfg) + + httpProxies[token.ResourceTypeWebApp] = httpproxy.NewProxy(httpproxy.Config{ + Handler: webAppHandler, + Metrics: httpMetrics, + Logger: logger, + ResourceType: metrics.ResourceTypeWebApp, }) } @@ -87,7 +114,7 @@ func NewProxy(config *gatewayconfig.Config, registry *prometheus.Registry, logge registry: registry, logger: logger, - httpProxy: httpProxy, + httpProxies: httpProxies, sshProxy: sshProxy, metricsServer: metricsServer, }, nil @@ -104,22 +131,22 @@ func (p *Proxy) Start() error { p.listener = listener - channels := make(map[connect.TransportProtocol]chan<- connect.Conn) + channels := make(map[token.ResourceType]chan<- connect.Conn) var sshListener *connect.ProtocolListener if p.sshProxy != nil { sshChannel := make(chan connect.Conn) - channels[connect.TransportSSH] = sshChannel + channels[token.ResourceTypeSSH] = sshChannel sshListener = connect.NewProtocolListener(sshChannel, listener.Addr()) } - var httpListener *connect.ProtocolListener + httpListeners := make(map[token.ResourceType]*connect.ProtocolListener) - if p.httpProxy != nil { - httpChannel := make(chan connect.Conn) - channels[connect.TransportTLS] = httpChannel - httpListener = connect.NewProtocolListener(httpChannel, listener.Addr()) + for resourceType := range p.httpProxies { + ch := make(chan connect.Conn) + channels[resourceType] = ch + httpListeners[resourceType] = connect.NewProtocolListener(ch, listener.Addr()) } connectListener, err := connect.NewListener( @@ -159,17 +186,19 @@ func (p *Proxy) Start() error { }) } - if p.httpProxy != nil { + for resourceType, proxy := range p.httpProxies { g.Go(func() error { - p.logger.Info("Starting HTTP proxy") + p.logger.Info("Starting HTTP proxy", zap.String("resource_type", string(resourceType))) - err := p.httpProxy.Start(httpListener) + err := proxy.Start(httpListeners[resourceType]) if errors.Is(err, http.ErrServerClosed) { return nil } if err != nil { - p.logger.Error("HTTP proxy stopped with error", zap.Error(err)) + p.logger.Error("HTTP proxy stopped with error", + zap.String("resource_type", string(resourceType)), + zap.Error(err)) } return err @@ -224,9 +253,11 @@ func (p *Proxy) shutdown() { } } - if p.httpProxy != nil { - if err := p.httpProxy.Shutdown(ctx); err != nil { - p.logger.Error("Failed to shut down HTTP proxy", zap.Error(err)) + for resourceType, proxy := range p.httpProxies { + if err := proxy.Shutdown(ctx); err != nil { + p.logger.Error("Failed to shut down HTTP proxy", + zap.String("resource_type", string(resourceType)), + zap.Error(err)) } } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 2aa929d..a5718bf 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -21,6 +21,7 @@ import ( "gateway/internal/kuberneteshandler" "gateway/internal/metrics" "gateway/internal/sshhandler" + "gateway/internal/token" ) var fullConfig = gatewayconfig.Config{ @@ -67,14 +68,16 @@ func TestNewProxy_Success(t *testing.T) { assert.Equal(t, registry, p.registry) assert.Equal(t, logger, p.logger) - assert.NotNil(t, p.httpProxy) + assert.Len(t, p.httpProxies, 1) + assert.Contains(t, p.httpProxies, token.ResourceTypeKubernetes) assert.NotNil(t, p.sshProxy) assert.NotNil(t, p.metricsServer) } -func TestNewProxy_KubernetesOnly(t *testing.T) { +func TestNewProxy_HTTPOnly(t *testing.T) { config := fullConfig config.SSH = nil + config.WebApp = &gatewayconfig.WebAppConfig{Headers: map[string]string{}} registry := prometheus.NewRegistry() logger, err := NewLogger(DefaultLoggerName, false) @@ -84,7 +87,9 @@ func TestNewProxy_KubernetesOnly(t *testing.T) { require.NoError(t, err) assert.NotNil(t, p) - assert.NotNil(t, p.httpProxy) + assert.Len(t, p.httpProxies, 2) + assert.Contains(t, p.httpProxies, token.ResourceTypeKubernetes) + assert.Contains(t, p.httpProxies, token.ResourceTypeWebApp) assert.Nil(t, p.sshProxy) } @@ -101,7 +106,7 @@ func TestNewProxy_SSHOnly(t *testing.T) { require.NoError(t, err) assert.NotNil(t, p) assert.NotNil(t, p.sshProxy) - assert.Nil(t, p.httpProxy) + assert.Empty(t, p.httpProxies) } func createTestProxy(t *testing.T) (*Proxy, net.Listener) { @@ -139,12 +144,15 @@ func TestShutdown_ClosesAllComponents(t *testing.T) { require.NoError(t, err) httpProxy := httpproxy.NewProxy(httpproxy.Config{ - Handler: k8sHandler, - Registry: registry, - Logger: zap.NewNop(), + Handler: k8sHandler, + Metrics: metrics.RegisterHTTPMetrics(registry), + ResourceType: metrics.ResourceTypeKubernetes, + Logger: zap.NewNop(), }) - p.httpProxy = httpProxy + p.httpProxies = map[token.ResourceType]*httpproxy.Proxy{ + token.ResourceTypeKubernetes: httpProxy, + } // Start HTTP proxy on a protocol listener httpChannel := make(chan connect.Conn) @@ -153,7 +161,7 @@ func TestShutdown_ClosesAllComponents(t *testing.T) { httpDone := make(chan error, 1) go func() { - httpDone <- p.httpProxy.Start(httpListener) + httpDone <- httpProxy.Start(httpListener) }() // Create and attach a real SSH proxy @@ -229,11 +237,11 @@ func TestShutdown_NilComponents(t *testing.T) { p := &Proxy{ logger: zap.NewNop(), listener: nil, - httpProxy: nil, + httpProxies: nil, sshProxy: nil, metricsServer: metricsServer, } - // Should not panic with nil listener, httpProxy, and sshProxy + // Should not panic with nil listener, httpProxies, and sshProxy p.shutdown() } diff --git a/internal/token/gat_claims.go b/internal/token/gat_claims.go index 293d9a4..0ea5c4c 100644 --- a/internal/token/gat_claims.go +++ b/internal/token/gat_claims.go @@ -61,6 +61,10 @@ func (p GATClaims) Validate() error { return nil } +func (p GATClaims) ShouldUpgradeTLS() bool { + return p.Resource.Type == ResourceTypeKubernetes +} + func (p GATClaims) getHeaderType() string { return "GAT" } @@ -85,15 +89,25 @@ func (u User) MarshalLogObject(enc zapcore.ObjectEncoder) error { return err } +type GeoIPLocation struct { + Lat float64 `json:"lat"` + Lon float64 `json:"lon"` + Country string `json:"country,omitempty"` + Region string `json:"region,omitempty"` + City string `json:"city,omitempty"` +} + type Device struct { - ID string `json:"id"` + ID string `json:"id"` + Location GeoIPLocation `json:"location,omitzero"` } -type ResourceType = string +type ResourceType string const ( - ResourceTypeKubernetes = "KUBERNETES" - ResourceTypeSSH = "SSH" + ResourceTypeKubernetes ResourceType = "KUBERNETES" + ResourceTypeSSH ResourceType = "SSH" + ResourceTypeWebApp ResourceType = "WEB_APP" ) type Resource struct { diff --git a/internal/token/gat_claims_test.go b/internal/token/gat_claims_test.go index 521f5a5..f3a6b13 100644 --- a/internal/token/gat_claims_test.go +++ b/internal/token/gat_claims_test.go @@ -13,9 +13,43 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestGATClaims_ShouldUpgradeTLS(t *testing.T) { + tests := []struct { + name string + resourceType ResourceType + expected bool + }{ + { + name: "Kubernetes should upgrade TLS", + resourceType: ResourceTypeKubernetes, + expected: true, + }, + { + name: "SSH should not upgrade TLS", + resourceType: ResourceTypeSSH, + expected: false, + }, + { + name: "Web app should not upgrade TLS", + resourceType: ResourceTypeWebApp, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claims := &GATClaims{ + Resource: Resource{Type: tt.resourceType}, + } + assert.Equal(t, tt.expected, claims.ShouldUpgradeTLS()) + }) + } +} + func TestGATTokenClaims_Validate(t *testing.T) { privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) @@ -29,6 +63,13 @@ func TestGATTokenClaims_Validate(t *testing.T) { }, Device: Device{ ID: "device-1", + Location: GeoIPLocation{ + Lat: 37.7749, + Lon: -122.4194, + Country: "US", + Region: "California", + City: "San Francisco", + }, }, Resource: Resource{ ID: "resource-1", diff --git a/internal/webapphandler/config.go b/internal/webapphandler/config.go new file mode 100644 index 0000000..b9bbada --- /dev/null +++ b/internal/webapphandler/config.go @@ -0,0 +1,39 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package webapphandler + +import ( + "fmt" + "slices" + + "go.uber.org/zap" + + "gateway/internal/metrics" + "gateway/internal/webapphandler/template" +) + +type Config struct { + headers map[string]*template.Template + roundTripperMetrics *metrics.RoundTripperMetrics + logger *zap.Logger +} + +func NewConfig(configHeaders map[string]string, roundTripperMetrics *metrics.RoundTripperMetrics, logger *zap.Logger) (*Config, error) { + headers := make(map[string]*template.Template, len(configHeaders)) + + for name, value := range configHeaders { + tmpl, err := template.New(value) + if err != nil { + return nil, fmt.Errorf("header %q: %w", name, err) + } + + if key := tmpl.Key(); key != "" && !slices.Contains(template.AllowedWebAppKeys, key) { + return nil, fmt.Errorf("header %q: %w %q", name, template.ErrUnsupportedKey, key) + } + + headers[name] = tmpl + } + + return &Config{headers: headers, roundTripperMetrics: roundTripperMetrics, logger: logger}, nil +} diff --git a/internal/webapphandler/config_test.go b/internal/webapphandler/config_test.go new file mode 100644 index 0000000..ee67a8a --- /dev/null +++ b/internal/webapphandler/config_test.go @@ -0,0 +1,66 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package webapphandler + +import ( + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "gateway/internal/webapphandler/template" +) + +func TestNewConfig(t *testing.T) { + tests := []struct { + name string + headers map[string]string + wantErr error + errContains string + }{ + { + name: "valid header templates", + headers: map[string]string{ + "Authorization": "Bearer {{twingate.jwt}}", + }, + }, + { + name: "invalid template syntax", + headers: map[string]string{ + "X-Invalid": "{{invalid}}", + }, + wantErr: template.ErrInvalidTemplate, + }, + { + name: "unsupported key", + headers: map[string]string{ + "X-Bad": "{{twingate.unknown}}", + }, + wantErr: template.ErrUnsupportedKey, + }, + { + name: "empty headers", + headers: map[string]string{}, + }, + { + name: "nil headers", + headers: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg, err := NewConfig(tt.headers, nil, zap.NewNop()) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + + return + } + + require.NoError(t, err) + require.NotNil(t, cfg) + }) + } +} diff --git a/internal/webapphandler/handler.go b/internal/webapphandler/handler.go new file mode 100644 index 0000000..7cf144a --- /dev/null +++ b/internal/webapphandler/handler.go @@ -0,0 +1,84 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package webapphandler + +import ( + "fmt" + "net/http" + "net/http/httputil" + "net/url" + "strings" + + "go.uber.org/zap" + + "gateway/internal/connect" + "gateway/internal/httpproxy" + "gateway/internal/metrics" + "gateway/internal/webapphandler/template" +) + +type Handler struct { + proxy http.Handler +} + +func NewHandler(cfg Config) *Handler { + proxy := &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + conn := httpproxy.ProxyConnFromContext(r.In.Context()) + + if err := rewrite(r, conn, cfg.headers); err != nil { + cfg.logger.Error("failed to rewrite headers", zap.Error(err)) + panic(err) + } + }, + Transport: metrics.InstrumentRoundTripper(cfg.roundTripperMetrics, metrics.ResourceTypeWebApp, http.DefaultTransport), + } + + return &Handler{proxy: proxy} +} + +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.proxy.ServeHTTP(w, r) +} + +func buildVariables(conn *connect.ProxyConn) map[string]string { + claims := conn.GATClaims() + clientLocation := claims.Device.Location + + latLong := "" + if clientLocation.Lat != 0 || clientLocation.Lon != 0 { + latLong = fmt.Sprintf("%v,%v", clientLocation.Lat, clientLocation.Lon) + } + + return map[string]string{ + template.JWT: conn.GetToken(), + template.Username: claims.User.Username, + template.Groups: strings.Join(claims.User.Groups, ","), + template.ClientGeoLatLong: latLong, + template.ClientGeoCity: clientLocation.City, + template.ClientGeoRegion: clientLocation.Region, + template.ClientGeoCountry: clientLocation.Country, + } +} + +func rewrite(r *httputil.ProxyRequest, conn *connect.ProxyConn, headers map[string]*template.Template) error { + targetURL := &url.URL{ + Scheme: "http", // plain HTTP — no upstream TLS + Host: conn.GetAddress(), + } + r.SetURL(targetURL) + + variables := buildVariables(conn) + + for headerName, tmpl := range headers { + headerValue, err := tmpl.Evaluate(variables) + if err != nil { + return fmt.Errorf("header %q: %w", headerName, err) + } + + r.Out.Header.Set(headerName, headerValue) + } + + return nil +} diff --git a/internal/webapphandler/handler_test.go b/internal/webapphandler/handler_test.go new file mode 100644 index 0000000..1f100e5 --- /dev/null +++ b/internal/webapphandler/handler_test.go @@ -0,0 +1,181 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package webapphandler + +import ( + "context" + "maps" + "net/http" + "net/http/httptest" + "net/http/httputil" + "slices" + "testing" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "gateway/internal/connect" + "gateway/internal/httpproxy" + "gateway/internal/metrics" + "gateway/internal/token" + "gateway/internal/webapphandler/template" +) + +func mustParse(t *testing.T, templates map[string]string) map[string]*template.Template { + t.Helper() + + result := make(map[string]*template.Template, len(templates)) + + for name, tmpl := range templates { + parsed, err := template.New(tmpl) + require.NoError(t, err, "failed to parse template for header %q", name) + + result[name] = parsed + } + + return result +} + +func TestNewHandler_PanicsOnRewriteError(t *testing.T) { + connMetrics := connect.CreateProxyConnMetrics(prometheus.NewRegistry()) + conn := connect.NewProxyConn(nil, nil, nil, zap.NewNop(), connMetrics) + conn.Claims = &token.GATClaims{ + User: token.User{Username: "alice@acme.com"}, + } + + unknownKeyTemplate, err := template.New("{{twingate.nonexistent}}") + require.NoError(t, err) + + handler := NewHandler(Config{ + headers: map[string]*template.Template{"X-Bad": unknownKeyTemplate}, + roundTripperMetrics: metrics.RegisterRoundTripperMetrics(prometheus.NewRegistry()), + logger: zap.NewNop(), + }) + + req := httptest.NewRequest(http.MethodGet, "http://test/api", nil) + ctx := context.WithValue(req.Context(), httpproxy.ConnContextKey{}, conn) + req = req.WithContext(ctx) + + assert.Panics(t, func() { + handler.ServeHTTP(httptest.NewRecorder(), req) + }) +} + +func TestRewrite(t *testing.T) { + baseClaims := &token.GATClaims{ + User: token.User{ + ID: "user-1", + Username: "alice@acme.com", + Groups: []string{"Everyone", "Engineering"}, + }, + Device: token.Device{ + ID: "device-1", + Location: token.GeoIPLocation{Lat: 37.5, Lon: -122.4, Country: "US", Region: "CA", City: "San Mateo"}, + }, + } + + tests := []struct { + name string + address string + jwtToken string + claims *token.GATClaims + headers map[string]string + wantHeaders map[string]string + }{ + { + name: "resolves all header templates", + jwtToken: "test-token", + claims: baseClaims, + headers: map[string]string{ + "Authorization": "Bearer {{twingate.jwt}}", + "X-Username": "{{twingate.username}}", + "X-Groups": "{{twingate.groups}}", + "X-LatLong": "{{twingate.clientGeoLatLong}}", + "X-City": "{{twingate.clientGeoCity}}", + "X-Region": "{{twingate.clientGeoRegion}}", + "X-Country": "{{twingate.clientGeoCountry}}", + "Existing": "new-value", + }, + wantHeaders: map[string]string{ + "Authorization": "Bearer test-token", + "X-Username": "alice@acme.com", + "X-Groups": "Everyone,Engineering", + "X-LatLong": "37.5,-122.4", + "X-City": "San Mateo", + "X-Region": "CA", + "X-Country": "US", + "Existing": "new-value", + }, + }, + { + name: "empty lat/lon with non-empty geo fields", + jwtToken: "test-token", + claims: &token.GATClaims{ + User: baseClaims.User, + Device: token.Device{ID: "device-1", Location: token.GeoIPLocation{Country: "US", Region: "CA", City: "San Mateo"}}, + }, + headers: map[string]string{ + "X-LatLong": "{{twingate.clientGeoLatLong}}", + "X-City": "{{twingate.clientGeoCity}}", + "X-Region": "{{twingate.clientGeoRegion}}", + "X-Country": "{{twingate.clientGeoCountry}}", + }, + wantHeaders: map[string]string{ + "X-LatLong": "", + "X-City": "San Mateo", + "X-Region": "CA", + "X-Country": "US", + "Existing": "old-value", + }, + }, + { + name: "empty headers", + jwtToken: "test-token", + claims: baseClaims, + headers: map[string]string{}, + wantHeaders: map[string]string{ + "Existing": "old-value", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + connMetrics := connect.CreateProxyConnMetrics(prometheus.NewRegistry()) + conn := connect.NewProxyConn(nil, nil, nil, zap.NewNop(), connMetrics) + conn.Address = tt.address + conn.Token = tt.jwtToken + conn.Claims = tt.claims + + outReq := httptest.NewRequest(http.MethodGet, "http://test/api/resource", nil) + outReq.Header.Set("Existing", "old-value") + + proxyReq := &httputil.ProxyRequest{ + In: httptest.NewRequest(http.MethodGet, "http://test/api/resource", nil), + Out: outReq, + } + parsedHeaders := mustParse(t, tt.headers) + + err := rewrite(proxyReq, conn, parsedHeaders) + require.NoError(t, err) + + for name, wantValue := range tt.wantHeaders { + assert.Equal(t, wantValue, proxyReq.Out.Header.Get(name)) + } + }) + } +} + +func TestBuildVariables_CoversAllowedKeys(t *testing.T) { + connMetrics := connect.CreateProxyConnMetrics(prometheus.NewRegistry()) + conn := connect.NewProxyConn(nil, nil, nil, zap.NewNop(), connMetrics) + conn.Claims = &token.GATClaims{} + + got := slices.Sorted(maps.Keys(buildVariables(conn))) + want := slices.Sorted(slices.Values(template.AllowedWebAppKeys)) + + assert.Equal(t, want, got) +} diff --git a/internal/webapphandler/template/template.go b/internal/webapphandler/template/template.go new file mode 100644 index 0000000..218f94f --- /dev/null +++ b/internal/webapphandler/template/template.go @@ -0,0 +1,103 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package template + +import ( + "errors" + "fmt" + "regexp" + "strings" + "unicode" +) + +const ( + allowedNamespace = "twingate" +) + +const ( + JWT = "jwt" + Username = "username" + Groups = "groups" + ClientGeoLatLong = "clientGeoLatLong" + ClientGeoCity = "clientGeoCity" + ClientGeoRegion = "clientGeoRegion" + ClientGeoCountry = "clientGeoCountry" +) + +var AllowedWebAppKeys = []string{ + JWT, + Username, + Groups, + ClientGeoLatLong, + ClientGeoCity, + ClientGeoRegion, + ClientGeoCountry, +} + +var ( + ErrInvalidTemplate = errors.New("invalid template") + ErrUnknownKey = errors.New("unknown key") + ErrUnsupportedKey = errors.New("unsupported key") +) + +var templateRe = regexp.MustCompile( + `^([^{}]*)` + // prefix (no braces allowed) + `{{\s*` + // opening braces + `([a-zA-Z0-9_-]+)` + // namespace + `\.` + + `([a-zA-Z0-9_-]+)` + // key + `\s*}}` + // closing braces + `([^{}]*)$`, // suffix (no braces allowed) +) + +type Template struct { + prefix string + key string + suffix string +} + +// New parses a string like " {{.}} " into a Template. +// If there is no template variable (just a static string), the key and suffix are empty and the prefix is the static string. +func New(s string) (*Template, error) { + match := templateRe.FindStringSubmatch(s) + + if match == nil { + if strings.Contains(s, "{{") || strings.Contains(s, "}}") { + return nil, fmt.Errorf("%w: unsupported syntax. Syntax must be {{twingate.key}} ", ErrInvalidTemplate) + } + + return &Template{prefix: strings.TrimSpace(s)}, nil + } + + prefix, namespace, key, suffix := match[1], match[2], match[3], match[4] + + if namespace != allowedNamespace { + return nil, fmt.Errorf("%w: unsupported namespace %q", ErrInvalidTemplate, namespace) + } + + return &Template{ + prefix: strings.TrimLeftFunc(prefix, unicode.IsSpace), + key: key, + suffix: strings.TrimRightFunc(suffix, unicode.IsSpace), + }, nil +} + +func (t *Template) Key() string { + return t.key +} + +// Evaluate replaces the key in the template with the corresponding value from the map +// and returns the resulting string along with the prefix and suffix. +func (t *Template) Evaluate(values map[string]string) (string, error) { + if t.key == "" { + return t.prefix, nil + } + + result, ok := values[t.key] + if !ok { + return "", fmt.Errorf("%w: %q", ErrUnknownKey, t.key) + } + + return t.prefix + result + t.suffix, nil +} diff --git a/internal/webapphandler/template/template_test.go b/internal/webapphandler/template/template_test.go new file mode 100644 index 0000000..0cb5b46 --- /dev/null +++ b/internal/webapphandler/template/template_test.go @@ -0,0 +1,145 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package template + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParser_NewTemplate(t *testing.T) { + tests := []struct { + name string + input string + wantPrefix string + wantKey string + wantSuffix string + wantErr error + errSubstr string + }{ + { + name: "plain text", + input: " static-value ", + wantPrefix: "static-value", + }, + { + name: "empty string", + input: "", + }, + { + name: "template only", + input: "{{twingate.jwt}}", + wantKey: "jwt", + }, + { + name: "template with leading and trailing space", + input: "{{ twingate.jwt }}", + wantKey: "jwt", + }, + { + name: "template with prefix", + input: " Bearer {{twingate.jwt}}", + wantPrefix: "Bearer ", + wantKey: "jwt", + }, + { + name: "template with suffix", + input: "{{twingate.username}}/profile ", + wantKey: "username", + wantSuffix: "/profile", + }, + { + name: "invalid template", + input: "{{invalid}}", + wantErr: ErrInvalidTemplate, + errSubstr: "unsupported syntax", + }, + { + name: "missing opening braces", + input: "twingate.jwt }}", + wantErr: ErrInvalidTemplate, + errSubstr: "unsupported syntax", + }, + { + name: "missing closing braces", + input: "{{ twingate.jwt", + wantErr: ErrInvalidTemplate, + errSubstr: "unsupported syntax", + }, + { + name: "multiple templates", + input: "{{twingate.username}} {{twingate.groups}}", + wantErr: ErrInvalidTemplate, + errSubstr: "unsupported syntax", + }, + { + name: "non-twingate namespace", + input: "{{other.key}}", + wantErr: ErrInvalidTemplate, + errSubstr: "unsupported namespace", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + template, err := New(tt.input) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + assert.Contains(t, err.Error(), tt.errSubstr) + + return + } + + require.NoError(t, err) + + assert.Equal(t, tt.wantPrefix, template.prefix) + assert.Equal(t, tt.wantKey, template.key) + assert.Equal(t, tt.wantSuffix, template.suffix) + }) + } +} + +func TestParser_Evaluate(t *testing.T) { + tests := []struct { + name string + template Template + values map[string]string + want string + wantErr error + errSubstr string + }{ + { + name: "Success", + template: Template{prefix: "Prefix ", key: "foo", suffix: " suffix"}, + values: map[string]string{"foo": "bar", "extra": "foo"}, + want: "Prefix bar suffix", + }, + { + name: "Missing key", + template: Template{prefix: "Bearer ", key: "jwt", suffix: ""}, + values: map[string]string{}, + wantErr: ErrUnknownKey, + errSubstr: "jwt", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tt.template.Evaluate(tt.values) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + assert.Contains(t, err.Error(), tt.errSubstr) + + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, result) + }) + } +}