diff --git a/cmd/gocached/gocached.go b/cmd/gocached/gocached.go index e444a07..7cd34e7 100644 --- a/cmd/gocached/gocached.go +++ b/cmd/gocached/gocached.go @@ -9,7 +9,6 @@ import ( "flag" "fmt" "log" - "maps" "net" "net/http" "os" @@ -74,15 +73,27 @@ func main() { log.Fatal("must specify --jwt-claim at least once when --jwt-issuer is set") } - globalClaims := map[string]string{} - maps.Copy(globalClaims, jwtClaims) - maps.Copy(globalClaims, globalJWTClaims) - - opts = append(opts, gocached.WithJWTAuth(gocached.JWTIssuerConfig{ - Issuer: *jwtIssuer, - RequiredClaims: jwtClaims, - GlobalWriteClaims: globalClaims, - })) + opts = append(opts, + gocached.WithJWTAuth(*jwtIssuer), + gocached.WithNamespaceMapping(func(claims map[string]any) (gocached.Namespace, error) { + var ns gocached.Namespace + for k, want := range jwtClaims { + if got := claims[k]; got != want { + return "", fmt.Errorf("claim %q = %v, want %v", k, got, want) + } + if ns != "" { + ns += "," + } + ns += gocached.Namespace(fmt.Sprintf("%s=%s", k, want)) + } + for k, want := range globalJWTClaims { + if got := claims[k]; got != want { + return ns, nil + } + } + return gocached.GlobalNamespace, nil + }), + ) } srv, err := gocached.NewServer(opts...) diff --git a/gocached/gocached.go b/gocached/gocached.go index 9f97931..961ca6a 100644 --- a/gocached/gocached.go +++ b/gocached/gocached.go @@ -250,6 +250,24 @@ func (srv *Server) start() error { } srv.db = db + // Fill the namespace ID cache. + rows, err := srv.db.Query("SELECT NamespaceID, Namespace FROM Namespaces") + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var id int64 + var ns string + if err := rows.Scan(&id, &ns); err != nil { + return err + } + srv.namespaces[Namespace(ns)] = id + } + if err := rows.Err(); err != nil { + return err + } + reg := prometheus.NewRegistry() reg.MustRegister( collectors.NewGoCollector(), @@ -276,17 +294,13 @@ func (srv *Server) start() error { } if len(srv.jwtIssuers) > 0 { - issuerURLs := make([]string, 0, len(srv.jwtIssuers)) - for iss := range srv.jwtIssuers { - issuerURLs = append(issuerURLs, iss) - } - srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, gocachedAudience, issuerURLs) + srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, gocachedAudience, srv.jwtIssuers) if err := srv.jwtValidator.RunUpdateJWKSLoop(srv.shutdownCtx); err != nil { return fmt.Errorf("failed to fetch JWKS for JWT validator: %w", err) } - for iss, entry := range srv.jwtIssuers { - srv.logf("gocached: using JWT issuer %q with required claims %v, global write claims %v", iss, entry.requiredClaims, entry.globalWriteClaims) + for _, iss := range srv.jwtIssuers { + srv.logf("gocached: using JWT issuer %q", iss) } go srv.runCleanSessionsLoop() @@ -383,40 +397,42 @@ func WithMaxAge(maxAge time.Duration) ServerOption { } } -// JWTIssuerConfig configures a single OIDC issuer for JWT-based authentication. -type JWTIssuerConfig struct { - // Issuer is the OIDC issuer URL. It must be a reachable HTTP(S) server - // that serves its JWKS via a URL discoverable at - // /.well-known/openid-configuration. - Issuer string - - // RequiredClaims are claims that any JWT from this issuer must have to - // start a session. All key-value pairs must match exactly. - RequiredClaims map[string]string - - // GlobalWriteClaims are claims that a JWT from this issuer must have to - // write to the cache's global namespace. It should be a superset of - // RequiredClaims. - GlobalWriteClaims map[string]string +// Namespace identifies a logical partition of the cache where each peer is +// equally trusted. Every session is associated with exactly one Namespace to +// which it can read and write; sessions for non-global namespaces also read +// from [GlobalNamespace]. See [WithNamespaceMapping]. Namespace is case- +// insensitive, but is otherwise treated as an opaque string. +type Namespace string + +// GlobalNamespace is a trusted namespace that all sessions can read from. +const GlobalNamespace Namespace = "" + +// WithJWTAuth enables JWT-based authentication for the server. Each issuer +// must be a reachable HTTP(S) server that serves its JWKS via a URL +// discoverable at /.well-known/openid-configuration. JWTs presented for token +// exchange must pass the standard signature/issuer/audience/expiry checks +// against one of these issuers. If [WithNamespaceMapping] is provided, then +// it may still be rejected if the mapping function returns an error for its +// claims. No requests other than token exchange are allowed without +// authentication. May be called multiple times; issuers accumulate. +func WithJWTAuth(issuers ...string) ServerOption { + return func(srv *Server) { + srv.jwtIssuers = append(srv.jwtIssuers, issuers...) + } } -// WithJWTAuth enables JWT-based authentication for the server. Each issuer must -// be a reachable HTTP(S) server that serves its JWKS via a URL discoverable at -// /.well-known/openid-configuration, and any JWT presented to the server must -// exactly match the issuer's required claims to start a session. No requests are -// allowed without authentication if JWT auth is enabled. It can be called multiple -// times; configs accumulate. -func WithJWTAuth(issuers ...JWTIssuerConfig) ServerOption { +// WithNamespaceMapping sets the function that makes policy decisions based on +// a JWT's claims. It is called once per token exchange after the JWT's +// signature and standard claims have been validated. It should return an error +// if the claims are not authorized, and otherwise return which [Namespace] it +// is allowed to read and write in. All sessions are allowed to read from the +// [GlobalNamespace] regardless of the namespace returned. Check claims["iss"] +// to switch on per-issuer rules. If JWT auth is enabled, but no mapping +// function is provided, all sessions will read and write in the +// [GlobalNamespace]. +func WithNamespaceMapping(fn func(claims map[string]any) (Namespace, error)) ServerOption { return func(srv *Server) { - if srv.jwtIssuers == nil { - srv.jwtIssuers = make(map[string]*jwtIssuerConfig) - } - for _, ic := range issuers { - srv.jwtIssuers[ic.Issuer] = &jwtIssuerConfig{ - requiredClaims: ic.RequiredClaims, - globalWriteClaims: ic.GlobalWriteClaims, - } - } + srv.namespaceMapping = fn } } @@ -428,12 +444,21 @@ func NewServer(opts ...ServerOption) (*Server, error) { shutdownCtx: context.Background(), logf: log.Printf, sessions: make(map[string]*sessionData), + namespaces: make(map[Namespace]int64), clock: time.Now, } for _, opt := range opts { opt(srv) } + if len(srv.jwtIssuers) > 0 && srv.namespaceMapping == nil { + // If JWT auth is enabled, but not namespace mapping, every session is in + // the global namespace. + srv.namespaceMapping = func(claims map[string]any) (Namespace, error) { + return GlobalNamespace, nil + } + } + err := srv.start() if err != nil { return nil, err @@ -474,13 +499,15 @@ type Server struct { shutdownCtx context.Context shutdownCancel context.CancelFunc - jwtValidator *ijwt.Validator // nil unless jwtIssuers is non-empty - jwtIssuers map[string]*jwtIssuerConfig // keyed by issuer URL + jwtValidator *ijwt.Validator // nil unless jwtIssuers is non-empty + jwtIssuers []string // accepted issuer URLs + namespaceMapping func(claims map[string]any) (Namespace, error) // required when jwtIssuers is non-empty mu sync.RWMutex // guards following fields in this block sessions map[string]*sessionData // maps access token -> session data. accessDirty map[actionKey]int64 // action -> accessTime accessFlushTimer *time.Timer // nil if no flush is scheduled + namespaces map[Namespace]int64 // cached namespace string -> NamespaceID // sqliteWriteMu serializes access to SQLite. In theory the SQLite driver // should serialize access with our 5000ms busy timeout, but empirically we @@ -520,18 +547,13 @@ type Server struct { } } -// jwtIssuerConfig holds per-issuer claim requirements for JWT auth. -type jwtIssuerConfig struct { - requiredClaims map[string]string - globalWriteClaims map[string]string -} - // sessionData corresponds to a specific access token, and is only used if JWT // auth is enabled. type sessionData struct { - expiry time.Time // Session valid until. - globalNSWrite bool // Whether this session can write to the cache's global namespace. - claims map[string]any // Claims from the JWT used to create this session, stored for debug. + expiry time.Time // Session valid until. + namespaceID int64 // The namespace this session writes to. 0 means GlobalNamespace; non-zero sessions also read from 0. + namespace Namespace // Canonical (lower-cased) Namespace this session writes to. Used as the namespaces cache key and shown on the debug page. + claims map[string]any // Claims from the JWT used to create this session, stored for debug. mu sync.Mutex // Guards stats. stats stats @@ -641,12 +663,11 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if r.Method == "PUT" { - if sessionData != nil && !sessionData.globalNSWrite { - // TODO(tomhjp): support per-namespace writes. - http.Error(w, "forbidden", http.StatusForbidden) - return + var writeNS int64 + if sessionData != nil { + writeNS = sessionData.namespaceID } - srv.handlePut(w, r, reqStats) + srv.handlePut(w, r, reqStats, writeNS) return } if r.Method != "GET" && r.Method != "HEAD" { @@ -654,7 +675,7 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if strings.HasPrefix(r.URL.Path, "/action/") { - srv.handleGetAction(w, r, reqStats) + srv.handleGetAction(w, r, reqStats, sessionData) return } if sessionData != nil && r.URL.Path == "/session/stats" { @@ -723,7 +744,7 @@ func getHexSuffix(r *http.Request, prefix string) (hexSuffix string, ok bool) { // actionKey is the comparable value type for the (NamespaceID, ActionID) // primary key tuple used in the SQLite Actions table. type actionKey struct { - NamespaceID int // 0 for global + NamespaceID int64 // 0 for global ActionID string } @@ -745,7 +766,27 @@ func validHex(x string) bool { // we do a DB write to update it. const relAtimeSeconds = 60 * 60 * 24 // 1 day -func (srv *Server) handleGetAction(w http.ResponseWriter, r *http.Request, stats *stats) { +const getFromGlobalNamespace = ` +SELECT b.SHA256, b.StoredSize, b.UncompressedSize, b.SmallData, a.AltOutputID, a.AccessTime, a.NamespaceID +FROM Actions a, Blobs b +WHERE a.NamespaceID = 0 + AND a.ActionID = ? + AND a.BlobID = b.BlobID +` + +// Get hits from the global namespace first so shared cache mtime is bumped +// with higher priority than namespaced cache. +const getFromSessionNamespace = ` +SELECT b.SHA256, b.StoredSize, b.UncompressedSize, b.SmallData, a.AltOutputID, a.AccessTime, a.NamespaceID +FROM Actions a, Blobs b +WHERE a.NamespaceID IN (0, ?) + AND a.ActionID = ? + AND a.BlobID = b.BlobID +ORDER BY CASE a.NamespaceID WHEN 0 THEN 0 ELSE 1 END +LIMIT 1 +` + +func (srv *Server) handleGetAction(w http.ResponseWriter, r *http.Request, stats *stats, sessionData *sessionData) { srv.m.ActiveGets.Add(1) defer srv.m.ActiveGets.Add(-1) @@ -771,19 +812,24 @@ func (srv *Server) handleGetAction(w http.ResponseWriter, r *http.Request, stats return } - var sha256hex string - var storedSize, uncompressedSize int64 - var smallData sql.NullString - var altObjectID string - var accessTime int64 - var actionKey = actionKey{ - NamespaceID: 0, // global for now; TODO(bradfitz): support namespac - ActionID: actionID, - } - err := srv.db.QueryRow( - "SELECT b.SHA256, b.StoredSize, b.UncompressedSize, b.SmallData, a.AltOutputID, a.AccessTime FROM Actions a, Blobs b WHERE a.NameSpaceID = ? AND a.ActionID = ? AND a.BlobID = b.BlobID", - actionKey.NamespaceID, actionKey.ActionID).Scan( - &sha256hex, &storedSize, &uncompressedSize, &smallData, &altObjectID, &accessTime) + var ( + sha256hex string + storedSize, uncompressedSize int64 + smallData sql.NullString + altObjectID string + accessTime int64 + err error + actionKey = actionKey{ + ActionID: actionID, + } + ) + if sessionData != nil && sessionData.namespaceID != 0 { + err = srv.db.QueryRow(getFromSessionNamespace, sessionData.namespaceID, actionID).Scan( + &sha256hex, &storedSize, &uncompressedSize, &smallData, &altObjectID, &accessTime, &actionKey.NamespaceID) + } else { + err = srv.db.QueryRow(getFromGlobalNamespace, actionID).Scan( + &sha256hex, &storedSize, &uncompressedSize, &smallData, &altObjectID, &accessTime, &actionKey.NamespaceID) + } if err != nil { if errors.Is(err, sql.ErrNoRows) { http.Error(w, "not found", http.StatusNotFound) @@ -997,7 +1043,7 @@ func (srv *Server) getObjectFromDiskOrPeer(_ context.Context, sha256hex string, return f, nil } -func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats) { +func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats, namespaceID int64) { s.m.ActivePuts.Add(1) defer s.m.ActivePuts.Add(-1) @@ -1074,13 +1120,12 @@ func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats) // Insert or update the action in the database. nowUnix := s.now().Unix() altObjectID := "" - namespace := 0 // global for now; TODO(bradfitz): support namespaces if sha256hex != outputID { altObjectID = outputID } res, err := s.db.Exec(`INSERT OR IGNORE INTO Actions (NamespaceID, ActionID, BlobID, AltOutputID, CreateTime, AccessTime) VALUES (?, ?, ?, ?, ?, ?)`, - namespace, + namespaceID, actionID, blobID, altObjectID, @@ -1139,23 +1184,36 @@ func (srv *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) { return } - globalNSWrite, err := srv.evaluateClaims(jwtClaims) + ns, err := srv.namespaceMapping(jwtClaims) if err != nil { srv.m.AuthErrs.Add(1) if srv.verbose { - srv.logf("token exchange: %v", err) + srv.logf("token exchange: namespace func error: %v", err) } http.Error(w, "unauthorized", http.StatusUnauthorized) return } + ns = Namespace(strings.ToLower(string(ns))) + var namespaceID int64 + if ns != GlobalNamespace { + namespaceID, err = srv.resolveNamespaceID(ns) + if err != nil { + srv.m.AuthErrs.Add(1) + srv.logf("token exchange: %v", err) + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + } + const ttl = time.Hour // 52 base32 characters, 256 bits of entropy. accessToken := tokenPrefix + strings.ToLower(rand.Text()+rand.Text()) srv.addSessionData(accessToken, &sessionData{ - expiry: srv.now().UTC().Add(ttl), - globalNSWrite: globalNSWrite, - claims: jwtClaims, + expiry: srv.now().UTC().Add(ttl), + namespaceID: namespaceID, + namespace: ns, + claims: jwtClaims, }) resp := map[string]any{ @@ -1173,38 +1231,38 @@ func (srv *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) { srv.m.Auths.Add(1) } -func (srv *Server) evaluateClaims(claims map[string]any) (globalNSWrite bool, _ error) { - iss, _ := claims["iss"].(string) - cfg, ok := srv.jwtIssuers[iss] - if !ok { - return false, fmt.Errorf("got claims %v; unknown issuer %q", claims, iss) - } - - if missing := findMissingClaims(cfg.requiredClaims, claims); len(missing) > 0 { - return false, fmt.Errorf("got claims %v; missing required claims: %v", claims, missing) +// resolveNamespaceID returns the integer ID for the given Namespace, +// inserting a row in the Namespaces table if one doesn't already exist. The +// name must already be lower-cased or it will violate the CHECK constraint. +func (srv *Server) resolveNamespaceID(ns Namespace) (int64, error) { + // If it's not a new namespace, we only need to consult our cache of IDs. + srv.mu.Lock() + id, ok := srv.namespaces[ns] + srv.mu.Unlock() + if ok { + return id, nil } - if missing := findMissingClaims(cfg.globalWriteClaims, claims); len(missing) == 0 { - return true, nil - } else if srv.verbose { - srv.logf("token exchange: missing global namespace write claims: %v", missing) - } + srv.sqliteWriteMu.Lock() + defer srv.sqliteWriteMu.Unlock() - return false, nil -} + srv.mu.Lock() + defer srv.mu.Unlock() -func findMissingClaims(wantClaims map[string]string, gotClaims map[string]any) map[string]any { - if wantClaims == nil { - return nil + // Check if we lost a race now that we have both locks. + if id, ok = srv.namespaces[ns]; ok { + return id, nil } - missing := make(map[string]any) - for k, want := range wantClaims { - if got, ok := gotClaims[k]; !ok || got != want { - missing[k] = want - } + err := srv.db.QueryRow(`INSERT INTO Namespaces (Namespace) VALUES (?) + RETURNING NamespaceID;`, ns).Scan(&id) + if err != nil { + return 0, fmt.Errorf("resolving namespace %q: %w", ns, err) } - return missing + + srv.namespaces[ns] = id + + return id, nil } func (srv *Server) handleSessionStats(w http.ResponseWriter, sessionData *sessionData) { @@ -1722,10 +1780,11 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) { for _, v := range srv.sessions { v.mu.Lock() sessions = append(sessions, &sessionData{ - expiry: v.expiry, - globalNSWrite: v.globalNSWrite, - claims: v.claims, - stats: v.stats, + expiry: v.expiry, + namespaceID: v.namespaceID, + namespace: v.namespace, + claims: v.claims, + stats: v.stats, }) v.mu.Unlock() } @@ -1733,15 +1792,13 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf-8") fmt.Fprintf(w, "

gocached sessions

\n") - for iss, cfg := range srv.jwtIssuers { + for _, iss := range srv.jwtIssuers { fmt.Fprintf(w, "

JWT issuer: %s

\n", iss) - fmt.Fprintf(w, "

JWT claims required: %v

\n", cfg.requiredClaims) - fmt.Fprintf(w, "

JWT global write claims required: %v

\n", cfg.globalWriteClaims) } fmt.Fprintf(w, "

Number of sessions: %d

\n", len(sessions)) fmt.Fprintf(w, "\n") - fmt.Fprintf(w, "\n") + fmt.Fprintf(w, "\n") slices.SortFunc(sessions, func(a, b *sessionData) int { return a.stats.LastUsed.Compare(b.stats.LastUsed) }) @@ -1750,10 +1807,14 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) { if !d.stats.LastUsed.IsZero() { lastUsed = durFmt(time.Since(d.stats.LastUsed)) + " ago" } + nsLabel := "(global)" + if d.namespaceID != 0 { + nsLabel = fmt.Sprintf("%q (id=%d)", d.namespace, d.namespaceID) + } statsJSON, _ := json.MarshalIndent(d.stats, "", " ") claimsJSON, _ := json.MarshalIndent(d.claims, "", " ") - fmt.Fprintf(w, "\n", - lastUsed, d.expiry.Format(time.RFC3339), d.globalNSWrite, statsJSON, claimsJSON) + fmt.Fprintf(w, "\n", + lastUsed, d.expiry.Format(time.RFC3339), nsLabel, statsJSON, claimsJSON) } fmt.Fprintf(w, "
Last usedExpiry timeGlobal writeStatsClaims
Last usedExpiry timeNamespaceStatsClaims
%s%s%v
%s
%s
%s%s%s
%s
%s
\n") } diff --git a/gocached/gocached_test.go b/gocached/gocached_test.go index a46d92e..a795f1a 100644 --- a/gocached/gocached_test.go +++ b/gocached/gocached_test.go @@ -38,6 +38,23 @@ import ( // value in SQLite to store bytes, as it's common. const sha256OfEmpty = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" +// Package-level ECDSA P-256 keys reused across tests that need OIDC signing +// keys. Generating these is the single most expensive step in the JWT tests, +// so we share them rather than regenerating per test. +var ( + testKey1 = mustGenerateTestKey() + testKey2 = mustGenerateTestKey() + testKey3 = mustGenerateTestKey() +) + +func mustGenerateTestKey() *ecdsa.PrivateKey { + k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(fmt.Sprintf("generating test ECDSA key: %v", err)) + } + return k +} + type tester struct { t testing.TB srv *Server @@ -641,41 +658,17 @@ func TestClientConnReuse(t *testing.T) { } func TestExchangeToken(t *testing.T) { - // Generate private keys outside of the loop for speed. - privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating OIDC server private key: %v", err) - } - otherPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating OIDC server private key: %v", err) - } - wantClaims := map[string]string{ - "sub": "user123", - } - wantGlobalClaims := map[string]string{ - "sub": "user123", - "ref": "refs/heads/main", - } + privateKey := testKey1 + otherPrivateKey := testKey2 for name, tc := range map[string]struct { mutateClaims func(jwt.MapClaims) signingKey *ecdsa.PrivateKey wantStatusCode int - wantWrite bool }{ // Base case: no mutation. - "valid_read": { - wantStatusCode: http.StatusOK, - wantWrite: false, - }, - // Additional claim needed for write scope. - "valid_write": { - mutateClaims: func(cl jwt.MapClaims) { - cl["ref"] = "refs/heads/main" - }, + "valid": { wantStatusCode: http.StatusOK, - wantWrite: true, }, // Every other test makes one mutation from the base case that should cause failure. "missing_sub": { @@ -722,10 +715,12 @@ func TestExchangeToken(t *testing.T) { t.Run(name, func(t *testing.T) { issuer, createJWT := startOIDCServer(t, privateKey.Public()) st := newServerTester(t, - WithJWTAuth(JWTIssuerConfig{ - Issuer: issuer, - RequiredClaims: wantClaims, - GlobalWriteClaims: wantGlobalClaims, + WithJWTAuth(issuer), + WithNamespaceMapping(func(claims map[string]any) (Namespace, error) { + if claims["sub"] != "user123" { + return "", fmt.Errorf("sub = %v, want user123", claims["sub"]) + } + return GlobalNamespace, nil }), ) @@ -798,14 +793,8 @@ func TestExchangeToken(t *testing.T) { cl.AccessToken = d.AccessToken st.wantGetMiss(cl, "abc123") - if tc.wantWrite { - st.wantPut(cl, "abc123", "def456", "data789") - st.wantGet(cl, "abc123", "def456", "data789") - } else { - if _, err := cl.Put(t.Context(), "abc123", "def456", 0, nil); err == nil { - t.Fatalf("Put without write scope succeeded unexpectedly") - } - } + st.wantPut(cl, "abc123", "def456", "data789") + st.wantGet(cl, "abc123", "def456", "data789") // Check session stats. reqStats, err := http.NewRequest("GET", st.hs.URL+"/session/stats", nil) @@ -833,7 +822,7 @@ func TestExchangeToken(t *testing.T) { if stats.Gets == 0 { t.Errorf("expected non-zero gets in session stats") } - if stats.Puts == 0 && tc.wantWrite { + if stats.Puts == 0 { t.Errorf("expected non-zero puts in session stats") } }) @@ -841,94 +830,41 @@ func TestExchangeToken(t *testing.T) { } func TestMultiIssuerAuth(t *testing.T) { - // Generate separate keys for each issuer. - keyA, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating key A: %v", err) - } - keyB, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating key B: %v", err) - } - keyC, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("error generating key C: %v", err) - } + keyA, keyB, keyC := testKey1, testKey2, testKey3 issuerA, createJWTA := startOIDCServer(t, keyA.Public()) issuerB, createJWTB := startOIDCServer(t, keyB.Public()) issuerC, createJWTC := startOIDCServer(t, keyC.Public()) - st := newServerTester(t, - WithJWTAuth( - JWTIssuerConfig{ - Issuer: issuerA, - RequiredClaims: map[string]string{"sub": "userA"}, - GlobalWriteClaims: map[string]string{ - "sub": "userA", - "ref": "refs/heads/main", - }, - }, - JWTIssuerConfig{ - Issuer: issuerB, - RequiredClaims: map[string]string{"sub": "userB"}, - GlobalWriteClaims: map[string]string{ - "sub": "userB", - "ref": "refs/heads/main", - }, - }, - ), - ) - - makeJWTBody := func(jwtString string) []byte { - body, err := json.Marshal(map[string]any{"jwt": jwtString}) - if err != nil { - t.Fatalf("error marshaling request body: %v", err) - } - return body - } - - exchangeToken := func(jwtBody []byte) (*http.Response, string) { - t.Helper() - req, err := http.NewRequest("POST", st.hs.URL+"/auth/exchange-token", bytes.NewReader(jwtBody)) - if err != nil { - t.Fatalf("error creating request: %v", err) + namespaceFunc := func(claims map[string]any) (Namespace, error) { + iss, _ := claims["iss"].(string) + var requiredSub string + switch iss { + case issuerA: + requiredSub = "userA" + case issuerB: + requiredSub = "userB" + default: + return "", fmt.Errorf("unknown issuer %q", iss) } - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("error making request: %v", err) - } - body, err := io.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - t.Fatalf("error reading response body: %v", err) + if claims["sub"] != requiredSub { + return "", fmt.Errorf("issuer %q: sub = %v, want %v", iss, claims["sub"], requiredSub) } - if resp.StatusCode == http.StatusOK { - var d struct { - AccessToken string `json:"access_token"` - } - if err := json.Unmarshal(body, &d); err != nil { - t.Fatalf("error decoding response body: %v", err) - } - return resp, d.AccessToken + if claims["ref"] == "refs/heads/main" { + return GlobalNamespace, nil } - return resp, "" + return Namespace(requiredSub), nil } - baseClaims := func(iss string) jwt.MapClaims { - return jwt.MapClaims{ - "iss": iss, - "aud": gocachedAudience, - "nbf": jwt.NewNumericDate(time.Now().Add(-time.Minute)), - "exp": jwt.NewNumericDate(time.Now().Add(time.Hour)), - } - } + st := newServerTester(t, + WithJWTAuth(issuerA, issuerB), + WithNamespaceMapping(namespaceFunc), + ) // Issuer A: valid read-only token. t.Run("issuerA_read", func(t *testing.T) { - claims := baseClaims(issuerA) - claims["sub"] = "userA" - resp, accessToken := exchangeToken(makeJWTBody(createJWTA(claims, keyA))) + claims := baseClaims(issuerA, "userA") + resp, accessToken := exchangeToken(t, st.hs.URL, createJWTA(claims, keyA)) if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) } @@ -943,10 +879,9 @@ func TestMultiIssuerAuth(t *testing.T) { // Issuer A: valid write token. t.Run("issuerA_write", func(t *testing.T) { - claims := baseClaims(issuerA) - claims["sub"] = "userA" + claims := baseClaims(issuerA, "userA") claims["ref"] = "refs/heads/main" - resp, accessToken := exchangeToken(makeJWTBody(createJWTA(claims, keyA))) + resp, accessToken := exchangeToken(t, st.hs.URL, createJWTA(claims, keyA)) if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) } @@ -958,9 +893,8 @@ func TestMultiIssuerAuth(t *testing.T) { // Issuer B: valid read-only token. t.Run("issuerB_read", func(t *testing.T) { - claims := baseClaims(issuerB) - claims["sub"] = "userB" - resp, accessToken := exchangeToken(makeJWTBody(createJWTB(claims, keyB))) + claims := baseClaims(issuerB, "userB") + resp, accessToken := exchangeToken(t, st.hs.URL, createJWTB(claims, keyB)) if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) } @@ -975,10 +909,9 @@ func TestMultiIssuerAuth(t *testing.T) { // Issuer B: valid write token. t.Run("issuerB_write", func(t *testing.T) { - claims := baseClaims(issuerB) - claims["sub"] = "userB" + claims := baseClaims(issuerB, "userB") claims["ref"] = "refs/heads/main" - resp, accessToken := exchangeToken(makeJWTBody(createJWTB(claims, keyB))) + resp, accessToken := exchangeToken(t, st.hs.URL, createJWTB(claims, keyB)) if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status code: want %d, got %d", http.StatusOK, resp.StatusCode) } @@ -990,9 +923,8 @@ func TestMultiIssuerAuth(t *testing.T) { // Issuer B: wrong required claims (sub doesn't match). t.Run("issuerB_wrong_sub", func(t *testing.T) { - claims := baseClaims(issuerB) - claims["sub"] = "userA" // issuer B requires sub=userB - resp, _ := exchangeToken(makeJWTBody(createJWTB(claims, keyB))) + claims := baseClaims(issuerB, "userA") // issuer B requires sub=userB + resp, _ := exchangeToken(t, st.hs.URL, createJWTB(claims, keyB)) if resp.StatusCode != http.StatusUnauthorized { t.Fatalf("unexpected status code: want %d, got %d", http.StatusUnauthorized, resp.StatusCode) } @@ -1000,15 +932,79 @@ func TestMultiIssuerAuth(t *testing.T) { // Issuer C: not configured, should be rejected. t.Run("issuerC_rejected", func(t *testing.T) { - claims := baseClaims(issuerC) - claims["sub"] = "userC" - resp, _ := exchangeToken(makeJWTBody(createJWTC(claims, keyC))) + claims := baseClaims(issuerC, "userC") + resp, _ := exchangeToken(t, st.hs.URL, createJWTC(claims, keyC)) if resp.StatusCode != http.StatusUnauthorized { t.Fatalf("unexpected status code: want %d, got %d", http.StatusUnauthorized, resp.StatusCode) } }) } +func TestNamespaces(t *testing.T) { + privateKey := testKey1 + issuer, createJWT := startOIDCServer(t, privateKey.Public()) + + // Namespace policy: ref=refs/heads/main grants global write; otherwise + // the session writes to a namespace named after its sub claim. + st := newServerTester(t, + WithJWTAuth(issuer), + WithNamespaceMapping(func(claims map[string]any) (Namespace, error) { + if claims["ref"] == "refs/heads/main" { + return GlobalNamespace, nil + } + sub, _ := claims["sub"].(string) + if sub == "" { + return "", fmt.Errorf("missing sub claim") + } + return Namespace(sub), nil + }), + ) + + // Exchange tokens once, but mint a fresh client (with empty disk cache) + // per operation so client-side disk hits don't short-circuit the + // server-side namespace-routing assertions we want to make. + mainClaims := baseClaims(issuer, "main-builder") + mainClaims["ref"] = "refs/heads/main" + _, aliceToken := exchangeToken(t, st.hs.URL, createJWT(baseClaims(issuer, "alice"), privateKey)) + _, bobToken := exchangeToken(t, st.hs.URL, createJWT(baseClaims(issuer, "bob"), privateKey)) + _, mainToken := exchangeToken(t, st.hs.URL, createJWT(mainClaims, privateKey)) + freshClient := func(token string) *cachers.HTTPClient { + c := st.mkClient() + c.AccessToken = token + return c + } + alice := func() *cachers.HTTPClient { return freshClient(aliceToken) } + bob := func() *cachers.HTTPClient { return freshClient(bobToken) } + main := func() *cachers.HTTPClient { return freshClient(mainToken) } + + const ( + actionA = "0001" + outA = "9901" + outG = "9902" + valA = "alice bytes" + valG = "global bytes" + ) + + // alice writes to her own namespace. + st.wantPut(alice(), actionA, outA, valA) + + // bob shares no namespace with alice and there's nothing in global yet. + st.wantGetMiss(bob(), actionA) + + // alice can read what she just wrote from her own namespace. + st.wantGet(alice(), actionA, outA, valA) + + // main-branch session writes to the global namespace. + st.wantPut(main(), actionA, outG, valG) + + // bob now hits via global. + st.wantGet(bob(), actionA, outG, valG) + + // alice prefers global over her own namespace, so the access-time bump + // lands on the shared row. + st.wantGet(alice(), actionA, outG, valG) +} + func BenchmarkFlushAccessTimes(b *testing.B) { st := newServerTester(b, WithVerbose(false)) s := st.srv @@ -1034,3 +1030,36 @@ func BenchmarkFlushAccessTimes(b *testing.B) { } } } + +func exchangeToken(t testing.TB, url string, jwt string) (resp *http.Response, accessToken string) { + t.Helper() + body, err := json.Marshal(map[string]any{"jwt": jwt}) + if err != nil { + t.Fatalf("marshal: %v", err) + } + resp, err = http.Post(url+"/auth/exchange-token", "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("token exchange: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return resp, "" + } + var d struct { + AccessToken string `json:"access_token"` + } + if err := json.NewDecoder(resp.Body).Decode(&d); err != nil { + t.Fatalf("decode: %v", err) + } + return resp, d.AccessToken +} + +func baseClaims(iss, sub string) jwt.MapClaims { + return jwt.MapClaims{ + "sub": sub, + "iss": iss, + "aud": gocachedAudience, + "nbf": jwt.NewNumericDate(time.Now().Add(-time.Minute)), + "exp": jwt.NewNumericDate(time.Now().Add(time.Hour)), + } +}