diff --git a/stapler-scripts/claude-proxy/providers/bedrock.py b/stapler-scripts/claude-proxy/providers/bedrock.py index 4792965..92cb57c 100644 --- a/stapler-scripts/claude-proxy/providers/bedrock.py +++ b/stapler-scripts/claude-proxy/providers/bedrock.py @@ -1,5 +1,6 @@ """AWS Bedrock provider implementation.""" import json +import asyncio import boto3 from typing import Dict, Any, AsyncIterator, Optional from . import Provider, RateLimitError, ValidationError @@ -86,16 +87,18 @@ async def send_message( bedrock_body.pop("model", None) try: - # Synchronous call wrapped in async - response = self.client.invoke_model( + # Synchronous call wrapped in async using thread pool + response = await asyncio.to_thread( + self.client.invoke_model, modelId=bedrock_model, contentType="application/json", accept="application/json", body=json.dumps(bedrock_body) ) - # Parse response - result = json.loads(response["body"].read()) + # Parse response - reading from the body is also blocking I/O + body_content = await asyncio.to_thread(response["body"].read) + result = json.loads(body_content) return self._convert_response(result, original_model) except self.client.exceptions.ThrottlingException: @@ -126,16 +129,22 @@ async def stream_message( bedrock_body.pop("model", None) try: - # Invoke with streaming - response = self.client.invoke_model_with_response_stream( + # Invoke with streaming wrapped in async using thread pool + response = await asyncio.to_thread( + self.client.invoke_model_with_response_stream, modelId=bedrock_model, contentType="application/json", accept="application/json", body=json.dumps(bedrock_body) ) - # Stream events - for event in response["body"]: + # Stream events - the EventStream is a synchronous iterator, so we wrap next() in a thread + iterator = iter(response["body"]) + while True: + event = await asyncio.to_thread(next, iterator, None) + if event is None: + break + chunk = json.loads(event["chunk"]["bytes"]) # Convert to SSE format matching Anthropic