Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions libs/core/langchain_core/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ def on_stream_event(
tags: list[str] | None = None,
**kwargs: Any,
) -> Any:
"""Run on each protocol event produced by `stream_v2` / `astream_v2`.
"""Run on each protocol event from `stream_events(version="v3")`.

Also fires for the async equivalent
(`astream_events(version="v3")`).

Fires once per `MessagesData` event — `message-start`, per-block
`content-block-start` / `content-block-delta` /
Expand Down Expand Up @@ -699,7 +702,7 @@ async def on_stream_event(
tags: list[str] | None = None,
**kwargs: Any,
) -> None:
"""Run on each protocol event produced by `astream_v2`.
"""Run on each protocol event produced by `astream_events(version="v3")`.

See :meth:`LLMManagerMixin.on_stream_event` for the full contract.
Fires once per `MessagesData` event at event granularity, uniformly
Expand Down
4 changes: 2 additions & 2 deletions libs/core/langchain_core/callbacks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def on_llm_error(
)

def on_stream_event(self, event: MessagesData, **kwargs: Any) -> None:
"""Run on each protocol event from `stream_v2`.
"""Run on each protocol event from `stream_events(version="v3")`.

Args:
event: The protocol event.
Expand Down Expand Up @@ -871,7 +871,7 @@ async def on_llm_error(
)

async def on_stream_event(self, event: MessagesData, **kwargs: Any) -> None:
"""Run on each protocol event from `astream_v2`.
"""Run on each protocol event from `astream_events(version="v3")`.

Args:
event: The protocol event.
Expand Down
152 changes: 104 additions & 48 deletions libs/core/langchain_core/language_models/_compat_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator

from langchain_protocol.protocol import (
BlockDelta,
BlockDeltaFields,
ContentBlockDelta,
DataDelta,
ReasoningDelta,
TextDelta,
)

from langchain_core.outputs import ChatGenerationChunk


Expand Down Expand Up @@ -101,6 +110,38 @@ def _to_finalized_block(block: CompatBlock) -> FinalizedContentBlock:
return cast("FinalizedContentBlock", block)


def _to_block_delta_fields(block: CompatBlock) -> BlockDeltaFields:
"""Narrow an internal working dict to protocol block-delta fields."""
return cast("BlockDeltaFields", block)


def _to_content_delta(block: CompatBlock) -> ContentBlockDelta:
"""Convert a content-block slice/snapshot to an explicit protocol delta."""
btype = block.get("type")
if btype == "text":
return cast("TextDelta", {"type": "text-delta", "text": block.get("text", "")})
if btype == "reasoning":
return cast(
"ReasoningDelta",
{
"type": "reasoning-delta",
"reasoning": block.get("reasoning", ""),
},
)
if "data" in block:
delta = cast("DataDelta", {"type": "data-delta", "data": block.get("data", "")})
if block.get("encoding") == "base64":
delta["encoding"] = "base64"
return delta
return cast(
"BlockDelta",
{
"type": "block-delta",
"fields": _to_block_delta_fields(block),
},
)


# ---------------------------------------------------------------------------
# Block iteration
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -236,6 +277,8 @@ def _should_emit_delta(block: CompatBlock) -> bool:
return bool(
block.get("args") or block.get("id") or block.get("name"),
)
if "data" in block:
return bool(block.get("data"))
return False


Expand Down Expand Up @@ -281,6 +324,15 @@ def _accumulate(state: CompatBlock | None, delta: CompatBlock) -> CompatBlock:
state["id"] = delta["id"]
if delta.get("name") is not None:
state["name"] = delta["name"]
elif btype == dtype and "data" in delta:
state["data"] = (state.get("data", "") or "") + (delta.get("data") or "")
for key, value in delta.items():
if key in ("type", "data") or value is None:
continue
if key == "extras" and isinstance(value, dict):
state["extras"] = {**(state.get("extras") or {}), **value}
Comment on lines +329 to +333

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Functional High

Using dict.update() for input_token_details and output_token_details overwrites per-key counts across chunks instead of summing them, so nested usage totals become incorrect; sum numeric subkeys when merging.

Suggested fix
            if detail_key not in current:
                current[detail_key] = {}
            for subkey, subvalue in delta[detail_key].items():
                current[detail_key][subkey] = current[detail_key].get(subkey, 0) + subvalue
Prompt for AI assistance

Copy the prompt below and paste it into ChatGPT, Claude, or any LLM:

You are an expert python developer with deep knowledge of security, performance, and best practices.

### Context

File: libs/core/langchain_core/language_models/_compat_bridge.py
Lines: 329-333
Issue Type: functional-high
Severity: high

Issue Description:
Using `dict.update()` for `input_token_details` and `output_token_details` overwrites per-key counts across chunks instead of summing them, so nested usage totals become incorrect; sum numeric subkeys when merging.

Current Code:
            if detail_key not in current:
                current[detail_key] = {}
            current[detail_key].update(delta[detail_key])

---

### Instructions

1. Fix the issue described above
2. Maintain the exact indentation and code style from the original
3. Follow python best practices and language-specific idioms
4. Ensure the fix addresses the root cause, not just the symptoms
5. Add brief inline comments explaining the fix if needed

### Constraints

- Do not change functionality beyond fixing the identified issue
- Preserve existing variable names and function signatures unless they are part of the problem
- Ensure the fix is production-ready

---


Like Dislike Create Issue Jira

else:
state[key] = value
else:
# Self-contained or already-finalized types: replace wholesale.
state.clear()
Expand Down Expand Up @@ -429,11 +481,11 @@ def _to_protocol_usage(usage: dict[str, Any] | None) -> UsageInfo | None:
"""Convert accumulated usage to the protocol's `UsageInfo` shape."""
if usage is None:
return None
result: UsageInfo = {}
result: dict[str, Any] = {}
for key in ("input_tokens", "output_tokens", "total_tokens", "cached_tokens"):
if key in usage:
result[key] = usage[key]
return result or None
return cast("UsageInfo", result) if result else None


# ---------------------------------------------------------------------------
Expand All @@ -445,10 +497,10 @@ def _build_message_start(
msg: BaseMessage,
message_id: str | None,
) -> MessageStartData:
start_data = MessageStartData(event="message-start", role="ai")
start_data = MessageStartData(event="message-start", role="ai", id="")
resolved_id = message_id if message_id is not None else getattr(msg, "id", None)
if resolved_id:
start_data["message_id"] = resolved_id
start_data["id"] = resolved_id
Comment on lines +500 to +503

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 message-start always emits id: "" when no message ID is available

Initialising MessageStartData with id="" means the message-start event always contains the id key, even when no provider message ID was resolved. Previously the field was absent. Consumers that distinguish "has an ID" from "no ID" via "id" in event (rather than if event.get("id")) will now always see the key, treating the empty string as a valid ID. The _push_message_start guard (if message_id:) correctly ignores the empty string internally, but external consumers of the raw event may not apply the same guard. Omitting the field when no ID is available would match the previous protocol contract.

Suggested change
start_data = MessageStartData(event="message-start", role="ai", id="")
resolved_id = message_id if message_id is not None else getattr(msg, "id", None)
if resolved_id:
start_data["message_id"] = resolved_id
start_data["id"] = resolved_id
start_data = MessageStartData(event="message-start", role="ai")
resolved_id = message_id if message_id is not None else getattr(msg, "id", None)
if resolved_id:
start_data["id"] = resolved_id

start_metadata = _extract_start_metadata(msg.response_metadata or {})
if start_metadata:
start_data["metadata"] = start_metadata
Expand All @@ -464,13 +516,13 @@ def _build_message_finish(
# `MessageFinishData`; the provider's raw `finish_reason` /
# `stop_reason` now rides inside `metadata` alongside other
# response metadata. Pass it through unchanged.
finish_data = MessageFinishData(event="message-finish")
finish_data: dict[str, Any] = {"event": "message-finish"}
usage_info = _to_protocol_usage(usage)
if usage_info is not None:
finish_data["usage"] = usage_info
if response_metadata:
finish_data["metadata"] = dict(response_metadata)
return finish_data
return cast("MessageFinishData", finish_data)


def _finalize_and_build_finish(
Expand All @@ -481,7 +533,7 @@ def _finalize_and_build_finish(
return ContentBlockFinishData(
event="content-block-finish",
index=wire_idx,
content_block=_finalize_block(block),
content=_finalize_block(block),
)


Expand All @@ -497,12 +549,12 @@ def chunks_to_events(
) -> Iterator[MessagesData]:
"""Convert a stream of `ChatGenerationChunk` to protocol events.

Blocks stream one at a time: when a chunk carries a different block
identifier than the currently-open one, the open block is finished
before the new block starts, matching the protocol's no-interleave
rule. Source-side identifiers (from the block's `index` field, which
may be int or string) are translated to sequential `uint` wire
indices.
Blocks are tracked independently by source-side identifier. Providers
such as Anthropic can interleave parallel tool-call chunks by index, so
each first-seen block gets a `content-block-start`, deltas keep their
stable wire index, and all open blocks are finalized at message end.
Source-side identifiers (from the block's `index` field, which may be
int or string) are translated to sequential `uint` wire indices.

Args:
chunks: Iterator of `ChatGenerationChunk` from `_stream()`.
Expand All @@ -512,9 +564,7 @@ def chunks_to_events(
`MessagesData` lifecycle events.
"""
started = False
open_key: Any = None
open_block: CompatBlock | None = None
open_wire_idx: int = 0
blocks: dict[Any, tuple[int, CompatBlock]] = {}
next_wire_idx = 0
usage: dict[str, Any] | None = None
response_metadata: dict[str, Any] = {}
Expand Down Expand Up @@ -545,25 +595,29 @@ def chunks_to_events(
yield _build_message_start(msg, message_id)

for key, block in _iter_protocol_blocks(msg):
if key != open_key:
if open_block is not None:
yield _finalize_and_build_finish(open_wire_idx, open_block)
open_key = key
open_wire_idx = next_wire_idx
if key not in blocks:
wire_idx = next_wire_idx
next_wire_idx += 1
open_block = dict(block)
blocks[key] = (wire_idx, dict(block))
yield ContentBlockStartData(
event="content-block-start",
index=open_wire_idx,
content_block=_start_skeleton(block),
index=wire_idx,
content=_start_skeleton(block),
)
else:
open_block = _accumulate(open_block, block)
wire_idx, existing = blocks[key]
blocks[key] = (wire_idx, _accumulate(existing, block))
if _should_emit_delta(block):
wire_idx, current = blocks[key]
is_block_delta = block.get("type") in (
"tool_call_chunk",
"server_tool_call_chunk",
)
delta_source = current if is_block_delta else block
yield ContentBlockDeltaData(
event="content-block-delta",
index=open_wire_idx,
content_block=_to_protocol_block(block),
index=wire_idx,
delta=_to_content_delta(delta_source or block),
)
Comment on lines 610 to 621

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Tool-call delta events carry cumulative state, not incremental chunks

For tool_call_chunk and server_tool_call_chunk blocks, delta_source = current (the fully accumulated block state), so every content-block-delta emits fields.args equal to the total concatenated args up to that point rather than just the new bytes from the incoming block. A consumer that appends each delta's args field to build the final tool call will produce doubled content. _push_content_block_delta's block-delta handler calls _merge_block_delta_into_store which overwrites fields on each delta, so it converges to the correct state, but external SDK consumers following the naive "concatenate each delta" pattern will produce incorrect args. Additionally, delta_source or block is always equivalent to delta_source because delta_source is always assigned a non-None dict (current or block) — the or block fallback is dead code.


if msg.usage_metadata:
Expand All @@ -572,8 +626,8 @@ def chunks_to_events(
if not started:
return

if open_block is not None:
yield _finalize_and_build_finish(open_wire_idx, open_block)
for wire_idx, block in blocks.values():
yield _finalize_and_build_finish(wire_idx, block)

yield _build_message_finish(
usage=usage,
Expand All @@ -588,9 +642,7 @@ async def achunks_to_events(
) -> AsyncIterator[MessagesData]:
"""Async variant of `chunks_to_events`."""
started = False
open_key: Any = None
open_block: CompatBlock | None = None
open_wire_idx: int = 0
blocks: dict[Any, tuple[int, CompatBlock]] = {}
next_wire_idx = 0
usage: dict[str, Any] | None = None
response_metadata: dict[str, Any] = {}
Expand All @@ -615,25 +667,29 @@ async def achunks_to_events(
yield _build_message_start(msg, message_id)

for key, block in _iter_protocol_blocks(msg):
if key != open_key:
if open_block is not None:
yield _finalize_and_build_finish(open_wire_idx, open_block)
open_key = key
open_wire_idx = next_wire_idx
if key not in blocks:
wire_idx = next_wire_idx
next_wire_idx += 1
open_block = dict(block)
blocks[key] = (wire_idx, dict(block))
yield ContentBlockStartData(
event="content-block-start",
index=open_wire_idx,
content_block=_start_skeleton(block),
index=wire_idx,
content=_start_skeleton(block),
)
else:
open_block = _accumulate(open_block, block)
wire_idx, existing = blocks[key]
blocks[key] = (wire_idx, _accumulate(existing, block))
if _should_emit_delta(block):
wire_idx, current = blocks[key]
is_block_delta = block.get("type") in (
"tool_call_chunk",
"server_tool_call_chunk",
)
delta_source = current if is_block_delta else block
yield ContentBlockDeltaData(
event="content-block-delta",
index=open_wire_idx,
content_block=_to_protocol_block(block),
index=wire_idx,
delta=_to_content_delta(delta_source or block),
)

if msg.usage_metadata:
Expand All @@ -642,8 +698,8 @@ async def achunks_to_events(
if not started:
return

if open_block is not None:
yield _finalize_and_build_finish(open_wire_idx, open_block)
for wire_idx, block in blocks.values():
yield _finalize_and_build_finish(wire_idx, block)

yield _build_message_finish(
usage=usage,
Expand Down Expand Up @@ -682,18 +738,18 @@ def message_to_events(
yield ContentBlockStartData(
event="content-block-start",
index=wire_idx,
content_block=_start_skeleton(block),
content=_start_skeleton(block),
)
if _should_emit_delta(block):
yield ContentBlockDeltaData(
event="content-block-delta",
index=wire_idx,
content_block=_to_protocol_block(block),
delta=_to_content_delta(block),
)
yield ContentBlockFinishData(
event="content-block-finish",
index=wire_idx,
content_block=_finalize_block(block),
content=_finalize_block(block),
)

yield _build_message_finish(
Expand Down
Loading