Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions modules/clickhouse/schema_routing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ func (c *schemaProxyClient) Start(_ context.Context) error { return nil }
func (c *schemaProxyClient) Stop(_ context.Context) error { return nil }
func (c *schemaProxyClient) URL() string { return c.url }
func (c *schemaProxyClient) RegisterToken() string { return c.token }
func (c *schemaProxyClient) Invalidate() {}
func (c *schemaProxyClient) RevokeToken() {}
func (c *schemaProxyClient) ClickHouseDatasources() []string {
return schemaDatasourceNames(c.ClickHouseDatasourceInfo())
Expand Down
1 change: 1 addition & 0 deletions pkg/app/refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ func (f *fakeProxyClient) Start(_ context.Context) error { return
func (f *fakeProxyClient) Stop(_ context.Context) error { return nil }
func (f *fakeProxyClient) URL() string { return "" }
func (f *fakeProxyClient) RegisterToken() string { return "" }
func (f *fakeProxyClient) Invalidate() {}
func (f *fakeProxyClient) RevokeToken() {}
func (f *fakeProxyClient) Discover(_ context.Context) error { return nil }
func (f *fakeProxyClient) EnsureAuthenticated(_ context.Context) error { return nil }
Expand Down
37 changes: 37 additions & 0 deletions pkg/auth/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,43 @@ type Tokens struct {
RefreshTokenIssuedAt time.Time `json:"refresh_token_issued_at,omitempty"`
}

// DefaultRefreshFraction is the elapsed-lifetime fraction at which an access
// token is proactively refreshed: at 0.5 it refreshes once half its lifetime
// has passed (e.g. ~30 min into a 1h token), leaving a wide margin so a request
// never races a just-expired token.
const DefaultRefreshFraction = 0.5

// ShouldRefresh reports whether a token expiring at expiresAt (minted with an
// original lifetime of expiresIn seconds) should be proactively refreshed at
// the moment now.
//
// It returns true once the token is within buffer of expiry, or — when the
// original lifetime is known — once it has passed refreshFraction of that
// lifetime. refreshFraction is the elapsed-life fraction at which to refresh
// (e.g. 0.5 refreshes at the halfway point, well before expiry, so a request
// never races a just-expired token). A zero expiresAt (unknown expiry) only
// triggers on the buffer check.
func ShouldRefresh(now, expiresAt time.Time, expiresIn int, buffer time.Duration, refreshFraction float64) bool {
if expiresAt.IsZero() {
return false
}

if now.Add(buffer).After(expiresAt) {
return true
}

if expiresIn > 0 && refreshFraction > 0 && refreshFraction < 1 {
lifetime := time.Duration(expiresIn) * time.Second
refreshAt := expiresAt.Add(-time.Duration(float64(lifetime) * (1 - refreshFraction)))

if now.After(refreshAt) {
return true
}
}

return false
}

// Config configures the OAuth client.
type Config struct {
// IssuerURL is the OIDC issuer URL (e.g., https://dex.example.com).
Expand Down
82 changes: 82 additions & 0 deletions pkg/auth/client/refreshpolicy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package client

import (
"testing"
"time"
)

func TestShouldRefresh(t *testing.T) {
t.Parallel()

now := time.Now()
const buffer = 5 * time.Minute

tests := []struct {
name string
expiresAt time.Time
expiresIn int
fraction float64
want bool
}{
{
name: "fresh token below the fraction is not refreshed",
expiresAt: now.Add(40 * time.Minute), // 1h token, 33% elapsed
expiresIn: 3600,
fraction: 0.5,
want: false,
},
{
name: "token past 50% of its lifetime is refreshed",
expiresAt: now.Add(20 * time.Minute), // 1h token, 66% elapsed
expiresIn: 3600,
fraction: 0.5,
want: true,
},
{
name: "token exactly at the fraction is not yet refreshed",
expiresAt: now.Add(30 * time.Minute), // 1h token, exactly 50%
expiresIn: 3600,
fraction: 0.5,
want: false,
},
{
name: "within the expiry buffer is always refreshed",
expiresAt: now.Add(2 * time.Minute),
expiresIn: 3600,
fraction: 0.5,
want: true,
},
{
name: "a higher fraction refreshes later",
expiresAt: now.Add(20 * time.Minute), // 66% elapsed
expiresIn: 3600,
fraction: 0.75, // refresh only past 75%
want: false,
},
{
name: "unknown lifetime falls back to the buffer only",
expiresAt: now.Add(20 * time.Minute),
expiresIn: 0,
fraction: 0.5,
want: false,
},
{
name: "zero expiry never triggers",
expiresAt: time.Time{},
expiresIn: 3600,
fraction: 0.5,
want: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

got := ShouldRefresh(now, tt.expiresAt, tt.expiresIn, buffer, tt.fraction)
if got != tt.want {
t.Fatalf("ShouldRefresh = %v, want %v", got, tt.want)
}
})
}
}
54 changes: 47 additions & 7 deletions pkg/auth/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ type Store interface {
// GetAccessToken returns a valid access token, refreshing if needed.
GetAccessToken() (string, error)

// Invalidate forces the next GetAccessToken to refresh the access token,
// even if it has not yet hit the local expiry buffer. Used when the proxy
// rejects a token that should still be valid locally (e.g. revocation).
Invalidate()

// IsAuthenticated returns true if valid tokens are stored.
IsAuthenticated() bool
}
Expand Down Expand Up @@ -68,11 +73,12 @@ type Config struct {

// store implements the Store interface.
type store struct {
log logrus.FieldLogger
cfg Config
mu sync.RWMutex
tokens *client.Tokens
refreshMu sync.Mutex
log logrus.FieldLogger
cfg Config
mu sync.RWMutex
tokens *client.Tokens
forceRefresh bool
refreshMu sync.Mutex
}

// New creates a new credential store.
Expand Down Expand Up @@ -197,6 +203,30 @@ func (s *store) GetAccessToken() (string, error) {
return tokens.AccessToken, nil
}

// Invalidate forces the next GetAccessToken to refresh.
func (s *store) Invalidate() {
s.mu.Lock()
defer s.mu.Unlock()

s.forceRefresh = true
}

// forceRefreshRequested reports whether Invalidate has armed a forced refresh.
func (s *store) forceRefreshRequested() bool {
s.mu.RLock()
defer s.mu.RUnlock()

return s.forceRefresh
}

// clearForceRefresh clears the forced-refresh flag after a successful refresh.
func (s *store) clearForceRefresh() {
s.mu.Lock()
defer s.mu.Unlock()

s.forceRefresh = false
}

// IsAuthenticated returns true if valid tokens are stored.
func (s *store) IsAuthenticated() bool {
tokens, err := s.getTokens()
Expand Down Expand Up @@ -234,8 +264,16 @@ func (s *store) needsRefresh(tokens *client.Tokens) bool {
return true
}

// Refresh if the access token is within the buffer of expiry.
if time.Now().Add(s.cfg.RefreshBuffer).After(tokens.ExpiresAt) {
if s.forceRefreshRequested() {
return true
}

// Refresh proactively once the access token passes the configured fraction
// of its lifetime, or once it is within the buffer of expiry.
if client.ShouldRefresh(
time.Now(), tokens.ExpiresAt, tokens.ExpiresIn,
s.cfg.RefreshBuffer, client.DefaultRefreshFraction,
) {
return true
}

Expand Down Expand Up @@ -291,6 +329,8 @@ func (s *store) refresh(prior *client.Tokens) (*client.Tokens, error) {
return nil, err
}

s.clearForceRefresh()

return newTokens, nil
}

Expand Down
107 changes: 107 additions & 0 deletions pkg/auth/store/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package store
import (
"context"
"errors"
"path/filepath"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -165,6 +166,112 @@ func TestGetAccessTokenSerializesConcurrentRefreshes(t *testing.T) {
}
}

func TestGetAccessTokenRefreshesAtAccessTokenHalfLife(t *testing.T) {
t.Parallel()

client := &stubAuthClient{}
store := New(logrus.New(), Config{
Path: filepath.Join(t.TempDir(), "creds.json"),
AuthClient: client,
RefreshBuffer: 5 * time.Minute,
}).(*store)
// 1h token with 20m left = 66% elapsed: past the 50% refresh point, but
// well outside the 5m expiry buffer.
store.tokens = &authclient.Tokens{
AccessToken: "aging",
RefreshToken: "refresh-token",
ExpiresIn: 3600,
ExpiresAt: time.Now().Add(20 * time.Minute),
}

token, err := store.GetAccessToken()
if err != nil {
t.Fatalf("GetAccessToken returned error: %v", err)
}

if token != "refreshed-token" {
t.Fatalf("expected proactive refresh past 50%% of lifetime, got %q", token)
}

if client.refreshCalls != 1 {
t.Fatalf("expected 1 refresh call, got %d", client.refreshCalls)
}
}

func TestGetAccessTokenDoesNotRefreshBeforeAccessTokenHalfLife(t *testing.T) {
t.Parallel()

client := &stubAuthClient{}
store := New(logrus.New(), Config{
Path: filepath.Join(t.TempDir(), "creds.json"),
AuthClient: client,
RefreshBuffer: 5 * time.Minute,
}).(*store)
// 1h token with 40m left = 33% elapsed: before the 50% refresh point.
store.tokens = &authclient.Tokens{
AccessToken: "fresh",
RefreshToken: "refresh-token",
ExpiresIn: 3600,
ExpiresAt: time.Now().Add(40 * time.Minute),
}

token, err := store.GetAccessToken()
if err != nil {
t.Fatalf("GetAccessToken returned error: %v", err)
}

if token != "fresh" {
t.Fatalf("expected original token before 50%% of lifetime, got %q", token)
}

if client.refreshCalls != 0 {
t.Fatalf("expected no refresh calls, got %d", client.refreshCalls)
}
}

func TestInvalidateForcesRefreshThenClears(t *testing.T) {
t.Parallel()

client := &stubAuthClient{}
store := New(logrus.New(), Config{
Path: filepath.Join(t.TempDir(), "creds.json"),
AuthClient: client,
RefreshBuffer: 5 * time.Minute,
}).(*store)
// A fresh token that would not otherwise refresh (0% elapsed).
store.tokens = &authclient.Tokens{
AccessToken: "fresh",
RefreshToken: "refresh-token",
ExpiresIn: 3600,
ExpiresAt: time.Now().Add(time.Hour),
}

store.Invalidate()

token, err := store.GetAccessToken()
if err != nil {
t.Fatalf("GetAccessToken returned error: %v", err)
}

if token != "refreshed-token" {
t.Fatalf("expected forced refresh after Invalidate, got %q", token)
}

if client.refreshCalls != 1 {
t.Fatalf("expected 1 refresh call after Invalidate, got %d", client.refreshCalls)
}

// The forced-refresh flag must clear: the now-fresh token should not
// refresh again on the next call.
if _, err := store.GetAccessToken(); err != nil {
t.Fatalf("second GetAccessToken returned error: %v", err)
}

if client.refreshCalls != 1 {
t.Fatalf("expected no extra refresh after the flag cleared, got %d", client.refreshCalls)
}
}

type stubAuthClient struct {
refreshCalls int
refreshErr error
Expand Down
Loading
Loading