diff --git a/agents/agents/agents/agent_base.py b/agents/agents/agents/agent_base.py index c2a8639..0d712ce 100644 --- a/agents/agents/agents/agent_base.py +++ b/agents/agents/agents/agent_base.py @@ -14,6 +14,7 @@ import os import transformers import warnings +from .chain.streaming_observer import ConsoleStreamObserver, StreamingManager try: from verl.protocol import DataProto except ImportError: @@ -43,6 +44,7 @@ def __init__( log_file: str = "agent", project_name: str = None, run_name: str = None, + streaming: str = "console", **kwargs # To pass other unused arguments ): """ @@ -68,6 +70,12 @@ def __init__( self.jinja_template = get_template(self.template).jinja_template() self.project_name = project_name self.run_name = run_name + self.streaming_manager = StreamingManager() + if streaming == "console": + self.streaming_manager.add_observer(ConsoleStreamObserver()) + else: + # TODO: Support other streaming modes + raise ValueError(f"Streaming mode {streaming} is not supported.") super().__init__() if kwargs: warnings.warn(f"Unused arguments for agent initialization: {kwargs}") @@ -118,6 +126,27 @@ async def generate_async(self, messages_list_or_inputs: List[List[Dict]], **args List of responses. """ return await self.llm_engine.generate_async(messages_list_or_inputs, **args) + + async def generate_streaming(self, messages_list_or_inputs: List[List[Dict]], streaming_callback=None, **args): + """ + Generate responses with streaming support. This method yields response chunks as they are generated. + + Args: + messages_list_or_inputs: List of messages to generate responses for. + streaming_callback: Optional callback function for streaming chunks. + **args: Additional arguments for generation. + + Yields: + str: Response chunks as they are generated. + """ + if hasattr(self.llm_engine, 'generate_streaming'): + async for chunk in self.llm_engine.generate_streaming(messages_list_or_inputs, streaming_callback=streaming_callback, **args): + yield chunk + else: + # Fallback to non-streaming generation + responses = await self.generate_async(messages_list_or_inputs, **args) + for response in responses: + yield response @property def timing_data(self): diff --git a/agents/agents/agents/chain/chain_base.py b/agents/agents/agents/chain/chain_base.py index 0ee77be..61cd9d7 100644 --- a/agents/agents/agents/chain/chain_base.py +++ b/agents/agents/agents/chain/chain_base.py @@ -2,8 +2,9 @@ from collections import defaultdict from dataclasses import dataclass, field import json +import time from ...utils.timing import Timer -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Callable import uuid from termcolor import colored import numpy as np @@ -13,6 +14,8 @@ from ...utils.monitor import JsonlSink, MetricEvent, Monitor, WandbSink, emit, serialize_for_json from ... import AGENT_DATA_DIR import wandb +from .streaming_observer import ConsoleStreamObserver, StreamingManager, StreamEvent, StreamEventType + @dataclass class Node: is_terminal: bool = False @@ -210,40 +213,35 @@ def prepare_chain_messages(self, start_messages: Union[List[dict], np.ndarray]): other_info_list.append(info) return messages_list, other_info_list + + def validate_run_args(self, max_steps: int, num_chains: int): + assert max_steps >= 1, "max_steps must be at least 1." + assert num_chains >= 1, "num_chains must be at least 1." + for observer in self.streaming_manager.observers: + if isinstance(observer, ConsoleStreamObserver): + assert num_chains == 1, "num_chains must be 1 when ConsoleStreamObserver is used." + async def run_async(self, max_steps: int, start_messages: Union[List[dict], np.ndarray], num_chains: int, - generation_config: Optional[Dict[str, Any]] = None + generation_config: Optional[Dict[str, Any]] = None, + enable_streaming: bool = False, + streaming_callback: Optional[Callable] = None, ): """ - Run the chain-based rollout. + Run the chain-based rollout with optional streaming support. Args: max_steps: Maximum number of steps for each chain. - start_messages: List of messages to start the chains. Each message should be a dict - with "messages" key containing a list of message dictionaries. + start_messages: List of messages to start the chains. num_chains: Number of chains to run for each message. generation_config: Generation configuration dictionary. - - Example: - .. code-block:: python - - start_messages = [ - { - "messages": [{"role": "user", "content": "..."}], - "info": {"question": "..."}, - "answer": "..." - }, - { - "messages": [{"role": "user", "content": "..."}], - "info": {"question": "..."}, - "answer": "..." - } - ] + enable_streaming: Whether to enable streaming mode. + streaming_callback: Optional callback for streaming events. """ - assert max_steps >= 1, "max_steps must be at least 1." + self.validate_run_args(max_steps, num_chains) Monitor.ensure_started() self.reset() messages_list, other_info_list = self.prepare_chain_messages(start_messages) @@ -257,27 +255,19 @@ async def run_async(self, done_q = asyncio.Queue() tasks = [ asyncio.create_task( - self._run_chain_async( - cid, - node, - chains[cid], - tool_schemas, - max_steps=max_steps, - done_queue=done_q) + self._run_single_chain( + cid, + node, + chains[cid], + tool_schemas, + max_steps=max_steps, + done_queue=done_q, + enable_streaming=enable_streaming + ) ) for cid, node in first_nodes.items() ] - # Throttle the number of concurrent chains - # print([tool.parallel_size for tool in self.tools]) - - # minimal_tool_parallel_size = 1 - # sem = asyncio.Semaphore(minimal_tool_parallel_size) - # async def guarded_run(cid, *args): - # async with sem: - # return await self._run_chain_async(cid, *args) - # tasks = [guarded_run(cid, node, chains[cid], max_steps, done_q) for cid, node in first_nodes.items()] - # await asyncio.gather(*tasks) await tqdm_asyncio.gather(*tasks) self.chains = {} @@ -289,29 +279,31 @@ async def run_async(self, self.global_step += 1 self.monitor_step() - async def _run_chain_async(self, + async def _run_single_chain(self, chain_id: str, first_node: Node, chain: Chain, tools: List[Dict], max_steps: int, - done_queue: asyncio.Queue + done_queue: asyncio.Queue, + enable_streaming: bool = False ): """ - Drives *one* trajectory until it terminates or max_steps is reached. - Writes (chain_id, chain) to done_queue when finished. + Run a single chain with optional streaming support. """ current_node = first_node depth = 0 - final_result = None have_set_tools = False while not current_node.is_terminal and depth < max_steps: - newest_messages = current_node.messages + newest_messages = deepcopy(current_node.messages) + if not current_node.is_terminal: - responses = await self.generate_async([current_node.messages], tools=tools, num_return_sequences=1) - new_msg = self.parse(responses, self.tools) - new_msg = new_msg[0] + # Generate response + new_msg = await self._generate_response( + current_node, tools, depth, chain_id, enable_streaming + ) + newest_messages.append(new_msg) thought_node = chain.add_node( type="Thought", @@ -321,47 +313,186 @@ async def _run_chain_async(self, thought_node.is_terminal = new_msg.get("status", "continue") in self.terminal_status current_node = thought_node + # Handle tool calls if current_node.messages[-1].get("tool_calls"): for tool_call in current_node.messages[-1]["tool_calls"]: - tool_name = tool_call["function"]["name"] - tool_input = tool_call["function"]["arguments"] - action_node = chain.add_node( - type="Action", - messages=deepcopy(newest_messages), - description=tool_name + current_node = await self._execute_tool_call( + tool_call, newest_messages, chain, chain_id, depth, + have_set_tools, enable_streaming ) - if not have_set_tools: - await self.set_tools(chain_id, chain.info) - have_set_tools = True - - result = await submit_tool_call(tool_name, tool_input, id=chain_id) - final_result = result - action_input_node = chain.add_node( - type="Action Input", - messages=deepcopy(newest_messages), - description=result.get("arguments", "") - ) - observation = result["observation"] - observation_json = json.dumps({ - "name": result["name"], - "content": observation, - }, indent=4) - action_input_node.observation = observation_json - action_input_node.observation_code = result["status"] - newest_messages.append({ - "role": "tool", - "tool_call_id": tool_call["id"], - "content": [{"type": "text", "text": observation_json}], - }) - action_input_node.messages = deepcopy(newest_messages) - action_input_node.is_terminal = result["status"] in self.terminal_status - current_node = action_input_node + have_set_tools = True else: - # If there is no tool call, we assume the chain is finished + # No tool calls, chain is finished break depth += 1 + # Finalize chain + await self._finalize_chain(chain_id, chain, current_node, depth) + await done_queue.put((chain_id, chain, current_node)) + + self.finished_chains_count += 1 + self.monitor_chain() + + async def _generate_response(self, current_node, tools, depth, chain_id, enable_streaming): + """Generate response with optional streaming support.""" + if enable_streaming: + # Emit generation start event + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.LLM_GENERATION_START, + chain_id=chain_id, + timestamp=time.time(), + data={"depth": depth}, + step=depth, + depth=depth + )) + + # Check if we have streaming capabilities + has_streaming = False + if hasattr(self, 'generate_streaming'): + has_streaming = True + elif hasattr(self, 'llm_engine') and hasattr(self.llm_engine, 'generate_streaming'): + has_streaming = True + # Create a wrapper to use the LLM engine's streaming + async def generate_streaming_wrapper(messages_list, **kwargs): + async for chunk in self.llm_engine.generate_streaming(messages_list, **kwargs): + yield chunk + self.generate_streaming = generate_streaming_wrapper + + if has_streaming: + # Collect full response from streaming + full_response = "" + async for chunk in self.generate_streaming([current_node.messages], tools=tools): + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.LLM_GENERATION_CHUNK, + chain_id=chain_id, + timestamp=time.time(), + data={"content": chunk}, + step=depth, + depth=depth + )) + # chunk is the whole generated text + full_response = chunk + + # Emit generation end event + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.LLM_GENERATION_END, + chain_id=chain_id, + timestamp=time.time(), + data={"full_response": full_response}, + step=depth, + depth=depth + )) + + # Parse response + new_msg = self.parse([full_response], self.tools) + return new_msg[0] + else: + # Fallback to non-streaming generation + responses = await self.generate_async([current_node.messages], tools=tools, num_return_sequences=1) + new_msg = self.parse(responses, self.tools) + + # Emit a single chunk event for the full response + full_response = new_msg[0].get("content", "") + if isinstance(full_response, list) and len(full_response) > 0: + # Handle case where content is a list of content blocks + if isinstance(full_response[0], dict) and "text" in full_response[0]: + full_response = full_response[0]["text"] + else: + full_response = str(full_response) + elif not isinstance(full_response, str): + full_response = str(full_response) + + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.LLM_GENERATION_CHUNK, + chain_id=chain_id, + timestamp=time.time(), + data={"content": full_response}, + step=depth, + depth=depth + )) + + # Emit generation end event + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.LLM_GENERATION_END, + chain_id=chain_id, + timestamp=time.time(), + data={"full_response": full_response}, + step=depth, + depth=depth + )) + + return new_msg[0] + else: + # Non-streaming generation + responses = await self.generate_async([current_node.messages], tools=tools, num_return_sequences=1) + new_msg = self.parse(responses, self.tools) + return new_msg[0] + + async def _execute_tool_call(self, tool_call, newest_messages, chain, chain_id, depth, have_set_tools, enable_streaming): + """Execute a tool call with optional streaming support.""" + tool_name = tool_call["function"]["name"] + tool_input = tool_call["function"]["arguments"] + + # Create action node + action_node = chain.add_node( + type="Action", + messages=deepcopy(newest_messages), + description=tool_name + ) + + # Set up tools if needed + if not have_set_tools: + await self.set_tools(chain_id, chain.info) + have_set_tools = True + + # Execute tool call + result = await submit_tool_call(tool_name, tool_input, id=chain_id) + + if enable_streaming: + # Emit tool observation event + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.TOOL_OBSERVATION, + chain_id=chain_id, + timestamp=time.time(), + data={ + "tool_name": tool_name, + "observation": result["observation"], + "status": result["status"] + }, + step=depth, + depth=depth + )) + + + # Create action input node + action_input_node = chain.add_node( + type="Action Input", + messages=deepcopy(newest_messages), + description=result.get("arguments", "") + ) + + # Process observation + observation = result["observation"] + observation_json = json.dumps({ + "name": result["name"], + "content": observation, + }, indent=4) + + action_input_node.observation = observation_json + action_input_node.observation_code = result["status"] + newest_messages.append({ + "role": "tool", + "tool_call_id": tool_call["id"], + "content": [{"type": "text", "text": observation_json}], + }) + action_input_node.messages = deepcopy(newest_messages) + action_input_node.is_terminal = result["status"] in self.terminal_status + + return action_input_node + + async def _finalize_chain(self, chain_id, chain, current_node, depth): + """Finalize the chain with reward calculation and cleanup.""" if self._reward_fn is not None: trajectory = current_node.messages final_response = self.extract_final_response(trajectory) @@ -369,12 +500,9 @@ async def _run_chain_async(self, chain.info["reward"] = await self._reward_fn(prediction=final_response, **other_args, trajectory=trajectory, id=chain_id) else: chain.info["reward"] = None + await self.release_resources(chain_id) - await done_queue.put((chain_id, chain, current_node)) - - self.finished_chains_count += 1 - self.monitor_chain() async def release_resources(self, id: str) -> None: for tool in self.tools: diff --git a/agents/agents/agents/chain/streaming_observer.py b/agents/agents/agents/chain/streaming_observer.py new file mode 100644 index 0000000..f2ab24e --- /dev/null +++ b/agents/agents/agents/chain/streaming_observer.py @@ -0,0 +1,254 @@ +import asyncio +from abc import ABC, abstractmethod +from dataclasses import dataclass +import os +from typing import Any, Dict, List, Optional, Callable, AsyncGenerator, Set +from enum import Enum +import json +import time +from termcolor import colored + + +class StreamEventType(Enum): + """Types of streaming events""" + LLM_GENERATION_START = "llm_generation_start" + LLM_GENERATION_CHUNK = "llm_generation_chunk" + LLM_GENERATION_END = "llm_generation_end" + TOOL_CALL_START = "tool_call_start" + TOOL_CALL_END = "tool_call_end" + TOOL_OBSERVATION = "tool_observation" + CHAIN_START = "chain_start" + CHAIN_END = "chain_end" + ERROR = "error" + + +@dataclass +class StreamEvent: + """A streaming event with metadata""" + event_type: StreamEventType + chain_id: str + timestamp: float + data: Dict[str, Any] + step: Optional[int] = None + depth: Optional[int] = None + + def __post_init__(self): + # Add a unique identifier for this event + if not hasattr(self, 'event_id'): + self.event_id = f"{self.chain_id}_{self.timestamp}_{self.event_type.value}" + + +class StreamObserver(ABC): + """Abstract base class for stream observers""" + + @abstractmethod + async def on_event(self, event: StreamEvent) -> None: + """Handle a streaming event""" + pass + + async def on_error(self, error: Exception, chain_id: str) -> None: + """Handle an error event""" + event = StreamEvent( + event_type=StreamEventType.ERROR, + chain_id=chain_id, + timestamp=time.time(), + data={"error": str(error), "error_type": type(error).__name__} + ) + await self.on_event(event) + + +class StreamingManager: + """Manages streaming observers and event distribution""" + + def __init__(self): + self.observers: List[StreamObserver] = [] + self.enabled = False + self.active_chains: Set[str] = set() + self.chain_events: Dict[str, List[StreamEvent]] = {} + + def add_observer(self, observer: StreamObserver) -> None: + """Add a streaming observer""" + self.observers.append(observer) + self.enabled = True + + def remove_observer(self, observer: StreamObserver) -> None: + """Remove a streaming observer""" + if observer in self.observers: + self.observers.remove(observer) + if not self.observers: + self.enabled = False + + async def emit_event(self, event: StreamEvent) -> None: + """Emit an event to all observers""" + if not self.enabled: + return + + # Track active chains + if event.event_type == StreamEventType.CHAIN_START: + self.active_chains.add(event.chain_id) + self.chain_events[event.chain_id] = [] + elif event.event_type == StreamEventType.CHAIN_END: + self.active_chains.discard(event.chain_id) + + # Store event for this chain + if event.chain_id in self.chain_events: + self.chain_events[event.chain_id].append(event) + + tasks = [observer.on_event(event) for observer in self.observers] + await asyncio.gather(*tasks, return_exceptions=True) + + async def emit_error(self, error: Exception, chain_id: str) -> None: + """Emit an error event to all observers""" + if not self.enabled: + return + + tasks = [observer.on_error(error, chain_id) for observer in self.observers] + await asyncio.gather(*tasks, return_exceptions=True) + + def get_chain_events(self, chain_id: str) -> List[StreamEvent]: + """Get all events for a specific chain""" + return self.chain_events.get(chain_id, []) + + def get_active_chains(self) -> Set[str]: + """Get all currently active chain IDs""" + return self.active_chains.copy() + + +class ConsoleStreamObserver(StreamObserver): + """Simple console-based stream observer for debugging""" + + def __init__(self, show_timestamps: bool = True, chain_filter: Optional[str] = None): + self.show_timestamps = show_timestamps + self.chain_filter = chain_filter # Only show events for this chain_id + self.chain_colors = ["red", "green", "blue", "yellow", "magenta", "cyan"] + self.chain_id_data = {} + + + async def on_event(self, event: StreamEvent) -> None: + # Filter by chain if specified + if self.chain_filter and event.chain_id != self.chain_filter: + return + + turn_info = f" (turn {event.step})" if event.step is not None else "" + + # Use different colors for different chains + chain_index = hash(event.chain_id) % len(self.chain_colors) + chain_color = self.chain_colors[chain_index] + if event.chain_id not in self.chain_id_data: + self.chain_id_data[event.chain_id] = { + "color": chain_color, + "timestamp": event.timestamp, + "step": event.step, + "depth": event.depth, + "event_type": event.event_type.value, + "content_buffer": "" + } + + if event.event_type == StreamEventType.LLM_GENERATION_START: + print(f"{event.timestamp - self.chain_id_data[event.chain_id]['timestamp']:.2f}s {turn_info}".center(80, "="), flush=True) + elif event.event_type == StreamEventType.LLM_GENERATION_CHUNK: + content = event.data.get("content", "") + if content: + # clear the terminal + if self.chain_id_data[event.chain_id]["event_type"] == StreamEventType.LLM_GENERATION_CHUNK: + print(colored(f"""{content[len(self.chain_id_data[event.chain_id]["content_buffer"]):]}""", color=chain_color), end="", flush=True) + self.chain_id_data[event.chain_id]["content_buffer"] = content + else: + self.chain_id_data[event.chain_id]["content_buffer"] = content + print(colored(f"{content}", color=chain_color), end="", flush=True) + self.chain_id_data[event.chain_id]["event_type"] = StreamEventType.LLM_GENERATION_CHUNK + elif event.event_type == StreamEventType.LLM_GENERATION_END: + print(colored(f"\n{event.timestamp - self.chain_id_data[event.chain_id]['timestamp']:.2f}s", color=chain_color), flush=True) + self.chain_id_data[event.chain_id]["event_type"] = StreamEventType.LLM_GENERATION_END + elif event.event_type == StreamEventType.TOOL_OBSERVATION: + observation = event.data.get("observation", "") + tool_name = event.data.get("tool_name", "") + print(colored(f"Tool: [{tool_name}] {observation[:1024]}{'...' if len(observation) > 200 else ''}", color=chain_color)) + print(f"".center(80, "="), flush=True) + self.chain_id_data[event.chain_id]["event_type"] = StreamEventType.TOOL_OBSERVATION + elif event.event_type == StreamEventType.ERROR: + error_msg = event.data.get("error", "") + print(colored(f" ❌ Error: {error_msg}", color=chain_color)) + self.chain_id_data[event.chain_id]["event_type"] = StreamEventType.ERROR + + +class AsyncGeneratorStreamObserver(StreamObserver): + """Stream observer that yields events as an async generator""" + + def __init__(self, chain_filter: Optional[str] = None): + self.queue = asyncio.Queue() + self.chain_filter = chain_filter + + async def on_event(self, event: StreamEvent) -> None: + # Filter by chain if specified + if self.chain_filter and event.chain_id != self.chain_filter: + return + + await self.queue.put(event) + + async def events(self) -> AsyncGenerator[StreamEvent, None]: + """Yield events as they arrive""" + while True: + try: + event = await self.queue.get() + if event.event_type == StreamEventType.CHAIN_END: + # Send the final event and stop + yield event + break + yield event + except asyncio.CancelledError: + break + + +class ChainSpecificStreamObserver(StreamObserver): + """Stream observer that only handles events for a specific chain""" + + def __init__(self, target_chain_id: str, base_observer: StreamObserver): + self.target_chain_id = target_chain_id + self.base_observer = base_observer + + async def on_event(self, event: StreamEvent) -> None: + if event.chain_id == self.target_chain_id: + await self.base_observer.on_event(event) + + async def on_error(self, error: Exception, chain_id: str) -> None: + if chain_id == self.target_chain_id: + await self.base_observer.on_error(error, chain_id) + + +class MultiChainStreamObserver(StreamObserver): + """Stream observer that organizes events by chain""" + + def __init__(self): + self.chain_observers: Dict[str, List[StreamObserver]] = {} + self.global_observers: List[StreamObserver] = [] + + def add_chain_observer(self, chain_id: str, observer: StreamObserver) -> None: + """Add an observer for a specific chain""" + if chain_id not in self.chain_observers: + self.chain_observers[chain_id] = [] + self.chain_observers[chain_id].append(observer) + + def add_global_observer(self, observer: StreamObserver) -> None: + """Add an observer for all chains""" + self.global_observers.append(observer) + + async def on_event(self, event: StreamEvent) -> None: + # Send to chain-specific observers + if event.chain_id in self.chain_observers: + tasks = [obs.on_event(event) for obs in self.chain_observers[event.chain_id]] + await asyncio.gather(*tasks, return_exceptions=True) + + # Send to global observers + tasks = [obs.on_event(event) for obs in self.global_observers] + await asyncio.gather(*tasks, return_exceptions=True) + + async def on_error(self, error: Exception, chain_id: str) -> None: + # Send to chain-specific observers + if chain_id in self.chain_observers: + tasks = [obs.on_error(error, chain_id) for obs in self.chain_observers[chain_id]] + await asyncio.gather(*tasks, return_exceptions=True) + + # Send to global observers + tasks = [obs.on_error(error, chain_id) for obs in self.global_observers] + await asyncio.gather(*tasks, return_exceptions=True) \ No newline at end of file diff --git a/agents/agents/agents/chain/websocket_streaming.py b/agents/agents/agents/chain/websocket_streaming.py new file mode 100644 index 0000000..4ce4fc6 --- /dev/null +++ b/agents/agents/agents/chain/websocket_streaming.py @@ -0,0 +1,227 @@ +""" +WebSocket-based streaming interface for real-time agent interactions. +This module provides a WebSocket server that can stream agent events to web clients. +""" + +import asyncio +import json +import websockets +from typing import Dict, Set, Optional, Callable +from .streaming_observer import StreamObserver, StreamEvent, StreamEventType +import logging + +logger = logging.getLogger(__name__) + + +class WebSocketStreamObserver(StreamObserver): + """Stream observer that broadcasts events to WebSocket clients""" + + def __init__(self): + self.clients: Set[websockets.WebSocketServerProtocol] = set() + self.lock = asyncio.Lock() + + async def on_event(self, event: StreamEvent) -> None: + """Broadcast event to all connected WebSocket clients""" + if not self.clients: + return + + # Convert event to JSON + event_data = { + "event_type": event.event_type.value, + "chain_id": event.chain_id, + "timestamp": event.timestamp, + "step": event.step, + "depth": event.depth, + "data": event.data + } + + message = json.dumps(event_data) + + # Broadcast to all clients + disconnected_clients = set() + async with self.lock: + for client in self.clients: + try: + await client.send(message) + except websockets.exceptions.ConnectionClosed: + disconnected_clients.add(client) + except Exception as e: + logger.error(f"Error sending to WebSocket client: {e}") + disconnected_clients.add(client) + + # Remove disconnected clients + self.clients -= disconnected_clients + + async def add_client(self, websocket: websockets.WebSocketServerProtocol) -> None: + """Add a new WebSocket client""" + async with self.lock: + self.clients.add(websocket) + logger.info(f"WebSocket client connected. Total clients: {len(self.clients)}") + + async def remove_client(self, websocket: websockets.WebSocketServerProtocol) -> None: + """Remove a WebSocket client""" + async with self.lock: + self.clients.discard(websocket) + logger.info(f"WebSocket client disconnected. Total clients: {len(self.clients)}") + + +class WebSocketStreamingServer: + """WebSocket server for streaming agent events""" + + def __init__(self, host: str = "localhost", port: int = 8765): + self.host = host + self.port = port + self.observer = WebSocketStreamObserver() + self.server = None + + async def handle_client(self, websocket, path): + """Handle individual WebSocket client connection""" + await self.observer.add_client(websocket) + try: + # Keep connection alive and handle incoming messages + async for message in websocket: + try: + data = json.loads(message) + # Handle client messages if needed + logger.info(f"Received message from client: {data}") + except json.JSONDecodeError: + logger.warning(f"Invalid JSON from client: {message}") + except Exception as e: + logger.error(f"Error handling client message: {e}") + except websockets.exceptions.ConnectionClosed: + pass + finally: + await self.observer.remove_client(websocket) + + async def start(self): + """Start the WebSocket server""" + self.server = await websockets.serve( + self.handle_client, + self.host, + self.port + ) + logger.info(f"WebSocket server started on ws://{self.host}:{self.port}") + return self.server + + async def stop(self): + """Stop the WebSocket server""" + if self.server: + self.server.close() + await self.server.wait_closed() + logger.info("WebSocket server stopped") + + def get_observer(self) -> WebSocketStreamObserver: + """Get the WebSocket stream observer""" + return self.observer + + +class WebSocketStreamingClient: + """WebSocket client for receiving streaming events""" + + def __init__(self, uri: str = "ws://localhost:8765"): + self.uri = uri + self.websocket = None + + async def connect(self): + """Connect to the WebSocket server""" + self.websocket = await websockets.connect(self.uri) + logger.info(f"Connected to WebSocket server at {self.uri}") + + async def disconnect(self): + """Disconnect from the WebSocket server""" + if self.websocket: + await self.websocket.close() + logger.info("Disconnected from WebSocket server") + + async def receive_events(self, event_handler: Optional[Callable] = None): + """Receive and handle streaming events""" + if not self.websocket: + await self.connect() + + try: + async for message in self.websocket: + try: + event_data = json.loads(message) + if event_handler: + await event_handler(event_data) + else: + # Default event handling + event_type = event_data.get("event_type") + chain_id = event_data.get("chain_id") + data = event_data.get("data", {}) + + if event_type == "llm_generation_chunk": + content = data.get("content", "") + print(f"🤖 Chain {chain_id}: {content}", end="", flush=True) + elif event_type == "tool_observation": + tool_name = data.get("tool_name", "") + observation = data.get("observation", "") + print(f"\n🔧 {tool_name}: {observation[:100]}...") + elif event_type == "chain_end": + print(f"\n✅ Chain {chain_id} completed!") + + except json.JSONDecodeError: + logger.warning(f"Invalid JSON received: {message}") + except Exception as e: + logger.error(f"Error handling event: {e}") + except websockets.exceptions.ConnectionClosed: + logger.info("WebSocket connection closed") + except Exception as e: + logger.error(f"WebSocket error: {e}") + + +# Example usage functions +async def start_websocket_server(): + """Start the WebSocket streaming server""" + server = WebSocketStreamingServer() + await server.start() + return server + + +async def run_agent_with_websocket_streaming(agent, start_messages, max_steps=5, num_chains=1): + """Run an agent with WebSocket streaming""" + + # Start WebSocket server + server = await start_websocket_server() + + # Add WebSocket observer to agent + agent.streaming_manager.add_observer(server.get_observer()) + + try: + # Run the agent + await agent.run_async( + max_steps=max_steps, + start_messages=start_messages, + num_chains=num_chains, + enable_streaming=True + ) + finally: + # Stop the server + await server.stop() + + +async def connect_and_monitor(): + """Connect to WebSocket server and monitor events""" + client = WebSocketStreamingClient() + + async def event_handler(event_data): + """Custom event handler""" + event_type = event_data.get("event_type") + chain_id = event_data.get("chain_id") + + if event_type == "llm_generation_start": + print(f"🚀 Chain {chain_id}: Starting LLM generation...") + elif event_type == "tool_call_start": + tool_name = event_data.get("data", {}).get("tool_name", "") + print(f"🔧 Chain {chain_id}: Calling tool {tool_name}...") + elif event_type == "chain_end": + final_depth = event_data.get("data", {}).get("final_depth", 0) + reward = event_data.get("data", {}).get("reward") + print(f"✅ Chain {chain_id}: Completed in {final_depth} steps (reward: {reward})") + + await client.receive_events(event_handler) + + +if __name__ == "__main__": + # Example: Start WebSocket server + asyncio.run(start_websocket_server()) \ No newline at end of file diff --git a/agents/agents/agents/llm_backend.py b/agents/agents/agents/llm_backend.py index bc527f7..2401137 100644 --- a/agents/agents/agents/llm_backend.py +++ b/agents/agents/agents/llm_backend.py @@ -7,7 +7,7 @@ from collections import deque from functools import partial import time -from typing import Dict, Any, List, Optional +from typing import Dict, Any, List, Optional, Callable, AsyncGenerator import uuid from .templates.utils import convert_messages_to_openai_format import numpy as np @@ -54,6 +54,10 @@ def apply_chat_template(self, messages_list: List[List[Dict]], template: str, ad def generate(self, messages_list: str, **kwargs) -> str: """Generate text from prompt""" raise NotImplementedError("Subclasses must implement generate()") + + async def generate_streaming(self, messages_list: List[List[Dict]], streaming_callback: Optional[Callable] = None, **kwargs) -> AsyncGenerator[str, None]: + """Generate text with streaming support""" + raise NotImplementedError("Subclasses must implement generate_streaming()") class TransformersBackend(LLMBackend): """HuggingFace Transformers implementation""" @@ -100,8 +104,50 @@ def generate(self, messages_list: str, **kwargs) -> str: response_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) return response_texts - def generate_async(self, messages_list: str, **kwargs) -> str: - raise NotImplementedError("Transformers backend does not support async generation") + async def generate_async(self, messages_list: str, **kwargs) -> str: + """Async wrapper for generate""" + return self.generate(messages_list, **kwargs) + + async def generate_streaming(self, messages_list: List[List[Dict]], streaming_callback: Optional[Callable] = None, **kwargs) -> AsyncGenerator[str, None]: + """Generate text with streaming support using Transformers""" + max_new_tokens = kwargs.get("max_new_tokens", self.max_new_tokens) + temperature = kwargs.get("temperature", self.temperature) + + prompts, _ = self.apply_chat_template(messages_list, self.template) + + inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").to(self.llm_engine.device) + input_length = inputs['input_ids'].shape[1] + + # Use streaming generation + generated_tokens = [] + for i in range(max_new_tokens): + outputs = self.llm_engine.generate( + **inputs, + max_new_tokens=1, + temperature=temperature, + do_sample=temperature > 0, + pad_token_id=self.tokenizer.eos_token_id, + use_cache=True + ) + + new_token = outputs[0][-1].unsqueeze(0) + generated_tokens.append(new_token) + + # Decode the new token + new_text = self.tokenizer.decode(new_token, skip_special_tokens=True) + + if streaming_callback: + await streaming_callback(new_text) + + yield new_text + + # Check for EOS + if new_token.item() == self.tokenizer.eos_token_id: + break + + # Update input for next iteration + inputs['input_ids'] = torch.cat([inputs['input_ids'], new_token.unsqueeze(0)], dim=1) + inputs['attention_mask'] = torch.cat([inputs['attention_mask'], torch.ones(1, 1, device=inputs['attention_mask'].device)], dim=1) class VLLMBackend(LLMBackend): """vLLM implementation""" @@ -156,6 +202,35 @@ def generate(self, messages_list: str, **kwargs) -> str: def generate_async(self, messages_list: str, **kwargs) -> str: raise NotImplementedError("VLLM backend does not support async generation") + async def generate_streaming(self, messages_list: List[List[Dict]], streaming_callback: Optional[Callable] = None, **kwargs) -> AsyncGenerator[str, None]: + """Generate text with streaming support using vLLM""" + max_new_tokens = kwargs.get("max_new_tokens", self.max_new_tokens) + temperature = kwargs.get("temperature", self.temperature) + sampling_params = SamplingParams( + n=1, + max_tokens=max_new_tokens, + temperature=temperature, + ) + + tools = kwargs.get("tools", None) + prompts, vision_inputs = self.apply_chat_template(messages_list, self.template, tools=tools) + inputs = self._process_inputs(prompts, vision_inputs) + + # For streaming, we process one input at a time + for input_data in inputs: + outputs_gen = self.llm_engine.generate( + input_data, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + ) + + async for output in outputs_gen: + for sequence in output.outputs: + # Stream each token + if hasattr(sequence, 'text'): + if streaming_callback: + await streaming_callback(sequence.text) + yield sequence.text class AsyncVLLMBackend(LLMBackend): """Async vLLM implementation""" @@ -221,6 +296,34 @@ async def generate_async(self, messages_list: str, **kwargs) -> str: LOGGER.debug(f"[AsyncVLLMBackend] response_texts: {response_texts}") return response_texts + + async def generate_streaming(self, messages_list: List[List[Dict]], **kwargs) -> AsyncGenerator[str, None]: + """Generate text with streaming support using Async vLLM""" + max_new_tokens = kwargs.get("max_new_tokens", self.max_new_tokens) + temperature = kwargs.get("temperature", self.temperature) + sampling_params = SamplingParams( + n=1, + max_tokens=max_new_tokens, + temperature=temperature, + ) + + tools = kwargs.get("tools", None) + prompts, vision_inputs = self.apply_chat_template(messages_list, self.template, tools=tools) + inputs = self._process_inputs(prompts, vision_inputs) + + # For streaming, we process one input at a time + for input_data in inputs: + outputs_gen = self.llm_engine.generate( + input_data, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + ) + + async for output in outputs_gen: + for sequence in output.outputs: + # Stream each token + if hasattr(sequence, 'text'): + yield sequence.text class VerlBackend(LLMBackend): """Verl implementation""" diff --git a/agents/agents/agents/react/react_agent.py b/agents/agents/agents/react/react_agent.py index 765528f..e176149 100644 --- a/agents/agents/agents/react/react_agent.py +++ b/agents/agents/agents/react/react_agent.py @@ -15,29 +15,36 @@ def parse_react_step(text: str) -> Dict[str, Optional[str]]: """ - Parse a single ReAct-style step (one Thought→Action→Input) into its components. + Parse a single ReAct-style step into its components. Args: - text: A string containing exactly one Thought:, one Action:, and one Input:. + text: A string that may contain Thought:, Action:, and/or Input: components. Returns: - A dict with keys 'thought', 'action', and 'input', or None if not found. + A dict with keys 'thought', 'action', and 'input', with None for missing components. """ - pattern = re.compile( - r"Thought:\s*(?P.*?)\s*" - r"Action:\s*(?P.*?)\s*" - r"Input:\s*(?P.*)", - re.IGNORECASE | re.DOTALL - ) - m = pattern.search(text) - if not m: - return {"thought": None, "action": None, "input": None} - - return { - "thought": m.group("thought").strip(), - "action": m.group("action").strip(), - "input": m.group("input").strip(), - } + # Initialize result with None values + result = {"thought": None, "action": None, "input": None} + + # Pattern for Thought: + thought_pattern = re.compile(r"Thought:\s*(.*?)(?=\s*(?:Action:|Input:|$))", re.IGNORECASE | re.DOTALL) + thought_match = thought_pattern.search(text) + if thought_match: + result["thought"] = thought_match.group(1).strip() + + # Pattern for Action: + action_pattern = re.compile(r"Action:\s*(.*?)(?=\s*(?:Thought:|Input:|$))", re.IGNORECASE | re.DOTALL) + action_match = action_pattern.search(text) + if action_match: + result["action"] = action_match.group(1).strip() + + # Pattern for Input: + input_pattern = re.compile(r"Input:\s*(.*?)(?=\s*(?:Thought:|Action:|$))", re.IGNORECASE | re.DOTALL) + input_match = input_pattern.search(text) + if input_match: + result["input"] = input_match.group(1).strip() + + return result def extract_tool_calls(action_input: str) -> List[Dict]: if action_input is None: diff --git a/agents/agents/envs/manager/env_manager.py b/agents/agents/envs/manager/env_manager.py index c74dc7d..495d6d9 100644 --- a/agents/agents/envs/manager/env_manager.py +++ b/agents/agents/envs/manager/env_manager.py @@ -19,6 +19,7 @@ async def start(cls, env_cls: type[BaseEnv], size: int = 1, env_kwargs: dict | N or add more envs to the existing pool if the size is larger. If the size is smaller, do nothing. """ + # TODO: Currently, WarmPool will start all the envs at once. This should be fine for training, but might be wasteful for showing the demo, we may need to support feature to start a new env when acquiring, or make it a configurable option. key = env_cls if key not in cls._pools: cls._pools[key] = WarmPool( diff --git a/agents/agents/examples/streaming_example.py b/agents/agents/examples/streaming_example.py new file mode 100644 index 0000000..c4fd08b --- /dev/null +++ b/agents/agents/examples/streaming_example.py @@ -0,0 +1,59 @@ +import asyncio +from agents.agents.react.react_agent import ReactAgent +from agents.tools import code_interpreter +from agents.rewards import math_reward_tool +import json +from agents.agents.chain.streaming_observer import ConsoleStreamObserver + + +async def main(): + tools = [code_interpreter] + + agent = ReactAgent( + "Qwen/Qwen2.5-7B-Instruct", + tools=tools, + template="qwen2.5-no-tool", + backend="async_vllm", + reward_fn=math_reward_tool, + debug=True + ) + + + console_stream_observer = ConsoleStreamObserver() + agent.streaming_manager.add_observer(console_stream_observer) + + + question1 = "Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop." + answer1 = "204" + question2 = "$P(x)$ is a polynomial of degree $3n$ such that\n\\begin{eqnarray*} P(0) = P(3) = \\cdots &=& P(3n) = 2, \\\\ P(1) = P(4) = \\cdots &=& P(3n-2) = 1, \\\\ P(2) = P(5) = \\cdots &=& P(3n-1) = 0, \\quad\\text{ and }\\\\ && P(3n+1) = 730.\\end{eqnarray*}\nDetermine $n$." + answer2 = "3" + + + messages = [ + { + "messages": [ + {"role": "user", "content": f"{question1}"} + ], + "question": f"{question1}", + "answer": f"{answer1}" + }, + # { + # "messages": [ + # {"role": "user", "content": f"{question2}"} + # ], + # "question": f"{question2}", + # "answer": f"{answer2}" + # } + ] + + + await agent.run_async( + max_steps=4, + start_messages=messages, + num_chains=1, + enable_streaming=True + ) + +if __name__ == "__main__": + asyncio.run(main()) + diff --git a/agents/agents/rewards/math_reward.py b/agents/agents/rewards/math_reward.py index 3536666..0b63515 100644 --- a/agents/agents/rewards/math_reward.py +++ b/agents/agents/rewards/math_reward.py @@ -489,8 +489,8 @@ def math_reward_tool(prediction: str, answer: str, trajectory: List[Dict]) -> fl "acc": 1.0 if answer_correct else 0.0, } -@reward(name="math_reward_thought") -def math_reward_thought(prediction: str, answer: str, trajectory: List[Dict]) -> float: +@reward(name="math_reward_thought_with_tool") +def math_reward_thought_with_tool(prediction: str, answer: str, trajectory: List[Dict]) -> float: has_called_tool = False for msg in trajectory: if msg["role"] == "tool": @@ -519,7 +519,7 @@ def math_reward_thought(prediction: str, answer: str, trajectory: List[Dict]) -> elif has_called_tool and all_have_thought and not answer_correct: reward = 0.1 elif has_called_tool and not all_have_thought and answer_correct: - reward = 0.1 + reward = 0.0 elif has_called_tool and all_have_thought and answer_correct: reward = 1.0 else: diff --git a/agents/agents/rewards/qa_reward.py b/agents/agents/rewards/qa_reward.py index 83b596f..583519f 100644 --- a/agents/agents/rewards/qa_reward.py +++ b/agents/agents/rewards/qa_reward.py @@ -52,11 +52,11 @@ def em_score(prediction, ground_truth): @reward(name="qa_f1_reward") -def qa_f1_reward(prediction: str, golden_answer: str, trajectory: List[str]) -> float: +def qa_f1_reward(prediction: str, answer: str, trajectory: List[str]) -> float: # Extract answer from agent's response response = prediction - f1, precision, recall = f1_score(response, golden_answer) - em = em_score(response, golden_answer) + f1, precision, recall = f1_score(response, answer) + em = em_score(response, answer) return { "reward": f1, diff --git a/agents/agents/tools/src/search/dense_retriever.py b/agents/agents/tools/src/search/dense_retriever.py index 9ea0288..d8b77d2 100644 --- a/agents/agents/tools/src/search/dense_retriever.py +++ b/agents/agents/tools/src/search/dense_retriever.py @@ -4,7 +4,7 @@ from transformers import AutoTokenizer, AutoModel from torch import Tensor from ...tool_base import tool -from ....__init__ import AGENT_DATA_DIR +from ....__init__ import AGENT_CACHE_DIR def load_corpus(corpus_path: str): corpus = datasets.load_dataset( @@ -49,12 +49,12 @@ async def search(self, queries: list[str], top_k: int): @tool(name="dense_retrieve", description="Use a dense retriever to retrieve documents from a corpus.", max_length=4096) async def dense_retrieve(query: str): - global AGENT_DATA_DIR + global AGENT_CACHE_DIR if not query.startswith("query:"): query = "query: " + query global GLOBAL_RETRIEVER if GLOBAL_RETRIEVER is None: - GLOBAL_RETRIEVER = DenseRetriever(corpus_file=os.path.join(AGENT_DATA_DIR, "search", "wiki-18.jsonl"), index_file=os.path.join(AGENT_DATA_DIR, "search", "e5_Flat.index")) + GLOBAL_RETRIEVER = DenseRetriever(corpus_file=os.path.join(AGENT_CACHE_DIR, "data", "search", "wiki-18.jsonl"), index_file=os.path.join(AGENT_CACHE_DIR, "data", "search", "e5_Flat.index")) doc_list = await GLOBAL_RETRIEVER.search(query, 3) doc_list = doc_list[0] content = "" diff --git a/agents/tests/unit/agents/test_react_agent.py b/agents/tests/unit/agents/test_react_agent.py index 1e84eda..dcb2043 100644 --- a/agents/tests/unit/agents/test_react_agent.py +++ b/agents/tests/unit/agents/test_react_agent.py @@ -10,7 +10,7 @@ def test_react_agent_initialization(): agent = ReactAgent( "Qwen/Qwen2.5-3B-Instruct", tools=tools, - template="qwen-7b-chat", + template="qwen2.5", task_info=task_info, backend="client" ) @@ -35,7 +35,7 @@ def test_parse_react_step(): # Test with missing components text_missing = "Thought: I'm thinking about something." result_missing = parse_react_step(text_missing) - assert result_missing["thought"] is None + assert result_missing["thought"] == "I'm thinking about something." assert result_missing["action"] is None assert result_missing["input"] is None @@ -45,25 +45,19 @@ def test_react_agent_parse(): agent = ReactAgent( "Qwen/Qwen2.5-3B-Instruct", tools=tools, - template="qwen-7b-chat", + template="qwen2.5", backend="client" ) - # Mock the generate method to return a predefined response - def mock_generate(*args, **kwargs): - return ["""Thought: I need to search for information. + responses = ["""Thought: I need to search for information. Action: google_search Input: {"query": "test query"}"""] - agent.generate = mock_generate - - # Test the parse method - messages_list = [[{"role": "user", "content": "Find information about Python"}]] - result = agent.parse(messages_list, tools) - + result = agent.parse(responses, tools) + print(result) assert len(result) == 1 assert result[0]["role"] == "assistant" - assert "Thought: I need to search for information." in result[0]["content"] + assert "Thought: I need to search for information." in result[0]["content"][0]["text"] assert len(result[0]["tool_calls"]) == 1 assert result[0]["tool_calls"][0]["function"]["name"] == "google_search" - assert result[0]["tool_calls"][0]["function"]["arguments"] == '{"query": "test query"}' \ No newline at end of file + assert result[0]["tool_calls"][0]["function"]["arguments"] == {"query": "test query"} \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 80dd89c..3278b99 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -11,6 +11,7 @@ AgentFly is a scalable and extensible Agent-RL framework designed to empower LM start/installation start/training_example + start/agent_examples .. toctree:: :maxdepth: 2 diff --git a/docs/start/agent_examples.md b/docs/start/agent_examples.md new file mode 100644 index 0000000..2f5f3b0 --- /dev/null +++ b/docs/start/agent_examples.md @@ -0,0 +1,215 @@ +## Build an Agent + +### Use a Predefined Agent +We can specify the following arguments to use a predefined agent: + +- model_name: the path or name or the model, used to load weights +- tools: tools that will be used by the agent +- template: chat template +- backend: what type of backend + +The following shows an example to use Qwen2.5-7B-Instruct as a react agent: + +```python +from agents.agents.react.react_agent import ReactAgent +from agents.tools.src.code.tools import code_interpreter +from agents.tools.src.search.google_search import google_search_serper +from agents.tools.src.react.tools import answer + +tools = [google_search_serper, answer] + +task_info = "Use code to get answers. Result must be printed." + +react_agent = ReactAgent( + "Qwen/Qwen2.5-7B-Instruct", + tools=tools, + template="qwen2.5-no-tool", + task_info=task_info, + backend="async_vllm" +) + +question = "Solve the equation 2x + 5y = 4 such that sum of x and y is 7." +messages = [ + { + "messages": [ + {"role": "user", "content": f"{question}"} + ], + "question": f"{question}", + }, +] + +await react_agent.run_async( + max_steps=4, + start_messages=messages, + num_chains=5 # for the question, the agent will generate 5 trajectories +) + +``` + +After the rollout, we can obtain the trajectories: + +```python +react_agent.trajectories +``` + +Obtaining the rewards (if you specified reward function and give necessary parameters in input messages) +``` +react_agent.rewards) +``` + +### Customize Agent + +You can customize your own agent by defining how the agent do generation and handle tool calls. + +```python +class CustomizedAgent(BaseAgent): + def __init__(self, + **kwargs + ) + super().__init__(**kwargs) + + async def generate_async(self, messages_list: List[List[Dict]], **args): + return await self.llm_engine.generate_async(messages_list, **args) + + def parse(self, responses: List(str), tools): + # parse responses into tool calls + ... +``` + +### Use Trained Agent +We provide the following agent that we can try: + +- WebShop Agent: +```python +import asyncio +from agents.agents import ReactAgent +from agents.tools import webshop_browser +from agents.rewards import webshop_reward +from agents.agents.chain.streaming_observer import ConsoleStreamObserver + +tools = [webshop_browser] + +agent = ReactAgent( + "Agent-One/Qwen2.5-3B-Instruct-WebShop", + tools=tools, + template="qwen2.5", + backend="async_vllm", + streaming="console" +) + +question = "I am looking for a gluten free, 100% vegan plant based protein shake that is soy-free, and price lower than 40.00 dollars" + +messages = [ + { + "messages": [ + {"role": "user", "content": f"{question}"} + ], + "question": f"{question}" + }, +] + +await agent.run_async( + max_steps=10, + start_messages=messages, + num_chains=1, + enable_streaming=True +) +``` +- Science World Agent +```python +from agents.tools import scienceworld_explorer +from agents.rewards import scienceworld_reward + + +tools = [scienceworld_explorer] + +agent = ReactAgent( + "Agent-One/Qwen2.5-7B-Instruct-ScienceWorld-REINFORCEPP", + tools=tools, + template="qwen2.5-no-system-tool", + backend="async_vllm", + reward_fn=scienceworld_reward, + streaming="console" +) + +# question = "Your task is to boil ice cream. For compounds without a boiling point, combusting the substance is also acceptable. First, focus on the substance. Then, take actions that will cause it to change its state of matter." +# task_name = "boil" +# variation_idx = 14 + + +# question = "Your task is to find a(n) animal. First, focus on the thing. Then, move it to the green box in the kitchen." +# task_name = "find-animal" +# variation_idx = 150 + +# question = "Your task is to find a(n) plant. First, focus on the thing. Then, move it to the purple box in the bathroom." +# task_name = "find-plant" +# variation_idx = 191 + + +# question = "Your task is to grow a orange. This will require growing several plants, and them being crosspollinated to produce fruit. Seeds can be found in the kitchen. To complete the task, focus on the grown orange." +# task_name = "grow-fruit" +# variation_idx = 92 + +# question = "Your task is to determine which of the two inclined planes (bronze, copper) has the most friction. After completing your experiment, focus on the inclined plane with the most friction." +# task_name = "inclined-plane-friction-named-surfaces" +# variation_idx = 776 + +question = "Your task is to measure the melting point of orange juice, which is located around the kitchen. First, focus on the thermometer. Next, focus on the orange juice. If the melting point of orange juice is above 200.0 degrees celsius, focus on the yellow box. If the melting point of orange juice is below 200.0 degrees celsius, focus on the purple box. The boxes are located around the kitchen." +task_name = "measure-melting-point-known-substance" +variation_idx = 247 + +messages = [ + { + "messages": [ + {"role": "user", "content": f"{question}"} + ], + "question": f"{question}", + "task_name": task_name, + "variation_idx": variation_idx + }, +] + +await agent.run_async( + max_steps=20, + start_messages=messages, + num_chains=1, + enable_streaming=True +) + +print(agent.rewards) +``` + +- Retrieval Agent + +```python +from agents.tools import dense_retrieve, asyncdense_retrieve + +tools = [dense_retrieve] + +agent = ReactAgent( + "Agent-One/Qwen2.5-3B-Instruct-Retrieval-GRPO", + tools=tools, + template="qwen2.5-no-system-tool", + backend="async_vllm", + streaming="console" +) + +question = "Who is Geoffrey Hinton" + + +messages = [ + { + "messages": [ + {"role": "user", "content": f"{question}"} + ], + "question": f"{question}", + }, +] + +await agent.run_async( + max_steps=6, + start_messages=messages, + num_chains=1, + enable_streaming=True +) +``` \ No newline at end of file diff --git a/docs/start/use_agent.md b/docs/start/use_agent.md deleted file mode 100644 index 0ed10ea..0000000 --- a/docs/start/use_agent.md +++ /dev/null @@ -1,74 +0,0 @@ -## Build an Agent - -### Use a Predefined Agent -We can specify the following arguments to use a predefined agent: - -- model_name: the path or name or the model, used to load weights -- tools: tools that will be used by the agent -- template: chat template -- backend: what type of backend - -The following shows an example to use Qwen2.5-7B-Instruct as a react agent: - -```python -from agents.agents.react.react_agent import ReactAgent -from agents.tools.src.code.tools import code_interpreter -from agents.tools.src.search.google_search import google_search_serper -from agents.tools.src.react.tools import answer - -tools = [google_search_serper, answer] - -task_info = "Use code to get answers. Result must be printed." - -react_agent = ReactAgent( - "Qwen/Qwen2.5-7B-Instruct", - tools=tools, - template="qwen2.5-no-tool", - task_info=task_info, - backend="async_vllm" -) - -question = "Solve the equation 2x + 5y = 4 such that sum of x and y is 7." -messages = [ - { - "messages": [ - {"role": "user", "content": f"{question}"} - ], - "question": f"{question}", - }, -] - -await react_agent.run_async( - max_steps=4, - start_messages=messages, - num_chains=5 # for the question, the agent will generate 5 trajectories -) - -``` - -After the rollout, we can obtain the trajectories: - -```python -react_agent.trajectories -``` - -### Customize Agent - -You can customize your own agent by defining how the agent do generation and handle tool calls. - -```python -class CustomizedAgent(BaseAgent): - def __init__(self, - **kwargs - ) - super().__init__(**kwargs) - - async def generate_async(self, messages_list: List[List[Dict]], **args): - return await self.llm_engine.generate_async(messages_list, **args) - - def parse(self, responses: List(str), tools): - # parse responses into tool calls - ... -``` - - diff --git a/verl b/verl index de234c9..1b70157 160000 --- a/verl +++ b/verl @@ -1 +1 @@ -Subproject commit de234c9e7d0fa26e261c61ec2e0ee0307acd7376 +Subproject commit 1b70157b71ab01d7e09c96d81b5f7cfa70d5ee5a