108 changes: 74 additions & 34 deletions agents/agents/agents/agent_base.py
@@ -4,23 +4,28 @@

from .templates.templates import get_template
from ..__init__ import AGENT_DATA_DIR
from .llm_backend import AsyncVLLMBackend, AsyncVerlBackend, ClientBackend, TransformersBackend, VLLMBackend, VerlBackend
from .llm_backend import AsyncVLLMBackend, AsyncVerlBackend, ClientBackend, TransformersBackend, VLLMBackend
from ..utils.logging import get_logger
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from .templates.utils import is_vlm_template, tokenize_conversations
from .templates.utils import tokenize_conversations
from .templates.vision_processor import is_vision_template
from .chain.chain_base import ChainGeneration
import os
import transformers
import warnings
import logging
from .chain.streaming_observer import ConsoleStreamObserver, StreamingManager
from .utils.tokenizer import create_tokenizer
from .backend_config import BACKEND_CONFIGS
try:
from verl.protocol import DataProto
except ImportError:
print("verl can not be imported.")
pass
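A common hardening of this optional-import pattern, offered as a suggestion rather than what the PR does, is to bind the name explicitly so call sites can check for `None` instead of hitting a `NameError` later:

```python
try:
    from verl.protocol import DataProto
except ImportError:
    DataProto = None  # verl is optional; guard call sites accordingly
```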

Logger = logging.getLogger(__name__)

class BaseAgent(ChainGeneration, ABC):
"""
@@ -34,12 +39,13 @@ class BaseAgent(ChainGeneration, ABC):
def __init__(
self,
model_name_or_path,
template: str,
template: str = None,
system_prompt: str = None,
tools: List = None,
max_length: int=8192,
debug: bool = False,
backend: str = "transformers",
backend_config: Any = None,
reward_fn: Callable = None,
log_file: str = "agent",
project_name: str = None,
@@ -65,9 +71,30 @@ def __init__(
self.tools = tools
self.system_prompt = system_prompt
self.model_name_or_path = model_name_or_path
self.llm_engine, self.tokenizer, self.processor = self._init_llm_engine(model_name_or_path, backend)

# Handle backend configuration
if backend_config is None:
# Use default configuration for the backend
config_class = BACKEND_CONFIGS.get(backend)
if config_class:
self.backend_config = config_class()
else:
self.backend_config = None
else:
self.backend_config = backend_config

self.llm_engine = self._init_llm_engine(model_name_or_path, backend)

# Create appropriate tokenizer for trajectory processing
self.tokenizer = create_tokenizer(model_name_or_path)
self.processor = None  # populated via set_llm_engine (async_verl); consumed by tokenize_conversations

self._reward_fn = reward_fn
self.jinja_template = get_template(self.template).jinja_template()

if self.template is None:
self.jinja_template = None
else:
self.jinja_template = get_template(self.template).jinja_template()

self.project_name = project_name
self.run_name = run_name
self.streaming_manager = StreamingManager()
@@ -78,38 +105,59 @@ def __init__(
raise ValueError(f"Streaming mode {streaming} is not supported.")
super().__init__()
if kwargs:
warnings.warn(f"Unused arguments for agent initialization: {kwargs}")
# warnings.warn(f"Unused arguments for agent initialization: {kwargs}")
raise ValueError(f"Unused arguments for agent initialization: {kwargs}")

def _init_llm_engine(self, model_name_or_path: str, backend: str):
if isinstance(model_name_or_path, str):
# Extract backend-specific configuration
config_kwargs = {}
if self.backend_config:
config_kwargs = {k: v for k, v in self.backend_config.__dict__.items()
if not k.startswith('_')}

if backend == "transformers":
llm_engine = TransformersBackend(model_name_or_path, self.template, max_length=self.max_length)
elif backend == "vllm":
llm_engine = VLLMBackend(model_name_or_path, self.template, max_length=self.max_length)
llm_engine = TransformersBackend(
model_name_or_path,
self.template,
max_length=self.max_length,
**config_kwargs
)
elif backend == "async_vllm":
llm_engine = AsyncVLLMBackend(model_name_or_path, self.template, max_length=self.max_length)
elif backend == "verl":
llm_engine = VerlBackend(llm_engine=None, model_name_or_path=model_name_or_path, template=self.template, max_length=self.max_length)
llm_engine = AsyncVLLMBackend(
model_name_or_path,
self.template,
max_length=self.max_length,
**config_kwargs
)
elif backend == "async_verl":
llm_engine = AsyncVerlBackend(llm_engine=None, model_name_or_path=model_name_or_path, template=self.template, max_length=self.max_length)
llm_engine = AsyncVerlBackend(
llm_engine=None,
model_name_or_path=model_name_or_path,
template=self.template,
max_length=self.max_length,
**config_kwargs
)
elif backend == "client":
llm_engine = ClientBackend(model_name_or_path, self.template, max_length=self.max_length)
print(f"config_kwargs: {config_kwargs}")
llm_engine = ClientBackend(
model_name_or_path,
self.template,
max_length=self.max_length,
**config_kwargs
)
else:
raise ValueError(f"Backend {backend} is not supported.")
else:
raise ValueError("model_name_or_path must be a string.")

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
if is_vlm_template(self.template):
processor = transformers.AutoProcessor.from_pretrained(model_name_or_path)
else:
processor = None
return llm_engine, tokenizer, processor
return llm_engine
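Side note on the `config_kwargs` extraction above: since the backend configs are flat dataclasses, the filtered `__dict__` comprehension is equivalent to `dataclasses.asdict`, which may read more clearly. A possible simplification, not what this PR does:

```python
from dataclasses import asdict

# Inside _init_llm_engine; equivalent for flat dataclass configs:
config_kwargs = asdict(self.backend_config) if self.backend_config else {}
```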

def set_llm_engine(self, llm_engine: Any, tokenizer: Any):
def set_llm_engine(self, llm_engine: Any, tokenizer: Any, processor: Any = None):
assert self.backend == "async_verl", "Only async verl backend is supported for now"
self.llm_engine.llm_engine = llm_engine
self.tokenizer = tokenizer
self.processor = processor

def generate(self, messages_list_or_inputs: List[List[Dict]], **args):
return self.llm_engine.generate(messages_list_or_inputs, **args)
@@ -151,25 +199,17 @@ async def generate_streaming(self, messages_list_or_inputs: List[List[Dict]], st
@property
def timing_data(self):
return self.timer.timing_data

def forward(self, messages_list_or_inputs: List[List[Dict]], **args):
if isinstance(messages_list_or_inputs, List):
inputs = tokenize_conversations(messages_list_or_inputs, tokenizer=self.tokenizer, conv_template=self.template, max_length=self.max_length, processor=self.processor)
else:
raise ValueError("messages_list_or_inputs must be a list of messages or a dictionary of padded inputs.")

if isinstance(self.llm_engine, transformers.PreTrainedModel):
return self.llm_engine.forward(**inputs, **args) # Only support transformers models for now.
else:
raise ValueError("llm_engine must be a transformers.PretrainedModel.")

@property
def trajectories(self):
trajectories = self.get_messages()

return trajectories

def tokenize_trajectories(self, return_action_mask: bool = False, return_reward_mask: bool = False):
def tokenize_trajectories(self, tokenizer=None, return_action_mask: bool = False, return_reward_mask: bool = False):
if tokenizer is None:
tokenizer = self.tokenizer

trajectories = self.trajectories
self.logger.info("================ Trajectory ================")
self.logger.info(trajectories[0])
@@ -196,7 +236,7 @@ def tokenize_trajectories(self, return_action_mask: bool = False, return_reward_
info['last_response'] = last_response
other_info_list.append(info)

inputs = tokenize_conversations(messages_list, tokenizer=self.tokenizer, conv_template=self.template, processor=self.processor, max_length=self.max_length, return_reward_mask=return_reward_mask)
inputs = tokenize_conversations(messages_list, tokenizer=tokenizer, conv_template=self.template, processor=self.processor, max_length=self.max_length, return_reward_mask=return_reward_mask)
position_ids = torch.clip(torch.cumsum(inputs['attention_mask'], dim=-1) - 1, min=0, max=None)
inputs['position_ids'] = position_ids
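A standalone sketch of what the new `position_ids` line computes: with left padding, pad positions collapse to 0 and real tokens count up from 0.

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = torch.clip(torch.cumsum(attention_mask, dim=-1) - 1, min=0, max=None)
# tensor([[0, 0, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```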

66 changes: 66 additions & 0 deletions agents/agents/agents/backend_config.py
@@ -0,0 +1,66 @@
from dataclasses import dataclass
from typing import Optional, Dict, Any, List
import asyncio


@dataclass
class TransformersConfig:
"""Configuration for Transformers backend"""
temperature: float = 1.0
max_new_tokens: int = 1024
trust_remote_code: bool = True
device_map: str = "auto"


@dataclass
class VLLMConfig:
"""Configuration for VLLM backend"""
temperature: float = 1.0
max_new_tokens: int = 1024
# Add other vLLM specific parameters as needed


@dataclass
class AsyncVLLMConfig:
"""Configuration for Async VLLM backend"""
temperature: float = 1.0
max_new_tokens: int = 1024
# Add other async vLLM specific parameters as needed


@dataclass
class VerlConfig:
"""Configuration for Verl backend"""
temperature: float = 1.0
max_new_tokens: int = 1024
# Add other Verl specific parameters as needed


@dataclass
class AsyncVerlConfig:
"""Configuration for Async Verl backend"""
temperature: float = 1.0
max_new_tokens: int = 1024
# Add other async Verl specific parameters as needed


@dataclass
class ClientConfig:
"""Configuration for Client backend (OpenAI-compatible)"""
base_url: str = "http://localhost:8000/v1"
max_requests_per_minute: int = 100
timeout: int = 600
api_key: str = "EMPTY"
max_new_tokens: int = 1024
temperature: float = 1.0


# Backend configuration mapping
BACKEND_CONFIGS = {
"transformers": TransformersConfig,
"vllm": VLLMConfig,
"async_vllm": AsyncVLLMConfig,
"verl": VerlConfig,
"async_verl": AsyncVerlConfig,
"client": ClientConfig,
}
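A minimal sketch of how these configs resolve, mirroring the lookup in `BaseAgent.__init__` (import path assumed from this file's location):

```python
from agents.agents.agents.backend_config import BACKEND_CONFIGS, ClientConfig

config_class = BACKEND_CONFIGS.get("client")
config = config_class() if config_class else None
assert config == ClientConfig()  # dataclasses compare by field values

# Unknown backend names resolve to None here rather than raising;
# _init_llm_engine rejects unsupported backends later.
missing = BACKEND_CONFIGS.get("sglang")  # -> None ("sglang" is illustrative)
```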
56 changes: 29 additions & 27 deletions agents/agents/agents/chain/chain_base.py
@@ -316,14 +316,40 @@ async def _run_single_chain(self,
# Handle tool calls
if current_node.messages[-1].get("tool_calls"):
for tool_call in current_node.messages[-1]["tool_calls"]:
current_node = await self._execute_tool_call(
result = await self._execute_tool_call(
tool_call, newest_messages, chain, chain_id, depth,
have_set_tools, enable_streaming
)
have_set_tools = True

# Create action input node
action_input_node = chain.add_node(
type="Action Input",
messages=deepcopy(newest_messages),
description=result.get("arguments", "")
)

# Process observation
observation = result["observation"]
observation_json = json.dumps({
"name": result["name"],
"content": observation,
}, indent=4)

action_input_node.observation = observation_json
action_input_node.observation_code = result["status"]
newest_messages.append({
"role": "tool",
"tool_call_id": tool_call["id"],
"content": [{"type": "text", "text": observation_json}],
})
action_input_node.messages = deepcopy(newest_messages)
action_input_node.is_terminal = result["status"] in self.terminal_status
else:
# No tool calls, chain is finished
break

current_node = action_input_node

depth += 1

@@ -463,33 +489,9 @@ async def _execute_tool_call(self, tool_call, newest_messages, chain, chain_id,
step=depth,
depth=depth
))

return result


# Create action input node
action_input_node = chain.add_node(
type="Action Input",
messages=deepcopy(newest_messages),
description=result.get("arguments", "")
)

# Process observation
observation = result["observation"]
observation_json = json.dumps({
"name": result["name"],
"content": observation,
}, indent=4)

action_input_node.observation = observation_json
action_input_node.observation_code = result["status"]
newest_messages.append({
"role": "tool",
"tool_call_id": tool_call["id"],
"content": [{"type": "text", "text": observation_json}],
})
action_input_node.messages = deepcopy(newest_messages)
action_input_node.is_terminal = result["status"] in self.terminal_status

return action_input_node

async def _finalize_chain(self, chain_id, chain, current_node, depth):
"""Finalize the chain with reward calculation and cleanup."""