Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 68 additions & 42 deletions src/askui/models/shared/truncation_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
logger = logging.getLogger(__name__)

# needs to be below limits imposed by endpoint
MAX_INPUT_TOKENS = 100_000
MAX_INPUT_TOKENS = 200_000

# we will truncate as soon as we reach this threshold
TRUNCATION_THRESHOLD = 0.7
TRUNCATION_THRESHOLD = 0.56
Comment thread
mlikasam-askui marked this conversation as resolved.

# see https://docs.anthropic.com/en/api/messages#body-messages
MAX_MESSAGES = 100_000
Expand Down Expand Up @@ -222,6 +222,7 @@ def __init__(
) -> None:
self._full_message_history: list[MessageParam] = []
self._truncated_message_history: list[MessageParam] = []
self._first_user_message: MessageParam | None = None
self._max_messages = max_messages
self._absolute_truncation_threshold = int(
max_input_tokens * truncation_threshold
Expand Down Expand Up @@ -261,6 +262,60 @@ def truncate(self) -> None:
"""Force-truncate the message history."""
...

def _capture_first_user_message(self, message: MessageParam) -> None:
"""Store the first user message if not already captured.

Args:
message: The message to check and potentially store.
"""
if self._first_user_message is None and message.role == "user":
self._first_user_message = message

def _build_truncated_messages(
self,
summary: str,
recent: list[MessageParam],
) -> list[MessageParam]:
"""Build the new truncated history with the first user message preserved.

Prepends the original first user message (if captured),
adds the summary, ensures valid role alternation, and
appends the recent messages.

Args:
summary: The LLM-generated conversation summary.
recent: The most-recent messages to keep verbatim.

Returns:
The assembled message list.
"""
new_messages: list[MessageParam] = []

if self._first_user_message is not None:
new_messages.append(self._first_user_message)
new_messages.append(
MessageParam(
role="assistant",
content="Understood. I'll keep your original instructions in mind.",
)
)

new_messages.append(MessageParam(role="user", content=summary))

if recent and recent[0].role == "user":
new_messages.append(
MessageParam(
role="assistant",
content=(
"Understood. I'll continue based on "
"the conversation summary above."
),
)
)

new_messages.extend(recent)
return new_messages

def reset(self, messages: list[MessageParam] | None = None) -> None:
"""Reset message histories with optional initial messages.

Expand All @@ -271,9 +326,12 @@ def reset(self, messages: list[MessageParam] | None = None) -> None:
messages: Initial messages to populate both histories.
If ``None``, both histories are cleared.
"""
self._first_user_message = None
if messages is not None:
self._full_message_history = list(messages)
self._truncated_message_history = list(messages)
for m in messages:
self._capture_first_user_message(m)
else:
self._full_message_history = []
self._truncated_message_history = []
Expand Down Expand Up @@ -360,6 +418,7 @@ def append_message(self, message: MessageParam) -> None:
Args:
message: The message to append.
"""
self._capture_first_user_message(message)
self._full_message_history.append(message)
self._truncated_message_history.append(message)

Expand Down Expand Up @@ -443,27 +502,9 @@ def truncate(self) -> None:
return

recent = self._truncated_message_history[cut:]

# Build new history with the summary as a user message
new_messages: list[MessageParam] = [
MessageParam(role="user", content=summary),
]

# Ensure valid role alternation: if first recent message
# is also "user", insert a synthetic assistant ack.
if recent and recent[0].role == "user":
new_messages.append(
MessageParam(
role="assistant",
content=(
"Understood. I'll continue based on "
"the conversation summary above."
),
)
)

new_messages.extend(recent)
self._truncated_message_history = new_messages
self._truncated_message_history = self._build_truncated_messages(
summary, recent
)
self._image_removal_boundary_index = None

# ------------------------------------------------------------------
Expand Down Expand Up @@ -687,6 +728,7 @@ def append_message(self, message: MessageParam) -> None:
Args:
message: The message to append.
"""
self._capture_first_user_message(message)
self._full_message_history.append(message)
self._truncated_message_history.append(message)

Expand Down Expand Up @@ -768,22 +810,6 @@ def truncate(self) -> None:
return

recent = self._truncated_message_history[cut:]

new_messages: list[MessageParam] = [
MessageParam(role="user", content=summary),
]

# Ensure valid role alternation
if recent and recent[0].role == "user":
new_messages.append(
MessageParam(
role="assistant",
content=(
"Understood. I'll continue based on "
"the conversation summary above."
),
)
)

new_messages.extend(recent)
self._truncated_message_history = new_messages
self._truncated_message_history = self._build_truncated_messages(
summary, recent
)
142 changes: 123 additions & 19 deletions tests/unit/models/test_truncation_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,13 @@ def test_truncate_replaces_history_with_summary(self) -> None:
# Force truncation
strategy.truncate()
msgs = strategy.truncated_messages
# First message should be the summary (user role)
# First message is the preserved original first user message
assert msgs[0].role == "user"
assert msgs[0].content == "Summary of the conversation."
assert msgs[0].content == "msg 0"
# Then assistant ack, then summary
assert msgs[1].role == "assistant"
assert msgs[2].role == "user"
assert msgs[2].content == "Summary of the conversation."
# Last 2 messages preserved
assert msgs[-1].content == "msg 5"
assert msgs[-2].content == "msg 4"
Expand All @@ -335,12 +339,14 @@ def test_truncate_inserts_synthetic_assistant_for_alternation(self) -> None:
strategy.append_message(MessageParam(role=role, content=f"msg {i}"))
strategy.truncate()
msgs = strategy.truncated_messages
# Summary (user) -> msgs[-2] is "msg 4" (user)
# So a synthetic assistant should be inserted between
assert msgs[0].role == "user" # summary
assert msgs[1].role == "assistant" # synthetic
assert "Understood" in str(msgs[1].content)
assert msgs[2].role == "user" # msg 4
# First user message preserved, then ack, then summary
assert msgs[0].role == "user" # original first user message
assert msgs[0].content == "msg 0"
assert msgs[1].role == "assistant" # ack for first user message
assert msgs[2].role == "user" # summary
assert msgs[3].role == "assistant" # synthetic ack for alternation
assert "Understood" in str(msgs[3].content)
assert msgs[4].role == "user" # msg 4

def test_truncate_skips_when_too_few_messages(self) -> None:
strategy = _make_strategy(n_messages_to_keep=10)
Expand Down Expand Up @@ -382,14 +388,14 @@ def test_truncate_resets_image_boundary(self) -> None:
def test_full_messages_preserved_after_truncation(self) -> None:
vlm = _make_vlm_provider()
strategy = _make_strategy(vlm_provider=vlm, n_messages_to_keep=2)
for i in range(6):
for i in range(10):
role = "user" if i % 2 == 0 else "assistant"
strategy.append_message(MessageParam(role=role, content=f"msg {i}"))
strategy.truncate()
# Full messages should still have all 6
assert len(strategy.full_messages) == 6
# Full messages should still have all 10
assert len(strategy.full_messages) == 10
# Truncated messages should be shorter
assert len(strategy.truncated_messages) < 6
assert len(strategy.truncated_messages) < 10

def test_truncate_preserves_tool_use_tool_result_pairs(self) -> None:
vlm = _make_vlm_provider()
Expand Down Expand Up @@ -507,6 +513,91 @@ def test_truncate_deferred_when_last_message_has_tool_use(self) -> None:
vlm.create_message.assert_called_once()


# ---------------------------------------------------------------------------
# First user message preservation
# ---------------------------------------------------------------------------


class TestFirstUserMessagePreservation:
"""Both strategies must always keep the original first user message."""

def test_sliding_preserves_first_user_message(self) -> None:
vlm = _make_vlm_provider()
strategy = _make_strategy(vlm_provider=vlm, n_messages_to_keep=2)
for i in range(6):
role = "user" if i % 2 == 0 else "assistant"
strategy.append_message(MessageParam(role=role, content=f"msg {i}"))
strategy.truncate()
msgs = strategy.truncated_messages
assert msgs[0].role == "user"
assert msgs[0].content == "msg 0"

def test_summarizing_preserves_first_user_message(self) -> None:
vlm = _make_vlm_provider()
strategy = _make_summarizing_strategy(vlm_provider=vlm, n_messages_to_keep=2)
for i in range(6):
role = "user" if i % 2 == 0 else "assistant"
strategy.append_message(MessageParam(role=role, content=f"msg {i}"))
strategy.truncate()
msgs = strategy.truncated_messages
assert msgs[0].role == "user"
assert msgs[0].content == "msg 0"

def test_first_user_message_survives_multiple_truncations(self) -> None:
vlm = _make_vlm_provider()
strategy = _make_summarizing_strategy(vlm_provider=vlm, n_messages_to_keep=2)
for i in range(6):
role = "user" if i % 2 == 0 else "assistant"
strategy.append_message(MessageParam(role=role, content=f"msg {i}"))
strategy.truncate()
# Add more messages and truncate again
for i in range(6, 12):
role = "user" if i % 2 == 0 else "assistant"
strategy.append_message(MessageParam(role=role, content=f"msg {i}"))
strategy.truncate()
msgs = strategy.truncated_messages
# Original first user message must still be at position 0
assert msgs[0].role == "user"
assert msgs[0].content == "msg 0"

def test_first_user_message_captured_from_reset(self) -> None:
vlm = _make_vlm_provider()
strategy = _make_summarizing_strategy(vlm_provider=vlm, n_messages_to_keep=2)
initial_msgs = [
MessageParam(role="user", content="initial task"),
MessageParam(role="assistant", content="ok"),
]
strategy.reset(initial_msgs)
for i in range(6):
role = "user" if i % 2 == 0 else "assistant"
strategy.append_message(MessageParam(role=role, content=f"msg {i}"))
strategy.truncate()
msgs = strategy.truncated_messages
assert msgs[0].role == "user"
assert msgs[0].content == "initial task"

def test_first_user_message_cleared_on_none_reset(self) -> None:
strategy = _make_summarizing_strategy()
strategy.append_message(MessageParam(role="user", content="first"))
strategy.reset()
assert strategy._first_user_message is None # noqa: SLF001

def test_role_alternation_valid_after_truncation(self) -> None:
"""Verify user/assistant roles alternate correctly after truncation."""
vlm = _make_vlm_provider()
strategy = _make_summarizing_strategy(vlm_provider=vlm, n_messages_to_keep=2)
for i in range(6):
role = "user" if i % 2 == 0 else "assistant"
strategy.append_message(MessageParam(role=role, content=f"msg {i}"))
strategy.truncate()
msgs = strategy.truncated_messages
for i in range(len(msgs) - 1):
assert msgs[i].role != msgs[i + 1].role, (
f"Adjacent messages at {i} and {i + 1} have the same role: "
f"{msgs[i].role}"
)


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -670,8 +761,13 @@ def test_truncate_replaces_history_with_summary(self) -> None:
strategy.append_message(MessageParam(role=role, content=f"msg {i}"))
strategy.truncate()
msgs = strategy.truncated_messages
# First message is the preserved original first user message
assert msgs[0].role == "user"
assert msgs[0].content == "Summary of the conversation."
assert msgs[0].content == "msg 0"
# Then assistant ack, then summary
assert msgs[1].role == "assistant"
assert msgs[2].role == "user"
assert msgs[2].content == "Summary of the conversation."
assert msgs[-1].content == "msg 5"
assert msgs[-2].content == "msg 4"

Expand All @@ -683,9 +779,13 @@ def test_truncate_inserts_synthetic_assistant(self) -> None:
strategy.append_message(MessageParam(role=role, content=f"msg {i}"))
strategy.truncate()
msgs = strategy.truncated_messages
# First user message preserved, then ack, then summary
assert msgs[0].role == "user"
assert msgs[0].content == "msg 0"
assert msgs[1].role == "assistant"
assert "Understood" in str(msgs[1].content)
assert msgs[2].role == "user" # summary
assert msgs[3].role == "assistant" # synthetic ack
assert "Understood" in str(msgs[3].content)

def test_truncate_skips_when_too_few_messages(self) -> None:
strategy = _make_summarizing_strategy(n_messages_to_keep=10)
Expand All @@ -698,12 +798,12 @@ def test_truncate_skips_when_too_few_messages(self) -> None:
def test_full_messages_preserved_after_truncation(self) -> None:
vlm = _make_vlm_provider()
strategy = _make_summarizing_strategy(vlm_provider=vlm, n_messages_to_keep=2)
for i in range(6):
for i in range(10):
role = "user" if i % 2 == 0 else "assistant"
strategy.append_message(MessageParam(role=role, content=f"msg {i}"))
strategy.truncate()
assert len(strategy.full_messages) == 6
assert len(strategy.truncated_messages) < 6
assert len(strategy.full_messages) == 10
assert len(strategy.truncated_messages) < 10

def test_preserves_tool_use_tool_result_pairs(self) -> None:
vlm = _make_vlm_provider()
Expand Down Expand Up @@ -858,7 +958,9 @@ def test_strategy_does_not_report_when_no_reporter(self) -> None:
strategy.append_message(MessageParam(role=role, content=f"msg {i}"))
# Should not crash even though no reporter is set
strategy.truncate()
assert strategy.truncated_messages[0].content == "Summary of the conversation."
# First message is preserved original, summary is at index 2
assert strategy.truncated_messages[0].content == "msg 0"
assert strategy.truncated_messages[2].content == "Summary of the conversation."


class TestCallbackIntegration:
Expand Down Expand Up @@ -926,7 +1028,9 @@ def test_strategy_no_callbacks_no_crash(self) -> None:
role = "user" if i % 2 == 0 else "assistant"
strategy.append_message(MessageParam(role=role, content=f"msg {i}"))
strategy.truncate()
assert strategy.truncated_messages[0].content == "Summary of the conversation."
# First message is preserved original, summary is at index 2
assert strategy.truncated_messages[0].content == "msg 0"
assert strategy.truncated_messages[2].content == "Summary of the conversation."


class TestSummarizationRequestContext:
Expand Down
Loading