diff --git a/agent_assembly/cli/__init__.py b/agent_assembly/cli/__init__.py new file mode 100644 index 0000000..429ec93 --- /dev/null +++ b/agent_assembly/cli/__init__.py @@ -0,0 +1 @@ +"""CLI tools for Agent Assembly SDK.""" diff --git a/agent_assembly/cli/adapter_validator.py b/agent_assembly/cli/adapter_validator.py new file mode 100644 index 0000000..3ca1174 --- /dev/null +++ b/agent_assembly/cli/adapter_validator.py @@ -0,0 +1,295 @@ +"""Adapter contract validation logic for community adapters.""" + +from __future__ import annotations + +import importlib +import importlib.util +import inspect +import tomllib +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + pass + +from agent_assembly.adapters.base import FrameworkAdapter, GovernanceInterceptor + + +@dataclass(frozen=True, slots=True) +class AdapterValidationResult: + """Result of a single adapter contract check.""" + + check_name: str + passed: bool + message: str + + +def _check_inherits_framework_adapter(cls: type) -> AdapterValidationResult: + """Check that the class inherits from FrameworkAdapter.""" + if issubclass(cls, FrameworkAdapter): + return AdapterValidationResult( + check_name="inherits_framework_adapter", + passed=True, + message="Class inherits from FrameworkAdapter.", + ) + return AdapterValidationResult( + check_name="inherits_framework_adapter", + passed=False, + message=f"Class {cls.__name__} does not inherit from FrameworkAdapter.", + ) + + +_REQUIRED_ABSTRACT_METHODS = frozenset( + { + "get_framework_name", + "get_supported_versions", + "register_hooks", + "unregister_hooks", + } +) + + +def _check_abstract_methods_implemented(cls: type) -> AdapterValidationResult: + """Check that all 4 required abstract methods are concretely implemented.""" + remaining: frozenset[str] = getattr(cls, "__abstractmethods__", frozenset()) + missing = _REQUIRED_ABSTRACT_METHODS & remaining + if not missing: + return AdapterValidationResult( + check_name="abstract_methods_implemented", + passed=True, + message="All required abstract methods are implemented.", + ) + return AdapterValidationResult( + check_name="abstract_methods_implemented", + passed=False, + message=f"Missing implementations: {', '.join(sorted(missing))}.", + ) + + +def _check_framework_name(instance: FrameworkAdapter) -> AdapterValidationResult: + """Check that get_framework_name() returns a non-empty string.""" + try: + name = instance.get_framework_name() + except Exception as exc: + return AdapterValidationResult( + check_name="framework_name", + passed=False, + message=f"get_framework_name() raised {type(exc).__name__}: {exc}", + ) + if isinstance(name, str) and name.strip(): + return AdapterValidationResult( + check_name="framework_name", + passed=True, + message=f"Framework name: '{name}'.", + ) + return AdapterValidationResult( + check_name="framework_name", + passed=False, + message="get_framework_name() must return a non-empty string.", + ) + + +def _check_supported_versions(instance: FrameworkAdapter) -> AdapterValidationResult: + """Check that get_supported_versions() returns a non-empty list of strings.""" + try: + versions = instance.get_supported_versions() + except Exception as exc: + return AdapterValidationResult( + check_name="supported_versions", + passed=False, + message=f"get_supported_versions() raised {type(exc).__name__}: {exc}", + ) + if not isinstance(versions, list) or not versions: + return AdapterValidationResult( + check_name="supported_versions", + passed=False, + message="get_supported_versions() must return a non-empty list.", + ) + for i, v in enumerate(versions): + if not isinstance(v, str) or not v.strip(): + return AdapterValidationResult( + check_name="supported_versions", + passed=False, + message=f"Version at index {i} must be a non-empty string.", + ) + return AdapterValidationResult( + check_name="supported_versions", + passed=True, + message=f"Supported versions: {versions}.", + ) + + +def _check_register_hooks_signature(cls: type) -> AdapterValidationResult: + """Check that register_hooks accepts a GovernanceInterceptor argument.""" + register_hooks = getattr(cls, "register_hooks") + sig = inspect.signature(register_hooks) + params = [p for name, p in sig.parameters.items() if name != "self"] + if not params: + return AdapterValidationResult( + check_name="register_hooks_signature", + passed=False, + message="register_hooks() must accept an interceptor argument.", + ) + first_param = params[0] + annotation = first_param.annotation + acceptable = ( + annotation is inspect.Parameter.empty + or annotation is GovernanceInterceptor + or (isinstance(annotation, str) and "GovernanceInterceptor" in annotation) + ) + if acceptable: + return AdapterValidationResult( + check_name="register_hooks_signature", + passed=True, + message="register_hooks() accepts an interceptor argument.", + ) + return AdapterValidationResult( + check_name="register_hooks_signature", + passed=False, + message=(f"register_hooks() first parameter annotated as {annotation}, " f"expected GovernanceInterceptor."), + ) + + +def _check_unregister_hooks_idempotent( + instance: FrameworkAdapter, +) -> AdapterValidationResult: + """Check that calling unregister_hooks() twice does not raise.""" + try: + instance.unregister_hooks() + instance.unregister_hooks() + except Exception as exc: + return AdapterValidationResult( + check_name="unregister_hooks_idempotent", + passed=False, + message=(f"unregister_hooks() is not idempotent: " f"second call raised {type(exc).__name__}: {exc}"), + ) + return AdapterValidationResult( + check_name="unregister_hooks_idempotent", + passed=True, + message="unregister_hooks() is idempotent (two calls without error).", + ) + + +def _check_entry_point_metadata(cls: type, path_or_module: str) -> AdapterValidationResult: + """Check entry point metadata in pyproject.toml if present at the given path.""" + search_path = Path(path_or_module) + if search_path.is_file(): + search_path = search_path.parent + + pyproject_path = search_path / "pyproject.toml" + if not pyproject_path.is_file(): + return AdapterValidationResult( + check_name="entry_point_metadata", + passed=True, + message="No pyproject.toml found; skipping entry point check.", + ) + + try: + with open(pyproject_path, "rb") as f: + data = tomllib.load(f) + except Exception as exc: + return AdapterValidationResult( + check_name="entry_point_metadata", + passed=False, + message=f"Failed to parse pyproject.toml: {exc}", + ) + + entry_points = data.get("project", {}).get("entry-points", {}).get("agent_assembly.adapters", {}) + if not entry_points: + return AdapterValidationResult( + check_name="entry_point_metadata", + passed=False, + message=('pyproject.toml missing [project.entry-points."agent_assembly.adapters"] ' "section."), + ) + + class_qualname = f"{cls.__module__}:{cls.__qualname__}" + for ep_name, ep_value in entry_points.items(): + if ep_value == class_qualname: + return AdapterValidationResult( + check_name="entry_point_metadata", + passed=True, + message=f"Entry point '{ep_name}' correctly references {class_qualname}.", + ) + + return AdapterValidationResult( + check_name="entry_point_metadata", + passed=False, + message=(f"No entry point references {class_qualname}. " f"Found: {entry_points}."), + ) + + +def validate_adapter(cls: type, path_or_module: str) -> list[AdapterValidationResult]: + """Run all contract checks against an adapter class and return results.""" + results: list[AdapterValidationResult] = [] + + results.append(_check_inherits_framework_adapter(cls)) + results.append(_check_abstract_methods_implemented(cls)) + + # Instance-level checks require a concrete class that can be instantiated + if any(not r.passed for r in results): + return results + + instance = cls() + results.append(_check_framework_name(instance)) + results.append(_check_supported_versions(instance)) + results.append(_check_register_hooks_signature(cls)) + results.append(_check_unregister_hooks_idempotent(instance)) + results.append(_check_entry_point_metadata(cls, path_or_module)) + + return results + + +def _find_adapter_class_in_module(module: object) -> type | None: + """Scan a module for the first FrameworkAdapter subclass.""" + for _name, obj in inspect.getmembers(module, inspect.isclass): + if obj is not FrameworkAdapter and issubclass(obj, FrameworkAdapter): + return obj + return None + + +def load_adapter_class_from_module(module_name: str) -> type: + """Load an adapter class from a dotted module name. + + Raises: + ImportError: If the module cannot be imported. + ValueError: If no FrameworkAdapter subclass is found in the module. + """ + module = importlib.import_module(module_name) + cls = _find_adapter_class_in_module(module) + if cls is None: + raise ValueError(f"No FrameworkAdapter subclass found in module '{module_name}'.") + return cls + + +def load_adapter_class_from_path(file_path: str) -> type: + """Load an adapter class from a file system path. + + Raises: + FileNotFoundError: If the path does not exist. + ValueError: If no FrameworkAdapter subclass is found in the file. + """ + path = Path(file_path).resolve() + if not path.is_file(): + raise FileNotFoundError(f"File not found: {path}") + + module_name = f"_aasm_validate_{path.stem}" + spec = importlib.util.spec_from_file_location(module_name, path) + if spec is None or spec.loader is None: + raise ValueError(f"Cannot create module spec from path: {path}") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + cls = _find_adapter_class_in_module(module) + if cls is None: + raise ValueError(f"No FrameworkAdapter subclass found in '{path}'.") + return cls + + +def load_adapter_class(path_or_module: str) -> type: + """Load an adapter class from either a file path or a dotted module name.""" + candidate = Path(path_or_module) + if candidate.exists(): + return load_adapter_class_from_path(path_or_module) + return load_adapter_class_from_module(path_or_module) diff --git a/agent_assembly/cli/main.py b/agent_assembly/cli/main.py new file mode 100644 index 0000000..2e2ac08 --- /dev/null +++ b/agent_assembly/cli/main.py @@ -0,0 +1,61 @@ +"""CLI entry point for Agent Assembly SDK tools.""" + +from __future__ import annotations + +import argparse +import sys + + +def _build_parser() -> argparse.ArgumentParser: + """Build the top-level argument parser with the adapter subcommand.""" + parser = argparse.ArgumentParser( + prog="aasm", + description="Agent Assembly SDK command-line tools.", + ) + subparsers = parser.add_subparsers(dest="command", help="Available commands") + + adapter_parser = subparsers.add_parser("adapter", help="Adapter management commands") + adapter_subparsers = adapter_parser.add_subparsers(dest="adapter_command", help="Adapter subcommands") + + validate_parser = adapter_subparsers.add_parser( + "validate", help="Validate a community adapter against the FrameworkAdapter contract" + ) + validate_parser.add_argument( + "path_or_module", + help="File path or dotted module name of the adapter to validate", + ) + + return parser + + +def _handle_adapter_validate(args: argparse.Namespace) -> int: + """Run adapter validation and return an exit code.""" + from agent_assembly.cli.adapter_validator import ( + load_adapter_class, + validate_adapter, + ) + from agent_assembly.cli.output import format_results + + try: + cls = load_adapter_class(args.path_or_module) + except (ImportError, FileNotFoundError, ValueError) as exc: + print(f"Error: {exc}", file=sys.stderr) + return 1 + + results = validate_adapter(cls, args.path_or_module) + print(format_results(results)) + + all_passed = all(r.passed for r in results) + return 0 if all_passed else 1 + + +def main() -> None: + """Parse CLI arguments and dispatch to the appropriate handler.""" + parser = _build_parser() + args = parser.parse_args() + + if args.command == "adapter" and getattr(args, "adapter_command", None) == "validate": + sys.exit(_handle_adapter_validate(args)) + else: + parser.print_help() + sys.exit(1) diff --git a/agent_assembly/cli/output.py b/agent_assembly/cli/output.py new file mode 100644 index 0000000..bdad3d5 --- /dev/null +++ b/agent_assembly/cli/output.py @@ -0,0 +1,24 @@ +"""Pass/fail output formatting for adapter validation results.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from agent_assembly.cli.adapter_validator import AdapterValidationResult + + +def format_results(results: list[AdapterValidationResult]) -> str: + """Format validation results as human-readable PASS/FAIL lines.""" + lines: list[str] = [] + for result in results: + prefix = "PASS" if result.passed else "FAIL" + lines.append(f" [{prefix}] {result.check_name}: {result.message}") + + passed = sum(1 for r in results if r.passed) + failed = len(results) - passed + + lines.append("") + lines.append(f"Results: {passed} passed, {failed} failed, {len(results)} total") + + return "\n".join(lines) diff --git a/pyproject.toml b/pyproject.toml index ba11b12..cfec966 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,9 @@ dependencies = [ "typing-extensions>=4.0.0", ] +[project.scripts] +aasm = "agent_assembly.cli.main:main" + [project.urls] Homepage = "https://github.com/agent-assembly/python-sdk" Repository = "https://github.com/agent-assembly/python-sdk" diff --git a/test/unit/cli/__init__.py b/test/unit/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/unit/cli/conftest.py b/test/unit/cli/conftest.py new file mode 100644 index 0000000..7660b78 --- /dev/null +++ b/test/unit/cli/conftest.py @@ -0,0 +1,105 @@ +"""Shared fixtures for CLI adapter validator tests.""" + +from __future__ import annotations + +import pytest + +from agent_assembly.adapters.base import FrameworkAdapter, GovernanceInterceptor + + +class ValidAdapter(FrameworkAdapter): + """A fully valid adapter for testing.""" + + def get_framework_name(self) -> str: + return "test_framework" + + def get_supported_versions(self) -> list[str]: + return [">=1.0.0"] + + def register_hooks(self, interceptor: GovernanceInterceptor) -> None: + pass + + def unregister_hooks(self) -> None: + pass + + +class EmptyNameAdapter(FrameworkAdapter): + """Adapter that returns an empty framework name.""" + + def get_framework_name(self) -> str: + return "" + + def get_supported_versions(self) -> list[str]: + return [">=1.0.0"] + + def register_hooks(self, interceptor: GovernanceInterceptor) -> None: + pass + + def unregister_hooks(self) -> None: + pass + + +class EmptyVersionsAdapter(FrameworkAdapter): + """Adapter that returns an empty versions list.""" + + def get_framework_name(self) -> str: + return "test_framework" + + def get_supported_versions(self) -> list[str]: + return [] + + def register_hooks(self, interceptor: GovernanceInterceptor) -> None: + pass + + def unregister_hooks(self) -> None: + pass + + +class NonIdempotentAdapter(FrameworkAdapter): + """Adapter whose unregister_hooks raises on the second call.""" + + def __init__(self) -> None: + self._call_count = 0 + + def get_framework_name(self) -> str: + return "test_framework" + + def get_supported_versions(self) -> list[str]: + return [">=1.0.0"] + + def register_hooks(self, interceptor: GovernanceInterceptor) -> None: + pass + + def unregister_hooks(self) -> None: + self._call_count += 1 + if self._call_count > 1: + raise RuntimeError("Already unregistered") + + +class NotAnAdapter: + """A class that does not inherit from FrameworkAdapter.""" + + +@pytest.fixture() +def valid_adapter_cls() -> type: + return ValidAdapter + + +@pytest.fixture() +def empty_name_adapter_cls() -> type: + return EmptyNameAdapter + + +@pytest.fixture() +def empty_versions_adapter_cls() -> type: + return EmptyVersionsAdapter + + +@pytest.fixture() +def non_idempotent_adapter_cls() -> type: + return NonIdempotentAdapter + + +@pytest.fixture() +def not_an_adapter_cls() -> type: + return NotAnAdapter diff --git a/test/unit/cli/test_adapter_validator.py b/test/unit/cli/test_adapter_validator.py new file mode 100644 index 0000000..ecb596f --- /dev/null +++ b/test/unit/cli/test_adapter_validator.py @@ -0,0 +1,208 @@ +"""Unit tests for adapter validator logic.""" + +from __future__ import annotations + +from agent_assembly.adapters.base import FrameworkAdapter +from agent_assembly.cli.adapter_validator import ( + AdapterValidationResult, + _check_abstract_methods_implemented, + _check_entry_point_metadata, + _check_framework_name, + _check_inherits_framework_adapter, + _check_register_hooks_signature, + _check_supported_versions, + _check_unregister_hooks_idempotent, + validate_adapter, +) + + +class TestAdapterValidationResult: + """Tests for the AdapterValidationResult dataclass.""" + + def test_fields_stored(self) -> None: + result = AdapterValidationResult(check_name="test_check", passed=True, message="ok") + assert result.check_name == "test_check" + assert result.passed is True + assert result.message == "ok" + + def test_equality(self) -> None: + a = AdapterValidationResult(check_name="c", passed=True, message="m") + b = AdapterValidationResult(check_name="c", passed=True, message="m") + assert a == b + + def test_frozen(self) -> None: + import pytest + + result = AdapterValidationResult(check_name="c", passed=True, message="m") + with pytest.raises(AttributeError): + result.passed = False # type: ignore[misc] + + +class TestCheckInheritsFrameworkAdapter: + """Tests for _check_inherits_framework_adapter.""" + + def test_valid_subclass_passes(self, valid_adapter_cls: type) -> None: + result = _check_inherits_framework_adapter(valid_adapter_cls) + assert result.passed is True + assert result.check_name == "inherits_framework_adapter" + + def test_non_subclass_fails(self, not_an_adapter_cls: type) -> None: + result = _check_inherits_framework_adapter(not_an_adapter_cls) + assert result.passed is False + assert "does not inherit" in result.message + + +class TestCheckAbstractMethodsImplemented: + """Tests for _check_abstract_methods_implemented.""" + + def test_all_methods_concrete_passes(self, valid_adapter_cls: type) -> None: + result = _check_abstract_methods_implemented(valid_adapter_cls) + assert result.passed is True + + def test_missing_method_fails(self) -> None: + class PartialAdapter(FrameworkAdapter): + def get_framework_name(self) -> str: + return "test" + + def get_supported_versions(self) -> list[str]: + return [">=1.0.0"] + + result = _check_abstract_methods_implemented(PartialAdapter) + assert result.passed is False + assert "register_hooks" in result.message + assert "unregister_hooks" in result.message + + +class TestCheckFrameworkName: + """Tests for _check_framework_name.""" + + def test_non_empty_name_passes(self, valid_adapter_cls: type) -> None: + result = _check_framework_name(valid_adapter_cls()) + assert result.passed is True + assert "test_framework" in result.message + + def test_empty_name_fails(self, empty_name_adapter_cls: type) -> None: + result = _check_framework_name(empty_name_adapter_cls()) + assert result.passed is False + assert "non-empty string" in result.message + + def test_whitespace_name_fails(self) -> None: + from test.unit.cli.conftest import ValidAdapter + + class WhitespaceAdapter(ValidAdapter): + def get_framework_name(self) -> str: + return " " + + result = _check_framework_name(WhitespaceAdapter()) + assert result.passed is False + + +class TestCheckSupportedVersions: + """Tests for _check_supported_versions.""" + + def test_valid_list_passes(self, valid_adapter_cls: type) -> None: + result = _check_supported_versions(valid_adapter_cls()) + assert result.passed is True + + def test_empty_list_fails(self, empty_versions_adapter_cls: type) -> None: + result = _check_supported_versions(empty_versions_adapter_cls()) + assert result.passed is False + assert "non-empty list" in result.message + + def test_empty_string_in_list_fails(self) -> None: + from test.unit.cli.conftest import ValidAdapter + + class EmptyStringVersionAdapter(ValidAdapter): + def get_supported_versions(self) -> list[str]: + return [">=1.0.0", ""] + + result = _check_supported_versions(EmptyStringVersionAdapter()) + assert result.passed is False + assert "index 1" in result.message + + +class TestCheckRegisterHooksSignature: + """Tests for _check_register_hooks_signature.""" + + def test_correct_signature_passes(self, valid_adapter_cls: type) -> None: + result = _check_register_hooks_signature(valid_adapter_cls) + assert result.passed is True + + def test_missing_param_fails(self) -> None: + from test.unit.cli.conftest import ValidAdapter + + class NoParamAdapter(ValidAdapter): + def register_hooks(self) -> None: # type: ignore[override] + pass + + result = _check_register_hooks_signature(NoParamAdapter) + assert result.passed is False + assert "must accept" in result.message + + +class TestCheckUnregisterHooksIdempotent: + """Tests for _check_unregister_hooks_idempotent.""" + + def test_double_call_no_raise_passes(self, valid_adapter_cls: type) -> None: + result = _check_unregister_hooks_idempotent(valid_adapter_cls()) + assert result.passed is True + + def test_raises_on_second_call_fails(self, non_idempotent_adapter_cls: type) -> None: + result = _check_unregister_hooks_idempotent(non_idempotent_adapter_cls()) + assert result.passed is False + assert "not idempotent" in result.message + + +class TestCheckEntryPointMetadata: + """Tests for _check_entry_point_metadata.""" + + def test_valid_pyproject_passes(self, valid_adapter_cls: type, tmp_path: object) -> None: + import pathlib + + assert isinstance(tmp_path, pathlib.Path) + pyproject = tmp_path / "pyproject.toml" + qualname = f"{valid_adapter_cls.__module__}:{valid_adapter_cls.__qualname__}" + pyproject.write_text(f'[project.entry-points."agent_assembly.adapters"]\n' f'test_framework = "{qualname}"\n') + result = _check_entry_point_metadata(valid_adapter_cls, str(tmp_path)) + assert result.passed is True + + def test_missing_entry_point_fails(self, valid_adapter_cls: type, tmp_path: object) -> None: + import pathlib + + assert isinstance(tmp_path, pathlib.Path) + pyproject = tmp_path / "pyproject.toml" + pyproject.write_text("[project]\nname = 'test'\n") + result = _check_entry_point_metadata(valid_adapter_cls, str(tmp_path)) + assert result.passed is False + assert "missing" in result.message + + def test_no_pyproject_skips(self, valid_adapter_cls: type, tmp_path: object) -> None: + import pathlib + + assert isinstance(tmp_path, pathlib.Path) + result = _check_entry_point_metadata(valid_adapter_cls, str(tmp_path)) + assert result.passed is True + assert "skipping" in result.message.lower() + + +class TestValidateAdapter: + """Tests for validate_adapter orchestrator.""" + + def test_all_pass_for_valid_adapter(self, valid_adapter_cls: type) -> None: + results = validate_adapter(valid_adapter_cls, "test.module") + assert all(r.passed for r in results) + + def test_mixed_fail_for_empty_name(self, empty_name_adapter_cls: type) -> None: + results = validate_adapter(empty_name_adapter_cls, "test.module") + failed = [r for r in results if not r.passed] + assert len(failed) >= 1 + assert any(r.check_name == "framework_name" for r in failed) + + def test_short_circuits_on_inheritance_failure(self, not_an_adapter_cls: type) -> None: + results = validate_adapter(not_an_adapter_cls, "test.module") + assert len(results) == 2 + assert not results[0].passed + + def test_result_count_for_valid_adapter(self, valid_adapter_cls: type) -> None: + results = validate_adapter(valid_adapter_cls, "test.module") + assert len(results) == 7 diff --git a/test/unit/cli/test_cli_main.py b/test/unit/cli/test_cli_main.py new file mode 100644 index 0000000..bcd5df9 --- /dev/null +++ b/test/unit/cli/test_cli_main.py @@ -0,0 +1,41 @@ +"""Unit tests for CLI main entry point.""" + +from __future__ import annotations + +from unittest import mock + +import pytest + +from agent_assembly.cli.main import main + + +class TestCliMainExitCodeZero: + """Tests for CLI main returning exit code 0.""" + + def test_valid_adapter_exits_zero(self) -> None: + with mock.patch( + "sys.argv", + ["aasm", "adapter", "validate", "agent_assembly.adapters.langchain.adapter"], + ): + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 0 + + +class TestCliMainExitCodeOne: + """Tests for CLI main returning exit code 1.""" + + def test_nonexistent_module_exits_one(self) -> None: + with mock.patch( + "sys.argv", + ["aasm", "adapter", "validate", "nonexistent.module"], + ): + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 + + def test_no_command_exits_one(self) -> None: + with mock.patch("sys.argv", ["aasm"]): + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 1 diff --git a/test/unit/cli/test_loader.py b/test/unit/cli/test_loader.py new file mode 100644 index 0000000..de22781 --- /dev/null +++ b/test/unit/cli/test_loader.py @@ -0,0 +1,100 @@ +"""Unit tests for adapter class loader functions.""" + +from __future__ import annotations + +import pytest + +from agent_assembly.cli.adapter_validator import ( + load_adapter_class, + load_adapter_class_from_module, + load_adapter_class_from_path, +) + + +class TestLoadAdapterClassFromModule: + """Tests for load_adapter_class_from_module.""" + + def test_valid_module(self) -> None: + cls = load_adapter_class_from_module("agent_assembly.adapters.langchain.adapter") + from agent_assembly.adapters.base import FrameworkAdapter + + assert issubclass(cls, FrameworkAdapter) + + def test_invalid_module_raises(self) -> None: + with pytest.raises(ImportError): + load_adapter_class_from_module("nonexistent.module.path") + + def test_module_with_no_adapter_raises(self) -> None: + with pytest.raises(ValueError, match="No FrameworkAdapter subclass"): + load_adapter_class_from_module("agent_assembly.exceptions") + + +class TestLoadAdapterClassFromPath: + """Tests for load_adapter_class_from_path.""" + + def test_valid_file_path(self, tmp_path: object) -> None: + import pathlib + + assert isinstance(tmp_path, pathlib.Path) + adapter_file = tmp_path / "my_adapter.py" + adapter_file.write_text( + "from agent_assembly.adapters.base import FrameworkAdapter, GovernanceInterceptor\n" + "\n" + "class MyAdapter(FrameworkAdapter):\n" + " def get_framework_name(self) -> str:\n" + " return 'my_framework'\n" + " def get_supported_versions(self) -> list[str]:\n" + " return ['>=1.0.0']\n" + " def register_hooks(self, interceptor: GovernanceInterceptor) -> None:\n" + " pass\n" + " def unregister_hooks(self) -> None:\n" + " pass\n" + ) + cls = load_adapter_class_from_path(str(adapter_file)) + from agent_assembly.adapters.base import FrameworkAdapter + + assert issubclass(cls, FrameworkAdapter) + + def test_invalid_path_raises(self) -> None: + with pytest.raises(FileNotFoundError): + load_adapter_class_from_path("/nonexistent/path/adapter.py") + + def test_file_with_no_adapter_raises(self, tmp_path: object) -> None: + import pathlib + + assert isinstance(tmp_path, pathlib.Path) + empty_file = tmp_path / "empty.py" + empty_file.write_text("x = 1\n") + with pytest.raises(ValueError, match="No FrameworkAdapter subclass"): + load_adapter_class_from_path(str(empty_file)) + + +class TestLoadAdapterClass: + """Tests for load_adapter_class dispatcher.""" + + def test_dispatches_to_path_for_existing_file(self, tmp_path: object) -> None: + import pathlib + + assert isinstance(tmp_path, pathlib.Path) + adapter_file = tmp_path / "my_adapter.py" + adapter_file.write_text( + "from agent_assembly.adapters.base import FrameworkAdapter, GovernanceInterceptor\n" + "\n" + "class MyAdapter(FrameworkAdapter):\n" + " def get_framework_name(self) -> str:\n" + " return 'my_framework'\n" + " def get_supported_versions(self) -> list[str]:\n" + " return ['>=1.0.0']\n" + " def register_hooks(self, interceptor: GovernanceInterceptor) -> None:\n" + " pass\n" + " def unregister_hooks(self) -> None:\n" + " pass\n" + ) + cls = load_adapter_class(str(adapter_file)) + assert cls.__name__ == "MyAdapter" + + def test_dispatches_to_module_for_dotted_name(self) -> None: + cls = load_adapter_class("agent_assembly.adapters.langchain.adapter") + from agent_assembly.adapters.base import FrameworkAdapter + + assert issubclass(cls, FrameworkAdapter) diff --git a/test/unit/cli/test_output.py b/test/unit/cli/test_output.py new file mode 100644 index 0000000..8f847f8 --- /dev/null +++ b/test/unit/cli/test_output.py @@ -0,0 +1,37 @@ +"""Unit tests for output formatting.""" + +from __future__ import annotations + +from agent_assembly.cli.adapter_validator import AdapterValidationResult +from agent_assembly.cli.output import format_results + + +class TestFormatResults: + """Tests for format_results.""" + + def test_all_pass_output(self) -> None: + results = [ + AdapterValidationResult(check_name="check_a", passed=True, message="ok"), + AdapterValidationResult(check_name="check_b", passed=True, message="ok"), + ] + output = format_results(results) + assert "[PASS]" in output + assert "[FAIL]" not in output + assert "2 passed, 0 failed" in output + + def test_mixed_output(self) -> None: + results = [ + AdapterValidationResult(check_name="check_a", passed=True, message="ok"), + AdapterValidationResult(check_name="check_b", passed=False, message="bad"), + ] + output = format_results(results) + assert "[PASS]" in output + assert "[FAIL]" in output + assert "1 passed, 1 failed" in output + + def test_pass_fail_prefix_format(self) -> None: + results = [ + AdapterValidationResult(check_name="my_check", passed=True, message="m"), + ] + output = format_results(results) + assert " [PASS] my_check: m" in output