diff --git a/CONTEXT.md b/CONTEXT.md index 692792b..dc17aab 100644 --- a/CONTEXT.md +++ b/CONTEXT.md @@ -68,7 +68,9 @@ User's Python app (this library) ``` src/ - __init__.py # Public API exports: KnowledgeBase, KnowledgeBaseBuilder, KnowledgeBaseSettings + __init__.py # Public API exports: KnowledgeBase, KnowledgeBaseBuilder, KnowledgeBaseSettings, Depends + dependency_injection.py # Depends — DI marker for handler parameter annotations + # resolve_dependencies — resolves Depends-annotated params at call time knowledge_base.py # KnowledgeBase class — the main user-facing class knowledge_base_builder.py # KnowledgeBaseBuilder — settings-aware builder that wraps KnowledgeBase knowledge_interaction.py # KnowledgeInteractionContext, Handler type, status enum @@ -174,6 +176,36 @@ kb.start_handling_loop(loops=10) # Runs exactly 10 poll iterations (useful for --- +### `Depends` — Dependency injection + +Handlers can declare dependencies (database connections, HTTP clients, config, etc.) using +`Depends()` in `Annotated` type hints. The framework resolves them at call time. + +```python +from typing import Annotated +from src import Depends + +def get_db() -> MyDatabase: + return MyDatabase(url="...") + +@kb.answer_ki(name="...", graph_pattern="...") +def handler( + binding_set: list[PersonBinding], + info: KnowledgeInteractionInfo, + db: Annotated[MyDatabase, Depends(get_db)], +) -> list[PersonBinding]: + return db.query(binding_set) +``` + +**Behaviour:** +- `Depends` params are discovered and resolved each time the handler is called (inside `KnowledgeBase.call()`), not at registration time. +- Dependency factories are **sync-only** (async support is out of scope). +- Factories can themselves declare `Depends` parameters — nested/transitive resolution is supported. +- `cache=True` (default): factory called once per KI-call invocation; result shared across all uses. +- `cache=False`: factory called fresh every time it is needed. 
+ +--- + ### `KnowledgeBaseBuilder` Returned by `KnowledgeBase.from_settings()`. Wraps a `KnowledgeBase` internally and exposes @@ -385,3 +417,4 @@ These are excluded from linting (`ruff`) and are kept for historical reference o - **KI registry indexed by ID after registration**: `KnowledgeBase` maintains a secondary index (`_ki_registry_by_id`) populated once a KI is registered with the SC and assigned an ID. The handling loop dispatches by ID using this index. - **Handler introspection**: `KnowledgeInteractionContext.__post_init__` inspects handler signatures to auto-detect binding models, enabling transparent (de)serialization without manual type dispatch. Dispatch logic (validate → call → serialize for ANSWER/REACT; prepare_outgoing + parse_result for ASK/POST) lives in `KnowledgeInteractionContext`, not in `KnowledgeBase`. - **`KnowledgeBaseBuilder` wraps `KnowledgeBase`**: Settings-based KI registration belongs to `KnowledgeBaseBuilder`, not to `KnowledgeBase`. `KnowledgeBase.from_settings()` returns a builder; `builder.build()` returns the finished `KnowledgeBase`. `KnowledgeBase` itself has no knowledge of settings. ASK/POST KIs are auto-registered at `build()` time; ANSWER/REACT KIs require a handler attached via `builder.handler(name, func)` before `build()` is called. +- **Dependency injection via `Depends`**: `KnowledgeBase.call()` calls `resolve_dependencies(handler)` before invoking the handler, passing resolved values as kwargs. The resolver (`src/dependency_injection.py`) uses `get_type_hints(include_extras=True)` to find `Annotated[T, Depends(factory)]` params, recursively resolves factory deps (transitive), and caches results per invocation when `cache=True`. `@wraps` on the decorator wrapper preserves `__annotations__`, so the resolver sees the original handler's hints. 
class SensorRepository:
    """Stand-in repository that serves canned temperature readings.

    The ``db_url`` is only stored and logged; lookups are answered from the
    class-level ``_READINGS`` table rather than from a real database.
    """

    # Fixed readings keyed by sensor URI, shared by every instance.
    _READINGS: dict[str, float] = {
        f"{EX}sensor1": 21.3,
        f"{EX}sensor2": 19.8,
        f"{EX}sensor3": 22.7,
    }

    def __init__(self, db_url: str):
        logger.debug("SensorRepository initialised (db_url=%s)", db_url)
        self.db_url = db_url

    def get_reading(self, sensor_uri: str) -> float | None:
        """Return the reading stored for *sensor_uri*, or ``None`` if unknown."""
        readings = self._READINGS
        if sensor_uri in readings:
            return readings[sensor_uri]
        return None

    def all_readings(self) -> dict[str, float]:
        """Return a fresh dict of every known sensor URI and its reading."""
        return {uri: value for uri, value in self._READINGS.items()}
class SensorReadingBinding(BindingModel):
    """Binding model used by the sensor-readings handler in this example.

    The field names correspond to the ``?sensor`` and ``?temperature``
    variables of the graph pattern registered for
    ``sensor-readings-answer-ki``.
    """

    # URI of the sensor; the handler looks readings up by ``str(sensor)``.
    sensor: Uri
    # Temperature value for that sensor (unit is not specified by this example).
    temperature: Literal[float]
+ """ + logger.info( + "Handling sensor-readings query (db=%s, incoming=%d bindings)", + config.db_url, + len(binding_set), + ) + + if binding_set: + # Filtered query: only return the requested sensors + results = [] + for b in binding_set: + sensor_uri = str(b.sensor) + temperature = repo.get_reading(sensor_uri) + if temperature is not None: + results.append( + SensorReadingBinding(sensor=b.sensor, temperature=temperature) + ) + return results + + # Open query: return all known sensors + return [ + SensorReadingBinding(sensor=uri, temperature=temp) # type: ignore[arg-type] + for uri, temp in repo.all_readings().items() + ] + + +if __name__ == "__main__": + kb.connect() + kb.register() + logger.info("Registered the dependency-injection example KB!") + kb.unregister() diff --git a/src/__init__.py b/src/__init__.py index 98df6c3..ac00237 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,5 +1,6 @@ import logging +from .dependency_injection import Depends from .kb.builder import KnowledgeBaseBuilder from .kb.knowledge_base import KnowledgeBase from .settings import KnowledgeBaseSettings diff --git a/src/dependency_injection.py b/src/dependency_injection.py new file mode 100644 index 0000000..3ad63f4 --- /dev/null +++ b/src/dependency_injection.py @@ -0,0 +1,94 @@ +"""Dependency injection resolver for KI handler parameters.""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, get_args, get_type_hints + + +@dataclass +class Depends: + """Mark a handler parameter as a resolved dependency. + + Usage:: + + def get_db() -> MyDatabase: + return MyDatabase(url="...") + + @kb.answer_ki(name="...", graph_pattern="...") + def handler( + binding_set: list[PersonBinding], + info: KnowledgeInteractionInfo, + db: Annotated[MyDatabase, Depends(get_db)], + ) -> list[PersonBinding]: + return db.query(binding_set) + + Args: + factory: A callable (sync) that returns the dependency value. 
The + factory may itself declare ``Annotated[T, Depends(...)]`` parameters + for nested/transitive resolution. + cache: When ``True`` (the default) the factory is called at most once + per KI-call invocation and the result is shared across all + parameters that reference the same factory. When ``False`` the + factory is called fresh every time it is needed. + """ + + factory: Callable[..., Any] + cache: bool = field(default=True) + + +def _get_dep_params(func: Callable[..., Any]) -> dict[str, Depends]: + """Return a mapping of parameter-name → Depends for all Annotated Depends params.""" + try: + hints = get_type_hints(func, include_extras=True) + except Exception: + return {} + + dep_params: dict[str, Depends] = {} + for name, hint in hints.items(): + if name == "return": + continue + if not hasattr(hint, "__metadata__"): + continue + for meta in get_args(hint)[1:]: + if isinstance(meta, Depends): + dep_params[name] = meta + break + return dep_params + + +def resolve_dependencies( + func: Callable[..., Any], + cache: dict[Callable[..., Any], Any] | None = None, +) -> dict[str, Any]: + """Resolve all ``Annotated[T, Depends(...)]`` parameters of *func*. + + Args: + func: The callable whose parameters should be inspected. + cache: A per-call cache mapping factory → resolved value. Pass the + same dict for all calls within a single KI invocation so that + ``cache=True`` factories are called at most once. Pass ``None`` + to start fresh (a new empty dict will be created). + + Returns: + A dict mapping parameter name → resolved value for every + ``Depends``-annotated parameter found in *func*'s signature. 
+ """ + if cache is None: + cache = {} + + dep_params = _get_dep_params(func) + resolved: dict[str, Any] = {} + for param_name, dep in dep_params.items(): + factory = dep.factory + if dep.cache and factory in cache: + resolved[param_name] = cache[factory] + else: + # Recursively resolve factory's own dependencies first + factory_kwargs = resolve_dependencies(factory, cache) + value = factory(**factory_kwargs) + if dep.cache: + cache[factory] = value + resolved[param_name] = value + return resolved diff --git a/src/kb/knowledge_base.py b/src/kb/knowledge_base.py index 692e0af..2e0cbfd 100644 --- a/src/kb/knowledge_base.py +++ b/src/kb/knowledge_base.py @@ -6,6 +6,7 @@ from functools import wraps from typing import TYPE_CHECKING, Any +from ..dependency_injection import resolve_dependencies from ..ke import Client from ..ke.client import ClientProtocol, PollResult from ..ke.errors import KnowledgeEngineNotAvailableError @@ -388,13 +389,15 @@ def call(self, binding_set: BindingSet, ki_name: str) -> BindingSet: ki_ctx = self.ki_registry[ki_name] assert ki_ctx.handler is not None # Should always be set for ANSWER/REACT KI's + dep_kwargs = resolve_dependencies(ki_ctx.handler) + if ki_ctx.validation_model: binding_models = [ ki_ctx.validation_model.model_validate(b) for b in binding_set ] - result_bindings = ki_ctx.handler(binding_models, ki_ctx.info) + result_bindings = ki_ctx.handler(binding_models, ki_ctx.info, **dep_kwargs) else: - result_bindings = ki_ctx.handler(binding_set, ki_ctx.info) + result_bindings = ki_ctx.handler(binding_set, ki_ctx.info, **dep_kwargs) if ki_ctx.serialization_model and result_bindings: # We can assume the result bindings are BindingModels, so we can model_dump diff --git a/tests/test_dependency_injection.py b/tests/test_dependency_injection.py new file mode 100644 index 0000000..32c5274 --- /dev/null +++ b/tests/test_dependency_injection.py @@ -0,0 +1,137 @@ +from typing import Annotated + +import pytest + +from src import Depends +from 
def test_cached_dependency_factory_called_once(kb: KnowledgeBase):
    """With cache=True (default), a shared factory is called only once per KI call.

    ``get_db`` is requested twice within one call: directly by the handler and
    indirectly through ``get_service``.
    """
    call_count = 0

    def get_db():
        nonlocal call_count
        call_count += 1
        return object()

    def get_service(db: Annotated[object, Depends(get_db)]):
        return db  # returns the injected db unchanged, so identity can be compared

    @kb.answer_ki(name="cache-ki", graph_pattern="?s ?p ?o .")
    def handler(
        binding_set: BindingSet,
        info,
        db: Annotated[object, Depends(get_db)],
        svc: Annotated[object, Depends(get_service)],
    ) -> BindingSet:
        # get_service hands back the db it received, so svc *is* that db;
        # identity with the directly injected db shows one shared instance.
        assert db is svc
        return []

    kb.call([], "cache-ki")
    # One factory call serves both the direct and the transitive use.
    assert call_count == 1
def test_transitive_dependency_resolution(kb: KnowledgeBase):
    """A factory that declares its own Depends params is resolved transitively.

    The handler only asks for ``Db``; ``Config`` is requested solely by the
    ``get_db`` factory.
    """

    class Config:
        url = "sqlite://:memory:"

    class Db:
        def __init__(self, config: Config):
            self.url = config.url

    def get_config() -> Config:
        return Config()

    # get_db cannot run until its own ``config`` dependency has been resolved.
    def get_db(config: Annotated[Config, Depends(get_config)]) -> Db:
        return Db(config)

    @kb.answer_ki(name="transitive-ki", graph_pattern="?s ?p ?o .")
    def handler(
        binding_set: BindingSet,
        info,
        db: Annotated[Db, Depends(get_db)],
    ) -> BindingSet:
        return [{"url": db.url}]

    result = kb.call([], "transitive-ki")
    # The URL originates in Config, so seeing it here shows the nested
    # get_config -> get_db chain was resolved.
    assert result == [{"url": "sqlite://:memory:"}]