Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 79 additions & 9 deletions libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ResponseT,
StateT,
)
from langchain.tools import ToolRuntime


class Action(TypedDict):
Expand Down Expand Up @@ -130,6 +131,18 @@ def __call__(
...


class _InterruptWhen(Protocol):
"""Predicate that decides whether a tool call should trigger an interrupt."""

def __call__(
self,
tool_call: ToolCall,
runtime: ToolRuntime[ContextT, Any],
) -> bool:
"""Return `True` to interrupt this tool call, `False` to auto-approve it."""
...


class InterruptOnConfig(TypedDict):
"""Configuration for an action requiring human in the loop.

Expand Down Expand Up @@ -178,6 +191,24 @@ def format_tool_description(
args_schema: NotRequired[dict[str, Any]]
"""JSON schema for the args associated with the action, if edits are allowed."""

interrupt_when: NotRequired[_InterruptWhen]
"""Optional predicate gating whether this tool call triggers an interrupt.

Called with the proposed `ToolCall` and a `ToolRuntime` constructed by the
middleware. Return `True` to interrupt (with this config's `allowed_decisions`)
or `False` to auto-approve the call as if the tool were not listed in
`interrupt_on`.

Predicates must be synchronous and deterministic — LangGraph interrupt replay
on resume requires the same interrupts to fire on each evaluation. Exceptions
propagate.

Note: the `ToolRuntime` constructed for predicate evaluation has three
deviations from a tool-time `ToolRuntime`: `tools=[]`, `config={}`, and
`execution_info`/`server_info=None`. Predicates should rely on `state`,
`tool_call_id`, `context`, and `store` rather than these fields.
"""


class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
"""Human in the loop middleware."""
Expand All @@ -201,8 +232,15 @@ def __init__(
* `InterruptOnConfig` indicates the specific decisions allowed for this
tool.

The `InterruptOnConfig` can include a `description` field (`str` or
`Callable`) for custom formatting of the interrupt description.
The `InterruptOnConfig` can include:
- a `description` field (`str` or `Callable`) for custom formatting
of the interrupt description, and
- an `interrupt_when` predicate `(ToolCall, ToolRuntime) -> bool`.
When set, only calls for which the predicate returns `True`
interrupt. Calls returning `False` are auto-approved exactly as
if the tool were not listed in `interrupt_on`. Predicates must be
synchronous and deterministic (LangGraph interrupt replay re-runs
them on resume).
description_prefix: The prefix to use when constructing action requests.

This is used to provide context about the tool call and the action being
Expand All @@ -223,6 +261,34 @@ def __init__(
self.interrupt_on = resolved_configs
self.description_prefix = description_prefix

def _should_interrupt(
self,
tool_call: ToolCall,
config: InterruptOnConfig,
state: AgentState[Any],
runtime: Runtime[ContextT],
) -> bool:
"""Evaluate the per-call predicate, if any.

Returns `True` when the call should interrupt, `False` to auto-approve.
When no `interrupt_when` is configured, returns `True` (unchanged behavior).
"""
interrupt_when = config.get("interrupt_when")
if interrupt_when is None:
return True
tool_runtime: ToolRuntime[ContextT, Any] = ToolRuntime(
state=state,
context=runtime.context,
config={},
stream_writer=runtime.stream_writer,
tool_call_id=tool_call["id"],
store=runtime.store,
tools=[],
execution_info=None,
server_info=None,
)
return interrupt_when(tool_call, tool_runtime)
Comment on lines +276 to +290

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Async predicate silently bypasses user logic

If a user passes an async def function as interrupt_when, calling interrupt_when(tool_call, tool_runtime) returns a coroutine object rather than a bool. Python coroutine objects are always truthy, so _should_interrupt returns True unconditionally — the interrupt fires regardless of what the predicate was intended to decide, and the coroutine is never awaited (producing only a background RuntimeWarning). Since aafter_model delegates synchronously to after_model, there is no path to correctly handle async predicates. A runtime guard after the call (e.g. asyncio.iscoroutine(result)) would surface this as a clear TypeError rather than silent wrong behavior.

Suggested change
interrupt_when = config.get("interrupt_when")
if interrupt_when is None:
return True
tool_runtime: ToolRuntime[ContextT, Any] = ToolRuntime(
state=state,
context=runtime.context,
config={},
stream_writer=runtime.stream_writer,
tool_call_id=tool_call["id"],
store=runtime.store,
tools=[],
execution_info=None,
server_info=None,
)
return interrupt_when(tool_call, tool_runtime)
import asyncio
result = interrupt_when(tool_call, tool_runtime)
if asyncio.iscoroutine(result):
result.close()
msg = (
"interrupt_when predicate returned a coroutine. "
"Predicates must be synchronous; pass a plain `def`, not `async def`."
)
raise TypeError(msg)
return bool(result)


def _create_action_and_config(
self,
tool_call: ToolCall,
Expand Down Expand Up @@ -340,13 +406,17 @@ def after_model(
interrupt_indices: list[int] = []

for idx, tool_call in enumerate(last_ai_msg.tool_calls):
if (config := self.interrupt_on.get(tool_call["name"])) is not None:
action_request, review_config = self._create_action_and_config(
tool_call, config, state, runtime
)
action_requests.append(action_request)
review_configs.append(review_config)
interrupt_indices.append(idx)
config = self.interrupt_on.get(tool_call["name"])
if config is None:
continue
if not self._should_interrupt(tool_call, config, state, runtime):
continue
action_request, review_config = self._create_action_and_config(
tool_call, config, state, runtime
)
action_requests.append(action_request)
review_configs.append(review_config)
interrupt_indices.append(idx)

# If no interrupts needed, return early
if not action_requests:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
HumanInTheLoopMiddleware,
)
from langchain.agents.middleware.types import AgentState
from langchain.tools import ToolRuntime


def test_human_in_the_loop_middleware_initialization() -> None:
Expand Down Expand Up @@ -883,3 +884,231 @@ def mock_mixed_responses(_: Any) -> dict[str, Any]:
assert isinstance(tool_message, ToolMessage)
assert tool_message.content == "Rejected tool B"
assert tool_message.tool_call_id == "id_b"


def test_interrupt_on_config_accepts_interrupt_when() -> None:
"""`InterruptOnConfig` accepts an optional `interrupt_when` predicate."""
config: InterruptOnConfig = {
"allowed_decisions": ["approve", "reject"],
"interrupt_when": lambda _tc, _rt: True,
}
assert config["interrupt_when"](None, None) is True # type: ignore[arg-type]


def test_interrupt_when_false_auto_approves() -> None:
"""`interrupt_when` returning `False` skips the interrupt entirely."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"edit_file": {
"allowed_decisions": ["approve", "reject"],
"interrupt_when": lambda _tc, _rt: False,
}
}
)
ai_message = AIMessage(
content="Writing safe file",
tool_calls=[{"name": "edit_file", "args": {"path": "/safe/path"}, "id": "1"}],
)
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])

def fail_if_called(_: Any) -> dict[str, Any]:
msg = "interrupt() should not be called when predicate returns False"
raise AssertionError(msg)

with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
side_effect=fail_if_called,
):
result = middleware.after_model(state, Runtime())

assert result is None


def test_interrupt_when_true_interrupts() -> None:
"""`interrupt_when` returning `True` interrupts with the configured decisions."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"edit_file": {
"allowed_decisions": ["approve", "reject"],
"interrupt_when": lambda _tc, _rt: True,
}
}
)
ai_message = AIMessage(
content="Writing protected file",
tool_calls=[{"name": "edit_file", "args": {"path": "/etc/secret"}, "id": "1"}],
)
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])

captured: dict[str, Any] = {}

def mock_approve(request: Any) -> dict[str, Any]:
captured["request"] = request
return {"decisions": [{"type": "approve"}]}

with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
side_effect=mock_approve,
):
result = middleware.after_model(state, Runtime())

assert result is not None
assert len(captured["request"]["action_requests"]) == 1
assert captured["request"]["action_requests"][0]["name"] == "edit_file"
assert captured["request"]["review_configs"][0]["allowed_decisions"] == ["approve", "reject"]
assert result["messages"][0].tool_calls[0]["id"] == "1"


def test_interrupt_when_mixed_calls_same_tool() -> None:
"""One protected-path call interrupts, one wiki call auto-approves, order preserved."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"edit_file": {
"allowed_decisions": ["approve", "reject"],
"interrupt_when": lambda tc, _rt: bool(
re.match(r"^/etc/", tc["args"].get("path", ""))
),
}
}
)
ai_message = AIMessage(
content="Two writes",
tool_calls=[
{"name": "edit_file", "args": {"path": "/etc/secret"}, "id": "1"},
{"name": "edit_file", "args": {"path": "/wiki/page"}, "id": "2"},
],
)
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])

captured: dict[str, Any] = {}

def mock_approve(request: Any) -> dict[str, Any]:
captured["request"] = request
return {"decisions": [{"type": "approve"}]}

with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
side_effect=mock_approve,
):
result = middleware.after_model(state, Runtime())

assert len(captured["request"]["action_requests"]) == 1
assert captured["request"]["action_requests"][0]["args"] == {"path": "/etc/secret"}

assert result is not None
revised = result["messages"][0].tool_calls
assert [tc["id"] for tc in revised] == ["1", "2"]
assert [tc["args"]["path"] for tc in revised] == ["/etc/secret", "/wiki/page"]


def test_interrupt_when_mixed_configured_tools() -> None:
"""Unlisted, listed-with-false, and listed-with-true tools all behave correctly."""
middleware = HumanInTheLoopMiddleware(
interrupt_on={
"edit_file": {
"allowed_decisions": ["approve"],
"interrupt_when": lambda _tc, _rt: False,
},
"delete_file": {
"allowed_decisions": ["approve", "reject"],
"interrupt_when": lambda _tc, _rt: True,
},
}
)
ai_message = AIMessage(
content="Three calls",
tool_calls=[
{"name": "search", "args": {"q": "x"}, "id": "1"},
{"name": "edit_file", "args": {"path": "/wiki/page"}, "id": "2"},
{"name": "delete_file", "args": {"path": "/wiki/page"}, "id": "3"},
],
)
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])

captured: dict[str, Any] = {}

def mock_approve(request: Any) -> dict[str, Any]:
captured["request"] = request
return {"decisions": [{"type": "approve"}]}

with patch(
"langchain.agents.middleware.human_in_the_loop.interrupt",
side_effect=mock_approve,
):
result = middleware.after_model(state, Runtime())

assert [a["name"] for a in captured["request"]["action_requests"]] == ["delete_file"]
assert result is not None
revised = result["messages"][0].tool_calls
assert [tc["id"] for tc in revised] == ["1", "2", "3"]


def test_interrupt_when_tool_runtime_fields() -> None:
"""Predicate receives a `ToolRuntime` with correct fields and documented deviations."""
captured: dict[str, Any] = {}

def capture_predicate(tc: ToolCall, rt: ToolRuntime[Any, Any]) -> bool:
captured["tool_call"] = tc
captured["runtime"] = rt
return False

middleware = HumanInTheLoopMiddleware(
interrupt_on={
"edit_file": {
"allowed_decisions": ["approve"],
"interrupt_when": capture_predicate,
}
}
)
ai_message = AIMessage(
content="One call",
tool_calls=[{"name": "edit_file", "args": {"path": "/x"}, "id": "abc"}],
)
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])

sentinel_context = object()
sentinel_store = object()
runtime = Runtime(context=sentinel_context, store=sentinel_store) # type: ignore[arg-type]

middleware.after_model(state, runtime)

rt = captured["runtime"]
assert rt.tool_call_id == "abc"
assert rt.state is state
assert rt.context is sentinel_context
assert rt.store is sentinel_store
assert rt.stream_writer is runtime.stream_writer
assert rt.tools == []
assert rt.config == {}
assert rt.execution_info is None
assert rt.server_info is None
assert captured["tool_call"]["id"] == "abc"
assert captured["tool_call"]["args"] == {"path": "/x"}


def test_interrupt_when_exceptions_propagate() -> None:
"""A predicate that raises propagates the exception (no silent auto-approval)."""

class PredicateBugError(RuntimeError):
pass

def bad_predicate(_tc: Any, _rt: Any) -> bool:
msg = "kaboom"
raise PredicateBugError(msg)

middleware = HumanInTheLoopMiddleware(
interrupt_on={
"edit_file": {
"allowed_decisions": ["approve"],
"interrupt_when": bad_predicate,
}
}
)
ai_message = AIMessage(
content="One call",
tool_calls=[{"name": "edit_file", "args": {"path": "/x"}, "id": "1"}],
)
state = AgentState[Any](messages=[HumanMessage(content="Hi"), ai_message])

with pytest.raises(PredicateBugError, match="kaboom"):
middleware.after_model(state, Runtime())