diff --git a/src/askui/model_providers/anthropic_vlm_provider.py b/src/askui/model_providers/anthropic_vlm_provider.py index 9edd42b9..aef5b48e 100644 --- a/src/askui/model_providers/anthropic_vlm_provider.py +++ b/src/askui/model_providers/anthropic_vlm_provider.py @@ -15,6 +15,7 @@ ToolChoiceParam, ) from askui.models.shared.prompts import SystemPrompt +from askui.models.shared.request_size import ANTHROPIC_MAX_REQUEST_BYTES from askui.models.shared.tools import ToolCollection from askui.utils.model_pricing import ModelPricing @@ -104,6 +105,11 @@ def model_id(self) -> str: def pricing(self) -> ModelPricing | None: return self._pricing + @property + @override + def max_request_bytes(self) -> int | None: + return ANTHROPIC_MAX_REQUEST_BYTES + @cached_property def _messages_api(self) -> AnthropicMessagesApi: """Lazily initialise the AnthropicMessagesApi on first use.""" diff --git a/src/askui/model_providers/askui_vlm_provider.py b/src/askui/model_providers/askui_vlm_provider.py index d149deff..cd67856c 100644 --- a/src/askui/model_providers/askui_vlm_provider.py +++ b/src/askui/model_providers/askui_vlm_provider.py @@ -16,6 +16,7 @@ ToolChoiceParam, ) from askui.models.shared.prompts import SystemPrompt +from askui.models.shared.request_size import ANTHROPIC_MAX_REQUEST_BYTES from askui.models.shared.tools import ToolCollection _DEFAULT_MODEL_ID = "claude-sonnet-4-6" @@ -69,6 +70,11 @@ def __init__( def model_id(self) -> str: return self._model_id_value + @property + @override + def max_request_bytes(self) -> int | None: + return ANTHROPIC_MAX_REQUEST_BYTES + @cached_property def _messages_api(self) -> AnthropicMessagesApi: """Lazily initialise the AnthropicMessagesApi on first use.""" diff --git a/src/askui/model_providers/vlm_provider.py b/src/askui/model_providers/vlm_provider.py index 1e98b972..afc3de33 100644 --- a/src/askui/model_providers/vlm_provider.py +++ b/src/askui/model_providers/vlm_provider.py @@ -53,6 +53,15 @@ def pricing(self) -> ModelPricing | 
None: """ return None + @property + def max_request_bytes(self) -> int | None: + """Hard cap on the serialized request body size, in bytes. + Returns ``None`` when the provider has no known/specific limit, in + which case truncation strategies skip byte-budget enforcement. Override in + subclasses backed by an endpoint with a known cap. + """ + return None + @abstractmethod def create_message( self, diff --git a/src/askui/models/anthropic/messages_api.py b/src/askui/models/anthropic/messages_api.py index 47ea2681..3b7621a5 100644 --- a/src/askui/models/anthropic/messages_api.py +++ b/src/askui/models/anthropic/messages_api.py @@ -1,3 +1,4 @@ +import logging from typing import Any, Tuple, cast from anthropic import ( @@ -39,9 +40,15 @@ ) from askui.models.shared.messages_api import MessagesApi from askui.models.shared.prompts import SystemPrompt +from askui.models.shared.request_size import ( + ANTHROPIC_MAX_REQUEST_BYTES, + estimate_messages_bytes, +) from askui.models.shared.tools import ToolCollection from askui.utils.image_utils import image_to_base64 +logger = logging.getLogger(__name__) + def _is_retryable_error(exception: BaseException) -> bool: """Check if the exception is a retryable error.""" @@ -193,6 +200,16 @@ def create_message( temperature: float | None = None, provider_options: dict[str, Any] | None = None, ) -> MessageParam: + estimated_bytes = estimate_messages_bytes(messages) + if estimated_bytes > ANTHROPIC_MAX_REQUEST_BYTES: + logger.warning( + "Estimated request size ~%d bytes exceeds the Anthropic " + "limit of %d bytes; the request may be rejected with a 400. 
" + "Configure a truncation strategy with byte enforcement.", + estimated_bytes, + ANTHROPIC_MAX_REQUEST_BYTES, + ) + # convert each message to anthropic BetaMessageParam type _messages = [from_message_param(message) for message in messages] diff --git a/src/askui/models/shared/request_size.py b/src/askui/models/shared/request_size.py new file mode 100644 index 00000000..f3404ba4 --- /dev/null +++ b/src/askui/models/shared/request_size.py @@ -0,0 +1,67 @@ +"""Request size estimation shared across truncation and providers. + +The Anthropic Messages API rejects requests whose serialized body exceeds +~32 MB with a 400 ``BadRequestError``. Base64-encoded screenshots dominate +that payload. These helpers estimate the serialized byte size cheaply so +truncation strategies (and a provider-side safety net) can keep requests +under the limit. + +The estimate reads cached string lengths (``len`` is O(1) on Python +strings, and base64 ``data`` is ASCII so its length equals its serialized +byte count), making a full pass O(number of blocks) rather than +O(payload size). Structural JSON overhead (field names, braces, quotes) is +not counted; it is sub-percent of image-heavy payloads and absorbed by the +threshold headroom callers apply on top of the hard limit. +""" + +from askui.models.shared.agent_message_param import ( + Base64ImageSourceParam, + BetaThinkingBlock, + ContentBlockParam, + ImageBlockParam, + MessageParam, + TextBlockParam, + ToolResultBlockParam, + ToolUseBlockParam, +) + +# Hard cap on serialized request size for the Anthropic Messages API. +ANTHROPIC_MAX_REQUEST_BYTES = 30 * 1024 * 1024 + + +def estimate_block_bytes(block: ContentBlockParam) -> int: + """Cheaply estimate the serialized byte size of one content block. + + Base64 image ``data`` is ASCII, so ``len`` equals its byte count and + is O(1) on Python strings. 
Walking blocks is therefore O(number of + blocks) rather than O(payload size), keeping the byte check cheap even + with many multi-megabyte screenshots. + """ + if isinstance(block, ImageBlockParam): + if isinstance(block.source, Base64ImageSourceParam): + return len(block.source.data) + return len(block.source.url) + if isinstance(block, TextBlockParam): + return len(block.text) + if isinstance(block, ToolResultBlockParam): + if isinstance(block.content, str): + return len(block.content) + return sum(estimate_block_bytes(nested) for nested in block.content) + if isinstance(block, ToolUseBlockParam): + return len(str(block.input)) + len(block.name) + if isinstance(block, BetaThinkingBlock): + return len(block.thinking) + len(block.signature) + # BetaRedactedThinkingBlock + return len(block.data) + + +def estimate_message_bytes(message: MessageParam) -> int: + """Estimate the serialized byte size of a single message.""" + if isinstance(message.content, str): + return len(message.content) + return sum(estimate_block_bytes(block) for block in message.content) + + +def estimate_messages_bytes(messages: list[MessageParam]) -> int: + """Estimate the serialized byte size of a message history.""" + return sum(estimate_message_bytes(message) for message in messages) diff --git a/src/askui/models/shared/truncation_strategies.py b/src/askui/models/shared/truncation_strategies.py index baf2d023..009c5fff 100644 --- a/src/askui/models/shared/truncation_strategies.py +++ b/src/askui/models/shared/truncation_strategies.py @@ -18,6 +18,9 @@ ToolUseBlockParam, ) from askui.models.shared.prompts import SystemPrompt +from askui.models.shared.request_size import ( + estimate_messages_bytes, +) from askui.models.shared.token_counter import SimpleTokenCounter from askui.models.shared.tools import ToolCollection from askui.prompts.truncation import SUMMARIZE_INSTRUCTION_PROMPT @@ -38,9 +41,178 @@ # see https://docs.anthropic.com/en/api/messages#body-messages MAX_MESSAGES = 100_000 + 
+REQUEST_SIZE_THRESHOLD = 0.8 + IMAGE_REMOVED_PLACEHOLDER = "[Screenshot removed to reduce message history length]" """Text used to replace stripped base64 images.""" +_REMOVED_IMAGE_PLACEHOLDER_BYTES = len(IMAGE_REMOVED_PLACEHOLDER) + + +def _base64_image_byte_sizes(messages: list[MessageParam]) -> list[int]: + """Byte sizes of every base64 image, in chronological order. + + Recurses into ``ToolResultBlockParam`` content. URL-based images + are excluded as they carry no payload bytes. + """ + sizes: list[int] = [] + for msg in messages: + if isinstance(msg.content, str): + continue + for block in msg.content: + if isinstance(block, ImageBlockParam) and isinstance( + block.source, Base64ImageSourceParam + ): + sizes.append(len(block.source.data)) + elif isinstance(block, ToolResultBlockParam) and isinstance( + block.content, list + ): + sizes.extend( + len(nested.source.data) + for nested in block.content + if isinstance(nested, ImageBlockParam) + and isinstance(nested.source, Base64ImageSourceParam) + ) + return sizes + + +def _count_base64_images(messages: list[MessageParam]) -> int: + """Count total base64 image blocks across messages.""" + return len(_base64_image_byte_sizes(messages)) + + +def _strip_base64_images( + content: list[ContentBlockParam], + max_to_strip: int, +) -> tuple[list[ContentBlockParam], int]: + """Strip up to ``max_to_strip`` base64 images from a content list. + + Replaces base64 ``ImageBlockParam`` blocks (top-level and nested + inside ``ToolResultBlockParam``) with text placeholders. URL-based + images are never stripped. + + Args: + content: The content blocks to process. + max_to_strip: Maximum number of images to strip. + + Returns: + Tuple of (new content list, count stripped). 
+ """ + stripped = 0 + new_content: list[ContentBlockParam] = [] + + for block in content: + if stripped >= max_to_strip: + new_content.append(block) + continue + + if isinstance(block, ImageBlockParam) and isinstance( + block.source, Base64ImageSourceParam + ): + new_content.append(TextBlockParam(text=IMAGE_REMOVED_PLACEHOLDER)) + stripped += 1 + elif isinstance(block, ToolResultBlockParam) and isinstance( + block.content, list + ): + new_nested: list[TextBlockParam | ImageBlockParam] = [] + for nested in block.content: + if ( + stripped < max_to_strip + and isinstance(nested, ImageBlockParam) + and isinstance(nested.source, Base64ImageSourceParam) + ): + new_nested.append(TextBlockParam(text=IMAGE_REMOVED_PLACEHOLDER)) + stripped += 1 + else: + new_nested.append(nested) + new_content.append( + ToolResultBlockParam( + tool_use_id=block.tool_use_id, + content=new_nested, + is_error=block.is_error, + cache_control=block.cache_control, + ) + ) + else: + new_content.append(block) + + return new_content, stripped + + +def _strip_oldest_base64_images( + messages: list[MessageParam], + n_images_to_keep: int, +) -> int | None: + """Replace oldest base64 images with placeholders in place. + + Keeps the newest ``n_images_to_keep`` base64 images and walks from + the start of the history, replacing excess images (oldest first). + + Args: + messages: The message history to mutate. + n_images_to_keep: Number of most-recent base64 images to retain. + + Returns: + The index of the last message modified, or ``None`` if nothing + was stripped. 
+ """ + to_remove = _count_base64_images(messages) - n_images_to_keep + if to_remove <= 0: + return None + + removed = 0 + boundary: int | None = None + for i, msg in enumerate(messages): + if removed >= to_remove: + break + if isinstance(msg.content, str): + continue + new_content, removed_in_msg = _strip_base64_images( + msg.content, to_remove - removed + ) + if removed_in_msg > 0: + messages[i] = MessageParam( + role=msg.role, + content=new_content, + stop_reason=msg.stop_reason, + usage=msg.usage, + ) + boundary = i + removed += removed_in_msg + return boundary + + +def _image_keep_count_for_byte_budget( + messages: list[MessageParam], + byte_budget: int, + current_bytes: int, +) -> int: + """Smallest number of newest base64 images to keep within budget. + + Oldest images are dropped first (their payload replaced by the small + placeholder) until the estimated request size falls to or below + ``byte_budget``. If dropping every image is still not enough, returns + ``0`` (strip all) and the caller falls back to summarization. + + Args: + messages: The message history to inspect. + byte_budget: Maximum allowed estimated request size in bytes. + current_bytes: Current estimated request size in bytes. + + Returns: + The target number of newest base64 images to keep (``>= 0``). + """ + sizes = _base64_image_byte_sizes(messages) + freed = 0 + keep = len(sizes) + for size in sizes: # oldest first + if current_bytes - freed <= byte_budget: + break + freed += size - _REMOVED_IMAGE_PLACEHOLDER_BYTES + keep -= 1 + return max(keep, 0) + def _has_orphaned_tool_results(msg: MessageParam) -> bool: """Check if a message contains tool_result blocks. @@ -212,6 +384,15 @@ class TruncationStrategy(ABC): max_input_tokens: Maximum input tokens for the endpoint. truncation_threshold: Fraction of `max_input_tokens` at which to truncate. + max_request_bytes: Hard cap on the serialized request size in + bytes. 
When ``None`` (default), it is read from the + conversation's `VlmProvider` (e.g. ~32 MB for Anthropic). If + no provider advertises a limit either, byte-budget enforcement + is skipped entirely (no images are stripped to meet a byte + budget). Set it to override the provider's value. + request_size_threshold: Fraction of `max_request_bytes` at + which to start stripping the oldest base64 images (newest + kept) so the request stays under the endpoint's byte limit. """ def __init__( @@ -219,11 +400,15 @@ def __init__( max_messages: int = MAX_MESSAGES, max_input_tokens: int = MAX_INPUT_TOKENS, truncation_threshold: float = TRUNCATION_THRESHOLD, + max_request_bytes: int | None = None, + request_size_threshold: float = REQUEST_SIZE_THRESHOLD, ) -> None: self._full_message_history: list[MessageParam] = [] self._truncated_message_history: list[MessageParam] = [] self._first_user_message: MessageParam | None = None self._max_messages = max_messages + self._max_request_bytes = max_request_bytes + self._request_size_threshold = request_size_threshold self._absolute_truncation_threshold = int( max_input_tokens * truncation_threshold ) @@ -235,6 +420,45 @@ def __init__( self.callbacks: list[ConversationCallback] = [] self.conversation: "Conversation | None" = None + def _resolve_max_request_bytes(self) -> int | None: + """Resolve the endpoint's hard request byte limit. + + Precedence: an explicit ``max_request_bytes`` override, then the + provider used for the outgoing request (the conversation's + ``vlm_provider``, falling back to this strategy's summarization + ``vlm_provider``). + + Returns ``None`` when no limit is defined anywhere (no override and + no provider advertises one), in which case byte-budget enforcement + is skipped entirely (no images are stripped to meet a byte budget). 
+ """ + if self._max_request_bytes is not None: + return self._max_request_bytes + providers = ( + self.conversation.vlm_provider if self.conversation else None, + self.vlm_provider, + ) + for provider in providers: + if provider is not None and provider.max_request_bytes is not None: + return provider.max_request_bytes + return None + + def _byte_budget(self) -> int | None: + """Effective byte budget: hard limit scaled by the threshold. + + Returns ``None`` when no hard limit is defined, signalling that + byte-budget enforcement should be skipped. + """ + max_request_bytes = self._resolve_max_request_bytes() + if max_request_bytes is None: + return None + return int(max_request_bytes * self._request_size_threshold) + + def _report(self, content: str) -> None: + """Send a debug message to the reporter if one is attached.""" + if self.reporter: + self.reporter.add_message("TruncationStrategy", content) + def _summarization_request_context( self, ) -> tuple[SystemPrompt | None, ToolCollection | None, dict[str, Any] | None]: @@ -262,6 +486,73 @@ def truncate(self) -> None: """Force-truncate the message history.""" ... + def _enforce_byte_budget(self) -> int | None: + """Strip oldest base64 images until within the byte budget. + + When no hard limit is defined (no `max_request_bytes` override and + no provider advertises one), `_byte_budget` returns ``None`` and + enforcement is skipped entirely: no images are stripped. + Returns: + The index of the last message modified, or ``None`` if the + history was already within budget, or no budget is defined + (no change made). 
+ """ + budget = self._byte_budget() + if budget is None: + self._report( + "[byte budget] no request byte limit defined, " + "skipping byte-budget enforcement" + ) + return None + + current = estimate_messages_bytes(self._truncated_message_history) + n_images = _count_base64_images(self._truncated_message_history) + self._report( + f"[byte budget] start: current={current} bytes, " + f"budget={budget} bytes, images={n_images}, " + f"messages={len(self._truncated_message_history)}" + ) + + if current <= budget: + self._report( + f"[byte budget] within budget ({current} <= {budget}), " + "no images stripped" + ) + return None + + keep = _image_keep_count_for_byte_budget( + self._truncated_message_history, + budget, + current, + ) + self._report( + f"[byte budget] over budget ({current} > {budget}), " + f"keeping {keep}/{n_images} newest images, " + f"stripping {n_images - keep} oldest" + ) + + boundary = _strip_oldest_base64_images(self._truncated_message_history, keep) + new_size = estimate_messages_bytes(self._truncated_message_history) + if new_size > budget: + warn_msg = ( + f"Request still ~{new_size} bytes after stripping images " + f"(budget {budget}); non-image content exceeds the byte " + "budget. Relying on summarization to reduce it further." + ) + logger.warning(warn_msg) + self._report(warn_msg) + elif boundary is not None: + info_msg = ( + f"Stripped old images to meet byte budget: " + f"{current} -> {new_size} bytes (budget {budget})" + ) + logger.info(info_msg) + self._report(info_msg) + self._report( + f"[byte budget] done: new size={new_size} bytes, boundary index={boundary}" + ) + return boundary + def _capture_first_user_message(self, message: MessageParam) -> None: """Store the first user message if not already captured. @@ -514,119 +805,17 @@ def truncate(self) -> None: def _remove_images(self) -> None: """Strip old base64 images from truncated history. 
- Walks from the beginning and replaces excess base64 - `ImageBlockParam` blocks with text placeholders. Also - recurses into `ToolResultBlockParam.content` lists. - URL-based images are never stripped. + Keeps the newest `n_images_to_keep` base64 images and replaces + older ones with text placeholders, recursing into + `ToolResultBlockParam.content` lists. URL-based images are + never stripped. Updates `_image_removal_boundary_index` to the + last message modified. """ - total = self._count_base64_images(self._truncated_message_history) - to_remove = total - self._n_images_to_keep - if to_remove <= 0: - return - - removed = 0 - for i, msg in enumerate(self._truncated_message_history): - if removed >= to_remove: - break - if isinstance(msg.content, str): - continue - - new_content, removed_in_msg = self._strip_base64_images( - msg.content, to_remove - removed - ) - if removed_in_msg > 0: - self._truncated_message_history[i] = MessageParam( - role=msg.role, - content=new_content, - stop_reason=msg.stop_reason, - usage=msg.usage, - ) - self._image_removal_boundary_index = i - removed += removed_in_msg - - @staticmethod - def _count_base64_images( - messages: list[MessageParam], - ) -> int: - """Count total base64 image blocks across messages.""" - count = 0 - for msg in messages: - if isinstance(msg.content, str): - continue - for block in msg.content: - if isinstance(block, ImageBlockParam) and isinstance( - block.source, Base64ImageSourceParam - ): - count += 1 - elif isinstance(block, ToolResultBlockParam) and isinstance( - block.content, list - ): - for nested in block.content: - if isinstance(nested, ImageBlockParam) and isinstance( - nested.source, - Base64ImageSourceParam, - ): - count += 1 - return count - - @staticmethod - def _strip_base64_images( - content: list[ContentBlockParam], - max_to_strip: int, - ) -> tuple[list[ContentBlockParam], int]: - """Strip up to `max_to_strip` base64 images. - - Args: - content: The content blocks to process. 
- max_to_strip: Maximum number of images to strip. - - Returns: - Tuple of (new content list, count stripped). - """ - stripped = 0 - new_content: list[ContentBlockParam] = [] - - for block in content: - if stripped >= max_to_strip: - new_content.append(block) - continue - - if isinstance(block, ImageBlockParam) and isinstance( - block.source, Base64ImageSourceParam - ): - new_content.append(TextBlockParam(text=IMAGE_REMOVED_PLACEHOLDER)) - stripped += 1 - elif isinstance(block, ToolResultBlockParam) and isinstance( - block.content, list - ): - new_nested: list[TextBlockParam | ImageBlockParam] = [] - for nested in block.content: - if ( - stripped < max_to_strip - and isinstance(nested, ImageBlockParam) - and isinstance( - nested.source, - Base64ImageSourceParam, - ) - ): - new_nested.append( - TextBlockParam(text=IMAGE_REMOVED_PLACEHOLDER) - ) - stripped += 1 - else: - new_nested.append(nested) - new_content.append( - ToolResultBlockParam( - tool_use_id=block.tool_use_id, - content=new_nested, - is_error=block.is_error, - cache_control=block.cache_control, - ) - ) - else: - new_content.append(block) - - return new_content, stripped + boundary = _strip_oldest_base64_images( + self._truncated_message_history, self._n_images_to_keep + ) + if boundary is not None: + self._image_removal_boundary_index = boundary # ------------------------------------------------------------------ # Cache breakpoints @@ -673,11 +862,17 @@ class SummarizingTruncationStrategy(TruncationStrategy): """Truncation strategy that summarizes when limits are hit. Unlike `SlidingImageWindowSummarizingTruncationStrategy`, - this strategy does **not** strip images. It places a - single cache breakpoint on the last user message (moving - it forward on each append) and summarizes the conversation - history via the VLM when the token or message count - exceeds the configured threshold. + this strategy does **not** strip images on a sliding window. 
+ It places a single cache breakpoint on the last user message + (moving it forward on each append) and summarizes the + conversation history via the VLM when the token or message + count exceeds the configured threshold. + + Images are only stripped as a last-resort safeguard: if the + estimated serialized request size exceeds the byte budget + (`max_request_bytes` scaled by `request_size_threshold`), the oldest + base64 images are dropped — newest kept — to bring the request back + under that budget (and so under the endpoint's limit) before it is sent. Conversation-owned dependencies (``vlm_provider``, ``reporter``, ``callbacks``, ``conversation``) are auto-injected by @@ -694,6 +889,14 @@ class SummarizingTruncationStrategy(TruncationStrategy): endpoint. truncation_threshold: Fraction of `max_input_tokens` at which to truncate. + max_request_bytes: Hard cap on the serialized request size in + bytes. When ``None`` (default), it is read from the + conversation's `VlmProvider` (e.g. ~32 MB for Anthropic). If + no provider advertises a limit either, byte-budget enforcement + is skipped entirely (no images are stripped to meet a budget). + request_size_threshold: Fraction of `max_request_bytes` at + which the oldest base64 images are stripped to keep the + request under the endpoint's byte limit. vlm_provider: Optional override for the summarization VLM. When ``None`` (default), the conversation's ``vlm_provider`` is used. 
@@ -705,12 +908,16 @@ def __init__( max_messages: int = MAX_MESSAGES, max_input_tokens: int = MAX_INPUT_TOKENS, truncation_threshold: float = TRUNCATION_THRESHOLD, + max_request_bytes: int | None = None, + request_size_threshold: float = REQUEST_SIZE_THRESHOLD, vlm_provider: VlmProvider | None = None, ) -> None: super().__init__( max_messages, max_input_tokens, truncation_threshold, + max_request_bytes, + request_size_threshold, ) self.vlm_provider = vlm_provider self._n_messages_to_keep = n_messages_to_keep @@ -732,6 +939,11 @@ def append_message(self, message: MessageParam) -> None: self._full_message_history.append(message) self._truncated_message_history.append(message) + # Enforce the request byte budget first so the summarization + # request itself stays under the endpoint's hard byte limit. + # Strips oldest images only if the history exceeds the budget. + self._enforce_byte_budget() + # Move cache breakpoint to last user message self._move_cache_breakpoint() @@ -743,6 +955,12 @@ def append_message(self, message: MessageParam) -> None: or token_counts.total > self._absolute_truncation_threshold ): self.truncate() + # Summarization re-inserts the original first user message, + # which may carry a large image; re-enforce so the final + # history stays under budget, then re-place the breakpoint + # on the rebuilt history. + self._enforce_byte_budget() + self._move_cache_breakpoint() def _move_cache_breakpoint(self) -> None: """Place a cache breakpoint on the last user message. 
diff --git a/tests/unit/model_providers/test_openai_vlm_provider.py b/tests/unit/model_providers/test_openai_vlm_provider.py index d51ff74b..cd8b2b08 100644 --- a/tests/unit/model_providers/test_openai_vlm_provider.py +++ b/tests/unit/model_providers/test_openai_vlm_provider.py @@ -17,6 +17,12 @@ def test_pricing_returns_none(self) -> None: provider = OpenAIVlmProvider(model_id="gpt-4o", api_key="sk-test") assert provider.pricing is None + def test_max_request_bytes_none_by_default(self) -> None: + # OpenAI-compatible providers do not advertise a byte limit, so + # byte-budget enforcement is skipped for them. + provider = OpenAIVlmProvider(model_id="gpt-4o", api_key="sk-test") + assert provider.max_request_bytes is None + def test_injected_client_used(self) -> None: mock_client = MagicMock(spec=OpenAI) provider = OpenAIVlmProvider(model_id="gpt-4o", client=mock_client) diff --git a/tests/unit/model_providers/test_vlm_provider_request_bytes.py b/tests/unit/model_providers/test_vlm_provider_request_bytes.py new file mode 100644 index 00000000..d61c4893 --- /dev/null +++ b/tests/unit/model_providers/test_vlm_provider_request_bytes.py @@ -0,0 +1,18 @@ +"""Tests for the ``max_request_bytes`` limit advertised by VLM providers.""" + +from unittest.mock import MagicMock + +from askui.model_providers.anthropic_vlm_provider import AnthropicVlmProvider +from askui.model_providers.askui_vlm_provider import AskUIVlmProvider +from askui.models.shared.request_size import ANTHROPIC_MAX_REQUEST_BYTES + + +class TestProviderMaxRequestBytes: + def test_anthropic_provider_reports_anthropic_limit(self) -> None: + provider = AnthropicVlmProvider(client=MagicMock()) + assert provider.max_request_bytes == ANTHROPIC_MAX_REQUEST_BYTES + + def test_askui_provider_reports_anthropic_limit(self) -> None: + # Pass mock settings so construction does not validate env credentials. 
+ provider = AskUIVlmProvider(askui_settings=MagicMock(), client=MagicMock()) + assert provider.max_request_bytes == ANTHROPIC_MAX_REQUEST_BYTES diff --git a/tests/unit/models/test_truncation_strategies.py b/tests/unit/models/test_truncation_strategies.py index 37adee7f..4fed6f33 100644 --- a/tests/unit/models/test_truncation_strategies.py +++ b/tests/unit/models/test_truncation_strategies.py @@ -1,7 +1,10 @@ """Unit tests for truncation strategies.""" +import logging from unittest.mock import MagicMock +import pytest + from askui.callbacks.conversation_callback import ConversationCallback from askui.models.shared.agent_message_param import ( Base64ImageSourceParam, @@ -14,9 +17,13 @@ UrlImageSourceParam, UsageParam, ) +from askui.models.shared.request_size import ( + estimate_messages_bytes, +) from askui.models.shared.truncation_strategies import ( SlidingImageWindowSummarizingTruncationStrategy, SummarizingTruncationStrategy, + _image_keep_count_for_byte_budget, ) IMAGE_REMOVED_PLACEHOLDER = "[Screenshot removed to reduce message history length]" @@ -27,12 +34,17 @@ # --------------------------------------------------------------------------- -def _make_base64_image_block() -> ImageBlockParam: +def _make_base64_image_block(data: str = "abc123") -> ImageBlockParam: return ImageBlockParam( - source=Base64ImageSourceParam(data="abc123", media_type="image/png"), + source=Base64ImageSourceParam(data=data, media_type="image/png"), ) +def _make_sized_image_block(n_bytes: int) -> ImageBlockParam: + """Base64 image whose estimated byte size is ``n_bytes``.""" + return _make_base64_image_block(data="x" * n_bytes) + + def _make_url_image_block() -> ImageBlockParam: return ImageBlockParam( source=UrlImageSourceParam(url="https://example.com/img.png"), @@ -56,6 +68,10 @@ def _make_vlm_provider(usage: UsageParam | None = None) -> MagicMock: content="Summary of the conversation.", usage=usage, ) + # A bare MagicMock attribute is truthy; set None so byte-budget + # enforcement 
is skipped (no images stripped to meet a budget) unless + # a test configures an explicit limit. + provider.max_request_bytes = None return provider @@ -923,8 +939,15 @@ def test_summarizing_strategy_reports_summary_response(self) -> None: role = "user" if i % 2 == 0 else "assistant" strategy.append_message(MessageParam(role=role, content=f"msg {i}")) strategy.truncate() - reporter.add_message.assert_called_once() - call_args = reporter.add_message.call_args + # Byte-budget debug messages (plain strings) may also be reported; + # the summary response is the one dict payload. + summary_calls = [ + c + for c in reporter.add_message.call_args_list + if isinstance(c.args[1], dict) + ] + assert len(summary_calls) == 1 + call_args = summary_calls[0] assert call_args.args[0] == "TruncationStrategy" # Logged content is the raw VLM response dump assert call_args.args[1]["role"] == "assistant" @@ -942,8 +965,15 @@ def test_sliding_strategy_reports_summary_response(self) -> None: role = "user" if i % 2 == 0 else "assistant" strategy.append_message(MessageParam(role=role, content=f"msg {i}")) strategy.truncate() - reporter.add_message.assert_called_once() - call_args = reporter.add_message.call_args + # Byte-budget debug messages (plain strings) may also be reported; + # the summary response is the one dict payload. + summary_calls = [ + c + for c in reporter.add_message.call_args_list + if isinstance(c.args[1], dict) + ] + assert len(summary_calls) == 1 + call_args = summary_calls[0] assert call_args.args[0] == "TruncationStrategy" assert call_args.args[1]["content"] == "Summary of the conversation." 
@@ -1112,3 +1142,201 @@ def test_strategy_without_conversation_passes_none_for_context(self) -> None:
         assert call_kwargs["system"] is None
         assert call_kwargs["tools"] is None
         assert call_kwargs["provider_options"] is None
+
+
+# ---------------------------------------------------------------------------
+# Byte-budget enforcement
+# ---------------------------------------------------------------------------
+
+
+def _is_placeholder(block: ContentBlockParam) -> bool:
+    """Return True when ``block`` is the image-removed placeholder text."""
+    return isinstance(block, TextBlockParam) and block.text == IMAGE_REMOVED_PLACEHOLDER
+
+
+def _first_block(msg: MessageParam) -> ContentBlockParam:
+    """Return the first content block of ``msg`` (content must be a list)."""
+    assert isinstance(msg.content, list)
+    return msg.content[0]
+
+
+def _append_sized_images(
+    strategy: SummarizingTruncationStrategy,
+    count: int,
+    n_bytes: int = 500,
+) -> None:
+    """Append ``count`` single-image messages to ``strategy``.
+
+    Roles alternate starting with ``"user"``; each message carries one base64
+    image built by ``_make_sized_image_block(n_bytes)``.
+    """
+    for index in range(count):
+        role = "user" if index % 2 == 0 else "assistant"
+        strategy.append_message(
+            MessageParam(role=role, content=[_make_sized_image_block(n_bytes)])
+        )
+
+
+class TestByteBudgetHelpers:
+    """Direct tests of the keep-count math driving byte enforcement."""
+
+    def test_keep_all_when_under_budget(self) -> None:
+        msgs = [
+            MessageParam(role="user", content=[_make_sized_image_block(500)]),
+            MessageParam(role="user", content=[_make_sized_image_block(500)]),
+        ]
+        current = estimate_messages_bytes(msgs)
+        # Budget exactly equal to current => keep everything.
+        assert _image_keep_count_for_byte_budget(msgs, current, current) == 2
+
+    def test_drops_oldest_until_under_budget(self) -> None:
+        msgs = [
+            MessageParam(role="user", content=[_make_sized_image_block(500)]),
+            MessageParam(role="assistant", content=[_make_sized_image_block(500)]),
+            MessageParam(role="user", content=[_make_sized_image_block(500)]),
+        ]
+        current = estimate_messages_bytes(msgs)  # ~1500
+        # Only the newest image (~500) fits under 1000.
+        assert _image_keep_count_for_byte_budget(msgs, 1000, current) == 1
+
+    def test_drops_all_when_even_one_image_exceeds_budget(self) -> None:
+        msgs = [MessageParam(role="user", content=[_make_sized_image_block(500)])]
+        current = estimate_messages_bytes(msgs)
+        assert _image_keep_count_for_byte_budget(msgs, 100, current) == 0
+
+    def test_no_images_keeps_zero(self) -> None:
+        msgs = [MessageParam(role="user", content="x" * 500)]
+        assert _image_keep_count_for_byte_budget(msgs, 100, 500) == 0
+
+
+class TestSummarizingByteBudget:
+    """`SummarizingTruncationStrategy` strips oldest images on byte overflow."""
+
+    def _make(self, max_request_bytes: int) -> SummarizingTruncationStrategy:
+        """Build a strategy with ``max_request_bytes`` as its exact byte budget."""
+        return SummarizingTruncationStrategy(
+            vlm_provider=_make_vlm_provider(),
+            n_messages_to_keep=100,
+            # Keep token-based truncation out of the way.
+            max_input_tokens=10_000_000,
+            max_request_bytes=max_request_bytes,
+            # Treat max_request_bytes as the exact budget for these tests.
+            request_size_threshold=1.0,
+        )
+
+    def test_strips_oldest_images_when_over_budget(self) -> None:
+        strategy = self._make(max_request_bytes=1000)
+        _append_sized_images(strategy, 3)
+        msgs = strategy.truncated_messages
+        # Oldest two images replaced by placeholders, newest kept.
+        assert _is_placeholder(_first_block(msgs[0]))
+        assert _is_placeholder(_first_block(msgs[1]))
+        assert isinstance(_first_block(msgs[2]), ImageBlockParam)
+
+    def test_final_history_within_budget(self) -> None:
+        strategy = self._make(max_request_bytes=1000)
+        _append_sized_images(strategy, 5)
+        assert estimate_messages_bytes(strategy.truncated_messages) <= 1000
+
+    def test_no_stripping_when_under_budget(self) -> None:
+        strategy = self._make(max_request_bytes=10_000)
+        _append_sized_images(strategy, 3)
+        for msg in strategy.truncated_messages:
+            assert isinstance(_first_block(msg), ImageBlockParam)
+
+    def test_full_messages_keep_original_images(self) -> None:
+        strategy = self._make(max_request_bytes=1000)
+        _append_sized_images(strategy, 3)
+        # Full (append-only) history must retain every original image.
+        for msg in strategy.full_messages:
+            assert isinstance(_first_block(msg), ImageBlockParam)
+
+    def test_url_images_not_stripped(self) -> None:
+        strategy = self._make(max_request_bytes=1)
+        strategy.append_message(
+            MessageParam(role="user", content=[_make_url_image_block()])
+        )
+        # URL images carry no payload bytes and must never be stripped.
+        assert isinstance(_first_block(strategy.truncated_messages[0]), ImageBlockParam)
+
+    def test_warns_when_non_image_content_exceeds_budget(
+        self, caplog: pytest.LogCaptureFixture
+    ) -> None:
+        strategy = self._make(max_request_bytes=100)
+        with caplog.at_level(logging.WARNING):
+            strategy.append_message(MessageParam(role="user", content="x" * 500))
+        # Cannot strip text; logs a warning but preserves the message.
+        assert any("byte budget" in message for message in caplog.messages)
+        assert strategy.truncated_messages[0].content == "x" * 500
+
+
+class TestRequestSizeThreshold:
+    """Stripping kicks in at ``max_request_bytes * request_size_threshold``."""
+
+    def test_strips_at_threshold_not_hard_limit(self) -> None:
+        # Hard limit 1000, threshold 0.8 => effective budget 800.
+        strategy = SummarizingTruncationStrategy(
+            vlm_provider=_make_vlm_provider(),
+            n_messages_to_keep=100,
+            max_input_tokens=10_000_000,
+            max_request_bytes=1000,
+            request_size_threshold=0.8,
+        )
+        # Two 500-byte images = ~1000 bytes: under the hard limit but over
+        # the 800-byte threshold, so the oldest image must be stripped.
+        _append_sized_images(strategy, 2)
+        msgs = strategy.truncated_messages
+        assert _is_placeholder(_first_block(msgs[0]))
+        assert isinstance(_first_block(msgs[1]), ImageBlockParam)
+        assert estimate_messages_bytes(msgs) <= 800
+
+    def test_default_threshold_is_80_percent(self) -> None:
+        strategy = SummarizingTruncationStrategy(max_request_bytes=1000)
+        assert strategy._request_size_threshold == 0.8  # noqa: SLF001
+        assert strategy._byte_budget() == 800  # noqa: SLF001
+
+
+class TestByteBudgetResolution:
+    """The byte limit is sourced from the provider unless overridden."""
+
+    def test_explicit_override_wins(self) -> None:
+        strategy = SummarizingTruncationStrategy(
+            max_request_bytes=1234,
+            request_size_threshold=1.0,
+        )
+        assert strategy._resolve_max_request_bytes() == 1234  # noqa: SLF001
+
+    def test_reads_limit_from_conversation_provider(self) -> None:
+        provider = MagicMock()
+        provider.max_request_bytes = 5_000_000
+        conversation = MagicMock()
+        conversation.vlm_provider = provider
+
+        strategy = SummarizingTruncationStrategy(request_size_threshold=1.0)
+        strategy.conversation = conversation
+        assert strategy._resolve_max_request_bytes() == 5_000_000  # noqa: SLF001
+
+    def test_resolves_to_none_when_provider_has_no_limit(self) -> None:
+        provider = MagicMock()
+        provider.max_request_bytes = None
+        conversation = MagicMock()
+        conversation.vlm_provider = provider
+
+        strategy = SummarizingTruncationStrategy(request_size_threshold=1.0)
+        strategy.conversation = conversation
+        assert strategy._resolve_max_request_bytes() is None  # noqa: SLF001
+        assert strategy._byte_budget() is None  # noqa: SLF001
+
+    def test_no_images_stripped_when_no_limit_defined(self) -> None:
+        # No explicit override and provider advertises no limit: byte-budget
+        # enforcement is skipped, so even large images are kept verbatim.
+        provider = _make_vlm_provider()  # max_request_bytes is None
+        strategy = SummarizingTruncationStrategy(
+            vlm_provider=provider,
+            n_messages_to_keep=100,
+            max_input_tokens=10_000_000,
+        )
+        _append_sized_images(strategy, 2, n_bytes=5000)
+        msgs = strategy.truncated_messages
+        assert isinstance(_first_block(msgs[0]), ImageBlockParam)
+        assert isinstance(_first_block(msgs[1]), ImageBlockParam)