diff --git a/backend/src/servicers/wiki.py b/backend/src/servicers/wiki.py index bde3173..7df45ab 100644 --- a/backend/src/servicers/wiki.py +++ b/backend/src/servicers/wiki.py @@ -30,10 +30,10 @@ def _caller_is_owner( *, context: ReaderContext, - state: Union[Wiki.State, Transcript.State, Page.State], + state: Union[Wiki.State, Page.State, Transcript.State], **kwargs, ): - """Allow when the caller's `user_id` matches the stored owner ID.""" + """Allow when the caller's `user_id` matches `state.owner_id`.""" if context.auth is None or context.auth.user_id is None: return Unauthenticated() if state is not None and context.auth.user_id == state.owner_id: @@ -41,6 +41,18 @@ def _caller_is_owner( return PermissionDenied() +def _caller_is_authenticated( + *, + context: ReaderContext, + **kwargs, +): + """Allow any authenticated caller. Used for factory `create` methods + where no state exists yet to check ownership against.""" + if context.auth is None or context.auth.user_id is None: + return Unauthenticated() + return Ok() + + def _truncate(value: object, limit: int = 500) -> str: """Render `value` as a string, shortened for log lines so a big tool payload doesn't flood the output.""" @@ -309,7 +321,13 @@ class WikiServicer(Wiki.Servicer): `ingest` workflow that folds transcripts into it.""" def authorizer(self): - return allow_if(any=[_caller_is_owner, is_app_internal]) + return Wiki.Authorizer( + create=allow_if(any=[_caller_is_authenticated, is_app_internal]), + get=allow_if(any=[_caller_is_owner, is_app_internal]), + update=allow_if(any=[_caller_is_owner, is_app_internal]), + add_transcript=allow_if(any=[_caller_is_owner, is_app_internal]), + ingest=allow_if(any=[_caller_is_owner, is_app_internal]), + ) async def create( self, @@ -445,7 +463,11 @@ class PageServicer(Page.Servicer): a title.""" def authorizer(self): - return allow_if(any=[_caller_is_owner, is_app_internal]) + return Page.Authorizer( + create=allow_if(any=[_caller_is_authenticated, is_app_internal]), + get=allow_if(any=[_caller_is_owner, is_app_internal]), + update=allow_if(any=[_caller_is_owner, is_app_internal]), + ) async def create( self, @@ -479,7 +501,11 @@ class TranscriptServicer(Transcript.Servicer): transcript).""" def authorizer(self): - return allow_if(any=[_caller_is_owner, is_app_internal]) + return Transcript.Authorizer( + create=allow_if(any=[_caller_is_authenticated, is_app_internal]), + get=allow_if(any=[_caller_is_owner, is_app_internal]), + update=allow_if(any=[_caller_is_owner, is_app_internal]), + ) async def create( self, diff --git a/backend/tests/wiki_test.py b/backend/tests/wiki_test.py index f43490d..ed77bec 100644 --- a/backend/tests/wiki_test.py +++ b/backend/tests/wiki_test.py @@ -19,7 +19,6 @@ ) from pydantic_ai.models.function import AgentInfo, FunctionModel from reboot.aio.applications import Application -from reboot.aio.auth.authorizers import allow from reboot.aio.tests import Reboot from servicers import wiki as wiki_module from servicers.wiki import ( @@ -29,44 +28,11 @@ WikiServicer, ) -# Production servicers intentionally don't define an -# `authorizer()`: in development Reboot defaults to allow-all, -# but in production an absent authorizer denies by default, -# which we rely on so that no permissive code accidentally -# ships. The tests run against the production-mode harness, so -# we extend each servicer here and grant `allow()` for the -# duration of the suite. - - -class PermissiveUserServicer(UserServicer): - - def authorizer(self): - return allow() - - -class PermissiveWikiServicer(WikiServicer): - - def authorizer(self): - return allow() - - -class PermissivePageServicer(PageServicer): - - def authorizer(self): - return allow() - - -class PermissiveTranscriptServicer(TranscriptServicer): - - def authorizer(self): - return allow() - - APPLICATION_SERVICERS = [ - PermissiveUserServicer, - PermissiveWikiServicer, - PermissivePageServicer, - PermissiveTranscriptServicer, + UserServicer, + WikiServicer, + PageServicer, + TranscriptServicer, ] @@ -115,7 +81,7 @@ async def asyncSetUp(self) -> None: # production the MCP session's "new session" hook # calls `_auto_construct` for the authenticated user. # Tests don't go through that hook, so we do it here. - await PermissiveUserServicer._auto_construct( + await UserServicer._auto_construct( self.context, state_id=self.user_id, ) @@ -170,6 +136,7 @@ async def test_page_crud(self) -> None: self.context, title="My Page", content="Initial body.", + owner_id=self.user_id, ) got = await page.get(self.context) self.assertEqual(got.title, "My Page") @@ -194,6 +161,7 @@ async def test_transcript_crud(self) -> None: transcript, _ = await Transcript.create( self.context, messages=messages, + owner_id=self.user_id, ) got = await transcript.get(self.context) self.assertEqual(len(got.messages), 2) @@ -346,7 +314,7 @@ async def asyncSetUp(self) -> None: # production the MCP session's "new session" hook # calls `_auto_construct` for the authenticated user. # Tests don't go through that hook, so we do it here. - await PermissiveUserServicer._auto_construct( + await UserServicer._auto_construct( self.context, state_id=self.user_id, )