diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 5a3433ef7e60e..3336af3ea1b63 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -253,6 +253,35 @@ async def wrapped(*args: Any, **kwargs: Any) -> Any: return cast("Func", wrapped) +async def _achat_model_start_fallback( + coro: Coroutine[Any, Any, Any], + handler: BaseCallbackHandler, + *args: Any, + **kwargs: Any, +) -> None: + """Wrap an async `on_chat_model_start` coroutine with fallback. + + Catches `NotImplementedError` and triggers the `on_llm_start` fallback. + This covers async handlers invoked from a **sync** `handle_event` call, + where the coroutine is collected into `coros` and executed later by + `_run_coros`. Without this wrapper the `NotImplementedError` would be + caught generically by `_run_coros` and the trace would be lost. + """ + try: + await coro + except NotImplementedError: + message_strings = [get_buffer_string(m) for m in args[1]] + await _ahandle_event_for_handler( + handler, + "on_llm_start", + "ignore_llm", + args[0], + message_strings, + *args[2:], + **kwargs, + ) + + def handle_event( handlers: list[BaseCallbackHandler], event_name: str, @@ -282,6 +311,10 @@ def handle_event( ): event = getattr(handler, event_name)(*args, **kwargs) if asyncio.iscoroutine(event): + if event_name == "on_chat_model_start": + event = _achat_model_start_fallback( + event, handler, *args, **kwargs + ) coros.append(event) except NotImplementedError as e: if event_name == "on_chat_model_start": diff --git a/libs/core/tests/unit_tests/callbacks/test_handle_event.py b/libs/core/tests/unit_tests/callbacks/test_handle_event.py new file mode 100644 index 0000000000000..3bf3e9cec40b1 --- /dev/null +++ b/libs/core/tests/unit_tests/callbacks/test_handle_event.py @@ -0,0 +1,155 @@ +"""Tests for the `on_chat_model_start` fallback in `handle_event`.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from uuid import UUID, uuid4 + +from langchain_core.callbacks.manager import handle_event +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.tracers.base import AsyncBaseTracer + +if TYPE_CHECKING: + import pytest + + from langchain_core.tracers.schemas import Run + +SERIALIZED = {"id": ["chat_model"]} + + +class _NoOpAsyncTracer(AsyncBaseTracer): + """Async tracer that does NOT override `on_chat_model_start`. + + Records `on_llm_start` calls so the test can verify the fallback fired. + """ + + def __init__(self) -> None: + super().__init__() + self.runs: list[Run] = [] + self.llm_start_calls: list[dict[str, Any]] = [] + + async def _persist_run(self, run: Run) -> None: + self.runs.append(run) + + async def on_llm_start( + self, + serialized: dict[str, Any], + prompts: list[str], + *, + run_id: UUID, + **_kwargs: Any, + ) -> None: + self.llm_start_calls.append( + { + "serialized": serialized, + "prompts": prompts, + "run_id": run_id, + } + ) + + +class _WorkingAsyncTracer(AsyncBaseTracer): + """Async tracer that DOES override `on_chat_model_start`. + + Used to verify the normal (non-fallback) path still works. + """ + + def __init__(self) -> None: + super().__init__() + self.runs: list[Run] = [] + self.chat_model_start_calls: list[dict[str, Any]] = [] + + async def _persist_run(self, run: Run) -> None: + self.runs.append(run) + + async def on_chat_model_start( + self, + serialized: dict[str, Any], + messages: list[list[Any]], + *, + run_id: UUID, + **_kwargs: Any, + ) -> None: + self.chat_model_start_calls.append( + { + "serialized": serialized, + "messages": messages, + "run_id": run_id, + } + ) + + +def test_async_tracer_falls_back_to_on_llm_start_in_sync_context() -> None: + """Async tracer without `on_chat_model_start` falls back. + + When `handle_event` is called synchronously with an + `AsyncBaseTracer` that doesn't implement `on_chat_model_start`, + the `on_llm_start` callback should fire as a fallback. + """ + tracer = _NoOpAsyncTracer() + run_id = uuid4() + messages = [[SystemMessage(content="sys"), HumanMessage(content="hi")]] + + handle_event( + [tracer], + "on_chat_model_start", + "ignore_chat_model", + SERIALIZED, + messages, + run_id=run_id, + ) + + assert len(tracer.llm_start_calls) == 1 + call = tracer.llm_start_calls[0] + assert call["serialized"] == SERIALIZED + # The fallback converts messages to strings via get_buffer_string + assert isinstance(call["prompts"], list) + assert len(call["prompts"]) == 1 + assert isinstance(call["prompts"][0], str) + + +def test_async_tracer_no_fallback_when_implemented() -> None: + """Async tracer WITH `on_chat_model_start` does not fall back. + + When the handler implements `on_chat_model_start`, no fallback + should be triggered. + """ + tracer = _WorkingAsyncTracer() + run_id = uuid4() + messages = [[HumanMessage(content="hello")]] + + handle_event( + [tracer], + "on_chat_model_start", + "ignore_chat_model", + SERIALIZED, + messages, + run_id=run_id, + ) + + assert len(tracer.chat_model_start_calls) == 1 + call = tracer.chat_model_start_calls[0] + assert call["serialized"] == SERIALIZED + assert call["messages"] == messages + + +def test_async_tracer_fallback_no_error_logged( + caplog: pytest.LogCaptureFixture, +) -> None: + """The fallback path should not produce any warning/error logs.""" + tracer = _NoOpAsyncTracer() + messages = [[HumanMessage(content="test")]] + + with caplog.at_level("WARNING", logger="langchain_core.callbacks.manager"): + handle_event( + [tracer], + "on_chat_model_start", + "ignore_chat_model", + SERIALIZED, + messages, + run_id=uuid4(), + ) + + assert not caplog.records, ( + f"Expected no warnings but got: {[r.message for r in caplog.records]}" + )