diff --git a/cmd/gocached/gocached.go b/cmd/gocached/gocached.go index e444a07..7cd34e7 100644 --- a/cmd/gocached/gocached.go +++ b/cmd/gocached/gocached.go @@ -9,7 +9,6 @@ import ( "flag" "fmt" "log" - "maps" "net" "net/http" "os" @@ -74,15 +73,27 @@ func main() { log.Fatal("must specify --jwt-claim at least once when --jwt-issuer is set") } - globalClaims := map[string]string{} - maps.Copy(globalClaims, jwtClaims) - maps.Copy(globalClaims, globalJWTClaims) - - opts = append(opts, gocached.WithJWTAuth(gocached.JWTIssuerConfig{ - Issuer: *jwtIssuer, - RequiredClaims: jwtClaims, - GlobalWriteClaims: globalClaims, - })) + opts = append(opts, + gocached.WithJWTAuth(*jwtIssuer), + gocached.WithNamespaceMapping(func(claims map[string]any) (gocached.Namespace, error) { + var ns gocached.Namespace + for k, want := range jwtClaims { + if got := claims[k]; got != want { + return "", fmt.Errorf("claim %q = %v, want %v", k, got, want) + } + if ns != "" { + ns += "," + } + ns += gocached.Namespace(fmt.Sprintf("%s=%s", k, want)) + } + for k, want := range globalJWTClaims { + if got := claims[k]; got != want { + return ns, nil + } + } + return gocached.GlobalNamespace, nil + }), + ) } srv, err := gocached.NewServer(opts...) diff --git a/gocached/gocached.go b/gocached/gocached.go index 9f97931..961ca6a 100644 --- a/gocached/gocached.go +++ b/gocached/gocached.go @@ -250,6 +250,24 @@ func (srv *Server) start() error { } srv.db = db + // Fill the namespace ID cache. + rows, err := srv.db.Query("SELECT NamespaceID, Namespace FROM Namespaces") + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var id int64 + var ns string + if err := rows.Scan(&id, &ns); err != nil { + return err + } + srv.namespaces[Namespace(ns)] = id + } + if err := rows.Err(); err != nil { + return err + } + reg := prometheus.NewRegistry() reg.MustRegister( collectors.NewGoCollector(), @@ -276,17 +294,13 @@ func (srv *Server) start() error { } if len(srv.jwtIssuers) > 0 { - issuerURLs := make([]string, 0, len(srv.jwtIssuers)) - for iss := range srv.jwtIssuers { - issuerURLs = append(issuerURLs, iss) - } - srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, gocachedAudience, issuerURLs) + srv.jwtValidator = ijwt.NewJWTValidator(srv.logf, gocachedAudience, srv.jwtIssuers) if err := srv.jwtValidator.RunUpdateJWKSLoop(srv.shutdownCtx); err != nil { return fmt.Errorf("failed to fetch JWKS for JWT validator: %w", err) } - for iss, entry := range srv.jwtIssuers { - srv.logf("gocached: using JWT issuer %q with required claims %v, global write claims %v", iss, entry.requiredClaims, entry.globalWriteClaims) + for _, iss := range srv.jwtIssuers { + srv.logf("gocached: using JWT issuer %q", iss) } go srv.runCleanSessionsLoop() @@ -383,40 +397,42 @@ func WithMaxAge(maxAge time.Duration) ServerOption { } } -// JWTIssuerConfig configures a single OIDC issuer for JWT-based authentication. -type JWTIssuerConfig struct { - // Issuer is the OIDC issuer URL. It must be a reachable HTTP(S) server - // that serves its JWKS via a URL discoverable at - // /.well-known/openid-configuration. - Issuer string - - // RequiredClaims are claims that any JWT from this issuer must have to - // start a session. All key-value pairs must match exactly. - RequiredClaims map[string]string - - // GlobalWriteClaims are claims that a JWT from this issuer must have to - // write to the cache's global namespace. It should be a superset of - // RequiredClaims. - GlobalWriteClaims map[string]string +// Namespace identifies a logical partition of the cache where each peer is +// equally trusted. Every session is associated with exactly one Namespace to +// which it can read and write; sessions for non-global namespaces also read +// from [GlobalNamespace]. See [WithNamespaceMapping]. Namespace is case- +// insensitive, but is otherwise treated as an opaque string. +type Namespace string + +// GlobalNamespace is a trusted namespace that all sessions can read from. +const GlobalNamespace Namespace = "" + +// WithJWTAuth enables JWT-based authentication for the server. Each issuer +// must be a reachable HTTP(S) server that serves its JWKS via a URL +// discoverable at /.well-known/openid-configuration. JWTs presented for token +// exchange must pass the standard signature/issuer/audience/expiry checks +// against one of these issuers. If [WithNamespaceMapping] is provided, then +// it may still be rejected if the mapping function returns an error for its +// claims. No requests other than token exchange are allowed without +// authentication. May be called multiple times; issuers accumulate. +func WithJWTAuth(issuers ...string) ServerOption { + return func(srv *Server) { + srv.jwtIssuers = append(srv.jwtIssuers, issuers...) + } } -// WithJWTAuth enables JWT-based authentication for the server. Each issuer must -// be a reachable HTTP(S) server that serves its JWKS via a URL discoverable at -// /.well-known/openid-configuration, and any JWT presented to the server must -// exactly match the issuer's required claims to start a session. No requests are -// allowed without authentication if JWT auth is enabled. It can be called multiple -// times; configs accumulate. -func WithJWTAuth(issuers ...JWTIssuerConfig) ServerOption { +// WithNamespaceMapping sets the function that makes policy decisions based on +// a JWT's claims. It is called once per token exchange after the JWT's +// signature and standard claims have been validated. It should return an error +// if the claims are not authorized, and otherwise return which [Namespace] it +// is allowed to read and write in. All sessions are allowed to read from the +// [GlobalNamespace] regardless of the namespace returned. Check claims["iss"] +// to switch on per-issuer rules. If JWT auth is enabled, but no mapping +// function is provided, all sessions will read and write in the +// [GlobalNamespace]. +func WithNamespaceMapping(fn func(claims map[string]any) (Namespace, error)) ServerOption { return func(srv *Server) { - if srv.jwtIssuers == nil { - srv.jwtIssuers = make(map[string]*jwtIssuerConfig) - } - for _, ic := range issuers { - srv.jwtIssuers[ic.Issuer] = &jwtIssuerConfig{ - requiredClaims: ic.RequiredClaims, - globalWriteClaims: ic.GlobalWriteClaims, - } - } + srv.namespaceMapping = fn } } @@ -428,12 +444,21 @@ func NewServer(opts ...ServerOption) (*Server, error) { shutdownCtx: context.Background(), logf: log.Printf, sessions: make(map[string]*sessionData), + namespaces: make(map[Namespace]int64), clock: time.Now, } for _, opt := range opts { opt(srv) } + if len(srv.jwtIssuers) > 0 && srv.namespaceMapping == nil { + // If JWT auth is enabled, but not namespace mapping, every session is in + // the global namespace. + srv.namespaceMapping = func(claims map[string]any) (Namespace, error) { + return GlobalNamespace, nil + } + } + err := srv.start() if err != nil { return nil, err @@ -474,13 +499,15 @@ type Server struct { shutdownCtx context.Context shutdownCancel context.CancelFunc - jwtValidator *ijwt.Validator // nil unless jwtIssuers is non-empty - jwtIssuers map[string]*jwtIssuerConfig // keyed by issuer URL + jwtValidator *ijwt.Validator // nil unless jwtIssuers is non-empty + jwtIssuers []string // accepted issuer URLs + namespaceMapping func(claims map[string]any) (Namespace, error) // required when jwtIssuers is non-empty mu sync.RWMutex // guards following fields in this block sessions map[string]*sessionData // maps access token -> session data. accessDirty map[actionKey]int64 // action -> accessTime accessFlushTimer *time.Timer // nil if no flush is scheduled + namespaces map[Namespace]int64 // cached namespace string -> NamespaceID // sqliteWriteMu serializes access to SQLite. In theory the SQLite driver // should serialize access with our 5000ms busy timeout, but empirically we @@ -520,18 +547,13 @@ type Server struct { } } -// jwtIssuerConfig holds per-issuer claim requirements for JWT auth. -type jwtIssuerConfig struct { - requiredClaims map[string]string - globalWriteClaims map[string]string -} - // sessionData corresponds to a specific access token, and is only used if JWT // auth is enabled. type sessionData struct { - expiry time.Time // Session valid until. - globalNSWrite bool // Whether this session can write to the cache's global namespace. - claims map[string]any // Claims from the JWT used to create this session, stored for debug. + expiry time.Time // Session valid until. + namespaceID int64 // The namespace this session writes to. 0 means GlobalNamespace; non-zero sessions also read from 0. + namespace Namespace // Canonical (lower-cased) Namespace this session writes to. Used as the namespaces cache key and shown on the debug page. + claims map[string]any // Claims from the JWT used to create this session, stored for debug. mu sync.Mutex // Guards stats. stats stats @@ -641,12 +663,11 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if r.Method == "PUT" { - if sessionData != nil && !sessionData.globalNSWrite { - // TODO(tomhjp): support per-namespace writes. - http.Error(w, "forbidden", http.StatusForbidden) - return + var writeNS int64 + if sessionData != nil { + writeNS = sessionData.namespaceID } - srv.handlePut(w, r, reqStats) + srv.handlePut(w, r, reqStats, writeNS) return } if r.Method != "GET" && r.Method != "HEAD" { @@ -654,7 +675,7 @@ func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if strings.HasPrefix(r.URL.Path, "/action/") { - srv.handleGetAction(w, r, reqStats) + srv.handleGetAction(w, r, reqStats, sessionData) return } if sessionData != nil && r.URL.Path == "/session/stats" { @@ -723,7 +744,7 @@ func getHexSuffix(r *http.Request, prefix string) (hexSuffix string, ok bool) { // actionKey is the comparable value type for the (NamespaceID, ActionID) // primary key tuple used in the SQLite Actions table. type actionKey struct { - NamespaceID int // 0 for global + NamespaceID int64 // 0 for global ActionID string } @@ -745,7 +766,27 @@ func validHex(x string) bool { // we do a DB write to update it. const relAtimeSeconds = 60 * 60 * 24 // 1 day -func (srv *Server) handleGetAction(w http.ResponseWriter, r *http.Request, stats *stats) { +const getFromGlobalNamespace = ` +SELECT b.SHA256, b.StoredSize, b.UncompressedSize, b.SmallData, a.AltOutputID, a.AccessTime, a.NamespaceID +FROM Actions a, Blobs b +WHERE a.NamespaceID = 0 + AND a.ActionID = ? + AND a.BlobID = b.BlobID +` + +// Get hits from the global namespace first so shared cache mtime is bumped +// with higher priority than namespaced cache. +const getFromSessionNamespace = ` +SELECT b.SHA256, b.StoredSize, b.UncompressedSize, b.SmallData, a.AltOutputID, a.AccessTime, a.NamespaceID +FROM Actions a, Blobs b +WHERE a.NamespaceID IN (0, ?) + AND a.ActionID = ? + AND a.BlobID = b.BlobID +ORDER BY CASE a.NamespaceID WHEN 0 THEN 0 ELSE 1 END +LIMIT 1 +` + +func (srv *Server) handleGetAction(w http.ResponseWriter, r *http.Request, stats *stats, sessionData *sessionData) { srv.m.ActiveGets.Add(1) defer srv.m.ActiveGets.Add(-1) @@ -771,19 +812,24 @@ func (srv *Server) handleGetAction(w http.ResponseWriter, r *http.Request, stats return } - var sha256hex string - var storedSize, uncompressedSize int64 - var smallData sql.NullString - var altObjectID string - var accessTime int64 - var actionKey = actionKey{ - NamespaceID: 0, // global for now; TODO(bradfitz): support namespac - ActionID: actionID, - } - err := srv.db.QueryRow( - "SELECT b.SHA256, b.StoredSize, b.UncompressedSize, b.SmallData, a.AltOutputID, a.AccessTime FROM Actions a, Blobs b WHERE a.NameSpaceID = ? AND a.ActionID = ? AND a.BlobID = b.BlobID", - actionKey.NamespaceID, actionKey.ActionID).Scan( - &sha256hex, &storedSize, &uncompressedSize, &smallData, &altObjectID, &accessTime) + var ( + sha256hex string + storedSize, uncompressedSize int64 + smallData sql.NullString + altObjectID string + accessTime int64 + err error + actionKey = actionKey{ + ActionID: actionID, + } + ) + if sessionData != nil && sessionData.namespaceID != 0 { + err = srv.db.QueryRow(getFromSessionNamespace, sessionData.namespaceID, actionID).Scan( + &sha256hex, &storedSize, &uncompressedSize, &smallData, &altObjectID, &accessTime, &actionKey.NamespaceID) + } else { + err = srv.db.QueryRow(getFromGlobalNamespace, actionID).Scan( + &sha256hex, &storedSize, &uncompressedSize, &smallData, &altObjectID, &accessTime, &actionKey.NamespaceID) + } if err != nil { if errors.Is(err, sql.ErrNoRows) { http.Error(w, "not found", http.StatusNotFound) @@ -997,7 +1043,7 @@ func (srv *Server) getObjectFromDiskOrPeer(_ context.Context, sha256hex string, return f, nil } -func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats) { +func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats, namespaceID int64) { s.m.ActivePuts.Add(1) defer s.m.ActivePuts.Add(-1) @@ -1074,13 +1120,12 @@ func (s *Server) handlePut(w http.ResponseWriter, r *http.Request, stats *stats) // Insert or update the action in the database. nowUnix := s.now().Unix() altObjectID := "" - namespace := 0 // global for now; TODO(bradfitz): support namespaces if sha256hex != outputID { altObjectID = outputID } res, err := s.db.Exec(`INSERT OR IGNORE INTO Actions (NamespaceID, ActionID, BlobID, AltOutputID, CreateTime, AccessTime) VALUES (?, ?, ?, ?, ?, ?)`, - namespace, + namespaceID, actionID, blobID, altObjectID, @@ -1139,23 +1184,36 @@ func (srv *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) { return } - globalNSWrite, err := srv.evaluateClaims(jwtClaims) + ns, err := srv.namespaceMapping(jwtClaims) if err != nil { srv.m.AuthErrs.Add(1) if srv.verbose { - srv.logf("token exchange: %v", err) + srv.logf("token exchange: namespace func error: %v", err) } http.Error(w, "unauthorized", http.StatusUnauthorized) return } + ns = Namespace(strings.ToLower(string(ns))) + var namespaceID int64 + if ns != GlobalNamespace { + namespaceID, err = srv.resolveNamespaceID(ns) + if err != nil { + srv.m.AuthErrs.Add(1) + srv.logf("token exchange: %v", err) + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + } + const ttl = time.Hour // 52 base32 characters, 256 bits of entropy. accessToken := tokenPrefix + strings.ToLower(rand.Text()+rand.Text()) srv.addSessionData(accessToken, &sessionData{ - expiry: srv.now().UTC().Add(ttl), - globalNSWrite: globalNSWrite, - claims: jwtClaims, + expiry: srv.now().UTC().Add(ttl), + namespaceID: namespaceID, + namespace: ns, + claims: jwtClaims, }) resp := map[string]any{ @@ -1173,38 +1231,38 @@ func (srv *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request) { srv.m.Auths.Add(1) } -func (srv *Server) evaluateClaims(claims map[string]any) (globalNSWrite bool, _ error) { - iss, _ := claims["iss"].(string) - cfg, ok := srv.jwtIssuers[iss] - if !ok { - return false, fmt.Errorf("got claims %v; unknown issuer %q", claims, iss) - } - - if missing := findMissingClaims(cfg.requiredClaims, claims); len(missing) > 0 { - return false, fmt.Errorf("got claims %v; missing required claims: %v", claims, missing) +// resolveNamespaceID returns the integer ID for the given Namespace, +// inserting a row in the Namespaces table if one doesn't already exist. The +// name must already be lower-cased or it will violate the CHECK constraint. +func (srv *Server) resolveNamespaceID(ns Namespace) (int64, error) { + // If it's not a new namespace, we only need to consult our cache of IDs. + srv.mu.Lock() + id, ok := srv.namespaces[ns] + srv.mu.Unlock() + if ok { + return id, nil } - if missing := findMissingClaims(cfg.globalWriteClaims, claims); len(missing) == 0 { - return true, nil - } else if srv.verbose { - srv.logf("token exchange: missing global namespace write claims: %v", missing) - } + srv.sqliteWriteMu.Lock() + defer srv.sqliteWriteMu.Unlock() - return false, nil -} + srv.mu.Lock() + defer srv.mu.Unlock() -func findMissingClaims(wantClaims map[string]string, gotClaims map[string]any) map[string]any { - if wantClaims == nil { - return nil + // Check if we lost a race now that we have both locks. + if id, ok = srv.namespaces[ns]; ok { + return id, nil } - missing := make(map[string]any) - for k, want := range wantClaims { - if got, ok := gotClaims[k]; !ok || got != want { - missing[k] = want - } + err := srv.db.QueryRow(`INSERT INTO Namespaces (Namespace) VALUES (?) + RETURNING NamespaceID;`, ns).Scan(&id) + if err != nil { + return 0, fmt.Errorf("resolving namespace %q: %w", ns, err) } - return missing + + srv.namespaces[ns] = id + + return id, nil } func (srv *Server) handleSessionStats(w http.ResponseWriter, sessionData *sessionData) { @@ -1722,10 +1780,11 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) { for _, v := range srv.sessions { v.mu.Lock() sessions = append(sessions, &sessionData{ - expiry: v.expiry, - globalNSWrite: v.globalNSWrite, - claims: v.claims, - stats: v.stats, + expiry: v.expiry, + namespaceID: v.namespaceID, + namespace: v.namespace, + claims: v.claims, + stats: v.stats, }) v.mu.Unlock() } @@ -1733,15 +1792,13 @@ func (srv *Server) serveSessions(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf-8") fmt.Fprintf(w, "
JWT issuer: %s
\n", iss) - fmt.Fprintf(w, "JWT claims required: %v
\n", cfg.requiredClaims) - fmt.Fprintf(w, "JWT global write claims required: %v
\n", cfg.globalWriteClaims) } fmt.Fprintf(w, "Number of sessions: %d
\n", len(sessions)) fmt.Fprintf(w, "| Last used | Expiry time | Global write | Stats | Claims |
|---|---|---|---|---|
| Last used | Expiry time | Namespace | Stats | Claims |
| %s | %s | %v | %s | %s |
| %s | %s | %s | %s | %s |