From d822db0b4cc8eb48648419fae4321998a67cddf2 Mon Sep 17 00:00:00 2001 From: Sam Calder-Mason Date: Wed, 17 Jun 2026 10:12:59 +1000 Subject: [PATCH 1/2] Extract proxy auth into a token.Source abstraction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Token handling was split across two packages and implemented twice: the interactive grant lived in pkg/auth/{client,store} while the proxy client carried its own client_credentials token cache, mode branching, and a mode-specific 401 retry. The proxy — whose job is to send authenticated requests — knew which OAuth grant was in play and ran one of two token caches. Introduce pkg/auth/token with a single Source interface (Token / Invalidate) and a NewSource factory that owns the grant decision and the construction of the OAuth client and credential store. The proxy now resolves only its issuer and holds an opaque Source; the ccTokens cache, usesClientCredentials branching, and clientCredentialsToken all leave pkg/proxy. Refresh the access token proactively at 50% of its lifetime (shared client.ShouldRefresh policy) instead of waiting for the 5-minute expiry buffer, leaving a wide margin before expiry. Add a uniform invalidate-and-retry-once on 401/403 across every server-to-proxy request path (ClickHouseQuery, Discover, the server-side query funnel, and the embedder), replacing the client_credentials-only retry, so a token revoked before the local buffer self-heals regardless of grant. --- modules/clickhouse/schema_routing_test.go | 1 + pkg/app/refresh_test.go | 1 + pkg/auth/client/client.go | 37 +++ pkg/auth/client/refreshpolicy_test.go | 82 +++++++ pkg/auth/store/store.go | 54 ++++- pkg/auth/store/store_test.go | 107 +++++++++ pkg/auth/token/token.go | 184 +++++++++++++++ pkg/auth/token/token_test.go | 229 ++++++++++++++++++ pkg/embedding/remote.go | 94 ++++---- pkg/embedding/remote_test.go | 14 +- pkg/proxy/client.go | 268 +++++++++------------- pkg/proxy/client_credentials_test.go | 57 +++-- pkg/proxy/proxy.go | 4 + pkg/proxy/router.go | 10 + pkg/proxy/router_test.go | 1 + pkg/proxy/server.go | 4 + pkg/searchruntime/runtime.go | 1 + pkg/server/api.go | 75 ++++-- pkg/server/operations_ethnode_test.go | 1 + pkg/server/proxy_routing_test.go | 1 + 20 files changed, 968 insertions(+), 257 deletions(-) create mode 100644 pkg/auth/client/refreshpolicy_test.go create mode 100644 pkg/auth/token/token.go create mode 100644 pkg/auth/token/token_test.go diff --git a/modules/clickhouse/schema_routing_test.go b/modules/clickhouse/schema_routing_test.go index 04962ad4..9c28549f 100644 --- a/modules/clickhouse/schema_routing_test.go +++ b/modules/clickhouse/schema_routing_test.go @@ -100,6 +100,7 @@ func (c *schemaProxyClient) Start(_ context.Context) error { return nil } func (c *schemaProxyClient) Stop(_ context.Context) error { return nil } func (c *schemaProxyClient) URL() string { return c.url } func (c *schemaProxyClient) RegisterToken() string { return c.token } +func (c *schemaProxyClient) Invalidate() {} func (c *schemaProxyClient) RevokeToken() {} func (c *schemaProxyClient) ClickHouseDatasources() []string { return schemaDatasourceNames(c.ClickHouseDatasourceInfo()) diff --git a/pkg/app/refresh_test.go b/pkg/app/refresh_test.go index 9fff3361..631d2fa3 100644 --- a/pkg/app/refresh_test.go +++ b/pkg/app/refresh_test.go @@ -325,6 +325,7 @@ func (f *fakeProxyClient) Start(_ context.Context) error { return func (f *fakeProxyClient) Stop(_ context.Context) error { return nil } func (f *fakeProxyClient) URL() string { return "" } func (f *fakeProxyClient) RegisterToken() string { return "" } +func (f *fakeProxyClient) Invalidate() {} func (f *fakeProxyClient) RevokeToken() {} func (f *fakeProxyClient) Discover(_ context.Context) error { return nil } func (f *fakeProxyClient) EnsureAuthenticated(_ context.Context) error { return nil } diff --git a/pkg/auth/client/client.go b/pkg/auth/client/client.go index f11d9e81..92b16720 100644 --- a/pkg/auth/client/client.go +++ b/pkg/auth/client/client.go @@ -46,6 +46,43 @@ type Tokens struct { RefreshTokenIssuedAt time.Time `json:"refresh_token_issued_at,omitempty"` } +// DefaultRefreshFraction is the elapsed-lifetime fraction at which an access +// token is proactively refreshed: at 0.5 it refreshes once half its lifetime +// has passed (e.g. ~30 min into a 1h token), leaving a wide margin so a request +// never races a just-expired token. +const DefaultRefreshFraction = 0.5 + +// ShouldRefresh reports whether a token expiring at expiresAt (minted with an +// original lifetime of expiresIn seconds) should be proactively refreshed at +// the moment now. +// +// It returns true once the token is within buffer of expiry, or — when the +// original lifetime is known — once it has passed refreshFraction of that +// lifetime. refreshFraction is the elapsed-life fraction at which to refresh +// (e.g. 0.5 refreshes at the halfway point, well before expiry, so a request +// never races a just-expired token). A zero expiresAt (unknown expiry) only +// triggers on the buffer check. +func ShouldRefresh(now, expiresAt time.Time, expiresIn int, buffer time.Duration, refreshFraction float64) bool { + if expiresAt.IsZero() { + return false + } + + if now.Add(buffer).After(expiresAt) { + return true + } + + if expiresIn > 0 && refreshFraction > 0 && refreshFraction < 1 { + lifetime := time.Duration(expiresIn) * time.Second + refreshAt := expiresAt.Add(-time.Duration(float64(lifetime) * (1 - refreshFraction))) + + if now.After(refreshAt) { + return true + } + } + + return false +} + // Config configures the OAuth client. type Config struct { // IssuerURL is the OIDC issuer URL (e.g., https://dex.example.com). diff --git a/pkg/auth/client/refreshpolicy_test.go b/pkg/auth/client/refreshpolicy_test.go new file mode 100644 index 00000000..3035acdc --- /dev/null +++ b/pkg/auth/client/refreshpolicy_test.go @@ -0,0 +1,82 @@ +package client + +import ( + "testing" + "time" +) + +func TestShouldRefresh(t *testing.T) { + t.Parallel() + + now := time.Now() + const buffer = 5 * time.Minute + + tests := []struct { + name string + expiresAt time.Time + expiresIn int + fraction float64 + want bool + }{ + { + name: "fresh token below the fraction is not refreshed", + expiresAt: now.Add(40 * time.Minute), // 1h token, 33% elapsed + expiresIn: 3600, + fraction: 0.5, + want: false, + }, + { + name: "token past 50% of its lifetime is refreshed", + expiresAt: now.Add(20 * time.Minute), // 1h token, 66% elapsed + expiresIn: 3600, + fraction: 0.5, + want: true, + }, + { + name: "token exactly at the fraction is not yet refreshed", + expiresAt: now.Add(30 * time.Minute), // 1h token, exactly 50% + expiresIn: 3600, + fraction: 0.5, + want: false, + }, + { + name: "within the expiry buffer is always refreshed", + expiresAt: now.Add(2 * time.Minute), + expiresIn: 3600, + fraction: 0.5, + want: true, + }, + { + name: "a higher fraction refreshes later", + expiresAt: now.Add(20 * time.Minute), // 66% elapsed + expiresIn: 3600, + fraction: 0.75, // refresh only past 75% + want: false, + }, + { + name: "unknown lifetime falls back to the buffer only", + expiresAt: now.Add(20 * time.Minute), + expiresIn: 0, + fraction: 0.5, + want: false, + }, + { + name: "zero expiry never triggers", + expiresAt: time.Time{}, + expiresIn: 3600, + fraction: 0.5, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := ShouldRefresh(now, tt.expiresAt, tt.expiresIn, buffer, tt.fraction) + if got != tt.want { + t.Fatalf("ShouldRefresh = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/auth/store/store.go b/pkg/auth/store/store.go index c6223baa..baa2e721 100644 --- a/pkg/auth/store/store.go +++ b/pkg/auth/store/store.go @@ -35,6 +35,11 @@ type Store interface { // GetAccessToken returns a valid access token, refreshing if needed. GetAccessToken() (string, error) + // Invalidate forces the next GetAccessToken to refresh the access token, + // even if it has not yet hit the local expiry buffer. Used when the proxy + // rejects a token that should still be valid locally (e.g. revocation). + Invalidate() + // IsAuthenticated returns true if valid tokens are stored. IsAuthenticated() bool } @@ -68,11 +73,12 @@ type Config struct { // store implements the Store interface. type store struct { - log logrus.FieldLogger - cfg Config - mu sync.RWMutex - tokens *client.Tokens - refreshMu sync.Mutex + log logrus.FieldLogger + cfg Config + mu sync.RWMutex + tokens *client.Tokens + forceRefresh bool + refreshMu sync.Mutex } // New creates a new credential store. @@ -197,6 +203,30 @@ func (s *store) GetAccessToken() (string, error) { return tokens.AccessToken, nil } +// Invalidate forces the next GetAccessToken to refresh. +func (s *store) Invalidate() { + s.mu.Lock() + defer s.mu.Unlock() + + s.forceRefresh = true +} + +// forceRefreshRequested reports whether Invalidate has armed a forced refresh. +func (s *store) forceRefreshRequested() bool { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.forceRefresh +} + +// clearForceRefresh clears the forced-refresh flag after a successful refresh. +func (s *store) clearForceRefresh() { + s.mu.Lock() + defer s.mu.Unlock() + + s.forceRefresh = false +} + // IsAuthenticated returns true if valid tokens are stored. func (s *store) IsAuthenticated() bool { tokens, err := s.getTokens() @@ -234,8 +264,16 @@ func (s *store) needsRefresh(tokens *client.Tokens) bool { return true } - // Refresh if the access token is within the buffer of expiry. - if time.Now().Add(s.cfg.RefreshBuffer).After(tokens.ExpiresAt) { + if s.forceRefreshRequested() { + return true + } + + // Refresh proactively once the access token passes the configured fraction + // of its lifetime, or once it is within the buffer of expiry. + if client.ShouldRefresh( + time.Now(), tokens.ExpiresAt, tokens.ExpiresIn, + s.cfg.RefreshBuffer, client.DefaultRefreshFraction, + ) { return true } @@ -291,6 +329,8 @@ func (s *store) refresh(prior *client.Tokens) (*client.Tokens, error) { return nil, err } + s.clearForceRefresh() + return newTokens, nil } diff --git a/pkg/auth/store/store_test.go b/pkg/auth/store/store_test.go index 1acf6121..f6b66fcb 100644 --- a/pkg/auth/store/store_test.go +++ b/pkg/auth/store/store_test.go @@ -3,6 +3,7 @@ package store import ( "context" "errors" + "path/filepath" "sync" "sync/atomic" "testing" @@ -165,6 +166,112 @@ func TestGetAccessTokenSerializesConcurrentRefreshes(t *testing.T) { } } +func TestGetAccessTokenRefreshesAtAccessTokenHalfLife(t *testing.T) { + t.Parallel() + + client := &stubAuthClient{} + store := New(logrus.New(), Config{ + Path: filepath.Join(t.TempDir(), "creds.json"), + AuthClient: client, + RefreshBuffer: 5 * time.Minute, + }).(*store) + // 1h token with 20m left = 66% elapsed: past the 50% refresh point, but + // well outside the 5m expiry buffer. + store.tokens = &authclient.Tokens{ + AccessToken: "aging", + RefreshToken: "refresh-token", + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(20 * time.Minute), + } + + token, err := store.GetAccessToken() + if err != nil { + t.Fatalf("GetAccessToken returned error: %v", err) + } + + if token != "refreshed-token" { + t.Fatalf("expected proactive refresh past 50%% of lifetime, got %q", token) + } + + if client.refreshCalls != 1 { + t.Fatalf("expected 1 refresh call, got %d", client.refreshCalls) + } +} + +func TestGetAccessTokenDoesNotRefreshBeforeAccessTokenHalfLife(t *testing.T) { + t.Parallel() + + client := &stubAuthClient{} + store := New(logrus.New(), Config{ + Path: filepath.Join(t.TempDir(), "creds.json"), + AuthClient: client, + RefreshBuffer: 5 * time.Minute, + }).(*store) + // 1h token with 40m left = 33% elapsed: before the 50% refresh point. + store.tokens = &authclient.Tokens{ + AccessToken: "fresh", + RefreshToken: "refresh-token", + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(40 * time.Minute), + } + + token, err := store.GetAccessToken() + if err != nil { + t.Fatalf("GetAccessToken returned error: %v", err) + } + + if token != "fresh" { + t.Fatalf("expected original token before 50%% of lifetime, got %q", token) + } + + if client.refreshCalls != 0 { + t.Fatalf("expected no refresh calls, got %d", client.refreshCalls) + } +} + +func TestInvalidateForcesRefreshThenClears(t *testing.T) { + t.Parallel() + + client := &stubAuthClient{} + store := New(logrus.New(), Config{ + Path: filepath.Join(t.TempDir(), "creds.json"), + AuthClient: client, + RefreshBuffer: 5 * time.Minute, + }).(*store) + // A fresh token that would not otherwise refresh (0% elapsed). + store.tokens = &authclient.Tokens{ + AccessToken: "fresh", + RefreshToken: "refresh-token", + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(time.Hour), + } + + store.Invalidate() + + token, err := store.GetAccessToken() + if err != nil { + t.Fatalf("GetAccessToken returned error: %v", err) + } + + if token != "refreshed-token" { + t.Fatalf("expected forced refresh after Invalidate, got %q", token) + } + + if client.refreshCalls != 1 { + t.Fatalf("expected 1 refresh call after Invalidate, got %d", client.refreshCalls) + } + + // The forced-refresh flag must clear: the now-fresh token should not + // refresh again on the next call. + if _, err := store.GetAccessToken(); err != nil { + t.Fatalf("second GetAccessToken returned error: %v", err) + } + + if client.refreshCalls != 1 { + t.Fatalf("expected no extra refresh after the flag cleared, got %d", client.refreshCalls) + } +} + type stubAuthClient struct { refreshCalls int refreshErr error diff --git a/pkg/auth/token/token.go b/pkg/auth/token/token.go new file mode 100644 index 00000000..fbe2eafc --- /dev/null +++ b/pkg/auth/token/token.go @@ -0,0 +1,184 @@ +// Package token provides grant-agnostic access-token sources for proxy auth. +// +// A Source hides which OAuth grant is in play (interactive refresh-token vs +// client_credentials) behind a single Token/Invalidate seam, so request code +// can attach a token and retry on rejection without branching on auth mode. +package token + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/sirupsen/logrus" + + "github.com/ethpandaops/panda/pkg/auth/client" + "github.com/ethpandaops/panda/pkg/auth/store" +) + +// ModeClientCredentials selects the non-interactive service-account grant. Any +// other mode ("", "oauth", "oidc") uses the interactive refresh-token grant. +const ModeClientCredentials = "client_credentials" + +// clientCredentialsBuffer is how long before expiry a cached client_credentials +// access token is re-minted, in addition to the proactive fractional refresh. +const clientCredentialsBuffer = 5 * time.Minute + +// Config describes how to build a Source. A blank IssuerURL or ClientID means +// no auth is configured and NewSource returns nil. +type Config struct { + // IssuerURL is the resolved OIDC issuer (callers apply any proxy-URL + // fallback before passing it in). + IssuerURL string + // ClientID is the OAuth client ID. + ClientID string + // Resource is the optional RFC 8707 resource indicator. + Resource string + // Username and Password are the service-account credentials for + // ModeClientCredentials. + Username string + Password string + // Mode selects the grant: ModeClientCredentials or interactive (default). + Mode string + // RefreshTokenTTL is the expected refresh-token lifetime, used by the + // interactive store to keep the refresh token alive via rotation. + RefreshTokenTTL time.Duration + // MintTimeout bounds a single client_credentials mint. + MintTimeout time.Duration +} + +// Source yields valid access tokens for proxy requests, hiding the OAuth grant. +type Source interface { + // Token returns a currently valid access token, refreshing or minting one + // as needed. + Token(ctx context.Context) (string, error) + + // Invalidate drops any cached access token so the next Token call obtains a + // fresh one. Used when the proxy rejects a token that has not yet hit the + // local expiry buffer (e.g. server-side revocation). + Invalidate() +} + +// refreshSource is the interactive (authorization-code / device) grant. It +// delegates to the on-disk credential store, which refreshes via the refresh +// token and persists rotation. +type refreshSource struct { + store store.Store +} + +// clientCredentialsSource is the non-interactive service-account grant. It +// mints access tokens on demand and caches them in memory only — no refresh +// token, nothing written to disk. +type clientCredentialsSource struct { + log logrus.FieldLogger + client client.Client + mintTimeout time.Duration + + mu sync.Mutex + tokens *client.Tokens +} + +// NewSource builds the access-token Source for cfg, owning the grant decision +// and the construction of the OAuth client and (for interactive grants) the +// on-disk credential store. It returns nil when no auth is configured, so the +// proxy can treat "no token source" as "no auth required". +func NewSource(log logrus.FieldLogger, cfg Config) Source { + issuer := strings.TrimRight(cfg.IssuerURL, "/") + clientID := strings.TrimSpace(cfg.ClientID) + + if issuer == "" || clientID == "" { + return nil + } + + resource := strings.TrimRight(cfg.Resource, "/") + + authClient := client.New(log, client.Config{ + IssuerURL: issuer, + ClientID: clientID, + Resource: resource, + Username: cfg.Username, + Password: cfg.Password, + }) + + if cfg.Mode == ModeClientCredentials { + return NewClientCredentialsSource(log, authClient, cfg.MintTimeout) + } + + return NewRefreshSource(store.New(log, store.Config{ + AuthClient: authClient, + IssuerURL: issuer, + ClientID: clientID, + Resource: resource, + RefreshTokenTTL: cfg.RefreshTokenTTL, + })) +} + +// NewRefreshSource builds a Source backed by the on-disk credential store. +func NewRefreshSource(s store.Store) Source { + return &refreshSource{store: s} +} + +// NewClientCredentialsSource builds a Source that mints tokens via the +// client_credentials grant, caching them in memory for mintTimeout-bounded +// re-mints. +func NewClientCredentialsSource(log logrus.FieldLogger, c client.Client, mintTimeout time.Duration) Source { + return &clientCredentialsSource{ + log: log.WithField("component", "client-credentials-token"), + client: c, + mintTimeout: mintTimeout, + } +} + +// Token returns a valid access token from the credential store. +func (s *refreshSource) Token(_ context.Context) (string, error) { + return s.store.GetAccessToken() +} + +// Invalidate forces the store to refresh on the next Token call. +func (s *refreshSource) Invalidate() { + s.store.Invalidate() +} + +// Token returns the cached client_credentials access token, re-minting via the +// issuer's token endpoint when missing or past the refresh threshold. +func (s *clientCredentialsSource) Token(_ context.Context) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.tokens != nil && !client.ShouldRefresh( + time.Now(), s.tokens.ExpiresAt, s.tokens.ExpiresIn, + clientCredentialsBuffer, client.DefaultRefreshFraction, + ) { + return s.tokens.AccessToken, nil + } + + ctx, cancel := context.WithTimeout(context.Background(), s.mintTimeout) + defer cancel() + + tokens, err := s.client.ClientCredentials(ctx) + if err != nil { + // Keep serving a still-valid cached token across transient mint failures. + if s.tokens != nil && time.Now().Before(s.tokens.ExpiresAt) { + s.log.WithError(err).Warn("Re-minting client_credentials token failed; using cached token") + + return s.tokens.AccessToken, nil + } + + return "", fmt.Errorf("minting client_credentials token: %w", err) + } + + s.tokens = tokens + + return tokens.AccessToken, nil +} + +// Invalidate drops the cached client_credentials token so the next Token call +// mints a fresh one. +func (s *clientCredentialsSource) Invalidate() { + s.mu.Lock() + defer s.mu.Unlock() + + s.tokens = nil +} diff --git a/pkg/auth/token/token_test.go b/pkg/auth/token/token_test.go new file mode 100644 index 00000000..1ad16542 --- /dev/null +++ b/pkg/auth/token/token_test.go @@ -0,0 +1,229 @@ +package token + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/sirupsen/logrus" + + "github.com/ethpandaops/panda/pkg/auth/client" +) + +func discardLog() logrus.FieldLogger { + log := logrus.New() + log.SetOutput(discardWriter{}) + + return log +} + +type discardWriter struct{} + +func (discardWriter) Write(p []byte) (int, error) { return len(p), nil } + +// fakeAuthClient is a client.Client whose ClientCredentials mints +// "svc-token-" with the configured lifetime and can be made to fail. +type fakeAuthClient struct { + mu sync.Mutex + mints int + expiresIn int + fail bool +} + +func (f *fakeAuthClient) Login(_ context.Context) (*client.Tokens, error) { + return nil, errors.New("not implemented") +} + +func (f *fakeAuthClient) Refresh(_ context.Context, _ string) (*client.Tokens, error) { + return nil, errors.New("not implemented") +} + +func (f *fakeAuthClient) ClientCredentials(_ context.Context) (*client.Tokens, error) { + f.mu.Lock() + defer f.mu.Unlock() + + f.mints++ + if f.fail { + return nil, errors.New("mint failed") + } + + return &client.Tokens{ + AccessToken: fmt.Sprintf("svc-token-%d", f.mints), + TokenType: "Bearer", + ExpiresIn: f.expiresIn, + ExpiresAt: time.Now().Add(time.Duration(f.expiresIn) * time.Second), + }, nil +} + +func (f *fakeAuthClient) mintCount() int { + f.mu.Lock() + defer f.mu.Unlock() + + return f.mints +} + +func TestClientCredentialsSourceCachesUntilRefreshFraction(t *testing.T) { + t.Parallel() + + auth := &fakeAuthClient{expiresIn: 3600} + src := NewClientCredentialsSource(discardLog(), auth, time.Second) + + first, err := src.Token(context.Background()) + if err != nil { + t.Fatalf("first Token error: %v", err) + } + + second, err := src.Token(context.Background()) + if err != nil { + t.Fatalf("second Token error: %v", err) + } + + if first != "svc-token-1" || second != "svc-token-1" { + t.Fatalf("expected cached svc-token-1, got %q then %q", first, second) + } + + if got := auth.mintCount(); got != 1 { + t.Fatalf("expected 1 mint for a cached token, got %d", got) + } +} + +func TestClientCredentialsSourceReMintsWithinBuffer(t *testing.T) { + t.Parallel() + + // A 1s lifetime is inside the refresh buffer, so each call re-mints. + auth := &fakeAuthClient{expiresIn: 1} + src := NewClientCredentialsSource(discardLog(), auth, time.Second) + + first, _ := src.Token(context.Background()) + second, _ := src.Token(context.Background()) + + if first == second { + t.Fatalf("expected a re-mint, got %q twice", first) + } + + if got := auth.mintCount(); got != 2 { + t.Fatalf("expected 2 mints, got %d", got) + } +} + +func TestClientCredentialsSourceInvalidateForcesReMint(t *testing.T) { + t.Parallel() + + auth := &fakeAuthClient{expiresIn: 3600} + src := NewClientCredentialsSource(discardLog(), auth, time.Second) + + first, _ := src.Token(context.Background()) + + src.Invalidate() + + second, err := src.Token(context.Background()) + if err != nil { + t.Fatalf("Token after Invalidate error: %v", err) + } + + if first == second { + t.Fatalf("expected a fresh token after Invalidate, got %q twice", first) + } + + if got := auth.mintCount(); got != 2 { + t.Fatalf("expected 2 mints after Invalidate, got %d", got) + } +} + +func TestClientCredentialsSourceServesCachedTokenAcrossMintOutage(t *testing.T) { + t.Parallel() + + // 60s lifetime sits inside the 5m buffer, so the next call attempts a + // re-mint — but the cached token is still valid. + auth := &fakeAuthClient{expiresIn: 60} + src := NewClientCredentialsSource(discardLog(), auth, time.Second) + + first, err := src.Token(context.Background()) + if err != nil { + t.Fatalf("first Token error: %v", err) + } + + auth.mu.Lock() + auth.fail = true + auth.mu.Unlock() + + got, err := src.Token(context.Background()) + if err != nil { + t.Fatalf("expected cached token across mint outage, got error: %v", err) + } + + if got != first { + t.Fatalf("expected cached token %q across outage, got %q", first, got) + } +} + +// fakeStore is a minimal store.Store for exercising the refresh source. +type fakeStore struct { + token string + invalidated bool +} + +func (f *fakeStore) Path() string { return "" } +func (f *fakeStore) Save(*client.Tokens) error { return nil } +func (f *fakeStore) Load() (*client.Tokens, error) { return nil, nil } +func (f *fakeStore) Clear() error { return nil } +func (f *fakeStore) GetAccessToken() (string, error) { return f.token, nil } +func (f *fakeStore) Invalidate() { f.invalidated = true } +func (f *fakeStore) IsAuthenticated() bool { return f.token != "" } + +func TestRefreshSourceDelegatesToStore(t *testing.T) { + t.Parallel() + + fs := &fakeStore{token: "access-123"} + src := NewRefreshSource(fs) + + tok, err := src.Token(context.Background()) + if err != nil { + t.Fatalf("Token error: %v", err) + } + + if tok != "access-123" { + t.Fatalf("Token = %q, want access-123", tok) + } + + src.Invalidate() + if !fs.invalidated { + t.Fatal("Invalidate did not propagate to the store") + } +} + +func TestNewSourceNilWhenUnconfigured(t *testing.T) { + t.Parallel() + + if src := NewSource(discardLog(), Config{}); src != nil { + t.Fatal("NewSource with no issuer/clientID should return nil") + } + + if src := NewSource(discardLog(), Config{IssuerURL: "https://issuer.example"}); src != nil { + t.Fatal("NewSource with no clientID should return nil") + } +} + +func TestNewSourceBuildsConfiguredGrants(t *testing.T) { + t.Parallel() + + cc := NewSource(discardLog(), Config{ + IssuerURL: "https://issuer.example", + ClientID: "panda-proxy", + Mode: ModeClientCredentials, + }) + if cc == nil { + t.Fatal("NewSource for client_credentials returned nil") + } + + interactive := NewSource(discardLog(), Config{ + IssuerURL: "https://issuer.example", + ClientID: "panda", + }) + if interactive == nil { + t.Fatal("NewSource for the interactive grant returned nil") + } +} diff --git a/pkg/embedding/remote.go b/pkg/embedding/remote.go index 8195c543..b24c1beb 100644 --- a/pkg/embedding/remote.go +++ b/pkg/embedding/remote.go @@ -59,35 +59,40 @@ type embedResult struct { // RemoteEmbedder implements Embedder by calling the proxy's /embed endpoint. // An optional local cache avoids round-trips to the proxy on warm restarts. type RemoteEmbedder struct { - log logrus.FieldLogger - proxyURL string - httpClient *http.Client - tokenFn func() string - localCache cache.Cache - model string + log logrus.FieldLogger + proxyURL string + httpClient *http.Client + tokenFn func() string + invalidateFn func() + localCache cache.Cache + model string } // Compile-time interface check. var _ Embedder = (*RemoteEmbedder)(nil) // NewRemote creates a new RemoteEmbedder that calls the proxy's /embed endpoint. -// tokenFn is called on each request to get the current auth token. +// tokenFn is called on each request to get the current auth token, and +// invalidateFn drops the cached token so a 401/403 can be retried with a fresh +// one (it may be nil to disable the retry). // localCache and model are optional — when both are set, embedding vectors are // cached locally using {model}:{textHash} keys to avoid proxy round-trips. func NewRemote( log logrus.FieldLogger, proxyURL string, tokenFn func() string, + invalidateFn func(), localCache cache.Cache, model string, ) *RemoteEmbedder { return &RemoteEmbedder{ - log: log.WithField("component", "remote-embedder"), - proxyURL: proxyURL, - httpClient: &http.Client{Timeout: remoteEmbedTimeout}, - tokenFn: tokenFn, - localCache: localCache, - model: model, + log: log.WithField("component", "remote-embedder"), + proxyURL: proxyURL, + httpClient: &http.Client{Timeout: remoteEmbedTimeout}, + tokenFn: tokenFn, + invalidateFn: invalidateFn, + localCache: localCache, + model: model, } } @@ -355,27 +360,48 @@ func (e *RemoteEmbedder) embedDirect( return vectors, nil } -func (e *RemoteEmbedder) checkCached(hashes []string) ([]embedResult, error) { - reqBody, err := json.Marshal(embedCheckRequest{Hashes: hashes}) - if err != nil { - return nil, fmt.Errorf("marshaling check request: %w", err) +// doWithAuthRetry POSTs jsonBody to the proxy path with the current auth token, +// retrying once after invalidating the token on a 401/403. +func (e *RemoteEmbedder) doWithAuthRetry(path string, jsonBody []byte) (*http.Response, error) { + send := func() (*http.Response, error) { + req, err := http.NewRequestWithContext( + context.Background(), http.MethodPost, e.proxyURL+path, bytes.NewReader(jsonBody), + ) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + + if token := e.tokenFn(); token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + + return e.httpClient.Do(req) } - req, err := http.NewRequestWithContext( - context.Background(), http.MethodPost, - e.proxyURL+"/embed/check", bytes.NewReader(reqBody), - ) + resp, err := send() if err != nil { - return nil, fmt.Errorf("creating check request: %w", err) + return nil, err + } + + if (resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden) && e.invalidateFn != nil { + _ = resp.Body.Close() + e.invalidateFn() + + return send() } - req.Header.Set("Content-Type", "application/json") + return resp, nil +} - if token := e.tokenFn(); token != "" { - req.Header.Set("Authorization", "Bearer "+token) +func (e *RemoteEmbedder) checkCached(hashes []string) ([]embedResult, error) { + reqBody, err := json.Marshal(embedCheckRequest{Hashes: hashes}) + if err != nil { + return nil, fmt.Errorf("marshaling check request: %w", err) } - resp, err := e.httpClient.Do(req) + resp, err := e.doWithAuthRetry("/embed/check", reqBody) if err != nil { return nil, fmt.Errorf("calling embed check: %w", err) } @@ -401,21 +427,7 @@ func (e *RemoteEmbedder) callEmbed(items []embedItem) (*embedResponse, error) { return nil, fmt.Errorf("marshaling embed request: %w", err) } - req, err := http.NewRequestWithContext( - context.Background(), http.MethodPost, - e.proxyURL+"/embed", bytes.NewReader(reqBody), - ) - if err != nil { - return nil, fmt.Errorf("creating embed request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - - if token := e.tokenFn(); token != "" { - req.Header.Set("Authorization", "Bearer "+token) - } - - resp, err := e.httpClient.Do(req) + resp, err := e.doWithAuthRetry("/embed", reqBody) if err != nil { return nil, fmt.Errorf("calling proxy embed: %w", err) } diff --git a/pkg/embedding/remote_test.go b/pkg/embedding/remote_test.go index b3870961..877b632b 100644 --- a/pkg/embedding/remote_test.go +++ b/pkg/embedding/remote_test.go @@ -65,7 +65,7 @@ func TestRemoteEmbedder_Embed(t *testing.T) { require.NoError(t, json.NewEncoder(w).Encode(resp)) }, nil) - embedder := NewRemote(logrus.New(), srv.URL, func() string { return "" }, nil, "") + embedder := NewRemote(logrus.New(), srv.URL, func() string { return "" }, nil, nil, "") vec, err := embedder.Embed("hello world") require.NoError(t, err) @@ -111,7 +111,7 @@ func TestRemoteEmbedder_EmbedBatch_AllMisses(t *testing.T) { }, ) - embedder := NewRemote(logrus.New(), srv.URL, func() string { return "" }, nil, "") + embedder := NewRemote(logrus.New(), srv.URL, func() string { return "" }, nil, nil, "") vectors, err := embedder.EmbedBatch(texts) require.NoError(t, err) @@ -155,7 +155,7 @@ func TestRemoteEmbedder_EmbedBatch_AllCached(t *testing.T) { }, ) - embedder := NewRemote(logrus.New(), srv.URL, func() string { return "" }, nil, "") + embedder := NewRemote(logrus.New(), srv.URL, func() string { return "" }, nil, nil, "") vectors, err := embedder.EmbedBatch(texts) require.NoError(t, err) @@ -197,7 +197,7 @@ func TestRemoteEmbedder_EmbedBatch_PartialCache(t *testing.T) { }, ) - embedder := NewRemote(logrus.New(), srv.URL, func() string { return "" }, nil, "") + embedder := NewRemote(logrus.New(), srv.URL, func() string { return "" }, nil, nil, "") vectors, err := embedder.EmbedBatch(texts) require.NoError(t, err) @@ -228,7 +228,7 @@ func TestRemoteEmbedder_EmbedBatch_DuplicateTexts(t *testing.T) { nil, ) - embedder := NewRemote(logrus.New(), srv.URL, func() string { return "" }, nil, "") + embedder := NewRemote(logrus.New(), srv.URL, func() string { return "" }, nil, nil, "") vectors, err := embedder.EmbedBatch(texts) require.NoError(t, err) @@ -245,7 +245,7 @@ func TestRemoteEmbedder_ServerError(t *testing.T) { http.Error(w, "internal server error", http.StatusInternalServerError) }, nil) - embedder := NewRemote(logrus.New(), srv.URL, func() string { return "" }, nil, "") + embedder := NewRemote(logrus.New(), srv.URL, func() string { return "" }, nil, nil, "") _, err := embedder.Embed("test") require.Error(t, err) @@ -273,7 +273,7 @@ func TestRemoteEmbedder_AuthHeader(t *testing.T) { tokenCalled.Store(true) return expectedToken - }, nil, "") + }, nil, nil, "") _, err := embedder.Embed("test") require.NoError(t, err) diff --git a/pkg/proxy/client.go b/pkg/proxy/client.go index eb218ddf..636f7498 100644 --- a/pkg/proxy/client.go +++ b/pkg/proxy/client.go @@ -16,8 +16,7 @@ import ( "github.com/ethpandaops/panda/internal/version" "github.com/ethpandaops/panda/pkg/attribution" - "github.com/ethpandaops/panda/pkg/auth/client" - "github.com/ethpandaops/panda/pkg/auth/store" + "github.com/ethpandaops/panda/pkg/auth/token" "github.com/ethpandaops/panda/pkg/proxy/handlers" "github.com/ethpandaops/panda/pkg/types" ) @@ -105,28 +104,22 @@ type proxyClient struct { cfg ClientConfig httpClient *http.Client queryHTTPClient *http.Client - authClient client.Client - credStore store.Store + + // tokenSource yields access tokens for proxy requests regardless of grant + // type (interactive refresh-token or client_credentials). It is nil when + // the proxy needs no auth. + tokenSource token.Source mu sync.RWMutex datasources *DatasourcesResponse stopCh chan struct{} stopped bool - - // ccMu guards ccTokens, the in-memory client_credentials token cache. - // Tokens minted via client_credentials are never written to disk. - ccMu sync.Mutex - ccTokens *client.Tokens } // AuthModeClientCredentials is the ClientConfig.AuthMode value for the // OAuth2 client_credentials grant (Authentik service-account form). const AuthModeClientCredentials = "client_credentials" -// clientCredentialsRefreshBuffer is how long before expiry a cached -// client_credentials access token is re-minted. -const clientCredentialsRefreshBuffer = 5 * time.Minute - var ErrAuthenticationRequired = errors.New("proxy authentication required") // Compile-time interface checks. @@ -152,45 +145,29 @@ func NewClient(log logrus.FieldLogger, cfg ClientConfig) Client { stopCh: make(chan struct{}), } - // Set up auth client and credential store if OIDC is configured. + // The proxy owns only "what is my issuer" (its own URL is the issuer in + // embedded mode); the auth package owns everything else — grant choice, + // OAuth client, and credential store. The proxy treats the result as an + // opaque token provider, nil when no auth is configured. issuerURL := strings.TrimRight(cfg.IssuerURL, "/") if issuerURL == "" { issuerURL = strings.TrimRight(cfg.URL, "/") } - resource := strings.TrimRight(cfg.Resource, "/") - - if issuerURL != "" && cfg.ClientID != "" { - c.authClient = client.New(log, client.Config{ - IssuerURL: issuerURL, - ClientID: cfg.ClientID, - Resource: resource, - Username: cfg.Username, - Password: cfg.Password, - }) - - // client_credentials mints tokens on demand and keeps them in - // memory only — no on-disk credential store. - if cfg.AuthMode != AuthModeClientCredentials { - c.credStore = store.New(log, store.Config{ - AuthClient: c.authClient, - IssuerURL: issuerURL, - ClientID: cfg.ClientID, - Resource: resource, - RefreshTokenTTL: cfg.RefreshTokenTTL, - }) - } - } + c.tokenSource = token.NewSource(log, token.Config{ + IssuerURL: issuerURL, + ClientID: cfg.ClientID, + Resource: cfg.Resource, + Username: cfg.Username, + Password: cfg.Password, + Mode: cfg.AuthMode, + RefreshTokenTTL: cfg.RefreshTokenTTL, + MintTimeout: cfg.HTTPTimeout, + }) return c } -// usesClientCredentials reports whether this client mints tokens via the -// client_credentials grant. -func (c *proxyClient) usesClientCredentials() bool { - return c.cfg.AuthMode == AuthModeClientCredentials && c.authClient != nil -} - // Start starts the client and performs initial discovery. func (c *proxyClient) Start(ctx context.Context) error { c.log.WithField("url", c.cfg.URL).Info("Starting proxy client") @@ -235,22 +212,42 @@ func (c *proxyClient) URL() string { } func (c *proxyClient) RegisterToken() string { - if c.credStore == nil && !c.usesClientCredentials() { + if c.tokenSource == nil { return NoAuthToken } - token, err := c.loadAccessToken() + tok, err := c.tokenSource.Token(context.Background()) if err != nil { - if errors.Is(err, ErrAuthenticationRequired) { - c.log.WithError(err).Debug("Proxy access token is not available") - } else { - c.log.WithError(err).Error("Failed to get proxy access token from credential store") - } + c.log.WithError(err).Debug("Proxy access token is not available") return "" } - return token + return tok +} + +// Invalidate drops the cached access token so the next request obtains a fresh +// one. Callers use it on a 401/403 before retrying. +func (c *proxyClient) Invalidate() { + if c.tokenSource != nil { + c.tokenSource.Invalidate() + } +} + +// accessToken returns the bearer token for a proxy request: an empty string +// when the proxy needs no auth, the token when available, or an error wrapping +// ErrAuthenticationRequired when credentials are missing or unusable. +func (c *proxyClient) accessToken(ctx context.Context) (string, error) { + if c.tokenSource == nil { + return "", nil + } + + tok, err := c.tokenSource.Token(ctx) + if err != nil { + return "", fmt.Errorf("%w: %v", ErrAuthenticationRequired, err) + } + + return tok, nil } func (c *proxyClient) RevokeToken() { @@ -353,40 +350,66 @@ func (c *proxyClient) ClickHouseQuery(ctx context.Context, datasource, sql strin requestURL += "?" + encoded } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, strings.NewReader(sql)) - if err != nil { - return nil, fmt.Errorf("creating request: %w", err) - } + send := func() (int, []byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, strings.NewReader(sql)) + if err != nil { + return 0, nil, fmt.Errorf("creating request: %w", err) + } - req.Header.Set(handlers.DatasourceHeader, datasource) - req.Header.Set("Content-Type", "text/plain") + req.Header.Set(handlers.DatasourceHeader, datasource) + req.Header.Set("Content-Type", "text/plain") - if v := attribution.FromContext(ctx); v != "" { - req.Header.Set(attribution.Header, v) - } + if v := attribution.FromContext(ctx); v != "" { + req.Header.Set(attribution.Header, v) + } + + if token := c.RegisterToken(); token != "" && token != NoAuthToken { + req.Header.Set("Authorization", "Bearer "+token) + } + + resp, err := c.queryHTTPClient.Do(req) + if err != nil { + return 0, nil, fmt.Errorf("executing query: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return resp.StatusCode, nil, fmt.Errorf("reading response: %w", err) + } - if token := c.RegisterToken(); token != "" && token != NoAuthToken { - req.Header.Set("Authorization", "Bearer "+token) + return resp.StatusCode, body, nil } - resp, err := c.queryHTTPClient.Do(req) + status, body, err := send() if err != nil { - return nil, fmt.Errorf("executing query: %w", err) + return nil, err } - defer func() { _ = resp.Body.Close() }() - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("reading response: %w", err) + // One invalidate-and-retry on auth rejection covers a token revoked + // server-side before the local refresh buffer kicks in. + if isAuthRejection(status) && c.tokenSource != nil { + c.log.Debug("Proxy rejected query token; invalidating and retrying") + c.tokenSource.Invalidate() + + if status, body, err = send(); err != nil { + return nil, err + } } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, fmt.Errorf("query failed (%d): %s", resp.StatusCode, strings.TrimSpace(string(body))) + if status < 200 || status >= 300 { + return nil, fmt.Errorf("query failed (%d): %s", status, strings.TrimSpace(string(body))) } return body, nil } +// isAuthRejection reports whether status is a proxy auth rejection that a token +// refresh-and-retry can recover from. +func isAuthRejection(status int) bool { + return status == http.StatusUnauthorized || status == http.StatusForbidden +} + // PrometheusDatasources returns the discovered Prometheus datasource names. func (c *proxyClient) PrometheusDatasources() []string { c.mu.RLock() @@ -460,14 +483,14 @@ func (c *proxyClient) EmbeddingModel() string { } // Discover fetches datasource information from the proxy's /datasources endpoint. -// In client_credentials mode a 401/403 invalidates the cached token and the -// request is retried once with a freshly minted one (covers proxy-side -// revocation before the local expiry buffer kicks in). +// A 401/403 invalidates the cached token and the request is retried once with a +// fresh one, covering proxy-side revocation before the local expiry buffer +// kicks in. This is independent of the auth grant in use. func (c *proxyClient) Discover(ctx context.Context) error { err := c.discoverOnce(ctx) - if err != nil && errors.Is(err, ErrAuthenticationRequired) && c.usesClientCredentials() { - c.log.WithError(err).Debug("Proxy rejected client_credentials token; re-minting and retrying") - c.invalidateClientCredentialsToken() + if err != nil && errors.Is(err, ErrAuthenticationRequired) && c.tokenSource != nil { + c.log.WithError(err).Debug("Proxy rejected token; invalidating and retrying") + c.tokenSource.Invalidate() return c.discoverOnce(ctx) } @@ -484,17 +507,13 @@ func (c *proxyClient) discoverOnce(ctx context.Context) error { return fmt.Errorf("creating request: %w", err) } - token, err := c.loadAccessToken() + tok, err := c.accessToken(ctx) if err != nil { - if errors.Is(err, ErrAuthenticationRequired) { - return err - } - - return fmt.Errorf("loading access token: %w", err) + return err } - if token != "" { - req.Header.Set("Authorization", "Bearer "+token) + if tok != "" { + req.Header.Set("Authorization", "Bearer "+tok) } resp, err := c.httpClient.Do(req) @@ -536,22 +555,13 @@ func (c *proxyClient) discoverOnce(ctx context.Context) error { } // EnsureAuthenticated checks if the user has valid credentials. -func (c *proxyClient) EnsureAuthenticated(_ context.Context) error { - if c.usesClientCredentials() { - if _, err := c.loadAccessToken(); err != nil { - return fmt.Errorf("authenticating to proxy via client_credentials: %w", err) - } - - return nil - } - - if c.credStore == nil { +func (c *proxyClient) EnsureAuthenticated(ctx context.Context) error { + if c.tokenSource == nil { // No auth required (e.g., local dev mode). return nil } - _, err := c.loadAccessToken() - if err != nil { + if _, err := c.tokenSource.Token(ctx); err != nil { return fmt.Errorf( "not authenticated to proxy. Run `panda auth login` first: %w", err, @@ -561,72 +571,6 @@ func (c *proxyClient) EnsureAuthenticated(_ context.Context) error { return nil } -func (c *proxyClient) loadAccessToken() (string, error) { - if c.usesClientCredentials() { - return c.clientCredentialsToken() - } - - if c.credStore == nil { - return "", nil - } - - tokens, err := c.credStore.Load() - if err != nil { - return "", fmt.Errorf("loading stored credentials: %w", err) - } - - if tokens == nil { - return "", ErrAuthenticationRequired - } - - token, err := c.credStore.GetAccessToken() - if err != nil { - return "", fmt.Errorf("%w: %v", ErrAuthenticationRequired, err) - } - - return token, nil -} - -// clientCredentialsToken returns the cached client_credentials access token, -// re-minting via the issuer's token endpoint when missing or close to expiry. -// Tokens live in memory only; nothing is written to disk. -func (c *proxyClient) clientCredentialsToken() (string, error) { - c.ccMu.Lock() - defer c.ccMu.Unlock() - - if c.ccTokens != nil && time.Now().Add(clientCredentialsRefreshBuffer).Before(c.ccTokens.ExpiresAt) { - return c.ccTokens.AccessToken, nil - } - - ctx, cancel := context.WithTimeout(context.Background(), c.cfg.HTTPTimeout) - defer cancel() - - tokens, err := c.authClient.ClientCredentials(ctx) - if err != nil { - // Keep serving a still-valid cached token across transient mint failures. - if c.ccTokens != nil && time.Now().Before(c.ccTokens.ExpiresAt) { - c.log.WithError(err).Warn("Re-minting client_credentials token failed; using cached token") - return c.ccTokens.AccessToken, nil - } - - return "", fmt.Errorf("minting client_credentials token: %w", err) - } - - c.ccTokens = tokens - - return tokens.AccessToken, nil -} - -// invalidateClientCredentialsToken drops the cached client_credentials token -// so the next loadAccessToken mints a fresh one. Used when the proxy rejects -// a token that has not yet hit the local expiry buffer (e.g. revocation). -func (c *proxyClient) invalidateClientCredentialsToken() { - c.ccMu.Lock() - defer c.ccMu.Unlock() - - c.ccTokens = nil -} - // backgroundRefresh periodically refreshes datasource information. func (c *proxyClient) backgroundRefresh() { ticker := time.NewTicker(c.cfg.DiscoveryInterval) diff --git a/pkg/proxy/client_credentials_test.go b/pkg/proxy/client_credentials_test.go index 1e189cbe..2386350d 100644 --- a/pkg/proxy/client_credentials_test.go +++ b/pkg/proxy/client_credentials_test.go @@ -9,7 +9,6 @@ import ( "strings" "sync/atomic" "testing" - "time" "github.com/sirupsen/logrus" ) @@ -105,10 +104,6 @@ func TestClientCredentialsTokenIsCachedInMemory(t *testing.T) { issuer := newFakeIssuer(t, 3600) client := newClientCredentialsClient(issuer.server.URL, "http://unused.test") - if client.credStore != nil { - t.Fatal("client_credentials mode must not create an on-disk credential store") - } - first := client.RegisterToken() if first != "svc-token-1" { t.Fatalf("RegisterToken() = %q, want svc-token-1", first) @@ -198,26 +193,50 @@ func TestClientCredentialsDiscoverFailsWhenRetryRejected(t *testing.T) { } } -func TestClientCredentialsServesCachedTokenAcrossMintOutage(t *testing.T) { +// Serving a cached client_credentials token across a mint outage while the +// token is inside the refresh buffer is covered by the token package's +// clientCredentialsSource tests, where the in-memory cache is directly +// controllable. + +func TestClickHouseQueryRetriesOn401(t *testing.T) { t.Parallel() issuer := newFakeIssuer(t, 3600) - client := newClientCredentialsClient(issuer.server.URL, "http://unused.test") - first := client.RegisterToken() - if first == "" { - t.Fatal("expected an initial token") + var calls atomic.Int64 + + proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/clickhouse/" { + http.NotFound(w, r) + return + } + + // Reject the first token (revoked server-side), accept the re-mint. + if calls.Add(1) == 1 { + http.Error(w, "revoked", http.StatusUnauthorized) + return + } + + _, _ = w.Write([]byte("query-ok")) + })) + t.Cleanup(proxy.Close) + + client := newClientCredentialsClient(issuer.server.URL, proxy.URL) + + body, err := client.ClickHouseQuery(context.Background(), "xatu", "SELECT 1", nil) + if err != nil { + t.Fatalf("ClickHouseQuery error after retry: %v", err) } - // Simulate the issuer going away while the cached token is still valid - // but inside the refresh buffer. - issuer.server.Close() - client.ccMu.Lock() - client.ccTokens.ExpiresAt = time.Now().Add(time.Minute) - client.ccMu.Unlock() + if string(body) != "query-ok" { + t.Fatalf("body = %q, want query-ok", body) + } - got := client.RegisterToken() - if got != first { - t.Fatalf("expected cached token %q across mint outage, got %q", first, got) + if got := calls.Load(); got != 2 { + t.Fatalf("expected 2 clickhouse attempts (initial + retry), got %d", got) + } + + if got := issuer.mints.Load(); got != 2 { + t.Fatalf("expected 2 mints (initial + re-mint on 401), got %d", got) } } diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 954a9ee6..4877a870 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -30,6 +30,10 @@ type Service interface { // requests, or NoAuthToken when no bearer token is required. RegisterToken() string + // Invalidate drops the cached access token so the next RegisterToken + // obtains a fresh one. Callers use it on a 401/403 before retrying. + Invalidate() + // RevokeToken is a no-op for client-managed bearer tokens. RevokeToken() diff --git a/pkg/proxy/router.go b/pkg/proxy/router.go index 362eb6ae..d845da95 100644 --- a/pkg/proxy/router.go +++ b/pkg/proxy/router.go @@ -171,6 +171,16 @@ func (r *routerClient) RegisterToken() string { return primary.RegisterToken() } +// Invalidate drops the primary proxy's cached token for primary-only requests. +func (r *routerClient) Invalidate() { + primary := r.Primary() + if primary == nil { + return + } + + primary.Invalidate() +} + // RevokeToken revokes a primary-proxy token for primary-only proxy requests. func (r *routerClient) RevokeToken() { primary := r.Primary() diff --git a/pkg/proxy/router_test.go b/pkg/proxy/router_test.go index 0cacfa03..66d468ca 100644 --- a/pkg/proxy/router_test.go +++ b/pkg/proxy/router_test.go @@ -296,6 +296,7 @@ func (f *fakeRouterClient) URL() string { return f.url } func (f *fakeRouterClient) RegisterToken() string { return f.token } +func (f *fakeRouterClient) Invalidate() {} func (f *fakeRouterClient) RevokeToken() {} func (f *fakeRouterClient) ClickHouseDatasources() []string { diff --git a/pkg/proxy/server.go b/pkg/proxy/server.go index 1b43c83c..a821c9a3 100644 --- a/pkg/proxy/server.go +++ b/pkg/proxy/server.go @@ -679,6 +679,10 @@ func (s *server) RegisterToken() string { return NoAuthToken } +// Invalidate is a no-op: the embedded proxy issues no bearer tokens. +func (s *server) Invalidate() { +} + func (s *server) RevokeToken() { } diff --git a/pkg/searchruntime/runtime.go b/pkg/searchruntime/runtime.go index 1bcdfbbe..6a491f84 100644 --- a/pkg/searchruntime/runtime.go +++ b/pkg/searchruntime/runtime.go @@ -97,6 +97,7 @@ func Build( log, proxyService.URL(), func() string { return proxyService.RegisterToken() }, + proxyService.Invalidate, localCache, model, ) diff --git a/pkg/server/api.go b/pkg/server/api.go index 5980160b..8d412c77 100644 --- a/pkg/server/api.go +++ b/pkg/server/api.go @@ -767,41 +767,74 @@ func (s *service) proxyRequestWithService( } targetURL := baseURL + requestPath - req, err := http.NewRequestWithContext(ctx, method, targetURL, body) - if err != nil { - return nil, http.StatusInternalServerError, nil, fmt.Errorf("creating proxy request: %w", err) - } - for key, values := range headers { - for _, value := range values { - req.Header.Add(key, value) + // Buffer the body so the request can be replayed on an auth retry. + var bodyBytes []byte + if body != nil { + buffered, err := io.ReadAll(body) + if err != nil { + return nil, http.StatusInternalServerError, nil, fmt.Errorf("reading proxy request body: %w", err) } - } - req.Header.Del("Authorization") - if v := attribution.FromContext(ctx); v != "" { - req.Header.Set(attribution.Header, v) + bodyBytes = buffered } - token := proxySvc.RegisterToken() defer proxySvc.RevokeToken() - if token != "" && token != proxy.NoAuthToken { - req.Header.Set("Authorization", "Bearer "+token) + send := func() ([]byte, int, http.Header, error) { + var reqBody io.Reader + if bodyBytes != nil { + reqBody = bytes.NewReader(bodyBytes) + } + + req, err := http.NewRequestWithContext(ctx, method, targetURL, reqBody) + if err != nil { + return nil, http.StatusInternalServerError, nil, fmt.Errorf("creating proxy request: %w", err) + } + + for key, values := range headers { + for _, value := range values { + req.Header.Add(key, value) + } + } + req.Header.Del("Authorization") + + if v := attribution.FromContext(ctx); v != "" { + req.Header.Set(attribution.Header, v) + } + + if token := proxySvc.RegisterToken(); token != "" && token != proxy.NoAuthToken { + req.Header.Set("Authorization", "Bearer "+token) + } + + resp, err := s.httpClient.Do(req) + if err != nil { + return nil, http.StatusBadGateway, nil, err + } + defer func() { _ = resp.Body.Close() }() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, resp.StatusCode, resp.Header.Clone(), fmt.Errorf("reading proxy response: %w", err) + } + + return data, resp.StatusCode, resp.Header.Clone(), nil } - resp, err := s.httpClient.Do(req) + data, status, header, err := send() if err != nil { - return nil, http.StatusBadGateway, nil, err + return data, status, header, err } - defer func() { _ = resp.Body.Close() }() - data, err := io.ReadAll(resp.Body) - if err != nil { - return nil, resp.StatusCode, resp.Header.Clone(), fmt.Errorf("reading proxy response: %w", err) + // One invalidate-and-retry on auth rejection covers a token revoked + // server-side before the local refresh buffer kicks in. + if status == http.StatusUnauthorized || status == http.StatusForbidden { + proxySvc.Invalidate() + + return send() } - return data, resp.StatusCode, resp.Header.Clone(), nil + return data, status, header, nil } func (s *service) primaryProxyService() (proxy.Service, error) { diff --git a/pkg/server/operations_ethnode_test.go b/pkg/server/operations_ethnode_test.go index 36d8ef14..348d32dd 100644 --- a/pkg/server/operations_ethnode_test.go +++ b/pkg/server/operations_ethnode_test.go @@ -112,6 +112,7 @@ func (p *ethNodeOperationProxy) Start(_ context.Context) error { return nil } func (p *ethNodeOperationProxy) Stop(_ context.Context) error { return nil } func (p *ethNodeOperationProxy) URL() string { return p.url } func (p *ethNodeOperationProxy) RegisterToken() string { return proxy.NoAuthToken } +func (p *ethNodeOperationProxy) Invalidate() {} func (p *ethNodeOperationProxy) RevokeToken() {} func (p *ethNodeOperationProxy) ClickHouseDatasources() []string { return nil diff --git a/pkg/server/proxy_routing_test.go b/pkg/server/proxy_routing_test.go index 8e019298..f8b79642 100644 --- a/pkg/server/proxy_routing_test.go +++ b/pkg/server/proxy_routing_test.go @@ -250,6 +250,7 @@ func (c *routingProxyClient) Start(_ context.Context) error { return nil } func (c *routingProxyClient) Stop(_ context.Context) error { return nil } func (c *routingProxyClient) URL() string { return c.url } func (c *routingProxyClient) RegisterToken() string { return c.token } +func (c *routingProxyClient) Invalidate() {} func (c *routingProxyClient) RevokeToken() {} func (c *routingProxyClient) ClickHouseDatasources() []string { return datasourceNames(c.ClickHouseDatasourceInfo()) From da8474785391432e1896d2dff4791b51d80aba6c Mon Sep 17 00:00:00 2001 From: Sam Calder-Mason Date: Wed, 17 Jun 2026 11:00:08 +1000 Subject: [PATCH 2/2] Reload credentials from disk on each interactive token fetch Restores pre-refactor behaviour: the interactive refresh Source now reloads the credential file on every Token call, so a login or logout performed by the host CLI on the bind-mounted credentials file is observed by a running server without a restart. Returns ErrNotAuthenticated when the file is absent. --- pkg/auth/token/token.go | 19 ++++++++++++++++++- pkg/auth/token/token_test.go | 33 ++++++++++++++++++++++++++++++--- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/pkg/auth/token/token.go b/pkg/auth/token/token.go index fbe2eafc..46e95719 100644 --- a/pkg/auth/token/token.go +++ b/pkg/auth/token/token.go @@ -7,6 +7,7 @@ package token import ( "context" + "errors" "fmt" "strings" "sync" @@ -22,6 +23,10 @@ import ( // other mode ("", "oauth", "oidc") uses the interactive refresh-token grant. const ModeClientCredentials = "client_credentials" +// ErrNotAuthenticated is returned by an interactive Source when no credentials +// are stored (e.g. the user has not run `panda auth login`, or logged out). +var ErrNotAuthenticated = errors.New("not authenticated") + // clientCredentialsBuffer is how long before expiry a cached client_credentials // access token is re-minted, in addition to the proactive fractional refresh. const clientCredentialsBuffer = 5 * time.Minute @@ -131,8 +136,20 @@ func NewClientCredentialsSource(log logrus.FieldLogger, c client.Client, mintTim } } -// Token returns a valid access token from the credential store. +// Token returns a valid access token from the credential store. It reloads the +// credentials from disk on each call so a login or logout performed outside the +// running process (the host CLI writes the bind-mounted credentials file) is +// observed without a server restart. func (s *refreshSource) Token(_ context.Context) (string, error) { + tokens, err := s.store.Load() + if err != nil { + return "", err + } + + if tokens == nil { + return "", ErrNotAuthenticated + } + return s.store.GetAccessToken() } diff --git a/pkg/auth/token/token_test.go b/pkg/auth/token/token_test.go index 1ad16542..02f05b41 100644 --- a/pkg/auth/token/token_test.go +++ b/pkg/auth/token/token_test.go @@ -166,9 +166,17 @@ type fakeStore struct { invalidated bool } -func (f *fakeStore) Path() string { return "" } -func (f *fakeStore) Save(*client.Tokens) error { return nil } -func (f *fakeStore) Load() (*client.Tokens, error) { return nil, nil } +func (f *fakeStore) Path() string { return "" } +func (f *fakeStore) Save(*client.Tokens) error { return nil } + +func (f *fakeStore) Load() (*client.Tokens, error) { + if f.token == "" { + return nil, nil + } + + return &client.Tokens{AccessToken: f.token}, nil +} + func (f *fakeStore) Clear() error { return nil } func (f *fakeStore) GetAccessToken() (string, error) { return f.token, nil } func (f *fakeStore) Invalidate() { f.invalidated = true } @@ -195,6 +203,25 @@ func TestRefreshSourceDelegatesToStore(t *testing.T) { } } +func TestRefreshSourceDetectsLogoutOnReload(t *testing.T) { + t.Parallel() + + fs := &fakeStore{token: "access-123"} + src := NewRefreshSource(fs) + + if _, err := src.Token(context.Background()); err != nil { + t.Fatalf("Token error while authenticated: %v", err) + } + + // Simulate a logout (credentials file cleared) under a running server. + fs.token = "" + + _, err := src.Token(context.Background()) + if !errors.Is(err, ErrNotAuthenticated) { + t.Fatalf("expected ErrNotAuthenticated after logout, got %v", err) + } +} + func TestNewSourceNilWhenUnconfigured(t *testing.T) { t.Parallel()