diff --git a/agents/agents/agents/agent_base.py b/agents/agents/agents/agent_base.py
index 0d712ce..e878a79 100644
--- a/agents/agents/agents/agent_base.py
+++ b/agents/agents/agents/agent_base.py
@@ -4,23 +4,28 @@
from .templates.templates import get_template
from ..__init__ import AGENT_DATA_DIR
-from .llm_backend import AsyncVLLMBackend, AsyncVerlBackend, ClientBackend, TransformersBackend, VLLMBackend, VerlBackend
+from .llm_backend import AsyncVLLMBackend, AsyncVerlBackend, ClientBackend, TransformersBackend, VLLMBackend
from ..utils.logging import get_logger
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
-from .templates.utils import is_vlm_template, tokenize_conversations
+from .templates.utils import tokenize_conversations
+from .templates.vision_processor import is_vision_template
from .chain.chain_base import ChainGeneration
import os
import transformers
import warnings
+import logging
from .chain.streaming_observer import ConsoleStreamObserver, StreamingManager
+from .utils.tokenizer import create_tokenizer
+from .backend_config import BACKEND_CONFIGS
try:
from verl.protocol import DataProto
except ImportError:
print("verl can not be imported.")
pass
+Logger = logging.getLogger(__name__)
class BaseAgent(ChainGeneration, ABC):
"""
@@ -34,12 +39,13 @@ class BaseAgent(ChainGeneration, ABC):
def __init__(
self,
model_name_or_path,
- template: str,
+        template: str = None,
system_prompt: str = None,
tools: List = None,
max_length: int=8192,
debug: bool = False,
backend: str = "transformers",
+ backend_config: Any = None,
reward_fn: Callable = None,
log_file: str = "agent",
project_name: str = None,
@@ -65,9 +71,30 @@ def __init__(
self.tools = tools
self.system_prompt = system_prompt
self.model_name_or_path = model_name_or_path
- self.llm_engine, self.tokenizer, self.processor = self._init_llm_engine(model_name_or_path, backend)
+
+ # Handle backend configuration
+ if backend_config is None:
+ # Use default configuration for the backend
+ config_class = BACKEND_CONFIGS.get(backend)
+ if config_class:
+ self.backend_config = config_class()
+ else:
+ self.backend_config = None
+ else:
+ self.backend_config = backend_config
+
+ self.llm_engine = self._init_llm_engine(model_name_or_path, backend)
+
+        # Create an appropriate tokenizer for trajectory processing
+        self.tokenizer = create_tokenizer(model_name_or_path)
+        # Default processor; set_llm_engine() or subclasses may override it (e.g. for async_verl)
+        self.processor = None
+
self._reward_fn = reward_fn
- self.jinja_template = get_template(self.template).jinja_template()
+
+ if self.template is None:
+ self.jinja_template = None
+ else:
+ self.jinja_template = get_template(self.template).jinja_template()
+
self.project_name = project_name
self.run_name = run_name
self.streaming_manager = StreamingManager()
@@ -78,38 +105,59 @@ def __init__(
raise ValueError(f"Streaming mode {streaming} is not supported.")
super().__init__()
if kwargs:
- warnings.warn(f"Unused arguments for agent initialization: {kwargs}")
+ # warnings.warn(f"Unused arguments for agent initialization: {kwargs}")
+ raise ValueError(f"Unused arguments for agent initialization: {kwargs}")
def _init_llm_engine(self, model_name_or_path: str, backend: str):
if isinstance(model_name_or_path, str):
+ # Extract backend-specific configuration
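+            # (dataclass fields such as temperature or max_new_tokens are forwarded
+            # to the backend constructor as keyword arguments)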
+ config_kwargs = {}
+ if self.backend_config:
+ config_kwargs = {k: v for k, v in self.backend_config.__dict__.items()
+ if not k.startswith('_')}
+
if backend == "transformers":
- llm_engine = TransformersBackend(model_name_or_path, self.template, max_length=self.max_length)
- elif backend == "vllm":
- llm_engine = VLLMBackend(model_name_or_path, self.template, max_length=self.max_length)
+ llm_engine = TransformersBackend(
+ model_name_or_path,
+ self.template,
+ max_length=self.max_length,
+ **config_kwargs
+ )
elif backend == "async_vllm":
- llm_engine = AsyncVLLMBackend(model_name_or_path, self.template, max_length=self.max_length)
- elif backend == "verl":
- llm_engine = VerlBackend(llm_engine=None, model_name_or_path=model_name_or_path, template=self.template, max_length=self.max_length)
+ llm_engine = AsyncVLLMBackend(
+ model_name_or_path,
+ self.template,
+ max_length=self.max_length,
+ **config_kwargs
+ )
elif backend == "async_verl":
- llm_engine = AsyncVerlBackend(llm_engine=None, model_name_or_path=model_name_or_path, template=self.template, max_length=self.max_length)
+ llm_engine = AsyncVerlBackend(
+ llm_engine=None,
+ model_name_or_path=model_name_or_path,
+ template=self.template,
+ max_length=self.max_length,
+ **config_kwargs
+ )
elif backend == "client":
- llm_engine = ClientBackend(model_name_or_path, self.template, max_length=self.max_length)
+ print(f"config_kwargs: {config_kwargs}")
+ llm_engine = ClientBackend(
+ model_name_or_path,
+ self.template,
+ max_length=self.max_length,
+ **config_kwargs
+ )
else:
raise ValueError(f"Backend {backend} is not supported.")
else:
raise ValueError("model_name_or_path must be a string.")
- tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
- if is_vlm_template(self.template):
- processor = transformers.AutoProcessor.from_pretrained(model_name_or_path)
- else:
- processor = None
- return llm_engine, tokenizer, processor
+ return llm_engine
- def set_llm_engine(self, llm_engine: Any, tokenizer: Any):
+ def set_llm_engine(self, llm_engine: Any, tokenizer: Any, processor: Any):
assert self.backend == "async_verl", "Only async verl backend is supported for now"
self.llm_engine.llm_engine = llm_engine
self.tokenizer = tokenizer
+ self.processor = processor
def generate(self, messages_list_or_inputs: List[List[Dict]], **args):
return self.llm_engine.generate(messages_list_or_inputs, **args)
@@ -151,17 +199,6 @@ async def generate_streaming(self, messages_list_or_inputs: List[List[Dict]], st
@property
def timing_data(self):
return self.timer.timing_data
-
- def forward(self, messages_list_or_inputs: List[List[Dict]], **args):
- if isinstance(messages_list_or_inputs, List):
- inputs = tokenize_conversations(messages_list_or_inputs, tokenizer=self.tokenizer, conv_template=self.template, max_length=self.max_length, processor=self.processor)
- else:
- raise ValueError("messages_list_or_inputs must be a list of messages or a dictionary of padded inputs.")
-
- if isinstance(self.llm_engine, transformers.PreTrainedModel):
- return self.llm_engine.forward(**inputs, **args) # Only support transformers models for now.
- else:
- raise ValueError("llm_engine must be a transformers.PretrainedModel.")
@property
def trajectories(self):
@@ -169,7 +206,10 @@ def trajectories(self):
return trajectories
- def tokenize_trajectories(self, return_action_mask: bool = False, return_reward_mask: bool = False):
+    def tokenize_trajectories(self, tokenizer=None, return_action_mask: bool = False, return_reward_mask: bool = False):
+ if tokenizer is None:
+ tokenizer = self.tokenizer
+
trajectories = self.trajectories
self.logger.info("================ Trajectory ================")
self.logger.info(trajectories[0])
@@ -196,7 +236,7 @@ def tokenize_trajectories(self, return_action_mask: bool = False, return_reward_
info['last_response'] = last_response
other_info_list.append(info)
- inputs = tokenize_conversations(messages_list, tokenizer=self.tokenizer, conv_template=self.template, processor=self.processor, max_length=self.max_length, return_reward_mask=return_reward_mask)
+ inputs = tokenize_conversations(messages_list, tokenizer=tokenizer, conv_template=self.template, processor=self.processor, max_length=self.max_length, return_reward_mask=return_reward_mask)
position_ids = torch.clip(torch.cumsum(inputs['attention_mask'], dim=-1) - 1, min=0, max=None)
inputs['position_ids'] = position_ids
diff --git a/agents/agents/agents/backend_config.py b/agents/agents/agents/backend_config.py
new file mode 100644
index 0000000..54a5a62
--- /dev/null
+++ b/agents/agents/agents/backend_config.py
@@ -0,0 +1,66 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class TransformersConfig:
+ """Configuration for Transformers backend"""
+ temperature: float = 1.0
+ max_new_tokens: int = 1024
+ trust_remote_code: bool = True
+ device_map: str = "auto"
+
+
+@dataclass
+class VLLMConfig:
+ """Configuration for VLLM backend"""
+ temperature: float = 1.0
+ max_new_tokens: int = 1024
+ # Add other vLLM specific parameters as needed
+
+
+@dataclass
+class AsyncVLLMConfig:
+ """Configuration for Async VLLM backend"""
+ temperature: float = 1.0
+ max_new_tokens: int = 1024
+ # Add other async vLLM specific parameters as needed
+
+
+@dataclass
+class VerlConfig:
+ """Configuration for Verl backend"""
+ temperature: float = 1.0
+ max_new_tokens: int = 1024
+ # Add other Verl specific parameters as needed
+
+
+@dataclass
+class AsyncVerlConfig:
+ """Configuration for Async Verl backend"""
+ temperature: float = 1.0
+ max_new_tokens: int = 1024
+ # Add other async Verl specific parameters as needed
+
+
+@dataclass
+class ClientConfig:
+ """Configuration for Client backend (OpenAI-compatible)"""
+ base_url: str = "http://localhost:8000/v1"
+ max_requests_per_minute: int = 100
+ timeout: int = 600
+ api_key: str = "EMPTY"
+ max_new_tokens: int = 1024
+ temperature: float = 1.0
+
+
+# Backend configuration mapping
+BACKEND_CONFIGS = {
+ "transformers": TransformersConfig,
+ "vllm": VLLMConfig,
+ "async_vllm": AsyncVLLMConfig,
+ "verl": VerlConfig,
+ "async_verl": AsyncVerlConfig,
+ "client": ClientConfig,
+}
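+
+# Example (sketch): BaseAgent resolves a default config when none is supplied, e.g.
+#     config_cls = BACKEND_CONFIGS.get("client")            # -> ClientConfig
+#     backend_config = config_cls() if config_cls else None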
\ No newline at end of file
diff --git a/agents/agents/agents/chain/chain_base.py b/agents/agents/agents/chain/chain_base.py
index 843b681..02fa27d 100644
--- a/agents/agents/agents/chain/chain_base.py
+++ b/agents/agents/agents/chain/chain_base.py
@@ -316,14 +316,40 @@ async def _run_single_chain(self,
# Handle tool calls
if current_node.messages[-1].get("tool_calls"):
for tool_call in current_node.messages[-1]["tool_calls"]:
- current_node = await self._execute_tool_call(
+ result = await self._execute_tool_call(
tool_call, newest_messages, chain, chain_id, depth,
have_set_tools, enable_streaming
)
have_set_tools = True
+
+ # Create action input node
+ action_input_node = chain.add_node(
+ type="Action Input",
+ messages=deepcopy(newest_messages),
+ description=result.get("arguments", "")
+ )
+
+ # Process observation
+ observation = result["observation"]
+ observation_json = json.dumps({
+ "name": result["name"],
+ "content": observation,
+ }, indent=4)
+
+ action_input_node.observation = observation_json
+ action_input_node.observation_code = result["status"]
+ newest_messages.append({
+ "role": "tool",
+ "tool_call_id": tool_call["id"],
+ "content": [{"type": "text", "text": observation_json}],
+ })
+ action_input_node.messages = deepcopy(newest_messages)
+ action_input_node.is_terminal = result["status"] in self.terminal_status
else:
# No tool calls, chain is finished
break
+
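+                # Continue the chain from the most recent observation node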
+ current_node = action_input_node
depth += 1
@@ -463,33 +489,9 @@ async def _execute_tool_call(self, tool_call, newest_messages, chain, chain_id,
step=depth,
depth=depth
))
+
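+        # Hand the raw tool result (name / arguments / observation / status) back to
+        # _run_single_chain, which now creates the Action Input node.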
+ return result
-
- # Create action input node
- action_input_node = chain.add_node(
- type="Action Input",
- messages=deepcopy(newest_messages),
- description=result.get("arguments", "")
- )
-
- # Process observation
- observation = result["observation"]
- observation_json = json.dumps({
- "name": result["name"],
- "content": observation,
- }, indent=4)
-
- action_input_node.observation = observation_json
- action_input_node.observation_code = result["status"]
- newest_messages.append({
- "role": "tool",
- "tool_call_id": tool_call["id"],
- "content": [{"type": "text", "text": observation_json}],
- })
- action_input_node.messages = deepcopy(newest_messages)
- action_input_node.is_terminal = result["status"] in self.terminal_status
-
- return action_input_node
async def _finalize_chain(self, chain_id, chain, current_node, depth):
"""Finalize the chain with reward calculation and cleanup."""
diff --git a/agents/agents/agents/llm_backend.py b/agents/agents/agents/llm_backend.py
index 2401137..726620a 100644
--- a/agents/agents/agents/llm_backend.py
+++ b/agents/agents/agents/llm_backend.py
@@ -20,6 +20,7 @@
from vllm import LLM, AsyncLLMEngine, SamplingParams, AsyncEngineArgs
import openai
from .templates.templates import Chat
+from .templates.vision_processor import get_processor
import logging
import PIL
@@ -46,8 +47,8 @@ def apply_chat_template(self, messages_list: List[List[Dict]], template: str, ad
for messages in messages_list:
chat = Chat(template, messages)
prompts.append(chat.prompt(add_generation_prompt=add_generation_prompt, tools=tools))
- # We don't support vision inputs for now
- vision_inputs.append(None)
+ # We only support image inputs for now
+ vision_inputs.append(chat.vision_inputs())
return prompts, vision_inputs
@@ -325,49 +326,6 @@ async def generate_streaming(self, messages_list: List[List[Dict]], **kwargs) ->
if hasattr(sequence, 'text'):
yield sequence.text
-class VerlBackend(LLMBackend):
- """Verl implementation"""
-
- def __init__(self, llm_engine, model_name_or_path: str, template: str, max_length: int=8192, **kwargs):
- super().__init__(**kwargs)
- self.model_name = model_name_or_path
- self.max_length = max_length
- self.template = template
- self.tokenizer = AutoTokenizer.from_pretrained(
- self.model_name,
- trust_remote_code=True,
- )
- self.llm_engine = llm_engine
-
- def generate(self, messages_list: str, **kwargs) -> str:
- """Generate text from prompt using Verl"""
- # We need to build a DataProto from the prompts
- prompts = self.apply_chat_template(messages_list, self.template)
- inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
- # We need to do padding for compatibility with the verl DataProto, which
- # assumes that the batch size must be divisible by the dp size
- world_size = self.llm_engine.world_size
- inputs['input_ids'] = pad_tensor_to_rank_size(inputs['input_ids'], world_size)
- inputs['attention_mask'] = pad_tensor_to_rank_size(inputs['attention_mask'], world_size)
-
- position_ids = torch.clip(torch.cumsum(inputs.attention_mask, dim=-1) - 1, min=0, max=None)
- inputs['position_ids'] = position_ids
-
- n = kwargs.get("num_return_sequences", 1)
- temperature = kwargs.get("temperature", 1.0)
- use_agent = True
- batch = DataProto.from_single_dict(inputs, meta_info={"n": n, "use_agent": use_agent, "temperature": temperature})
-
- gen_batch_output = self.llm_engine.generate_sequences(batch)
- responses = gen_batch_output.batch['responses'] # BS x L
- response_texts = self.tokenizer.batch_decode(responses, skip_special_tokens=True) # List of string with length BS
- response_texts = response_texts[:len(prompts)*n]
- return response_texts
-
- def generate_async(self, messages_list: str, **kwargs) -> str:
- raise NotImplementedError("Verl backend does not support async generation")
-
-
class AsyncVerlBackend(LLMBackend):
"""Verl implementation"""
@@ -417,12 +375,6 @@ async def generate_async(self, messages_list: str, **kwargs) -> str:
gen_batch_output = await self.llm_engine.generate_sequences_async(batch, **generation_config)
response_texts = gen_batch_output.batch['responses'].tolist() # np.array of strings with length BS
- # print(f"[AsyncVerlbackend] response_texts: {response_texts.shape} {type(response_texts)}")
- # print(f"[AsyncVerlBackend] response_texts: {response_texts[0]}")
- # raise NotImplementedError("Async Verl backend does not support sync generation")
- # response_texts = [text for text in response_texts]
- # response_texts = self.tokenizer.batch_decode(responses, skip_special_tokens=True) # List of string with length BS
- # response_texts = responses[:len(prompts)*n]
return response_texts
@@ -465,24 +417,39 @@ def __init__(
# --------------------------------------------------------------------- #
# Low‑level single request (runs in threadpool so it doesn't block loop)
# --------------------------------------------------------------------- #
- @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15))
- def _blocking_call(self, messages: List[List[Dict]], **kw) -> str:
- if "num_return_sequences" in kw:
- n = kw.pop("num_return_sequences")
+ @retry(stop=stop_after_attempt(1), wait=wait_exponential(multiplier=1, min=4, max=15))
+ def _blocking_call(self, messages: List[List[Dict]], **kwargs) -> str:
+ if "num_return_sequences" in kwargs:
+ n = kwargs.pop("num_return_sequences")
else:
n = 1
+
+ if "tool_choice" in kwargs:
+ tool_choice = kwargs.pop("tool_choice")
+ else:
+ tool_choice = "none"
+
+ print(f"[ClientBackend] messages: {messages}")
resp = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
timeout=self.timeout,
max_tokens=self.max_new_tokens,
n=n,
- tool_choice="none",
- **kw,
+ tool_choice=tool_choice,
+ **kwargs,
)
- response_texts = [choice.message.content for choice in resp.choices]
+ resp_json = resp.dict()
+ response_texts = [choice["message"]["content"] for choice in resp_json["choices"]]
+ tool_calls = [choice["message"]["tool_calls"] for choice in resp_json["choices"]]
- return response_texts
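+        # When tool_choice is "none" only the texts are returned; otherwise a dict with
+        # both texts and tool calls is handed back to async_generate for flattening.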
+ if tool_choice == "none":
+ return response_texts
+ else:
+ return {
+ "response_texts": response_texts,
+ "tool_calls": tool_calls,
+ }
async def _call(self, messages: List[List[Dict]], **kw) -> str:
# acquire a rate‑limit token
@@ -495,7 +462,7 @@ async def _call(self, messages: List[List[Dict]], **kw) -> str:
def async_generate(
self,
messages: List[List[Dict]] | List[Dict],
- **kw,
+ **kwargs,
) -> List[str] | asyncio.Task:
"""
• Pass a *list of messages* → single completion.
@@ -510,14 +477,24 @@ def async_generate(
messages_list = [messages] # single
else:
messages_list = messages # batch
-
+ print(f"[ClientBackend] messages_list: {messages_list}")
messages_list = [convert_messages_to_openai_format(messages) for messages in messages_list]
async def _runner():
- tasks = [asyncio.create_task(self._call(_input, **kw)) for _input in messages_list]
- texts_list = await asyncio.gather(*tasks)
- response_texts = [text for texts in texts_list for text in texts]
- return response_texts
+ tasks = [asyncio.create_task(self._call(_input, **kwargs)) for _input in messages_list]
+            response_texts_list_or_dict = await asyncio.gather(*tasks)
+            # Each entry is a dict when tool_choice is not "none", otherwise a list of
+            # strings; flatten across requests either way.
+ if isinstance(response_texts_list_or_dict[0], dict):
+ response_texts = [text for response in response_texts_list_or_dict for text in response["response_texts"]]
+ tool_calls = [tool_call for response in response_texts_list_or_dict for tool_call in response["tool_calls"]]
+ return {
+ "response_texts": response_texts,
+ "tool_calls": tool_calls,
+ }
+ else:
+ response_texts = [text for response in response_texts_list_or_dict for text in response]
+ return response_texts
try:
loop = asyncio.get_running_loop() # ➊ already inside a loop?
@@ -532,8 +509,8 @@ async def _runner():
async def generate_async(self,
messages: List[List[Dict]] | List[Dict],
- **kw) -> List[str]:
- return await self.async_generate(messages, **kw)
+ **kwargs) -> List[str]:
+ return await self.async_generate(messages, **kwargs)
# Background token‑bucket refill (one token each 60/max_rpm seconds)
async def _refill_tokens(self):
diff --git a/agents/agents/agents/specialized/openai_agent.py b/agents/agents/agents/specialized/openai_agent.py
index f64a056..ae0f37a 100644
--- a/agents/agents/agents/specialized/openai_agent.py
+++ b/agents/agents/agents/specialized/openai_agent.py
@@ -1,10 +1,13 @@
import os
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Tuple, Union
from openai import OpenAI, AzureOpenAI
import httpx
-
+import asyncio
+from ...tools import answer_qa
from ...tools.tool_base import tool
from ..agent_base import BaseAgent
+from ..llm_backend import ClientBackend
+from ..backend_config import ClientConfig
from tenacity import retry, wait_random_exponential, stop_after_attempt
from termcolor import colored
import json
@@ -14,102 +17,113 @@ class OpenAIAgent(BaseAgent):
def __init__(
self,
api_key="",
+ base_url="https://api.openai.com/v1",
**kwargs
):
- wrapper = kwargs.get("wrapper", False)
- if not wrapper:
- kwargs["wrapper"] = True
- template = kwargs.get("template", "openai")
- kwargs["template"] = template
+ assert api_key is not None and api_key != "", "API key is required"
+ backend = kwargs.get("backend", "client")
+ assert backend == "client", "OpenAI agent only supports client backend"
+ kwargs["backend"] = backend
+
+ # Create client-specific configuration
+ client_config = ClientConfig(
+ api_key=api_key,
+ base_url=base_url,
+ max_requests_per_minute=kwargs.get("max_requests_per_minute", 100),
+ timeout=kwargs.get("timeout", 600),
+ max_new_tokens=kwargs.get("max_new_tokens", 1024),
+ temperature=kwargs.get("temperature", 1.0)
+ )
+ kwargs["backend_config"] = client_config
+
+ # Initialize the base class
super(OpenAIAgent, self).__init__(**kwargs)
+
model_name_or_path = kwargs.get("model_name_or_path", "gpt-3.5-turbo")
self.client = OpenAI(api_key=api_key)
self.api_key = api_key
self.model = model_name_or_path
+
+        # API-based client backends do not load a local tokenizer or processor
+ if self.backend == "client":
+ self.tokenizer = None
+ self.processor = None
- def parse(self, messages_list: List[List[Dict]], tools: List[Any], **args):
- # OpenAI use 'n' to specify the number of return sequences
- num_return_sequences = args.get("num_return_sequences", 1)
- tool_schemas = [tool.schema for tool in tools]
- del args["num_return_sequences"]
- args["n"] = num_return_sequences
-
- def process_message(message_data):
- message, api_key, model, tool_schemas, args = message_data
- client = OpenAI(api_key=api_key)
- try:
- json_data = {
- "model": model,
- "messages": message,
- **args
- }
- if tool_schemas is not None:
- json_data.update({"tools": tool_schemas})
- openai_response = client.chat.completions.create(**json_data)
- result = openai_response.dict()
- new_message = result['choices'][0]['message']
- new_message["loss"] = True
- return new_message
- except Exception as e:
- print(f"Parsing Exception: {repr(e)}. Try again.")
- return {
- "role": "assistant",
- "content": result,
- "tool_calls": [],
- "loss": True
- }
+ async def generate_async(self, messages_list_or_inputs: List[List[Dict]], **args):
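+        # Force tool_choice="auto" so the client backend also returns tool calls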
+ responses = await super().generate_async(messages_list_or_inputs, tool_choice="auto", **args)
+ return responses
- # Prepare data for each message
- message_data_list = [
- (message, self.api_key, self.model, tool_schemas, args)
- for message in messages_list
- ]
- pool = Pool()
- results = pool.map(process_message, message_data_list)
- pool.close()
- pool.join()
-
- return results
-
+ def parse(self, responses: Union[Dict[str, List], List[str]], tools: List[Any], **args) -> List[Dict]:
+ """
+ Parse responses into the correct message format.
+
+ Args:
+            responses: Either a list of response strings, or a dict with
+                "response_texts" and "tool_calls" lists when tool calls are enabled.
+ tools: List of tools available to the agent.
+ **args: Additional arguments for parsing.
+
+ Returns:
+ List of assistant messages in the correct format.
+ """
+
+ new_messages = []
+ if isinstance(responses, dict):
+ for response, tool_calls in zip(responses["response_texts"], responses["tool_calls"]):
+ new_message = {"role": "assistant"}
+
+ new_message["content"] = response
+ if len(tool_calls) > 0:
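+                    # Keep only the first tool call from each response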
+ tool_calls = tool_calls[:1]
+ new_message["tool_calls"] = tool_calls
+
+ new_messages.append(new_message)
+ elif isinstance(responses, list):
+ for content in responses:
+ new_message = {"role": "assistant", "content": content}
+ new_messages.append(new_message)
+ else:
+ raise ValueError(f"Invalid responses type: {type(responses)}")
+
+ return new_messages
- @retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))
- def chat_completion_request(
- self,
- messages,
- tools=None,
- tool_choice=None,
- model=None,
- stop=None,
- client=None,
- **args
- ):
- if model is None:
- model = self.model
- if client is None:
- client = self.client
+ # @retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(1))
+ # def chat_completion_request(
+ # self,
+ # messages,
+ # tools=None,
+ # tool_choice=None,
+ # model=None,
+ # stop=None,
+ # client=None,
+ # **args
+ # ):
+ # if model is None:
+ # model = self.model
+ # if client is None:
+ # client = self.client
- json_data = {
- "model": model,
- "messages": messages,
- **args
- }
- if stop is not None:
- json_data.update({"stop": stop})
- if tools is not None:
- json_data.update({"tools": tools})
- if tool_choice is not None:
- json_data.update({"tool_choice": tool_choice})
+ # json_data = {
+ # "model": model,
+ # "messages": messages,
+ # **args
+ # }
+ # if stop is not None:
+ # json_data.update({"stop": stop})
+ # if tools is not None:
+ # json_data.update({"tools": tools})
+ # if tool_choice is not None:
+ # json_data.update({"tool_choice": tool_choice})
- try:
- # We use chat completion API
- openai_response = client.chat.completions.create(**json_data)
- json_data = openai_response.dict()
- return json_data
- except Exception as e:
- print(f"Unable to generate ChatCompletion response: {e}")
- raise e
+ # try:
+ # # We use chat completion API
+ # openai_response = client.chat.completions.create(**json_data)
+ # json_data = openai_response.dict()
+ # return json_data
+ # except Exception as e:
+ # print(f"Unable to generate ChatCompletion response: {e}")
+ # raise e
@tool()
@@ -129,17 +143,23 @@ def get_current_weather(location: str, unit: str="fahrenheit"):
if __name__ == "__main__":
agent = OpenAIAgent(
- model_name_or_path="gpt-3.5-turbo",
- api_key="",
- tools=[get_current_weather]
+ model_name_or_path="gpt-4o-mini",
+ api_key="OpenAI API Key",
+        tools=[get_current_weather, answer_qa]
)
- messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
- agent.run(
- max_steps=3,
+ messages = [
+ {
+ "messages": [
+ {"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}
+ ]
+ }
+ ]
+ asyncio.run(agent.run_async(
+ max_steps=5,
start_messages=messages,
num_chains=1
- )
+ ))
trajectories = agent.trajectories
- print(trajectories[0]["messages"])
+ print(f"""Trajectory: {trajectories[0]["messages"]}""")
diff --git a/agents/agents/agents/templates/constants.py b/agents/agents/agents/templates/constants.py
new file mode 100644
index 0000000..796ba97
--- /dev/null
+++ b/agents/agents/agents/templates/constants.py
@@ -0,0 +1,18 @@
+from enum import Enum, auto
+
+class ToolPlacement(Enum):
+ """
+ Where to inject the tool catalogue in the rendered prompt.
+ """
+ SYSTEM = auto() # inside the system message
+ FIRST_USER = auto() # as an extra first-user turn
+ LAST_USER = auto() # appended to the last user turn
+ SEPARATE = auto() # its own dedicated turn / role
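+
+    # Template._insert_tools() consults the selected placement via `self.tool_policy.placement`
+    # to decide where the formatted tool catalogue is injected into the prompt.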
+
+
+class Role(Enum):
+ SYSTEM = "system"
+ USER = "user"
+ ASSISTANT = "assistant"
+ TOOL = "tool"
+ ASSISTANT_PREFIX = "assistant_prefix"
\ No newline at end of file
diff --git a/agents/agents/agents/templates/system_policy.py b/agents/agents/agents/templates/system_policy.py
new file mode 100644
index 0000000..8ede734
--- /dev/null
+++ b/agents/agents/agents/templates/system_policy.py
@@ -0,0 +1,46 @@
+from abc import ABC, abstractmethod
+import dataclasses
+import datetime
+from typing import Callable
+
+@dataclasses.dataclass
+class SystemPolicy:
+ use_system: bool = True
+ use_system_without_system_message: bool = True
+ content_processor: Callable[[str], str] = None
+
+
+class SystemContentProcessor(ABC):
+ @abstractmethod
+ def __call__(self, system_message: str) -> str:
+ raise NotImplementedError
+
+ @abstractmethod
+ def jinja(self) -> str:
+ raise NotImplementedError
+
+class Llama32DateProcessor(SystemContentProcessor):
+ """
+ A system content processor that adds date information to system messages.
+
+ In Python mode, it dynamically computes the current date.
+ In Jinja mode, it provides a template with placeholders that can be processed.
+
+ Usage in Jinja templates:
+ - The template includes '__CURRENT_DATE__' placeholder
+ - Replace '__CURRENT_DATE__' with the actual formatted date during processing
+ - Format should be 'dd MMM yyyy' (e.g., '15 Dec 2024')
+ - No external context variables required
+ """
+ def __call__(self, system_message: str) -> str:
+ return f"Cutting Knowledge Date: December 2023\nToday Date: {datetime.datetime.now().strftime('%d %b %Y')}\n\n{system_message}"
+
+ def jinja(self) -> str:
+ # For Jinja templates used by external systems (like vLLM), we need a self-contained approach
+ # Since external systems can't provide context variables, we use a placeholder approach
+ # The external system should replace __CURRENT_DATE__ with the actual date
+ return """Cutting Knowledge Date: December 2023
+Today Date: __CURRENT_DATE__
+
+{{ system_message }}"""
+
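+# Example (sketch): wiring the processor into a SystemPolicy; Template applies it to
+# system messages through `system_policy.content_processor`:
+#
+#     policy = SystemPolicy(content_processor=Llama32DateProcessor())
+#     policy.content_processor("You are a helpful assistant.")
+#     # -> "Cutting Knowledge Date: December 2023\nToday Date: <today>\n\nYou are a helpful assistant."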
\ No newline at end of file
diff --git a/agents/agents/agents/templates/templates.py b/agents/agents/agents/templates/templates.py
index 1b71cfe..7697f87 100644
--- a/agents/agents/agents/templates/templates.py
+++ b/agents/agents/agents/templates/templates.py
@@ -1,43 +1,44 @@
-"""
-This file is a modified version of the original fastchat/conversation.py file.
-
-The original file can be found at:
-https://github.com/LM-SYS/fastchat/blob/main/fastchat/conversation.py
-
-"""
from collections import defaultdict
-from copy import copy
+from copy import copy, deepcopy
import dataclasses
-from enum import Enum, auto, IntEnum
import json
-from typing import List, Any, Dict, Union, Tuple
+from typing import Callable, List, Any, Dict, Union, Tuple
import warnings
-
+import logging
import torch
-from .preprocess import open_image_from_any
from transformers import PreTrainedTokenizer
+from .preprocess import open_image_from_any
+from .vision_processor import is_vision_template
import re
+from typing import Protocol
+from .tool_policy import (
+ ToolFormatter,
+ JsonMinifiedFormatter,
+ JsonCompactFormatter,
+ JsonIndentedFormatter,
+ ToolMainContentProcessor,
+ JsonQwenFormatter,
+)
+from datetime import datetime
+from .constants import Role
+from .system_policy import Llama32DateProcessor, SystemPolicy
+from .tool_policy import ToolPolicy
+from .constants import ToolPlacement, Role
+
+Logger = logging.getLogger(__name__)
+
+# Add console handler if no handlers exist
+if not Logger.handlers:
+ console_handler = logging.StreamHandler()
+ console_handler.setLevel(logging.DEBUG)
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ console_handler.setFormatter(formatter)
+ Logger.addHandler(console_handler)
-class Role(Enum):
- SYSTEM = "system"
- USER = "user"
- ASSISTANT = "assistant"
- TOOL = "tool"
- ASSISTANT_PREFIX = "assistant_prefix"
-
-
-class SeparatorStyle(IntEnum):
- """Separator styles."""
- NO_COLON_SINGLE = auto()
- LLAMA2 = auto()
- LLAMA3 = auto()
- CHATGLM = auto()
- CHATML = auto()
- ADD_SPACE_TWO = auto()
- GENERAL = auto() # No sep, we take all sep as part of the role
- GENERAL_STOP_ALL = auto() # No sep, and add stop_str at each turn
-
+@dataclasses.dataclass
+class GlobalPolicy:
+ prefix: str = None
@dataclasses.dataclass
@@ -59,8 +60,15 @@ class Template:
tool_template: str = None
# The user template
user_template: str = None
+ user_template_with_tools: str = None
# The assistant template
assistant_template: str = None
+ # Global policy
+ global_policy: "GlobalPolicy" = None
+ # System message policy
+ system_policy: "SystemPolicy" = None
+ # Tool policy for this template
+ tool_policy: "ToolPolicy" = None
## vision part
vision_start: str = None
@@ -68,63 +76,199 @@ class Template:
image_token: str = None
video_token: str = None
+ chat_template: str = None
+
+ def __post_init__(self):
+ """Post-initialization to automatically register vision processor if vision tokens are defined"""
+ if self.image_token or self.video_token:
+ self._register_vision_processor()
+ # Initialise default tool policy if none was provided
+ if self.tool_policy is None:
+ self.tool_policy = ToolPolicy()
+ if self.system_policy is None:
+ self.system_policy = SystemPolicy()
+
+ def _register_vision_processor(self):
+ """Automatically register a vision processor for this template"""
+ from .vision_processor import VisionProcessorConfig, register_processor
+
+ # Determine model type based on template name
+ model_type = self._infer_model_type()
+
+ # Create vision config
+ config = VisionProcessorConfig(
+ model_type=model_type,
+ image_token=self.image_token or "",
+ video_token=self.video_token or "",
+ vision_start=self.vision_start or "",
+ vision_end=self.vision_end or "",
+ processor_class="AutoProcessor",
+ expansion_strategy="patch_based"
+ )
+
+ # Register the processor
+ register_processor(self.name, config)
+
+ def _infer_model_type(self) -> str:
+ """Infer model type from template name"""
+ name_lower = self.name.lower()
+
+ if "qwen" in name_lower:
+ return "qwen_vl"
+ elif "llava" in name_lower:
+ return "llava"
+ elif "gemma" in name_lower:
+ return "gemma3"
+ elif "paligemma" in name_lower:
+ return "paligemma"
+ elif "internvl" in name_lower:
+ return "internvl"
+ elif "minicpm" in name_lower:
+ return "minicpm"
+ elif "mllama" in name_lower:
+ return "mllama"
+ elif "pixtral" in name_lower:
+ return "pixtral"
+ elif "video" in name_lower:
+ return "video_llava"
+ else:
+ # Default to patch-based for unknown models
+ return "patch_based"
+
+ def _supports_tool_call(self) -> bool:
+ if (self.system_template_with_tools or self.user_template_with_tools) and self.tool_template:
+ return True
+ else:
+ return False
+
def render(self, messages: List[Dict], tools=None, add_generation_prompt: bool = False) -> str:
- """Render the template with the given messages and kwargs.
- messages: [
- {
- "role": "user",
- "content": [
- {
- "type": "text",
- "text": "Hello, how are you?"
- }
- ]
- },
- {
- "role": "assistant",
- ]
+ """Render the template and return
+ 1. the final prompt string,
+ 2. the list of string *elements* that compose the prompt, and
+ 3. the corresponding list of *roles* (used by downstream post-processing).
+
+ The heavy lifting is delegated to small, single-purpose helpers so the
+ high-level flow is immediately apparent:
+
+ 1. _insert_tools – decide where the tool catalogue lives
+ 2. _encode_turns – encode every conversation turn
+ 3. _maybe_add_generation_prompt – append the generation prefix if requested
"""
- elements = []
- roles = []
- if tools:
- tools = self._encode_system_tools(tools)
- for i, message in enumerate(messages):
+ # Step 1 – decide tool placement & clone messages
+ work_messages, tools_str, insert_tools_idx = self._insert_tools(messages, tools)
+
+ # Step 2 – encode each conversation turn to text tokens
+ elements, roles = self._encode_turns(work_messages, tools_str, insert_tools_idx)
+
+ # Step 3 – append generation prefix if needed
+ if add_generation_prompt:
+ self._maybe_add_generation_prompt(elements, roles)
+
+ # Concatenate the prompt
+ prompt = "".join(elements)
+ return prompt, elements, roles
- if i == 0 and self._detect_role(message["role"]) == Role.SYSTEM:
- system_message = self._encode_system_message(message["content"], tools=tools)
- elements.append(system_message)
- roles.append(Role.SYSTEM)
- # This message is done
+ def _insert_tools(self, messages: List[Dict], tools):
+ """Clone *messages* and compute where (and how) the tool catalogue
+ should be injected.
+
+ Returns
+ -------
+ work_messages : List[Dict]
+ A deepcopy of the original *messages* so we never mutate caller data.
+ tools_str : Optional[str]
+ The formatted tool catalogue or *None* if `tools` is falsy.
+ insert_tools_idx : int
+ Index of the *user* message that receives the catalogue, or -1 when
+ no injection is required.
+ """
+
+ work_messages = deepcopy(messages)
+ if tools:
+ tools_str = self.tool_policy.format_tools(tools)
+ placement = self.tool_policy.placement
+ insert_tools_idx = self._find_insert_tools_index(work_messages, placement)
+ else:
+ tools_str = None
+ insert_tools_idx = -1
+ return work_messages, tools_str, insert_tools_idx
+
+ def _encode_turns(
+ self,
+ work_messages: List[Dict],
+ tools_str: str,
+ insert_tools_idx: int,
+ ) -> Tuple[List[str], List[Role]]:
+ """Convert every message dict into its textual representation while
+ tracking roles for later masking logic."""
+
+ elements: List[str] = []
+ roles: List[Role] = []
+
+ # Global prefix comes first (rarely used but must respect ordering)
+ if self.global_policy and self.global_policy.prefix:
+ elements.append(self.global_policy.prefix)
+ roles.append(Role.SYSTEM)
+
+ for i, message in enumerate(work_messages):
+ current_role = self._detect_role(message["role"])
+
+ # --------------------------------------------------------------
+ # Handle system message insertion on the very first turn
+ # --------------------------------------------------------------
+ if i == 0 and current_role == Role.SYSTEM:
+ if self.system_policy.use_system:
+ system_message = self._encode_system_message(
+ message["content"], tools=tools_str
+ )
+ elements.append(system_message)
+ roles.append(Role.SYSTEM)
+ # Whether inserted or not, we skip further handling of this
+ # message because it's the (optional) system turn itself.
continue
- elif i == 0 and self._detect_role(message["role"]) != Role.SYSTEM:
- system_message = self._encode_system_message_default(tools=tools)
- elements.append(system_message)
- roles.append(Role.SYSTEM)
- # This message is not done, we need to handle other roles
-
- if self._detect_role(message["role"]) == Role.USER:
- user_message = self._encode_user_message(message["content"])
+ elif i == 0 and current_role != Role.SYSTEM:
+ if self.system_policy.use_system:
+ system_message = self._encode_system_message_default(tools=tools_str)
+ elements.append(system_message)
+ roles.append(Role.SYSTEM)
+ # Do *not* `continue` – we still need to encode this first message.
+
+ # --------------------------------------------------------------
+ # Encode regular conversation turns
+ # --------------------------------------------------------------
+ if current_role == Role.USER:
+ if i == insert_tools_idx:
+ user_message = self._encode_user_message_with_tools(
+ message["content"], tools=tools_str
+ )
+ else:
+ user_message = self._encode_user_message(message["content"])
elements.append(user_message)
roles.append(Role.USER)
- elif self._detect_role(message["role"]) == Role.ASSISTANT:
+
+ elif current_role == Role.ASSISTANT:
assistant_message = self._encode_assistant_message(message["content"])
elements.append(assistant_message)
roles.append(Role.ASSISTANT)
- elif self._detect_role(message["role"]) == Role.TOOL:
+
+ elif current_role == Role.TOOL:
tool_message = self._encode_tool_message(message["content"])
elements.append(tool_message)
roles.append(Role.TOOL)
+
else:
raise ValueError(f"Invalid role: {message['role']}")
-
- if add_generation_prompt:
- generation_prefix = self._encode_generation_prompt()
- elements.append(generation_prefix)
- roles.append(Role.ASSISTANT_PREFIX)
-
- prompt = "".join(elements)
- return prompt, elements, roles
+
+ return elements, roles
+
+ def _maybe_add_generation_prompt(self, elements: List[str], roles: List[Role]):
+ """Append the generation prefix so the model knows to continue
+ generating an assistant response."""
+
+ generation_prefix = self._encode_generation_prompt()
+ elements.append(generation_prefix)
+ roles.append(Role.ASSISTANT_PREFIX)
def _detect_role(self, role: str) -> Role:
if role == "system":
@@ -137,53 +281,114 @@ def _detect_role(self, role: str) -> Role:
return Role.TOOL
else:
raise ValueError(f"Invalid role: {role}")
+
+ def _find_insert_tools_index(self, work_messages: List[Dict], placement: ToolPlacement) -> int:
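+        """Return the index of the message that receives the tool catalogue.
+
+        SYSTEM -> 0, FIRST_USER -> first user turn, LAST_USER -> last user turn.
+        """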
+ insert_tools_idx = 0 # Default to insert tools at system message
+ for i, message in enumerate(work_messages):
+ if placement == ToolPlacement.SYSTEM:
+ insert_tools_idx = 0
+ elif placement == ToolPlacement.FIRST_USER:
+ if message.get("role") == "user":
+ insert_tools_idx = i
+ break
+ elif placement == ToolPlacement.LAST_USER:
+ if message.get("role") == "user":
+ insert_tools_idx = i
+ else:
+ raise ValueError(f"Unhandled ToolPlacement: {placement}")
+ return insert_tools_idx
def _encode_system_tools(self, tools: List[Dict]) -> str:
return "\n".join([json.dumps(tool) for tool in tools])
def _encode_system_message_default(self, tools=None) -> str:
+ if not self.system_policy.use_system_without_system_message:
+ return ""
+
+ if self.system_policy.content_processor is not None:
+ system_message = self.system_policy.content_processor(self.system_message)
+ else:
+ system_message = self.system_message
+
if tools is None:
- return self.system_template.format(system_message=self.system_message)
+ return self.system_template.format(system_message=system_message)
else:
if self.system_template_with_tools:
- return self.system_template_with_tools.format(system_message=self.system_message, tools=tools)
+ return self.system_template_with_tools.format(system_message=system_message, tools=tools)
else:
- return self.system_template.format(system_message=self.system_message)
+ return self.system_template.format(system_message=system_message)
def _encode_system_message(self, content, tools=None) -> str:
- if tools is None:
+ # Handle both string content and list content formats
+ if isinstance(content, str):
+ system_message = content
+ else:
system_message = content[0]['text']
+
+ if self.system_policy.content_processor is not None:
+ system_message = self.system_policy.content_processor(system_message)
+
+ if tools is None:
return self.system_template.format(system_message=system_message)
else:
- system_message = content[0]['text']
if self.system_template_with_tools is None:
return self.system_template.format(system_message=system_message)
else:
return self.system_template_with_tools.format(system_message=system_message, tools=tools)
-
- def _encode_user_message(self, content: List[Dict]) -> str:
- text = ""
- for item in content:
- if item["type"] == "text":
- text += item["text"]
- elif item["type"] == "image":
- text += self.vision_start + self.image_token + self.vision_end
- elif item["type"] == "video":
- text += self.vision_start + self.video_token + self.vision_end
- else:
- raise ValueError(f"Invalid message type: {item['type']}")
+
+ def _encode_user_message_with_tools(self, content, tools: str) -> str:
+ # Handle both string content and list content formats
+ if isinstance(content, str):
+ text = content
+ else:
+ text = ""
+ for item in content:
+ if item["type"] == "text":
+ text += item["text"]
+ else:
+ raise ValueError(f"Invalid message type: {item['type']}")
+
+ if self.user_template_with_tools:
+ user_message = self.user_template_with_tools.format(content=text, tools=tools)
+ else:
+ user_message = self.user_template.format(content=text)
+ return user_message
+
+ def _encode_user_message(self, content) -> str:
+ # Handle both string content and list content formats
+ if isinstance(content, str):
+ text = content
+ else:
+ text = ""
+ for item in content:
+ if item["type"] == "text":
+ text += item["text"]
+ elif item["type"] in ["image", "image_url"]:
+ text += self.vision_start + self.image_token + self.vision_end
+ elif item["type"] == "video":
+ text += self.vision_start + self.video_token + self.vision_end
+ else:
+ raise ValueError(f"Invalid message type: {item['type']}")
user_message = self.user_template.format(content=text)
return user_message
- def _encode_assistant_message(self, content: List[Dict]) -> str:
- assert len(content) == 1, "Assistant message must be a single message"
- text = content[0]["text"]
+ def _encode_assistant_message(self, content) -> str:
+ # Handle both string content and list content formats
+ if isinstance(content, str):
+ text = content
+ else:
+ assert len(content) == 1, "Assistant message must be a single message"
+ text = content[0]["text"]
assistant_message = self.assistant_template.format(content=text)
return assistant_message
- def _encode_tool_message(self, content: List[Dict]) -> str:
- assert len(content) == 1, "Tool message must be a single message"
- text = content[0]["text"]
+ def _encode_tool_message(self, content) -> str:
+ # Handle both string content and list content formats
+ if isinstance(content, str):
+ text = content
+ else:
+ assert len(content) == 1, "Tool message must be a single message"
+ text = content[0]["text"]
tool_message = self.tool_template.format(observation=text)
return tool_message
@@ -208,19 +413,36 @@ def _split_assistant_message(self, assistant_message: str) -> List[str]:
return generation_prefix, content, suffix
- def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None) -> str:
- prompt, elements, roles = self.render(messages, tools=tools)
+ def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, processor=None) -> str:
+ if processor is None and self.supports_vision():
+ raise ValueError(f"Processor is required for vision templates: {self.name}")
+
+ if self.supports_vision():
+ # Use vision-aware encoding with proper alignment
+ return self._encode_with_vision_processor(messages, tokenizer, return_tensors, tools, add_generation_prompt=add_generation_prompt, processor=processor)
+ else:
+ # Use standard encoding
+ return self._encode_standard(messages, tokenizer, return_tensors, tools, add_generation_prompt=add_generation_prompt)
+
+ def _encode_standard(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False) -> str:
+ Logger.debug(f"[Template] Encoding standard for template: {self.name}")
+ """Standard encoding without vision support"""
+ prompt, elements, roles = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt)
elements, mask_flags = self._postprocess_elements(elements, roles)
input_ids = []
attention_mask = []
labels = []
action_mask = []
- if tokenizer.bos_token and tokenizer.add_bos_token:
- input_ids.append(tokenizer.bos_token_id)
- attention_mask.append(1)
- labels.append(-100)
- action_mask.append(0)
+
+ if tokenizer.bos_token:
+            # If add_bos_token is not set, we assume the BOS token should be added.
+            # This may be wrong for tokenizers that define a bos_token but do not add it by default.
+ if getattr(tokenizer, "add_bos_token", True):
+ input_ids.append(tokenizer.bos_token_id)
+ attention_mask.append(1)
+ labels.append(-100)
+ action_mask.append(0)
for element, mask_flag in zip(elements, mask_flags):
cur_input_ids = tokenizer.encode(element, add_special_tokens=False)
@@ -241,6 +463,41 @@ def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_te
if return_tensors == "pt":
inputs = {k: torch.tensor([v]) for k, v in inputs.items()}
return inputs
+
+ def _encode_with_vision_processor(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, processor=None) -> str:
+ Logger.debug(f"[Template] Encoding with vision processor for template: {self.name}")
+ """Encode with vision processor handling proper alignment"""
+ from .vision_processor import get_processor
+ from .utils import extract_vision_inputs_from_messages
+
+ # Get vision processor
+ vision_processor = get_processor(self.name)
+ if vision_processor is None:
+ raise ValueError(f"No vision processor registered for template: {self.name}")
+
+ # Get base prompt and mask information
+ prompt, elements, roles = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt)
+ elements, mask_flags = self._postprocess_elements(elements, roles)
+
+ # Extract vision inputs
+ images, videos = extract_vision_inputs_from_messages(messages)
+
+ Logger.debug(f"[Template] images: {len(images)}")
+ Logger.debug(f"[Template] videos: {len(videos)}")
+
+ Logger.debug(f"[Template] messages: {messages}")
+
+ # Use vision processor with alignment support
+ return vision_processor.process_for_llm(
+ prompt=prompt,
+ elements=elements,
+ mask_flags=mask_flags,
+ images=images,
+ videos=videos,
+ processor=processor,
+ tokenizer=tokenizer,
+ return_tensors=return_tensors
+ )
def _postprocess_elements(self, elements: List[str], roles) -> List[str]:
@@ -302,67 +559,112 @@ def _postprocess_elements(self, elements: List[str], roles) -> List[str]:
merged_mask_flags.append(prev_mask_flag)
return merged_elements, merged_mask_flags
+ def supports_vision(self) -> bool:
+ """Check if this template supports vision processing"""
+ return is_vision_template(self.name)
- def get_vision_inputs(self):
+ def get_vision_inputs(self, messages: List[Dict]):
vision_inputs = defaultdict(list)
- for role, message, _ in self.messages:
- if isinstance(message, list):
- for item in message:
+ Logger.debug(f"[Template] get_vision_inputs: messages: {messages}")
+ for message in messages:
+ content = message["content"]
+ if isinstance(content, list):
+ for item in content:
if item['type'] == 'text':
continue
- elif item['type'] == 'image':
- vision_inputs['image'].append(open_image_from_any(item['image']))
+ elif item['type'] in ['image', 'image_url', 'image_base64']:
+ vision_inputs["image"].append(open_image_from_any(item[item['type']]))
elif item['type'] == 'video':
raise NotImplementedError("Video is not supported for chat template.")
else:
raise ValueError(f"Invalid message type: {item['type']}")
+ else:
+ raise ValueError(f"Invalid message content: {content}, the content should be a list of dicts")
return vision_inputs
- # if self.name == "qwen2.5-vl":
- # jinja_template = """{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['role'] == 'tool'%}\n{% endif %}{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% if message['role'] == 'tool' %}\n\n{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"""
-
def jinja_template(self) -> str:
- """
- Build a Hugging-Face-style chat-template (Jinja-mini dialect) that mimics
- `self.render`. The template expects three variables in its context:
+ if self.chat_template:
+ return self.chat_template
+ else:
+ return self.render_jinja_template()
- • messages – list[dict] (same format you pass to .render)
- • add_generation_prompt – bool (default False)
- • tools – list[dict] (optional, for tool-enabled templates)
+ def render_jinja_template(self) -> str:
+ """Return a Hugging-Face style chat-template (Jinja-mini dialect).
- No other Python state is referenced, so the string can be cached in the
- tokenizer and shipped to a different process.
+        The implementation mirrors the staged structure of `render()` for
+        easier maintenance:
+
+        1.   _jinja_header_constants   – immutable `set` statements
+        2.   _jinja_system_block       – first turn / system handling
+        2.5  _jinja_compute_insert_idx – pre-compute the tool-injection index
+        3.   _jinja_loop_messages      – remaining turns & per-role logic
+        4.   _jinja_generation_block   – optional generation prefix
"""
- # ------------------------------------------------------------------
- # 1. Pre-compute constant strings so the inner template stays tiny
- # ------------------------------------------------------------------
- default_system = self.system_template.format(system_message=self.system_message)
+
+ parts: List[str] = []
+
+ # 1. Constant header (always first)
+ parts.extend(self._jinja_header_constants())
+
+ # 2. System-message handling (depends on presence of tools etc.)
+ parts.extend(self._jinja_system_block())
+
+ # 2.5 Pre-compute insert index for user placement
+ parts.extend(self._jinja_compute_insert_idx())
+
+ # 3. Loop over remaining messages
+ parts.extend(self._jinja_loop_messages())
+
+ # 4. Generation prefix block
+ parts.extend(self._jinja_generation_block())
+
+ template_str = "".join(parts)
- # Don't pre-format system_template_with_tools - handle it in Jinja
- system_template_with_tools_raw = self.system_template_with_tools if self.system_template_with_tools else None
+ # Post-process: Replace __CURRENT_DATE__ placeholder with actual date
+ if "__CURRENT_DATE__" in template_str:
+ from datetime import datetime
+ current_date = datetime.now().strftime('%d %b %Y')
+ template_str = template_str.replace("__CURRENT_DATE__", current_date)
+
+ return template_str
+
+ # ------------------------------------------------------------------
+ # Private helpers – keep them together for readability
+ # ------------------------------------------------------------------
+
+ def _jinja_header_constants(self) -> List[str]:
+ """Return Jinja `set` statements for all constant strings."""
+ # Compute default system message considering content processor
+ if self.system_policy.content_processor is not None:
+ # Apply content processor to system message
+ processed_system_message = self.system_policy.content_processor(self.system_message)
+ default_system = self.system_template.format(system_message=processed_system_message)
+ else:
+ default_system = self.system_template.format(system_message=self.system_message)
+
+ system_template_with_tools_raw = (
+ self.system_template_with_tools if self.system_template_with_tools else None
+ )
+
+ # Split templates
try:
u_pref, u_suff = self.user_template.split("{content}")
a_pref, a_suff = self.assistant_template.split("{content}")
- except ValueError as e: # missing {content}
- raise ValueError("`user_template` / `assistant_template` must contain "
- "`{content}` placeholder") from e
+ except ValueError as exc:
+ raise ValueError(
+ "`user_template` / `assistant_template` must contain `{content}` placeholder"
+ ) from exc
if self.tool_template:
t_pref, t_suff = self.tool_template.split("{observation}")
- else: # tools optional
+ else:
t_pref, t_suff = "", ""
- # tokens for images / videos
+ # Tokens for images / videos
img_tok = (self.vision_start or "") + (self.image_token or "") + (self.vision_end or "")
vid_tok = (self.vision_start or "") + (self.video_token or "") + (self.vision_end or "")
- # ------------------------------------------------------------------
- # 2. Assemble the Jinja text; everything in triple-quotes is copied
- # verbatim into the tokenizer. We splice in the constants that
- # never change for this Template instance.
- # ------------------------------------------------------------------
- template_parts = [
+ header = [
f"{{% set _u_pref = {u_pref!r} %}}",
f"{{% set _u_suff = {u_suff!r} %}}",
f"{{% set _a_pref = {a_pref!r} %}}",
@@ -374,39 +676,146 @@ def jinja_template(self) -> str:
f"{{% set _default_system = {default_system!r} %}}",
f"{{% set _system_message = {self.system_message!r} %}}",
f"{{% set _system_template = {self.system_template!r} %}}",
+ f"{{% set _tool_placement = {self.tool_policy.placement.name!r} %}}",
]
-
- # Add system template with tools if available
+
if system_template_with_tools_raw:
- template_parts.append(f"{{% set _system_template_with_tools = {system_template_with_tools_raw!r} %}}")
-
- template_parts.extend([
+ header.append(
+ f"{{% set _system_template_with_tools = {system_template_with_tools_raw!r} %}}"
+ )
+
+ # Add user template with tools if it exists
+ if self.user_template_with_tools:
+ # Convert double braces to single braces for Jinja compatibility
+ processed_template = self.user_template_with_tools.replace('{{', '{').replace('}}', '}')
+ header.append(
+ f"{{% set _u_template_with_tools = {processed_template!r} %}}"
+ )
+
+ # ------------------------------------------------------------------
+ # Formatter macro for tools (only if the template supports tool calls)
+ # ------------------------------------------------------------------
+
+ if self._supports_tool_call():
+ # Build a Jinja macro that reproduces ToolPolicy.format_tools behaviour
+ formatter_snippet = self.tool_policy.formatter.jinja()
+
+ # The snippet usually comes wrapped in "{{ ... }}". We drop the
+ # outer braces because macro bodies are already an output context.
+ formatter_body = formatter_snippet.strip()
+
+ header.extend(
+ [
+ "{% macro _fmt_tools(tools) %}",
+ f"{formatter_body}",
+ "{% endmacro %}",
+ ]
+ )
+
+ # ------------------------------------------------------------------
+ # System processor macro (if system policy has a content processor)
+ # ------------------------------------------------------------------
+
+ if self.system_policy.content_processor is not None:
+ # Build a Jinja macro that reproduces the system content processor behaviour
+ processor_snippet = self.system_policy.content_processor.jinja()
+
+ # The snippet should be a template that expects 'system_message' variable
+ # We create a macro that can be called with the system message
+ header.extend(
+ [
+ "{% macro _process_system_message(system_message) %}",
+ f"{processor_snippet}",
+ "{% endmacro %}",
+ ]
+ )
+
+ return header
+
+ def _jinja_compute_insert_idx(self) -> List[str]:
+ """Return Jinja code that pre-computes the index where tools should
+ be injected for FIRST_USER and LAST_USER placements."""
+
+ return [
+ "{% set _insert_ns = namespace(idx=-1) %}",
+ "{% if _tool_placement in ['FIRST_USER', 'LAST_USER'] %}",
+ "{%- for _m in messages -%}",
+ "{%- if _m['role'] == 'user' -%}",
+ "{%- if _tool_placement == 'FIRST_USER' and _insert_ns.idx == -1 -%}",
+ "{% set _insert_ns.idx = loop.index0 %}",
+ "{%- elif _tool_placement == 'LAST_USER' -%}",
+ "{% set _insert_ns.idx = loop.index0 %}",
+ "{%- endif -%}",
+ "{%- endif -%}",
+ "{%- endfor -%}",
+ "{% endif %}",
+ ]
+
+ def _jinja_system_block(self) -> List[str]:
+ """Return Jinja code that handles the system message logic."""
+
+ return [
# Handle system message first (matching render logic)
"{% if messages and messages[0]['role'] == 'system' %}",
"{% if tools and _system_template_with_tools %}",
"{% if messages[0]['content'] is string %}",
- "{{ _system_template_with_tools.format(system_message=messages[0]['content'], tools=tools | map('tojson') | join('\\n')) }}",
+ "{% if _process_system_message is defined %}",
+ "{{ _system_template_with_tools.format(system_message=_process_system_message(messages[0]['content']), tools=_fmt_tools(tools)) }}",
+ "{% else %}",
+ "{{ _system_template_with_tools.format(system_message=messages[0]['content'], tools=_fmt_tools(tools)) }}",
+ "{% endif %}",
"{% else %}",
- "{{ _system_template_with_tools.format(system_message=messages[0]['content'][0]['text'], tools=tools | map('tojson') | join('\\n')) }}",
+ "{% if _process_system_message is defined %}",
+ "{{ _system_template_with_tools.format(system_message=_process_system_message(messages[0]['content'][0]['text']), tools=_fmt_tools(tools)) }}",
+ "{% else %}",
+ "{{ _system_template_with_tools.format(system_message=messages[0]['content'][0]['text'], tools=_fmt_tools(tools)) }}",
+ "{% endif %}",
"{% endif %}",
"{% else %}",
"{% if messages[0]['content'] is string %}",
+ "{% if _process_system_message is defined %}",
+ "{% set processed_message = _process_system_message(messages[0]['content']) %}",
+ "{% set formatted_system = _system_template | replace('{system_message}', processed_message) %}{{ formatted_system }}",
+ "{% else %}",
"{% set formatted_system = _system_template | replace('{system_message}', messages[0]['content']) %}{{ formatted_system }}",
+ "{% endif %}",
+ "{% else %}",
+ "{% if _process_system_message is defined %}",
+ "{% set processed_message = _process_system_message(messages[0]['content'][0]['text']) %}",
+ "{% set formatted_system = _system_template | replace('{system_message}', processed_message) %}{{ formatted_system }}",
"{% else %}",
"{% set formatted_system = _system_template | replace('{system_message}', messages[0]['content'][0]['text']) %}{{ formatted_system }}",
"{% endif %}",
"{% endif %}",
+ "{% endif %}",
"{% else %}",
"{% if tools and _system_template_with_tools %}",
- "{{ _system_template_with_tools.format(system_message=_system_message, tools=tools | map('tojson') | join('\\n')) }}",
+ "{% if _process_system_message is defined %}",
+ "{{ _system_template_with_tools.format(system_message=_process_system_message(_system_message), tools=_fmt_tools(tools)) }}",
+ "{% else %}",
+ "{{ _system_template_with_tools.format(system_message=_system_message, tools=_fmt_tools(tools)) }}",
+ "{% endif %}",
+ "{% else %}",
+ "{% if _process_system_message is defined %}",
+ "{% set processed_message = _process_system_message(_system_message) %}",
+ "{% set formatted_system = _system_template | replace('{system_message}', processed_message) %}{{ formatted_system }}",
"{% else %}",
"{{ _default_system }}",
"{% endif %}",
"{% endif %}",
+ "{% endif %}",
+ ]
+
+ def _jinja_loop_messages(self) -> List[str]:
+ """Return Jinja loop that encodes all messages except the first system."""
+
+ return [
+ "{% set _tool_ns = namespace(inserted=False, user_count=0) %}",
# Process remaining messages (skip first if it was system)
"{% for m in messages %}",
"{% if not (loop.first and m['role'] == 'system') %}",
"{% if m['role'] == 'user' %}",
+ "{% set _tool_ns.user_count = _tool_ns.user_count + 1 %}",
"{% set ns = namespace(txt='') %}",
"{% if m['content'] is string %}",
"{% set ns.txt = m['content'] %}",
@@ -421,7 +830,17 @@ def jinja_template(self) -> str:
"{% endif %}",
"{% endfor %}",
"{% endif %}",
+ "{% if tools and ((_tool_placement == 'FIRST_USER' and _tool_ns.user_count == 1) or (_tool_placement == 'LAST_USER' and loop.index0 == _insert_ns.idx)) and not _tool_ns.inserted %}",
+ "{% if _u_template_with_tools is defined %}",
+ "{% set formatted_tools = _fmt_tools(tools) %}",
+ "{{ _u_template_with_tools | replace('{content}', ns.txt) | replace('{tools}', formatted_tools) }}",
+ "{% else %}",
+ "{{ _u_pref }}{{ ns.txt }}{{ _u_suff }}\\n{{ _fmt_tools(tools) }}",
+ "{% endif %}",
+ "{% set _tool_ns.inserted = True %}",
+ "{% else %}",
"{{ _u_pref }}{{ ns.txt }}{{ _u_suff }}",
+ "{% endif %}",
"{% elif m['role'] == 'assistant' %}",
"{% if m['content'] is string %}",
"{{ _a_pref }}{{ m['content'] }}{{ _a_suff }}",
@@ -437,12 +856,16 @@ def jinja_template(self) -> str:
"{% endif %}",
"{% endif %}",
"{% endfor %}",
+ ]
+
+ def _jinja_generation_block(self) -> List[str]:
+ """Return Jinja code that appends the generation prefix when requested."""
+
+ return [
"{% if add_generation_prompt %}",
"{{ _a_pref }}",
- "{% endif %}"
- ])
-
- return "".join(template_parts)
+ "{% endif %}",
+ ]
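+
+    # Illustrative sketch (an assumption, not shown in this hunk): the public
+    # `jinja_template()` presumably concatenates the helper blocks above, e.g.
+    #
+    #     parts = (self._jinja_header_constants()
+    #              + self._jinja_compute_insert_idx()
+    #              + self._jinja_system_block()
+    #              + self._jinja_loop_messages()
+    #              + self._jinja_generation_block())
+    #     return "".join(parts)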
def render_with_mask(self, messages: List[Dict], add_generation_prompt: bool = False, tools=None):
@@ -463,14 +886,6 @@ def set_system_message(self, system_message: str):
"""Set the system message."""
self.system_message = system_message
- def set_tools(self, tools: List[Dict]):
- """Set the tools."""
- if self.tool_aggregator == "DEFAULT":
- self.tools = json.dumps(tools)
- elif self.tool_aggregator == "STACKED":
- self.tools = "\n".join([json.dumps(tool) for tool in tools])
- else:
- raise ValueError(f"Invalid tool aggregator: {self.tool_aggregator}")
def copy(self):
return Template(
@@ -479,6 +894,7 @@ def copy(self):
system_template_with_tools=self.system_template_with_tools,
system_message=self.system_message,
user_template=self.user_template,
+ user_template_with_tools=self.user_template_with_tools,
assistant_template=self.assistant_template,
tool_template=self.tool_template,
stop_words=self.stop_words,
@@ -486,6 +902,10 @@ def copy(self):
vision_end=self.vision_end,
image_token=self.image_token,
video_token=self.video_token,
+ global_policy=deepcopy(self.global_policy),
+ system_policy=deepcopy(self.system_policy),
+ tool_policy=deepcopy(self.tool_policy),
+ chat_template=self.chat_template,
)
def dict(self):
@@ -554,12 +974,15 @@ def prompt_with_mask(self, add_generation_prompt=False, tools=None) -> str:
prompt_with_mask, _, _ = self.template.render_with_mask(messages=self.messages, add_generation_prompt=add_generation_prompt, tools=tools)
return prompt_with_mask
- def tokenize(self, tokenizer: PreTrainedTokenizer = None, add_generation_prompt=False, tools=None) -> List[int]:
+ def vision_inputs(self) -> List[Any]:
+ return self.template.get_vision_inputs(self.messages)
+
+ def tokenize(self, tokenizer: PreTrainedTokenizer = None, add_generation_prompt=False, tools=None, processor=None) -> List[int]:
if tokenizer is None:
tokenizer = self.tokenizer
if tools is None:
tools = self.tools
- return self.template.encode(messages=self.messages, tokenizer=tokenizer, return_tensors="pt", tools=tools)
+ return self.template.encode(messages=self.messages, tokenizer=tokenizer, return_tensors="pt", tools=tools, add_generation_prompt=add_generation_prompt, processor=processor)
def append(self, message: Union[Dict, List[Dict]]):
self._convert_single_message_to_hf_format(message)
@@ -594,13 +1017,26 @@ def get_template(name: str) -> Template:
assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n",
tool_template="<|im_start|>user\n\n{observation}\n<|im_end|>\n",
stop_words=["<|im_end|>"],
+ )
+)
+
+register_template(
+ Template(
+ name="qwen2.5-vl",
+ system_template="<|im_start|>system\n{system_message}<|im_end|>\n",
+ system_message="You are a helpful assistant.",
+ user_template="<|im_start|>user\n{content}<|im_end|>\n",
+ assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n",
+ tool_template="<|im_start|>tool\n{observation}<|im_end|>\n",
vision_start="<|vision_start|>",
vision_end="<|vision_end|>",
image_token="<|image_pad|>",
video_token="<|video_pad|>",
+ stop_words=["<|im_end|>"],
)
)
+
register_template(
Template(
name="qwen2.5",
@@ -611,10 +1047,6 @@ def get_template(name: str) -> Template:
assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n",
tool_template="<|im_start|>user\n\n{observation}\n<|im_end|>\n",
stop_words=["<|im_end|>"],
- vision_start="<|vision_start|>",
- vision_end="<|vision_end|>",
- image_token="<|image_pad|>",
- video_token="<|video_pad|>",
)
)
@@ -648,48 +1080,98 @@ def get_template(name: str) -> Template:
)
)
-# register_conv_template(
-# Template(
-# name="qwen3",
-# system_template="<|im_start|>system\n{system_message}<|im_end|>\n",
-# # system_message="",
-# system_template_tool="""<|im_start|>system\n{system_message}# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{{"name": , "arguments": }}\n<|im_end|>\n""",
-# roles=("<|im_start|>user\n", "<|im_start|>assistant\n", "<|im_start|>user\n\n{observation}\n"),
-# tool_aggregator="STACKED",
-# sep_style=SeparatorStyle.GENERAL_STOP_ALL,
-# sep="",
-# stop_token_ids=[
-# 151643,
-# 151644,
-# 151645,
-# ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>"
-# stop_str="<|im_end|>\n",
-# )
-# )
-# register_conv_template(
+
+# TODO: the mistral template has many corner cases; leave it for now.
+# register_template(
# Template(
-# name="qwen2.5-vl",
-# system_template="<|im_start|>system\n{system_message}<|im_end|>\n",
-# system_message="You are a helpful assistant.",
-# roles=("<|im_start|>user\n", "<|im_start|>assistant\n", "<|im_start|>tool\n\n{observation}\n"),
-# tool_aggregator="STACKED",
-# sep_style=SeparatorStyle.GENERAL_STOP_ALL,
-# sep="",
-# vision_start="<|vision_start|>",
-# vision_end="<|vision_end|>",
-# image_token="<|image_pad|>",
-# video_token="<|video_pad|>",
-# stop_token_ids=[
-# 151643,
-# 151644,
-# 151645,
-# ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>"
-# stop_str="<|im_end|>\n",
+# name="mistral",
+# system_template="{system_message}",
+# user_template="[INST] {content}[/INST] ",
+# user_template_with_tools="[AVAILABLE TOOLS] {tools} [/AVAILABLE TOOLS] [INST] {content}[/INST] ",
+# assistant_template="{content}",
+# tool_template="{observation}",
+# stop_words=[""],
+# system_policy=SystemPolicy(
+# use_system=False,
+# ),
+# tool_policy=ToolPolicy(
+# placement=ToolPlacement.LAST_USER,
+# formatter=JsonCompactFormatter()
+# )
# )
# )
+# TODO: system template includes current date
+register_template(
+ Template(
+ name="llama-3.2",
+ system_template="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
+ system_template_with_tools="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\n{system_message}<|eot_id|>",
+ user_template="<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>",
+ user_template_with_tools="""<|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}.Do not use variables.\n\n{tools}\n\n{content}<|eot_id|>""",
+ assistant_template="<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>",
+ tool_template="""<|start_header_id|>ipython<|end_header_id|>\n\n"{observation}"<|eot_id|>""",
+ stop_words=["<|eot_id|>"],
+ system_policy=SystemPolicy(
+ use_system=True,
+ content_processor=Llama32DateProcessor(),
+ ),
+ tool_policy=ToolPolicy(
+ placement=ToolPlacement.FIRST_USER,
+ formatter=JsonIndentedFormatter()
+ )
+ )
+)
+
+register_template(
+ Template(
+ name="glm-4",
+ system_template="<|system|>\n{system_message}",
+ user_template="<|user|>\n{content}",
+ assistant_template="<|assistant|>\n{content}",
+ stop_words=[""],
+ global_policy=GlobalPolicy(
+ prefix="[gMASK]"
+ ),
+ system_policy=SystemPolicy(
+ use_system=True,
+ use_system_without_system_message=False,
+ ),
+ )
+)
+
+register_template(
+ Template(
+ name="phi-4",
+ system_template="<|im_start|>system<|im_sep|>{system_message}<|im_end|>",
+ user_template="<|im_start|>user<|im_sep|>{content}<|im_end|>",
+ assistant_template="<|im_start|>assistant<|im_sep|>{content}<|im_end|>",
+ stop_words=["<|im_end|>"],
+ )
+)
+# Note: partially aligned with the reference template; some minor newline differences remain.
+register_template(
+ Template(
+ name="nemotron",
+ system_template="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{system_message}<|eot_id|>",
+ system_template_with_tools="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{system_message}{tools}<|eot_id|>""",
+ user_template="<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>",
+ assistant_template="<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>",
+ tool_template="<|start_header_id|>user<|end_header_id|>\n\n[{observation}]<|eot_id|>",
+ stop_words=["<|eot_id|>"],
+ system_policy=SystemPolicy(
+ use_system=True,
+ content_processor=lambda x: f"\n{x}",
+ ),
+ tool_policy=ToolPolicy(
+ placement=ToolPlacement.SYSTEM,
+ content_processor=ToolMainContentProcessor(),
+ formatter=JsonCompactFormatter(),
+ )
+ )
+)
if __name__ == "__main__":
- pass
+ pass
\ No newline at end of file
diff --git a/agents/agents/agents/templates/tool_policy.py b/agents/agents/agents/templates/tool_policy.py
new file mode 100644
index 0000000..377c098
--- /dev/null
+++ b/agents/agents/agents/templates/tool_policy.py
@@ -0,0 +1,237 @@
+from typing import List, Dict, Tuple
+import json
+from typing import Callable, Any # Added for content processor typing
+from abc import ABC, abstractmethod
+import dataclasses
+
+from .constants import ToolPlacement
+
+# Convert ToolFormatter into an abstract base class
+class ToolFormatter(ABC):
+ """
+ Strategy that converts an in-memory list[dict] describing tools
+ into the textual representation expected by the target model.
+ """
+ @abstractmethod
+ def format(self, tools: List[Dict]) -> str:
+ """Format a list of tool dictionaries into a string representation."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def jinja(self) -> str:
+ """Return a Jinja template that can be used to format the tools."""
+ raise NotImplementedError
+
+
+class ToolContentProcessor(ABC):
+ """
+ Strategy that processes the content of a tool before it is serialized.
+ """
+ @abstractmethod
+ def __call__(self, tool: Dict) -> Dict:
+ raise NotImplementedError
+
+ @abstractmethod
+ def jinja(self) -> str:
+ """Return a Jinja template that can be used to process the content of a tool."""
+ raise NotImplementedError
+
+
+class ToolMainContentProcessor(ToolContentProcessor):
+ """
+ Strategy that processes the main content of a tool before it is serialized.
+ """
+ def __call__(self, tool: Dict) -> Dict:
+ assert isinstance(tool, dict), "Tool must be a dictionary"
+ if "function" in tool:
+ content = tool["function"]
+ assert "name" in content, "Tool function must have a name"
+ assert "parameters" in content, "Tool function must have parameters"
+ return content
+ elif "name" in tool and "parameters" in tool:
+ return tool
+ else:
+ raise ValueError(f"Tool must have a function or name and parameters: {tool}")
+
+ # The main-content extraction cannot be replicated in pure Jinja, so we
+ # fall back to the identity behaviour at template-generation time. This
+ # means the processor is *ignored* in frozen chat-templates; users who
+ # require it must rely on the Python render path.
+
+ def jinja(self) -> str:
+ # We deliberately document the limitation by returning a simple pass-
+ # through expression.
+ return "{{ tool }}"
+
+# Make JsonFormatter inherit the ToolFormatter base class
+class JsonFormatter(ToolFormatter):
+ """General JSON formatter with configurable indent, separators, and joiner."""
+
+ def __init__(
+ self,
+ *,
+ indent: int | None = None,
+ separators: Tuple[str, str] | None = None,
+ joiner: str = "\n",
+ format_as_list: bool = False,
+ content_processor: ToolContentProcessor = None,
+ ):
+ """Create a new JsonFormatter.
+
+ Args:
+ indent: Indentation level passed to ``json.dumps``. ``None`` means no pretty-print.
+ separators: Custom separators passed to ``json.dumps``; useful for minification.
+ joiner: String used to join per-tool JSON strings when ``format_as_list`` is *False*.
+ format_as_list: If *True*, the entire ``tools`` list is serialised in a single
+ ``json.dumps`` call, ignoring ``joiner``. This is handy when the target
+ model expects a single JSON array instead of multiple individual objects.
+ content_processor: Optional callable applied to each individual tool dictionary
+ before serialisation. Defaults to the identity function.
+ """
+ self.indent = indent
+ self.separators = separators
+ self.joiner = joiner
+        self.format_as_list = format_as_list
+        self.content_processor = content_processor
+
+ def format(self, tools: List[Dict]) -> str: # noqa: D401
+ """Return a single string obtained by dumping every tool to JSON then joining them.
+
+ Args:
+ tools: A list of tool dictionaries to be stringified.
+
+ Returns:
+ A string representation of the tools, formatted according to the
+ given ``indent``/``separators`` and concatenated with ``joiner``.
+ """
+        # Apply the per-tool content processor first (identity when not set).
+        if self.content_processor is not None:
+            tools = [self.content_processor(t) for t in tools]
+
+ if self.format_as_list:
+ # Serialize the whole list in one go – joiner is irrelevant in this mode.
+ return json.dumps(tools, indent=self.indent, separators=self.separators)
+
+ # Default behaviour: dump each tool individually then concatenate.
+ return self.joiner.join(
+ json.dumps(t, indent=self.indent, separators=self.separators) for t in tools
+ )
+
+ # ------------------------------------------------------------------
+ # Jinja support
+ # ------------------------------------------------------------------
+
+ def _escape_joiner(self, joiner: str) -> str: # local helper
+ """Return *joiner* escaped so it is safe inside a single‐quoted Jinja
+ string literal (the HF chat-template parser understands the Python
+ backslash escapes)."""
+
+ return joiner.replace("\\", "\\\\").replace("'", "\\'")
+
+ def jinja(self) -> str: # noqa: D401
+ """Return a **Jinja-mini** snippet that serialises the *tools* variable
+ with the same settings as :py:meth:`format`.
+
+ The template assumes that a ``tools`` list is present in the Jinja
+ context. Because the Hugging-Face chat-template dialect only supports
+ a limited subset of Jinja, we restrict ourselves to `map`, `tojson`,
+ `join`, and optional indent on a *single* tojson call when
+ ``format_as_list`` is *True*.
+
+ When ``format_as_list`` is *False* and ``indent`` is specified, we use
+ a Jinja loop to apply indentation to each individual tool.
+ """
+
+ # Serialise whole list -> one tojson call (supports indent argument)
+ if self.format_as_list:
+ if self.indent is None:
+ return "{{ tools | tojson }}"
+ else:
+ return f"{{{{ tools | tojson(indent={self.indent}) }}}}"
+
+ # Individual objects: use loop if indent is needed, otherwise use map
+ if self.indent is not None:
+ # Use loop to apply indentation to each individual tool
+ # For joiners containing newlines, we need to avoid whitespace control to preserve them
+ # For other joiners, we can use whitespace control for cleaner output
+
+ if '\n' in self.joiner:
+ # Joiner contains newlines - use Jinja's string replacement to convert \n to actual newlines
+ # We'll create a Jinja variable with the proper newlines
+ joiner_var = '{% set joiner = "' + self.joiner.replace('\n', '\\n') + '" | replace("\\\\n", "\n") %}'
+ return joiner_var + f"{{% for tool in tools %}}{{{{ tool | tojson(indent={self.indent}) }}}}{{% if not loop.last %}}{{{{ joiner }}}}{{% endif %}}{{% endfor %}}"
+ else:
+ # Joiner doesn't contain newlines - safe to use whitespace control and escaping
+ joiner_escaped = self._escape_joiner(self.joiner)
+ return f"{{%- for tool in tools -%}}{{{{ tool | tojson(indent={self.indent}) }}}}{{%- if not loop.last -%}}{joiner_escaped}{{%- endif -%}}{{%- endfor -%}}"
+ else:
+ # No indentation needed, use the simpler map approach
+ joiner_escaped = self._escape_joiner(self.joiner)
+ return (
+ "{{ tools | map('tojson') | join('" + joiner_escaped + "') }}"
+ )
+
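+# Illustrative examples (a sketch; the tool dict is hypothetical):
+#
+#     JsonFormatter().format([{"name": "add"}])            # -> '{"name": "add"}'
+#     JsonFormatter(format_as_list=True).jinja()           # -> '{{ tools | tojson }}'
+#     JsonFormatter(format_as_list=True, indent=2).jinja() # -> '{{ tools | tojson(indent=2) }}'
+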
+class JsonMinifiedFormatter(JsonFormatter):
+ """Single-line JSON objects without extra whitespace (legacy alias)."""
+
+ def __init__(self, joiner: str = "\n", *, content_processor: Callable[[Dict], Any] | None = None):
+ super().__init__(indent=None, separators=(",", ":"), joiner=joiner, content_processor=content_processor)
+
+
+class JsonIndentedFormatter(JsonFormatter):
+ """
+ Pretty printed JSON with configurable indent (default 4).
+ Frequently required by models like Mistral-v0.3.
+ (legacy alias)
+ """
+
+ def __init__(self, indent: int = 4, *, joiner: str = "\n\n", format_as_list: bool = False):
+ super().__init__(indent=indent, separators=None, joiner=joiner, format_as_list=format_as_list)
+
+
+class JsonCompactFormatter(JsonFormatter):
+ """Single-line JSON objects without extra whitespace."""
+ def __init__(self, *, format_as_list: bool = True, content_processor: Callable[[Dict], Any] | None = None):
+ super().__init__(indent=None, separators=None, format_as_list=format_as_list, content_processor=content_processor)
+
+class JsonQwenFormatter(JsonFormatter):
+ """
+ JSON formatter for Qwen models.
+ """
+ def __init__(self):
+ super().__init__(indent=None, separators=None, format_as_list=False, content_processor=None)
+
+ # No special behaviour – inherits .jinja from JsonFormatter
+
+
+# ---------------------------------------------------------------------------
+# Optional formatters and the ToolPolicy container
+# ---------------------------------------------------------------------------
+
+
+try:
+ import yaml as _yaml # optional dependency
+
+    class YamlFormatter(ToolFormatter):  # type: ignore
+        def format(self, tools: List[Dict]) -> str:  # noqa: D401
+            return _yaml.safe_dump(tools, sort_keys=False)
+
+        def jinja(self) -> str:  # noqa: D401
+            raise NotImplementedError("YAML tools cannot be expressed in the HF chat-template dialect")
+except ModuleNotFoundError: # pragma: no cover
+ YamlFormatter = None # type: ignore
+
+
+@dataclasses.dataclass
+class ToolPolicy:
+ """
+ Encapsulates every configuration decision about how *tools*
+ appear in the prompt for a given template.
+ """
+ placement: "ToolPlacement" = ToolPlacement.SYSTEM
+ content_processor: Callable[[Dict], Any] = None
+ formatter: ToolFormatter = dataclasses.field(default_factory=lambda: JsonQwenFormatter())
+
+ def format_tools(self, tools: List[Dict]) -> str:
+ """
+ Convert `tools` into ready-to-inject text according to the chosen formatter.
+ """
+ if self.content_processor is not None:
+ processed_tools = [self.content_processor(t) for t in tools]
+ else:
+ processed_tools = tools
+ return self.formatter.format(processed_tools)
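+
+# Illustrative usage (a sketch; the tool schema below is a hypothetical example):
+#
+#     policy = ToolPolicy(placement=ToolPlacement.FIRST_USER, formatter=JsonCompactFormatter())
+#     policy.format_tools([{"type": "function", "function": {"name": "search", "parameters": {}}}])
+#     # -> '[{"type": "function", "function": {"name": "search", "parameters": {}}}]'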
diff --git a/agents/agents/agents/templates/utils.py b/agents/agents/agents/templates/utils.py
index ccb2258..9f62182 100644
--- a/agents/agents/agents/templates/utils.py
+++ b/agents/agents/agents/templates/utils.py
@@ -10,7 +10,8 @@
import logging
from .templates import Chat, get_template
from ... import AGENT_DATA_DIR
-
+from typing import Any
+from .vision_processor import get_processor
# Set up logging that won't be overridden by other modules
LOGGER = logging.getLogger(__name__)
@@ -20,9 +21,6 @@ def strip_ansi(s: str) -> str:
"""Remove ANSI escape sequences from a string."""
return ANSI_RE.sub('', s)
-def is_vlm_template(template: str) -> bool:
- return template in ["qwen2.5-vl"]
-
def convert_messages_to_openai_format(messages: list) -> list:
"""
@@ -31,10 +29,10 @@ def convert_messages_to_openai_format(messages: list) -> list:
"""
messages = copy.deepcopy(messages)
for message in messages:
- if "tool_calls" in message:
- del message["tool_calls"]
- if "tool_call_id" in message:
- del message["tool_call_id"]
+ # if "tool_calls" in message:
+ # del message["tool_calls"]
+ # if "tool_call_id" in message:
+ # del message["tool_call_id"]
if "tool_choice" in message:
del message["tool_choice"]
return messages
@@ -139,12 +137,15 @@ def tokenize_conversation(
:return: input_ids, attention_mask, labels, action_mask
"""
chat = Chat(template=template, messages=messages, tokenizer=tokenizer)
- inputs = chat.tokenize(tokenizer, add_generation_prompt=add_generation_prompt, tools=tools)
+ inputs = chat.tokenize(tokenizer, add_generation_prompt=add_generation_prompt, tools=tools, processor=processor)
+
if max_length is not None:
inputs['input_ids'] = inputs['input_ids'][:, :max_length]
inputs['attention_mask'] = inputs['attention_mask'][:, :max_length]
- inputs['labels'] = inputs['labels'][:, :max_length]
- inputs['action_mask'] = inputs['action_mask'][:, :max_length]
+ if 'labels' in inputs:
+ inputs['labels'] = inputs['labels'][:, :max_length]
+ if 'action_mask' in inputs:
+ inputs['action_mask'] = inputs['action_mask'][:, :max_length]
return inputs
@@ -153,82 +154,85 @@ def convert_inputs_to_vision_inputs(template: str,
processor, # AutoProcessor (not bare tokenizer)
messages: list):
"""
- Expand `inputs` (built from a chat template that contains ONE
- <|image_pad|> / <|video_pad|> placeholder per asset) into the real
- processor outputs and stretch `action_mask` / `labels` so they stay
- aligned after the pad tokens are repeated.
-
- Returns
- -------
- dict -- processor(...) result + expanded masks
+ NEW PIPELINE: Template processes messages → Human-readable prompt → Vision processor → LLM-ready inputs
+
+ The correct pipeline is:
+ 1. Template processes messages to get human-readable prompt with single multi-modal tokens
+ 2. Vision processor handles image/video processing and token expansion
+ 3. Final result is directly usable by LLMs with model(**inputs)
"""
- assert template == "qwen2.5-vl", "Only qwen2.5-vl is supported"
-
-
- # ------------------------------------------------------------------
- # 1. special-token ids
- # ------------------------------------------------------------------
- tk = processor.tokenizer
- conv = get_conv_template(template)
- image_pad_id = tk.encode(conv.image_token, add_special_tokens=False)[0]
- video_pad_id = tk.encode(conv.video_token, add_special_tokens=False)[0]
- vis_start_id = tk.encode(conv.vision_start, add_special_tokens=False)[0]
- vis_end_id = tk.encode(conv.vision_end, add_special_tokens=False)[0]
-
- repeat_ids = torch.tensor([image_pad_id, video_pad_id], dtype=torch.long)
- vision_ids = torch.tensor([image_pad_id, video_pad_id,
- vis_start_id, vis_end_id], dtype=torch.long)
-
- # ------------------------------------------------------------------
- # 2. run the processor (adds patch-level vision tokens)
- # ------------------------------------------------------------------
- LOGGER.debug(f"[Template::convert_inputs_to_vision_inputs] messages: {messages}")
- imgs, vids = process_vision_info(messages)
- text = format_conversation(messages, template).get_prompt()
- proc_out = processor(
- text=text,
- images=imgs,
- videos=vids,
- return_tensors="pt",
- padding=False,
- truncation=False,
+ # Get the vision processor for this template
+ vision_processor = get_processor(template)
+ if vision_processor is None:
+ raise ValueError(f"No vision processor registered for template: {template}")
+
+ # Step 1: Template processes messages to get human-readable prompt
+ from .templates import Chat
+ chat = Chat(template=template, messages=messages, tokenizer=processor.tokenizer)
+ prompt = chat.prompt() # This gives us human-readable prompt with single multi-modal tokens
+
+ # Step 2: Extract vision inputs from messages
+ images, videos = extract_vision_inputs_from_messages(messages)
+
+ # Step 3: Vision processor handles the complete pipeline
+ # This expands tokens and generates LLM-ready inputs
+ final_inputs = vision_processor.process_for_llm(
+ prompt=prompt,
+ images=images,
+ videos=videos,
+ processor=processor,
+ tokenizer=processor.tokenizer
)
- new_ids = proc_out["input_ids"][0] # (new_len,)
- new_attention_mask = proc_out["attention_mask"][0]
- device = new_ids.device
-
- # ------------------------------------------------------------------
- # 3. original (pre-processor) tensors
- # ------------------------------------------------------------------
- old_ids = inputs["input_ids"][0].to(device)
- old_action = inputs.get("action_mask", None)
- old_labels = inputs.get("labels", None)
- old_attention_mask = inputs.get("attention_mask", None)
-
- # ------------------------------------------------------------------
- # 4. build boolean masks that mark NON-repeated positions
- # ------------------------------------------------------------------
- new_keep = ~torch.isin(new_ids, repeat_ids) # True where we copy from old
- old_keep = ~torch.isin(old_ids, repeat_ids)
-
- # sanity check
- assert new_keep.sum() == old_keep.sum(), f"Mismatch after dropping repeated vision tokens: {new_ids} {old_ids}"
-
- # ------------------------------------------------------------------
- # 5. allocate expanded tensors and copy en masse
- # ------------------------------------------------------------------
- if old_action is not None:
- exp_action = torch.zeros_like(new_ids, dtype=old_action.dtype)
- exp_action[new_keep] = old_action[0][old_keep].to(device)
- proc_out["action_mask"] = exp_action.unsqueeze(0)
+ return final_inputs
- if old_labels is not None:
- exp_labels = torch.full_like(new_ids, -100, dtype=old_labels.dtype)
- exp_labels[new_keep] = old_labels[0][old_keep].to(device)
- proc_out["labels"] = exp_labels.unsqueeze(0)
-
- return proc_out
+def extract_vision_inputs_from_messages(messages: list) -> tuple[list, list]:
+ """Extract images and videos from messages"""
+ images, videos = [], []
+
+ for message in messages:
+ if isinstance(message.get('content'), list):
+ for item in message['content']:
+ if item.get('type') in ['image', 'image_url']:
+ if 'image' in item:
+ images.append(item['image'])
+ elif 'image_url' in item:
+ images.append(item['image_url']['url'])
+ elif item.get('type') in ['video', 'video_url']:
+ if 'video' in item:
+ videos.append(item['video'])
+ elif 'video_url' in item:
+ videos.append(item['video_url']['url'])
+
+ return images, videos
+
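+# Illustrative message shape (hypothetical content):
+#
+#     msgs = [{"role": "user", "content": [
+#         {"type": "text", "text": "Describe this"},
+#         {"type": "image", "image": "/tmp/cat.jpg"},
+#     ]}]
+#     extract_vision_inputs_from_messages(msgs)  # -> (["/tmp/cat.jpg"], [])
+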
+def process_prompt_with_vision(
+ prompt: str,
+ template: str,
+ processor: Any,
+ images: list = None,
+ videos: list = None,
+) -> dict:
+ """Process a prompt with vision support"""
+ vision_processor = get_processor(template)
+ if vision_processor is None:
+ # If no vision processor, just return tokenized prompt
+ return processor.tokenizer(
+ prompt,
+ return_tensors="pt",
+ add_special_tokens=True,
+ padding=True,
+ truncation=True
+ )
+
+ # Use vision processor to handle the complete pipeline
+ return vision_processor.process_for_llm(
+ prompt=prompt,
+ images=images or [],
+ videos=videos or [],
+ processor=processor,
+ tokenizer=processor.tokenizer
+ )
def tokenize_conversations(messages_list, tokenizer, conv_template, max_length, processor=None, return_tensors="pt", return_reward_mask=False):
@@ -245,7 +249,9 @@ def tokenize_conversations(messages_list, tokenizer, conv_template, max_length,
batch_action_masks.append(inputs['action_mask'].squeeze(0))
if return_tensors == "pt":
- batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+        # Use pad_token_id from the tokenizer interface; fall back to 0 if it is missing or unset
+        pad_token_id = getattr(tokenizer, 'pad_token_id', None) or 0
+ batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=pad_token_id)
batch_attention_masks = torch.nn.utils.rnn.pad_sequence(batch_attention_masks, batch_first=True, padding_value=0)
batch_labels = torch.nn.utils.rnn.pad_sequence(batch_labels, batch_first=True, padding_value=-100)
batch_action_masks = torch.nn.utils.rnn.pad_sequence(batch_action_masks, batch_first=True, padding_value=0)
@@ -299,6 +305,9 @@ def compare_hf_template(tokenizer, template_name, messages=None, tools=None, add
plain_highlighted_prompt = strip_ansi(highlighted_prompt)
is_equal_between_implemented_prompts = implemented_prompt == plain_highlighted_prompt
jinja_template = chat.template.jinja_template()
+ # Save jinja template to file
+ with open("jinja_template.jinja", "w") as f:
+ f.write(jinja_template)
tokenizer.chat_template = jinja_template
implemented_jinja_prompt = tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, add_generation_prompt=add_generation_prompt)
is_equal_between_jinja_prompts = implemented_jinja_prompt == implemented_prompt
@@ -312,7 +321,9 @@ def vllm_serve(model_name_or_path, template, tp, pp, dp):
os.makedirs(f"{AGENT_DATA_DIR}/cache")
with open(f"{AGENT_DATA_DIR}/cache/jinja_template.jinja", "w") as f:
f.write(jinja_template)
- command = f"vllm serve {model_name_or_path} --chat-template {AGENT_DATA_DIR}/cache/jinja_template.jinja --tensor-parallel-size {tp} --pipeline-parallel-size {pp} --data-parallel-size {dp} --port {port} --enable-auto-tool-choice --tool-call-parser hermes --expand-tools-even-if-tool-choice-none"
+ # command = f"vllm serve {model_name_or_path} --chat-template {AGENT_DATA_DIR}/cache/jinja_template.jinja --tensor-parallel-size {tp} --pipeline-parallel-size {pp} --data-parallel-size {dp} --port {port} --enable-auto-tool-choice --tool-call-parser hermes --expand-tools-even-if-tool-choice-none"
+ command = f"vllm serve {model_name_or_path} --tensor-parallel-size {tp} --pipeline-parallel-size {pp} --data-parallel-size {dp} --port {port} --enable-auto-tool-choice --tool-call-parser hermes --expand-tools-even-if-tool-choice-none"
+
print(command)
os.system(command)
@@ -321,5 +332,6 @@ def vllm_serve(model_name_or_path, template, tp, pp, dp):
"python -m agents.agents.templates.utils"
# model = "/mnt/sharefs/users/haonan.li/models/Qwen2.5-7B-instruct-am_think_v1_distilled"
model = "Qwen/Qwen2.5-7B-Instruct"
- vllm_serve(model, "qwen2.5-think", 2, 1, 4)
+ # vllm_serve(model, "qwen2.5-think", 2, 1, 4)
+ vllm_serve(model, "qwen2.5", 1, 1, 1)
diff --git a/agents/agents/agents/templates/vision_processor.py b/agents/agents/agents/templates/vision_processor.py
new file mode 100644
index 0000000..ca14aea
--- /dev/null
+++ b/agents/agents/agents/templates/vision_processor.py
@@ -0,0 +1,668 @@
+"""
+Comprehensive multi-modal vision processor that handles vision processing separately from template processing.
+The pipeline is: Template → Human-readable prompt → Vision processor → LLM-ready inputs.
+"""
+
+import base64
+import inspect
+import math
+import os
+import re
+import urllib.parse
+import urllib.request
+from copy import deepcopy
+from dataclasses import dataclass
+from io import BytesIO
+from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union, List, Dict, Any
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch
+from PIL import Image
+from PIL.Image import Image as ImageObject
+from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
+from transformers.models.mllama.processing_mllama import (
+ convert_sparse_cross_attention_mask_to_dense,
+ get_cross_attention_token_mask,
+)
+from typing_extensions import override
+
+if TYPE_CHECKING:
+ from av.stream import Stream
+ from numpy.typing import NDArray
+ from transformers import PreTrainedTokenizer, ProcessorMixin
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
+ from transformers.image_processing_utils import BaseImageProcessor
+
+ class EncodedImage(TypedDict):
+ path: Optional[str]
+ bytes: Optional[bytes]
+
+ ImageInput = Union[str, bytes, EncodedImage, BinaryIO, "ImageObject"]
+ VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
+
+ class MMProcessor(ProcessorMixin):
+ patch_size: int
+ image_seq_length: int
+ num_additional_image_tokens: int
+ vision_feature_select_strategy: Literal["default", "full"]
+
+ def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
+ pass
+
+
+@dataclass
+class VisionProcessorConfig:
+ """Configuration for vision processing"""
+ model_type: str
+ image_token: str
+ video_token: str
+ vision_start: str = ""
+ vision_end: str = ""
+ processor_class: str = "AutoProcessor"
+ expansion_strategy: str = "patch_based"
+ image_max_pixels: int = 16384 * 28 * 28
+ image_min_pixels: int = 4 * 28 * 28
+ video_max_pixels: int = 16384 * 28 * 28
+ video_min_pixels: int = 4 * 28 * 28
+ video_fps: float = 2.0
+ video_maxlen: int = 128
+
+class VisionProcessor(ABC):
+ """Abstract base class for vision processing strategies"""
+
+ def __init__(self, config: VisionProcessorConfig):
+ self.config = config
+ self._validate_config()
+
+ def _validate_config(self):
+ """Validate the vision configuration"""
+ required_fields = ['image_token', 'video_token']
+ for field in required_fields:
+ if not hasattr(self.config, field) or getattr(self.config, field) is None:
+ raise ValueError(f"Missing required field: {field}")
+
+ @abstractmethod
+ def preprocess_images(self, images: List["ImageInput"], processor: Any) -> Dict[str, Any]:
+ """Preprocess images for the model"""
+ pass
+
+ @abstractmethod
+ def preprocess_videos(self, videos: List["VideoInput"], processor: Any) -> Dict[str, Any]:
+ """Preprocess videos for the model"""
+ pass
+
+ @abstractmethod
+ def calculate_image_tokens(self, image_data: Dict[str, Any], processor: Any) -> int:
+ """Calculate the number of tokens needed for an image"""
+ pass
+
+ @abstractmethod
+ def calculate_video_tokens(self, video_data: Dict[str, Any], processor: Any) -> int:
+ """Calculate the number of tokens needed for a video"""
+ pass
+
+ @abstractmethod
+ def expand_vision_tokens(
+ self,
+ prompt: str,
+ images: List["ImageInput"],
+ videos: List["VideoInput"],
+ processor: Optional[Any],
+ ) -> str:
+ """Expand vision tokens in the prompt to their actual token representations"""
+ pass
+
+ @abstractmethod
+ def get_mm_inputs(
+ self,
+ images: List["ImageInput"],
+ videos: List["VideoInput"],
+ processor: Optional[Any],
+ ) -> Dict[str, torch.Tensor]:
+ """Generate multi-modal inputs for the model"""
+ pass
+
+ def process_vision_info(self, messages: List[Dict]) -> Dict[str, torch.Tensor]:
+ """Process vision information from messages"""
+ pass
+
+ # def process_for_llm(
+ # self,
+ # prompt: str,
+ # images: List["ImageInput"],
+ # videos: List["VideoInput"],
+ # processor: Optional[Any],
+ # tokenizer: Any,
+ # ) -> Dict[str, torch.Tensor]:
+ # """
+ # Complete pipeline: expand tokens and generate LLM-ready inputs.
+ # Returns inputs that can be used directly with model(**inputs).
+ # """
+ # # Step 1: Expand vision tokens in the prompt
+ # expanded_prompt = self.expand_vision_tokens(prompt, images, videos, processor)
+
+ # # Step 2: Tokenize the expanded prompt
+ # tokenized_inputs = tokenizer(
+ # expanded_prompt,
+ # return_tensors="pt",
+ # add_special_tokens=True,
+ # padding=True,
+ # truncation=True
+ # )
+
+ # # Step 3: Generate multi-modal inputs
+ # mm_inputs = self.get_mm_inputs(images, videos, processor)
+
+ # # Step 4: Combine tokenized inputs with multi-modal inputs
+ # final_inputs = {**tokenized_inputs, **mm_inputs}
+
+ # return final_inputs
+
+ def process_for_llm(
+ self,
+ prompt: str,
+ elements: List[str],
+ mask_flags: List[bool],
+ images: List["ImageInput"],
+ videos: List["VideoInput"],
+ processor: Any,
+ tokenizer: Any,
+ return_tensors: str = None,
+ ) -> Dict[str, torch.Tensor]:
+ """
+ Process with proper alignment of all tensors (input_ids, attention_mask, labels, action_mask).
+ This ensures that when vision tokens are expanded, all corresponding tensors are expanded
+ at the same positions, maintaining proper alignment for training and inference.
+ """
+
+ # Step 1: Tokenize elements to get base tensors with proper alignment
+ input_ids = []
+ attention_mask = []
+ labels = []
+ action_mask = []
+
+ # Add BOS token if needed
+        if tokenizer.bos_token and getattr(tokenizer, "add_bos_token", False):
+ input_ids.append(tokenizer.bos_token_id)
+ attention_mask.append(1)
+ labels.append(-100)
+ action_mask.append(0)
+
+ # Step 2: Process each element with vision token expansion
+ for element, mask_flag in zip(elements, mask_flags):
+ # Check if element contains vision tokens
+ if self._contains_vision_tokens(element):
+ # Expand vision tokens in this element
+ expanded_element = self.expand_vision_tokens(element, images, videos, processor)
+ cur_input_ids = tokenizer.encode(expanded_element, add_special_tokens=False)
+ else:
+ cur_input_ids = tokenizer.encode(element, add_special_tokens=False)
+
+ # Add tokens with proper alignment
+ input_ids.extend(cur_input_ids)
+ attention_mask.extend([1] * len(cur_input_ids))
+
+ if mask_flag:
+ labels.extend([-100] * len(cur_input_ids))
+ action_mask.extend([0] * len(cur_input_ids))
+ else:
+ labels.extend(cur_input_ids)
+ action_mask.extend([1] * len(cur_input_ids))
+
+ # Step 3: Create base inputs
+ inputs = {
+ 'input_ids': input_ids,
+ 'attention_mask': attention_mask,
+ 'labels': labels,
+ 'action_mask': action_mask
+ }
+
+ # Convert to tensors if requested
+ if return_tensors == "pt":
+ inputs = {k: torch.tensor([v]) for k, v in inputs.items()}
+
+ # Step 4: Add vision inputs
+ mm_inputs = self.get_mm_inputs(images, videos, processor)
+ inputs.update(mm_inputs)
+
+ return inputs
+
+ def _contains_vision_tokens(self, text: str) -> bool:
+ """Check if text contains vision tokens"""
+ return self.config.image_token in text or self.config.video_token in text
+
+class PatchBasedProcessor(VisionProcessor):
+ """Patch-based vision processor (used by Qwen-VL, LLaVA, etc.)
+
+ Supports multiple image input formats:
+ - File paths (str): "/path/to/image.jpg"
+ - URLs (str): "https://example.com/image.jpg"
+ - Base64 strings (str): "data:image/jpeg;base64,/9j/4AAQ..." or raw base64
+ - PIL Image objects
+ - Bytes objects
+ - File-like objects
+ - Dict format: {"path": "/path/to/image.jpg"} or {"bytes": b"image_data"}
+ """
+
+ def _load_image_from_input(self, image_input) -> "ImageObject":
+ """Load image from various input formats including URL and base64"""
+ from PIL import Image
+
+ # Handle PIL Image objects directly
+ if hasattr(image_input, 'width') and hasattr(image_input, 'height'):
+ return image_input
+
+ # Handle string inputs (file path, URL, or base64)
+ if isinstance(image_input, str):
+ # Check if it's a URL
+ if image_input.startswith(('http://', 'https://')):
+ try:
+ with urllib.request.urlopen(image_input) as response:
+ image_data = response.read()
+ return Image.open(BytesIO(image_data))
+ except Exception as e:
+ raise ValueError(f"Failed to load image from URL {image_input}: {e}")
+
+ # Check if it's a base64 string
+ elif image_input.startswith('data:image/') or image_input.startswith('data:application/octet-stream'):
+ # Handle data URL format: data:image/jpeg;base64,/9j/4AAQ...
+ try:
+ # Extract the base64 part after the comma
+ base64_data = image_input.split(',', 1)[1]
+ image_data = base64.b64decode(base64_data)
+ return Image.open(BytesIO(image_data))
+ except Exception as e:
+ raise ValueError(f"Failed to decode base64 image: {e}")
+
+ elif image_input.startswith('iVBORw0KGgo') or len(image_input) > 100:
+ # Likely a raw base64 string (common for PNG images starting with iVBORw0KGgo)
+ try:
+ image_data = base64.b64decode(image_input)
+ return Image.open(BytesIO(image_data))
+ except Exception as e:
+ raise ValueError(f"Failed to decode base64 image: {e}")
+
+ # Assume it's a file path
+ else:
+ return Image.open(image_input)
+
+ # Handle bytes
+ elif isinstance(image_input, bytes):
+ return Image.open(BytesIO(image_input))
+
+ # Handle file-like objects
+ elif hasattr(image_input, 'read'):
+ return Image.open(image_input)
+
+ # Handle dict format
+ elif isinstance(image_input, dict):
+ if image_input.get("bytes") is not None:
+ return Image.open(BytesIO(image_input["bytes"]))
+ elif image_input.get("path") is not None:
+ return Image.open(image_input["path"])
+ else:
+ raise ValueError("Invalid image dict format")
+
+ else:
+ raise ValueError(f"Unsupported image input type: {type(image_input)}")
+
+ def _preprocess_single_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
+ """Preprocess a single image"""
+ if (image.width * image.height) > self.config.image_max_pixels:
+ resize_factor = math.sqrt(self.config.image_max_pixels / (image.width * image.height))
+ width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+ image = image.resize((width, height))
+
+ if (image.width * image.height) < self.config.image_min_pixels:
+ resize_factor = math.sqrt(self.config.image_min_pixels / (image.width * image.height))
+ width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+ image = image.resize((width, height))
+
+ if image.mode != "RGB":
+ image = image.convert("RGB")
+
+ return image
+
+ def _regularize_images(self, images: List["ImageInput"]) -> List["ImageObject"]:
+ """Regularize images to avoid errors"""
+ results = []
+ for image in images:
+ # Use the new helper method to handle all input formats
+ pil_image = self._load_image_from_input(image)
+ results.append(self._preprocess_single_image(pil_image))
+
+ return results
+
+ def _regularize_videos(self, videos: List["VideoInput"]) -> List[List["ImageObject"]]:
+ """Regularize videos to avoid errors"""
+ results = []
+ for video in videos:
+ frames: List["ImageObject"] = []
+
+ # Check if video is nested images
+ if isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video):
+ # Use the new image loading method for each frame
+ for frame in video:
+ try:
+ pil_image = self._load_image_from_input(frame)
+ frames.append(pil_image)
+ except Exception as e:
+ raise ValueError(f"Invalid image found in video frames: {e}")
+ else:
+ # Process actual video file
+ import av
+ container = av.open(video, "r")
+ video_stream = next(stream for stream in container.streams if stream.type == "video")
+
+ # Calculate sample indices
+ total_frames = video_stream.frames
+ if total_frames == 0: # infinite video
+ sample_indices = np.linspace(0, self.config.video_maxlen - 1, self.config.video_maxlen).astype(np.int32)
+ else:
+ sample_frames = max(1, math.floor(float(video_stream.duration * video_stream.time_base) * self.config.video_fps))
+ sample_frames = min(total_frames, self.config.video_maxlen, sample_frames)
+ sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
+
+ container.seek(0)
+ for frame_idx, frame in enumerate(container.decode(video_stream)):
+ if frame_idx in sample_indices:
+ frames.append(frame.to_image())
+
+ frames = self._regularize_images(frames)
+ results.append(frames)
+
+ return results
+
+ def preprocess_images(self, images: List["ImageInput"], processor: Any) -> Dict[str, Any]:
+ """Preprocess images for the model"""
+ if not images:
+ return {}
+
+ image_processor = getattr(processor, "image_processor", None)
+ if image_processor is None:
+ raise ValueError("Image processor not found")
+
+ images = self._regularize_images(images)
+ return image_processor(images, return_tensors="pt")
+
+ def preprocess_videos(self, videos: List["VideoInput"], processor: Any) -> Dict[str, Any]:
+ """Preprocess videos for the model"""
+ if not videos:
+ return {}
+
+ video_processor = getattr(processor, "video_processor", getattr(processor, "image_processor", None))
+ if video_processor is None:
+ raise ValueError("Video processor not found")
+
+ videos = self._regularize_videos(videos)
+
+ # Handle different video processor interfaces
+ if "videos" in inspect.signature(video_processor.preprocess).parameters:
+ return video_processor(images=None, videos=videos, return_tensors="pt")
+ else:
+ return video_processor(videos, return_tensors="pt")
+
+ def calculate_image_tokens(self, image_data: Dict[str, Any], processor: Any) -> int:
+ """Calculate the number of tokens needed for an image
+
+ Uses two approaches:
+ 1. Grid-based (HuggingFace method): Uses image_grid_thw and merge_size
+ - More accurate for models like Qwen-VL
+ - Accounts for hierarchical token merging
+ 2. Patch-based (fallback): Uses image dimensions and patch_size
+ - Standard approach for most ViT-based models
+ - Assumes each patch corresponds to one token
+ """
+ if "pixel_values" in image_data:
+ # Try grid-based calculation first (HuggingFace method)
+ if "image_grid_thw" in image_data:
+ grid_info = image_data["image_grid_thw"]
+ if isinstance(grid_info, torch.Tensor):
+ grid_prod = grid_info.prod().item()
+ elif isinstance(grid_info, list):
+ grid_prod = math.prod(grid_info)
+ else:
+ grid_prod = grid_info
+
+ # Get merge_size from processor
+ merge_size = getattr(processor, "merge_size", 1)
+ merge_length = merge_size ** 2
+
+ num_image_tokens = grid_prod // merge_length
+ return max(1, num_image_tokens)
+
+ # Fallback to patch-based calculation
+ height, width = get_image_size(to_numpy_array(image_data["pixel_values"][0]))
+ image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
+ if hasattr(processor, 'num_additional_image_tokens'):
+ image_seqlen += processor.num_additional_image_tokens
+ if hasattr(processor, 'vision_feature_select_strategy') and processor.vision_feature_select_strategy == "default":
+ image_seqlen -= 1
+ return image_seqlen
+ return 1
+
+ def calculate_video_tokens(self, video_data: Dict[str, Any], processor: Any) -> int:
+ """Calculate the number of tokens needed for a video"""
+ if "pixel_values" in video_data:
+ # For videos, we need to calculate based on frames
+ video_tensor = video_data["pixel_values"][0]
+ if len(video_tensor.shape) > 3: # Has frame dimension
+ num_frames = video_tensor.shape[0]
+ height, width = get_image_size(to_numpy_array(video_tensor[0]))
+ frame_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
+ if hasattr(processor, 'num_additional_image_tokens'):
+ frame_seqlen += processor.num_additional_image_tokens
+ if hasattr(processor, 'vision_feature_select_strategy') and processor.vision_feature_select_strategy == "default":
+ frame_seqlen -= 1
+ return frame_seqlen * num_frames
+ else:
+ # Single frame video
+ return self.calculate_image_tokens(video_data, processor)
+ return 1
+
+ def expand_vision_tokens(
+ self,
+ prompt: str,
+ images: List["ImageInput"],
+ videos: List["VideoInput"],
+ processor: Optional[Any],
+ ) -> str:
+ """Expand vision tokens in the prompt to their actual token representations"""
+ if processor is None:
+ raise ValueError("Processor is required for vision processing")
+
+ # Validate that number of placeholders matches number of inputs
+ num_image_placeholders = prompt.count(self.config.image_token)
+ num_video_placeholders = prompt.count(self.config.video_token)
+
+ if len(images) != num_image_placeholders:
+ raise ValueError(f"Number of images ({len(images)}) doesn't match placeholders ({num_image_placeholders})")
+ if len(videos) != num_video_placeholders:
+ raise ValueError(f"Number of videos ({len(videos)}) doesn't match placeholders ({num_video_placeholders})")
+
+ # Preprocess images and videos to get individual token counts
+ processed_images = self.preprocess_images(images, processor) if images else {}
+ processed_videos = self.preprocess_videos(videos, processor) if videos else {}
+
+ # Expand image tokens using regex to avoid infinite loops
+ expanded_prompt = prompt
+ if self.config.image_token in expanded_prompt:
+ if processed_images and "pixel_values" in processed_images:
+ # Calculate tokens for this specific image
+ image_tokens = self.calculate_image_tokens(processed_images, processor)
+ replacement = self.config.image_token * image_tokens
+ else:
+ replacement = self.config.image_token
+
+ # Use regex to replace all occurrences at once
+ expanded_prompt = re.sub(re.escape(self.config.image_token), replacement, expanded_prompt)
+
+ # Expand video tokens using regex to avoid infinite loops
+ if self.config.video_token in expanded_prompt:
+ if processed_videos and "pixel_values" in processed_videos:
+ # Calculate tokens for this specific video
+ video_tokens = self.calculate_video_tokens(processed_videos, processor)
+ replacement = self.config.video_token * video_tokens
+ else:
+ replacement = self.config.video_token
+
+ # Use regex to replace all occurrences at once
+ expanded_prompt = re.sub(re.escape(self.config.video_token), replacement, expanded_prompt)
+
+ return expanded_prompt
+
+ def get_mm_inputs(
+ self,
+ images: List["ImageInput"],
+ videos: List["VideoInput"],
+ processor: Optional[Any],
+ ) -> Dict[str, torch.Tensor]:
+ """Generate multi-modal inputs for the model"""
+ mm_inputs = {}
+
+ # Process images
+ if images:
+ mm_inputs.update(self.preprocess_images(images, processor))
+
+ # Process videos
+ if videos:
+ mm_inputs.update(self.preprocess_videos(videos, processor))
+
+ return mm_inputs
+
+ def process_vision_info(self, messages: List[Dict], processor: Any):
+ """Process vision information from messages"""
+ image_message_types = ["image", "image_url", "image_base64"]
+ images = []
+ for message in messages:
+ for content in message["content"]:
+ if content["type"] in image_message_types:
+ content_type = content["type"]
+ images.append(content[content_type])
+ mm_inputs = self.get_mm_inputs(images, [], processor)
+ return mm_inputs
+
+
+class QwenVLProcessor(PatchBasedProcessor):
+ """Qwen-VL specific processor with custom image preprocessing"""
+
+ def _preprocess_single_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
+ """Qwen-VL specific image preprocessing"""
+ image = super()._preprocess_single_image(image, **kwargs)
+
+ # Qwen-VL specific adjustments
+ if min(image.width, image.height) < 28:
+ width, height = max(image.width, 28), max(image.height, 28)
+ image = image.resize((width, height))
+
+ if image.width / image.height > 200:
+ width, height = image.height * 180, image.height
+ image = image.resize((width, height))
+
+ if image.height / image.width > 200:
+ width, height = image.width, image.width * 180
+ image = image.resize((width, height))
+
+ return image
+
+ def calculate_image_tokens(self, image_data: Dict[str, Any], processor: Any) -> int:
+ """Qwen-VL specific token calculation using grid-based approach"""
+ if "image_grid_thw" in image_data:
+ # Use grid information for more accurate token calculation
+ grid_info = image_data["image_grid_thw"]
+ if isinstance(grid_info, torch.Tensor):
+ grid_prod = grid_info.prod().item()
+ elif isinstance(grid_info, list):
+ grid_prod = math.prod(grid_info)
+ else:
+ grid_prod = grid_info
+
+ # Get merge_size from processor (Qwen-VL typically uses merge_size=2)
+ merge_size = getattr(processor, "merge_size", 2)
+ merge_length = merge_size ** 2
+
+ num_image_tokens = grid_prod // merge_length
+ return max(1, num_image_tokens)
+
+ # Fallback to standard calculation
+ return super().calculate_image_tokens(image_data, processor)
+
+ def expand_vision_tokens(
+ self,
+ prompt: str,
+ images: List["ImageInput"],
+ videos: List["VideoInput"],
+ processor: Optional[Any],
+ ) -> str:
+ """Qwen-VL specific token expansion with vision tags"""
+ expanded_prompt = super().expand_vision_tokens(prompt, images, videos, processor)
+
+ return expanded_prompt
+
+class LlavaProcessor(PatchBasedProcessor):
+ """LLaVA specific processor"""
+
+ def calculate_image_tokens(self, image_data: Dict[str, Any], processor: Any) -> int:
+ """LLaVA specific token calculation"""
+ if "pixel_values" in image_data:
+ height, width = get_image_size(to_numpy_array(image_data["pixel_values"][0]))
+ image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
+ if hasattr(processor, 'num_additional_image_tokens'):
+ image_seqlen += processor.num_additional_image_tokens
+ if hasattr(processor, 'vision_feature_select_strategy') and processor.vision_feature_select_strategy == "default":
+ image_seqlen -= 1
+ return image_seqlen
+ return 1
+
+
+VISION_PROCESSORS: Dict[str, VisionProcessor] = {}
+
+model_type_to_processor_class = {
+ "qwen_vl": QwenVLProcessor,
+ "llava": LlavaProcessor,
+ "gemma3": PatchBasedProcessor,
+ "paligemma": PatchBasedProcessor,
+ "internvl": PatchBasedProcessor,
+ "minicpm": PatchBasedProcessor,
+ "mllama": PatchBasedProcessor,
+ "pixtral": PatchBasedProcessor,
+ "video_llava": PatchBasedProcessor,
+ "patch_based": PatchBasedProcessor,
+}
+
+def register_processor(template_name: str, config: VisionProcessorConfig):
+ """Register a vision processor for a template"""
+ processor_class = model_type_to_processor_class.get(config.model_type)
+ if processor_class is None:
+ raise ValueError(f"No processor class found for model type: {config.model_type}")
+ VISION_PROCESSORS[template_name] = processor_class(config)
+
+
+def register(template_name: str, config: VisionProcessorConfig, processor_class: type = None):
+    """Register a vision processor for a template, optionally with an explicit processor class"""
+    if processor_class is not None:
+        # If processor_class is provided, use it directly
+        VISION_PROCESSORS[template_name] = processor_class(config)
+    else:
+        # Otherwise resolve the processor class from the model type
+        register_processor(template_name, config)
+
+def get_processor(template_name: str) -> Optional[VisionProcessor]:
+ """Get vision processor for a template"""
+ return VISION_PROCESSORS.get(template_name)
+
+def get_processor_config(template_name: str) -> Optional[VisionProcessorConfig]:
+ """Get vision config for a template"""
+ processor = get_processor(template_name)
+ return processor.config if processor else None
+
+def is_vision_template(template_name: str) -> bool:
+ """Check if template supports vision"""
+ return template_name in VISION_PROCESSORS
+
+def list_vision_templates() -> List[str]:
+ """List all vision-enabled templates"""
+ return list(VISION_PROCESSORS.keys())
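For orientation, the registry above is keyed by template name: `register_processor` resolves a processor class from `config.model_type` via `model_type_to_processor_class`, after which `is_vision_template` and `get_processor` can answer queries for that template. A minimal usage sketch follows; the `VisionProcessorConfig` constructor arguments shown are assumptions (only fields referenced in this diff, such as `model_type`, `image_token`, and `video_token`, are known), and the import path is taken from the tests added later in this diff.

```python
# Hedged sketch of how the registry helpers compose; not the canonical API.
from agents.agents.templates.vision_processor import (
    VisionProcessorConfig,
    register_processor,
    get_processor,
    is_vision_template,
)

# Hypothetical config: field names beyond model_type/image_token/video_token are assumptions.
config = VisionProcessorConfig(
    model_type="qwen_vl",          # selects QwenVLProcessor from model_type_to_processor_class
    image_token="<|image_pad|>",
    video_token="<|video_pad|>",
)
register_processor("my-qwen-vl", config)

assert is_vision_template("my-qwen-vl")
processor = get_processor("my-qwen-vl")  # a QwenVLProcessor wrapping the config above
```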
diff --git a/agents/agents/agents/utils/tokenizer.py b/agents/agents/agents/utils/tokenizer.py
new file mode 100644
index 0000000..00ab8fe
--- /dev/null
+++ b/agents/agents/agents/utils/tokenizer.py
@@ -0,0 +1,10 @@
+from transformers import AutoTokenizer
+
+def create_tokenizer(model_name_or_path: str):
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+    except OSError:
+        # The tokenizer cannot be found in a local directory or on the Hugging Face Hub
+        tokenizer = None
+
+ return tokenizer
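`create_tokenizer` intentionally returns `None` instead of raising when `AutoTokenizer.from_pretrained` cannot resolve the path, so callers have to handle the missing-tokenizer case themselves. A minimal sketch of the expected calling pattern (the module path is inferred from the file location above and may differ in the installed package):

```python
from agents.agents.agents.utils.tokenizer import create_tokenizer

tokenizer = create_tokenizer("Qwen/Qwen2.5-3B-Instruct")
if tokenizer is None:
    # Neither a local directory nor a Hugging Face Hub repo matched the path;
    # any trajectory processing that needs token ids must be skipped or deferred.
    print("No tokenizer available for this model path.")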
diff --git a/agents/agents/examples/streaming_example.py b/agents/agents/examples/streaming_example.py
deleted file mode 100644
index c4fd08b..0000000
--- a/agents/agents/examples/streaming_example.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import asyncio
-from agents.agents.react.react_agent import ReactAgent
-from agents.tools import code_interpreter
-from agents.rewards import math_reward_tool
-import json
-from agents.agents.chain.streaming_observer import ConsoleStreamObserver
-
-
-async def main():
- tools = [code_interpreter]
-
- agent = ReactAgent(
- "Qwen/Qwen2.5-7B-Instruct",
- tools=tools,
- template="qwen2.5-no-tool",
- backend="async_vllm",
- reward_fn=math_reward_tool,
- debug=True
- )
-
-
- console_stream_observer = ConsoleStreamObserver()
- agent.streaming_manager.add_observer(console_stream_observer)
-
-
- question1 = "Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop."
- answer1 = "204"
- question2 = "$P(x)$ is a polynomial of degree $3n$ such that\n\\begin{eqnarray*} P(0) = P(3) = \\cdots &=& P(3n) = 2, \\\\ P(1) = P(4) = \\cdots &=& P(3n-2) = 1, \\\\ P(2) = P(5) = \\cdots &=& P(3n-1) = 0, \\quad\\text{ and }\\\\ && P(3n+1) = 730.\\end{eqnarray*}\nDetermine $n$."
- answer2 = "3"
-
-
- messages = [
- {
- "messages": [
- {"role": "user", "content": f"{question1}"}
- ],
- "question": f"{question1}",
- "answer": f"{answer1}"
- },
- # {
- # "messages": [
- # {"role": "user", "content": f"{question2}"}
- # ],
- # "question": f"{question2}",
- # "answer": f"{answer2}"
- # }
- ]
-
-
- await agent.run_async(
- max_steps=4,
- start_messages=messages,
- num_chains=1,
- enable_streaming=True
- )
-
-if __name__ == "__main__":
- asyncio.run(main())
-
diff --git a/agents/agents/rewards/qa_reward.py b/agents/agents/rewards/qa_reward.py
index 583519f..bd335b5 100644
--- a/agents/agents/rewards/qa_reward.py
+++ b/agents/agents/rewards/qa_reward.py
@@ -104,3 +104,18 @@ def qa_f1_reward_format(prediction: str, answer: str, trajectory: List[str]) ->
raise ValueError(f"Invalid prediction or trajectory for qa reward with format: Trajectory: {trajectory}")
return rewards_dict
+
+
+@reward(name="ok_vqa_reward")
+def ok_vqa_reward(prediction: str, answers: List[str], trajectory: List[str]) -> float:
+ """
+ Calculate the reward for the agent's response based on the F1 score and EM score.
+ The reward is 0.0 if the agent has not called any tool.
+ The reward is the F1 score if the agent has called a tool.
+ """
+ f1_scores = []
+ for answer in answers:
+ f1, precision, recall = f1_score(prediction, answer)
+ f1_scores.append(f1)
+ # All answers are the correct answer, take the max f1 score
+ return max(f1_scores)
\ No newline at end of file
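`ok_vqa_reward` reuses the `f1_score` helper already defined in this module and scores the prediction against every reference answer, keeping the best match. That helper is not shown in this diff; the sketch below is a standard SQuAD-style token-overlap F1 and is only an assumption about its behaviour.

```python
from collections import Counter

def f1_score_sketch(prediction: str, answer: str):
    """Hypothetical stand-in for the module's f1_score helper: token-overlap F1."""
    pred_tokens, ans_tokens = prediction.lower().split(), answer.lower().split()
    overlap = sum((Counter(pred_tokens) & Counter(ans_tokens)).values())
    if overlap == 0:
        return 0.0, 0.0, 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ans_tokens)
    f1 = 2 * precision * recall / (precision + recall)
    return f1, precision, recall

# ok_vqa_reward then keeps the best score over all references, e.g. a prediction of
# "a small cat" against ["a cat", "kitten"] keeps the F1 from the "a cat" match.
```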
diff --git a/agents/tests/unit/agents/prompts/test_qwen25_vl_prompt.py b/agents/tests/unit/agents/prompts/test_qwen25_vl_prompt.py
deleted file mode 100644
index 0376b49..0000000
--- a/agents/tests/unit/agents/prompts/test_qwen25_vl_prompt.py
+++ /dev/null
@@ -1,138 +0,0 @@
-from agents.agents.agents.templates.templates import get_conv_template
-from agents.agents.templates.utils import compare_hf_template, format_conversation
-from transformers import AutoProcessor, AutoTokenizer
-
-def test_simple_prompts():
- messages = [
- {
- "role": "user",
- "content": [
- {
- "type": "image",
- "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
- },
- {"type": "text", "text": "Describe this image."},
- ],
- },
- {
- "role": "assistant",
- "content": [
- {
- "type": "text",
- "text": "The image is a cat.",
- },
- ],
- },
- {
- "role": "user",
- "content": [
- {
- "type": "text",
- "text": "What is in the image?",
- },
- ],
- }
- ]
- processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
- is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(processor, "qwen2.5-vl", messages=messages, add_generation_prompt=True)
- assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}"
- print(f"Highlighted prompt:\n\n{highlighted_prompt}")
-
-
-def test_simple_prompts_with_system():
- messages = [
- {
- "role": "system",
- "content": "You are a multi-modal assistant that can answer questions about images.",
- },
- {
- "role": "user",
- "content": [
- {
- "type": "image",
- "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
- },
- {"type": "text", "text": "Describe this image."},
- ],
- },
- {
- "role": "assistant",
- "content": [
- {
- "type": "text",
- "text": "The image is a cat.",
- },
- ],
- },
- {
- "role": "user",
- "content": [
- {
- "type": "text",
- "text": "What is in the image?",
- },
- ],
- }
- ]
- processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
- is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(processor, "qwen2.5-vl", messages=messages, add_generation_prompt=True)
- assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}"
- print(f"Highlighted prompt:\n\n{highlighted_prompt}")
-
-
-def test_jinja_template():
- conv = get_conv_template("qwen2.5-vl")
- jinja_template = conv.get_jinja_template()
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
- messages = [
- {
- "role": "system",
- "content": "You are a multi-modal assistant that can answer questions about images.",
- },
- {
- "role": "user",
- "content": [
- {
- "type": "image",
- "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
- },
- {"type": "text", "text": "Describe this image."},
- ],
- },
- {
- "role": "assistant",
- "content": [
- {
- "type": "text",
- "text": "The image is a cat.",
- },
- ],
- },
- {
- "role": "tool",
- "content": [
- {
- "type": "text",
- "text": "Example tool response.",
- },
- ],
- },
- {
- "role": "user",
- "content": [
- {
- "type": "text",
- "text": "What is in the image?",
- },
- ],
- }
- ]
- official_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
- tokenizer.chat_template = jinja_template
- prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
- print(official_prompt)
- print(prompt)
-
- conv = format_conversation(messages, "qwen2.5-vl", add_generation_prompt=True)
- implemented_prompt = conv.get_prompt()
- print(implemented_prompt)
\ No newline at end of file
diff --git a/agents/tests/unit/agents/prompts/test_qwen3_prompt.py b/agents/tests/unit/agents/prompts/test_qwen3_prompt.py
deleted file mode 100644
index f11785e..0000000
--- a/agents/tests/unit/agents/prompts/test_qwen3_prompt.py
+++ /dev/null
@@ -1,81 +0,0 @@
-from agents.agents.agents.templates.templates import get_conv_template
-from agents.agents.templates.utils import compare_hf_template
-
-def test_simple_prompts():
- messages = [
- {"role": "user", "content": "Hello, how are you?"},
- {"role": "assistant", "content": "I am fine, thank you."},
- {"role": "user", "content": "Want to play a game?"},
- {"role": "assistant", "content": "Sure, what game?"},
- {"role": "user", "content": "What is 3 times 5?"},
- {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''},
- {"role": "tool", "content": "15"},
- ]
- qwen3_model = "Qwen/Qwen3-30B-A3B"
- is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(qwen3_model, "qwen3", messages=messages, add_generation_prompt=True)
- assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}"
- print(f"Highlighted prompt:\n\n{highlighted_prompt}")
-
-
-def test_simple_prompts_with_system():
- messages = [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Hello, how are you?"},
- {"role": "assistant", "content": "I am fine, thank you."},
- {"role": "user", "content": "Want to play a game?"},
- {"role": "assistant", "content": "Sure, what game?"},
- {"role": "user", "content": "What is 3 times 5?"},
- {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''},
- {"role": "tool", "content": "15"},
- ]
- qwen3_model = "Qwen/Qwen3-30B-A3B"
- is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(qwen3_model, "qwen3", messages=messages, add_generation_prompt=True)
- assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}"
- print(f"Highlighted prompt:\n\n{highlighted_prompt}")
-
-
-def test_with_tools():
- multiply_schema = {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}
- add_schema = {"type": "function", "function": {"name": "add", "description": "A function that add two numbers", "parameters": {"type": "object", "properties": {"a": {"type": "number", "description": "The first number to add"}, "b": {"type": "number", "description": "The second number to add"}}, "required": ["a", "b"]}}}
- tools = [multiply_schema, add_schema]
- messages = [
- {"role": "user", "content": "Hello, how are you?"},
- {"role": "assistant", "content": "I am fine, thank you."},
- {"role": "user", "content": "Want to play a game?"},
- {"role": "assistant", "content": "Sure, what game?"},
- {"role": "user", "content": "What is 3 times 5?"},
- {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''},
- {"role": "tool", "content": "15"},
- ]
- qwen3_model = "Qwen/Qwen3-30B-A3B"
- is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(qwen3_model, "qwen3", messages=messages, tools=tools, add_generation_prompt=True)
- assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}"
- print(f"Highlighted prompt:\n\n{highlighted_prompt}")
-
-
-def test_with_system_message():
- messages = [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Hello, how are you?"},
- {"role": "assistant", "content": "I am fine, thank you."},
- {"role": "user", "content": "What is 3 times 5?"},
- ]
- qwen3_model = "Qwen/Qwen3-30B-A3B"
- is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(qwen3_model, "qwen3", messages=messages, enable_thinking=False, add_generation_prompt=True)
- assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}"
- print(f"Highlighted prompt:\n\n{highlighted_prompt}")
-
-def test_with_system_message_and_tools():
- multiply_schema = {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}
- add_schema = {"type": "function", "function": {"name": "add", "description": "A function that add two numbers", "parameters": {"type": "object", "properties": {"a": {"type": "number", "description": "The first number to add"}, "b": {"type": "number", "description": "The second number to add"}}, "required": ["a", "b"]}}}
- tools = [multiply_schema, add_schema]
- messages = [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Hello, how are you?"},
- {"role": "assistant", "content": "I am fine, thank you."},
- {"role": "user", "content": "What is 3 times 5?"},
- ]
- qwen3_model = "Qwen/Qwen3-30B-A3B"
- is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(qwen3_model, "qwen3", messages=messages, tools=tools, enable_thinking=False,add_generation_prompt=True)
- assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}"
- print(f"Highlighted prompt:\n\n{highlighted_prompt}")
\ No newline at end of file
diff --git a/agents/tests/unit/agents/prompts/test_template_tokenize.py b/agents/tests/unit/agents/prompts/test_template_tokenize.py
deleted file mode 100644
index ff22ba1..0000000
--- a/agents/tests/unit/agents/prompts/test_template_tokenize.py
+++ /dev/null
@@ -1,62 +0,0 @@
-from agents.agents.templates.utils import is_vlm_template, tokenize_conversation
-import pytest
-from transformers import AutoTokenizer, AutoProcessor
-import torch
-from agents.agents.templates.templates import Chat
-
-@pytest.mark.parametrize("template", ["qwen2.5"])
-@pytest.mark.parametrize("messages", [
- [
- {"role": "user", "content": "Hello, how are you?"},
- {"role": "assistant", "content": "I am fine, thank you."},
- {"role": "user", "content": "Want to play a game?"},
- {"role": "assistant", "content": "Sure, what game?"},
- ],
- [
- {"role": "user", "content": "Help me to calculate 3 times 5."},
- {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''},
- {"role": "tool", "content": "15"},
- ],
- [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Hello, how are you?"},
- {"role": "assistant", "content": "I am fine, thank you."},
- {"role": "user", "content": "What is 3 times 5?"},
- ],
-])
-@pytest.mark.parametrize("tools", [
- None,
- # [
- # {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}},
- # {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}},
- # ]
-])
-@pytest.mark.parametrize("add_generation_prompt", [False])
-def test_template_tokenize(template, messages, tools, add_generation_prompt):
- template_tokenizer_mapping = {
- "qwen2.5": "Qwen/Qwen2.5-3B-Instruct",
- "qwen2.5-vl": "Qwen/Qwen2.5-VL-3B-Instruct",
- "qwen3": "Qwen/Qwen3-8B",
- }
- tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template])
- if is_vlm_template(template):
- processor = AutoProcessor.from_pretrained(template_tokenizer_mapping[template])
- else:
- processor = None
- try:
- official_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)
- official_inputs = tokenizer(official_prompt, return_tensors="pt")
- # chat = Chat(template, messages, tokenizer)
- # implemented_inputs = chat.tokenize()
- implemented_inputs = tokenize_conversation(messages, tokenizer, template, max_length=2048, processor=processor, tools=tools, add_generation_prompt=add_generation_prompt, return_tensors="pt")
-
- assert torch.equal(official_inputs["input_ids"], implemented_inputs["input_ids"]), f"template: {template}\n\nmessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nofficial_prompt: {official_prompt}\n\nimplemented_prompt: {tokenizer.decode(implemented_inputs['input_ids'][0])}\n\nofficial_inputs: {official_inputs}\n\nimplemented_inputs: {implemented_inputs}"
- assert torch.equal(official_inputs["attention_mask"], implemented_inputs["attention_mask"])
- except Exception as e:
- if isinstance(e, ValueError) and "does not support tool calling." in str(e) and tools is not None and template in ["qwen2.5-vl"]:
- pass
- else:
- raise e
- # assert torch.equal(official_inputs["labels"], implemented_inputs["labels"])
-
-
diff --git a/agents/tests/unit/agents/prompts/test_think_prompt.py b/agents/tests/unit/agents/prompts/test_think_prompt.py
deleted file mode 100644
index 94a3ec2..0000000
--- a/agents/tests/unit/agents/prompts/test_think_prompt.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import openai
-
-def test_think_prompt():
- client = openai.OpenAI(api_key="token-123", base_url="http://0.0.0.0:8000/v1")
- tools = [
- {
- "type": "function",
- "function": {
- "name": "code_interpreter",
- "description": "A python code interpreter that can execute code and return the result",
- "parameters": {
- "type": "object",
- "properties": {
- "code": {"type": "string", "description": "The python code to execute"},
- },
- },
- }
- }
- ]
- response = client.chat.completions.create(
- # model="/mnt/sharefs/users/haonan.li/models/Qwen2.5-7B-instruct-am_think_v1_distilled",
- model="Qwen/Qwen2.5-7B-Instruct",
- tools=tools,
- messages=[
- {"role": "user", "content": "Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop."},
- ],
- max_tokens=8192
- )
- print(response.choices[0].message.content)
\ No newline at end of file
diff --git a/agents/tests/unit/agents/templates/test_qwen3_prompt.py b/agents/tests/unit/agents/templates/test_qwen3_prompt.py
new file mode 100644
index 0000000..e69de29
diff --git a/agents/tests/unit/agents/templates/test_template_utilities.py b/agents/tests/unit/agents/templates/test_template_utilities.py
new file mode 100644
index 0000000..47cba54
--- /dev/null
+++ b/agents/tests/unit/agents/templates/test_template_utilities.py
@@ -0,0 +1,33 @@
+from agents.agents.templates.templates import get_template, register_template, Template
+from agents.agents.templates.vision_processor import get_processor
+
+def test_template_registration():
+ register_template(
+ Template(
+ name="test",
+ system_template="",
+ user_template="",
+ assistant_template="",
+ stop_words=[]
+ )
+ )
+ assert get_template("test") is not None
+ assert get_template("test").name == "test"
+
+def test_template_registration_with_vision():
+ register_template(
+ Template(
+ name="test-vl",
+ system_template="",
+ user_template="",
+ assistant_template="",
+ stop_words=[],
+ image_token="<|image_pad|>",
+ )
+ )
+ assert get_processor("test-vl") is not None
+ assert get_processor("test-vl").config.image_token == "<|image_pad|>"
+
+
+
+
\ No newline at end of file
diff --git a/agents/tests/unit/agents/templates/test_text_templates_full_align.py b/agents/tests/unit/agents/templates/test_text_templates_full_align.py
new file mode 100644
index 0000000..006ee4e
--- /dev/null
+++ b/agents/tests/unit/agents/templates/test_text_templates_full_align.py
@@ -0,0 +1,111 @@
+""" This file is for testing the text templates that align seamlessly with HF templates. The templates should align on following aspects:
+ - The obtained textual prompt should be the same as the one obtained from HF template with all the following options:
+ - add_generation_prompt
+ - tools
+ - The obtained textual prompt should be the same as the one obtained from Jinja template with all the following options:
+ - add_generation_prompt
+ - tools
+"""
+
+
+from agents.agents.templates.utils import compare_hf_template
+from transformers import AutoTokenizer
+import pytest
+
+@pytest.mark.parametrize("model_name_or_path", [
+ # "Qwen/Qwen2.5-3B-Instruct",
+ "mistralai/Mistral-7B-Instruct-v0.3",
+])
+@pytest.mark.parametrize("messages", [
+ [
+ {"role": "user", "content": "Hello, how are you?"},
+ {"role": "assistant", "content": "I am fine, thank you."},
+ {"role": "user", "content": "Want to play a game?"},
+ {"role": "assistant", "content": "Sure, what game?"},
+ ],
+ [
+ {"role": "user", "content": "Help me to calculate 3 times 5."},
+ {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''},
+ {"role": "tool", "content": "15"},
+ ],
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello, how are you?"},
+ {"role": "assistant", "content": "I am fine, thank you."},
+ {"role": "user", "content": "What is 3 times 5?"},
+ ],
+])
+@pytest.mark.parametrize("tools", [
+ None,
+ [
+ {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}},
+ {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}},
+ ]
+])
+@pytest.mark.parametrize("add_generation_prompt", [True, False])
+def test_hf_template_print(model_name_or_path, messages, tools, add_generation_prompt):
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt, tools=tools)
+ print(f"========================================\nModel: {model_name_or_path}\nMessages: {messages}\nTools: {tools}\nAdd generation prompt: {add_generation_prompt}\n")
+ print(prompt)
+ print("========================================\n")
+
+
+# "qwen2.5-think", "qwen2.5", "qwen2.5-no-tool",
+# "llama-3.2", "mistral", "glm-4", "internlm2.5", "phi-3.5", "phi-4"
+@pytest.mark.parametrize("template", ["llama-3.2", "qwen2.5"])
+@pytest.mark.parametrize("messages", [
+ [
+ {"role": "user", "content": "Hello, how are you?"},
+ {"role": "assistant", "content": "I am fine, thank you."},
+ {"role": "user", "content": "Want to play a game?"},
+ {"role": "assistant", "content": "Sure, what game?"},
+ ],
+ [
+ {"role": "user", "content": "Help me to calculate 3 times 5."},
+ {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''},
+ {"role": "tool", "content": "15"},
+ ],
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello, how are you?"},
+ {"role": "assistant", "content": "I am fine, thank you."},
+ {"role": "user", "content": "What is 3 times 5?"},
+ ],
+])
+@pytest.mark.parametrize("tools", [
+ None,
+ [
+ {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}},
+ {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}},
+ ]
+])
+@pytest.mark.parametrize("add_generation_prompt", [True, False])
+def test_chat_template_equal(template, messages, tools, add_generation_prompt):
+ # Filter invalid combinations
+ if add_generation_prompt and messages[-1]['role'] == 'assistant':
+ return
+
+
+ template_tokenizer_mapping = {
+ "qwen2.5": "Qwen/Qwen2.5-3B-Instruct",
+ "qwen2.5-think": "Qwen/Qwen2.5-3B-Instruct",
+ "qwen2.5-no-system-tool": "Qwen/Qwen2.5-3B-Instruct",
+ "deepseek-prover-v2": "deepseek-ai/DeepSeek-Prover-V2-7B",
+ "llama-3.2": "meta-llama/Llama-3.2-3B-Instruct",
+ "mistral": "mistralai/Mistral-7B-Instruct-v0.3",
+ "glm-4": "THUDM/glm-4-9b-chat",
+ "internlm2.5": "internlm/internlm2_5-7b-chat",
+ "phi-3.5": "microsoft/Phi-3.5-mini-instruct",
+ "phi-4": "microsoft/Phi-4",
+ "nemotron": "nvidia/Llama-3.1-Nemotron-Nano-8B-v1",
+ }
+ tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template], trust_remote_code=True)
+
+ is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template, messages=messages, tools=tools,add_generation_prompt=add_generation_prompt)
+
+ assert is_equal, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}"
+ assert is_equal_between_jinja_prompts, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja prompt:\n\n{implemented_jinja_prompt}"
+ # print(f"Official prompt:\n\n{official_prompt}")
+ # print(f"Highlighted prompt:\n\n{highlighted_prompt}")
+
diff --git a/agents/tests/unit/agents/prompts/test_templates.py b/agents/tests/unit/agents/templates/test_text_templates_partial_align.py
similarity index 54%
rename from agents/tests/unit/agents/prompts/test_templates.py
rename to agents/tests/unit/agents/templates/test_text_templates_partial_align.py
index 068f832..78b7f87 100644
--- a/agents/tests/unit/agents/prompts/test_templates.py
+++ b/agents/tests/unit/agents/templates/test_text_templates_partial_align.py
@@ -1,8 +1,10 @@
-from agents.agents.templates.utils import compare_hf_template
-from transformers import AutoTokenizer
import pytest
+from transformers import AutoTokenizer
+from agents.agents.templates.templates import get_template
+from agents.agents.templates.utils import compare_hf_template
-@pytest.mark.parametrize("template", ["qwen2.5-think", "qwen2.5", "qwen2.5-no-tool"])
+# nemotron, phi-4, glm-4
+@pytest.mark.parametrize("template_name", ["qwen2.5-think", "qwen2.5-no-system-tool",])
@pytest.mark.parametrize("messages", [
[
{"role": "user", "content": "Hello, how are you?"},
@@ -13,7 +15,7 @@
[
{"role": "user", "content": "Help me to calculate 3 times 5."},
{"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''},
- {"role": "tool", "content": "15"},
+ {"role": "tool", "content": "15", "tool_call_id": "123456789"},
],
[
{"role": "system", "content": "You are a helpful assistant."},
@@ -30,20 +32,37 @@
]
])
@pytest.mark.parametrize("add_generation_prompt", [True, False])
-def test_chat_template_equal(template, messages, tools, add_generation_prompt):
+def test_chat_template_equal(template_name, messages, tools, add_generation_prompt):
# Filter invalid combinations
if add_generation_prompt and messages[-1]['role'] == 'assistant':
return
+
+ template = get_template(template_name)
+ if tools and not template._supports_tool_call():
+ return
+
+ contain_tool_role = any(message['role'] == 'tool' for message in messages)
+ if contain_tool_role and not template._supports_tool_call():
+ return
template_tokenizer_mapping = {
"qwen2.5": "Qwen/Qwen2.5-3B-Instruct",
"qwen2.5-think": "Qwen/Qwen2.5-3B-Instruct",
- "qwen2.5-no-tool": "Qwen/Qwen2.5-3B-Instruct",
+ "qwen2.5-no-system-tool": "Qwen/Qwen2.5-3B-Instruct",
+ "deepseek-prover-v2": "deepseek-ai/DeepSeek-Prover-V2-7B",
+ "llama-3.2": "meta-llama/Llama-3.2-3B-Instruct",
+ "mistral": "mistralai/Mistral-7B-Instruct-v0.3",
+ "glm-4": "THUDM/glm-4-9b-chat",
+ "internlm2.5": "internlm/internlm2_5-7b-chat",
+ "phi-3.5": "microsoft/Phi-3.5-mini-instruct",
+ "phi-4": "microsoft/Phi-4",
+ "nemotron": "nvidia/Llama-3.1-Nemotron-Nano-8B-v1",
}
- tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template])
-
- is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template, messages=messages, tools=tools,add_generation_prompt=add_generation_prompt)
- # assert is_equal, print(f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}")
- assert is_equal_between_jinja_prompts, print(f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja prompt:\n\n{implemented_jinja_prompt}")
- print(f"Highlighted prompt:\n\n{highlighted_prompt}")
+ tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template_name], trust_remote_code=True)
+ is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template_name, messages=messages, tools=tools,add_generation_prompt=add_generation_prompt)
+
+ print(f"Official prompt:\n\n\"{official_prompt}\"\n\n")
+ print(f"Implemented prompt:\n\n\"{implemented_prompt}\"\n\n")
+ assert is_equal, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}"
+ assert is_equal_between_jinja_prompts, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja prompt:\n\n{implemented_jinja_prompt}"
\ No newline at end of file
diff --git a/agents/tests/unit/agents/templates/test_text_templates_tokenize.py b/agents/tests/unit/agents/templates/test_text_templates_tokenize.py
new file mode 100644
index 0000000..a89e0b8
--- /dev/null
+++ b/agents/tests/unit/agents/templates/test_text_templates_tokenize.py
@@ -0,0 +1,61 @@
+""" This file is for testing the tokenization of the templates. The templates should align on following aspects:
+ - The tokenized prompt should be the same as the one obtained from HF template with all the following options:
+ - add_generation_prompt
+ - tools
+ - We need to observe the labels and action_mask to make sure the the they are correct.
+
+Since the align for textual prompt is already tested in other files, we only need to test the tokenization of the templates.
+"""
+
+from agents.agents.templates.utils import tokenize_conversation
+import pytest
+from transformers import AutoTokenizer
+import torch
+from agents.agents.templates.templates import Chat
+
+@pytest.mark.parametrize("template", ["llama-3.2", "qwen2.5"])
+@pytest.mark.parametrize("messages", [
+ [
+ {"role": "user", "content": "Hello, how are you?"},
+ {"role": "assistant", "content": "I am fine, thank you."},
+ {"role": "user", "content": "Want to play a game?"},
+ {"role": "assistant", "content": "Sure, what game?"},
+ ],
+ [
+ {"role": "user", "content": "Help me to calculate 3 times 5."},
+ {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''},
+ {"role": "tool", "content": "15"},
+ ],
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello, how are you?"},
+ {"role": "assistant", "content": "I am fine, thank you."},
+ {"role": "user", "content": "What is 3 times 5?"},
+ {"role": "assistant", "content": "15"},
+ {"role": "user", "content": "OK, what is 3 times 6?"},
+ {"role": "assistant", "content": "18"},
+ ],
+])
+@pytest.mark.parametrize("tools", [
+ None,
+ [
+ {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}},
+ {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}},
+ ]
+])
+@pytest.mark.parametrize("add_generation_prompt", [False, True])
+def test_template_tokenize(template, messages, tools, add_generation_prompt):
+ template_tokenizer_mapping = {
+ "qwen2.5": "Qwen/Qwen2.5-3B-Instruct",
+ "llama-3.2": "meta-llama/Llama-3.2-3B-Instruct",
+ }
+ tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template], trust_remote_code=True)
+
+ chat = Chat(template, messages, tools=tools)
+ prompt = chat.prompt(add_generation_prompt=add_generation_prompt, tools=tools)
+
+ hf_inputs = tokenizer(prompt, return_tensors="pt")
+
+ implemented_inputs = tokenize_conversation(messages, tokenizer, template, max_length=2048, tools=tools, add_generation_prompt=add_generation_prompt, return_tensors="pt")
+
+ assert torch.equal(hf_inputs["input_ids"], implemented_inputs["input_ids"]), f"template: {template}\n\nmessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nprompt: {prompt}\n\nimplemented_prompt: {tokenizer.decode(implemented_inputs['input_ids'][0])}\n\nhf_inputs: {hf_inputs}\n\nimplemented_inputs: {implemented_inputs}"
diff --git a/agents/tests/unit/agents/templates/test_vision_templates_full_align.py b/agents/tests/unit/agents/templates/test_vision_templates_full_align.py
new file mode 100644
index 0000000..167307d
--- /dev/null
+++ b/agents/tests/unit/agents/templates/test_vision_templates_full_align.py
@@ -0,0 +1,87 @@
+""" This file is for testing the vision templates that align seamlessly with HF templates. The templates should align on following aspects:
+ - The obtained textual prompt should be the same as the one obtained from HF template with all the following options:
+ - add_generation_prompt
+ - tools
+ - The obtained textual prompt should be the same as the one obtained from Jinja template with all the following options:
+ - add_generation_prompt
+ - tools
+To test vision part, the messages should contain at least one image.
+"""
+
+
+from agents.agents.templates.utils import compare_hf_template
+from transformers import AutoTokenizer
+import pytest
+# "qwen2.5-think", "qwen2.5", "qwen2.5-no-tool",
+@pytest.mark.parametrize("template", ["qwen2.5-vl"])
+@pytest.mark.parametrize("messages", [
+ [
+ {
+ "role": "system",
+ "content": "You are a multi-modal assistant that can answer questions about images.",
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+ },
+ {"type": "text", "text": "Describe this image."},
+ ],
+ },
+ {
+ "role": "assistant",
+ "content": [
+ {
+ "type": "text",
+ "text": "The image is a cat.",
+ },
+ ],
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What is in the image?",
+ },
+ ],
+ }
+ ],
+ [
+ {"role": "user", "content": "Help me to calculate 3 times 5."},
+ {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''},
+ {"role": "tool", "content": "15"},
+ ],
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello, how are you?"},
+ {"role": "assistant", "content": "I am fine, thank you."},
+ {"role": "user", "content": "What is 3 times 5?"},
+ ],
+])
+@pytest.mark.parametrize("tools", [
+ None,
+ [
+ {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}},
+ {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}},
+ ]
+])
+@pytest.mark.parametrize("add_generation_prompt", [True, False])
+def test_chat_template_equal(template, messages, tools, add_generation_prompt):
+ # Filter invalid combinations
+ if add_generation_prompt and messages[-1]['role'] == 'assistant':
+ return
+
+ template_tokenizer_mapping = {
+ "qwen2.5-vl": "Qwen/Qwen2.5-VL-3B-Instruct",
+ }
+ tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template])
+
+ is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template, messages=messages, tools=tools,add_generation_prompt=add_generation_prompt)
+ assert is_equal, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}"
+ assert is_equal_between_jinja_prompts, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja prompt:\n\n{implemented_jinja_prompt}"
+ print(f"Official prompt:\n\n{official_prompt}")
+ print(f"Highlighted prompt:\n\n{highlighted_prompt}")
+
diff --git a/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py b/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py
new file mode 100644
index 0000000..31b5ecd
--- /dev/null
+++ b/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py
@@ -0,0 +1,110 @@
+""" This file is for testing the text templates that align seamlessly with HF templates. The templates should align on following aspects:
+ - The obtained textual prompt should be the same as the one obtained from HF template with all the following options:
+ - add_generation_prompt
+ - tools
+ - The obtained textual prompt should be the same as the one obtained from Jinja template with all the following options:
+ - add_generation_prompt
+ - tools
+"""
+
+
+from agents.agents.templates.templates import Chat
+from agents.agents.templates.utils import compare_hf_template, tokenize_conversation
+from transformers import AutoTokenizer
+import pytest
+import torch
+from transformers import AutoProcessor
+from qwen_vl_utils import process_vision_info
+
+@pytest.mark.parametrize("template", ["qwen2.5-vl"])
+@pytest.mark.parametrize("messages", [
+ [
+ {
+ "role": "system",
+ "content": "You are a multi-modal assistant that can answer questions about images.",
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+ },
+ {"type": "text", "text": "Describe this image."},
+ ],
+ },
+ {
+ "role": "assistant",
+ "content": [
+ {
+ "type": "text",
+ "text": "The image is a cat.",
+ },
+ ],
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What is in the image?",
+ },
+ ],
+ }
+ ],
+ # [
+ # {"role": "user", "content": "Help me to calculate 3 times 5."},
+ # {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''},
+ # {"role": "tool", "content": "15"},
+ # ],
+ # [
+ # {"role": "system", "content": "You are a helpful assistant."},
+ # {"role": "user", "content": "Hello, how are you?"},
+ # {"role": "assistant", "content": "I am fine, thank you."},
+ # {"role": "user", "content": "What is 3 times 5?"},
+ # ],
+])
+@pytest.mark.parametrize("tools", [
+ None,
+ # [
+ # {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}},
+ # {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}},
+ # ]
+])
+@pytest.mark.parametrize("add_generation_prompt", [True, False])
+def test_chat_template_equal(template, messages, tools, add_generation_prompt):
+ template_tokenizer_mapping = {
+ "qwen2.5-vl": "Qwen/Qwen2.5-VL-3B-Instruct",
+ }
+ tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template])
+ processor = AutoProcessor.from_pretrained(template_tokenizer_mapping[template])
+    official_prompt = tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=add_generation_prompt)
+ image_inputs, video_inputs = process_vision_info(messages)
+ official_inputs = processor(
+ text=[official_prompt],
+ images=image_inputs,
+ videos=video_inputs,
+ padding=True,
+ return_tensors="pt",
+ )
+
+ implemented_inputs = tokenize_conversation(messages, tokenizer, template, max_length=8192, tools=tools, add_generation_prompt=add_generation_prompt, return_tensors="pt", processor=processor)
+
+ official_prompt = tokenizer.decode(official_inputs['input_ids'][0])
+ implemented_prompt = tokenizer.decode(implemented_inputs['input_ids'][0])
+ print(f"Official prompt image tokens: {official_prompt.count('<|image_pad|>')}\nImplemented prompt image tokens: {implemented_prompt.count('<|image_pad|>')}")
+
+ print(f"Official images: {official_inputs['pixel_values'].shape}\nImplemented images: {implemented_inputs['pixel_values'].shape}")
+
+    assert torch.equal(official_inputs["input_ids"], implemented_inputs["input_ids"]), \
+        f"Official prompt:\n{official_prompt}\nImplemented prompt:\n{implemented_prompt}"
+
+ assert torch.equal(official_inputs["pixel_values"], implemented_inputs["pixel_values"])
+
+ assert torch.equal(official_inputs["image_grid_thw"], implemented_inputs["image_grid_thw"])
+
+    assert implemented_inputs["input_ids"].shape == implemented_inputs["action_mask"].shape, f"""input_ids shape: {implemented_inputs["input_ids"].shape}\naction_mask shape: {implemented_inputs["action_mask"].shape}"""
+
+
+ print(f"official_prompt: {official_prompt}\nimplemented_prompt: {tokenizer.decode(implemented_inputs['input_ids'][0])}\nofficial_inputs: {official_inputs.keys()}\nimplemented_inputs: {implemented_inputs.keys()}\n")
\ No newline at end of file
diff --git a/verl b/verl
index 0186ec8..861f63b 160000
--- a/verl
+++ b/verl
@@ -1 +1 @@
-Subproject commit 0186ec81e9273953d2f34b0c4d48741cc2f9aabc
+Subproject commit 861f63ba8097a43ababe27116842512783080586