Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions reboot/examples/agent-wiki/backend/src/servicers/wiki.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,29 @@
def _caller_is_owner(
*,
context: ReaderContext,
state: Union[Wiki.State, Transcript.State, Page.State],
state: Union[Wiki.State, Page.State, Transcript.State],
**kwargs,
):
"""Allow when the caller's `user_id` matches the stored owner ID."""
"""Allow when the caller's `user_id` matches `state.owner_id`."""
if context.auth is None or context.auth.user_id is None:
return Unauthenticated()
if state is not None and context.auth.user_id == state.owner_id:
return Ok()
return PermissionDenied()


def _caller_is_authenticated(
*,
context: ReaderContext,
**kwargs,
):
"""Allow any authenticated caller. Used for factory `create` methods
where no state exists yet to check ownership against."""
if context.auth is None or context.auth.user_id is None:
return Unauthenticated()
return Ok()


def _truncate(value: object, limit: int = 500) -> str:
"""Render `value` as a string, shortened for log lines so a
big tool payload doesn't flood the output."""
Expand Down Expand Up @@ -309,7 +321,13 @@ class WikiServicer(Wiki.Servicer):
`ingest` workflow that folds transcripts into it."""

def authorizer(self):
return allow_if(any=[_caller_is_owner, is_app_internal])
return Wiki.Authorizer(
create=allow_if(any=[_caller_is_authenticated, is_app_internal]),
get=allow_if(any=[_caller_is_owner, is_app_internal]),
update=allow_if(any=[_caller_is_owner, is_app_internal]),
add_transcript=allow_if(any=[_caller_is_owner, is_app_internal]),
ingest=allow_if(any=[_caller_is_owner, is_app_internal]),
)

async def create(
self,
Expand Down Expand Up @@ -445,7 +463,11 @@ class PageServicer(Page.Servicer):
a title."""

def authorizer(self):
return allow_if(any=[_caller_is_owner, is_app_internal])
return Page.Authorizer(
create=allow_if(any=[_caller_is_authenticated, is_app_internal]),
get=allow_if(any=[_caller_is_owner, is_app_internal]),
update=allow_if(any=[_caller_is_owner, is_app_internal]),
)

async def create(
self,
Expand Down Expand Up @@ -479,7 +501,11 @@ class TranscriptServicer(Transcript.Servicer):
transcript)."""

def authorizer(self):
return allow_if(any=[_caller_is_owner, is_app_internal])
return Transcript.Authorizer(
create=allow_if(any=[_caller_is_authenticated, is_app_internal]),
get=allow_if(any=[_caller_is_owner, is_app_internal]),
update=allow_if(any=[_caller_is_owner, is_app_internal]),
)

async def create(
self,
Expand Down
48 changes: 8 additions & 40 deletions reboot/examples/agent-wiki/backend/tests/wiki_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)
from pydantic_ai.models.function import AgentInfo, FunctionModel
from reboot.aio.applications import Application
from reboot.aio.auth.authorizers import allow
from reboot.aio.tests import Reboot
from servicers import wiki as wiki_module
from servicers.wiki import (
Expand All @@ -29,44 +28,11 @@
WikiServicer,
)

# Production servicers intentionally don't define an
# `authorizer()`: in development Reboot defaults to allow-all,
# but in production an absent authorizer denies by default,
# which we rely on so that no permissive code accidentally
# ships. The tests run against the production-mode harness, so
# we extend each servicer here and grant `allow()` for the
# duration of the suite.


class PermissiveUserServicer(UserServicer):

def authorizer(self):
return allow()


class PermissiveWikiServicer(WikiServicer):

def authorizer(self):
return allow()


class PermissivePageServicer(PageServicer):

def authorizer(self):
return allow()


class PermissiveTranscriptServicer(TranscriptServicer):

def authorizer(self):
return allow()


APPLICATION_SERVICERS = [
PermissiveUserServicer,
PermissiveWikiServicer,
PermissivePageServicer,
PermissiveTranscriptServicer,
UserServicer,
WikiServicer,
PageServicer,
TranscriptServicer,
]


Expand Down Expand Up @@ -115,7 +81,7 @@ async def asyncSetUp(self) -> None:
# production the MCP session's "new session" hook
# calls `_auto_construct` for the authenticated user.
# Tests don't go through that hook, so we do it here.
await PermissiveUserServicer._auto_construct(
await UserServicer._auto_construct(
self.context,
state_id=self.user_id,
)
Expand Down Expand Up @@ -170,6 +136,7 @@ async def test_page_crud(self) -> None:
self.context,
title="My Page",
content="Initial body.",
owner_id=self.user_id,
)
got = await page.get(self.context)
self.assertEqual(got.title, "My Page")
Expand All @@ -194,6 +161,7 @@ async def test_transcript_crud(self) -> None:
transcript, _ = await Transcript.create(
self.context,
messages=messages,
owner_id=self.user_id,
)
got = await transcript.get(self.context)
self.assertEqual(len(got.messages), 2)
Expand Down Expand Up @@ -346,7 +314,7 @@ async def asyncSetUp(self) -> None:
# production the MCP session's "new session" hook
# calls `_auto_construct` for the authenticated user.
# Tests don't go through that hook, so we do it here.
await PermissiveUserServicer._auto_construct(
await UserServicer._auto_construct(
self.context,
state_id=self.user_id,
)
Expand Down
Loading