Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/temporal-direct/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ async def get_population_activity(city: str) -> int:

@dataclasses.dataclass
class LLMParams:
model_id: str
messages: list[dict[str, Any]]
tool_schemas: list[dict[str, Any]]

Expand Down Expand Up @@ -148,6 +149,7 @@ async def loop(
result = await temporalio.workflow.execute_activity(
llm_call_activity,
LLMParams(
model_id=context.model.id,
messages=[m.model_dump() for m in context.messages],
tool_schemas=tool_schemas,
),
Expand Down
35 changes: 20 additions & 15 deletions examples/temporal-direct/test_durability.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def read_activity_log(log_file: pathlib.Path) -> Counter[str]:

async def test_happy_path(
client: temporalio.client.Client, log_file: pathlib.Path
) -> str:
) -> tuple[str, int]:
print("\n── test_happy_path ────────────────────────────────")
log_file.write_text("")

Expand All @@ -96,9 +96,10 @@ async def test_happy_path(
assert (
"8,336,817" in result or "8336817" in result
), f"expected NYC population in result, got: {result!r}"
counts = read_activity_log(log_file)
print(f" ✓ workflow {wid} produced {len(result)} chars")
print(f" ✓ activity calls: {dict(read_activity_log(log_file))}")
return wid
print(f" ✓ activity calls: {dict(counts)}")
return wid, counts.total()


# ── Test 2: replay determinism ───────────────────────────────────
Expand All @@ -125,6 +126,7 @@ async def test_activity_caching(
env: temporalio.testing.WorkflowEnvironment,
client: temporalio.client.Client,
log_file: pathlib.Path,
baseline_total: int,
) -> None:
print("\n── test_activity_caching ──────────────────────────")
log_file.write_text("")
Expand Down Expand Up @@ -194,21 +196,24 @@ async def test_activity_caching(
# Sanity: we actually killed worker1 mid-workflow. If worker1 had
# finished everything before the SIGINT landed, the test would
# vacuously "pass" the cache invariant without exercising resume.
assert total_post > total_pre, (
assert total_pre < baseline_total, (
f"worker1 finished the entire workflow before shutdown landed "
f"(pre={total_pre}, post={total_post}); test isn't exercising resume"
f"(pre={total_pre}, baseline={baseline_total}); not exercising resume"
)

# If worker2 ignored history and re-ran everything, total_post would
# be roughly 2x total_pre (worker1's executions + worker2 redoing
# them all). Catch that case loudly.
expected_double_run = total_pre * 2
assert total_post < expected_double_run, (
f"suspiciously high activity count after resume: {total_post} "
f"(would expect at most ~{expected_double_run - 1} if cache replayed)"
# Cache invariant: post-restart total equals one full workflow's
# worth of completions (from the happy-path baseline). If worker2
# had ignored history and re-run cached activities, total_post
# would exceed baseline_total.
assert total_post == baseline_total, (
f"unexpected activity count after resume: {total_post} "
f"(baseline is {baseline_total}); worker2 may have re-run "
f"cached activities"
)
print(" ✓ resume completed without re-running cached activities")
print(f" (total before: {total_pre}, after: {total_post})")
print(
f" (pre={total_pre}, post={total_post}, baseline={baseline_total})"
)


# ── Entry point ──────────────────────────────────────────────────
Expand Down Expand Up @@ -246,9 +251,9 @@ async def main() -> None:
):
client = env.client

wid = await test_happy_path(client, log_file)
wid, baseline_total = await test_happy_path(client, log_file)
await test_replay_determinism(client, wid)
await test_activity_caching(env, client, log_file)
await test_activity_caching(env, client, log_file, baseline_total)

print("\nAll durability checks passed.")

Expand Down
414 changes: 4 additions & 410 deletions examples/temporal-direct/uv.lock

Large diffs are not rendered by default.

25 changes: 20 additions & 5 deletions src/ai/providers/ai_gateway/client/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,29 @@ def __init__(
self.base_url = base_url
self.api_key = api_key
self.headers = dict(headers or {})
self._http = client or httpx.AsyncClient(
timeout=httpx.Timeout(timeout=300.0, connect=10.0),
)
self._owns_http = client is None
self._http_cached: httpx.AsyncClient | None = client

@property
def _http(self) -> httpx.AsyncClient:
# Constructing httpx.AsyncClient imports httpcore and anyio, which
# executes `threading.local()` at module load -- disallowed inside
# a Temporal workflow sandbox. Defer until first request so that
# constructing a Provider (e.g. via ai.get_model) is safe in a
# workflow; the actual HTTP client only materializes in activities.
if self._http_cached is None:
self._http_cached = httpx.AsyncClient(
timeout=httpx.Timeout(timeout=300.0, connect=10.0),
)
return self._http_cached

async def aclose(self) -> None:
if self._owns_http and not self._http.is_closed:
await self._http.aclose()
if (
self._owns_http
and self._http_cached is not None
and not self._http_cached.is_closed
):
await self._http_cached.aclose()

def url(self, path: str) -> str:
return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
Expand Down
25 changes: 20 additions & 5 deletions src/ai/providers/anthropic/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ def __init__(
self._close_client_on_aclose = (
sdk_client is None and http_client is None
)
if sdk_client is None:
sdk_client = self._make_sdk_client(http_client=http_client)
self._set_client(sdk_client)
self._pending_http_client = http_client
if sdk_client is not None:
self._set_client(sdk_client)

def _make_sdk_client(
self,
Expand All @@ -112,6 +112,21 @@ def _make_sdk_client(
},
)

@property
def client(self) -> AnthropicSDKClient:
# Constructing the Anthropic SDK client imports httpx internals
# (e.g. httpx._models defines a class inheriting from
# urllib.request.Request) -- disallowed inside a Temporal
# workflow sandbox. Defer until first use so that constructing a
# Provider (e.g. via ai.get_model) is safe in a workflow; the
# actual HTTP client only materializes in activities.
if self._client is None:
self._set_client(
self._make_sdk_client(http_client=self._pending_http_client)
)
assert self._client is not None
return self._client

@property
def sdk_client(self) -> AnthropicSDKClient:
"""Provider SDK client used for Anthropic-compatible API requests."""
Expand All @@ -126,8 +141,8 @@ def is_configured(self) -> bool:

async def aclose(self) -> None:
"""Close the provider-owned SDK client, if any."""
if self._close_client_on_aclose:
await self.client.close()
if self._close_client_on_aclose and self._client is not None:
await self._client.close()

def stream(
self,
Expand Down
24 changes: 19 additions & 5 deletions src/ai/providers/openai/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ def __init__(
self._close_client_on_aclose = (
sdk_client is None and http_client is None
)
if sdk_client is None:
sdk_client = self._make_sdk_client(http_client=http_client)
self._set_client(sdk_client)
self._pending_http_client = http_client
if sdk_client is not None:
self._set_client(sdk_client)

def _make_sdk_client(
self,
Expand All @@ -110,6 +110,20 @@ def _make_sdk_client(
http_client=http_client,
)

@property
def client(self) -> OpenAISDKClient:
# Constructing the OpenAI SDK client imports httpx internals that
# trip the Temporal workflow sandbox. Defer until first use so
# that constructing a Provider (e.g. via ai.get_model) is safe in
# a workflow; the actual HTTP client only materializes in
# activities.
if self._client is None:
self._set_client(
self._make_sdk_client(http_client=self._pending_http_client)
)
assert self._client is not None
return self._client

@property
def sdk_client(self) -> OpenAISDKClient:
"""Provider SDK client used for OpenAI-compatible API requests."""
Expand All @@ -124,8 +138,8 @@ def is_configured(self) -> bool:

async def aclose(self) -> None:
"""Close the provider-owned SDK client, if any."""
if self._close_client_on_aclose:
await self.client.close()
if self._close_client_on_aclose and self._client is not None:
await self._client.close()

def stream(
self,
Expand Down
8 changes: 7 additions & 1 deletion tests/providers/anthropic/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,14 @@ def _missing_anthropic(name: str, package: str | None = None) -> object:

monkeypatch.setattr(importlib, "import_module", _missing_anthropic)

# The SDK client is constructed lazily on first use (so that building
# a provider stays import-free, e.g. inside a Temporal workflow
# sandbox), so the missing-SDK error surfaces when the client is
# accessed rather than at construction.
provider = ai.get_provider("anthropic", api_key="sk-test")
assert isinstance(provider, AnthropicCompatibleProvider)
with pytest.raises(ai.InstallationError) as exc_info:
ai.get_provider("anthropic", api_key="sk-test")
_ = provider.sdk_client

assert "could not import `anthropic`" in str(exc_info.value)
assert "required to use the anthropic provider" in str(exc_info.value)
Expand Down
27 changes: 19 additions & 8 deletions tests/providers/openai/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,14 @@ def _missing_openai(name: str, package: str | None = None) -> object:

monkeypatch.setattr(importlib, "import_module", _missing_openai)

# The SDK client is constructed lazily on first use (so that building
# a provider stays import-free, e.g. inside a Temporal workflow
# sandbox), so the missing-SDK error surfaces when the client is
# accessed rather than at construction.
provider = ai.get_provider("openai", api_key="sk-test")
assert isinstance(provider, OpenAICompatibleProvider)
with pytest.raises(ai.InstallationError) as exc_info:
ai.get_provider("openai", api_key="sk-test")
_ = provider.sdk_client

assert "could not import `openai`" in str(exc_info.value)
assert "required to use the openai provider" in str(exc_info.value)
Expand All @@ -188,14 +194,19 @@ def _missing_openai(name: str, package: str | None = None) -> object:

monkeypatch.setattr(importlib, "import_module", _missing_openai)

# The SDK client is constructed lazily on first use, so the
# missing-SDK error surfaces when the client is accessed rather than
# at construction.
provider = ai.get_provider(
"cloudflare-workers-ai",
env={
"CLOUDFLARE_ACCOUNT_ID": "account-123",
"CLOUDFLARE_API_KEY": "sk-test",
},
)
assert isinstance(provider, OpenAICompatibleProvider)
with pytest.raises(ai.InstallationError) as exc_info:
ai.get_provider(
"cloudflare-workers-ai",
env={
"CLOUDFLARE_ACCOUNT_ID": "account-123",
"CLOUDFLARE_API_KEY": "sk-test",
},
)
_ = provider.sdk_client

assert "required to use the cloudflare-workers-ai provider" in str(
exc_info.value
Expand Down
Loading