diff --git a/Makefile b/Makefile index acc46ee..d4fabf8 100644 --- a/Makefile +++ b/Makefile @@ -31,7 +31,7 @@ docs: # Testing # ============================================================================ -test: utils_test client_test langchain_adapter_test opg_token_test +test: utils_test client_test langchain_adapter_test opg_token_test tee_registry_test utils_test: pytest tests/utils_test.py -v @@ -45,6 +45,9 @@ langchain_adapter_test: opg_token_test: pytest tests/opg_token_test.py -v +tee_registry_test: + pytest tests/tee_registry_test.py -v + integrationtest: python integrationtest/agent/test_agent.py python integrationtest/workflow_models/test_workflow_models.py @@ -87,16 +90,16 @@ chat-stream: chat-tool: python -m opengradient.cli chat \ --model $(MODEL) \ - --messages '[{"role":"user","content":"What is the weather in Tokyo?"}]' \ - --tools '[{"type":"function","function":{"name":"get_weather","description":"Get weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}]' \ - --max-tokens 100 + --messages '[{"role":"system","content":"You are a helpful assistant. Use tools when needed."},{"role":"user","content":"What'\''s the weather like in Dallas, Texas? Give me the temperature in fahrenheit."}]' \ + --tools '[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"city":{"type":"string"},"state":{"type":"string"},"unit":{"type":"string","enum":["fahrenheit","celsius"]}},"required":["city","state","unit"]}}}]' \ + --max-tokens 200 chat-stream-tool: python -m opengradient.cli chat \ --model $(MODEL) \ - --messages '[{"role":"user","content":"What is the weather in Tokyo?"}]' \ - --tools '[{"type":"function","function":{"name":"get_weather","description":"Get weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}]' \ - --max-tokens 100 \ + --messages '[{"role":"system","content":"You are a helpful assistant. Use tools when needed."},{"role":"user","content":"What'\''s the weather like in Dallas, Texas? Give me the temperature in fahrenheit."}]' \ + --tools '[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"city":{"type":"string"},"state":{"type":"string"},"unit":{"type":"string","enum":["fahrenheit","celsius"]}},"required":["city","state","unit"]}}}]' \ + --max-tokens 200 \ --stream .PHONY: install build publish check docs test utils_test client_test langchain_adapter_test opg_token_test integrationtest examples \ diff --git a/pyproject.toml b/pyproject.toml index 262637b..fd0428d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "opengradient" -version = "0.7.5" +version = "0.7.6" description = "Python SDK for OpenGradient decentralized model management & inference services" authors = [{name = "OpenGradient", email = "adam@vannalabs.ai"}] readme = "README.md" diff --git a/src/opengradient/abi/TEERegistry.abi b/src/opengradient/abi/TEERegistry.abi new file mode 100644 index 0000000..51fe5b3 --- /dev/null +++ b/src/opengradient/abi/TEERegistry.abi @@ -0,0 +1,80 @@ +[ + { + "inputs": [{"internalType": "uint8", "name": "teeType", "type": "uint8"}], + "name": "getActiveTEEs", + "outputs": [ + { + "components": [ + {"internalType": "address", "name": "owner", "type": "address"}, + {"internalType": "address", "name": "paymentAddress", "type": "address"}, + {"internalType": "string", "name": "endpoint", "type": "string"}, + {"internalType": "bytes", "name": "publicKey", "type": "bytes"}, + {"internalType": "bytes", "name": "tlsCertificate", "type": "bytes"}, + {"internalType": "bytes32", "name": "pcrHash", "type": "bytes32"}, + {"internalType": "uint8", "name": "teeType", "type": "uint8"}, + {"internalType": "bool", "name": "enabled", "type": "bool"}, + {"internalType": "uint256", "name": "registeredAt", "type": "uint256"}, + {"internalType": "uint256", "name": "lastHeartbeatAt", "type": "uint256"} + ], + "internalType": "struct TEERegistry.TEEInfo[]", + "name": "", + "type": "tuple[]" + } + ], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [{"internalType": "uint8", "name": "teeType", "type": "uint8"}], + "name": "getEnabledTEEs", + "outputs": [{"internalType": "bytes32[]", "name": "", "type": "bytes32[]"}], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [{"internalType": "uint8", "name": "teeType", "type": "uint8"}], + "name": "getTEEsByType", + "outputs": [{"internalType": "bytes32[]", "name": "", "type": "bytes32[]"}], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [{"internalType": "bytes32", "name": "teeId", "type": "bytes32"}], + "name": "getTEE", + "outputs": [ + { + "components": [ + {"internalType": "address", "name": "owner", "type": "address"}, + {"internalType": "address", "name": "paymentAddress", "type": "address"}, + {"internalType": "string", "name": "endpoint", "type": "string"}, + {"internalType": "bytes", "name": "publicKey", "type": "bytes"}, + {"internalType": "bytes", "name": "tlsCertificate", "type": "bytes"}, + {"internalType": "bytes32", "name": "pcrHash", "type": "bytes32"}, + {"internalType": "uint8", "name": "teeType", "type": "uint8"}, + {"internalType": "bool", "name": "enabled", "type": "bool"}, + {"internalType": "uint256", "name": "registeredAt", "type": "uint256"}, + {"internalType": "uint256", "name": "lastHeartbeatAt", "type": "uint256"} + ], + "internalType": "struct TEERegistry.TEEInfo", + "name": "", + "type": "tuple" + } + ], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [{"internalType": "bytes32", "name": "teeId", "type": "bytes32"}], + "name": "isTEEActive", + "outputs": [{"internalType": "bool", "name": "", "type": "bool"}], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [{"internalType": "bytes32", "name": "teeId", "type": "bytes32"}], + "name": "isTEEEnabled", + "outputs": [{"internalType": "bool", "name": "", "type": "bool"}], + "stateMutability": "view", + "type": "function" + } +] diff --git a/src/opengradient/cli.py b/src/opengradient/cli.py index ef46d33..cfca61b 100644 --- a/src/opengradient/cli.py +++ b/src/opengradient/cli.py @@ -413,13 +413,31 @@ def completion( x402_settlement_mode=x402SettlementModes[x402_settlement_mode], ) - print_llm_completion_result(model_cid, completion_output.transaction_hash, completion_output.completion_output, is_vanilla=False) + print_llm_completion_result( + model_cid, completion_output.transaction_hash, completion_output.completion_output, is_vanilla=False, result=completion_output + ) except Exception as e: click.echo(f"Error running LLM completion: {str(e)}") -def print_llm_completion_result(model_cid, tx_hash, llm_output, is_vanilla=True): +def _print_tee_info(tee_id, tee_endpoint, tee_payment_address): + """Print TEE node info if available.""" + if not any([tee_id, tee_endpoint, tee_payment_address]): + return + click.secho("TEE Node:", fg="magenta", bold=True) + if tee_endpoint: + click.echo(" Endpoint: ", nl=False) + click.secho(tee_endpoint, fg="magenta") + if tee_id: + click.echo(" TEE ID: ", nl=False) + click.secho(tee_id, fg="magenta") + if tee_payment_address: + click.echo(" Payment address: ", nl=False) + click.secho(tee_payment_address, fg="magenta") + + +def print_llm_completion_result(model_cid, tx_hash, llm_output, is_vanilla=True, result=None): click.secho("✅ LLM completion Successful", fg="green", bold=True) click.echo("──────────────────────────────────────") click.echo("Model: ", nl=False) @@ -435,6 +453,9 @@ def print_llm_completion_result(model_cid, tx_hash, llm_output, is_vanilla=True) click.echo("Source: ", nl=False) click.secho("OpenGradient TEE", fg="cyan", bold=True) + if result is not None: + _print_tee_info(result.tee_id, result.tee_endpoint, result.tee_payment_address) + click.echo("──────────────────────────────────────") click.secho("LLM Output:", fg="yellow", bold=True) click.echo() @@ -578,13 +599,15 @@ def chat( if stream: print_streaming_chat_result(model_cid, result, is_tee=True) else: - print_llm_chat_result(model_cid, result.transaction_hash, result.finish_reason, result.chat_output, is_vanilla=False) + print_llm_chat_result( + model_cid, result.transaction_hash, result.finish_reason, result.chat_output, is_vanilla=False, result=result + ) except Exception as e: click.echo(f"Error running LLM chat inference: {str(e)}") -def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output, is_vanilla=True): +def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output, is_vanilla=True, result=None): click.secho("✅ LLM Chat Successful", fg="green", bold=True) click.echo("──────────────────────────────────────") click.echo("Model: ", nl=False) @@ -600,6 +623,9 @@ def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output, is_van click.echo("Source: ", nl=False) click.secho("OpenGradient TEE", fg="cyan", bold=True) + if result is not None: + _print_tee_info(result.tee_id, result.tee_endpoint, result.tee_payment_address) + click.echo("──────────────────────────────────────") click.secho("Finish Reason: ", fg="yellow", bold=True) click.echo() @@ -608,7 +634,16 @@ def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output, is_van click.secho("Chat Output:", fg="yellow", bold=True) click.echo() for key, value in chat_output.items(): - if value is not None and value not in ("", "[]", []): + if value is None or value in ("", "[]", []): + continue + if key == "tool_calls": + # Format tool calls the same way as the streaming path + click.secho("Tool Calls:", fg="yellow", bold=True) + for tool_call in value: + fn = tool_call.get("function", {}) + click.echo(f" Function: {fn.get('name', '')}") + click.echo(f" Arguments: {fn.get('arguments', '')}") + elif key == "content" and isinstance(value, list): # Normalize list-of-blocks content (e.g. Gemini 3 thought signatures) if key == "content" and isinstance(value, list): text = " ".join(block.get("text", "") for block in value if isinstance(block, dict) and block.get("type") == "text").strip() @@ -638,20 +673,21 @@ def print_streaming_chat_result(model_cid, stream, is_tee=True): for chunk in stream: chunk_count += 1 - if chunk.choices[0].delta.content: - content = chunk.choices[0].delta.content - sys.stdout.write(content) - sys.stdout.flush() - content_parts.append(content) - - # Handle tool calls - if chunk.choices[0].delta.tool_calls: - sys.stdout.write("\n") - sys.stdout.flush() - click.secho("Tool Calls:", fg="yellow", bold=True) - for tool_call in chunk.choices[0].delta.tool_calls: - click.echo(f" Function: {tool_call['function']['name']}") - click.echo(f" Arguments: {tool_call['function']['arguments']}") + if chunk.choices: + if chunk.choices[0].delta.content: + content = chunk.choices[0].delta.content + sys.stdout.write(content) + sys.stdout.flush() + content_parts.append(content) + + # Handle tool calls + if chunk.choices[0].delta.tool_calls: + sys.stdout.write("\n") + sys.stdout.flush() + click.secho("Tool Calls:", fg="yellow", bold=True) + for tool_call in chunk.choices[0].delta.tool_calls: + click.echo(f" Function: {tool_call['function']['name']}") + click.echo(f" Arguments: {tool_call['function']['arguments']}") # Print final info when stream completes if chunk.is_final: @@ -666,10 +702,12 @@ def print_streaming_chat_result(model_cid, stream, is_tee=True): click.echo(f" Total tokens: {chunk.usage.total_tokens}") click.echo() - if chunk.choices[0].finish_reason: + if chunk.choices and chunk.choices[0].finish_reason: click.echo("Finish reason: ", nl=False) click.secho(chunk.choices[0].finish_reason, fg="green") + _print_tee_info(chunk.tee_id, chunk.tee_endpoint, chunk.tee_payment_address) + click.echo("──────────────────────────────────────") click.echo(f"Chunks received: {chunk_count}") click.echo(f"Content length: {len(''.join(content_parts))} characters") diff --git a/src/opengradient/client/client.py b/src/opengradient/client/client.py index 2caef8b..08dc9ac 100644 --- a/src/opengradient/client/client.py +++ b/src/opengradient/client/client.py @@ -1,5 +1,6 @@ """Main Client class that unifies all OpenGradient service namespaces.""" +import logging from typing import Optional from web3 import Web3 @@ -7,15 +8,17 @@ from ..defaults import ( DEFAULT_API_URL, DEFAULT_INFERENCE_CONTRACT_ADDRESS, - DEFAULT_OPENGRADIENT_LLM_SERVER_URL, - DEFAULT_OPENGRADIENT_LLM_STREAMING_SERVER_URL, DEFAULT_RPC_URL, + DEFAULT_TEE_REGISTRY_ADDRESS, ) from .alpha import Alpha from .llm import LLM from .model_hub import ModelHub +from .tee_registry import TEERegistry from .twins import Twins +logger = logging.getLogger(__name__) + class Client: """ @@ -62,8 +65,8 @@ def __init__( rpc_url: str = DEFAULT_RPC_URL, api_url: str = DEFAULT_API_URL, contract_address: str = DEFAULT_INFERENCE_CONTRACT_ADDRESS, - og_llm_server_url: Optional[str] = DEFAULT_OPENGRADIENT_LLM_SERVER_URL, - og_llm_streaming_server_url: Optional[str] = DEFAULT_OPENGRADIENT_LLM_STREAMING_SERVER_URL, + og_llm_server_url: Optional[str] = None, + tee_registry_address: str = DEFAULT_TEE_REGISTRY_ADDRESS, ): """ Initialize the OpenGradient client. @@ -74,6 +77,11 @@ def __init__( You can supply a separate ``alpha_private_key`` so each chain uses its own funded wallet. When omitted, ``private_key`` is used for both. + By default the LLM server endpoint and its TLS certificate are fetched from + the on-chain TEE Registry, which stores certificates that were verified during + enclave attestation. You can override the endpoint by passing + ``og_llm_server_url`` explicitly (the system CA bundle is used for that URL). + Args: private_key: Private key whose wallet holds **Base Sepolia OPG tokens** for x402 LLM payments. @@ -86,8 +94,11 @@ def __init__( rpc_url: RPC URL for the OpenGradient Alpha Testnet. api_url: API URL for the OpenGradient API. contract_address: Inference contract address. - og_llm_server_url: OpenGradient LLM server URL. - og_llm_streaming_server_url: OpenGradient LLM streaming server URL. + og_llm_server_url: Override the LLM server URL instead of using the + registry-discovered endpoint. When set, the TLS certificate is + validated against the system CA bundle rather than the registry. + tee_registry_address: Address of the TEERegistry contract used to + discover active LLM proxy endpoints and their verified TLS certs. """ blockchain = Web3(Web3.HTTPProvider(rpc_url)) wallet_account = blockchain.eth.account.from_key(private_key) @@ -102,6 +113,32 @@ def __init__( if email is not None: hub_user = ModelHub._login_to_hub(email, password) + # Resolve LLM server URL and TLS certificate. + # If the caller provided explicit URLs, use those with standard CA verification. + # Otherwise, discover the endpoint and registry-verified cert from the TEE Registry. + llm_tls_cert_der: Optional[bytes] = None + tee = None + if og_llm_server_url is None: + try: + registry = TEERegistry( + rpc_url=rpc_url, + registry_address=tee_registry_address, + ) + tee = registry.get_llm_tee() + if tee is not None: + og_llm_server_url = tee.endpoint + llm_tls_cert_der = tee.tls_cert_der + logger.info("Using TEE endpoint from registry: %s (teeId=%s)", tee.endpoint, tee.tee_id) + else: + raise ValueError("No active LLM proxy TEE found in the registry. Pass og_llm_server_url explicitly to override.") + except ValueError: + raise + except Exception as e: + raise RuntimeError( + f"Failed to fetch LLM TEE endpoint from registry ({tee_registry_address} on {rpc_url}): {e}. " + "Pass og_llm_server_url explicitly to override." + ) from e + # Create namespaces self.model_hub = ModelHub(hub_user=hub_user) self.wallet_address = wallet_account.address @@ -109,7 +146,9 @@ def __init__( self.llm = LLM( wallet_account=wallet_account, og_llm_server_url=og_llm_server_url, - og_llm_streaming_server_url=og_llm_streaming_server_url, + tls_cert_der=llm_tls_cert_der, + tee_id=tee.tee_id if tee is not None else None, + tee_payment_address=tee.payment_address if tee is not None else None, ) self.alpha = Alpha( diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index 5e1176a..84868cd 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -2,13 +2,10 @@ import asyncio import json +import ssl import threading from queue import Queue from typing import AsyncGenerator, Dict, List, Optional, Union -import ssl -import socket -import tempfile -from urllib.parse import urlparse import httpx from eth_account.account import LocalAccount @@ -18,9 +15,10 @@ from x402v2.mechanisms.evm.exact.register import register_exact_evm_client as register_exact_evm_clientv2 from x402v2.mechanisms.evm.upto.register import register_upto_evm_client as register_upto_evm_clientv2 -from ..types import TEE_LLM, StreamChunk, TextGenerationOutput, TextGenerationStream, x402SettlementMode +from ..types import TEE_LLM, StreamChunk, StreamChoice, StreamDelta, TextGenerationOutput, TextGenerationStream, x402SettlementMode from .exceptions import OpenGradientError from .opg_token import Permit2ApprovalResult, ensure_opg_approval +from .tee_registry import build_ssl_context_from_der X402_PROCESSING_HASH_HEADER = "x-processing-hash" X402_PLACEHOLDER_API_KEY = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" @@ -40,51 +38,6 @@ ) -def _fetch_tls_cert_as_ssl_context(server_url: str) -> Optional[ssl.SSLContext]: - """ - Connect to a server, retrieve its TLS certificate (TOFU), - and return an ssl.SSLContext that trusts ONLY that certificate. - - Hostname verification is disabled because the TEE server's cert - is typically issued for a hostname but we may connect via IP address. - The pinned certificate itself provides the trust anchor. - - Returns None if the server is not HTTPS or unreachable. - """ - parsed = urlparse(server_url) - if parsed.scheme != "https": - return None - - hostname = parsed.hostname - port = parsed.port or 443 - - # Connect without verification to retrieve the server's certificate - fetch_ctx = ssl.create_default_context() - fetch_ctx.check_hostname = False - fetch_ctx.verify_mode = ssl.CERT_NONE - - try: - with socket.create_connection((hostname, port), timeout=10) as sock: - with fetch_ctx.wrap_socket(sock, server_hostname=hostname) as ssock: - der_cert = ssock.getpeercert(binary_form=True) - pem_cert = ssl.DER_cert_to_PEM_cert(der_cert) - except Exception: - return None - - # Write PEM to a temp file so we can load it into the SSLContext - cert_file = tempfile.NamedTemporaryFile(prefix="og_tee_tls_", suffix=".pem", delete=False, mode="w") - cert_file.write(pem_cert) - cert_file.flush() - cert_file.close() - - # Build an SSLContext that trusts ONLY this cert, with hostname check disabled - ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ctx.load_verify_locations(cert_file.name) - ctx.check_hostname = False # Cert is for a hostname, but we connect via IP - ctx.verify_mode = ssl.CERT_REQUIRED # Still verify the cert itself - return ctx - - class LLM: """ LLM inference namespace. @@ -108,13 +61,32 @@ class LLM: result = client.llm.completion(model=TEE_LLM.CLAUDE_HAIKU_4_5, prompt="Hello") """ - def __init__(self, wallet_account: LocalAccount, og_llm_server_url: str, og_llm_streaming_server_url: str): + def __init__( + self, + wallet_account: LocalAccount, + og_llm_server_url: str, + tls_cert_der: Optional[bytes] = None, + tee_id: Optional[str] = None, + tee_payment_address: Optional[str] = None, + ): self._wallet_account = wallet_account self._og_llm_server_url = og_llm_server_url - self._og_llm_streaming_server_url = og_llm_streaming_server_url - self._tls_verify: Union[ssl.SSLContext, bool] = _fetch_tls_cert_as_ssl_context(self._og_llm_server_url) or True - self._streaming_tls_verify: Union[ssl.SSLContext, bool] = _fetch_tls_cert_as_ssl_context(self._og_llm_streaming_server_url) or True + # TEE metadata surfaced on every response so callers can verify/audit which + # enclave served the request. + self._tee_id = tee_id + self._tee_endpoint = og_llm_server_url + self._tee_payment_address = tee_payment_address + + if tls_cert_der: + # Use the registry-verified certificate as the sole trust anchor. + ssl_ctx = build_ssl_context_from_der(tls_cert_der) + self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx + self._streaming_tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx + else: + # No cert from registry — fall back to default system CA verification. + self._tls_verify = True + self._streaming_tls_verify = True signer = EthAccountSignerv2(self._wallet_account) self._x402_client = x402Clientv2() @@ -199,7 +171,7 @@ def completion( max_tokens: int = 100, stop_sequence: Optional[List[str]] = None, temperature: float = 0.0, - x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.BATCH_HASHED, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, ) -> TextGenerationOutput: """ Perform inference on an LLM model using completions via TEE. @@ -241,7 +213,7 @@ def _tee_llm_completion( max_tokens: int = 100, stop_sequence: Optional[List[str]] = None, temperature: float = 0.0, - x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.BATCH_HASHED, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, ) -> TextGenerationOutput: """ Route completion request to OpenGradient TEE LLM server with x402 payments. @@ -277,6 +249,9 @@ async def make_request_v2(): completion_output=result.get("completion"), tee_signature=result.get("tee_signature"), tee_timestamp=result.get("tee_timestamp"), + tee_id=self._tee_id, + tee_endpoint=self._tee_endpoint, + tee_payment_address=self._tee_payment_address, ) except Exception as e: @@ -298,7 +273,7 @@ def chat( temperature: float = 0.0, tools: Optional[List[Dict]] = None, tool_choice: Optional[str] = None, - x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.BATCH_HASHED, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, stream: bool = False, ) -> Union[TextGenerationOutput, TextGenerationStream]: """ @@ -328,6 +303,20 @@ def chat( OpenGradientError: If the inference fails. """ if stream: + if tools: + # The TEE streaming endpoint omits tool call content from SSE events. + # Fall back transparently to the non-streaming endpoint and emit a + # single final StreamChunk so callers get the complete tool call data. + return self._tee_llm_chat_tools_as_stream( + model=model.split("/")[1], + messages=messages, + max_tokens=max_tokens, + stop_sequence=stop_sequence, + temperature=temperature, + tools=tools, + tool_choice=tool_choice, + x402_settlement_mode=x402_settlement_mode, + ) # Use threading bridge for true sync streaming return self._tee_llm_chat_stream_sync( model=model.split("/")[1], @@ -413,6 +402,9 @@ async def make_request_v2(): chat_output=message, tee_signature=result.get("tee_signature"), tee_timestamp=result.get("tee_timestamp"), + tee_id=self._tee_id, + tee_endpoint=self._tee_endpoint, + tee_payment_address=self._tee_payment_address, ) except Exception as e: @@ -425,6 +417,59 @@ async def make_request_v2(): except Exception as e: raise OpenGradientError(f"TEE LLM chat failed: {str(e)}") + def _tee_llm_chat_tools_as_stream( + self, + model: str, + messages: List[Dict], + max_tokens: int = 100, + stop_sequence: Optional[List[str]] = None, + temperature: float = 0.0, + tools: Optional[List[Dict]] = None, + tool_choice: Optional[str] = None, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, + ) -> TextGenerationStream: + """ + Transparent non-streaming fallback for tool-call requests with stream=True. + + The TEE streaming endpoint returns an empty delta when tools are present — + tool call content is not emitted as SSE events. This method calls the + non-streaming endpoint instead and emits a single final StreamChunk that + carries the complete tool call response, preserving the streaming interface + for callers (including the CLI). + """ + result = self._tee_llm_chat( + model=model, + messages=messages, + max_tokens=max_tokens, + stop_sequence=stop_sequence, + temperature=temperature, + tools=tools, + tool_choice=tool_choice, + x402_settlement_mode=x402_settlement_mode, + ) + + chat_output = result.chat_output or {} + delta = StreamDelta( + role=chat_output.get("role"), + content=chat_output.get("content"), + tool_calls=chat_output.get("tool_calls"), + ) + choice = StreamChoice( + delta=delta, + index=0, + finish_reason=result.finish_reason, + ) + yield StreamChunk( + choices=[choice], + model=model, + is_final=True, + tee_signature=result.tee_signature, + tee_timestamp=result.tee_timestamp, + tee_id=result.tee_id, + tee_endpoint=result.tee_endpoint, + tee_payment_address=result.tee_payment_address, + ) + def _tee_llm_chat_stream_sync( self, model: str, @@ -551,14 +596,19 @@ async def _parse_sse_response(response) -> AsyncGenerator[StreamChunk, None]: try: data = json.loads(data_str) - yield StreamChunk.from_sse_data(data) + chunk = StreamChunk.from_sse_data(data) + if chunk.is_final: + chunk.tee_id = self._tee_id + chunk.tee_endpoint = self._tee_endpoint + chunk.tee_payment_address = self._tee_payment_address + yield chunk except json.JSONDecodeError: continue endpoint = "/v1/chat/completions" async with self._stream_client.stream( "POST", - self._og_llm_streaming_server_url + endpoint, + self._og_llm_server_url + endpoint, json=payload, headers=headers, timeout=60, diff --git a/src/opengradient/client/tee_registry.py b/src/opengradient/client/tee_registry.py new file mode 100644 index 0000000..99294c0 --- /dev/null +++ b/src/opengradient/client/tee_registry.py @@ -0,0 +1,153 @@ +"""TEE Registry client for fetching verified TEE endpoints and TLS certificates.""" + +import logging +import ssl +from dataclasses import dataclass +from typing import List, NamedTuple, Optional + +from web3 import Web3 + +from ._utils import get_abi + +logger = logging.getLogger(__name__) + +# TEE types as defined in the registry contract +TEE_TYPE_LLM_PROXY = 0 +TEE_TYPE_VALIDATOR = 1 + + +class TEEInfo(NamedTuple): + """Mirrors the on-chain TEERegistry.TEEInfo struct.""" + + owner: str + payment_address: str + endpoint: str + public_key: bytes + tls_certificate: bytes + pcr_hash: bytes + tee_type: int + enabled: bool + registered_at: int + last_heartbeat_at: int + + +@dataclass +class TEEEndpoint: + """A verified TEE with its endpoint URL and TLS certificate from the registry.""" + + tee_id: str + endpoint: str + tls_cert_der: bytes + payment_address: str + + +class TEERegistry: + """ + Queries the on-chain TEE Registry contract to retrieve verified TEE endpoints + and their TLS certificates. + + Instead of blindly trusting the TLS certificate presented by a TEE server + (TOFU), this class fetches the certificate that was submitted and verified + during TEE registration. Any certificate that does not match the one stored + in the registry should be rejected. + + Args: + rpc_url: RPC endpoint for the chain where the registry is deployed. + registry_address: Address of the deployed TEERegistry contract. + """ + + def __init__(self, rpc_url: str, registry_address: str): + self._web3 = Web3(Web3.HTTPProvider(rpc_url)) + abi = get_abi("TEERegistry.abi") + self._contract = self._web3.eth.contract( + address=Web3.to_checksum_address(registry_address), + abi=abi, + ) + + def get_active_tees_by_type(self, tee_type: int) -> List[TEEEndpoint]: + """ + Return all active TEEs of the given type with their endpoints and TLS certs. + + Uses the contract's ``getActiveTEEs(teeType)`` which returns only TEEs that + are enabled, have a valid (non-revoked) PCR, and a fresh heartbeat — all in + a single on-chain call. + + Args: + tee_type: Integer TEE type (0=LLMProxy, 1=Validator). + + Returns: + List of TEEEndpoint objects for active TEEs of that type. + """ + type_label = {TEE_TYPE_LLM_PROXY: "LLMProxy", TEE_TYPE_VALIDATOR: "Validator"}.get(tee_type, str(tee_type)) + + try: + tee_infos = self._contract.functions.getActiveTEEs(tee_type).call() + except Exception as e: + logger.warning("Failed to fetch active TEEs from registry (type=%s): %s", type_label, e) + return [] + + logger.debug("Registry returned %d active TEE(s) for type=%s", len(tee_infos), type_label) + + endpoints: List[TEEEndpoint] = [] + for raw in tee_infos: + tee = TEEInfo(*raw) + tee_id_hex = Web3.keccak(tee.public_key).hex() + if not tee.endpoint or not tee.tls_certificate: + logger.warning(" teeId=%s missing endpoint or TLS cert (skipped)", tee_id_hex) + continue + logger.info( + " teeId=%s endpoint=%s paymentAddress=%s certBytes=%d", + tee_id_hex, + tee.endpoint, + tee.payment_address, + len(tee.tls_certificate), + ) + endpoints.append( + TEEEndpoint( + tee_id=tee_id_hex, + endpoint=tee.endpoint, + tls_cert_der=bytes(tee.tls_certificate), + payment_address=tee.payment_address, + ) + ) + + logger.info("Discovered %d active %s TEE(s) from registry", len(endpoints), type_label) + return endpoints + + def get_llm_tee(self) -> Optional[TEEEndpoint]: + """ + Return the first active LLM proxy TEE from the registry. + + Returns: + TEEEndpoint for an active LLM proxy TEE, or None if none are available. + """ + logger.debug("Querying TEE registry for active LLM proxy TEEs...") + tees = self.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + if tees: + logger.info("Selected LLM TEE: endpoint=%s teeId=%s", tees[0].endpoint, tees[0].tee_id) + else: + logger.warning("No active LLM proxy TEEs found in registry") + return tees[0] if tees else None + + +def build_ssl_context_from_der(der_cert: bytes) -> ssl.SSLContext: + """ + Build an ssl.SSLContext that trusts *only* the given DER-encoded certificate. + + Hostname verification is disabled because TEE servers are typically addressed + by IP while the cert may be issued for a different hostname. The pinned + certificate itself is the trust anchor — only that cert is accepted. + + Args: + der_cert: DER-encoded X.509 certificate bytes as stored in the registry. + + Returns: + ssl.SSLContext configured to accept only the pinned certificate. + """ + pem = ssl.DER_cert_to_PEM_cert(der_cert) + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.load_verify_locations(cadata=pem) + ctx.check_hostname = False # TEE cert may be issued for a hostname; we connect via IP + ctx.verify_mode = ssl.CERT_REQUIRED + return ctx diff --git a/src/opengradient/defaults.py b/src/opengradient/defaults.py index c053225..ba2bfc5 100644 --- a/src/opengradient/defaults.py +++ b/src/opengradient/defaults.py @@ -6,6 +6,6 @@ DEFAULT_INFERENCE_CONTRACT_ADDRESS = "0x8383C9bD7462F12Eb996DD02F78234C0421A6FaE" DEFAULT_SCHEDULER_ADDRESS = "0x7179724De4e7FF9271FA40C0337c7f90C0508eF6" DEFAULT_BLOCKCHAIN_EXPLORER = "https://explorer.opengradient.ai/tx/" -# TODO (Kyle): Add a process to fetch these IPs from the TEE registry -DEFAULT_OPENGRADIENT_LLM_SERVER_URL = "https://3.15.214.21:443" -DEFAULT_OPENGRADIENT_LLM_STREAMING_SERVER_URL = "https://3.15.214.21:443" +# TEE Registry contract on the OG EVM chain — used to discover LLM proxy endpoints +# and fetch their registry-verified TLS certificates instead of blindly trusting TOFU. +DEFAULT_TEE_REGISTRY_ADDRESS = "0x4e72238852f3c918f4E4e57AeC9280dDB0c80248" diff --git a/src/opengradient/types.py b/src/opengradient/types.py index 43a5060..492bed4 100644 --- a/src/opengradient/types.py +++ b/src/opengradient/types.py @@ -234,6 +234,9 @@ class StreamChunk: is_final: Whether this is the final chunk (before [DONE]) tee_signature: RSA-PSS signature over the response, present on the final chunk tee_timestamp: ISO timestamp from the TEE at signing time, present on the final chunk + tee_id: On-chain TEE registry ID of the enclave that served this request (final chunk only) + tee_endpoint: Endpoint URL of the TEE that served this request (final chunk only) + tee_payment_address: Payment address registered for the TEE (final chunk only) """ choices: List[StreamChoice] @@ -242,6 +245,9 @@ class StreamChunk: is_final: bool = False tee_signature: Optional[str] = None tee_timestamp: Optional[str] = None + tee_id: Optional[str] = None + tee_endpoint: Optional[str] = None + tee_payment_address: Optional[str] = None @classmethod def from_sse_data(cls, data: Dict) -> "StreamChunk": @@ -256,7 +262,9 @@ def from_sse_data(cls, data: Dict) -> "StreamChunk": """ choices = [] for choice_data in data.get("choices", []): - delta_data = choice_data.get("delta", {}) + # The TEE proxy sometimes sends SSE events using the non-streaming "message" + # key instead of the standard streaming "delta" key. Fall back gracefully. + delta_data = choice_data.get("delta") or choice_data.get("message") or {} delta = StreamDelta(content=delta_data.get("content"), role=delta_data.get("role"), tool_calls=delta_data.get("tool_calls")) choice = StreamChoice(delta=delta, index=choice_data.get("index", 0), finish_reason=choice_data.get("finish_reason")) choices.append(choice) @@ -423,6 +431,15 @@ class TextGenerationOutput: tee_timestamp: Optional[str] = None """ISO-8601 timestamp from the TEE at signing time.""" + tee_id: Optional[str] = None + """On-chain TEE registry ID (keccak256 of the enclave's public key) of the TEE that served this request.""" + + tee_endpoint: Optional[str] = None + """Endpoint URL of the TEE that served this request, as registered on-chain.""" + + tee_payment_address: Optional[str] = None + """Payment address registered for the TEE that served this request.""" + @dataclass class AbiFunction: diff --git a/tests/client_test.py b/tests/client_test.py index 5077be5..dadb5a0 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -21,7 +21,10 @@ @pytest.fixture def mock_web3(): """Create a mock Web3 instance.""" - with patch("src.opengradient.client.client.Web3") as mock: + with ( + patch("src.opengradient.client.client.Web3") as mock, + patch("src.opengradient.client.client.TEERegistry") as mock_tee_registry, + ): mock_instance = MagicMock() mock.return_value = mock_instance mock.HTTPProvider.return_value = MagicMock() @@ -31,6 +34,14 @@ def mock_web3(): mock_instance.eth.gas_price = 1000000000 mock_instance.eth.contract.return_value = MagicMock() + # Return a fake active TEE endpoint so Client.__init__ doesn't need a live registry + mock_tee = MagicMock() + mock_tee.endpoint = "https://test.tee.server" + mock_tee.tls_cert_der = None + mock_tee.tee_id = "test-tee-id" + mock_tee.payment_address = "0xTestPaymentAddress" + mock_tee_registry.return_value.get_llm_tee.return_value = mock_tee + yield mock_instance @@ -103,9 +114,8 @@ def test_client_initialization_with_auth(self, mock_web3, mock_abi_files): assert client.model_hub._hub_user["idToken"] == "test_token" def test_client_initialization_custom_llm_urls(self, mock_web3, mock_abi_files): - """Test client initialization with custom LLM server URLs.""" + """Test client initialization with custom LLM server URL.""" custom_llm_url = "https://custom.llm.server" - custom_streaming_url = "https://custom.streaming.server" client = Client( private_key="0x" + "a" * 64, @@ -113,11 +123,9 @@ def test_client_initialization_custom_llm_urls(self, mock_web3, mock_abi_files): api_url="https://test.api.url", contract_address="0x" + "b" * 40, og_llm_server_url=custom_llm_url, - og_llm_streaming_server_url=custom_streaming_url, ) assert client.llm._og_llm_server_url == custom_llm_url - assert client.llm._og_llm_streaming_server_url == custom_streaming_url class TestAlphaProperty: diff --git a/tests/tee_registry_test.py b/tests/tee_registry_test.py new file mode 100644 index 0000000..15b15d2 --- /dev/null +++ b/tests/tee_registry_test.py @@ -0,0 +1,206 @@ +import os +import ssl +import sys +from unittest.mock import MagicMock, patch + +import pytest + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) + +from src.opengradient.client.tee_registry import ( + TEE_TYPE_LLM_PROXY, + TEE_TYPE_VALIDATOR, + TEEEndpoint, + TEERegistry, + build_ssl_context_from_der, +) + + +# --- Helpers --- + + +def _make_tee_info( + endpoint="https://tee.example.com", + payment_address="0xPayment", + pub_key=b"pubkey", + tls_cert_der=b"\x01\x02\x03", +): + """Build a tuple matching the TEEInfo struct order from the new contract.""" + return ( + "0xOwner", # owner + payment_address, # paymentAddress + endpoint, # endpoint + pub_key, # publicKey + tls_cert_der, # tlsCertificate + b"\x00" * 32, # pcrHash + 0, # teeType + True, # enabled (always True from getActiveTEEs) + 1000, # registeredAt + 2000, # lastHeartbeatAt + ) + + +def _make_self_signed_der() -> bytes: + """Generate a minimal self-signed DER certificate for testing.""" + from cryptography import x509 + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + import datetime + + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")]) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.UTC)) + .not_valid_after(datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=1)) + .sign(key, hashes.SHA256()) + ) + return cert.public_bytes(serialization.Encoding.DER) + + +# --- Fixtures --- + + +@pytest.fixture +def mock_contract(): + """Create a TEERegistry with a mocked Web3 contract.""" + with ( + patch("src.opengradient.client.tee_registry.Web3") as mock_web3_cls, + patch("src.opengradient.client.tee_registry.get_abi") as mock_get_abi, + ): + mock_get_abi.return_value = [] + mock_web3 = MagicMock() + mock_web3_cls.return_value = mock_web3 + mock_web3_cls.HTTPProvider.return_value = MagicMock() + mock_web3_cls.to_checksum_address.side_effect = lambda x: x + mock_web3_cls.keccak.side_effect = lambda data: b"\xaa" * 32 if data == b"pubkey" else b"\xbb" * 32 + + contract = MagicMock() + mock_web3.eth.contract.return_value = contract + + registry = TEERegistry(rpc_url="http://localhost:8545", registry_address="0xRegistry") + yield registry, contract + + +# --- TEERegistry Tests --- + + +class TestGetActiveTeesByType: + def test_returns_active_tees(self, mock_contract): + registry, contract = mock_contract + + contract.functions.getActiveTEEs.return_value.call.return_value = [_make_tee_info()] + + result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + + assert len(result) == 1 + assert result[0].endpoint == "https://tee.example.com" + assert result[0].payment_address == "0xPayment" + assert result[0].tls_cert_der == b"\x01\x02\x03" + contract.functions.getActiveTEEs.assert_called_once_with(TEE_TYPE_LLM_PROXY) + + def test_skips_tee_with_empty_endpoint(self, mock_contract): + registry, contract = mock_contract + + contract.functions.getActiveTEEs.return_value.call.return_value = [_make_tee_info(endpoint="")] + + result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + assert len(result) == 0 + + def test_skips_tee_with_empty_cert(self, mock_contract): + registry, contract = mock_contract + + contract.functions.getActiveTEEs.return_value.call.return_value = [_make_tee_info(tls_cert_der=b"")] + + result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + assert len(result) == 0 + + def test_returns_empty_on_rpc_failure(self, mock_contract): + registry, contract = mock_contract + + contract.functions.getActiveTEEs.return_value.call.side_effect = Exception("RPC error") + + result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + assert result == [] + + def test_multiple_active_tees(self, mock_contract): + registry, contract = mock_contract + + infos = [_make_tee_info(endpoint=f"https://tee-{i}.example.com", pub_key=f"pubkey{i}".encode()) for i in range(3)] + contract.functions.getActiveTEEs.return_value.call.return_value = infos + + result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + assert len(result) == 3 + + def test_validator_type_label(self, mock_contract): + """Ensure validator type queries work the same way.""" + registry, contract = mock_contract + + contract.functions.getActiveTEEs.return_value.call.return_value = [] + + result = registry.get_active_tees_by_type(TEE_TYPE_VALIDATOR) + assert result == [] + contract.functions.getActiveTEEs.assert_called_once_with(TEE_TYPE_VALIDATOR) + + +class TestGetLlmTee: + def test_returns_first_active_tee(self, mock_contract): + registry, contract = mock_contract + + contract.functions.getActiveTEEs.return_value.call.return_value = [ + _make_tee_info(endpoint="https://tee-1.example.com"), + _make_tee_info(endpoint="https://tee-2.example.com"), + ] + + result = registry.get_llm_tee() + + assert result is not None + assert result.endpoint == "https://tee-1.example.com" + + def test_returns_none_when_no_tees(self, mock_contract): + registry, contract = mock_contract + + contract.functions.getActiveTEEs.return_value.call.return_value = [] + + result = registry.get_llm_tee() + assert result is None + + def test_queries_llm_proxy_type(self, mock_contract): + registry, contract = mock_contract + + contract.functions.getActiveTEEs.return_value.call.return_value = [] + registry.get_llm_tee() + + contract.functions.getActiveTEEs.assert_called_once_with(TEE_TYPE_LLM_PROXY) + + +# --- build_ssl_context_from_der Tests --- + + +class TestBuildSslContextFromDer: + def test_returns_ssl_context(self): + der_cert = _make_self_signed_der() + ctx = build_ssl_context_from_der(der_cert) + + assert isinstance(ctx, ssl.SSLContext) + + def test_hostname_check_disabled(self): + der_cert = _make_self_signed_der() + ctx = build_ssl_context_from_der(der_cert) + + assert ctx.check_hostname is False + + def test_cert_required(self): + der_cert = _make_self_signed_der() + ctx = build_ssl_context_from_der(der_cert) + + assert ctx.verify_mode == ssl.CERT_REQUIRED + + def test_rejects_invalid_der(self): + with pytest.raises(Exception): + build_ssl_context_from_der(b"not-a-valid-cert")