Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,28 @@ from agent_assembly import init_assembly


async def main() -> None:
client = init_assembly(
context = init_assembly(
gateway_url="http://localhost:8080",
api_key="required-api-key",
agent_id="my-agent-001",
api_key="optional-api-key",
mode="auto",
)

try:
registration = await client.register_agent()
decision = await client.check_policy_compliance("tool.call")
registration = await context.client.register_agent()
decision = await context.client.check_policy_compliance("tool.call")
print(registration)
print(decision)
finally:
client.close()
context.shutdown()


asyncio.run(main())
```

## Public API

- `init_assembly(gateway_url, agent_id, api_key=None) -> GatewayClient`
- `init_assembly(gateway_url, api_key, agent_id=None, mode="auto") -> AssemblyContext`
- `GatewayClient.register_agent() -> dict`
- `GatewayClient.check_policy_compliance(action: str) -> dict`
- Exceptions: `AssemblyError`, `AgentError`, `PolicyError`, `GatewayError`, `ConfigurationError`
Expand All @@ -62,7 +63,7 @@ from agent_assembly import init_assembly
from agent_assembly.exceptions import ConfigurationError

try:
client = init_assembly(gateway_url="", agent_id="my-agent-001")
context = init_assembly(gateway_url="", api_key="my-api-key", agent_id="my-agent-001")
except ConfigurationError as exc:
print(f"Invalid configuration: {exc}")
```
Expand Down
3 changes: 2 additions & 1 deletion agent_assembly/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Agent Assembly Python SDK."""

from agent_assembly.adapters import FrameworkAdapter, GovernanceInterceptor
from agent_assembly.core import init_assembly
from agent_assembly.core import AssemblyContext, init_assembly
from agent_assembly.exceptions import (
AdapterValidationError,
AgentError,
Expand All @@ -17,6 +17,7 @@
__all__ = [
"__version__",
"init_assembly",
"AssemblyContext",
"GovernanceInterceptor",
"FrameworkAdapter",
"AssemblyError",
Expand Down
41 changes: 41 additions & 0 deletions agent_assembly/adapters/crewai/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ def apply(self) -> bool:
_apply_task_execute_sync_patch(task_cls, self.callback_handler)
return True

def revert(self) -> None:
"""Revert CrewAI runtime monkey patches when available."""
base_tool_cls = _load_crewai_basetool_class()
if base_tool_cls is not None:
_revert_basetool_run_patch(base_tool_cls)

task_cls = _load_crewai_task_class()
if task_cls is not None:
_revert_task_execute_sync_patch(task_cls)
return None


def _load_crewai_basetool_class() -> type[Any] | None:
try:
Expand Down Expand Up @@ -263,6 +274,21 @@ def patched_run(self: Any, *args: Any, **kwargs: Any) -> Any:
setattr(base_tool_cls, _TOOLS_PATCHED_FLAG, True)


def _revert_basetool_run_patch(base_tool_cls: type[Any]) -> None:
if not getattr(base_tool_cls, _TOOLS_PATCHED_FLAG, False):
return None

original_run = getattr(base_tool_cls, _ORIGINAL_TOOL_RUN, None)
if callable(original_run):
setattr(base_tool_cls, "run", original_run)

if hasattr(base_tool_cls, _ORIGINAL_TOOL_RUN):
delattr(base_tool_cls, _ORIGINAL_TOOL_RUN)
if hasattr(base_tool_cls, _TOOLS_PATCHED_FLAG):
delattr(base_tool_cls, _TOOLS_PATCHED_FLAG)
return None


def _record_task_start(callback_handler: Any, task: Any) -> None:
method = getattr(callback_handler, "record", None)
if callable(method):
Expand Down Expand Up @@ -313,3 +339,18 @@ def patched_execute_sync(self: Any, *args: Any, **kwargs: Any) -> Any:
setattr(task_cls, _ORIGINAL_TASK_EXECUTE_SYNC, original_execute_sync)
setattr(task_cls, "execute_sync", patched_execute_sync)
setattr(task_cls, _TASK_PATCHED_FLAG, True)


def _revert_task_execute_sync_patch(task_cls: type[Any]) -> None:
if not getattr(task_cls, _TASK_PATCHED_FLAG, False):
return None

original_execute_sync = getattr(task_cls, _ORIGINAL_TASK_EXECUTE_SYNC, None)
if callable(original_execute_sync):
setattr(task_cls, "execute_sync", original_execute_sync)

if hasattr(task_cls, _ORIGINAL_TASK_EXECUTE_SYNC):
delattr(task_cls, _ORIGINAL_TASK_EXECUTE_SYNC)
if hasattr(task_cls, _TASK_PATCHED_FLAG):
delattr(task_cls, _TASK_PATCHED_FLAG)
return None
2 changes: 2 additions & 0 deletions agent_assembly/adapters/langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from agent_assembly.adapters.langchain.callback_handler import AssemblyCallbackHandler
from agent_assembly.adapters.langgraph import LangGraphPatch
from agent_assembly.adapters.langchain.langgraph_patch import patch_stategraph_compile
from agent_assembly.adapters.langchain.patch import LangChainPatch
from agent_assembly.adapters.langchain.runtime import (
auto_inject_callback_handler,
get_active_callback_handler,
)

__all__ = [
"AssemblyCallbackHandler",
"LangChainPatch",
"LangGraphPatch",
"patch_stategraph_compile",
"auto_inject_callback_handler",
Expand Down
36 changes: 36 additions & 0 deletions agent_assembly/adapters/langchain/patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""LangChain patch orchestrator."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from agent_assembly.adapters.langchain.callback_handler import AssemblyCallbackHandler
from agent_assembly.adapters.langchain.runtime import (
_reset_runtime_state_for_tests,
auto_inject_callback_handler,
)
from agent_assembly.adapters.pydantic_ai.patch import set_process_agent_id


@dataclass(slots=True)
class LangChainPatch:
"""Apply/remove LangChain callback handler injection."""

interceptor: Any
process_agent_id: str | None = None

def apply(self) -> bool:
auto_inject_callback_handler(
self.interceptor,
process_agent_id=self.process_agent_id,
)
return True

def revert(self) -> None:
set_process_agent_id(None)
_reset_runtime_state_for_tests()
return None


__all__ = ["LangChainPatch", "AssemblyCallbackHandler"]
10 changes: 1 addition & 9 deletions agent_assembly/adapters/langchain/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
from threading import Lock
from typing import Any

from agent_assembly.adapters.crewai.patch import CrewAIPatch
from agent_assembly.adapters.langchain.callback_handler import AssemblyCallbackHandler
from agent_assembly.adapters.langgraph import LangGraphPatch
from agent_assembly.adapters.pydantic_ai.patch import PydanticAIPatch, set_process_agent_id
from agent_assembly.adapters.pydantic_ai.patch import set_process_agent_id

_ACTIVE_CALLBACK_HANDLER: AssemblyCallbackHandler | None = None
_RUNTIME_LOCK = Lock()
Expand All @@ -27,16 +25,10 @@ def auto_inject_callback_handler(
set_process_agent_id(process_agent_id)

if _ACTIVE_CALLBACK_HANDLER is not None:
LangGraphPatch(_ACTIVE_CALLBACK_HANDLER).apply()
CrewAIPatch(interceptor).apply()
PydanticAIPatch(interceptor).apply()
return _ACTIVE_CALLBACK_HANDLER

handler = AssemblyCallbackHandler(interceptor)
_ACTIVE_CALLBACK_HANDLER = handler
LangGraphPatch(handler).apply()
CrewAIPatch(interceptor).apply()
PydanticAIPatch(interceptor).apply()
return handler


Expand Down
23 changes: 23 additions & 0 deletions agent_assembly/adapters/langgraph/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def apply(self) -> bool:
_apply_stategraph_compile_patch(state_graph_cls, self.callback_handler)
return True

def revert(self) -> None:
"""Revert state graph compile patch when it is currently active."""
state_graph_cls = _load_stategraph_class()
if state_graph_cls is None:
return None
_revert_stategraph_compile_patch(state_graph_cls)
return None


def patch_stategraph_compile(callback_handler: Any) -> bool:
"""Backward-compatible helper used by existing runtime wiring."""
Expand Down Expand Up @@ -239,6 +247,21 @@ def patched_compile(self: Any, *args: Any, **kwargs: Any) -> Any:
setattr(state_graph_cls, _PATCHED_FLAG, True)


def _revert_stategraph_compile_patch(state_graph_cls: type[Any]) -> None:
if not getattr(state_graph_cls, _PATCHED_FLAG, False):
return None

original_compile = getattr(state_graph_cls, _ORIGINAL_COMPILE, None)
if callable(original_compile):
setattr(state_graph_cls, "compile", original_compile)

if hasattr(state_graph_cls, _ORIGINAL_COMPILE):
delattr(state_graph_cls, _ORIGINAL_COMPILE)
if hasattr(state_graph_cls, _PATCHED_FLAG):
delattr(state_graph_cls, _PATCHED_FLAG)
return None


def _wrap_graph_invoke_fallback(compiled_graph: Any, callback_handler: Any) -> None:
invoke = getattr(compiled_graph, "invoke", None)
if not callable(invoke) or getattr(invoke, _INVOKE_WRAPPED_FLAG, False):
Expand Down
5 changes: 5 additions & 0 deletions agent_assembly/adapters/mcp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""MCP adapter package."""

from agent_assembly.adapters.mcp.patch import MCPClientPatch

__all__ = ["MCPClientPatch"]
25 changes: 25 additions & 0 deletions agent_assembly/adapters/mcp/patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""MCP client patch module."""

from __future__ import annotations

from dataclasses import dataclass
import importlib.util
from typing import Any


@dataclass(slots=True)
class MCPClientPatch:
"""Patch placeholder for MCP client interception."""

callback_handler: Any

def apply(self) -> bool:
_ = self.callback_handler
return _is_mcp_available()

def revert(self) -> None:
return None


def _is_mcp_available() -> bool:
return importlib.util.find_spec("mcp") is not None
5 changes: 5 additions & 0 deletions agent_assembly/adapters/openai_agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""OpenAI Agents adapter package."""

from agent_assembly.adapters.openai_agents.patch import OpenAIAgentsPatch

__all__ = ["OpenAIAgentsPatch"]
25 changes: 25 additions & 0 deletions agent_assembly/adapters/openai_agents/patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""OpenAI Agents patch module."""

from __future__ import annotations

from dataclasses import dataclass
import importlib.util
from typing import Any


@dataclass(slots=True)
class OpenAIAgentsPatch:
"""Patch placeholder for OpenAI Agents SDK interception."""

callback_handler: Any

def apply(self) -> bool:
_ = self.callback_handler
return _is_openai_agents_available()

def revert(self) -> None:
return None


def _is_openai_agents_available() -> bool:
return importlib.util.find_spec("openai.agents") is not None
23 changes: 23 additions & 0 deletions agent_assembly/adapters/pydantic_ai/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ def apply(self) -> bool:
_apply_tool_run_patch(tool_cls, self.callback_handler)
return True

def revert(self) -> None:
"""Revert Pydantic AI tool patch when available."""
tool_cls = _load_pydantic_ai_tool_class()
if tool_cls is not None:
_revert_tool_run_patch(tool_cls)
set_process_agent_id(None)
return None


class AssemblyModelWrapper:
"""Optional model wrapper for LLM input scan-forward interception."""
Expand Down Expand Up @@ -131,6 +139,21 @@ async def patched_run(self: Any, ctx: Any, args: Any, **kwargs: Any) -> Any:
setattr(tool_cls, _TOOLS_PATCHED_FLAG, True)


def _revert_tool_run_patch(tool_cls: type[Any]) -> None:
if not getattr(tool_cls, _TOOLS_PATCHED_FLAG, False):
return None

original_run = getattr(tool_cls, _ORIGINAL_TOOL_RUN, None)
if callable(original_run):
setattr(tool_cls, "_run", original_run)

if hasattr(tool_cls, _ORIGINAL_TOOL_RUN):
delattr(tool_cls, _ORIGINAL_TOOL_RUN)
if hasattr(tool_cls, _TOOLS_PATCHED_FLAG):
delattr(tool_cls, _TOOLS_PATCHED_FLAG)
return None


def set_process_agent_id(agent_id: str | None) -> None:
global _PROCESS_AGENT_ID
_PROCESS_AGENT_ID = agent_id
Expand Down
4 changes: 2 additions & 2 deletions agent_assembly/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Core module for Agent Assembly initialization."""

from agent_assembly.core.assembly import init_assembly
from agent_assembly.core.assembly import AssemblyContext, init_assembly

__all__ = ["init_assembly"]
__all__ = ["init_assembly", "AssemblyContext"]
Loading