Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions stapler-scripts/claude-proxy/providers/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""AWS Bedrock provider implementation."""
import json
import asyncio
import boto3
from typing import Dict, Any, AsyncIterator, Optional
from . import Provider, RateLimitError, ValidationError
Expand Down Expand Up @@ -86,16 +87,18 @@ async def send_message(
bedrock_body.pop("model", None)

try:
# Synchronous call wrapped in async
response = self.client.invoke_model(
# Synchronous call wrapped in async using thread pool
response = await asyncio.to_thread(
self.client.invoke_model,
modelId=bedrock_model,
contentType="application/json",
accept="application/json",
body=json.dumps(bedrock_body)
)

# Parse response
result = json.loads(response["body"].read())
# Parse response - reading from the body is also blocking I/O
body_content = await asyncio.to_thread(response["body"].read)
result = json.loads(body_content)
return self._convert_response(result, original_model)

except self.client.exceptions.ThrottlingException:
Expand Down Expand Up @@ -126,16 +129,22 @@ async def stream_message(
bedrock_body.pop("model", None)

try:
# Invoke with streaming
response = self.client.invoke_model_with_response_stream(
# Invoke with streaming wrapped in async using thread pool
response = await asyncio.to_thread(
self.client.invoke_model_with_response_stream,
modelId=bedrock_model,
contentType="application/json",
accept="application/json",
body=json.dumps(bedrock_body)
)

# Stream events
for event in response["body"]:
# Stream events - the EventStream is a synchronous iterator, so we wrap next() in a thread
iterator = iter(response["body"])
while True:
event = await asyncio.to_thread(next, iterator, None)
if event is None:
break

chunk = json.loads(event["chunk"]["bytes"])

# Convert to SSE format matching Anthropic
Expand Down