diff --git a/src/opengradient/agents/__init__.py b/src/opengradient/agents/__init__.py
index a2dd27c..5a15424 100644
--- a/src/opengradient/agents/__init__.py
+++ b/src/opengradient/agents/__init__.py
@@ -6,13 +6,18 @@
 into existing applications and agent frameworks.
 """
 
+from typing import Optional, Union
+
+from ..client import Client
 from ..types import TEE_LLM, x402SettlementMode
 from .og_langchain import *
 
 
 def langchain_adapter(
-    private_key: str,
-    model_cid: TEE_LLM,
+    model_cid: Union[TEE_LLM, str],
+    private_key: Optional[str] = None,
+    client: Optional[Client] = None,
+    temperature: float = 0.0,
     max_tokens: int = 300,
     x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH,
 ) -> OpenGradientChatModel:
@@ -21,8 +26,10 @@ def langchain_adapter(
     and can be plugged into LangChain agents.
     """
     return OpenGradientChatModel(
-        private_key=private_key,
         model_cid=model_cid,
+        private_key=private_key,
+        client=client,
+        temperature=temperature,
         max_tokens=max_tokens,
         x402_settlement_mode=x402_settlement_mode,
     )
diff --git a/src/opengradient/agents/og_langchain.py b/src/opengradient/agents/og_langchain.py
index b62e443..cacd156 100644
--- a/src/opengradient/agents/og_langchain.py
+++ b/src/opengradient/agents/og_langchain.py
@@ -1,12 +1,14 @@
 # mypy: ignore-errors
+import asyncio
 import json
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Sequence, Union
 
-from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
 from langchain_core.language_models.base import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     HumanMessage,
     SystemMessage,
@@ -15,6 +17,7 @@
 from langchain_core.messages.tool import ToolMessage
 from langchain_core.outputs import (
     ChatGeneration,
+    ChatGenerationChunk,
     ChatResult,
 )
 from langchain_core.runnables import Runnable
@@ -25,6 +28,7 @@
 from ..types import TEE_LLM, x402SettlementMode
 
 __all__ = ["OpenGradientChatModel"]
+_STREAM_END = object()
 
 
 def _extract_content(content: Any) -> str:
@@ -69,28 +73,38 @@ def _parse_tool_call(tool_call: Dict) -> ToolCall:
 class OpenGradientChatModel(BaseChatModel):
     """OpenGradient adapter class for LangChain chat model"""
 
-    model_cid: str
+    model_cid: Union[TEE_LLM, str]
+    temperature: float = 0.0
     max_tokens: int = 300
-    x402_settlement_mode: Optional[str] = x402SettlementMode.SETTLE_BATCH
+    x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH
 
     _client: Client = PrivateAttr()
     _tools: List[Dict] = PrivateAttr(default_factory=list)
 
     def __init__(
         self,
-        private_key: str,
-        model_cid: TEE_LLM,
+        model_cid: Union[TEE_LLM, str],
+        private_key: Optional[str] = None,
+        client: Optional[Client] = None,
+        temperature: float = 0.0,
         max_tokens: int = 300,
-        x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH,
+        x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH,
        **kwargs,
     ):
         super().__init__(
             model_cid=model_cid,
+            temperature=temperature,
             max_tokens=max_tokens,
             x402_settlement_mode=x402_settlement_mode,
             **kwargs,
         )
-        self._client = Client(private_key=private_key)
+        if client is not None and private_key is not None:
+            raise ValueError("Pass either client or private_key, not both.")
+        if client is None:
+            if private_key is None:
+                raise ValueError("Either client or private_key must be provided.")
+            client = Client(private_key=private_key)
+        self._client = client
 
     @property
     def _llm_type(self) -> str:
@@ -136,7 +150,103 @@ def _generate(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        sdk_messages = []
+        sdk_messages = self._to_sdk_messages(messages)
+
+        chat_output = self._client.llm.chat(
+            model=self.model_cid,
+            messages=sdk_messages,
+            stop_sequence=stop,
+            temperature=self.temperature,
+            max_tokens=self.max_tokens,
+            tools=self._tools,
+            x402_settlement_mode=self.x402_settlement_mode,
+        )
+
+        finish_reason = chat_output.finish_reason or ""
+        chat_response = chat_output.chat_output or {}
+
+        if chat_response.get("tool_calls"):
+            tool_calls = [_parse_tool_call(tc) for tc in chat_response["tool_calls"]]
+            ai_message = AIMessage(content="", tool_calls=tool_calls)
+        else:
+            ai_message = AIMessage(content=_extract_content(chat_response.get("content", "")))
+
+        return ChatResult(generations=[ChatGeneration(message=ai_message, generation_info={"finish_reason": finish_reason})])
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        sdk_messages = self._to_sdk_messages(messages)
+        stream = self._client.llm.chat(
+            model=self.model_cid,
+            messages=sdk_messages,
+            stop_sequence=stop,
+            temperature=self.temperature,
+            max_tokens=self.max_tokens,
+            tools=self._tools,
+            x402_settlement_mode=self.x402_settlement_mode,
+            stream=True,
+        )
+
+        for chunk in stream:
+            if not chunk.choices:
+                continue
+
+            choice = chunk.choices[0]
+            delta = choice.delta
+            content = _extract_content(delta.content)
+
+            additional_kwargs: Dict[str, Any] = {}
+            if delta.tool_calls:
+                additional_kwargs["tool_calls"] = delta.tool_calls
+
+            chunk_kwargs: Dict[str, Any] = {
+                "content": content,
+                "additional_kwargs": additional_kwargs,
+            }
+            if chunk.usage:
+                chunk_kwargs["usage_metadata"] = {
+                    "input_tokens": chunk.usage.prompt_tokens,
+                    "output_tokens": chunk.usage.completion_tokens,
+                    "total_tokens": chunk.usage.total_tokens,
+                }
+
+            generation_info = {"finish_reason": choice.finish_reason} if choice.finish_reason else None
+            yield ChatGenerationChunk(
+                message=AIMessageChunk(**chunk_kwargs),
+                generation_info=generation_info,
+            )
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        # Bridge the sync iterator from the SDK to LangChain's async streaming API.
+        iterator = self._stream(messages=messages, stop=stop, **kwargs)
+        while True:
+            # Use next(..., default) so StopIteration does not cross Future boundaries.
+            chunk = await asyncio.to_thread(next, iterator, _STREAM_END)
+            if chunk is _STREAM_END:
+                break
+            yield chunk
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        return {
+            "model_name": self.model_cid,
+            "temperature": self.temperature,
+        }
+
+    @staticmethod
+    def _to_sdk_messages(messages: List[Any]) -> List[Dict[str, Any]]:
+        sdk_messages: List[Dict[str, Any]] = []
         for message in messages:
             if isinstance(message, SystemMessage):
                 sdk_messages.append({"role": "system", "content": _extract_content(message.content)})
@@ -162,31 +272,21 @@ def _generate(
                         "tool_call_id": message.tool_call_id,
                     }
                 )
-            else:
-                raise ValueError(f"Unexpected message type: {message}")
+            elif isinstance(message, dict):
+                role = message.get("role")
+                if role not in {"system", "user", "assistant", "tool"}:
+                    raise ValueError(f"Unexpected message role in dict message: {role}")
+
-
-        chat_output = self._client.llm.chat(
-            model=self.model_cid,
-            messages=sdk_messages,
-            stop_sequence=stop,
-            max_tokens=self.max_tokens,
-            tools=self._tools,
-            x402_settlement_mode=self.x402_settlement_mode,
-        )
-
-        finish_reason = chat_output.finish_reason or ""
-        chat_response = chat_output.chat_output or {}
-
-        if chat_response.get("tool_calls"):
-            tool_calls = [_parse_tool_call(tc) for tc in chat_response["tool_calls"]]
-            ai_message = AIMessage(content="", tool_calls=tool_calls)
-        else:
-            ai_message = AIMessage(content=_extract_content(chat_response.get("content", "")))
+                sdk_message: Dict[str, Any] = {
+                    "role": role,
+                    "content": _extract_content(message.get("content", "")),
+                }
+                if role == "assistant" and message.get("tool_calls"):
+                    sdk_message["tool_calls"] = message["tool_calls"]
+                if role == "tool" and message.get("tool_call_id"):
+                    sdk_message["tool_call_id"] = message["tool_call_id"]
+
-
-        return ChatResult(generations=[ChatGeneration(message=ai_message, generation_info={"finish_reason": finish_reason})])
-
-    @property
-    def _identifying_params(self) -> Dict[str, Any]:
-        return {
-            "model_name": self.model_cid,
-        }
+                sdk_messages.append(sdk_message)
+            else:
+                raise ValueError(f"Unexpected message type: {message}")
+        return sdk_messages
diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py
index 83f8eb8..e628a4b 100644
--- a/src/opengradient/client/llm.py
+++ b/src/opengradient/client/llm.py
@@ -25,6 +25,10 @@
 X402_PROCESSING_HASH_HEADER = "x-processing-hash"
 X402_PLACEHOLDER_API_KEY = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"
 BASE_TESTNET_NETWORK = "eip155:84532"
+RETRYABLE_PAYMENT_ERROR_SNIPPETS = (
+    "invalid payment required response",
+    "failed to handle payment",
+)
 
 TIMEOUT = httpx.Timeout(
     timeout=90.0,
@@ -122,10 +126,7 @@ def __init__(self, wallet_account: LocalAccount, og_llm_server_url: str, og_llm_
             _fetch_tls_cert_as_ssl_context(self._og_llm_streaming_server_url) or True
         )
 
-        signer = EthAccountSignerv2(self._wallet_account)
-        self._x402_client = x402Clientv2()
-        register_exact_evm_clientv2(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK])
-        register_upto_evm_clientv2(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK])
+        self._initialize_x402_client()
 
         self._request_client_ctx = None
         self._request_client = None
@@ -156,6 +157,24 @@ async def _initialize_http_clients(self) -> None:
         self._stream_client_ctx = x402HttpxClientv2(self._x402_client, verify=self._streaming_tls_verify)
         self._stream_client = await self._stream_client_ctx.__aenter__()
 
+    def _initialize_x402_client(self) -> None:
+        signer = EthAccountSignerv2(self._wallet_account)
+        self._x402_client = x402Clientv2()
+        register_exact_evm_clientv2(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK])
+        register_upto_evm_clientv2(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK])
+
+    @staticmethod
+    def _is_retryable_payment_error(error: Exception) -> bool:
+        message = str(error).lower()
+        if any(snippet in message for snippet in RETRYABLE_PAYMENT_ERROR_SNIPPETS):
+            return True
+        return "paymenterror" in error.__class__.__name__.lower()
+
+    async def _refresh_payment_clients(self) -> None:
+        await self._close_http_clients()
+        self._initialize_x402_client()
+        await self._initialize_http_clients()
+
     async def _close_http_clients(self) -> None:
         if self._request_client_ctx is not None:
             await self._request_client_ctx.__aexit__(None, None, None)
@@ -205,7 +224,7 @@ def completion(
         max_tokens: int = 100,
         stop_sequence: Optional[List[str]] = None,
         temperature: float = 0.0,
-        x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH,
+        x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH,
     ) -> TextGenerationOutput:
         """
         Perform inference on an LLM model using completions via TEE.
@@ -247,7 +266,7 @@ def _tee_llm_completion(
         max_tokens: int = 100,
         stop_sequence: Optional[List[str]] = None,
         temperature: float = 0.0,
-        x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH,
+        x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH,
     ) -> TextGenerationOutput:
         """
         Route completion request to OpenGradient TEE LLM server with x402 payments.
@@ -290,9 +309,12 @@ async def make_request_v2():
 
         try:
             return self._run_coroutine(make_request_v2())
-        except OpenGradientError:
-            raise
         except Exception as e:
+            if self._is_retryable_payment_error(e):
+                self._run_coroutine(self._refresh_payment_clients())
+                return self._run_coroutine(make_request_v2())
+            if isinstance(e, OpenGradientError):
+                raise
             raise OpenGradientError(f"TEE LLM completion failed: {str(e)}")
 
@@ -304,7 +326,7 @@ def chat(
         self,
         model: Union[TEE_LLM, str],
         messages: List[Dict],
         max_tokens: int = 100,
         temperature: float = 0.0,
         tools: Optional[List[Dict]] = None,
         tool_choice: Optional[str] = None,
-        x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH,
+        x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH,
         stream: bool = False,
     ) -> Union[TextGenerationOutput, TextGenerationStream]:
         """
@@ -429,9 +451,12 @@ async def make_request_v2():
 
         try:
             return self._run_coroutine(make_request_v2())
-        except OpenGradientError:
-            raise
         except Exception as e:
+            if self._is_retryable_payment_error(e):
+                self._run_coroutine(self._refresh_payment_clients())
+                return self._run_coroutine(make_request_v2())
+            if isinstance(e, OpenGradientError):
+                raise
             raise OpenGradientError(f"TEE LLM chat failed: {str(e)}")
 
     def _tee_llm_chat_stream_sync(
@@ -456,17 +481,33 @@ def _tee_llm_chat_stream_sync(
 
         async def _stream():
             try:
-                async for chunk in self._tee_llm_chat_stream_async(
-                    model=model,
-                    messages=messages,
-                    max_tokens=max_tokens,
-                    stop_sequence=stop_sequence,
-                    temperature=temperature,
-                    tools=tools,
-                    tool_choice=tool_choice,
-                    x402_settlement_mode=x402_settlement_mode,
-                ):
-                    queue.put(chunk)
+                try:
+                    async for chunk in self._tee_llm_chat_stream_async(
+                        model=model,
+                        messages=messages,
+                        max_tokens=max_tokens,
+                        stop_sequence=stop_sequence,
+                        temperature=temperature,
+                        tools=tools,
+                        tool_choice=tool_choice,
+                        x402_settlement_mode=x402_settlement_mode,
+                    ):
+                        queue.put(chunk)
+                except Exception as e:
+                    if not self._is_retryable_payment_error(e):
+                        raise
+                    await self._refresh_payment_clients()
+                    async for chunk in self._tee_llm_chat_stream_async(
+                        model=model,
+                        messages=messages,
+                        max_tokens=max_tokens,
+                        stop_sequence=stop_sequence,
+                        temperature=temperature,
+                        tools=tools,
+                        tool_choice=tool_choice,
+                        x402_settlement_mode=x402_settlement_mode,
+                    ):
+                        queue.put(chunk)
             except Exception as e:
                 queue.put(e)
             finally:
diff --git a/tests/langchain_adapter_test.py b/tests/langchain_adapter_test.py
index 1671c7f..d8ec348 100644
--- a/tests/langchain_adapter_test.py
+++ b/tests/langchain_adapter_test.py
@@ -1,6 +1,7 @@
 import json
 import os
 import sys
+import asyncio
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -11,7 +12,7 @@
 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
 
 from src.opengradient.agents.og_langchain import OpenGradientChatModel, _extract_content, _parse_tool_call
-from src.opengradient.types import TEE_LLM, TextGenerationOutput, x402SettlementMode
+from src.opengradient.types import StreamChoice, StreamChunk, StreamDelta, StreamUsage, TEE_LLM, TextGenerationOutput, x402SettlementMode
 
 
 @pytest.fixture
 def mock_client():
@@ -26,7 +27,7 @@ def mock_client():
 @pytest.fixture
 def model(mock_client):
     """Create an OpenGradientChatModel with a mocked client."""
-    return OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid=TEE_LLM.GPT_5)
+    return OpenGradientChatModel(model_cid=TEE_LLM.GPT_5, private_key="0x" + "a" * 64)
 
 
 class TestOpenGradientChatModel:
@@ -39,21 +40,41 @@ def test_initialization(self, model):
 
     def test_initialization_custom_max_tokens(self, mock_client):
         """Test model initializes with custom max_tokens."""
-        model = OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid=TEE_LLM.CLAUDE_HAIKU_4_5, max_tokens=1000)
+        model = OpenGradientChatModel(model_cid=TEE_LLM.CLAUDE_HAIKU_4_5, private_key="0x" + "a" * 64, max_tokens=1000)
         assert model.max_tokens == 1000
 
     def test_initialization_custom_settlement_mode(self, mock_client):
         """Test model initializes with custom settlement mode."""
         model = OpenGradientChatModel(
-            private_key="0x" + "a" * 64,
             model_cid=TEE_LLM.GPT_5,
+            private_key="0x" + "a" * 64,
             x402_settlement_mode=x402SettlementMode.SETTLE,
         )
         assert model.x402_settlement_mode == x402SettlementMode.SETTLE
 
+    def test_initialization_with_injected_client(self, mock_client):
+        """Test model can reuse an injected SDK client."""
+        model = OpenGradientChatModel(model_cid="openai/gpt-4.1", client=mock_client)
+        assert model.model_cid == "openai/gpt-4.1"
+        assert model._client is mock_client
+
+    def test_initialization_requires_client_or_private_key(self):
+        """Test model requires either a private key or SDK client."""
+        with pytest.raises(ValueError, match="Either client or private_key must be provided."):
+            OpenGradientChatModel(model_cid=TEE_LLM.GPT_5)
+
+    def test_initialization_rejects_client_and_private_key(self, mock_client):
+        """Test model rejects duplicate client configuration."""
+        with pytest.raises(ValueError, match="Pass either client or private_key, not both."):
+            OpenGradientChatModel(
+                model_cid=TEE_LLM.GPT_5,
+                private_key="0x" + "a" * 64,
+                client=mock_client,
+            )
+
     def test_identifying_params(self, model):
         """Test _identifying_params returns model name."""
-        assert model._identifying_params == {"model_name": TEE_LLM.GPT_5}
+        assert model._identifying_params == {"model_name": TEE_LLM.GPT_5, "temperature": 0.0}
 
 
 class TestGenerate:
@@ -155,6 +176,91 @@ def test_empty_chat_output(self, model, mock_client):
 
         assert result.generations[0].message.content == ""
 
+    def test_stream_response(self, model, mock_client):
+        """Test _stream yields incremental chunks and usage metadata."""
+        stream_chunks = iter(
+            [
+                StreamChunk(
+                    choices=[
+                        StreamChoice(
+                            delta=StreamDelta(content="Hello"),
+                            index=0,
+                            finish_reason=None,
+                        )
+                    ],
+                    model="openai/gpt-5",
+                ),
+                StreamChunk(
+                    choices=[
+                        StreamChoice(
+                            delta=StreamDelta(content=" world"),
+                            index=0,
+                            finish_reason="stop",
+                        )
+                    ],
+                    model="openai/gpt-5",
+                    usage=StreamUsage(prompt_tokens=10, completion_tokens=2, total_tokens=12),
+                    is_final=True,
+                ),
+            ]
+        )
+        mock_client.llm.chat.return_value = stream_chunks
+
+        generations = list(model._stream([HumanMessage(content="Hi")]))
+
+        assert len(generations) == 2
+        assert generations[0].message.content == "Hello"
+        assert generations[1].message.content == " world"
+        assert generations[1].generation_info == {"finish_reason": "stop"}
+        assert generations[1].message.usage_metadata == {
+            "input_tokens": 10,
+            "output_tokens": 2,
+            "total_tokens": 12,
+        }
+        assert mock_client.llm.chat.call_args.kwargs["stream"] is True
+
+    def test_astream_response(self, model, mock_client):
+        """Test _astream yields incremental chunks via async interface."""
+        stream_chunks = iter(
+            [
+                StreamChunk(
+                    choices=[
+                        StreamChoice(
+                            delta=StreamDelta(content="Hello"),
+                            index=0,
+                            finish_reason=None,
+                        )
+                    ],
+                    model="openai/gpt-5",
+                ),
+                StreamChunk(
+                    choices=[
+                        StreamChoice(
+                            delta=StreamDelta(content=" world"),
+                            index=0,
+                            finish_reason="stop",
+                        )
+                    ],
+                    model="openai/gpt-5",
+                    is_final=True,
+                ),
+            ]
+        )
+        mock_client.llm.chat.return_value = stream_chunks
+
+        async def collect_generations():
+            generations = []
+            async for generation in model._astream([HumanMessage(content="Hi")]):
+                generations.append(generation)
+            return generations
+
+        generations = asyncio.run(collect_generations())
+
+        assert len(generations) == 2
+        assert generations[0].message.content == "Hello"
+        assert generations[1].message.content == " world"
+        assert generations[1].generation_info == {"finish_reason": "stop"}
+
 
 class TestMessageConversion:
     def test_converts_all_message_types(self, model, mock_client):
@@ -199,6 +305,20 @@ def test_unsupported_message_type_raises(self, model, mock_client):
         with pytest.raises(ValueError, match="Unexpected message type"):
             model._generate([MagicMock(spec=[])])
 
+    def test_accepts_dict_messages(self, model, mock_client):
+        """Test that dict messages are accepted for compatibility with existing routes."""
+        mock_client.llm.chat.return_value = TextGenerationOutput(
+            transaction_hash="external",
+            finish_reason="stop",
+            chat_output={"role": "assistant", "content": "ok"},
+        )
+
+        model._generate([{"role": "user", "content": "Hi"}])
+
+        assert mock_client.llm.chat.call_args.kwargs["messages"] == [
+            {"role": "user", "content": "Hi"}
+        ]
+
     def test_passes_correct_params_to_client(self, model, mock_client):
         """Test that _generate passes model params correctly to the SDK client."""
         mock_client.llm.chat.return_value = TextGenerationOutput(
@@ -213,6 +333,7 @@ def test_passes_correct_params_to_client(self, model, mock_client):
             model=TEE_LLM.GPT_5,
             messages=[{"role": "user", "content": "Hi"}],
             stop_sequence=["END"],
+            temperature=0.0,
             max_tokens=300,
             tools=[],
             x402_settlement_mode=x402SettlementMode.SETTLE_BATCH,
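
Usage sketch (not part of the patch): after this change the LangChain adapter accepts either a private key or an injected SDK client, and LangChain's streaming entry points work end to end. The import paths and the Client constructor below are assumptions inferred from the diff, not verified public API.

    from opengradient.agents import langchain_adapter
    from opengradient.client import Client  # assumed path, mirroring `from ..client import Client` above
    from opengradient.types import TEE_LLM

    # Option 1: the adapter builds its own SDK client from a key.
    chat_model = langchain_adapter(model_cid=TEE_LLM.GPT_5, private_key="0x...")  # placeholder key

    # Option 2 (new): reuse an existing SDK client. Passing both arguments raises ValueError.
    client = Client(private_key="0x...")  # placeholder key
    chat_model = langchain_adapter(model_cid="openai/gpt-4.1", client=client)

    # .stream() now yields message chunks via OpenGradientChatModel._stream.
    for chunk in chat_model.stream("Say hello"):
        print(chunk.content, end="", flush=True)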
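The _astream bridge deserves a note: a StopIteration raised by next() inside the worker thread would surface as a RuntimeError when it crosses the Future boundary, which is why the adapter passes the _STREAM_END sentinel as next()'s default instead of catching the exception. A minimal standalone sketch of the same pattern, standard library only and hypothetical names:

    import asyncio

    _SENTINEL = object()

    def sync_chunks():
        # Stand-in for the SDK's blocking stream iterator.
        yield from ["Hello", " ", "world"]

    async def bridge():
        iterator = iter(sync_chunks())
        while True:
            # next(iterator, _SENTINEL) returns the sentinel at exhaustion
            # instead of raising, so no exception crosses the thread boundary.
            item = await asyncio.to_thread(next, iterator, _SENTINEL)
            if item is _SENTINEL:
                break
            print(item, end="")

    asyncio.run(bridge())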
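The llm.py changes all follow one shape: classify the failure, rebuild the x402 payment clients, then retry the request exactly once. Reduced to a sketch with hypothetical names (the real code routes through _run_coroutine and the registered x402 signers):

    RETRYABLE_SNIPPETS = ("invalid payment required response", "failed to handle payment")

    def is_retryable(error: Exception) -> bool:
        # Match on the message text, or on an exception type named like "...PaymentError".
        message = str(error).lower()
        return any(s in message for s in RETRYABLE_SNIPPETS) or "paymenterror" in type(error).__name__.lower()

    def call_with_payment_retry(make_request, refresh_clients):
        try:
            return make_request()
        except Exception as error:
            if not is_retryable(error):
                raise
            refresh_clients()  # tear down and rebuild the HTTP + x402 clients
            return make_request()  # retried once; a second failure propagates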