From 570436735bb3698f8188e0029cf47b67066e2e01 Mon Sep 17 00:00:00 2001 From: Reason-Wang Date: Tue, 5 Aug 2025 06:51:00 +0000 Subject: [PATCH 1/6] Fix openai agent --- agents/agents/agents/agent_base.py | 100 ++++++--- agents/agents/agents/backend_config.py | 66 ++++++ agents/agents/agents/chain/chain_base.py | 56 ++--- agents/agents/agents/llm_backend.py | 100 ++++----- .../agents/agents/specialized/openai_agent.py | 204 ++++++++++-------- agents/agents/agents/templates/utils.py | 31 ++- agents/agents/agents/utils/tokenizer.py | 10 + agents/agents/examples/streaming_example.py | 59 ----- 8 files changed, 348 insertions(+), 278 deletions(-) create mode 100644 agents/agents/agents/backend_config.py create mode 100644 agents/agents/agents/utils/tokenizer.py delete mode 100644 agents/agents/examples/streaming_example.py diff --git a/agents/agents/agents/agent_base.py b/agents/agents/agents/agent_base.py index 0d712ce..8987dca 100644 --- a/agents/agents/agents/agent_base.py +++ b/agents/agents/agents/agent_base.py @@ -4,7 +4,7 @@ from .templates.templates import get_template from ..__init__ import AGENT_DATA_DIR -from .llm_backend import AsyncVLLMBackend, AsyncVerlBackend, ClientBackend, TransformersBackend, VLLMBackend, VerlBackend +from .llm_backend import AsyncVLLMBackend, AsyncVerlBackend, ClientBackend, TransformersBackend, VLLMBackend from ..utils.logging import get_logger from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -15,6 +15,8 @@ import transformers import warnings from .chain.streaming_observer import ConsoleStreamObserver, StreamingManager +from .utils.tokenizer import create_tokenizer +from .backend_config import BACKEND_CONFIGS try: from verl.protocol import DataProto except ImportError: @@ -34,12 +36,13 @@ class BaseAgent(ChainGeneration, ABC): def __init__( self, model_name_or_path, - template: str, + template: str=None, system_prompt: str = None, tools: List = None, max_length: int=8192, debug: bool = False, backend: str = "transformers", + backend_config: Any = None, reward_fn: Callable = None, log_file: str = "agent", project_name: str = None, @@ -65,9 +68,30 @@ def __init__( self.tools = tools self.system_prompt = system_prompt self.model_name_or_path = model_name_or_path - self.llm_engine, self.tokenizer, self.processor = self._init_llm_engine(model_name_or_path, backend) + + # Handle backend configuration + if backend_config is None: + # Use default configuration for the backend + config_class = BACKEND_CONFIGS.get(backend) + if config_class: + self.backend_config = config_class() + else: + self.backend_config = None + else: + self.backend_config = backend_config + + self.llm_engine = self._init_llm_engine(model_name_or_path, backend) + + # Create appropriate tokenizer for trajectory processing + self.tokenizer = create_tokenizer(model_name_or_path) + self._reward_fn = reward_fn - self.jinja_template = get_template(self.template).jinja_template() + + if self.template is None: + self.jinja_template = None + else: + self.jinja_template = get_template(self.template).jinja_template() + self.project_name = project_name self.run_name = run_name self.streaming_manager = StreamingManager() @@ -78,33 +102,53 @@ def __init__( raise ValueError(f"Streaming mode {streaming} is not supported.") super().__init__() if kwargs: - warnings.warn(f"Unused arguments for agent initialization: {kwargs}") + # warnings.warn(f"Unused arguments for agent initialization: {kwargs}") + raise ValueError(f"Unused arguments for agent initialization: {kwargs}") def _init_llm_engine(self, model_name_or_path: str, backend: str): if isinstance(model_name_or_path, str): + # Extract backend-specific configuration + config_kwargs = {} + if self.backend_config: + config_kwargs = {k: v for k, v in self.backend_config.__dict__.items() + if not k.startswith('_')} + if backend == "transformers": - llm_engine = TransformersBackend(model_name_or_path, self.template, max_length=self.max_length) - elif backend == "vllm": - llm_engine = VLLMBackend(model_name_or_path, self.template, max_length=self.max_length) + llm_engine = TransformersBackend( + model_name_or_path, + self.template, + max_length=self.max_length, + **config_kwargs + ) elif backend == "async_vllm": - llm_engine = AsyncVLLMBackend(model_name_or_path, self.template, max_length=self.max_length) - elif backend == "verl": - llm_engine = VerlBackend(llm_engine=None, model_name_or_path=model_name_or_path, template=self.template, max_length=self.max_length) + llm_engine = AsyncVLLMBackend( + model_name_or_path, + self.template, + max_length=self.max_length, + **config_kwargs + ) elif backend == "async_verl": - llm_engine = AsyncVerlBackend(llm_engine=None, model_name_or_path=model_name_or_path, template=self.template, max_length=self.max_length) + llm_engine = AsyncVerlBackend( + llm_engine=None, + model_name_or_path=model_name_or_path, + template=self.template, + max_length=self.max_length, + **config_kwargs + ) elif backend == "client": - llm_engine = ClientBackend(model_name_or_path, self.template, max_length=self.max_length) + print(f"config_kwargs: {config_kwargs}") + llm_engine = ClientBackend( + model_name_or_path, + self.template, + max_length=self.max_length, + **config_kwargs + ) else: raise ValueError(f"Backend {backend} is not supported.") else: raise ValueError("model_name_or_path must be a string.") - tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path) - if is_vlm_template(self.template): - processor = transformers.AutoProcessor.from_pretrained(model_name_or_path) - else: - processor = None - return llm_engine, tokenizer, processor + return llm_engine def set_llm_engine(self, llm_engine: Any, tokenizer: Any): assert self.backend == "async_verl", "Only async verl backend is supported for now" @@ -151,17 +195,6 @@ async def generate_streaming(self, messages_list_or_inputs: List[List[Dict]], st @property def timing_data(self): return self.timer.timing_data - - def forward(self, messages_list_or_inputs: List[List[Dict]], **args): - if isinstance(messages_list_or_inputs, List): - inputs = tokenize_conversations(messages_list_or_inputs, tokenizer=self.tokenizer, conv_template=self.template, max_length=self.max_length, processor=self.processor) - else: - raise ValueError("messages_list_or_inputs must be a list of messages or a dictionary of padded inputs.") - - if isinstance(self.llm_engine, transformers.PreTrainedModel): - return self.llm_engine.forward(**inputs, **args) # Only support transformers models for now. - else: - raise ValueError("llm_engine must be a transformers.PretrainedModel.") @property def trajectories(self): @@ -169,7 +202,10 @@ def trajectories(self): return trajectories - def tokenize_trajectories(self, return_action_mask: bool = False, return_reward_mask: bool = False): + def tokenize_trajectories(self, tokenizer, return_action_mask: bool = False, return_reward_mask: bool = False): + if tokenizer is None: + tokenizer = self.tokenizer + trajectories = self.trajectories self.logger.info("================ Trajectory ================") self.logger.info(trajectories[0]) @@ -196,7 +232,7 @@ def tokenize_trajectories(self, return_action_mask: bool = False, return_reward_ info['last_response'] = last_response other_info_list.append(info) - inputs = tokenize_conversations(messages_list, tokenizer=self.tokenizer, conv_template=self.template, processor=self.processor, max_length=self.max_length, return_reward_mask=return_reward_mask) + inputs = tokenize_conversations(messages_list, tokenizer=tokenizer, conv_template=self.template, processor=self.processor, max_length=self.max_length, return_reward_mask=return_reward_mask) position_ids = torch.clip(torch.cumsum(inputs['attention_mask'], dim=-1) - 1, min=0, max=None) inputs['position_ids'] = position_ids diff --git a/agents/agents/agents/backend_config.py b/agents/agents/agents/backend_config.py new file mode 100644 index 0000000..54a5a62 --- /dev/null +++ b/agents/agents/agents/backend_config.py @@ -0,0 +1,66 @@ +from dataclasses import dataclass +from typing import Optional, Dict, Any, List +import asyncio + + +@dataclass +class TransformersConfig: + """Configuration for Transformers backend""" + temperature: float = 1.0 + max_new_tokens: int = 1024 + trust_remote_code: bool = True + device_map: str = "auto" + + +@dataclass +class VLLMConfig: + """Configuration for VLLM backend""" + temperature: float = 1.0 + max_new_tokens: int = 1024 + # Add other vLLM specific parameters as needed + + +@dataclass +class AsyncVLLMConfig: + """Configuration for Async VLLM backend""" + temperature: float = 1.0 + max_new_tokens: int = 1024 + # Add other async vLLM specific parameters as needed + + +@dataclass +class VerlConfig: + """Configuration for Verl backend""" + temperature: float = 1.0 + max_new_tokens: int = 1024 + # Add other Verl specific parameters as needed + + +@dataclass +class AsyncVerlConfig: + """Configuration for Async Verl backend""" + temperature: float = 1.0 + max_new_tokens: int = 1024 + # Add other async Verl specific parameters as needed + + +@dataclass +class ClientConfig: + """Configuration for Client backend (OpenAI-compatible)""" + base_url: str = "http://localhost:8000/v1" + max_requests_per_minute: int = 100 + timeout: int = 600 + api_key: str = "EMPTY" + max_new_tokens: int = 1024 + temperature: float = 1.0 + + +# Backend configuration mapping +BACKEND_CONFIGS = { + "transformers": TransformersConfig, + "vllm": VLLMConfig, + "async_vllm": AsyncVLLMConfig, + "verl": VerlConfig, + "async_verl": AsyncVerlConfig, + "client": ClientConfig, +} \ No newline at end of file diff --git a/agents/agents/agents/chain/chain_base.py b/agents/agents/agents/chain/chain_base.py index 843b681..02fa27d 100644 --- a/agents/agents/agents/chain/chain_base.py +++ b/agents/agents/agents/chain/chain_base.py @@ -316,14 +316,40 @@ async def _run_single_chain(self, # Handle tool calls if current_node.messages[-1].get("tool_calls"): for tool_call in current_node.messages[-1]["tool_calls"]: - current_node = await self._execute_tool_call( + result = await self._execute_tool_call( tool_call, newest_messages, chain, chain_id, depth, have_set_tools, enable_streaming ) have_set_tools = True + + # Create action input node + action_input_node = chain.add_node( + type="Action Input", + messages=deepcopy(newest_messages), + description=result.get("arguments", "") + ) + + # Process observation + observation = result["observation"] + observation_json = json.dumps({ + "name": result["name"], + "content": observation, + }, indent=4) + + action_input_node.observation = observation_json + action_input_node.observation_code = result["status"] + newest_messages.append({ + "role": "tool", + "tool_call_id": tool_call["id"], + "content": [{"type": "text", "text": observation_json}], + }) + action_input_node.messages = deepcopy(newest_messages) + action_input_node.is_terminal = result["status"] in self.terminal_status else: # No tool calls, chain is finished break + + current_node = action_input_node depth += 1 @@ -463,33 +489,9 @@ async def _execute_tool_call(self, tool_call, newest_messages, chain, chain_id, step=depth, depth=depth )) + + return result - - # Create action input node - action_input_node = chain.add_node( - type="Action Input", - messages=deepcopy(newest_messages), - description=result.get("arguments", "") - ) - - # Process observation - observation = result["observation"] - observation_json = json.dumps({ - "name": result["name"], - "content": observation, - }, indent=4) - - action_input_node.observation = observation_json - action_input_node.observation_code = result["status"] - newest_messages.append({ - "role": "tool", - "tool_call_id": tool_call["id"], - "content": [{"type": "text", "text": observation_json}], - }) - action_input_node.messages = deepcopy(newest_messages) - action_input_node.is_terminal = result["status"] in self.terminal_status - - return action_input_node async def _finalize_chain(self, chain_id, chain, current_node, depth): """Finalize the chain with reward calculation and cleanup.""" diff --git a/agents/agents/agents/llm_backend.py b/agents/agents/agents/llm_backend.py index 2401137..697c68a 100644 --- a/agents/agents/agents/llm_backend.py +++ b/agents/agents/agents/llm_backend.py @@ -325,49 +325,6 @@ async def generate_streaming(self, messages_list: List[List[Dict]], **kwargs) -> if hasattr(sequence, 'text'): yield sequence.text -class VerlBackend(LLMBackend): - """Verl implementation""" - - def __init__(self, llm_engine, model_name_or_path: str, template: str, max_length: int=8192, **kwargs): - super().__init__(**kwargs) - self.model_name = model_name_or_path - self.max_length = max_length - self.template = template - self.tokenizer = AutoTokenizer.from_pretrained( - self.model_name, - trust_remote_code=True, - ) - self.llm_engine = llm_engine - - def generate(self, messages_list: str, **kwargs) -> str: - """Generate text from prompt using Verl""" - # We need to build a DataProto from the prompts - prompts = self.apply_chat_template(messages_list, self.template) - inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") - # We need to do padding for compatibility with the verl DataProto, which - # assumes that the batch size must be divisible by the dp size - world_size = self.llm_engine.world_size - inputs['input_ids'] = pad_tensor_to_rank_size(inputs['input_ids'], world_size) - inputs['attention_mask'] = pad_tensor_to_rank_size(inputs['attention_mask'], world_size) - - position_ids = torch.clip(torch.cumsum(inputs.attention_mask, dim=-1) - 1, min=0, max=None) - inputs['position_ids'] = position_ids - - n = kwargs.get("num_return_sequences", 1) - temperature = kwargs.get("temperature", 1.0) - use_agent = True - batch = DataProto.from_single_dict(inputs, meta_info={"n": n, "use_agent": use_agent, "temperature": temperature}) - - gen_batch_output = self.llm_engine.generate_sequences(batch) - responses = gen_batch_output.batch['responses'] # BS x L - response_texts = self.tokenizer.batch_decode(responses, skip_special_tokens=True) # List of string with length BS - response_texts = response_texts[:len(prompts)*n] - return response_texts - - def generate_async(self, messages_list: str, **kwargs) -> str: - raise NotImplementedError("Verl backend does not support async generation") - - class AsyncVerlBackend(LLMBackend): """Verl implementation""" @@ -465,24 +422,39 @@ def __init__( # --------------------------------------------------------------------- # # Low‑level single request (runs in threadpool so it doesn't block loop) # --------------------------------------------------------------------- # - @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15)) - def _blocking_call(self, messages: List[List[Dict]], **kw) -> str: - if "num_return_sequences" in kw: - n = kw.pop("num_return_sequences") + @retry(stop=stop_after_attempt(1), wait=wait_exponential(multiplier=1, min=4, max=15)) + def _blocking_call(self, messages: List[List[Dict]], **kwargs) -> str: + if "num_return_sequences" in kwargs: + n = kwargs.pop("num_return_sequences") else: n = 1 + + if "tool_choice" in kwargs: + tool_choice = kwargs.pop("tool_choice") + else: + tool_choice = "none" + + print(f"[ClientBackend] messages: {messages}") resp = self.client.chat.completions.create( model=self.model_name, messages=messages, timeout=self.timeout, max_tokens=self.max_new_tokens, n=n, - tool_choice="none", - **kw, + tool_choice=tool_choice, + **kwargs, ) - response_texts = [choice.message.content for choice in resp.choices] + resp_json = resp.dict() + response_texts = [choice["message"]["content"] for choice in resp_json["choices"]] + tool_calls = [choice["message"]["tool_calls"] for choice in resp_json["choices"]] - return response_texts + if tool_choice == "none": + return response_texts + else: + return { + "response_texts": response_texts, + "tool_calls": tool_calls, + } async def _call(self, messages: List[List[Dict]], **kw) -> str: # acquire a rate‑limit token @@ -495,7 +467,7 @@ async def _call(self, messages: List[List[Dict]], **kw) -> str: def async_generate( self, messages: List[List[Dict]] | List[Dict], - **kw, + **kwargs, ) -> List[str] | asyncio.Task: """ • Pass a *list of messages* → single completion. @@ -510,14 +482,24 @@ def async_generate( messages_list = [messages] # single else: messages_list = messages # batch - + print(f"[ClientBackend] messages_list: {messages_list}") messages_list = [convert_messages_to_openai_format(messages) for messages in messages_list] async def _runner(): - tasks = [asyncio.create_task(self._call(_input, **kw)) for _input in messages_list] - texts_list = await asyncio.gather(*tasks) - response_texts = [text for texts in texts_list for text in texts] - return response_texts + tasks = [asyncio.create_task(self._call(_input, **kwargs)) for _input in messages_list] + # Flatten the response list + response_texts_list_or_dict = await asyncio.gather(*tasks) + # return is a dict if tool_choice is not none, otherwise a list of strings + if isinstance(response_texts_list_or_dict[0], dict): + response_texts = [text for response in response_texts_list_or_dict for text in response["response_texts"]] + tool_calls = [tool_call for response in response_texts_list_or_dict for tool_call in response["tool_calls"]] + return { + "response_texts": response_texts, + "tool_calls": tool_calls, + } + else: + response_texts = [text for response in response_texts_list_or_dict for text in response] + return response_texts try: loop = asyncio.get_running_loop() # ➊ already inside a loop? @@ -532,8 +514,8 @@ async def _runner(): async def generate_async(self, messages: List[List[Dict]] | List[Dict], - **kw) -> List[str]: - return await self.async_generate(messages, **kw) + **kwargs) -> List[str]: + return await self.async_generate(messages, **kwargs) # Background token‑bucket refill (one token each 60/max_rpm seconds) async def _refill_tokens(self): diff --git a/agents/agents/agents/specialized/openai_agent.py b/agents/agents/agents/specialized/openai_agent.py index f64a056..ae0f37a 100644 --- a/agents/agents/agents/specialized/openai_agent.py +++ b/agents/agents/agents/specialized/openai_agent.py @@ -1,10 +1,13 @@ import os -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple, Union from openai import OpenAI, AzureOpenAI import httpx - +import asyncio +from ...tools import answer_qa from ...tools.tool_base import tool from ..agent_base import BaseAgent +from ..llm_backend import ClientBackend +from ..backend_config import ClientConfig from tenacity import retry, wait_random_exponential, stop_after_attempt from termcolor import colored import json @@ -14,102 +17,113 @@ class OpenAIAgent(BaseAgent): def __init__( self, api_key="", + base_url="https://api.openai.com/v1", **kwargs ): - wrapper = kwargs.get("wrapper", False) - if not wrapper: - kwargs["wrapper"] = True - template = kwargs.get("template", "openai") - kwargs["template"] = template + assert api_key is not None and api_key != "", "API key is required" + backend = kwargs.get("backend", "client") + assert backend == "client", "OpenAI agent only supports client backend" + kwargs["backend"] = backend + + # Create client-specific configuration + client_config = ClientConfig( + api_key=api_key, + base_url=base_url, + max_requests_per_minute=kwargs.get("max_requests_per_minute", 100), + timeout=kwargs.get("timeout", 600), + max_new_tokens=kwargs.get("max_new_tokens", 1024), + temperature=kwargs.get("temperature", 1.0) + ) + kwargs["backend_config"] = client_config + + # Initialize the base class super(OpenAIAgent, self).__init__(**kwargs) + model_name_or_path = kwargs.get("model_name_or_path", "gpt-3.5-turbo") self.client = OpenAI(api_key=api_key) self.api_key = api_key self.model = model_name_or_path + + # For OpenAI models, we don't need a tokenizer for the LLM engine, but we still need one for trajectory processing + if self.backend == "client": + self.tokenizer = None + self.processor = None - def parse(self, messages_list: List[List[Dict]], tools: List[Any], **args): - # OpenAI use 'n' to specify the number of return sequences - num_return_sequences = args.get("num_return_sequences", 1) - tool_schemas = [tool.schema for tool in tools] - del args["num_return_sequences"] - args["n"] = num_return_sequences - - def process_message(message_data): - message, api_key, model, tool_schemas, args = message_data - client = OpenAI(api_key=api_key) - try: - json_data = { - "model": model, - "messages": message, - **args - } - if tool_schemas is not None: - json_data.update({"tools": tool_schemas}) - openai_response = client.chat.completions.create(**json_data) - result = openai_response.dict() - new_message = result['choices'][0]['message'] - new_message["loss"] = True - return new_message - except Exception as e: - print(f"Parsing Exception: {repr(e)}. Try again.") - return { - "role": "assistant", - "content": result, - "tool_calls": [], - "loss": True - } + async def generate_async(self, messages_list_or_inputs: List[List[Dict]], **args): + responses = await super().generate_async(messages_list_or_inputs, tool_choice="auto", **args) + return responses - # Prepare data for each message - message_data_list = [ - (message, self.api_key, self.model, tool_schemas, args) - for message in messages_list - ] - pool = Pool() - results = pool.map(process_message, message_data_list) - pool.close() - pool.join() - - return results - + def parse(self, responses: Union[Dict[str, List], List[str]], tools: List[Any], **args) -> List[Dict]: + """ + Parse responses into the correct message format. + + Args: + responses: List of response strings from the LLM. + tools: List of tools available to the agent. + **args: Additional arguments for parsing. + + Returns: + List of assistant messages in the correct format. + """ + + new_messages = [] + if isinstance(responses, dict): + for response, tool_calls in zip(responses["response_texts"], responses["tool_calls"]): + new_message = {"role": "assistant"} + + new_message["content"] = response + if len(tool_calls) > 0: + tool_calls = tool_calls[:1] + new_message["tool_calls"] = tool_calls + + new_messages.append(new_message) + elif isinstance(responses, list): + for content in responses: + new_message = {"role": "assistant", "content": content} + new_messages.append(new_message) + else: + raise ValueError(f"Invalid responses type: {type(responses)}") + + return new_messages - @retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3)) - def chat_completion_request( - self, - messages, - tools=None, - tool_choice=None, - model=None, - stop=None, - client=None, - **args - ): - if model is None: - model = self.model - if client is None: - client = self.client + # @retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(1)) + # def chat_completion_request( + # self, + # messages, + # tools=None, + # tool_choice=None, + # model=None, + # stop=None, + # client=None, + # **args + # ): + # if model is None: + # model = self.model + # if client is None: + # client = self.client - json_data = { - "model": model, - "messages": messages, - **args - } - if stop is not None: - json_data.update({"stop": stop}) - if tools is not None: - json_data.update({"tools": tools}) - if tool_choice is not None: - json_data.update({"tool_choice": tool_choice}) + # json_data = { + # "model": model, + # "messages": messages, + # **args + # } + # if stop is not None: + # json_data.update({"stop": stop}) + # if tools is not None: + # json_data.update({"tools": tools}) + # if tool_choice is not None: + # json_data.update({"tool_choice": tool_choice}) - try: - # We use chat completion API - openai_response = client.chat.completions.create(**json_data) - json_data = openai_response.dict() - return json_data - except Exception as e: - print(f"Unable to generate ChatCompletion response: {e}") - raise e + # try: + # # We use chat completion API + # openai_response = client.chat.completions.create(**json_data) + # json_data = openai_response.dict() + # return json_data + # except Exception as e: + # print(f"Unable to generate ChatCompletion response: {e}") + # raise e @tool() @@ -129,17 +143,23 @@ def get_current_weather(location: str, unit: str="fahrenheit"): if __name__ == "__main__": agent = OpenAIAgent( - model_name_or_path="gpt-3.5-turbo", - api_key="", - tools=[get_current_weather] + model_name_or_path="gpt-4o-mini", + api_key="OpenAI API Key", + tools=[get_current_weather,answer_qa] ) - messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}] - agent.run( - max_steps=3, + messages = [ + { + "messages": [ + {"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"} + ] + } + ] + asyncio.run(agent.run_async( + max_steps=5, start_messages=messages, num_chains=1 - ) + )) trajectories = agent.trajectories - print(trajectories[0]["messages"]) + print(f"""Trajectory: {trajectories[0]["messages"]}""") diff --git a/agents/agents/agents/templates/utils.py b/agents/agents/agents/templates/utils.py index ccb2258..3a11320 100644 --- a/agents/agents/agents/templates/utils.py +++ b/agents/agents/agents/templates/utils.py @@ -31,10 +31,10 @@ def convert_messages_to_openai_format(messages: list) -> list: """ messages = copy.deepcopy(messages) for message in messages: - if "tool_calls" in message: - del message["tool_calls"] - if "tool_call_id" in message: - del message["tool_call_id"] + # if "tool_calls" in message: + # del message["tool_calls"] + # if "tool_call_id" in message: + # del message["tool_call_id"] if "tool_choice" in message: del message["tool_choice"] return messages @@ -138,8 +138,16 @@ def tokenize_conversation( :param max_length: :return: input_ids, attention_mask, labels, action_mask """ - chat = Chat(template=template, messages=messages, tokenizer=tokenizer) - inputs = chat.tokenize(tokenizer, add_generation_prompt=add_generation_prompt, tools=tools) + # Check if tokenizer is our interface or a HuggingFace tokenizer + if hasattr(tokenizer, 'tokenizer'): # Our interface + # Use the underlying HuggingFace tokenizer for Chat template + hf_tokenizer = tokenizer.tokenizer + else: # Direct HuggingFace tokenizer + hf_tokenizer = tokenizer + + chat = Chat(template=template, messages=messages, tokenizer=hf_tokenizer) + inputs = chat.tokenize(hf_tokenizer, add_generation_prompt=add_generation_prompt, tools=tools) + if max_length is not None: inputs['input_ids'] = inputs['input_ids'][:, :max_length] inputs['attention_mask'] = inputs['attention_mask'][:, :max_length] @@ -245,7 +253,9 @@ def tokenize_conversations(messages_list, tokenizer, conv_template, max_length, batch_action_masks.append(inputs['action_mask'].squeeze(0)) if return_tensors == "pt": - batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) + # Use pad_token_id from the tokenizer interface + pad_token_id = getattr(tokenizer, 'pad_token_id', 0) + batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=pad_token_id) batch_attention_masks = torch.nn.utils.rnn.pad_sequence(batch_attention_masks, batch_first=True, padding_value=0) batch_labels = torch.nn.utils.rnn.pad_sequence(batch_labels, batch_first=True, padding_value=-100) batch_action_masks = torch.nn.utils.rnn.pad_sequence(batch_action_masks, batch_first=True, padding_value=0) @@ -312,7 +322,9 @@ def vllm_serve(model_name_or_path, template, tp, pp, dp): os.makedirs(f"{AGENT_DATA_DIR}/cache") with open(f"{AGENT_DATA_DIR}/cache/jinja_template.jinja", "w") as f: f.write(jinja_template) - command = f"vllm serve {model_name_or_path} --chat-template {AGENT_DATA_DIR}/cache/jinja_template.jinja --tensor-parallel-size {tp} --pipeline-parallel-size {pp} --data-parallel-size {dp} --port {port} --enable-auto-tool-choice --tool-call-parser hermes --expand-tools-even-if-tool-choice-none" + # command = f"vllm serve {model_name_or_path} --chat-template {AGENT_DATA_DIR}/cache/jinja_template.jinja --tensor-parallel-size {tp} --pipeline-parallel-size {pp} --data-parallel-size {dp} --port {port} --enable-auto-tool-choice --tool-call-parser hermes --expand-tools-even-if-tool-choice-none" + command = f"vllm serve {model_name_or_path} --tensor-parallel-size {tp} --pipeline-parallel-size {pp} --data-parallel-size {dp} --port {port} --enable-auto-tool-choice --tool-call-parser hermes --expand-tools-even-if-tool-choice-none" + print(command) os.system(command) @@ -321,5 +333,6 @@ def vllm_serve(model_name_or_path, template, tp, pp, dp): "python -m agents.agents.templates.utils" # model = "/mnt/sharefs/users/haonan.li/models/Qwen2.5-7B-instruct-am_think_v1_distilled" model = "Qwen/Qwen2.5-7B-Instruct" - vllm_serve(model, "qwen2.5-think", 2, 1, 4) + # vllm_serve(model, "qwen2.5-think", 2, 1, 4) + vllm_serve(model, "qwen2.5", 1, 1, 1) diff --git a/agents/agents/agents/utils/tokenizer.py b/agents/agents/agents/utils/tokenizer.py new file mode 100644 index 0000000..00ab8fe --- /dev/null +++ b/agents/agents/agents/utils/tokenizer.py @@ -0,0 +1,10 @@ +from transformers import AutoTokenizer + +def create_tokenizer(model_name_or_path: str): + try: + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + # Can not find the tokenizer in local directory or huggingface hub + except OSError: + tokenizer = None + + return tokenizer diff --git a/agents/agents/examples/streaming_example.py b/agents/agents/examples/streaming_example.py deleted file mode 100644 index c4fd08b..0000000 --- a/agents/agents/examples/streaming_example.py +++ /dev/null @@ -1,59 +0,0 @@ -import asyncio -from agents.agents.react.react_agent import ReactAgent -from agents.tools import code_interpreter -from agents.rewards import math_reward_tool -import json -from agents.agents.chain.streaming_observer import ConsoleStreamObserver - - -async def main(): - tools = [code_interpreter] - - agent = ReactAgent( - "Qwen/Qwen2.5-7B-Instruct", - tools=tools, - template="qwen2.5-no-tool", - backend="async_vllm", - reward_fn=math_reward_tool, - debug=True - ) - - - console_stream_observer = ConsoleStreamObserver() - agent.streaming_manager.add_observer(console_stream_observer) - - - question1 = "Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop." - answer1 = "204" - question2 = "$P(x)$ is a polynomial of degree $3n$ such that\n\\begin{eqnarray*} P(0) = P(3) = \\cdots &=& P(3n) = 2, \\\\ P(1) = P(4) = \\cdots &=& P(3n-2) = 1, \\\\ P(2) = P(5) = \\cdots &=& P(3n-1) = 0, \\quad\\text{ and }\\\\ && P(3n+1) = 730.\\end{eqnarray*}\nDetermine $n$." - answer2 = "3" - - - messages = [ - { - "messages": [ - {"role": "user", "content": f"{question1}"} - ], - "question": f"{question1}", - "answer": f"{answer1}" - }, - # { - # "messages": [ - # {"role": "user", "content": f"{question2}"} - # ], - # "question": f"{question2}", - # "answer": f"{answer2}" - # } - ] - - - await agent.run_async( - max_steps=4, - start_messages=messages, - num_chains=1, - enable_streaming=True - ) - -if __name__ == "__main__": - asyncio.run(main()) - From 816d7d3cb0625c2fe3905cf04fecfe1ad686f7aa Mon Sep 17 00:00:00 2001 From: Reason-Wang Date: Thu, 7 Aug 2025 15:32:52 +0000 Subject: [PATCH 2/6] Add vision templates --- agents/agents/agents/agent_base.py | 5 +- agents/agents/agents/llm_backend.py:6 | 0 agents/agents/agents/llm_backend.py:6: | 0 agents/agents/agents/templates/mm_plugin.py | 1912 +++++++++++++++++ agents/agents/agents/templates/templates.py | 214 +- .../agents/agents/templates/test_alignment.py | 170 ++ agents/agents/agents/templates/utils.py | 164 +- .../agents/templates/vision_processor.py | 651 ++++++ .../agents/prompts/test_qwen25_vl_prompt.py | 138 -- .../unit/agents/prompts/test_qwen3_prompt.py | 81 - .../agents/prompts/test_template_tokenize.py | 62 - .../unit/agents/prompts/test_think_prompt.py | 29 - .../agents/templates/test_qwen25_vl_prompt.py | 0 .../agents/templates/test_qwen3_prompt.py | 0 .../templates/test_template_utilities.py | 33 + .../test_text_templates_full_align.py | 0 .../test_text_templates_tokenize.py} | 40 +- .../agents/templates/test_think_prompt.py | 0 .../test_vision_templates_full_align.py | 87 + .../test_vision_templates_tokenize.py | 107 + 20 files changed, 3217 insertions(+), 476 deletions(-) create mode 100644 agents/agents/agents/llm_backend.py:6 create mode 100644 agents/agents/agents/llm_backend.py:6: create mode 100644 agents/agents/agents/templates/mm_plugin.py create mode 100644 agents/agents/agents/templates/test_alignment.py create mode 100644 agents/agents/agents/templates/vision_processor.py delete mode 100644 agents/tests/unit/agents/prompts/test_qwen25_vl_prompt.py delete mode 100644 agents/tests/unit/agents/prompts/test_qwen3_prompt.py delete mode 100644 agents/tests/unit/agents/prompts/test_template_tokenize.py delete mode 100644 agents/tests/unit/agents/prompts/test_think_prompt.py create mode 100644 agents/tests/unit/agents/templates/test_qwen25_vl_prompt.py create mode 100644 agents/tests/unit/agents/templates/test_qwen3_prompt.py create mode 100644 agents/tests/unit/agents/templates/test_template_utilities.py create mode 100644 agents/tests/unit/agents/templates/test_text_templates_full_align.py rename agents/tests/unit/agents/{prompts/test_templates.py => templates/test_text_templates_tokenize.py} (50%) create mode 100644 agents/tests/unit/agents/templates/test_think_prompt.py create mode 100644 agents/tests/unit/agents/templates/test_vision_templates_full_align.py create mode 100644 agents/tests/unit/agents/templates/test_vision_templates_tokenize.py diff --git a/agents/agents/agents/agent_base.py b/agents/agents/agents/agent_base.py index 0d712ce..67c9a6b 100644 --- a/agents/agents/agents/agent_base.py +++ b/agents/agents/agents/agent_base.py @@ -9,7 +9,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch -from .templates.utils import is_vlm_template, tokenize_conversations +from .templates.utils import tokenize_conversations +from .templates.vision_processor import is_vision_template from .chain.chain_base import ChainGeneration import os import transformers @@ -100,7 +101,7 @@ def _init_llm_engine(self, model_name_or_path: str, backend: str): raise ValueError("model_name_or_path must be a string.") tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path) - if is_vlm_template(self.template): + if is_vision_template(self.template): processor = transformers.AutoProcessor.from_pretrained(model_name_or_path) else: processor = None diff --git a/agents/agents/agents/llm_backend.py:6 b/agents/agents/agents/llm_backend.py:6 new file mode 100644 index 0000000..e69de29 diff --git a/agents/agents/agents/llm_backend.py:6: b/agents/agents/agents/llm_backend.py:6: new file mode 100644 index 0000000..e69de29 diff --git a/agents/agents/agents/templates/mm_plugin.py b/agents/agents/agents/templates/mm_plugin.py new file mode 100644 index 0000000..be8cea9 --- /dev/null +++ b/agents/agents/agents/templates/mm_plugin.py @@ -0,0 +1,1912 @@ +# Copyright 2025 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's Transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/processing_llava.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +import os +import re +from copy import deepcopy +from dataclasses import dataclass +from io import BytesIO +from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union + +import numpy as np +import torch +from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array +from transformers.models.mllama.processing_mllama import ( + convert_sparse_cross_attention_mask_to_dense, + get_cross_attention_token_mask, +) +from typing_extensions import override + +from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER +from ..extras.packages import ( + is_librosa_available, + is_pillow_available, + is_pyav_available, + is_transformers_version_greater_than, +) + + +if is_librosa_available(): + import librosa + + +if is_pillow_available(): + from PIL import Image + from PIL.Image import Image as ImageObject + + +if is_pyav_available(): + import av + + +if is_transformers_version_greater_than("4.52.0"): + from transformers.image_utils import make_flat_list_of_images + from transformers.video_utils import make_batched_videos +else: + from transformers.image_utils import make_batched_videos, make_flat_list_of_images + + +if TYPE_CHECKING: + from av.stream import Stream + from numpy.typing import NDArray + from transformers import PreTrainedTokenizer, ProcessorMixin + from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor + from transformers.image_processing_utils import BaseImageProcessor + + class EncodedImage(TypedDict): + path: Optional[str] + bytes: Optional[bytes] + + ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject] + VideoInput = Union[str, BinaryIO, list[list[ImageInput]]] + AudioInput = Union[str, BinaryIO, NDArray] + + class MMProcessor(ProcessorMixin): + patch_size: int + image_seq_length: int + num_additional_image_tokens: int + vision_feature_select_strategy: Literal["default", "full"] + + def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int: + pass + + +def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]: + r"""Get paligemma token type ids for computing loss. + + It is slightly different with the original token type ids where the prompt part is 0. + + Returns: + batch_token_type_ids: shape (batch_size, seq_length) + + """ + batch_token_type_ids = [] + for imglen, seqlen in zip(imglens, seqlens): + image_seqlen = imglen * processor.image_seq_length + batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen)) + + return batch_token_type_ids + + +def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcessor"): + r"""Get gemma3 token type ids for computing loss. + + Returns: + batch_token_type_ids: shape (batch_size, seq_length) + + """ + image_token_id: int = getattr(processor, "image_token_id") + batch_token_type_ids = [] + for token_ids in batch_ids: + token_ids = np.array(token_ids) + token_type_ids = np.zeros_like(token_ids) + token_type_ids[token_ids == image_token_id] = 1 + batch_token_type_ids.append(token_type_ids.tolist()) + + return batch_token_type_ids + + +def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]: + r"""Make nested list of images.""" + batch_images = [] + for imglen in imglens: + batch_images.append(images[:imglen]) + images = images[imglen:] + + return batch_images + + +def _check_video_is_nested_images(video: "VideoInput") -> bool: + r"""Check if the video is nested images.""" + return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video) + + +@dataclass +class MMPluginMixin: + image_token: Optional[str] + video_token: Optional[str] + audio_token: Optional[str] + expand_mm_tokens: bool = True + + def _validate_input( + self, + processor: Optional["MMProcessor"], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + ) -> None: + r"""Validate if this model accepts the input modalities.""" + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + video_processor: BaseImageProcessor = getattr( + processor, "video_processor", getattr(processor, "image_processor", None) + ) + feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) + if len(images) != 0 and self.image_token is None: + raise ValueError( + "This model does not support image input. Please check whether the correct `template` is used." + ) + + if len(videos) != 0 and self.video_token is None: + raise ValueError( + "This model does not support video input. Please check whether the correct `template` is used." + ) + + if len(audios) != 0 and self.audio_token is None: + raise ValueError( + "This model does not support audio input. Please check whether the correct `template` is used." + ) + + if self.image_token is not None and processor is None: + raise ValueError("Processor was not found, please check and update your model file.") + + if self.image_token is not None and image_processor is None: + raise ValueError("Image processor was not found, please check and update your model file.") + + if self.video_token is not None and video_processor is None: + raise ValueError("Video processor was not found, please check and update your model file.") + + if self.audio_token is not None and feature_extractor is None: + raise ValueError("Audio feature extractor was not found, please check and update your model file.") + + def _validate_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + ): + r"""Validate if the number of images, videos and audios match the number of placeholders in messages.""" + num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0 + for message in messages: + num_image_tokens += message["content"].count(IMAGE_PLACEHOLDER) + num_video_tokens += message["content"].count(VIDEO_PLACEHOLDER) + num_audio_tokens += message["content"].count(AUDIO_PLACEHOLDER) + + if len(images) != num_image_tokens: + raise ValueError( + f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens in {messages}." + ) + + if len(videos) != num_video_tokens: + raise ValueError( + f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens in {messages}." + ) + + if len(audios) != num_audio_tokens: + raise ValueError( + f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens in {messages}." + ) + + def _preprocess_image( + self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs + ) -> "ImageObject": + r"""Pre-process a single image.""" + if (image.width * image.height) > image_max_pixels: + resize_factor = math.sqrt(image_max_pixels / (image.width * image.height)) + width, height = int(image.width * resize_factor), int(image.height * resize_factor) + image = image.resize((width, height)) + + if (image.width * image.height) < image_min_pixels: + resize_factor = math.sqrt(image_min_pixels / (image.width * image.height)) + width, height = int(image.width * resize_factor), int(image.height * resize_factor) + image = image.resize((width, height)) + + if image.mode != "RGB": + image = image.convert("RGB") + + return image + + def _get_video_sample_indices( + self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs + ) -> list[int]: + r"""Compute video sample indices according to fps.""" + total_frames = video_stream.frames + if total_frames == 0: # infinite video + return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32) + + sample_frames = max(1, math.floor(float(video_stream.duration * video_stream.time_base) * video_fps)) + sample_frames = min(total_frames, video_maxlen, sample_frames) + return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) + + def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]: + r"""Regularize images to avoid error. Including reading and pre-processing.""" + results = [] + for image in images: + if isinstance(image, (str, BinaryIO)): + image = Image.open(image) + elif isinstance(image, bytes): + image = Image.open(BytesIO(image)) + elif isinstance(image, dict): + if image["bytes"] is not None: + image = Image.open(BytesIO(image["bytes"])) + else: + image = Image.open(image["path"]) + + if not isinstance(image, ImageObject): + raise ValueError(f"Expect input is a list of images, but got {type(image)}.") + + results.append(self._preprocess_image(image, **kwargs)) + + return {"images": results} + + def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]: + r"""Regularizes videos to avoid error. Including reading, resizing and converting.""" + results = [] + for video in videos: + frames: list[ImageObject] = [] + if _check_video_is_nested_images(video): + for frame in video: + if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame): + raise ValueError("Invalid image found in video frames.") + frames = video + else: + container = av.open(video, "r") + video_stream = next(stream for stream in container.streams if stream.type == "video") + sample_indices = self._get_video_sample_indices(video_stream, **kwargs) + container.seek(0) + for frame_idx, frame in enumerate(container.decode(video_stream)): + if frame_idx in sample_indices: + frames.append(frame.to_image()) + + frames = self._regularize_images(frames, **kwargs)["images"] + results.append(frames) + + return {"videos": results} + + def _regularize_audios( + self, audios: list["AudioInput"], sampling_rate: float, **kwargs + ) -> dict[str, Union[list["NDArray"], list[float]]]: + r"""Regularizes audios to avoid error. Including reading and resampling.""" + results, sampling_rates = [], [] + for audio in audios: + if not isinstance(audio, np.ndarray): + audio, sampling_rate = librosa.load(audio, sr=sampling_rate) + + results.append(audio) + sampling_rates.append(sampling_rate) + + return {"audios": results, "sampling_rates": sampling_rates} + + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "MMProcessor", + imglens: Optional[list[int]] = None, + ) -> dict[str, "torch.Tensor"]: + r"""Process visual inputs. + + Returns: (llava and paligemma) + pixel_values: tensor with shape (B, C, H, W) + + Returns: (qwen2-vl) + pixel_values: tensor with shape (num_patches, patch_dim) + image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height + where num_patches == torch.prod(image_grid_thw) + + Returns: (mllama) + pixel_values: tensor with shape + (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width) + For example, (2, 1, 4, 3, 560, 560). + aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1). + aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4). + num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1). + + """ + mm_inputs = {} + if len(images) != 0: + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + if imglens is not None: # if imglens are provided, make batched images + images = _make_batched_images(images, imglens) + + image_processor_kwargs = {} + if getattr(processor, "image_do_pan_and_scan", False): # gemma3 image processor + image_processor_kwargs.update( + { + "do_pan_and_scan": True, + "pan_and_scan_min_crop_size": 256, + "pan_and_scan_max_num_crops": 4, + "pan_and_scan_min_ratio_to_activate": 1.2, + } + ) + + mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs)) + + if len(videos) != 0: + video_processor: BaseImageProcessor = getattr( + processor, "video_processor", getattr(processor, "image_processor", None) + ) + videos = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + )["videos"] + if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava + mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt")) + else: # for llava_next_video + mm_inputs.update(video_processor(videos, return_tensors="pt")) + + if len(audios) != 0: + feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) + audios = self._regularize_audios( + audios, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + )["audios"] + mm_inputs.update( + feature_extractor( + audios, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + return_attention_mask=True, + padding="max_length", + return_tensors="pt", + ) + ) + mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask", None) # prevent conflicts + + return mm_inputs + + +@dataclass +class BasePlugin(MMPluginMixin): + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + r"""Pre-process input messages before tokenization for VLMs.""" + self._validate_input(processor, images, videos, audios) + return messages + + def process_token_ids( + self, + input_ids: list[int], + labels: Optional[list[int]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + tokenizer: "PreTrainedTokenizer", + processor: Optional["MMProcessor"], + ) -> tuple[list[int], Optional[list[int]]]: + r"""Pre-process token ids after tokenization for VLMs.""" + self._validate_input(processor, images, videos, audios) + return input_ids, labels + + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + r"""Build batched multimodal inputs for VLMs. + + Arguments: + images: a list of image inputs, shape (num_images,) + videos: a list of video inputs, shape (num_videos,) + audios: a list of audio inputs, shape (num_audios,) + imglens: number of images in each sample, shape (batch_size,) + vidlens: number of videos in each sample, shape (batch_size,) + audlens: number of audios in each sample, shape (batch_size,) + batch_ids: token ids of input samples, shape (batch_size, seq_len) + processor: a processor for pre-processing images and videos + + """ + self._validate_input(processor, images, videos, audios) + return self._get_mm_inputs(images, videos, audios, processor) + + +@dataclass +class Gemma3Plugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens = 0 + messages = deepcopy(messages) + boi_token: str = getattr(processor, "boi_token") + full_image_sequence: str = getattr(processor, "full_image_sequence") + image_str = full_image_sequence if self.expand_mm_tokens else boi_token + + do_pan_and_scan: bool = getattr(processor, "image_do_pan_and_scan", False) + if do_pan_and_scan: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + if do_pan_and_scan: + image_placeholder_str = ( + "Here is the original image {{image}} and here are some crops to help you see better " + + " ".join(["{{image}}"] * mm_inputs["num_crops"][0][num_image_tokens]) + ) + else: + image_placeholder_str = "{{image}}" + + content = content.replace(IMAGE_PLACEHOLDER, image_placeholder_str, 1) + num_image_tokens += 1 + + message["content"] = content.replace("{{image}}", image_str) + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + mm_inputs.pop("num_crops", None) + mm_inputs["token_type_ids"] = _get_gemma3_token_type_ids(batch_ids, processor) + return mm_inputs + + +class Gemma3nPlugin(Gemma3Plugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + messages = deepcopy(messages) + boi_token: str = getattr(processor, "boi_token") + boa_token: str = getattr(processor, "boa_token") + full_image_sequence: str = getattr(processor, "full_image_sequence") + full_audio_sequence: str = getattr(processor, "full_audio_sequence") + image_str = full_image_sequence if self.expand_mm_tokens else boi_token + audio_str = full_audio_sequence if self.expand_mm_tokens else boa_token + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + content = content.replace(IMAGE_PLACEHOLDER, image_str, 1) + + while AUDIO_PLACEHOLDER in content: + content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1) + + message["content"] = content + + return messages + + +@dataclass +class InternVLPlugin(BasePlugin): + @override + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "ProcessorMixin", + **kwargs, + ) -> dict[str, "torch.Tensor"]: + image_processor: BaseImageProcessor = getattr(processor, "image_processor") + image_processor_kwargs = {} + if getattr(processor, "crop_to_patches", False): + image_processor_kwargs.update( + { + "crop_to_patches": True, + "max_patches": 12, + "min_patches": 1, + } + ) + + mm_inputs = {} + image_video_patches = [] + + if len(images) != 0: + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 1024 * 1024), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + + if len(videos) != 0: + videos = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + )["videos"] + + if len(images) != 0: + images = make_flat_list_of_images(images) + image_inputs = image_processor(images=images, return_tensors="pt", **image_processor_kwargs) + image_num_patches = image_inputs.pop("num_patches") + image_pixel_values = image_inputs.pop("pixel_values") + image_num_patches_indices = np.cumsum(image_num_patches) + + if len(videos) != 0: + videos = make_batched_videos(videos) + num_frames_per_video = [len(video) for video in videos] + patch_indices = np.cumsum(num_frames_per_video) + image_processor_kwargs["crop_to_patches"] = False + video_inputs = image_processor(images=videos, return_tensors="pt", **image_processor_kwargs) + video_num_patches = video_inputs.pop("num_patches") + video_pixel_values = video_inputs.pop("pixel_values") + video_num_patches_indices = np.cumsum(video_num_patches) + + # NOT SUPPORT IMAGE VIDEO INTERLEAVED + if len(images) != 0 and image_pixel_values is not None: + for i in range(len(images)): + start_index = image_num_patches_indices[i - 1] if i > 0 else 0 + end_index = image_num_patches_indices[i] + image_video_patches.append(image_pixel_values[start_index:end_index]) + + if len(videos) != 0 and video_pixel_values is not None: + patch_indices_with_prefix = [0] + list(patch_indices) + for i in range(len(videos)): + current_patch_index = patch_indices_with_prefix[i] + end_patch_index = patch_indices_with_prefix[i + 1] + start_index = video_num_patches_indices[current_patch_index - 1] if i > 0 else 0 + end_index = video_num_patches_indices[end_patch_index - 1] + image_video_patches.append(video_pixel_values[start_index:end_index]) + + if len(images) != 0 or len(videos) != 0: + mm_inputs["pixel_values"] = torch.cat(image_video_patches, dim=0) + + if len(images) != 0: + mm_inputs.update({"image_num_patches": image_num_patches}) + + if len(videos) != 0: + mm_inputs.update({"video_patch_indices": patch_indices}) + mm_inputs.update({"video_num_patches": video_num_patches}) + + return mm_inputs + + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["ProcessorMixin"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens, num_video_tokens = 0, 0 + image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1 + messages = deepcopy(messages) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + + image_pixel_patch_list = mm_inputs.get("image_num_patches") # pathes of images + video_num_patches = mm_inputs.get("video_num_patches") # all patches for frames of videos + video_patch_indices = mm_inputs.get("video_patch_indices") # num frames of per video + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + content = content.replace( + IMAGE_PLACEHOLDER, + f"{'' * image_seqlen * image_pixel_patch_list[num_image_tokens]}", + 1, + ) + num_image_tokens += 1 + + while VIDEO_PLACEHOLDER in content: + current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0 + end_patch_index = video_patch_indices[num_video_tokens] + num_patches = list(video_num_patches[current_patch_index:end_patch_index]) + video_replaced_prompt = "\n".join( + f"Frame{i + 1}: {'' * image_seqlen * num_patches[i]}" + for i in range(len(num_patches)) + ) + content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1) + num_video_tokens += 1 + + message["content"] = content + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["ProcessorMixin"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + mm_inputs.pop("image_num_patches", None) + mm_inputs.pop("video_patch_indices", None) + mm_inputs.pop("video_num_patches", None) + return mm_inputs + + +class KimiVLPlugin(BasePlugin): + @override + def process_messages(self, messages, images, videos, audios, processor): + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + image_grid_hws = mm_inputs.get("image_grid_hws", []) + else: + image_grid_hws = [None] * len(images) + + num_image_tokens = 0 + image_processor: BaseImageProcessor = getattr(processor, "image_processor") + merge_length = math.prod(image_processor.merge_kernel_size) + messages = deepcopy(messages) + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1 + content = content.replace( + IMAGE_PLACEHOLDER, + f"<|media_start|>image<|media_content|>{self.image_token * image_seqlen}<|media_end|>", + 1, + ) + num_image_tokens += 1 + + message["content"] = content + + return messages + + +@dataclass +class Llama4Plugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values" in mm_inputs: + image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:] + num_patches_per_chunk = int( + (image_height // processor.patch_size) + * (image_width // processor.patch_size) + // processor.downsample_ratio + ) + aspect_ratios = mm_inputs.pop("aspect_ratios") + + num_image_tokens = 0 + messages = deepcopy(messages) + for message in messages: + content = message["content"] + if self.expand_mm_tokens: + placeholder_count = content.count(IMAGE_PLACEHOLDER) + prompt_splits = content.split(IMAGE_PLACEHOLDER) + new_content = [] + for local_image_index, split_part in enumerate(prompt_splits): + new_content.append(split_part) + if local_image_index < placeholder_count: + tokens_for_this_image = processor._prompt_split_image( + aspect_ratios[num_image_tokens], num_patches_per_chunk + ) + num_image_tokens += 1 + new_content.append(tokens_for_this_image) + + content = "".join(new_content) + else: + content = content.replace(IMAGE_PLACEHOLDER, self.image_token) + + message["content"] = content + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + mm_inputs.pop("aspect_ratios", None) + return mm_inputs + + +@dataclass +class LlavaPlugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + messages = deepcopy(messages) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values" in mm_inputs: + height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0])) + image_seqlen = (height // processor.patch_size) * ( + width // processor.patch_size + ) + processor.num_additional_image_tokens + if processor.vision_feature_select_strategy == "default": + image_seqlen -= 1 + else: + image_seqlen = 1 + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) + + message["content"] = content.replace("{{image}}", self.image_token) + + return messages + + +@dataclass +class LlavaNextPlugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens = 0 + messages = deepcopy(messages) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values" in mm_inputs: + image_sizes = iter(mm_inputs["image_sizes"].tolist()) + height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + if self.expand_mm_tokens: + orig_height, orig_width = next(image_sizes) + image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) + if processor.vision_feature_select_strategy == "default": + image_seqlen -= 1 + else: + image_seqlen = 1 + + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) + num_image_tokens += 1 + + message["content"] = content.replace("{{image}}", self.image_token) + + return messages + + +@dataclass +class LlavaNextVideoPlugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + messages = deepcopy(messages) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values" in mm_inputs: + image_sizes = iter(mm_inputs["image_sizes"].tolist()) + height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + if self.expand_mm_tokens: + orig_height, orig_width = next(image_sizes) + image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) + if processor.vision_feature_select_strategy == "default": + image_seqlen -= 1 + else: + image_seqlen = 1 + + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) + + message["content"] = content.replace("{{image}}", self.image_token) + + if self.expand_mm_tokens: + if "pixel_values_videos" in mm_inputs: + one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) + height, width = get_image_size(one_video[0]) + num_frames = one_video.shape[0] # frame dim is always after batch dim + image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer + else: + video_seqlen = 1 + + for message in messages: + content = message["content"] + while VIDEO_PLACEHOLDER in content: + content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1) + + message["content"] = content.replace("{{video}}", self.video_token) + + return messages + + +@dataclass +class MiniCPMVPlugin(BasePlugin): + @override + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "MMProcessor", + **kwargs, + ) -> dict[str, "torch.Tensor"]: + image_processor: BaseImageProcessor = getattr(processor, "image_processor") + mm_inputs = {} + if len(images) != 0: + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + if "valid_image_nums_ls" in kwargs: + valid_image_nums_ls = kwargs["valid_image_nums_ls"] + new_images = [] + idx = 0 + for valid_image_nums in valid_image_nums_ls: + new_images.append(images[idx : idx + valid_image_nums]) + idx += valid_image_nums + + images = new_images + + image_inputs = image_processor( + images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt" + ) + mm_inputs.update(image_inputs) + + if len(videos) != 0: + videos = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + )["videos"] + video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt") + mm_inputs.update(video_inputs) + + if len(audios) != 0: + audios = self._regularize_audios( + audios, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + )["audios"] + if "valid_audio_nums_ls" in kwargs: + valid_audio_nums_ls = kwargs["valid_audio_nums_ls"] + audios_ls = [] + idx = 0 + for valid_audio_nums in valid_audio_nums_ls: + audios_ls.append(audios[idx : idx + valid_audio_nums]) + idx += valid_audio_nums + else: + audios_ls = [audios] + + audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract( + audios_ls, + chunk_input=True, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + ) + audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens] + mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens}) + if kwargs.get("ret_phs", False): + mm_inputs.update({"audio_phs": audio_phs}) + + return mm_inputs + + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0 + messages = deepcopy(messages) + image_processor: BaseImageProcessor = getattr(processor, "image_processor") + mm_inputs, audio_inputs = {}, {} + if len(images) != 0 and len(videos) != 0: + raise ValueError("MiniCPM-V model does not support input images and videos at the same time.") + + if len(videos) != 0: + max_slice_nums = 2 + use_image_id = False + mm_inputs = self._get_mm_inputs([], videos, [], processor) + else: + max_slice_nums = image_processor.max_slice_nums + use_image_id = image_processor.use_image_id + + for i, message in enumerate(messages): + content = message["content"] + while IMAGE_PLACEHOLDER in content: + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) + num_image_tokens += 1 + + while VIDEO_PLACEHOLDER in content: + video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1 + content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1) + num_video_tokens += 1 + + while AUDIO_PLACEHOLDER in content: + content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1) + num_audio_tokens += 1 + + message["content"] = content.replace("{{image}}", "(./)").replace( + "{{audio}}", "()" + ) + + if len(images): + mm_inputs = self._get_mm_inputs(images, [], [], processor) + + if len(audios): + audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True) + + if self.expand_mm_tokens and mm_inputs: + pattern = "(./)" + image_sizes = mm_inputs["image_sizes"] + idx = 0 + for index, message in enumerate(messages): + text = message["content"] + image_tags = re.findall(pattern, text) + text_chunks = text.split(pattern) + final_text = "" + for i in range(len(image_tags)): + final_text = ( + final_text + + text_chunks[i] + + image_processor.get_slice_image_placeholder( + image_sizes[0][idx], idx, max_slice_nums, use_image_id + ) + ) + idx += 1 + + final_text += text_chunks[-1] + messages[index]["content"] = final_text + + if self.expand_mm_tokens and audio_inputs: + pattern = "()" + idx = 0 + for index, message in enumerate(messages): + text = message["content"] + audio_tags = re.findall(pattern, text) + text_chunks = text.split(pattern) + final_text = "" + for i in range(len(audio_tags)): + audio_placeholder = audio_inputs["audio_phs"][0][idx] + final_text = final_text + text_chunks[i] + audio_placeholder + idx += 1 + + final_text += text_chunks[-1] + messages[index]["content"] = final_text + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + # image bound + image_bounds_list = [] + valid_image_nums_ls = [] + for i, input_ids in enumerate(batch_ids): + input_ids_ = torch.tensor(input_ids) + start_cond = (input_ids_ == processor.tokenizer.im_start_id) | ( + input_ids_ == processor.tokenizer.slice_start_id + ) + end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id) + image_start_tokens = torch.where(start_cond)[0] + image_start_tokens += 1 + image_end_tokens = torch.where(end_cond)[0] + valid_image_nums_ls.append(imglens[i]) + image_bounds = torch.hstack( + [ + image_start_tokens.unsqueeze(-1), + image_end_tokens.unsqueeze(-1), + ] + ) + image_bounds_list.append(image_bounds) + + mm_inputs = self._get_mm_inputs(images, videos, [], processor, valid_image_nums_ls=valid_image_nums_ls) + if "tgt_sizes" not in mm_inputs: + dummy_data = [torch.empty(0) for _ in range(len(batch_ids))] + mm_inputs.update({"tgt_sizes": dummy_data, "pixel_values": dummy_data, "image_sizes": dummy_data}) + + mm_inputs.update({"image_bound": image_bounds_list}) + + if len(audios) > 0: + # audio bound + audio_bounds_ls = [] + spk_bounds_ls = [] + valid_audio_nums_ls = [] + + for input_ids, audiolen in zip(batch_ids, audlens): + input_ids_ = torch.tensor(input_ids) + audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0] + audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0] + assert len(audio_start_idx) == len(audio_end_idx) + audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)]) + audio_bounds_ls.append(audio_bounds) + valid_audio_nums_ls.append(audiolen) + + spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0] + spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0] + assert len(spk_start_idx) == len(spk_end_idx) + spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)]) + spk_bounds_ls.append(spk_bounds) + + audio_inputs = self._get_mm_inputs([], [], audios, processor, valid_audio_nums_ls=valid_audio_nums_ls) + mm_inputs.update(audio_inputs) + mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls}) + + return mm_inputs + + +@dataclass +class MllamaPlugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens = 0 + messages = deepcopy(messages) + for message in messages: + content = message["content"] + num_image_tokens += content.count(IMAGE_PLACEHOLDER) + message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token) + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens) + if mm_inputs: + num_tiles = mm_inputs.pop("num_tiles") + image_token_id: int = getattr(processor, "image_token_id") + max_image_tiles: int = getattr(processor.image_processor, "max_image_tiles") + cross_attention_token_mask = [ + get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids + ] + mm_inputs["cross_attention_mask"] = torch.from_numpy( + convert_sparse_cross_attention_mask_to_dense( + cross_attention_token_mask, + num_tiles=num_tiles, + max_num_tiles=max_image_tiles, + length=max(len(input_ids) for input_ids in batch_ids), + ) + ) # shape: (batch_size, length, max_num_images, max_num_tiles) + + return mm_inputs + + +@dataclass +class PaliGemmaPlugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens = 0 + messages = deepcopy(messages) + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + content = content.replace(IMAGE_PLACEHOLDER, "", 1) + num_image_tokens += 1 + + message["content"] = content + + return messages + + @override + def process_token_ids( + self, + input_ids: list[int], + labels: Optional[list[int]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + tokenizer: "PreTrainedTokenizer", + processor: Optional["MMProcessor"], + ) -> tuple[list[int], Optional[list[int]]]: + self._validate_input(processor, images, videos, audios) + num_images = len(images) + image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token + image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) + input_ids = [image_token_id] * num_images * image_seqlen + input_ids + if labels is not None: + labels = [IGNORE_INDEX] * num_images * image_seqlen + labels + + return input_ids, labels + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + seqlens = [len(input_ids) for input_ids in batch_ids] + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor) + return mm_inputs + + +@dataclass +class PixtralPlugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + messages = deepcopy(messages) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values" in mm_inputs: + # BC for transformers < 4.49.0 + if isinstance(mm_inputs["image_sizes"], list): + image_sizes = iter(mm_inputs["image_sizes"][0]) + else: + image_sizes = iter(mm_inputs["image_sizes"].tolist()) + + image_break_token: str = getattr(processor, "image_break_token") + image_end_token: str = getattr(processor, "image_end_token") + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + if self.expand_mm_tokens: + patch_size = processor.patch_size * getattr(processor, "spatial_merge_size", 1) + height, width = next(image_sizes) + num_height_tokens = height // patch_size + num_width_tokens = width // patch_size + replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens + replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list + replace_tokens[-1] = image_end_token + replace_str = "".join(replace_tokens) + else: + replace_str = self.image_token + + content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1) + + message["content"] = content + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + # ref to this commit https://github.com/huggingface/transformers/pull/35122 + # after transformers 4.49.0, the `image_sizes` is mandatory as an input parameter for Pixtral VisionEncoder forwarding. + # it can be passed into `LlavaConditionalGeneration` as a parameter. + if not is_transformers_version_greater_than("4.49.0"): + mm_inputs.pop("image_sizes", None) + return mm_inputs + + +@dataclass +class Qwen2AudioPlugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + bos_token: str = getattr(processor, "audio_bos_token") + eos_token: str = getattr(processor, "audio_eos_token") + messages = deepcopy(messages) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs([], [], audios, processor) + if "feature_attention_mask" in mm_inputs: + audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist() + + for message in messages: + content = message["content"] + while AUDIO_PLACEHOLDER in content: + if self.expand_mm_tokens: + audio_length = audio_lengths.pop(0) + input_length = (audio_length - 1) // 2 + 1 + audio_seqlen = (input_length - 2) // 2 + 1 + else: + audio_seqlen = 1 + + content = content.replace( + AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1 + ) + + message["content"] = content + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["MMProcessor"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + return self._get_mm_inputs(images, videos, audios, processor) + + +@dataclass +class Qwen2VLPlugin(BasePlugin): + @override + def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject": + image = super()._preprocess_image(image, **kwargs) + if min(image.width, image.height) < 28: + width, height = max(image.width, 28), max(image.height, 28) + image = image.resize((width, height)) + + if image.width / image.height > 200: + width, height = image.height * 180, image.height + image = image.resize((width, height)) + + if image.height / image.width > 200: + width, height = image.width, image.width * 180 + image = image.resize((width, height)) + + return image + + @override + def _regularize_videos( + self, videos: list["VideoInput"], **kwargs + ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]: + results, fps_per_video = [], [] + for video in videos: + frames: list[ImageObject] = [] + if _check_video_is_nested_images(video): + for frame in video: + if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame): + raise ValueError("Invalid image found in video frames.") + + frames = video + fps_per_video.append(kwargs.get("video_fps", 2.0)) + else: + container = av.open(video, "r") + video_stream = next(stream for stream in container.streams if stream.type == "video") + sample_indices = self._get_video_sample_indices(video_stream, **kwargs) + container.seek(0) + for frame_idx, frame in enumerate(container.decode(video_stream)): + if frame_idx in sample_indices: + frames.append(frame.to_image()) + + if video_stream.duration is None: + fps_per_video.append(kwargs.get("video_fps", 2.0)) + else: + fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base)) + + if len(frames) % 2 != 0: + frames.append(frames[-1]) + + frames = self._regularize_images(frames, **kwargs)["images"] + results.append(frames) + + return {"videos": results, "fps_per_video": fps_per_video} + + @override + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "MMProcessor", + ) -> dict[str, "torch.Tensor"]: + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + mm_inputs = {} + if len(images) != 0: + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + mm_inputs.update(image_processor(images, return_tensors="pt")) + + if len(videos) != 0: + video_data = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + ) + mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt")) + temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2) + if "second_per_grid_ts" in processor.model_input_names: + mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]] + + return mm_inputs + + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens, num_video_tokens = 0, 0 + messages = deepcopy(messages) + image_processor: BaseImageProcessor = getattr(processor, "image_processor") + + merge_length: int = getattr(image_processor, "merge_size") ** 2 + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + image_grid_thw = mm_inputs.get("image_grid_thw", []) + video_grid_thw = mm_inputs.get("video_grid_thw", []) + else: + image_grid_thw = [None] * len(images) + video_grid_thw = [None] * len(videos) + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1 + content = content.replace( + IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1 + ) + num_image_tokens += 1 + + while VIDEO_PLACEHOLDER in content: + video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1 + content = content.replace( + VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1 + ) + num_video_tokens += 1 + + message["content"] = content + + return messages + + +@dataclass +class GLM4VPlugin(Qwen2VLPlugin): + @override + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "MMProcessor", + ) -> dict[str, "torch.Tensor"]: + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + video_processor: BaseImageProcessor = getattr(processor, "video_processor", None) + mm_inputs = {} + if len(images) != 0: + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + mm_inputs.update(image_processor(images, return_tensors="pt")) + + if len(videos) != 0: + video_data = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + ) + # prepare video metadata + video_metadata = [ + {"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"] + ] + mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata)) + + return mm_inputs + + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens, num_video_tokens = 0, 0 + messages = deepcopy(messages) + image_processor: BaseImageProcessor = getattr(processor, "image_processor") + + merge_length: int = getattr(image_processor, "merge_size") ** 2 + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + image_grid_thw = mm_inputs.get("image_grid_thw", []) + video_grid_thw = mm_inputs.get("video_grid_thw", []) + num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now + timestamps = mm_inputs.get("timestamps", []) + + if hasattr(timestamps, "tolist"): + timestamps = timestamps.tolist() + + if not timestamps: + timestamps_list = [] + elif isinstance(timestamps[0], list): + timestamps_list = timestamps[0] + else: + timestamps_list = timestamps + + unique_timestamps = timestamps_list.copy() + selected_timestamps = unique_timestamps[:num_frames] + while len(selected_timestamps) < num_frames: + selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0) + + else: + image_grid_thw = [None] * len(images) + video_grid_thw = [None] * len(videos) + num_frames = 0 + selected_timestamps = [0] + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1 + content = content.replace( + IMAGE_PLACEHOLDER, f"<|begin_of_image|>{self.image_token * image_seqlen}<|end_of_image|>", 1 + ) + num_image_tokens += 1 + + while VIDEO_PLACEHOLDER in content: + video_structure = "" + for frame_index in range(num_frames): + video_seqlen = ( + video_grid_thw[num_video_tokens][1:].prod() // merge_length if self.expand_mm_tokens else 1 + ) + timestamp_sec = selected_timestamps[frame_index] + frame_structure = ( + f"<|begin_of_image|>{self.image_token * video_seqlen}<|end_of_image|>{timestamp_sec}" + ) + video_structure += frame_structure + + content = content.replace(VIDEO_PLACEHOLDER, f"<|begin_of_video|>{video_structure}<|end_of_video|>", 1) + num_video_tokens += 1 + + message["content"] = content + + return messages + + @override + def get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], + processor: Optional["ProcessorMixin"], + ) -> dict[str, Union[list[int], "torch.Tensor"]]: + self._validate_input(processor, images, videos, audios) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + mm_inputs.pop("timestamps", None) + return mm_inputs + + +class Qwen2OmniPlugin(Qwen2VLPlugin): + @override + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "MMProcessor", + ) -> dict[str, "torch.Tensor"]: + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) + mm_inputs = {} + if len(images) != 0: + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + mm_inputs.update(image_processor(images, return_tensors="pt")) + + if len(videos) != 0: + video_dict = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + ) + mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt")) + temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2) + mm_inputs["video_second_per_grid"] = torch.tensor( + [temporal_patch_size / fps for fps in video_dict["fps_per_video"]] + ) + + if len(audios) != 0: + audios = self._regularize_audios( + audios, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + )["audios"] + mm_inputs.update( + feature_extractor( + audios, + sampling_rate=getattr(processor, "audio_sampling_rate", 16000), + return_attention_mask=True, + padding="max_length", + return_tensors="pt", + ) + ) + mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts + + return mm_inputs + + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0 + messages = deepcopy(messages) + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) + + merge_length = processor.image_processor.merge_size**2 + use_audio_in_video = getattr(processor, "use_audio_in_video", False) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + image_grid_thw = mm_inputs.get("image_grid_thw", []) + video_grid_thw = mm_inputs.get("video_grid_thw", []) + if "feature_attention_mask" in mm_inputs: + input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1 + audio_lengths = (input_lengths - 2) // 2 + 1 + else: + mm_inputs = {} + image_grid_thw = [None] * len(images) + video_grid_thw = [None] * len(videos) + audio_lengths = [None] * len(audios) + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1 + content = content.replace( + IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1 + ) + num_image_tokens += 1 + + if ( + use_audio_in_video and len(audios) and len(videos) + ): # if use the audio of video # deal video token and audio token togather + if len(videos) != len(audios): + raise ValueError( + f"Number of videos ({len(videos)}) must match number of audios ({len(audios)}) when using audio in video." + ) + + while VIDEO_PLACEHOLDER in content: + video_pos = content.find(VIDEO_PLACEHOLDER) + audio_pos = content.find(AUDIO_PLACEHOLDER, video_pos) + if audio_pos == -1 or audio_pos < video_pos: + raise ValueError( + f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video." + ) + + audio_t_index = torch.arange(audio_lengths[num_audio_tokens]) + video_t_index = ( + torch.arange(video_grid_thw[num_video_tokens][0]) + .view(-1, 1, 1) + .expand( + -1, + video_grid_thw[num_video_tokens][1] // image_processor.merge_size, + video_grid_thw[num_video_tokens][2] // image_processor.merge_size, + ) + .flatten() + * mm_inputs["video_second_per_grid"][num_video_tokens] + * 25 # FIXME hardcode of position_id_per_seconds=25 + ).long() + t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2] + video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk) + audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk) + placeholder_string = "" + placeholder_string += "<|vision_bos|>" + "<|audio_bos|>" + for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))): + video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None + audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None + if video_chunk_index is not None: + placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0]) + + if audio_chunk_index is not None: + placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0]) + + placeholder_string += "<|audio_eos|>" + "<|vision_eos|>" + content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1) + content = content.replace(AUDIO_PLACEHOLDER, "", 1) + num_audio_tokens += 1 + num_video_tokens += 1 + else: + while AUDIO_PLACEHOLDER in content: + audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1 + content = content.replace( + AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1 + ) + num_audio_tokens += 1 + + while VIDEO_PLACEHOLDER in content: + video_seqlen = ( + video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1 + ) + content = content.replace( + VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1 + ) + num_video_tokens += 1 + + message["content"] = content + + return messages + + +@dataclass +class VideoLlavaPlugin(BasePlugin): + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + num_image_tokens, num_video_tokens = 0, 0 + messages = deepcopy(messages) + num_frames = 0 + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values_images" in mm_inputs: + height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values_images"][0])) + num_frames = 1 + + if "pixel_values_videos" in mm_inputs: + one_video = to_numpy_array(mm_inputs["pixel_values_videos"][0]) + height, width = get_image_size(one_video[0]) + num_frames = one_video.shape[0] # frame dim is always after batch dim + + if "pixel_values_images" in mm_inputs or "pixel_values_videos" in mm_inputs: + image_seqlen = (height // processor.patch_size) * ( + width // processor.patch_size + ) + processor.num_additional_image_tokens + video_seqlen = image_seqlen * num_frames + if processor.vision_feature_select_strategy == "default": + image_seqlen -= 1 + else: + image_seqlen, video_seqlen = 1, 1 + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) + num_image_tokens += 1 + + while VIDEO_PLACEHOLDER in content: + content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1) + num_video_tokens += 1 + + content = content.replace("{{image}}", self.image_token) + message["content"] = content.replace("{{video}}", self.video_token) + + return messages + + +PLUGINS = { + "base": BasePlugin, + "gemma3": Gemma3Plugin, + "glm4v": GLM4VPlugin, + "gemma3n": Gemma3nPlugin, + "intern_vl": InternVLPlugin, + "kimi_vl": KimiVLPlugin, + "llama4": Llama4Plugin, + "llava": LlavaPlugin, + "llava_next": LlavaNextPlugin, + "llava_next_video": LlavaNextVideoPlugin, + "minicpm_v": MiniCPMVPlugin, + "mllama": MllamaPlugin, + "paligemma": PaliGemmaPlugin, + "pixtral": PixtralPlugin, + "qwen2_audio": Qwen2AudioPlugin, + "qwen2_omni": Qwen2OmniPlugin, + "qwen2_vl": Qwen2VLPlugin, + "video_llava": VideoLlavaPlugin, +} + + +def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None: + r"""Register a multimodal plugin.""" + if name in PLUGINS: + raise ValueError(f"Multimodal plugin {name} already exists.") + + PLUGINS[name] = plugin_class + + +def get_mm_plugin( + name: str, + image_token: Optional[str] = None, + video_token: Optional[str] = None, + audio_token: Optional[str] = None, +) -> "BasePlugin": + r"""Get plugin for multimodal inputs.""" + if name not in PLUGINS: + raise ValueError(f"Multimodal plugin `{name}` not found.") + + return PLUGINS[name](image_token, video_token, audio_token) diff --git a/agents/agents/agents/templates/templates.py b/agents/agents/agents/templates/templates.py index 1b71cfe..fbdf63a 100644 --- a/agents/agents/agents/templates/templates.py +++ b/agents/agents/agents/templates/templates.py @@ -1,10 +1,3 @@ -""" -This file is a modified version of the original fastchat/conversation.py file. - -The original file can be found at: -https://github.com/LM-SYS/fastchat/blob/main/fastchat/conversation.py - -""" from collections import defaultdict from copy import copy @@ -13,31 +6,29 @@ import json from typing import List, Any, Dict, Union, Tuple import warnings - +import logging import torch from .preprocess import open_image_from_any from transformers import PreTrainedTokenizer +from .vision_processor import is_vision_template import re +Logger = logging.getLogger(__name__) + +# Add console handler if no handlers exist +if not Logger.handlers: + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + console_handler.setFormatter(formatter) + Logger.addHandler(console_handler) + class Role(Enum): SYSTEM = "system" USER = "user" ASSISTANT = "assistant" TOOL = "tool" ASSISTANT_PREFIX = "assistant_prefix" - - -class SeparatorStyle(IntEnum): - """Separator styles.""" - NO_COLON_SINGLE = auto() - LLAMA2 = auto() - LLAMA3 = auto() - CHATGLM = auto() - CHATML = auto() - ADD_SPACE_TWO = auto() - GENERAL = auto() # No sep, we take all sep as part of the role - GENERAL_STOP_ALL = auto() # No sep, and add stop_str at each turn - @dataclasses.dataclass @@ -68,6 +59,58 @@ class Template: image_token: str = None video_token: str = None + def __post_init__(self): + """Post-initialization to automatically register vision processor if vision tokens are defined""" + if self.image_token or self.video_token: + self._register_vision_processor() + + def _register_vision_processor(self): + """Automatically register a vision processor for this template""" + from .vision_processor import VisionProcessorConfig, register_processor + + # Determine model type based on template name + model_type = self._infer_model_type() + + # Create vision config + config = VisionProcessorConfig( + model_type=model_type, + image_token=self.image_token or "", + video_token=self.video_token or "", + vision_start=self.vision_start or "", + vision_end=self.vision_end or "", + processor_class="AutoProcessor", + expansion_strategy="patch_based" + ) + + # Register the processor + register_processor(self.name, config) + + def _infer_model_type(self) -> str: + """Infer model type from template name""" + name_lower = self.name.lower() + + if "qwen" in name_lower: + return "qwen_vl" + elif "llava" in name_lower: + return "llava" + elif "gemma" in name_lower: + return "gemma3" + elif "paligemma" in name_lower: + return "paligemma" + elif "internvl" in name_lower: + return "internvl" + elif "minicpm" in name_lower: + return "minicpm" + elif "mllama" in name_lower: + return "mllama" + elif "pixtral" in name_lower: + return "pixtral" + elif "video" in name_lower: + return "video_llava" + else: + # Default to patch-based for unknown models + return "patch_based" + def render(self, messages: List[Dict], tools=None, add_generation_prompt: bool = False) -> str: """Render the template with the given messages and kwargs. messages: [ @@ -208,8 +251,21 @@ def _split_assistant_message(self, assistant_message: str) -> List[str]: return generation_prefix, content, suffix - def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None) -> str: - prompt, elements, roles = self.render(messages, tools=tools) + def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, processor=None) -> str: + if processor is None and self.supports_vision(): + raise ValueError(f"Processor is required for vision templates: {self.name}") + + if self.supports_vision(): + # Use vision-aware encoding with proper alignment + return self._encode_with_vision_processor(messages, tokenizer, return_tensors, tools, add_generation_prompt=add_generation_prompt, processor=processor) + else: + # Use standard encoding + return self._encode_standard(messages, tokenizer, return_tensors, tools, add_generation_prompt=add_generation_prompt) + + def _encode_standard(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False) -> str: + Logger.debug(f"[Template] Encoding standard for template: {self.name}") + """Standard encoding without vision support""" + prompt, elements, roles = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt) elements, mask_flags = self._postprocess_elements(elements, roles) input_ids = [] attention_mask = [] @@ -241,6 +297,36 @@ def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_te if return_tensors == "pt": inputs = {k: torch.tensor([v]) for k, v in inputs.items()} return inputs + + def _encode_with_vision_processor(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, processor=None) -> str: + Logger.debug(f"[Template] Encoding with vision processor for template: {self.name}") + """Encode with vision processor handling proper alignment""" + from .vision_processor import get_processor + from .utils import extract_vision_inputs_from_messages + + # Get vision processor + vision_processor = get_processor(self.name) + if vision_processor is None: + raise ValueError(f"No vision processor registered for template: {self.name}") + + # Get base prompt and mask information + prompt, elements, roles = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt) + elements, mask_flags = self._postprocess_elements(elements, roles) + + # Extract vision inputs + images, videos = extract_vision_inputs_from_messages(messages) + + # Use vision processor with alignment support + return vision_processor.process_for_llm( + prompt=prompt, + elements=elements, + mask_flags=mask_flags, + images=images, + videos=videos, + processor=processor, + tokenizer=tokenizer, + return_tensors=return_tensors + ) def _postprocess_elements(self, elements: List[str], roles) -> List[str]: @@ -302,6 +388,14 @@ def _postprocess_elements(self, elements: List[str], roles) -> List[str]: merged_mask_flags.append(prev_mask_flag) return merged_elements, merged_mask_flags + def supports_vision(self) -> bool: + """Check if this template supports vision processing""" + return is_vision_template(self.name) + + def get_vision_config(self): + """Get vision configuration for this template""" + from .vision_processor import VisionProcessorRegistry + return VisionProcessorRegistry.get_config(self.name) def get_vision_inputs(self): vision_inputs = defaultdict(list) @@ -318,9 +412,6 @@ def get_vision_inputs(self): raise ValueError(f"Invalid message type: {item['type']}") return vision_inputs - # if self.name == "qwen2.5-vl": - # jinja_template = """{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['role'] == 'tool'%}\n{% endif %}{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% if message['role'] == 'tool' %}\n\n{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}""" - def jinja_template(self) -> str: """ Build a Hugging-Face-style chat-template (Jinja-mini dialect) that mimics @@ -554,12 +645,12 @@ def prompt_with_mask(self, add_generation_prompt=False, tools=None) -> str: prompt_with_mask, _, _ = self.template.render_with_mask(messages=self.messages, add_generation_prompt=add_generation_prompt, tools=tools) return prompt_with_mask - def tokenize(self, tokenizer: PreTrainedTokenizer = None, add_generation_prompt=False, tools=None) -> List[int]: + def tokenize(self, tokenizer: PreTrainedTokenizer = None, add_generation_prompt=False, tools=None, processor=None) -> List[int]: if tokenizer is None: tokenizer = self.tokenizer if tools is None: tools = self.tools - return self.template.encode(messages=self.messages, tokenizer=tokenizer, return_tensors="pt", tools=tools) + return self.template.encode(messages=self.messages, tokenizer=tokenizer, return_tensors="pt", tools=tools, add_generation_prompt=add_generation_prompt, processor=processor) def append(self, message: Union[Dict, List[Dict]]): self._convert_single_message_to_hf_format(message) @@ -601,6 +692,23 @@ def get_template(name: str) -> Template: ) ) +register_template( + Template( + name="qwen2.5-vl", + system_template="<|im_start|>system\n{system_message}<|im_end|>\n", + system_message="You are a helpful assistant.", + user_template="<|im_start|>user\n{content}<|im_end|>\n", + assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n", + tool_template="<|im_start|>tool\n{observation}<|im_end|>\n", + vision_start="<|vision_start|>", + vision_end="<|vision_end|>", + image_token="<|image_pad|>", + video_token="<|video_pad|>", + stop_words=["<|im_end|>"], + ) +) + + register_template( Template( name="qwen2.5", @@ -648,47 +756,15 @@ def get_template(name: str) -> Template: ) ) -# register_conv_template( -# Template( -# name="qwen3", -# system_template="<|im_start|>system\n{system_message}<|im_end|>\n", -# # system_message="", -# system_template_tool="""<|im_start|>system\n{system_message}# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{{"name": , "arguments": }}\n<|im_end|>\n""", -# roles=("<|im_start|>user\n", "<|im_start|>assistant\n", "<|im_start|>user\n\n{observation}\n"), -# tool_aggregator="STACKED", -# sep_style=SeparatorStyle.GENERAL_STOP_ALL, -# sep="", -# stop_token_ids=[ -# 151643, -# 151644, -# 151645, -# ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>" -# stop_str="<|im_end|>\n", -# ) -# ) - -# register_conv_template( -# Template( -# name="qwen2.5-vl", -# system_template="<|im_start|>system\n{system_message}<|im_end|>\n", -# system_message="You are a helpful assistant.", -# roles=("<|im_start|>user\n", "<|im_start|>assistant\n", "<|im_start|>tool\n\n{observation}\n"), -# tool_aggregator="STACKED", -# sep_style=SeparatorStyle.GENERAL_STOP_ALL, -# sep="", -# vision_start="<|vision_start|>", -# vision_end="<|vision_end|>", -# image_token="<|image_pad|>", -# video_token="<|video_pad|>", -# stop_token_ids=[ -# 151643, -# 151644, -# 151645, -# ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>" -# stop_str="<|im_end|>\n", -# ) -# ) - +register_template( + Template( + name="deepseek-prover-v2", + system_template="<|begin▁of▁sentence|>{system_message}", + user_template="<|User|>{content}", + assistant_template="<|Assistant|>{content}<|end▁of▁sentence|>", + stop_words=["<|end▁of▁sentence|>"], + ) +) if __name__ == "__main__": diff --git a/agents/agents/agents/templates/test_alignment.py b/agents/agents/agents/templates/test_alignment.py new file mode 100644 index 0000000..4ef7756 --- /dev/null +++ b/agents/agents/agents/templates/test_alignment.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +""" +Test script to verify tensor alignment functionality in the vision processor system. +""" + +import torch +from typing import Dict, List, Any + +def test_tensor_alignment(): + """Test that all tensors are properly aligned after vision token expansion""" + print("=== Testing Tensor Alignment ===") + + # Mock vision processor for testing + class MockVisionProcessor: + def expand_vision_tokens(self, prompt, images, videos, processor): + # Expand single image token to multiple tokens + return prompt.replace("<|image_pad|>", "<|image_pad|><|image_pad|><|image_pad|>") + + def calculate_image_tokens(self, image_data, processor): + return 3 # Mock: each image expands to 3 tokens + + def preprocess_images(self, images, processor): + return {"pixel_values": torch.randn(1, 3, 224, 224)} + + def preprocess_videos(self, videos, processor): + return {} + + def get_mm_inputs(self, images, videos, processor): + return {"pixel_values": torch.randn(1, 3, 224, 224)} + + # Mock tokenizer + class MockTokenizer: + def encode(self, text, add_special_tokens=False): + if "<|image_pad|>" in text: + # Replace image tokens with token IDs + text = text.replace("<|image_pad|>", "999") + return [1, 2, 999, 999, 999, 4] # Expanded tokens + return [1, 2, 3, 4] # Regular tokens + + @property + def add_bos_token(self): + return True + + @property + def bos_token(self): + return "" + + @property + def bos_token_id(self): + return 0 + + # Test data + elements = [ + "What's in this image? <|image_pad|>", # User message (masked) + "I can see a cat in the image." # Assistant message (not masked) + ] + mask_flags = [True, False] # User message masked, assistant not + + processor = MockVisionProcessor() + tokenizer = MockTokenizer() + + # Simulate the alignment process + input_ids = [] + attention_mask = [] + labels = [] + action_mask = [] + + # Add BOS token + input_ids.append(tokenizer.bos_token_id) + attention_mask.append(1) + labels.append(-100) + action_mask.append(0) + + # Process each element + for element, mask_flag in zip(elements, mask_flags): + # Check if element contains vision tokens + if "<|image_pad|>" in element: + # Expand vision tokens + expanded_element = processor.expand_vision_tokens(element, ["image.jpg"], [], None) + cur_input_ids = tokenizer.encode(expanded_element, add_special_tokens=False) + else: + cur_input_ids = tokenizer.encode(element, add_special_tokens=False) + + # Add tokens with proper alignment + input_ids.extend(cur_input_ids) + attention_mask.extend([1] * len(cur_input_ids)) + + if mask_flag: + labels.extend([-100] * len(cur_input_ids)) + action_mask.extend([0] * len(cur_input_ids)) + else: + labels.extend(cur_input_ids) + action_mask.extend([1] * len(cur_input_ids)) + + # Convert to tensors + inputs = { + 'input_ids': torch.tensor([input_ids]), + 'attention_mask': torch.tensor([attention_mask]), + 'labels': torch.tensor([labels]), + 'action_mask': torch.tensor([action_mask]) + } + + # Verify alignment + print("Tensor shapes:") + for key, value in inputs.items(): + print(f" {key}: {value.shape}") + + # Check that all tensors have the same sequence length + seq_len = inputs['input_ids'].shape[1] + assert inputs['attention_mask'].shape[1] == seq_len, "attention_mask not aligned" + assert inputs['labels'].shape[1] == seq_len, "labels not aligned" + assert inputs['action_mask'].shape[1] == seq_len, "action_mask not aligned" + + print(f"✅ All tensors aligned with sequence length: {seq_len}") + + # Verify the content makes sense + print("\nTensor content verification:") + print(f"input_ids: {inputs['input_ids'][0].tolist()}") + print(f"attention_mask: {inputs['attention_mask'][0].tolist()}") + print(f"labels: {inputs['labels'][0].tolist()}") + print(f"action_mask: {inputs['action_mask'][0].tolist()}") + + # Check that vision tokens are properly handled + vision_token_positions = [i for i, token_id in enumerate(inputs['input_ids'][0]) if token_id == 999] + print(f"\nVision token positions: {vision_token_positions}") + + # Verify that all tensors have proper values at vision token positions + for pos in vision_token_positions: + assert inputs['attention_mask'][0][pos] == 1, f"attention_mask should be 1 at position {pos}" + # Labels and action_mask depend on whether it's in a masked region + # This is a simplified test - in practice, you'd check the actual mask flags + + print("✅ Vision tokens properly handled in all tensors") + print("✅ Tensor alignment test passed!") + +def test_vision_processor_integration(): + """Test integration with the actual vision processor system""" + print("\n=== Testing Vision Processor Integration ===") + + try: + from .vision_processor import VisionProcessorRegistry, VisionProcessorConfig, PatchBasedProcessor + from .templates import get_template + + # Check if qwen2.5-vl template is registered + if VisionProcessorRegistry.is_vision_template("qwen2.5-vl"): + print("✅ qwen2.5-vl template is registered") + + processor = VisionProcessorRegistry.get_processor("qwen2.5-vl") + if processor is not None: + print("✅ Vision processor retrieved successfully") + + # Test the contains_vision_tokens method + test_text = "What's in this image? <|image_pad|>" + has_vision = processor._contains_vision_tokens(test_text) + print(f"✅ Vision token detection: {has_vision}") + + else: + print("❌ Vision processor not found") + else: + print("❌ qwen2.5-vl template not registered") + + except ImportError as e: + print(f"❌ Import error: {e}") + except Exception as e: + print(f"❌ Error: {e}") + +if __name__ == "__main__": + test_tensor_alignment() + test_vision_processor_integration() + print("\n=== All Tests Completed ===") \ No newline at end of file diff --git a/agents/agents/agents/templates/utils.py b/agents/agents/agents/templates/utils.py index ccb2258..eb5a9e9 100644 --- a/agents/agents/agents/templates/utils.py +++ b/agents/agents/agents/templates/utils.py @@ -10,7 +10,8 @@ import logging from .templates import Chat, get_template from ... import AGENT_DATA_DIR - +from typing import Any +from .vision_processor import get_processor # Set up logging that won't be overridden by other modules LOGGER = logging.getLogger(__name__) @@ -20,9 +21,6 @@ def strip_ansi(s: str) -> str: """Remove ANSI escape sequences from a string.""" return ANSI_RE.sub('', s) -def is_vlm_template(template: str) -> bool: - return template in ["qwen2.5-vl"] - def convert_messages_to_openai_format(messages: list) -> list: """ @@ -139,12 +137,15 @@ def tokenize_conversation( :return: input_ids, attention_mask, labels, action_mask """ chat = Chat(template=template, messages=messages, tokenizer=tokenizer) - inputs = chat.tokenize(tokenizer, add_generation_prompt=add_generation_prompt, tools=tools) + inputs = chat.tokenize(tokenizer, add_generation_prompt=add_generation_prompt, tools=tools, processor=processor) + if max_length is not None: inputs['input_ids'] = inputs['input_ids'][:, :max_length] inputs['attention_mask'] = inputs['attention_mask'][:, :max_length] - inputs['labels'] = inputs['labels'][:, :max_length] - inputs['action_mask'] = inputs['action_mask'][:, :max_length] + if 'labels' in inputs: + inputs['labels'] = inputs['labels'][:, :max_length] + if 'action_mask' in inputs: + inputs['action_mask'] = inputs['action_mask'][:, :max_length] return inputs @@ -153,82 +154,85 @@ def convert_inputs_to_vision_inputs(template: str, processor, # AutoProcessor (not bare tokenizer) messages: list): """ - Expand `inputs` (built from a chat template that contains ONE - <|image_pad|> / <|video_pad|> placeholder per asset) into the real - processor outputs and stretch `action_mask` / `labels` so they stay - aligned after the pad tokens are repeated. - - Returns - ------- - dict -- processor(...) result + expanded masks + NEW PIPELINE: Template processes messages → Human-readable prompt → Vision processor → LLM-ready inputs + + The correct pipeline is: + 1. Template processes messages to get human-readable prompt with single multi-modal tokens + 2. Vision processor handles image/video processing and token expansion + 3. Final result is directly usable by LLMs with model(**inputs) """ - assert template == "qwen2.5-vl", "Only qwen2.5-vl is supported" - - - # ------------------------------------------------------------------ - # 1. special-token ids - # ------------------------------------------------------------------ - tk = processor.tokenizer - conv = get_conv_template(template) - image_pad_id = tk.encode(conv.image_token, add_special_tokens=False)[0] - video_pad_id = tk.encode(conv.video_token, add_special_tokens=False)[0] - vis_start_id = tk.encode(conv.vision_start, add_special_tokens=False)[0] - vis_end_id = tk.encode(conv.vision_end, add_special_tokens=False)[0] - - repeat_ids = torch.tensor([image_pad_id, video_pad_id], dtype=torch.long) - vision_ids = torch.tensor([image_pad_id, video_pad_id, - vis_start_id, vis_end_id], dtype=torch.long) - - # ------------------------------------------------------------------ - # 2. run the processor (adds patch-level vision tokens) - # ------------------------------------------------------------------ - LOGGER.debug(f"[Template::convert_inputs_to_vision_inputs] messages: {messages}") - imgs, vids = process_vision_info(messages) - text = format_conversation(messages, template).get_prompt() - proc_out = processor( - text=text, - images=imgs, - videos=vids, - return_tensors="pt", - padding=False, - truncation=False, + # Get the vision processor for this template + vision_processor = get_processor(template) + if vision_processor is None: + raise ValueError(f"No vision processor registered for template: {template}") + + # Step 1: Template processes messages to get human-readable prompt + from .templates import Chat + chat = Chat(template=template, messages=messages, tokenizer=processor.tokenizer) + prompt = chat.prompt() # This gives us human-readable prompt with single multi-modal tokens + + # Step 2: Extract vision inputs from messages + images, videos = extract_vision_inputs_from_messages(messages) + + # Step 3: Vision processor handles the complete pipeline + # This expands tokens and generates LLM-ready inputs + final_inputs = vision_processor.process_for_llm( + prompt=prompt, + images=images, + videos=videos, + processor=processor, + tokenizer=processor.tokenizer ) - new_ids = proc_out["input_ids"][0] # (new_len,) - new_attention_mask = proc_out["attention_mask"][0] - device = new_ids.device - - # ------------------------------------------------------------------ - # 3. original (pre-processor) tensors - # ------------------------------------------------------------------ - old_ids = inputs["input_ids"][0].to(device) - old_action = inputs.get("action_mask", None) - old_labels = inputs.get("labels", None) - old_attention_mask = inputs.get("attention_mask", None) - - # ------------------------------------------------------------------ - # 4. build boolean masks that mark NON-repeated positions - # ------------------------------------------------------------------ - new_keep = ~torch.isin(new_ids, repeat_ids) # True where we copy from old - old_keep = ~torch.isin(old_ids, repeat_ids) - - # sanity check - assert new_keep.sum() == old_keep.sum(), f"Mismatch after dropping repeated vision tokens: {new_ids} {old_ids}" - - # ------------------------------------------------------------------ - # 5. allocate expanded tensors and copy en masse - # ------------------------------------------------------------------ - if old_action is not None: - exp_action = torch.zeros_like(new_ids, dtype=old_action.dtype) - exp_action[new_keep] = old_action[0][old_keep].to(device) - proc_out["action_mask"] = exp_action.unsqueeze(0) - - if old_labels is not None: - exp_labels = torch.full_like(new_ids, -100, dtype=old_labels.dtype) - exp_labels[new_keep] = old_labels[0][old_keep].to(device) - proc_out["labels"] = exp_labels.unsqueeze(0) - - return proc_out + return final_inputs + +def extract_vision_inputs_from_messages(messages: list) -> tuple[list, list]: + """Extract images and videos from messages""" + images, videos = [], [] + + for message in messages: + if isinstance(message.get('content'), list): + for item in message['content']: + if item.get('type') == 'image': + if 'image' in item: + images.append(item['image']) + elif 'image_url' in item: + images.append(item['image_url']['url']) + elif item.get('type') == 'video': + if 'video' in item: + videos.append(item['video']) + elif 'video_url' in item: + videos.append(item['video_url']['url']) + + return images, videos + +def process_prompt_with_vision( + prompt: str, + template: str, + processor: Any, + images: list = None, + videos: list = None, +) -> dict: + """Process a prompt with vision support""" + vision_processor = get_processor(template) + if vision_processor is None: + # If no vision processor, just return tokenized prompt + return processor.tokenizer( + prompt, + return_tensors="pt", + add_special_tokens=True, + padding=True, + truncation=True + ) + + # Use vision processor to handle the complete pipeline + return vision_processor.process_for_llm( + prompt=prompt, + images=images or [], + videos=videos or [], + processor=processor, + tokenizer=processor.tokenizer + ) def tokenize_conversations(messages_list, tokenizer, conv_template, max_length, processor=None, return_tensors="pt", return_reward_mask=False): diff --git a/agents/agents/agents/templates/vision_processor.py b/agents/agents/agents/templates/vision_processor.py new file mode 100644 index 0000000..e3cd09c --- /dev/null +++ b/agents/agents/agents/templates/vision_processor.py @@ -0,0 +1,651 @@ +""" +Comprehensive multi-modal vision processor that handles vision processing separately from template processing. +The pipeline is: Template → Human-readable prompt → Vision processor → LLM-ready inputs. +""" + +import base64 +import inspect +import math +import os +import re +import urllib.parse +import urllib.request +from copy import deepcopy +from dataclasses import dataclass +from io import BytesIO +from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union, List, Dict, Any +from abc import ABC, abstractmethod + +import numpy as np +import torch +from PIL import Image +from PIL.Image import Image as ImageObject +from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array +from transformers.models.mllama.processing_mllama import ( + convert_sparse_cross_attention_mask_to_dense, + get_cross_attention_token_mask, +) +from typing_extensions import override + +if TYPE_CHECKING: + from av.stream import Stream + from numpy.typing import NDArray + from transformers import PreTrainedTokenizer, ProcessorMixin + from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor + from transformers.image_processing_utils import BaseImageProcessor + + class EncodedImage(TypedDict): + path: Optional[str] + bytes: Optional[bytes] + + ImageInput = Union[str, bytes, EncodedImage, BinaryIO, "ImageObject"] + VideoInput = Union[str, BinaryIO, list[list[ImageInput]]] + + class MMProcessor(ProcessorMixin): + patch_size: int + image_seq_length: int + num_additional_image_tokens: int + vision_feature_select_strategy: Literal["default", "full"] + + def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int: + pass + + +@dataclass +class VisionProcessorConfig: + """Configuration for vision processing""" + model_type: str + image_token: str + video_token: str + vision_start: str = "" + vision_end: str = "" + processor_class: str = "AutoProcessor" + expansion_strategy: str = "patch_based" + image_max_pixels: int = 16384 * 28 * 28 + image_min_pixels: int = 4 * 28 * 28 + video_max_pixels: int = 16384 * 28 * 28 + video_min_pixels: int = 4 * 28 * 28 + video_fps: float = 2.0 + video_maxlen: int = 128 + +class VisionProcessor(ABC): + """Abstract base class for vision processing strategies""" + + def __init__(self, config: VisionProcessorConfig): + self.config = config + self._validate_config() + + def _validate_config(self): + """Validate the vision configuration""" + required_fields = ['image_token', 'video_token'] + for field in required_fields: + if not hasattr(self.config, field) or getattr(self.config, field) is None: + raise ValueError(f"Missing required field: {field}") + + @abstractmethod + def preprocess_images(self, images: List["ImageInput"], processor: Any) -> Dict[str, Any]: + """Preprocess images for the model""" + pass + + @abstractmethod + def preprocess_videos(self, videos: List["VideoInput"], processor: Any) -> Dict[str, Any]: + """Preprocess videos for the model""" + pass + + @abstractmethod + def calculate_image_tokens(self, image_data: Dict[str, Any], processor: Any) -> int: + """Calculate the number of tokens needed for an image""" + pass + + @abstractmethod + def calculate_video_tokens(self, video_data: Dict[str, Any], processor: Any) -> int: + """Calculate the number of tokens needed for a video""" + pass + + @abstractmethod + def expand_vision_tokens( + self, + prompt: str, + images: List["ImageInput"], + videos: List["VideoInput"], + processor: Optional[Any], + ) -> str: + """Expand vision tokens in the prompt to their actual token representations""" + pass + + @abstractmethod + def get_mm_inputs( + self, + images: List["ImageInput"], + videos: List["VideoInput"], + processor: Optional[Any], + ) -> Dict[str, torch.Tensor]: + """Generate multi-modal inputs for the model""" + pass + + # def process_for_llm( + # self, + # prompt: str, + # images: List["ImageInput"], + # videos: List["VideoInput"], + # processor: Optional[Any], + # tokenizer: Any, + # ) -> Dict[str, torch.Tensor]: + # """ + # Complete pipeline: expand tokens and generate LLM-ready inputs. + # Returns inputs that can be used directly with model(**inputs). + # """ + # # Step 1: Expand vision tokens in the prompt + # expanded_prompt = self.expand_vision_tokens(prompt, images, videos, processor) + + # # Step 2: Tokenize the expanded prompt + # tokenized_inputs = tokenizer( + # expanded_prompt, + # return_tensors="pt", + # add_special_tokens=True, + # padding=True, + # truncation=True + # ) + + # # Step 3: Generate multi-modal inputs + # mm_inputs = self.get_mm_inputs(images, videos, processor) + + # # Step 4: Combine tokenized inputs with multi-modal inputs + # final_inputs = {**tokenized_inputs, **mm_inputs} + + # return final_inputs + + def process_for_llm( + self, + prompt: str, + elements: List[str], + mask_flags: List[bool], + images: List["ImageInput"], + videos: List["VideoInput"], + processor: Any, + tokenizer: Any, + return_tensors: str = None, + ) -> Dict[str, torch.Tensor]: + """ + Process with proper alignment of all tensors (input_ids, attention_mask, labels, action_mask). + This ensures that when vision tokens are expanded, all corresponding tensors are expanded + at the same positions, maintaining proper alignment for training and inference. + """ + import torch + + # Step 1: Tokenize elements to get base tensors with proper alignment + input_ids = [] + attention_mask = [] + labels = [] + action_mask = [] + + # Add BOS token if needed + if tokenizer.bos_token and tokenizer.add_bos_token: + input_ids.append(tokenizer.bos_token_id) + attention_mask.append(1) + labels.append(-100) + action_mask.append(0) + + # Step 2: Process each element with vision token expansion + for element, mask_flag in zip(elements, mask_flags): + # Check if element contains vision tokens + if self._contains_vision_tokens(element): + # Expand vision tokens in this element + expanded_element = self.expand_vision_tokens(element, images, videos, processor) + cur_input_ids = tokenizer.encode(expanded_element, add_special_tokens=False) + else: + cur_input_ids = tokenizer.encode(element, add_special_tokens=False) + + # Add tokens with proper alignment + input_ids.extend(cur_input_ids) + attention_mask.extend([1] * len(cur_input_ids)) + + if mask_flag: + labels.extend([-100] * len(cur_input_ids)) + action_mask.extend([0] * len(cur_input_ids)) + else: + labels.extend(cur_input_ids) + action_mask.extend([1] * len(cur_input_ids)) + + # Step 3: Create base inputs + inputs = { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'labels': labels, + 'action_mask': action_mask + } + + # Convert to tensors if requested + if return_tensors == "pt": + inputs = {k: torch.tensor([v]) for k, v in inputs.items()} + + # Step 4: Add vision inputs + mm_inputs = self.get_mm_inputs(images, videos, processor) + inputs.update(mm_inputs) + + return inputs + + def _contains_vision_tokens(self, text: str) -> bool: + """Check if text contains vision tokens""" + return self.config.image_token in text or self.config.video_token in text + +class PatchBasedProcessor(VisionProcessor): + """Patch-based vision processor (used by Qwen-VL, LLaVA, etc.) + + Supports multiple image input formats: + - File paths (str): "/path/to/image.jpg" + - URLs (str): "https://example.com/image.jpg" + - Base64 strings (str): "data:image/jpeg;base64,/9j/4AAQ..." or raw base64 + - PIL Image objects + - Bytes objects + - File-like objects + - Dict format: {"path": "/path/to/image.jpg"} or {"bytes": b"image_data"} + """ + + def _load_image_from_input(self, image_input) -> "ImageObject": + """Load image from various input formats including URL and base64""" + from PIL import Image + + # Handle PIL Image objects directly + if hasattr(image_input, 'width') and hasattr(image_input, 'height'): + return image_input + + # Handle string inputs (file path, URL, or base64) + if isinstance(image_input, str): + # Check if it's a URL + if image_input.startswith(('http://', 'https://')): + try: + with urllib.request.urlopen(image_input) as response: + image_data = response.read() + return Image.open(BytesIO(image_data)) + except Exception as e: + raise ValueError(f"Failed to load image from URL {image_input}: {e}") + + # Check if it's a base64 string + elif image_input.startswith('data:image/') or image_input.startswith('data:application/octet-stream'): + # Handle data URL format: data:image/jpeg;base64,/9j/4AAQ... + try: + # Extract the base64 part after the comma + base64_data = image_input.split(',', 1)[1] + image_data = base64.b64decode(base64_data) + return Image.open(BytesIO(image_data)) + except Exception as e: + raise ValueError(f"Failed to decode base64 image: {e}") + + elif image_input.startswith('iVBORw0KGgo') or len(image_input) > 100: + # Likely a raw base64 string (common for PNG images starting with iVBORw0KGgo) + try: + image_data = base64.b64decode(image_input) + return Image.open(BytesIO(image_data)) + except Exception as e: + raise ValueError(f"Failed to decode base64 image: {e}") + + # Assume it's a file path + else: + return Image.open(image_input) + + # Handle bytes + elif isinstance(image_input, bytes): + return Image.open(BytesIO(image_input)) + + # Handle file-like objects + elif hasattr(image_input, 'read'): + return Image.open(image_input) + + # Handle dict format + elif isinstance(image_input, dict): + if image_input.get("bytes") is not None: + return Image.open(BytesIO(image_input["bytes"])) + elif image_input.get("path") is not None: + return Image.open(image_input["path"]) + else: + raise ValueError("Invalid image dict format") + + else: + raise ValueError(f"Unsupported image input type: {type(image_input)}") + + def _preprocess_single_image(self, image: "ImageObject", **kwargs) -> "ImageObject": + """Preprocess a single image""" + if (image.width * image.height) > self.config.image_max_pixels: + resize_factor = math.sqrt(self.config.image_max_pixels / (image.width * image.height)) + width, height = int(image.width * resize_factor), int(image.height * resize_factor) + image = image.resize((width, height)) + + if (image.width * image.height) < self.config.image_min_pixels: + resize_factor = math.sqrt(self.config.image_min_pixels / (image.width * image.height)) + width, height = int(image.width * resize_factor), int(image.height * resize_factor) + image = image.resize((width, height)) + + if image.mode != "RGB": + image = image.convert("RGB") + + return image + + def _regularize_images(self, images: List["ImageInput"]) -> List["ImageObject"]: + """Regularize images to avoid errors""" + results = [] + for image in images: + # Use the new helper method to handle all input formats + pil_image = self._load_image_from_input(image) + results.append(self._preprocess_single_image(pil_image)) + + return results + + def _regularize_videos(self, videos: List["VideoInput"]) -> List[List["ImageObject"]]: + """Regularize videos to avoid errors""" + results = [] + for video in videos: + frames: List["ImageObject"] = [] + + # Check if video is nested images + if isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video): + # Use the new image loading method for each frame + for frame in video: + try: + pil_image = self._load_image_from_input(frame) + frames.append(pil_image) + except Exception as e: + raise ValueError(f"Invalid image found in video frames: {e}") + else: + # Process actual video file + import av + container = av.open(video, "r") + video_stream = next(stream for stream in container.streams if stream.type == "video") + + # Calculate sample indices + total_frames = video_stream.frames + if total_frames == 0: # infinite video + sample_indices = np.linspace(0, self.config.video_maxlen - 1, self.config.video_maxlen).astype(np.int32) + else: + sample_frames = max(1, math.floor(float(video_stream.duration * video_stream.time_base) * self.config.video_fps)) + sample_frames = min(total_frames, self.config.video_maxlen, sample_frames) + sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) + + container.seek(0) + for frame_idx, frame in enumerate(container.decode(video_stream)): + if frame_idx in sample_indices: + frames.append(frame.to_image()) + + frames = self._regularize_images(frames) + results.append(frames) + + return results + + def preprocess_images(self, images: List["ImageInput"], processor: Any) -> Dict[str, Any]: + """Preprocess images for the model""" + if not images: + return {} + + image_processor = getattr(processor, "image_processor", None) + if image_processor is None: + raise ValueError("Image processor not found") + + images = self._regularize_images(images) + return image_processor(images, return_tensors="pt") + + def preprocess_videos(self, videos: List["VideoInput"], processor: Any) -> Dict[str, Any]: + """Preprocess videos for the model""" + if not videos: + return {} + + video_processor = getattr(processor, "video_processor", getattr(processor, "image_processor", None)) + if video_processor is None: + raise ValueError("Video processor not found") + + videos = self._regularize_videos(videos) + + # Handle different video processor interfaces + if "videos" in inspect.signature(video_processor.preprocess).parameters: + return video_processor(images=None, videos=videos, return_tensors="pt") + else: + return video_processor(videos, return_tensors="pt") + + def calculate_image_tokens(self, image_data: Dict[str, Any], processor: Any) -> int: + """Calculate the number of tokens needed for an image + + Uses two approaches: + 1. Grid-based (HuggingFace method): Uses image_grid_thw and merge_size + - More accurate for models like Qwen-VL + - Accounts for hierarchical token merging + 2. Patch-based (fallback): Uses image dimensions and patch_size + - Standard approach for most ViT-based models + - Assumes each patch corresponds to one token + """ + if "pixel_values" in image_data: + # Try grid-based calculation first (HuggingFace method) + if "image_grid_thw" in image_data: + grid_info = image_data["image_grid_thw"] + if isinstance(grid_info, torch.Tensor): + grid_prod = grid_info.prod().item() + elif isinstance(grid_info, list): + grid_prod = math.prod(grid_info) + else: + grid_prod = grid_info + + # Get merge_size from processor + merge_size = getattr(processor, "merge_size", 1) + merge_length = merge_size ** 2 + + num_image_tokens = grid_prod // merge_length + return max(1, num_image_tokens) + + # Fallback to patch-based calculation + height, width = get_image_size(to_numpy_array(image_data["pixel_values"][0])) + image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + if hasattr(processor, 'num_additional_image_tokens'): + image_seqlen += processor.num_additional_image_tokens + if hasattr(processor, 'vision_feature_select_strategy') and processor.vision_feature_select_strategy == "default": + image_seqlen -= 1 + return image_seqlen + return 1 + + def calculate_video_tokens(self, video_data: Dict[str, Any], processor: Any) -> int: + """Calculate the number of tokens needed for a video""" + if "pixel_values" in video_data: + # For videos, we need to calculate based on frames + video_tensor = video_data["pixel_values"][0] + if len(video_tensor.shape) > 3: # Has frame dimension + num_frames = video_tensor.shape[0] + height, width = get_image_size(to_numpy_array(video_tensor[0])) + frame_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + if hasattr(processor, 'num_additional_image_tokens'): + frame_seqlen += processor.num_additional_image_tokens + if hasattr(processor, 'vision_feature_select_strategy') and processor.vision_feature_select_strategy == "default": + frame_seqlen -= 1 + return frame_seqlen * num_frames + else: + # Single frame video + return self.calculate_image_tokens(video_data, processor) + return 1 + + def expand_vision_tokens( + self, + prompt: str, + images: List["ImageInput"], + videos: List["VideoInput"], + processor: Optional[Any], + ) -> str: + """Expand vision tokens in the prompt to their actual token representations""" + if processor is None: + raise ValueError("Processor is required for vision processing") + + # Validate that number of placeholders matches number of inputs + num_image_placeholders = prompt.count(self.config.image_token) + num_video_placeholders = prompt.count(self.config.video_token) + + if len(images) != num_image_placeholders: + raise ValueError(f"Number of images ({len(images)}) doesn't match placeholders ({num_image_placeholders})") + if len(videos) != num_video_placeholders: + raise ValueError(f"Number of videos ({len(videos)}) doesn't match placeholders ({num_video_placeholders})") + + # Preprocess images and videos to get individual token counts + processed_images = self.preprocess_images(images, processor) if images else {} + processed_videos = self.preprocess_videos(videos, processor) if videos else {} + + # Expand image tokens using regex to avoid infinite loops + expanded_prompt = prompt + if self.config.image_token in expanded_prompt: + if processed_images and "pixel_values" in processed_images: + # Calculate tokens for this specific image + image_tokens = self.calculate_image_tokens(processed_images, processor) + replacement = self.config.image_token * image_tokens + else: + replacement = self.config.image_token + + # Use regex to replace all occurrences at once + import re + expanded_prompt = re.sub(re.escape(self.config.image_token), replacement, expanded_prompt) + + # Expand video tokens using regex to avoid infinite loops + if self.config.video_token in expanded_prompt: + if processed_videos and "pixel_values" in processed_videos: + # Calculate tokens for this specific video + video_tokens = self.calculate_video_tokens(processed_videos, processor) + replacement = self.config.video_token * video_tokens + else: + replacement = self.config.video_token + + # Use regex to replace all occurrences at once + expanded_prompt = re.sub(re.escape(self.config.video_token), replacement, expanded_prompt) + + return expanded_prompt + + def get_mm_inputs( + self, + images: List["ImageInput"], + videos: List["VideoInput"], + processor: Optional[Any], + ) -> Dict[str, torch.Tensor]: + """Generate multi-modal inputs for the model""" + mm_inputs = {} + + # Process images + if images: + mm_inputs.update(self.preprocess_images(images, processor)) + + # Process videos + if videos: + mm_inputs.update(self.preprocess_videos(videos, processor)) + + return mm_inputs + +class QwenVLProcessor(PatchBasedProcessor): + """Qwen-VL specific processor with custom image preprocessing""" + + def _preprocess_single_image(self, image: "ImageObject", **kwargs) -> "ImageObject": + """Qwen-VL specific image preprocessing""" + image = super()._preprocess_single_image(image, **kwargs) + + # Qwen-VL specific adjustments + if min(image.width, image.height) < 28: + width, height = max(image.width, 28), max(image.height, 28) + image = image.resize((width, height)) + + if image.width / image.height > 200: + width, height = image.height * 180, image.height + image = image.resize((width, height)) + + if image.height / image.width > 200: + width, height = image.width, image.width * 180 + image = image.resize((width, height)) + + return image + + def calculate_image_tokens(self, image_data: Dict[str, Any], processor: Any) -> int: + """Qwen-VL specific token calculation using grid-based approach""" + if "image_grid_thw" in image_data: + # Use grid information for more accurate token calculation + grid_info = image_data["image_grid_thw"] + if isinstance(grid_info, torch.Tensor): + grid_prod = grid_info.prod().item() + elif isinstance(grid_info, list): + grid_prod = math.prod(grid_info) + else: + grid_prod = grid_info + + # Get merge_size from processor (Qwen-VL typically uses merge_size=2) + merge_size = getattr(processor, "merge_size", 2) + merge_length = merge_size ** 2 + + num_image_tokens = grid_prod // merge_length + return max(1, num_image_tokens) + + # Fallback to standard calculation + return super().calculate_image_tokens(image_data, processor) + + def expand_vision_tokens( + self, + prompt: str, + images: List["ImageInput"], + videos: List["VideoInput"], + processor: Optional[Any], + ) -> str: + """Qwen-VL specific token expansion with vision tags""" + expanded_prompt = super().expand_vision_tokens(prompt, images, videos, processor) + + return expanded_prompt + +class LlavaProcessor(PatchBasedProcessor): + """LLaVA specific processor""" + + def calculate_image_tokens(self, image_data: Dict[str, Any], processor: Any) -> int: + """LLaVA specific token calculation""" + if "pixel_values" in image_data: + height, width = get_image_size(to_numpy_array(image_data["pixel_values"][0])) + image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + if hasattr(processor, 'num_additional_image_tokens'): + image_seqlen += processor.num_additional_image_tokens + if hasattr(processor, 'vision_feature_select_strategy') and processor.vision_feature_select_strategy == "default": + image_seqlen -= 1 + return image_seqlen + return 1 + + +VISION_PROCESSORS: Dict[str, VisionProcessor] = {} + +model_type_to_processor_class = { + "qwen_vl": QwenVLProcessor, + "llava": LlavaProcessor, + "gemma3": PatchBasedProcessor, + "paligemma": PatchBasedProcessor, + "internvl": PatchBasedProcessor, + "minicpm": PatchBasedProcessor, + "mllama": PatchBasedProcessor, + "pixtral": PatchBasedProcessor, + "video_llava": PatchBasedProcessor, + "patch_based": PatchBasedProcessor, +} + +def register_processor(template_name: str, config: VisionProcessorConfig): + """Register a vision processor for a template""" + processor_class = model_type_to_processor_class.get(config.model_type) + if processor_class is None: + raise ValueError(f"No processor class found for model type: {config.model_type}") + VISION_PROCESSORS[template_name] = processor_class(config) + + +def register(cls, template_name: str, config: VisionProcessorConfig, processor_class: type = None): + """Register a vision processor for a template""" + if processor_class is not None: + # If processor_class is provided, use it directly + VISION_PROCESSORS[template_name] = processor_class(config) + else: + # Use the global register_processor function + register_processor(template_name, config) + +def get_processor(template_name: str) -> Optional[VisionProcessor]: + """Get vision processor for a template""" + return VISION_PROCESSORS.get(template_name) + +def get_processor_config(template_name: str) -> Optional[VisionProcessorConfig]: + """Get vision config for a template""" + processor = get_processor(template_name) + return processor.config if processor else None + +def is_vision_template(template_name: str) -> bool: + """Check if template supports vision""" + return template_name in VISION_PROCESSORS + +def list_vision_templates() -> List[str]: + """List all vision-enabled templates""" + return list(VISION_PROCESSORS.keys()) diff --git a/agents/tests/unit/agents/prompts/test_qwen25_vl_prompt.py b/agents/tests/unit/agents/prompts/test_qwen25_vl_prompt.py deleted file mode 100644 index 0376b49..0000000 --- a/agents/tests/unit/agents/prompts/test_qwen25_vl_prompt.py +++ /dev/null @@ -1,138 +0,0 @@ -from agents.agents.agents.templates.templates import get_conv_template -from agents.agents.templates.utils import compare_hf_template, format_conversation -from transformers import AutoProcessor, AutoTokenizer - -def test_simple_prompts(): - messages = [ - { - "role": "user", - "content": [ - { - "type": "image", - "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", - }, - {"type": "text", "text": "Describe this image."}, - ], - }, - { - "role": "assistant", - "content": [ - { - "type": "text", - "text": "The image is a cat.", - }, - ], - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": "What is in the image?", - }, - ], - } - ] - processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(processor, "qwen2.5-vl", messages=messages, add_generation_prompt=True) - assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" - print(f"Highlighted prompt:\n\n{highlighted_prompt}") - - -def test_simple_prompts_with_system(): - messages = [ - { - "role": "system", - "content": "You are a multi-modal assistant that can answer questions about images.", - }, - { - "role": "user", - "content": [ - { - "type": "image", - "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", - }, - {"type": "text", "text": "Describe this image."}, - ], - }, - { - "role": "assistant", - "content": [ - { - "type": "text", - "text": "The image is a cat.", - }, - ], - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": "What is in the image?", - }, - ], - } - ] - processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(processor, "qwen2.5-vl", messages=messages, add_generation_prompt=True) - assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" - print(f"Highlighted prompt:\n\n{highlighted_prompt}") - - -def test_jinja_template(): - conv = get_conv_template("qwen2.5-vl") - jinja_template = conv.get_jinja_template() - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - messages = [ - { - "role": "system", - "content": "You are a multi-modal assistant that can answer questions about images.", - }, - { - "role": "user", - "content": [ - { - "type": "image", - "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", - }, - {"type": "text", "text": "Describe this image."}, - ], - }, - { - "role": "assistant", - "content": [ - { - "type": "text", - "text": "The image is a cat.", - }, - ], - }, - { - "role": "tool", - "content": [ - { - "type": "text", - "text": "Example tool response.", - }, - ], - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": "What is in the image?", - }, - ], - } - ] - official_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) - tokenizer.chat_template = jinja_template - prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) - print(official_prompt) - print(prompt) - - conv = format_conversation(messages, "qwen2.5-vl", add_generation_prompt=True) - implemented_prompt = conv.get_prompt() - print(implemented_prompt) \ No newline at end of file diff --git a/agents/tests/unit/agents/prompts/test_qwen3_prompt.py b/agents/tests/unit/agents/prompts/test_qwen3_prompt.py deleted file mode 100644 index f11785e..0000000 --- a/agents/tests/unit/agents/prompts/test_qwen3_prompt.py +++ /dev/null @@ -1,81 +0,0 @@ -from agents.agents.agents.templates.templates import get_conv_template -from agents.agents.templates.utils import compare_hf_template - -def test_simple_prompts(): - messages = [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am fine, thank you."}, - {"role": "user", "content": "Want to play a game?"}, - {"role": "assistant", "content": "Sure, what game?"}, - {"role": "user", "content": "What is 3 times 5?"}, - {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, - {"role": "tool", "content": "15"}, - ] - qwen3_model = "Qwen/Qwen3-30B-A3B" - is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(qwen3_model, "qwen3", messages=messages, add_generation_prompt=True) - assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" - print(f"Highlighted prompt:\n\n{highlighted_prompt}") - - -def test_simple_prompts_with_system(): - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am fine, thank you."}, - {"role": "user", "content": "Want to play a game?"}, - {"role": "assistant", "content": "Sure, what game?"}, - {"role": "user", "content": "What is 3 times 5?"}, - {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, - {"role": "tool", "content": "15"}, - ] - qwen3_model = "Qwen/Qwen3-30B-A3B" - is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(qwen3_model, "qwen3", messages=messages, add_generation_prompt=True) - assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" - print(f"Highlighted prompt:\n\n{highlighted_prompt}") - - -def test_with_tools(): - multiply_schema = {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}} - add_schema = {"type": "function", "function": {"name": "add", "description": "A function that add two numbers", "parameters": {"type": "object", "properties": {"a": {"type": "number", "description": "The first number to add"}, "b": {"type": "number", "description": "The second number to add"}}, "required": ["a", "b"]}}} - tools = [multiply_schema, add_schema] - messages = [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am fine, thank you."}, - {"role": "user", "content": "Want to play a game?"}, - {"role": "assistant", "content": "Sure, what game?"}, - {"role": "user", "content": "What is 3 times 5?"}, - {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, - {"role": "tool", "content": "15"}, - ] - qwen3_model = "Qwen/Qwen3-30B-A3B" - is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(qwen3_model, "qwen3", messages=messages, tools=tools, add_generation_prompt=True) - assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" - print(f"Highlighted prompt:\n\n{highlighted_prompt}") - - -def test_with_system_message(): - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am fine, thank you."}, - {"role": "user", "content": "What is 3 times 5?"}, - ] - qwen3_model = "Qwen/Qwen3-30B-A3B" - is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(qwen3_model, "qwen3", messages=messages, enable_thinking=False, add_generation_prompt=True) - assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" - print(f"Highlighted prompt:\n\n{highlighted_prompt}") - -def test_with_system_message_and_tools(): - multiply_schema = {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}} - add_schema = {"type": "function", "function": {"name": "add", "description": "A function that add two numbers", "parameters": {"type": "object", "properties": {"a": {"type": "number", "description": "The first number to add"}, "b": {"type": "number", "description": "The second number to add"}}, "required": ["a", "b"]}}} - tools = [multiply_schema, add_schema] - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am fine, thank you."}, - {"role": "user", "content": "What is 3 times 5?"}, - ] - qwen3_model = "Qwen/Qwen3-30B-A3B" - is_equal, official_prompt, implemented_prompt, highlighted_prompt = compare_hf_template(qwen3_model, "qwen3", messages=messages, tools=tools, enable_thinking=False,add_generation_prompt=True) - assert is_equal, f"Official prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" - print(f"Highlighted prompt:\n\n{highlighted_prompt}") \ No newline at end of file diff --git a/agents/tests/unit/agents/prompts/test_template_tokenize.py b/agents/tests/unit/agents/prompts/test_template_tokenize.py deleted file mode 100644 index ff22ba1..0000000 --- a/agents/tests/unit/agents/prompts/test_template_tokenize.py +++ /dev/null @@ -1,62 +0,0 @@ -from agents.agents.templates.utils import is_vlm_template, tokenize_conversation -import pytest -from transformers import AutoTokenizer, AutoProcessor -import torch -from agents.agents.templates.templates import Chat - -@pytest.mark.parametrize("template", ["qwen2.5"]) -@pytest.mark.parametrize("messages", [ - [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am fine, thank you."}, - {"role": "user", "content": "Want to play a game?"}, - {"role": "assistant", "content": "Sure, what game?"}, - ], - [ - {"role": "user", "content": "Help me to calculate 3 times 5."}, - {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, - {"role": "tool", "content": "15"}, - ], - [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I am fine, thank you."}, - {"role": "user", "content": "What is 3 times 5?"}, - ], -]) -@pytest.mark.parametrize("tools", [ - None, - # [ - # {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, - # {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, - # ] -]) -@pytest.mark.parametrize("add_generation_prompt", [False]) -def test_template_tokenize(template, messages, tools, add_generation_prompt): - template_tokenizer_mapping = { - "qwen2.5": "Qwen/Qwen2.5-3B-Instruct", - "qwen2.5-vl": "Qwen/Qwen2.5-VL-3B-Instruct", - "qwen3": "Qwen/Qwen3-8B", - } - tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template]) - if is_vlm_template(template): - processor = AutoProcessor.from_pretrained(template_tokenizer_mapping[template]) - else: - processor = None - try: - official_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt) - official_inputs = tokenizer(official_prompt, return_tensors="pt") - # chat = Chat(template, messages, tokenizer) - # implemented_inputs = chat.tokenize() - implemented_inputs = tokenize_conversation(messages, tokenizer, template, max_length=2048, processor=processor, tools=tools, add_generation_prompt=add_generation_prompt, return_tensors="pt") - - assert torch.equal(official_inputs["input_ids"], implemented_inputs["input_ids"]), f"template: {template}\n\nmessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nofficial_prompt: {official_prompt}\n\nimplemented_prompt: {tokenizer.decode(implemented_inputs['input_ids'][0])}\n\nofficial_inputs: {official_inputs}\n\nimplemented_inputs: {implemented_inputs}" - assert torch.equal(official_inputs["attention_mask"], implemented_inputs["attention_mask"]) - except Exception as e: - if isinstance(e, ValueError) and "does not support tool calling." in str(e) and tools is not None and template in ["qwen2.5-vl"]: - pass - else: - raise e - # assert torch.equal(official_inputs["labels"], implemented_inputs["labels"]) - - diff --git a/agents/tests/unit/agents/prompts/test_think_prompt.py b/agents/tests/unit/agents/prompts/test_think_prompt.py deleted file mode 100644 index 94a3ec2..0000000 --- a/agents/tests/unit/agents/prompts/test_think_prompt.py +++ /dev/null @@ -1,29 +0,0 @@ -import openai - -def test_think_prompt(): - client = openai.OpenAI(api_key="token-123", base_url="http://0.0.0.0:8000/v1") - tools = [ - { - "type": "function", - "function": { - "name": "code_interpreter", - "description": "A python code interpreter that can execute code and return the result", - "parameters": { - "type": "object", - "properties": { - "code": {"type": "string", "description": "The python code to execute"}, - }, - }, - } - } - ] - response = client.chat.completions.create( - # model="/mnt/sharefs/users/haonan.li/models/Qwen2.5-7B-instruct-am_think_v1_distilled", - model="Qwen/Qwen2.5-7B-Instruct", - tools=tools, - messages=[ - {"role": "user", "content": "Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop."}, - ], - max_tokens=8192 - ) - print(response.choices[0].message.content) \ No newline at end of file diff --git a/agents/tests/unit/agents/templates/test_qwen25_vl_prompt.py b/agents/tests/unit/agents/templates/test_qwen25_vl_prompt.py new file mode 100644 index 0000000..e69de29 diff --git a/agents/tests/unit/agents/templates/test_qwen3_prompt.py b/agents/tests/unit/agents/templates/test_qwen3_prompt.py new file mode 100644 index 0000000..e69de29 diff --git a/agents/tests/unit/agents/templates/test_template_utilities.py b/agents/tests/unit/agents/templates/test_template_utilities.py new file mode 100644 index 0000000..47cba54 --- /dev/null +++ b/agents/tests/unit/agents/templates/test_template_utilities.py @@ -0,0 +1,33 @@ +from agents.agents.templates.templates import get_template, register_template, Template +from agents.agents.templates.vision_processor import get_processor + +def test_template_registration(): + register_template( + Template( + name="test", + system_template="", + user_template="", + assistant_template="", + stop_words=[] + ) + ) + assert get_template("test") is not None + assert get_template("test").name == "test" + +def test_template_registration_with_vision(): + register_template( + Template( + name="test-vl", + system_template="", + user_template="", + assistant_template="", + stop_words=[], + image_token="<|image_pad|>", + ) + ) + assert get_processor("test-vl") is not None + assert get_processor("test-vl").config.image_token == "<|image_pad|>" + + + + \ No newline at end of file diff --git a/agents/tests/unit/agents/templates/test_text_templates_full_align.py b/agents/tests/unit/agents/templates/test_text_templates_full_align.py new file mode 100644 index 0000000..e69de29 diff --git a/agents/tests/unit/agents/prompts/test_templates.py b/agents/tests/unit/agents/templates/test_text_templates_tokenize.py similarity index 50% rename from agents/tests/unit/agents/prompts/test_templates.py rename to agents/tests/unit/agents/templates/test_text_templates_tokenize.py index 068f832..d17235b 100644 --- a/agents/tests/unit/agents/prompts/test_templates.py +++ b/agents/tests/unit/agents/templates/test_text_templates_tokenize.py @@ -1,8 +1,19 @@ -from agents.agents.templates.utils import compare_hf_template -from transformers import AutoTokenizer +""" This file is for testing the tokenization of the templates. The templates should align on following aspects: + - The tokenized prompt should be the same as the one obtained from HF template with all the following options: + - add_generation_prompt + - tools + - We need to observe the labels and action_mask to make sure the the they are correct. + +Since the align for textual prompt is already tested in other files, we only need to test the tokenization of the templates. +""" + +from agents.agents.templates.utils import is_vlm_template, tokenize_conversation import pytest +from transformers import AutoTokenizer, AutoProcessor +import torch +from agents.agents.templates.templates import Chat -@pytest.mark.parametrize("template", ["qwen2.5-think", "qwen2.5", "qwen2.5-no-tool"]) +@pytest.mark.parametrize("template", ["qwen2.5"]) @pytest.mark.parametrize("messages", [ [ {"role": "user", "content": "Hello, how are you?"}, @@ -29,21 +40,20 @@ {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, ] ]) -@pytest.mark.parametrize("add_generation_prompt", [True, False]) -def test_chat_template_equal(template, messages, tools, add_generation_prompt): - # Filter invalid combinations - if add_generation_prompt and messages[-1]['role'] == 'assistant': - return - +@pytest.mark.parametrize("add_generation_prompt", [False, True]) +def test_template_tokenize(template, messages, tools, add_generation_prompt): template_tokenizer_mapping = { "qwen2.5": "Qwen/Qwen2.5-3B-Instruct", - "qwen2.5-think": "Qwen/Qwen2.5-3B-Instruct", - "qwen2.5-no-tool": "Qwen/Qwen2.5-3B-Instruct", + "qwen2.5-vl": "Qwen/Qwen2.5-VL-3B-Instruct", + "qwen3": "Qwen/Qwen3-8B", } tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template]) - is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template, messages=messages, tools=tools,add_generation_prompt=add_generation_prompt) - # assert is_equal, print(f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}") - assert is_equal_between_jinja_prompts, print(f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja prompt:\n\n{implemented_jinja_prompt}") - print(f"Highlighted prompt:\n\n{highlighted_prompt}") + chat = Chat(template, messages, tools=tools) + prompt = chat.prompt(add_generation_prompt=add_generation_prompt, tools=tools) + + hf_inputs = tokenizer(prompt, return_tensors="pt") + + implemented_inputs = tokenize_conversation(messages, tokenizer, template, max_length=2048, tools=tools, add_generation_prompt=add_generation_prompt, return_tensors="pt") + assert torch.equal(hf_inputs["input_ids"], implemented_inputs["input_ids"]), f"template: {template}\n\nmessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nprompt: {prompt}\n\nimplemented_prompt: {tokenizer.decode(implemented_inputs['input_ids'][0])}\n\nhf_inputs: {hf_inputs}\n\nimplemented_inputs: {implemented_inputs}" diff --git a/agents/tests/unit/agents/templates/test_think_prompt.py b/agents/tests/unit/agents/templates/test_think_prompt.py new file mode 100644 index 0000000..e69de29 diff --git a/agents/tests/unit/agents/templates/test_vision_templates_full_align.py b/agents/tests/unit/agents/templates/test_vision_templates_full_align.py new file mode 100644 index 0000000..167307d --- /dev/null +++ b/agents/tests/unit/agents/templates/test_vision_templates_full_align.py @@ -0,0 +1,87 @@ +""" This file is for testing the vision templates that align seamlessly with HF templates. The templates should align on following aspects: + - The obtained textual prompt should be the same as the one obtained from HF template with all the following options: + - add_generation_prompt + - tools + - The obtained textual prompt should be the same as the one obtained from Jinja template with all the following options: + - add_generation_prompt + - tools +To test vision part, the messages should contain at least one image. +""" + + +from agents.agents.templates.utils import compare_hf_template +from transformers import AutoTokenizer +import pytest +# "qwen2.5-think", "qwen2.5", "qwen2.5-no-tool", +@pytest.mark.parametrize("template", ["qwen2.5-vl"]) +@pytest.mark.parametrize("messages", [ + [ + { + "role": "system", + "content": "You are a multi-modal assistant that can answer questions about images.", + }, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + }, + {"type": "text", "text": "Describe this image."}, + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "The image is a cat.", + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What is in the image?", + }, + ], + } + ], + [ + {"role": "user", "content": "Help me to calculate 3 times 5."}, + {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, + {"role": "tool", "content": "15"}, + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + {"role": "user", "content": "What is 3 times 5?"}, + ], +]) +@pytest.mark.parametrize("tools", [ + None, + [ + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + ] +]) +@pytest.mark.parametrize("add_generation_prompt", [True, False]) +def test_chat_template_equal(template, messages, tools, add_generation_prompt): + # Filter invalid combinations + if add_generation_prompt and messages[-1]['role'] == 'assistant': + return + + template_tokenizer_mapping = { + "qwen2.5-vl": "Qwen/Qwen2.5-VL-3B-Instruct", + } + tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template]) + + is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template, messages=messages, tools=tools,add_generation_prompt=add_generation_prompt) + assert is_equal, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" + assert is_equal_between_jinja_prompts, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja prompt:\n\n{implemented_jinja_prompt}" + print(f"Official prompt:\n\n{official_prompt}") + print(f"Highlighted prompt:\n\n{highlighted_prompt}") + diff --git a/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py b/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py new file mode 100644 index 0000000..54d1714 --- /dev/null +++ b/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py @@ -0,0 +1,107 @@ +""" This file is for testing the text templates that align seamlessly with HF templates. The templates should align on following aspects: + - The obtained textual prompt should be the same as the one obtained from HF template with all the following options: + - add_generation_prompt + - tools + - The obtained textual prompt should be the same as the one obtained from Jinja template with all the following options: + - add_generation_prompt + - tools +""" + + +from agents.agents.templates.templates import Chat +from agents.agents.templates.utils import compare_hf_template, tokenize_conversation +from transformers import AutoTokenizer +import pytest +import torch +from transformers import AutoProcessor +from qwen_vl_utils import process_vision_info + +@pytest.mark.parametrize("template", ["qwen2.5-vl"]) +@pytest.mark.parametrize("messages", [ + [ + { + "role": "system", + "content": "You are a multi-modal assistant that can answer questions about images.", + }, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + }, + {"type": "text", "text": "Describe this image."}, + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "The image is a cat.", + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What is in the image?", + }, + ], + } + ], + # [ + # {"role": "user", "content": "Help me to calculate 3 times 5."}, + # {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, + # {"role": "tool", "content": "15"}, + # ], + # [ + # {"role": "system", "content": "You are a helpful assistant."}, + # {"role": "user", "content": "Hello, how are you?"}, + # {"role": "assistant", "content": "I am fine, thank you."}, + # {"role": "user", "content": "What is 3 times 5?"}, + # ], +]) +@pytest.mark.parametrize("tools", [ + None, + # [ + # {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + # {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + # ] +]) +@pytest.mark.parametrize("add_generation_prompt", [True, False]) +def test_chat_template_equal(template, messages, tools, add_generation_prompt): + template_tokenizer_mapping = { + "qwen2.5-vl": "Qwen/Qwen2.5-VL-3B-Instruct", + } + tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template]) + processor = AutoProcessor.from_pretrained(template_tokenizer_mapping[template]) + official_prompt = tokenizer.apply_chat_template(messages + , tokenize=False, add_generation_prompt=add_generation_prompt) + image_inputs, video_inputs = process_vision_info(messages) + official_inputs = processor( + text=[official_prompt], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + + implemented_inputs = tokenize_conversation(messages, tokenizer, template, max_length=8192, tools=tools, add_generation_prompt=add_generation_prompt, return_tensors="pt", processor=processor) + + official_prompt = tokenizer.decode(official_inputs['input_ids'][0]) + implemented_prompt = tokenizer.decode(implemented_inputs['input_ids'][0]) + print(f"Official prompt image tokens: {official_prompt.count('<|image_pad|>')}\nImplemented prompt image tokens: {implemented_prompt.count('<|image_pad|>')}") + + print(f"Official images: {official_inputs['pixel_values'].shape}\nImplemented images: {implemented_inputs['pixel_values'].shape}") + + assert torch.equal(official_inputs["input_ids"], implemented_inputs["input_ids"]), f"""Offical + prompt:\n{official_prompt}\nImplemented prompt:\n{implemented_prompt}""" + + assert torch.equal(official_inputs["pixel_values"], implemented_inputs["pixel_values"]) + + assert torch.equal(official_inputs["image_grid_thw"], implemented_inputs["image_grid_thw"]) + + print(f"official_prompt: {official_prompt}\nimplemented_prompt: {tokenizer.decode(implemented_inputs['input_ids'][0])}\nofficial_inputs: {official_inputs.keys()}\nimplemented_inputs: {implemented_inputs.keys()}\n") \ No newline at end of file From d3afba2c2d8f62a0510ef78f5df305ddb5c54141 Mon Sep 17 00:00:00 2001 From: Reason-Wang Date: Sat, 9 Aug 2025 07:32:41 +0000 Subject: [PATCH 3/6] Test multi-modal training --- agents/agents/agents/agent_base.py | 5 +- agents/agents/agents/llm_backend.py | 11 +- agents/agents/agents/llm_backend.py:6 | 0 agents/agents/agents/llm_backend.py:6: | 0 agents/agents/agents/templates/mm_plugin.py | 1912 ----------------- agents/agents/agents/templates/templates.py | 31 +- .../agents/agents/templates/test_alignment.py | 170 -- agents/agents/agents/templates/utils.py | 4 +- .../agents/templates/vision_processor.py | 17 + agents/agents/rewards/qa_reward.py | 15 + .../test_vision_templates_tokenize.py | 3 + 11 files changed, 63 insertions(+), 2105 deletions(-) delete mode 100644 agents/agents/agents/llm_backend.py:6 delete mode 100644 agents/agents/agents/llm_backend.py:6: delete mode 100644 agents/agents/agents/templates/mm_plugin.py delete mode 100644 agents/agents/agents/templates/test_alignment.py diff --git a/agents/agents/agents/agent_base.py b/agents/agents/agents/agent_base.py index 67c9a6b..c29997e 100644 --- a/agents/agents/agents/agent_base.py +++ b/agents/agents/agents/agent_base.py @@ -15,6 +15,7 @@ import os import transformers import warnings +import logging from .chain.streaming_observer import ConsoleStreamObserver, StreamingManager try: from verl.protocol import DataProto @@ -22,6 +23,7 @@ print("verl can not be imported.") pass +Logger = logging.getLogger(__name__) class BaseAgent(ChainGeneration, ABC): """ @@ -107,10 +109,11 @@ def _init_llm_engine(self, model_name_or_path: str, backend: str): processor = None return llm_engine, tokenizer, processor - def set_llm_engine(self, llm_engine: Any, tokenizer: Any): + def set_llm_engine(self, llm_engine: Any, tokenizer: Any, processor: Any): assert self.backend == "async_verl", "Only async verl backend is supported for now" self.llm_engine.llm_engine = llm_engine self.tokenizer = tokenizer + self.processor = processor def generate(self, messages_list_or_inputs: List[List[Dict]], **args): return self.llm_engine.generate(messages_list_or_inputs, **args) diff --git a/agents/agents/agents/llm_backend.py b/agents/agents/agents/llm_backend.py index 2401137..8837bfd 100644 --- a/agents/agents/agents/llm_backend.py +++ b/agents/agents/agents/llm_backend.py @@ -20,6 +20,7 @@ from vllm import LLM, AsyncLLMEngine, SamplingParams, AsyncEngineArgs import openai from .templates.templates import Chat +from .templates.vision_processor import get_processor import logging import PIL @@ -46,8 +47,8 @@ def apply_chat_template(self, messages_list: List[List[Dict]], template: str, ad for messages in messages_list: chat = Chat(template, messages) prompts.append(chat.prompt(add_generation_prompt=add_generation_prompt, tools=tools)) - # We don't support vision inputs for now - vision_inputs.append(None) + # We only support image inputs for now + vision_inputs.append(chat.vision_inputs()) return prompts, vision_inputs @@ -417,12 +418,6 @@ async def generate_async(self, messages_list: str, **kwargs) -> str: gen_batch_output = await self.llm_engine.generate_sequences_async(batch, **generation_config) response_texts = gen_batch_output.batch['responses'].tolist() # np.array of strings with length BS - # print(f"[AsyncVerlbackend] response_texts: {response_texts.shape} {type(response_texts)}") - # print(f"[AsyncVerlBackend] response_texts: {response_texts[0]}") - # raise NotImplementedError("Async Verl backend does not support sync generation") - # response_texts = [text for text in response_texts] - # response_texts = self.tokenizer.batch_decode(responses, skip_special_tokens=True) # List of string with length BS - # response_texts = responses[:len(prompts)*n] return response_texts diff --git a/agents/agents/agents/llm_backend.py:6 b/agents/agents/agents/llm_backend.py:6 deleted file mode 100644 index e69de29..0000000 diff --git a/agents/agents/agents/llm_backend.py:6: b/agents/agents/agents/llm_backend.py:6: deleted file mode 100644 index e69de29..0000000 diff --git a/agents/agents/agents/templates/mm_plugin.py b/agents/agents/agents/templates/mm_plugin.py deleted file mode 100644 index be8cea9..0000000 --- a/agents/agents/agents/templates/mm_plugin.py +++ /dev/null @@ -1,1912 +0,0 @@ -# Copyright 2025 HuggingFace Inc. and the LlamaFactory team. -# -# This code is inspired by the HuggingFace's Transformers library. -# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/processing_llava.py -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import math -import os -import re -from copy import deepcopy -from dataclasses import dataclass -from io import BytesIO -from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union - -import numpy as np -import torch -from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array -from transformers.models.mllama.processing_mllama import ( - convert_sparse_cross_attention_mask_to_dense, - get_cross_attention_token_mask, -) -from typing_extensions import override - -from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER -from ..extras.packages import ( - is_librosa_available, - is_pillow_available, - is_pyav_available, - is_transformers_version_greater_than, -) - - -if is_librosa_available(): - import librosa - - -if is_pillow_available(): - from PIL import Image - from PIL.Image import Image as ImageObject - - -if is_pyav_available(): - import av - - -if is_transformers_version_greater_than("4.52.0"): - from transformers.image_utils import make_flat_list_of_images - from transformers.video_utils import make_batched_videos -else: - from transformers.image_utils import make_batched_videos, make_flat_list_of_images - - -if TYPE_CHECKING: - from av.stream import Stream - from numpy.typing import NDArray - from transformers import PreTrainedTokenizer, ProcessorMixin - from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor - from transformers.image_processing_utils import BaseImageProcessor - - class EncodedImage(TypedDict): - path: Optional[str] - bytes: Optional[bytes] - - ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject] - VideoInput = Union[str, BinaryIO, list[list[ImageInput]]] - AudioInput = Union[str, BinaryIO, NDArray] - - class MMProcessor(ProcessorMixin): - patch_size: int - image_seq_length: int - num_additional_image_tokens: int - vision_feature_select_strategy: Literal["default", "full"] - - def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int: - pass - - -def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]: - r"""Get paligemma token type ids for computing loss. - - It is slightly different with the original token type ids where the prompt part is 0. - - Returns: - batch_token_type_ids: shape (batch_size, seq_length) - - """ - batch_token_type_ids = [] - for imglen, seqlen in zip(imglens, seqlens): - image_seqlen = imglen * processor.image_seq_length - batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen)) - - return batch_token_type_ids - - -def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcessor"): - r"""Get gemma3 token type ids for computing loss. - - Returns: - batch_token_type_ids: shape (batch_size, seq_length) - - """ - image_token_id: int = getattr(processor, "image_token_id") - batch_token_type_ids = [] - for token_ids in batch_ids: - token_ids = np.array(token_ids) - token_type_ids = np.zeros_like(token_ids) - token_type_ids[token_ids == image_token_id] = 1 - batch_token_type_ids.append(token_type_ids.tolist()) - - return batch_token_type_ids - - -def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]: - r"""Make nested list of images.""" - batch_images = [] - for imglen in imglens: - batch_images.append(images[:imglen]) - images = images[imglen:] - - return batch_images - - -def _check_video_is_nested_images(video: "VideoInput") -> bool: - r"""Check if the video is nested images.""" - return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video) - - -@dataclass -class MMPluginMixin: - image_token: Optional[str] - video_token: Optional[str] - audio_token: Optional[str] - expand_mm_tokens: bool = True - - def _validate_input( - self, - processor: Optional["MMProcessor"], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - ) -> None: - r"""Validate if this model accepts the input modalities.""" - image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) - video_processor: BaseImageProcessor = getattr( - processor, "video_processor", getattr(processor, "image_processor", None) - ) - feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) - if len(images) != 0 and self.image_token is None: - raise ValueError( - "This model does not support image input. Please check whether the correct `template` is used." - ) - - if len(videos) != 0 and self.video_token is None: - raise ValueError( - "This model does not support video input. Please check whether the correct `template` is used." - ) - - if len(audios) != 0 and self.audio_token is None: - raise ValueError( - "This model does not support audio input. Please check whether the correct `template` is used." - ) - - if self.image_token is not None and processor is None: - raise ValueError("Processor was not found, please check and update your model file.") - - if self.image_token is not None and image_processor is None: - raise ValueError("Image processor was not found, please check and update your model file.") - - if self.video_token is not None and video_processor is None: - raise ValueError("Video processor was not found, please check and update your model file.") - - if self.audio_token is not None and feature_extractor is None: - raise ValueError("Audio feature extractor was not found, please check and update your model file.") - - def _validate_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - ): - r"""Validate if the number of images, videos and audios match the number of placeholders in messages.""" - num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0 - for message in messages: - num_image_tokens += message["content"].count(IMAGE_PLACEHOLDER) - num_video_tokens += message["content"].count(VIDEO_PLACEHOLDER) - num_audio_tokens += message["content"].count(AUDIO_PLACEHOLDER) - - if len(images) != num_image_tokens: - raise ValueError( - f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens in {messages}." - ) - - if len(videos) != num_video_tokens: - raise ValueError( - f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens in {messages}." - ) - - if len(audios) != num_audio_tokens: - raise ValueError( - f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens in {messages}." - ) - - def _preprocess_image( - self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs - ) -> "ImageObject": - r"""Pre-process a single image.""" - if (image.width * image.height) > image_max_pixels: - resize_factor = math.sqrt(image_max_pixels / (image.width * image.height)) - width, height = int(image.width * resize_factor), int(image.height * resize_factor) - image = image.resize((width, height)) - - if (image.width * image.height) < image_min_pixels: - resize_factor = math.sqrt(image_min_pixels / (image.width * image.height)) - width, height = int(image.width * resize_factor), int(image.height * resize_factor) - image = image.resize((width, height)) - - if image.mode != "RGB": - image = image.convert("RGB") - - return image - - def _get_video_sample_indices( - self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs - ) -> list[int]: - r"""Compute video sample indices according to fps.""" - total_frames = video_stream.frames - if total_frames == 0: # infinite video - return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32) - - sample_frames = max(1, math.floor(float(video_stream.duration * video_stream.time_base) * video_fps)) - sample_frames = min(total_frames, video_maxlen, sample_frames) - return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) - - def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]: - r"""Regularize images to avoid error. Including reading and pre-processing.""" - results = [] - for image in images: - if isinstance(image, (str, BinaryIO)): - image = Image.open(image) - elif isinstance(image, bytes): - image = Image.open(BytesIO(image)) - elif isinstance(image, dict): - if image["bytes"] is not None: - image = Image.open(BytesIO(image["bytes"])) - else: - image = Image.open(image["path"]) - - if not isinstance(image, ImageObject): - raise ValueError(f"Expect input is a list of images, but got {type(image)}.") - - results.append(self._preprocess_image(image, **kwargs)) - - return {"images": results} - - def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]: - r"""Regularizes videos to avoid error. Including reading, resizing and converting.""" - results = [] - for video in videos: - frames: list[ImageObject] = [] - if _check_video_is_nested_images(video): - for frame in video: - if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame): - raise ValueError("Invalid image found in video frames.") - frames = video - else: - container = av.open(video, "r") - video_stream = next(stream for stream in container.streams if stream.type == "video") - sample_indices = self._get_video_sample_indices(video_stream, **kwargs) - container.seek(0) - for frame_idx, frame in enumerate(container.decode(video_stream)): - if frame_idx in sample_indices: - frames.append(frame.to_image()) - - frames = self._regularize_images(frames, **kwargs)["images"] - results.append(frames) - - return {"videos": results} - - def _regularize_audios( - self, audios: list["AudioInput"], sampling_rate: float, **kwargs - ) -> dict[str, Union[list["NDArray"], list[float]]]: - r"""Regularizes audios to avoid error. Including reading and resampling.""" - results, sampling_rates = [], [] - for audio in audios: - if not isinstance(audio, np.ndarray): - audio, sampling_rate = librosa.load(audio, sr=sampling_rate) - - results.append(audio) - sampling_rates.append(sampling_rate) - - return {"audios": results, "sampling_rates": sampling_rates} - - def _get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: "MMProcessor", - imglens: Optional[list[int]] = None, - ) -> dict[str, "torch.Tensor"]: - r"""Process visual inputs. - - Returns: (llava and paligemma) - pixel_values: tensor with shape (B, C, H, W) - - Returns: (qwen2-vl) - pixel_values: tensor with shape (num_patches, patch_dim) - image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height - where num_patches == torch.prod(image_grid_thw) - - Returns: (mllama) - pixel_values: tensor with shape - (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width) - For example, (2, 1, 4, 3, 560, 560). - aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1). - aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4). - num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1). - - """ - mm_inputs = {} - if len(images) != 0: - image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) - images = self._regularize_images( - images, - image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), - image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), - )["images"] - if imglens is not None: # if imglens are provided, make batched images - images = _make_batched_images(images, imglens) - - image_processor_kwargs = {} - if getattr(processor, "image_do_pan_and_scan", False): # gemma3 image processor - image_processor_kwargs.update( - { - "do_pan_and_scan": True, - "pan_and_scan_min_crop_size": 256, - "pan_and_scan_max_num_crops": 4, - "pan_and_scan_min_ratio_to_activate": 1.2, - } - ) - - mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs)) - - if len(videos) != 0: - video_processor: BaseImageProcessor = getattr( - processor, "video_processor", getattr(processor, "image_processor", None) - ) - videos = self._regularize_videos( - videos, - image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), - image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), - video_fps=getattr(processor, "video_fps", 2.0), - video_maxlen=getattr(processor, "video_maxlen", 128), - )["videos"] - if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava - mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt")) - else: # for llava_next_video - mm_inputs.update(video_processor(videos, return_tensors="pt")) - - if len(audios) != 0: - feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) - audios = self._regularize_audios( - audios, - sampling_rate=getattr(processor, "audio_sampling_rate", 16000), - )["audios"] - mm_inputs.update( - feature_extractor( - audios, - sampling_rate=getattr(processor, "audio_sampling_rate", 16000), - return_attention_mask=True, - padding="max_length", - return_tensors="pt", - ) - ) - mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask", None) # prevent conflicts - - return mm_inputs - - -@dataclass -class BasePlugin(MMPluginMixin): - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - r"""Pre-process input messages before tokenization for VLMs.""" - self._validate_input(processor, images, videos, audios) - return messages - - def process_token_ids( - self, - input_ids: list[int], - labels: Optional[list[int]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - tokenizer: "PreTrainedTokenizer", - processor: Optional["MMProcessor"], - ) -> tuple[list[int], Optional[list[int]]]: - r"""Pre-process token ids after tokenization for VLMs.""" - self._validate_input(processor, images, videos, audios) - return input_ids, labels - - def get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - imglens: list[int], - vidlens: list[int], - audlens: list[int], - batch_ids: list[list[int]], - processor: Optional["MMProcessor"], - ) -> dict[str, Union[list[int], "torch.Tensor"]]: - r"""Build batched multimodal inputs for VLMs. - - Arguments: - images: a list of image inputs, shape (num_images,) - videos: a list of video inputs, shape (num_videos,) - audios: a list of audio inputs, shape (num_audios,) - imglens: number of images in each sample, shape (batch_size,) - vidlens: number of videos in each sample, shape (batch_size,) - audlens: number of audios in each sample, shape (batch_size,) - batch_ids: token ids of input samples, shape (batch_size, seq_len) - processor: a processor for pre-processing images and videos - - """ - self._validate_input(processor, images, videos, audios) - return self._get_mm_inputs(images, videos, audios, processor) - - -@dataclass -class Gemma3Plugin(BasePlugin): - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - num_image_tokens = 0 - messages = deepcopy(messages) - boi_token: str = getattr(processor, "boi_token") - full_image_sequence: str = getattr(processor, "full_image_sequence") - image_str = full_image_sequence if self.expand_mm_tokens else boi_token - - do_pan_and_scan: bool = getattr(processor, "image_do_pan_and_scan", False) - if do_pan_and_scan: - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - if do_pan_and_scan: - image_placeholder_str = ( - "Here is the original image {{image}} and here are some crops to help you see better " - + " ".join(["{{image}}"] * mm_inputs["num_crops"][0][num_image_tokens]) - ) - else: - image_placeholder_str = "{{image}}" - - content = content.replace(IMAGE_PLACEHOLDER, image_placeholder_str, 1) - num_image_tokens += 1 - - message["content"] = content.replace("{{image}}", image_str) - - return messages - - @override - def get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - imglens: list[int], - vidlens: list[int], - audlens: list[int], - batch_ids: list[list[int]], - processor: Optional["MMProcessor"], - ) -> dict[str, Union[list[int], "torch.Tensor"]]: - self._validate_input(processor, images, videos, audios) - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - mm_inputs.pop("num_crops", None) - mm_inputs["token_type_ids"] = _get_gemma3_token_type_ids(batch_ids, processor) - return mm_inputs - - -class Gemma3nPlugin(Gemma3Plugin): - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - messages = deepcopy(messages) - boi_token: str = getattr(processor, "boi_token") - boa_token: str = getattr(processor, "boa_token") - full_image_sequence: str = getattr(processor, "full_image_sequence") - full_audio_sequence: str = getattr(processor, "full_audio_sequence") - image_str = full_image_sequence if self.expand_mm_tokens else boi_token - audio_str = full_audio_sequence if self.expand_mm_tokens else boa_token - - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - content = content.replace(IMAGE_PLACEHOLDER, image_str, 1) - - while AUDIO_PLACEHOLDER in content: - content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1) - - message["content"] = content - - return messages - - -@dataclass -class InternVLPlugin(BasePlugin): - @override - def _get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: "ProcessorMixin", - **kwargs, - ) -> dict[str, "torch.Tensor"]: - image_processor: BaseImageProcessor = getattr(processor, "image_processor") - image_processor_kwargs = {} - if getattr(processor, "crop_to_patches", False): - image_processor_kwargs.update( - { - "crop_to_patches": True, - "max_patches": 12, - "min_patches": 1, - } - ) - - mm_inputs = {} - image_video_patches = [] - - if len(images) != 0: - images = self._regularize_images( - images, - image_max_pixels=getattr(processor, "image_max_pixels", 1024 * 1024), - image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), - )["images"] - - if len(videos) != 0: - videos = self._regularize_videos( - videos, - image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), - image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), - video_fps=getattr(processor, "video_fps", 2.0), - video_maxlen=getattr(processor, "video_maxlen", 128), - )["videos"] - - if len(images) != 0: - images = make_flat_list_of_images(images) - image_inputs = image_processor(images=images, return_tensors="pt", **image_processor_kwargs) - image_num_patches = image_inputs.pop("num_patches") - image_pixel_values = image_inputs.pop("pixel_values") - image_num_patches_indices = np.cumsum(image_num_patches) - - if len(videos) != 0: - videos = make_batched_videos(videos) - num_frames_per_video = [len(video) for video in videos] - patch_indices = np.cumsum(num_frames_per_video) - image_processor_kwargs["crop_to_patches"] = False - video_inputs = image_processor(images=videos, return_tensors="pt", **image_processor_kwargs) - video_num_patches = video_inputs.pop("num_patches") - video_pixel_values = video_inputs.pop("pixel_values") - video_num_patches_indices = np.cumsum(video_num_patches) - - # NOT SUPPORT IMAGE VIDEO INTERLEAVED - if len(images) != 0 and image_pixel_values is not None: - for i in range(len(images)): - start_index = image_num_patches_indices[i - 1] if i > 0 else 0 - end_index = image_num_patches_indices[i] - image_video_patches.append(image_pixel_values[start_index:end_index]) - - if len(videos) != 0 and video_pixel_values is not None: - patch_indices_with_prefix = [0] + list(patch_indices) - for i in range(len(videos)): - current_patch_index = patch_indices_with_prefix[i] - end_patch_index = patch_indices_with_prefix[i + 1] - start_index = video_num_patches_indices[current_patch_index - 1] if i > 0 else 0 - end_index = video_num_patches_indices[end_patch_index - 1] - image_video_patches.append(video_pixel_values[start_index:end_index]) - - if len(images) != 0 or len(videos) != 0: - mm_inputs["pixel_values"] = torch.cat(image_video_patches, dim=0) - - if len(images) != 0: - mm_inputs.update({"image_num_patches": image_num_patches}) - - if len(videos) != 0: - mm_inputs.update({"video_patch_indices": patch_indices}) - mm_inputs.update({"video_num_patches": video_num_patches}) - - return mm_inputs - - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["ProcessorMixin"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - num_image_tokens, num_video_tokens = 0, 0 - image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1 - messages = deepcopy(messages) - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - - image_pixel_patch_list = mm_inputs.get("image_num_patches") # pathes of images - video_num_patches = mm_inputs.get("video_num_patches") # all patches for frames of videos - video_patch_indices = mm_inputs.get("video_patch_indices") # num frames of per video - - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - content = content.replace( - IMAGE_PLACEHOLDER, - f"{'' * image_seqlen * image_pixel_patch_list[num_image_tokens]}", - 1, - ) - num_image_tokens += 1 - - while VIDEO_PLACEHOLDER in content: - current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0 - end_patch_index = video_patch_indices[num_video_tokens] - num_patches = list(video_num_patches[current_patch_index:end_patch_index]) - video_replaced_prompt = "\n".join( - f"Frame{i + 1}: {'' * image_seqlen * num_patches[i]}" - for i in range(len(num_patches)) - ) - content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1) - num_video_tokens += 1 - - message["content"] = content - - return messages - - @override - def get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - imglens: list[int], - vidlens: list[int], - audlens: list[int], - batch_ids: list[list[int]], - processor: Optional["ProcessorMixin"], - ) -> dict[str, Union[list[int], "torch.Tensor"]]: - self._validate_input(processor, images, videos, audios) - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - mm_inputs.pop("image_num_patches", None) - mm_inputs.pop("video_patch_indices", None) - mm_inputs.pop("video_num_patches", None) - return mm_inputs - - -class KimiVLPlugin(BasePlugin): - @override - def process_messages(self, messages, images, videos, audios, processor): - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - if self.expand_mm_tokens: - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - image_grid_hws = mm_inputs.get("image_grid_hws", []) - else: - image_grid_hws = [None] * len(images) - - num_image_tokens = 0 - image_processor: BaseImageProcessor = getattr(processor, "image_processor") - merge_length = math.prod(image_processor.merge_kernel_size) - messages = deepcopy(messages) - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1 - content = content.replace( - IMAGE_PLACEHOLDER, - f"<|media_start|>image<|media_content|>{self.image_token * image_seqlen}<|media_end|>", - 1, - ) - num_image_tokens += 1 - - message["content"] = content - - return messages - - -@dataclass -class Llama4Plugin(BasePlugin): - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - if self.expand_mm_tokens: - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - if "pixel_values" in mm_inputs: - image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:] - num_patches_per_chunk = int( - (image_height // processor.patch_size) - * (image_width // processor.patch_size) - // processor.downsample_ratio - ) - aspect_ratios = mm_inputs.pop("aspect_ratios") - - num_image_tokens = 0 - messages = deepcopy(messages) - for message in messages: - content = message["content"] - if self.expand_mm_tokens: - placeholder_count = content.count(IMAGE_PLACEHOLDER) - prompt_splits = content.split(IMAGE_PLACEHOLDER) - new_content = [] - for local_image_index, split_part in enumerate(prompt_splits): - new_content.append(split_part) - if local_image_index < placeholder_count: - tokens_for_this_image = processor._prompt_split_image( - aspect_ratios[num_image_tokens], num_patches_per_chunk - ) - num_image_tokens += 1 - new_content.append(tokens_for_this_image) - - content = "".join(new_content) - else: - content = content.replace(IMAGE_PLACEHOLDER, self.image_token) - - message["content"] = content - - return messages - - @override - def get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - imglens: list[int], - vidlens: list[int], - audlens: list[int], - batch_ids: list[list[int]], - processor: Optional["MMProcessor"], - ) -> dict[str, Union[list[int], "torch.Tensor"]]: - self._validate_input(processor, images, videos, audios) - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - mm_inputs.pop("aspect_ratios", None) - return mm_inputs - - -@dataclass -class LlavaPlugin(BasePlugin): - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - messages = deepcopy(messages) - if self.expand_mm_tokens: - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - if "pixel_values" in mm_inputs: - height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0])) - image_seqlen = (height // processor.patch_size) * ( - width // processor.patch_size - ) + processor.num_additional_image_tokens - if processor.vision_feature_select_strategy == "default": - image_seqlen -= 1 - else: - image_seqlen = 1 - - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) - - message["content"] = content.replace("{{image}}", self.image_token) - - return messages - - -@dataclass -class LlavaNextPlugin(BasePlugin): - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - num_image_tokens = 0 - messages = deepcopy(messages) - if self.expand_mm_tokens: - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - if "pixel_values" in mm_inputs: - image_sizes = iter(mm_inputs["image_sizes"].tolist()) - height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) - - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - if self.expand_mm_tokens: - orig_height, orig_width = next(image_sizes) - image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) - if processor.vision_feature_select_strategy == "default": - image_seqlen -= 1 - else: - image_seqlen = 1 - - content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) - num_image_tokens += 1 - - message["content"] = content.replace("{{image}}", self.image_token) - - return messages - - -@dataclass -class LlavaNextVideoPlugin(BasePlugin): - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - messages = deepcopy(messages) - if self.expand_mm_tokens: - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - if "pixel_values" in mm_inputs: - image_sizes = iter(mm_inputs["image_sizes"].tolist()) - height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) - - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - if self.expand_mm_tokens: - orig_height, orig_width = next(image_sizes) - image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) - if processor.vision_feature_select_strategy == "default": - image_seqlen -= 1 - else: - image_seqlen = 1 - - content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) - - message["content"] = content.replace("{{image}}", self.image_token) - - if self.expand_mm_tokens: - if "pixel_values_videos" in mm_inputs: - one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) - height, width = get_image_size(one_video[0]) - num_frames = one_video.shape[0] # frame dim is always after batch dim - image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) - video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer - else: - video_seqlen = 1 - - for message in messages: - content = message["content"] - while VIDEO_PLACEHOLDER in content: - content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1) - - message["content"] = content.replace("{{video}}", self.video_token) - - return messages - - -@dataclass -class MiniCPMVPlugin(BasePlugin): - @override - def _get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: "MMProcessor", - **kwargs, - ) -> dict[str, "torch.Tensor"]: - image_processor: BaseImageProcessor = getattr(processor, "image_processor") - mm_inputs = {} - if len(images) != 0: - images = self._regularize_images( - images, - image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), - image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), - )["images"] - if "valid_image_nums_ls" in kwargs: - valid_image_nums_ls = kwargs["valid_image_nums_ls"] - new_images = [] - idx = 0 - for valid_image_nums in valid_image_nums_ls: - new_images.append(images[idx : idx + valid_image_nums]) - idx += valid_image_nums - - images = new_images - - image_inputs = image_processor( - images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt" - ) - mm_inputs.update(image_inputs) - - if len(videos) != 0: - videos = self._regularize_videos( - videos, - image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), - image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), - video_fps=getattr(processor, "video_fps", 2.0), - video_maxlen=getattr(processor, "video_maxlen", 128), - )["videos"] - video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt") - mm_inputs.update(video_inputs) - - if len(audios) != 0: - audios = self._regularize_audios( - audios, - sampling_rate=getattr(processor, "audio_sampling_rate", 16000), - )["audios"] - if "valid_audio_nums_ls" in kwargs: - valid_audio_nums_ls = kwargs["valid_audio_nums_ls"] - audios_ls = [] - idx = 0 - for valid_audio_nums in valid_audio_nums_ls: - audios_ls.append(audios[idx : idx + valid_audio_nums]) - idx += valid_audio_nums - else: - audios_ls = [audios] - - audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract( - audios_ls, - chunk_input=True, - sampling_rate=getattr(processor, "audio_sampling_rate", 16000), - ) - audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens] - mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens}) - if kwargs.get("ret_phs", False): - mm_inputs.update({"audio_phs": audio_phs}) - - return mm_inputs - - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0 - messages = deepcopy(messages) - image_processor: BaseImageProcessor = getattr(processor, "image_processor") - mm_inputs, audio_inputs = {}, {} - if len(images) != 0 and len(videos) != 0: - raise ValueError("MiniCPM-V model does not support input images and videos at the same time.") - - if len(videos) != 0: - max_slice_nums = 2 - use_image_id = False - mm_inputs = self._get_mm_inputs([], videos, [], processor) - else: - max_slice_nums = image_processor.max_slice_nums - use_image_id = image_processor.use_image_id - - for i, message in enumerate(messages): - content = message["content"] - while IMAGE_PLACEHOLDER in content: - content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) - num_image_tokens += 1 - - while VIDEO_PLACEHOLDER in content: - video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1 - content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1) - num_video_tokens += 1 - - while AUDIO_PLACEHOLDER in content: - content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1) - num_audio_tokens += 1 - - message["content"] = content.replace("{{image}}", "(./)").replace( - "{{audio}}", "()" - ) - - if len(images): - mm_inputs = self._get_mm_inputs(images, [], [], processor) - - if len(audios): - audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True) - - if self.expand_mm_tokens and mm_inputs: - pattern = "(./)" - image_sizes = mm_inputs["image_sizes"] - idx = 0 - for index, message in enumerate(messages): - text = message["content"] - image_tags = re.findall(pattern, text) - text_chunks = text.split(pattern) - final_text = "" - for i in range(len(image_tags)): - final_text = ( - final_text - + text_chunks[i] - + image_processor.get_slice_image_placeholder( - image_sizes[0][idx], idx, max_slice_nums, use_image_id - ) - ) - idx += 1 - - final_text += text_chunks[-1] - messages[index]["content"] = final_text - - if self.expand_mm_tokens and audio_inputs: - pattern = "()" - idx = 0 - for index, message in enumerate(messages): - text = message["content"] - audio_tags = re.findall(pattern, text) - text_chunks = text.split(pattern) - final_text = "" - for i in range(len(audio_tags)): - audio_placeholder = audio_inputs["audio_phs"][0][idx] - final_text = final_text + text_chunks[i] + audio_placeholder - idx += 1 - - final_text += text_chunks[-1] - messages[index]["content"] = final_text - - return messages - - @override - def get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - imglens: list[int], - vidlens: list[int], - audlens: list[int], - batch_ids: list[list[int]], - processor: Optional["MMProcessor"], - ) -> dict[str, Union[list[int], "torch.Tensor"]]: - self._validate_input(processor, images, videos, audios) - # image bound - image_bounds_list = [] - valid_image_nums_ls = [] - for i, input_ids in enumerate(batch_ids): - input_ids_ = torch.tensor(input_ids) - start_cond = (input_ids_ == processor.tokenizer.im_start_id) | ( - input_ids_ == processor.tokenizer.slice_start_id - ) - end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id) - image_start_tokens = torch.where(start_cond)[0] - image_start_tokens += 1 - image_end_tokens = torch.where(end_cond)[0] - valid_image_nums_ls.append(imglens[i]) - image_bounds = torch.hstack( - [ - image_start_tokens.unsqueeze(-1), - image_end_tokens.unsqueeze(-1), - ] - ) - image_bounds_list.append(image_bounds) - - mm_inputs = self._get_mm_inputs(images, videos, [], processor, valid_image_nums_ls=valid_image_nums_ls) - if "tgt_sizes" not in mm_inputs: - dummy_data = [torch.empty(0) for _ in range(len(batch_ids))] - mm_inputs.update({"tgt_sizes": dummy_data, "pixel_values": dummy_data, "image_sizes": dummy_data}) - - mm_inputs.update({"image_bound": image_bounds_list}) - - if len(audios) > 0: - # audio bound - audio_bounds_ls = [] - spk_bounds_ls = [] - valid_audio_nums_ls = [] - - for input_ids, audiolen in zip(batch_ids, audlens): - input_ids_ = torch.tensor(input_ids) - audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0] - audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0] - assert len(audio_start_idx) == len(audio_end_idx) - audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)]) - audio_bounds_ls.append(audio_bounds) - valid_audio_nums_ls.append(audiolen) - - spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0] - spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0] - assert len(spk_start_idx) == len(spk_end_idx) - spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)]) - spk_bounds_ls.append(spk_bounds) - - audio_inputs = self._get_mm_inputs([], [], audios, processor, valid_audio_nums_ls=valid_audio_nums_ls) - mm_inputs.update(audio_inputs) - mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls}) - - return mm_inputs - - -@dataclass -class MllamaPlugin(BasePlugin): - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - num_image_tokens = 0 - messages = deepcopy(messages) - for message in messages: - content = message["content"] - num_image_tokens += content.count(IMAGE_PLACEHOLDER) - message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token) - - return messages - - @override - def get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - imglens: list[int], - vidlens: list[int], - audlens: list[int], - batch_ids: list[list[int]], - processor: Optional["MMProcessor"], - ) -> dict[str, Union[list[int], "torch.Tensor"]]: - self._validate_input(processor, images, videos, audios) - mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens) - if mm_inputs: - num_tiles = mm_inputs.pop("num_tiles") - image_token_id: int = getattr(processor, "image_token_id") - max_image_tiles: int = getattr(processor.image_processor, "max_image_tiles") - cross_attention_token_mask = [ - get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids - ] - mm_inputs["cross_attention_mask"] = torch.from_numpy( - convert_sparse_cross_attention_mask_to_dense( - cross_attention_token_mask, - num_tiles=num_tiles, - max_num_tiles=max_image_tiles, - length=max(len(input_ids) for input_ids in batch_ids), - ) - ) # shape: (batch_size, length, max_num_images, max_num_tiles) - - return mm_inputs - - -@dataclass -class PaliGemmaPlugin(BasePlugin): - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - num_image_tokens = 0 - messages = deepcopy(messages) - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - content = content.replace(IMAGE_PLACEHOLDER, "", 1) - num_image_tokens += 1 - - message["content"] = content - - return messages - - @override - def process_token_ids( - self, - input_ids: list[int], - labels: Optional[list[int]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - tokenizer: "PreTrainedTokenizer", - processor: Optional["MMProcessor"], - ) -> tuple[list[int], Optional[list[int]]]: - self._validate_input(processor, images, videos, audios) - num_images = len(images) - image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token - image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) - input_ids = [image_token_id] * num_images * image_seqlen + input_ids - if labels is not None: - labels = [IGNORE_INDEX] * num_images * image_seqlen + labels - - return input_ids, labels - - @override - def get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - imglens: list[int], - vidlens: list[int], - audlens: list[int], - batch_ids: list[list[int]], - processor: Optional["MMProcessor"], - ) -> dict[str, Union[list[int], "torch.Tensor"]]: - self._validate_input(processor, images, videos, audios) - seqlens = [len(input_ids) for input_ids in batch_ids] - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor) - return mm_inputs - - -@dataclass -class PixtralPlugin(BasePlugin): - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - messages = deepcopy(messages) - if self.expand_mm_tokens: - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - if "pixel_values" in mm_inputs: - # BC for transformers < 4.49.0 - if isinstance(mm_inputs["image_sizes"], list): - image_sizes = iter(mm_inputs["image_sizes"][0]) - else: - image_sizes = iter(mm_inputs["image_sizes"].tolist()) - - image_break_token: str = getattr(processor, "image_break_token") - image_end_token: str = getattr(processor, "image_end_token") - - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - if self.expand_mm_tokens: - patch_size = processor.patch_size * getattr(processor, "spatial_merge_size", 1) - height, width = next(image_sizes) - num_height_tokens = height // patch_size - num_width_tokens = width // patch_size - replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens - replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list - replace_tokens[-1] = image_end_token - replace_str = "".join(replace_tokens) - else: - replace_str = self.image_token - - content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1) - - message["content"] = content - - return messages - - @override - def get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - imglens: list[int], - vidlens: list[int], - audlens: list[int], - batch_ids: list[list[int]], - processor: Optional["MMProcessor"], - ) -> dict[str, Union[list[int], "torch.Tensor"]]: - self._validate_input(processor, images, videos, audios) - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - # ref to this commit https://github.com/huggingface/transformers/pull/35122 - # after transformers 4.49.0, the `image_sizes` is mandatory as an input parameter for Pixtral VisionEncoder forwarding. - # it can be passed into `LlavaConditionalGeneration` as a parameter. - if not is_transformers_version_greater_than("4.49.0"): - mm_inputs.pop("image_sizes", None) - return mm_inputs - - -@dataclass -class Qwen2AudioPlugin(BasePlugin): - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - bos_token: str = getattr(processor, "audio_bos_token") - eos_token: str = getattr(processor, "audio_eos_token") - messages = deepcopy(messages) - if self.expand_mm_tokens: - mm_inputs = self._get_mm_inputs([], [], audios, processor) - if "feature_attention_mask" in mm_inputs: - audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist() - - for message in messages: - content = message["content"] - while AUDIO_PLACEHOLDER in content: - if self.expand_mm_tokens: - audio_length = audio_lengths.pop(0) - input_length = (audio_length - 1) // 2 + 1 - audio_seqlen = (input_length - 2) // 2 + 1 - else: - audio_seqlen = 1 - - content = content.replace( - AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1 - ) - - message["content"] = content - - return messages - - @override - def get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - imglens: list[int], - vidlens: list[int], - audlens: list[int], - batch_ids: list[list[int]], - processor: Optional["MMProcessor"], - ) -> dict[str, Union[list[int], "torch.Tensor"]]: - self._validate_input(processor, images, videos, audios) - return self._get_mm_inputs(images, videos, audios, processor) - - -@dataclass -class Qwen2VLPlugin(BasePlugin): - @override - def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject": - image = super()._preprocess_image(image, **kwargs) - if min(image.width, image.height) < 28: - width, height = max(image.width, 28), max(image.height, 28) - image = image.resize((width, height)) - - if image.width / image.height > 200: - width, height = image.height * 180, image.height - image = image.resize((width, height)) - - if image.height / image.width > 200: - width, height = image.width, image.width * 180 - image = image.resize((width, height)) - - return image - - @override - def _regularize_videos( - self, videos: list["VideoInput"], **kwargs - ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]: - results, fps_per_video = [], [] - for video in videos: - frames: list[ImageObject] = [] - if _check_video_is_nested_images(video): - for frame in video: - if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame): - raise ValueError("Invalid image found in video frames.") - - frames = video - fps_per_video.append(kwargs.get("video_fps", 2.0)) - else: - container = av.open(video, "r") - video_stream = next(stream for stream in container.streams if stream.type == "video") - sample_indices = self._get_video_sample_indices(video_stream, **kwargs) - container.seek(0) - for frame_idx, frame in enumerate(container.decode(video_stream)): - if frame_idx in sample_indices: - frames.append(frame.to_image()) - - if video_stream.duration is None: - fps_per_video.append(kwargs.get("video_fps", 2.0)) - else: - fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base)) - - if len(frames) % 2 != 0: - frames.append(frames[-1]) - - frames = self._regularize_images(frames, **kwargs)["images"] - results.append(frames) - - return {"videos": results, "fps_per_video": fps_per_video} - - @override - def _get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: "MMProcessor", - ) -> dict[str, "torch.Tensor"]: - image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) - mm_inputs = {} - if len(images) != 0: - images = self._regularize_images( - images, - image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), - image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), - )["images"] - mm_inputs.update(image_processor(images, return_tensors="pt")) - - if len(videos) != 0: - video_data = self._regularize_videos( - videos, - image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), - image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), - video_fps=getattr(processor, "video_fps", 2.0), - video_maxlen=getattr(processor, "video_maxlen", 128), - ) - mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt")) - temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2) - if "second_per_grid_ts" in processor.model_input_names: - mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]] - - return mm_inputs - - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - num_image_tokens, num_video_tokens = 0, 0 - messages = deepcopy(messages) - image_processor: BaseImageProcessor = getattr(processor, "image_processor") - - merge_length: int = getattr(image_processor, "merge_size") ** 2 - if self.expand_mm_tokens: - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - image_grid_thw = mm_inputs.get("image_grid_thw", []) - video_grid_thw = mm_inputs.get("video_grid_thw", []) - else: - image_grid_thw = [None] * len(images) - video_grid_thw = [None] * len(videos) - - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1 - content = content.replace( - IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1 - ) - num_image_tokens += 1 - - while VIDEO_PLACEHOLDER in content: - video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1 - content = content.replace( - VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1 - ) - num_video_tokens += 1 - - message["content"] = content - - return messages - - -@dataclass -class GLM4VPlugin(Qwen2VLPlugin): - @override - def _get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: "MMProcessor", - ) -> dict[str, "torch.Tensor"]: - image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) - video_processor: BaseImageProcessor = getattr(processor, "video_processor", None) - mm_inputs = {} - if len(images) != 0: - images = self._regularize_images( - images, - image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), - image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), - )["images"] - mm_inputs.update(image_processor(images, return_tensors="pt")) - - if len(videos) != 0: - video_data = self._regularize_videos( - videos, - image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), - image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), - video_fps=getattr(processor, "video_fps", 2.0), - video_maxlen=getattr(processor, "video_maxlen", 128), - ) - # prepare video metadata - video_metadata = [ - {"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"] - ] - mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata)) - - return mm_inputs - - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - num_image_tokens, num_video_tokens = 0, 0 - messages = deepcopy(messages) - image_processor: BaseImageProcessor = getattr(processor, "image_processor") - - merge_length: int = getattr(image_processor, "merge_size") ** 2 - if self.expand_mm_tokens: - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - image_grid_thw = mm_inputs.get("image_grid_thw", []) - video_grid_thw = mm_inputs.get("video_grid_thw", []) - num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now - timestamps = mm_inputs.get("timestamps", []) - - if hasattr(timestamps, "tolist"): - timestamps = timestamps.tolist() - - if not timestamps: - timestamps_list = [] - elif isinstance(timestamps[0], list): - timestamps_list = timestamps[0] - else: - timestamps_list = timestamps - - unique_timestamps = timestamps_list.copy() - selected_timestamps = unique_timestamps[:num_frames] - while len(selected_timestamps) < num_frames: - selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0) - - else: - image_grid_thw = [None] * len(images) - video_grid_thw = [None] * len(videos) - num_frames = 0 - selected_timestamps = [0] - - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1 - content = content.replace( - IMAGE_PLACEHOLDER, f"<|begin_of_image|>{self.image_token * image_seqlen}<|end_of_image|>", 1 - ) - num_image_tokens += 1 - - while VIDEO_PLACEHOLDER in content: - video_structure = "" - for frame_index in range(num_frames): - video_seqlen = ( - video_grid_thw[num_video_tokens][1:].prod() // merge_length if self.expand_mm_tokens else 1 - ) - timestamp_sec = selected_timestamps[frame_index] - frame_structure = ( - f"<|begin_of_image|>{self.image_token * video_seqlen}<|end_of_image|>{timestamp_sec}" - ) - video_structure += frame_structure - - content = content.replace(VIDEO_PLACEHOLDER, f"<|begin_of_video|>{video_structure}<|end_of_video|>", 1) - num_video_tokens += 1 - - message["content"] = content - - return messages - - @override - def get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - imglens: list[int], - vidlens: list[int], - audlens: list[int], - batch_ids: list[list[int]], - processor: Optional["ProcessorMixin"], - ) -> dict[str, Union[list[int], "torch.Tensor"]]: - self._validate_input(processor, images, videos, audios) - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - mm_inputs.pop("timestamps", None) - return mm_inputs - - -class Qwen2OmniPlugin(Qwen2VLPlugin): - @override - def _get_mm_inputs( - self, - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: "MMProcessor", - ) -> dict[str, "torch.Tensor"]: - image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) - feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) - mm_inputs = {} - if len(images) != 0: - images = self._regularize_images( - images, - image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), - image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), - )["images"] - mm_inputs.update(image_processor(images, return_tensors="pt")) - - if len(videos) != 0: - video_dict = self._regularize_videos( - videos, - image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), - image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), - video_fps=getattr(processor, "video_fps", 2.0), - video_maxlen=getattr(processor, "video_maxlen", 128), - ) - mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt")) - temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2) - mm_inputs["video_second_per_grid"] = torch.tensor( - [temporal_patch_size / fps for fps in video_dict["fps_per_video"]] - ) - - if len(audios) != 0: - audios = self._regularize_audios( - audios, - sampling_rate=getattr(processor, "audio_sampling_rate", 16000), - )["audios"] - mm_inputs.update( - feature_extractor( - audios, - sampling_rate=getattr(processor, "audio_sampling_rate", 16000), - return_attention_mask=True, - padding="max_length", - return_tensors="pt", - ) - ) - mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts - - return mm_inputs - - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0 - messages = deepcopy(messages) - image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) - - merge_length = processor.image_processor.merge_size**2 - use_audio_in_video = getattr(processor, "use_audio_in_video", False) - if self.expand_mm_tokens: - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - image_grid_thw = mm_inputs.get("image_grid_thw", []) - video_grid_thw = mm_inputs.get("video_grid_thw", []) - if "feature_attention_mask" in mm_inputs: - input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1 - audio_lengths = (input_lengths - 2) // 2 + 1 - else: - mm_inputs = {} - image_grid_thw = [None] * len(images) - video_grid_thw = [None] * len(videos) - audio_lengths = [None] * len(audios) - - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1 - content = content.replace( - IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1 - ) - num_image_tokens += 1 - - if ( - use_audio_in_video and len(audios) and len(videos) - ): # if use the audio of video # deal video token and audio token togather - if len(videos) != len(audios): - raise ValueError( - f"Number of videos ({len(videos)}) must match number of audios ({len(audios)}) when using audio in video." - ) - - while VIDEO_PLACEHOLDER in content: - video_pos = content.find(VIDEO_PLACEHOLDER) - audio_pos = content.find(AUDIO_PLACEHOLDER, video_pos) - if audio_pos == -1 or audio_pos < video_pos: - raise ValueError( - f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video." - ) - - audio_t_index = torch.arange(audio_lengths[num_audio_tokens]) - video_t_index = ( - torch.arange(video_grid_thw[num_video_tokens][0]) - .view(-1, 1, 1) - .expand( - -1, - video_grid_thw[num_video_tokens][1] // image_processor.merge_size, - video_grid_thw[num_video_tokens][2] // image_processor.merge_size, - ) - .flatten() - * mm_inputs["video_second_per_grid"][num_video_tokens] - * 25 # FIXME hardcode of position_id_per_seconds=25 - ).long() - t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2] - video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk) - audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk) - placeholder_string = "" - placeholder_string += "<|vision_bos|>" + "<|audio_bos|>" - for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))): - video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None - audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None - if video_chunk_index is not None: - placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0]) - - if audio_chunk_index is not None: - placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0]) - - placeholder_string += "<|audio_eos|>" + "<|vision_eos|>" - content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1) - content = content.replace(AUDIO_PLACEHOLDER, "", 1) - num_audio_tokens += 1 - num_video_tokens += 1 - else: - while AUDIO_PLACEHOLDER in content: - audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1 - content = content.replace( - AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1 - ) - num_audio_tokens += 1 - - while VIDEO_PLACEHOLDER in content: - video_seqlen = ( - video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1 - ) - content = content.replace( - VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1 - ) - num_video_tokens += 1 - - message["content"] = content - - return messages - - -@dataclass -class VideoLlavaPlugin(BasePlugin): - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["MMProcessor"], - ) -> list[dict[str, str]]: - self._validate_input(processor, images, videos, audios) - self._validate_messages(messages, images, videos, audios) - num_image_tokens, num_video_tokens = 0, 0 - messages = deepcopy(messages) - num_frames = 0 - if self.expand_mm_tokens: - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - if "pixel_values_images" in mm_inputs: - height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values_images"][0])) - num_frames = 1 - - if "pixel_values_videos" in mm_inputs: - one_video = to_numpy_array(mm_inputs["pixel_values_videos"][0]) - height, width = get_image_size(one_video[0]) - num_frames = one_video.shape[0] # frame dim is always after batch dim - - if "pixel_values_images" in mm_inputs or "pixel_values_videos" in mm_inputs: - image_seqlen = (height // processor.patch_size) * ( - width // processor.patch_size - ) + processor.num_additional_image_tokens - video_seqlen = image_seqlen * num_frames - if processor.vision_feature_select_strategy == "default": - image_seqlen -= 1 - else: - image_seqlen, video_seqlen = 1, 1 - - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) - num_image_tokens += 1 - - while VIDEO_PLACEHOLDER in content: - content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1) - num_video_tokens += 1 - - content = content.replace("{{image}}", self.image_token) - message["content"] = content.replace("{{video}}", self.video_token) - - return messages - - -PLUGINS = { - "base": BasePlugin, - "gemma3": Gemma3Plugin, - "glm4v": GLM4VPlugin, - "gemma3n": Gemma3nPlugin, - "intern_vl": InternVLPlugin, - "kimi_vl": KimiVLPlugin, - "llama4": Llama4Plugin, - "llava": LlavaPlugin, - "llava_next": LlavaNextPlugin, - "llava_next_video": LlavaNextVideoPlugin, - "minicpm_v": MiniCPMVPlugin, - "mllama": MllamaPlugin, - "paligemma": PaliGemmaPlugin, - "pixtral": PixtralPlugin, - "qwen2_audio": Qwen2AudioPlugin, - "qwen2_omni": Qwen2OmniPlugin, - "qwen2_vl": Qwen2VLPlugin, - "video_llava": VideoLlavaPlugin, -} - - -def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None: - r"""Register a multimodal plugin.""" - if name in PLUGINS: - raise ValueError(f"Multimodal plugin {name} already exists.") - - PLUGINS[name] = plugin_class - - -def get_mm_plugin( - name: str, - image_token: Optional[str] = None, - video_token: Optional[str] = None, - audio_token: Optional[str] = None, -) -> "BasePlugin": - r"""Get plugin for multimodal inputs.""" - if name not in PLUGINS: - raise ValueError(f"Multimodal plugin `{name}` not found.") - - return PLUGINS[name](image_token, video_token, audio_token) diff --git a/agents/agents/agents/templates/templates.py b/agents/agents/agents/templates/templates.py index fbdf63a..c0b8fdd 100644 --- a/agents/agents/agents/templates/templates.py +++ b/agents/agents/agents/templates/templates.py @@ -209,7 +209,7 @@ def _encode_user_message(self, content: List[Dict]) -> str: for item in content: if item["type"] == "text": text += item["text"] - elif item["type"] == "image": + elif item["type"] in ["image", "image_url"]: text += self.vision_start + self.image_token + self.vision_end elif item["type"] == "video": text += self.vision_start + self.video_token + self.vision_end @@ -315,6 +315,11 @@ def _encode_with_vision_processor(self, messages: List[Dict], tokenizer: PreTrai # Extract vision inputs images, videos = extract_vision_inputs_from_messages(messages) + + Logger.debug(f"[Template] images: {len(images)}") + Logger.debug(f"[Template] videos: {len(videos)}") + + Logger.debug(f"[Template] messages: {messages}") # Use vision processor with alignment support return vision_processor.process_for_llm( @@ -392,24 +397,23 @@ def supports_vision(self) -> bool: """Check if this template supports vision processing""" return is_vision_template(self.name) - def get_vision_config(self): - """Get vision configuration for this template""" - from .vision_processor import VisionProcessorRegistry - return VisionProcessorRegistry.get_config(self.name) - - def get_vision_inputs(self): + def get_vision_inputs(self, messages: List[Dict]): vision_inputs = defaultdict(list) - for role, message, _ in self.messages: - if isinstance(message, list): - for item in message: + Logger.debug(f"[Template] get_vision_inputs: messages: {messages}") + for message in messages: + content = message["content"] + if isinstance(content, list): + for item in content: if item['type'] == 'text': continue - elif item['type'] == 'image': - vision_inputs['image'].append(open_image_from_any(item['image'])) + elif item['type'] in ['image', 'image_url', 'image_base64']: + vision_inputs["image"].append(open_image_from_any(item[item['type']])) elif item['type'] == 'video': raise NotImplementedError("Video is not supported for chat template.") else: raise ValueError(f"Invalid message type: {item['type']}") + else: + raise ValueError(f"Invalid message content: {content}, the content should be a list of dicts") return vision_inputs def jinja_template(self) -> str: @@ -645,6 +649,9 @@ def prompt_with_mask(self, add_generation_prompt=False, tools=None) -> str: prompt_with_mask, _, _ = self.template.render_with_mask(messages=self.messages, add_generation_prompt=add_generation_prompt, tools=tools) return prompt_with_mask + def vision_inputs(self) -> List[Any]: + return self.template.get_vision_inputs(self.messages) + def tokenize(self, tokenizer: PreTrainedTokenizer = None, add_generation_prompt=False, tools=None, processor=None) -> List[int]: if tokenizer is None: tokenizer = self.tokenizer diff --git a/agents/agents/agents/templates/test_alignment.py b/agents/agents/agents/templates/test_alignment.py deleted file mode 100644 index 4ef7756..0000000 --- a/agents/agents/agents/templates/test_alignment.py +++ /dev/null @@ -1,170 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to verify tensor alignment functionality in the vision processor system. -""" - -import torch -from typing import Dict, List, Any - -def test_tensor_alignment(): - """Test that all tensors are properly aligned after vision token expansion""" - print("=== Testing Tensor Alignment ===") - - # Mock vision processor for testing - class MockVisionProcessor: - def expand_vision_tokens(self, prompt, images, videos, processor): - # Expand single image token to multiple tokens - return prompt.replace("<|image_pad|>", "<|image_pad|><|image_pad|><|image_pad|>") - - def calculate_image_tokens(self, image_data, processor): - return 3 # Mock: each image expands to 3 tokens - - def preprocess_images(self, images, processor): - return {"pixel_values": torch.randn(1, 3, 224, 224)} - - def preprocess_videos(self, videos, processor): - return {} - - def get_mm_inputs(self, images, videos, processor): - return {"pixel_values": torch.randn(1, 3, 224, 224)} - - # Mock tokenizer - class MockTokenizer: - def encode(self, text, add_special_tokens=False): - if "<|image_pad|>" in text: - # Replace image tokens with token IDs - text = text.replace("<|image_pad|>", "999") - return [1, 2, 999, 999, 999, 4] # Expanded tokens - return [1, 2, 3, 4] # Regular tokens - - @property - def add_bos_token(self): - return True - - @property - def bos_token(self): - return "" - - @property - def bos_token_id(self): - return 0 - - # Test data - elements = [ - "What's in this image? <|image_pad|>", # User message (masked) - "I can see a cat in the image." # Assistant message (not masked) - ] - mask_flags = [True, False] # User message masked, assistant not - - processor = MockVisionProcessor() - tokenizer = MockTokenizer() - - # Simulate the alignment process - input_ids = [] - attention_mask = [] - labels = [] - action_mask = [] - - # Add BOS token - input_ids.append(tokenizer.bos_token_id) - attention_mask.append(1) - labels.append(-100) - action_mask.append(0) - - # Process each element - for element, mask_flag in zip(elements, mask_flags): - # Check if element contains vision tokens - if "<|image_pad|>" in element: - # Expand vision tokens - expanded_element = processor.expand_vision_tokens(element, ["image.jpg"], [], None) - cur_input_ids = tokenizer.encode(expanded_element, add_special_tokens=False) - else: - cur_input_ids = tokenizer.encode(element, add_special_tokens=False) - - # Add tokens with proper alignment - input_ids.extend(cur_input_ids) - attention_mask.extend([1] * len(cur_input_ids)) - - if mask_flag: - labels.extend([-100] * len(cur_input_ids)) - action_mask.extend([0] * len(cur_input_ids)) - else: - labels.extend(cur_input_ids) - action_mask.extend([1] * len(cur_input_ids)) - - # Convert to tensors - inputs = { - 'input_ids': torch.tensor([input_ids]), - 'attention_mask': torch.tensor([attention_mask]), - 'labels': torch.tensor([labels]), - 'action_mask': torch.tensor([action_mask]) - } - - # Verify alignment - print("Tensor shapes:") - for key, value in inputs.items(): - print(f" {key}: {value.shape}") - - # Check that all tensors have the same sequence length - seq_len = inputs['input_ids'].shape[1] - assert inputs['attention_mask'].shape[1] == seq_len, "attention_mask not aligned" - assert inputs['labels'].shape[1] == seq_len, "labels not aligned" - assert inputs['action_mask'].shape[1] == seq_len, "action_mask not aligned" - - print(f"✅ All tensors aligned with sequence length: {seq_len}") - - # Verify the content makes sense - print("\nTensor content verification:") - print(f"input_ids: {inputs['input_ids'][0].tolist()}") - print(f"attention_mask: {inputs['attention_mask'][0].tolist()}") - print(f"labels: {inputs['labels'][0].tolist()}") - print(f"action_mask: {inputs['action_mask'][0].tolist()}") - - # Check that vision tokens are properly handled - vision_token_positions = [i for i, token_id in enumerate(inputs['input_ids'][0]) if token_id == 999] - print(f"\nVision token positions: {vision_token_positions}") - - # Verify that all tensors have proper values at vision token positions - for pos in vision_token_positions: - assert inputs['attention_mask'][0][pos] == 1, f"attention_mask should be 1 at position {pos}" - # Labels and action_mask depend on whether it's in a masked region - # This is a simplified test - in practice, you'd check the actual mask flags - - print("✅ Vision tokens properly handled in all tensors") - print("✅ Tensor alignment test passed!") - -def test_vision_processor_integration(): - """Test integration with the actual vision processor system""" - print("\n=== Testing Vision Processor Integration ===") - - try: - from .vision_processor import VisionProcessorRegistry, VisionProcessorConfig, PatchBasedProcessor - from .templates import get_template - - # Check if qwen2.5-vl template is registered - if VisionProcessorRegistry.is_vision_template("qwen2.5-vl"): - print("✅ qwen2.5-vl template is registered") - - processor = VisionProcessorRegistry.get_processor("qwen2.5-vl") - if processor is not None: - print("✅ Vision processor retrieved successfully") - - # Test the contains_vision_tokens method - test_text = "What's in this image? <|image_pad|>" - has_vision = processor._contains_vision_tokens(test_text) - print(f"✅ Vision token detection: {has_vision}") - - else: - print("❌ Vision processor not found") - else: - print("❌ qwen2.5-vl template not registered") - - except ImportError as e: - print(f"❌ Import error: {e}") - except Exception as e: - print(f"❌ Error: {e}") - -if __name__ == "__main__": - test_tensor_alignment() - test_vision_processor_integration() - print("\n=== All Tests Completed ===") \ No newline at end of file diff --git a/agents/agents/agents/templates/utils.py b/agents/agents/agents/templates/utils.py index eb5a9e9..3feb50a 100644 --- a/agents/agents/agents/templates/utils.py +++ b/agents/agents/agents/templates/utils.py @@ -193,12 +193,12 @@ def extract_vision_inputs_from_messages(messages: list) -> tuple[list, list]: for message in messages: if isinstance(message.get('content'), list): for item in message['content']: - if item.get('type') == 'image': + if item.get('type') in ['image', 'image_url']: if 'image' in item: images.append(item['image']) elif 'image_url' in item: images.append(item['image_url']['url']) - elif item.get('type') == 'video': + elif item.get('type') in ['video', 'video_url']: if 'video' in item: videos.append(item['video']) elif 'video_url' in item: diff --git a/agents/agents/agents/templates/vision_processor.py b/agents/agents/agents/templates/vision_processor.py index e3cd09c..ca14aea 100644 --- a/agents/agents/agents/templates/vision_processor.py +++ b/agents/agents/agents/templates/vision_processor.py @@ -122,6 +122,10 @@ def get_mm_inputs( ) -> Dict[str, torch.Tensor]: """Generate multi-modal inputs for the model""" pass + + def process_vision_info(self, messages: List[Dict]) -> Dict[str, torch.Tensor]: + """Process vision information from messages""" + pass # def process_for_llm( # self, @@ -529,6 +533,19 @@ def get_mm_inputs( return mm_inputs + def process_vision_info(self, messages: List[Dict], processor: Any): + """Process vision information from messages""" + image_message_types = ["image", "image_url", "image_base64"] + images = [] + for message in messages: + for content in message["content"]: + if content["type"] in image_message_types: + content_type = content["type"] + images.append(content[content_type]) + mm_inputs = self.get_mm_inputs(images, [], processor) + return mm_inputs + + class QwenVLProcessor(PatchBasedProcessor): """Qwen-VL specific processor with custom image preprocessing""" diff --git a/agents/agents/rewards/qa_reward.py b/agents/agents/rewards/qa_reward.py index 583519f..bd335b5 100644 --- a/agents/agents/rewards/qa_reward.py +++ b/agents/agents/rewards/qa_reward.py @@ -104,3 +104,18 @@ def qa_f1_reward_format(prediction: str, answer: str, trajectory: List[str]) -> raise ValueError(f"Invalid prediction or trajectory for qa reward with format: Trajectory: {trajectory}") return rewards_dict + + +@reward(name="ok_vqa_reward") +def ok_vqa_reward(prediction: str, answers: List[str], trajectory: List[str]) -> float: + """ + Calculate the reward for the agent's response based on the F1 score and EM score. + The reward is 0.0 if the agent has not called any tool. + The reward is the F1 score if the agent has called a tool. + """ + f1_scores = [] + for answer in answers: + f1, precision, recall = f1_score(prediction, answer) + f1_scores.append(f1) + # All answers are the correct answer, take the max f1 score + return max(f1_scores) \ No newline at end of file diff --git a/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py b/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py index 54d1714..31b5ecd 100644 --- a/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py +++ b/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py @@ -103,5 +103,8 @@ def test_chat_template_equal(template, messages, tools, add_generation_prompt): assert torch.equal(official_inputs["pixel_values"], implemented_inputs["pixel_values"]) assert torch.equal(official_inputs["image_grid_thw"], implemented_inputs["image_grid_thw"]) + + assert implemented_inputs["input_ids"].shape == implemented_inputs["action_mask"].shape, f"""Official action mask shape: {official_inputs["action_mask"].shape}\nImplemented action mask shape: {implemented_inputs["action_mask"].shape}""" + print(f"official_prompt: {official_prompt}\nimplemented_prompt: {tokenizer.decode(implemented_inputs['input_ids'][0])}\nofficial_inputs: {official_inputs.keys()}\nimplemented_inputs: {implemented_inputs.keys()}\n") \ No newline at end of file From 0844f3f8819b0d478edcad6fe77528e43f972f38 Mon Sep 17 00:00:00 2001 From: Reason-Wang Date: Sat, 9 Aug 2025 09:50:46 +0000 Subject: [PATCH 4/6] Update verl --- verl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl b/verl index 0186ec8..861f63b 160000 --- a/verl +++ b/verl @@ -1 +1 @@ -Subproject commit 0186ec81e9273953d2f34b0c4d48741cc2f9aabc +Subproject commit 861f63ba8097a43ababe27116842512783080586 From bbe3c0dc0441fd729c8ec0ee1ba36a760d1f6899 Mon Sep 17 00:00:00 2001 From: Reason-Wang Date: Sat, 9 Aug 2025 18:49:27 +0000 Subject: [PATCH 5/6] Add template tests --- agents/agents/agents/templates/templates.py | 4 - .../agents/templates/test_qwen25_vl_prompt.py | 0 .../test_text_templates_full_align.py | 98 +++++++++++++++++++ .../templates/test_text_templates_tokenize.py | 6 +- .../agents/templates/test_think_prompt.py | 0 5 files changed, 100 insertions(+), 8 deletions(-) delete mode 100644 agents/tests/unit/agents/templates/test_qwen25_vl_prompt.py delete mode 100644 agents/tests/unit/agents/templates/test_think_prompt.py diff --git a/agents/agents/agents/templates/templates.py b/agents/agents/agents/templates/templates.py index c0b8fdd..f249747 100644 --- a/agents/agents/agents/templates/templates.py +++ b/agents/agents/agents/templates/templates.py @@ -726,10 +726,6 @@ def get_template(name: str) -> Template: assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n", tool_template="<|im_start|>user\n\n{observation}\n<|im_end|>\n", stop_words=["<|im_end|>"], - vision_start="<|vision_start|>", - vision_end="<|vision_end|>", - image_token="<|image_pad|>", - video_token="<|video_pad|>", ) ) diff --git a/agents/tests/unit/agents/templates/test_qwen25_vl_prompt.py b/agents/tests/unit/agents/templates/test_qwen25_vl_prompt.py deleted file mode 100644 index e69de29..0000000 diff --git a/agents/tests/unit/agents/templates/test_text_templates_full_align.py b/agents/tests/unit/agents/templates/test_text_templates_full_align.py index e69de29..dd1f326 100644 --- a/agents/tests/unit/agents/templates/test_text_templates_full_align.py +++ b/agents/tests/unit/agents/templates/test_text_templates_full_align.py @@ -0,0 +1,98 @@ +""" This file is for testing the text templates that align seamlessly with HF templates. The templates should align on following aspects: + - The obtained textual prompt should be the same as the one obtained from HF template with all the following options: + - add_generation_prompt + - tools + - The obtained textual prompt should be the same as the one obtained from Jinja template with all the following options: + - add_generation_prompt + - tools +""" + + +from agents.agents.templates.utils import compare_hf_template +from transformers import AutoTokenizer +import pytest + +@pytest.mark.parametrize("model_name_or_path", ["Qwen/Qwen2.5-3B-Instruct"]) +@pytest.mark.parametrize("messages", [ + [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + {"role": "user", "content": "Want to play a game?"}, + {"role": "assistant", "content": "Sure, what game?"}, + ], + [ + {"role": "user", "content": "Help me to calculate 3 times 5."}, + {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, + {"role": "tool", "content": "15"}, + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + {"role": "user", "content": "What is 3 times 5?"}, + ], +]) +@pytest.mark.parametrize("tools", [ + None, + [ + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + ] +]) +@pytest.mark.parametrize("add_generation_prompt", [True, False]) +def test_hf_template_print(model_name_or_path, messages, tools, add_generation_prompt): + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt, tools=tools) + print(f"========================================\nModel: {model_name_or_path}\nMessages: {messages}\nTools: {tools}\nAdd generation prompt: {add_generation_prompt}\n") + print(prompt) + print("========================================\n") + + +# "qwen2.5-think", "qwen2.5", "qwen2.5-no-tool", +@pytest.mark.parametrize("template", ["qwen2.5"]) +@pytest.mark.parametrize("messages", [ + [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + {"role": "user", "content": "Want to play a game?"}, + {"role": "assistant", "content": "Sure, what game?"}, + ], + [ + {"role": "user", "content": "Help me to calculate 3 times 5."}, + {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, + {"role": "tool", "content": "15"}, + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + {"role": "user", "content": "What is 3 times 5?"}, + ], +]) +@pytest.mark.parametrize("tools", [ + None, + [ + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + ] +]) +@pytest.mark.parametrize("add_generation_prompt", [True, False]) +def test_chat_template_equal(template, messages, tools, add_generation_prompt): + # Filter invalid combinations + if add_generation_prompt and messages[-1]['role'] == 'assistant': + return + + template_tokenizer_mapping = { + "qwen2.5": "Qwen/Qwen2.5-3B-Instruct", + "qwen2.5-think": "Qwen/Qwen2.5-3B-Instruct", + "qwen2.5-no-system-tool": "Qwen/Qwen2.5-3B-Instruct", + "deepseek-prover-v2": "deepseek-ai/DeepSeek-Prover-V2-7B", + } + tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template]) + + is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template, messages=messages, tools=tools,add_generation_prompt=add_generation_prompt) + assert is_equal, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" + assert is_equal_between_jinja_prompts, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja prompt:\n\n{implemented_jinja_prompt}" + print(f"Official prompt:\n\n{official_prompt}") + print(f"Highlighted prompt:\n\n{highlighted_prompt}") + diff --git a/agents/tests/unit/agents/templates/test_text_templates_tokenize.py b/agents/tests/unit/agents/templates/test_text_templates_tokenize.py index d17235b..566e2b5 100644 --- a/agents/tests/unit/agents/templates/test_text_templates_tokenize.py +++ b/agents/tests/unit/agents/templates/test_text_templates_tokenize.py @@ -7,9 +7,9 @@ Since the align for textual prompt is already tested in other files, we only need to test the tokenization of the templates. """ -from agents.agents.templates.utils import is_vlm_template, tokenize_conversation +from agents.agents.templates.utils import tokenize_conversation import pytest -from transformers import AutoTokenizer, AutoProcessor +from transformers import AutoTokenizer import torch from agents.agents.templates.templates import Chat @@ -44,8 +44,6 @@ def test_template_tokenize(template, messages, tools, add_generation_prompt): template_tokenizer_mapping = { "qwen2.5": "Qwen/Qwen2.5-3B-Instruct", - "qwen2.5-vl": "Qwen/Qwen2.5-VL-3B-Instruct", - "qwen3": "Qwen/Qwen3-8B", } tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template]) diff --git a/agents/tests/unit/agents/templates/test_think_prompt.py b/agents/tests/unit/agents/templates/test_think_prompt.py deleted file mode 100644 index e69de29..0000000 From c8af7ea9a0e21502a9a38424889559cda04466f4 Mon Sep 17 00:00:00 2001 From: Reason-Wang Date: Tue, 12 Aug 2025 19:43:03 +0000 Subject: [PATCH 6/6] More powerful, flexible template system --- agents/agents/agents/templates/constants.py | 18 + .../agents/agents/templates/system_policy.py | 46 ++ agents/agents/agents/templates/templates.py | 673 ++++++++++++++---- agents/agents/agents/templates/tool_policy.py | 237 ++++++ agents/agents/agents/templates/utils.py | 3 + .../test_text_templates_full_align.py | 25 +- .../test_text_templates_partial_align.py | 68 ++ .../templates/test_text_templates_tokenize.py | 8 +- 8 files changed, 935 insertions(+), 143 deletions(-) create mode 100644 agents/agents/agents/templates/constants.py create mode 100644 agents/agents/agents/templates/system_policy.py create mode 100644 agents/agents/agents/templates/tool_policy.py create mode 100644 agents/tests/unit/agents/templates/test_text_templates_partial_align.py diff --git a/agents/agents/agents/templates/constants.py b/agents/agents/agents/templates/constants.py new file mode 100644 index 0000000..796ba97 --- /dev/null +++ b/agents/agents/agents/templates/constants.py @@ -0,0 +1,18 @@ +from enum import Enum, auto + +class ToolPlacement(Enum): + """ + Where to inject the tool catalogue in the rendered prompt. + """ + SYSTEM = auto() # inside the system message + FIRST_USER = auto() # as an extra first-user turn + LAST_USER = auto() # appended to the last user turn + SEPARATE = auto() # its own dedicated turn / role + + +class Role(Enum): + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + ASSISTANT_PREFIX = "assistant_prefix" \ No newline at end of file diff --git a/agents/agents/agents/templates/system_policy.py b/agents/agents/agents/templates/system_policy.py new file mode 100644 index 0000000..8ede734 --- /dev/null +++ b/agents/agents/agents/templates/system_policy.py @@ -0,0 +1,46 @@ +from abc import ABC, abstractmethod +import dataclasses +import datetime +from typing import Callable + +@dataclasses.dataclass +class SystemPolicy: + use_system: bool = True + use_system_without_system_message: bool = True + content_processor: Callable[[str], str] = None + + +class SystemContentProcessor(ABC): + @abstractmethod + def __call__(self, system_message: str) -> str: + raise NotImplementedError + + @abstractmethod + def jinja(self) -> str: + raise NotImplementedError + +class Llama32DateProcessor(SystemContentProcessor): + """ + A system content processor that adds date information to system messages. + + In Python mode, it dynamically computes the current date. + In Jinja mode, it provides a template with placeholders that can be processed. + + Usage in Jinja templates: + - The template includes '__CURRENT_DATE__' placeholder + - Replace '__CURRENT_DATE__' with the actual formatted date during processing + - Format should be 'dd MMM yyyy' (e.g., '15 Dec 2024') + - No external context variables required + """ + def __call__(self, system_message: str) -> str: + return f"Cutting Knowledge Date: December 2023\nToday Date: {datetime.datetime.now().strftime('%d %b %Y')}\n\n{system_message}" + + def jinja(self) -> str: + # For Jinja templates used by external systems (like vLLM), we need a self-contained approach + # Since external systems can't provide context variables, we use a placeholder approach + # The external system should replace __CURRENT_DATE__ with the actual date + return """Cutting Knowledge Date: December 2023 +Today Date: __CURRENT_DATE__ + +{{ system_message }}""" + \ No newline at end of file diff --git a/agents/agents/agents/templates/templates.py b/agents/agents/agents/templates/templates.py index f249747..7697f87 100644 --- a/agents/agents/agents/templates/templates.py +++ b/agents/agents/agents/templates/templates.py @@ -1,17 +1,30 @@ from collections import defaultdict -from copy import copy +from copy import copy, deepcopy import dataclasses -from enum import Enum, auto, IntEnum import json -from typing import List, Any, Dict, Union, Tuple +from typing import Callable, List, Any, Dict, Union, Tuple import warnings import logging import torch -from .preprocess import open_image_from_any from transformers import PreTrainedTokenizer +from .preprocess import open_image_from_any from .vision_processor import is_vision_template import re +from typing import Protocol +from .tool_policy import ( + ToolFormatter, + JsonMinifiedFormatter, + JsonCompactFormatter, + JsonIndentedFormatter, + ToolMainContentProcessor, + JsonQwenFormatter, +) +from datetime import datetime +from .constants import Role +from .system_policy import Llama32DateProcessor, SystemPolicy +from .tool_policy import ToolPolicy +from .constants import ToolPlacement, Role Logger = logging.getLogger(__name__) @@ -23,12 +36,9 @@ console_handler.setFormatter(formatter) Logger.addHandler(console_handler) -class Role(Enum): - SYSTEM = "system" - USER = "user" - ASSISTANT = "assistant" - TOOL = "tool" - ASSISTANT_PREFIX = "assistant_prefix" +@dataclasses.dataclass +class GlobalPolicy: + prefix: str = None @dataclasses.dataclass @@ -50,8 +60,15 @@ class Template: tool_template: str = None # The user template user_template: str = None + user_template_with_tools: str = None # The assistant template assistant_template: str = None + # Global policy + global_policy: "GlobalPolicy" = None + # System message policy + system_policy: "SystemPolicy" = None + # Tool policy for this template + tool_policy: "ToolPolicy" = None ## vision part vision_start: str = None @@ -59,10 +76,17 @@ class Template: image_token: str = None video_token: str = None + chat_template: str = None + def __post_init__(self): """Post-initialization to automatically register vision processor if vision tokens are defined""" if self.image_token or self.video_token: self._register_vision_processor() + # Initialise default tool policy if none was provided + if self.tool_policy is None: + self.tool_policy = ToolPolicy() + if self.system_policy is None: + self.system_policy = SystemPolicy() def _register_vision_processor(self): """Automatically register a vision processor for this template""" @@ -111,63 +135,140 @@ def _infer_model_type(self) -> str: # Default to patch-based for unknown models return "patch_based" + def _supports_tool_call(self) -> bool: + if (self.system_template_with_tools or self.user_template_with_tools) and self.tool_template: + return True + else: + return False + def render(self, messages: List[Dict], tools=None, add_generation_prompt: bool = False) -> str: - """Render the template with the given messages and kwargs. - messages: [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "Hello, how are you?" - } - ] - }, - { - "role": "assistant", - ] + """Render the template and return + 1. the final prompt string, + 2. the list of string *elements* that compose the prompt, and + 3. the corresponding list of *roles* (used by downstream post-processing). + + The heavy lifting is delegated to small, single-purpose helpers so the + high-level flow is immediately apparent: + + 1. _insert_tools – decide where the tool catalogue lives + 2. _encode_turns – encode every conversation turn + 3. _maybe_add_generation_prompt – append the generation prefix if requested """ - elements = [] - roles = [] - if tools: - tools = self._encode_system_tools(tools) - for i, message in enumerate(messages): + # Step 1 – decide tool placement & clone messages + work_messages, tools_str, insert_tools_idx = self._insert_tools(messages, tools) + + # Step 2 – encode each conversation turn to text tokens + elements, roles = self._encode_turns(work_messages, tools_str, insert_tools_idx) + + # Step 3 – append generation prefix if needed + if add_generation_prompt: + self._maybe_add_generation_prompt(elements, roles) + + # Concatenate the prompt + prompt = "".join(elements) + return prompt, elements, roles + + def _insert_tools(self, messages: List[Dict], tools): + """Clone *messages* and compute where (and how) the tool catalogue + should be injected. + + Returns + ------- + work_messages : List[Dict] + A deepcopy of the original *messages* so we never mutate caller data. + tools_str : Optional[str] + The formatted tool catalogue or *None* if `tools` is falsy. + insert_tools_idx : int + Index of the *user* message that receives the catalogue, or -1 when + no injection is required. + """ - if i == 0 and self._detect_role(message["role"]) == Role.SYSTEM: - system_message = self._encode_system_message(message["content"], tools=tools) - elements.append(system_message) - roles.append(Role.SYSTEM) - # This message is done + work_messages = deepcopy(messages) + if tools: + tools_str = self.tool_policy.format_tools(tools) + placement = self.tool_policy.placement + insert_tools_idx = self._find_insert_tools_index(work_messages, placement) + else: + tools_str = None + insert_tools_idx = -1 + return work_messages, tools_str, insert_tools_idx + + def _encode_turns( + self, + work_messages: List[Dict], + tools_str: str, + insert_tools_idx: int, + ) -> Tuple[List[str], List[Role]]: + """Convert every message dict into its textual representation while + tracking roles for later masking logic.""" + + elements: List[str] = [] + roles: List[Role] = [] + + # Global prefix comes first (rarely used but must respect ordering) + if self.global_policy and self.global_policy.prefix: + elements.append(self.global_policy.prefix) + roles.append(Role.SYSTEM) + + for i, message in enumerate(work_messages): + current_role = self._detect_role(message["role"]) + + # -------------------------------------------------------------- + # Handle system message insertion on the very first turn + # -------------------------------------------------------------- + if i == 0 and current_role == Role.SYSTEM: + if self.system_policy.use_system: + system_message = self._encode_system_message( + message["content"], tools=tools_str + ) + elements.append(system_message) + roles.append(Role.SYSTEM) + # Whether inserted or not, we skip further handling of this + # message because it's the (optional) system turn itself. continue - elif i == 0 and self._detect_role(message["role"]) != Role.SYSTEM: - system_message = self._encode_system_message_default(tools=tools) - elements.append(system_message) - roles.append(Role.SYSTEM) - # This message is not done, we need to handle other roles - - if self._detect_role(message["role"]) == Role.USER: - user_message = self._encode_user_message(message["content"]) + elif i == 0 and current_role != Role.SYSTEM: + if self.system_policy.use_system: + system_message = self._encode_system_message_default(tools=tools_str) + elements.append(system_message) + roles.append(Role.SYSTEM) + # Do *not* `continue` – we still need to encode this first message. + + # -------------------------------------------------------------- + # Encode regular conversation turns + # -------------------------------------------------------------- + if current_role == Role.USER: + if i == insert_tools_idx: + user_message = self._encode_user_message_with_tools( + message["content"], tools=tools_str + ) + else: + user_message = self._encode_user_message(message["content"]) elements.append(user_message) roles.append(Role.USER) - elif self._detect_role(message["role"]) == Role.ASSISTANT: + + elif current_role == Role.ASSISTANT: assistant_message = self._encode_assistant_message(message["content"]) elements.append(assistant_message) roles.append(Role.ASSISTANT) - elif self._detect_role(message["role"]) == Role.TOOL: + + elif current_role == Role.TOOL: tool_message = self._encode_tool_message(message["content"]) elements.append(tool_message) roles.append(Role.TOOL) + else: raise ValueError(f"Invalid role: {message['role']}") - - if add_generation_prompt: - generation_prefix = self._encode_generation_prompt() - elements.append(generation_prefix) - roles.append(Role.ASSISTANT_PREFIX) - - prompt = "".join(elements) - return prompt, elements, roles + + return elements, roles + + def _maybe_add_generation_prompt(self, elements: List[str], roles: List[Role]): + """Append the generation prefix so the model knows to continue + generating an assistant response.""" + + generation_prefix = self._encode_generation_prompt() + elements.append(generation_prefix) + roles.append(Role.ASSISTANT_PREFIX) def _detect_role(self, role: str) -> Role: if role == "system": @@ -180,53 +281,114 @@ def _detect_role(self, role: str) -> Role: return Role.TOOL else: raise ValueError(f"Invalid role: {role}") + + def _find_insert_tools_index(self, work_messages: List[Dict], placement: ToolPlacement) -> int: + insert_tools_idx = 0 # Default to insert tools at system message + for i, message in enumerate(work_messages): + if placement == ToolPlacement.SYSTEM: + insert_tools_idx = 0 + elif placement == ToolPlacement.FIRST_USER: + if message.get("role") == "user": + insert_tools_idx = i + break + elif placement == ToolPlacement.LAST_USER: + if message.get("role") == "user": + insert_tools_idx = i + else: + raise ValueError(f"Unhandled ToolPlacement: {placement}") + return insert_tools_idx def _encode_system_tools(self, tools: List[Dict]) -> str: return "\n".join([json.dumps(tool) for tool in tools]) def _encode_system_message_default(self, tools=None) -> str: + if not self.system_policy.use_system_without_system_message: + return "" + + if self.system_policy.content_processor is not None: + system_message = self.system_policy.content_processor(self.system_message) + else: + system_message = self.system_message + if tools is None: - return self.system_template.format(system_message=self.system_message) + return self.system_template.format(system_message=system_message) else: if self.system_template_with_tools: - return self.system_template_with_tools.format(system_message=self.system_message, tools=tools) + return self.system_template_with_tools.format(system_message=system_message, tools=tools) else: - return self.system_template.format(system_message=self.system_message) + return self.system_template.format(system_message=system_message) def _encode_system_message(self, content, tools=None) -> str: - if tools is None: + # Handle both string content and list content formats + if isinstance(content, str): + system_message = content + else: system_message = content[0]['text'] + + if self.system_policy.content_processor is not None: + system_message = self.system_policy.content_processor(system_message) + + if tools is None: return self.system_template.format(system_message=system_message) else: - system_message = content[0]['text'] if self.system_template_with_tools is None: return self.system_template.format(system_message=system_message) else: return self.system_template_with_tools.format(system_message=system_message, tools=tools) - - def _encode_user_message(self, content: List[Dict]) -> str: - text = "" - for item in content: - if item["type"] == "text": - text += item["text"] - elif item["type"] in ["image", "image_url"]: - text += self.vision_start + self.image_token + self.vision_end - elif item["type"] == "video": - text += self.vision_start + self.video_token + self.vision_end - else: - raise ValueError(f"Invalid message type: {item['type']}") + + def _encode_user_message_with_tools(self, content, tools: str) -> str: + # Handle both string content and list content formats + if isinstance(content, str): + text = content + else: + text = "" + for item in content: + if item["type"] == "text": + text += item["text"] + else: + raise ValueError(f"Invalid message type: {item['type']}") + + if self.user_template_with_tools: + user_message = self.user_template_with_tools.format(content=text, tools=tools) + else: + user_message = self.user_template.format(content=text) + return user_message + + def _encode_user_message(self, content) -> str: + # Handle both string content and list content formats + if isinstance(content, str): + text = content + else: + text = "" + for item in content: + if item["type"] == "text": + text += item["text"] + elif item["type"] in ["image", "image_url"]: + text += self.vision_start + self.image_token + self.vision_end + elif item["type"] == "video": + text += self.vision_start + self.video_token + self.vision_end + else: + raise ValueError(f"Invalid message type: {item['type']}") user_message = self.user_template.format(content=text) return user_message - def _encode_assistant_message(self, content: List[Dict]) -> str: - assert len(content) == 1, "Assistant message must be a single message" - text = content[0]["text"] + def _encode_assistant_message(self, content) -> str: + # Handle both string content and list content formats + if isinstance(content, str): + text = content + else: + assert len(content) == 1, "Assistant message must be a single message" + text = content[0]["text"] assistant_message = self.assistant_template.format(content=text) return assistant_message - def _encode_tool_message(self, content: List[Dict]) -> str: - assert len(content) == 1, "Tool message must be a single message" - text = content[0]["text"] + def _encode_tool_message(self, content) -> str: + # Handle both string content and list content formats + if isinstance(content, str): + text = content + else: + assert len(content) == 1, "Tool message must be a single message" + text = content[0]["text"] tool_message = self.tool_template.format(observation=text) return tool_message @@ -272,11 +434,15 @@ def _encode_standard(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, labels = [] action_mask = [] - if tokenizer.bos_token and tokenizer.add_bos_token: - input_ids.append(tokenizer.bos_token_id) - attention_mask.append(1) - labels.append(-100) - action_mask.append(0) + + if tokenizer.bos_token: + # If add_bos_token is not set, we assume to add bos token + # There is potential issue if the tokenizer has bos_token but do not add it by default + if getattr(tokenizer, "add_bos_token", True): + input_ids.append(tokenizer.bos_token_id) + attention_mask.append(1) + labels.append(-100) + action_mask.append(0) for element, mask_flag in zip(elements, mask_flags): cur_input_ids = tokenizer.encode(element, add_special_tokens=False) @@ -417,47 +583,88 @@ def get_vision_inputs(self, messages: List[Dict]): return vision_inputs def jinja_template(self) -> str: - """ - Build a Hugging-Face-style chat-template (Jinja-mini dialect) that mimics - `self.render`. The template expects three variables in its context: + if self.chat_template: + return self.chat_template + else: + return self.render_jinja_template() + + def render_jinja_template(self) -> str: + """Return a Hugging-Face style chat-template (Jinja-mini dialect). - • messages – list[dict] (same format you pass to .render) - • add_generation_prompt – bool (default False) - • tools – list[dict] (optional, for tool-enabled templates) + The implementation now mirrors the three-step structure of + `render()` for easier maintenance: - No other Python state is referenced, so the string can be cached in the - tokenizer and shipped to a different process. + 1. _jinja_header_constants – immutable `set` statements + 2. _jinja_system_block – first turn / system handling + 3. _jinja_loop_messages – remaining turns & per-role logic + 4. _jinja_generation_block – optional generation prefix """ - # ------------------------------------------------------------------ - # 1. Pre-compute constant strings so the inner template stays tiny - # ------------------------------------------------------------------ - default_system = self.system_template.format(system_message=self.system_message) + + parts: List[str] = [] + + # 1. Constant header (always first) + parts.extend(self._jinja_header_constants()) + + # 2. System-message handling (depends on presence of tools etc.) + parts.extend(self._jinja_system_block()) + + # 2.5 Pre-compute insert index for user placement + parts.extend(self._jinja_compute_insert_idx()) + + # 3. Loop over remaining messages + parts.extend(self._jinja_loop_messages()) + + # 4. Generation prefix block + parts.extend(self._jinja_generation_block()) + + template_str = "".join(parts) - # Don't pre-format system_template_with_tools - handle it in Jinja - system_template_with_tools_raw = self.system_template_with_tools if self.system_template_with_tools else None + # Post-process: Replace __CURRENT_DATE__ placeholder with actual date + if "__CURRENT_DATE__" in template_str: + from datetime import datetime + current_date = datetime.now().strftime('%d %b %Y') + template_str = template_str.replace("__CURRENT_DATE__", current_date) + + return template_str + + # ------------------------------------------------------------------ + # Private helpers – keep them together for readability + # ------------------------------------------------------------------ + def _jinja_header_constants(self) -> List[str]: + """Return Jinja `set` statements for all constant strings.""" + + # Compute default system message considering content processor + if self.system_policy.content_processor is not None: + # Apply content processor to system message + processed_system_message = self.system_policy.content_processor(self.system_message) + default_system = self.system_template.format(system_message=processed_system_message) + else: + default_system = self.system_template.format(system_message=self.system_message) + + system_template_with_tools_raw = ( + self.system_template_with_tools if self.system_template_with_tools else None + ) + + # Split templates try: u_pref, u_suff = self.user_template.split("{content}") a_pref, a_suff = self.assistant_template.split("{content}") - except ValueError as e: # missing {content} - raise ValueError("`user_template` / `assistant_template` must contain " - "`{content}` placeholder") from e + except ValueError as exc: + raise ValueError( + "`user_template` / `assistant_template` must contain `{content}` placeholder" + ) from exc if self.tool_template: t_pref, t_suff = self.tool_template.split("{observation}") - else: # tools optional + else: t_pref, t_suff = "", "" - # tokens for images / videos + # Tokens for images / videos img_tok = (self.vision_start or "") + (self.image_token or "") + (self.vision_end or "") vid_tok = (self.vision_start or "") + (self.video_token or "") + (self.vision_end or "") - # ------------------------------------------------------------------ - # 2. Assemble the Jinja text; everything in triple-quotes is copied - # verbatim into the tokenizer. We splice in the constants that - # never change for this Template instance. - # ------------------------------------------------------------------ - template_parts = [ + header = [ f"{{% set _u_pref = {u_pref!r} %}}", f"{{% set _u_suff = {u_suff!r} %}}", f"{{% set _a_pref = {a_pref!r} %}}", @@ -469,39 +676,146 @@ def jinja_template(self) -> str: f"{{% set _default_system = {default_system!r} %}}", f"{{% set _system_message = {self.system_message!r} %}}", f"{{% set _system_template = {self.system_template!r} %}}", + f"{{% set _tool_placement = {self.tool_policy.placement.name!r} %}}", ] - - # Add system template with tools if available + if system_template_with_tools_raw: - template_parts.append(f"{{% set _system_template_with_tools = {system_template_with_tools_raw!r} %}}") - - template_parts.extend([ + header.append( + f"{{% set _system_template_with_tools = {system_template_with_tools_raw!r} %}}" + ) + + # Add user template with tools if it exists + if self.user_template_with_tools: + # Convert double braces to single braces for Jinja compatibility + processed_template = self.user_template_with_tools.replace('{{', '{').replace('}}', '}') + header.append( + f"{{% set _u_template_with_tools = {processed_template!r} %}}" + ) + + # ------------------------------------------------------------------ + # Formatter macro for tools (only if the template supports tool calls) + # ------------------------------------------------------------------ + + if self._supports_tool_call(): + # Build a Jinja macro that reproduces ToolPolicy.format_tools behaviour + formatter_snippet = self.tool_policy.formatter.jinja() + + # The snippet usually comes wrapped in "{{ ... }}". We drop the + # outer braces because macro bodies are already an output context. + formatter_body = formatter_snippet.strip() + + header.extend( + [ + "{% macro _fmt_tools(tools) %}", + f"{formatter_body}", + "{% endmacro %}", + ] + ) + + # ------------------------------------------------------------------ + # System processor macro (if system policy has a content processor) + # ------------------------------------------------------------------ + + if self.system_policy.content_processor is not None: + # Build a Jinja macro that reproduces the system content processor behaviour + processor_snippet = self.system_policy.content_processor.jinja() + + # The snippet should be a template that expects 'system_message' variable + # We create a macro that can be called with the system message + header.extend( + [ + "{% macro _process_system_message(system_message) %}", + f"{processor_snippet}", + "{% endmacro %}", + ] + ) + + return header + + def _jinja_compute_insert_idx(self) -> List[str]: + """Return Jinja code that pre-computes the index where tools should + be injected for FIRST_USER and LAST_USER placements.""" + + return [ + "{% set _insert_ns = namespace(idx=-1) %}", + "{% if _tool_placement in ['FIRST_USER', 'LAST_USER'] %}", + "{%- for _m in messages -%}", + "{%- if _m['role'] == 'user' -%}", + "{%- if _tool_placement == 'FIRST_USER' and _insert_ns.idx == -1 -%}", + "{% set _insert_ns.idx = loop.index0 %}", + "{%- elif _tool_placement == 'LAST_USER' -%}", + "{% set _insert_ns.idx = loop.index0 %}", + "{%- endif -%}", + "{%- endif -%}", + "{%- endfor -%}", + "{% endif %}", + ] + + def _jinja_system_block(self) -> List[str]: + """Return Jinja code that handles the system message logic.""" + + return [ # Handle system message first (matching render logic) "{% if messages and messages[0]['role'] == 'system' %}", "{% if tools and _system_template_with_tools %}", "{% if messages[0]['content'] is string %}", - "{{ _system_template_with_tools.format(system_message=messages[0]['content'], tools=tools | map('tojson') | join('\\n')) }}", + "{% if _process_system_message is defined %}", + "{{ _system_template_with_tools.format(system_message=_process_system_message(messages[0]['content']), tools=_fmt_tools(tools)) }}", "{% else %}", - "{{ _system_template_with_tools.format(system_message=messages[0]['content'][0]['text'], tools=tools | map('tojson') | join('\\n')) }}", + "{{ _system_template_with_tools.format(system_message=messages[0]['content'], tools=_fmt_tools(tools)) }}", + "{% endif %}", + "{% else %}", + "{% if _process_system_message is defined %}", + "{{ _system_template_with_tools.format(system_message=_process_system_message(messages[0]['content'][0]['text']), tools=_fmt_tools(tools)) }}", + "{% else %}", + "{{ _system_template_with_tools.format(system_message=messages[0]['content'][0]['text'], tools=_fmt_tools(tools)) }}", + "{% endif %}", "{% endif %}", "{% else %}", "{% if messages[0]['content'] is string %}", + "{% if _process_system_message is defined %}", + "{% set processed_message = _process_system_message(messages[0]['content']) %}", + "{% set formatted_system = _system_template | replace('{system_message}', processed_message) %}{{ formatted_system }}", + "{% else %}", "{% set formatted_system = _system_template | replace('{system_message}', messages[0]['content']) %}{{ formatted_system }}", + "{% endif %}", + "{% else %}", + "{% if _process_system_message is defined %}", + "{% set processed_message = _process_system_message(messages[0]['content'][0]['text']) %}", + "{% set formatted_system = _system_template | replace('{system_message}', processed_message) %}{{ formatted_system }}", "{% else %}", "{% set formatted_system = _system_template | replace('{system_message}', messages[0]['content'][0]['text']) %}{{ formatted_system }}", "{% endif %}", "{% endif %}", + "{% endif %}", "{% else %}", "{% if tools and _system_template_with_tools %}", - "{{ _system_template_with_tools.format(system_message=_system_message, tools=tools | map('tojson') | join('\\n')) }}", + "{% if _process_system_message is defined %}", + "{{ _system_template_with_tools.format(system_message=_process_system_message(_system_message), tools=_fmt_tools(tools)) }}", + "{% else %}", + "{{ _system_template_with_tools.format(system_message=_system_message, tools=_fmt_tools(tools)) }}", + "{% endif %}", + "{% else %}", + "{% if _process_system_message is defined %}", + "{% set processed_message = _process_system_message(_system_message) %}", + "{% set formatted_system = _system_template | replace('{system_message}', processed_message) %}{{ formatted_system }}", "{% else %}", "{{ _default_system }}", "{% endif %}", "{% endif %}", + "{% endif %}", + ] + + def _jinja_loop_messages(self) -> List[str]: + """Return Jinja loop that encodes all messages except the first system.""" + + return [ + "{% set _tool_ns = namespace(inserted=False, user_count=0) %}", # Process remaining messages (skip first if it was system) "{% for m in messages %}", "{% if not (loop.first and m['role'] == 'system') %}", "{% if m['role'] == 'user' %}", + "{% set _tool_ns.user_count = _tool_ns.user_count + 1 %}", "{% set ns = namespace(txt='') %}", "{% if m['content'] is string %}", "{% set ns.txt = m['content'] %}", @@ -516,7 +830,17 @@ def jinja_template(self) -> str: "{% endif %}", "{% endfor %}", "{% endif %}", + "{% if tools and ((_tool_placement == 'FIRST_USER' and _tool_ns.user_count == 1) or (_tool_placement == 'LAST_USER' and loop.index0 == _insert_ns.idx)) and not _tool_ns.inserted %}", + "{% if _u_template_with_tools is defined %}", + "{% set formatted_tools = _fmt_tools(tools) %}", + "{{ _u_template_with_tools | replace('{content}', ns.txt) | replace('{tools}', formatted_tools) }}", + "{% else %}", + "{{ _u_pref }}{{ ns.txt }}{{ _u_suff }}\\n{{ _fmt_tools(tools) }}", + "{% endif %}", + "{% set _tool_ns.inserted = True %}", + "{% else %}", "{{ _u_pref }}{{ ns.txt }}{{ _u_suff }}", + "{% endif %}", "{% elif m['role'] == 'assistant' %}", "{% if m['content'] is string %}", "{{ _a_pref }}{{ m['content'] }}{{ _a_suff }}", @@ -532,12 +856,16 @@ def jinja_template(self) -> str: "{% endif %}", "{% endif %}", "{% endfor %}", + ] + + def _jinja_generation_block(self) -> List[str]: + """Return Jinja code that appends the generation prefix when requested.""" + + return [ "{% if add_generation_prompt %}", "{{ _a_pref }}", - "{% endif %}" - ]) - - return "".join(template_parts) + "{% endif %}", + ] def render_with_mask(self, messages: List[Dict], add_generation_prompt: bool = False, tools=None): @@ -558,14 +886,6 @@ def set_system_message(self, system_message: str): """Set the system message.""" self.system_message = system_message - def set_tools(self, tools: List[Dict]): - """Set the tools.""" - if self.tool_aggregator == "DEFAULT": - self.tools = json.dumps(tools) - elif self.tool_aggregator == "STACKED": - self.tools = "\n".join([json.dumps(tool) for tool in tools]) - else: - raise ValueError(f"Invalid tool aggregator: {self.tool_aggregator}") def copy(self): return Template( @@ -574,6 +894,7 @@ def copy(self): system_template_with_tools=self.system_template_with_tools, system_message=self.system_message, user_template=self.user_template, + user_template_with_tools=self.user_template_with_tools, assistant_template=self.assistant_template, tool_template=self.tool_template, stop_words=self.stop_words, @@ -581,6 +902,10 @@ def copy(self): vision_end=self.vision_end, image_token=self.image_token, video_token=self.video_token, + global_policy=deepcopy(self.global_policy), + system_policy=deepcopy(self.system_policy), + tool_policy=deepcopy(self.tool_policy), + chat_template=self.chat_template, ) def dict(self): @@ -692,10 +1017,6 @@ def get_template(name: str) -> Template: assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n", tool_template="<|im_start|>user\n\n{observation}\n<|im_end|>\n", stop_words=["<|im_end|>"], - vision_start="<|vision_start|>", - vision_end="<|vision_end|>", - image_token="<|image_pad|>", - video_token="<|video_pad|>", ) ) @@ -759,16 +1080,98 @@ def get_template(name: str) -> Template: ) ) + + +# TODO: mistral template has many cornor cases, leave it for now +# register_template( +# Template( +# name="mistral", +# system_template="{system_message}", +# user_template="[INST] {content}[/INST] ", +# user_template_with_tools="[AVAILABLE TOOLS] {tools} [/AVAILABLE TOOLS] [INST] {content}[/INST] ", +# assistant_template="{content}", +# tool_template="{observation}", +# stop_words=[""], +# system_policy=SystemPolicy( +# use_system=False, +# ), +# tool_policy=ToolPolicy( +# placement=ToolPlacement.LAST_USER, +# formatter=JsonCompactFormatter() +# ) +# ) +# ) + +# TODO: system template includes current date +register_template( + Template( + name="llama-3.2", + system_template="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", + system_template_with_tools="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\n{system_message}<|eot_id|>", + user_template="<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>", + user_template_with_tools="""<|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}.Do not use variables.\n\n{tools}\n\n{content}<|eot_id|>""", + assistant_template="<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>", + tool_template="""<|start_header_id|>ipython<|end_header_id|>\n\n"{observation}"<|eot_id|>""", + stop_words=["<|eot_id|>"], + system_policy=SystemPolicy( + use_system=True, + content_processor=Llama32DateProcessor(), + ), + tool_policy=ToolPolicy( + placement=ToolPlacement.FIRST_USER, + formatter=JsonIndentedFormatter() + ) + ) +) + register_template( Template( - name="deepseek-prover-v2", - system_template="<|begin▁of▁sentence|>{system_message}", - user_template="<|User|>{content}", - assistant_template="<|Assistant|>{content}<|end▁of▁sentence|>", - stop_words=["<|end▁of▁sentence|>"], + name="glm-4", + system_template="<|system|>\n{system_message}", + user_template="<|user|>\n{content}", + assistant_template="<|assistant|>\n{content}", + stop_words=[""], + global_policy=GlobalPolicy( + prefix="[gMASK]" + ), + system_policy=SystemPolicy( + use_system=True, + use_system_without_system_message=False, + ), ) ) +register_template( + Template( + name="phi-4", + system_template="<|im_start|>system<|im_sep|>{system_message}<|im_end|>", + user_template="<|im_start|>user<|im_sep|>{content}<|im_end|>", + assistant_template="<|im_start|>assistant<|im_sep|>{content}<|im_end|>", + stop_words=["<|im_end|>"], + ) +) + +# Note: Partial align, some minor new-line problems. +register_template( + Template( + name="nemotron", + system_template="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{system_message}<|eot_id|>", + system_template_with_tools="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{system_message}{tools}<|eot_id|>""", + user_template="<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>", + assistant_template="<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>", + tool_template="<|start_header_id|>user<|end_header_id|>\n\n[{observation}]<|eot_id|>", + stop_words=["<|eot_id|>"], + system_policy=SystemPolicy( + use_system=True, + content_processor=lambda x: f"\n{x}", + ), + tool_policy=ToolPolicy( + placement=ToolPlacement.SYSTEM, + content_processor=ToolMainContentProcessor(), + formatter=JsonCompactFormatter(), + ) + ) +) if __name__ == "__main__": - pass + pass \ No newline at end of file diff --git a/agents/agents/agents/templates/tool_policy.py b/agents/agents/agents/templates/tool_policy.py new file mode 100644 index 0000000..377c098 --- /dev/null +++ b/agents/agents/agents/templates/tool_policy.py @@ -0,0 +1,237 @@ +from typing import List, Dict, Tuple +import json +from typing import Callable, Any # Added for content processor typing +from abc import ABC, abstractmethod +import dataclasses + +from .constants import ToolPlacement + +# Convert ToolFormatter into an abstract base class +class ToolFormatter(ABC): + """ + Strategy that converts an in-memory list[dict] describing tools + into the textual representation expected by the target model. + """ + @abstractmethod + def format(self, tools: List[Dict]) -> str: + """Format a list of tool dictionaries into a string representation.""" + raise NotImplementedError + + @abstractmethod + def jinja(self) -> str: + """Return a Jinja template that can be used to format the tools.""" + raise NotImplementedError + + +class ToolContentProcessor(ABC): + """ + Strategy that processes the content of a tool before it is serialized. + """ + @abstractmethod + def __call__(self, tool: Dict) -> Dict: + raise NotImplementedError + + @abstractmethod + def jinja(self) -> str: + """Return a Jinja template that can be used to process the content of a tool.""" + raise NotImplementedError + + +class ToolMainContentProcessor(ToolContentProcessor): + """ + Strategy that processes the main content of a tool before it is serialized. + """ + def __call__(self, tool: Dict) -> Dict: + assert isinstance(tool, dict), "Tool must be a dictionary" + if "function" in tool: + content = tool["function"] + assert "name" in content, "Tool function must have a name" + assert "parameters" in content, "Tool function must have parameters" + return content + elif "name" in tool and "parameters" in tool: + return tool + else: + raise ValueError(f"Tool must have a function or name and parameters: {tool}") + + # The main-content extraction cannot be replicated in pure Jinja, so we + # fall back to the identity behaviour at template-generation time. This + # means the processor is *ignored* in frozen chat-templates; users who + # require it must rely on the Python render path. + + def jinja(self) -> str: + # We deliberately document the limitation by returning a simple pass- + # through expression. + return "{{ tool }}" + +# Make JsonFormatter inherit the ToolFormatter base class +class JsonFormatter(ToolFormatter): + """General JSON formatter with configurable indent, separators, and joiner.""" + + def __init__( + self, + *, + indent: int | None = None, + separators: Tuple[str, str] | None = None, + joiner: str = "\n", + format_as_list: bool = False, + content_processor: ToolContentProcessor = None, + ): + """Create a new JsonFormatter. + + Args: + indent: Indentation level passed to ``json.dumps``. ``None`` means no pretty-print. + separators: Custom separators passed to ``json.dumps``; useful for minification. + joiner: String used to join per-tool JSON strings when ``format_as_list`` is *False*. + format_as_list: If *True*, the entire ``tools`` list is serialised in a single + ``json.dumps`` call, ignoring ``joiner``. This is handy when the target + model expects a single JSON array instead of multiple individual objects. + content_processor: Optional callable applied to each individual tool dictionary + before serialisation. Defaults to the identity function. + """ + self.indent = indent + self.separators = separators + self.joiner = joiner + self.format_as_list = format_as_list + + def format(self, tools: List[Dict]) -> str: # noqa: D401 + """Return a single string obtained by dumping every tool to JSON then joining them. + + Args: + tools: A list of tool dictionaries to be stringified. + + Returns: + A string representation of the tools, formatted according to the + given ``indent``/``separators`` and concatenated with ``joiner``. + """ + # Apply the per-tool content processor first + + if self.format_as_list: + # Serialize the whole list in one go – joiner is irrelevant in this mode. + return json.dumps(tools, indent=self.indent, separators=self.separators) + + # Default behaviour: dump each tool individually then concatenate. + return self.joiner.join( + json.dumps(t, indent=self.indent, separators=self.separators) for t in tools + ) + + # ------------------------------------------------------------------ + # Jinja support + # ------------------------------------------------------------------ + + def _escape_joiner(self, joiner: str) -> str: # local helper + """Return *joiner* escaped so it is safe inside a single‐quoted Jinja + string literal (the HF chat-template parser understands the Python + backslash escapes).""" + + return joiner.replace("\\", "\\\\").replace("'", "\\'") + + def jinja(self) -> str: # noqa: D401 + """Return a **Jinja-mini** snippet that serialises the *tools* variable + with the same settings as :py:meth:`format`. + + The template assumes that a ``tools`` list is present in the Jinja + context. Because the Hugging-Face chat-template dialect only supports + a limited subset of Jinja, we restrict ourselves to `map`, `tojson`, + `join`, and optional indent on a *single* tojson call when + ``format_as_list`` is *True*. + + When ``format_as_list`` is *False* and ``indent`` is specified, we use + a Jinja loop to apply indentation to each individual tool. + """ + + # Serialise whole list -> one tojson call (supports indent argument) + if self.format_as_list: + if self.indent is None: + return "{{ tools | tojson }}" + else: + return f"{{{{ tools | tojson(indent={self.indent}) }}}}" + + # Individual objects: use loop if indent is needed, otherwise use map + if self.indent is not None: + # Use loop to apply indentation to each individual tool + # For joiners containing newlines, we need to avoid whitespace control to preserve them + # For other joiners, we can use whitespace control for cleaner output + + if '\n' in self.joiner: + # Joiner contains newlines - use Jinja's string replacement to convert \n to actual newlines + # We'll create a Jinja variable with the proper newlines + joiner_var = '{% set joiner = "' + self.joiner.replace('\n', '\\n') + '" | replace("\\\\n", "\n") %}' + return joiner_var + f"{{% for tool in tools %}}{{{{ tool | tojson(indent={self.indent}) }}}}{{% if not loop.last %}}{{{{ joiner }}}}{{% endif %}}{{% endfor %}}" + else: + # Joiner doesn't contain newlines - safe to use whitespace control and escaping + joiner_escaped = self._escape_joiner(self.joiner) + return f"{{%- for tool in tools -%}}{{{{ tool | tojson(indent={self.indent}) }}}}{{%- if not loop.last -%}}{joiner_escaped}{{%- endif -%}}{{%- endfor -%}}" + else: + # No indentation needed, use the simpler map approach + joiner_escaped = self._escape_joiner(self.joiner) + return ( + "{{ tools | map('tojson') | join('" + joiner_escaped + "') }}" + ) + +class JsonMinifiedFormatter(JsonFormatter): + """Single-line JSON objects without extra whitespace (legacy alias).""" + + def __init__(self, joiner: str = "\n", *, content_processor: Callable[[Dict], Any] | None = None): + super().__init__(indent=None, separators=(",", ":"), joiner=joiner, content_processor=content_processor) + + +class JsonIndentedFormatter(JsonFormatter): + """ + Pretty printed JSON with configurable indent (default 4). + Frequently required by models like Mistral-v0.3. + (legacy alias) + """ + + def __init__(self, indent: int = 4, *, joiner: str = "\n\n", format_as_list: bool = False): + super().__init__(indent=indent, separators=None, joiner=joiner, format_as_list=format_as_list) + + +class JsonCompactFormatter(JsonFormatter): + """Single-line JSON objects without extra whitespace.""" + def __init__(self, *, format_as_list: bool = True, content_processor: Callable[[Dict], Any] | None = None): + super().__init__(indent=None, separators=None, format_as_list=format_as_list, content_processor=content_processor) + +class JsonQwenFormatter(JsonFormatter): + """ + JSON formatter for Qwen models. + """ + def __init__(self): + super().__init__(indent=None, separators=None, format_as_list=False, content_processor=None) + + # No special behaviour – inherits .jinja from JsonFormatter + + +# --------------------------------------------------------------------------- +# Content processors – only implement jinja where feasible +# --------------------------------------------------------------------------- + + +try: + import yaml as _yaml # optional dependency + + class YamlFormatter(ToolFormatter): # type: ignore + def format(self, tools: List[Dict]) -> str: # noqa: D401 + return _yaml.safe_dump(tools, sort_keys=False) +except ModuleNotFoundError: # pragma: no cover + YamlFormatter = None # type: ignore + + +@dataclasses.dataclass +class ToolPolicy: + """ + Encapsulates every configuration decision about how *tools* + appear in the prompt for a given template. + """ + placement: "ToolPlacement" = ToolPlacement.SYSTEM + content_processor: Callable[[Dict], Any] = None + formatter: ToolFormatter = dataclasses.field(default_factory=lambda: JsonQwenFormatter()) + + def format_tools(self, tools: List[Dict]) -> str: + """ + Convert `tools` into ready-to-inject text according to the chosen formatter. + """ + if self.content_processor is not None: + processed_tools = [self.content_processor(t) for t in tools] + else: + processed_tools = tools + return self.formatter.format(processed_tools) diff --git a/agents/agents/agents/templates/utils.py b/agents/agents/agents/templates/utils.py index b8c312a..9f62182 100644 --- a/agents/agents/agents/templates/utils.py +++ b/agents/agents/agents/templates/utils.py @@ -305,6 +305,9 @@ def compare_hf_template(tokenizer, template_name, messages=None, tools=None, add plain_highlighted_prompt = strip_ansi(highlighted_prompt) is_equal_between_implemented_prompts = implemented_prompt == plain_highlighted_prompt jinja_template = chat.template.jinja_template() + # Save jinja template to file + with open("jinja_template.jinja", "w") as f: + f.write(jinja_template) tokenizer.chat_template = jinja_template implemented_jinja_prompt = tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, add_generation_prompt=add_generation_prompt) is_equal_between_jinja_prompts = implemented_jinja_prompt == implemented_prompt diff --git a/agents/tests/unit/agents/templates/test_text_templates_full_align.py b/agents/tests/unit/agents/templates/test_text_templates_full_align.py index dd1f326..006ee4e 100644 --- a/agents/tests/unit/agents/templates/test_text_templates_full_align.py +++ b/agents/tests/unit/agents/templates/test_text_templates_full_align.py @@ -12,7 +12,10 @@ from transformers import AutoTokenizer import pytest -@pytest.mark.parametrize("model_name_or_path", ["Qwen/Qwen2.5-3B-Instruct"]) +@pytest.mark.parametrize("model_name_or_path", [ + # "Qwen/Qwen2.5-3B-Instruct", + "mistralai/Mistral-7B-Instruct-v0.3", +]) @pytest.mark.parametrize("messages", [ [ {"role": "user", "content": "Hello, how are you?"}, @@ -41,7 +44,7 @@ ]) @pytest.mark.parametrize("add_generation_prompt", [True, False]) def test_hf_template_print(model_name_or_path, messages, tools, add_generation_prompt): - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt, tools=tools) print(f"========================================\nModel: {model_name_or_path}\nMessages: {messages}\nTools: {tools}\nAdd generation prompt: {add_generation_prompt}\n") print(prompt) @@ -49,7 +52,8 @@ def test_hf_template_print(model_name_or_path, messages, tools, add_generation_p # "qwen2.5-think", "qwen2.5", "qwen2.5-no-tool", -@pytest.mark.parametrize("template", ["qwen2.5"]) +# "llama-3.2", "mistral", "glm-4", "internlm2.5", "phi-3.5", "phi-4" +@pytest.mark.parametrize("template", ["llama-3.2", "qwen2.5"]) @pytest.mark.parametrize("messages", [ [ {"role": "user", "content": "Hello, how are you?"}, @@ -82,17 +86,26 @@ def test_chat_template_equal(template, messages, tools, add_generation_prompt): if add_generation_prompt and messages[-1]['role'] == 'assistant': return + template_tokenizer_mapping = { "qwen2.5": "Qwen/Qwen2.5-3B-Instruct", "qwen2.5-think": "Qwen/Qwen2.5-3B-Instruct", "qwen2.5-no-system-tool": "Qwen/Qwen2.5-3B-Instruct", "deepseek-prover-v2": "deepseek-ai/DeepSeek-Prover-V2-7B", + "llama-3.2": "meta-llama/Llama-3.2-3B-Instruct", + "mistral": "mistralai/Mistral-7B-Instruct-v0.3", + "glm-4": "THUDM/glm-4-9b-chat", + "internlm2.5": "internlm/internlm2_5-7b-chat", + "phi-3.5": "microsoft/Phi-3.5-mini-instruct", + "phi-4": "microsoft/Phi-4", + "nemotron": "nvidia/Llama-3.1-Nemotron-Nano-8B-v1", } - tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template]) + tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template], trust_remote_code=True) is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template, messages=messages, tools=tools,add_generation_prompt=add_generation_prompt) + assert is_equal, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" assert is_equal_between_jinja_prompts, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja prompt:\n\n{implemented_jinja_prompt}" - print(f"Official prompt:\n\n{official_prompt}") - print(f"Highlighted prompt:\n\n{highlighted_prompt}") + # print(f"Official prompt:\n\n{official_prompt}") + # print(f"Highlighted prompt:\n\n{highlighted_prompt}") diff --git a/agents/tests/unit/agents/templates/test_text_templates_partial_align.py b/agents/tests/unit/agents/templates/test_text_templates_partial_align.py new file mode 100644 index 0000000..78b7f87 --- /dev/null +++ b/agents/tests/unit/agents/templates/test_text_templates_partial_align.py @@ -0,0 +1,68 @@ +import pytest +from transformers import AutoTokenizer +from agents.agents.templates.templates import get_template +from agents.agents.templates.utils import compare_hf_template + +# nemotron, phi-4, glm-4 +@pytest.mark.parametrize("template_name", ["qwen2.5-think", "qwen2.5-no-system-tool",]) +@pytest.mark.parametrize("messages", [ + [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + {"role": "user", "content": "Want to play a game?"}, + {"role": "assistant", "content": "Sure, what game?"}, + ], + [ + {"role": "user", "content": "Help me to calculate 3 times 5."}, + {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''}, + {"role": "tool", "content": "15", "tool_call_id": "123456789"}, + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + {"role": "user", "content": "What is 3 times 5?"}, + ], +]) +@pytest.mark.parametrize("tools", [ + None, + [ + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + ] +]) +@pytest.mark.parametrize("add_generation_prompt", [True, False]) +def test_chat_template_equal(template_name, messages, tools, add_generation_prompt): + # Filter invalid combinations + if add_generation_prompt and messages[-1]['role'] == 'assistant': + return + + template = get_template(template_name) + if tools and not template._supports_tool_call(): + return + + contain_tool_role = any(message['role'] == 'tool' for message in messages) + if contain_tool_role and not template._supports_tool_call(): + return + + template_tokenizer_mapping = { + "qwen2.5": "Qwen/Qwen2.5-3B-Instruct", + "qwen2.5-think": "Qwen/Qwen2.5-3B-Instruct", + "qwen2.5-no-system-tool": "Qwen/Qwen2.5-3B-Instruct", + "deepseek-prover-v2": "deepseek-ai/DeepSeek-Prover-V2-7B", + "llama-3.2": "meta-llama/Llama-3.2-3B-Instruct", + "mistral": "mistralai/Mistral-7B-Instruct-v0.3", + "glm-4": "THUDM/glm-4-9b-chat", + "internlm2.5": "internlm/internlm2_5-7b-chat", + "phi-3.5": "microsoft/Phi-3.5-mini-instruct", + "phi-4": "microsoft/Phi-4", + "nemotron": "nvidia/Llama-3.1-Nemotron-Nano-8B-v1", + } + tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template_name], trust_remote_code=True) + + is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template_name, messages=messages, tools=tools,add_generation_prompt=add_generation_prompt) + + print(f"Official prompt:\n\n\"{official_prompt}\"\n\n") + print(f"Implemented prompt:\n\n\"{implemented_prompt}\"\n\n") + assert is_equal, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" + assert is_equal_between_jinja_prompts, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja prompt:\n\n{implemented_jinja_prompt}" \ No newline at end of file diff --git a/agents/tests/unit/agents/templates/test_text_templates_tokenize.py b/agents/tests/unit/agents/templates/test_text_templates_tokenize.py index 566e2b5..a89e0b8 100644 --- a/agents/tests/unit/agents/templates/test_text_templates_tokenize.py +++ b/agents/tests/unit/agents/templates/test_text_templates_tokenize.py @@ -13,7 +13,7 @@ import torch from agents.agents.templates.templates import Chat -@pytest.mark.parametrize("template", ["qwen2.5"]) +@pytest.mark.parametrize("template", ["llama-3.2", "qwen2.5"]) @pytest.mark.parametrize("messages", [ [ {"role": "user", "content": "Hello, how are you?"}, @@ -31,6 +31,9 @@ {"role": "user", "content": "Hello, how are you?"}, {"role": "assistant", "content": "I am fine, thank you."}, {"role": "user", "content": "What is 3 times 5?"}, + {"role": "assistant", "content": "15"}, + {"role": "user", "content": "OK, what is 3 times 6?"}, + {"role": "assistant", "content": "18"}, ], ]) @pytest.mark.parametrize("tools", [ @@ -44,8 +47,9 @@ def test_template_tokenize(template, messages, tools, add_generation_prompt): template_tokenizer_mapping = { "qwen2.5": "Qwen/Qwen2.5-3B-Instruct", + "llama-3.2": "meta-llama/Llama-3.2-3B-Instruct", } - tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template]) + tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template], trust_remote_code=True) chat = Chat(template, messages, tools=tools) prompt = chat.prompt(add_generation_prompt=add_generation_prompt, tools=tools)