diff --git a/agents/agents/agents/agent_base.py b/agents/agents/agents/agent_base.py index 0d712ce..e878a79 100644 --- a/agents/agents/agents/agent_base.py +++ b/agents/agents/agents/agent_base.py @@ -4,23 +4,28 @@ from .templates.templates import get_template from ..__init__ import AGENT_DATA_DIR -from .llm_backend import AsyncVLLMBackend, AsyncVerlBackend, ClientBackend, TransformersBackend, VLLMBackend, VerlBackend +from .llm_backend import AsyncVLLMBackend, AsyncVerlBackend, ClientBackend, TransformersBackend, VLLMBackend from ..utils.logging import get_logger from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch -from .templates.utils import is_vlm_template, tokenize_conversations +from .templates.utils import tokenize_conversations +from .templates.vision_processor import is_vision_template from .chain.chain_base import ChainGeneration import os import transformers import warnings +import logging from .chain.streaming_observer import ConsoleStreamObserver, StreamingManager +from .utils.tokenizer import create_tokenizer +from .backend_config import BACKEND_CONFIGS try: from verl.protocol import DataProto except ImportError: print("verl can not be imported.") pass +Logger = logging.getLogger(__name__) class BaseAgent(ChainGeneration, ABC): """ @@ -34,12 +39,13 @@ class BaseAgent(ChainGeneration, ABC): def __init__( self, model_name_or_path, - template: str, + template: str=None, system_prompt: str = None, tools: List = None, max_length: int=8192, debug: bool = False, backend: str = "transformers", + backend_config: Any = None, reward_fn: Callable = None, log_file: str = "agent", project_name: str = None, @@ -65,9 +71,30 @@ def __init__( self.tools = tools self.system_prompt = system_prompt self.model_name_or_path = model_name_or_path - self.llm_engine, self.tokenizer, self.processor = self._init_llm_engine(model_name_or_path, backend) + + # Handle backend configuration + if backend_config is None: + # Use default configuration for the backend + config_class = BACKEND_CONFIGS.get(backend) + if config_class: + self.backend_config = config_class() + else: + self.backend_config = None + else: + self.backend_config = backend_config + + self.llm_engine = self._init_llm_engine(model_name_or_path, backend) + + # Create appropriate tokenizer for trajectory processing + self.tokenizer = create_tokenizer(model_name_or_path) + self._reward_fn = reward_fn - self.jinja_template = get_template(self.template).jinja_template() + + if self.template is None: + self.jinja_template = None + else: + self.jinja_template = get_template(self.template).jinja_template() + self.project_name = project_name self.run_name = run_name self.streaming_manager = StreamingManager() @@ -78,38 +105,59 @@ def __init__( raise ValueError(f"Streaming mode {streaming} is not supported.") super().__init__() if kwargs: - warnings.warn(f"Unused arguments for agent initialization: {kwargs}") + # warnings.warn(f"Unused arguments for agent initialization: {kwargs}") + raise ValueError(f"Unused arguments for agent initialization: {kwargs}") def _init_llm_engine(self, model_name_or_path: str, backend: str): if isinstance(model_name_or_path, str): + # Extract backend-specific configuration + config_kwargs = {} + if self.backend_config: + config_kwargs = {k: v for k, v in self.backend_config.__dict__.items() + if not k.startswith('_')} + if backend == "transformers": - llm_engine = TransformersBackend(model_name_or_path, self.template, max_length=self.max_length) - elif backend == "vllm": - 
llm_engine = VLLMBackend(model_name_or_path, self.template, max_length=self.max_length) + llm_engine = TransformersBackend( + model_name_or_path, + self.template, + max_length=self.max_length, + **config_kwargs + ) elif backend == "async_vllm": - llm_engine = AsyncVLLMBackend(model_name_or_path, self.template, max_length=self.max_length) - elif backend == "verl": - llm_engine = VerlBackend(llm_engine=None, model_name_or_path=model_name_or_path, template=self.template, max_length=self.max_length) + llm_engine = AsyncVLLMBackend( + model_name_or_path, + self.template, + max_length=self.max_length, + **config_kwargs + ) elif backend == "async_verl": - llm_engine = AsyncVerlBackend(llm_engine=None, model_name_or_path=model_name_or_path, template=self.template, max_length=self.max_length) + llm_engine = AsyncVerlBackend( + llm_engine=None, + model_name_or_path=model_name_or_path, + template=self.template, + max_length=self.max_length, + **config_kwargs + ) elif backend == "client": - llm_engine = ClientBackend(model_name_or_path, self.template, max_length=self.max_length) + print(f"config_kwargs: {config_kwargs}") + llm_engine = ClientBackend( + model_name_or_path, + self.template, + max_length=self.max_length, + **config_kwargs + ) else: raise ValueError(f"Backend {backend} is not supported.") else: raise ValueError("model_name_or_path must be a string.") - tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path) - if is_vlm_template(self.template): - processor = transformers.AutoProcessor.from_pretrained(model_name_or_path) - else: - processor = None - return llm_engine, tokenizer, processor + return llm_engine - def set_llm_engine(self, llm_engine: Any, tokenizer: Any): + def set_llm_engine(self, llm_engine: Any, tokenizer: Any, processor: Any): assert self.backend == "async_verl", "Only async verl backend is supported for now" self.llm_engine.llm_engine = llm_engine self.tokenizer = tokenizer + self.processor = processor def generate(self, messages_list_or_inputs: List[List[Dict]], **args): return self.llm_engine.generate(messages_list_or_inputs, **args) @@ -151,17 +199,6 @@ async def generate_streaming(self, messages_list_or_inputs: List[List[Dict]], st @property def timing_data(self): return self.timer.timing_data - - def forward(self, messages_list_or_inputs: List[List[Dict]], **args): - if isinstance(messages_list_or_inputs, List): - inputs = tokenize_conversations(messages_list_or_inputs, tokenizer=self.tokenizer, conv_template=self.template, max_length=self.max_length, processor=self.processor) - else: - raise ValueError("messages_list_or_inputs must be a list of messages or a dictionary of padded inputs.") - - if isinstance(self.llm_engine, transformers.PreTrainedModel): - return self.llm_engine.forward(**inputs, **args) # Only support transformers models for now. 
- else: - raise ValueError("llm_engine must be a transformers.PretrainedModel.") @property def trajectories(self): @@ -169,7 +206,10 @@ def trajectories(self): return trajectories - def tokenize_trajectories(self, return_action_mask: bool = False, return_reward_mask: bool = False): + def tokenize_trajectories(self, tokenizer, return_action_mask: bool = False, return_reward_mask: bool = False): + if tokenizer is None: + tokenizer = self.tokenizer + trajectories = self.trajectories self.logger.info("================ Trajectory ================") self.logger.info(trajectories[0]) @@ -196,7 +236,7 @@ def tokenize_trajectories(self, return_action_mask: bool = False, return_reward_ info['last_response'] = last_response other_info_list.append(info) - inputs = tokenize_conversations(messages_list, tokenizer=self.tokenizer, conv_template=self.template, processor=self.processor, max_length=self.max_length, return_reward_mask=return_reward_mask) + inputs = tokenize_conversations(messages_list, tokenizer=tokenizer, conv_template=self.template, processor=self.processor, max_length=self.max_length, return_reward_mask=return_reward_mask) position_ids = torch.clip(torch.cumsum(inputs['attention_mask'], dim=-1) - 1, min=0, max=None) inputs['position_ids'] = position_ids diff --git a/agents/agents/agents/backend_config.py b/agents/agents/agents/backend_config.py new file mode 100644 index 0000000..54a5a62 --- /dev/null +++ b/agents/agents/agents/backend_config.py @@ -0,0 +1,66 @@ +from dataclasses import dataclass +from typing import Optional, Dict, Any, List +import asyncio + + +@dataclass +class TransformersConfig: + """Configuration for Transformers backend""" + temperature: float = 1.0 + max_new_tokens: int = 1024 + trust_remote_code: bool = True + device_map: str = "auto" + + +@dataclass +class VLLMConfig: + """Configuration for VLLM backend""" + temperature: float = 1.0 + max_new_tokens: int = 1024 + # Add other vLLM specific parameters as needed + + +@dataclass +class AsyncVLLMConfig: + """Configuration for Async VLLM backend""" + temperature: float = 1.0 + max_new_tokens: int = 1024 + # Add other async vLLM specific parameters as needed + + +@dataclass +class VerlConfig: + """Configuration for Verl backend""" + temperature: float = 1.0 + max_new_tokens: int = 1024 + # Add other Verl specific parameters as needed + + +@dataclass +class AsyncVerlConfig: + """Configuration for Async Verl backend""" + temperature: float = 1.0 + max_new_tokens: int = 1024 + # Add other async Verl specific parameters as needed + + +@dataclass +class ClientConfig: + """Configuration for Client backend (OpenAI-compatible)""" + base_url: str = "http://localhost:8000/v1" + max_requests_per_minute: int = 100 + timeout: int = 600 + api_key: str = "EMPTY" + max_new_tokens: int = 1024 + temperature: float = 1.0 + + +# Backend configuration mapping +BACKEND_CONFIGS = { + "transformers": TransformersConfig, + "vllm": VLLMConfig, + "async_vllm": AsyncVLLMConfig, + "verl": VerlConfig, + "async_verl": AsyncVerlConfig, + "client": ClientConfig, +} \ No newline at end of file diff --git a/agents/agents/agents/chain/chain_base.py b/agents/agents/agents/chain/chain_base.py index 843b681..02fa27d 100644 --- a/agents/agents/agents/chain/chain_base.py +++ b/agents/agents/agents/chain/chain_base.py @@ -316,14 +316,40 @@ async def _run_single_chain(self, # Handle tool calls if current_node.messages[-1].get("tool_calls"): for tool_call in current_node.messages[-1]["tool_calls"]: - current_node = await self._execute_tool_call( + result = await 
self._execute_tool_call( tool_call, newest_messages, chain, chain_id, depth, have_set_tools, enable_streaming ) have_set_tools = True + + # Create action input node + action_input_node = chain.add_node( + type="Action Input", + messages=deepcopy(newest_messages), + description=result.get("arguments", "") + ) + + # Process observation + observation = result["observation"] + observation_json = json.dumps({ + "name": result["name"], + "content": observation, + }, indent=4) + + action_input_node.observation = observation_json + action_input_node.observation_code = result["status"] + newest_messages.append({ + "role": "tool", + "tool_call_id": tool_call["id"], + "content": [{"type": "text", "text": observation_json}], + }) + action_input_node.messages = deepcopy(newest_messages) + action_input_node.is_terminal = result["status"] in self.terminal_status else: # No tool calls, chain is finished break + + current_node = action_input_node depth += 1 @@ -463,33 +489,9 @@ async def _execute_tool_call(self, tool_call, newest_messages, chain, chain_id, step=depth, depth=depth )) + + return result - - # Create action input node - action_input_node = chain.add_node( - type="Action Input", - messages=deepcopy(newest_messages), - description=result.get("arguments", "") - ) - - # Process observation - observation = result["observation"] - observation_json = json.dumps({ - "name": result["name"], - "content": observation, - }, indent=4) - - action_input_node.observation = observation_json - action_input_node.observation_code = result["status"] - newest_messages.append({ - "role": "tool", - "tool_call_id": tool_call["id"], - "content": [{"type": "text", "text": observation_json}], - }) - action_input_node.messages = deepcopy(newest_messages) - action_input_node.is_terminal = result["status"] in self.terminal_status - - return action_input_node async def _finalize_chain(self, chain_id, chain, current_node, depth): """Finalize the chain with reward calculation and cleanup.""" diff --git a/agents/agents/agents/llm_backend.py b/agents/agents/agents/llm_backend.py index 2401137..726620a 100644 --- a/agents/agents/agents/llm_backend.py +++ b/agents/agents/agents/llm_backend.py @@ -20,6 +20,7 @@ from vllm import LLM, AsyncLLMEngine, SamplingParams, AsyncEngineArgs import openai from .templates.templates import Chat +from .templates.vision_processor import get_processor import logging import PIL @@ -46,8 +47,8 @@ def apply_chat_template(self, messages_list: List[List[Dict]], template: str, ad for messages in messages_list: chat = Chat(template, messages) prompts.append(chat.prompt(add_generation_prompt=add_generation_prompt, tools=tools)) - # We don't support vision inputs for now - vision_inputs.append(None) + # We only support image inputs for now + vision_inputs.append(chat.vision_inputs()) return prompts, vision_inputs @@ -325,49 +326,6 @@ async def generate_streaming(self, messages_list: List[List[Dict]], **kwargs) -> if hasattr(sequence, 'text'): yield sequence.text -class VerlBackend(LLMBackend): - """Verl implementation""" - - def __init__(self, llm_engine, model_name_or_path: str, template: str, max_length: int=8192, **kwargs): - super().__init__(**kwargs) - self.model_name = model_name_or_path - self.max_length = max_length - self.template = template - self.tokenizer = AutoTokenizer.from_pretrained( - self.model_name, - trust_remote_code=True, - ) - self.llm_engine = llm_engine - - def generate(self, messages_list: str, **kwargs) -> str: - """Generate text from prompt using Verl""" - # We need to build a 
DataProto from the prompts - prompts = self.apply_chat_template(messages_list, self.template) - inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") - # We need to do padding for compatibility with the verl DataProto, which - # assumes that the batch size must be divisible by the dp size - world_size = self.llm_engine.world_size - inputs['input_ids'] = pad_tensor_to_rank_size(inputs['input_ids'], world_size) - inputs['attention_mask'] = pad_tensor_to_rank_size(inputs['attention_mask'], world_size) - - position_ids = torch.clip(torch.cumsum(inputs.attention_mask, dim=-1) - 1, min=0, max=None) - inputs['position_ids'] = position_ids - - n = kwargs.get("num_return_sequences", 1) - temperature = kwargs.get("temperature", 1.0) - use_agent = True - batch = DataProto.from_single_dict(inputs, meta_info={"n": n, "use_agent": use_agent, "temperature": temperature}) - - gen_batch_output = self.llm_engine.generate_sequences(batch) - responses = gen_batch_output.batch['responses'] # BS x L - response_texts = self.tokenizer.batch_decode(responses, skip_special_tokens=True) # List of string with length BS - response_texts = response_texts[:len(prompts)*n] - return response_texts - - def generate_async(self, messages_list: str, **kwargs) -> str: - raise NotImplementedError("Verl backend does not support async generation") - - class AsyncVerlBackend(LLMBackend): """Verl implementation""" @@ -417,12 +375,6 @@ async def generate_async(self, messages_list: str, **kwargs) -> str: gen_batch_output = await self.llm_engine.generate_sequences_async(batch, **generation_config) response_texts = gen_batch_output.batch['responses'].tolist() # np.array of strings with length BS - # print(f"[AsyncVerlbackend] response_texts: {response_texts.shape} {type(response_texts)}") - # print(f"[AsyncVerlBackend] response_texts: {response_texts[0]}") - # raise NotImplementedError("Async Verl backend does not support sync generation") - # response_texts = [text for text in response_texts] - # response_texts = self.tokenizer.batch_decode(responses, skip_special_tokens=True) # List of string with length BS - # response_texts = responses[:len(prompts)*n] return response_texts @@ -465,24 +417,39 @@ def __init__( # --------------------------------------------------------------------- # # Low‑level single request (runs in threadpool so it doesn't block loop) # --------------------------------------------------------------------- # - @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15)) - def _blocking_call(self, messages: List[List[Dict]], **kw) -> str: - if "num_return_sequences" in kw: - n = kw.pop("num_return_sequences") + @retry(stop=stop_after_attempt(1), wait=wait_exponential(multiplier=1, min=4, max=15)) + def _blocking_call(self, messages: List[List[Dict]], **kwargs) -> str: + if "num_return_sequences" in kwargs: + n = kwargs.pop("num_return_sequences") else: n = 1 + + if "tool_choice" in kwargs: + tool_choice = kwargs.pop("tool_choice") + else: + tool_choice = "none" + + print(f"[ClientBackend] messages: {messages}") resp = self.client.chat.completions.create( model=self.model_name, messages=messages, timeout=self.timeout, max_tokens=self.max_new_tokens, n=n, - tool_choice="none", - **kw, + tool_choice=tool_choice, + **kwargs, ) - response_texts = [choice.message.content for choice in resp.choices] + resp_json = resp.dict() + response_texts = [choice["message"]["content"] for choice in resp_json["choices"]] + tool_calls = [choice["message"]["tool_calls"] for 
choice in resp_json["choices"]] - return response_texts + if tool_choice == "none": + return response_texts + else: + return { + "response_texts": response_texts, + "tool_calls": tool_calls, + } async def _call(self, messages: List[List[Dict]], **kw) -> str: # acquire a rate‑limit token @@ -495,7 +462,7 @@ async def _call(self, messages: List[List[Dict]], **kw) -> str: def async_generate( self, messages: List[List[Dict]] | List[Dict], - **kw, + **kwargs, ) -> List[str] | asyncio.Task: """ • Pass a *list of messages* → single completion. @@ -510,14 +477,24 @@ def async_generate( messages_list = [messages] # single else: messages_list = messages # batch - + print(f"[ClientBackend] messages_list: {messages_list}") messages_list = [convert_messages_to_openai_format(messages) for messages in messages_list] async def _runner(): - tasks = [asyncio.create_task(self._call(_input, **kw)) for _input in messages_list] - texts_list = await asyncio.gather(*tasks) - response_texts = [text for texts in texts_list for text in texts] - return response_texts + tasks = [asyncio.create_task(self._call(_input, **kwargs)) for _input in messages_list] + # Flatten the response list + response_texts_list_or_dict = await asyncio.gather(*tasks) + # return is a dict if tool_choice is not none, otherwise a list of strings + if isinstance(response_texts_list_or_dict[0], dict): + response_texts = [text for response in response_texts_list_or_dict for text in response["response_texts"]] + tool_calls = [tool_call for response in response_texts_list_or_dict for tool_call in response["tool_calls"]] + return { + "response_texts": response_texts, + "tool_calls": tool_calls, + } + else: + response_texts = [text for response in response_texts_list_or_dict for text in response] + return response_texts try: loop = asyncio.get_running_loop() # ➊ already inside a loop? 
@@ -532,8 +509,8 @@ async def _runner(): async def generate_async(self, messages: List[List[Dict]] | List[Dict], - **kw) -> List[str]: - return await self.async_generate(messages, **kw) + **kwargs) -> List[str]: + return await self.async_generate(messages, **kwargs) # Background token‑bucket refill (one token each 60/max_rpm seconds) async def _refill_tokens(self): diff --git a/agents/agents/agents/specialized/openai_agent.py b/agents/agents/agents/specialized/openai_agent.py index f64a056..ae0f37a 100644 --- a/agents/agents/agents/specialized/openai_agent.py +++ b/agents/agents/agents/specialized/openai_agent.py @@ -1,10 +1,13 @@ import os -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple, Union from openai import OpenAI, AzureOpenAI import httpx - +import asyncio +from ...tools import answer_qa from ...tools.tool_base import tool from ..agent_base import BaseAgent +from ..llm_backend import ClientBackend +from ..backend_config import ClientConfig from tenacity import retry, wait_random_exponential, stop_after_attempt from termcolor import colored import json @@ -14,102 +17,113 @@ class OpenAIAgent(BaseAgent): def __init__( self, api_key="", + base_url="https://api.openai.com/v1", **kwargs ): - wrapper = kwargs.get("wrapper", False) - if not wrapper: - kwargs["wrapper"] = True - template = kwargs.get("template", "openai") - kwargs["template"] = template + assert api_key is not None and api_key != "", "API key is required" + backend = kwargs.get("backend", "client") + assert backend == "client", "OpenAI agent only supports client backend" + kwargs["backend"] = backend + + # Create client-specific configuration + client_config = ClientConfig( + api_key=api_key, + base_url=base_url, + max_requests_per_minute=kwargs.get("max_requests_per_minute", 100), + timeout=kwargs.get("timeout", 600), + max_new_tokens=kwargs.get("max_new_tokens", 1024), + temperature=kwargs.get("temperature", 1.0) + ) + kwargs["backend_config"] = client_config + + # Initialize the base class super(OpenAIAgent, self).__init__(**kwargs) + model_name_or_path = kwargs.get("model_name_or_path", "gpt-3.5-turbo") self.client = OpenAI(api_key=api_key) self.api_key = api_key self.model = model_name_or_path + + # For OpenAI models, we don't need a tokenizer for the LLM engine, but we still need one for trajectory processing + if self.backend == "client": + self.tokenizer = None + self.processor = None - def parse(self, messages_list: List[List[Dict]], tools: List[Any], **args): - # OpenAI use 'n' to specify the number of return sequences - num_return_sequences = args.get("num_return_sequences", 1) - tool_schemas = [tool.schema for tool in tools] - del args["num_return_sequences"] - args["n"] = num_return_sequences - - def process_message(message_data): - message, api_key, model, tool_schemas, args = message_data - client = OpenAI(api_key=api_key) - try: - json_data = { - "model": model, - "messages": message, - **args - } - if tool_schemas is not None: - json_data.update({"tools": tool_schemas}) - openai_response = client.chat.completions.create(**json_data) - result = openai_response.dict() - new_message = result['choices'][0]['message'] - new_message["loss"] = True - return new_message - except Exception as e: - print(f"Parsing Exception: {repr(e)}. 
Try again.") - return { - "role": "assistant", - "content": result, - "tool_calls": [], - "loss": True - } + async def generate_async(self, messages_list_or_inputs: List[List[Dict]], **args): + responses = await super().generate_async(messages_list_or_inputs, tool_choice="auto", **args) + return responses - # Prepare data for each message - message_data_list = [ - (message, self.api_key, self.model, tool_schemas, args) - for message in messages_list - ] - pool = Pool() - results = pool.map(process_message, message_data_list) - pool.close() - pool.join() - - return results - + def parse(self, responses: Union[Dict[str, List], List[str]], tools: List[Any], **args) -> List[Dict]: + """ + Parse responses into the correct message format. + + Args: + responses: List of response strings from the LLM. + tools: List of tools available to the agent. + **args: Additional arguments for parsing. + + Returns: + List of assistant messages in the correct format. + """ + + new_messages = [] + if isinstance(responses, dict): + for response, tool_calls in zip(responses["response_texts"], responses["tool_calls"]): + new_message = {"role": "assistant"} + + new_message["content"] = response + if len(tool_calls) > 0: + tool_calls = tool_calls[:1] + new_message["tool_calls"] = tool_calls + + new_messages.append(new_message) + elif isinstance(responses, list): + for content in responses: + new_message = {"role": "assistant", "content": content} + new_messages.append(new_message) + else: + raise ValueError(f"Invalid responses type: {type(responses)}") + + return new_messages - @retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3)) - def chat_completion_request( - self, - messages, - tools=None, - tool_choice=None, - model=None, - stop=None, - client=None, - **args - ): - if model is None: - model = self.model - if client is None: - client = self.client + # @retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(1)) + # def chat_completion_request( + # self, + # messages, + # tools=None, + # tool_choice=None, + # model=None, + # stop=None, + # client=None, + # **args + # ): + # if model is None: + # model = self.model + # if client is None: + # client = self.client - json_data = { - "model": model, - "messages": messages, - **args - } - if stop is not None: - json_data.update({"stop": stop}) - if tools is not None: - json_data.update({"tools": tools}) - if tool_choice is not None: - json_data.update({"tool_choice": tool_choice}) + # json_data = { + # "model": model, + # "messages": messages, + # **args + # } + # if stop is not None: + # json_data.update({"stop": stop}) + # if tools is not None: + # json_data.update({"tools": tools}) + # if tool_choice is not None: + # json_data.update({"tool_choice": tool_choice}) - try: - # We use chat completion API - openai_response = client.chat.completions.create(**json_data) - json_data = openai_response.dict() - return json_data - except Exception as e: - print(f"Unable to generate ChatCompletion response: {e}") - raise e + # try: + # # We use chat completion API + # openai_response = client.chat.completions.create(**json_data) + # json_data = openai_response.dict() + # return json_data + # except Exception as e: + # print(f"Unable to generate ChatCompletion response: {e}") + # raise e @tool() @@ -129,17 +143,23 @@ def get_current_weather(location: str, unit: str="fahrenheit"): if __name__ == "__main__": agent = OpenAIAgent( - model_name_or_path="gpt-3.5-turbo", - api_key="", - tools=[get_current_weather] + 
model_name_or_path="gpt-4o-mini", + api_key="OpenAI API Key", + tools=[get_current_weather,answer_qa] ) - messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}] - agent.run( - max_steps=3, + messages = [ + { + "messages": [ + {"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"} + ] + } + ] + asyncio.run(agent.run_async( + max_steps=5, start_messages=messages, num_chains=1 - ) + )) trajectories = agent.trajectories - print(trajectories[0]["messages"]) + print(f"""Trajectory: {trajectories[0]["messages"]}""") diff --git a/agents/agents/agents/templates/constants.py b/agents/agents/agents/templates/constants.py new file mode 100644 index 0000000..796ba97 --- /dev/null +++ b/agents/agents/agents/templates/constants.py @@ -0,0 +1,18 @@ +from enum import Enum, auto + +class ToolPlacement(Enum): + """ + Where to inject the tool catalogue in the rendered prompt. + """ + SYSTEM = auto() # inside the system message + FIRST_USER = auto() # as an extra first-user turn + LAST_USER = auto() # appended to the last user turn + SEPARATE = auto() # its own dedicated turn / role + + +class Role(Enum): + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + ASSISTANT_PREFIX = "assistant_prefix" \ No newline at end of file diff --git a/agents/agents/agents/templates/system_policy.py b/agents/agents/agents/templates/system_policy.py new file mode 100644 index 0000000..8ede734 --- /dev/null +++ b/agents/agents/agents/templates/system_policy.py @@ -0,0 +1,46 @@ +from abc import ABC, abstractmethod +import dataclasses +import datetime +from typing import Callable + +@dataclasses.dataclass +class SystemPolicy: + use_system: bool = True + use_system_without_system_message: bool = True + content_processor: Callable[[str], str] = None + + +class SystemContentProcessor(ABC): + @abstractmethod + def __call__(self, system_message: str) -> str: + raise NotImplementedError + + @abstractmethod + def jinja(self) -> str: + raise NotImplementedError + +class Llama32DateProcessor(SystemContentProcessor): + """ + A system content processor that adds date information to system messages. + + In Python mode, it dynamically computes the current date. + In Jinja mode, it provides a template with placeholders that can be processed. + + Usage in Jinja templates: + - The template includes '__CURRENT_DATE__' placeholder + - Replace '__CURRENT_DATE__' with the actual formatted date during processing + - Format should be 'dd MMM yyyy' (e.g., '15 Dec 2024') + - No external context variables required + """ + def __call__(self, system_message: str) -> str: + return f"Cutting Knowledge Date: December 2023\nToday Date: {datetime.datetime.now().strftime('%d %b %Y')}\n\n{system_message}" + + def jinja(self) -> str: + # For Jinja templates used by external systems (like vLLM), we need a self-contained approach + # Since external systems can't provide context variables, we use a placeholder approach + # The external system should replace __CURRENT_DATE__ with the actual date + return """Cutting Knowledge Date: December 2023 +Today Date: __CURRENT_DATE__ + +{{ system_message }}""" + \ No newline at end of file diff --git a/agents/agents/agents/templates/templates.py b/agents/agents/agents/templates/templates.py index 1b71cfe..7697f87 100644 --- a/agents/agents/agents/templates/templates.py +++ b/agents/agents/agents/templates/templates.py @@ -1,43 +1,44 @@ -""" -This file is a modified version of the original fastchat/conversation.py file. 
- -The original file can be found at: -https://github.com/LM-SYS/fastchat/blob/main/fastchat/conversation.py - -""" from collections import defaultdict -from copy import copy +from copy import copy, deepcopy import dataclasses -from enum import Enum, auto, IntEnum import json -from typing import List, Any, Dict, Union, Tuple +from typing import Callable, List, Any, Dict, Union, Tuple import warnings - +import logging import torch -from .preprocess import open_image_from_any from transformers import PreTrainedTokenizer +from .preprocess import open_image_from_any +from .vision_processor import is_vision_template import re +from typing import Protocol +from .tool_policy import ( + ToolFormatter, + JsonMinifiedFormatter, + JsonCompactFormatter, + JsonIndentedFormatter, + ToolMainContentProcessor, + JsonQwenFormatter, +) +from datetime import datetime +from .constants import Role +from .system_policy import Llama32DateProcessor, SystemPolicy +from .tool_policy import ToolPolicy +from .constants import ToolPlacement, Role + +Logger = logging.getLogger(__name__) + +# Add console handler if no handlers exist +if not Logger.handlers: + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + console_handler.setFormatter(formatter) + Logger.addHandler(console_handler) -class Role(Enum): - SYSTEM = "system" - USER = "user" - ASSISTANT = "assistant" - TOOL = "tool" - ASSISTANT_PREFIX = "assistant_prefix" - - -class SeparatorStyle(IntEnum): - """Separator styles.""" - NO_COLON_SINGLE = auto() - LLAMA2 = auto() - LLAMA3 = auto() - CHATGLM = auto() - CHATML = auto() - ADD_SPACE_TWO = auto() - GENERAL = auto() # No sep, we take all sep as part of the role - GENERAL_STOP_ALL = auto() # No sep, and add stop_str at each turn - +@dataclasses.dataclass +class GlobalPolicy: + prefix: str = None @dataclasses.dataclass @@ -59,8 +60,15 @@ class Template: tool_template: str = None # The user template user_template: str = None + user_template_with_tools: str = None # The assistant template assistant_template: str = None + # Global policy + global_policy: "GlobalPolicy" = None + # System message policy + system_policy: "SystemPolicy" = None + # Tool policy for this template + tool_policy: "ToolPolicy" = None ## vision part vision_start: str = None @@ -68,63 +76,199 @@ class Template: image_token: str = None video_token: str = None + chat_template: str = None + + def __post_init__(self): + """Post-initialization to automatically register vision processor if vision tokens are defined""" + if self.image_token or self.video_token: + self._register_vision_processor() + # Initialise default tool policy if none was provided + if self.tool_policy is None: + self.tool_policy = ToolPolicy() + if self.system_policy is None: + self.system_policy = SystemPolicy() + + def _register_vision_processor(self): + """Automatically register a vision processor for this template""" + from .vision_processor import VisionProcessorConfig, register_processor + + # Determine model type based on template name + model_type = self._infer_model_type() + + # Create vision config + config = VisionProcessorConfig( + model_type=model_type, + image_token=self.image_token or "", + video_token=self.video_token or "", + vision_start=self.vision_start or "", + vision_end=self.vision_end or "", + processor_class="AutoProcessor", + expansion_strategy="patch_based" + ) + + # Register the processor + register_processor(self.name, config) + + def 
_infer_model_type(self) -> str: + """Infer model type from template name""" + name_lower = self.name.lower() + + if "qwen" in name_lower: + return "qwen_vl" + elif "llava" in name_lower: + return "llava" + elif "gemma" in name_lower: + return "gemma3" + elif "paligemma" in name_lower: + return "paligemma" + elif "internvl" in name_lower: + return "internvl" + elif "minicpm" in name_lower: + return "minicpm" + elif "mllama" in name_lower: + return "mllama" + elif "pixtral" in name_lower: + return "pixtral" + elif "video" in name_lower: + return "video_llava" + else: + # Default to patch-based for unknown models + return "patch_based" + + def _supports_tool_call(self) -> bool: + if (self.system_template_with_tools or self.user_template_with_tools) and self.tool_template: + return True + else: + return False + def render(self, messages: List[Dict], tools=None, add_generation_prompt: bool = False) -> str: - """Render the template with the given messages and kwargs. - messages: [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "Hello, how are you?" - } - ] - }, - { - "role": "assistant", - ] + """Render the template and return + 1. the final prompt string, + 2. the list of string *elements* that compose the prompt, and + 3. the corresponding list of *roles* (used by downstream post-processing). + + The heavy lifting is delegated to small, single-purpose helpers so the + high-level flow is immediately apparent: + + 1. _insert_tools – decide where the tool catalogue lives + 2. _encode_turns – encode every conversation turn + 3. _maybe_add_generation_prompt – append the generation prefix if requested """ - elements = [] - roles = [] - if tools: - tools = self._encode_system_tools(tools) - for i, message in enumerate(messages): + # Step 1 – decide tool placement & clone messages + work_messages, tools_str, insert_tools_idx = self._insert_tools(messages, tools) + + # Step 2 – encode each conversation turn to text tokens + elements, roles = self._encode_turns(work_messages, tools_str, insert_tools_idx) + + # Step 3 – append generation prefix if needed + if add_generation_prompt: + self._maybe_add_generation_prompt(elements, roles) + + # Concatenate the prompt + prompt = "".join(elements) + return prompt, elements, roles - if i == 0 and self._detect_role(message["role"]) == Role.SYSTEM: - system_message = self._encode_system_message(message["content"], tools=tools) - elements.append(system_message) - roles.append(Role.SYSTEM) - # This message is done + def _insert_tools(self, messages: List[Dict], tools): + """Clone *messages* and compute where (and how) the tool catalogue + should be injected. + + Returns + ------- + work_messages : List[Dict] + A deepcopy of the original *messages* so we never mutate caller data. + tools_str : Optional[str] + The formatted tool catalogue or *None* if `tools` is falsy. + insert_tools_idx : int + Index of the *user* message that receives the catalogue, or -1 when + no injection is required. 
+ """ + + work_messages = deepcopy(messages) + if tools: + tools_str = self.tool_policy.format_tools(tools) + placement = self.tool_policy.placement + insert_tools_idx = self._find_insert_tools_index(work_messages, placement) + else: + tools_str = None + insert_tools_idx = -1 + return work_messages, tools_str, insert_tools_idx + + def _encode_turns( + self, + work_messages: List[Dict], + tools_str: str, + insert_tools_idx: int, + ) -> Tuple[List[str], List[Role]]: + """Convert every message dict into its textual representation while + tracking roles for later masking logic.""" + + elements: List[str] = [] + roles: List[Role] = [] + + # Global prefix comes first (rarely used but must respect ordering) + if self.global_policy and self.global_policy.prefix: + elements.append(self.global_policy.prefix) + roles.append(Role.SYSTEM) + + for i, message in enumerate(work_messages): + current_role = self._detect_role(message["role"]) + + # -------------------------------------------------------------- + # Handle system message insertion on the very first turn + # -------------------------------------------------------------- + if i == 0 and current_role == Role.SYSTEM: + if self.system_policy.use_system: + system_message = self._encode_system_message( + message["content"], tools=tools_str + ) + elements.append(system_message) + roles.append(Role.SYSTEM) + # Whether inserted or not, we skip further handling of this + # message because it's the (optional) system turn itself. continue - elif i == 0 and self._detect_role(message["role"]) != Role.SYSTEM: - system_message = self._encode_system_message_default(tools=tools) - elements.append(system_message) - roles.append(Role.SYSTEM) - # This message is not done, we need to handle other roles - - if self._detect_role(message["role"]) == Role.USER: - user_message = self._encode_user_message(message["content"]) + elif i == 0 and current_role != Role.SYSTEM: + if self.system_policy.use_system: + system_message = self._encode_system_message_default(tools=tools_str) + elements.append(system_message) + roles.append(Role.SYSTEM) + # Do *not* `continue` – we still need to encode this first message. 
+ + # -------------------------------------------------------------- + # Encode regular conversation turns + # -------------------------------------------------------------- + if current_role == Role.USER: + if i == insert_tools_idx: + user_message = self._encode_user_message_with_tools( + message["content"], tools=tools_str + ) + else: + user_message = self._encode_user_message(message["content"]) elements.append(user_message) roles.append(Role.USER) - elif self._detect_role(message["role"]) == Role.ASSISTANT: + + elif current_role == Role.ASSISTANT: assistant_message = self._encode_assistant_message(message["content"]) elements.append(assistant_message) roles.append(Role.ASSISTANT) - elif self._detect_role(message["role"]) == Role.TOOL: + + elif current_role == Role.TOOL: tool_message = self._encode_tool_message(message["content"]) elements.append(tool_message) roles.append(Role.TOOL) + else: raise ValueError(f"Invalid role: {message['role']}") - - if add_generation_prompt: - generation_prefix = self._encode_generation_prompt() - elements.append(generation_prefix) - roles.append(Role.ASSISTANT_PREFIX) - - prompt = "".join(elements) - return prompt, elements, roles + + return elements, roles + + def _maybe_add_generation_prompt(self, elements: List[str], roles: List[Role]): + """Append the generation prefix so the model knows to continue + generating an assistant response.""" + + generation_prefix = self._encode_generation_prompt() + elements.append(generation_prefix) + roles.append(Role.ASSISTANT_PREFIX) def _detect_role(self, role: str) -> Role: if role == "system": @@ -137,53 +281,114 @@ def _detect_role(self, role: str) -> Role: return Role.TOOL else: raise ValueError(f"Invalid role: {role}") + + def _find_insert_tools_index(self, work_messages: List[Dict], placement: ToolPlacement) -> int: + insert_tools_idx = 0 # Default to insert tools at system message + for i, message in enumerate(work_messages): + if placement == ToolPlacement.SYSTEM: + insert_tools_idx = 0 + elif placement == ToolPlacement.FIRST_USER: + if message.get("role") == "user": + insert_tools_idx = i + break + elif placement == ToolPlacement.LAST_USER: + if message.get("role") == "user": + insert_tools_idx = i + else: + raise ValueError(f"Unhandled ToolPlacement: {placement}") + return insert_tools_idx def _encode_system_tools(self, tools: List[Dict]) -> str: return "\n".join([json.dumps(tool) for tool in tools]) def _encode_system_message_default(self, tools=None) -> str: + if not self.system_policy.use_system_without_system_message: + return "" + + if self.system_policy.content_processor is not None: + system_message = self.system_policy.content_processor(self.system_message) + else: + system_message = self.system_message + if tools is None: - return self.system_template.format(system_message=self.system_message) + return self.system_template.format(system_message=system_message) else: if self.system_template_with_tools: - return self.system_template_with_tools.format(system_message=self.system_message, tools=tools) + return self.system_template_with_tools.format(system_message=system_message, tools=tools) else: - return self.system_template.format(system_message=self.system_message) + return self.system_template.format(system_message=system_message) def _encode_system_message(self, content, tools=None) -> str: - if tools is None: + # Handle both string content and list content formats + if isinstance(content, str): + system_message = content + else: system_message = content[0]['text'] + + if 
self.system_policy.content_processor is not None: + system_message = self.system_policy.content_processor(system_message) + + if tools is None: return self.system_template.format(system_message=system_message) else: - system_message = content[0]['text'] if self.system_template_with_tools is None: return self.system_template.format(system_message=system_message) else: return self.system_template_with_tools.format(system_message=system_message, tools=tools) - - def _encode_user_message(self, content: List[Dict]) -> str: - text = "" - for item in content: - if item["type"] == "text": - text += item["text"] - elif item["type"] == "image": - text += self.vision_start + self.image_token + self.vision_end - elif item["type"] == "video": - text += self.vision_start + self.video_token + self.vision_end - else: - raise ValueError(f"Invalid message type: {item['type']}") + + def _encode_user_message_with_tools(self, content, tools: str) -> str: + # Handle both string content and list content formats + if isinstance(content, str): + text = content + else: + text = "" + for item in content: + if item["type"] == "text": + text += item["text"] + else: + raise ValueError(f"Invalid message type: {item['type']}") + + if self.user_template_with_tools: + user_message = self.user_template_with_tools.format(content=text, tools=tools) + else: + user_message = self.user_template.format(content=text) + return user_message + + def _encode_user_message(self, content) -> str: + # Handle both string content and list content formats + if isinstance(content, str): + text = content + else: + text = "" + for item in content: + if item["type"] == "text": + text += item["text"] + elif item["type"] in ["image", "image_url"]: + text += self.vision_start + self.image_token + self.vision_end + elif item["type"] == "video": + text += self.vision_start + self.video_token + self.vision_end + else: + raise ValueError(f"Invalid message type: {item['type']}") user_message = self.user_template.format(content=text) return user_message - def _encode_assistant_message(self, content: List[Dict]) -> str: - assert len(content) == 1, "Assistant message must be a single message" - text = content[0]["text"] + def _encode_assistant_message(self, content) -> str: + # Handle both string content and list content formats + if isinstance(content, str): + text = content + else: + assert len(content) == 1, "Assistant message must be a single message" + text = content[0]["text"] assistant_message = self.assistant_template.format(content=text) return assistant_message - def _encode_tool_message(self, content: List[Dict]) -> str: - assert len(content) == 1, "Tool message must be a single message" - text = content[0]["text"] + def _encode_tool_message(self, content) -> str: + # Handle both string content and list content formats + if isinstance(content, str): + text = content + else: + assert len(content) == 1, "Tool message must be a single message" + text = content[0]["text"] tool_message = self.tool_template.format(observation=text) return tool_message @@ -208,19 +413,36 @@ def _split_assistant_message(self, assistant_message: str) -> List[str]: return generation_prefix, content, suffix - def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None) -> str: - prompt, elements, roles = self.render(messages, tools=tools) + def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, processor=None) -> str: + if processor is None 
and self.supports_vision(): + raise ValueError(f"Processor is required for vision templates: {self.name}") + + if self.supports_vision(): + # Use vision-aware encoding with proper alignment + return self._encode_with_vision_processor(messages, tokenizer, return_tensors, tools, add_generation_prompt=add_generation_prompt, processor=processor) + else: + # Use standard encoding + return self._encode_standard(messages, tokenizer, return_tensors, tools, add_generation_prompt=add_generation_prompt) + + def _encode_standard(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False) -> str: + Logger.debug(f"[Template] Encoding standard for template: {self.name}") + """Standard encoding without vision support""" + prompt, elements, roles = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt) elements, mask_flags = self._postprocess_elements(elements, roles) input_ids = [] attention_mask = [] labels = [] action_mask = [] - if tokenizer.bos_token and tokenizer.add_bos_token: - input_ids.append(tokenizer.bos_token_id) - attention_mask.append(1) - labels.append(-100) - action_mask.append(0) + + if tokenizer.bos_token: + # If add_bos_token is not set, we assume to add bos token + # There is potential issue if the tokenizer has bos_token but do not add it by default + if getattr(tokenizer, "add_bos_token", True): + input_ids.append(tokenizer.bos_token_id) + attention_mask.append(1) + labels.append(-100) + action_mask.append(0) for element, mask_flag in zip(elements, mask_flags): cur_input_ids = tokenizer.encode(element, add_special_tokens=False) @@ -241,6 +463,41 @@ def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_te if return_tensors == "pt": inputs = {k: torch.tensor([v]) for k, v in inputs.items()} return inputs + + def _encode_with_vision_processor(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, processor=None) -> str: + Logger.debug(f"[Template] Encoding with vision processor for template: {self.name}") + """Encode with vision processor handling proper alignment""" + from .vision_processor import get_processor + from .utils import extract_vision_inputs_from_messages + + # Get vision processor + vision_processor = get_processor(self.name) + if vision_processor is None: + raise ValueError(f"No vision processor registered for template: {self.name}") + + # Get base prompt and mask information + prompt, elements, roles = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt) + elements, mask_flags = self._postprocess_elements(elements, roles) + + # Extract vision inputs + images, videos = extract_vision_inputs_from_messages(messages) + + Logger.debug(f"[Template] images: {len(images)}") + Logger.debug(f"[Template] videos: {len(videos)}") + + Logger.debug(f"[Template] messages: {messages}") + + # Use vision processor with alignment support + return vision_processor.process_for_llm( + prompt=prompt, + elements=elements, + mask_flags=mask_flags, + images=images, + videos=videos, + processor=processor, + tokenizer=tokenizer, + return_tensors=return_tensors + ) def _postprocess_elements(self, elements: List[str], roles) -> List[str]: @@ -302,67 +559,112 @@ def _postprocess_elements(self, elements: List[str], roles) -> List[str]: merged_mask_flags.append(prev_mask_flag) return merged_elements, merged_mask_flags + def supports_vision(self) -> bool: + """Check if this template 
supports vision processing""" + return is_vision_template(self.name) - def get_vision_inputs(self): + def get_vision_inputs(self, messages: List[Dict]): vision_inputs = defaultdict(list) - for role, message, _ in self.messages: - if isinstance(message, list): - for item in message: + Logger.debug(f"[Template] get_vision_inputs: messages: {messages}") + for message in messages: + content = message["content"] + if isinstance(content, list): + for item in content: if item['type'] == 'text': continue - elif item['type'] == 'image': - vision_inputs['image'].append(open_image_from_any(item['image'])) + elif item['type'] in ['image', 'image_url', 'image_base64']: + vision_inputs["image"].append(open_image_from_any(item[item['type']])) elif item['type'] == 'video': raise NotImplementedError("Video is not supported for chat template.") else: raise ValueError(f"Invalid message type: {item['type']}") + else: + raise ValueError(f"Invalid message content: {content}, the content should be a list of dicts") return vision_inputs - # if self.name == "qwen2.5-vl": - # jinja_template = """{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['role'] == 'tool'%}\n{% endif %}{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% if message['role'] == 'tool' %}\n\n{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}""" - def jinja_template(self) -> str: - """ - Build a Hugging-Face-style chat-template (Jinja-mini dialect) that mimics - `self.render`. The template expects three variables in its context: + if self.chat_template: + return self.chat_template + else: + return self.render_jinja_template() - • messages – list[dict] (same format you pass to .render) - • add_generation_prompt – bool (default False) - • tools – list[dict] (optional, for tool-enabled templates) + def render_jinja_template(self) -> str: + """Return a Hugging-Face style chat-template (Jinja-mini dialect). - No other Python state is referenced, so the string can be cached in the - tokenizer and shipped to a different process. + The implementation now mirrors the three-step structure of + `render()` for easier maintenance: + + 1. _jinja_header_constants – immutable `set` statements + 2. _jinja_system_block – first turn / system handling + 3. _jinja_loop_messages – remaining turns & per-role logic + 4. _jinja_generation_block – optional generation prefix """ - # ------------------------------------------------------------------ - # 1. 
Pre-compute constant strings so the inner template stays tiny - # ------------------------------------------------------------------ - default_system = self.system_template.format(system_message=self.system_message) + + parts: List[str] = [] + + # 1. Constant header (always first) + parts.extend(self._jinja_header_constants()) + + # 2. System-message handling (depends on presence of tools etc.) + parts.extend(self._jinja_system_block()) + + # 2.5 Pre-compute insert index for user placement + parts.extend(self._jinja_compute_insert_idx()) + + # 3. Loop over remaining messages + parts.extend(self._jinja_loop_messages()) + + # 4. Generation prefix block + parts.extend(self._jinja_generation_block()) + + template_str = "".join(parts) - # Don't pre-format system_template_with_tools - handle it in Jinja - system_template_with_tools_raw = self.system_template_with_tools if self.system_template_with_tools else None + # Post-process: Replace __CURRENT_DATE__ placeholder with actual date + if "__CURRENT_DATE__" in template_str: + from datetime import datetime + current_date = datetime.now().strftime('%d %b %Y') + template_str = template_str.replace("__CURRENT_DATE__", current_date) + + return template_str + + # ------------------------------------------------------------------ + # Private helpers – keep them together for readability + # ------------------------------------------------------------------ + + def _jinja_header_constants(self) -> List[str]: + """Return Jinja `set` statements for all constant strings.""" + # Compute default system message considering content processor + if self.system_policy.content_processor is not None: + # Apply content processor to system message + processed_system_message = self.system_policy.content_processor(self.system_message) + default_system = self.system_template.format(system_message=processed_system_message) + else: + default_system = self.system_template.format(system_message=self.system_message) + + system_template_with_tools_raw = ( + self.system_template_with_tools if self.system_template_with_tools else None + ) + + # Split templates try: u_pref, u_suff = self.user_template.split("{content}") a_pref, a_suff = self.assistant_template.split("{content}") - except ValueError as e: # missing {content} - raise ValueError("`user_template` / `assistant_template` must contain " - "`{content}` placeholder") from e + except ValueError as exc: + raise ValueError( + "`user_template` / `assistant_template` must contain `{content}` placeholder" + ) from exc if self.tool_template: t_pref, t_suff = self.tool_template.split("{observation}") - else: # tools optional + else: t_pref, t_suff = "", "" - # tokens for images / videos + # Tokens for images / videos img_tok = (self.vision_start or "") + (self.image_token or "") + (self.vision_end or "") vid_tok = (self.vision_start or "") + (self.video_token or "") + (self.vision_end or "") - # ------------------------------------------------------------------ - # 2. Assemble the Jinja text; everything in triple-quotes is copied - # verbatim into the tokenizer. We splice in the constants that - # never change for this Template instance. 
- # ------------------------------------------------------------------ - template_parts = [ + header = [ f"{{% set _u_pref = {u_pref!r} %}}", f"{{% set _u_suff = {u_suff!r} %}}", f"{{% set _a_pref = {a_pref!r} %}}", @@ -374,39 +676,146 @@ def jinja_template(self) -> str: f"{{% set _default_system = {default_system!r} %}}", f"{{% set _system_message = {self.system_message!r} %}}", f"{{% set _system_template = {self.system_template!r} %}}", + f"{{% set _tool_placement = {self.tool_policy.placement.name!r} %}}", ] - - # Add system template with tools if available + if system_template_with_tools_raw: - template_parts.append(f"{{% set _system_template_with_tools = {system_template_with_tools_raw!r} %}}") - - template_parts.extend([ + header.append( + f"{{% set _system_template_with_tools = {system_template_with_tools_raw!r} %}}" + ) + + # Add user template with tools if it exists + if self.user_template_with_tools: + # Convert double braces to single braces for Jinja compatibility + processed_template = self.user_template_with_tools.replace('{{', '{').replace('}}', '}') + header.append( + f"{{% set _u_template_with_tools = {processed_template!r} %}}" + ) + + # ------------------------------------------------------------------ + # Formatter macro for tools (only if the template supports tool calls) + # ------------------------------------------------------------------ + + if self._supports_tool_call(): + # Build a Jinja macro that reproduces ToolPolicy.format_tools behaviour + formatter_snippet = self.tool_policy.formatter.jinja() + + # The snippet usually comes wrapped in "{{ ... }}". We drop the + # outer braces because macro bodies are already an output context. + formatter_body = formatter_snippet.strip() + + header.extend( + [ + "{% macro _fmt_tools(tools) %}", + f"{formatter_body}", + "{% endmacro %}", + ] + ) + + # ------------------------------------------------------------------ + # System processor macro (if system policy has a content processor) + # ------------------------------------------------------------------ + + if self.system_policy.content_processor is not None: + # Build a Jinja macro that reproduces the system content processor behaviour + processor_snippet = self.system_policy.content_processor.jinja() + + # The snippet should be a template that expects 'system_message' variable + # We create a macro that can be called with the system message + header.extend( + [ + "{% macro _process_system_message(system_message) %}", + f"{processor_snippet}", + "{% endmacro %}", + ] + ) + + return header + + def _jinja_compute_insert_idx(self) -> List[str]: + """Return Jinja code that pre-computes the index where tools should + be injected for FIRST_USER and LAST_USER placements.""" + + return [ + "{% set _insert_ns = namespace(idx=-1) %}", + "{% if _tool_placement in ['FIRST_USER', 'LAST_USER'] %}", + "{%- for _m in messages -%}", + "{%- if _m['role'] == 'user' -%}", + "{%- if _tool_placement == 'FIRST_USER' and _insert_ns.idx == -1 -%}", + "{% set _insert_ns.idx = loop.index0 %}", + "{%- elif _tool_placement == 'LAST_USER' -%}", + "{% set _insert_ns.idx = loop.index0 %}", + "{%- endif -%}", + "{%- endif -%}", + "{%- endfor -%}", + "{% endif %}", + ] + + def _jinja_system_block(self) -> List[str]: + """Return Jinja code that handles the system message logic.""" + + return [ # Handle system message first (matching render logic) "{% if messages and messages[0]['role'] == 'system' %}", "{% if tools and _system_template_with_tools %}", "{% if messages[0]['content'] is string %}", - "{{ 
_system_template_with_tools.format(system_message=messages[0]['content'], tools=tools | map('tojson') | join('\\n')) }}", + "{% if _process_system_message is defined %}", + "{{ _system_template_with_tools.format(system_message=_process_system_message(messages[0]['content']), tools=_fmt_tools(tools)) }}", + "{% else %}", + "{{ _system_template_with_tools.format(system_message=messages[0]['content'], tools=_fmt_tools(tools)) }}", + "{% endif %}", "{% else %}", - "{{ _system_template_with_tools.format(system_message=messages[0]['content'][0]['text'], tools=tools | map('tojson') | join('\\n')) }}", + "{% if _process_system_message is defined %}", + "{{ _system_template_with_tools.format(system_message=_process_system_message(messages[0]['content'][0]['text']), tools=_fmt_tools(tools)) }}", + "{% else %}", + "{{ _system_template_with_tools.format(system_message=messages[0]['content'][0]['text'], tools=_fmt_tools(tools)) }}", + "{% endif %}", "{% endif %}", "{% else %}", "{% if messages[0]['content'] is string %}", + "{% if _process_system_message is defined %}", + "{% set processed_message = _process_system_message(messages[0]['content']) %}", + "{% set formatted_system = _system_template | replace('{system_message}', processed_message) %}{{ formatted_system }}", + "{% else %}", "{% set formatted_system = _system_template | replace('{system_message}', messages[0]['content']) %}{{ formatted_system }}", + "{% endif %}", + "{% else %}", + "{% if _process_system_message is defined %}", + "{% set processed_message = _process_system_message(messages[0]['content'][0]['text']) %}", + "{% set formatted_system = _system_template | replace('{system_message}', processed_message) %}{{ formatted_system }}", "{% else %}", "{% set formatted_system = _system_template | replace('{system_message}', messages[0]['content'][0]['text']) %}{{ formatted_system }}", "{% endif %}", "{% endif %}", + "{% endif %}", "{% else %}", "{% if tools and _system_template_with_tools %}", - "{{ _system_template_with_tools.format(system_message=_system_message, tools=tools | map('tojson') | join('\\n')) }}", + "{% if _process_system_message is defined %}", + "{{ _system_template_with_tools.format(system_message=_process_system_message(_system_message), tools=_fmt_tools(tools)) }}", + "{% else %}", + "{{ _system_template_with_tools.format(system_message=_system_message, tools=_fmt_tools(tools)) }}", + "{% endif %}", + "{% else %}", + "{% if _process_system_message is defined %}", + "{% set processed_message = _process_system_message(_system_message) %}", + "{% set formatted_system = _system_template | replace('{system_message}', processed_message) %}{{ formatted_system }}", "{% else %}", "{{ _default_system }}", "{% endif %}", "{% endif %}", + "{% endif %}", + ] + + def _jinja_loop_messages(self) -> List[str]: + """Return Jinja loop that encodes all messages except the first system.""" + + return [ + "{% set _tool_ns = namespace(inserted=False, user_count=0) %}", # Process remaining messages (skip first if it was system) "{% for m in messages %}", "{% if not (loop.first and m['role'] == 'system') %}", "{% if m['role'] == 'user' %}", + "{% set _tool_ns.user_count = _tool_ns.user_count + 1 %}", "{% set ns = namespace(txt='') %}", "{% if m['content'] is string %}", "{% set ns.txt = m['content'] %}", @@ -421,7 +830,17 @@ def jinja_template(self) -> str: "{% endif %}", "{% endfor %}", "{% endif %}", + "{% if tools and ((_tool_placement == 'FIRST_USER' and _tool_ns.user_count == 1) or (_tool_placement == 'LAST_USER' and loop.index0 == 
_insert_ns.idx)) and not _tool_ns.inserted %}", + "{% if _u_template_with_tools is defined %}", + "{% set formatted_tools = _fmt_tools(tools) %}", + "{{ _u_template_with_tools | replace('{content}', ns.txt) | replace('{tools}', formatted_tools) }}", + "{% else %}", + "{{ _u_pref }}{{ ns.txt }}{{ _u_suff }}\\n{{ _fmt_tools(tools) }}", + "{% endif %}", + "{% set _tool_ns.inserted = True %}", + "{% else %}", "{{ _u_pref }}{{ ns.txt }}{{ _u_suff }}", + "{% endif %}", "{% elif m['role'] == 'assistant' %}", "{% if m['content'] is string %}", "{{ _a_pref }}{{ m['content'] }}{{ _a_suff }}", @@ -437,12 +856,16 @@ def jinja_template(self) -> str: "{% endif %}", "{% endif %}", "{% endfor %}", + ] + + def _jinja_generation_block(self) -> List[str]: + """Return Jinja code that appends the generation prefix when requested.""" + + return [ "{% if add_generation_prompt %}", "{{ _a_pref }}", - "{% endif %}" - ]) - - return "".join(template_parts) + "{% endif %}", + ] def render_with_mask(self, messages: List[Dict], add_generation_prompt: bool = False, tools=None): @@ -463,14 +886,6 @@ def set_system_message(self, system_message: str): """Set the system message.""" self.system_message = system_message - def set_tools(self, tools: List[Dict]): - """Set the tools.""" - if self.tool_aggregator == "DEFAULT": - self.tools = json.dumps(tools) - elif self.tool_aggregator == "STACKED": - self.tools = "\n".join([json.dumps(tool) for tool in tools]) - else: - raise ValueError(f"Invalid tool aggregator: {self.tool_aggregator}") def copy(self): return Template( @@ -479,6 +894,7 @@ def copy(self): system_template_with_tools=self.system_template_with_tools, system_message=self.system_message, user_template=self.user_template, + user_template_with_tools=self.user_template_with_tools, assistant_template=self.assistant_template, tool_template=self.tool_template, stop_words=self.stop_words, @@ -486,6 +902,10 @@ def copy(self): vision_end=self.vision_end, image_token=self.image_token, video_token=self.video_token, + global_policy=deepcopy(self.global_policy), + system_policy=deepcopy(self.system_policy), + tool_policy=deepcopy(self.tool_policy), + chat_template=self.chat_template, ) def dict(self): @@ -554,12 +974,15 @@ def prompt_with_mask(self, add_generation_prompt=False, tools=None) -> str: prompt_with_mask, _, _ = self.template.render_with_mask(messages=self.messages, add_generation_prompt=add_generation_prompt, tools=tools) return prompt_with_mask - def tokenize(self, tokenizer: PreTrainedTokenizer = None, add_generation_prompt=False, tools=None) -> List[int]: + def vision_inputs(self) -> List[Any]: + return self.template.get_vision_inputs(self.messages) + + def tokenize(self, tokenizer: PreTrainedTokenizer = None, add_generation_prompt=False, tools=None, processor=None) -> List[int]: if tokenizer is None: tokenizer = self.tokenizer if tools is None: tools = self.tools - return self.template.encode(messages=self.messages, tokenizer=tokenizer, return_tensors="pt", tools=tools) + return self.template.encode(messages=self.messages, tokenizer=tokenizer, return_tensors="pt", tools=tools, add_generation_prompt=add_generation_prompt, processor=processor) def append(self, message: Union[Dict, List[Dict]]): self._convert_single_message_to_hf_format(message) @@ -594,13 +1017,26 @@ def get_template(name: str) -> Template: assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n", tool_template="<|im_start|>user\n\n{observation}\n<|im_end|>\n", stop_words=["<|im_end|>"], + ) +) + +register_template( + Template( + 
name="qwen2.5-vl", + system_template="<|im_start|>system\n{system_message}<|im_end|>\n", + system_message="You are a helpful assistant.", + user_template="<|im_start|>user\n{content}<|im_end|>\n", + assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n", + tool_template="<|im_start|>tool\n{observation}<|im_end|>\n", vision_start="<|vision_start|>", vision_end="<|vision_end|>", image_token="<|image_pad|>", video_token="<|video_pad|>", + stop_words=["<|im_end|>"], ) ) + register_template( Template( name="qwen2.5", @@ -611,10 +1047,6 @@ def get_template(name: str) -> Template: assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n", tool_template="<|im_start|>user\n\n{observation}\n<|im_end|>\n", stop_words=["<|im_end|>"], - vision_start="<|vision_start|>", - vision_end="<|vision_end|>", - image_token="<|image_pad|>", - video_token="<|video_pad|>", ) ) @@ -648,48 +1080,98 @@ def get_template(name: str) -> Template: ) ) -# register_conv_template( -# Template( -# name="qwen3", -# system_template="<|im_start|>system\n{system_message}<|im_end|>\n", -# # system_message="", -# system_template_tool="""<|im_start|>system\n{system_message}# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{{"name": , "arguments": }}\n<|im_end|>\n""", -# roles=("<|im_start|>user\n", "<|im_start|>assistant\n", "<|im_start|>user\n\n{observation}\n"), -# tool_aggregator="STACKED", -# sep_style=SeparatorStyle.GENERAL_STOP_ALL, -# sep="", -# stop_token_ids=[ -# 151643, -# 151644, -# 151645, -# ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>" -# stop_str="<|im_end|>\n", -# ) -# ) -# register_conv_template( + +# TODO: mistral template has many cornor cases, leave it for now +# register_template( # Template( -# name="qwen2.5-vl", -# system_template="<|im_start|>system\n{system_message}<|im_end|>\n", -# system_message="You are a helpful assistant.", -# roles=("<|im_start|>user\n", "<|im_start|>assistant\n", "<|im_start|>tool\n\n{observation}\n"), -# tool_aggregator="STACKED", -# sep_style=SeparatorStyle.GENERAL_STOP_ALL, -# sep="", -# vision_start="<|vision_start|>", -# vision_end="<|vision_end|>", -# image_token="<|image_pad|>", -# video_token="<|video_pad|>", -# stop_token_ids=[ -# 151643, -# 151644, -# 151645, -# ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>" -# stop_str="<|im_end|>\n", +# name="mistral", +# system_template="{system_message}", +# user_template="[INST] {content}[/INST] ", +# user_template_with_tools="[AVAILABLE TOOLS] {tools} [/AVAILABLE TOOLS] [INST] {content}[/INST] ", +# assistant_template="{content}", +# tool_template="{observation}", +# stop_words=[""], +# system_policy=SystemPolicy( +# use_system=False, +# ), +# tool_policy=ToolPolicy( +# placement=ToolPlacement.LAST_USER, +# formatter=JsonCompactFormatter() +# ) # ) # ) +# TODO: system template includes current date +register_template( + Template( + name="llama-3.2", + system_template="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", + system_template_with_tools="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\n{system_message}<|eot_id|>", + user_template="<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>", + user_template_with_tools="""<|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON 
for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}.Do not use variables.\n\n{tools}\n\n{content}<|eot_id|>""", + assistant_template="<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>", + tool_template="""<|start_header_id|>ipython<|end_header_id|>\n\n"{observation}"<|eot_id|>""", + stop_words=["<|eot_id|>"], + system_policy=SystemPolicy( + use_system=True, + content_processor=Llama32DateProcessor(), + ), + tool_policy=ToolPolicy( + placement=ToolPlacement.FIRST_USER, + formatter=JsonIndentedFormatter() + ) + ) +) + +register_template( + Template( + name="glm-4", + system_template="<|system|>\n{system_message}", + user_template="<|user|>\n{content}", + assistant_template="<|assistant|>\n{content}", + stop_words=[""], + global_policy=GlobalPolicy( + prefix="[gMASK]" + ), + system_policy=SystemPolicy( + use_system=True, + use_system_without_system_message=False, + ), + ) +) + +register_template( + Template( + name="phi-4", + system_template="<|im_start|>system<|im_sep|>{system_message}<|im_end|>", + user_template="<|im_start|>user<|im_sep|>{content}<|im_end|>", + assistant_template="<|im_start|>assistant<|im_sep|>{content}<|im_end|>", + stop_words=["<|im_end|>"], + ) +) +# Note: Partial align, some minor new-line problems. +register_template( + Template( + name="nemotron", + system_template="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{system_message}<|eot_id|>", + system_template_with_tools="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{system_message}{tools}<|eot_id|>""", + user_template="<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>", + assistant_template="<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>", + tool_template="<|start_header_id|>user<|end_header_id|>\n\n[{observation}]<|eot_id|>", + stop_words=["<|eot_id|>"], + system_policy=SystemPolicy( + use_system=True, + content_processor=lambda x: f"\n{x}", + ), + tool_policy=ToolPolicy( + placement=ToolPlacement.SYSTEM, + content_processor=ToolMainContentProcessor(), + formatter=JsonCompactFormatter(), + ) + ) +) if __name__ == "__main__": - pass + pass \ No newline at end of file diff --git a/agents/agents/agents/templates/tool_policy.py b/agents/agents/agents/templates/tool_policy.py new file mode 100644 index 0000000..377c098 --- /dev/null +++ b/agents/agents/agents/templates/tool_policy.py @@ -0,0 +1,237 @@ +from typing import List, Dict, Tuple +import json +from typing import Callable, Any # Added for content processor typing +from abc import ABC, abstractmethod +import dataclasses + +from .constants import ToolPlacement + +# Convert ToolFormatter into an abstract base class +class ToolFormatter(ABC): + """ + Strategy that converts an in-memory list[dict] describing tools + into the textual representation expected by the target model. + """ + @abstractmethod + def format(self, tools: List[Dict]) -> str: + """Format a list of tool dictionaries into a string representation.""" + raise NotImplementedError + + @abstractmethod + def jinja(self) -> str: + """Return a Jinja template that can be used to format the tools.""" + raise NotImplementedError + + +class ToolContentProcessor(ABC): + """ + Strategy that processes the content of a tool before it is serialized. 
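    A minimal illustration (matching ToolMainContentProcessor below): an OpenAI-style
    entry such as
        {"type": "function", "function": {"name": "search", "parameters": {...}}}
    would be reduced to the bare definition
        {"name": "search", "parameters": {...}}
    before the ToolFormatter serialises it.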
+ """ + @abstractmethod + def __call__(self, tool: Dict) -> Dict: + raise NotImplementedError + + @abstractmethod + def jinja(self) -> str: + """Return a Jinja template that can be used to process the content of a tool.""" + raise NotImplementedError + + +class ToolMainContentProcessor(ToolContentProcessor): + """ + Strategy that processes the main content of a tool before it is serialized. + """ + def __call__(self, tool: Dict) -> Dict: + assert isinstance(tool, dict), "Tool must be a dictionary" + if "function" in tool: + content = tool["function"] + assert "name" in content, "Tool function must have a name" + assert "parameters" in content, "Tool function must have parameters" + return content + elif "name" in tool and "parameters" in tool: + return tool + else: + raise ValueError(f"Tool must have a function or name and parameters: {tool}") + + # The main-content extraction cannot be replicated in pure Jinja, so we + # fall back to the identity behaviour at template-generation time. This + # means the processor is *ignored* in frozen chat-templates; users who + # require it must rely on the Python render path. + + def jinja(self) -> str: + # We deliberately document the limitation by returning a simple pass- + # through expression. + return "{{ tool }}" + +# Make JsonFormatter inherit the ToolFormatter base class +class JsonFormatter(ToolFormatter): + """General JSON formatter with configurable indent, separators, and joiner.""" + + def __init__( + self, + *, + indent: int | None = None, + separators: Tuple[str, str] | None = None, + joiner: str = "\n", + format_as_list: bool = False, + content_processor: ToolContentProcessor = None, + ): + """Create a new JsonFormatter. + + Args: + indent: Indentation level passed to ``json.dumps``. ``None`` means no pretty-print. + separators: Custom separators passed to ``json.dumps``; useful for minification. + joiner: String used to join per-tool JSON strings when ``format_as_list`` is *False*. + format_as_list: If *True*, the entire ``tools`` list is serialised in a single + ``json.dumps`` call, ignoring ``joiner``. This is handy when the target + model expects a single JSON array instead of multiple individual objects. + content_processor: Optional callable applied to each individual tool dictionary + before serialisation. Defaults to the identity function. + """ + self.indent = indent + self.separators = separators + self.joiner = joiner + self.format_as_list = format_as_list + + def format(self, tools: List[Dict]) -> str: # noqa: D401 + """Return a single string obtained by dumping every tool to JSON then joining them. + + Args: + tools: A list of tool dictionaries to be stringified. + + Returns: + A string representation of the tools, formatted according to the + given ``indent``/``separators`` and concatenated with ``joiner``. + """ + # Apply the per-tool content processor first + + if self.format_as_list: + # Serialize the whole list in one go – joiner is irrelevant in this mode. + return json.dumps(tools, indent=self.indent, separators=self.separators) + + # Default behaviour: dump each tool individually then concatenate. 
+ return self.joiner.join( + json.dumps(t, indent=self.indent, separators=self.separators) for t in tools + ) + + # ------------------------------------------------------------------ + # Jinja support + # ------------------------------------------------------------------ + + def _escape_joiner(self, joiner: str) -> str: # local helper + """Return *joiner* escaped so it is safe inside a single‐quoted Jinja + string literal (the HF chat-template parser understands the Python + backslash escapes).""" + + return joiner.replace("\\", "\\\\").replace("'", "\\'") + + def jinja(self) -> str: # noqa: D401 + """Return a **Jinja-mini** snippet that serialises the *tools* variable + with the same settings as :py:meth:`format`. + + The template assumes that a ``tools`` list is present in the Jinja + context. Because the Hugging-Face chat-template dialect only supports + a limited subset of Jinja, we restrict ourselves to `map`, `tojson`, + `join`, and optional indent on a *single* tojson call when + ``format_as_list`` is *True*. + + When ``format_as_list`` is *False* and ``indent`` is specified, we use + a Jinja loop to apply indentation to each individual tool. + """ + + # Serialise whole list -> one tojson call (supports indent argument) + if self.format_as_list: + if self.indent is None: + return "{{ tools | tojson }}" + else: + return f"{{{{ tools | tojson(indent={self.indent}) }}}}" + + # Individual objects: use loop if indent is needed, otherwise use map + if self.indent is not None: + # Use loop to apply indentation to each individual tool + # For joiners containing newlines, we need to avoid whitespace control to preserve them + # For other joiners, we can use whitespace control for cleaner output + + if '\n' in self.joiner: + # Joiner contains newlines - use Jinja's string replacement to convert \n to actual newlines + # We'll create a Jinja variable with the proper newlines + joiner_var = '{% set joiner = "' + self.joiner.replace('\n', '\\n') + '" | replace("\\\\n", "\n") %}' + return joiner_var + f"{{% for tool in tools %}}{{{{ tool | tojson(indent={self.indent}) }}}}{{% if not loop.last %}}{{{{ joiner }}}}{{% endif %}}{{% endfor %}}" + else: + # Joiner doesn't contain newlines - safe to use whitespace control and escaping + joiner_escaped = self._escape_joiner(self.joiner) + return f"{{%- for tool in tools -%}}{{{{ tool | tojson(indent={self.indent}) }}}}{{%- if not loop.last -%}}{joiner_escaped}{{%- endif -%}}{{%- endfor -%}}" + else: + # No indentation needed, use the simpler map approach + joiner_escaped = self._escape_joiner(self.joiner) + return ( + "{{ tools | map('tojson') | join('" + joiner_escaped + "') }}" + ) + +class JsonMinifiedFormatter(JsonFormatter): + """Single-line JSON objects without extra whitespace (legacy alias).""" + + def __init__(self, joiner: str = "\n", *, content_processor: Callable[[Dict], Any] | None = None): + super().__init__(indent=None, separators=(",", ":"), joiner=joiner, content_processor=content_processor) + + +class JsonIndentedFormatter(JsonFormatter): + """ + Pretty printed JSON with configurable indent (default 4). + Frequently required by models like Mistral-v0.3. 
+ (legacy alias) + """ + + def __init__(self, indent: int = 4, *, joiner: str = "\n\n", format_as_list: bool = False): + super().__init__(indent=indent, separators=None, joiner=joiner, format_as_list=format_as_list) + + +class JsonCompactFormatter(JsonFormatter): + """Single-line JSON objects without extra whitespace.""" + def __init__(self, *, format_as_list: bool = True, content_processor: Callable[[Dict], Any] | None = None): + super().__init__(indent=None, separators=None, format_as_list=format_as_list, content_processor=content_processor) + +class JsonQwenFormatter(JsonFormatter): + """ + JSON formatter for Qwen models. + """ + def __init__(self): + super().__init__(indent=None, separators=None, format_as_list=False, content_processor=None) + + # No special behaviour – inherits .jinja from JsonFormatter + + +# --------------------------------------------------------------------------- +# Content processors – only implement jinja where feasible +# --------------------------------------------------------------------------- + + +try: + import yaml as _yaml # optional dependency + + class YamlFormatter(ToolFormatter): # type: ignore + def format(self, tools: List[Dict]) -> str: # noqa: D401 + return _yaml.safe_dump(tools, sort_keys=False) +except ModuleNotFoundError: # pragma: no cover + YamlFormatter = None # type: ignore + + +@dataclasses.dataclass +class ToolPolicy: + """ + Encapsulates every configuration decision about how *tools* + appear in the prompt for a given template. + """ + placement: "ToolPlacement" = ToolPlacement.SYSTEM + content_processor: Callable[[Dict], Any] = None + formatter: ToolFormatter = dataclasses.field(default_factory=lambda: JsonQwenFormatter()) + + def format_tools(self, tools: List[Dict]) -> str: + """ + Convert `tools` into ready-to-inject text according to the chosen formatter. + """ + if self.content_processor is not None: + processed_tools = [self.content_processor(t) for t in tools] + else: + processed_tools = tools + return self.formatter.format(processed_tools) diff --git a/agents/agents/agents/templates/utils.py b/agents/agents/agents/templates/utils.py index ccb2258..9f62182 100644 --- a/agents/agents/agents/templates/utils.py +++ b/agents/agents/agents/templates/utils.py @@ -10,7 +10,8 @@ import logging from .templates import Chat, get_template from ... 
import AGENT_DATA_DIR - +from typing import Any +from .vision_processor import get_processor # Set up logging that won't be overridden by other modules LOGGER = logging.getLogger(__name__) @@ -20,9 +21,6 @@ def strip_ansi(s: str) -> str: """Remove ANSI escape sequences from a string.""" return ANSI_RE.sub('', s) -def is_vlm_template(template: str) -> bool: - return template in ["qwen2.5-vl"] - def convert_messages_to_openai_format(messages: list) -> list: """ @@ -31,10 +29,10 @@ def convert_messages_to_openai_format(messages: list) -> list: """ messages = copy.deepcopy(messages) for message in messages: - if "tool_calls" in message: - del message["tool_calls"] - if "tool_call_id" in message: - del message["tool_call_id"] + # if "tool_calls" in message: + # del message["tool_calls"] + # if "tool_call_id" in message: + # del message["tool_call_id"] if "tool_choice" in message: del message["tool_choice"] return messages @@ -139,12 +137,15 @@ def tokenize_conversation( :return: input_ids, attention_mask, labels, action_mask """ chat = Chat(template=template, messages=messages, tokenizer=tokenizer) - inputs = chat.tokenize(tokenizer, add_generation_prompt=add_generation_prompt, tools=tools) + inputs = chat.tokenize(tokenizer, add_generation_prompt=add_generation_prompt, tools=tools, processor=processor) + if max_length is not None: inputs['input_ids'] = inputs['input_ids'][:, :max_length] inputs['attention_mask'] = inputs['attention_mask'][:, :max_length] - inputs['labels'] = inputs['labels'][:, :max_length] - inputs['action_mask'] = inputs['action_mask'][:, :max_length] + if 'labels' in inputs: + inputs['labels'] = inputs['labels'][:, :max_length] + if 'action_mask' in inputs: + inputs['action_mask'] = inputs['action_mask'][:, :max_length] return inputs @@ -153,82 +154,85 @@ def convert_inputs_to_vision_inputs(template: str, processor, # AutoProcessor (not bare tokenizer) messages: list): """ - Expand `inputs` (built from a chat template that contains ONE - <|image_pad|> / <|video_pad|> placeholder per asset) into the real - processor outputs and stretch `action_mask` / `labels` so they stay - aligned after the pad tokens are repeated. - - Returns - ------- - dict -- processor(...) result + expanded masks + NEW PIPELINE: Template processes messages → Human-readable prompt → Vision processor → LLM-ready inputs + + The correct pipeline is: + 1. Template processes messages to get human-readable prompt with single multi-modal tokens + 2. Vision processor handles image/video processing and token expansion + 3. Final result is directly usable by LLMs with model(**inputs) """ - assert template == "qwen2.5-vl", "Only qwen2.5-vl is supported" - - - # ------------------------------------------------------------------ - # 1. special-token ids - # ------------------------------------------------------------------ - tk = processor.tokenizer - conv = get_conv_template(template) - image_pad_id = tk.encode(conv.image_token, add_special_tokens=False)[0] - video_pad_id = tk.encode(conv.video_token, add_special_tokens=False)[0] - vis_start_id = tk.encode(conv.vision_start, add_special_tokens=False)[0] - vis_end_id = tk.encode(conv.vision_end, add_special_tokens=False)[0] - - repeat_ids = torch.tensor([image_pad_id, video_pad_id], dtype=torch.long) - vision_ids = torch.tensor([image_pad_id, video_pad_id, - vis_start_id, vis_end_id], dtype=torch.long) - - # ------------------------------------------------------------------ - # 2. 
run the processor (adds patch-level vision tokens) - # ------------------------------------------------------------------ - LOGGER.debug(f"[Template::convert_inputs_to_vision_inputs] messages: {messages}") - imgs, vids = process_vision_info(messages) - text = format_conversation(messages, template).get_prompt() - proc_out = processor( - text=text, - images=imgs, - videos=vids, - return_tensors="pt", - padding=False, - truncation=False, + # Get the vision processor for this template + vision_processor = get_processor(template) + if vision_processor is None: + raise ValueError(f"No vision processor registered for template: {template}") + + # Step 1: Template processes messages to get human-readable prompt + from .templates import Chat + chat = Chat(template=template, messages=messages, tokenizer=processor.tokenizer) + prompt = chat.prompt() # This gives us human-readable prompt with single multi-modal tokens + + # Step 2: Extract vision inputs from messages + images, videos = extract_vision_inputs_from_messages(messages) + + # Step 3: Vision processor handles the complete pipeline + # This expands tokens and generates LLM-ready inputs + final_inputs = vision_processor.process_for_llm( + prompt=prompt, + images=images, + videos=videos, + processor=processor, + tokenizer=processor.tokenizer ) - new_ids = proc_out["input_ids"][0] # (new_len,) - new_attention_mask = proc_out["attention_mask"][0] - device = new_ids.device - - # ------------------------------------------------------------------ - # 3. original (pre-processor) tensors - # ------------------------------------------------------------------ - old_ids = inputs["input_ids"][0].to(device) - old_action = inputs.get("action_mask", None) - old_labels = inputs.get("labels", None) - old_attention_mask = inputs.get("attention_mask", None) - - # ------------------------------------------------------------------ - # 4. build boolean masks that mark NON-repeated positions - # ------------------------------------------------------------------ - new_keep = ~torch.isin(new_ids, repeat_ids) # True where we copy from old - old_keep = ~torch.isin(old_ids, repeat_ids) - - # sanity check - assert new_keep.sum() == old_keep.sum(), f"Mismatch after dropping repeated vision tokens: {new_ids} {old_ids}" - - # ------------------------------------------------------------------ - # 5. 
allocate expanded tensors and copy en masse - # ------------------------------------------------------------------ - if old_action is not None: - exp_action = torch.zeros_like(new_ids, dtype=old_action.dtype) - exp_action[new_keep] = old_action[0][old_keep].to(device) - proc_out["action_mask"] = exp_action.unsqueeze(0) + return final_inputs - if old_labels is not None: - exp_labels = torch.full_like(new_ids, -100, dtype=old_labels.dtype) - exp_labels[new_keep] = old_labels[0][old_keep].to(device) - proc_out["labels"] = exp_labels.unsqueeze(0) - - return proc_out +def extract_vision_inputs_from_messages(messages: list) -> tuple[list, list]: + """Extract images and videos from messages""" + images, videos = [], [] + + for message in messages: + if isinstance(message.get('content'), list): + for item in message['content']: + if item.get('type') in ['image', 'image_url']: + if 'image' in item: + images.append(item['image']) + elif 'image_url' in item: + images.append(item['image_url']['url']) + elif item.get('type') in ['video', 'video_url']: + if 'video' in item: + videos.append(item['video']) + elif 'video_url' in item: + videos.append(item['video_url']['url']) + + return images, videos + +def process_prompt_with_vision( + prompt: str, + template: str, + processor: Any, + images: list = None, + videos: list = None, +) -> dict: + """Process a prompt with vision support""" + vision_processor = get_processor(template) + if vision_processor is None: + # If no vision processor, just return tokenized prompt + return processor.tokenizer( + prompt, + return_tensors="pt", + add_special_tokens=True, + padding=True, + truncation=True + ) + + # Use vision processor to handle the complete pipeline + return vision_processor.process_for_llm( + prompt=prompt, + images=images or [], + videos=videos or [], + processor=processor, + tokenizer=processor.tokenizer + ) def tokenize_conversations(messages_list, tokenizer, conv_template, max_length, processor=None, return_tensors="pt", return_reward_mask=False): @@ -245,7 +249,9 @@ def tokenize_conversations(messages_list, tokenizer, conv_template, max_length, batch_action_masks.append(inputs['action_mask'].squeeze(0)) if return_tensors == "pt": - batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) + # Use pad_token_id from the tokenizer interface + pad_token_id = getattr(tokenizer, 'pad_token_id', 0) + batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=pad_token_id) batch_attention_masks = torch.nn.utils.rnn.pad_sequence(batch_attention_masks, batch_first=True, padding_value=0) batch_labels = torch.nn.utils.rnn.pad_sequence(batch_labels, batch_first=True, padding_value=-100) batch_action_masks = torch.nn.utils.rnn.pad_sequence(batch_action_masks, batch_first=True, padding_value=0) @@ -299,6 +305,9 @@ def compare_hf_template(tokenizer, template_name, messages=None, tools=None, add plain_highlighted_prompt = strip_ansi(highlighted_prompt) is_equal_between_implemented_prompts = implemented_prompt == plain_highlighted_prompt jinja_template = chat.template.jinja_template() + # Save jinja template to file + with open("jinja_template.jinja", "w") as f: + f.write(jinja_template) tokenizer.chat_template = jinja_template implemented_jinja_prompt = tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, add_generation_prompt=add_generation_prompt) is_equal_between_jinja_prompts = implemented_jinja_prompt == implemented_prompt @@ -312,7 +321,9 @@ 
def vllm_serve(model_name_or_path, template, tp, pp, dp): os.makedirs(f"{AGENT_DATA_DIR}/cache") with open(f"{AGENT_DATA_DIR}/cache/jinja_template.jinja", "w") as f: f.write(jinja_template) - command = f"vllm serve {model_name_or_path} --chat-template {AGENT_DATA_DIR}/cache/jinja_template.jinja --tensor-parallel-size {tp} --pipeline-parallel-size {pp} --data-parallel-size {dp} --port {port} --enable-auto-tool-choice --tool-call-parser hermes --expand-tools-even-if-tool-choice-none" + # command = f"vllm serve {model_name_or_path} --chat-template {AGENT_DATA_DIR}/cache/jinja_template.jinja --tensor-parallel-size {tp} --pipeline-parallel-size {pp} --data-parallel-size {dp} --port {port} --enable-auto-tool-choice --tool-call-parser hermes --expand-tools-even-if-tool-choice-none" + command = f"vllm serve {model_name_or_path} --tensor-parallel-size {tp} --pipeline-parallel-size {pp} --data-parallel-size {dp} --port {port} --enable-auto-tool-choice --tool-call-parser hermes --expand-tools-even-if-tool-choice-none" + print(command) os.system(command) @@ -321,5 +332,6 @@ def vllm_serve(model_name_or_path, template, tp, pp, dp): "python -m agents.agents.templates.utils" # model = "/mnt/sharefs/users/haonan.li/models/Qwen2.5-7B-instruct-am_think_v1_distilled" model = "Qwen/Qwen2.5-7B-Instruct" - vllm_serve(model, "qwen2.5-think", 2, 1, 4) + # vllm_serve(model, "qwen2.5-think", 2, 1, 4) + vllm_serve(model, "qwen2.5", 1, 1, 1) diff --git a/agents/agents/agents/templates/vision_processor.py b/agents/agents/agents/templates/vision_processor.py new file mode 100644 index 0000000..ca14aea --- /dev/null +++ b/agents/agents/agents/templates/vision_processor.py @@ -0,0 +1,668 @@ +""" +Comprehensive multi-modal vision processor that handles vision processing separately from template processing. +The pipeline is: Template → Human-readable prompt → Vision processor → LLM-ready inputs. 
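Illustrative usage (a sketch, assuming a processor has been registered for the
template via register_processor below, and `hf_processor` is the matching
transformers AutoProcessor):

    vp = get_processor("qwen2.5-vl")
    expanded = vp.expand_vision_tokens(prompt, images, videos, hf_processor)
    mm_inputs = vp.get_mm_inputs(images, videos, hf_processor)

`expanded` repeats each single image/video placeholder to its true token length,
and `mm_inputs` carries pixel_values / grid information ready to merge with the
tokenised prompt.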
+""" + +import base64 +import inspect +import math +import os +import re +import urllib.parse +import urllib.request +from copy import deepcopy +from dataclasses import dataclass +from io import BytesIO +from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union, List, Dict, Any +from abc import ABC, abstractmethod + +import numpy as np +import torch +from PIL import Image +from PIL.Image import Image as ImageObject +from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array +from transformers.models.mllama.processing_mllama import ( + convert_sparse_cross_attention_mask_to_dense, + get_cross_attention_token_mask, +) +from typing_extensions import override + +if TYPE_CHECKING: + from av.stream import Stream + from numpy.typing import NDArray + from transformers import PreTrainedTokenizer, ProcessorMixin + from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor + from transformers.image_processing_utils import BaseImageProcessor + + class EncodedImage(TypedDict): + path: Optional[str] + bytes: Optional[bytes] + + ImageInput = Union[str, bytes, EncodedImage, BinaryIO, "ImageObject"] + VideoInput = Union[str, BinaryIO, list[list[ImageInput]]] + + class MMProcessor(ProcessorMixin): + patch_size: int + image_seq_length: int + num_additional_image_tokens: int + vision_feature_select_strategy: Literal["default", "full"] + + def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int: + pass + + +@dataclass +class VisionProcessorConfig: + """Configuration for vision processing""" + model_type: str + image_token: str + video_token: str + vision_start: str = "" + vision_end: str = "" + processor_class: str = "AutoProcessor" + expansion_strategy: str = "patch_based" + image_max_pixels: int = 16384 * 28 * 28 + image_min_pixels: int = 4 * 28 * 28 + video_max_pixels: int = 16384 * 28 * 28 + video_min_pixels: int = 4 * 28 * 28 + video_fps: float = 2.0 + video_maxlen: int = 128 + +class VisionProcessor(ABC): + """Abstract base class for vision processing strategies""" + + def __init__(self, config: VisionProcessorConfig): + self.config = config + self._validate_config() + + def _validate_config(self): + """Validate the vision configuration""" + required_fields = ['image_token', 'video_token'] + for field in required_fields: + if not hasattr(self.config, field) or getattr(self.config, field) is None: + raise ValueError(f"Missing required field: {field}") + + @abstractmethod + def preprocess_images(self, images: List["ImageInput"], processor: Any) -> Dict[str, Any]: + """Preprocess images for the model""" + pass + + @abstractmethod + def preprocess_videos(self, videos: List["VideoInput"], processor: Any) -> Dict[str, Any]: + """Preprocess videos for the model""" + pass + + @abstractmethod + def calculate_image_tokens(self, image_data: Dict[str, Any], processor: Any) -> int: + """Calculate the number of tokens needed for an image""" + pass + + @abstractmethod + def calculate_video_tokens(self, video_data: Dict[str, Any], processor: Any) -> int: + """Calculate the number of tokens needed for a video""" + pass + + @abstractmethod + def expand_vision_tokens( + self, + prompt: str, + images: List["ImageInput"], + videos: List["VideoInput"], + processor: Optional[Any], + ) -> str: + """Expand vision tokens in the prompt to their actual token representations""" + pass + + @abstractmethod + def get_mm_inputs( + self, + images: List["ImageInput"], + videos: List["VideoInput"], + processor: 
Optional[Any], + ) -> Dict[str, torch.Tensor]: + """Generate multi-modal inputs for the model""" + pass + + def process_vision_info(self, messages: List[Dict]) -> Dict[str, torch.Tensor]: + """Process vision information from messages""" + pass + + # def process_for_llm( + # self, + # prompt: str, + # images: List["ImageInput"], + # videos: List["VideoInput"], + # processor: Optional[Any], + # tokenizer: Any, + # ) -> Dict[str, torch.Tensor]: + # """ + # Complete pipeline: expand tokens and generate LLM-ready inputs. + # Returns inputs that can be used directly with model(**inputs). + # """ + # # Step 1: Expand vision tokens in the prompt + # expanded_prompt = self.expand_vision_tokens(prompt, images, videos, processor) + + # # Step 2: Tokenize the expanded prompt + # tokenized_inputs = tokenizer( + # expanded_prompt, + # return_tensors="pt", + # add_special_tokens=True, + # padding=True, + # truncation=True + # ) + + # # Step 3: Generate multi-modal inputs + # mm_inputs = self.get_mm_inputs(images, videos, processor) + + # # Step 4: Combine tokenized inputs with multi-modal inputs + # final_inputs = {**tokenized_inputs, **mm_inputs} + + # return final_inputs + + def process_for_llm( + self, + prompt: str, + elements: List[str], + mask_flags: List[bool], + images: List["ImageInput"], + videos: List["VideoInput"], + processor: Any, + tokenizer: Any, + return_tensors: str = None, + ) -> Dict[str, torch.Tensor]: + """ + Process with proper alignment of all tensors (input_ids, attention_mask, labels, action_mask). + This ensures that when vision tokens are expanded, all corresponding tensors are expanded + at the same positions, maintaining proper alignment for training and inference. + """ + import torch + + # Step 1: Tokenize elements to get base tensors with proper alignment + input_ids = [] + attention_mask = [] + labels = [] + action_mask = [] + + # Add BOS token if needed + if tokenizer.bos_token and tokenizer.add_bos_token: + input_ids.append(tokenizer.bos_token_id) + attention_mask.append(1) + labels.append(-100) + action_mask.append(0) + + # Step 2: Process each element with vision token expansion + for element, mask_flag in zip(elements, mask_flags): + # Check if element contains vision tokens + if self._contains_vision_tokens(element): + # Expand vision tokens in this element + expanded_element = self.expand_vision_tokens(element, images, videos, processor) + cur_input_ids = tokenizer.encode(expanded_element, add_special_tokens=False) + else: + cur_input_ids = tokenizer.encode(element, add_special_tokens=False) + + # Add tokens with proper alignment + input_ids.extend(cur_input_ids) + attention_mask.extend([1] * len(cur_input_ids)) + + if mask_flag: + labels.extend([-100] * len(cur_input_ids)) + action_mask.extend([0] * len(cur_input_ids)) + else: + labels.extend(cur_input_ids) + action_mask.extend([1] * len(cur_input_ids)) + + # Step 3: Create base inputs + inputs = { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'labels': labels, + 'action_mask': action_mask + } + + # Convert to tensors if requested + if return_tensors == "pt": + inputs = {k: torch.tensor([v]) for k, v in inputs.items()} + + # Step 4: Add vision inputs + mm_inputs = self.get_mm_inputs(images, videos, processor) + inputs.update(mm_inputs) + + return inputs + + def _contains_vision_tokens(self, text: str) -> bool: + """Check if text contains vision tokens""" + return self.config.image_token in text or self.config.video_token in text + +class PatchBasedProcessor(VisionProcessor): + """Patch-based 
vision processor (used by Qwen-VL, LLaVA, etc.) + + Supports multiple image input formats: + - File paths (str): "/path/to/image.jpg" + - URLs (str): "https://example.com/image.jpg" + - Base64 strings (str): "data:image/jpeg;base64,/9j/4AAQ..." or raw base64 + - PIL Image objects + - Bytes objects + - File-like objects + - Dict format: {"path": "/path/to/image.jpg"} or {"bytes": b"image_data"} + """ + + def _load_image_from_input(self, image_input) -> "ImageObject": + """Load image from various input formats including URL and base64""" + from PIL import Image + + # Handle PIL Image objects directly + if hasattr(image_input, 'width') and hasattr(image_input, 'height'): + return image_input + + # Handle string inputs (file path, URL, or base64) + if isinstance(image_input, str): + # Check if it's a URL + if image_input.startswith(('http://', 'https://')): + try: + with urllib.request.urlopen(image_input) as response: + image_data = response.read() + return Image.open(BytesIO(image_data)) + except Exception as e: + raise ValueError(f"Failed to load image from URL {image_input}: {e}") + + # Check if it's a base64 string + elif image_input.startswith('data:image/') or image_input.startswith('data:application/octet-stream'): + # Handle data URL format: data:image/jpeg;base64,/9j/4AAQ... + try: + # Extract the base64 part after the comma + base64_data = image_input.split(',', 1)[1] + image_data = base64.b64decode(base64_data) + return Image.open(BytesIO(image_data)) + except Exception as e: + raise ValueError(f"Failed to decode base64 image: {e}") + + elif image_input.startswith('iVBORw0KGgo') or len(image_input) > 100: + # Likely a raw base64 string (common for PNG images starting with iVBORw0KGgo) + try: + image_data = base64.b64decode(image_input) + return Image.open(BytesIO(image_data)) + except Exception as e: + raise ValueError(f"Failed to decode base64 image: {e}") + + # Assume it's a file path + else: + return Image.open(image_input) + + # Handle bytes + elif isinstance(image_input, bytes): + return Image.open(BytesIO(image_input)) + + # Handle file-like objects + elif hasattr(image_input, 'read'): + return Image.open(image_input) + + # Handle dict format + elif isinstance(image_input, dict): + if image_input.get("bytes") is not None: + return Image.open(BytesIO(image_input["bytes"])) + elif image_input.get("path") is not None: + return Image.open(image_input["path"]) + else: + raise ValueError("Invalid image dict format") + + else: + raise ValueError(f"Unsupported image input type: {type(image_input)}") + + def _preprocess_single_image(self, image: "ImageObject", **kwargs) -> "ImageObject": + """Preprocess a single image""" + if (image.width * image.height) > self.config.image_max_pixels: + resize_factor = math.sqrt(self.config.image_max_pixels / (image.width * image.height)) + width, height = int(image.width * resize_factor), int(image.height * resize_factor) + image = image.resize((width, height)) + + if (image.width * image.height) < self.config.image_min_pixels: + resize_factor = math.sqrt(self.config.image_min_pixels / (image.width * image.height)) + width, height = int(image.width * resize_factor), int(image.height * resize_factor) + image = image.resize((width, height)) + + if image.mode != "RGB": + image = image.convert("RGB") + + return image + + def _regularize_images(self, images: List["ImageInput"]) -> List["ImageObject"]: + """Regularize images to avoid errors""" + results = [] + for image in images: + # Use the new helper method to handle all input formats + pil_image = 
self._load_image_from_input(image) + results.append(self._preprocess_single_image(pil_image)) + + return results + + def _regularize_videos(self, videos: List["VideoInput"]) -> List[List["ImageObject"]]: + """Regularize videos to avoid errors""" + results = [] + for video in videos: + frames: List["ImageObject"] = [] + + # Check if video is nested images + if isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video): + # Use the new image loading method for each frame + for frame in video: + try: + pil_image = self._load_image_from_input(frame) + frames.append(pil_image) + except Exception as e: + raise ValueError(f"Invalid image found in video frames: {e}") + else: + # Process actual video file + import av + container = av.open(video, "r") + video_stream = next(stream for stream in container.streams if stream.type == "video") + + # Calculate sample indices + total_frames = video_stream.frames + if total_frames == 0: # infinite video + sample_indices = np.linspace(0, self.config.video_maxlen - 1, self.config.video_maxlen).astype(np.int32) + else: + sample_frames = max(1, math.floor(float(video_stream.duration * video_stream.time_base) * self.config.video_fps)) + sample_frames = min(total_frames, self.config.video_maxlen, sample_frames) + sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) + + container.seek(0) + for frame_idx, frame in enumerate(container.decode(video_stream)): + if frame_idx in sample_indices: + frames.append(frame.to_image()) + + frames = self._regularize_images(frames) + results.append(frames) + + return results + + def preprocess_images(self, images: List["ImageInput"], processor: Any) -> Dict[str, Any]: + """Preprocess images for the model""" + if not images: + return {} + + image_processor = getattr(processor, "image_processor", None) + if image_processor is None: + raise ValueError("Image processor not found") + + images = self._regularize_images(images) + return image_processor(images, return_tensors="pt") + + def preprocess_videos(self, videos: List["VideoInput"], processor: Any) -> Dict[str, Any]: + """Preprocess videos for the model""" + if not videos: + return {} + + video_processor = getattr(processor, "video_processor", getattr(processor, "image_processor", None)) + if video_processor is None: + raise ValueError("Video processor not found") + + videos = self._regularize_videos(videos) + + # Handle different video processor interfaces + if "videos" in inspect.signature(video_processor.preprocess).parameters: + return video_processor(images=None, videos=videos, return_tensors="pt") + else: + return video_processor(videos, return_tensors="pt") + + def calculate_image_tokens(self, image_data: Dict[str, Any], processor: Any) -> int: + """Calculate the number of tokens needed for an image + + Uses two approaches: + 1. Grid-based (HuggingFace method): Uses image_grid_thw and merge_size + - More accurate for models like Qwen-VL + - Accounts for hierarchical token merging + 2. 
Patch-based (fallback): Uses image dimensions and patch_size + - Standard approach for most ViT-based models + - Assumes each patch corresponds to one token + """ + if "pixel_values" in image_data: + # Try grid-based calculation first (HuggingFace method) + if "image_grid_thw" in image_data: + grid_info = image_data["image_grid_thw"] + if isinstance(grid_info, torch.Tensor): + grid_prod = grid_info.prod().item() + elif isinstance(grid_info, list): + grid_prod = math.prod(grid_info) + else: + grid_prod = grid_info + + # Get merge_size from processor + merge_size = getattr(processor, "merge_size", 1) + merge_length = merge_size ** 2 + + num_image_tokens = grid_prod // merge_length + return max(1, num_image_tokens) + + # Fallback to patch-based calculation + height, width = get_image_size(to_numpy_array(image_data["pixel_values"][0])) + image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + if hasattr(processor, 'num_additional_image_tokens'): + image_seqlen += processor.num_additional_image_tokens + if hasattr(processor, 'vision_feature_select_strategy') and processor.vision_feature_select_strategy == "default": + image_seqlen -= 1 + return image_seqlen + return 1 + + def calculate_video_tokens(self, video_data: Dict[str, Any], processor: Any) -> int: + """Calculate the number of tokens needed for a video""" + if "pixel_values" in video_data: + # For videos, we need to calculate based on frames + video_tensor = video_data["pixel_values"][0] + if len(video_tensor.shape) > 3: # Has frame dimension + num_frames = video_tensor.shape[0] + height, width = get_image_size(to_numpy_array(video_tensor[0])) + frame_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + if hasattr(processor, 'num_additional_image_tokens'): + frame_seqlen += processor.num_additional_image_tokens + if hasattr(processor, 'vision_feature_select_strategy') and processor.vision_feature_select_strategy == "default": + frame_seqlen -= 1 + return frame_seqlen * num_frames + else: + # Single frame video + return self.calculate_image_tokens(video_data, processor) + return 1 + + def expand_vision_tokens( + self, + prompt: str, + images: List["ImageInput"], + videos: List["VideoInput"], + processor: Optional[Any], + ) -> str: + """Expand vision tokens in the prompt to their actual token representations""" + if processor is None: + raise ValueError("Processor is required for vision processing") + + # Validate that number of placeholders matches number of inputs + num_image_placeholders = prompt.count(self.config.image_token) + num_video_placeholders = prompt.count(self.config.video_token) + + if len(images) != num_image_placeholders: + raise ValueError(f"Number of images ({len(images)}) doesn't match placeholders ({num_image_placeholders})") + if len(videos) != num_video_placeholders: + raise ValueError(f"Number of videos ({len(videos)}) doesn't match placeholders ({num_video_placeholders})") + + # Preprocess images and videos to get individual token counts + processed_images = self.preprocess_images(images, processor) if images else {} + processed_videos = self.preprocess_videos(videos, processor) if videos else {} + + # Expand image tokens using regex to avoid infinite loops + expanded_prompt = prompt + if self.config.image_token in expanded_prompt: + if processed_images and "pixel_values" in processed_images: + # Calculate tokens for this specific image + image_tokens = self.calculate_image_tokens(processed_images, processor) + replacement = self.config.image_token * image_tokens 
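                # A worked example, assuming a hypothetical Qwen-VL image resized to
                # 1036x560 px (14-px patches -> image_grid_thw = (1, 40, 74)) and a
                # processor exposing merge_size = 2:
                #     image_tokens = 1 * 40 * 74 // (2 ** 2) = 740
                # so the single image placeholder is repeated 740 times here.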
+ else: + replacement = self.config.image_token + + # Use regex to replace all occurrences at once + import re + expanded_prompt = re.sub(re.escape(self.config.image_token), replacement, expanded_prompt) + + # Expand video tokens using regex to avoid infinite loops + if self.config.video_token in expanded_prompt: + if processed_videos and "pixel_values" in processed_videos: + # Calculate tokens for this specific video + video_tokens = self.calculate_video_tokens(processed_videos, processor) + replacement = self.config.video_token * video_tokens + else: + replacement = self.config.video_token + + # Use regex to replace all occurrences at once + expanded_prompt = re.sub(re.escape(self.config.video_token), replacement, expanded_prompt) + + return expanded_prompt + + def get_mm_inputs( + self, + images: List["ImageInput"], + videos: List["VideoInput"], + processor: Optional[Any], + ) -> Dict[str, torch.Tensor]: + """Generate multi-modal inputs for the model""" + mm_inputs = {} + + # Process images + if images: + mm_inputs.update(self.preprocess_images(images, processor)) + + # Process videos + if videos: + mm_inputs.update(self.preprocess_videos(videos, processor)) + + return mm_inputs + + def process_vision_info(self, messages: List[Dict], processor: Any): + """Process vision information from messages""" + image_message_types = ["image", "image_url", "image_base64"] + images = [] + for message in messages: + for content in message["content"]: + if content["type"] in image_message_types: + content_type = content["type"] + images.append(content[content_type]) + mm_inputs = self.get_mm_inputs(images, [], processor) + return mm_inputs + + +class QwenVLProcessor(PatchBasedProcessor): + """Qwen-VL specific processor with custom image preprocessing""" + + def _preprocess_single_image(self, image: "ImageObject", **kwargs) -> "ImageObject": + """Qwen-VL specific image preprocessing""" + image = super()._preprocess_single_image(image, **kwargs) + + # Qwen-VL specific adjustments + if min(image.width, image.height) < 28: + width, height = max(image.width, 28), max(image.height, 28) + image = image.resize((width, height)) + + if image.width / image.height > 200: + width, height = image.height * 180, image.height + image = image.resize((width, height)) + + if image.height / image.width > 200: + width, height = image.width, image.width * 180 + image = image.resize((width, height)) + + return image + + def calculate_image_tokens(self, image_data: Dict[str, Any], processor: Any) -> int: + """Qwen-VL specific token calculation using grid-based approach""" + if "image_grid_thw" in image_data: + # Use grid information for more accurate token calculation + grid_info = image_data["image_grid_thw"] + if isinstance(grid_info, torch.Tensor): + grid_prod = grid_info.prod().item() + elif isinstance(grid_info, list): + grid_prod = math.prod(grid_info) + else: + grid_prod = grid_info + + # Get merge_size from processor (Qwen-VL typically uses merge_size=2) + merge_size = getattr(processor, "merge_size", 2) + merge_length = merge_size ** 2 + + num_image_tokens = grid_prod // merge_length + return max(1, num_image_tokens) + + # Fallback to standard calculation + return super().calculate_image_tokens(image_data, processor) + + def expand_vision_tokens( + self, + prompt: str, + images: List["ImageInput"], + videos: List["VideoInput"], + processor: Optional[Any], + ) -> str: + """Qwen-VL specific token expansion with vision tags""" + expanded_prompt = super().expand_vision_tokens(prompt, images, videos, processor) + + return 
expanded_prompt + +class LlavaProcessor(PatchBasedProcessor): + """LLaVA specific processor""" + + def calculate_image_tokens(self, image_data: Dict[str, Any], processor: Any) -> int: + """LLaVA specific token calculation""" + if "pixel_values" in image_data: + height, width = get_image_size(to_numpy_array(image_data["pixel_values"][0])) + image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + if hasattr(processor, 'num_additional_image_tokens'): + image_seqlen += processor.num_additional_image_tokens + if hasattr(processor, 'vision_feature_select_strategy') and processor.vision_feature_select_strategy == "default": + image_seqlen -= 1 + return image_seqlen + return 1 + + +VISION_PROCESSORS: Dict[str, VisionProcessor] = {} + +model_type_to_processor_class = { + "qwen_vl": QwenVLProcessor, + "llava": LlavaProcessor, + "gemma3": PatchBasedProcessor, + "paligemma": PatchBasedProcessor, + "internvl": PatchBasedProcessor, + "minicpm": PatchBasedProcessor, + "mllama": PatchBasedProcessor, + "pixtral": PatchBasedProcessor, + "video_llava": PatchBasedProcessor, + "patch_based": PatchBasedProcessor, +} + +def register_processor(template_name: str, config: VisionProcessorConfig): + """Register a vision processor for a template""" + processor_class = model_type_to_processor_class.get(config.model_type) + if processor_class is None: + raise ValueError(f"No processor class found for model type: {config.model_type}") + VISION_PROCESSORS[template_name] = processor_class(config) + + +def register(cls, template_name: str, config: VisionProcessorConfig, processor_class: type = None): + """Register a vision processor for a template""" + if processor_class is not None: + # If processor_class is provided, use it directly + VISION_PROCESSORS[template_name] = processor_class(config) + else: + # Use the global register_processor function + register_processor(template_name, config) + +def get_processor(template_name: str) -> Optional[VisionProcessor]: + """Get vision processor for a template""" + return VISION_PROCESSORS.get(template_name) + +def get_processor_config(template_name: str) -> Optional[VisionProcessorConfig]: + """Get vision config for a template""" + processor = get_processor(template_name) + return processor.config if processor else None + +def is_vision_template(template_name: str) -> bool: + """Check if template supports vision""" + return template_name in VISION_PROCESSORS + +def list_vision_templates() -> List[str]: + """List all vision-enabled templates""" + return list(VISION_PROCESSORS.keys()) diff --git a/agents/agents/agents/utils/tokenizer.py b/agents/agents/agents/utils/tokenizer.py new file mode 100644 index 0000000..00ab8fe --- /dev/null +++ b/agents/agents/agents/utils/tokenizer.py @@ -0,0 +1,10 @@ +from transformers import AutoTokenizer + +def create_tokenizer(model_name_or_path: str): + try: + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + # Can not find the tokenizer in local directory or huggingface hub + except OSError: + tokenizer = None + + return tokenizer diff --git a/agents/agents/examples/streaming_example.py b/agents/agents/examples/streaming_example.py deleted file mode 100644 index c4fd08b..0000000 --- a/agents/agents/examples/streaming_example.py +++ /dev/null @@ -1,59 +0,0 @@ -import asyncio -from agents.agents.react.react_agent import ReactAgent -from agents.tools import code_interpreter -from agents.rewards import math_reward_tool -import json -from agents.agents.chain.streaming_observer import ConsoleStreamObserver - - 
-async def main(): - tools = [code_interpreter] - - agent = ReactAgent( - "Qwen/Qwen2.5-7B-Instruct", - tools=tools, - template="qwen2.5-no-tool", - backend="async_vllm", - reward_fn=math_reward_tool, - debug=True - ) - - - console_stream_observer = ConsoleStreamObserver() - agent.streaming_manager.add_observer(console_stream_observer) - - - question1 = "Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop." - answer1 = "204" - question2 = "$P(x)$ is a polynomial of degree $3n$ such that\n\\begin{eqnarray*} P(0) = P(3) = \\cdots &=& P(3n) = 2, \\\\ P(1) = P(4) = \\cdots &=& P(3n-2) = 1, \\\\ P(2) = P(5) = \\cdots &=& P(3n-1) = 0, \\quad\\text{ and }\\\\ && P(3n+1) = 730.\\end{eqnarray*}\nDetermine $n$." - answer2 = "3" - - - messages = [ - { - "messages": [ - {"role": "user", "content": f"{question1}"} - ], - "question": f"{question1}", - "answer": f"{answer1}" - }, - # { - # "messages": [ - # {"role": "user", "content": f"{question2}"} - # ], - # "question": f"{question2}", - # "answer": f"{answer2}" - # } - ] - - - await agent.run_async( - max_steps=4, - start_messages=messages, - num_chains=1, - enable_streaming=True - ) - -if __name__ == "__main__": - asyncio.run(main()) - diff --git a/agents/agents/rewards/qa_reward.py b/agents/agents/rewards/qa_reward.py index 583519f..bd335b5 100644 --- a/agents/agents/rewards/qa_reward.py +++ b/agents/agents/rewards/qa_reward.py @@ -104,3 +104,18 @@ def qa_f1_reward_format(prediction: str, answer: str, trajectory: List[str]) -> raise ValueError(f"Invalid prediction or trajectory for qa reward with format: Trajectory: {trajectory}") return rewards_dict + + +@reward(name="ok_vqa_reward") +def ok_vqa_reward(prediction: str, answers: List[str], trajectory: List[str]) -> float: + """ + Calculate the reward for the agent's response based on the F1 score and EM score. + The reward is 0.0 if the agent has not called any tool. + The reward is the F1 score if the agent has called a tool. 
+ """ + f1_scores = [] + for answer in answers: + f1, precision, recall = f1_score(prediction, answer) + f1_scores.append(f1) + # All answers are the correct answer, take the max f1 score + return max(f1_scores) \ No newline at end of file diff --git a/agents/tests/unit/agents/prompts/test_qwen25_vl_prompt.py b/agents/tests/unit/agents/prompts/test_qwen25_vl_prompt.py deleted file mode 100644 index 0376b49..0000000 --- a/agents/tests/unit/agents/prompts/test_qwen25_vl_prompt.py +++ /dev/null @@ -1,138 +0,0 @@ -from agents.agents.agents.templates.templates import get_conv_template -from agents.agents.templates.utils import compare_hf_template, format_conversation -from transformers import AutoProcessor, AutoTokenizer - -def test_simple_prompts(): - messages = [ - { - "role": "user", - "content": [ - { - "type": "image", - "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", - }, - {"type": "text", "text": "Describe this image."}, - ], - }, - { - "role": "assistant", - "content": [ - { - "type": "text", - "text": "The image is a cat.", - }, - ], - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": "What is in the image?", - }, - ], - } - ] - processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(processor, "qwen2.5-vl", messages=messages, add_generation_prompt=True) - assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" - print(f"Highlighted prompt:\n\n{highlighted_prompt}") - - -def test_simple_prompts_with_system(): - messages = [ - { - "role": "system", - "content": "You are a multi-modal assistant that can answer questions about images.", - }, - { - "role": "user", - "content": [ - { - "type": "image", - "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", - }, - {"type": "text", "text": "Describe this image."}, - ], - }, - { - "role": "assistant", - "content": [ - { - "type": "text", - "text": "The image is a cat.", - }, - ], - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": "What is in the image?", - }, - ], - } - ] - processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(processor, "qwen2.5-vl", messages=messages, add_generation_prompt=True) - assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" - print(f"Highlighted prompt:\n\n{highlighted_prompt}") - - -def test_jinja_template(): - conv = get_conv_template("qwen2.5-vl") - jinja_template = conv.get_jinja_template() - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - messages = [ - { - "role": "system", - "content": "You are a multi-modal assistant that can answer questions about images.", - }, - { - "role": "user", - "content": [ - { - "type": "image", - "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", - }, - {"type": "text", "text": "Describe this image."}, - ], - }, - { - "role": "assistant", - "content": [ - { - "type": "text", - "text": "The image is a cat.", - }, - ], - }, - { - "role": "tool", - "content": [ - { - "type": "text", - "text": "Example tool response.", - }, - ], - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": "What is in the image?", - }, - ], - } - ] - official_prompt = 
tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) - tokenizer.chat_template = jinja_template - prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) - print(official_prompt) - print(prompt) - - conv = format_conversation(messages, "qwen2.5-vl", add_generation_prompt=True) - implemented_prompt = conv.get_prompt() - print(implemented_prompt) \ No newline at end of file diff --git a/agents/tests/unit/agents/prompts/test_qwen3_prompt.py b/agents/tests/unit/agents/prompts/test_qwen3_prompt.py deleted file mode 100644 index f11785e..0000000 --- a/agents/tests/unit/agents/prompts/test_qwen3_prompt.py +++ /dev/null @@ -1,81 +0,0 @@ -from agents.agents.agents.templates.templates import get_conv_template -from agents.agents.templates.utils import compare_hf_template - -def test_simple_prompts(): - messages = [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am fine, thank you."}, - {"role": "user", "content": "Want to play a game?"}, - {"role": "assistant", "content": "Sure, what game?"}, - {"role": "user", "content": "What is 3 times 5?"}, - {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, - {"role": "tool", "content": "15"}, - ] - qwen3_model = "Qwen/Qwen3-30B-A3B" - is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(qwen3_model, "qwen3", messages=messages, add_generation_prompt=True) - assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" - print(f"Highlighted prompt:\n\n{highlighted_prompt}") - - -def test_simple_prompts_with_system(): - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am fine, thank you."}, - {"role": "user", "content": "Want to play a game?"}, - {"role": "assistant", "content": "Sure, what game?"}, - {"role": "user", "content": "What is 3 times 5?"}, - {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, - {"role": "tool", "content": "15"}, - ] - qwen3_model = "Qwen/Qwen3-30B-A3B" - is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(qwen3_model, "qwen3", messages=messages, add_generation_prompt=True) - assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" - print(f"Highlighted prompt:\n\n{highlighted_prompt}") - - -def test_with_tools(): - multiply_schema = {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}} - add_schema = {"type": "function", "function": {"name": "add", "description": "A function that add two numbers", "parameters": {"type": "object", "properties": {"a": {"type": "number", "description": "The first number to add"}, "b": {"type": "number", "description": "The second number to add"}}, "required": ["a", "b"]}}} - tools = [multiply_schema, add_schema] - messages = [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am fine, thank you."}, - {"role": "user", "content": "Want to play a game?"}, - {"role": "assistant", "content": "Sure, what game?"}, - {"role": 
"user", "content": "What is 3 times 5?"}, - {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, - {"role": "tool", "content": "15"}, - ] - qwen3_model = "Qwen/Qwen3-30B-A3B" - is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(qwen3_model, "qwen3", messages=messages, tools=tools, add_generation_prompt=True) - assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" - print(f"Highlighted prompt:\n\n{highlighted_prompt}") - - -def test_with_system_message(): - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am fine, thank you."}, - {"role": "user", "content": "What is 3 times 5?"}, - ] - qwen3_model = "Qwen/Qwen3-30B-A3B" - is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(qwen3_model, "qwen3", messages=messages, enable_thinking=False, add_generation_prompt=True) - assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" - print(f"Highlighted prompt:\n\n{highlighted_prompt}") - -def test_with_system_message_and_tools(): - multiply_schema = {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}} - add_schema = {"type": "function", "function": {"name": "add", "description": "A function that add two numbers", "parameters": {"type": "object", "properties": {"a": {"type": "number", "description": "The first number to add"}, "b": {"type": "number", "description": "The second number to add"}}, "required": ["a", "b"]}}} - tools = [multiply_schema, add_schema] - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am fine, thank you."}, - {"role": "user", "content": "What is 3 times 5?"}, - ] - qwen3_model = "Qwen/Qwen3-30B-A3B" - is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(qwen3_model, "qwen3", messages=messages, tools=tools, enable_thinking=False,add_generation_prompt=True) - assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" - print(f"Highlighted prompt:\n\n{highlighted_prompt}") \ No newline at end of file diff --git a/agents/tests/unit/agents/prompts/test_template_tokenize.py b/agents/tests/unit/agents/prompts/test_template_tokenize.py deleted file mode 100644 index ff22ba1..0000000 --- a/agents/tests/unit/agents/prompts/test_template_tokenize.py +++ /dev/null @@ -1,62 +0,0 @@ -from agents.agents.templates.utils import is_vlm_template, tokenize_conversation -import pytest -from transformers import AutoTokenizer, AutoProcessor -import torch -from agents.agents.templates.templates import Chat - -@pytest.mark.parametrize("template", ["qwen2.5"]) -@pytest.mark.parametrize("messages", [ - [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am fine, thank you."}, - {"role": "user", "content": "Want to play a game?"}, - {"role": "assistant", "content": "Sure, what game?"}, - ], - [ - {"role": "user", "content": "Help me to calculate 3 times 5."}, - 
{"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, - {"role": "tool", "content": "15"}, - ], - [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am fine, thank you."}, - {"role": "user", "content": "What is 3 times 5?"}, - ], -]) -@pytest.mark.parametrize("tools", [ - None, - # [ - # {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, - # {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, - # ] -]) -@pytest.mark.parametrize("add_generation_prompt", [False]) -def test_template_tokenize(template, messages, tools, add_generation_prompt): - template_tokenizer_mapping = { - "qwen2.5": "Qwen/Qwen2.5-3B-Instruct", - "qwen2.5-vl": "Qwen/Qwen2.5-VL-3B-Instruct", - "qwen3": "Qwen/Qwen3-8B", - } - tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template]) - if is_vlm_template(template): - processor = AutoProcessor.from_pretrained(template_tokenizer_mapping[template]) - else: - processor = None - try: - official_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt) - official_inputs = tokenizer(official_prompt, return_tensors="pt") - # chat = Chat(template, messages, tokenizer) - # implemented_inputs = chat.tokenize() - implemented_inputs = tokenize_conversation(messages, tokenizer, template, max_length=2048, processor=processor, tools=tools, add_generation_prompt=add_generation_prompt, return_tensors="pt") - - assert torch.equal(official_inputs["input_ids"], implemented_inputs["input_ids"]), f"template: {template}\n\nmessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nofficial_prompt: {official_prompt}\n\nimplemented_prompt: {tokenizer.decode(implemented_inputs['input_ids'][0])}\n\nofficial_inputs: {official_inputs}\n\nimplemented_inputs: {implemented_inputs}" - assert torch.equal(official_inputs["attention_mask"], implemented_inputs["attention_mask"]) - except Exception as e: - if isinstance(e, ValueError) and "does not support tool calling." 
in str(e) and tools is not None and template in ["qwen2.5-vl"]: - pass - else: - raise e - # assert torch.equal(official_inputs["labels"], implemented_inputs["labels"]) - - diff --git a/agents/tests/unit/agents/prompts/test_think_prompt.py b/agents/tests/unit/agents/prompts/test_think_prompt.py deleted file mode 100644 index 94a3ec2..0000000 --- a/agents/tests/unit/agents/prompts/test_think_prompt.py +++ /dev/null @@ -1,29 +0,0 @@ -import openai - -def test_think_prompt(): - client = openai.OpenAI(api_key="token-123", base_url="http://0.0.0.0:8000/v1") - tools = [ - { - "type": "function", - "function": { - "name": "code_interpreter", - "description": "A python code interpreter that can execute code and return the result", - "parameters": { - "type": "object", - "properties": { - "code": {"type": "string", "description": "The python code to execute"}, - }, - }, - } - } - ] - response = client.chat.completions.create( - # model="/mnt/sharefs/users/haonan.li/models/Qwen2.5-7B-instruct-am_think_v1_distilled", - model="Qwen/Qwen2.5-7B-Instruct", - tools=tools, - messages=[ - {"role": "user", "content": "Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop."}, - ], - max_tokens=8192 - ) - print(response.choices[0].message.content) \ No newline at end of file diff --git a/agents/tests/unit/agents/templates/test_qwen3_prompt.py b/agents/tests/unit/agents/templates/test_qwen3_prompt.py new file mode 100644 index 0000000..e69de29 diff --git a/agents/tests/unit/agents/templates/test_template_utilities.py b/agents/tests/unit/agents/templates/test_template_utilities.py new file mode 100644 index 0000000..47cba54 --- /dev/null +++ b/agents/tests/unit/agents/templates/test_template_utilities.py @@ -0,0 +1,33 @@ +from agents.agents.templates.templates import get_template, register_template, Template +from agents.agents.templates.vision_processor import get_processor + +def test_template_registration(): + register_template( + Template( + name="test", + system_template="", + user_template="", + assistant_template="", + stop_words=[] + ) + ) + assert get_template("test") is not None + assert get_template("test").name == "test" + +def test_template_registration_with_vision(): + register_template( + Template( + name="test-vl", + system_template="", + user_template="", + assistant_template="", + stop_words=[], + image_token="<|image_pad|>", + ) + ) + assert get_processor("test-vl") is not None + assert get_processor("test-vl").config.image_token == "<|image_pad|>" + + + + \ No newline at end of file diff --git a/agents/tests/unit/agents/templates/test_text_templates_full_align.py b/agents/tests/unit/agents/templates/test_text_templates_full_align.py new file mode 100644 index 0000000..006ee4e --- /dev/null +++ b/agents/tests/unit/agents/templates/test_text_templates_full_align.py @@ -0,0 +1,111 @@ +""" This file is for testing the text templates that align seamlessly with HF templates. 
The templates should align on following aspects: + - The obtained textual prompt should be the same as the one obtained from HF template with all the following options: + - add_generation_prompt + - tools + - The obtained textual prompt should be the same as the one obtained from Jinja template with all the following options: + - add_generation_prompt + - tools +""" + + +from agents.agents.templates.utils import compare_hf_template +from transformers import AutoTokenizer +import pytest + +@pytest.mark.parametrize("model_name_or_path", [ + # "Qwen/Qwen2.5-3B-Instruct", + "mistralai/Mistral-7B-Instruct-v0.3", +]) +@pytest.mark.parametrize("messages", [ + [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + {"role": "user", "content": "Want to play a game?"}, + {"role": "assistant", "content": "Sure, what game?"}, + ], + [ + {"role": "user", "content": "Help me to calculate 3 times 5."}, + {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, + {"role": "tool", "content": "15"}, + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + {"role": "user", "content": "What is 3 times 5?"}, + ], +]) +@pytest.mark.parametrize("tools", [ + None, + [ + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + ] +]) +@pytest.mark.parametrize("add_generation_prompt", [True, False]) +def test_hf_template_print(model_name_or_path, messages, tools, add_generation_prompt): + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) + prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt, tools=tools) + print(f"========================================\nModel: {model_name_or_path}\nMessages: {messages}\nTools: {tools}\nAdd generation prompt: {add_generation_prompt}\n") + print(prompt) + print("========================================\n") + + +# "qwen2.5-think", "qwen2.5", "qwen2.5-no-tool", +# "llama-3.2", "mistral", "glm-4", "internlm2.5", "phi-3.5", "phi-4" +@pytest.mark.parametrize("template", ["llama-3.2", "qwen2.5"]) +@pytest.mark.parametrize("messages", [ + [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + {"role": "user", "content": "Want to play a game?"}, + {"role": "assistant", "content": "Sure, what game?"}, + ], + [ + {"role": "user", "content": "Help me to calculate 3 times 5."}, + {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, + {"role": "tool", "content": "15"}, + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + {"role": "user", 
"content": "What is 3 times 5?"}, + ], +]) +@pytest.mark.parametrize("tools", [ + None, + [ + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + ] +]) +@pytest.mark.parametrize("add_generation_prompt", [True, False]) +def test_chat_template_equal(template, messages, tools, add_generation_prompt): + # Filter invalid combinations + if add_generation_prompt and messages[-1]['role'] == 'assistant': + return + + + template_tokenizer_mapping = { + "qwen2.5": "Qwen/Qwen2.5-3B-Instruct", + "qwen2.5-think": "Qwen/Qwen2.5-3B-Instruct", + "qwen2.5-no-system-tool": "Qwen/Qwen2.5-3B-Instruct", + "deepseek-prover-v2": "deepseek-ai/DeepSeek-Prover-V2-7B", + "llama-3.2": "meta-llama/Llama-3.2-3B-Instruct", + "mistral": "mistralai/Mistral-7B-Instruct-v0.3", + "glm-4": "THUDM/glm-4-9b-chat", + "internlm2.5": "internlm/internlm2_5-7b-chat", + "phi-3.5": "microsoft/Phi-3.5-mini-instruct", + "phi-4": "microsoft/Phi-4", + "nemotron": "nvidia/Llama-3.1-Nemotron-Nano-8B-v1", + } + tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template], trust_remote_code=True) + + is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template, messages=messages, tools=tools,add_generation_prompt=add_generation_prompt) + + assert is_equal, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" + assert is_equal_between_jinja_prompts, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja prompt:\n\n{implemented_jinja_prompt}" + # print(f"Official prompt:\n\n{official_prompt}") + # print(f"Highlighted prompt:\n\n{highlighted_prompt}") + diff --git a/agents/tests/unit/agents/prompts/test_templates.py b/agents/tests/unit/agents/templates/test_text_templates_partial_align.py similarity index 54% rename from agents/tests/unit/agents/prompts/test_templates.py rename to agents/tests/unit/agents/templates/test_text_templates_partial_align.py index 068f832..78b7f87 100644 --- a/agents/tests/unit/agents/prompts/test_templates.py +++ b/agents/tests/unit/agents/templates/test_text_templates_partial_align.py @@ -1,8 +1,10 @@ -from agents.agents.templates.utils import compare_hf_template -from transformers import AutoTokenizer import pytest +from transformers import AutoTokenizer +from agents.agents.templates.templates import get_template +from agents.agents.templates.utils import compare_hf_template -@pytest.mark.parametrize("template", ["qwen2.5-think", "qwen2.5", "qwen2.5-no-tool"]) +# nemotron, phi-4, glm-4 +@pytest.mark.parametrize("template_name", ["qwen2.5-think", "qwen2.5-no-system-tool",]) 
@pytest.mark.parametrize("messages", [ [ {"role": "user", "content": "Hello, how are you?"}, @@ -13,7 +15,7 @@ [ {"role": "user", "content": "Help me to calculate 3 times 5."}, {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, - {"role": "tool", "content": "15"}, + {"role": "tool", "content": "15", "tool_call_id": "123456789"}, ], [ {"role": "system", "content": "You are a helpful assistant."}, @@ -30,20 +32,37 @@ ] ]) @pytest.mark.parametrize("add_generation_prompt", [True, False]) -def test_chat_template_equal(template, messages, tools, add_generation_prompt): +def test_chat_template_equal(template_name, messages, tools, add_generation_prompt): # Filter invalid combinations if add_generation_prompt and messages[-1]['role'] == 'assistant': return + + template = get_template(template_name) + if tools and not template._supports_tool_call(): + return + + contain_tool_role = any(message['role'] == 'tool' for message in messages) + if contain_tool_role and not template._supports_tool_call(): + return template_tokenizer_mapping = { "qwen2.5": "Qwen/Qwen2.5-3B-Instruct", "qwen2.5-think": "Qwen/Qwen2.5-3B-Instruct", - "qwen2.5-no-tool": "Qwen/Qwen2.5-3B-Instruct", + "qwen2.5-no-system-tool": "Qwen/Qwen2.5-3B-Instruct", + "deepseek-prover-v2": "deepseek-ai/DeepSeek-Prover-V2-7B", + "llama-3.2": "meta-llama/Llama-3.2-3B-Instruct", + "mistral": "mistralai/Mistral-7B-Instruct-v0.3", + "glm-4": "THUDM/glm-4-9b-chat", + "internlm2.5": "internlm/internlm2_5-7b-chat", + "phi-3.5": "microsoft/Phi-3.5-mini-instruct", + "phi-4": "microsoft/Phi-4", + "nemotron": "nvidia/Llama-3.1-Nemotron-Nano-8B-v1", } - tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template]) - - is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template, messages=messages, tools=tools,add_generation_prompt=add_generation_prompt) - # assert is_equal, print(f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}") - assert is_equal_between_jinja_prompts, print(f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja prompt:\n\n{implemented_jinja_prompt}") - print(f"Highlighted prompt:\n\n{highlighted_prompt}") + tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template_name], trust_remote_code=True) + is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template_name, messages=messages, tools=tools,add_generation_prompt=add_generation_prompt) + + print(f"Official prompt:\n\n\"{official_prompt}\"\n\n") + print(f"Implemented prompt:\n\n\"{implemented_prompt}\"\n\n") + assert is_equal, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" + assert is_equal_between_jinja_prompts, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja 
prompt:\n\n{implemented_jinja_prompt}" \ No newline at end of file diff --git a/agents/tests/unit/agents/templates/test_text_templates_tokenize.py b/agents/tests/unit/agents/templates/test_text_templates_tokenize.py new file mode 100644 index 0000000..a89e0b8 --- /dev/null +++ b/agents/tests/unit/agents/templates/test_text_templates_tokenize.py @@ -0,0 +1,61 @@ +""" This file is for testing the tokenization of the templates. The templates should align on the following aspects: + - The tokenized prompt should be the same as the one obtained from HF template with all the following options: + - add_generation_prompt + - tools + - We need to observe the labels and action_mask to make sure they are correct. + +Since the textual prompt alignment is already tested in other files, we only need to test the tokenization of the templates. +""" + +from agents.agents.templates.utils import tokenize_conversation +import pytest +from transformers import AutoTokenizer +import torch +from agents.agents.templates.templates import Chat + +@pytest.mark.parametrize("template", ["llama-3.2", "qwen2.5"]) +@pytest.mark.parametrize("messages", [ + [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + {"role": "user", "content": "Want to play a game?"}, + {"role": "assistant", "content": "Sure, what game?"}, + ], + [ + {"role": "user", "content": "Help me to calculate 3 times 5."}, + {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, + {"role": "tool", "content": "15"}, + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + {"role": "user", "content": "What is 3 times 5?"}, + {"role": "assistant", "content": "15"}, + {"role": "user", "content": "OK, what is 3 times 6?"}, + {"role": "assistant", "content": "18"}, + ], +]) +@pytest.mark.parametrize("tools", [ + None, + [ + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + ] +]) +@pytest.mark.parametrize("add_generation_prompt", [False, True]) +def test_template_tokenize(template, messages, tools, add_generation_prompt): + template_tokenizer_mapping = { + "qwen2.5": "Qwen/Qwen2.5-3B-Instruct", + "llama-3.2": "meta-llama/Llama-3.2-3B-Instruct", + } + tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template], trust_remote_code=True) + + chat = Chat(template, messages, tools=tools) + prompt = chat.prompt(add_generation_prompt=add_generation_prompt, tools=tools) + + hf_inputs = tokenizer(prompt, return_tensors="pt") + + implemented_inputs = tokenize_conversation(messages, tokenizer, template, max_length=2048, tools=tools, add_generation_prompt=add_generation_prompt, return_tensors="pt") + + assert torch.equal(hf_inputs["input_ids"], implemented_inputs["input_ids"]), f"template: {template}\n\nmessages: 
{messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nprompt: {prompt}\n\nimplemented_prompt: {tokenizer.decode(implemented_inputs['input_ids'][0])}\n\nhf_inputs: {hf_inputs}\n\nimplemented_inputs: {implemented_inputs}" diff --git a/agents/tests/unit/agents/templates/test_vision_templates_full_align.py b/agents/tests/unit/agents/templates/test_vision_templates_full_align.py new file mode 100644 index 0000000..167307d --- /dev/null +++ b/agents/tests/unit/agents/templates/test_vision_templates_full_align.py @@ -0,0 +1,87 @@ +""" This file is for testing the vision templates that align seamlessly with HF templates. The templates should align on following aspects: + - The obtained textual prompt should be the same as the one obtained from HF template with all the following options: + - add_generation_prompt + - tools + - The obtained textual prompt should be the same as the one obtained from Jinja template with all the following options: + - add_generation_prompt + - tools +To test vision part, the messages should contain at least one image. +""" + + +from agents.agents.templates.utils import compare_hf_template +from transformers import AutoTokenizer +import pytest +# "qwen2.5-think", "qwen2.5", "qwen2.5-no-tool", +@pytest.mark.parametrize("template", ["qwen2.5-vl"]) +@pytest.mark.parametrize("messages", [ + [ + { + "role": "system", + "content": "You are a multi-modal assistant that can answer questions about images.", + }, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + }, + {"type": "text", "text": "Describe this image."}, + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "The image is a cat.", + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What is in the image?", + }, + ], + } + ], + [ + {"role": "user", "content": "Help me to calculate 3 times 5."}, + {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, + {"role": "tool", "content": "15"}, + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + {"role": "user", "content": "What is 3 times 5?"}, + ], +]) +@pytest.mark.parametrize("tools", [ + None, + [ + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + ] +]) +@pytest.mark.parametrize("add_generation_prompt", [True, False]) +def test_chat_template_equal(template, messages, tools, add_generation_prompt): + # Filter invalid combinations + if add_generation_prompt and messages[-1]['role'] == 'assistant': + return + + template_tokenizer_mapping = { + "qwen2.5-vl": "Qwen/Qwen2.5-VL-3B-Instruct", + } + tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template]) + + is_equal, 
is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template, messages=messages, tools=tools,add_generation_prompt=add_generation_prompt) + assert is_equal, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" + assert is_equal_between_jinja_prompts, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja prompt:\n\n{implemented_jinja_prompt}" + print(f"Official prompt:\n\n{official_prompt}") + print(f"Highlighted prompt:\n\n{highlighted_prompt}") + diff --git a/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py b/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py new file mode 100644 index 0000000..31b5ecd --- /dev/null +++ b/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py @@ -0,0 +1,110 @@ +""" This file is for testing the tokenization of the vision templates. The tokenized inputs should align with the HF processor on the following aspects: + - The input_ids, pixel_values and image_grid_thw should be the same as the ones obtained from the HF processor with all the following options: + - add_generation_prompt + - tools + - The action_mask should have the same shape as input_ids. +To test the vision part, the messages should contain at least one image. +""" + + +from agents.agents.templates.templates import Chat +from agents.agents.templates.utils import compare_hf_template, tokenize_conversation +from transformers import AutoTokenizer +import pytest +import torch +from transformers import AutoProcessor +from qwen_vl_utils import process_vision_info + +@pytest.mark.parametrize("template", ["qwen2.5-vl"]) +@pytest.mark.parametrize("messages", [ + [ + { + "role": "system", + "content": "You are a multi-modal assistant that can answer questions about images.", + }, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + }, + {"type": "text", "text": "Describe this image."}, + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "The image is a cat.", + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What is in the image?", + }, + ], + } + ], + # [ + # {"role": "user", "content": "Help me to calculate 3 times 5."}, + # {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, + # {"role": "tool", "content": "15"}, + # ], + # [ + # {"role": "system", "content": "You are a helpful assistant."}, + # {"role": "user", "content": "Hello, how are you?"}, + # {"role": "assistant", "content": "I am fine, thank you."}, + # {"role": "user", "content": "What is 3 times 5?"}, + # ], +]) +@pytest.mark.parametrize("tools", [ + None, + # [ + # {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + # {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", 
"parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + # ] +]) +@pytest.mark.parametrize("add_generation_prompt", [True, False]) +def test_chat_template_equal(template, messages, tools, add_generation_prompt): + template_tokenizer_mapping = { + "qwen2.5-vl": "Qwen/Qwen2.5-VL-3B-Instruct", + } + tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template]) + processor = AutoProcessor.from_pretrained(template_tokenizer_mapping[template]) + official_prompt = tokenizer.apply_chat_template(messages + , tokenize=False, add_generation_prompt=add_generation_prompt) + image_inputs, video_inputs = process_vision_info(messages) + official_inputs = processor( + text=[official_prompt], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + + implemented_inputs = tokenize_conversation(messages, tokenizer, template, max_length=8192, tools=tools, add_generation_prompt=add_generation_prompt, return_tensors="pt", processor=processor) + + official_prompt = tokenizer.decode(official_inputs['input_ids'][0]) + implemented_prompt = tokenizer.decode(implemented_inputs['input_ids'][0]) + print(f"Official prompt image tokens: {official_prompt.count('<|image_pad|>')}\nImplemented prompt image tokens: {implemented_prompt.count('<|image_pad|>')}") + + print(f"Official images: {official_inputs['pixel_values'].shape}\nImplemented images: {implemented_inputs['pixel_values'].shape}") + + assert torch.equal(official_inputs["input_ids"], implemented_inputs["input_ids"]), f"""Offical + prompt:\n{official_prompt}\nImplemented prompt:\n{implemented_prompt}""" + + assert torch.equal(official_inputs["pixel_values"], implemented_inputs["pixel_values"]) + + assert torch.equal(official_inputs["image_grid_thw"], implemented_inputs["image_grid_thw"]) + + assert implemented_inputs["input_ids"].shape == implemented_inputs["action_mask"].shape, f"""Official action mask shape: {official_inputs["action_mask"].shape}\nImplemented action mask shape: {implemented_inputs["action_mask"].shape}""" + + + print(f"official_prompt: {official_prompt}\nimplemented_prompt: {tokenizer.decode(implemented_inputs['input_ids'][0])}\nofficial_inputs: {official_inputs.keys()}\nimplemented_inputs: {implemented_inputs.keys()}\n") \ No newline at end of file diff --git a/verl b/verl index 0186ec8..861f63b 160000 --- a/verl +++ b/verl @@ -1 +1 @@ -Subproject commit 0186ec81e9273953d2f34b0c4d48741cc2f9aabc +Subproject commit 861f63ba8097a43ababe27116842512783080586