Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
with:
enable-cache: true

- run: uv sync --all-extras --all-packages --group lint
- run: uv sync --all-packages --group lint

- uses: pre-commit/action@v3.0.0
with:
Expand All @@ -40,6 +40,11 @@ jobs:
test:
runs-on: ubuntu-latest
timeout-minutes: 10
strategy:
fail-fast: false
matrix:
pydantic-ai-version: [min, latest]
name: test (pydantic-ai ${{ matrix.pydantic-ai-version }})
steps:
- uses: actions/checkout@v4

Expand All @@ -48,5 +53,10 @@ jobs:
enable-cache: true

- run: mkdir .coverage
- run: uv sync --all-extras --group dev
- run: uv run coverage run -m pytest --durations=100 -n auto --dist=loadgroup
- if: matrix.pydantic-ai-version == 'min'
run: uv sync --group dev --resolution lowest-direct
- if: matrix.pydantic-ai-version == 'latest'
run: uv sync --group dev --upgrade-package pydantic-ai
env:
UV_FROZEN: "0"
- run: make test
7 changes: 0 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,6 @@ repos:

- repo: local
hooks:
- id: format
name: Format
entry: make
args: [format]
language: system
types: [python]
pass_filenames: false
- id: lint
name: Lint
entry: make
Expand Down
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
.PHONY: install
install: ## Install git hooks for local development
uv run pre-commit install

.PHONY: format
format: ## Format the code
uv run ruff format
Expand Down
19 changes: 5 additions & 14 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,15 @@ version = { attr = "cragents._version.__version__" }

[dependency-groups]
dev = [
"anyio>=4.11.0",
"coverage>=7.12.0",
"anyio>=4.12.1",
"coverage>=7.13.2",
"inline-snapshot>=0.31.1",
"pyright>=1.1.406",
"pytest>=9.0.1",
"pre-commit>=4.5.1",
"pyright>=1.1.408",
"pytest>=9.0.2",
"pytest-xdist>=3.8.0",
"trio>=0.32.0",
]
example = [
"chromadb>=1.3.7",
"datasets>=4.4.2",
"langchain-text-splitters>=1.1.0",
"markdownify>=1.2.2",
"playwright>=1.57.0",
"requests-cache>=1.2.1",
"rich>=14.2.0",
"transformers>=4.57.3",
]
lint = [
"pyright>=1.1.406",
"ruff>=0.14.1",
Expand Down
105 changes: 104 additions & 1 deletion tests/test_.py → tests/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest
from inline_snapshot import snapshot
from pydantic_ai import ToolOutput
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.models.openai import OpenAIChatModel, OpenAIChatModelSettings
from pydantic_ai.models.test import TestModel
from pydantic_ai.providers.openai import OpenAIProvider

from cragents import Anchor, Constrain, CRAgent, Free, Think, UseTools, vllm_model_profile
Expand All @@ -24,6 +25,9 @@
]


# ── end-to-end set_guide output type tests ────────────────────────────────────


async def test_default_agent_output():
agent = CRAgent(model)
await agent.set_guide(generation_sequence)
Expand Down Expand Up @@ -134,3 +138,102 @@ async def test_mixed_output_type():
}
}
)


# ── set_guide error handling and model settings ───────────────────────────────


async def test_set_guide_requires_openai_model():
agent = CRAgent(TestModel())
with pytest.raises(RuntimeError, match="OpenAIChatModel required"):
await agent.set_guide([Anchor("hi ")])


async def test_set_guide_creates_model_settings_when_none():
agent = CRAgent(model)
assert agent.model_settings is None
await agent.set_guide([Anchor("hi ")])
assert agent.model_settings is not None
assert "extra_body" in agent.model_settings


async def test_set_guide_preserves_existing_model_settings():
agent = CRAgent(model, model_settings=OpenAIChatModelSettings(temperature=0.5))
await agent.set_guide([Anchor("hi ")])
assert agent.model_settings["temperature"] == 0.5
assert "extra_body" in agent.model_settings


async def test_set_guide_overwrites_on_second_call():
agent = CRAgent(model)
await agent.set_guide([Anchor("first ")])
first_grammar = agent.model_settings["extra_body"]["structured_outputs"]["grammar"]
await agent.set_guide([Anchor("second ")])
second_grammar = agent.model_settings["extra_body"]["structured_outputs"]["grammar"]
assert first_grammar != second_grammar
assert "second" in second_grammar


# ── set_guide UseTools schema handling ────────────────────────────────────────


async def test_set_guide_explicit_use_tools_schema_not_overwritten():
explicit_schema = {"type": "number"}
agent = CRAgent(model)
await agent.set_guide([UseTools(json_schema=explicit_schema)])
grammar = agent.model_settings["extra_body"]["structured_outputs"]["grammar"]
assert '"type": "number"' in grammar


async def test_set_guide_use_tools_with_registered_tool():
agent = CRAgent(model)

@agent.tool_plain
def my_tool(x: int) -> str:
return str(x)

await agent.set_guide([UseTools()])
grammar = agent.model_settings["extra_body"]["structured_outputs"]["grammar"]
# The tool's parameter schema (containing "x") should appear in the grammar
assert '"x"' in grammar


async def test_set_guide_use_tools_tool_names():
agent = CRAgent(model)
await agent.set_guide([UseTools(json_schema={"type": "string"}, tool_names=["alpha", "beta"])])
grammar = agent.model_settings["extra_body"]["structured_outputs"]["grammar"]
assert 'FUNCTION_NAME: ("alpha" | "beta")' in grammar


async def test_set_guide_merges_toolset_with_anyof_output():
# When the output schema already has anyOf (multiple output types) and the agent
# also has registered tools, anyOf = toolset_schemas + return_schema["anyOf"]
agent = CRAgent(model, output_type=[ToolOutput(bool), ToolOutput(int)])

@agent.tool_plain
def helper(x: str) -> str:
return x

await agent.set_guide([UseTools()])
grammar = agent.model_settings["extra_body"]["structured_outputs"]["grammar"]
assert "tool_schema" in grammar
assert "anyOf" in grammar


# ── vllm_model_profile ─────────────────────────────────────────────────────────


def test_vllm_profile_strict_tool_definition():
assert vllm_model_profile.openai_supports_strict_tool_definition is False


def test_vllm_profile_tool_choice_required():
assert vllm_model_profile.openai_supports_tool_choice_required is False


def test_vllm_profile_json_object_output():
assert vllm_model_profile.supports_json_object_output is False


def test_vllm_profile_json_schema_output():
assert vllm_model_profile.supports_json_schema_output is True
Loading