Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ import (
"fmt"
"net/http"
"os"
"slices"
"strings"
"time"

"github.com/hashicorp/go-retryablehttp"
"go.uber.org/zap"
"go.yaml.in/yaml/v4"
"golang.org/x/crypto/ssh"

"gateway/internal/webapphandler/template"
)

var (
Expand All @@ -23,6 +26,7 @@ var (
ErrDuplicateUpstream = errors.New("duplicate upstream name")
ErrInvalidSSHKeyType = errors.New("invalid SSH key type")
ErrNegativeTTL = errors.New("TTL must be non-negative")
ErrUnsupportedKey = errors.New("unsupported key")
)

const (
Expand All @@ -41,6 +45,11 @@ type Config struct {
TLS TLSConfig `yaml:"tls"`
Kubernetes *KubernetesConfig `yaml:"kubernetes,omitempty"`
SSH *SSHConfig `yaml:"ssh,omitempty"`
WebApp *WebAppConfig `yaml:"webApp,omitempty"`
}

type WebAppConfig struct {
Headers map[string]string `yaml:"headers,omitempty"`
}
Comment thread
clement0010 marked this conversation as resolved.
Comment thread
clement0010 marked this conversation as resolved.

type TwingateConfig struct {
Expand Down Expand Up @@ -268,9 +277,15 @@ func (c *Config) Validate() error {
}
}

if c.WebApp != nil {
if err := c.WebApp.Validate(); err != nil {
return fmt.Errorf("webApp config: %w", err)
}
}

// Check that at least one protocol is configured
if c.Kubernetes == nil && c.SSH == nil {
return fmt.Errorf("%w: at least one protocol (Kubernetes or SSH) must be configured", ErrRequired)
if c.Kubernetes == nil && c.SSH == nil && c.WebApp == nil {
return fmt.Errorf("%w: at least one protocol (Kubernetes, SSH, or WebApp) must be configured", ErrRequired)
}

return nil
Expand Down Expand Up @@ -436,6 +451,22 @@ func (v *SSHCAVaultConfig) Validate() error {
return nil
}

func (w *WebAppConfig) Validate() error {
for name, value := range w.Headers {
tmpl, err := template.New(value)
if err != nil {
return fmt.Errorf("header %q: %w", name, err)
}

key := tmpl.Key()
if key != "" && !slices.Contains(template.AllowedWebAppKeys, key) {
return fmt.Errorf("header %q: %w %q", name, ErrUnsupportedKey, key)
}
}

return nil
}

var (
ErrConflictingAuthConfig = errors.New("only one of 'token', 'appRole', 'gcp', or 'aws' can be specified for Vault auth")
ErrConflictingSecretIDConfig = errors.New("only one of 'secretID' or 'secretIDFile' can be specified")
Expand Down
64 changes: 64 additions & 0 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,9 @@ func TestConfig_Validate(t *testing.T) {
PrivateKeyFile: "tls.key",
},
Kubernetes: &KubernetesConfig{},
WebApp: &WebAppConfig{
Headers: map[string]string{"Authorization": "Bearer {{twingate.jwt}}"},
},
},
wantErr: false,
},
Expand Down Expand Up @@ -365,6 +368,23 @@ func TestConfig_Validate(t *testing.T) {
wantErr: true,
errContains: "at least one protocol",
},
{
name: "invalid WebApp header template",
config: &Config{
Twingate: TwingateConfig{Network: "test"},
Port: 8443,
MetricsPort: 9090,
TLS: TLSConfig{
CertificateFile: "tls.crt",
PrivateKeyFile: "tls.key",
},
WebApp: &WebAppConfig{
Headers: map[string]string{"Authorization": "Bearer {{twingate.jwt"},
},
},
wantErr: true,
errContains: "webApp config",
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -1104,3 +1124,47 @@ func TestSSHCAVaultAWSConfig_Validate(t *testing.T) {
})
}
}

func TestWebAppConfig_Validate(t *testing.T) {
tests := []struct {
name string
config WebAppConfig
wantErr bool
errContains string
}{
{
name: "empty headers",
config: WebAppConfig{Headers: map[string]string{}},
wantErr: false,
},
{
name: "valid template variable",
config: WebAppConfig{Headers: map[string]string{"Authorization": "Bearer {{twingate.jwt}}"}},
wantErr: false,
},
{
name: "invalid template syntax",
config: WebAppConfig{Headers: map[string]string{"X-Bad": "{{invalid"}},
wantErr: true,
errContains: "unsupported syntax",
},
{
name: "unsupported key",
config: WebAppConfig{Headers: map[string]string{"X-Bad": "{{twingate.unknown}}"}},
wantErr: true,
errContains: "unsupported key",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errContains)
} else {
require.NoError(t, err)
}
})
}
}
28 changes: 11 additions & 17 deletions internal/connect/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,6 @@ func httpResponseString(httpCode int) string {
return fmt.Sprintf("HTTP/1.1 %d %s\r\n\r\n", httpCode, http.StatusText(httpCode))
}

type TransportProtocol int

const (
TransportTLS TransportProtocol = iota
TransportSSH
)

// Conn is a custom connection that wraps the underlying TCP net.Conn, handling downstream
// proxy (Twingate Client)'s authentication via the initial CONNECT message. It handles 2 TLS
// upgrades: with downstream proxy and then optionally with downstream client e.g. `kubectl`.
Expand All @@ -47,8 +40,8 @@ type Conn interface {
GATClaims() *token.GATClaims
GetID() string
GetAddress() string
GetToken() string
Authenticate() error
TransportProtocol() TransportProtocol
UpgradeToTLS() error

Close() error
Expand All @@ -64,6 +57,7 @@ type ProxyConn struct {
ID string
Address string
Claims *token.GATClaims
Token string

Timer *time.Timer
Mu sync.Mutex
Expand Down Expand Up @@ -97,14 +91,6 @@ func (p *ProxyConn) Close() error {
return p.Conn.Close()
}

func (p *ProxyConn) TransportProtocol() TransportProtocol {
if p.GATClaims().Resource.Type == token.ResourceTypeSSH {
return TransportSSH
}

return TransportTLS
}

func (p *ProxyConn) GATClaims() *token.GATClaims {
return p.Claims
}
Expand All @@ -117,6 +103,10 @@ func (p *ProxyConn) GetAddress() string {
return p.Address
}

func (p *ProxyConn) GetToken() string {
return p.Token
}

// Authenticate sets up TLS and processes the CONNECT message for authentication.
func (p *ProxyConn) Authenticate() error {
_ = p.SetDeadline(time.Now().Add(defaultTimeout))
Expand Down Expand Up @@ -225,7 +215,10 @@ func (p *ProxyConn) Authenticate() error {

p.tracker.RecordConnectMetrics(httpCode)

p.Logger.Info("Authenticated connection", zap.String("resource_address", connectInfo.Claims.Resource.Address))
p.Logger.Info("Authenticated connection",
zap.String("resource_type", string(connectInfo.Claims.Resource.Type)),
zap.String("resource_address", connectInfo.Claims.Resource.Address),
)
p.setConnectInfo(connectInfo)

return nil
Expand All @@ -249,6 +242,7 @@ func (p *ProxyConn) setConnectInfo(connectInfo Info) {
p.ID = connectInfo.ConnID
p.Address = connectInfo.Address
p.Claims = connectInfo.Claims
p.Token = connectInfo.Token
p.Timer = time.AfterFunc(time.Until(connectInfo.Claims.ExpiresAt.Time), func() {
_ = p.Close()
})
Expand Down
27 changes: 0 additions & 27 deletions internal/connect/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,33 +83,6 @@ func TestProxyConn_setConnectInfo(t *testing.T) {
})
}

func TestProxyConn_TransportProtocol(t *testing.T) {
tests := []struct {
resourceType token.ResourceType
expected TransportProtocol
}{
{
resourceType: token.ResourceTypeKubernetes,
expected: TransportTLS,
},
{
resourceType: token.ResourceTypeSSH,
expected: TransportSSH,
},
}

for _, tt := range tests {
t.Run(tt.resourceType, func(t *testing.T) {
claims := &token.GATClaims{
Resource: token.Resource{Type: tt.resourceType},
}
proxyConn := &ProxyConn{Claims: claims}

assert.Equal(t, tt.expected, proxyConn.TransportProtocol())
})
}
}

func TestProxyConn_Close(t *testing.T) {
conn := &mockConn{}
timer := time.NewTimer(0 * time.Millisecond)
Expand Down
2 changes: 2 additions & 0 deletions internal/connect/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Info struct {
Address string
Claims *token.GATClaims
ConnID string
Token string
}

type HTTPError struct {
Expand Down Expand Up @@ -156,5 +157,6 @@ func (v *MessageValidator) ParseConnect(req *http.Request, ekm []byte) (connectI
Address: address,
Claims: gatClaims,
ConnID: connID,
Token: bearerToken,
}, nil
}
2 changes: 2 additions & 0 deletions internal/connect/connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ func TestConnectValidator_ParseConnect(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, *connectInfo.Claims, gatClaims)
assert.Equal(t, "conn-id", connectInfo.ConnID)
assert.Equal(t, signedToken, connectInfo.Token)
})

t.Run("Non-CONNECT method", func(t *testing.T) {
Expand All @@ -127,6 +128,7 @@ func TestConnectValidator_ParseConnect(t *testing.T) {
assert.Contains(t, httpErr.Error(), "expected CONNECT request")
assert.Nil(t, connectInfo.Claims)
assert.Empty(t, connectInfo.ConnID)
assert.Empty(t, connectInfo.Token)
})

t.Run("Missing auth header", func(t *testing.T) {
Expand Down
13 changes: 6 additions & 7 deletions internal/connect/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (l *ProtocolListener) Addr() net.Addr {
type ConnFactory func(net.Conn, *tls.Config, Validator, *zap.Logger) Conn

type Listener struct {
channels map[TransportProtocol]chan<- Conn
channels map[token.ResourceType]chan<- Conn

tokenParser *token.Parser
certReloader *CertReloader
Expand All @@ -88,7 +88,7 @@ type Listener struct {
func NewListener(
twingateConfig config.TwingateConfig,
tlsCfg config.TLSConfig,
channels map[TransportProtocol]chan<- Conn,
channels map[token.ResourceType]chan<- Conn,
registry *prometheus.Registry,
logger *zap.Logger,
) (*Listener, error) {
Expand Down Expand Up @@ -172,19 +172,18 @@ func (l *Listener) Serve(ctx context.Context, listener net.Listener) error {
return
}

tp := proxyConn.TransportProtocol()
channel, exists := l.channels[tp]
resourceType := proxyConn.GATClaims().Resource.Type
channel, exists := l.channels[resourceType]

if !exists {
l.logger.Error("Unsupported transport protocol", zap.Int("transport", int(tp)))
l.logger.Error("Unsupported resource type", zap.String("resource_type", string(resourceType)))

_ = proxyConn.Close()

return
}

// For non-SSH protocols, upgrade to TLS
if tp != TransportSSH {
if proxyConn.GATClaims().ShouldUpgradeTLS() {
if err := proxyConn.UpgradeToTLS(); err != nil {
l.logger.Error("Failed to upgrade to TLS", zap.Error(err))

Expand Down
Loading
Loading