Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 72 additions & 19 deletions src/agents/function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

# griffelib exposes the `griffe` package at runtime but currently does not ship typing markers.
from griffe import Docstring, DocstringSectionKind # type: ignore[import-untyped]
from pydantic import BaseModel, Field, create_model
from pydantic import BaseModel, ConfigDict, Field, create_model
from pydantic.fields import FieldInfo

from .exceptions import UserError
Expand Down Expand Up @@ -40,6 +40,8 @@ class FuncSchema:
strict_json_schema: bool = True
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input."""
pydantic_field_name_map: dict[str, str] | None = None
"""Maps function parameter names to the internal Pydantic field names used for validation."""

def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
"""
Expand All @@ -56,7 +58,8 @@ def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
if self.takes_context and idx == 0:
continue

value = getattr(data, name, None)
pydantic_field_name = (self.pydantic_field_name_map or {}).get(name, name)
value = getattr(data, pydantic_field_name, None)
if param.kind == param.VAR_POSITIONAL:
# e.g. *args: extend positional args and mark that *args is now seen
positional_args.extend(value or [])
Expand Down Expand Up @@ -221,6 +224,34 @@ def _extract_field_info_from_metadata(metadata: tuple[Any, ...]) -> FieldInfo |
return None


_PYDANTIC_PROTECTED_FIELD_PREFIXES = ("model_dump", "model_validate")


def _requires_pydantic_alias(name: str) -> bool:
"""Returns True when a parameter name cannot safely be used as a Pydantic field name."""

return name == "model_config" or any(
name.startswith(prefix) for prefix in _PYDANTIC_PROTECTED_FIELD_PREFIXES
)


def _make_safe_pydantic_field_name(name: str, used_names: set[str]) -> str:
"""Generates a unique internal Pydantic field name for aliased parameters."""

candidate = f"func_arg_{name}"
suffix = 1
while candidate in used_names:
suffix += 1
candidate = f"func_arg_{name}_{suffix}"
return candidate


def _with_field_alias(field_info: FieldInfo, alias: str) -> FieldInfo:
"""Returns a copy of a field definition with an explicit validation/serialization alias."""

return FieldInfo.merge_field_infos(field_info, Field(alias=alias))
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve aliases when aliasing reserved params

When a reserved parameter also has a user-specified Field(alias=...) (for example model_dump: str = Field(alias="payload")), this helper is called after merging that FieldInfo and replaces the public alias with the Python parameter name. The generated schema then exposes/validates model_dump instead of payload, making Field alias behavior depend on whether the parameter name happens to be Pydantic-protected; preserve an existing validation/serialization alias while still mapping the internal field back for to_call_args().

Useful? React with 👍 / 👎.



def function_schema(
func: Callable[..., Any],
docstring_style: DocstringStyle | None = None,
Expand Down Expand Up @@ -317,8 +348,18 @@ def function_schema(
# We will collect field definitions for create_model as a dict:
# field_name -> (type_annotation, default_value_or_Field(...))
fields: dict[str, Any] = {}
pydantic_field_name_map: dict[str, str] = {}
used_pydantic_field_names = {name for name, _ in filtered_params}

for name, param in filtered_params:
pydantic_field_name = (
_make_safe_pydantic_field_name(name, used_pydantic_field_names)
if _requires_pydantic_alias(name)
else name
Comment thread
Epochex marked this conversation as resolved.
)
pydantic_field_name_map[name] = pydantic_field_name
used_pydantic_field_names.add(pydantic_field_name)

ann = type_hints.get(name, param.annotation)
default = param.default

Expand All @@ -344,9 +385,9 @@ def function_schema(
ann = list[ann] # type: ignore

# Default factory to empty list
fields[name] = (
fields[pydantic_field_name] = (
ann,
Field(default_factory=list, description=field_description),
Field(default_factory=list, description=field_description, alias=name),
)

elif param.kind == param.VAR_KEYWORD:
Expand All @@ -362,9 +403,9 @@ def function_schema(
# e.g. def foo(**kwargs: int) -> Dict[str, int]
ann = dict[str, ann] # type: ignore

fields[name] = (
fields[pydantic_field_name] = (
ann,
Field(default_factory=dict, description=field_description),
Field(default_factory=dict, description=field_description, alias=name),
)

else:
Expand All @@ -381,30 +422,41 @@ def function_schema(
merged = FieldInfo.merge_field_infos(merged, default=default)
elif isinstance(default, FieldInfo):
merged = FieldInfo.merge_field_infos(merged, default)
fields[name] = (ann, merged)
if pydantic_field_name != name:
merged = _with_field_alias(merged, name)
fields[pydantic_field_name] = (ann, merged)
elif default == inspect._empty:
# Required field
fields[name] = (
ann,
Field(..., description=field_description),
field = (
Field(..., description=field_description, alias=name)
if pydantic_field_name != name
else Field(..., description=field_description)
)
fields[pydantic_field_name] = (ann, field)
elif isinstance(default, FieldInfo):
# Parameter with a default value that is a Field(...)
fields[name] = (
ann,
FieldInfo.merge_field_infos(
default, description=field_description or default.description
),
field = FieldInfo.merge_field_infos(
default,
description=field_description or default.description,
)
if pydantic_field_name != name:
field = _with_field_alias(field, name)
fields[pydantic_field_name] = (ann, field)
else:
# Parameter with a default value
fields[name] = (
ann,
Field(default=default, description=field_description),
field = (
Field(default=default, description=field_description, alias=name)
if pydantic_field_name != name
else Field(default=default, description=field_description)
)
fields[pydantic_field_name] = (ann, field)

# 3. Dynamically build a Pydantic model
dynamic_model = create_model(f"{func_name}_args", __base__=BaseModel, **fields)
dynamic_model = create_model(
f"{func_name}_args",
__config__=ConfigDict(populate_by_name=True),
**fields,
)

# 4. Build JSON schema from that model
json_schema = dynamic_model.model_json_schema()
Expand All @@ -421,4 +473,5 @@ def function_schema(
signature=sig,
takes_context=takes_context,
strict_json_schema=strict_json_schema,
pydantic_field_name_map=pydantic_field_name_map,
)
52 changes: 52 additions & 0 deletions tests/test_function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,58 @@ def test_varargs_function():
assert result2 == (7, (9.9,), False, {"some_key": "some_value"})


@pytest.mark.parametrize("param_name", ["model_config", "model_dump", "model_validate"])
def test_function_schema_supports_pydantic_reserved_param_names(param_name: str) -> None:
namespace: dict[str, Any] = {}
exec(
f"def reserved_name_tool({param_name}: str) -> str:\n return {param_name}\n",
namespace,
)
func = namespace["reserved_name_tool"]

func_schema = function_schema(func)

assert func_schema.params_json_schema["properties"][param_name]["type"] == "string"
parsed = func_schema.params_pydantic_model(**{param_name: "value"})
args, kwargs_dict = func_schema.to_call_args(parsed)

assert func(*args, **kwargs_dict) == "value"


def test_function_schema_avoids_reserved_name_alias_collisions() -> None:
def collision_tool(model_dump: str, func_arg_model_dump: int) -> tuple[str, int]:
return model_dump, func_arg_model_dump

func_schema = function_schema(collision_tool)

properties = func_schema.params_json_schema["properties"]
assert set(properties) == {"model_dump", "func_arg_model_dump"}

parsed = func_schema.params_pydantic_model(
model_dump="value",
func_arg_model_dump=3,
)
args, kwargs_dict = func_schema.to_call_args(parsed)

assert collision_tool(*args, **kwargs_dict) == ("value", 3)


def test_function_schema_preserves_field_alias_defaults() -> None:
def aliased_tool(city: str = Field(alias="location")) -> str:
return city

func_schema = function_schema(aliased_tool)

properties = func_schema.params_json_schema["properties"]
assert "location" in properties
assert "city" not in properties

parsed = func_schema.params_pydantic_model(location="Paris")
args, kwargs_dict = func_schema.to_call_args(parsed)

assert aliased_tool(*args, **kwargs_dict) == "Paris"


class Foo(TypedDict):
a: int
b: str
Expand Down