Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion litellm/proxy/anthropic_endpoints/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,36 @@
router = APIRouter()


def _get_blocked_response_usage(
model: str,
messages: Optional[list],
block_message: str,
) -> dict:
"""
Compute token usage for a synthetic guardrail-blocked response.

The original request still consumed input tokens and the synthetic block
message has real content, so both counts are computed rather than reported
as zero. Token counting is best-effort: any failure falls back to 0 so a
blocked response is always returned to the caller.
"""
import litellm

input_tokens = 0
output_tokens = 0
try:
if messages:
input_tokens = litellm.token_counter(model=model, messages=messages)
if block_message:
output_tokens = litellm.token_counter(model=model, text=block_message)
except Exception as token_count_error:
verbose_proxy_logger.debug(
"Failed to count tokens for blocked response: %s", token_count_error
)

return {"input_tokens": input_tokens, "output_tokens": output_tokens}


@router.post(
"/v1/messages",
tags=["[beta] Anthropic `/v1/messages`"],
Expand Down Expand Up @@ -87,14 +117,24 @@ async def anthropic_response( # noqa: PLR0915

from litellm.types.utils import AnthropicMessagesResponse

# The original request was processed upstream and the block message is
# synthesized here, so both token counts are reachable: input_tokens
# from the original request messages (carried on the exception) and
# output_tokens from the block message text.
_usage = _get_blocked_response_usage(
model=e.model,
messages=_data.get("messages"),
block_message=e.message,
)

_anthropic_response = AnthropicMessagesResponse(
id=f"msg_{str(uuid.uuid4())}",
type="message",
role="assistant",
content=[{"type": "text", "text": e.message}],
model=e.model,
stop_reason="end_turn",
usage={"input_tokens": 0, "output_tokens": 0},
usage=_usage,
)

if data.get("stream", None) is not None and data["stream"] is True:
Expand Down
54 changes: 54 additions & 0 deletions tests/test_litellm/proxy/anthropic_endpoints/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,60 @@ async def test_async_data_generator_anthropic_dict_handling(self, mock_safe_dump
) # Called twice, once for each dict object


class TestBlockedResponseUsage:
"""Token usage on synthetic guardrail-blocked responses."""

def test_blocked_response_reports_nonzero_token_counts(self):
from litellm.proxy.anthropic_endpoints.endpoints import (
_get_blocked_response_usage,
)

usage = _get_blocked_response_usage(
model="claude-3-5-sonnet-20240620",
messages=[{"role": "user", "content": "Tell me something harmful"}],
block_message="This request was blocked by a content guardrail.",
)

assert usage["input_tokens"] > 0
assert usage["output_tokens"] > 0

def test_blocked_response_output_tokens_match_block_message(self):
import litellm
from litellm.proxy.anthropic_endpoints.endpoints import (
_get_blocked_response_usage,
)

model = "claude-3-5-sonnet-20240620"
block_message = "Blocked: this content violates policy."

usage = _get_blocked_response_usage(
model=model,
messages=[{"role": "user", "content": "hi"}],
block_message=block_message,
)

expected_output = litellm.token_counter(model=model, text=block_message)
assert usage["output_tokens"] == expected_output

def test_blocked_response_usage_falls_back_to_zero_on_error(self):
from unittest.mock import patch

from litellm.proxy.anthropic_endpoints.endpoints import (
_get_blocked_response_usage,
)

with patch(
"litellm.token_counter", side_effect=RuntimeError("boom")
):
usage = _get_blocked_response_usage(
model="claude-3-5-sonnet-20240620",
messages=[{"role": "user", "content": "hi"}],
block_message="blocked",
)

assert usage == {"input_tokens": 0, "output_tokens": 0}


class TestEventLoggingBatchEndpoint:
"""Test the stubbed event logging batch endpoint"""

Expand Down
Loading