From 79ec925c8e475173e53386ac493d67dcb258cd69 Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Wed, 6 May 2026 16:20:06 +0000 Subject: [PATCH] gocached: support namespaces Remove the special-case concept of "global writes" and instead allow callers to provide a namespace mapping function that makes policy decisions about which clients are globally trusted and which clients should be isolated from each other. As a result, all clients are now able to write, but not necessarily into the shared global namespace. All clients can still read from the global namespace, as well as their own. The Namespaces table already existed in the schema, so no schema updates required. However, there are some breaking changes in the package API, WithJWTAuth now takes issuer URLs only and policy moves into the new WithNamespaceMapping option. cmd/gocached implements the spirit of the old API in terms of a namespace mapping function, with the main difference that it now allows writes if you don't have the global claims, but just into your own isolated namespace. Updates tailscale/corp#38092 Signed-off-by: Tom Proctor --- cmd/gocached/gocached.go | 31 ++-- gocached/gocached.go | 287 ++++++++++++++++++++++--------------- gocached/gocached_test.go | 289 +++++++++++++++++++++----------------- 3 files changed, 354 insertions(+), 253 deletions(-) diff --git a/cmd/gocached/gocached.go b/cmd/gocached/gocached.go index e444a07..7cd34e7 100644 --- a/cmd/gocached/gocached.go +++ b/cmd/gocached/gocached.go @@ -9,7 +9,6 @@ import ( "flag" "fmt" "log" - "maps" "net" "net/http" "os" @@ -74,15 +73,27 @@ func main() { log.Fatal("must specify --jwt-claim at least once when --jwt-issuer is set") } - globalClaims := map[string]string{} - maps.Copy(globalClaims, jwtClaims) - maps.Copy(globalClaims, globalJWTClaims) - - opts = append(opts, gocached.WithJWTAuth(gocached.JWTIssuerConfig{ - Issuer: *jwtIssuer, - RequiredClaims: jwtClaims, - GlobalWriteClaims: globalClaims, - })) + opts = append(opts, + gocached.WithJWTAuth(*jwtIssuer), + gocached.WithNamespaceMapping(func(claims map[string]any) (gocached.Namespace, error) { + var ns gocached.Namespace + for k, want := range jwtClaims { + if got := claims[k]; got != want { + return "", fmt.Errorf("claim %q = %v, want %v", k, got, want) + } + if ns != "" { + ns += "," + } + ns += gocached.Namespace(fmt.Sprintf("%s=%s", k, want)) + } + for k, want := range globalJWTClaims { + if got := claims[k]; got != want { + return ns, nil + } + } + return gocached.GlobalNamespace, nil + }), + ) } srv, err := gocached.NewServer(opts...) diff --git a/gocached/gocached.go b/gocached/gocached.go index 9f97931..961ca6a 100644 --- a/gocached/gocached.go +++ b/gocached/gocached.go @@ -250,6 +250,24 @@ func (srv *Server) start() error { } srv.db = db + // Fill the namespace ID cache. + rows, err := srv.db.Query("SELECT NamespaceID, Namespace FROM Namespaces") + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var id int64 + var ns string + if err := rows.Scan(&id, &ns); err != nil { + return err + } + srv.namespaces[Namespace(ns)] = id + } + if err := rows.Err(); err != nil { + return err + } + reg := prometheus.NewRegistry() reg.MustRegister( collectors.NewGoCollector(), @@ -276,17 +294,13 @@ func (srv *Server) start() error { } if len(srv.jwtIssuers) > 0 { - issuerURLs := make([]string, 0, len(srv.jwtIssuers)) - for iss := range srv.jwtIssuers { - issuerURLs = append(issuerURLs, iss) - } - srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, gocachedAudience, issuerURLs) + srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, gocachedAudience, srv.jwtIssuers) if err := srv.jwtValidator.RunUpdateJWKSLoop(srv.shutdownCtx); err != nil { return fmt.Errorf("failed to fetch JWKS for JWT validator: %w", err) } - for iss, entry := range srv.jwtIssuers { - srv.logf("gocached: using JWT issuer %q with required claims %v, global write claims %v", iss, entry.requiredClaims, entry.globalWriteClaims) + for _, iss := range srv.jwtIssuers { + srv.logf("gocached: using JWT issuer %q", iss) } go srv.runCleanSessionsLoop() @@ -383,40 +397,42 @@ func WithMaxAge(maxAge time.Duration) ServerOption { } } -// JWTIssuerConfig configures a single OIDC issuer for JWT-based authentication. -type JWTIssuerConfig struct { - // Issuer is the OIDC issuer URL. It must be a reachable HTTP(S) server - // that serves its JWKS via a URL discoverable at - // /.well-known/openid-configuration. - Issuer string - - // RequiredClaims are claims that any JWT from this issuer must have to - // start a session. All key-value pairs must match exactly. - RequiredClaims map[string]string - - // GlobalWriteClaims are claims that a JWT from this issuer must have to - // write to the cache's global namespace. It should be a superset of - // RequiredClaims. - GlobalWriteClaims map[string]string +// Namespace identifies a logical partition of the cache where each peer is +// equally trusted. Every session is associated with exactly one Namespace to +// which it can read and write; sessions for non-global namespaces also read +// from [GlobalNamespace]. See [WithNamespaceMapping]. Namespace is case- +// insensitive, but is otherwise treated as an opaque string. +type Namespace string + +// GlobalNamespace is a trusted namespace that all sessions can read from. +const GlobalNamespace Namespace = "" + +// WithJWTAuth enables JWT-based authentication for the server. Each issuer +// must be a reachable HTTP(S) server that serves its JWKS via a URL +// discoverable at /.well-known/openid-configuration. JWTs presented for token +// exchange must pass the standard signature/issuer/audience/expiry checks +// against one of these issuers. If [WithNamespaceMapping] is provided, then +// it may still be rejected if the mapping function returns an error for its +// claims. No requests other than token exchange are allowed without +// authentication. May be called multiple times; issuers accumulate. +func WithJWTAuth(issuers ...string) ServerOption { + return func(srv *Server) { + srv.jwtIssuers = append(srv.jwtIssuers, issuers...) + } } -// WithJWTAuth enables JWT-based authentication for the server. Each issuer must -// be a reachable HTTP(S) server that serves its JWKS via a URL discoverable at -// /.well-known/openid-configuration, and any JWT presented to the server must -// exactly match the issuer's required claims to start a session. No requests are -// allowed without authentication if JWT auth is enabled. It can be called multiple -// times; configs accumulate. -func WithJWTAuth(issuers ...JWTIssuerConfig) ServerOption { +// WithNamespaceMapping sets the function that makes policy decisions based on +// a JWT's claims. It is called once per token exchange after the JWT's +// signature and standard claims have been validated. It should return an error +// if the claims are not authorized, and otherwise return which [Namespace] it +// is allowed to read and write in. All sessions are allowed to read from the +// [GlobalNamespace] regardless of the namespace returned. Check claims["iss"] +// to switch on per-issuer rules. If JWT auth is enabled, but no mapping +// function is provided, all sessions will read and write in the +// [GlobalNamespace]. +func WithNamespaceMapping(fn func(claims map[string]any) (Namespace, error)) ServerOption { return func(srv *Server) { - if srv.jwtIssuers == nil { - srv.jwtIssuers = make(map[string]*jwtIssuerConfig) - } - for _, ic := range issuers { - srv.jwtIssuers[ic.Issuer] = &jwtIssuerConfig{ - requiredClaims: ic.RequiredClaims, - globalWriteClaims: ic.GlobalWriteClaims, - } - } + srv.namespaceMapping = fn } } @@ -428,12 +444,21 @@ func NewServer(opts ...ServerOption) (*Server, error) { shutdownCtx: context.Background(), logf: log.Printf, sessions: make(map[string]*sessionData), + namespaces: make(map[Namespace]int64), clock: time.Now, } for _, opt := range opts { opt(srv) } + if len(srv.jwtIssuers) > 0 && srv.namespaceMapping == nil { + // If JWT auth is enabled, but not namespace mapping, every session is in + // the global namespace. + srv.namespaceMapping = func(claims map[string]any) (Namespace, error) { + return GlobalNamespace, nil + } + } + err := srv.start() if err != nil { return nil, err @@ -474,13 +499,15 @@ type Server struct { shutdownCtx context.Context shutdownCancel context.CancelFunc - jwtValidator *ijwt.Validator // nil unless jwtIssuers is non-empty - jwtIssuers map[string]*jwtIssuerConfig // keyed by issuer URL + jwtValidator *ijwt.Validator // nil unless jwtIssuers is non-empty + jwtIssuers []string // accepted issuer URLs + namespaceMapping func(claims map[string]any) (Namespace, error) // required when jwtIssuers is non-empty mu sync.RWMutex // guards following fields in this block sessions map[string]*sessionData // maps access token -> session data. accessDirty map[actionKey]int64 // action -> accessTime accessFlushTimer *time.Timer // nil if no flush is scheduled + namespaces map[Namespace]int64 // cached namespace string -> NamespaceID // sqliteWriteMu serializes access to SQLite. In theory the SQLite driver // should serialize access with our 5000ms busy timeout, but empirically we @@ -520,18 +547,13 @@ type Server struct { } } -// jwtIssuerConfig holds per-issuer claim requirements for JWT auth. -type jwtIssuerConfig struct { - requiredClaims map[string]string - globalWriteClaims map[string]string -} - // sessionData corresponds to a specific access token, and is only used if JWT // auth is enabled. type sessionData struct { - expiry time.Time // Session valid until. - globalNSWrite bool // Whether this session can write to the cache's global namespace. - claims map[string]any // Claims from the JWT used to create this session, stored for debug. + expiry time.Time // Session valid until. + namespaceID int64 // The namespace this session writes to. 0 means GlobalNamespace; non-zero sessions also read from 0. + namespace Namespace // Canonical (lower-cased) Namespace this session writes to. Used as the namespaces cache key and shown on the debug page. + claims map[string]any // Claims from the JWT used to create this session, stored for debug. mu sync.Mutex // Guards stats. stats stats @@ -641,12 +663,11 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if r.Method == "PUT" { - if sessionData != nil && !sessionData.globalNSWrite { - // TODO(tomhjp): support per-namespace writes. - http.Error(w, "forbidden", http.StatusForbidden) - return + var writeNS int64 + if sessionData != nil { + writeNS = sessionData.namespaceID } - srv.handlePut(w, r, reqStats) + srv.handlePut(w, r, reqStats, writeNS) return } if r.Method != "GET" && r.Method != "HEAD" { @@ -654,7 +675,7 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if strings.HasPrefix(r.URL.Path, "/action/") { - srv.handleGetAction(w, r, reqStats) + srv.handleGetAction(w, r, reqStats, sessionData) return } if sessionData != nil && r.URL.Path == "/session/stats" { @@ -723,7 +744,7 @@ func getHexSuffix(r *http.Request, prefix string) (hexSuffix string, ok bool) { // actionKey is the comparable value type for the (NamespaceID, ActionID) // primary key tuple used in the SQLite Actions table. type actionKey struct { - NamespaceID int // 0 for global + NamespaceID int64 // 0 for global ActionID string } @@ -745,7 +766,27 @@ func validHex(x string) bool { // we do a DB write to update it. const relAtimeSeconds = 60 * 60 * 24 // 1 day -func (srv *Server) handleGetAction(w http.ResponseWriter, r *http.Request, stats *stats) { +const getFromGlobalNamespace = ` +SELECT b.SHA256, b.StoredSize, b.UncompressedSize, b.SmallData, a.AltOutputID, a.AccessTime, a.NamespaceID +FROM Actions a, Blobs b +WHERE a.NamespaceID = 0 + AND a.ActionID = ? + AND a.BlobID = b.BlobID +` + +// Get hits from the global namespace first so shared cache mtime is bumped +// with higher priority than namespaced cache. +const getFromSessionNamespace = ` +SELECT b.SHA256, b.StoredSize, b.UncompressedSize, b.SmallData, a.AltOutputID, a.AccessTime, a.NamespaceID +FROM Actions a, Blobs b +WHERE a.NamespaceID IN (0, ?) + AND a.ActionID = ? + AND a.BlobID = b.BlobID +ORDER BY CASE a.NamespaceID WHEN 0 THEN 0 ELSE 1 END +LIMIT 1 +` + +func (srv *Server) handleGetAction(w http.ResponseWriter, r *http.Request, stats *stats, sessionData *sessionData) { srv.m.ActiveGets.Add(1) defer srv.m.ActiveGets.Add(-1) @@ -771,19 +812,24 @@ func (srv *Server) handleGetAction(w http.ResponseWriter, r *http.Request, stats return } - var sha256hex string - var storedSize, uncompressedSize int64 - var smallData sql.NullString - var altObjectID string - var accessTime int64 - var actionKey = actionKey{ - NamespaceID: 0, // global for now; TODO(bradfitz): support namespac - ActionID: actionID, - } - err := srv.db.QueryRow( - "SELECT b.SHA256, b.StoredSize, b.UncompressedSize, b.SmallData, a.AltOutputID, a.AccessTime FROM Actions a, Blobs b WHERE a.NameSpaceID = ? AND a.ActionID = ? AND a.BlobID = b.BlobID", - actionKey.NamespaceID, actionKey.ActionID).Scan( - &sha256hex, &storedSize, &uncompressedSize, &smallData, &altObjectID, &accessTime) + var ( + sha256hex string + storedSize, uncompressedSize int64 + smallData sql.NullString + altObjectID string + accessTime int64 + err error + actionKey = actionKey{ + ActionID: actionID, + } + ) + if sessionData != nil && sessionData.namespaceID != 0 { + err = srv.db.QueryRow(getFromSessionNamespace, sessionData.namespaceID, actionID).Scan( + &sha256hex, &storedSize, &uncompressedSize, &smallData, &altObjectID, &accessTime, &actionKey.NamespaceID) + } else { + err = srv.db.QueryRow(getFromGlobalNamespace, actionID).Scan( + &sha256hex, &storedSize, &uncompressedSize, &smallData, &altObjectID, &accessTime, &actionKey.NamespaceID) + } if err != nil { if errors.Is(err, sql.ErrNoRows) { http.Error(w, "not found", http.StatusNotFound) @@ -997,7 +1043,7 @@ func (srv *Server) getObjectFromDiskOrPeer(_ context.Context, sha256hex string, return f, nil } -func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats) { +func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats, namespaceID int64) { s.m.ActivePuts.Add(1) defer s.m.ActivePuts.Add(-1) @@ -1074,13 +1120,12 @@ func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats) // Insert or update the action in the database. nowUnix := s.now().Unix() altObjectID := "" - namespace := 0 // global for now; TODO(bradfitz): support namespaces if sha256hex != outputID { altObjectID = outputID } res, err := s.db.Exec(`INSERT OR IGNORE INTO Actions (NamespaceID, ActionID, BlobID, AltOutputID, CreateTime, AccessTime) VALUES (?, ?, ?, ?, ?, ?)`, - namespace, + namespaceID, actionID, blobID, altObjectID, @@ -1139,23 +1184,36 @@ func (srv *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) { return } - globalNSWrite, err := srv.evaluateClaims(jwtClaims) + ns, err := srv.namespaceMapping(jwtClaims) if err != nil { srv.m.AuthErrs.Add(1) if srv.verbose { - srv.logf("token exchange: %v", err) + srv.logf("token exchange: namespace func error: %v", err) } http.Error(w, "unauthorized", http.StatusUnauthorized) return } + ns = Namespace(strings.ToLower(string(ns))) + var namespaceID int64 + if ns != GlobalNamespace { + namespaceID, err = srv.resolveNamespaceID(ns) + if err != nil { + srv.m.AuthErrs.Add(1) + srv.logf("token exchange: %v", err) + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + } + const ttl = time.Hour // 52 base32 characters, 256 bits of entropy. accessToken := tokenPrefix + strings.ToLower(rand.Text()+rand.Text()) srv.addSessionData(accessToken, &sessionData{ - expiry: srv.now().UTC().Add(ttl), - globalNSWrite: globalNSWrite, - claims: jwtClaims, + expiry: srv.now().UTC().Add(ttl), + namespaceID: namespaceID, + namespace: ns, + claims: jwtClaims, }) resp := map[string]any{ @@ -1173,38 +1231,38 @@ func (srv *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) { srv.m.Auths.Add(1) } -func (srv *Server) evaluateClaims(claims map[string]any) (globalNSWrite bool, _ error) { - iss, _ := claims["iss"].(string) - cfg, ok := srv.jwtIssuers[iss] - if !ok { - return false, fmt.Errorf("got claims %v; unknown issuer %q", claims, iss) - } - - if missing := findMissingClaims(cfg.requiredClaims, claims); len(missing) > 0 { - return false, fmt.Errorf("got claims %v; missing required claims: %v", claims, missing) +// resolveNamespaceID returns the integer ID for the given Namespace, +// inserting a row in the Namespaces table if one doesn't already exist. The +// name must already be lower-cased or it will violate the CHECK constraint. +func (srv *Server) resolveNamespaceID(ns Namespace) (int64, error) { + // If it's not a new namespace, we only need to consult our cache of IDs. + srv.mu.Lock() + id, ok := srv.namespaces[ns] + srv.mu.Unlock() + if ok { + return id, nil } - if missing := findMissingClaims(cfg.globalWriteClaims, claims); len(missing) == 0 { - return true, nil - } else if srv.verbose { - srv.logf("token exchange: missing global namespace write claims: %v", missing) - } + srv.sqliteWriteMu.Lock() + defer srv.sqliteWriteMu.Unlock() - return false, nil -} + srv.mu.Lock() + defer srv.mu.Unlock() -func findMissingClaims(wantClaims map[string]string, gotClaims map[string]any) map[string]any { - if wantClaims == nil { - return nil + // Check if we lost a race now that we have both locks. + if id, ok = srv.namespaces[ns]; ok { + return id, nil } - missing := make(map[string]any) - for k, want := range wantClaims { - if got, ok := gotClaims[k]; !ok || got != want { - missing[k] = want - } + err := srv.db.QueryRow(`INSERT INTO Namespaces (Namespace) VALUES (?) + RETURNING NamespaceID;`, ns).Scan(&id) + if err != nil { + return 0, fmt.Errorf("resolving namespace %q: %w", ns, err) } - return missing + + srv.namespaces[ns] = id + + return id, nil } func (srv *Server) handleSessionStats(w http.ResponseWriter, sessionData *sessionData) { @@ -1722,10 +1780,11 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) { for _, v := range srv.sessions { v.mu.Lock() sessions = append(sessions, &sessionData{ - expiry: v.expiry, - globalNSWrite: v.globalNSWrite, - claims: v.claims, - stats: v.stats, + expiry: v.expiry, + namespaceID: v.namespaceID, + namespace: v.namespace, + claims: v.claims, + stats: v.stats, }) v.mu.Unlock() } @@ -1733,15 +1792,13 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf-8") fmt.Fprintf(w, "

gocached sessions

\n") - for iss, cfg := range srv.jwtIssuers { + for _, iss := range srv.jwtIssuers { fmt.Fprintf(w, "

JWT issuer: %s

\n", iss) - fmt.Fprintf(w, "

JWT claims required: %v

\n", cfg.requiredClaims) - fmt.Fprintf(w, "

JWT global write claims required: %v

\n", cfg.globalWriteClaims) } fmt.Fprintf(w, "

Number of sessions: %d

\n", len(sessions)) fmt.Fprintf(w, "\n") - fmt.Fprintf(w, "\n") + fmt.Fprintf(w, "\n") slices.SortFunc(sessions, func(a, b *sessionData) int { return a.stats.LastUsed.Compare(b.stats.LastUsed) }) @@ -1750,10 +1807,14 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) { if !d.stats.LastUsed.IsZero() { lastUsed = durFmt(time.Since(d.stats.LastUsed)) + " ago" } + nsLabel := "(global)" + if d.namespaceID != 0 { + nsLabel = fmt.Sprintf("%q (id=%d)", d.namespace, d.namespaceID) + } statsJSON, _ := json.MarshalIndent(d.stats, "", " ") claimsJSON, _ := json.MarshalIndent(d.claims, "", " ") - fmt.Fprintf(w, "\n", - lastUsed, d.expiry.Format(time.RFC3339), d.globalNSWrite, statsJSON, claimsJSON) + fmt.Fprintf(w, "\n", + lastUsed, d.expiry.Format(time.RFC3339), nsLabel, statsJSON, claimsJSON) } fmt.Fprintf(w, "
Last usedExpiry timeGlobal writeStatsClaims
Last usedExpiry timeNamespaceStatsClaims
%s%s%v
%s
%s
%s%s%s
%s
%s
\n") } diff --git a/gocached/gocached_test.go b/gocached/gocached_test.go index a46d92e..a795f1a 100644 --- a/gocached/gocached_test.go +++ b/gocached/gocached_test.go @@ -38,6 +38,23 @@ import ( // value in SQLite to store bytes, as it's common. const sha256OfEmpty = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" +// Package-level ECDSA P-256 keys reused across tests that need OIDC signing +// keys. Generating these is the single most expensive step in the JWT tests, +// so we share them rather than regenerating per test. +var ( + testKey1 = mustGenerateTestKey() + testKey2 = mustGenerateTestKey() + testKey3 = mustGenerateTestKey() +) + +func mustGenerateTestKey() *ecdsa.PrivateKey { + k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(fmt.Sprintf("generating test ECDSA key: %v", err)) + } + return k +} + type tester struct { t testing.TB srv *Server @@ -641,41 +658,17 @@ func TestClientConnReuse(t *testing.T) { } func TestExchangeToken(t *testing.T) { - // Generate private keys outside of the loop for speed. - privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating OIDC server private key: %v", err) - } - otherPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating OIDC server private key: %v", err) - } - wantClaims := map[string]string{ - "sub": "user123", - } - wantGlobalClaims := map[string]string{ - "sub": "user123", - "ref": "refs/heads/main", - } + privateKey := testKey1 + otherPrivateKey := testKey2 for name, tc := range map[string]struct { mutateClaims func(jwt.MapClaims) signingKey *ecdsa.PrivateKey wantStatusCode int - wantWrite bool }{ // Base case: no mutation. - "valid_read": { - wantStatusCode: http.StatusOK, - wantWrite: false, - }, - // Additional claim needed for write scope. - "valid_write": { - mutateClaims: func(cl jwt.MapClaims) { - cl["ref"] = "refs/heads/main" - }, + "valid": { wantStatusCode: http.StatusOK, - wantWrite: true, }, // Every other test makes one mutation from the base case that should cause failure. "missing_sub": { @@ -722,10 +715,12 @@ func TestExchangeToken(t *testing.T) { t.Run(name, func(t *testing.T) { issuer, createJWT := startOIDCServer(t, privateKey.Public()) st := newServerTester(t, - WithJWTAuth(JWTIssuerConfig{ - Issuer: issuer, - RequiredClaims: wantClaims, - GlobalWriteClaims: wantGlobalClaims, + WithJWTAuth(issuer), + WithNamespaceMapping(func(claims map[string]any) (Namespace, error) { + if claims["sub"] != "user123" { + return "", fmt.Errorf("sub = %v, want user123", claims["sub"]) + } + return GlobalNamespace, nil }), ) @@ -798,14 +793,8 @@ func TestExchangeToken(t *testing.T) { cl.AccessToken = d.AccessToken st.wantGetMiss(cl, "abc123") - if tc.wantWrite { - st.wantPut(cl, "abc123", "def456", "data789") - st.wantGet(cl, "abc123", "def456", "data789") - } else { - if _, err := cl.Put(t.Context(), "abc123", "def456", 0, nil); err == nil { - t.Fatalf("Put without write scope succeeded unexpectedly") - } - } + st.wantPut(cl, "abc123", "def456", "data789") + st.wantGet(cl, "abc123", "def456", "data789") // Check session stats. reqStats, err := http.NewRequest("GET", st.hs.URL+"/session/stats", nil) @@ -833,7 +822,7 @@ func TestExchangeToken(t *testing.T) { if stats.Gets == 0 { t.Errorf("expected non-zero gets in session stats") } - if stats.Puts == 0 && tc.wantWrite { + if stats.Puts == 0 { t.Errorf("expected non-zero puts in session stats") } }) @@ -841,94 +830,41 @@ func TestExchangeToken(t *testing.T) { } func TestMultiIssuerAuth(t *testing.T) { - // Generate separate keys for each issuer. - keyA, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating key A: %v", err) - } - keyB, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating key B: %v", err) - } - keyC, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating key C: %v", err) - } + keyA, keyB, keyC := testKey1, testKey2, testKey3 issuerA, createJWTA := startOIDCServer(t, keyA.Public()) issuerB, createJWTB := startOIDCServer(t, keyB.Public()) issuerC, createJWTC := startOIDCServer(t, keyC.Public()) - st := newServerTester(t, - WithJWTAuth( - JWTIssuerConfig{ - Issuer: issuerA, - RequiredClaims: map[string]string{"sub": "userA"}, - GlobalWriteClaims: map[string]string{ - "sub": "userA", - "ref": "refs/heads/main", - }, - }, - JWTIssuerConfig{ - Issuer: issuerB, - RequiredClaims: map[string]string{"sub": "userB"}, - GlobalWriteClaims: map[string]string{ - "sub": "userB", - "ref": "refs/heads/main", - }, - }, - ), - ) - - makeJWTBody := func(jwtString string) []byte { - body, err := json.Marshal(map[string]any{"jwt": jwtString}) - if err != nil { - t.Fatalf("error marshaling request body: %v", err) - } - return body - } - - exchangeToken := func(jwtBody []byte) (*http.Response, string) { - t.Helper() - req, err := http.NewRequest("POST", st.hs.URL+"/auth/exchange-token", bytes.NewReader(jwtBody)) - if err != nil { - t.Fatalf("error creating request: %v", err) + namespaceFunc := func(claims map[string]any) (Namespace, error) { + iss, _ := claims["iss"].(string) + var requiredSub string + switch iss { + case issuerA: + requiredSub = "userA" + case issuerB: + requiredSub = "userB" + default: + return "", fmt.Errorf("unknown issuer %q", iss) } - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("error making request: %v", err) - } - body, err := io.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - t.Fatalf("error reading response body: %v", err) + if claims["sub"] != requiredSub { + return "", fmt.Errorf("issuer %q: sub = %v, want %v", iss, claims["sub"], requiredSub) } - if resp.StatusCode == http.StatusOK { - var d struct { - AccessToken string `json:"access_token"` - } - if err := json.Unmarshal(body, &d); err != nil { - t.Fatalf("error decoding response body: %v", err) - } - return resp, d.AccessToken + if claims["ref"] == "refs/heads/main" { + return GlobalNamespace, nil } - return resp, "" + return Namespace(requiredSub), nil } - baseClaims := func(iss string) jwt.MapClaims { - return jwt.MapClaims{ - "iss": iss, - "aud": gocachedAudience, - "nbf": jwt.NewNumericDate(time.Now().Add(-time.Minute)), - "exp": jwt.NewNumericDate(time.Now().Add(time.Hour)), - } - } + st := newServerTester(t, + WithJWTAuth(issuerA, issuerB), + WithNamespaceMapping(namespaceFunc), + ) // Issuer A: valid read-only token. t.Run("issuerA_read", func(t *testing.T) { - claims := baseClaims(issuerA) - claims["sub"] = "userA" - resp, accessToken := exchangeToken(makeJWTBody(createJWTA(claims, keyA))) + claims := baseClaims(issuerA, "userA") + resp, accessToken := exchangeToken(t, st.hs.URL, createJWTA(claims, keyA)) if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) } @@ -943,10 +879,9 @@ func TestMultiIssuerAuth(t *testing.T) { // Issuer A: valid write token. t.Run("issuerA_write", func(t *testing.T) { - claims := baseClaims(issuerA) - claims["sub"] = "userA" + claims := baseClaims(issuerA, "userA") claims["ref"] = "refs/heads/main" - resp, accessToken := exchangeToken(makeJWTBody(createJWTA(claims, keyA))) + resp, accessToken := exchangeToken(t, st.hs.URL, createJWTA(claims, keyA)) if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) } @@ -958,9 +893,8 @@ func TestMultiIssuerAuth(t *testing.T) { // Issuer B: valid read-only token. t.Run("issuerB_read", func(t *testing.T) { - claims := baseClaims(issuerB) - claims["sub"] = "userB" - resp, accessToken := exchangeToken(makeJWTBody(createJWTB(claims, keyB))) + claims := baseClaims(issuerB, "userB") + resp, accessToken := exchangeToken(t, st.hs.URL, createJWTB(claims, keyB)) if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) } @@ -975,10 +909,9 @@ func TestMultiIssuerAuth(t *testing.T) { // Issuer B: valid write token. t.Run("issuerB_write", func(t *testing.T) { - claims := baseClaims(issuerB) - claims["sub"] = "userB" + claims := baseClaims(issuerB, "userB") claims["ref"] = "refs/heads/main" - resp, accessToken := exchangeToken(makeJWTBody(createJWTB(claims, keyB))) + resp, accessToken := exchangeToken(t, st.hs.URL, createJWTB(claims, keyB)) if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) } @@ -990,9 +923,8 @@ func TestMultiIssuerAuth(t *testing.T) { // Issuer B: wrong required claims (sub doesn't match). t.Run("issuerB_wrong_sub", func(t *testing.T) { - claims := baseClaims(issuerB) - claims["sub"] = "userA" // issuer B requires sub=userB - resp, _ := exchangeToken(makeJWTBody(createJWTB(claims, keyB))) + claims := baseClaims(issuerB, "userA") // issuer B requires sub=userB + resp, _ := exchangeToken(t, st.hs.URL, createJWTB(claims, keyB)) if resp.StatusCode != http.StatusUnauthorized { t.Fatalf("unexpected status code: want %d, got %d", http.StatusUnauthorized, resp.StatusCode) } @@ -1000,15 +932,79 @@ func TestMultiIssuerAuth(t *testing.T) { // Issuer C: not configured, should be rejected. t.Run("issuerC_rejected", func(t *testing.T) { - claims := baseClaims(issuerC) - claims["sub"] = "userC" - resp, _ := exchangeToken(makeJWTBody(createJWTC(claims, keyC))) + claims := baseClaims(issuerC, "userC") + resp, _ := exchangeToken(t, st.hs.URL, createJWTC(claims, keyC)) if resp.StatusCode != http.StatusUnauthorized { t.Fatalf("unexpected status code: want %d, got %d", http.StatusUnauthorized, resp.StatusCode) } }) } +func TestNamespaces(t *testing.T) { + privateKey := testKey1 + issuer, createJWT := startOIDCServer(t, privateKey.Public()) + + // Namespace policy: ref=refs/heads/main grants global write; otherwise + // the session writes to a namespace named after its sub claim. + st := newServerTester(t, + WithJWTAuth(issuer), + WithNamespaceMapping(func(claims map[string]any) (Namespace, error) { + if claims["ref"] == "refs/heads/main" { + return GlobalNamespace, nil + } + sub, _ := claims["sub"].(string) + if sub == "" { + return "", fmt.Errorf("missing sub claim") + } + return Namespace(sub), nil + }), + ) + + // Exchange tokens once, but mint a fresh client (with empty disk cache) + // per operation so client-side disk hits don't short-circuit the + // server-side namespace-routing assertions we want to make. + mainClaims := baseClaims(issuer, "main-builder") + mainClaims["ref"] = "refs/heads/main" + _, aliceToken := exchangeToken(t, st.hs.URL, createJWT(baseClaims(issuer, "alice"), privateKey)) + _, bobToken := exchangeToken(t, st.hs.URL, createJWT(baseClaims(issuer, "bob"), privateKey)) + _, mainToken := exchangeToken(t, st.hs.URL, createJWT(mainClaims, privateKey)) + freshClient := func(token string) *cachers.HTTPClient { + c := st.mkClient() + c.AccessToken = token + return c + } + alice := func() *cachers.HTTPClient { return freshClient(aliceToken) } + bob := func() *cachers.HTTPClient { return freshClient(bobToken) } + main := func() *cachers.HTTPClient { return freshClient(mainToken) } + + const ( + actionA = "0001" + outA = "9901" + outG = "9902" + valA = "alice bytes" + valG = "global bytes" + ) + + // alice writes to her own namespace. + st.wantPut(alice(), actionA, outA, valA) + + // bob shares no namespace with alice and there's nothing in global yet. + st.wantGetMiss(bob(), actionA) + + // alice can read what she just wrote from her own namespace. + st.wantGet(alice(), actionA, outA, valA) + + // main-branch session writes to the global namespace. + st.wantPut(main(), actionA, outG, valG) + + // bob now hits via global. + st.wantGet(bob(), actionA, outG, valG) + + // alice prefers global over her own namespace, so the access-time bump + // lands on the shared row. + st.wantGet(alice(), actionA, outG, valG) +} + func BenchmarkFlushAccessTimes(b *testing.B) { st := newServerTester(b, WithVerbose(false)) s := st.srv @@ -1034,3 +1030,36 @@ func BenchmarkFlushAccessTimes(b *testing.B) { } } } + +func exchangeToken(t testing.TB, url string, jwt string) (resp *http.Response, accessToken string) { + t.Helper() + body, err := json.Marshal(map[string]any{"jwt": jwt}) + if err != nil { + t.Fatalf("marshal: %v", err) + } + resp, err = http.Post(url+"/auth/exchange-token", "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("token exchange: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return resp, "" + } + var d struct { + AccessToken string `json:"access_token"` + } + if err := json.NewDecoder(resp.Body).Decode(&d); err != nil { + t.Fatalf("decode: %v", err) + } + return resp, d.AccessToken +} + +func baseClaims(iss, sub string) jwt.MapClaims { + return jwt.MapClaims{ + "sub": sub, + "iss": iss, + "aud": gocachedAudience, + "nbf": jwt.NewNumericDate(time.Now().Add(-time.Minute)), + "exp": jwt.NewNumericDate(time.Now().Add(time.Hour)), + } +}