diff --git a/agent_assembly/adapters/langchain/runtime.py b/agent_assembly/adapters/langchain/runtime.py index aca6841..70ef52d 100644 --- a/agent_assembly/adapters/langchain/runtime.py +++ b/agent_assembly/adapters/langchain/runtime.py @@ -8,25 +8,35 @@ from agent_assembly.adapters.crewai.patch import CrewAIPatch from agent_assembly.adapters.langchain.callback_handler import AssemblyCallbackHandler from agent_assembly.adapters.langgraph import LangGraphPatch +from agent_assembly.adapters.pydantic_ai.patch import PydanticAIPatch, set_process_agent_id _ACTIVE_CALLBACK_HANDLER: AssemblyCallbackHandler | None = None _RUNTIME_LOCK = Lock() -def auto_inject_callback_handler(interceptor: Any) -> AssemblyCallbackHandler: +def auto_inject_callback_handler( + interceptor: Any, + *, + process_agent_id: str | None = None, +) -> AssemblyCallbackHandler: """Create and register the active callback handler instance.""" global _ACTIVE_CALLBACK_HANDLER with _RUNTIME_LOCK: + if process_agent_id is not None: + set_process_agent_id(process_agent_id) + if _ACTIVE_CALLBACK_HANDLER is not None: LangGraphPatch(_ACTIVE_CALLBACK_HANDLER).apply() CrewAIPatch(interceptor).apply() + PydanticAIPatch(interceptor).apply() return _ACTIVE_CALLBACK_HANDLER handler = AssemblyCallbackHandler(interceptor) _ACTIVE_CALLBACK_HANDLER = handler LangGraphPatch(handler).apply() CrewAIPatch(interceptor).apply() + PydanticAIPatch(interceptor).apply() return handler diff --git a/agent_assembly/adapters/pydantic_ai/__init__.py b/agent_assembly/adapters/pydantic_ai/__init__.py new file mode 100644 index 0000000..328e92a --- /dev/null +++ b/agent_assembly/adapters/pydantic_ai/__init__.py @@ -0,0 +1,5 @@ +"""Pydantic AI adapter package.""" + +from agent_assembly.adapters.pydantic_ai.patch import PydanticAIPatch + +__all__ = ["PydanticAIPatch"] diff --git a/agent_assembly/adapters/pydantic_ai/patch.py b/agent_assembly/adapters/pydantic_ai/patch.py new file mode 100644 index 0000000..79139d4 --- /dev/null +++ b/agent_assembly/adapters/pydantic_ai/patch.py @@ -0,0 +1,283 @@ +"""Pydantic AI patch module.""" + +from __future__ import annotations + +from dataclasses import dataclass +from functools import wraps +import importlib +import inspect +from typing import Any, Literal, Mapping + +from agent_assembly.adapters.crewai.patch import ( + _get_pending_tool_approval_timeout_seconds as _resolve_pending_timeout_seconds, +) +from agent_assembly.adapters.crewai.patch import _normalize_decision as _normalize_governance_decision + +_ORIGINAL_TOOL_RUN = "_agent_assembly_original_pydantic_ai_tool_run" +_TOOLS_PATCHED_FLAG = "_agent_assembly_pydantic_ai_tools_patched" +_PROCESS_AGENT_ID: str | None = None +_MAX_AUDIT_RESULT_CHARS = 2000 + + +@dataclass(slots=True) +class PydanticAIPatch: + """Applies Pydantic AI runtime monkey-patching hooks.""" + + callback_handler: Any + + def apply(self) -> bool: + """Apply patch wiring and return whether Pydantic AI is available.""" + tool_cls = _load_pydantic_ai_tool_class() + if tool_cls is None: + return False + _apply_tool_run_patch(tool_cls, self.callback_handler) + return True + + +class AssemblyModelWrapper: + """Optional model wrapper for LLM input scan-forward interception.""" + + def __init__(self, model: Any, callback_handler: Any) -> None: + self._model = model + self._callback_handler = callback_handler + + async def request(self, *args: Any, **kwargs: Any) -> Any: + scan_method = getattr(self._callback_handler, "on_llm_start_scan", None) + if callable(scan_method): + scan_result = scan_method( + serialized={"name": self._model.__class__.__name__}, + prompts=[str(args[0])] if args else [], + run_id=kwargs.get("run_id"), + ) + if inspect.isawaitable(scan_result): + await scan_result + + result = self._model.request(*args, **kwargs) + if inspect.isawaitable(result): + return await result + return result + + def __getattr__(self, name: str) -> Any: + return getattr(self._model, name) + + +def _load_pydantic_ai_tool_class() -> type[Any] | None: + try: + module = importlib.import_module("pydantic_ai.tools") + except ImportError: + return None + + tool_cls = getattr(module, "Tool", None) + if isinstance(tool_cls, type): + return tool_cls + return None + + +def _apply_tool_run_patch(tool_cls: type[Any], callback_handler: Any) -> None: + if getattr(tool_cls, _TOOLS_PATCHED_FLAG, False): + return None + + original_run = tool_cls._run + + @wraps(original_run) + async def patched_run(self: Any, ctx: Any, args: Any, **kwargs: Any) -> Any: + tool_name = str(getattr(self, "name", self.__class__.__name__)) + tool_args = _serialize_tool_args(args) + agent_id = _resolve_agent_id(ctx) + run_id = _resolve_run_id(ctx) + + decision = await _invoke_async_tool_check( + callback_handler, + tool_name=tool_name, + tool_args=tool_args, + agent_id=agent_id, + run_id=run_id, + ) + status, reason = _normalize_decision(decision) + is_pending_flow = False + if status == "pending": + is_pending_flow = True + timeout_seconds = _get_pending_tool_approval_timeout_seconds(callback_handler) + final_decision = await _wait_for_async_tool_approval( + callback_handler, + tool_name=tool_name, + timeout_seconds=timeout_seconds, + tool_args=tool_args, + agent_id=agent_id, + run_id=run_id, + ) + status, reason = _normalize_decision(final_decision) + + if status == "deny": + if is_pending_flow: + raise _build_pending_rejected_error(tool_name, reason) + raise _build_denied_error(tool_name, reason) + + result = original_run(self, ctx, args, **kwargs) + if inspect.isawaitable(result): + result = await result + + await _record_async_tool_result( + callback_handler, + tool_name=tool_name, + result=result, + agent_id=agent_id, + run_id=run_id, + ) + return result + + setattr(tool_cls, _ORIGINAL_TOOL_RUN, original_run) + setattr(tool_cls, "_run", patched_run) + setattr(tool_cls, _TOOLS_PATCHED_FLAG, True) + + +def set_process_agent_id(agent_id: str | None) -> None: + global _PROCESS_AGENT_ID + _PROCESS_AGENT_ID = agent_id + + +def _get_process_agent_id() -> str | None: + if isinstance(_PROCESS_AGENT_ID, str) and _PROCESS_AGENT_ID: + return _PROCESS_AGENT_ID + return None + + +def _resolve_agent_id(ctx: Any) -> str | None: + deps = getattr(ctx, "deps", None) + candidate = getattr(deps, "assembly_agent_id", None) + if isinstance(candidate, str) and candidate: + return candidate + return _get_process_agent_id() + + +def _resolve_run_id(ctx: Any) -> str | None: + run_id = getattr(ctx, "run_id", None) + if run_id is None: + return None + return str(run_id) + + +def _serialize_tool_args(args: Any) -> dict[str, Any]: + if hasattr(args, "model_dump"): + model_dump = getattr(args, "model_dump") + if callable(model_dump): + dumped = model_dump() + if isinstance(dumped, dict): + return dict(dumped) + + if isinstance(args, Mapping): + return dict(args) + + return {"value": str(args)} + + +def _normalize_decision( + decision: object, +) -> tuple[Literal["allow", "deny", "pending"], str | None]: + return _normalize_governance_decision(decision) + + +async def _invoke_async_tool_check( + callback_handler: Any, + *, + tool_name: str, + tool_args: dict[str, Any], + agent_id: str | None, + run_id: str | None, +) -> object: + method = getattr(callback_handler, "check_tool_start", None) + if not callable(method): + return {"status": "allow"} + + result = method( + serialized={"name": tool_name}, + input_str=str(tool_args), + tool_name=tool_name, + args=tool_args, + agent_id=agent_id, + run_id=run_id, + ) + if inspect.isawaitable(result): + return await result + return result + + +async def _wait_for_async_tool_approval( + callback_handler: Any, + *, + tool_name: str, + timeout_seconds: int, + tool_args: dict[str, Any], + agent_id: str | None, + run_id: str | None, +) -> object: + method = getattr(callback_handler, "wait_for_tool_approval", None) + if not callable(method): + return {"status": "deny", "reason": "Approval handler is unavailable."} + + result = method( + serialized={"name": tool_name}, + input_str=str(tool_args), + tool_name=tool_name, + timeout_seconds=timeout_seconds, + args=tool_args, + agent_id=agent_id, + run_id=run_id, + ) + if inspect.isawaitable(result): + return await result + return result + + +def _get_pending_tool_approval_timeout_seconds(callback_handler: Any) -> int: + return _resolve_pending_timeout_seconds(callback_handler) + + +def _truncate_result_for_audit(result: object) -> str: + return str(result)[:_MAX_AUDIT_RESULT_CHARS] + + +async def _record_async_tool_result( + callback_handler: Any, + *, + tool_name: str, + result: object, + agent_id: str | None, + run_id: str | None, +) -> None: + record_method = getattr(callback_handler, "record_result", None) + if callable(record_method): + recorded = record_method( + tool_name=tool_name, + result=_truncate_result_for_audit(result), + agent_id=agent_id, + run_id=run_id, + ) + if inspect.isawaitable(recorded): + await recorded + return None + + tool_end_method = getattr(callback_handler, "on_tool_end", None) + if callable(tool_end_method): + recorded = tool_end_method( + output=_truncate_result_for_audit(result), + tool_name=tool_name, + agent_id=agent_id, + run_id=run_id, + ) + if inspect.isawaitable(recorded): + await recorded + + +def _build_denied_error(tool_name: str, reason: str | None) -> Exception: + from agent_assembly.exceptions import PolicyViolationError + + reason_text = reason or "No reason provided." + return PolicyViolationError(f"Tool '{tool_name}' blocked by governance policy: {reason_text}") + + +def _build_pending_rejected_error(tool_name: str, reason: str | None) -> Exception: + from agent_assembly.exceptions import PolicyViolationError + + reason_text = reason or "No reason provided." + return PolicyViolationError(f"Tool '{tool_name}' rejected during approval: {reason_text}") diff --git a/agent_assembly/core/assembly.py b/agent_assembly/core/assembly.py index be49337..c92e5f6 100644 --- a/agent_assembly/core/assembly.py +++ b/agent_assembly/core/assembly.py @@ -38,5 +38,5 @@ def init_assembly( agent_id=agent_id, api_key=api_key, ) - auto_inject_callback_handler(interceptor=object()) + auto_inject_callback_handler(interceptor=object(), process_agent_id=agent_id) return client diff --git a/agent_assembly/exceptions/__init__.py b/agent_assembly/exceptions/__init__.py index a3ab7f4..91b91d1 100644 --- a/agent_assembly/exceptions/__init__.py +++ b/agent_assembly/exceptions/__init__.py @@ -10,6 +10,7 @@ "ConfigurationError", "AdapterValidationError", "ToolExecutionBlockedError", + "PolicyViolationError", ] @@ -46,3 +47,8 @@ class AdapterValidationError(AssemblyError): class ToolExecutionBlockedError(AssemblyError): """Exception raised when a tool run is blocked by governance.""" pass + + +class PolicyViolationError(ToolExecutionBlockedError): + """Exception raised when policy blocks tool execution.""" + pass diff --git a/docs/contents/document/api-references/index.mdx b/docs/contents/document/api-references/index.mdx index e757ff1..a996844 100644 --- a/docs/contents/document/api-references/index.mdx +++ b/docs/contents/document/api-references/index.mdx @@ -81,3 +81,13 @@ from agent_assembly.exceptions import ( - Blocked tool calls return policy message strings (instead of raising exceptions). - `pending` approval decisions block synchronously and return denial strings if not approved. - `Task.execute_sync()` emits task start/complete audit events. + +## Pydantic AI runtime interception + +`init_assembly(...)` also applies Pydantic AI runtime patches when Pydantic AI is installed. + +- `Tool._run(...)` is patched with async governance checks. +- `deny` decisions raise `PolicyViolationError` with tool name and reason. +- `pending` decisions await approval and raise on rejection. +- Successful results are recorded with audit payload truncation at 2000 chars. +- Agent identity resolves from `ctx.deps.assembly_agent_id` with process-level fallback. diff --git a/test/integration/test_pydantic_ai_interception_integration.py b/test/integration/test_pydantic_ai_interception_integration.py new file mode 100644 index 0000000..674a7ab --- /dev/null +++ b/test/integration/test_pydantic_ai_interception_integration.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +from agent_assembly.adapters.pydantic_ai import patch as pydantic_ai_patch +from agent_assembly.exceptions import PolicyViolationError + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_pydantic_ai_two_tool_flow_continues_after_blocked_tool( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class FakeTool: + def __init__(self, name: str, result: str) -> None: + self.name = name + self._result = result + + async def _run(self, ctx: Any, args: Any, **kwargs: Any) -> str: + del ctx, args, kwargs + return self._result + + fake_pydantic_ai_tools = SimpleNamespace(Tool=FakeTool) + + def fake_import_module(module_name: str) -> object: + if module_name == "pydantic_ai.tools": + return fake_pydantic_ai_tools + raise ImportError(module_name) + + monkeypatch.setattr(pydantic_ai_patch.importlib, "import_module", fake_import_module) + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + if kwargs.get("tool_name") == "blocked_tool": + return {"status": "deny", "reason": "blocked by policy"} + return {"status": "allow"} + + patcher = pydantic_ai_patch.PydanticAIPatch(Interceptor()) + assert patcher.apply() is True + + ctx = SimpleNamespace(deps=SimpleNamespace(assembly_agent_id="agent-1"), run_id="run-main") + blocked_tool = FakeTool("blocked_tool", "should-not-run") + safe_tool = FakeTool("safe_tool", "ok:safe_tool") + + with pytest.raises(PolicyViolationError, match="blocked by governance policy: blocked by policy"): + await blocked_tool._run(ctx, {"step": 1}) + + safe_result = await safe_tool._run(ctx, {"step": 2}) + assert safe_result == "ok:safe_tool" + + +@pytest.mark.integration +def test_pydantic_ai_real_tool_class_patch_path_when_available() -> None: + pydantic_ai_tools = pytest.importorskip("pydantic_ai.tools") + tool_cls = pydantic_ai_tools.Tool + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "allow"} + + patcher = pydantic_ai_patch.PydanticAIPatch(Interceptor()) + assert patcher.apply() is True + assert getattr(tool_cls, pydantic_ai_patch._TOOLS_PATCHED_FLAG, False) is True diff --git a/test/unit/adapters/pydantic_ai/test_pydantic_ai_patch.py b/test/unit/adapters/pydantic_ai/test_pydantic_ai_patch.py new file mode 100644 index 0000000..58a4edd --- /dev/null +++ b/test/unit/adapters/pydantic_ai/test_pydantic_ai_patch.py @@ -0,0 +1,349 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +from agent_assembly.adapters.pydantic_ai import patch as pydantic_ai_patch +from agent_assembly.exceptions import PolicyViolationError + + +class _RecordingInterceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "allow"} + + +class _ArgsModel: + def __init__(self, payload: dict[str, Any]) -> None: + self._payload = payload + + def model_dump(self) -> dict[str, Any]: + return dict(self._payload) + + +def _install_fake_pydantic_ai_modules( + monkeypatch: pytest.MonkeyPatch, +) -> type[Any]: + class FakeTool: + name = "fake_tool" + + async def _run(self, ctx: Any, args: Any, **kwargs: Any) -> dict[str, object]: + return { + "ctx": ctx, + "args": args, + "kwargs": kwargs, + } + + fake_pydantic_ai_tools = SimpleNamespace(Tool=FakeTool) + + def fake_import_module(module_name: str) -> object: + if module_name == "pydantic_ai.tools": + return fake_pydantic_ai_tools + raise ImportError(module_name) + + monkeypatch.setattr(pydantic_ai_patch.importlib, "import_module", fake_import_module) + return FakeTool + + +@pytest.mark.asyncio +async def test_apply_patches_tool_run_and_is_idempotent(monkeypatch: pytest.MonkeyPatch) -> None: + FakeTool = _install_fake_pydantic_ai_modules(monkeypatch) + + patcher = pydantic_ai_patch.PydanticAIPatch(_RecordingInterceptor()) + assert patcher.apply() is True + first_run_ref = FakeTool._run + + assert getattr(FakeTool, pydantic_ai_patch._TOOLS_PATCHED_FLAG, False) is True + + assert patcher.apply() is True + assert FakeTool._run is first_run_ref + + +def test_loader_edge_cases_and_apply_false_without_tool(monkeypatch: pytest.MonkeyPatch) -> None: + def raise_import_error(module_name: str) -> object: + raise ImportError(module_name) + + monkeypatch.setattr(pydantic_ai_patch.importlib, "import_module", raise_import_error) + assert pydantic_ai_patch._load_pydantic_ai_tool_class() is None + assert pydantic_ai_patch.PydanticAIPatch(_RecordingInterceptor()).apply() is False + + fake_pydantic_ai_tools = SimpleNamespace(Tool=object()) + + def return_non_type(module_name: str) -> object: + if module_name == "pydantic_ai.tools": + return fake_pydantic_ai_tools + raise ImportError(module_name) + + monkeypatch.setattr(pydantic_ai_patch.importlib, "import_module", return_non_type) + assert pydantic_ai_patch._load_pydantic_ai_tool_class() is None + + +def test_helper_branches_for_agent_id_timeout_and_serialization() -> None: + class TimeoutProvider: + def get_pending_tool_approval_timeout_seconds(self) -> str: + return "42" + + assert pydantic_ai_patch._get_pending_tool_approval_timeout_seconds(TimeoutProvider()) == 42 + assert pydantic_ai_patch._get_pending_tool_approval_timeout_seconds( + SimpleNamespace(pending_tool_approval_timeout_seconds=0) + ) == 300 + assert pydantic_ai_patch._get_pending_tool_approval_timeout_seconds( + SimpleNamespace(pending_tool_approval_timeout_seconds=True) + ) == 300 + + assert pydantic_ai_patch._normalize_decision("deny") == ("deny", None) + assert pydantic_ai_patch._normalize_decision("pending") == ("pending", None) + assert pydantic_ai_patch._normalize_decision("allow") == ("allow", None) + assert pydantic_ai_patch._normalize_decision(12345) == ("allow", None) + + pydantic_ai_patch.set_process_agent_id("process-agent") + ctx_with_deps = SimpleNamespace(deps=SimpleNamespace(assembly_agent_id="deps-agent"), run_id=123) + ctx_without_deps = SimpleNamespace(deps=SimpleNamespace(), run_id=None) + + assert pydantic_ai_patch._resolve_agent_id(ctx_with_deps) == "deps-agent" + assert pydantic_ai_patch._resolve_agent_id(ctx_without_deps) == "process-agent" + assert pydantic_ai_patch._resolve_run_id(ctx_with_deps) == "123" + assert pydantic_ai_patch._resolve_run_id(ctx_without_deps) is None + + model_args = _ArgsModel({"x": 1, "y": 2}) + assert pydantic_ai_patch._serialize_tool_args(model_args) == {"x": 1, "y": 2} + assert pydantic_ai_patch._serialize_tool_args({"a": 1}) == {"a": 1} + assert pydantic_ai_patch._serialize_tool_args(99) == {"value": "99"} + + pydantic_ai_patch.set_process_agent_id(None) + + +@pytest.mark.asyncio +async def test_denied_tool_raises_policy_violation_error(monkeypatch: pytest.MonkeyPatch) -> None: + FakeTool = _install_fake_pydantic_ai_modules(monkeypatch) + + class BlockInterceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "deny", "reason": "blocked for safety"} + + patcher = pydantic_ai_patch.PydanticAIPatch(BlockInterceptor()) + assert patcher.apply() is True + + tool = FakeTool() + ctx = SimpleNamespace(deps=SimpleNamespace(assembly_agent_id="agent-a"), run_id="run-1") + + with pytest.raises(PolicyViolationError, match="blocked by governance policy: blocked for safety"): + await tool._run(ctx, _ArgsModel({"topic": "finance"})) + + +@pytest.mark.asyncio +async def test_pending_then_approved_runs_and_records_result( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeTool = _install_fake_pydantic_ai_modules(monkeypatch) + wait_calls: list[dict[str, object]] = [] + recorded_results: list[dict[str, object]] = [] + + class PendingThenApproveInterceptor: + pending_tool_approval_timeout_seconds = 25 + + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "pending", "reason": "needs approval"} + + async def wait_for_tool_approval(self, **kwargs: object) -> dict[str, str]: + wait_calls.append(dict(kwargs)) + return {"status": "allow"} + + async def record_result(self, **kwargs: object) -> None: + recorded_results.append(dict(kwargs)) + + patcher = pydantic_ai_patch.PydanticAIPatch(PendingThenApproveInterceptor()) + assert patcher.apply() is True + + tool = FakeTool() + ctx = SimpleNamespace(deps=SimpleNamespace(assembly_agent_id="agent-b"), run_id="run-2") + result = await tool._run(ctx, _ArgsModel({"q": "hello"}), trace="yes") + + assert result["kwargs"] == {"trace": "yes"} + assert len(wait_calls) == 1 + assert wait_calls[0]["timeout_seconds"] == 25 + assert len(recorded_results) == 1 + assert recorded_results[0]["tool_name"] == "fake_tool" + assert recorded_results[0]["agent_id"] == "agent-b" + + +@pytest.mark.asyncio +async def test_pending_then_rejected_raises_policy_violation_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + FakeTool = _install_fake_pydantic_ai_modules(monkeypatch) + + class PendingThenRejectInterceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "pending", "reason": "requires approval"} + + async def wait_for_tool_approval(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "deny", "reason": "approval rejected"} + + patcher = pydantic_ai_patch.PydanticAIPatch(PendingThenRejectInterceptor()) + assert patcher.apply() is True + + tool = FakeTool() + ctx = SimpleNamespace(deps=SimpleNamespace(assembly_agent_id="agent-c"), run_id="run-3") + + with pytest.raises(PolicyViolationError, match="rejected during approval: approval rejected"): + await tool._run(ctx, _ArgsModel({"q": "secret"})) + + +@pytest.mark.asyncio +async def test_result_recording_truncates_to_2000_chars(monkeypatch: pytest.MonkeyPatch) -> None: + class FakeTool: + name = "truncate_tool" + + async def _run(self, ctx: Any, args: Any, **kwargs: Any) -> str: + del ctx, args, kwargs + return "x" * 2500 + + fake_pydantic_ai_tools = SimpleNamespace(Tool=FakeTool) + + def fake_import_module(module_name: str) -> object: + if module_name == "pydantic_ai.tools": + return fake_pydantic_ai_tools + raise ImportError(module_name) + + monkeypatch.setattr(pydantic_ai_patch.importlib, "import_module", fake_import_module) + + observed: list[str] = [] + + class Interceptor: + async def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "allow"} + + async def record_result(self, **kwargs: object) -> None: + observed.append(str(kwargs["result"])) + + patcher = pydantic_ai_patch.PydanticAIPatch(Interceptor()) + assert patcher.apply() is True + + tool = FakeTool() + ctx = SimpleNamespace(deps=SimpleNamespace(assembly_agent_id="agent-d"), run_id="run-4") + result = await tool._run(ctx, _ArgsModel({"q": "long"})) + + assert isinstance(result, str) + assert len(result) == 2500 + assert len(observed) == 1 + assert len(observed[0]) == 2000 + + +@pytest.mark.asyncio +async def test_assembly_model_wrapper_scans_and_forwards_request() -> None: + class FakeModel: + async def request(self, *args: object, **kwargs: object) -> str: + del kwargs + return f"forwarded:{args[0]}" + + scanned_prompts: list[list[str]] = [] + + class Interceptor: + async def on_llm_start_scan(self, **kwargs: object) -> None: + prompts = kwargs.get("prompts", []) + if isinstance(prompts, list): + scanned_prompts.append([str(item) for item in prompts]) + + wrapper = pydantic_ai_patch.AssemblyModelWrapper(FakeModel(), Interceptor()) + result = await wrapper.request("hello") + + assert result == "forwarded:hello" + assert scanned_prompts == [["hello"]] + + +@pytest.mark.asyncio +async def test_assembly_model_wrapper_passthrough_attrs() -> None: + class FakeModel: + model_name = "demo-model" + + def request(self, *args: object, **kwargs: object) -> str: + del args, kwargs + return "ok" + + wrapper = pydantic_ai_patch.AssemblyModelWrapper(FakeModel(), object()) + assert wrapper.model_name == "demo-model" + result = await wrapper.request("ignored") + assert result == "ok" + + +@pytest.mark.asyncio +async def test_fallback_and_non_awaitable_branches_for_async_helpers() -> None: + class NoHandlers: + pass + + pydantic_ai_patch.set_process_agent_id(None) + assert pydantic_ai_patch._get_process_agent_id() is None + + fallback_check = await pydantic_ai_patch._invoke_async_tool_check( + NoHandlers(), + tool_name="x", + tool_args={}, + agent_id=None, + run_id=None, + ) + assert fallback_check == {"status": "allow"} + + fallback_wait = await pydantic_ai_patch._wait_for_async_tool_approval( + NoHandlers(), + tool_name="x", + timeout_seconds=1, + tool_args={}, + agent_id=None, + run_id=None, + ) + assert fallback_wait == {"status": "deny", "reason": "Approval handler is unavailable."} + + assert pydantic_ai_patch._get_pending_tool_approval_timeout_seconds( + SimpleNamespace(pending_tool_approval_timeout_seconds="NaN") + ) == 300 + + class SyncInterceptor: + def check_tool_start(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "allow"} + + def wait_for_tool_approval(self, **kwargs: object) -> dict[str, str]: + del kwargs + return {"status": "allow"} + + sync_check = await pydantic_ai_patch._invoke_async_tool_check( + SyncInterceptor(), + tool_name="x", + tool_args={}, + agent_id=None, + run_id=None, + ) + assert sync_check == {"status": "allow"} + + sync_wait = await pydantic_ai_patch._wait_for_async_tool_approval( + SyncInterceptor(), + tool_name="x", + timeout_seconds=2, + tool_args={}, + agent_id=None, + run_id=None, + ) + assert sync_wait == {"status": "allow"} + + observed_outputs: list[str] = [] + + class ToolEndOnlyInterceptor: + async def on_tool_end(self, **kwargs: object) -> None: + observed_outputs.append(str(kwargs["output"])) + + await pydantic_ai_patch._record_async_tool_result( + ToolEndOnlyInterceptor(), + tool_name="fallback", + result="result-value", + agent_id="agent-z", + run_id="run-z", + ) + assert observed_outputs == ["result-value"]