diff --git a/agent_assembly/adapters/crewai/__init__.py b/agent_assembly/adapters/crewai/__init__.py
new file mode 100644
index 0000000..e050dae
--- /dev/null
+++ b/agent_assembly/adapters/crewai/__init__.py
@@ -0,0 +1,5 @@
+"""CrewAI adapter package."""
+
+from agent_assembly.adapters.crewai.patch import CrewAIPatch
+
+__all__ = ["CrewAIPatch"]
diff --git a/agent_assembly/adapters/crewai/patch.py b/agent_assembly/adapters/crewai/patch.py
new file mode 100644
index 0000000..261a7ec
--- /dev/null
+++ b/agent_assembly/adapters/crewai/patch.py
@@ -0,0 +1,404 @@
+"""CrewAI patch module.
+
+Wraps ``crewai.tools.BaseTool.run`` and ``crewai.Task.execute_sync`` so that
+synchronous tool calls pass through governance checks and task lifecycle
+events are reported to a callback handler.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from functools import wraps
+import importlib
+from threading import local
+from typing import Any, Literal, Mapping
+
+_TOOLS_PATCHED_FLAG = "_agent_assembly_crewai_tools_patched"
+_TASK_PATCHED_FLAG = "_agent_assembly_crewai_task_patched"
+_ORIGINAL_TOOL_RUN = "_agent_assembly_original_crewai_tool_run"
+_ORIGINAL_TASK_EXECUTE_SYNC = "_agent_assembly_original_crewai_task_execute_sync"
+# Thread-local slot holding the agent id of the task running on this thread.
+_AGENT_CONTEXT = local()
+_DEFAULT_PENDING_APPROVAL_TIMEOUT_SECONDS = 300
+
+
+@dataclass(slots=True)
+class CrewAIPatch:
+    """Applies CrewAI runtime monkey-patching hooks."""
+
+    callback_handler: Any
+
+    def apply(self) -> bool:
+        """Apply patch wiring and return whether CrewAI is available.
+
+        Returns ``False`` without patching anything when
+        ``crewai.tools.BaseTool`` cannot be loaded. ``Task`` is patched only
+        when ``crewai.Task`` can be loaded as well.
+        """
+        base_tool_cls = _load_crewai_basetool_class()
+        if base_tool_cls is None:
+            return False
+
+        _apply_basetool_run_patch(base_tool_cls, self.callback_handler)
+        task_cls = _load_crewai_task_class()
+        if task_cls is not None:
+            _apply_task_execute_sync_patch(task_cls, self.callback_handler)
+        return True
+
+
+def _load_crewai_basetool_class() -> type[Any] | None:
+    """Return ``crewai.tools.BaseTool`` or ``None`` when it is unavailable."""
+    try:
+        module = importlib.import_module("crewai.tools")
+    except ImportError:
+        return None
+
+    base_tool_cls = getattr(module, "BaseTool", None)
+    if isinstance(base_tool_cls, type):
+        return base_tool_cls
+    return None
+
+
+def _load_crewai_task_class() -> type[Any] | None:
+    """Return ``crewai.Task`` or ``None`` when it is unavailable."""
+    try:
+        module = importlib.import_module("crewai")
+    except ImportError:
+        return None
+
+    task_cls = getattr(module, "Task", None)
+    if isinstance(task_cls, type):
+        return task_cls
+    return None
+
+
+def _set_thread_local_agent_id(agent_id: str | None) -> None:
+    """Store *agent_id* for the current thread."""
+    _AGENT_CONTEXT.agent_id = agent_id
+
+
+def _get_thread_local_agent_id() -> str | None:
+    """Return the current thread's agent id, or ``None`` if unset or empty."""
+    agent_id = getattr(_AGENT_CONTEXT, "agent_id", None)
+    if isinstance(agent_id, str) and agent_id:
+        return agent_id
+    return None
+
+
+def _extract_agent_id_from_inputs(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str | None:
+    """Find a non-empty string agent id in the call inputs.
+
+    Checked in order: ``kwargs["agent_id"]``, ``config["configurable"]``,
+    ``config["metadata"]``, then a dict passed as the first positional argument.
+    """
+    direct_agent_id = kwargs.get("agent_id")
+    if isinstance(direct_agent_id, str) and direct_agent_id:
+        return direct_agent_id
+
+    config = kwargs.get("config")
+    if isinstance(config, dict):
+        configurable = config.get("configurable")
+        if isinstance(configurable, dict):
+            configurable_agent_id = configurable.get("agent_id")
+            if isinstance(configurable_agent_id, str) and configurable_agent_id:
+                return configurable_agent_id
+
+        metadata = config.get("metadata")
+        if isinstance(metadata, dict):
+            metadata_agent_id = metadata.get("agent_id")
+            if isinstance(metadata_agent_id, str) and metadata_agent_id:
+                return metadata_agent_id
+
+    if args and isinstance(args[0], dict):
+        state_agent_id = args[0].get("agent_id")
+        if isinstance(state_agent_id, str) and state_agent_id:
+            return state_agent_id
+
+    return None
+
+
+def _resolve_task_agent_id(
+    task: Any,
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any],
+) -> str | None:
+    """Resolve the agent id for one task execution.
+
+    An explicit id found by ``_extract_agent_id_from_inputs`` wins. Otherwise the
+    ``id`` attribute of the ``agent=`` keyword, the first positional argument or
+    ``task.agent`` is used, in that order.
+    """
+    explicit_agent_id = _extract_agent_id_from_inputs(args, kwargs)
+    if explicit_agent_id is not None:
+        return explicit_agent_id
+
+    # NOTE(review): assumes CrewAI hands the agent object to ``execute_sync``
+    # (or stores it on the task) and that agents expose ``id`` -- confirm
+    # against the installed CrewAI version.
+    candidates = [kwargs.get("agent"), args[0] if args else None, getattr(task, "agent", None)]
+    for candidate in candidates:
+        raw_agent_id = getattr(candidate, "id", None)
+        if raw_agent_id is not None and str(raw_agent_id):
+            return str(raw_agent_id)
+    return None
+
+
+def _format_blocked_message(reason: str | None) -> str:
+    """Build the string returned to the agent when a tool call is blocked."""
+    # Drop trailing periods so the template below never renders "..".
+    reason_text = (reason or "").strip().rstrip(".") or "No reason provided"
+    return (
+        f"[BLOCKED by governance policy] {reason_text}. "
+        "Please choose a different approach to accomplish this task."
+    )
+
+
+def _format_approval_rejected_message(reason: str | None) -> str:
+    """Build the string returned when a pending approval is not granted."""
+    reason_text = reason or "No reason provided."
+    return f"[APPROVAL REJECTED] Action was reviewed and denied: {reason_text}"
+
+
+def _normalize_decision(
+    decision: object,
+) -> tuple[Literal["allow", "deny", "pending"], str | None]:
+    """Reduce a handler decision to ``(status, reason)``.
+
+    Accepts a status string or a mapping with ``status`` and an optional
+    ``reason``. Any status other than ``deny`` or ``pending``, and any other
+    decision type, is treated as ``allow``.
+    """
+    if isinstance(decision, str):
+        normalized = decision.strip().lower()
+        if normalized == "deny":
+            return "deny", None
+        if normalized == "pending":
+            return "pending", None
+        return "allow", None
+
+    if isinstance(decision, Mapping):
+        raw_status = str(decision.get("status", "allow")).strip().lower()
+        if raw_status == "deny":
+            status: Literal["allow", "deny", "pending"] = "deny"
+        elif raw_status == "pending":
+            status = "pending"
+        else:
+            status = "allow"
+
+        reason_value = decision.get("reason")
+        reason = str(reason_value) if reason_value is not None else None
+        return status, reason
+
+    return "allow", None
+
+
+def _invoke_sync_tool_check(
+    callback_handler: Any,
+    *,
+    tool_name: str,
+    tool_args: dict[str, Any],
+    agent_id: str | None,
+) -> object:
+    """Call ``check_tool_start`` on the handler; allow when it is not defined."""
+    method = getattr(callback_handler, "check_tool_start", None)
+    if callable(method):
+        return method(
+            serialized={"name": tool_name},
+            input_str=str(tool_args),
+            tool_name=tool_name,
+            args=tool_args,
+            agent_id=agent_id,
+        )
+
+    return {"status": "allow"}
+
+
+def _wait_for_sync_tool_approval(
+    callback_handler: Any,
+    *,
+    tool_name: str,
+    timeout_seconds: int,
+    tool_args: dict[str, Any],
+    agent_id: str | None,
+) -> object:
+    """Call ``wait_for_tool_approval`` on the handler; deny when it is not defined."""
+    method = getattr(callback_handler, "wait_for_tool_approval", None)
+    if callable(method):
+        return method(
+            tool_name=tool_name,
+            timeout_seconds=timeout_seconds,
+            args=tool_args,
+            agent_id=agent_id,
+        )
+
+    return {"status": "deny", "reason": "Approval handler is unavailable."}
+
+
+def _get_pending_tool_approval_timeout_seconds(callback_handler: Any) -> int:
+    """Return the approval wait timeout configured on the handler.
+
+    Reads ``get_pending_tool_approval_timeout_seconds()`` when callable,
+    otherwise the ``pending_tool_approval_timeout_seconds`` attribute. Only
+    positive ints or positive decimal strings are accepted; everything else
+    yields the 300 second default.
+    """
+    provider = getattr(callback_handler, "get_pending_tool_approval_timeout_seconds", None)
+    if callable(provider):
+        configured = provider()
+    else:
+        configured = getattr(callback_handler, "pending_tool_approval_timeout_seconds", None)
+
+    if isinstance(configured, str):
+        stripped = configured.strip()
+        # isdecimal(), unlike isdigit(), rejects characters such as "²" that
+        # int() cannot parse, so malformed values fall back to the default.
+        if stripped.isdecimal():
+            parsed = int(stripped)
+            if parsed > 0:
+                return parsed
+        return _DEFAULT_PENDING_APPROVAL_TIMEOUT_SECONDS
+
+    # bool is a subclass of int; True/False are not meaningful timeouts.
+    if isinstance(configured, bool):
+        return _DEFAULT_PENDING_APPROVAL_TIMEOUT_SECONDS
+
+    if isinstance(configured, int) and configured > 0:
+        return configured
+
+    return _DEFAULT_PENDING_APPROVAL_TIMEOUT_SECONDS
+
+
+def _record_sync_tool_result(
+    callback_handler: Any,
+    *,
+    tool_name: str,
+    result: object,
+) -> None:
+    """Report a tool result via ``record_result``, else ``on_tool_end``."""
+    record_method = getattr(callback_handler, "record_result", None)
+    if callable(record_method):
+        record_method(tool_name=tool_name, result=result)
+        return None
+
+    tool_end_method = getattr(callback_handler, "on_tool_end", None)
+    if callable(tool_end_method):
+        tool_end_method(output=result, tool_name=tool_name)
+    return None
+
+
+def _apply_basetool_run_patch(base_tool_cls: type[Any], callback_handler: Any) -> None:
+    """Replace ``base_tool_cls.run`` with a governance-checked wrapper (once).
+
+    The wrapper returns a message string instead of running the tool when the
+    call is denied, or when a pending approval is rejected or left unresolved.
+    """
+    if getattr(base_tool_cls, _TOOLS_PATCHED_FLAG, False):
+        return None
+
+    original_run = base_tool_cls.run
+
+    @wraps(original_run)
+    def patched_run(self: Any, *args: Any, **kwargs: Any) -> Any:
+        tool_name = getattr(self, "name", self.__class__.__name__)
+        # NOTE(review): only keyword arguments are shown to the policy check;
+        # positional tool arguments are not inspected.
+        tool_args = dict(kwargs)
+        agent_id = _get_thread_local_agent_id()
+        decision = _invoke_sync_tool_check(
+            callback_handler,
+            tool_name=str(tool_name),
+            tool_args=tool_args,
+            agent_id=agent_id,
+        )
+        status, reason = _normalize_decision(decision)
+        is_pending_flow = False
+        if status == "pending":
+            is_pending_flow = True
+            timeout_seconds = _get_pending_tool_approval_timeout_seconds(callback_handler)
+            final_decision = _wait_for_sync_tool_approval(
+                callback_handler,
+                tool_name=str(tool_name),
+                timeout_seconds=timeout_seconds,
+                tool_args=tool_args,
+                agent_id=agent_id,
+            )
+            status, reason = _normalize_decision(final_decision)
+            if status == "pending":
+                # Still unresolved after waiting: fail closed rather than run
+                # a tool nobody approved.
+                status = "deny"
+                reason = reason or f"Approval was not granted within {timeout_seconds} seconds."
+
+        if status == "deny":
+            if is_pending_flow:
+                return _format_approval_rejected_message(reason)
+            return _format_blocked_message(reason)
+
+        result = original_run(self, *args, **kwargs)
+        _record_sync_tool_result(callback_handler, tool_name=str(tool_name), result=result)
+        return result
+
+    setattr(base_tool_cls, _ORIGINAL_TOOL_RUN, original_run)
+    setattr(base_tool_cls, "run", patched_run)
+    setattr(base_tool_cls, _TOOLS_PATCHED_FLAG, True)
+
+
+def _record_task_start(callback_handler: Any, task: Any) -> None:
+    """Report task start via ``record``, else ``on_task_start``."""
+    method = getattr(callback_handler, "record", None)
+    if callable(method):
+        method(
+            action="task_start",
+            task_description=str(getattr(task, "description", ""))[:200],
+            expected_output=getattr(task, "expected_output", None),
+        )
+        return None
+
+    fallback = getattr(callback_handler, "on_task_start", None)
+    if callable(fallback):
+        fallback(task=task)
+    return None
+
+
+def _record_task_complete(callback_handler: Any, result: object) -> None:
+    """Report task completion via ``record``, else ``on_task_complete``."""
+    method = getattr(callback_handler, "record", None)
+    if callable(method):
+        method(action="task_complete", output_preview=str(result)[:500])
+        return None
+
+    fallback = getattr(callback_handler, "on_task_complete", None)
+    if callable(fallback):
+        fallback(result=result)
+    return None
+
+
+def _apply_task_execute_sync_patch(task_cls: type[Any], callback_handler: Any) -> None:
+    """Replace ``task_cls.execute_sync`` with an auditing wrapper (once).
+
+    The wrapper publishes the task's agent id to the current thread for the
+    duration of the call and records task start/complete events.
+    """
+    if getattr(task_cls, _TASK_PATCHED_FLAG, False):
+        return None
+
+    original_execute_sync = task_cls.execute_sync
+
+    @wraps(original_execute_sync)
+    def patched_execute_sync(self: Any, *args: Any, **kwargs: Any) -> Any:
+        previous_agent_id = _get_thread_local_agent_id()
+        _set_thread_local_agent_id(_resolve_task_agent_id(self, args, kwargs))
+        try:
+            # Inside the try so a failing start hook cannot leak the agent id
+            # into later work on this thread.
+            _record_task_start(callback_handler, self)
+            result = original_execute_sync(self, *args, **kwargs)
+        finally:
+            _set_thread_local_agent_id(previous_agent_id)
+
+        _record_task_complete(callback_handler, result)
+        return result
+
+    setattr(task_cls, _ORIGINAL_TASK_EXECUTE_SYNC, original_execute_sync)
+    setattr(task_cls, "execute_sync", patched_execute_sync)
+    setattr(task_cls, _TASK_PATCHED_FLAG, True)
diff --git a/agent_assembly/adapters/langchain/runtime.py b/agent_assembly/adapters/langchain/runtime.py
index 50bb20f..aca6841 100644
--- a/agent_assembly/adapters/langchain/runtime.py
+++ b/agent_assembly/adapters/langchain/runtime.py
@@ -5,6 +5,7 @@
 from threading import Lock
 from typing import Any
 
+from agent_assembly.adapters.crewai.patch import CrewAIPatch
 from agent_assembly.adapters.langchain.callback_handler import AssemblyCallbackHandler
 from agent_assembly.adapters.langgraph import LangGraphPatch
 
@@ -19,11 +20,13 @@ def auto_inject_callback_handler(interceptor: Any) -> AssemblyCallbackHandler:
     with _RUNTIME_LOCK:
         if _ACTIVE_CALLBACK_HANDLER is not None:
             LangGraphPatch(_ACTIVE_CALLBACK_HANDLER).apply()
+            CrewAIPatch(interceptor).apply()
             return _ACTIVE_CALLBACK_HANDLER
 
         handler = AssemblyCallbackHandler(interceptor)
         _ACTIVE_CALLBACK_HANDLER = handler
         LangGraphPatch(handler).apply()
+        CrewAIPatch(interceptor).apply()
         return handler
 
 
diff --git a/docs/contents/document/api-references/index.mdx b/docs/contents/document/api-references/index.mdx
index 8ed37d7..e757ff1 100644
--- a/docs/contents/document/api-references/index.mdx
+++ b/docs/contents/document/api-references/index.mdx
@@ -72,3 +72,12 @@ from agent_assembly.exceptions import (
 - `deny` (or unresolved `pending`) raises `ToolExecutionBlockedError`.
 - LLM start interception is scan-only and does not mutate prompt content.
 - LangGraph `StateGraph.compile()` is patched to add pre/post invocation governance hooks.
+
+## CrewAI runtime interception
+
+`init_assembly(...)` also applies CrewAI runtime patches when CrewAI is installed.
+
+- `BaseTool.run()` is patched with synchronous governance checks.
+- Blocked tool calls return policy message strings (instead of raising exceptions).
+- `pending` approval decisions block synchronously and return denial strings if not approved.
+- `Task.execute_sync()` emits task start/complete audit events.
diff --git a/test/integration/test_crewai_interception_integration.py b/test/integration/test_crewai_interception_integration.py
new file mode 100644
index 0000000..b3a297a
--- /dev/null
+++ b/test/integration/test_crewai_interception_integration.py
@@ -0,0 +1,162 @@
+from __future__ import annotations
+
+from types import SimpleNamespace
+from typing import Any
+
+import pytest
+
+from agent_assembly.adapters.crewai import patch as crewai_patch
+
+
+@pytest.mark.integration
+def test_crewai_two_task_flow_continues_after_blocked_tool(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """A denied tool yields a policy string and the next task still runs."""
+    class FakeBaseTool:
+        def __init__(self, name: str) -> None:
+            self.name = name
+
+        def run(self, *args: Any, **kwargs: Any) -> str:
+            del args, kwargs
+            return f"ok:{self.name}"
+
+    class FakeTask:
+        def __init__(self, tool: FakeBaseTool, description: str, expected_output: str) -> None:
+            self.tool = tool
+            self.description = description
+            self.expected_output = expected_output
+
+        def execute_sync(self, *args: Any, **kwargs: Any) -> str:
+            del args, kwargs
+            return self.tool.run()
+
+    fake_crewai_tools = SimpleNamespace(BaseTool=FakeBaseTool)
+    fake_crewai_module = SimpleNamespace(Task=FakeTask)
+
+    def fake_import_module(module_name: str) -> object:
+        if module_name == "crewai.tools":
+            return fake_crewai_tools
+        if module_name == "crewai":
+            return fake_crewai_module
+        raise ImportError(module_name)
+
+    monkeypatch.setattr(crewai_patch.importlib, "import_module", fake_import_module)
+
+    class Interceptor:
+        def check_tool_start(self, **kwargs: object) -> dict[str, str]:
+            tool_name = kwargs.get("tool_name")
+            if tool_name == "blocked_tool":
+                return {"status": "deny", "reason": "blocked by policy"}
+            return {"status": "allow"}
+
+    patcher = crewai_patch.CrewAIPatch(Interceptor())
+    assert patcher.apply() is True
+
+    blocked_task = FakeTask(
+        tool=FakeBaseTool("blocked_tool"),
+        description="task1",
+        expected_output="blocked string",
+    )
+    safe_task = FakeTask(
+        tool=FakeBaseTool("safe_tool"),
+        description="task2",
+        expected_output="normal result",
+    )
+
+    results = [
+        blocked_task.execute_sync(agent_id="agent-1"),
+        safe_task.execute_sync(agent_id="agent-2"),
+    ]
+
+    assert isinstance(results[0], str)
+    assert "[BLOCKED by governance policy]" in results[0]
+    assert "blocked by policy" in results[0]
+    assert results[1] == "ok:safe_tool"
+
+
+@pytest.mark.integration
+def test_crewai_real_task_and_tool_classes_flow_when_available(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """Same two-task flow, exercised against the real CrewAI classes."""
+    crewai = pytest.importorskip("crewai")
+    crewai_tools = pytest.importorskip("crewai.tools")
+
+    BaseTool = crewai_tools.BaseTool
+    Task = crewai.Task
+    Agent = crewai.Agent
+    Crew = crewai.Crew
+
+    class BlockedTool(BaseTool):
+        name: str = "blocked_tool"
+        description: str = "Tool that should be blocked by governance."
+
+        def _run(self, **kwargs: object) -> str:
+            del kwargs
+            return "should-not-run"
+
+    class SafeTool(BaseTool):
+        name: str = "safe_tool"
+        description: str = "Tool that should remain allowed."
+
+        def _run(self, **kwargs: object) -> str:
+            del kwargs
+            return "ok:safe_tool"
+
+    def fake_execute_sync(self: object, *args: object, **kwargs: object) -> str:
+        del args, kwargs
+        task_tools = getattr(self, "tools", None) or []
+        if task_tools:
+            return str(task_tools[0].run())
+        return "no-tool"
+
+    monkeypatch.setattr(Task, "execute_sync", fake_execute_sync, raising=True)
+    # CrewAIPatch mutates the real CrewAI classes. Register every attribute it
+    # touches with monkeypatch so teardown restores them and a stale "patched"
+    # flag cannot make later apply() calls silently skip patching.
+    monkeypatch.setattr(BaseTool, "run", BaseTool.run, raising=True)
+    for cls, attr_names in (
+        (BaseTool, (crewai_patch._TOOLS_PATCHED_FLAG, crewai_patch._ORIGINAL_TOOL_RUN)),
+        (Task, (crewai_patch._TASK_PATCHED_FLAG, crewai_patch._ORIGINAL_TASK_EXECUTE_SYNC)),
+    ):
+        for attr_name in attr_names:
+            monkeypatch.setattr(cls, attr_name, False, raising=False)
+
+    class Interceptor:
+        def check_tool_start(self, **kwargs: object) -> dict[str, str]:
+            if kwargs.get("tool_name") == "blocked_tool":
+                return {"status": "deny", "reason": "blocked by policy"}
+            return {"status": "allow"}
+
+    patcher = crewai_patch.CrewAIPatch(Interceptor())
+    assert patcher.apply() is True
+
+    blocked_agent = Agent(role="blocked", goal="run blocked task", backstory="blocked")
+    safe_agent = Agent(role="safe", goal="run safe task", backstory="safe")
+
+    blocked_task = Task(
+        description="blocked task",
+        expected_output="blocked string",
+        agent=blocked_agent,
+        tools=[BlockedTool()],
+    )
+    safe_task = Task(
+        description="safe task",
+        expected_output="safe string",
+        agent=safe_agent,
+        tools=[SafeTool()],
+    )
+
+    # Build a real CrewAI Crew object to validate object wiring and two-agent task setup.
+    Crew(agents=[blocked_agent, safe_agent], tasks=[blocked_task, safe_task], verbose=False)
+
+    results = [
+        blocked_task.execute_sync(agent_id="agent-1"),
+        safe_task.execute_sync(agent_id="agent-2"),
+    ]
+
+    assert isinstance(results[0], str)
+    assert "[BLOCKED by governance policy]" in results[0]
+    assert "blocked by policy" in results[0]
+    assert results[1] == "ok:safe_tool"
diff --git a/test/unit/adapters/crewai/test_patch.py b/test/unit/adapters/crewai/test_patch.py
new file mode 100644
index 0000000..92e048e
--- /dev/null
+++ b/test/unit/adapters/crewai/test_patch.py
@@ -0,0 +1,414 @@
+from __future__ import annotations
+
+from concurrent.futures import ThreadPoolExecutor
+from types import SimpleNamespace
+from typing import Any
+
+import pytest
+
+from agent_assembly.adapters.crewai import patch as crewai_patch
+
+
+class _RecordingInterceptor:
+    """Interceptor stub that allows every tool call."""
+
+    def check_tool_start(self, **kwargs: object) -> dict[str, str]:
+        del kwargs
+        return {"status": "allow"}
+
+
+def _install_fake_crewai_modules(
+    monkeypatch: pytest.MonkeyPatch,
+) -> tuple[type[Any], type[Any]]:
+    """Route the patch module's imports to fake CrewAI classes and return them."""
+    class FakeBaseTool:
+        name = "fake_tool"
+
+        def run(self, *args: Any, **kwargs: Any) -> dict[str, object]:
+            return {"args": args, "kwargs": kwargs}
+
+    class FakeTask:
+        description = "fake task"
+        expected_output = "fake output"
+
+        def execute_sync(self, *args: Any, **kwargs: Any) -> dict[str, object]:
+            return {"args": args, "kwargs": kwargs}
+
+    fake_crewai_tools = SimpleNamespace(BaseTool=FakeBaseTool)
+    fake_crewai_module = SimpleNamespace(Task=FakeTask)
+
+    def fake_import_module(module_name: str) -> object:
+        if module_name == "crewai.tools":
+            return fake_crewai_tools
+        if module_name == "crewai":
+            return fake_crewai_module
+        raise ImportError(module_name)
+
+    monkeypatch.setattr(crewai_patch.importlib, "import_module", fake_import_module)
+    return FakeBaseTool, FakeTask
+
+
+def test_apply_patches_crewai_run_and_is_idempotent(monkeypatch: pytest.MonkeyPatch) -> None:
+    """apply() patches both classes and a second apply() is a no-op."""
+    FakeBaseTool, FakeTask = _install_fake_crewai_modules(monkeypatch)
+
+    patcher = crewai_patch.CrewAIPatch(_RecordingInterceptor())
+    assert patcher.apply() is True
+    first_run_ref = FakeBaseTool.run
+    first_task_ref = FakeTask.execute_sync
+
+    assert getattr(FakeBaseTool, crewai_patch._TOOLS_PATCHED_FLAG, False) is True
+    assert getattr(FakeTask, crewai_patch._TASK_PATCHED_FLAG, False) is True
+
+    assert patcher.apply() is True
+    assert FakeBaseTool.run is first_run_ref
+    assert FakeTask.execute_sync is first_task_ref
+
+
+def test_loader_edge_cases_and_apply_false_without_basetool(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """Loaders return None on ImportError or when the attribute is not a class."""
+    def raise_import_error(module_name: str) -> object:
+        raise ImportError(module_name)
+
+    monkeypatch.setattr(crewai_patch.importlib, "import_module", raise_import_error)
+    assert crewai_patch._load_crewai_basetool_class() is None
+    assert crewai_patch._load_crewai_task_class() is None
+    assert crewai_patch.CrewAIPatch(_RecordingInterceptor()).apply() is False
+
+    fake_crewai_tools = SimpleNamespace(BaseTool=object())
+    fake_crewai_module = SimpleNamespace(Task=object())
+
+    def return_non_type(module_name: str) -> object:
+        if module_name == "crewai.tools":
+            return fake_crewai_tools
+        if module_name == "crewai":
+            return fake_crewai_module
+        raise ImportError(module_name)
+
+    monkeypatch.setattr(crewai_patch.importlib, "import_module", return_non_type)
+    assert crewai_patch._load_crewai_basetool_class() is None
+    assert crewai_patch._load_crewai_task_class() is None
+
+
+def test_helper_branch_coverage_for_decision_and_agent_extraction() -> None:
+    """Exercise helper branches: decisions, agent id lookup, fallbacks, timeout."""
+    assert crewai_patch._normalize_decision("deny") == ("deny", None)
+    assert crewai_patch._normalize_decision("pending") == ("pending", None)
+    assert crewai_patch._normalize_decision("allow") == ("allow", None)
+    assert crewai_patch._normalize_decision(12345) == ("allow", None)
+
+    assert (
+        crewai_patch._extract_agent_id_from_inputs((), {"agent_id": "agent-direct"}) == "agent-direct"
+    )
+    assert (
+        crewai_patch._extract_agent_id_from_inputs(
+            (),
+            {"config": {"configurable": {"agent_id": "agent-configurable"}}},
+        )
+        == "agent-configurable"
+    )
+    assert (
+        crewai_patch._extract_agent_id_from_inputs(
+            (),
+            {"config": {"metadata": {"agent_id": "agent-metadata"}}},
+        )
+        == "agent-metadata"
+    )
+    assert crewai_patch._extract_agent_id_from_inputs(({"agent_id": "agent-state"},), {}) == "agent-state"
+    assert crewai_patch._extract_agent_id_from_inputs((), {}) is None
+
+    class NoHandlers:
+        pass
+
+    fallback_check = crewai_patch._invoke_sync_tool_check(
+        NoHandlers(),
+        tool_name="x",
+        tool_args={},
+        agent_id=None,
+    )
+    assert fallback_check == {"status": "allow"}
+
+    fallback_wait = crewai_patch._wait_for_sync_tool_approval(
+        NoHandlers(),
+        tool_name="x",
+        timeout_seconds=1,
+        tool_args={},
+        agent_id=None,
+    )
+    assert fallback_wait == {"status": "deny", "reason": "Approval handler is unavailable."}
+
+    class TimeoutProvider:
+        def get_pending_tool_approval_timeout_seconds(self) -> str:
+            return "42"
+
+    assert crewai_patch._get_pending_tool_approval_timeout_seconds(TimeoutProvider()) == 42
+    assert crewai_patch._get_pending_tool_approval_timeout_seconds(
+        SimpleNamespace(pending_tool_approval_timeout_seconds=0)
+    ) == 300
+    assert crewai_patch._get_pending_tool_approval_timeout_seconds(
+        SimpleNamespace(pending_tool_approval_timeout_seconds=True)
+    ) == 300
+
+
+def test_record_result_and_task_fallback_handlers_are_used(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """record_result and the on_task_* fallbacks are called when present."""
+    class FakeBaseTool:
+        name = "record_result_tool"
+
+        def run(self, *args: Any, **kwargs: Any) -> dict[str, object]:
+            return {"args": args, "kwargs": kwargs}
+
+    class FakeTask:
+        description = "fallback task"
+        expected_output = "fallback output"
+
+        def execute_sync(self, *args: Any, **kwargs: Any) -> str:
+            del args, kwargs
+            return "task-result"
+
+    fake_crewai_tools = SimpleNamespace(BaseTool=FakeBaseTool)
+    fake_crewai_module = SimpleNamespace(Task=FakeTask)
+
+    def fake_import_module(module_name: str) -> object:
+        if module_name == "crewai.tools":
+            return fake_crewai_tools
+        if module_name == "crewai":
+            return fake_crewai_module
+        raise ImportError(module_name)
+
+    monkeypatch.setattr(crewai_patch.importlib, "import_module", fake_import_module)
+
+    seen_results: list[object] = []
+    lifecycle_events: list[str] = []
+
+    class FallbackInterceptor:
+        def check_tool_start(self, **kwargs: object) -> dict[str, str]:
+            del kwargs
+            return {"status": "allow"}
+
+        def record_result(self, **kwargs: object) -> None:
+            seen_results.append(kwargs["result"])
+
+        def on_task_start(self, **kwargs: object) -> None:
+            del kwargs
+            lifecycle_events.append("start")
+
+        def on_task_complete(self, **kwargs: object) -> None:
+            del kwargs
+            lifecycle_events.append("complete")
+
+    patcher = crewai_patch.CrewAIPatch(FallbackInterceptor())
+    assert patcher.apply() is True
+
+    tool_result = FakeBaseTool().run(alpha=1)
+    task_result = FakeTask().execute_sync()
+
+    assert tool_result == {"args": (), "kwargs": {"alpha": 1}}
+    assert seen_results == [tool_result]
+    assert task_result == "task-result"
+    assert lifecycle_events == ["start", "complete"]
+
+
+def test_blocked_tool_returns_policy_string(monkeypatch: pytest.MonkeyPatch) -> None:
+    """A denied tool call returns the policy message instead of running."""
+    FakeBaseTool, _ = _install_fake_crewai_modules(monkeypatch)
+
+    class BlockInterceptor:
+        def check_tool_start(self, **kwargs: object) -> dict[str, str]:
+            del kwargs
+            return {"status": "deny", "reason": "blocked for safety"}
+
+    patcher = crewai_patch.CrewAIPatch(BlockInterceptor())
+    assert patcher.apply() is True
+
+    tool = FakeBaseTool()
+    result = tool.run(param="value")
+
+    assert isinstance(result, str)
+    assert "[BLOCKED by governance policy]" in result
+    assert "blocked for safety" in result
+
+
+def test_allowed_tool_runs_and_records_result(monkeypatch: pytest.MonkeyPatch) -> None:
+    """An allowed tool runs and its result is reported via on_tool_end."""
+    FakeBaseTool, _ = _install_fake_crewai_modules(monkeypatch)
+    observed: list[object] = []
+
+    class AllowInterceptor:
+        def check_tool_start(self, **kwargs: object) -> dict[str, str]:
+            del kwargs
+            return {"status": "allow"}
+
+        def on_tool_end(self, *, output: object, **kwargs: object) -> None:
+            del kwargs
+            observed.append(output)
+
+    patcher = crewai_patch.CrewAIPatch(AllowInterceptor())
+    assert patcher.apply() is True
+
+    tool = FakeBaseTool()
+    result = tool.run(param="value")
+
+    assert result == {"args": (), "kwargs": {"param": "value"}}
+    assert observed == [result]
+
+
+def test_pending_tool_waits_and_allows_when_approved(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """A pending decision waits for approval and then runs the tool."""
+    FakeBaseTool, _ = _install_fake_crewai_modules(monkeypatch)
+    wait_calls: list[dict[str, object]] = []
+
+    class PendingThenApproveInterceptor:
+        def check_tool_start(self, **kwargs: object) -> dict[str, str]:
+            del kwargs
+            return {"status": "pending", "reason": "needs approval"}
+
+        def wait_for_tool_approval(self, **kwargs: object) -> dict[str, str]:
+            wait_calls.append(dict(kwargs))
+            return {"status": "allow"}
+
+    patcher = crewai_patch.CrewAIPatch(PendingThenApproveInterceptor())
+    assert patcher.apply() is True
+
+    tool = FakeBaseTool()
+    result = tool.run(param="value")
+
+    assert result == {"args": (), "kwargs": {"param": "value"}}
+    assert len(wait_calls) == 1
+    assert wait_calls[0]["timeout_seconds"] == 300
+
+
+def test_pending_tool_uses_configurable_timeout(monkeypatch: pytest.MonkeyPatch) -> None:
+    """The handler's configured timeout is passed to the approval wait."""
+    FakeBaseTool, _ = _install_fake_crewai_modules(monkeypatch)
+    wait_calls: list[dict[str, object]] = []
+
+    class PendingWithConfigurableTimeoutInterceptor:
+        pending_tool_approval_timeout_seconds = 37
+
+        def check_tool_start(self, **kwargs: object) -> dict[str, str]:
+            del kwargs
+            return {"status": "pending", "reason": "needs approval"}
+
+        def wait_for_tool_approval(self, **kwargs: object) -> dict[str, str]:
+            wait_calls.append(dict(kwargs))
+            return {"status": "allow"}
+
+    patcher = crewai_patch.CrewAIPatch(PendingWithConfigurableTimeoutInterceptor())
+    assert patcher.apply() is True
+
+    tool = FakeBaseTool()
+    result = tool.run(param="value")
+
+    assert result == {"args": (), "kwargs": {"param": "value"}}
+    assert len(wait_calls) == 1
+    assert wait_calls[0]["timeout_seconds"] == 37
+
+
+def test_pending_timeout_returns_denied_string(monkeypatch: pytest.MonkeyPatch) -> None:
+    """A denied approval returns the approval-rejected message."""
+    FakeBaseTool, _ = _install_fake_crewai_modules(monkeypatch)
+
+    class PendingTimeoutInterceptor:
+        def check_tool_start(self, **kwargs: object) -> dict[str, str]:
+            del kwargs
+            return {"status": "pending", "reason": "requires manual approval"}
+
+        def wait_for_tool_approval(self, **kwargs: object) -> dict[str, str]:
+            del kwargs
+            return {"status": "deny", "reason": "approval timeout"}
+
+    patcher = crewai_patch.CrewAIPatch(PendingTimeoutInterceptor())
+    assert patcher.apply() is True
+
+    tool = FakeBaseTool()
+    result = tool.run(param="value")
+
+    assert isinstance(result, str)
+    assert result.startswith("[APPROVAL REJECTED]")
+    assert "approval timeout" in result
+
+
+def test_task_start_and_complete_events_are_recorded(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """Task execution records task_start then task_complete."""
+    _, FakeTask = _install_fake_crewai_modules(monkeypatch)
+    recorded: list[dict[str, object]] = []
+
+    class TaskRecordInterceptor:
+        def check_tool_start(self, **kwargs: object) -> dict[str, str]:
+            del kwargs
+            return {"status": "allow"}
+
+        def record(self, **kwargs: object) -> None:
+            recorded.append(dict(kwargs))
+
+    patcher = crewai_patch.CrewAIPatch(TaskRecordInterceptor())
+    assert patcher.apply() is True
+
+    task = FakeTask()
+    result = task.execute_sync(input_text="hello")
+
+    assert result == {"args": (), "kwargs": {"input_text": "hello"}}
+    assert len(recorded) == 2
+    assert recorded[0]["action"] == "task_start"
+    assert recorded[1]["action"] == "task_complete"
+
+
+def test_thread_local_agent_id_isolated_across_concurrent_tasks(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """Concurrent tasks each see their own agent id in tool checks."""
+    class FakeBaseTool:
+        name = "thread_tool"
+
+        def run(self, *args: Any, **kwargs: Any) -> dict[str, object]:
+            return {"args": args, "kwargs": kwargs}
+
+    class FakeTask:
+        description = "thread task"
+        expected_output = "thread output"
+
+        def execute_sync(self, *args: Any, **kwargs: Any) -> object:
+            tool = kwargs["tool"]
+            return tool.run()
+
+    fake_crewai_tools = SimpleNamespace(BaseTool=FakeBaseTool)
+    fake_crewai_module = SimpleNamespace(Task=FakeTask)
+
+    def fake_import_module(module_name: str) -> object:
+        if module_name == "crewai.tools":
+            return fake_crewai_tools
+        if module_name == "crewai":
+            return fake_crewai_module
+        raise ImportError(module_name)
+
+    monkeypatch.setattr(crewai_patch.importlib, "import_module", fake_import_module)
+
+    observed_agent_ids: list[str | None] = []
+
+    class ConcurrencyInterceptor:
+        def check_tool_start(self, **kwargs: object) -> dict[str, str]:
+            observed_agent_ids.append(
+                str(kwargs.get("agent_id")) if kwargs.get("agent_id") is not None else None
+            )
+            return {"status": "allow"}
+
+    patcher = crewai_patch.CrewAIPatch(ConcurrencyInterceptor())
+    assert patcher.apply() is True
+
+    task = FakeTask()
+    with ThreadPoolExecutor(max_workers=2) as pool:
+        future_a = pool.submit(task.execute_sync, tool=FakeBaseTool(), agent_id="agent-A")
+        future_b = pool.submit(task.execute_sync, tool=FakeBaseTool(), agent_id="agent-B")
+        assert future_a.result() == {"args": (), "kwargs": {}}
+        assert future_b.result() == {"args": (), "kwargs": {}}
+
+    assert sorted(observed_agent_ids) == ["agent-A", "agent-B"]