class _InterruptWhen(Protocol):
    """Predicate that decides whether a tool call should trigger an interrupt.

    This is the callable type accepted by the optional `interrupt_when` entry
    of `InterruptOnConfig`. The middleware invokes it synchronously, once per
    proposed tool call whose tool name is listed in `interrupt_on`.
    """

    def __call__(
        self,
        tool_call: ToolCall,
        runtime: ToolRuntime[ContextT, Any],
    ) -> bool:
        """Return `True` to interrupt this tool call, `False` to auto-approve it.

        Args:
            tool_call: The tool call proposed by the model.
            runtime: A `ToolRuntime` constructed by the middleware for predicate
                evaluation. It carries `state`, `context`, `store`,
                `stream_writer` and `tool_call_id`; `tools` is `[]`, `config`
                is `{}`, and `execution_info` / `server_info` are `None`.

        Returns:
            `True` if the call must be reviewed by a human, `False` if it may
            proceed as though the tool were not listed in `interrupt_on`.
        """
        ...
+ """ + class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]): """Human in the loop middleware.""" @@ -201,8 +232,15 @@ def __init__( * `InterruptOnConfig` indicates the specific decisions allowed for this tool. - The `InterruptOnConfig` can include a `description` field (`str` or - `Callable`) for custom formatting of the interrupt description. + The `InterruptOnConfig` can include: + - a `description` field (`str` or `Callable`) for custom formatting + of the interrupt description, and + - an `interrupt_when` predicate `(ToolCall, ToolRuntime) -> bool`. + When set, only calls for which the predicate returns `True` + interrupt. Calls returning `False` are auto-approved exactly as + if the tool were not listed in `interrupt_on`. Predicates must be + synchronous and deterministic (LangGraph interrupt replay re-runs + them on resume). description_prefix: The prefix to use when constructing action requests. This is used to provide context about the tool call and the action being @@ -223,6 +261,34 @@ def __init__( self.interrupt_on = resolved_configs self.description_prefix = description_prefix + def _should_interrupt( + self, + tool_call: ToolCall, + config: InterruptOnConfig, + state: AgentState[Any], + runtime: Runtime[ContextT], + ) -> bool: + """Evaluate the per-call predicate, if any. + + Returns `True` when the call should interrupt, `False` to auto-approve. + When no `interrupt_when` is configured, returns `True` (unchanged behavior). 
+ """ + interrupt_when = config.get("interrupt_when") + if interrupt_when is None: + return True + tool_runtime: ToolRuntime[ContextT, Any] = ToolRuntime( + state=state, + context=runtime.context, + config={}, + stream_writer=runtime.stream_writer, + tool_call_id=tool_call["id"], + store=runtime.store, + tools=[], + execution_info=None, + server_info=None, + ) + return interrupt_when(tool_call, tool_runtime) + def _create_action_and_config( self, tool_call: ToolCall, @@ -340,13 +406,17 @@ def after_model( interrupt_indices: list[int] = [] for idx, tool_call in enumerate(last_ai_msg.tool_calls): - if (config := self.interrupt_on.get(tool_call["name"])) is not None: - action_request, review_config = self._create_action_and_config( - tool_call, config, state, runtime - ) - action_requests.append(action_request) - review_configs.append(review_config) - interrupt_indices.append(idx) + config = self.interrupt_on.get(tool_call["name"]) + if config is None: + continue + if not self._should_interrupt(tool_call, config, state, runtime): + continue + action_request, review_config = self._create_action_and_config( + tool_call, config, state, runtime + ) + action_requests.append(action_request) + review_configs.append(review_config) + interrupt_indices.append(idx) # If no interrupts needed, return early if not action_requests: diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_human_in_the_loop.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_human_in_the_loop.py index 5e03b0bed55ca..43fb4aa726ee3 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_human_in_the_loop.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_human_in_the_loop.py @@ -12,6 +12,7 @@ HumanInTheLoopMiddleware, ) from langchain.agents.middleware.types import AgentState +from langchain.tools import ToolRuntime def test_human_in_the_loop_middleware_initialization() -> None: @@ 
def test_interrupt_when_false_auto_approves() -> None:
    """A predicate returning `False` lets the call through without any interrupt."""

    def never_interrupt(_tool_call: Any, _tool_runtime: Any) -> bool:
        return False

    def fail_if_called(_: Any) -> dict[str, Any]:
        # Reaching `interrupt()` at all means the predicate was ignored.
        msg = "interrupt() should not be called when predicate returns False"
        raise AssertionError(msg)

    hitl = HumanInTheLoopMiddleware(
        interrupt_on={
            "edit_file": {
                "allowed_decisions": ["approve", "reject"],
                "interrupt_when": never_interrupt,
            }
        }
    )
    model_turn = AIMessage(
        content="Writing safe file",
        tool_calls=[{"name": "edit_file", "args": {"path": "/safe/path"}, "id": "1"}],
    )
    state = AgentState[Any](messages=[HumanMessage(content="Hi"), model_turn])

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        side_effect=fail_if_called,
    ):
        outcome = hitl.after_model(state, Runtime())

    # With nothing to review the middleware reports no state update.
    assert outcome is None
def test_interrupt_when_mixed_calls_same_tool() -> None:
    """One protected-path call interrupts, one wiki call auto-approves, order preserved."""
    middleware = HumanInTheLoopMiddleware(
        interrupt_on={
            "edit_file": {
                "allowed_decisions": ["approve", "reject"],
                # Plain prefix check: `str.startswith` already returns a real
                # `bool`, so no regex (or `bool(...)` wrapper) is needed.
                "interrupt_when": lambda tc, _rt: tc["args"].get("path", "").startswith("/etc/"),
            }
        }
    )
    ai_message = AIMessage(
        content="Two writes",
        tool_calls=[
            {"name": "edit_file", "args": {"path": "/etc/secret"}, "id": "1"},
            {"name": "edit_file", "args": {"path": "/wiki/page"}, "id": "2"},
        ],
    )
    state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])

    captured: dict[str, Any] = {}

    def mock_approve(request: Any) -> dict[str, Any]:
        # Record the interrupt payload so the test can inspect what was sent
        # for review, then approve the single pending action.
        captured["request"] = request
        return {"decisions": [{"type": "approve"}]}

    with patch(
        "langchain.agents.middleware.human_in_the_loop.interrupt",
        side_effect=mock_approve,
    ):
        result = middleware.after_model(state, Runtime())

    # Only the `/etc/` call is put in front of the human.
    assert len(captured["request"]["action_requests"]) == 1
    assert captured["request"]["action_requests"][0]["args"] == {"path": "/etc/secret"}

    # Both calls survive, in their original order.
    assert result is not None
    revised = result["messages"][0].tool_calls
    assert [tc["id"] for tc in revised] == ["1", "2"]
    assert [tc["args"]["path"] for tc in revised] == ["/etc/secret", "/wiki/page"]
test_interrupt_when_mixed_configured_tools() -> None: + """Unlisted, listed-with-false, and listed-with-true tools all behave correctly.""" + middleware = HumanInTheLoopMiddleware( + interrupt_on={ + "edit_file": { + "allowed_decisions": ["approve"], + "interrupt_when": lambda _tc, _rt: False, + }, + "delete_file": { + "allowed_decisions": ["approve", "reject"], + "interrupt_when": lambda _tc, _rt: True, + }, + } + ) + ai_message = AIMessage( + content="Three calls", + tool_calls=[ + {"name": "search", "args": {"q": "x"}, "id": "1"}, + {"name": "edit_file", "args": {"path": "/wiki/page"}, "id": "2"}, + {"name": "delete_file", "args": {"path": "/wiki/page"}, "id": "3"}, + ], + ) + state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message]) + + captured: dict[str, Any] = {} + + def mock_approve(request: Any) -> dict[str, Any]: + captured["request"] = request + return {"decisions": [{"type": "approve"}]} + + with patch( + "langchain.agents.middleware.human_in_the_loop.interrupt", + side_effect=mock_approve, + ): + result = middleware.after_model(state, Runtime()) + + assert [a["name"] for a in captured["request"]["action_requests"]] == ["delete_file"] + assert result is not None + revised = result["messages"][0].tool_calls + assert [tc["id"] for tc in revised] == ["1", "2", "3"] + + +def test_interrupt_when_tool_runtime_fields() -> None: + """Predicate receives a `ToolRuntime` with correct fields and documented deviations.""" + captured: dict[str, Any] = {} + + def capture_predicate(tc: ToolCall, rt: ToolRuntime[Any, Any]) -> bool: + captured["tool_call"] = tc + captured["runtime"] = rt + return False + + middleware = HumanInTheLoopMiddleware( + interrupt_on={ + "edit_file": { + "allowed_decisions": ["approve"], + "interrupt_when": capture_predicate, + } + } + ) + ai_message = AIMessage( + content="One call", + tool_calls=[{"name": "edit_file", "args": {"path": "/x"}, "id": "abc"}], + ) + state = AgentState[Any](messages=[HumanMessage(content="Hi"), 
def test_interrupt_when_exceptions_propagate() -> None:
    """An error raised inside the predicate surfaces from `after_model` (no silent approval)."""

    class PredicateBugError(RuntimeError):
        """Dedicated error type so the assertion cannot match an unrelated failure."""

    def exploding_predicate(_tool_call: Any, _tool_runtime: Any) -> bool:
        msg = "kaboom"
        raise PredicateBugError(msg)

    interrupt_on = {
        "edit_file": {
            "allowed_decisions": ["approve"],
            "interrupt_when": exploding_predicate,
        }
    }
    hitl = HumanInTheLoopMiddleware(interrupt_on=interrupt_on)

    model_turn = AIMessage(
        content="One call",
        tool_calls=[{"name": "edit_file", "args": {"path": "/x"}, "id": "1"}],
    )
    state = AgentState[Any](messages=[HumanMessage(content="Hi"), model_turn])

    with pytest.raises(PredicateBugError, match="kaboom"):
        hitl.after_model(state, Runtime())