Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 48 additions & 15 deletions libs/core/langchain_core/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,18 +969,36 @@ def __init__(
self.inheritable_tags = inheritable_tags or []
self.metadata = metadata or {}
self.inheritable_metadata = inheritable_metadata or {}
self._cow = False

def _cow_copy(self) -> None:
    """Materialize copy-on-write shared state before mutation.

    When the ``_cow`` flag is set, the handler, tag and metadata containers
    (and their inheritable counterparts) are each rebound to a shallow copy
    and the flag is cleared. In-place mutation performed afterwards then
    touches only this instance's own containers.

    When the flag is unset the method does nothing.
    """
    # An instance created without running ``__init__`` may not carry the
    # flag at all; treat that the same as "not shared" instead of raising.
    if not getattr(self, "_cow", False):
        return
    for attr_name in (
        "handlers",
        "inheritable_handlers",
        "tags",
        "inheritable_tags",
        "metadata",
        "inheritable_metadata",
    ):
        # ``list.copy`` / ``dict.copy`` are shallow: elements stay shared,
        # only the container is duplicated.
        setattr(self, attr_name, getattr(self, attr_name).copy())
    self._cow = False

def copy(self) -> Self:
"""Return a copy of the callback manager."""
return self.__class__(
handlers=self.handlers.copy(),
inheritable_handlers=self.inheritable_handlers.copy(),
parent_run_id=self.parent_run_id,
tags=self.tags.copy(),
inheritable_tags=self.inheritable_tags.copy(),
metadata=self.metadata.copy(),
inheritable_metadata=self.inheritable_metadata.copy(),
)
"""Return a copy of the callback manager.

Uses copy-on-write: the copy shares underlying lists/dicts until
either the original or the copy is mutated.
"""
self._cow = True
clone = self.__class__.__new__(self.__class__)
clone.handlers = self.handlers
clone.inheritable_handlers = self.inheritable_handlers
clone.parent_run_id = self.parent_run_id
clone.tags = self.tags
clone.inheritable_tags = self.inheritable_tags
clone.metadata = self.metadata
clone.inheritable_metadata = self.inheritable_metadata
clone._cow = True # noqa: SLF001
Comment on lines +992 to +1000

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Functional High

The copy-on-write clone bypasses __init__, so subclass-specific state is silently dropped; use the normal constructor or a subclass hook to preserve required attributes.

Suggested fix
        return self.__class__(
            handlers=self.handlers.copy(),
            inheritable_handlers=self.inheritable_handlers.copy(),
            parent_run_id=self.parent_run_id,
            tags=self.tags.copy(),
            inheritable_tags=self.inheritable_tags.copy(),
            metadata=self.metadata.copy(),
            inheritable_metadata=self.inheritable_metadata.copy(),
        )
Prompt for AI assistance

Copy the prompt below and paste it into ChatGPT, Claude, or any LLM:

You are an expert python developer with deep knowledge of security, performance, and best practices.

### Context

File: libs/core/langchain_core/callbacks/base.py
Lines: 992-1000
Issue Type: functional-high
Severity: high

Issue Description:
The copy-on-write clone bypasses `__init__`, so subclass-specific state is silently dropped; use the normal constructor or a subclass hook to preserve required attributes.

Current Code:
        clone = self.__class__.__new__(self.__class__)
        clone.handlers = self.handlers
        clone.inheritable_handlers = self.inheritable_handlers
        clone.parent_run_id = self.parent_run_id
        clone.tags = self.tags
        clone.inheritable_tags = self.inheritable_tags
        clone.metadata = self.metadata
        clone.inheritable_metadata = self.inheritable_metadata
        clone._cow = True  # noqa: SLF001

---

### Instructions

1. Fix the issue described above
2. Maintain the exact indentation and code style from the original
3. Follow python best practices and language-specific idioms
4. Ensure the fix addresses the root cause, not just the symptoms
5. Add brief inline comments explaining the fix if needed

### Constraints

- Do not change functionality beyond fixing the identified issue
- Preserve existing variable names and function signatures unless they are part of the problem
- Ensure the fix is production-ready

---


Like Dislike Create Issue Jira

return clone
Comment on lines 985 to +1001

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 __new__-based clone skips subclass __init__

copy() uses self.__class__.__new__(self.__class__) and manually copies only the seven attributes declared in BaseCallbackManager.__init__. Any subclass that extends __init__ with additional instance attributes (e.g., self.ended, self.parent_run_manager in CallbackManagerForChainGroup) will produce a clone that raises AttributeError on first access.

The two current subclasses that add extra attributes (CallbackManagerForChainGroup, AsyncCallbackManagerForChainGroup) both override copy() and are safe today, but this is an invisible footgun for any future subclass that forgets to do the same. A docstring note or a __init_subclass__ guard would prevent silent breakage.


def merge(self, other: BaseCallbackManager) -> Self:
"""Merge the callback manager with another callback manager.
Expand Down Expand Up @@ -1053,6 +1071,7 @@ def add_handler(
handler: The handler to add.
inherit: Whether to inherit the handler.
"""
self._cow_copy()
if handler not in self.handlers:
self.handlers.append(handler)
if inherit and handler not in self.inheritable_handlers:
Expand All @@ -1064,6 +1083,7 @@ def remove_handler(self, handler: BaseCallbackHandler) -> None:
Args:
handler: The handler to remove.
"""
self._cow_copy()
if handler in self.handlers:
self.handlers.remove(handler)
if handler in self.inheritable_handlers:
Expand All @@ -1080,6 +1100,7 @@ def set_handlers(
handlers: The handlers to set.
inherit: Whether to inherit the handlers.
"""
self._cow = False
self.handlers = []
self.inheritable_handlers = []
for handler in handlers:
Expand Down Expand Up @@ -1109,19 +1130,29 @@ def add_tags(
tags: The tags to add.
inherit: Whether to inherit the tags.
"""
for tag in tags:
if tag in self.tags:
self.remove_tags([tag])
self.tags.extend(tags)
self._cow_copy()
if not self.tags:
self.tags.extend(tags)
if inherit:
self.inheritable_tags.extend(tags)
return
# Deduplicate: tag order is not meaningful across the codebase
# (merge_configs sorts, tracers deduplicate via sets).
existing = set(self.tags)
new_tags = [t for t in tags if t not in existing]
self.tags.extend(new_tags)
if inherit:
self.inheritable_tags.extend(tags)
existing_inh = set(self.inheritable_tags)
new_inh = [t for t in tags if t not in existing_inh]
self.inheritable_tags.extend(new_inh)

def remove_tags(self, tags: list[str]) -> None:
"""Remove tags from the callback manager.

Args:
tags: The tags to remove.
"""
self._cow_copy()
for tag in tags:
if tag in self.tags:
self.tags.remove(tag)
Expand All @@ -1139,6 +1170,7 @@ def add_metadata(
metadata: The metadata to add.
inherit: Whether to inherit the metadata.
"""
self._cow_copy()
self.metadata.update(metadata)
if inherit:
self.inheritable_metadata.update(metadata)
Expand All @@ -1149,6 +1181,7 @@ def remove_metadata(self, keys: list[str]) -> None:
Args:
keys: The keys to remove.
"""
self._cow_copy()
for key in keys:
self.metadata.pop(key, None)
self.inheritable_metadata.pop(key, None)
Expand Down
49 changes: 33 additions & 16 deletions libs/core/langchain_core/callbacks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ def handle_event(
**kwargs: The keyword arguments to pass to the event handler

"""
if not handlers:
return

coros: list[Coroutine[Any, Any, Any]] = []

try:
Expand Down Expand Up @@ -433,6 +436,9 @@ async def ahandle_event(
**kwargs: The keyword arguments to pass to the event handler.

"""
if not handlers:
return

for handler in [h for h in handlers if h.run_inline]:
await _ahandle_event_for_handler(
handler, event_name, ignore_condition_name, *args, **kwargs
Expand Down Expand Up @@ -574,13 +580,18 @@ def get_child(self, tag: str | None = None) -> CallbackManager:
The child callback manager.

"""
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
manager.add_metadata(self.inheritable_metadata)
tags = list(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], inherit=False)
return manager
tags.append(tag)
return CallbackManager(
handlers=list(self.inheritable_handlers),
inheritable_handlers=list(self.inheritable_handlers),
parent_run_id=self.run_id,
tags=tags,
inheritable_tags=list(self.inheritable_tags),
metadata=dict(self.inheritable_metadata),
inheritable_metadata=dict(self.inheritable_metadata),
)


class AsyncRunManager(BaseRunManager, ABC):
Expand Down Expand Up @@ -658,13 +669,18 @@ def get_child(self, tag: str | None = None) -> AsyncCallbackManager:
The child callback manager.

"""
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
manager.add_metadata(self.inheritable_metadata)
tags = list(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], inherit=False)
return manager
tags.append(tag)
return AsyncCallbackManager(
handlers=list(self.inheritable_handlers),
inheritable_handlers=list(self.inheritable_handlers),
parent_run_id=self.run_id,
tags=tags,
inheritable_tags=list(self.inheritable_tags),
metadata=dict(self.inheritable_metadata),
inheritable_metadata=dict(self.inheritable_metadata),
)


class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
Expand Down Expand Up @@ -2340,10 +2356,6 @@ def _configure(
tracing_tags = tracing_context["tags"]
run_tree: Run | None = tracing_context["parent"]
parent_run_id = None if run_tree is None else run_tree.id
callback_manager = callback_manager_cls(
handlers=[],
parent_run_id=parent_run_id,
)
if inheritable_callbacks or local_callbacks:
if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:
inheritable_callbacks_ = inheritable_callbacks or []
Expand Down Expand Up @@ -2381,6 +2393,11 @@ def _configure(
)
for handler in local_handlers_:
callback_manager.add_handler(handler, inherit=False)
else:
callback_manager = callback_manager_cls(
handlers=[],
parent_run_id=parent_run_id,
)
if inheritable_tags or local_tags:
callback_manager.add_tags(inheritable_tags or [])
callback_manager.add_tags(local_tags or [], inherit=False)
Expand Down
31 changes: 29 additions & 2 deletions libs/core/langchain_core/language_models/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def _format_for_tracing(messages: list[BaseMessage]) -> list[BaseMessage]:
List of messages formatted for tracing.

"""
# Fast path: if no messages have list content, no formatting is needed.
if not any(isinstance(m.content, list) for m in messages):
return messages
messages_to_trace = []
for message in messages:
message_to_trace = message
Expand Down Expand Up @@ -243,6 +246,30 @@ def _format_ls_structured_output(ls_structured_output_format: dict | None) -> di
return ls_structured_output_format_dict


# Per-class caches recording whether ``_generate`` / ``_agenerate`` declare a
# ``run_manager`` parameter, so each signature is inspected once per class.
_generate_accepts_run_manager: dict[type, bool] = {}
_agenerate_accepts_run_manager: dict[type, bool] = {}


def _method_accepts_run_manager(
    model: BaseChatModel, method_name: str, cache: dict[type, bool]
) -> bool:
    """Return whether ``model.<method_name>`` has a ``run_manager`` parameter.

    Args:
        model: The instance whose method is inspected.
        method_name: Name of the method to look up on ``model``.
        cache: Mapping from class to the previously computed answer; it is
            updated in place on a miss.

    Returns:
        ``True`` if the method's signature contains a parameter named
        ``run_manager``, ``False`` otherwise.
    """
    cls = type(model)
    accepts = cache.get(cls)
    if accepts is None:
        parameters = inspect.signature(getattr(model, method_name)).parameters
        accepts = cache[cls] = "run_manager" in parameters
    return accepts


def _check_generates_accept_run_manager(self: BaseChatModel) -> bool:
    """Return whether ``self._generate`` accepts ``run_manager`` (cached per class)."""
    return _method_accepts_run_manager(
        self, "_generate", _generate_accepts_run_manager
    )


def _check_agenerates_accept_run_manager(self: BaseChatModel) -> bool:
    """Return whether ``self._agenerate`` accepts ``run_manager`` (cached per class)."""
    return _method_accepts_run_manager(
        self, "_agenerate", _agenerate_accepts_run_manager
    )


class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
r"""Base class for chat models.

Expand Down Expand Up @@ -1231,7 +1258,7 @@ def _generate_with_cache(
run_manager.on_llm_new_token("", chunk=chunk)
chunks.append(chunk)
result = generate_from_stream(iter(chunks))
elif inspect.signature(self._generate).parameters.get("run_manager"):
elif _check_generates_accept_run_manager(self):
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
Expand Down Expand Up @@ -1357,7 +1384,7 @@ async def _agenerate_with_cache(
await run_manager.on_llm_new_token("", chunk=chunk)
chunks.append(chunk)
result = generate_from_stream(iter(chunks))
elif inspect.signature(self._agenerate).parameters.get("run_manager"):
elif _check_agenerates_accept_run_manager(self):
result = await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
Expand Down
38 changes: 21 additions & 17 deletions libs/core/langchain_core/runnables/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,23 +120,27 @@ class RunnableConfig(TypedDict, total=False):
"""


# Names of the keys grouped under CONFIG_KEYS. A frozenset gives constant-time
# membership tests and cannot be mutated by accident.
CONFIG_KEYS = frozenset(
    {
        "tags",
        "metadata",
        "callbacks",
        "run_name",
        "max_concurrency",
        "recursion_limit",
        "configurable",
        "run_id",
    }
)

# Subset of CONFIG_KEYS. NOTE(review): the name suggests these are the keys
# whose values get copied rather than shared -- confirm at the use sites.
COPIABLE_KEYS = frozenset(
    {
        "tags",
        "metadata",
        "callbacks",
        "configurable",
    }
)

DEFAULT_RECURSION_LIMIT = 25

Expand Down
23 changes: 17 additions & 6 deletions libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,19 @@ def __init__(self, **kwargs: Any) -> None:
)
raise TypeError(msg)
super().__init__(**kwargs)
# Cache per-invocation introspection results
try:
self._has_run_manager_param: bool = bool(
signature(self._run).parameters.get("run_manager")
)
except (ValueError, TypeError):
self._has_run_manager_param = False
try:
self._runnable_config_param: str | None = _get_runnable_config_param(
self._run
)
except (ValueError, TypeError):
self._runnable_config_param = None

model_config = ConfigDict(
arbitrary_types_allowed=True,
Expand Down Expand Up @@ -794,9 +807,7 @@ async def _arun(self, *args: Any, **kwargs: Any) -> Any:
Returns:
The result of the tool execution.
"""
if kwargs.get("run_manager") and signature(self._run).parameters.get(
"run_manager"
):
if kwargs.get("run_manager") and self._has_run_manager_param:
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
return await run_in_executor(None, self._run, *args, **kwargs)

Expand Down Expand Up @@ -960,10 +971,10 @@ def run(
tool_args, tool_kwargs = self._to_args_and_kwargs(
tool_input, tool_call_id
)
if signature(self._run).parameters.get("run_manager"):
if self._has_run_manager_param:
tool_kwargs |= {"run_manager": run_manager}
if config_param := _get_runnable_config_param(self._run):
tool_kwargs |= {config_param: config}
if self._runnable_config_param:
tool_kwargs |= {self._runnable_config_param: config}
response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_artifact":
msg = (
Expand Down
4 changes: 3 additions & 1 deletion libs/partners/deepseek/langchain_deepseek/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections.abc import Callable, Iterator, Sequence
from json import JSONDecodeError
from typing import Any, Literal, TypeAlias, cast
from urllib.parse import urlparse

import openai
from langchain_core.callbacks import (
Expand Down Expand Up @@ -197,7 +198,8 @@ class Joke(BaseModel):
@property
def _is_azure_endpoint(self) -> bool:
    """Check if the configured endpoint is an Azure deployment.

    Only the parsed hostname of ``api_base`` is examined: it must be exactly
    ``azure.com`` or end with ``.azure.com``. A plain substring test is not
    used because it also matches look-alike hosts (``evil-azure.com``) and
    URLs that merely mention ``azure.com`` in their path.

    Returns:
        ``True`` when the hostname is ``azure.com`` or one of its subdomains;
        ``False`` otherwise, including when ``api_base`` is unset, has no
        network location, or cannot be parsed.
    """
    try:
        # ``hostname`` is lower-cased by ``urlparse`` and is ``None`` when the
        # URL has no network location.
        hostname = urlparse(self.api_base or "").hostname or ""
    except ValueError:
        # ``urlparse`` rejects some malformed URLs (e.g. unbalanced IPv6
        # brackets); such a base cannot be an Azure endpoint.
        return False
    return hostname == "azure.com" or hostname.endswith(".azure.com")

@property
def _llm_type(self) -> str:
Expand Down
3 changes: 3 additions & 0 deletions libs/partners/deepseek/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,9 @@ def test_is_azure_endpoint_detection(self) -> None:
DEFAULT_API_BASE,
"https://api.openai.com/v1",
"https://custom-endpoint.com/api",
"https://evil-azure.com/v1", # hostname bypass attempt
"https://notazure.com.evil.com/", # subdomain bypass attempt
"https://example.com/azure.com", # path bypass attempt
]
for endpoint in non_azure_endpoints:
llm = ChatDeepSeek(
Expand Down