Skip to content

Commit 7beea3a

Browse files
authored
feat(memory): implement core memory and agentic components (#67)
1 parent b2e97c2 commit 7beea3a

11 files changed

Lines changed: 2089 additions & 0 deletions

File tree

AGENTS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,8 @@ If you are implementing a new feature, please implement the unit test and exampl
3535

3636
- For unit test, add in `tests/<module_name>`, and inherit the `unittest.TestCase` class.
3737
- For example, add in `examples/<module_name>`, and just demo the simple usage. (do not add too many use cases in single file)
38+
39+
## Comment Style
40+
41+
- All comments should be in English.
42+
- All comments should be in the Google style.
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""Minimal example showing how to build a memory timeline."""
2+
3+
from quantmind.brain.memory import Memory
4+
from quantmind.models.memory import ActionStep, TaskStep, ToolCall
5+
from quantmind.models.messages import ChatMessage, MessageRole
6+
from quantmind.utils.monitoring import Timing, TokenUsage
7+
8+
9+
def main():
10+
"""Main function."""
11+
memory = Memory("You are a quantitative research assistant.")
12+
13+
memory.steps.append(TaskStep(task="Gather the latest market sentiment."))
14+
15+
memory.steps.append(
16+
ActionStep(
17+
step_number=1,
18+
timing=Timing(start_time=0.0, end_time=0.5),
19+
model_input_messages=[
20+
ChatMessage(
21+
role=MessageRole.USER,
22+
content=[
23+
{"type": "text", "text": "Any updates on bond markets?"}
24+
],
25+
)
26+
],
27+
tool_calls=[
28+
ToolCall(
29+
name="fetch_sentiment",
30+
arguments={"asset": "treasury", "lookback": "1d"},
31+
id="call-1",
32+
)
33+
],
34+
model_output="Sentiment looks neutral across regions.",
35+
observations="Tool returned neutral scores.",
36+
token_usage=TokenUsage(input_tokens=12, output_tokens=9),
37+
)
38+
)
39+
40+
print("Succinct steps:")
41+
for step in memory.get_succinct_steps():
42+
print(step)
43+
44+
print("\nMessages replay:")
45+
messages = []
46+
messages.extend(memory.system_prompt.to_messages())
47+
for step in memory.steps:
48+
messages.extend(step.to_messages())
49+
50+
# Define colors for different message roles
51+
ROLE_COLORS = {
52+
MessageRole.SYSTEM: "\033[35m", # Magenta
53+
MessageRole.USER: "\033[32m", # Green
54+
MessageRole.ASSISTANT: "\033[36m", # Cyan
55+
MessageRole.TOOL_CALL: "\033[33m", # Yellow
56+
MessageRole.TOOL_RESPONSE: "\033[34m", # Blue
57+
}
58+
RESET = "\033[0m"
59+
60+
for message in messages:
61+
color = ROLE_COLORS.get(message.role, "")
62+
print(f"{color}[{message.role.value}]{RESET} {message.content}")
63+
64+
65+
if __name__ == "__main__":
66+
main()

quantmind/brain/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .memory import Memory
2+
3+
__all__ = ["Memory"]

quantmind/brain/memory.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import inspect
2+
from logging import getLogger
3+
from typing import Callable, Type
4+
5+
from quantmind.models.memory import (
6+
ActionStep,
7+
MemoryStep,
8+
PlanningStep,
9+
SystemPromptStep,
10+
TaskStep,
11+
)
12+
from quantmind.utils.monitoring import AgentLogger, LogLevel
13+
14+
logger = getLogger(__name__)
15+
16+
17+
class Memory:
18+
"""Memory for the brain, containing the system prompt and all steps taken by the brain.
19+
20+
This class is used to store the agent's steps, including tasks, actions, and planning steps.
21+
It allows for resetting the memory, retrieving succinct or full step information, and replaying
22+
the agent's steps.
23+
24+
Args:
25+
system_prompt (`str`): System prompt for the agent, which sets the context and instructions
26+
for the agent's behavior.
27+
28+
**Attributes**:
29+
- **system_prompt** (`SystemPromptStep`) -- System prompt step for the agent.
30+
- **steps** (`list[TaskStep | ActionStep | PlanningStep]`) -- List of steps taken by the
31+
agent, which can include tasks, actions, and planning steps.
32+
"""
33+
34+
def __init__(self, system_prompt: str):
35+
self.system_prompt: SystemPromptStep = SystemPromptStep(
36+
system_prompt=system_prompt
37+
)
38+
self.steps: list[TaskStep | ActionStep | PlanningStep] = []
39+
40+
def reset(self):
41+
"""Reset the agent's memory, clearing all steps and keeping the system prompt."""
42+
self.steps = []
43+
44+
def get_succinct_steps(self) -> list[dict]:
45+
"""Return a succinct representation of the agent's steps, excluding model input messages."""
46+
return [
47+
{
48+
key: value
49+
for key, value in step.dict().items()
50+
if key != "model_input_messages"
51+
}
52+
for step in self.steps
53+
]
54+
55+
def get_full_steps(self) -> list[dict]:
56+
"""Return a full representation of the agent's steps, including model input messages."""
57+
if len(self.steps) == 0:
58+
return []
59+
return [step.dict() for step in self.steps]
60+
61+
def replay(self, logger: AgentLogger, detailed: bool = False):
62+
"""Prints a pretty replay of the agent's steps.
63+
64+
Args:
65+
logger (`AgentLogger`): The logger to print replay logs to.
66+
detailed (`bool`, default `False`): If True, also displays the memory at each step.
67+
Defaults to False.
68+
Careful: will increase log length exponentially. Use only for debugging.
69+
"""
70+
logger.console.log("Replaying the agent's steps:")
71+
logger.log_markdown(
72+
title="System prompt",
73+
content=self.system_prompt.system_prompt,
74+
level=LogLevel.ERROR,
75+
)
76+
for step in self.steps:
77+
if isinstance(step, TaskStep):
78+
logger.log_task(step.task, "", level=LogLevel.ERROR)
79+
elif isinstance(step, ActionStep):
80+
logger.log_rule(
81+
f"Step {step.step_number}", level=LogLevel.ERROR
82+
)
83+
if detailed and step.model_input_messages is not None:
84+
logger.log_messages(
85+
step.model_input_messages, level=LogLevel.ERROR
86+
)
87+
if step.model_output is not None:
88+
logger.log_markdown(
89+
title="Agent output:",
90+
content=step.model_output,
91+
level=LogLevel.ERROR,
92+
)
93+
elif isinstance(step, PlanningStep):
94+
logger.log_rule("Planning step", level=LogLevel.ERROR)
95+
if detailed and step.model_input_messages is not None:
96+
logger.log_messages(
97+
step.model_input_messages, level=LogLevel.ERROR
98+
)
99+
logger.log_markdown(
100+
title="Agent output:",
101+
content=step.plan,
102+
level=LogLevel.ERROR,
103+
)
104+
105+
def return_full_code(self) -> str:
106+
"""Returns all code actions from the agent's steps, concatenated as a single script."""
107+
return "\n\n".join(
108+
[
109+
step.code_action
110+
for step in self.steps
111+
if isinstance(step, ActionStep) and step.code_action is not None
112+
]
113+
)
114+
115+
116+
class CallbackRegistry:
117+
"""Registry for callbacks that are called at each step of the agent's execution.
118+
119+
Callbacks are registered by passing a step class and a callback function.
120+
"""
121+
122+
def __init__(self):
123+
self._callbacks: dict[Type[MemoryStep], list[Callable]] = {}
124+
125+
def register(self, step_cls: Type[MemoryStep], callback: Callable):
126+
"""Register a callback for a step class.
127+
128+
Args:
129+
step_cls (Type[MemoryStep]): Step class to register the callback for.
130+
callback (Callable): Callback function to register.
131+
"""
132+
if step_cls not in self._callbacks:
133+
self._callbacks[step_cls] = []
134+
self._callbacks[step_cls].append(callback)
135+
136+
def callback(self, memory_step, **kwargs):
137+
"""Call callbacks registered for a step type.
138+
139+
Args:
140+
memory_step (MemoryStep): Step to call the callbacks for.
141+
**kwargs: Additional arguments to pass to callbacks that accept them.
142+
Typically, includes the agent instance.
143+
144+
Notes:
145+
For backwards compatibility, callbacks with a single parameter signature
146+
receive only the memory_step, while callbacks with multiple parameters
147+
receive both the memory_step and any additional kwargs.
148+
"""
149+
# For compatibility with old callbacks that only take the step as an argument
150+
for cls in memory_step.__class__.__mro__:
151+
for cb in self._callbacks.get(cls, []):
152+
cb(memory_step) if len(
153+
inspect.signature(cb).parameters
154+
) == 1 else cb(memory_step, **kwargs)

0 commit comments

Comments
 (0)