diff --git a/agent_assembly/adapters/__init__.py b/agent_assembly/adapters/__init__.py index c16f448..1da5617 100644 --- a/agent_assembly/adapters/__init__.py +++ b/agent_assembly/adapters/__init__.py @@ -1,3 +1,4 @@ from agent_assembly.adapters.base import FrameworkAdapter, GovernanceInterceptor +from agent_assembly.adapters.registry import AdapterInfo, AdapterRegistry -__all__ = ["GovernanceInterceptor", "FrameworkAdapter"] +__all__ = ["GovernanceInterceptor", "FrameworkAdapter", "AdapterInfo", "AdapterRegistry"] diff --git a/agent_assembly/adapters/registry.py b/agent_assembly/adapters/registry.py new file mode 100644 index 0000000..3c620e6 --- /dev/null +++ b/agent_assembly/adapters/registry.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from dataclasses import dataclass +from importlib import metadata +from threading import Lock +from typing import Callable, Literal + +from agent_assembly.adapters.base import FrameworkAdapter + + +@dataclass(frozen=True, slots=True) +class AdapterInfo: + name: str + version: str + status: Literal["active", "error"] + hooks_registered: int + + +class _BuiltinPlaceholderAdapter(FrameworkAdapter): + def __init__(self, framework_name: str, import_name: str | None = None) -> None: + self._framework_name = framework_name + self._import_name = import_name or framework_name + + def get_framework_name(self) -> str: + return self._import_name + + def get_supported_versions(self) -> list[str]: + return [">=0.0.0"] + + def register_hooks(self, interceptor: object) -> None: + return None + + def unregister_hooks(self) -> None: + return None + + +def _noop_interceptor_method(*args: object, **kwargs: object) -> None: + del args, kwargs + return None + + +class _NoopGovernanceInterceptor: + def __getattr__(self, name: str) -> Callable[..., None]: + del name + return _noop_interceptor_method + + +_NOOP_GOVERNANCE_INTERCEPTOR = _NoopGovernanceInterceptor() + + +class AdapterRegistry: + def __init__(self) -> None: + self._lock = Lock() + self._registered: dict[str, FrameworkAdapter] = {} + self._active: dict[str, FrameworkAdapter] = {} + self._errors: dict[str, str] = {} + self._discovered_entry_points: set[str] = set() + builtin_frameworks = [ + ("langchain", "langchain"), + ("langgraph", "langgraph"), + ("crewai", "crewai"), + ("pydantic-ai", "pydantic_ai"), + ] + for registry_name, import_name in builtin_frameworks: + self._registered[registry_name] = _BuiltinPlaceholderAdapter( + framework_name=registry_name, + import_name=import_name, + ) + + def register(self, adapter: FrameworkAdapter) -> None: + adapter_name = adapter.get_framework_name() + with self._lock: + self._registered[adapter_name] = adapter + self._errors.pop(adapter_name, None) + if adapter_name in self._active and self._active[adapter_name] is not adapter: + self._active.pop(adapter_name, None) + + def unregister(self, name: str) -> None: + with self._lock: + active_adapter = self._active.pop(name, None) + self._registered.pop(name, None) + self._errors.pop(name, None) + + if active_adapter is not None: + active_adapter.unregister_hooks() + + def list_active(self) -> list[AdapterInfo]: + with self._lock: + active_items = list(self._active.items()) + error_names = set(self._errors.keys()) + + result: list[AdapterInfo] = [] + for name, adapter in active_items: + hooks_registered = getattr(adapter, "_hooks_registered_count", 0) + if not isinstance(hooks_registered, int): + hooks_registered = 0 + + result.append( + AdapterInfo( + name=name, + version=adapter.get_active_version() or "", + status="active", + hooks_registered=hooks_registered, + ) + ) + + active_names = {name for name, _ in active_items} + for name in sorted(error_names - active_names): + result.append( + AdapterInfo( + name=name, + version="", + status="error", + hooks_registered=0, + ) + ) + + return sorted(result, key=lambda info: info.name) + + def _discover_entry_point_adapters(self) -> list[str]: + discovered: list[str] = [] + entry_points = metadata.entry_points() + adapter_entry_points = entry_points.select(group="agent_assembly.adapters") + + for entry_point in adapter_entry_points: + with self._lock: + if entry_point.name in self._discovered_entry_points: + continue + + try: + loaded = entry_point.load() + except Exception as error: # pragma: no cover - guarded by tests via monkeypatch + with self._lock: + self._errors[entry_point.name] = str(error) + continue + + if not isinstance(loaded, type): + with self._lock: + self._errors[entry_point.name] = "Entry point did not load a class." + continue + + if not issubclass(loaded, FrameworkAdapter): + with self._lock: + self._errors[entry_point.name] = "Entry point class is not a FrameworkAdapter." + continue + + try: + adapter = loaded() + except Exception as error: # pragma: no cover - guarded by tests via monkeypatch + with self._lock: + self._errors[entry_point.name] = str(error) + continue + + self.register(adapter) + with self._lock: + self._discovered_entry_points.add(entry_point.name) + discovered.append(adapter.get_framework_name()) + + return discovered + + def auto_detect(self) -> list[str]: + self._discover_entry_point_adapters() + + with self._lock: + registered_items = list(self._registered.items()) + + activated: list[str] = [] + for name, adapter in registered_items: + if not adapter.is_available(): + continue + + with self._lock: + if self._active.get(name) is adapter: + continue + + try: + adapter.register(_NOOP_GOVERNANCE_INTERCEPTOR) + except Exception as error: + with self._lock: + self._errors[name] = str(error) + continue + + with self._lock: + self._active[name] = adapter + self._errors.pop(name, None) + activated.append(name) + + return activated diff --git a/mypy.ini b/mypy.ini index 70c6cf5..a2f1140 100644 --- a/mypy.ini +++ b/mypy.ini @@ -23,3 +23,6 @@ strict_concatenate = True [mypy-agent_assembly.adapters.base] strict = True + +[mypy-agent_assembly.adapters.registry] +strict = True diff --git a/test/unit/adapters/test_registry.py b/test/unit/adapters/test_registry.py new file mode 100644 index 0000000..4721b4b --- /dev/null +++ b/test/unit/adapters/test_registry.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor +from types import SimpleNamespace + +import pytest + +from agent_assembly.adapters import AdapterRegistry, FrameworkAdapter, GovernanceInterceptor + + +class DummyAdapter(FrameworkAdapter): + def __init__(self, framework_name: str) -> None: + self._framework_name = framework_name + self.register_calls = 0 + + def get_framework_name(self) -> str: + return self._framework_name + + def get_supported_versions(self) -> list[str]: + return [">=0.1.0"] + + def register_hooks(self, interceptor: GovernanceInterceptor) -> None: + self.register_calls += 1 + + def unregister_hooks(self) -> None: + return None + + +class EmptyEntryPoints(list[object]): + def select(self, *, group: str) -> list[object]: + del group + return [] + + +class FakeEntryPoint: + def __init__(self, name: str, loaded: type[FrameworkAdapter]) -> None: + self.name = name + self._loaded = loaded + + def load(self) -> type[FrameworkAdapter]: + return self._loaded + + +class ThirdPartyAdapter(FrameworkAdapter): + def get_framework_name(self) -> str: + return "third_party_framework" + + def get_supported_versions(self) -> list[str]: + return [">=1.0.0"] + + def register_hooks(self, interceptor: GovernanceInterceptor) -> None: + return None + + def unregister_hooks(self) -> None: + return None + + +class CountingEntryPointAdapter(FrameworkAdapter): + register_calls = 0 + + def get_framework_name(self) -> str: + return "entrypoint_counting_framework" + + def get_supported_versions(self) -> list[str]: + return [">=1.0.0"] + + def register_hooks(self, interceptor: GovernanceInterceptor) -> None: + CountingEntryPointAdapter.register_calls += 1 + + def unregister_hooks(self) -> None: + return None + + +class InterceptorCallingAdapter(FrameworkAdapter): + def __init__(self) -> None: + self.hook_registered = False + + def get_framework_name(self) -> str: + return "interceptor_calling_framework" + + def get_supported_versions(self) -> list[str]: + return [">=1.0.0"] + + def register_hooks(self, interceptor: GovernanceInterceptor) -> None: + interceptor.record_event("adapter-registered") + self.hook_registered = True + + def unregister_hooks(self) -> None: + return None + + +def test_auto_detect_activates_only_importable_frameworks( + monkeypatch: pytest.MonkeyPatch, +) -> None: + registry = AdapterRegistry() + importable = DummyAdapter("available_framework") + missing = DummyAdapter("missing_framework") + registry._registered = { + importable.get_framework_name(): importable, + missing.get_framework_name(): missing, + } + + monkeypatch.setattr( + "agent_assembly.adapters.registry.metadata.entry_points", + lambda: EmptyEntryPoints(), + ) + + def fake_import_module(module_name: str) -> object: + if module_name == "available_framework": + return SimpleNamespace(__version__="1.2.3") + raise ImportError + + monkeypatch.setattr("agent_assembly.adapters.base.importlib.import_module", fake_import_module) + + activated = registry.auto_detect() + + assert activated == ["available_framework"] + assert importable.register_calls == 1 + assert missing.register_calls == 0 + + +def test_entry_point_discovery_loads_third_party_adapters( + monkeypatch: pytest.MonkeyPatch, +) -> None: + registry = AdapterRegistry() + + class FakeEntryPoints(list[FakeEntryPoint]): + def select(self, *, group: str) -> list[FakeEntryPoint]: + assert group == "agent_assembly.adapters" + return list(self) + + monkeypatch.setattr( + "agent_assembly.adapters.registry.metadata.entry_points", + lambda: FakeEntryPoints([FakeEntryPoint("third-party", ThirdPartyAdapter)]), + ) + + discovered = registry._discover_entry_point_adapters() + + assert discovered == ["third_party_framework"] + assert "third_party_framework" in registry._registered + + +def test_auto_detect_is_idempotent(monkeypatch: pytest.MonkeyPatch) -> None: + registry = AdapterRegistry() + adapter = DummyAdapter("idempotent_framework") + registry._registered = {adapter.get_framework_name(): adapter} + + monkeypatch.setattr( + "agent_assembly.adapters.registry.metadata.entry_points", + lambda: EmptyEntryPoints(), + ) + monkeypatch.setattr( + "agent_assembly.adapters.base.importlib.import_module", + lambda module_name: SimpleNamespace(__version__="9.9.9"), + ) + + first_activation = registry.auto_detect() + second_activation = registry.auto_detect() + + assert first_activation == ["idempotent_framework"] + assert second_activation == [] + assert adapter.register_calls == 1 + + +def test_list_active_reflects_real_time_state( + monkeypatch: pytest.MonkeyPatch, +) -> None: + registry = AdapterRegistry() + adapter = DummyAdapter("stateful_framework") + registry._registered = {adapter.get_framework_name(): adapter} + + monkeypatch.setattr( + "agent_assembly.adapters.registry.metadata.entry_points", + lambda: EmptyEntryPoints(), + ) + monkeypatch.setattr( + "agent_assembly.adapters.base.importlib.import_module", + lambda module_name: SimpleNamespace(__version__="2.0.0"), + ) + + registry.auto_detect() + active_after_detect = registry.list_active() + + assert len(active_after_detect) == 1 + assert active_after_detect[0].name == "stateful_framework" + assert active_after_detect[0].status == "active" + + registry.unregister("stateful_framework") + + assert registry.list_active() == [] + + +def test_register_unregister_is_thread_safe() -> None: + registry = AdapterRegistry() + + def mutate_registry(thread_id: int) -> None: + for round_id in range(40): + adapter = DummyAdapter(f"concurrent_{thread_id}_{round_id}") + registry.register(adapter) + registry.unregister(adapter.get_framework_name()) + + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(mutate_registry, thread_id) for thread_id in range(8)] + for future in futures: + future.result() + + with registry._lock: + assert isinstance(registry._registered, dict) + assert isinstance(registry._active, dict) + + +def test_auto_detect_is_idempotent_for_entry_point_adapters( + monkeypatch: pytest.MonkeyPatch, +) -> None: + registry = AdapterRegistry() + CountingEntryPointAdapter.register_calls = 0 + + class FakeEntryPoints(list[FakeEntryPoint]): + def select(self, *, group: str) -> list[FakeEntryPoint]: + assert group == "agent_assembly.adapters" + return list(self) + + monkeypatch.setattr( + "agent_assembly.adapters.registry.metadata.entry_points", + lambda: FakeEntryPoints([FakeEntryPoint("counting-entrypoint", CountingEntryPointAdapter)]), + ) + + def fake_import_module(module_name: str) -> object: + if module_name == "entrypoint_counting_framework": + return SimpleNamespace(__version__="3.0.0") + raise ImportError + + monkeypatch.setattr("agent_assembly.adapters.base.importlib.import_module", fake_import_module) + + first = registry.auto_detect() + second = registry.auto_detect() + + assert first == ["entrypoint_counting_framework"] + assert second == [] + assert CountingEntryPointAdapter.register_calls == 1 + + +def test_auto_detect_uses_resilient_noop_interceptor( + monkeypatch: pytest.MonkeyPatch, +) -> None: + registry = AdapterRegistry() + adapter = InterceptorCallingAdapter() + registry._registered = {adapter.get_framework_name(): adapter} + + monkeypatch.setattr( + "agent_assembly.adapters.registry.metadata.entry_points", + lambda: EmptyEntryPoints(), + ) + + def fake_import_module(module_name: str) -> object: + if module_name == "interceptor_calling_framework": + return SimpleNamespace(__version__="1.0.0") + raise ImportError + + monkeypatch.setattr("agent_assembly.adapters.base.importlib.import_module", fake_import_module) + + activated = registry.auto_detect() + + assert activated == ["interceptor_calling_framework"] + assert adapter.hook_registered is True + + +def test_builtin_registry_name_uses_pydantic_ai_label_with_python_import_name() -> None: + registry = AdapterRegistry() + + assert "pydantic-ai" in registry._registered + assert "pydantic_ai" not in registry._registered + assert registry._registered["pydantic-ai"].get_framework_name() == "pydantic_ai"