diff --git a/docs/refactor.md b/docs/refactor.md index 212cf8aa..9da1a087 100644 --- a/docs/refactor.md +++ b/docs/refactor.md @@ -1,6 +1,6 @@ # Storage, Secrets, Identity, And Composition Refactor Plan -_Updated 2026-05-27. Reflects actual codebase state vs. original plan._ +_Updated 2026-05-28. Reflects actual codebase state vs. original plan._ --- @@ -11,9 +11,9 @@ _Updated 2026-05-27. Reflects actual codebase state vs. original plan._ | 1 — Storage foundation | **Replaced** | `StoreDatabase` (SQLite/Postgres) for registries; `DiskStore + AesGcmEncryptionWrapper` for credentials. Two-store design is better than the single-KV-substrate plan. | | 2 — Secret source chain | **Partial** | `ServerSecretResolver` (env → file → keyring → generate) covers server-owned keys (master key, UI session key). Client identity private keys still resolved in `identity/local.py`. | | 3 — IdentityRepository | **Partial** | Server-side: `IdentityRegistry` in `server/store/repositories.py` (relational, done). Client-side: still raw `identity/local.py`, no structural server boundary. | -| 4 — CredentialRepository | **Not done** | `AuthService` still calls `build_store_key()` and `self._vault.get/put/delete` directly. This is the most significant remaining gap. | -| 5 — ProviderRepository | **Partial** | `ProviderDefinitionRepository` handles custom providers (done). Bundled provider loading still lives in `AuthService._load_bundled_providers()`. | -| 6 — Slim AuthService | **Partial** | Receives `ProviderDefinitionRepository` (done). Still owns raw vault key construction and all vault I/O. | +| 4 — CredentialRepository | **Done** | `server/credential_repository.py` owns vault key construction and credential persistence. | +| 5 — ProviderRepository | **Done** | `server/provider_repository.py` owns bundled + custom provider resolution. | +| 6 — Slim AuthService | **Partial** | Receives `CredentialRepository` and `ProviderRepository`; policy cleanup and audit logger injection remain separate follow-ups. | | 7 — Server composition root | **Functional** | `ServerStore` + `app.state` + `dependencies.py` cover the spirit. No `ServerState` dataclass. `identity="server"` placeholder is gone. | | 8 — Proxy server authority | **Done** | `proxy_catalog.py` + `/proxy/routes` endpoint. | | 9 — Docs | **Partial** | CONTEXT.md and UBIQUITOUS_LANGUAGE.md updated. | @@ -31,7 +31,7 @@ _Updated 2026-05-27. Reflects actual codebase state vs. original plan._ ## Remaining work -Three gaps remain. They are listed in delivery order — each one unblocks the next. +The original CredentialRepository and ProviderRepository gaps are now closed. Remaining phases below are follow-up work and historical context for the broader refactor. --- diff --git a/src/authsome/auth/utils.py b/src/authsome/auth/utils.py index 43c50d4d..8941c0a7 100644 --- a/src/authsome/auth/utils.py +++ b/src/authsome/auth/utils.py @@ -9,6 +9,13 @@ from typing import TYPE_CHECKING from urllib.parse import urlsplit, urlunsplit +from authsome.auth.input_provider import InputField +from authsome.auth.models.connection import ProviderClientRecord +from authsome.auth.models.enums import AuthType, FlowType +from authsome.auth.models.provider import ProviderDefinition +from authsome.errors import InvalidProviderSchemaError +from authsome.utils import is_filesystem_safe + if TYPE_CHECKING: from authsome.auth.sessions import AuthSession @@ -53,3 +60,105 @@ def normalize_base_url(base_url: str | None) -> str | None: def export_name_part(value: str) -> str: """Convert a string into a component suitable for an environment variable name.""" return re.sub(r"[^A-Z0-9]+", "_", value.upper()).strip("_") + + +VALID_FLOWS: dict[AuthType, set[FlowType]] = { + AuthType.OAUTH2: {FlowType.PKCE, FlowType.DEVICE_CODE, FlowType.DCR_PKCE}, + AuthType.API_KEY: {FlowType.API_KEY}, + AuthType.BROWSER: {FlowType.BROWSER}, +} + + +def validate_provider_definition(definition: ProviderDefinition) -> None: + if not is_filesystem_safe(definition.name): + raise InvalidProviderSchemaError( + f"Provider name '{definition.name}' is not filesystem-safe", + provider=definition.name, + ) + valid_flows = VALID_FLOWS.get(definition.auth_type) + if valid_flows is None: + raise InvalidProviderSchemaError( + f"Unrecognized auth_type: {definition.auth_type}", + provider=definition.name, + ) + if definition.flow not in valid_flows: + raise InvalidProviderSchemaError( + f"Flow '{definition.flow}' is not valid for auth_type '{definition.auth_type}'. " + f"Valid flows: {[flow.value for flow in valid_flows]}", + provider=definition.name, + ) + if definition.auth_type == AuthType.OAUTH2 and definition.oauth is None: + raise InvalidProviderSchemaError( + "auth_type 'oauth2' requires an 'oauth' configuration section", + provider=definition.name, + ) + if definition.auth_type == AuthType.API_KEY and definition.api_key is None: + raise InvalidProviderSchemaError( + "auth_type 'api_key' requires an 'api_key' configuration section", + provider=definition.name, + ) + if definition.auth_type == AuthType.BROWSER and definition.browser is None: + raise InvalidProviderSchemaError( + "auth_type 'browser' requires a 'browser' configuration section", + provider=definition.name, + ) + + +def required_inputs( + *, + provider: ProviderDefinition, + flow_type: FlowType, + client_record: ProviderClientRecord | None, + scopes: list[str] | None = None, + base_url: str | None = None, + provider_config_only: bool = False, +) -> list[InputField]: + flow_base_url = base_url or (client_record.base_url if client_record else None) + flow_client_id = client_record.client_id if client_record else None + persisted_scopes = client_record.scopes if client_record else None + fields: list[InputField] = [] + + if provider.oauth and provider.oauth.base_url and (provider_config_only or not flow_base_url): + fields.append( + InputField( + name="base_url", + label="Base URL", + secret=False, + default=flow_base_url or provider.oauth.base_url, + ) + ) + fields.append( + InputField( + name="api_url", + label="API Host URL", + secret=False, + default=( + client_record.api_url + if client_record and client_record.api_url + else provider.primary_api_url() or "" + ), + ) + ) + + if flow_type == FlowType.PKCE and (provider_config_only or not flow_client_id): + fields.append(InputField(name="client_id", label="Client ID", secret=False, default=flow_client_id or "")) + fields.append(InputField(name="client_secret", label="Client Secret", secret=True, default="")) + elif flow_type == FlowType.DEVICE_CODE and (provider_config_only or not flow_client_id): + fields.append(InputField(name="client_id", label="Client ID", secret=False, default=flow_client_id or "")) + fields.append(InputField(name="client_secret", label="Client Secret (Optional)", secret=True, default="")) + + if flow_type in (FlowType.PKCE, FlowType.DEVICE_CODE, FlowType.DCR_PKCE): + if scopes is None and persisted_scopes is None: + default_scopes = ",".join(provider.oauth.scopes) if provider.oauth and provider.oauth.scopes else "" + fields.append( + InputField(name="scopes", label="Scopes (comma-separated)", secret=False, default=default_scopes) + ) + + if flow_type == FlowType.API_KEY: + api_key_field = InputField(name="api_key", label="API Key", secret=True) + if provider.api_key and provider.api_key.key_pattern: + api_key_field.pattern = provider.api_key.key_pattern + api_key_field.pattern_hint = provider.api_key.key_pattern_hint + fields.append(api_key_field) + + return fields diff --git a/src/authsome/identity/__init__.py b/src/authsome/identity/__init__.py index 1d851dbe..380b7e93 100644 --- a/src/authsome/identity/__init__.py +++ b/src/authsome/identity/__init__.py @@ -30,8 +30,6 @@ ClaimStatus, IdentityClaimRecord, PrincipalRecord, - PrincipalVaultBindingRecord, - VaultRecord, ) from authsome.identity.proof import ( POP_AUTH_SCHEME, @@ -51,13 +49,11 @@ "IdentityStatus", "IdentityRegistration", "PrincipalRecord", - "PrincipalVaultBindingRecord", "POP_AUTH_SCHEME", "ProofClaims", "ProofValidationError", "ReplayCache", "RuntimeIdentity", - "VaultRecord", "current_from_home", "create_identity", "create_proof_jwt", diff --git a/src/authsome/identity/principal.py b/src/authsome/identity/principal.py index 1a9c01bb..b704bda9 100644 --- a/src/authsome/identity/principal.py +++ b/src/authsome/identity/principal.py @@ -1,4 +1,4 @@ -"""Domain models for principals, vaults, and ownership bindings. +"""Domain models for principals and identity claims. These are pure data models shared across server, cli, and identity modules. Filesystem-backed registry implementations live in server/registries.py. @@ -40,16 +40,6 @@ class PrincipalRecord(BaseModel): updated_at: datetime = Field(default_factory=utc_now) -# TODO: This should be a server property. The principal module should not care about which vault is owned by who -class VaultRecord(BaseModel): - """Vault record owned as a first-class resource.""" - - vault_id: str - handle: str = "default" - created_at: datetime = Field(default_factory=utc_now) - updated_at: datetime = Field(default_factory=utc_now) - - class IdentityClaimRecord(BaseModel): """Binding from identity to principal with lifecycle state.""" @@ -58,14 +48,3 @@ class IdentityClaimRecord(BaseModel): claim_status: ClaimStatus = ClaimStatus.PENDING created_at: datetime = Field(default_factory=utc_now) updated_at: datetime = Field(default_factory=utc_now) - - -# TODO: This should be a server property. The principal module should not care about which vault is owned by who -class PrincipalVaultBindingRecord(BaseModel): - """Binding from principal to a vault.""" - - principal_id: str - vault_id: str - is_default: bool = False - created_at: datetime = Field(default_factory=utc_now) - updated_at: datetime = Field(default_factory=utc_now) diff --git a/src/authsome/server/app.py b/src/authsome/server/app.py index 06ef7c2a..96dcd9d8 100644 --- a/src/authsome/server/app.py +++ b/src/authsome/server/app.py @@ -25,6 +25,7 @@ load_server_config, load_ui_session_signing_secret, ) +from authsome.server.provider_repository import ProviderRepository from authsome.server.routes.audit import router as audit_router from authsome.server.routes.auth import router as auth_router from authsome.server.routes.connections import router as connections_router @@ -52,7 +53,7 @@ async def lifespan(app: FastAPI): app.state.vault_registry = app.state.store.vaults app.state.identity_claim_registry = app.state.store.identity_claims app.state.principal_vault_binding_registry = app.state.store.principal_vault_bindings - app.state.provider_definition_repository = app.state.store.provider_definitions + app.state.provider_repository = ProviderRepository(app.state.store.provider_definitions) app.state.hosted_account_service = create_hosted_account_service(app.state.store) app.state.server_base_url = get_server_base_url() init_posthog() diff --git a/src/authsome/server/credential_repository.py b/src/authsome/server/credential_repository.py new file mode 100644 index 00000000..5aa2c482 --- /dev/null +++ b/src/authsome/server/credential_repository.py @@ -0,0 +1,248 @@ +"""Server-owned repository for credential records stored in Vault.""" + +from __future__ import annotations + +import json +from typing import NamedTuple + +from loguru import logger + +from authsome.auth.models.connection import ( + ConnectionRecord, + ProviderClientRecord, + ProviderMetadataRecord, + ProviderStateRecord, +) +from authsome.vault import Vault + + +class StoreKeyParts(NamedTuple): + """Parsed components of a credential store key.""" + + vault: str | None = None + identity: str | None = None + provider: str | None = None + record_type: str | None = None + connection: str | None = None + + +def build_store_key( + *, + vault: str | None = None, + identity: str | None = None, + provider: str | None = None, + record_type: str | None = None, + connection: str | None = None, +) -> str: + """Build a namespaced key for server-owned credential storage.""" + if record_type == "definition" and provider: + return f"provider:{provider}:definition" + if record_type == "server" and provider: + return f"server:provider:{provider}:client" + + if vault and provider: + if record_type == "metadata": + return f"vault:{vault}:{provider}:metadata" + if record_type == "state": + return f"vault:{vault}:{provider}:state" + if record_type == "connection" and connection: + return f"vault:{vault}:{provider}:connection:{connection}" + if record_type == "client": + return f"vault:{vault}:{provider}:client" + + if identity and provider: + if record_type == "metadata": + return f"identity:{identity}:{provider}:metadata" + if record_type == "state": + return f"identity:{identity}:{provider}:state" + if record_type == "connection" and connection: + return f"identity:{identity}:{provider}:connection:{connection}" + if record_type == "client": + return f"identity:{identity}:{provider}:client" + + raise ValueError( + f"Cannot build store key with vault={vault}, identity={identity}, provider={provider}, " + f"record_type={record_type}, connection={connection}" + ) + + +def parse_store_key(key: str) -> StoreKeyParts: + """Parse a credential store key into its components.""" + if key.startswith("provider:") and key.endswith(":definition"): + provider = key[len("provider:") : -len(":definition")] + return StoreKeyParts(provider=provider, record_type="definition") + + if key.startswith("server:provider:") and key.endswith(":client"): + provider = key[len("server:provider:") : -len(":client")] + return StoreKeyParts(provider=provider, record_type="server") + + if key.startswith("vault:"): + parts = key.split(":", 2) + if len(parts) < 3: + return StoreKeyParts() + vault = parts[1] + remainder = parts[2] + + if remainder.endswith(":metadata"): + return StoreKeyParts(vault=vault, provider=remainder[:-9], record_type="metadata") + if remainder.endswith(":state"): + return StoreKeyParts(vault=vault, provider=remainder[:-6], record_type="state") + if remainder.endswith(":client"): + return StoreKeyParts(vault=vault, provider=remainder[:-7], record_type="client") + + if ":connection:" in remainder: + provider, _, connection = remainder.partition(":connection:") + return StoreKeyParts( + vault=vault, + provider=provider, + record_type="connection", + connection=connection, + ) + + if key.startswith("identity:"): + parts = key.split(":", 2) + if len(parts) < 3: + return StoreKeyParts() + identity = parts[1] + remainder = parts[2] + + if remainder.endswith(":metadata"): + return StoreKeyParts(identity=identity, provider=remainder[:-9], record_type="metadata") + if remainder.endswith(":state"): + return StoreKeyParts(identity=identity, provider=remainder[:-6], record_type="state") + if remainder.endswith(":client"): + return StoreKeyParts(identity=identity, provider=remainder[:-7], record_type="client") + + if ":connection:" in remainder: + provider, _, connection = remainder.partition(":connection:") + return StoreKeyParts( + identity=identity, + provider=provider, + record_type="connection", + connection=connection, + ) + + return StoreKeyParts() + + +class CredentialRepository: + """Persist Authsome credential records in vault-backed namespaces.""" + + def __init__( + self, + vault: Vault, + *, + identity: str | None, + principal_id: str | None, + vault_id: str, + ) -> None: + self._vault = vault + self._identity = identity + self._principal_id = principal_id + self._vault_id = vault_id + + @property + def vault(self) -> Vault: + return self._vault + + @property + def vault_id(self) -> str: + return self._vault_id + + @property + def collection(self) -> str: + return f"vault:{self._vault_id}" + + @property + def server_collection(self) -> str: + return "server" + + async def list_connection_keys(self) -> list[str]: + prefix = f"vault:{self._vault_id}:" + keys = await self._vault.list(prefix, collection=self.collection) + return [key for key in keys if parse_store_key(key).record_type == "connection"] + + async def get_connection(self, provider: str, connection: str) -> ConnectionRecord | None: + key = build_store_key( + vault=self._vault_id, + provider=provider, + record_type="connection", + connection=connection, + ) + raw = await self._vault.get(key, collection=self.collection) + if raw is None: + return None + return self.load_connection_record(raw, key) + + async def save_connection(self, record: ConnectionRecord) -> None: + record.identity = self._identity + record.principal_id = self._principal_id + record.vault_id = self._vault_id + key = build_store_key( + vault=self._vault_id, + provider=record.provider, + record_type="connection", + connection=record.connection_name, + ) + await self._vault.put(key, record.model_dump_json(), collection=self.collection) + + async def delete_connection(self, provider: str, connection: str) -> None: + key = build_store_key( + vault=self._vault_id, + provider=provider, + record_type="connection", + connection=connection, + ) + await self._vault.delete(key, collection=self.collection) + + async def get_provider_metadata(self, provider: str) -> ProviderMetadataRecord | None: + key = build_store_key(vault=self._vault_id, provider=provider, record_type="metadata") + raw = await self._vault.get(key, collection=self.collection) + return ProviderMetadataRecord.model_validate_json(raw) if raw else None + + async def save_provider_metadata(self, record: ProviderMetadataRecord) -> None: + record.identity = self._identity + record.principal_id = self._principal_id + record.vault_id = self._vault_id + key = build_store_key(vault=self._vault_id, provider=record.provider, record_type="metadata") + await self._vault.put(key, record.model_dump_json(), collection=self.collection) + + async def delete_provider_metadata(self, provider: str) -> None: + key = build_store_key(vault=self._vault_id, provider=provider, record_type="metadata") + await self._vault.delete(key, collection=self.collection) + + async def get_provider_state(self, provider: str) -> ProviderStateRecord | None: + key = build_store_key(vault=self._vault_id, provider=provider, record_type="state") + raw = await self._vault.get(key, collection=self.collection) + return ProviderStateRecord.model_validate_json(raw) if raw else None + + async def save_provider_state(self, record: ProviderStateRecord) -> None: + record.identity = self._identity + record.principal_id = self._principal_id + record.vault_id = self._vault_id + key = build_store_key(vault=self._vault_id, provider=record.provider, record_type="state") + await self._vault.put(key, record.model_dump_json(), collection=self.collection) + + async def get_provider_client(self, provider: str) -> ProviderClientRecord | None: + key = build_store_key(provider=provider, record_type="server") + raw = await self._vault.get(key, collection=self.server_collection) + return ProviderClientRecord.model_validate_json(raw) if raw else None + + async def save_provider_client(self, record: ProviderClientRecord) -> None: + key = build_store_key(provider=record.provider, record_type="server") + await self._vault.put(key, record.model_dump_json(), collection=self.server_collection) + + async def delete_provider_client(self, provider: str) -> None: + key = build_store_key(provider=provider, record_type="server") + await self._vault.delete(key, collection=self.server_collection) + + @staticmethod + def load_connection_record(raw: str, key: str) -> ConnectionRecord | None: + try: + data = json.loads(raw) + except json.JSONDecodeError: + logger.warning("Corrupt record at key {}", key) + return None + if data.get("schema_version", 1) < 2: + return None + return ConnectionRecord.model_validate(data) diff --git a/src/authsome/server/credential_service.py b/src/authsome/server/credential_service.py index c297cd72..7980b195 100644 --- a/src/authsome/server/credential_service.py +++ b/src/authsome/server/credential_service.py @@ -6,7 +6,6 @@ from __future__ import annotations -import importlib.resources import json from datetime import timedelta from typing import Any @@ -16,7 +15,6 @@ from authsome import audit from authsome.auth.flows.api_key import ApiKeyFlow -from authsome.auth.flows.base import AuthFlow from authsome.auth.flows.browser import BrowserFlow from authsome.auth.flows.dcr_pkce import DcrPkceFlow from authsome.auth.flows.device_code import DeviceCodeFlow @@ -31,33 +29,32 @@ from authsome.auth.models.enums import AuthType, ConnectionStatus, ExportFormat, FlowType from authsome.auth.models.provider import ProviderDefinition from authsome.auth.sessions import AuthSession -from authsome.auth.utils import export_name_part, normalize_base_url, normalize_scopes +from authsome.auth.utils import ( + export_name_part, + normalize_base_url, + normalize_scopes, + required_inputs, + validate_provider_definition, +) from authsome.errors import ( - AuthsomeError, ConnectionNotFoundError, CredentialMissingError, IdentityNotFoundError, InvalidProviderSchemaError, OperationNotAllowedError, - ProviderNotFoundError, RefreshFailedError, TokenExpiredError, UnsupportedFlowError, ) from authsome.identity.principal import PrincipalRole -from authsome.server.store.repositories import ProviderDefinitionRepository -from authsome.utils import build_store_key, format_duration, is_filesystem_safe, parse_store_key, utc_now +from authsome.server.credential_repository import CredentialRepository, parse_store_key +from authsome.server.provider_repository import ProviderRepository +from authsome.utils import format_duration, utc_now from authsome.vault import Vault -_VALID_FLOWS: dict[AuthType, set[FlowType]] = { - AuthType.OAUTH2: {FlowType.PKCE, FlowType.DEVICE_CODE, FlowType.DCR_PKCE}, - AuthType.API_KEY: {FlowType.API_KEY}, - AuthType.BROWSER: {FlowType.BROWSER}, -} - _NEAR_EXPIRY_SECONDS = 300 -_FLOW_HANDLERS: dict[FlowType, type[AuthFlow]] = { +_FLOW_HANDLERS = { FlowType.PKCE: PkceFlow, FlowType.DEVICE_CODE: DeviceCodeFlow, FlowType.DCR_PKCE: DcrPkceFlow, @@ -70,42 +67,29 @@ class AuthService: """ Authentication and credential lifecycle service. - All reads and writes go through self._vault. - Key construction (vault:::...) lives here. + Coordinates provider lookup, auth flows, credential persistence, and policy checks. """ def __init__( self, - vault: Vault, - provider_definitions: ProviderDefinitionRepository, + *, + credentials: CredentialRepository, + providers: ProviderRepository, identity: str | None = None, principal_id: str | None = None, principal_role: PrincipalRole = PrincipalRole.USER, vault_id: str | None = None, ) -> None: - self._vault = vault + self._credentials = credentials self._identity = identity self._principal_id = principal_id + self._vault_id = vault_id or credentials.vault_id self._principal_role = principal_role - self._vault_id = vault_id - self._provider_definitions = provider_definitions - self._bundled: dict[str, ProviderDefinition] = self._load_bundled_providers() - - @property - def _coll(self) -> str: - """Vault collection for the resolved credential scope.""" - if self._vault_id is None: - raise ValueError("AuthService.vault_id is required for vault-scoped operations but was not set") - return f"vault:{self._vault_id}" - - @property - def _server_coll(self) -> str: - """Vault collection for server-scoped provider client records.""" - return "server" + self._providers = providers @property def vault(self) -> Vault: - return self._vault + return self._credentials.vault @property def identity(self) -> str | None: @@ -131,46 +115,18 @@ def vault_id(self) -> str | None: # ── Provider operations ─────────────────────────────────────────────── - @staticmethod - def _load_bundled_providers() -> dict[str, ProviderDefinition]: - bundled: dict[str, ProviderDefinition] = {} - try: - files = importlib.resources.files("authsome.auth.bundled_providers") - for file in files.iterdir(): - if file.name.endswith(".json"): - with file.open("r", encoding="utf-8") as f: - data = json.load(f) - defn = ProviderDefinition.model_validate(data) - bundled[defn.name] = defn - except Exception as e: - logger.warning("Error loading bundled providers: {}", e) - return bundled - - async def _load_custom_providers(self) -> dict[str, ProviderDefinition]: - providers = await self._provider_definitions.list() - return {provider.name: provider for provider in providers} - async def list_providers(self) -> list[ProviderDefinition]: - providers = {**self._bundled, **(await self._load_custom_providers())} - return sorted(providers.values(), key=lambda p: p.name) + return await self._providers.list() async def list_providers_by_source(self) -> dict[str, list[ProviderDefinition]]: - bundled_list = sorted(self._bundled.values(), key=lambda p: p.name) - custom_providers = await self._load_custom_providers() - custom_list = sorted(custom_providers.values(), key=lambda p: p.name) - return {"bundled": bundled_list, "custom": custom_list} + return await self._providers.list_by_source() async def get_provider(self, provider: str) -> ProviderDefinition: - custom = await self._provider_definitions.get(provider) - if custom is not None: - return custom - if provider in self._bundled: - return self._bundled[provider] - raise ProviderNotFoundError(provider) + return await self._providers.get(provider) async def is_local_provider(self, provider: str) -> bool: """Check if a provider is a custom/local provider.""" - return await self._provider_definitions.get(provider) is not None + return await self._providers.is_custom(provider) async def resolve_credentials(self, **kwargs: Any) -> dict[str, Any]: """Resolve credentials for a provider/connection pair.""" @@ -189,12 +145,12 @@ async def resolve_credentials(self, **kwargs: Any) -> dict[str, Any]: async def register_provider(self, definition: ProviderDefinition, *, force: bool = False) -> None: self._ensure_admin_operation_allowed("register", definition.name) self._validate_provider(definition) - await self._provider_definitions.save(definition, force=force) + await self._providers.save_custom(definition, force=force) logger.info("Registered provider: {}", definition.name) async def remove_provider(self, name: str) -> bool: """Remove a custom provider. Returns True if removed.""" - return await self._provider_definitions.delete(name) + return await self._providers.delete_custom(name) def _ensure_admin_operation_allowed(self, operation: str, provider: str) -> None: if self._principal_role == PrincipalRole.ADMIN: @@ -215,33 +171,7 @@ def _ensure_provider_client_mutation_allowed(self, provider: str) -> None: ) def _validate_provider(self, definition: ProviderDefinition) -> None: - if not is_filesystem_safe(definition.name): - raise InvalidProviderSchemaError( - f"Provider name '{definition.name}' is not filesystem-safe", provider=definition.name - ) - valid_flows = _VALID_FLOWS.get(definition.auth_type) - if valid_flows is None: - raise InvalidProviderSchemaError( - f"Unrecognized auth_type: {definition.auth_type}", provider=definition.name - ) - if definition.flow not in valid_flows: - raise InvalidProviderSchemaError( - f"Flow '{definition.flow}' is not valid for auth_type '{definition.auth_type}'. " - f"Valid flows: {[f.value for f in valid_flows]}", - provider=definition.name, - ) - if definition.auth_type == AuthType.OAUTH2 and definition.oauth is None: - raise InvalidProviderSchemaError( - "auth_type 'oauth2' requires an 'oauth' configuration section", provider=definition.name - ) - if definition.auth_type == AuthType.API_KEY and definition.api_key is None: - raise InvalidProviderSchemaError( - "auth_type 'api_key' requires an 'api_key' configuration section", provider=definition.name - ) - if definition.auth_type == AuthType.BROWSER and definition.browser is None: - raise InvalidProviderSchemaError( - "auth_type 'browser' requires a 'browser' configuration section", provider=definition.name - ) + validate_provider_definition(definition) if definition.oauth: for field_name in ("authorization_url", "token_url"): url = getattr(definition.oauth, field_name, None) @@ -259,8 +189,7 @@ def _validate_url(url: str, field_name: str, provider_name: str) -> None: # ── Connection operations ───────────────────────────────────────────── async def list_connections(self) -> list[dict[str, Any]]: - prefix = f"vault:{self._vault_id}:" - keys = await self._vault.list(prefix, collection=self._coll) + keys = await self._credentials.list_connection_keys() providers: dict[str, list[dict[str, Any]]] = {} defaults: dict[str, str] = {} @@ -270,33 +199,25 @@ async def list_connections(self) -> list[dict[str, Any]]: provider_name = parts.provider connection_name = parts.connection if provider_name not in defaults: - meta_key = build_store_key(vault=self._vault_id, provider=provider_name, record_type="metadata") - meta_json = await self._vault.get(meta_key, collection=self._coll) - if meta_json: - defaults[provider_name] = ProviderMetadataRecord.model_validate_json( - meta_json - ).default_connection - else: - defaults[provider_name] = "default" - record_json = await self._vault.get(key, collection=self._coll) - if record_json: - record = self._load_connection_record(record_json, key) - if record is None: - continue - if provider_name not in providers: - providers[provider_name] = [] - providers[provider_name].append( - { - "connection_name": connection_name, - "is_default": connection_name == defaults.get(provider_name, "default"), - "auth_type": record.auth_type.value, - "status": record.status.value, - "scopes": record.scopes, - "base_url": record.base_url, - "api_url": record.api_url, - "expires_at": record.expires_at.isoformat() if record.expires_at else None, - } - ) + metadata = await self._credentials.get_provider_metadata(provider_name) + defaults[provider_name] = metadata.default_connection if metadata else "default" + record = await self._credentials.get_connection(provider_name, connection_name) + if record is None: + continue + if provider_name not in providers: + providers[provider_name] = [] + providers[provider_name].append( + { + "connection_name": connection_name, + "is_default": connection_name == defaults.get(provider_name, "default"), + "auth_type": record.auth_type.value, + "status": record.status.value, + "scopes": record.scopes, + "base_url": record.base_url, + "api_url": record.api_url, + "expires_at": record.expires_at.isoformat() if record.expires_at else None, + } + ) return [ {"name": pname, "default_connection": defaults.get(pname, "default"), "connections": conns} @@ -309,32 +230,21 @@ async def get_connection( connection: str = "default", ) -> ConnectionRecord: connection = await self.resolve_connection_name(provider, connection) - key = build_store_key(vault=self._vault_id, provider=provider, record_type="connection", connection=connection) - record_json = await self._vault.get(key, collection=self._coll) - if not record_json: + record = await self._credentials.get_connection(provider, connection) + if record is None: raise ConnectionNotFoundError( provider=provider, connection=connection, identity=self._identity or self._principal_id or "hosted-ui", ) - record = self._load_connection_record(record_json, key) - if record is None: - raise AuthsomeError( - f"Stored credentials for '{provider}' use the old v1 format. " - "Please run: authsome provider revoke {provider} && authsome login {provider}" - ) return record async def resolve_connection_name(self, provider: str, connection: str | None = None) -> str: """Resolve an optional connection name to the provider default.""" if connection: return connection - meta_key = build_store_key(vault=self._vault_id, provider=provider, record_type="metadata") - existing_json = await self._vault.get(meta_key, collection=self._coll) - if existing_json: - metadata = ProviderMetadataRecord.model_validate_json(existing_json) - return metadata.default_connection - return "default" + metadata = await self._credentials.get_provider_metadata(provider) + return metadata.default_connection if metadata else "default" async def get_provider_client(self, provider: str) -> ProviderClientRecord | None: """Return stored client credentials for a provider, or None if absent. @@ -411,11 +321,8 @@ async def update_provider_configuration( async def set_default_connection(self, provider: str, connection: str) -> None: """Set the default connection for a provider.""" await self.get_connection(provider, connection) - meta_key = build_store_key(vault=self._vault_id, provider=provider, record_type="metadata") - existing_json = await self._vault.get(meta_key, collection=self._coll) - if existing_json: - metadata = ProviderMetadataRecord.model_validate_json(existing_json) - else: + metadata = await self._credentials.get_provider_metadata(provider) + if metadata is None: metadata = ProviderMetadataRecord( identity=self._identity, principal_id=self._principal_id, @@ -426,7 +333,7 @@ async def set_default_connection(self, provider: str, connection: str) -> None: metadata.connection_names.append(connection) metadata.default_connection = connection metadata.last_used_connection = connection - await self._vault.put(meta_key, metadata.model_dump_json(), collection=self._coll) + await self._credentials.save_provider_metadata(metadata) # ── Authentication ──────────────────────────────────────────────────── @@ -437,80 +344,17 @@ async def get_required_inputs( base_url: str | None = None, ) -> list[InputField]: """Determine what inputs are missing for a given session.""" - from authsome.auth.input_provider import InputField - provider = session.provider definition = await self.get_provider(provider) - flow_type = FlowType(session.flow_type) client_record = await self._get_provider_client_credentials(provider) - provider_config_only = bool(session.payload.get("provider_config_only")) - - flow_base_url = base_url or (client_record.base_url if client_record else None) - flow_client_id = client_record.client_id if client_record else None - persisted_scopes = client_record.scopes if client_record else None - - fields: list[InputField] = [] - - if definition.oauth and definition.oauth.base_url and (provider_config_only or not flow_base_url): - fields.append( - InputField( - name="base_url", - label="Base URL", - secret=False, - default=flow_base_url or definition.oauth.base_url, - ) - ) - fields.append( - InputField( - name="api_url", - label="API Host URL", - secret=False, - default=( - client_record.api_url - if client_record and client_record.api_url - else definition.primary_api_url() or "" - ), - ) - ) - - if flow_type == FlowType.PKCE and (provider_config_only or not flow_client_id): - fields.append( - InputField( - name="client_id", - label="Client ID", - secret=False, - default=flow_client_id or "", - ) - ) - fields.append(InputField(name="client_secret", label="Client Secret", secret=True, default="")) - elif flow_type == FlowType.DEVICE_CODE and (provider_config_only or not flow_client_id): - fields.append( - InputField( - name="client_id", - label="Client ID", - secret=False, - default=flow_client_id or "", - ) - ) - fields.append(InputField(name="client_secret", label="Client Secret (Optional)", secret=True, default="")) - - if flow_type in (FlowType.PKCE, FlowType.DEVICE_CODE, FlowType.DCR_PKCE): - if scopes is None and persisted_scopes is None: - default_scopes = ( - ",".join(definition.oauth.scopes) if definition.oauth and definition.oauth.scopes else "" - ) - fields.append( - InputField(name="scopes", label="Scopes (comma-separated)", secret=False, default=default_scopes) - ) - - if flow_type == FlowType.API_KEY: - api_key_field = InputField(name="api_key", label="API Key", secret=True) - if definition.api_key and definition.api_key.key_pattern: - api_key_field.pattern = definition.api_key.key_pattern - api_key_field.pattern_hint = definition.api_key.key_pattern_hint - fields.append(api_key_field) - - return fields + return required_inputs( + provider=definition, + flow_type=FlowType(session.flow_type), + client_record=client_record, + scopes=scopes, + base_url=base_url, + provider_config_only=bool(session.payload.get("provider_config_only")), + ) async def save_inputs(self, session: AuthSession, inputs: dict[str, str]) -> None: """Save collected inputs to the Vault or session payload.""" @@ -693,17 +537,9 @@ def _build_docs_hints(definition: ProviderDefinition, flow_type: FlowType) -> li """Convert provider docs URL into a bridge instruction block.""" if not definition.docs_url: return [] - if flow_type not in (FlowType.PKCE, FlowType.DEVICE_CODE, FlowType.DCR_PKCE, FlowType.API_KEY): return [] - - return [ - { - "type": "instructions", - "label": "Instructions", - "url": definition.docs_url, - } - ] + return [{"type": "instructions", "label": "Instructions", "url": definition.docs_url}] # ── Token operations ────────────────────────────────────────────────── @@ -741,8 +577,7 @@ async def logout(self, provider: str, connection: str = "default") -> None: client_secret=client_secret, ) - key = build_store_key(vault=self._vault_id, provider=provider, record_type="connection", connection=connection) - await self._vault.delete(key, collection=self._coll) + await self._credentials.delete_connection(provider, connection) await self._remove_from_provider_metadata(provider, connection) async def revoke(self, provider: str, vault_ids: list[str] | None = None) -> None: @@ -755,33 +590,36 @@ async def revoke(self, provider: str, vault_ids: list[str] | None = None) -> Non await self.get_provider(provider) ids_to_revoke = vault_ids if vault_ids is not None else ([self._vault_id] if self._vault_id else []) for vault_id in ids_to_revoke: + credentials = CredentialRepository( + self.vault, + identity=self._identity, + principal_id=self._principal_id, + vault_id=vault_id, + ) vault_service = AuthService( - vault=self._vault, + credentials=credentials, + providers=self._providers, identity=self._identity, principal_id=self._principal_id, principal_role=self._principal_role, vault_id=vault_id, - provider_definitions=self._provider_definitions, ) - meta_key = build_store_key(vault=vault_id, provider=provider, record_type="metadata") - existing_json = await self._vault.get(meta_key, collection=vault_service._coll) - if not existing_json: + metadata = await credentials.get_provider_metadata(provider) + if metadata is None: continue - metadata = ProviderMetadataRecord.model_validate_json(existing_json) for conn_name in list(metadata.connection_names): await vault_service.logout(provider, connection=conn_name) - await self._vault.delete(meta_key, collection=vault_service._coll) + await credentials.delete_provider_metadata(provider) - client_key = build_store_key(provider=provider, record_type="server") - await self._vault.delete(client_key, collection=self._server_coll) + await self._credentials.delete_provider_client(provider) async def remove(self, provider: str) -> None: """Revoke all tokens and remove the provider definition if it is local.""" self._ensure_admin_operation_allowed("remove", provider) await self.revoke(provider) if await self.is_local_provider(provider): - await self._provider_definitions.delete(provider) + await self._providers.delete_custom(provider) logger.info("Removed local provider definition: {}", provider) else: logger.info("Revoked bundled provider: {} (definition kept)", provider) @@ -888,48 +726,18 @@ async def get_identity(self, name: str) -> str: # ── Internal helpers ────────────────────────────────────────────────── - def _load_connection_record(self, record_json: str, key: str) -> ConnectionRecord | None: - """Load and validate a connection record, detecting v1 format.""" - try: - data = json.loads(record_json) - except json.JSONDecodeError: - logger.warning("Corrupt record at key {}", key) - return None - - if data.get("schema_version", 1) < 2: - return None # v1 data — caller handles detection - - return ConnectionRecord.model_validate(data) - async def _save_connection(self, record: ConnectionRecord) -> None: - record.identity = self._identity - record.principal_id = self._principal_id - record.vault_id = self._vault_id - key = build_store_key( - vault=self._vault_id, - provider=record.provider, - record_type="connection", - connection=record.connection_name, - ) - await self._vault.put(key, record.model_dump_json(), collection=self._coll) + await self._credentials.save_connection(record) async def _get_provider_client_credentials(self, provider: str) -> ProviderClientRecord | None: - key = build_store_key(provider=provider, record_type="server") - record_json = await self._vault.get(key, collection=self._server_coll) - if record_json: - return ProviderClientRecord.model_validate_json(record_json) - return None + return await self._credentials.get_provider_client(provider) async def _save_provider_client_credentials(self, record: ProviderClientRecord) -> None: - key = build_store_key(provider=record.provider, record_type="server") - await self._vault.put(key, record.model_dump_json(), collection=self._server_coll) + await self._credentials.save_provider_client(record) async def _update_provider_metadata(self, provider: str, connection_name: str) -> None: - meta_key = build_store_key(vault=self._vault_id, provider=provider, record_type="metadata") - existing_json = await self._vault.get(meta_key, collection=self._coll) - if existing_json: - metadata = ProviderMetadataRecord.model_validate_json(existing_json) - else: + metadata = await self._credentials.get_provider_metadata(provider) + if metadata is None: metadata = ProviderMetadataRecord( identity=self._identity, principal_id=self._principal_id, @@ -939,18 +747,17 @@ async def _update_provider_metadata(self, provider: str, connection_name: str) - if connection_name not in metadata.connection_names: metadata.connection_names.append(connection_name) metadata.last_used_connection = connection_name - await self._vault.put(meta_key, metadata.model_dump_json(), collection=self._coll) + await self._credentials.save_provider_metadata(metadata) async def _remove_from_provider_metadata(self, provider: str, connection_name: str) -> None: - meta_key = build_store_key(vault=self._vault_id, provider=provider, record_type="metadata") - existing_json = await self._vault.get(meta_key, collection=self._coll) - if existing_json: - metadata = ProviderMetadataRecord.model_validate_json(existing_json) - if connection_name in metadata.connection_names: - metadata.connection_names.remove(connection_name) - if metadata.last_used_connection == connection_name: - metadata.last_used_connection = metadata.connection_names[0] if metadata.connection_names else None - await self._vault.put(meta_key, metadata.model_dump_json(), collection=self._coll) + metadata = await self._credentials.get_provider_metadata(provider) + if metadata is None: + return + if connection_name in metadata.connection_names: + metadata.connection_names.remove(connection_name) + if metadata.last_used_connection == connection_name: + metadata.last_used_connection = metadata.connection_names[0] if metadata.connection_names else None + await self._credentials.save_provider_metadata(metadata) def _get_api_key(self, record: ConnectionRecord) -> str: if record.api_key is None: @@ -1059,10 +866,9 @@ async def _refresh_token(self, record: ConnectionRecord, provider_name: str) -> return record async def _get_or_create_provider_state(self, provider: str) -> ProviderStateRecord: - key = build_store_key(vault=self._vault_id, provider=provider, record_type="state") - existing = await self._vault.get(key, collection=self._coll) + existing = await self._credentials.get_provider_state(provider) if existing: - return ProviderStateRecord.model_validate_json(existing) + return existing return ProviderStateRecord( provider=provider, identity=self._identity, @@ -1071,11 +877,7 @@ async def _get_or_create_provider_state(self, provider: str) -> ProviderStateRec ) async def _save_provider_state(self, state: ProviderStateRecord) -> None: - state.identity = self._identity - state.principal_id = self._principal_id - state.vault_id = self._vault_id - key = build_store_key(vault=self._vault_id, provider=state.provider, record_type="state") - await self._vault.put(key, state.model_dump_json(), collection=self._coll) + await self._credentials.save_provider_state(state) async def _get_access_token_from_record(self, record: ConnectionRecord) -> str: if record.auth_type == AuthType.API_KEY: diff --git a/src/authsome/server/dependencies.py b/src/authsome/server/dependencies.py index e0ab7d5e..9340bf7e 100644 --- a/src/authsome/server/dependencies.py +++ b/src/authsome/server/dependencies.py @@ -16,9 +16,11 @@ from authsome.paths import get_server_audit_db_path as _get_server_audit_db_path from authsome.paths import get_server_home as _get_server_home from authsome.paths import get_server_log_path as _get_server_log_path +from authsome.server.credential_repository import CredentialRepository from authsome.server.hosted_auth import HostedAccountService from authsome.server.identity_bootstrap import IdentityBootstrapService from authsome.server.ownership import OwnershipResolver +from authsome.server.provider_repository import ProviderRepository from authsome.server.secrets import load_master_secret, load_ui_session_signing_secret from authsome.server.store import ServerStore from authsome.server.store import create_server_store as _create_server_store @@ -93,6 +95,21 @@ async def create_vault(home: Path) -> Vault: return Vault(encrypted_kv) +def create_credential_repository( + *, + vault: Vault, + identity: str | None, + principal_id: str | None, + vault_id: str, +) -> CredentialRepository: + return CredentialRepository( + vault, + identity=identity, + principal_id=principal_id, + vault_id=vault_id, + ) + + async def create_auth_service( home: Path | None = None, identity: str | None = None, vault_id: str | None = None ) -> AuthService: @@ -106,10 +123,15 @@ async def create_auth_service( store = await create_store(home) vault = await create_vault(store.home) return AuthService( - vault=vault, + credentials=create_credential_repository( + vault=vault, + identity=identity, + principal_id=None, + vault_id=vault_id, + ), + providers=ProviderRepository(store.provider_definitions), identity=identity, vault_id=vault_id, - provider_definitions=store.provider_definitions, ) diff --git a/src/authsome/server/provider_repository.py b/src/authsome/server/provider_repository.py new file mode 100644 index 00000000..415a8ed3 --- /dev/null +++ b/src/authsome/server/provider_repository.py @@ -0,0 +1,67 @@ +"""Server-owned provider definition repository.""" + +from __future__ import annotations + +import builtins +import importlib.resources +import json + +from loguru import logger + +from authsome.auth.models.provider import ProviderDefinition +from authsome.errors import ProviderNotFoundError +from authsome.server.store.repositories import ProviderDefinitionRepository + + +class ProviderRepository: + """Resolve bundled and custom provider definitions.""" + + def __init__(self, custom: ProviderDefinitionRepository) -> None: + self._custom = custom + self._bundled: dict[str, ProviderDefinition] | None = None + + def _load_bundled(self) -> dict[str, ProviderDefinition]: + if self._bundled is not None: + return self._bundled + + bundled: dict[str, ProviderDefinition] = {} + try: + files = importlib.resources.files("authsome.auth.bundled_providers") + for file in files.iterdir(): + if file.name.endswith(".json"): + with file.open("r", encoding="utf-8") as handle: + definition = ProviderDefinition.model_validate(json.load(handle)) + bundled[definition.name] = definition + except Exception as exc: + logger.warning("Error loading bundled providers: {}", exc) + + self._bundled = bundled + return bundled + + async def get(self, name: str) -> ProviderDefinition: + custom = await self._custom.get(name) + if custom is not None: + return custom + bundled = self._load_bundled() + if name in bundled: + return bundled[name] + raise ProviderNotFoundError(name) + + async def list(self) -> builtins.list[ProviderDefinition]: + providers = {**self._load_bundled()} + providers.update({provider.name: provider for provider in await self._custom.list()}) + return sorted(providers.values(), key=lambda provider: provider.name) + + async def list_by_source(self) -> dict[str, builtins.list[ProviderDefinition]]: + bundled = sorted(self._load_bundled().values(), key=lambda provider: provider.name) + custom = sorted(await self._custom.list(), key=lambda provider: provider.name) + return {"bundled": bundled, "custom": custom} + + async def save_custom(self, definition: ProviderDefinition, *, force: bool = False) -> None: + await self._custom.save(definition, force=force) + + async def delete_custom(self, name: str) -> bool: + return await self._custom.delete(name) + + async def is_custom(self, name: str) -> bool: + return await self._custom.get(name) is not None diff --git a/src/authsome/server/routes/_deps.py b/src/authsome/server/routes/_deps.py index 11b64d0f..f08a7f3d 100644 --- a/src/authsome/server/routes/_deps.py +++ b/src/authsome/server/routes/_deps.py @@ -7,6 +7,7 @@ from authsome.auth.sessions import AuthSessionStore from authsome.identity.principal import PrincipalRole from authsome.identity.proof import POP_AUTH_SCHEME, ProofValidationError, validate_proof_jwt +from authsome.server.credential_repository import CredentialRepository from authsome.server.credential_service import AuthService from authsome.server.store.repositories import VaultRegistry from authsome.server.ui_sessions import UiSessionStore @@ -14,6 +15,30 @@ UI_SESSION_COOKIE_NAME = "authsome_ui_session" +def build_auth_service( + request: Request, + *, + identity: str | None, + principal_id: str | None, + principal_role: PrincipalRole = PrincipalRole.USER, + vault_id: str, +) -> AuthService: + credentials = CredentialRepository( + request.app.state.vault, + identity=identity, + principal_id=principal_id, + vault_id=vault_id, + ) + return AuthService( + credentials=credentials, + providers=request.app.state.provider_repository, + identity=identity, + principal_id=principal_id, + principal_role=principal_role, + vault_id=vault_id, + ) + + async def get_auth_service( request: Request, *, @@ -25,13 +50,12 @@ async def get_auth_service( if resolved is None: resolved = await request.app.state.ownership_resolver.resolve(identity=identity) request.app.state.ownership_cache[identity] = resolved - return AuthService( - vault=request.app.state.vault, + return build_auth_service( + request, identity=identity, principal_id=resolved.principal_id, principal_role=resolved.role, vault_id=resolved.vault_id, - provider_definitions=request.app.state.provider_definition_repository, ) if principal_id is None: @@ -43,13 +67,12 @@ async def get_auth_service( principal = await request.app.state.store.principals.get(principal_id) if principal is None: return None - return AuthService( - vault=request.app.state.vault, + return build_auth_service( + request, identity=None, principal_id=principal_id, principal_role=principal.role, vault_id=binding.vault_id, - provider_definitions=request.app.state.provider_definition_repository, ) diff --git a/src/authsome/server/routes/ui.py b/src/authsome/server/routes/ui.py index c3592893..c3eeb516 100644 --- a/src/authsome/server/routes/ui.py +++ b/src/authsome/server/routes/ui.py @@ -26,6 +26,7 @@ from authsome.server.credential_service import AuthService from authsome.server.routes._deps import ( UI_SESSION_COOKIE_NAME, + build_auth_service, get_auth_service, get_auth_sessions, get_protected_auth_service, @@ -328,12 +329,11 @@ async def _provider_connection_groups( groups: list[dict[str, Any]] = [] for binding in await bindings.list_for_principal(principal_id): - scoped_auth = AuthService( - vault=request.app.state.vault, + scoped_auth = build_auth_service( + request, identity=identity, principal_id=principal_id, vault_id=binding.vault_id, - provider_definitions=request.app.state.provider_definition_repository, ) provider_connections = next( (group["connections"] for group in await scoped_auth.list_connections() if group["name"] == provider_name), diff --git a/src/authsome/server/schemas.py b/src/authsome/server/schemas.py index c4900019..0b3af3f1 100644 --- a/src/authsome/server/schemas.py +++ b/src/authsome/server/schemas.py @@ -11,6 +11,8 @@ from pydantic import BaseModel, Field +from authsome.utils import utc_now + class HealthResponse(BaseModel): status: Literal["ok"] @@ -105,3 +107,22 @@ class ProviderRoute(BaseModel): class ProxyRoutesResponse(BaseModel): routes: list[ProviderRoute] + + +class VaultRecord(BaseModel): + """Vault record owned as a first-class server resource.""" + + vault_id: str + handle: str = "default" + created_at: datetime = Field(default_factory=utc_now) + updated_at: datetime = Field(default_factory=utc_now) + + +class PrincipalVaultBindingRecord(BaseModel): + """Server-owned binding from principal to vault.""" + + principal_id: str + vault_id: str + is_default: bool = False + created_at: datetime = Field(default_factory=utc_now) + updated_at: datetime = Field(default_factory=utc_now) diff --git a/src/authsome/server/store/repositories.py b/src/authsome/server/store/repositories.py index e5984e4f..f5cab606 100644 --- a/src/authsome/server/store/repositories.py +++ b/src/authsome/server/store/repositories.py @@ -18,10 +18,9 @@ IdentityClaimRecord, PrincipalRecord, PrincipalRole, - PrincipalVaultBindingRecord, - VaultRecord, ) from authsome.identity.registry import IdentityRegistration +from authsome.server.schemas import PrincipalVaultBindingRecord, VaultRecord from authsome.server.store.database import StoreBackend, StoreDatabase from authsome.utils import utc_now diff --git a/src/authsome/utils.py b/src/authsome/utils.py index ba61efec..8be90c57 100644 --- a/src/authsome/utils.py +++ b/src/authsome/utils.py @@ -4,21 +4,11 @@ import re from datetime import UTC, datetime -from typing import Any, NamedTuple +from typing import Any from authsome.errors import AuthsomeError -class StoreKeyParts(NamedTuple): - """Parsed components of a credential store key.""" - - vault: str | None = None - identity: str | None = None - provider: str | None = None - record_type: str | None = None - connection: str | None = None - - def utc_now() -> datetime: """Return the current UTC datetime.""" return datetime.now(UTC) @@ -70,124 +60,6 @@ def is_filesystem_safe(name: str) -> bool: return True -def build_store_key( - *, - vault: str | None = None, - identity: str | None = None, - provider: str | None = None, - record_type: str | None = None, - connection: str | None = None, -) -> str: - """ - Build a namespaced key for the credential store. - - Spec §10.1 key namespace: - provider::definition - server:provider::client - vault:::metadata - vault:::state - vault:::connection: - vault:::client - identity:::metadata - identity:::state - identity:::connection: - identity:::client - """ - if record_type == "definition" and provider: - return f"provider:{provider}:definition" - if record_type == "server" and provider: - return f"server:provider:{provider}:client" - - if vault and provider: - if record_type == "metadata": - return f"vault:{vault}:{provider}:metadata" - elif record_type == "state": - return f"vault:{vault}:{provider}:state" - elif record_type == "connection" and connection: - return f"vault:{vault}:{provider}:connection:{connection}" - elif record_type == "client": - return f"vault:{vault}:{provider}:client" - - if identity and provider: - if record_type == "metadata": - return f"identity:{identity}:{provider}:metadata" - elif record_type == "state": - return f"identity:{identity}:{provider}:state" - elif record_type == "connection" and connection: - return f"identity:{identity}:{provider}:connection:{connection}" - elif record_type == "client": - return f"identity:{identity}:{provider}:client" - - raise ValueError( - f"Cannot build store key with vault={vault}, identity={identity}, provider={provider}, " - f"record_type={record_type}, connection={connection}" - ) - - -def parse_store_key(key: str) -> StoreKeyParts: - """ - Parse a credential store key into its components. - - Safely handles provider and connection names that may contain colons. - """ - if key.startswith("provider:") and key.endswith(":definition"): - provider = key[len("provider:") : -len(":definition")] - return StoreKeyParts(provider=provider, record_type="definition") - - if key.startswith("server:provider:") and key.endswith(":client"): - provider = key[len("server:provider:") : -len(":client")] - return StoreKeyParts(provider=provider, record_type="server") - - if key.startswith("vault:"): - parts = key.split(":", 2) - if len(parts) < 3: - return StoreKeyParts() - vault = parts[1] - remainder = parts[2] - - if remainder.endswith(":metadata"): - return StoreKeyParts(vault=vault, provider=remainder[:-9], record_type="metadata") - if remainder.endswith(":state"): - return StoreKeyParts(vault=vault, provider=remainder[:-6], record_type="state") - if remainder.endswith(":client"): - return StoreKeyParts(vault=vault, provider=remainder[:-7], record_type="client") - - if ":connection:" in remainder: - provider, _, connection = remainder.partition(":connection:") - return StoreKeyParts( - vault=vault, - provider=provider, - record_type="connection", - connection=connection, - ) - - if key.startswith("identity:"): - # Format: identity:: - parts = key.split(":", 2) - if len(parts) < 3: - return StoreKeyParts() - identity = parts[1] - remainder = parts[2] - - if remainder.endswith(":metadata"): - return StoreKeyParts(identity=identity, provider=remainder[:-9], record_type="metadata") - if remainder.endswith(":state"): - return StoreKeyParts(identity=identity, provider=remainder[:-6], record_type="state") - if remainder.endswith(":client"): - return StoreKeyParts(identity=identity, provider=remainder[:-7], record_type="client") - - if ":connection:" in remainder: - provider, _, connection = remainder.partition(":connection:") - return StoreKeyParts( - identity=identity, - provider=provider, - record_type="connection", - connection=connection, - ) - - return StoreKeyParts() - - def redact(record: Any, redacted_value: str = "***REDACTED***") -> dict[str, Any]: """ Return a dict of a Pydantic model with Sensitive-annotated fields replaced. diff --git a/tests/auth/test_browser_service.py b/tests/auth/test_browser_service.py index 0f8727d8..28ab5772 100644 --- a/tests/auth/test_browser_service.py +++ b/tests/auth/test_browser_service.py @@ -11,14 +11,39 @@ from authsome.auth.models.enums import AuthType, ConnectionStatus, FlowType from authsome.auth.models.provider import BrowserConfig, ExtractRule, ProviderDefinition from authsome.errors import CredentialMissingError, TokenExpiredError +from authsome.server.credential_repository import CredentialRepository from authsome.server.credential_service import AuthService from authsome.utils import utc_now +class StaticProviders: + async def get(self, name: str): # noqa: ANN001, ANN201 + return _provider() + + async def list(self): # noqa: ANN201 + return [_provider()] + + async def list_by_source(self): # noqa: ANN201 + return {"bundled": [_provider()], "custom": []} + + async def save_custom(self, definition, *, force: bool = False) -> None: # noqa: ANN001 + raise AssertionError("unexpected provider save") + + async def delete_custom(self, name: str) -> bool: + return False + + async def is_custom(self, name: str) -> bool: + return False + + def _svc() -> AuthService: vault = MagicMock() return AuthService( - vault=vault, identity="agent", principal_id="p1", vault_id="v1", provider_definitions=[_provider()] + credentials=CredentialRepository(vault, identity="agent", principal_id="p1", vault_id="v1"), + providers=StaticProviders(), + identity="agent", + principal_id="p1", + vault_id="v1", ) diff --git a/tests/auth/test_service.py b/tests/auth/test_service.py index eb45cad8..7c16bc31 100644 --- a/tests/auth/test_service.py +++ b/tests/auth/test_service.py @@ -9,21 +9,36 @@ from authsome.auth.models.enums import AuthType, ConnectionStatus from authsome.errors import RefreshFailedError from authsome.server.audit import ServerAuditLog, configure_server_audit_log +from authsome.server.credential_repository import CredentialRepository from authsome.server.credential_service import AuthService from authsome.utils import utc_now -class EmptyProviderDefinitions: +class EmptyProviders: async def get(self, name: str): # noqa: ANN001, ANN201 - return None + from authsome.errors import ProviderNotFoundError + + raise ProviderNotFoundError(name) async def list(self): # noqa: ANN201 return [] - async def save(self, definition, *, force: bool = False) -> None: # noqa: ANN001 - raise AssertionError("unexpected provider definition save") + async def list_by_source(self): # noqa: ANN201 + return {"bundled": [], "custom": []} + + async def save_custom(self, definition, *, force: bool = False) -> None: # noqa: ANN001 + raise AssertionError("unexpected provider save") + + async def delete_custom(self, name: str) -> bool: + return False + - async def delete(self, name: str) -> bool: +def _credentials( + vault, *, identity: str | None = "agent-a", principal_id: str | None = None, vault_id: str = "vault_default" +): # noqa: ANN001 + return CredentialRepository(vault, identity=identity, principal_id=principal_id, vault_id=vault_id) + + async def is_custom(self, name: str) -> bool: return False @@ -41,10 +56,10 @@ def audit_log(self, tmp_path) -> ServerAuditLog: # noqa: ANN001 def service(self) -> AuthService: mock_vault = mock.AsyncMock() return AuthService( - mock_vault, + credentials=_credentials(mock_vault, identity="test-profile", vault_id="test-vault"), + providers=EmptyProviders(), identity="test-profile", vault_id="test-vault", - provider_definitions=EmptyProviderDefinitions(), ) async def test_refresh_failure_fallback_available(self, audit_log: ServerAuditLog, service: AuthService): @@ -131,11 +146,11 @@ async def test_refresh_failure_expired(self, audit_log: ServerAuditLog, service: def test_auth_service_allows_missing_identity() -> None: mock_vault = mock.AsyncMock() service = AuthService( - mock_vault, + credentials=_credentials(mock_vault, identity=None, principal_id="principal_1"), + providers=EmptyProviders(), identity=None, principal_id="principal_1", vault_id="vault_default", - provider_definitions=EmptyProviderDefinitions(), ) assert service.identity is None @@ -143,10 +158,21 @@ def test_auth_service_allows_missing_identity() -> None: def test_auth_service_scopes_collection_by_vault_id() -> None: mock_vault = mock.AsyncMock() service = AuthService( - mock_vault, + credentials=_credentials(mock_vault, identity="agent-a", principal_id="principal_1"), + providers=EmptyProviders(), identity="agent-a", principal_id="principal_1", vault_id="vault_default", - provider_definitions=EmptyProviderDefinitions(), ) - assert service._coll == "vault:vault_default" + assert service._credentials.collection == "vault:vault_default" + + +def test_auth_service_requires_providers() -> None: + mock_vault = mock.AsyncMock() + + with pytest.raises(TypeError): + AuthService( + credentials=_credentials(mock_vault, identity="agent-a"), + identity="agent-a", + vault_id="vault_default", + ) # type: ignore[call-arg] diff --git a/tests/auth/test_service_provider_clients.py b/tests/auth/test_service_provider_clients.py index e7861c4c..c9d9ab18 100644 --- a/tests/auth/test_service_provider_clients.py +++ b/tests/auth/test_service_provider_clients.py @@ -14,30 +14,43 @@ from authsome.errors import OperationNotAllowedError from authsome.identity import create_identity from authsome.identity.principal import PrincipalRole +from authsome.server.credential_repository import CredentialRepository, build_store_key from authsome.server.credential_service import AuthService from authsome.server.dependencies import ( create_store, create_vault, ) -from authsome.utils import build_store_key +from authsome.server.provider_repository import ProviderRepository -class EmptyProviderDefinitions: +class EmptyProviders: async def get(self, name: str): # noqa: ANN001, ANN201 - return None + from authsome.errors import ProviderNotFoundError + + raise ProviderNotFoundError(name) async def list(self): # noqa: ANN201 return [] - async def save(self, definition, *, force: bool = False) -> None: # noqa: ANN001 - raise AssertionError("unexpected provider definition save") + async def list_by_source(self): # noqa: ANN201 + return {"bundled": [], "custom": []} + + async def save_custom(self, definition, *, force: bool = False) -> None: # noqa: ANN001 + raise AssertionError("unexpected provider save") + + async def delete_custom(self, name: str) -> bool: + return False - async def delete(self, name: str) -> bool: + async def is_custom(self, name: str) -> bool: return False def _service(vault, **kwargs) -> AuthService: # noqa: ANN001, ANN003 - return AuthService(vault, provider_definitions=EmptyProviderDefinitions(), **kwargs) + identity = kwargs.get("identity") + principal_id = kwargs.get("principal_id") + vault_id = kwargs.get("vault_id", "vault_default") + credentials = CredentialRepository(vault, identity=identity, principal_id=principal_id, vault_id=vault_id) + return AuthService(credentials=credentials, providers=EmptyProviders(), **kwargs) def _make_provider(*, flow: FlowType = FlowType.PKCE) -> ProviderDefinition: @@ -371,12 +384,17 @@ async def test_revoke_local_deletes_shared_client_and_all_identity_connections(t vault = await create_vault(store.home) try: service = AuthService( - vault, + credentials=CredentialRepository( + vault, + identity="steady-wisely-boldly-0042", + principal_id="principal_1", + vault_id=primary_vault.vault_id, + ), + providers=ProviderRepository(store.provider_definitions), identity="steady-wisely-boldly-0042", principal_id="principal_1", principal_role=PrincipalRole.ADMIN, vault_id=primary_vault.vault_id, - provider_definitions=store.provider_definitions, ) primary_connection = ConnectionRecord( diff --git a/tests/auth/test_service_provider_definitions.py b/tests/auth/test_service_provider_definitions.py index 08883bdf..9073a343 100644 --- a/tests/auth/test_service_provider_definitions.py +++ b/tests/auth/test_service_provider_definitions.py @@ -9,7 +9,9 @@ from authsome.auth.models.enums import AuthType, FlowType from authsome.auth.models.provider import ApiKeyConfig, ProviderDefinition from authsome.identity.principal import PrincipalRole +from authsome.server.credential_repository import CredentialRepository from authsome.server.credential_service import AuthService +from authsome.server.provider_repository import ProviderRepository from authsome.server.store import create_server_store from authsome.vault import Vault @@ -30,11 +32,16 @@ async def test_custom_provider_definition_is_stored_in_store_not_vault(tmp_path: try: vault = AsyncMock(spec=Vault) service = AuthService( - vault=vault, + credentials=CredentialRepository( + vault, + identity="steady-wisely-boldly-0042", + principal_id=None, + vault_id="vault_test", + ), + providers=ProviderRepository(store.provider_definitions), identity="steady-wisely-boldly-0042", principal_role=PrincipalRole.ADMIN, vault_id="vault_test", - provider_definitions=store.provider_definitions, ) await service.register_provider(_provider()) @@ -51,10 +58,15 @@ async def test_provider_client_credentials_still_use_vault(tmp_path: Path) -> No try: vault = AsyncMock(spec=Vault) service = AuthService( - vault=vault, + credentials=CredentialRepository( + vault, + identity="steady-wisely-boldly-0042", + principal_id=None, + vault_id="vault_test", + ), + providers=ProviderRepository(store.provider_definitions), identity="steady-wisely-boldly-0042", vault_id="vault_test", - provider_definitions=store.provider_definitions, ) await service._save_provider_client_credentials(ProviderClientRecord(provider="github", client_id="cid")) diff --git a/tests/common/test_utils.py b/tests/common/test_utils.py index c51233ef..5fee1beb 100644 --- a/tests/common/test_utils.py +++ b/tests/common/test_utils.py @@ -2,14 +2,9 @@ from datetime import UTC, datetime -import pytest - from authsome.utils import ( - StoreKeyParts, - build_store_key, is_filesystem_safe, parse_rfc3339, - parse_store_key, to_rfc3339, utc_now, ) @@ -49,53 +44,3 @@ def test_is_filesystem_safe(): assert is_filesystem_safe("bad..name") is False assert is_filesystem_safe("bad\\name") is False assert is_filesystem_safe(".hidden") is False - - -def test_build_store_key(): - # Test definition key - assert build_store_key(record_type="definition", provider="github") == "provider:github:definition" - assert build_store_key(vault="vault_default", provider="github", record_type="metadata") == ( - "vault:vault_default:github:metadata" - ) - # Test metadata key - assert ( - build_store_key(identity="default", provider="github", record_type="metadata") - == "identity:default:github:metadata" - ) - # Test state key - assert ( - build_store_key(identity="default", provider="github", record_type="state") == "identity:default:github:state" - ) - # Test connection key - assert ( - build_store_key( - identity="default", - provider="github", - record_type="connection", - connection="personal", - ) - == "identity:default:github:connection:personal" - ) - # Test client key - assert ( - build_store_key(identity="default", provider="github", record_type="client") == "identity:default:github:client" - ) - # Test server-scoped client key - assert build_store_key(provider="github", record_type="server") == "server:provider:github:client" - # Test value error - with pytest.raises(ValueError): - build_store_key(identity="default", provider="github", record_type="unknown") - - # Test missing provider with identity - with pytest.raises(ValueError): - build_store_key(identity="default", record_type="metadata") - - -def test_parse_store_key_server() -> None: - assert parse_store_key("server:provider:github:client") == StoreKeyParts( - vault=None, - identity=None, - provider="github", - record_type="server", - connection=None, - ) diff --git a/tests/server/test_credential_repository.py b/tests/server/test_credential_repository.py new file mode 100644 index 00000000..5b8c5500 --- /dev/null +++ b/tests/server/test_credential_repository.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from authsome.auth.models.connection import ( + ConnectionRecord, + ProviderClientRecord, + ProviderMetadataRecord, + ProviderStateRecord, +) +from authsome.auth.models.enums import AuthType, ConnectionStatus +from authsome.server.credential_repository import CredentialRepository, StoreKeyParts, build_store_key, parse_store_key +from authsome.server.dependencies import create_vault + + +def _connection() -> ConnectionRecord: + return ConnectionRecord( + provider="github", + connection_name="default", + auth_type=AuthType.OAUTH2, + status=ConnectionStatus.CONNECTED, + access_token="access-token", + ) + + +@pytest.mark.asyncio +async def test_connection_roundtrip_preserves_vault_key_format(tmp_path: Path) -> None: + vault = await create_vault(tmp_path) + repo = CredentialRepository( + vault, + identity="steady-wisely-boldly-0042", + principal_id="principal_1", + vault_id="vault_1", + ) + + await repo.save_connection(_connection()) + + key = build_store_key(vault="vault_1", provider="github", record_type="connection", connection="default") + raw = await vault.get(key, collection="vault:vault_1") + loaded = await repo.get_connection("github", "default") + + assert raw is not None + assert loaded is not None + assert loaded.identity == "steady-wisely-boldly-0042" + assert loaded.principal_id == "principal_1" + assert loaded.vault_id == "vault_1" + assert loaded.access_token == "access-token" + + +@pytest.mark.asyncio +async def test_connection_delete(tmp_path: Path) -> None: + vault = await create_vault(tmp_path) + repo = CredentialRepository(vault, identity="agent", principal_id="principal_1", vault_id="vault_1") + + await repo.save_connection(_connection()) + await repo.delete_connection("github", "default") + + assert await repo.get_connection("github", "default") is None + + +@pytest.mark.asyncio +async def test_metadata_state_and_client_roundtrip(tmp_path: Path) -> None: + vault = await create_vault(tmp_path) + repo = CredentialRepository(vault, identity="agent", principal_id="principal_1", vault_id="vault_1") + + metadata = ProviderMetadataRecord(provider="github", connection_names=["default"]) + state = ProviderStateRecord(provider="github") + client = ProviderClientRecord(provider="github", client_id="cid", client_secret="secret") + + await repo.save_provider_metadata(metadata) + await repo.save_provider_state(state) + await repo.save_provider_client(client) + + loaded_metadata = await repo.get_provider_metadata("github") + loaded_state = await repo.get_provider_state("github") + loaded_client = await repo.get_provider_client("github") + + assert loaded_metadata is not None + assert loaded_metadata.vault_id == "vault_1" + assert loaded_metadata.connection_names == ["default"] + assert loaded_state is not None + assert loaded_state.vault_id == "vault_1" + assert loaded_client is not None + assert loaded_client.client_id == "cid" + + server_key = build_store_key(provider="github", record_type="server") + assert await vault.get(server_key, collection="server") is not None + + +@pytest.mark.asyncio +async def test_list_connection_keys_returns_existing_connection_keys(tmp_path: Path) -> None: + vault = await create_vault(tmp_path) + repo = CredentialRepository(vault, identity="agent", principal_id="principal_1", vault_id="vault_1") + + await repo.save_connection(_connection()) + + keys = await repo.list_connection_keys() + + assert keys == [build_store_key(vault="vault_1", provider="github", record_type="connection", connection="default")] + + +def test_build_store_key() -> None: + assert build_store_key(record_type="definition", provider="github") == "provider:github:definition" + assert build_store_key(vault="vault_default", provider="github", record_type="metadata") == ( + "vault:vault_default:github:metadata" + ) + assert ( + build_store_key(identity="default", provider="github", record_type="metadata") + == "identity:default:github:metadata" + ) + assert ( + build_store_key(identity="default", provider="github", record_type="state") == "identity:default:github:state" + ) + assert ( + build_store_key( + identity="default", + provider="github", + record_type="connection", + connection="personal", + ) + == "identity:default:github:connection:personal" + ) + assert ( + build_store_key(identity="default", provider="github", record_type="client") == "identity:default:github:client" + ) + assert build_store_key(provider="github", record_type="server") == "server:provider:github:client" + + with pytest.raises(ValueError): + build_store_key(identity="default", provider="github", record_type="unknown") + + with pytest.raises(ValueError): + build_store_key(identity="default", record_type="metadata") + + +def test_parse_store_key_server() -> None: + assert parse_store_key("server:provider:github:client") == StoreKeyParts( + vault=None, + identity=None, + provider="github", + record_type="server", + connection=None, + ) diff --git a/tests/server/test_pop_auth.py b/tests/server/test_pop_auth.py index 43fb97b7..4a41d53a 100644 --- a/tests/server/test_pop_auth.py +++ b/tests/server/test_pop_auth.py @@ -11,7 +11,7 @@ from authsome.identity import create_identity, load_private_key from authsome.identity.proof import create_proof_jwt from authsome.server.app import create_app -from authsome.utils import build_store_key +from authsome.server.credential_repository import build_store_key def _auth_header( diff --git a/tests/server/test_provider_repository.py b/tests/server/test_provider_repository.py new file mode 100644 index 00000000..d044988b --- /dev/null +++ b/tests/server/test_provider_repository.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from authsome.auth.models.enums import AuthType, FlowType +from authsome.auth.models.provider import ApiKeyConfig, ProviderDefinition +from authsome.server.provider_repository import ProviderRepository +from authsome.server.store import create_server_store + + +def _custom_provider(name: str = "github") -> ProviderDefinition: + return ProviderDefinition( + name=name, + display_name="Custom GitHub", + auth_type=AuthType.API_KEY, + flow=FlowType.API_KEY, + api_key=ApiKeyConfig(header_name="Authorization"), + ) + + +@pytest.mark.asyncio +async def test_bundled_provider_loads(tmp_path: Path) -> None: + store = await create_server_store(home=tmp_path) + try: + providers = ProviderRepository(store.provider_definitions) + + github = await providers.get("github") + + assert github.name == "github" + assert await providers.is_custom("github") is False + finally: + await store.close() + + +@pytest.mark.asyncio +async def test_custom_provider_overrides_bundled(tmp_path: Path) -> None: + store = await create_server_store(home=tmp_path) + try: + providers = ProviderRepository(store.provider_definitions) + await providers.save_custom(_custom_provider("github"), force=True) + + github = await providers.get("github") + + assert github.display_name == "Custom GitHub" + assert await providers.is_custom("github") is True + finally: + await store.close() + + +@pytest.mark.asyncio +async def test_delete_custom_exposes_bundled_fallback(tmp_path: Path) -> None: + store = await create_server_store(home=tmp_path) + try: + providers = ProviderRepository(store.provider_definitions) + await providers.save_custom(_custom_provider("github"), force=True) + + removed = await providers.delete_custom("github") + github = await providers.get("github") + + assert removed is True + assert github.name == "github" + assert github.display_name != "Custom GitHub" + assert await providers.is_custom("github") is False + finally: + await store.close() + + +@pytest.mark.asyncio +async def test_list_by_source_separates_bundled_and_custom(tmp_path: Path) -> None: + store = await create_server_store(home=tmp_path) + try: + providers = ProviderRepository(store.provider_definitions) + await providers.save_custom(_custom_provider("custom-only"), force=True) + + by_source = await providers.list_by_source() + + assert any(provider.name == "github" for provider in by_source["bundled"]) + assert [provider.name for provider in by_source["custom"]] == ["custom-only"] + finally: + await store.close() diff --git a/tests/server/test_ui_dashboard.py b/tests/server/test_ui_dashboard.py index e5e75b95..0cd01364 100644 --- a/tests/server/test_ui_dashboard.py +++ b/tests/server/test_ui_dashboard.py @@ -11,7 +11,8 @@ from authsome.identity import create_identity, load_private_key from authsome.identity.proof import create_proof_jwt from authsome.server.app import create_app -from authsome.utils import build_store_key, utc_now +from authsome.server.credential_repository import build_store_key +from authsome.utils import utc_now def _auth_header(tmp_path: Path, method: str, path: str, *, handle: str) -> dict[str, str]: diff --git a/tests/server/test_ui_sessions.py b/tests/server/test_ui_sessions.py index 6a40aff4..b2980c9b 100644 --- a/tests/server/test_ui_sessions.py +++ b/tests/server/test_ui_sessions.py @@ -8,8 +8,8 @@ from authsome.identity import create_identity, load_private_key from authsome.identity.proof import create_proof_jwt from authsome.server.app import create_app +from authsome.server.credential_repository import build_store_key from authsome.server.ui_sessions import UiSessionStore -from authsome.utils import build_store_key def _auth_header(tmp_path: Path, method: str, path: str, *, handle: str) -> dict[str, str]: diff --git a/tests/test_import_boundaries.py b/tests/test_import_boundaries.py new file mode 100644 index 00000000..b0713f62 --- /dev/null +++ b/tests/test_import_boundaries.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import ast +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +SRC = ROOT / "src" / "authsome" + + +def _imports_under(package: str) -> set[str]: + package_path = SRC / package + imports: set[str] = set() + for path in package_path.rglob("*.py"): + tree = ast.parse(path.read_text(), filename=str(path)) + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + if alias.name.startswith("authsome."): + imports.add(alias.name) + elif isinstance(node, ast.ImportFrom) and node.module: + if node.module.startswith("authsome."): + imports.add(node.module) + return imports + + +def _assert_no_imports(package: str, forbidden: set[str]) -> None: + imports = _imports_under(package) + own_prefix = f"authsome.{package}" + violations = sorted( + module + for module in imports + if module != own_prefix + and not module.startswith(f"{own_prefix}.") + and any(module == item or module.startswith(f"{item}.") for item in forbidden) + ) + assert violations == [] + + +def test_auth_is_orthogonal_library_module() -> None: + _assert_no_imports("auth", {"authsome.server", "authsome.vault", "authsome.identity"}) + + +def test_vault_is_orthogonal_library_module() -> None: + _assert_no_imports("vault", {"authsome.server", "authsome.auth", "authsome.identity"}) + + +def test_identity_is_orthogonal_library_module() -> None: + _assert_no_imports("identity", {"authsome.server", "authsome.auth", "authsome.vault"})