From b21c9c46fc276de4ebc7d843829e9b22d0707ad5 Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Thu, 28 May 2026 22:05:29 +0800 Subject: [PATCH 01/17] feat: Support HTTP web app Co-Authored-By: Claude Opus 4.6 --- internal/config/config.go | 9 +- internal/connect/conn.go | 17 +- internal/connect/conn_test.go | 35 ++++ internal/connect/connect.go | 2 + internal/connect/connect_test.go | 2 + internal/connect/listener.go | 3 +- internal/connect/listener_test.go | 8 + internal/httpproxy/proxy.go | 22 ++- internal/httpproxy/proxy_test.go | 68 +++++++- internal/kuberneteshandler/kubernetes.go | 2 +- internal/metrics/http_middleware.go | 3 +- internal/metrics/round_tripper.go | 25 +-- internal/metrics/round_tripper_test.go | 27 ++-- internal/proxy/proxy.go | 24 ++- internal/proxy/proxy_test.go | 3 +- internal/token/gat_claims.go | 16 +- internal/token/gat_claims_test.go | 9 ++ internal/utils/parser/parser.go | 82 ++++++++++ internal/utils/parser/parser_test.go | 151 ++++++++++++++++++ internal/webapphandler/web_app.go | 76 +++++++++ internal/webapphandler/web_app_config.go | 47 ++++++ internal/webapphandler/web_app_config_test.go | 67 ++++++++ internal/webapphandler/web_app_test.go | 126 +++++++++++++++ 23 files changed, 785 insertions(+), 39 deletions(-) create mode 100644 internal/utils/parser/parser.go create mode 100644 internal/utils/parser/parser_test.go create mode 100644 internal/webapphandler/web_app.go create mode 100644 internal/webapphandler/web_app_config.go create mode 100644 internal/webapphandler/web_app_config_test.go create mode 100644 internal/webapphandler/web_app_test.go diff --git a/internal/config/config.go b/internal/config/config.go index debd462..0618f25 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -41,6 +41,11 @@ type Config struct { TLS TLSConfig `yaml:"tls"` Kubernetes *KubernetesConfig `yaml:"kubernetes,omitempty"` SSH *SSHConfig `yaml:"ssh,omitempty"` + WebApp *WebAppConfig `yaml:"webApp,omitempty"` +} + +type WebAppConfig struct { + Headers map[string]string `yaml:"headers,omitempty"` } type TwingateConfig struct { @@ -269,8 +274,8 @@ func (c *Config) Validate() error { } // Check that at least one protocol is configured - if c.Kubernetes == nil && c.SSH == nil { - return fmt.Errorf("%w: at least one protocol (Kubernetes or SSH) must be configured", ErrRequired) + if c.Kubernetes == nil && c.SSH == nil && c.WebApp == nil { + return fmt.Errorf("%w: at least one protocol (Kubernetes, SSH, or WebApp) must be configured", ErrRequired) } return nil diff --git a/internal/connect/conn.go b/internal/connect/conn.go index 2534765..00cd631 100644 --- a/internal/connect/conn.go +++ b/internal/connect/conn.go @@ -47,8 +47,10 @@ type Conn interface { GATClaims() *token.GATClaims GetID() string GetAddress() string + GetToken() string Authenticate() error TransportProtocol() TransportProtocol + ShouldUpgradeTLS() bool UpgradeToTLS() error Close() error @@ -64,6 +66,7 @@ type ProxyConn struct { ID string Address string Claims *token.GATClaims + Token string Timer *time.Timer Mu sync.Mutex @@ -105,6 +108,10 @@ func (p *ProxyConn) TransportProtocol() TransportProtocol { return TransportTLS } +func (p *ProxyConn) ShouldUpgradeTLS() bool { + return p.TransportProtocol() == TransportTLS && p.GATClaims().Resource.Type != token.ResourceTypeWebApp +} + func (p *ProxyConn) GATClaims() *token.GATClaims { return p.Claims } @@ -117,6 +124,10 @@ func (p *ProxyConn) GetAddress() string { return p.Address } +func (p *ProxyConn) GetToken() string { + return p.Token +} + // Authenticate sets up TLS and processes the CONNECT message for authentication. func (p *ProxyConn) Authenticate() error { _ = p.SetDeadline(time.Now().Add(defaultTimeout)) @@ -225,7 +236,10 @@ func (p *ProxyConn) Authenticate() error { p.tracker.RecordConnectMetrics(httpCode) - p.Logger.Info("Authenticated connection", zap.String("resource_address", connectInfo.Claims.Resource.Address)) + p.Logger.Info("Authenticated connection", + zap.String("resource_type", connectInfo.Claims.Resource.Type), + zap.String("resource_address", connectInfo.Claims.Resource.Address), + ) p.setConnectInfo(connectInfo) return nil @@ -249,6 +263,7 @@ func (p *ProxyConn) setConnectInfo(connectInfo Info) { p.ID = connectInfo.ConnID p.Address = connectInfo.Address p.Claims = connectInfo.Claims + p.Token = connectInfo.Token p.Timer = time.AfterFunc(time.Until(connectInfo.Claims.ExpiresAt.Time), func() { _ = p.Close() }) diff --git a/internal/connect/conn_test.go b/internal/connect/conn_test.go index 84e2787..286bf0d 100644 --- a/internal/connect/conn_test.go +++ b/internal/connect/conn_test.go @@ -110,6 +110,41 @@ func TestProxyConn_TransportProtocol(t *testing.T) { } } +func TestProxyConn_ShouldUpgradeTLS(t *testing.T) { + tests := []struct { + name string + resourceType token.ResourceType + expected bool + }{ + { + name: "Kubernetes should upgrade TLS", + resourceType: token.ResourceTypeKubernetes, + expected: true, + }, + { + name: "SSH should not upgrade TLS", + resourceType: token.ResourceTypeSSH, + expected: false, + }, + { + name: "Web app should not upgrade TLS", + resourceType: token.ResourceTypeWebApp, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claims := &token.GATClaims{ + Resource: token.Resource{Type: tt.resourceType}, + } + proxyConn := &ProxyConn{Claims: claims} + + assert.Equal(t, tt.expected, proxyConn.ShouldUpgradeTLS()) + }) + } +} + func TestProxyConn_Close(t *testing.T) { conn := &mockConn{} timer := time.NewTimer(0 * time.Millisecond) diff --git a/internal/connect/connect.go b/internal/connect/connect.go index 8ad610e..acd90d7 100644 --- a/internal/connect/connect.go +++ b/internal/connect/connect.go @@ -28,6 +28,7 @@ type Info struct { Address string Claims *token.GATClaims ConnID string + Token string } type HTTPError struct { @@ -156,5 +157,6 @@ func (v *MessageValidator) ParseConnect(req *http.Request, ekm []byte) (connectI Address: address, Claims: gatClaims, ConnID: connID, + Token: bearerToken, }, nil } diff --git a/internal/connect/connect_test.go b/internal/connect/connect_test.go index 5d8e868..0c38faa 100644 --- a/internal/connect/connect_test.go +++ b/internal/connect/connect_test.go @@ -110,6 +110,7 @@ func TestConnectValidator_ParseConnect(t *testing.T) { require.NoError(t, err) assert.Equal(t, *connectInfo.Claims, gatClaims) assert.Equal(t, "conn-id", connectInfo.ConnID) + assert.Equal(t, signedToken, connectInfo.Token) }) t.Run("Non-CONNECT method", func(t *testing.T) { @@ -127,6 +128,7 @@ func TestConnectValidator_ParseConnect(t *testing.T) { assert.Contains(t, httpErr.Error(), "expected CONNECT request") assert.Nil(t, connectInfo.Claims) assert.Empty(t, connectInfo.ConnID) + assert.Empty(t, connectInfo.Token) }) t.Run("Missing auth header", func(t *testing.T) { diff --git a/internal/connect/listener.go b/internal/connect/listener.go index c66bf62..b88741a 100644 --- a/internal/connect/listener.go +++ b/internal/connect/listener.go @@ -183,8 +183,7 @@ func (l *Listener) Serve(ctx context.Context, listener net.Listener) error { return } - // For non-SSH protocols, upgrade to TLS - if tp != TransportSSH { + if proxyConn.ShouldUpgradeTLS() { if err := proxyConn.UpgradeToTLS(); err != nil { l.logger.Error("Failed to upgrade to TLS", zap.Error(err)) diff --git a/internal/connect/listener_test.go b/internal/connect/listener_test.go index f674e35..2782f71 100644 --- a/internal/connect/listener_test.go +++ b/internal/connect/listener_test.go @@ -60,6 +60,10 @@ func (m *mockProxyConn) GetAddress() string { return "mock" } +func (m *mockProxyConn) GetToken() string { + return "" +} + func (m *mockProxyConn) Authenticate() error { if m.isHealthz { // write health check response @@ -74,6 +78,10 @@ func (m *mockProxyConn) Authenticate() error { return nil } +func (m *mockProxyConn) ShouldUpgradeTLS() bool { + return m.TransportProtocol() == TransportTLS && m.GATClaims().Resource.Type != token.ResourceTypeWebApp +} + func (m *mockProxyConn) UpgradeToTLS() error { return nil } diff --git a/internal/httpproxy/proxy.go b/internal/httpproxy/proxy.go index 51d9cf0..5c89bc1 100644 --- a/internal/httpproxy/proxy.go +++ b/internal/httpproxy/proxy.go @@ -28,7 +28,7 @@ func ProxyConnFromContext(ctx context.Context) *connect.ProxyConn { } type Config struct { - Handler http.Handler + Handlers map[string]http.Handler Registry *prometheus.Registry Logger *zap.Logger } @@ -37,11 +37,29 @@ type Proxy struct { httpServer *http.Server } +func newResourceRouter(handlers map[string]http.Handler, logger *zap.Logger) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn := ProxyConnFromContext(r.Context()) + resourceType := conn.GATClaims().Resource.Type + + handler, exists := handlers[resourceType] + if !exists { + logger.Error("No handler for resource type", zap.String("type", resourceType)) + http.Error(w, "unsupported resource type", http.StatusNotFound) + + return + } + + handler.ServeHTTP(w, r) + }) +} + func NewProxy(cfg Config) *Proxy { + router := newResourceRouter(cfg.Handlers, cfg.Logger) handler := metrics.HTTPMiddleware(metrics.HTTPMiddlewareConfig{ Registry: cfg.Registry, Next: auditMiddleware(auditMiddlewareConfig{ - next: cfg.Handler, + next: router, logger: cfg.Logger, }), }) diff --git a/internal/httpproxy/proxy_test.go b/internal/httpproxy/proxy_test.go index 22fce14..48bd8e4 100644 --- a/internal/httpproxy/proxy_test.go +++ b/internal/httpproxy/proxy_test.go @@ -8,6 +8,7 @@ import ( "io" "net" "net/http" + "net/http/httptest" "testing" "github.com/prometheus/client_golang/prometheus" @@ -33,7 +34,10 @@ func (l *mockConnListener) Accept() (net.Conn, error) { proxyConn := connect.NewProxyConn(conn, nil, nil, zap.NewNop(), connMetrics) proxyConn.ID = "test-conn" proxyConn.Address = "localhost" - proxyConn.Claims = &token.GATClaims{User: token.User{Username: "test@acme.com"}} + proxyConn.Claims = &token.GATClaims{ + User: token.User{Username: "test@acme.com"}, + Resource: token.Resource{Type: token.ResourceTypeKubernetes}, + } return proxyConn, nil } @@ -56,6 +60,64 @@ func TestProxyConnFromContext(t *testing.T) { }) } +func TestResourceRouter_DispatchesToCorrectHandler(t *testing.T) { + var handledBy = "" + + handlers := map[string]http.Handler{ + token.ResourceTypeKubernetes: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + handledBy = "kubernetes" + + w.WriteHeader(http.StatusOK) + }), + token.ResourceTypeSSH: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + handledBy = "ssh" + + w.WriteHeader(http.StatusOK) + }), + } + + connMetrics := connect.CreateProxyConnMetrics(prometheus.NewRegistry()) + proxyConn := connect.NewProxyConn(nil, nil, nil, zap.NewNop(), connMetrics) + proxyConn.Claims = &token.GATClaims{ + Resource: token.Resource{Type: token.ResourceTypeKubernetes}, + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), ConnContextKey{}, proxyConn) + req = req.WithContext(ctx) + + router := newResourceRouter(handlers, zap.NewNop()) + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, "kubernetes", handledBy) +} + +func TestResourceRouter_UnknownResource(t *testing.T) { + handlers := map[string]http.Handler{ + token.ResourceTypeKubernetes: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }), + } + + connMetrics := connect.CreateProxyConnMetrics(prometheus.NewRegistry()) + proxyConn := connect.NewProxyConn(nil, nil, nil, zap.NewNop(), connMetrics) + proxyConn.Claims = &token.GATClaims{ + Resource: token.Resource{Type: "unknown"}, + } + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), ConnContextKey{}, proxyConn) + req = req.WithContext(ctx) + + router := newResourceRouter(handlers, zap.NewNop()) + router.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusNotFound, recorder.Code) +} + func TestProxy_ForwardRequest(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) @@ -66,7 +128,7 @@ func TestProxy_ForwardRequest(t *testing.T) { proxy := NewProxy(Config{ Registry: prometheus.NewRegistry(), - Handler: handler, + Handlers: map[string]http.Handler{token.ResourceTypeKubernetes: handler}, Logger: zap.NewNop(), }) @@ -99,7 +161,7 @@ func TestProxy_Shutdown(t *testing.T) { require.NoError(t, err) proxy := NewProxy(Config{ - Handler: handler, + Handlers: map[string]http.Handler{token.ResourceTypeKubernetes: handler}, Registry: prometheus.NewRegistry(), Logger: zap.NewNop(), }) diff --git a/internal/kuberneteshandler/kubernetes.go b/internal/kuberneteshandler/kubernetes.go index 433dcd4..2e71ed3 100644 --- a/internal/kuberneteshandler/kubernetes.go +++ b/internal/kuberneteshandler/kubernetes.go @@ -44,7 +44,7 @@ func NewHandler(cfg Config) (*Handler, error) { conn := httpproxy.ProxyConnFromContext(r.In.Context()) rewrite(r, conn) }, - Transport: metrics.InstrumentRoundTripper(cfg.roundTripperMetrics, transport), + Transport: metrics.InstrumentRoundTripper(cfg.roundTripperMetrics, "kubernetes", transport), } handler := &Handler{ diff --git a/internal/metrics/http_middleware.go b/internal/metrics/http_middleware.go index ee82590..bdfbf65 100644 --- a/internal/metrics/http_middleware.go +++ b/internal/metrics/http_middleware.go @@ -17,7 +17,8 @@ import ( // Metric label names. const ( - labelRequestType = "type" + labelRequestType = "type" + labelResourceType = "resourceType" ) // Request type values. diff --git a/internal/metrics/round_tripper.go b/internal/metrics/round_tripper.go index 066f6cb..6997365 100644 --- a/internal/metrics/round_tripper.go +++ b/internal/metrics/round_tripper.go @@ -4,6 +4,7 @@ package metrics import ( + "context" "net/http" "github.com/prometheus/client_golang/prometheus" @@ -22,13 +23,13 @@ func RegisterRoundTripperMetrics(registry *prometheus.Registry) *RoundTripperMet Namespace: Namespace, Name: "api_server_requests_total", Help: "Total number of requests from Gateway to API Server processed", - }, []string{"type", "method", "code"}), + }, []string{"resourceType", "type", "method", "code"}), activeRequests: prometheus.NewGaugeVec(prometheus.GaugeOpts{ Namespace: Namespace, Name: "api_server_active_requests", Help: "Number of currently active requests from Gateway to API Server", - }, []string{"type"}), + }, []string{"resourceType", "type"}), requestDuration: prometheus.NewHistogramVec( prometheus.HistogramOpts{ @@ -36,7 +37,7 @@ func RegisterRoundTripperMetrics(registry *prometheus.Registry) *RoundTripperMet Name: "api_server_request_duration_seconds", Help: "Measures the initial HTTP request-response latency between Gateway and API Server in seconds. For HTTP streaming, WebSocket, and SPDY connections, this metric captures only the setup time and not the duration of the data transfer.", Buckets: prometheus.DefBuckets, - }, []string{"type", "method", "code"}), + }, []string{"resourceType", "type", "method", "code"}), } registry.MustRegister(c.requestsTotal, c.activeRequests, c.requestDuration) @@ -44,20 +45,24 @@ func RegisterRoundTripperMetrics(registry *prometheus.Registry) *RoundTripperMet return c } -func InstrumentRoundTripper(metrics *RoundTripperMetrics, next http.RoundTripper) promhttp.RoundTripperFunc { - opts := promhttp.WithLabelFromCtx(labelRequestType, getRequestTypeFromContext) +func InstrumentRoundTripper(metrics *RoundTripperMetrics, resourceType string, next http.RoundTripper) promhttp.RoundTripperFunc { + resourceTypeOpt := promhttp.WithLabelFromCtx(labelResourceType, func(_ context.Context) string { return resourceType }) + requestTypeOpt := promhttp.WithLabelFromCtx(labelRequestType, getRequestTypeFromContext) base := promhttp.InstrumentRoundTripperCounter( metrics.requestsTotal, instrumentRoundTripperInFlight( metrics.activeRequests, + resourceType, promhttp.InstrumentRoundTripperDuration( metrics.requestDuration, next, - opts, + resourceTypeOpt, + requestTypeOpt, ), ), - opts, + resourceTypeOpt, + requestTypeOpt, ) return func(r *http.Request) (*http.Response, error) { @@ -65,12 +70,12 @@ func InstrumentRoundTripper(metrics *RoundTripperMetrics, next http.RoundTripper } } -func instrumentRoundTripperInFlight(activeRequests *prometheus.GaugeVec, next http.RoundTripper) promhttp.RoundTripperFunc { +func instrumentRoundTripperInFlight(activeRequests *prometheus.GaugeVec, resourceType string, next http.RoundTripper) promhttp.RoundTripperFunc { return func(r *http.Request) (*http.Response, error) { requestType := getRequestTypeFromContext(r.Context()) - activeRequests.WithLabelValues(requestType).Inc() - defer activeRequests.WithLabelValues(requestType).Dec() + activeRequests.WithLabelValues(resourceType, requestType).Inc() + defer activeRequests.WithLabelValues(resourceType, requestType).Dec() return next.RoundTrip(r) } diff --git a/internal/metrics/round_tripper_test.go b/internal/metrics/round_tripper_test.go index 62c60bb..dd99473 100644 --- a/internal/metrics/round_tripper_test.go +++ b/internal/metrics/round_tripper_test.go @@ -23,7 +23,7 @@ func TestInstrumentRoundTripper(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) - transport := InstrumentRoundTripper(collectors, promhttp.RoundTripperFunc(func(r *http.Request) (*http.Response, error) { + transport := InstrumentRoundTripper(collectors, "kubernetes", promhttp.RoundTripperFunc(func(r *http.Request) (*http.Response, error) { return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody, Request: r}, nil })) @@ -38,17 +38,20 @@ func TestInstrumentRoundTripper(t *testing.T) { labelsByMetric := testutil.ExtractLabelsFromMetrics(metricFamilies) expectedLabels := map[string]map[string]string{ "twingate_gateway_api_server_requests_total": { - "type": "http", - "method": "get", - "code": "200", + "resourceType": "kubernetes", + "type": "http", + "method": "get", + "code": "200", }, "twingate_gateway_api_server_active_requests": { - "type": "http", + "resourceType": "kubernetes", + "type": "http", }, "twingate_gateway_api_server_request_duration_seconds": { - "type": "http", - "method": "get", - "code": "200", + "resourceType": "kubernetes", + "type": "http", + "method": "get", + "code": "200", }, } assert.Equal(t, expectedLabels, labelsByMetric) @@ -64,17 +67,17 @@ func TestInstrumentRoundTripper_MultipleTransports(t *testing.T) { }) // Instrumenting multiple transports should not panic - transport1 := InstrumentRoundTripper(collectors, mockTransport) - transport2 := InstrumentRoundTripper(collectors, mockTransport) + k8sTransport := InstrumentRoundTripper(collectors, "kubernetes", mockTransport) + webAppTransport := InstrumentRoundTripper(collectors, "webapp", mockTransport) req := httptest.NewRequest(http.MethodGet, "/", nil) - resp1, err := transport1.RoundTrip(req) + resp1, err := k8sTransport.RoundTrip(req) require.NoError(t, err) defer resp1.Body.Close() - resp2, err := transport2.RoundTrip(req) + resp2, err := webAppTransport.RoundTrip(req) require.NoError(t, err) defer resp2.Body.Close() diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 5aaee99..fc2092f 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -25,6 +25,8 @@ import ( "gateway/internal/metrics" "gateway/internal/sessionrecorder" "gateway/internal/sshhandler" + "gateway/internal/token" + "gateway/internal/webapphandler" ) const shutdownTimeout = 30 * time.Second @@ -45,9 +47,11 @@ type Proxy struct { func NewProxy(config *gatewayconfig.Config, registry *prometheus.Registry, logger *zap.Logger) (*Proxy, error) { var httpProxy *httpproxy.Proxy - if config.Kubernetes != nil { - roundTripperMetrics := metrics.RegisterRoundTripperMetrics(registry) + handlers := make(map[string]http.Handler) + + roundTripperMetrics := metrics.RegisterRoundTripperMetrics(registry) + if config.Kubernetes != nil { k8sConfig, err := kuberneteshandler.NewConfig(&config.AuditLog, config.Kubernetes, roundTripperMetrics, logger) if err != nil { return nil, fmt.Errorf("failed to create Kubernetes config: %w", err) @@ -58,8 +62,22 @@ func NewProxy(config *gatewayconfig.Config, registry *prometheus.Registry, logge return nil, fmt.Errorf("failed to create Kubernetes handler: %w", err) } + handlers[token.ResourceTypeKubernetes] = k8sHandler + } + + if config.WebApp != nil { + webAppCfg, err := webapphandler.NewConfig(config.WebApp.Headers, roundTripperMetrics, logger) + if err != nil { + return nil, fmt.Errorf("failed to create web app config: %w", err) + } + + webAppHandler := webapphandler.NewHandler(*webAppCfg) + handlers[token.ResourceTypeWebApp] = webAppHandler + } + + if len(handlers) > 0 { httpProxy = httpproxy.NewProxy(httpproxy.Config{ - Handler: k8sHandler, + Handlers: handlers, Registry: registry, Logger: logger, }) diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 2aa929d..40ce41c 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -21,6 +21,7 @@ import ( "gateway/internal/kuberneteshandler" "gateway/internal/metrics" "gateway/internal/sshhandler" + "gateway/internal/token" ) var fullConfig = gatewayconfig.Config{ @@ -139,7 +140,7 @@ func TestShutdown_ClosesAllComponents(t *testing.T) { require.NoError(t, err) httpProxy := httpproxy.NewProxy(httpproxy.Config{ - Handler: k8sHandler, + Handlers: map[string]http.Handler{token.ResourceTypeKubernetes: k8sHandler}, Registry: registry, Logger: zap.NewNop(), }) diff --git a/internal/token/gat_claims.go b/internal/token/gat_claims.go index 293d9a4..56a0cad 100644 --- a/internal/token/gat_claims.go +++ b/internal/token/gat_claims.go @@ -85,8 +85,21 @@ func (u User) MarshalLogObject(enc zapcore.ObjectEncoder) error { return err } +type GeoIP struct { + Lat float64 `json:"lat"` + Lon float64 `json:"lon"` + Country string `json:"country,omitempty"` + Region string `json:"region,omitempty"` + City string `json:"city,omitempty"` +} + +type DeviceLocation struct { + GeoIP GeoIP `json:"geoip"` +} + type Device struct { - ID string `json:"id"` + ID string `json:"id"` + Location DeviceLocation `json:"location,omitzero"` } type ResourceType = string @@ -94,6 +107,7 @@ type ResourceType = string const ( ResourceTypeKubernetes = "KUBERNETES" ResourceTypeSSH = "SSH" + ResourceTypeWebApp = "WEB_APP" ) type Resource struct { diff --git a/internal/token/gat_claims_test.go b/internal/token/gat_claims_test.go index 521f5a5..3d48cb7 100644 --- a/internal/token/gat_claims_test.go +++ b/internal/token/gat_claims_test.go @@ -29,6 +29,15 @@ func TestGATTokenClaims_Validate(t *testing.T) { }, Device: Device{ ID: "device-1", + Location: DeviceLocation{ + GeoIP: GeoIP{ + Lat: 37.7749, + Lon: -122.4194, + Country: "US", + Region: "California", + City: "San Francisco", + }, + }, }, Resource: Resource{ ID: "resource-1", diff --git a/internal/utils/parser/parser.go b/internal/utils/parser/parser.go new file mode 100644 index 0000000..8235366 --- /dev/null +++ b/internal/utils/parser/parser.go @@ -0,0 +1,82 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package parser + +import ( + "errors" + "fmt" + "regexp" + "strings" + "unicode" +) + +const ( + allowedNamespace = "twingate" +) + +var ( + ErrInvalidTemplate = errors.New("invalid template") + ErrUnknownVariable = errors.New("unknown variable") +) + +var templateRe = regexp.MustCompile( + `^(.*?)` + // prefix + `{{\s*` + // opening braces + `([a-zA-Z0-9_-]+)` + // namespace + `\.` + + `([a-zA-Z0-9_-]+)` + // key + `\s*}}` + // closing braces + `(.*)$`, // suffix +) + +type Template struct { + prefix string + variable string + suffix string +} + +func New(s string) (*Template, error) { + match := templateRe.FindStringSubmatch(s) + + if match == nil { + if strings.Contains(s, "{{") || strings.Contains(s, "}}") { + return nil, fmt.Errorf("%w: invalid variable syntax", ErrInvalidTemplate) + } + + return &Template{prefix: s}, nil + } + + prefix, namespace, variable, suffix := match[1], match[2], match[3], match[4] + + if namespace != allowedNamespace { + return nil, fmt.Errorf("%w: unsupported namespace %q", ErrInvalidTemplate, namespace) + } + + if templateRe.MatchString(suffix) { + return nil, fmt.Errorf("%w: multiple variable are not supported", ErrInvalidTemplate) + } + + return &Template{ + prefix: strings.TrimLeftFunc(prefix, unicode.IsSpace), + variable: variable, + suffix: strings.TrimRightFunc(suffix, unicode.IsSpace), + }, nil +} + +func (t *Template) Variable() string { + return t.variable +} + +func (t *Template) Evaluate(variables map[string]string) (string, error) { + if t.variable == "" { + return t.prefix, nil + } + + result, ok := variables[t.variable] + if !ok { + return "", fmt.Errorf("%w: %q", ErrUnknownVariable, t.variable) + } + + return t.prefix + result + t.suffix, nil +} diff --git a/internal/utils/parser/parser_test.go b/internal/utils/parser/parser_test.go new file mode 100644 index 0000000..83c059a --- /dev/null +++ b/internal/utils/parser/parser_test.go @@ -0,0 +1,151 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package parser + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParser_New(t *testing.T) { + tests := []struct { + name string + variable string + wantPrefix string + wantVariable string + wantSuffix string + wantErr error + errSubstr string + }{ + { + name: "plain text", + variable: "static-value", + wantPrefix: "static-value", + }, + { + name: "empty string", + variable: "", + }, + { + name: "single placeholder", + variable: "{{twingate.jwt}}", + wantVariable: "jwt", + }, + { + name: "Expression with leading and trailing space", + variable: "{{ twingate.jwt }}", + wantVariable: "jwt", + }, + { + name: "Expression with prefix", + variable: " Bearer {{twingate.jwt}}", + wantPrefix: "Bearer ", + wantVariable: "jwt", + }, + { + name: "suffix after placeholder", + variable: "{{twingate.username}}/profile ", + wantVariable: "username", + wantSuffix: "/profile", + }, + { + name: "Invalid variable format", + variable: "{{invalid}}", + wantErr: ErrInvalidTemplate, + errSubstr: "invalid variable syntax", + }, + { + name: "Missing opening braces", + variable: "twingate.jwt }}", + wantErr: ErrInvalidTemplate, + errSubstr: "invalid variable syntax", + }, + { + name: "Missing closing braces", + variable: "{{ twingate.jwt", + wantErr: ErrInvalidTemplate, + errSubstr: "invalid variable syntax", + }, + { + name: "multiple variables rejected", + variable: "{{twingate.username}} {{twingate.groups}}", + wantErr: ErrInvalidTemplate, + errSubstr: "multiple variable are not supported", + }, + { + name: "non-twingate namespace", + variable: "{{other.key}}", + wantErr: ErrInvalidTemplate, + errSubstr: "unsupported namespace", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + template, err := New(tt.variable) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + assert.Contains(t, err.Error(), tt.errSubstr) + + return + } + + require.NoError(t, err) + + assert.Equal(t, tt.wantPrefix, template.prefix) + assert.Equal(t, tt.wantVariable, template.variable) + assert.Equal(t, tt.wantSuffix, template.suffix) + }) + } +} + +func TestParser_Evaluate(t *testing.T) { + tests := []struct { + name string + template Template + values map[string]string + want string + wantErr error + errSubstr string + }{ + { + name: "Success evaluate", + template: Template{prefix: "Prefix ", variable: "foo", suffix: " suffix"}, + values: map[string]string{"foo": "bar", "extra": "foo"}, + want: "Prefix bar suffix", + }, + { + name: "Success evaluate", + template: Template{prefix: "Bearer ", variable: "jwt"}, + values: map[string]string{"jwt": "test-token", "extra": "foo"}, + want: "Bearer test-token", + }, + { + name: "Missing variable", + template: Template{prefix: "Bearer ", variable: "jwt", suffix: ""}, + values: map[string]string{}, + wantErr: ErrUnknownVariable, + errSubstr: "jwt", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tt.template.Evaluate(tt.values) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + assert.Contains(t, err.Error(), tt.errSubstr) + + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, result) + }) + } +} diff --git a/internal/webapphandler/web_app.go b/internal/webapphandler/web_app.go new file mode 100644 index 0000000..3f6b72c --- /dev/null +++ b/internal/webapphandler/web_app.go @@ -0,0 +1,76 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package webapphandler + +import ( + "fmt" + "net/http" + "net/http/httputil" + "net/url" + "strings" + + "go.uber.org/zap" + + "gateway/internal/connect" + "gateway/internal/httpproxy" + "gateway/internal/metrics" + "gateway/internal/token" + "gateway/internal/utils/parser" +) + +type Handler struct { + proxy http.Handler +} + +func NewHandler(cfg Config) *Handler { + proxy := &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + conn := httpproxy.ProxyConnFromContext(r.In.Context()) + + if err := rewrite(r, conn, cfg.headers); err != nil { + cfg.logger.Error("failed to rewrite headers", zap.Error(err)) + } + }, + Transport: metrics.InstrumentRoundTripper(cfg.roundTripperMetrics, "webapp", http.DefaultTransport), + } + + return &Handler{proxy: proxy} +} + +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.proxy.ServeHTTP(w, r) +} + +func rewrite(r *httputil.ProxyRequest, conn *connect.ProxyConn, headers map[string]*parser.Template) error { + targetURL := &url.URL{ + Scheme: "http", // plain HTTP — no upstream TLS + Host: conn.GetAddress(), + } + r.SetURL(targetURL) + + claims := conn.GATClaims() + + geoLoc := "" + if claims.Device.Location != (token.DeviceLocation{}) { + geoLoc = fmt.Sprintf("%v,%v", claims.Device.Location.GeoIP.Lat, claims.Device.Location.GeoIP.Lon) + } + + variables := map[string]string{ + "jwt": conn.GetToken(), + "username": claims.User.Username, + "groups": strings.Join(claims.User.Groups, ","), + "clientGeoLoc": geoLoc, + } + + for headerName, template := range headers { + headerValue, err := template.Evaluate(variables) + if err != nil { + return fmt.Errorf("header %q: %w", headerName, err) + } + + r.Out.Header.Set(headerName, headerValue) + } + + return nil +} diff --git a/internal/webapphandler/web_app_config.go b/internal/webapphandler/web_app_config.go new file mode 100644 index 0000000..dd5603c --- /dev/null +++ b/internal/webapphandler/web_app_config.go @@ -0,0 +1,47 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package webapphandler + +import ( + "errors" + "fmt" + "slices" + + "go.uber.org/zap" + + "gateway/internal/metrics" + "gateway/internal/utils/parser" +) + +var ErrUnsupportedVariable = errors.New("unsupported variable") + +var allowedVariables = []string{ + "jwt", "username", "groups", "clientGeoLoc", +} + +type Config struct { + headers map[string]*parser.Template + roundTripperMetrics *metrics.RoundTripperMetrics + logger *zap.Logger +} + +func NewConfig(rawHeaders map[string]string, roundTripperMetrics *metrics.RoundTripperMetrics, logger *zap.Logger) (*Config, error) { + headers := make(map[string]*parser.Template, len(rawHeaders)) + + for name, value := range rawHeaders { + tmpl, err := parser.New(value) + if err != nil { + return nil, fmt.Errorf("header %q: %w", name, err) + } + + variable := tmpl.Variable() + if variable != "" && !slices.Contains(allowedVariables, variable) { + return nil, fmt.Errorf("header %q: %w %q", name, ErrUnsupportedVariable, variable) + } + + headers[name] = tmpl + } + + return &Config{headers: headers, roundTripperMetrics: roundTripperMetrics, logger: logger}, nil +} diff --git a/internal/webapphandler/web_app_config_test.go b/internal/webapphandler/web_app_config_test.go new file mode 100644 index 0000000..0e71ba2 --- /dev/null +++ b/internal/webapphandler/web_app_config_test.go @@ -0,0 +1,67 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package webapphandler + +import ( + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "gateway/internal/utils/parser" +) + +func TestNewConfig(t *testing.T) { + tests := []struct { + name string + headers map[string]string + wantErr error + }{ + { + name: "valid header templates", + headers: map[string]string{ + "Authorization": "Bearer {{twingate.jwt}}", + "X-Username": "{{twingate.username}}", + "X-Twingate": "test", + }, + }, + { + name: "unsupported variable", + headers: map[string]string{ + "X-Bad": "{{twingate.invalid}}", + }, + wantErr: ErrUnsupportedVariable, + }, + { + name: "invalid template syntax", + headers: map[string]string{ + "X-Bad": "{{invalid}}", + }, + wantErr: parser.ErrInvalidTemplate, + }, + { + name: "empty headers", + headers: map[string]string{}, + }, + { + name: "nil headers", + headers: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg, err := NewConfig(tt.headers, nil, zap.NewNop()) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + + return + } + + require.NoError(t, err) + require.NotNil(t, cfg) + }) + } +} diff --git a/internal/webapphandler/web_app_test.go b/internal/webapphandler/web_app_test.go new file mode 100644 index 0000000..2db97a8 --- /dev/null +++ b/internal/webapphandler/web_app_test.go @@ -0,0 +1,126 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package webapphandler + +import ( + "net/http" + "net/http/httptest" + "net/http/httputil" + "testing" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "gateway/internal/connect" + "gateway/internal/token" + "gateway/internal/utils/parser" +) + +func mustParse(t *testing.T, templates map[string]string) map[string]*parser.Template { + t.Helper() + + result := make(map[string]*parser.Template, len(templates)) + + for name, tmpl := range templates { + parsed, err := parser.New(tmpl) + require.NoError(t, err, "failed to parse template for header %q", name) + + result[name] = parsed + } + + return result +} + +func TestRewrite(t *testing.T) { + baseClaims := &token.GATClaims{ + User: token.User{ + ID: "user-1", + Username: "alice@acme.com", + Groups: []string{"Everyone", "Engineering"}, + }, + Device: token.Device{ + ID: "device-1", + Location: token.DeviceLocation{ + GeoIP: token.GeoIP{Lat: 37.5, Lon: -122.4}, + }, + }, + } + + tests := []struct { + name string + address string + jwtToken string + claims *token.GATClaims + headers map[string]string + wantHeaders map[string]string + }{ + { + name: "resolves all variables", + jwtToken: "test-token", + claims: baseClaims, + headers: map[string]string{ + "Authorization": "Bearer {{twingate.jwt}}", + "X-Username": "{{twingate.username}}", + "X-Groups": "{{twingate.groups}}", + "X-Geo": "{{twingate.clientGeoLoc}}", + "Existing": "new-value", + }, + wantHeaders: map[string]string{ + "Authorization": "Bearer test-token", + "X-Username": "alice@acme.com", + "X-Groups": "Everyone,Engineering", + "X-Geo": "37.5,-122.4", + "Existing": "new-value", + }, + }, + { + name: "empty geo when no device location", + jwtToken: "test-token", + claims: &token.GATClaims{ + User: baseClaims.User, + Device: token.Device{ID: "device-1"}, + Resource: baseClaims.Resource, + }, + headers: map[string]string{ + "X-Geo": "{{twingate.clientGeoLoc}}", + }, + wantHeaders: map[string]string{"X-Geo": ""}, + }, + { + name: "empty headers", + jwtToken: "test-token", + claims: baseClaims, + headers: map[string]string{}, + wantHeaders: map[string]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + connMetrics := connect.CreateProxyConnMetrics(prometheus.NewRegistry()) + conn := connect.NewProxyConn(nil, nil, nil, zap.NewNop(), connMetrics) + conn.Address = tt.address + conn.Token = tt.jwtToken + conn.Claims = tt.claims + + outReq := httptest.NewRequest(http.MethodGet, "http://test/api/resource", nil) + outReq.Header.Set("Existing", "old-value") + + proxyReq := &httputil.ProxyRequest{ + In: httptest.NewRequest(http.MethodGet, "http://test/api/resource", nil), + Out: outReq, + } + parsedHeaders := mustParse(t, tt.headers) + + err := rewrite(proxyReq, conn, parsedHeaders) + require.NoError(t, err) + + for name, wantValue := range tt.wantHeaders { + assert.Equal(t, wantValue, proxyReq.Out.Header.Get(name), "header %q", name) + } + }) + } +} From b17491a700c452b88a4e23bb900c7e18811e5a7d Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Thu, 28 May 2026 22:32:21 +0800 Subject: [PATCH 02/17] Better code --- internal/config/config.go | 30 ++++ internal/config/config_test.go | 64 ++++++++ internal/httpproxy/utils/parser/parser.go | 85 ++++++++++ .../httpproxy/utils/parser/parser_test.go | 145 +++++++++++++++++ internal/kuberneteshandler/kubernetes.go | 2 +- internal/metrics/http_middleware.go | 3 +- internal/metrics/round_tripper.go | 17 +- internal/metrics/round_tripper_test.go | 26 +-- internal/proxy/proxy.go | 5 +- internal/utils/parser/parser.go | 82 ---------- internal/utils/parser/parser_test.go | 151 ------------------ internal/webapphandler/web_app.go | 8 +- internal/webapphandler/web_app_config.go | 23 +-- internal/webapphandler/web_app_config_test.go | 13 +- internal/webapphandler/web_app_test.go | 18 ++- 15 files changed, 378 insertions(+), 294 deletions(-) create mode 100644 internal/httpproxy/utils/parser/parser.go create mode 100644 internal/httpproxy/utils/parser/parser_test.go delete mode 100644 internal/utils/parser/parser.go delete mode 100644 internal/utils/parser/parser_test.go diff --git a/internal/config/config.go b/internal/config/config.go index 0618f25..0ee4079 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,6 +8,7 @@ import ( "fmt" "net/http" "os" + "slices" "strings" "time" @@ -15,6 +16,8 @@ import ( "go.uber.org/zap" "go.yaml.in/yaml/v4" "golang.org/x/crypto/ssh" + + "gateway/internal/httpproxy/utils/parser" ) var ( @@ -23,6 +26,7 @@ var ( ErrDuplicateUpstream = errors.New("duplicate upstream name") ErrInvalidSSHKeyType = errors.New("invalid SSH key type") ErrNegativeTTL = errors.New("TTL must be non-negative") + ErrUnsupportedKey = errors.New("unsupported key") ) const ( @@ -273,6 +277,12 @@ func (c *Config) Validate() error { } } + if c.WebApp != nil { + if err := c.WebApp.Validate(); err != nil { + return fmt.Errorf("webApp config: %w", err) + } + } + // Check that at least one protocol is configured if c.Kubernetes == nil && c.SSH == nil && c.WebApp == nil { return fmt.Errorf("%w: at least one protocol (Kubernetes, SSH, or WebApp) must be configured", ErrRequired) @@ -441,6 +451,26 @@ func (v *SSHCAVaultConfig) Validate() error { return nil } +func (w *WebAppConfig) Validate() error { + allowedWebAppKeys := []string{ + "jwt", "username", "groups", "clientGeoLoc", + } + + for name, value := range w.Headers { + tmpl, err := parser.NewTemplate(value) + if err != nil { + return fmt.Errorf("header %q: %w", name, err) + } + + key := tmpl.Key() + if key != "" && !slices.Contains(allowedWebAppKeys, key) { + return fmt.Errorf("header %q: %w %q", name, ErrUnsupportedKey, key) + } + } + + return nil +} + var ( ErrConflictingAuthConfig = errors.New("only one of 'token', 'appRole', 'gcp', or 'aws' can be specified for Vault auth") ErrConflictingSecretIDConfig = errors.New("only one of 'secretID' or 'secretIDFile' can be specified") diff --git a/internal/config/config_test.go b/internal/config/config_test.go index b6898e4..4c04b23 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -304,6 +304,9 @@ func TestConfig_Validate(t *testing.T) { PrivateKeyFile: "tls.key", }, Kubernetes: &KubernetesConfig{}, + WebApp: &WebAppConfig{ + Headers: map[string]string{"Authorization": "Bearer {{twingate.jwt}}"}, + }, }, wantErr: false, }, @@ -365,6 +368,23 @@ func TestConfig_Validate(t *testing.T) { wantErr: true, errContains: "at least one protocol", }, + { + name: "invalid WebApp header template", + config: &Config{ + Twingate: TwingateConfig{Network: "test"}, + Port: 8443, + MetricsPort: 9090, + TLS: TLSConfig{ + CertificateFile: "tls.crt", + PrivateKeyFile: "tls.key", + }, + WebApp: &WebAppConfig{ + Headers: map[string]string{"Authorization": "Bearer {{twingate.jwt"}, + }, + }, + wantErr: true, + errContains: "webApp config", + }, } for _, tt := range tests { @@ -1104,3 +1124,47 @@ func TestSSHCAVaultAWSConfig_Validate(t *testing.T) { }) } } + +func TestWebAppConfig_Validate(t *testing.T) { + tests := []struct { + name string + config WebAppConfig + wantErr bool + errContains string + }{ + { + name: "empty headers", + config: WebAppConfig{Headers: map[string]string{}}, + wantErr: false, + }, + { + name: "valid template variable", + config: WebAppConfig{Headers: map[string]string{"Authorization": "Bearer {{twingate.jwt}}"}}, + wantErr: false, + }, + { + name: "invalid template", + config: WebAppConfig{Headers: map[string]string{"X-Bad": "{{invalid"}}, + wantErr: true, + errContains: "invalid brackets syntax", + }, + { + name: "unsupported key", + config: WebAppConfig{Headers: map[string]string{"X-Bad": "{{twingate.unknown}}"}}, + wantErr: true, + errContains: "unsupported key", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/internal/httpproxy/utils/parser/parser.go b/internal/httpproxy/utils/parser/parser.go new file mode 100644 index 0000000..6700856 --- /dev/null +++ b/internal/httpproxy/utils/parser/parser.go @@ -0,0 +1,85 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package parser + +import ( + "errors" + "fmt" + "regexp" + "strings" + "unicode" +) + +const ( + allowedNamespace = "twingate" +) + +var ( + ErrInvalidTemplate = errors.New("invalid template") + ErrUnknownKey = errors.New("unknown key") +) + +var templateRe = regexp.MustCompile( + `^(.*?)` + // prefix + `{{\s*` + // opening braces + `([a-zA-Z0-9_-]+)` + // namespace + `\.` + + `([a-zA-Z0-9_-]+)` + // key + `\s*}}` + // closing braces + `(.*)$`, // suffix +) + +type Template struct { + prefix string + key string + suffix string +} + +// NewTemplate parses a string like " {{.}} " into a Template. +func NewTemplate(s string) (*Template, error) { + match := templateRe.FindStringSubmatch(s) + + if match == nil { + if strings.Contains(s, "{{") || strings.Contains(s, "}}") { + return nil, fmt.Errorf("%w: invalid brackets syntax", ErrInvalidTemplate) + } + + return &Template{prefix: s}, nil + } + + prefix, namespace, key, suffix := match[1], match[2], match[3], match[4] + + if namespace != allowedNamespace { + return nil, fmt.Errorf("%w: unsupported namespace %q", ErrInvalidTemplate, namespace) + } + + if templateRe.MatchString(suffix) { + return nil, fmt.Errorf("%w: multiple templates are not supported", ErrInvalidTemplate) + } + + return &Template{ + prefix: strings.TrimLeftFunc(prefix, unicode.IsSpace), + key: key, + suffix: strings.TrimRightFunc(suffix, unicode.IsSpace), + }, nil +} + +func (t *Template) Key() string { + return t.key +} + +// Evaluate replaces the key in the template with the corresponding value from the map +// and returns the resulting string along with the prefix and suffix. +func (t *Template) Evaluate(values map[string]string) (string, error) { + if t.key == "" { + return t.prefix, nil + } + + result, ok := values[t.key] + if !ok { + return "", fmt.Errorf("%w: %q", ErrUnknownKey, t.key) + } + + return t.prefix + result + t.suffix, nil +} diff --git a/internal/httpproxy/utils/parser/parser_test.go b/internal/httpproxy/utils/parser/parser_test.go new file mode 100644 index 0000000..fd9ba71 --- /dev/null +++ b/internal/httpproxy/utils/parser/parser_test.go @@ -0,0 +1,145 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package parser + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParser_NewTemplate(t *testing.T) { + tests := []struct { + name string + input string + wantPrefix string + wantKey string + wantSuffix string + wantErr error + errSubstr string + }{ + { + name: "plain text", + input: "static-value", + wantPrefix: "static-value", + }, + { + name: "empty string", + input: "", + }, + { + name: "single placeholder", + input: "{{twingate.jwt}}", + wantKey: "jwt", + }, + { + name: "template with leading and trailing space", + input: "{{ twingate.jwt }}", + wantKey: "jwt", + }, + { + name: "template with prefix", + input: " Bearer {{twingate.jwt}}", + wantPrefix: "Bearer ", + wantKey: "jwt", + }, + { + name: "suffix after placeholder", + input: "{{twingate.username}}/profile ", + wantKey: "username", + wantSuffix: "/profile", + }, + { + name: "Invalid template format", + input: "{{invalid}}", + wantErr: ErrInvalidTemplate, + errSubstr: "invalid brackets syntax", + }, + { + name: "Missing opening braces", + input: "twingate.jwt }}", + wantErr: ErrInvalidTemplate, + errSubstr: "invalid brackets syntax", + }, + { + name: "Missing closing braces", + input: "{{ twingate.jwt", + wantErr: ErrInvalidTemplate, + errSubstr: "invalid brackets syntax", + }, + { + name: "multiple templates rejected", + input: "{{twingate.username}} {{twingate.groups}}", + wantErr: ErrInvalidTemplate, + errSubstr: "multiple templates are not supported", + }, + { + name: "non-twingate namespace", + input: "{{other.key}}", + wantErr: ErrInvalidTemplate, + errSubstr: "unsupported namespace", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + template, err := NewTemplate(tt.input) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + assert.Contains(t, err.Error(), tt.errSubstr) + + return + } + + require.NoError(t, err) + + assert.Equal(t, tt.wantPrefix, template.prefix) + assert.Equal(t, tt.wantKey, template.key) + assert.Equal(t, tt.wantSuffix, template.suffix) + }) + } +} + +func TestParser_Evaluate(t *testing.T) { + tests := []struct { + name string + template Template + values map[string]string + want string + wantErr error + errSubstr string + }{ + { + name: "Success", + template: Template{prefix: "Prefix ", key: "foo", suffix: " suffix"}, + values: map[string]string{"foo": "bar", "extra": "foo"}, + want: "Prefix bar suffix", + }, + { + name: "Missing key", + template: Template{prefix: "Bearer ", key: "jwt", suffix: ""}, + values: map[string]string{}, + wantErr: ErrUnknownKey, + errSubstr: "jwt", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tt.template.Evaluate(tt.values) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + assert.Contains(t, err.Error(), tt.errSubstr) + + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, result) + }) + } +} diff --git a/internal/kuberneteshandler/kubernetes.go b/internal/kuberneteshandler/kubernetes.go index 2e71ed3..ffd6e94 100644 --- a/internal/kuberneteshandler/kubernetes.go +++ b/internal/kuberneteshandler/kubernetes.go @@ -44,7 +44,7 @@ func NewHandler(cfg Config) (*Handler, error) { conn := httpproxy.ProxyConnFromContext(r.In.Context()) rewrite(r, conn) }, - Transport: metrics.InstrumentRoundTripper(cfg.roundTripperMetrics, "kubernetes", transport), + Transport: metrics.InstrumentRoundTripper(cfg.roundTripperMetrics, metrics.ResourceTypeKubernetes, transport), } handler := &Handler{ diff --git a/internal/metrics/http_middleware.go b/internal/metrics/http_middleware.go index bdfbf65..ee82590 100644 --- a/internal/metrics/http_middleware.go +++ b/internal/metrics/http_middleware.go @@ -17,8 +17,7 @@ import ( // Metric label names. const ( - labelRequestType = "type" - labelResourceType = "resourceType" + labelRequestType = "type" ) // Request type values. diff --git a/internal/metrics/round_tripper.go b/internal/metrics/round_tripper.go index 6997365..a3e0b78 100644 --- a/internal/metrics/round_tripper.go +++ b/internal/metrics/round_tripper.go @@ -11,6 +11,17 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" ) +// Metric label names. + +const labelResourceType = "resource_type" + +// Resource type values. + +const ( + ResourceTypeKubernetes = "kubernetes" + ResourceTypeWebApp = "web_app" +) + type RoundTripperMetrics struct { requestsTotal *prometheus.CounterVec activeRequests *prometheus.GaugeVec @@ -23,13 +34,13 @@ func RegisterRoundTripperMetrics(registry *prometheus.Registry) *RoundTripperMet Namespace: Namespace, Name: "api_server_requests_total", Help: "Total number of requests from Gateway to API Server processed", - }, []string{"resourceType", "type", "method", "code"}), + }, []string{labelResourceType, "type", "method", "code"}), activeRequests: prometheus.NewGaugeVec(prometheus.GaugeOpts{ Namespace: Namespace, Name: "api_server_active_requests", Help: "Number of currently active requests from Gateway to API Server", - }, []string{"resourceType", "type"}), + }, []string{labelResourceType, "type"}), requestDuration: prometheus.NewHistogramVec( prometheus.HistogramOpts{ @@ -37,7 +48,7 @@ func RegisterRoundTripperMetrics(registry *prometheus.Registry) *RoundTripperMet Name: "api_server_request_duration_seconds", Help: "Measures the initial HTTP request-response latency between Gateway and API Server in seconds. For HTTP streaming, WebSocket, and SPDY connections, this metric captures only the setup time and not the duration of the data transfer.", Buckets: prometheus.DefBuckets, - }, []string{"resourceType", "type", "method", "code"}), + }, []string{labelResourceType, "type", "method", "code"}), } registry.MustRegister(c.requestsTotal, c.activeRequests, c.requestDuration) diff --git a/internal/metrics/round_tripper_test.go b/internal/metrics/round_tripper_test.go index dd99473..6b05052 100644 --- a/internal/metrics/round_tripper_test.go +++ b/internal/metrics/round_tripper_test.go @@ -23,7 +23,7 @@ func TestInstrumentRoundTripper(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) - transport := InstrumentRoundTripper(collectors, "kubernetes", promhttp.RoundTripperFunc(func(r *http.Request) (*http.Response, error) { + transport := InstrumentRoundTripper(collectors, ResourceTypeKubernetes, promhttp.RoundTripperFunc(func(r *http.Request) (*http.Response, error) { return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody, Request: r}, nil })) @@ -38,20 +38,20 @@ func TestInstrumentRoundTripper(t *testing.T) { labelsByMetric := testutil.ExtractLabelsFromMetrics(metricFamilies) expectedLabels := map[string]map[string]string{ "twingate_gateway_api_server_requests_total": { - "resourceType": "kubernetes", - "type": "http", - "method": "get", - "code": "200", + "resource_type": "kubernetes", + "type": "http", + "method": "get", + "code": "200", }, "twingate_gateway_api_server_active_requests": { - "resourceType": "kubernetes", - "type": "http", + "resource_type": "kubernetes", + "type": "http", }, "twingate_gateway_api_server_request_duration_seconds": { - "resourceType": "kubernetes", - "type": "http", - "method": "get", - "code": "200", + "resource_type": "kubernetes", + "type": "http", + "method": "get", + "code": "200", }, } assert.Equal(t, expectedLabels, labelsByMetric) @@ -67,8 +67,8 @@ func TestInstrumentRoundTripper_MultipleTransports(t *testing.T) { }) // Instrumenting multiple transports should not panic - k8sTransport := InstrumentRoundTripper(collectors, "kubernetes", mockTransport) - webAppTransport := InstrumentRoundTripper(collectors, "webapp", mockTransport) + k8sTransport := InstrumentRoundTripper(collectors, ResourceTypeKubernetes, mockTransport) + webAppTransport := InstrumentRoundTripper(collectors, ResourceTypeWebApp, mockTransport) req := httptest.NewRequest(http.MethodGet, "/", nil) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index fc2092f..8cb9674 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -49,7 +49,10 @@ func NewProxy(config *gatewayconfig.Config, registry *prometheus.Registry, logge handlers := make(map[string]http.Handler) - roundTripperMetrics := metrics.RegisterRoundTripperMetrics(registry) + var roundTripperMetrics *metrics.RoundTripperMetrics + if config.Kubernetes != nil || config.WebApp != nil { + roundTripperMetrics = metrics.RegisterRoundTripperMetrics(registry) + } if config.Kubernetes != nil { k8sConfig, err := kuberneteshandler.NewConfig(&config.AuditLog, config.Kubernetes, roundTripperMetrics, logger) diff --git a/internal/utils/parser/parser.go b/internal/utils/parser/parser.go deleted file mode 100644 index 8235366..0000000 --- a/internal/utils/parser/parser.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) Twingate Inc. -// SPDX-License-Identifier: MPL-2.0 - -package parser - -import ( - "errors" - "fmt" - "regexp" - "strings" - "unicode" -) - -const ( - allowedNamespace = "twingate" -) - -var ( - ErrInvalidTemplate = errors.New("invalid template") - ErrUnknownVariable = errors.New("unknown variable") -) - -var templateRe = regexp.MustCompile( - `^(.*?)` + // prefix - `{{\s*` + // opening braces - `([a-zA-Z0-9_-]+)` + // namespace - `\.` + - `([a-zA-Z0-9_-]+)` + // key - `\s*}}` + // closing braces - `(.*)$`, // suffix -) - -type Template struct { - prefix string - variable string - suffix string -} - -func New(s string) (*Template, error) { - match := templateRe.FindStringSubmatch(s) - - if match == nil { - if strings.Contains(s, "{{") || strings.Contains(s, "}}") { - return nil, fmt.Errorf("%w: invalid variable syntax", ErrInvalidTemplate) - } - - return &Template{prefix: s}, nil - } - - prefix, namespace, variable, suffix := match[1], match[2], match[3], match[4] - - if namespace != allowedNamespace { - return nil, fmt.Errorf("%w: unsupported namespace %q", ErrInvalidTemplate, namespace) - } - - if templateRe.MatchString(suffix) { - return nil, fmt.Errorf("%w: multiple variable are not supported", ErrInvalidTemplate) - } - - return &Template{ - prefix: strings.TrimLeftFunc(prefix, unicode.IsSpace), - variable: variable, - suffix: strings.TrimRightFunc(suffix, unicode.IsSpace), - }, nil -} - -func (t *Template) Variable() string { - return t.variable -} - -func (t *Template) Evaluate(variables map[string]string) (string, error) { - if t.variable == "" { - return t.prefix, nil - } - - result, ok := variables[t.variable] - if !ok { - return "", fmt.Errorf("%w: %q", ErrUnknownVariable, t.variable) - } - - return t.prefix + result + t.suffix, nil -} diff --git a/internal/utils/parser/parser_test.go b/internal/utils/parser/parser_test.go deleted file mode 100644 index 83c059a..0000000 --- a/internal/utils/parser/parser_test.go +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright (c) Twingate Inc. -// SPDX-License-Identifier: MPL-2.0 - -package parser - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestParser_New(t *testing.T) { - tests := []struct { - name string - variable string - wantPrefix string - wantVariable string - wantSuffix string - wantErr error - errSubstr string - }{ - { - name: "plain text", - variable: "static-value", - wantPrefix: "static-value", - }, - { - name: "empty string", - variable: "", - }, - { - name: "single placeholder", - variable: "{{twingate.jwt}}", - wantVariable: "jwt", - }, - { - name: "Expression with leading and trailing space", - variable: "{{ twingate.jwt }}", - wantVariable: "jwt", - }, - { - name: "Expression with prefix", - variable: " Bearer {{twingate.jwt}}", - wantPrefix: "Bearer ", - wantVariable: "jwt", - }, - { - name: "suffix after placeholder", - variable: "{{twingate.username}}/profile ", - wantVariable: "username", - wantSuffix: "/profile", - }, - { - name: "Invalid variable format", - variable: "{{invalid}}", - wantErr: ErrInvalidTemplate, - errSubstr: "invalid variable syntax", - }, - { - name: "Missing opening braces", - variable: "twingate.jwt }}", - wantErr: ErrInvalidTemplate, - errSubstr: "invalid variable syntax", - }, - { - name: "Missing closing braces", - variable: "{{ twingate.jwt", - wantErr: ErrInvalidTemplate, - errSubstr: "invalid variable syntax", - }, - { - name: "multiple variables rejected", - variable: "{{twingate.username}} {{twingate.groups}}", - wantErr: ErrInvalidTemplate, - errSubstr: "multiple variable are not supported", - }, - { - name: "non-twingate namespace", - variable: "{{other.key}}", - wantErr: ErrInvalidTemplate, - errSubstr: "unsupported namespace", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - template, err := New(tt.variable) - - if tt.wantErr != nil { - require.ErrorIs(t, err, tt.wantErr) - assert.Contains(t, err.Error(), tt.errSubstr) - - return - } - - require.NoError(t, err) - - assert.Equal(t, tt.wantPrefix, template.prefix) - assert.Equal(t, tt.wantVariable, template.variable) - assert.Equal(t, tt.wantSuffix, template.suffix) - }) - } -} - -func TestParser_Evaluate(t *testing.T) { - tests := []struct { - name string - template Template - values map[string]string - want string - wantErr error - errSubstr string - }{ - { - name: "Success evaluate", - template: Template{prefix: "Prefix ", variable: "foo", suffix: " suffix"}, - values: map[string]string{"foo": "bar", "extra": "foo"}, - want: "Prefix bar suffix", - }, - { - name: "Success evaluate", - template: Template{prefix: "Bearer ", variable: "jwt"}, - values: map[string]string{"jwt": "test-token", "extra": "foo"}, - want: "Bearer test-token", - }, - { - name: "Missing variable", - template: Template{prefix: "Bearer ", variable: "jwt", suffix: ""}, - values: map[string]string{}, - wantErr: ErrUnknownVariable, - errSubstr: "jwt", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := tt.template.Evaluate(tt.values) - - if tt.wantErr != nil { - require.ErrorIs(t, err, tt.wantErr) - assert.Contains(t, err.Error(), tt.errSubstr) - - return - } - - require.NoError(t, err) - assert.Equal(t, tt.want, result) - }) - } -} diff --git a/internal/webapphandler/web_app.go b/internal/webapphandler/web_app.go index 3f6b72c..6cf6bd1 100644 --- a/internal/webapphandler/web_app.go +++ b/internal/webapphandler/web_app.go @@ -14,9 +14,9 @@ import ( "gateway/internal/connect" "gateway/internal/httpproxy" + "gateway/internal/httpproxy/utils/parser" "gateway/internal/metrics" "gateway/internal/token" - "gateway/internal/utils/parser" ) type Handler struct { @@ -32,7 +32,7 @@ func NewHandler(cfg Config) *Handler { cfg.logger.Error("failed to rewrite headers", zap.Error(err)) } }, - Transport: metrics.InstrumentRoundTripper(cfg.roundTripperMetrics, "webapp", http.DefaultTransport), + Transport: metrics.InstrumentRoundTripper(cfg.roundTripperMetrics, metrics.ResourceTypeWebApp, http.DefaultTransport), } return &Handler{proxy: proxy} @@ -63,8 +63,8 @@ func rewrite(r *httputil.ProxyRequest, conn *connect.ProxyConn, headers map[stri "clientGeoLoc": geoLoc, } - for headerName, template := range headers { - headerValue, err := template.Evaluate(variables) + for headerName, tmpl := range headers { + headerValue, err := tmpl.Evaluate(variables) if err != nil { return fmt.Errorf("header %q: %w", headerName, err) } diff --git a/internal/webapphandler/web_app_config.go b/internal/webapphandler/web_app_config.go index dd5603c..43610c9 100644 --- a/internal/webapphandler/web_app_config.go +++ b/internal/webapphandler/web_app_config.go @@ -4,42 +4,29 @@ package webapphandler import ( - "errors" "fmt" - "slices" "go.uber.org/zap" + "gateway/internal/httpproxy/utils/parser" "gateway/internal/metrics" - "gateway/internal/utils/parser" ) -var ErrUnsupportedVariable = errors.New("unsupported variable") - -var allowedVariables = []string{ - "jwt", "username", "groups", "clientGeoLoc", -} - type Config struct { headers map[string]*parser.Template roundTripperMetrics *metrics.RoundTripperMetrics logger *zap.Logger } -func NewConfig(rawHeaders map[string]string, roundTripperMetrics *metrics.RoundTripperMetrics, logger *zap.Logger) (*Config, error) { - headers := make(map[string]*parser.Template, len(rawHeaders)) +func NewConfig(configHeaders map[string]string, roundTripperMetrics *metrics.RoundTripperMetrics, logger *zap.Logger) (*Config, error) { + headers := make(map[string]*parser.Template, len(configHeaders)) - for name, value := range rawHeaders { - tmpl, err := parser.New(value) + for name, value := range configHeaders { + tmpl, err := parser.NewTemplate(value) if err != nil { return nil, fmt.Errorf("header %q: %w", name, err) } - variable := tmpl.Variable() - if variable != "" && !slices.Contains(allowedVariables, variable) { - return nil, fmt.Errorf("header %q: %w %q", name, ErrUnsupportedVariable, variable) - } - headers[name] = tmpl } diff --git a/internal/webapphandler/web_app_config_test.go b/internal/webapphandler/web_app_config_test.go index 0e71ba2..0993d36 100644 --- a/internal/webapphandler/web_app_config_test.go +++ b/internal/webapphandler/web_app_config_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/zap" - "gateway/internal/utils/parser" + "gateway/internal/httpproxy/utils/parser" ) func TestNewConfig(t *testing.T) { @@ -22,21 +22,12 @@ func TestNewConfig(t *testing.T) { name: "valid header templates", headers: map[string]string{ "Authorization": "Bearer {{twingate.jwt}}", - "X-Username": "{{twingate.username}}", - "X-Twingate": "test", }, }, - { - name: "unsupported variable", - headers: map[string]string{ - "X-Bad": "{{twingate.invalid}}", - }, - wantErr: ErrUnsupportedVariable, - }, { name: "invalid template syntax", headers: map[string]string{ - "X-Bad": "{{invalid}}", + "X-Invalid": "{{invalid}}", }, wantErr: parser.ErrInvalidTemplate, }, diff --git a/internal/webapphandler/web_app_test.go b/internal/webapphandler/web_app_test.go index 2db97a8..5020128 100644 --- a/internal/webapphandler/web_app_test.go +++ b/internal/webapphandler/web_app_test.go @@ -15,8 +15,8 @@ import ( "go.uber.org/zap" "gateway/internal/connect" + "gateway/internal/httpproxy/utils/parser" "gateway/internal/token" - "gateway/internal/utils/parser" ) func mustParse(t *testing.T, templates map[string]string) map[string]*parser.Template { @@ -25,7 +25,7 @@ func mustParse(t *testing.T, templates map[string]string) map[string]*parser.Tem result := make(map[string]*parser.Template, len(templates)) for name, tmpl := range templates { - parsed, err := parser.New(tmpl) + parsed, err := parser.NewTemplate(tmpl) require.NoError(t, err, "failed to parse template for header %q", name) result[name] = parsed @@ -90,11 +90,13 @@ func TestRewrite(t *testing.T) { wantHeaders: map[string]string{"X-Geo": ""}, }, { - name: "empty headers", - jwtToken: "test-token", - claims: baseClaims, - headers: map[string]string{}, - wantHeaders: map[string]string{}, + name: "empty headers", + jwtToken: "test-token", + claims: baseClaims, + headers: map[string]string{}, + wantHeaders: map[string]string{ + "Existing": "old-value", + }, }, } @@ -119,7 +121,7 @@ func TestRewrite(t *testing.T) { require.NoError(t, err) for name, wantValue := range tt.wantHeaders { - assert.Equal(t, wantValue, proxyReq.Out.Header.Get(name), "header %q", name) + assert.Equal(t, wantValue, proxyReq.Out.Header.Get(name)) } }) } From 2e42d6a24c35fb468cbd95e7f243fe44a8093cc6 Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Fri, 29 May 2026 10:40:41 +0800 Subject: [PATCH 03/17] Better regex --- internal/config/config_test.go | 2 +- internal/httpproxy/utils/parser/parser.go | 10 +++------- .../httpproxy/utils/parser/parser_test.go | 20 +++++++++---------- internal/webapphandler/web_app_test.go | 2 +- 4 files changed, 15 insertions(+), 19 deletions(-) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 4c04b23..cdbc9ae 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1146,7 +1146,7 @@ func TestWebAppConfig_Validate(t *testing.T) { name: "invalid template", config: WebAppConfig{Headers: map[string]string{"X-Bad": "{{invalid"}}, wantErr: true, - errContains: "invalid brackets syntax", + errContains: "invalid template format", }, { name: "unsupported key", diff --git a/internal/httpproxy/utils/parser/parser.go b/internal/httpproxy/utils/parser/parser.go index 6700856..3eadd09 100644 --- a/internal/httpproxy/utils/parser/parser.go +++ b/internal/httpproxy/utils/parser/parser.go @@ -21,13 +21,13 @@ var ( ) var templateRe = regexp.MustCompile( - `^(.*?)` + // prefix + `^([^{}]*)` + // prefix (no brackets allowed) `{{\s*` + // opening braces `([a-zA-Z0-9_-]+)` + // namespace `\.` + `([a-zA-Z0-9_-]+)` + // key `\s*}}` + // closing braces - `(.*)$`, // suffix + `([^{}]*)$`, // suffix (no brackets allowed) ) type Template struct { @@ -42,7 +42,7 @@ func NewTemplate(s string) (*Template, error) { if match == nil { if strings.Contains(s, "{{") || strings.Contains(s, "}}") { - return nil, fmt.Errorf("%w: invalid brackets syntax", ErrInvalidTemplate) + return nil, fmt.Errorf("%w: invalid template format. Format must be {{twingate.key}} ", ErrInvalidTemplate) } return &Template{prefix: s}, nil @@ -54,10 +54,6 @@ func NewTemplate(s string) (*Template, error) { return nil, fmt.Errorf("%w: unsupported namespace %q", ErrInvalidTemplate, namespace) } - if templateRe.MatchString(suffix) { - return nil, fmt.Errorf("%w: multiple templates are not supported", ErrInvalidTemplate) - } - return &Template{ prefix: strings.TrimLeftFunc(prefix, unicode.IsSpace), key: key, diff --git a/internal/httpproxy/utils/parser/parser_test.go b/internal/httpproxy/utils/parser/parser_test.go index fd9ba71..1c6bc5f 100644 --- a/internal/httpproxy/utils/parser/parser_test.go +++ b/internal/httpproxy/utils/parser/parser_test.go @@ -30,7 +30,7 @@ func TestParser_NewTemplate(t *testing.T) { input: "", }, { - name: "single placeholder", + name: "template only", input: "{{twingate.jwt}}", wantKey: "jwt", }, @@ -46,34 +46,34 @@ func TestParser_NewTemplate(t *testing.T) { wantKey: "jwt", }, { - name: "suffix after placeholder", + name: "template with suffix", input: "{{twingate.username}}/profile ", wantKey: "username", wantSuffix: "/profile", }, { - name: "Invalid template format", + name: "invalid template", input: "{{invalid}}", wantErr: ErrInvalidTemplate, - errSubstr: "invalid brackets syntax", + errSubstr: "invalid template format", }, { - name: "Missing opening braces", + name: "missing opening braces", input: "twingate.jwt }}", wantErr: ErrInvalidTemplate, - errSubstr: "invalid brackets syntax", + errSubstr: "invalid template format", }, { - name: "Missing closing braces", + name: "missing closing braces", input: "{{ twingate.jwt", wantErr: ErrInvalidTemplate, - errSubstr: "invalid brackets syntax", + errSubstr: "invalid template format", }, { - name: "multiple templates rejected", + name: "multiple templates", input: "{{twingate.username}} {{twingate.groups}}", wantErr: ErrInvalidTemplate, - errSubstr: "multiple templates are not supported", + errSubstr: "invalid template format", }, { name: "non-twingate namespace", diff --git a/internal/webapphandler/web_app_test.go b/internal/webapphandler/web_app_test.go index 5020128..7864e8d 100644 --- a/internal/webapphandler/web_app_test.go +++ b/internal/webapphandler/web_app_test.go @@ -58,7 +58,7 @@ func TestRewrite(t *testing.T) { wantHeaders map[string]string }{ { - name: "resolves all variables", + name: "resolves all header templates", jwtToken: "test-token", claims: baseClaims, headers: map[string]string{ From a5b9ebdf2b5f7f7162a3db5e1a69ea60f3146c12 Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Fri, 29 May 2026 11:00:10 +0800 Subject: [PATCH 04/17] Better wording --- internal/config/config_test.go | 4 ++-- internal/httpproxy/utils/parser/parser.go | 2 +- internal/httpproxy/utils/parser/parser_test.go | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index cdbc9ae..4d0b446 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1143,10 +1143,10 @@ func TestWebAppConfig_Validate(t *testing.T) { wantErr: false, }, { - name: "invalid template", + name: "invalid template syntax", config: WebAppConfig{Headers: map[string]string{"X-Bad": "{{invalid"}}, wantErr: true, - errContains: "invalid template format", + errContains: "unsupported syntax", }, { name: "unsupported key", diff --git a/internal/httpproxy/utils/parser/parser.go b/internal/httpproxy/utils/parser/parser.go index 3eadd09..ec198a9 100644 --- a/internal/httpproxy/utils/parser/parser.go +++ b/internal/httpproxy/utils/parser/parser.go @@ -42,7 +42,7 @@ func NewTemplate(s string) (*Template, error) { if match == nil { if strings.Contains(s, "{{") || strings.Contains(s, "}}") { - return nil, fmt.Errorf("%w: invalid template format. Format must be {{twingate.key}} ", ErrInvalidTemplate) + return nil, fmt.Errorf("%w: unsupported syntax. Syntax must be {{twingate.key}} ", ErrInvalidTemplate) } return &Template{prefix: s}, nil diff --git a/internal/httpproxy/utils/parser/parser_test.go b/internal/httpproxy/utils/parser/parser_test.go index 1c6bc5f..7975212 100644 --- a/internal/httpproxy/utils/parser/parser_test.go +++ b/internal/httpproxy/utils/parser/parser_test.go @@ -55,25 +55,25 @@ func TestParser_NewTemplate(t *testing.T) { name: "invalid template", input: "{{invalid}}", wantErr: ErrInvalidTemplate, - errSubstr: "invalid template format", + errSubstr: "unsupported syntax", }, { name: "missing opening braces", input: "twingate.jwt }}", wantErr: ErrInvalidTemplate, - errSubstr: "invalid template format", + errSubstr: "unsupported syntax", }, { name: "missing closing braces", input: "{{ twingate.jwt", wantErr: ErrInvalidTemplate, - errSubstr: "invalid template format", + errSubstr: "unsupported syntax", }, { name: "multiple templates", input: "{{twingate.username}} {{twingate.groups}}", wantErr: ErrInvalidTemplate, - errSubstr: "invalid template format", + errSubstr: "unsupported syntax", }, { name: "non-twingate namespace", From c21e89802212f1dc0b47d72e0c01df4bac19550b Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Fri, 29 May 2026 11:00:10 +0800 Subject: [PATCH 05/17] Update code --- internal/config/config_test.go | 4 ++-- internal/httpproxy/utils/parser/parser.go | 2 +- internal/httpproxy/utils/parser/parser_test.go | 8 ++++---- internal/token/gat_claims.go | 10 +++------- internal/token/gat_claims_test.go | 14 ++++++-------- internal/webapphandler/web_app.go | 4 ++-- internal/webapphandler/web_app_test.go | 6 ++---- 7 files changed, 20 insertions(+), 28 deletions(-) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index cdbc9ae..4d0b446 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1143,10 +1143,10 @@ func TestWebAppConfig_Validate(t *testing.T) { wantErr: false, }, { - name: "invalid template", + name: "invalid template syntax", config: WebAppConfig{Headers: map[string]string{"X-Bad": "{{invalid"}}, wantErr: true, - errContains: "invalid template format", + errContains: "unsupported syntax", }, { name: "unsupported key", diff --git a/internal/httpproxy/utils/parser/parser.go b/internal/httpproxy/utils/parser/parser.go index 3eadd09..ec198a9 100644 --- a/internal/httpproxy/utils/parser/parser.go +++ b/internal/httpproxy/utils/parser/parser.go @@ -42,7 +42,7 @@ func NewTemplate(s string) (*Template, error) { if match == nil { if strings.Contains(s, "{{") || strings.Contains(s, "}}") { - return nil, fmt.Errorf("%w: invalid template format. Format must be {{twingate.key}} ", ErrInvalidTemplate) + return nil, fmt.Errorf("%w: unsupported syntax. Syntax must be {{twingate.key}} ", ErrInvalidTemplate) } return &Template{prefix: s}, nil diff --git a/internal/httpproxy/utils/parser/parser_test.go b/internal/httpproxy/utils/parser/parser_test.go index 1c6bc5f..7975212 100644 --- a/internal/httpproxy/utils/parser/parser_test.go +++ b/internal/httpproxy/utils/parser/parser_test.go @@ -55,25 +55,25 @@ func TestParser_NewTemplate(t *testing.T) { name: "invalid template", input: "{{invalid}}", wantErr: ErrInvalidTemplate, - errSubstr: "invalid template format", + errSubstr: "unsupported syntax", }, { name: "missing opening braces", input: "twingate.jwt }}", wantErr: ErrInvalidTemplate, - errSubstr: "invalid template format", + errSubstr: "unsupported syntax", }, { name: "missing closing braces", input: "{{ twingate.jwt", wantErr: ErrInvalidTemplate, - errSubstr: "invalid template format", + errSubstr: "unsupported syntax", }, { name: "multiple templates", input: "{{twingate.username}} {{twingate.groups}}", wantErr: ErrInvalidTemplate, - errSubstr: "invalid template format", + errSubstr: "unsupported syntax", }, { name: "non-twingate namespace", diff --git a/internal/token/gat_claims.go b/internal/token/gat_claims.go index 56a0cad..f95db08 100644 --- a/internal/token/gat_claims.go +++ b/internal/token/gat_claims.go @@ -85,7 +85,7 @@ func (u User) MarshalLogObject(enc zapcore.ObjectEncoder) error { return err } -type GeoIP struct { +type GeoIPLocation struct { Lat float64 `json:"lat"` Lon float64 `json:"lon"` Country string `json:"country,omitempty"` @@ -93,13 +93,9 @@ type GeoIP struct { City string `json:"city,omitempty"` } -type DeviceLocation struct { - GeoIP GeoIP `json:"geoip"` -} - type Device struct { - ID string `json:"id"` - Location DeviceLocation `json:"location,omitzero"` + ID string `json:"id"` + Location GeoIPLocation `json:"location,omitzero"` } type ResourceType = string diff --git a/internal/token/gat_claims_test.go b/internal/token/gat_claims_test.go index 3d48cb7..bba08c7 100644 --- a/internal/token/gat_claims_test.go +++ b/internal/token/gat_claims_test.go @@ -29,14 +29,12 @@ func TestGATTokenClaims_Validate(t *testing.T) { }, Device: Device{ ID: "device-1", - Location: DeviceLocation{ - GeoIP: GeoIP{ - Lat: 37.7749, - Lon: -122.4194, - Country: "US", - Region: "California", - City: "San Francisco", - }, + Location: GeoIPLocation{ + Lat: 37.7749, + Lon: -122.4194, + Country: "US", + Region: "California", + City: "San Francisco", }, }, Resource: Resource{ diff --git a/internal/webapphandler/web_app.go b/internal/webapphandler/web_app.go index 6cf6bd1..55280c8 100644 --- a/internal/webapphandler/web_app.go +++ b/internal/webapphandler/web_app.go @@ -52,8 +52,8 @@ func rewrite(r *httputil.ProxyRequest, conn *connect.ProxyConn, headers map[stri claims := conn.GATClaims() geoLoc := "" - if claims.Device.Location != (token.DeviceLocation{}) { - geoLoc = fmt.Sprintf("%v,%v", claims.Device.Location.GeoIP.Lat, claims.Device.Location.GeoIP.Lon) + if claims.Device.Location != (token.GeoIPLocation{}) { + geoLoc = fmt.Sprintf("%v,%v", claims.Device.Location.Lat, claims.Device.Location.Lon) } variables := map[string]string{ diff --git a/internal/webapphandler/web_app_test.go b/internal/webapphandler/web_app_test.go index 7864e8d..d2fe796 100644 --- a/internal/webapphandler/web_app_test.go +++ b/internal/webapphandler/web_app_test.go @@ -42,10 +42,8 @@ func TestRewrite(t *testing.T) { Groups: []string{"Everyone", "Engineering"}, }, Device: token.Device{ - ID: "device-1", - Location: token.DeviceLocation{ - GeoIP: token.GeoIP{Lat: 37.5, Lon: -122.4}, - }, + ID: "device-1", + Location: token.GeoIPLocation{Lat: 37.5, Lon: -122.4}, }, } From 0c37a75f5ab1b7a67a4b985fa0582a16cd2b04dd Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Tue, 2 Jun 2026 12:08:07 +0800 Subject: [PATCH 06/17] Fix `parser.go` review comments --- internal/config/config.go | 2 +- internal/httpproxy/{utils => }/parser/parser.go | 7 ++++--- internal/httpproxy/{utils => }/parser/parser_test.go | 8 ++++---- internal/webapphandler/web_app.go | 2 +- internal/webapphandler/web_app_config.go | 2 +- internal/webapphandler/web_app_config_test.go | 2 +- internal/webapphandler/web_app_test.go | 2 +- 7 files changed, 13 insertions(+), 12 deletions(-) rename internal/httpproxy/{utils => }/parser/parser.go (86%) rename internal/httpproxy/{utils => }/parser/parser_test.go (96%) diff --git a/internal/config/config.go b/internal/config/config.go index 0ee4079..66bfa47 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -17,7 +17,7 @@ import ( "go.yaml.in/yaml/v4" "golang.org/x/crypto/ssh" - "gateway/internal/httpproxy/utils/parser" + "gateway/internal/httpproxy/parser" ) var ( diff --git a/internal/httpproxy/utils/parser/parser.go b/internal/httpproxy/parser/parser.go similarity index 86% rename from internal/httpproxy/utils/parser/parser.go rename to internal/httpproxy/parser/parser.go index ec198a9..97e8e43 100644 --- a/internal/httpproxy/utils/parser/parser.go +++ b/internal/httpproxy/parser/parser.go @@ -21,13 +21,13 @@ var ( ) var templateRe = regexp.MustCompile( - `^([^{}]*)` + // prefix (no brackets allowed) + `^([^{}]*)` + // prefix (no braces allowed) `{{\s*` + // opening braces `([a-zA-Z0-9_-]+)` + // namespace `\.` + `([a-zA-Z0-9_-]+)` + // key `\s*}}` + // closing braces - `([^{}]*)$`, // suffix (no brackets allowed) + `([^{}]*)$`, // suffix (no braces allowed) ) type Template struct { @@ -37,6 +37,7 @@ type Template struct { } // NewTemplate parses a string like " {{.}} " into a Template. +// If there is no template variable (just a static string), the key and suffix are empty and the prefix is the static string. func NewTemplate(s string) (*Template, error) { match := templateRe.FindStringSubmatch(s) @@ -45,7 +46,7 @@ func NewTemplate(s string) (*Template, error) { return nil, fmt.Errorf("%w: unsupported syntax. Syntax must be {{twingate.key}} ", ErrInvalidTemplate) } - return &Template{prefix: s}, nil + return &Template{prefix: strings.TrimSpace(s)}, nil } prefix, namespace, key, suffix := match[1], match[2], match[3], match[4] diff --git a/internal/httpproxy/utils/parser/parser_test.go b/internal/httpproxy/parser/parser_test.go similarity index 96% rename from internal/httpproxy/utils/parser/parser_test.go rename to internal/httpproxy/parser/parser_test.go index 7975212..b2d0cc5 100644 --- a/internal/httpproxy/utils/parser/parser_test.go +++ b/internal/httpproxy/parser/parser_test.go @@ -22,7 +22,7 @@ func TestParser_NewTemplate(t *testing.T) { }{ { name: "plain text", - input: "static-value", + input: " static-value ", wantPrefix: "static-value", }, { @@ -105,9 +105,9 @@ func TestParser_NewTemplate(t *testing.T) { func TestParser_Evaluate(t *testing.T) { tests := []struct { - name string - template Template - values map[string]string + name string + template Template + values map[string]string want string wantErr error errSubstr string diff --git a/internal/webapphandler/web_app.go b/internal/webapphandler/web_app.go index 55280c8..9b15a74 100644 --- a/internal/webapphandler/web_app.go +++ b/internal/webapphandler/web_app.go @@ -14,7 +14,7 @@ import ( "gateway/internal/connect" "gateway/internal/httpproxy" - "gateway/internal/httpproxy/utils/parser" + "gateway/internal/httpproxy/parser" "gateway/internal/metrics" "gateway/internal/token" ) diff --git a/internal/webapphandler/web_app_config.go b/internal/webapphandler/web_app_config.go index 43610c9..8dd1e72 100644 --- a/internal/webapphandler/web_app_config.go +++ b/internal/webapphandler/web_app_config.go @@ -8,7 +8,7 @@ import ( "go.uber.org/zap" - "gateway/internal/httpproxy/utils/parser" + "gateway/internal/httpproxy/parser" "gateway/internal/metrics" ) diff --git a/internal/webapphandler/web_app_config_test.go b/internal/webapphandler/web_app_config_test.go index 0993d36..fa82605 100644 --- a/internal/webapphandler/web_app_config_test.go +++ b/internal/webapphandler/web_app_config_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/zap" - "gateway/internal/httpproxy/utils/parser" + "gateway/internal/httpproxy/parser" ) func TestNewConfig(t *testing.T) { diff --git a/internal/webapphandler/web_app_test.go b/internal/webapphandler/web_app_test.go index d2fe796..c499bab 100644 --- a/internal/webapphandler/web_app_test.go +++ b/internal/webapphandler/web_app_test.go @@ -15,7 +15,7 @@ import ( "go.uber.org/zap" "gateway/internal/connect" - "gateway/internal/httpproxy/utils/parser" + "gateway/internal/httpproxy/parser" "gateway/internal/token" ) From 4016e4aff19e021f31c601fd2e6edef69f46bddd Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Tue, 2 Jun 2026 12:36:56 +0800 Subject: [PATCH 07/17] Refactor claims key to be constant --- internal/config/config.go | 4 ---- internal/config/constants.go | 25 +++++++++++++++++++++++++ internal/metrics/round_tripper.go | 6 +++--- internal/webapphandler/web_app.go | 19 ++++++++++++------- internal/webapphandler/web_app_test.go | 23 +++++++++++++++++++---- 5 files changed, 59 insertions(+), 18 deletions(-) create mode 100644 internal/config/constants.go diff --git a/internal/config/config.go b/internal/config/config.go index 66bfa47..0871365 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -452,10 +452,6 @@ func (v *SSHCAVaultConfig) Validate() error { } func (w *WebAppConfig) Validate() error { - allowedWebAppKeys := []string{ - "jwt", "username", "groups", "clientGeoLoc", - } - for name, value := range w.Headers { tmpl, err := parser.NewTemplate(value) if err != nil { diff --git a/internal/config/constants.go b/internal/config/constants.go new file mode 100644 index 0000000..094aca3 --- /dev/null +++ b/internal/config/constants.go @@ -0,0 +1,25 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package config + +// Keys used in HTTP header templates. +const ( + JWT = "jwt" + Username = "username" + Groups = "groups" + ClientGeoLatLong = "clientGeoLatLong" + ClientCity = "clientCity" + ClientRegion = "clientRegion" + ClientCountry = "clientCountry" +) + +var allowedWebAppKeys = []string{ + JWT, + Username, + Groups, + ClientGeoLatLong, + ClientCity, + ClientRegion, + ClientCountry, +} diff --git a/internal/metrics/round_tripper.go b/internal/metrics/round_tripper.go index a3e0b78..8d0ee0f 100644 --- a/internal/metrics/round_tripper.go +++ b/internal/metrics/round_tripper.go @@ -33,20 +33,20 @@ func RegisterRoundTripperMetrics(registry *prometheus.Registry) *RoundTripperMet requestsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{ Namespace: Namespace, Name: "api_server_requests_total", - Help: "Total number of requests from Gateway to API Server processed", + Help: "Total number of requests from Gateway to HTTP server processed", }, []string{labelResourceType, "type", "method", "code"}), activeRequests: prometheus.NewGaugeVec(prometheus.GaugeOpts{ Namespace: Namespace, Name: "api_server_active_requests", - Help: "Number of currently active requests from Gateway to API Server", + Help: "Number of currently active requests from Gateway to HTTP server", }, []string{labelResourceType, "type"}), requestDuration: prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: Namespace, Name: "api_server_request_duration_seconds", - Help: "Measures the initial HTTP request-response latency between Gateway and API Server in seconds. For HTTP streaming, WebSocket, and SPDY connections, this metric captures only the setup time and not the duration of the data transfer.", + Help: "Measures the initial HTTP request-response latency between Gateway and HTTP server in seconds. For HTTP streaming, WebSocket, and SPDY connections, this metric captures only the setup time and not the duration of the data transfer.", Buckets: prometheus.DefBuckets, }, []string{labelResourceType, "type", "method", "code"}), } diff --git a/internal/webapphandler/web_app.go b/internal/webapphandler/web_app.go index 9b15a74..77acc31 100644 --- a/internal/webapphandler/web_app.go +++ b/internal/webapphandler/web_app.go @@ -12,6 +12,7 @@ import ( "go.uber.org/zap" + gatewayconfig "gateway/internal/config" "gateway/internal/connect" "gateway/internal/httpproxy" "gateway/internal/httpproxy/parser" @@ -51,16 +52,20 @@ func rewrite(r *httputil.ProxyRequest, conn *connect.ProxyConn, headers map[stri claims := conn.GATClaims() - geoLoc := "" - if claims.Device.Location != (token.GeoIPLocation{}) { - geoLoc = fmt.Sprintf("%v,%v", claims.Device.Location.Lat, claims.Device.Location.Lon) + clientLocation := claims.Device.Location + clientGeoLatLong := "" + if clientLocation != (token.GeoIPLocation{}) { + clientGeoLatLong = fmt.Sprintf("%v,%v", clientLocation.Lat, clientLocation.Lon) } variables := map[string]string{ - "jwt": conn.GetToken(), - "username": claims.User.Username, - "groups": strings.Join(claims.User.Groups, ","), - "clientGeoLoc": geoLoc, + gatewayconfig.JWT: conn.GetToken(), + gatewayconfig.Username: claims.User.Username, + gatewayconfig.Groups: strings.Join(claims.User.Groups, ","), + gatewayconfig.ClientGeoLatLong: clientGeoLatLong, + gatewayconfig.ClientCity: clientLocation.City, + gatewayconfig.ClientRegion: clientLocation.Region, + gatewayconfig.ClientCountry: clientLocation.Country, } for headerName, tmpl := range headers { diff --git a/internal/webapphandler/web_app_test.go b/internal/webapphandler/web_app_test.go index c499bab..9f7aca3 100644 --- a/internal/webapphandler/web_app_test.go +++ b/internal/webapphandler/web_app_test.go @@ -43,7 +43,7 @@ func TestRewrite(t *testing.T) { }, Device: token.Device{ ID: "device-1", - Location: token.GeoIPLocation{Lat: 37.5, Lon: -122.4}, + Location: token.GeoIPLocation{Lat: 37.5, Lon: -122.4, Country: "US", Region: "CA", City: "San Mateo"}, }, } @@ -63,7 +63,10 @@ func TestRewrite(t *testing.T) { "Authorization": "Bearer {{twingate.jwt}}", "X-Username": "{{twingate.username}}", "X-Groups": "{{twingate.groups}}", - "X-Geo": "{{twingate.clientGeoLoc}}", + "X-Geo": "{{twingate.clientGeoLatLong}}", + "X-City": "{{twingate.clientCity}}", + "X-Region": "{{twingate.clientRegion}}", + "X-Country": "{{twingate.clientCountry}}", "Existing": "new-value", }, wantHeaders: map[string]string{ @@ -71,6 +74,9 @@ func TestRewrite(t *testing.T) { "X-Username": "alice@acme.com", "X-Groups": "Everyone,Engineering", "X-Geo": "37.5,-122.4", + "X-City": "San Mateo", + "X-Region": "CA", + "X-Country": "US", "Existing": "new-value", }, }, @@ -83,9 +89,18 @@ func TestRewrite(t *testing.T) { Resource: baseClaims.Resource, }, headers: map[string]string{ - "X-Geo": "{{twingate.clientGeoLoc}}", + "X-Geo": "{{twingate.clientGeoLatLong}}", + "X-City": "{{twingate.clientCity}}", + "X-Region": "{{twingate.clientRegion}}", + "X-Country": "{{twingate.clientCountry}}", + }, + wantHeaders: map[string]string{ + "X-Geo": "", + "X-City": "", + "X-Region": "", + "X-Country": "", + "Existing": "old-value", }, - wantHeaders: map[string]string{"X-Geo": ""}, }, { name: "empty headers", From 0699bd5c513675f0233a91863d01fe45bf0a8960 Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Tue, 2 Jun 2026 12:38:02 +0800 Subject: [PATCH 08/17] Remove unnecessary `web_app` prefix in filename --- internal/webapphandler/{web_app_config.go => config.go} | 0 internal/webapphandler/{web_app_config_test.go => config_test.go} | 0 internal/webapphandler/{web_app.go => handler.go} | 0 internal/webapphandler/{web_app_test.go => handler_test.go} | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename internal/webapphandler/{web_app_config.go => config.go} (100%) rename internal/webapphandler/{web_app_config_test.go => config_test.go} (100%) rename internal/webapphandler/{web_app.go => handler.go} (100%) rename internal/webapphandler/{web_app_test.go => handler_test.go} (100%) diff --git a/internal/webapphandler/web_app_config.go b/internal/webapphandler/config.go similarity index 100% rename from internal/webapphandler/web_app_config.go rename to internal/webapphandler/config.go diff --git a/internal/webapphandler/web_app_config_test.go b/internal/webapphandler/config_test.go similarity index 100% rename from internal/webapphandler/web_app_config_test.go rename to internal/webapphandler/config_test.go diff --git a/internal/webapphandler/web_app.go b/internal/webapphandler/handler.go similarity index 100% rename from internal/webapphandler/web_app.go rename to internal/webapphandler/handler.go diff --git a/internal/webapphandler/web_app_test.go b/internal/webapphandler/handler_test.go similarity index 100% rename from internal/webapphandler/web_app_test.go rename to internal/webapphandler/handler_test.go From 4ea5aab46195f1bc8dbbf0779bde3c0a2c1893f8 Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Tue, 2 Jun 2026 12:39:30 +0800 Subject: [PATCH 09/17] Fix lint --- internal/httpproxy/parser/parser_test.go | 6 +++--- internal/webapphandler/handler.go | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/internal/httpproxy/parser/parser_test.go b/internal/httpproxy/parser/parser_test.go index b2d0cc5..0c33c19 100644 --- a/internal/httpproxy/parser/parser_test.go +++ b/internal/httpproxy/parser/parser_test.go @@ -105,9 +105,9 @@ func TestParser_NewTemplate(t *testing.T) { func TestParser_Evaluate(t *testing.T) { tests := []struct { - name string - template Template - values map[string]string + name string + template Template + values map[string]string want string wantErr error errSubstr string diff --git a/internal/webapphandler/handler.go b/internal/webapphandler/handler.go index 77acc31..8202223 100644 --- a/internal/webapphandler/handler.go +++ b/internal/webapphandler/handler.go @@ -53,6 +53,7 @@ func rewrite(r *httputil.ProxyRequest, conn *connect.ProxyConn, headers map[stri claims := conn.GATClaims() clientLocation := claims.Device.Location + clientGeoLatLong := "" if clientLocation != (token.GeoIPLocation{}) { clientGeoLatLong = fmt.Sprintf("%v,%v", clientLocation.Lat, clientLocation.Lon) From 42deb34431e6f483fae265def90b35a2b807a1ec Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Tue, 2 Jun 2026 12:57:58 +0800 Subject: [PATCH 10/17] Rename parser as template --- internal/config/config.go | 4 ++-- .../httpproxy/{parser/parser.go => template/template.go} | 6 +++--- .../{parser/parser_test.go => template/template_test.go} | 4 ++-- internal/webapphandler/config.go | 8 ++++---- internal/webapphandler/config_test.go | 4 ++-- internal/webapphandler/handler.go | 4 ++-- internal/webapphandler/handler_test.go | 8 ++++---- 7 files changed, 19 insertions(+), 19 deletions(-) rename internal/httpproxy/{parser/parser.go => template/template.go} (92%) rename internal/httpproxy/{parser/parser_test.go => template/template_test.go} (98%) diff --git a/internal/config/config.go b/internal/config/config.go index 0871365..a7f48a5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -17,7 +17,7 @@ import ( "go.yaml.in/yaml/v4" "golang.org/x/crypto/ssh" - "gateway/internal/httpproxy/parser" + "gateway/internal/httpproxy/template" ) var ( @@ -453,7 +453,7 @@ func (v *SSHCAVaultConfig) Validate() error { func (w *WebAppConfig) Validate() error { for name, value := range w.Headers { - tmpl, err := parser.NewTemplate(value) + tmpl, err := template.New(value) if err != nil { return fmt.Errorf("header %q: %w", name, err) } diff --git a/internal/httpproxy/parser/parser.go b/internal/httpproxy/template/template.go similarity index 92% rename from internal/httpproxy/parser/parser.go rename to internal/httpproxy/template/template.go index 97e8e43..9ef5982 100644 --- a/internal/httpproxy/parser/parser.go +++ b/internal/httpproxy/template/template.go @@ -1,7 +1,7 @@ // Copyright (c) Twingate Inc. // SPDX-License-Identifier: MPL-2.0 -package parser +package template import ( "errors" @@ -36,9 +36,9 @@ type Template struct { suffix string } -// NewTemplate parses a string like " {{.}} " into a Template. +// New parses a string like " {{.}} " into a Template. // If there is no template variable (just a static string), the key and suffix are empty and the prefix is the static string. -func NewTemplate(s string) (*Template, error) { +func New(s string) (*Template, error) { match := templateRe.FindStringSubmatch(s) if match == nil { diff --git a/internal/httpproxy/parser/parser_test.go b/internal/httpproxy/template/template_test.go similarity index 98% rename from internal/httpproxy/parser/parser_test.go rename to internal/httpproxy/template/template_test.go index 0c33c19..0cb5b46 100644 --- a/internal/httpproxy/parser/parser_test.go +++ b/internal/httpproxy/template/template_test.go @@ -1,7 +1,7 @@ // Copyright (c) Twingate Inc. // SPDX-License-Identifier: MPL-2.0 -package parser +package template import ( "testing" @@ -85,7 +85,7 @@ func TestParser_NewTemplate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - template, err := NewTemplate(tt.input) + template, err := New(tt.input) if tt.wantErr != nil { require.ErrorIs(t, err, tt.wantErr) diff --git a/internal/webapphandler/config.go b/internal/webapphandler/config.go index 8dd1e72..e026003 100644 --- a/internal/webapphandler/config.go +++ b/internal/webapphandler/config.go @@ -8,21 +8,21 @@ import ( "go.uber.org/zap" - "gateway/internal/httpproxy/parser" + "gateway/internal/httpproxy/template" "gateway/internal/metrics" ) type Config struct { - headers map[string]*parser.Template + headers map[string]*template.Template roundTripperMetrics *metrics.RoundTripperMetrics logger *zap.Logger } func NewConfig(configHeaders map[string]string, roundTripperMetrics *metrics.RoundTripperMetrics, logger *zap.Logger) (*Config, error) { - headers := make(map[string]*parser.Template, len(configHeaders)) + headers := make(map[string]*template.Template, len(configHeaders)) for name, value := range configHeaders { - tmpl, err := parser.NewTemplate(value) + tmpl, err := template.New(value) if err != nil { return nil, fmt.Errorf("header %q: %w", name, err) } diff --git a/internal/webapphandler/config_test.go b/internal/webapphandler/config_test.go index fa82605..e7d16cc 100644 --- a/internal/webapphandler/config_test.go +++ b/internal/webapphandler/config_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/zap" - "gateway/internal/httpproxy/parser" + "gateway/internal/httpproxy/template" ) func TestNewConfig(t *testing.T) { @@ -29,7 +29,7 @@ func TestNewConfig(t *testing.T) { headers: map[string]string{ "X-Invalid": "{{invalid}}", }, - wantErr: parser.ErrInvalidTemplate, + wantErr: template.ErrInvalidTemplate, }, { name: "empty headers", diff --git a/internal/webapphandler/handler.go b/internal/webapphandler/handler.go index 8202223..8c73f39 100644 --- a/internal/webapphandler/handler.go +++ b/internal/webapphandler/handler.go @@ -15,7 +15,7 @@ import ( gatewayconfig "gateway/internal/config" "gateway/internal/connect" "gateway/internal/httpproxy" - "gateway/internal/httpproxy/parser" + "gateway/internal/httpproxy/template" "gateway/internal/metrics" "gateway/internal/token" ) @@ -43,7 +43,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.proxy.ServeHTTP(w, r) } -func rewrite(r *httputil.ProxyRequest, conn *connect.ProxyConn, headers map[string]*parser.Template) error { +func rewrite(r *httputil.ProxyRequest, conn *connect.ProxyConn, headers map[string]*template.Template) error { targetURL := &url.URL{ Scheme: "http", // plain HTTP — no upstream TLS Host: conn.GetAddress(), diff --git a/internal/webapphandler/handler_test.go b/internal/webapphandler/handler_test.go index 9f7aca3..acfaae5 100644 --- a/internal/webapphandler/handler_test.go +++ b/internal/webapphandler/handler_test.go @@ -15,17 +15,17 @@ import ( "go.uber.org/zap" "gateway/internal/connect" - "gateway/internal/httpproxy/parser" + "gateway/internal/httpproxy/template" "gateway/internal/token" ) -func mustParse(t *testing.T, templates map[string]string) map[string]*parser.Template { +func mustParse(t *testing.T, templates map[string]string) map[string]*template.Template { t.Helper() - result := make(map[string]*parser.Template, len(templates)) + result := make(map[string]*template.Template, len(templates)) for name, tmpl := range templates { - parsed, err := parser.NewTemplate(tmpl) + parsed, err := template.New(tmpl) require.NoError(t, err, "failed to parse template for header %q", name) result[name] = parsed From d8a5f04d6504760c53eda6f8455c8fe560e0b31b Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Tue, 2 Jun 2026 14:50:08 +0800 Subject: [PATCH 11/17] Move template to webapphandler --- internal/config/config.go | 4 +- internal/config/constants.go | 25 --- internal/webapphandler/config.go | 2 +- internal/webapphandler/config_test.go | 2 +- internal/webapphandler/handler.go | 18 +-- internal/webapphandler/handler_test.go | 22 +-- internal/webapphandler/template/template.go | 102 ++++++++++++ .../webapphandler/template/template_test.go | 145 ++++++++++++++++++ 8 files changed, 271 insertions(+), 49 deletions(-) delete mode 100644 internal/config/constants.go create mode 100644 internal/webapphandler/template/template.go create mode 100644 internal/webapphandler/template/template_test.go diff --git a/internal/config/config.go b/internal/config/config.go index a7f48a5..26ef353 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -17,7 +17,7 @@ import ( "go.yaml.in/yaml/v4" "golang.org/x/crypto/ssh" - "gateway/internal/httpproxy/template" + "gateway/internal/webapphandler/template" ) var ( @@ -459,7 +459,7 @@ func (w *WebAppConfig) Validate() error { } key := tmpl.Key() - if key != "" && !slices.Contains(allowedWebAppKeys, key) { + if key != "" && !slices.Contains(template.AllowedWebAppKeys, key) { return fmt.Errorf("header %q: %w %q", name, ErrUnsupportedKey, key) } } diff --git a/internal/config/constants.go b/internal/config/constants.go deleted file mode 100644 index 094aca3..0000000 --- a/internal/config/constants.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Twingate Inc. -// SPDX-License-Identifier: MPL-2.0 - -package config - -// Keys used in HTTP header templates. -const ( - JWT = "jwt" - Username = "username" - Groups = "groups" - ClientGeoLatLong = "clientGeoLatLong" - ClientCity = "clientCity" - ClientRegion = "clientRegion" - ClientCountry = "clientCountry" -) - -var allowedWebAppKeys = []string{ - JWT, - Username, - Groups, - ClientGeoLatLong, - ClientCity, - ClientRegion, - ClientCountry, -} diff --git a/internal/webapphandler/config.go b/internal/webapphandler/config.go index e026003..323f5c0 100644 --- a/internal/webapphandler/config.go +++ b/internal/webapphandler/config.go @@ -8,8 +8,8 @@ import ( "go.uber.org/zap" - "gateway/internal/httpproxy/template" "gateway/internal/metrics" + "gateway/internal/webapphandler/template" ) type Config struct { diff --git a/internal/webapphandler/config_test.go b/internal/webapphandler/config_test.go index e7d16cc..d678917 100644 --- a/internal/webapphandler/config_test.go +++ b/internal/webapphandler/config_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/zap" - "gateway/internal/httpproxy/template" + "gateway/internal/webapphandler/template" ) func TestNewConfig(t *testing.T) { diff --git a/internal/webapphandler/handler.go b/internal/webapphandler/handler.go index 8c73f39..b4fac94 100644 --- a/internal/webapphandler/handler.go +++ b/internal/webapphandler/handler.go @@ -12,12 +12,11 @@ import ( "go.uber.org/zap" - gatewayconfig "gateway/internal/config" "gateway/internal/connect" "gateway/internal/httpproxy" - "gateway/internal/httpproxy/template" "gateway/internal/metrics" "gateway/internal/token" + "gateway/internal/webapphandler/template" ) type Handler struct { @@ -31,6 +30,7 @@ func NewHandler(cfg Config) *Handler { if err := rewrite(r, conn, cfg.headers); err != nil { cfg.logger.Error("failed to rewrite headers", zap.Error(err)) + panic(err) } }, Transport: metrics.InstrumentRoundTripper(cfg.roundTripperMetrics, metrics.ResourceTypeWebApp, http.DefaultTransport), @@ -60,13 +60,13 @@ func rewrite(r *httputil.ProxyRequest, conn *connect.ProxyConn, headers map[stri } variables := map[string]string{ - gatewayconfig.JWT: conn.GetToken(), - gatewayconfig.Username: claims.User.Username, - gatewayconfig.Groups: strings.Join(claims.User.Groups, ","), - gatewayconfig.ClientGeoLatLong: clientGeoLatLong, - gatewayconfig.ClientCity: clientLocation.City, - gatewayconfig.ClientRegion: clientLocation.Region, - gatewayconfig.ClientCountry: clientLocation.Country, + template.JWT: conn.GetToken(), + template.Username: claims.User.Username, + template.Groups: strings.Join(claims.User.Groups, ","), + template.ClientGeoLatLong: clientGeoLatLong, + template.ClientGeoCity: clientLocation.City, + template.ClientGeoRegion: clientLocation.Region, + template.ClientGeoCountry: clientLocation.Country, } for headerName, tmpl := range headers { diff --git a/internal/webapphandler/handler_test.go b/internal/webapphandler/handler_test.go index acfaae5..788a220 100644 --- a/internal/webapphandler/handler_test.go +++ b/internal/webapphandler/handler_test.go @@ -15,8 +15,8 @@ import ( "go.uber.org/zap" "gateway/internal/connect" - "gateway/internal/httpproxy/template" "gateway/internal/token" + "gateway/internal/webapphandler/template" ) func mustParse(t *testing.T, templates map[string]string) map[string]*template.Template { @@ -63,17 +63,17 @@ func TestRewrite(t *testing.T) { "Authorization": "Bearer {{twingate.jwt}}", "X-Username": "{{twingate.username}}", "X-Groups": "{{twingate.groups}}", - "X-Geo": "{{twingate.clientGeoLatLong}}", - "X-City": "{{twingate.clientCity}}", - "X-Region": "{{twingate.clientRegion}}", - "X-Country": "{{twingate.clientCountry}}", + "X-LatLong": "{{twingate.clientGeoLatLong}}", + "X-City": "{{twingate.clientGeoCity}}", + "X-Region": "{{twingate.clientGeoRegion}}", + "X-Country": "{{twingate.clientGeoCountry}}", "Existing": "new-value", }, wantHeaders: map[string]string{ "Authorization": "Bearer test-token", "X-Username": "alice@acme.com", "X-Groups": "Everyone,Engineering", - "X-Geo": "37.5,-122.4", + "X-LatLong": "37.5,-122.4", "X-City": "San Mateo", "X-Region": "CA", "X-Country": "US", @@ -89,13 +89,13 @@ func TestRewrite(t *testing.T) { Resource: baseClaims.Resource, }, headers: map[string]string{ - "X-Geo": "{{twingate.clientGeoLatLong}}", - "X-City": "{{twingate.clientCity}}", - "X-Region": "{{twingate.clientRegion}}", - "X-Country": "{{twingate.clientCountry}}", + "X-LatLong": "{{twingate.clientGeoLatLong}}", + "X-City": "{{twingate.clientGeoCity}}", + "X-Region": "{{twingate.clientGeoRegion}}", + "X-Country": "{{twingate.clientGeoCountry}}", }, wantHeaders: map[string]string{ - "X-Geo": "", + "X-LatLong": "", "X-City": "", "X-Region": "", "X-Country": "", diff --git a/internal/webapphandler/template/template.go b/internal/webapphandler/template/template.go new file mode 100644 index 0000000..de45834 --- /dev/null +++ b/internal/webapphandler/template/template.go @@ -0,0 +1,102 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package template + +import ( + "errors" + "fmt" + "regexp" + "strings" + "unicode" +) + +const ( + allowedNamespace = "twingate" +) + +const ( + JWT = "jwt" + Username = "username" + Groups = "groups" + ClientGeoLatLong = "clientGeoLatLong" + ClientGeoCity = "clientGeoCity" + ClientGeoRegion = "clientGeoRegion" + ClientGeoCountry = "clientGeoCountry" +) + +var AllowedWebAppKeys = []string{ + JWT, + Username, + Groups, + ClientGeoLatLong, + ClientGeoCity, + ClientGeoRegion, + ClientGeoCountry, +} + +var ( + ErrInvalidTemplate = errors.New("invalid template") + ErrUnknownKey = errors.New("unknown key") +) + +var templateRe = regexp.MustCompile( + `^([^{}]*)` + // prefix (no braces allowed) + `{{\s*` + // opening braces + `([a-zA-Z0-9_-]+)` + // namespace + `\.` + + `([a-zA-Z0-9_-]+)` + // key + `\s*}}` + // closing braces + `([^{}]*)$`, // suffix (no braces allowed) +) + +type Template struct { + prefix string + key string + suffix string +} + +// New parses a string like " {{.}} " into a Template. +// If there is no template variable (just a static string), the key and suffix are empty and the prefix is the static string. +func New(s string) (*Template, error) { + match := templateRe.FindStringSubmatch(s) + + if match == nil { + if strings.Contains(s, "{{") || strings.Contains(s, "}}") { + return nil, fmt.Errorf("%w: unsupported syntax. Syntax must be {{twingate.key}} ", ErrInvalidTemplate) + } + + return &Template{prefix: strings.TrimSpace(s)}, nil + } + + prefix, namespace, key, suffix := match[1], match[2], match[3], match[4] + + if namespace != allowedNamespace { + return nil, fmt.Errorf("%w: unsupported namespace %q", ErrInvalidTemplate, namespace) + } + + return &Template{ + prefix: strings.TrimLeftFunc(prefix, unicode.IsSpace), + key: key, + suffix: strings.TrimRightFunc(suffix, unicode.IsSpace), + }, nil +} + +func (t *Template) Key() string { + return t.key +} + +// Evaluate replaces the key in the template with the corresponding value from the map +// and returns the resulting string along with the prefix and suffix. +func (t *Template) Evaluate(values map[string]string) (string, error) { + if t.key == "" { + return t.prefix, nil + } + + result, ok := values[t.key] + if !ok { + return "", fmt.Errorf("%w: %q", ErrUnknownKey, t.key) + } + + return t.prefix + result + t.suffix, nil +} diff --git a/internal/webapphandler/template/template_test.go b/internal/webapphandler/template/template_test.go new file mode 100644 index 0000000..0cb5b46 --- /dev/null +++ b/internal/webapphandler/template/template_test.go @@ -0,0 +1,145 @@ +// Copyright (c) Twingate Inc. +// SPDX-License-Identifier: MPL-2.0 + +package template + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParser_NewTemplate(t *testing.T) { + tests := []struct { + name string + input string + wantPrefix string + wantKey string + wantSuffix string + wantErr error + errSubstr string + }{ + { + name: "plain text", + input: " static-value ", + wantPrefix: "static-value", + }, + { + name: "empty string", + input: "", + }, + { + name: "template only", + input: "{{twingate.jwt}}", + wantKey: "jwt", + }, + { + name: "template with leading and trailing space", + input: "{{ twingate.jwt }}", + wantKey: "jwt", + }, + { + name: "template with prefix", + input: " Bearer {{twingate.jwt}}", + wantPrefix: "Bearer ", + wantKey: "jwt", + }, + { + name: "template with suffix", + input: "{{twingate.username}}/profile ", + wantKey: "username", + wantSuffix: "/profile", + }, + { + name: "invalid template", + input: "{{invalid}}", + wantErr: ErrInvalidTemplate, + errSubstr: "unsupported syntax", + }, + { + name: "missing opening braces", + input: "twingate.jwt }}", + wantErr: ErrInvalidTemplate, + errSubstr: "unsupported syntax", + }, + { + name: "missing closing braces", + input: "{{ twingate.jwt", + wantErr: ErrInvalidTemplate, + errSubstr: "unsupported syntax", + }, + { + name: "multiple templates", + input: "{{twingate.username}} {{twingate.groups}}", + wantErr: ErrInvalidTemplate, + errSubstr: "unsupported syntax", + }, + { + name: "non-twingate namespace", + input: "{{other.key}}", + wantErr: ErrInvalidTemplate, + errSubstr: "unsupported namespace", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + template, err := New(tt.input) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + assert.Contains(t, err.Error(), tt.errSubstr) + + return + } + + require.NoError(t, err) + + assert.Equal(t, tt.wantPrefix, template.prefix) + assert.Equal(t, tt.wantKey, template.key) + assert.Equal(t, tt.wantSuffix, template.suffix) + }) + } +} + +func TestParser_Evaluate(t *testing.T) { + tests := []struct { + name string + template Template + values map[string]string + want string + wantErr error + errSubstr string + }{ + { + name: "Success", + template: Template{prefix: "Prefix ", key: "foo", suffix: " suffix"}, + values: map[string]string{"foo": "bar", "extra": "foo"}, + want: "Prefix bar suffix", + }, + { + name: "Missing key", + template: Template{prefix: "Bearer ", key: "jwt", suffix: ""}, + values: map[string]string{}, + wantErr: ErrUnknownKey, + errSubstr: "jwt", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tt.template.Evaluate(tt.values) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + assert.Contains(t, err.Error(), tt.errSubstr) + + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, result) + }) + } +} From e7ceee93702e366c0a95e43049d7f529a61d5c65 Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Tue, 2 Jun 2026 16:03:28 +0800 Subject: [PATCH 12/17] Create http proxy per resource type --- internal/connect/conn.go | 2 +- internal/connect/conn_test.go | 2 +- internal/connect/listener.go | 10 +- internal/connect/listener_test.go | 69 +++++++------ internal/httpproxy/proxy.go | 37 ++----- internal/httpproxy/proxy_test.go | 74 ++------------ internal/metrics/http_middleware.go | 119 +++++++++++++---------- internal/metrics/http_middleware_test.go | 43 ++++---- internal/proxy/proxy.go | 64 +++++++----- internal/proxy/proxy_test.go | 29 +++--- internal/token/gat_claims.go | 8 +- internal/webapphandler/handler_test.go | 28 ++++++ 12 files changed, 241 insertions(+), 244 deletions(-) diff --git a/internal/connect/conn.go b/internal/connect/conn.go index 00cd631..09e7477 100644 --- a/internal/connect/conn.go +++ b/internal/connect/conn.go @@ -237,7 +237,7 @@ func (p *ProxyConn) Authenticate() error { p.tracker.RecordConnectMetrics(httpCode) p.Logger.Info("Authenticated connection", - zap.String("resource_type", connectInfo.Claims.Resource.Type), + zap.String("resource_type", string(connectInfo.Claims.Resource.Type)), zap.String("resource_address", connectInfo.Claims.Resource.Address), ) p.setConnectInfo(connectInfo) diff --git a/internal/connect/conn_test.go b/internal/connect/conn_test.go index 286bf0d..ef5349c 100644 --- a/internal/connect/conn_test.go +++ b/internal/connect/conn_test.go @@ -99,7 +99,7 @@ func TestProxyConn_TransportProtocol(t *testing.T) { } for _, tt := range tests { - t.Run(tt.resourceType, func(t *testing.T) { + t.Run(string(tt.resourceType), func(t *testing.T) { claims := &token.GATClaims{ Resource: token.Resource{Type: tt.resourceType}, } diff --git a/internal/connect/listener.go b/internal/connect/listener.go index b88741a..4268e51 100644 --- a/internal/connect/listener.go +++ b/internal/connect/listener.go @@ -70,7 +70,7 @@ func (l *ProtocolListener) Addr() net.Addr { type ConnFactory func(net.Conn, *tls.Config, Validator, *zap.Logger) Conn type Listener struct { - channels map[TransportProtocol]chan<- Conn + channels map[token.ResourceType]chan<- Conn tokenParser *token.Parser certReloader *CertReloader @@ -88,7 +88,7 @@ type Listener struct { func NewListener( twingateConfig config.TwingateConfig, tlsCfg config.TLSConfig, - channels map[TransportProtocol]chan<- Conn, + channels map[token.ResourceType]chan<- Conn, registry *prometheus.Registry, logger *zap.Logger, ) (*Listener, error) { @@ -172,11 +172,11 @@ func (l *Listener) Serve(ctx context.Context, listener net.Listener) error { return } - tp := proxyConn.TransportProtocol() - channel, exists := l.channels[tp] + resourceType := proxyConn.GATClaims().Resource.Type + channel, exists := l.channels[resourceType] if !exists { - l.logger.Error("Unsupported transport protocol", zap.Int("transport", int(tp))) + l.logger.Error("Unsupported resource type", zap.String("resource_type", string(resourceType))) _ = proxyConn.Close() diff --git a/internal/connect/listener_test.go b/internal/connect/listener_test.go index 2782f71..5d19e53 100644 --- a/internal/connect/listener_test.go +++ b/internal/connect/listener_test.go @@ -97,16 +97,20 @@ func createMockListener(t *testing.T) (net.Listener, string) { return listener, addr } -var listenerClaims = &token.GATClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), - }, - User: token.User{ - ID: "user-1", - Username: "user@acme.com", - Groups: []string{"Everyone", "Engineering"}, - }, - Resource: token.Resource{ID: "resource-1", Type: token.ResourceTypeKubernetes, Address: "https://api.acme.com"}, +func createClaims(t *testing.T, resourceType token.ResourceType) *token.GATClaims { + t.Helper() + + return &token.GATClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), + }, + User: token.User{ + ID: "user-1", + Username: "user@acme.com", + Groups: []string{"Everyone", "Engineering"}, + }, + Resource: token.Resource{ID: "resource-1", Type: resourceType, Address: "https://api.acme.com"}, + } } type testListenerFixtures struct { @@ -125,9 +129,9 @@ func createTestListenerWithChannels(t *testing.T) *testListenerFixtures { // Create channels for testing httpChannel := make(chan Conn, 1) sshChannel := make(chan Conn, 1) - channels := map[TransportProtocol]chan<- Conn{ - TransportTLS: httpChannel, - TransportSSH: sshChannel, + channels := map[token.ResourceType]chan<- Conn{ + token.ResourceTypeKubernetes: httpChannel, + token.ResourceTypeSSH: sshChannel, } registry := prometheus.NewRegistry() @@ -154,11 +158,13 @@ func createTestListenerWithChannels(t *testing.T) *testListenerFixtures { func TestListener_Serve_HTTPS(t *testing.T) { fixtures := createTestListenerWithChannels(t) + kubernetesClaims := createClaims(t, token.ResourceTypeKubernetes) + fixtures.listener.proxyConnFactory = func(conn net.Conn, _ *tls.Config, _ Validator, _ *zap.Logger) Conn { return &mockProxyConn{ Conn: conn, transportProtocol: TransportTLS, - Claims: listenerClaims, + Claims: kubernetesClaims, } } @@ -191,7 +197,7 @@ func TestListener_Serve_HTTPS(t *testing.T) { case conn := <-fixtures.httpChannel: require.False(t, conn.(*mockProxyConn).IsClosed()) require.Equal(t, TransportTLS, conn.TransportProtocol()) - require.Equal(t, listenerClaims, conn.GATClaims()) + require.Equal(t, kubernetesClaims, conn.GATClaims()) case <-time.After(1 * time.Second): t.Fatal("timeout waiting for HTTP connection") } @@ -206,11 +212,13 @@ func TestListener_Serve_HTTPS(t *testing.T) { func TestListener_Serve_SSH(t *testing.T) { fixtures := createTestListenerWithChannels(t) + sshClaims := createClaims(t, token.ResourceTypeSSH) + fixtures.listener.proxyConnFactory = func(conn net.Conn, _ *tls.Config, _ Validator, _ *zap.Logger) Conn { return &mockProxyConn{ Conn: conn, transportProtocol: TransportSSH, - Claims: listenerClaims, + Claims: sshClaims, } } @@ -242,8 +250,7 @@ func TestListener_Serve_SSH(t *testing.T) { select { case conn := <-fixtures.sshChannel: require.False(t, conn.(*mockProxyConn).IsClosed()) - require.Equal(t, TransportSSH, conn.TransportProtocol()) - require.Equal(t, listenerClaims, conn.GATClaims()) + require.Equal(t, sshClaims, conn.GATClaims()) case <-time.After(1 * time.Second): t.Fatal("timeout waiting for SSH connection") } @@ -329,13 +336,13 @@ func TestListener_Serve_Healthz(t *testing.T) { waitGroup.Wait() } -func TestListener_UnsupportedTransport(t *testing.T) { +func TestListener_UnsupportedResourceType(t *testing.T) { tcpListener, addr := createMockListener(t) // Create channels but omit one transport type httpChannel := make(chan Conn, 1) - channels := map[TransportProtocol]chan<- Conn{ - TransportTLS: httpChannel, + channels := map[token.ResourceType]chan<- Conn{ + token.ResourceTypeKubernetes: httpChannel, // SSH not included - unsupported } @@ -350,13 +357,14 @@ func TestListener_UnsupportedTransport(t *testing.T) { certReloader: certReloader, } + sshClaims := createClaims(t, token.ResourceTypeSSH) + // Use a channel to safely pass the connection from the factory connCreated := make(chan *mockProxyConn, 1) listener.proxyConnFactory = func(conn net.Conn, _ *tls.Config, _ Validator, _ *zap.Logger) Conn { mockConn := &mockProxyConn{ - Conn: conn, - transportProtocol: TransportSSH, // This transport is not supported - Claims: listenerClaims, + Conn: conn, + Claims: sshClaims, } connCreated <- mockConn @@ -391,12 +399,14 @@ func TestListener_Serve_GracefulShutdown(t *testing.T) { // Use unbuffered channel so the goroutine inside Serve blocks on send httpChannel := make(chan Conn) - channels := map[TransportProtocol]chan<- Conn{ - TransportTLS: httpChannel, + channels := map[token.ResourceType]chan<- Conn{ + token.ResourceTypeKubernetes: httpChannel, } logger := zap.NewNop() + kubernetesClaims := createClaims(t, token.ResourceTypeKubernetes) + listener := &Listener{ channels: channels, logger: logger, @@ -408,7 +418,7 @@ func TestListener_Serve_GracefulShutdown(t *testing.T) { return &mockProxyConn{ Conn: conn, transportProtocol: TransportTLS, - Claims: listenerClaims, + Claims: kubernetesClaims, } } @@ -465,9 +475,10 @@ func TestProtocolListener(t *testing.T) { assert.Equal(t, addr, protocolListener.Addr().String()) // Send a mock connection + kubernetesClaims := createClaims(t, token.ResourceTypeKubernetes) + mockConn := &mockProxyConn{ - transportProtocol: TransportTLS, - Claims: listenerClaims, + Claims: kubernetesClaims, } ch <- mockConn diff --git a/internal/httpproxy/proxy.go b/internal/httpproxy/proxy.go index 5c89bc1..a2bef30 100644 --- a/internal/httpproxy/proxy.go +++ b/internal/httpproxy/proxy.go @@ -9,7 +9,6 @@ import ( "net/http" "time" - "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" "gateway/internal/connect" @@ -28,41 +27,21 @@ func ProxyConnFromContext(ctx context.Context) *connect.ProxyConn { } type Config struct { - Handlers map[string]http.Handler - Registry *prometheus.Registry - Logger *zap.Logger + Handler http.Handler + Metrics *metrics.HTTPMetrics + Logger *zap.Logger + ResourceTypeLabel string } type Proxy struct { httpServer *http.Server } -func newResourceRouter(handlers map[string]http.Handler, logger *zap.Logger) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn := ProxyConnFromContext(r.Context()) - resourceType := conn.GATClaims().Resource.Type - - handler, exists := handlers[resourceType] - if !exists { - logger.Error("No handler for resource type", zap.String("type", resourceType)) - http.Error(w, "unsupported resource type", http.StatusNotFound) - - return - } - - handler.ServeHTTP(w, r) - }) -} - func NewProxy(cfg Config) *Proxy { - router := newResourceRouter(cfg.Handlers, cfg.Logger) - handler := metrics.HTTPMiddleware(metrics.HTTPMiddlewareConfig{ - Registry: cfg.Registry, - Next: auditMiddleware(auditMiddlewareConfig{ - next: router, - logger: cfg.Logger, - }), - }) + handler := metrics.HTTPMiddleware(cfg.Metrics, cfg.ResourceTypeLabel, auditMiddleware(auditMiddlewareConfig{ + next: cfg.Handler, + logger: cfg.Logger, + })) mux := http.NewServeMux() mux.Handle("/", handler) diff --git a/internal/httpproxy/proxy_test.go b/internal/httpproxy/proxy_test.go index 48bd8e4..a999def 100644 --- a/internal/httpproxy/proxy_test.go +++ b/internal/httpproxy/proxy_test.go @@ -8,7 +8,6 @@ import ( "io" "net" "net/http" - "net/http/httptest" "testing" "github.com/prometheus/client_golang/prometheus" @@ -17,6 +16,7 @@ import ( "go.uber.org/zap" "gateway/internal/connect" + "gateway/internal/metrics" "gateway/internal/token" ) @@ -60,64 +60,6 @@ func TestProxyConnFromContext(t *testing.T) { }) } -func TestResourceRouter_DispatchesToCorrectHandler(t *testing.T) { - var handledBy = "" - - handlers := map[string]http.Handler{ - token.ResourceTypeKubernetes: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - handledBy = "kubernetes" - - w.WriteHeader(http.StatusOK) - }), - token.ResourceTypeSSH: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - handledBy = "ssh" - - w.WriteHeader(http.StatusOK) - }), - } - - connMetrics := connect.CreateProxyConnMetrics(prometheus.NewRegistry()) - proxyConn := connect.NewProxyConn(nil, nil, nil, zap.NewNop(), connMetrics) - proxyConn.Claims = &token.GATClaims{ - Resource: token.Resource{Type: token.ResourceTypeKubernetes}, - } - - recorder := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/test", nil) - ctx := context.WithValue(req.Context(), ConnContextKey{}, proxyConn) - req = req.WithContext(ctx) - - router := newResourceRouter(handlers, zap.NewNop()) - router.ServeHTTP(recorder, req) - - assert.Equal(t, http.StatusOK, recorder.Code) - assert.Equal(t, "kubernetes", handledBy) -} - -func TestResourceRouter_UnknownResource(t *testing.T) { - handlers := map[string]http.Handler{ - token.ResourceTypeKubernetes: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - }), - } - - connMetrics := connect.CreateProxyConnMetrics(prometheus.NewRegistry()) - proxyConn := connect.NewProxyConn(nil, nil, nil, zap.NewNop(), connMetrics) - proxyConn.Claims = &token.GATClaims{ - Resource: token.Resource{Type: "unknown"}, - } - - recorder := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/test", nil) - ctx := context.WithValue(req.Context(), ConnContextKey{}, proxyConn) - req = req.WithContext(ctx) - - router := newResourceRouter(handlers, zap.NewNop()) - router.ServeHTTP(recorder, req) - - assert.Equal(t, http.StatusNotFound, recorder.Code) -} - func TestProxy_ForwardRequest(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) @@ -127,9 +69,10 @@ func TestProxy_ForwardRequest(t *testing.T) { require.NoError(t, err) proxy := NewProxy(Config{ - Registry: prometheus.NewRegistry(), - Handlers: map[string]http.Handler{token.ResourceTypeKubernetes: handler}, - Logger: zap.NewNop(), + Metrics: metrics.RegisterHTTPMetrics(prometheus.NewRegistry()), + ResourceTypeLabel: metrics.ResourceTypeKubernetes, + Handler: handler, + Logger: zap.NewNop(), }) go func() { @@ -161,9 +104,10 @@ func TestProxy_Shutdown(t *testing.T) { require.NoError(t, err) proxy := NewProxy(Config{ - Handlers: map[string]http.Handler{token.ResourceTypeKubernetes: handler}, - Registry: prometheus.NewRegistry(), - Logger: zap.NewNop(), + Metrics: metrics.RegisterHTTPMetrics(prometheus.NewRegistry()), + ResourceTypeLabel: metrics.ResourceTypeKubernetes, + Handler: handler, + Logger: zap.NewNop(), }) done := make(chan error, 1) diff --git a/internal/metrics/http_middleware.go b/internal/metrics/http_middleware.go index ee82590..47a371c 100644 --- a/internal/metrics/http_middleware.go +++ b/internal/metrics/http_middleware.go @@ -28,71 +28,84 @@ const ( requestTypeUnknown = "unknown" ) -type HTTPMiddlewareConfig struct { - Registry *prometheus.Registry - Next http.Handler +type contextKey struct{} + +type HTTPMetrics struct { + requestsTotal *prometheus.CounterVec + activeRequests *prometheus.GaugeVec + requestDuration *prometheus.HistogramVec + requestSizeBytes *prometheus.HistogramVec + responseSizeBytes *prometheus.HistogramVec } -type contextKey struct{} +func RegisterHTTPMetrics(registry *prometheus.Registry) *HTTPMetrics { + m := &HTTPMetrics{ + requestsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: Namespace, + Name: "http_requests_total", + Help: "Total number of HTTP requests processed", + }, []string{labelResourceType, "type", "method", "code"}), -func HTTPMiddleware(config HTTPMiddlewareConfig) http.HandlerFunc { - requestsTotal := prometheus.NewCounterVec(prometheus.CounterOpts{ - Namespace: Namespace, - Name: "http_requests_total", - Help: "Total number of HTTP requests processed", - }, []string{"type", "method", "code"}) - - activeRequests := prometheus.NewGaugeVec(prometheus.GaugeOpts{ - Namespace: Namespace, - Name: "http_active_requests", - Help: "Number of currently active HTTP requests", - }, []string{"type"}) - - requestDuration := prometheus.NewHistogramVec( - prometheus.HistogramOpts{ + activeRequests: prometheus.NewGaugeVec(prometheus.GaugeOpts{ Namespace: Namespace, - Name: "http_request_duration_seconds", - Help: "Latencies of HTTP requests in seconds", - Buckets: []float64{0.1, 0.25, 0.5, 1, 2, 5, 10, 30, 60, 120, 300, 600, 1800, 3600}, - }, []string{"type", "method", "code"}) - - requestSizeBytes := prometheus.NewHistogramVec(prometheus.HistogramOpts{ - Namespace: Namespace, - Name: "http_request_size_bytes", - Help: "Size of incoming HTTP request in bytes", - Buckets: prometheus.ExponentialBuckets(100, 10, 6), - }, []string{"type", "method", "code"}) - - responseSizeBytes := prometheus.NewHistogramVec(prometheus.HistogramOpts{ - Namespace: Namespace, - Name: "http_response_size_bytes", - Help: "Size of outgoing HTTP response in bytes", - Buckets: prometheus.ExponentialBuckets(100, 10, 6), - }, []string{"type", "method", "code"}, - ) + Name: "http_active_requests", + Help: "Number of currently active HTTP requests", + }, []string{labelResourceType, "type"}), + + requestDuration: prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: Namespace, + Name: "http_request_duration_seconds", + Help: "Latencies of HTTP requests in seconds", + Buckets: []float64{0.1, 0.25, 0.5, 1, 2, 5, 10, 30, 60, 120, 300, 600, 1800, 3600}, + }, []string{labelResourceType, "type", "method", "code"}), + + requestSizeBytes: prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: Namespace, + Name: "http_request_size_bytes", + Help: "Size of incoming HTTP request in bytes", + Buckets: prometheus.ExponentialBuckets(100, 10, 6), + }, []string{labelResourceType, "type", "method", "code"}), - config.Registry.MustRegister(requestsTotal, activeRequests, requestDuration, requestSizeBytes, responseSizeBytes) + responseSizeBytes: prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: Namespace, + Name: "http_response_size_bytes", + Help: "Size of outgoing HTTP response in bytes", + Buckets: prometheus.ExponentialBuckets(100, 10, 6), + }, []string{labelResourceType, "type", "method", "code"}), + } + + registry.MustRegister(m.requestsTotal, m.activeRequests, m.requestDuration, m.requestSizeBytes, m.responseSizeBytes) + + return m +} - opts := promhttp.WithLabelFromCtx(labelRequestType, getRequestTypeFromContext) +func HTTPMiddleware(metrics *HTTPMetrics, resourceType string, next http.Handler) http.HandlerFunc { + resourceTypeOpt := promhttp.WithLabelFromCtx(labelResourceType, func(_ context.Context) string { return resourceType }) + requestTypeOpt := promhttp.WithLabelFromCtx(labelRequestType, getRequestTypeFromContext) base := promhttp.InstrumentHandlerCounter( - requestsTotal, - instrumentHandlerInFlight(activeRequests, + metrics.requestsTotal, + instrumentHandlerInFlight(metrics.activeRequests, resourceType, promhttp.InstrumentHandlerDuration( - requestDuration, + metrics.requestDuration, promhttp.InstrumentHandlerRequestSize( - requestSizeBytes, + metrics.requestSizeBytes, promhttp.InstrumentHandlerResponseSize( - responseSizeBytes, - config.Next, - opts, + metrics.responseSizeBytes, + next, + resourceTypeOpt, + requestTypeOpt, ), - opts, + resourceTypeOpt, + requestTypeOpt, ), - opts, + resourceTypeOpt, + requestTypeOpt, ), ), - opts, + resourceTypeOpt, + requestTypeOpt, ) return func(w http.ResponseWriter, r *http.Request) { @@ -100,12 +113,12 @@ func HTTPMiddleware(config HTTPMiddlewareConfig) http.HandlerFunc { } } -func instrumentHandlerInFlight(activeRequests *prometheus.GaugeVec, next http.Handler) http.Handler { +func instrumentHandlerInFlight(activeRequests *prometheus.GaugeVec, resourceType string, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestType := getRequestTypeFromContext(r.Context()) - activeRequests.WithLabelValues(requestType).Inc() - defer activeRequests.WithLabelValues(requestType).Dec() + activeRequests.WithLabelValues(resourceType, requestType).Inc() + defer activeRequests.WithLabelValues(resourceType, requestType).Dec() next.ServeHTTP(w, r) }) diff --git a/internal/metrics/http_middleware_test.go b/internal/metrics/http_middleware_test.go index e441206..9084d49 100644 --- a/internal/metrics/http_middleware_test.go +++ b/internal/metrics/http_middleware_test.go @@ -108,14 +108,14 @@ func TestWithRequestType(t *testing.T) { func TestHTTPMiddleware(t *testing.T) { testRegistry := prometheus.NewRegistry() + httpMetrics := RegisterHTTPMetrics(testRegistry) server := httptest.NewServer(HTTPMiddleware( - HTTPMiddlewareConfig{ - Registry: testRegistry, - Next: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - }), - }, + httpMetrics, + ResourceTypeKubernetes, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }), )) defer server.Close() @@ -133,27 +133,32 @@ func TestHTTPMiddleware(t *testing.T) { labelsByMetric := testutil.ExtractLabelsFromMetrics(metricFamilies) expectedLabels := map[string]map[string]string{ "twingate_gateway_http_requests_total": { - "type": "http", - "method": "get", - "code": "200", + "resource_type": "kubernetes", + "type": "http", + "method": "get", + "code": "200", }, "twingate_gateway_http_active_requests": { - "type": "http", + "resource_type": "kubernetes", + "type": "http", }, "twingate_gateway_http_request_duration_seconds": { - "type": "http", - "method": "get", - "code": "200", + "resource_type": "kubernetes", + "type": "http", + "method": "get", + "code": "200", }, "twingate_gateway_http_request_size_bytes": { - "type": "http", - "method": "get", - "code": "200", + "resource_type": "kubernetes", + "type": "http", + "method": "get", + "code": "200", }, "twingate_gateway_http_response_size_bytes": { - "type": "http", - "method": "get", - "code": "200", + "resource_type": "kubernetes", + "type": "http", + "method": "get", + "code": "200", }, } assert.Equal(t, expectedLabels, labelsByMetric) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 8cb9674..6ce8498 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -36,7 +36,7 @@ type Proxy struct { registry *prometheus.Registry logger *zap.Logger - httpProxy *httpproxy.Proxy + httpProxies map[token.ResourceType]*httpproxy.Proxy sshProxy *sshhandler.SSHProxy metricsServer *metrics.Server @@ -45,13 +45,16 @@ type Proxy struct { } func NewProxy(config *gatewayconfig.Config, registry *prometheus.Registry, logger *zap.Logger) (*Proxy, error) { - var httpProxy *httpproxy.Proxy + httpProxies := make(map[token.ResourceType]*httpproxy.Proxy) - handlers := make(map[string]http.Handler) + var ( + roundTripperMetrics *metrics.RoundTripperMetrics + httpMetrics *metrics.HTTPMetrics + ) - var roundTripperMetrics *metrics.RoundTripperMetrics if config.Kubernetes != nil || config.WebApp != nil { roundTripperMetrics = metrics.RegisterRoundTripperMetrics(registry) + httpMetrics = metrics.RegisterHTTPMetrics(registry) } if config.Kubernetes != nil { @@ -65,7 +68,12 @@ func NewProxy(config *gatewayconfig.Config, registry *prometheus.Registry, logge return nil, fmt.Errorf("failed to create Kubernetes handler: %w", err) } - handlers[token.ResourceTypeKubernetes] = k8sHandler + httpProxies[token.ResourceTypeKubernetes] = httpproxy.NewProxy(httpproxy.Config{ + Handler: k8sHandler, + Metrics: httpMetrics, + Logger: logger, + ResourceTypeLabel: metrics.ResourceTypeKubernetes, + }) } if config.WebApp != nil { @@ -75,14 +83,12 @@ func NewProxy(config *gatewayconfig.Config, registry *prometheus.Registry, logge } webAppHandler := webapphandler.NewHandler(*webAppCfg) - handlers[token.ResourceTypeWebApp] = webAppHandler - } - if len(handlers) > 0 { - httpProxy = httpproxy.NewProxy(httpproxy.Config{ - Handlers: handlers, - Registry: registry, - Logger: logger, + httpProxies[token.ResourceTypeWebApp] = httpproxy.NewProxy(httpproxy.Config{ + Handler: webAppHandler, + Metrics: httpMetrics, + Logger: logger, + ResourceTypeLabel: metrics.ResourceTypeWebApp, }) } @@ -108,7 +114,7 @@ func NewProxy(config *gatewayconfig.Config, registry *prometheus.Registry, logge registry: registry, logger: logger, - httpProxy: httpProxy, + httpProxies: httpProxies, sshProxy: sshProxy, metricsServer: metricsServer, }, nil @@ -125,22 +131,22 @@ func (p *Proxy) Start() error { p.listener = listener - channels := make(map[connect.TransportProtocol]chan<- connect.Conn) + channels := make(map[token.ResourceType]chan<- connect.Conn) var sshListener *connect.ProtocolListener if p.sshProxy != nil { sshChannel := make(chan connect.Conn) - channels[connect.TransportSSH] = sshChannel + channels[token.ResourceTypeSSH] = sshChannel sshListener = connect.NewProtocolListener(sshChannel, listener.Addr()) } - var httpListener *connect.ProtocolListener + httpListeners := make(map[token.ResourceType]*connect.ProtocolListener) - if p.httpProxy != nil { - httpChannel := make(chan connect.Conn) - channels[connect.TransportTLS] = httpChannel - httpListener = connect.NewProtocolListener(httpChannel, listener.Addr()) + for resourceType := range p.httpProxies { + ch := make(chan connect.Conn) + channels[resourceType] = ch + httpListeners[resourceType] = connect.NewProtocolListener(ch, listener.Addr()) } connectListener, err := connect.NewListener( @@ -180,17 +186,19 @@ func (p *Proxy) Start() error { }) } - if p.httpProxy != nil { + for resourceType, proxy := range p.httpProxies { g.Go(func() error { - p.logger.Info("Starting HTTP proxy") + p.logger.Info("Starting HTTP proxy", zap.String("resource_type", string(resourceType))) - err := p.httpProxy.Start(httpListener) + err := proxy.Start(httpListeners[resourceType]) if errors.Is(err, http.ErrServerClosed) { return nil } if err != nil { - p.logger.Error("HTTP proxy stopped with error", zap.Error(err)) + p.logger.Error("HTTP proxy stopped with error", + zap.String("resource_type", string(resourceType)), + zap.Error(err)) } return err @@ -245,9 +253,11 @@ func (p *Proxy) shutdown() { } } - if p.httpProxy != nil { - if err := p.httpProxy.Shutdown(ctx); err != nil { - p.logger.Error("Failed to shut down HTTP proxy", zap.Error(err)) + for resourceType, proxy := range p.httpProxies { + if err := proxy.Shutdown(ctx); err != nil { + p.logger.Error("Failed to shut down HTTP proxy", + zap.String("resource_type", string(resourceType)), + zap.Error(err)) } } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 40ce41c..c11df12 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -68,14 +68,16 @@ func TestNewProxy_Success(t *testing.T) { assert.Equal(t, registry, p.registry) assert.Equal(t, logger, p.logger) - assert.NotNil(t, p.httpProxy) + assert.Len(t, p.httpProxies, 1) + assert.Contains(t, p.httpProxies, token.ResourceTypeKubernetes) assert.NotNil(t, p.sshProxy) assert.NotNil(t, p.metricsServer) } -func TestNewProxy_KubernetesOnly(t *testing.T) { +func TestNewProxy_HTTPOnly(t *testing.T) { config := fullConfig config.SSH = nil + config.WebApp = &gatewayconfig.WebAppConfig{Headers: map[string]string{}} registry := prometheus.NewRegistry() logger, err := NewLogger(DefaultLoggerName, false) @@ -85,7 +87,9 @@ func TestNewProxy_KubernetesOnly(t *testing.T) { require.NoError(t, err) assert.NotNil(t, p) - assert.NotNil(t, p.httpProxy) + assert.Len(t, p.httpProxies, 2) + assert.Contains(t, p.httpProxies, token.ResourceTypeKubernetes) + assert.Contains(t, p.httpProxies, token.ResourceTypeWebApp) assert.Nil(t, p.sshProxy) } @@ -102,7 +106,7 @@ func TestNewProxy_SSHOnly(t *testing.T) { require.NoError(t, err) assert.NotNil(t, p) assert.NotNil(t, p.sshProxy) - assert.Nil(t, p.httpProxy) + assert.Empty(t, p.httpProxies) } func createTestProxy(t *testing.T) (*Proxy, net.Listener) { @@ -140,12 +144,15 @@ func TestShutdown_ClosesAllComponents(t *testing.T) { require.NoError(t, err) httpProxy := httpproxy.NewProxy(httpproxy.Config{ - Handlers: map[string]http.Handler{token.ResourceTypeKubernetes: k8sHandler}, - Registry: registry, - Logger: zap.NewNop(), + Handler: k8sHandler, + Metrics: metrics.RegisterHTTPMetrics(registry), + ResourceTypeLabel: metrics.ResourceTypeKubernetes, + Logger: zap.NewNop(), }) - p.httpProxy = httpProxy + p.httpProxies = map[token.ResourceType]*httpproxy.Proxy{ + token.ResourceTypeKubernetes: httpProxy, + } // Start HTTP proxy on a protocol listener httpChannel := make(chan connect.Conn) @@ -154,7 +161,7 @@ func TestShutdown_ClosesAllComponents(t *testing.T) { httpDone := make(chan error, 1) go func() { - httpDone <- p.httpProxy.Start(httpListener) + httpDone <- httpProxy.Start(httpListener) }() // Create and attach a real SSH proxy @@ -230,11 +237,11 @@ func TestShutdown_NilComponents(t *testing.T) { p := &Proxy{ logger: zap.NewNop(), listener: nil, - httpProxy: nil, + httpProxies: nil, sshProxy: nil, metricsServer: metricsServer, } - // Should not panic with nil listener, httpProxy, and sshProxy + // Should not panic with nil listener, httpProxies, and sshProxy p.shutdown() } diff --git a/internal/token/gat_claims.go b/internal/token/gat_claims.go index f95db08..dc67b3c 100644 --- a/internal/token/gat_claims.go +++ b/internal/token/gat_claims.go @@ -98,12 +98,12 @@ type Device struct { Location GeoIPLocation `json:"location,omitzero"` } -type ResourceType = string +type ResourceType string const ( - ResourceTypeKubernetes = "KUBERNETES" - ResourceTypeSSH = "SSH" - ResourceTypeWebApp = "WEB_APP" + ResourceTypeKubernetes ResourceType = "KUBERNETES" + ResourceTypeSSH ResourceType = "SSH" + ResourceTypeWebApp ResourceType = "WEB_APP" ) type Resource struct { diff --git a/internal/webapphandler/handler_test.go b/internal/webapphandler/handler_test.go index 788a220..87defe6 100644 --- a/internal/webapphandler/handler_test.go +++ b/internal/webapphandler/handler_test.go @@ -4,6 +4,7 @@ package webapphandler import ( + "context" "net/http" "net/http/httptest" "net/http/httputil" @@ -15,6 +16,8 @@ import ( "go.uber.org/zap" "gateway/internal/connect" + "gateway/internal/httpproxy" + "gateway/internal/metrics" "gateway/internal/token" "gateway/internal/webapphandler/template" ) @@ -34,6 +37,31 @@ func mustParse(t *testing.T, templates map[string]string) map[string]*template.T return result } +func TestNewHandler_PanicsOnRewriteError(t *testing.T) { + connMetrics := connect.CreateProxyConnMetrics(prometheus.NewRegistry()) + conn := connect.NewProxyConn(nil, nil, nil, zap.NewNop(), connMetrics) + conn.Claims = &token.GATClaims{ + User: token.User{Username: "alice@acme.com"}, + } + + unknownKeyTemplate, err := template.New("{{twingate.nonexistent}}") + require.NoError(t, err) + + handler := NewHandler(Config{ + headers: map[string]*template.Template{"X-Bad": unknownKeyTemplate}, + roundTripperMetrics: metrics.RegisterRoundTripperMetrics(prometheus.NewRegistry()), + logger: zap.NewNop(), + }) + + req := httptest.NewRequest(http.MethodGet, "http://test/api", nil) + ctx := context.WithValue(req.Context(), httpproxy.ConnContextKey{}, conn) + req = req.WithContext(ctx) + + assert.Panics(t, func() { + handler.ServeHTTP(httptest.NewRecorder(), req) + }) +} + func TestRewrite(t *testing.T) { baseClaims := &token.GATClaims{ User: token.User{ From 558f37e9afeb34458ddd12b9d3198e47e9deadf5 Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Tue, 2 Jun 2026 16:55:34 +0800 Subject: [PATCH 13/17] Remove `TransportProtocol` type --- internal/connect/conn.go | 21 ----------- internal/connect/conn_test.go | 62 ------------------------------- internal/connect/listener.go | 2 +- internal/connect/listener_test.go | 27 ++++---------- internal/token/gat_claims.go | 4 ++ internal/token/gat_claims_test.go | 34 +++++++++++++++++ 6 files changed, 46 insertions(+), 104 deletions(-) diff --git a/internal/connect/conn.go b/internal/connect/conn.go index 09e7477..88d915b 100644 --- a/internal/connect/conn.go +++ b/internal/connect/conn.go @@ -32,13 +32,6 @@ func httpResponseString(httpCode int) string { return fmt.Sprintf("HTTP/1.1 %d %s\r\n\r\n", httpCode, http.StatusText(httpCode)) } -type TransportProtocol int - -const ( - TransportTLS TransportProtocol = iota - TransportSSH -) - // Conn is a custom connection that wraps the underlying TCP net.Conn, handling downstream // proxy (Twingate Client)'s authentication via the initial CONNECT message. It handles 2 TLS // upgrades: with downstream proxy and then optionally with downstream client e.g. `kubectl`. @@ -49,8 +42,6 @@ type Conn interface { GetAddress() string GetToken() string Authenticate() error - TransportProtocol() TransportProtocol - ShouldUpgradeTLS() bool UpgradeToTLS() error Close() error @@ -100,18 +91,6 @@ func (p *ProxyConn) Close() error { return p.Conn.Close() } -func (p *ProxyConn) TransportProtocol() TransportProtocol { - if p.GATClaims().Resource.Type == token.ResourceTypeSSH { - return TransportSSH - } - - return TransportTLS -} - -func (p *ProxyConn) ShouldUpgradeTLS() bool { - return p.TransportProtocol() == TransportTLS && p.GATClaims().Resource.Type != token.ResourceTypeWebApp -} - func (p *ProxyConn) GATClaims() *token.GATClaims { return p.Claims } diff --git a/internal/connect/conn_test.go b/internal/connect/conn_test.go index ef5349c..0bc06c3 100644 --- a/internal/connect/conn_test.go +++ b/internal/connect/conn_test.go @@ -83,68 +83,6 @@ func TestProxyConn_setConnectInfo(t *testing.T) { }) } -func TestProxyConn_TransportProtocol(t *testing.T) { - tests := []struct { - resourceType token.ResourceType - expected TransportProtocol - }{ - { - resourceType: token.ResourceTypeKubernetes, - expected: TransportTLS, - }, - { - resourceType: token.ResourceTypeSSH, - expected: TransportSSH, - }, - } - - for _, tt := range tests { - t.Run(string(tt.resourceType), func(t *testing.T) { - claims := &token.GATClaims{ - Resource: token.Resource{Type: tt.resourceType}, - } - proxyConn := &ProxyConn{Claims: claims} - - assert.Equal(t, tt.expected, proxyConn.TransportProtocol()) - }) - } -} - -func TestProxyConn_ShouldUpgradeTLS(t *testing.T) { - tests := []struct { - name string - resourceType token.ResourceType - expected bool - }{ - { - name: "Kubernetes should upgrade TLS", - resourceType: token.ResourceTypeKubernetes, - expected: true, - }, - { - name: "SSH should not upgrade TLS", - resourceType: token.ResourceTypeSSH, - expected: false, - }, - { - name: "Web app should not upgrade TLS", - resourceType: token.ResourceTypeWebApp, - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - claims := &token.GATClaims{ - Resource: token.Resource{Type: tt.resourceType}, - } - proxyConn := &ProxyConn{Claims: claims} - - assert.Equal(t, tt.expected, proxyConn.ShouldUpgradeTLS()) - }) - } -} - func TestProxyConn_Close(t *testing.T) { conn := &mockConn{} timer := time.NewTimer(0 * time.Millisecond) diff --git a/internal/connect/listener.go b/internal/connect/listener.go index 4268e51..f7a8cd0 100644 --- a/internal/connect/listener.go +++ b/internal/connect/listener.go @@ -183,7 +183,7 @@ func (l *Listener) Serve(ctx context.Context, listener net.Listener) error { return } - if proxyConn.ShouldUpgradeTLS() { + if proxyConn.GATClaims().ShouldUpgradeTLS() { if err := proxyConn.UpgradeToTLS(); err != nil { l.logger.Error("Failed to upgrade to TLS", zap.Error(err)) diff --git a/internal/connect/listener_test.go b/internal/connect/listener_test.go index 5d19e53..0044ba8 100644 --- a/internal/connect/listener_test.go +++ b/internal/connect/listener_test.go @@ -25,8 +25,7 @@ import ( type mockProxyConn struct { net.Conn - transportProtocol TransportProtocol - Claims *token.GATClaims + Claims *token.GATClaims isClosed atomic.Bool @@ -44,10 +43,6 @@ func (m *mockProxyConn) IsClosed() bool { return m.isClosed.Load() } -func (m *mockProxyConn) TransportProtocol() TransportProtocol { - return m.transportProtocol -} - func (m *mockProxyConn) GATClaims() *token.GATClaims { return m.Claims } @@ -78,10 +73,6 @@ func (m *mockProxyConn) Authenticate() error { return nil } -func (m *mockProxyConn) ShouldUpgradeTLS() bool { - return m.TransportProtocol() == TransportTLS && m.GATClaims().Resource.Type != token.ResourceTypeWebApp -} - func (m *mockProxyConn) UpgradeToTLS() error { return nil } @@ -162,9 +153,8 @@ func TestListener_Serve_HTTPS(t *testing.T) { fixtures.listener.proxyConnFactory = func(conn net.Conn, _ *tls.Config, _ Validator, _ *zap.Logger) Conn { return &mockProxyConn{ - Conn: conn, - transportProtocol: TransportTLS, - Claims: kubernetesClaims, + Conn: conn, + Claims: kubernetesClaims, } } @@ -196,7 +186,6 @@ func TestListener_Serve_HTTPS(t *testing.T) { select { case conn := <-fixtures.httpChannel: require.False(t, conn.(*mockProxyConn).IsClosed()) - require.Equal(t, TransportTLS, conn.TransportProtocol()) require.Equal(t, kubernetesClaims, conn.GATClaims()) case <-time.After(1 * time.Second): t.Fatal("timeout waiting for HTTP connection") @@ -216,9 +205,8 @@ func TestListener_Serve_SSH(t *testing.T) { fixtures.listener.proxyConnFactory = func(conn net.Conn, _ *tls.Config, _ Validator, _ *zap.Logger) Conn { return &mockProxyConn{ - Conn: conn, - transportProtocol: TransportSSH, - Claims: sshClaims, + Conn: conn, + Claims: sshClaims, } } @@ -416,9 +404,8 @@ func TestListener_Serve_GracefulShutdown(t *testing.T) { listener.proxyConnFactory = func(conn net.Conn, _ *tls.Config, _ Validator, _ *zap.Logger) Conn { return &mockProxyConn{ - Conn: conn, - transportProtocol: TransportTLS, - Claims: kubernetesClaims, + Conn: conn, + Claims: kubernetesClaims, } } diff --git a/internal/token/gat_claims.go b/internal/token/gat_claims.go index dc67b3c..0ea5c4c 100644 --- a/internal/token/gat_claims.go +++ b/internal/token/gat_claims.go @@ -61,6 +61,10 @@ func (p GATClaims) Validate() error { return nil } +func (p GATClaims) ShouldUpgradeTLS() bool { + return p.Resource.Type == ResourceTypeKubernetes +} + func (p GATClaims) getHeaderType() string { return "GAT" } diff --git a/internal/token/gat_claims_test.go b/internal/token/gat_claims_test.go index bba08c7..f3a6b13 100644 --- a/internal/token/gat_claims_test.go +++ b/internal/token/gat_claims_test.go @@ -13,9 +13,43 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestGATClaims_ShouldUpgradeTLS(t *testing.T) { + tests := []struct { + name string + resourceType ResourceType + expected bool + }{ + { + name: "Kubernetes should upgrade TLS", + resourceType: ResourceTypeKubernetes, + expected: true, + }, + { + name: "SSH should not upgrade TLS", + resourceType: ResourceTypeSSH, + expected: false, + }, + { + name: "Web app should not upgrade TLS", + resourceType: ResourceTypeWebApp, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claims := &GATClaims{ + Resource: Resource{Type: tt.resourceType}, + } + assert.Equal(t, tt.expected, claims.ShouldUpgradeTLS()) + }) + } +} + func TestGATTokenClaims_Validate(t *testing.T) { privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) From 289c6cfb31c0681da296bf668328f0a1daefad87 Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Wed, 3 Jun 2026 00:30:42 +0800 Subject: [PATCH 14/17] Update code --- internal/httpproxy/template/template.go | 82 ----------- internal/httpproxy/template/template_test.go | 145 ------------------- 2 files changed, 227 deletions(-) delete mode 100644 internal/httpproxy/template/template.go delete mode 100644 internal/httpproxy/template/template_test.go diff --git a/internal/httpproxy/template/template.go b/internal/httpproxy/template/template.go deleted file mode 100644 index 9ef5982..0000000 --- a/internal/httpproxy/template/template.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) Twingate Inc. -// SPDX-License-Identifier: MPL-2.0 - -package template - -import ( - "errors" - "fmt" - "regexp" - "strings" - "unicode" -) - -const ( - allowedNamespace = "twingate" -) - -var ( - ErrInvalidTemplate = errors.New("invalid template") - ErrUnknownKey = errors.New("unknown key") -) - -var templateRe = regexp.MustCompile( - `^([^{}]*)` + // prefix (no braces allowed) - `{{\s*` + // opening braces - `([a-zA-Z0-9_-]+)` + // namespace - `\.` + - `([a-zA-Z0-9_-]+)` + // key - `\s*}}` + // closing braces - `([^{}]*)$`, // suffix (no braces allowed) -) - -type Template struct { - prefix string - key string - suffix string -} - -// New parses a string like " {{.}} " into a Template. -// If there is no template variable (just a static string), the key and suffix are empty and the prefix is the static string. -func New(s string) (*Template, error) { - match := templateRe.FindStringSubmatch(s) - - if match == nil { - if strings.Contains(s, "{{") || strings.Contains(s, "}}") { - return nil, fmt.Errorf("%w: unsupported syntax. Syntax must be {{twingate.key}} ", ErrInvalidTemplate) - } - - return &Template{prefix: strings.TrimSpace(s)}, nil - } - - prefix, namespace, key, suffix := match[1], match[2], match[3], match[4] - - if namespace != allowedNamespace { - return nil, fmt.Errorf("%w: unsupported namespace %q", ErrInvalidTemplate, namespace) - } - - return &Template{ - prefix: strings.TrimLeftFunc(prefix, unicode.IsSpace), - key: key, - suffix: strings.TrimRightFunc(suffix, unicode.IsSpace), - }, nil -} - -func (t *Template) Key() string { - return t.key -} - -// Evaluate replaces the key in the template with the corresponding value from the map -// and returns the resulting string along with the prefix and suffix. -func (t *Template) Evaluate(values map[string]string) (string, error) { - if t.key == "" { - return t.prefix, nil - } - - result, ok := values[t.key] - if !ok { - return "", fmt.Errorf("%w: %q", ErrUnknownKey, t.key) - } - - return t.prefix + result + t.suffix, nil -} diff --git a/internal/httpproxy/template/template_test.go b/internal/httpproxy/template/template_test.go deleted file mode 100644 index 0cb5b46..0000000 --- a/internal/httpproxy/template/template_test.go +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright (c) Twingate Inc. -// SPDX-License-Identifier: MPL-2.0 - -package template - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestParser_NewTemplate(t *testing.T) { - tests := []struct { - name string - input string - wantPrefix string - wantKey string - wantSuffix string - wantErr error - errSubstr string - }{ - { - name: "plain text", - input: " static-value ", - wantPrefix: "static-value", - }, - { - name: "empty string", - input: "", - }, - { - name: "template only", - input: "{{twingate.jwt}}", - wantKey: "jwt", - }, - { - name: "template with leading and trailing space", - input: "{{ twingate.jwt }}", - wantKey: "jwt", - }, - { - name: "template with prefix", - input: " Bearer {{twingate.jwt}}", - wantPrefix: "Bearer ", - wantKey: "jwt", - }, - { - name: "template with suffix", - input: "{{twingate.username}}/profile ", - wantKey: "username", - wantSuffix: "/profile", - }, - { - name: "invalid template", - input: "{{invalid}}", - wantErr: ErrInvalidTemplate, - errSubstr: "unsupported syntax", - }, - { - name: "missing opening braces", - input: "twingate.jwt }}", - wantErr: ErrInvalidTemplate, - errSubstr: "unsupported syntax", - }, - { - name: "missing closing braces", - input: "{{ twingate.jwt", - wantErr: ErrInvalidTemplate, - errSubstr: "unsupported syntax", - }, - { - name: "multiple templates", - input: "{{twingate.username}} {{twingate.groups}}", - wantErr: ErrInvalidTemplate, - errSubstr: "unsupported syntax", - }, - { - name: "non-twingate namespace", - input: "{{other.key}}", - wantErr: ErrInvalidTemplate, - errSubstr: "unsupported namespace", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - template, err := New(tt.input) - - if tt.wantErr != nil { - require.ErrorIs(t, err, tt.wantErr) - assert.Contains(t, err.Error(), tt.errSubstr) - - return - } - - require.NoError(t, err) - - assert.Equal(t, tt.wantPrefix, template.prefix) - assert.Equal(t, tt.wantKey, template.key) - assert.Equal(t, tt.wantSuffix, template.suffix) - }) - } -} - -func TestParser_Evaluate(t *testing.T) { - tests := []struct { - name string - template Template - values map[string]string - want string - wantErr error - errSubstr string - }{ - { - name: "Success", - template: Template{prefix: "Prefix ", key: "foo", suffix: " suffix"}, - values: map[string]string{"foo": "bar", "extra": "foo"}, - want: "Prefix bar suffix", - }, - { - name: "Missing key", - template: Template{prefix: "Bearer ", key: "jwt", suffix: ""}, - values: map[string]string{}, - wantErr: ErrUnknownKey, - errSubstr: "jwt", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := tt.template.Evaluate(tt.values) - - if tt.wantErr != nil { - require.ErrorIs(t, err, tt.wantErr) - assert.Contains(t, err.Error(), tt.errSubstr) - - return - } - - require.NoError(t, err) - assert.Equal(t, tt.want, result) - }) - } -} From c921c379038c85f30eb4ad3cc8a67ac3f2c64f6b Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Wed, 3 Jun 2026 17:34:48 +0800 Subject: [PATCH 15/17] Remove duplicate validation in `internal/config` --- internal/config/config.go | 26 --------- internal/config/config_test.go | 61 --------------------- internal/webapphandler/config.go | 5 ++ internal/webapphandler/config_test.go | 14 ++++- internal/webapphandler/template/template.go | 1 + 5 files changed, 17 insertions(+), 90 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 26ef353..0618f25 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,7 +8,6 @@ import ( "fmt" "net/http" "os" - "slices" "strings" "time" @@ -16,8 +15,6 @@ import ( "go.uber.org/zap" "go.yaml.in/yaml/v4" "golang.org/x/crypto/ssh" - - "gateway/internal/webapphandler/template" ) var ( @@ -26,7 +23,6 @@ var ( ErrDuplicateUpstream = errors.New("duplicate upstream name") ErrInvalidSSHKeyType = errors.New("invalid SSH key type") ErrNegativeTTL = errors.New("TTL must be non-negative") - ErrUnsupportedKey = errors.New("unsupported key") ) const ( @@ -277,12 +273,6 @@ func (c *Config) Validate() error { } } - if c.WebApp != nil { - if err := c.WebApp.Validate(); err != nil { - return fmt.Errorf("webApp config: %w", err) - } - } - // Check that at least one protocol is configured if c.Kubernetes == nil && c.SSH == nil && c.WebApp == nil { return fmt.Errorf("%w: at least one protocol (Kubernetes, SSH, or WebApp) must be configured", ErrRequired) @@ -451,22 +441,6 @@ func (v *SSHCAVaultConfig) Validate() error { return nil } -func (w *WebAppConfig) Validate() error { - for name, value := range w.Headers { - tmpl, err := template.New(value) - if err != nil { - return fmt.Errorf("header %q: %w", name, err) - } - - key := tmpl.Key() - if key != "" && !slices.Contains(template.AllowedWebAppKeys, key) { - return fmt.Errorf("header %q: %w %q", name, ErrUnsupportedKey, key) - } - } - - return nil -} - var ( ErrConflictingAuthConfig = errors.New("only one of 'token', 'appRole', 'gcp', or 'aws' can be specified for Vault auth") ErrConflictingSecretIDConfig = errors.New("only one of 'secretID' or 'secretIDFile' can be specified") diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 4d0b446..9d5295f 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -368,23 +368,6 @@ func TestConfig_Validate(t *testing.T) { wantErr: true, errContains: "at least one protocol", }, - { - name: "invalid WebApp header template", - config: &Config{ - Twingate: TwingateConfig{Network: "test"}, - Port: 8443, - MetricsPort: 9090, - TLS: TLSConfig{ - CertificateFile: "tls.crt", - PrivateKeyFile: "tls.key", - }, - WebApp: &WebAppConfig{ - Headers: map[string]string{"Authorization": "Bearer {{twingate.jwt"}, - }, - }, - wantErr: true, - errContains: "webApp config", - }, } for _, tt := range tests { @@ -1124,47 +1107,3 @@ func TestSSHCAVaultAWSConfig_Validate(t *testing.T) { }) } } - -func TestWebAppConfig_Validate(t *testing.T) { - tests := []struct { - name string - config WebAppConfig - wantErr bool - errContains string - }{ - { - name: "empty headers", - config: WebAppConfig{Headers: map[string]string{}}, - wantErr: false, - }, - { - name: "valid template variable", - config: WebAppConfig{Headers: map[string]string{"Authorization": "Bearer {{twingate.jwt}}"}}, - wantErr: false, - }, - { - name: "invalid template syntax", - config: WebAppConfig{Headers: map[string]string{"X-Bad": "{{invalid"}}, - wantErr: true, - errContains: "unsupported syntax", - }, - { - name: "unsupported key", - config: WebAppConfig{Headers: map[string]string{"X-Bad": "{{twingate.unknown}}"}}, - wantErr: true, - errContains: "unsupported key", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.config.Validate() - if tt.wantErr { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errContains) - } else { - require.NoError(t, err) - } - }) - } -} diff --git a/internal/webapphandler/config.go b/internal/webapphandler/config.go index 323f5c0..b9bbada 100644 --- a/internal/webapphandler/config.go +++ b/internal/webapphandler/config.go @@ -5,6 +5,7 @@ package webapphandler import ( "fmt" + "slices" "go.uber.org/zap" @@ -27,6 +28,10 @@ func NewConfig(configHeaders map[string]string, roundTripperMetrics *metrics.Rou return nil, fmt.Errorf("header %q: %w", name, err) } + if key := tmpl.Key(); key != "" && !slices.Contains(template.AllowedWebAppKeys, key) { + return nil, fmt.Errorf("header %q: %w %q", name, template.ErrUnsupportedKey, key) + } + headers[name] = tmpl } diff --git a/internal/webapphandler/config_test.go b/internal/webapphandler/config_test.go index d678917..ee67a8a 100644 --- a/internal/webapphandler/config_test.go +++ b/internal/webapphandler/config_test.go @@ -14,9 +14,10 @@ import ( func TestNewConfig(t *testing.T) { tests := []struct { - name string - headers map[string]string - wantErr error + name string + headers map[string]string + wantErr error + errContains string }{ { name: "valid header templates", @@ -31,6 +32,13 @@ func TestNewConfig(t *testing.T) { }, wantErr: template.ErrInvalidTemplate, }, + { + name: "unsupported key", + headers: map[string]string{ + "X-Bad": "{{twingate.unknown}}", + }, + wantErr: template.ErrUnsupportedKey, + }, { name: "empty headers", headers: map[string]string{}, diff --git a/internal/webapphandler/template/template.go b/internal/webapphandler/template/template.go index de45834..218f94f 100644 --- a/internal/webapphandler/template/template.go +++ b/internal/webapphandler/template/template.go @@ -38,6 +38,7 @@ var AllowedWebAppKeys = []string{ var ( ErrInvalidTemplate = errors.New("invalid template") ErrUnknownKey = errors.New("unknown key") + ErrUnsupportedKey = errors.New("unsupported key") ) var templateRe = regexp.MustCompile( From e8bff292f0743bb37f22a38d1f510fc8fc022bb3 Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Wed, 3 Jun 2026 17:43:25 +0800 Subject: [PATCH 16/17] Add `ResourceType` to roundtripper metrics --- internal/httpproxy/proxy.go | 10 +++++----- internal/httpproxy/proxy_test.go | 16 ++++++++-------- internal/metrics/http_middleware.go | 6 +++--- internal/metrics/round_tripper.go | 17 +++++++++-------- internal/proxy/proxy.go | 16 ++++++++-------- internal/proxy/proxy_test.go | 8 ++++---- 6 files changed, 37 insertions(+), 36 deletions(-) diff --git a/internal/httpproxy/proxy.go b/internal/httpproxy/proxy.go index a2bef30..c3609e5 100644 --- a/internal/httpproxy/proxy.go +++ b/internal/httpproxy/proxy.go @@ -27,10 +27,10 @@ func ProxyConnFromContext(ctx context.Context) *connect.ProxyConn { } type Config struct { - Handler http.Handler - Metrics *metrics.HTTPMetrics - Logger *zap.Logger - ResourceTypeLabel string + Handler http.Handler + Metrics *metrics.HTTPMetrics + Logger *zap.Logger + ResourceType metrics.ResourceType } type Proxy struct { @@ -38,7 +38,7 @@ type Proxy struct { } func NewProxy(cfg Config) *Proxy { - handler := metrics.HTTPMiddleware(cfg.Metrics, cfg.ResourceTypeLabel, auditMiddleware(auditMiddlewareConfig{ + handler := metrics.HTTPMiddleware(cfg.Metrics, cfg.ResourceType, auditMiddleware(auditMiddlewareConfig{ next: cfg.Handler, logger: cfg.Logger, })) diff --git a/internal/httpproxy/proxy_test.go b/internal/httpproxy/proxy_test.go index a999def..5ede9e0 100644 --- a/internal/httpproxy/proxy_test.go +++ b/internal/httpproxy/proxy_test.go @@ -69,10 +69,10 @@ func TestProxy_ForwardRequest(t *testing.T) { require.NoError(t, err) proxy := NewProxy(Config{ - Metrics: metrics.RegisterHTTPMetrics(prometheus.NewRegistry()), - ResourceTypeLabel: metrics.ResourceTypeKubernetes, - Handler: handler, - Logger: zap.NewNop(), + Metrics: metrics.RegisterHTTPMetrics(prometheus.NewRegistry()), + ResourceType: metrics.ResourceTypeKubernetes, + Handler: handler, + Logger: zap.NewNop(), }) go func() { @@ -104,10 +104,10 @@ func TestProxy_Shutdown(t *testing.T) { require.NoError(t, err) proxy := NewProxy(Config{ - Metrics: metrics.RegisterHTTPMetrics(prometheus.NewRegistry()), - ResourceTypeLabel: metrics.ResourceTypeKubernetes, - Handler: handler, - Logger: zap.NewNop(), + Metrics: metrics.RegisterHTTPMetrics(prometheus.NewRegistry()), + ResourceType: metrics.ResourceTypeKubernetes, + Handler: handler, + Logger: zap.NewNop(), }) done := make(chan error, 1) diff --git a/internal/metrics/http_middleware.go b/internal/metrics/http_middleware.go index 47a371c..d71626f 100644 --- a/internal/metrics/http_middleware.go +++ b/internal/metrics/http_middleware.go @@ -80,13 +80,13 @@ func RegisterHTTPMetrics(registry *prometheus.Registry) *HTTPMetrics { return m } -func HTTPMiddleware(metrics *HTTPMetrics, resourceType string, next http.Handler) http.HandlerFunc { - resourceTypeOpt := promhttp.WithLabelFromCtx(labelResourceType, func(_ context.Context) string { return resourceType }) +func HTTPMiddleware(metrics *HTTPMetrics, resourceType ResourceType, next http.Handler) http.HandlerFunc { + resourceTypeOpt := promhttp.WithLabelFromCtx(labelResourceType, func(_ context.Context) string { return string(resourceType) }) requestTypeOpt := promhttp.WithLabelFromCtx(labelRequestType, getRequestTypeFromContext) base := promhttp.InstrumentHandlerCounter( metrics.requestsTotal, - instrumentHandlerInFlight(metrics.activeRequests, resourceType, + instrumentHandlerInFlight(metrics.activeRequests, string(resourceType), promhttp.InstrumentHandlerDuration( metrics.requestDuration, promhttp.InstrumentHandlerRequestSize( diff --git a/internal/metrics/round_tripper.go b/internal/metrics/round_tripper.go index 8d0ee0f..af1955b 100644 --- a/internal/metrics/round_tripper.go +++ b/internal/metrics/round_tripper.go @@ -15,11 +15,12 @@ import ( const labelResourceType = "resource_type" -// Resource type values. +// ResourceType identifies the upstream resource type in metric labels. +type ResourceType string const ( - ResourceTypeKubernetes = "kubernetes" - ResourceTypeWebApp = "web_app" + ResourceTypeKubernetes ResourceType = "kubernetes" + ResourceTypeWebApp ResourceType = "web_app" ) type RoundTripperMetrics struct { @@ -56,8 +57,8 @@ func RegisterRoundTripperMetrics(registry *prometheus.Registry) *RoundTripperMet return c } -func InstrumentRoundTripper(metrics *RoundTripperMetrics, resourceType string, next http.RoundTripper) promhttp.RoundTripperFunc { - resourceTypeOpt := promhttp.WithLabelFromCtx(labelResourceType, func(_ context.Context) string { return resourceType }) +func InstrumentRoundTripper(metrics *RoundTripperMetrics, resourceType ResourceType, next http.RoundTripper) promhttp.RoundTripperFunc { + resourceTypeOpt := promhttp.WithLabelFromCtx(labelResourceType, func(_ context.Context) string { return string(resourceType) }) requestTypeOpt := promhttp.WithLabelFromCtx(labelRequestType, getRequestTypeFromContext) base := promhttp.InstrumentRoundTripperCounter( @@ -81,12 +82,12 @@ func InstrumentRoundTripper(metrics *RoundTripperMetrics, resourceType string, n } } -func instrumentRoundTripperInFlight(activeRequests *prometheus.GaugeVec, resourceType string, next http.RoundTripper) promhttp.RoundTripperFunc { +func instrumentRoundTripperInFlight(activeRequests *prometheus.GaugeVec, resourceType ResourceType, next http.RoundTripper) promhttp.RoundTripperFunc { return func(r *http.Request) (*http.Response, error) { requestType := getRequestTypeFromContext(r.Context()) - activeRequests.WithLabelValues(resourceType, requestType).Inc() - defer activeRequests.WithLabelValues(resourceType, requestType).Dec() + activeRequests.WithLabelValues(string(resourceType), requestType).Inc() + defer activeRequests.WithLabelValues(string(resourceType), requestType).Dec() return next.RoundTrip(r) } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 6ce8498..3e068b8 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -69,10 +69,10 @@ func NewProxy(config *gatewayconfig.Config, registry *prometheus.Registry, logge } httpProxies[token.ResourceTypeKubernetes] = httpproxy.NewProxy(httpproxy.Config{ - Handler: k8sHandler, - Metrics: httpMetrics, - Logger: logger, - ResourceTypeLabel: metrics.ResourceTypeKubernetes, + Handler: k8sHandler, + Metrics: httpMetrics, + Logger: logger, + ResourceType: metrics.ResourceTypeKubernetes, }) } @@ -85,10 +85,10 @@ func NewProxy(config *gatewayconfig.Config, registry *prometheus.Registry, logge webAppHandler := webapphandler.NewHandler(*webAppCfg) httpProxies[token.ResourceTypeWebApp] = httpproxy.NewProxy(httpproxy.Config{ - Handler: webAppHandler, - Metrics: httpMetrics, - Logger: logger, - ResourceTypeLabel: metrics.ResourceTypeWebApp, + Handler: webAppHandler, + Metrics: httpMetrics, + Logger: logger, + ResourceType: metrics.ResourceTypeWebApp, }) } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index c11df12..a5718bf 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -144,10 +144,10 @@ func TestShutdown_ClosesAllComponents(t *testing.T) { require.NoError(t, err) httpProxy := httpproxy.NewProxy(httpproxy.Config{ - Handler: k8sHandler, - Metrics: metrics.RegisterHTTPMetrics(registry), - ResourceTypeLabel: metrics.ResourceTypeKubernetes, - Logger: zap.NewNop(), + Handler: k8sHandler, + Metrics: metrics.RegisterHTTPMetrics(registry), + ResourceType: metrics.ResourceTypeKubernetes, + Logger: zap.NewNop(), }) p.httpProxies = map[token.ResourceType]*httpproxy.Proxy{ From a6dfef22c80826c553f06e87d260766865241340 Mon Sep 17 00:00:00 2001 From: Clement Tee Date: Wed, 3 Jun 2026 17:43:50 +0800 Subject: [PATCH 17/17] Better code for web app rewrite --- internal/webapphandler/handler.go | 30 ++++++++++++++------------ internal/webapphandler/handler_test.go | 26 ++++++++++++++++------ 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/internal/webapphandler/handler.go b/internal/webapphandler/handler.go index b4fac94..7cf144a 100644 --- a/internal/webapphandler/handler.go +++ b/internal/webapphandler/handler.go @@ -15,7 +15,6 @@ import ( "gateway/internal/connect" "gateway/internal/httpproxy" "gateway/internal/metrics" - "gateway/internal/token" "gateway/internal/webapphandler/template" ) @@ -43,31 +42,34 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.proxy.ServeHTTP(w, r) } -func rewrite(r *httputil.ProxyRequest, conn *connect.ProxyConn, headers map[string]*template.Template) error { - targetURL := &url.URL{ - Scheme: "http", // plain HTTP — no upstream TLS - Host: conn.GetAddress(), - } - r.SetURL(targetURL) - +func buildVariables(conn *connect.ProxyConn) map[string]string { claims := conn.GATClaims() - clientLocation := claims.Device.Location - clientGeoLatLong := "" - if clientLocation != (token.GeoIPLocation{}) { - clientGeoLatLong = fmt.Sprintf("%v,%v", clientLocation.Lat, clientLocation.Lon) + latLong := "" + if clientLocation.Lat != 0 || clientLocation.Lon != 0 { + latLong = fmt.Sprintf("%v,%v", clientLocation.Lat, clientLocation.Lon) } - variables := map[string]string{ + return map[string]string{ template.JWT: conn.GetToken(), template.Username: claims.User.Username, template.Groups: strings.Join(claims.User.Groups, ","), - template.ClientGeoLatLong: clientGeoLatLong, + template.ClientGeoLatLong: latLong, template.ClientGeoCity: clientLocation.City, template.ClientGeoRegion: clientLocation.Region, template.ClientGeoCountry: clientLocation.Country, } +} + +func rewrite(r *httputil.ProxyRequest, conn *connect.ProxyConn, headers map[string]*template.Template) error { + targetURL := &url.URL{ + Scheme: "http", // plain HTTP — no upstream TLS + Host: conn.GetAddress(), + } + r.SetURL(targetURL) + + variables := buildVariables(conn) for headerName, tmpl := range headers { headerValue, err := tmpl.Evaluate(variables) diff --git a/internal/webapphandler/handler_test.go b/internal/webapphandler/handler_test.go index 87defe6..1f100e5 100644 --- a/internal/webapphandler/handler_test.go +++ b/internal/webapphandler/handler_test.go @@ -5,9 +5,11 @@ package webapphandler import ( "context" + "maps" "net/http" "net/http/httptest" "net/http/httputil" + "slices" "testing" "github.com/prometheus/client_golang/prometheus" @@ -109,12 +111,11 @@ func TestRewrite(t *testing.T) { }, }, { - name: "empty geo when no device location", + name: "empty lat/lon with non-empty geo fields", jwtToken: "test-token", claims: &token.GATClaims{ - User: baseClaims.User, - Device: token.Device{ID: "device-1"}, - Resource: baseClaims.Resource, + User: baseClaims.User, + Device: token.Device{ID: "device-1", Location: token.GeoIPLocation{Country: "US", Region: "CA", City: "San Mateo"}}, }, headers: map[string]string{ "X-LatLong": "{{twingate.clientGeoLatLong}}", @@ -124,9 +125,9 @@ func TestRewrite(t *testing.T) { }, wantHeaders: map[string]string{ "X-LatLong": "", - "X-City": "", - "X-Region": "", - "X-Country": "", + "X-City": "San Mateo", + "X-Region": "CA", + "X-Country": "US", "Existing": "old-value", }, }, @@ -167,3 +168,14 @@ func TestRewrite(t *testing.T) { }) } } + +func TestBuildVariables_CoversAllowedKeys(t *testing.T) { + connMetrics := connect.CreateProxyConnMetrics(prometheus.NewRegistry()) + conn := connect.NewProxyConn(nil, nil, nil, zap.NewNop(), connMetrics) + conn.Claims = &token.GATClaims{} + + got := slices.Sorted(maps.Keys(buildVariables(conn))) + want := slices.Sorted(slices.Values(template.AllowedWebAppKeys)) + + assert.Equal(t, want, got) +}