diff --git a/README.md b/README.md index 553347c7..11593a71 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,7 @@ Unlike prior approaches that simply concatenate full interaction histories, `ver - [3. Sokoban](#3-sokoban) - [4. Gym Cards](#4-gym-cards) - [5. AppWorld (Experimental)](#5-appworld-experimental) + - [6. Search](#6-search) - [Run Examples](#run-examples) - [RL Training](#rl-training) - [1. GiGPO](#1-gigpo) @@ -139,7 +140,7 @@ We have released our models on [HuggingFace](https://huggingface.co/collections/ # Installation ## Install veRL ```bash -conda create -n verl-agent python==3.12 -y +conda create -n verl-agent python==3.10 -y conda activate verl-agent pip3 install torch==2.6.0 --index-url https://download.pytorch.org/whl/cu124 @@ -177,12 +178,6 @@ alfworld-play-tw --- ### 2. WebShop -WebShop requires Python <=3.10, so begin by creating a new `verl-agent-webshop` environment -```bash -conda create -n verl-agent-webshop python==3.10 -y -conda activate verl-agent-webshop -``` - Install WebShop ```bash cd ./agent_system/environments/env_package/webshop/webshop @@ -241,9 +236,59 @@ appworld install appworld download data ``` +### 6. Search +```bash +conda activate verl-agent +cd ./agent_system/environments/env_package/search/third_party +pip install -e . +pip install gym==0.26.2 +``` + +Prepare dataset (data will be saved at `~/data/searchR1_processed_direct`): +```bash +cd repo_root/ +python examples/data_preprocess/preprocess_search_r1_dataset.py +``` + + +Since faiss-gpu is not available via pip, we set up a separate conda environment for the local retrieval server. Running this server will use around 6GB of GPU memory per GPU, so make sure to account for this in your training run configuration. 
Build Retriever environments: +```bash +# Create and activate the retriever environment with Python 3.10 +conda create -n retriever python=3.10 -y +conda activate retriever + +# Install PyTorch (with GPU support) and related libraries +conda install numpy==1.26.4 # needed to stop incompatible version of numpy from being installed via pip +pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124 - +# Install other Python packages +pip install transformers datasets pyserini huggingface_hub + +# Install the GPU version of faiss +conda install faiss-gpu==1.8.0 -c pytorch -c nvidia -y + +# Install the API service framework +pip install uvicorn fastapi +``` + +Download the index: +```bash +conda activate retriever + +local_dir=~/data/searchR1 +python examples/search/searchr1_download.py --local_dir $local_dir +cat $local_dir/part_* > $local_dir/e5_Flat.index +gzip -d $local_dir/wiki-18.jsonl.gz +``` + +Start the local flat e5 retrieval server: +```bash +conda activate retriever + +# redirect the output to a file to avoid cluttering the terminal +# we have observed outputting to the terminal causing spikes in server response times +bash examples/search/retriever/retrieval_launch.sh > retrieval_server.log +``` # Run Examples ## RL Training @@ -266,6 +311,9 @@ bash examples/gigpo_trainer/run_webshop.sh # WebShop ```bash bash examples/gigpo_trainer/run_sokoban.sh # Sokoban ``` +```bash +bash examples/gigpo_trainer/run_search.sh # Search +``` ### 2. GRPO GRPO is a critic-free algorithm that estimates relative advantages based on a group of full episode trajectories. ```bash @@ -274,6 +322,9 @@ bash examples/grpo_trainer/run_alfworld.sh # ALFWorld ```bash bash examples/grpo_trainer/run_webshop.sh # WebShop ``` +```bash +bash examples/grpo_trainer/run_search.sh # Search +``` ### 3. PPO PPO is a classic actor-critic algorithm that updates the policy using a clipped objective to ensure stable learning. 
It requires a separate value network (critic) to estimate state values. ```bash diff --git a/agent_system/agent/agent.py b/agent_system/agent/agent.py deleted file mode 100644 index 58e11724..00000000 --- a/agent_system/agent/agent.py +++ /dev/null @@ -1,247 +0,0 @@ -from __future__ import annotations -"""Agent definitions & registry. - -Each agent calls the LLM policy (``actor_rollout_wg``) in its -:py:meth:`act` method. Prompts for each role live in *prompt.py*. -""" -from typing import Dict, Any, Callable, Optional, List, Tuple -import copy -from verl import DataProto -from transformers import PreTrainedTokenizer -from agent_system.multi_turn_rollout.utils import preprocess_batch -from agent_system.agent.utils import tag_projection - -from agent_system.agent.prompts import AGENT_PROMPTS - -# Registry -class AgentRegistry: - _REGISTRY: Dict[str, Callable[..., "BaseAgent"]] = {} - - @classmethod - def register(cls, name: str): - def decorator(agent_cls: Callable[..., "BaseAgent"]): - if name in cls._REGISTRY: - raise ValueError(f"Agent '{name}' already registered.") - cls._REGISTRY[name] = agent_cls - return agent_cls - return decorator - - @classmethod - def create(cls, name: str, **kwargs): - if name not in cls._REGISTRY: - raise KeyError(f"Unknown agent '{name}'. Registered: {list(cls._REGISTRY)}") - return cls._REGISTRY[name](**kwargs) - - @classmethod - def names(cls) -> List[str]: - return list(cls._REGISTRY) - - -class BaseAgent: - """Abstract agent. 
All subclasses *must* implement :py:meth:`act`.""" - - def __init__(self, name: str, tokenizer: PreTrainedTokenizer, processor, config: Any): - self.name = name - self.tokenizer = tokenizer - self.processor = processor - self.config = config - - self.start_tag = None - self.end_tag = None - - # Check if prompt is defined for this agent via calling the property - if not hasattr(self, 'prompt') or not isinstance(self.prompt, str): - raise ValueError(f"Agent '{self.name}' must define a 'prompt' property.") - - def reset(self): - pass - - @property - def prompt(self) -> str: - """Return the prompt template""" - return AGENT_PROMPTS[self.name] - - def build_prompt(self, env_obs: Dict[str, Any], team_context: List[str], step: int) -> str: - """Build the prompt for the agent based on the observation.""" - # Naive Implementation - obs = copy.deepcopy(env_obs) - bs = len(obs['text']) - for i in range(bs): - if self.start_tag is not None and self.end_tag is not None: - obs['text'][i] = self.prompt.format(env_prompt=obs['text'][i], - team_context=team_context[i], - step=step, - start_tag=self.start_tag, - end_tag=self.end_tag) - else: - obs['text'][i] = self.prompt.format(env_prompt=obs['text'][i], - team_context=team_context[i], - step=step) - return obs - - def postprocess_batch(self, team_context: List[str], text_response: str) -> List[str]: - """Update the observation dictionary with the text response.""" - # Naive append of the latest responses to observations - for i in range(len(team_context)): - if team_context[i] == "": - team_context[i] = "Some of your teammates have already shared their thoughts for the current step. 
Their outputs are as follows:\n" - team_context[i] = team_context[i] + f"\n{self.name}:\n{text_response[i]}\n" - return team_context - - def _generate_with_llm(self, batch: DataProto, actor_rollout_wg, meta_info) -> Tuple[DataProto, List[str]]: - """Helper: prompt → input_ids → actor_rollout_wg → decoded str.""" - batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] - non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] - if "multi_modal_data" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("multi_modal_data") - if "raw_prompt" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("raw_prompt") - if "tools_kwargs" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("tools_kwargs") - batch_input = batch.pop( - batch_keys=batch_keys_to_pop, - non_tensor_batch_keys=non_tensor_batch_keys_to_pop, - ) - - batch_input.meta_info = meta_info - batch_output = actor_rollout_wg.generate_sequences(batch_input) - - batch = batch.union(batch_output) - - text_repsonses = self.tokenizer.batch_decode(batch.batch['responses'], skip_special_tokens=True) - - text_repsonses, valids = tag_projection(text_repsonses, start_tag=self.start_tag, end_tag=self.end_tag) - batch.non_tensor_batch['is_action_valid'] = valids - - return batch, text_repsonses - - def call( - self, - gen_batch: DataProto, - env_obs: Dict[str, Any], - team_context: List[str], - actor_rollout_wg, - step: int, - ) -> Tuple[DataProto, List[str], List[str]]: - """Generate a response based on the observation and the batch. - Args: - gen_batch (DataProto): The input batch for generation. - env_obs (Dict[str, Any]): Observations from the environment. - - 'text' (List[str]): Text observation data - - 'image' (np.ndarray or torch.Tensor): Image observation data - - 'anchor' (None or Any): Anchor observation without any histories or additional info. (for GiGPO only). - team_context (List[str]): Contextual information from the team. 
- actor_rollout_wg: The LLM policy for acting. - step: environment step - Returns: - Tuple[DataProto, List[str], List[str]]: - - batch (DataProto): The processed batch after generation. - - text_repsonses (List[str]): The generated text responses. - - team_context (List[str]): Updated team context after processing. - """ - raise NotImplementedError - - -# ============================================================================= -# Reference agents (Memory / Planner / Action) using LLM -# ============================================================================= -@AgentRegistry.register("Reflexion Agent") -class ReflexionAgent(BaseAgent): - def __init__(self, tokenizer: PreTrainedTokenizer, processor,config: Any): - super().__init__("Reflexion Agent", tokenizer=tokenizer, processor=processor,config=config) - self.start_tag = "" - self.end_tag = "" - - def call(self, gen_batch: DataProto, env_obs: Dict[str, Any], team_context: List[str], actor_rollout_wg, step: int) -> Tuple[DataProto, List[str], List[str]]: - """Generate a summary of the conversation history.""" - if step == 0: - return None, None, team_context - - obs = self.build_prompt(env_obs, team_context, step) - batch = preprocess_batch(gen_batch=gen_batch, - obs=obs, - config=self.config, - tokenizer=self.tokenizer, - processor=self.processor, - ) - batch, text_repsonses = self._generate_with_llm(batch, actor_rollout_wg, gen_batch.meta_info) - - team_context = self.postprocess_batch(team_context, text_repsonses) - return batch, text_repsonses, team_context - -@AgentRegistry.register("Planning Agent") -class PlanningAgent(BaseAgent): - def __init__(self, tokenizer: PreTrainedTokenizer, processor, config: Any): - super().__init__("Planning Agent", tokenizer=tokenizer, processor=processor,config=config) - self.start_tag = "" - self.end_tag = "" - - def call(self, gen_batch: DataProto, env_obs: Dict[str, Any], team_context: List[str], actor_rollout_wg, step: int) -> Tuple[DataProto, List[str], List[str]]: 
- """Generate a summary of the conversation history.""" - obs = self.build_prompt(env_obs, team_context, step) - batch = preprocess_batch(gen_batch=gen_batch, - obs=obs, - config=self.config, - tokenizer=self.tokenizer, - processor=self.processor, - ) - batch, text_repsonses = self._generate_with_llm(batch, actor_rollout_wg, gen_batch.meta_info) - - team_context = self.postprocess_batch(team_context, text_repsonses) - return batch, text_repsonses, team_context - - -@AgentRegistry.register("Action Agent") -class ActionAgent(BaseAgent): - def __init__(self, tokenizer: PreTrainedTokenizer, processor, config: Any): - super().__init__("Action Agent", tokenizer=tokenizer, processor=processor,config=config) - self.start_tag = "" - self.end_tag = "" - - def projection(self, text_repsonses: List[str]) -> List[str]: - return [response.strip() for response in text_repsonses] - - def call(self, gen_batch: DataProto, env_obs: Dict[str, Any], team_context: List[str], actor_rollout_wg, step: int) -> Tuple[DataProto, List[str], List[str]]: - """Generate a summary of the conversation history.""" - obs = self.build_prompt(env_obs, team_context, step) - batch = preprocess_batch(gen_batch=gen_batch, - obs=obs, - config=self.config, - tokenizer=self.tokenizer, - processor=self.processor, - ) - batch, text_repsonses = self._generate_with_llm(batch, actor_rollout_wg, gen_batch.meta_info) - - team_context = self.postprocess_batch(team_context, text_repsonses) - return batch, text_repsonses, team_context - - -@AgentRegistry.register("Memory Agent") -class MemoryAgent(BaseAgent): - def __init__(self, tokenizer: PreTrainedTokenizer, processor, config: Any): - super().__init__("Memory Agent", tokenizer=tokenizer, processor=processor,config=config) - self.start_tag = "" - self.end_tag = "" - - def call(self, gen_batch: DataProto, env_obs: Dict[str, Any], team_context: List[str], actor_rollout_wg, step: int) -> Tuple[DataProto, List[str], List[str]]: - """Generate a summary of the conversation 
history.""" - obs = self.build_prompt(env_obs, team_context, step) - batch = preprocess_batch(gen_batch=gen_batch, - obs=obs, - config=self.config, - tokenizer=self.tokenizer, - processor=self.processor, - ) - batch, text_repsonses = self._generate_with_llm(batch, actor_rollout_wg, gen_batch.meta_info) - - team_context = self.postprocess_batch(team_context, text_repsonses) - return batch, text_repsonses, team_context - -# ============================================================================= -__all__ = [ - "AgentRegistry", - "BaseAgent", - "ReflexionAgent", - "PlanningAgent", - "ActionAgent", -] diff --git a/agent_system/agent/agents/__init__.py b/agent_system/agent/agents/__init__.py new file mode 100644 index 00000000..68194559 --- /dev/null +++ b/agent_system/agent/agents/__init__.py @@ -0,0 +1,14 @@ +# agent_system/agent/agents/__init__.py +from .action_agent import ActionAgent +from .memory_agent import MemoryAgent +from .planning_agent import PlanningAgent +from .reflexion_agent import ReflexionAgent +from .search_agent import SearchAgent + +__all__ = [ + "ActionAgent", + "MemoryAgent", + "PlanningAgent", + "ReflexionAgent", + "SearchAgent", +] diff --git a/agent_system/agent/agents/action_agent.py b/agent_system/agent/agents/action_agent.py new file mode 100644 index 00000000..e4258416 --- /dev/null +++ b/agent_system/agent/agents/action_agent.py @@ -0,0 +1,44 @@ +from typing import Dict, Any, Callable, Optional, List, Tuple +from verl import DataProto +from transformers import PreTrainedTokenizer +from agent_system.multi_turn_rollout.utils import preprocess_batch +from agent_system.agent.registry import AgentRegistry +from agent_system.agent.base import BaseAgent + +PROMPT = """ +{env_prompt} + +{team_context} + +------- + +You are an "Action Agent", and your role within your team is to determine the final action for the current step. + +You are now at step {step}. 
Based on all information above, please decide on the most appropriate admissible action. +You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within tags. +Once you've finished your reasoning, select one admissible action and MUST present it enclosed within {start_tag} {end_tag} tags. +""" + +@AgentRegistry.register("Action Agent") +class ActionAgent(BaseAgent): + def __init__(self, tokenizer: PreTrainedTokenizer, processor, config: Any): + super().__init__("Action Agent", PROMPT, tokenizer=tokenizer, processor=processor,config=config) + self.start_tag = "" + self.end_tag = "" + + def projection(self, text_repsonses: List[str]) -> List[str]: + return [response.strip() for response in text_repsonses] + + def call(self, gen_batch: DataProto, env_obs: Dict[str, Any], team_context: List[str], actor_rollout_wg, step: int) -> Tuple[DataProto, List[str], List[str]]: + """Generate a summary of the conversation history.""" + obs = self.build_prompt(env_obs, team_context, step) + batch = preprocess_batch(gen_batch=gen_batch, + obs=obs, + config=self.config, + tokenizer=self.tokenizer, + processor=self.processor, + ) + batch, text_repsonses = self._generate_with_llm(batch, actor_rollout_wg, gen_batch.meta_info) + + team_context = self.postprocess_batch(team_context, text_repsonses) + return batch, text_repsonses, team_context diff --git a/agent_system/agent/agents/memory_agent.py b/agent_system/agent/agents/memory_agent.py new file mode 100644 index 00000000..a9fefe65 --- /dev/null +++ b/agent_system/agent/agents/memory_agent.py @@ -0,0 +1,49 @@ +from typing import Dict, Any, Callable, Optional, List, Tuple +from verl import DataProto +from transformers import PreTrainedTokenizer +from agent_system.multi_turn_rollout.utils import preprocess_batch +from agent_system.agent.registry import AgentRegistry +from agent_system.agent.base import BaseAgent + +PROMPT = """ +{env_prompt} + +{team_context} + +------- + +You are a 
"Memory Agent", and your role within your team is to maintain a complete memory for all important history details. + +Your responsibilities: +- Maintain an objective and accurate log of **important observation details** and the **team's actions**. +- Do not include internal team reasoning, planning, or discussions. +- Record one entry for each environment step using the format: "Step N: your memory for this step" +- The environment observation must be a high-level summary in your own words — do NOT copy raw observation text. +- Be sure to record meaningful and high-impact details (e.g., number, price, names, and identifiers) from the observation that could inform future decisions, or help recover from incorrect or suboptimal decisions. +- In this update, append one new entry for the current step to the existing memory buffer. +- You MUST output the full memory buffer, from step 1 to the current step, including all previous entries. + +You are now at step {step}. Based on all the information above, provide a complete memory buffer enclosed within {start_tag} {end_tag} tags. 
+""" + + +@AgentRegistry.register("Memory Agent") +class MemoryAgent(BaseAgent): + def __init__(self, tokenizer: PreTrainedTokenizer, processor, config: Any): + super().__init__("Memory Agent", PROMPT, tokenizer=tokenizer, processor=processor,config=config) + self.start_tag = "" + self.end_tag = "" + + def call(self, gen_batch: DataProto, env_obs: Dict[str, Any], team_context: List[str], actor_rollout_wg, step: int) -> Tuple[DataProto, List[str], List[str]]: + """Generate a summary of the conversation history.""" + obs = self.build_prompt(env_obs, team_context, step) + batch = preprocess_batch(gen_batch=gen_batch, + obs=obs, + config=self.config, + tokenizer=self.tokenizer, + processor=self.processor, + ) + batch, text_repsonses = self._generate_with_llm(batch, actor_rollout_wg, gen_batch.meta_info) + + team_context = self.postprocess_batch(team_context, text_repsonses) + return batch, text_repsonses, team_context \ No newline at end of file diff --git a/agent_system/agent/agents/planning_agent.py b/agent_system/agent/agents/planning_agent.py new file mode 100644 index 00000000..cf3cabcf --- /dev/null +++ b/agent_system/agent/agents/planning_agent.py @@ -0,0 +1,46 @@ +from typing import Dict, Any, Callable, Optional, List, Tuple +from verl import DataProto +from transformers import PreTrainedTokenizer +from agent_system.multi_turn_rollout.utils import preprocess_batch +from agent_system.agent.registry import AgentRegistry +from agent_system.agent.base import BaseAgent + +PROMPT = """ +{env_prompt} + +{team_context} + +------- + +You are a "Planning Agent", and your role within your team is to formulate a high-level plan and identify the most appropriate strategic objective. + +Your responsibilities are strictly limited to: +- Formulating a high-level plan that addresses the current situation +- Ensuring the plan aligns with long-term task success + +You are now at step {step}. 
Based on all information above, you should first reason step-by-step about the planning process. This reasoning process MUST be enclosed within tags. +Once you've finished your reasoning, present your final plan enclosed within {start_tag} {end_tag} tags. +""" + + + +@AgentRegistry.register("Planning Agent") +class PlanningAgent(BaseAgent): + def __init__(self, tokenizer: PreTrainedTokenizer, processor, config: Any): + super().__init__("Planning Agent", PROMPT, tokenizer=tokenizer, processor=processor, config=config) + self.start_tag = "" + self.end_tag = "" + + def call(self, gen_batch: DataProto, env_obs: Dict[str, Any], team_context: List[str], actor_rollout_wg, step: int) -> Tuple[DataProto, List[str], List[str]]: + """Generate a summary of the conversation history.""" + obs = self.build_prompt(env_obs, team_context, step) + batch = preprocess_batch(gen_batch=gen_batch, + obs=obs, + config=self.config, + tokenizer=self.tokenizer, + processor=self.processor, + ) + batch, text_repsonses = self._generate_with_llm(batch, actor_rollout_wg, gen_batch.meta_info) + + team_context = self.postprocess_batch(team_context, text_repsonses) + return batch, text_repsonses, team_context \ No newline at end of file diff --git a/agent_system/agent/agents/reflexion_agent.py b/agent_system/agent/agents/reflexion_agent.py new file mode 100644 index 00000000..97375f81 --- /dev/null +++ b/agent_system/agent/agents/reflexion_agent.py @@ -0,0 +1,49 @@ +from typing import Dict, Any, Callable, Optional, List, Tuple +from verl import DataProto +from transformers import PreTrainedTokenizer +from agent_system.multi_turn_rollout.utils import preprocess_batch +from agent_system.agent.registry import AgentRegistry +from agent_system.agent.base import BaseAgent + +PROMPT = """ +{env_prompt} + +{team_context} + +------- + +You are a "Reflexion Agent", and your role within your team is to analyze the team's past actions and identify any mistakes, inefficiencies, missed opportunities, or incorrect 
assumptions that may have occurred. +Your reflection will help the your team understand what could have been done better and how to improve in future steps. + +Your responsibilities are strictly limited to: +- Review past actions, decisions, and outcomes. +- Identify mistakes, missed opportunities, inefficiencies, or false assumptions. +- Suggest improvements that could guide better decisions in the future. + +You are now at step {step}. Based on all information above, you should first reason step-by-step about the past events. This reasoning process MUST be enclosed within tags. +Once you've finished your reasoning, provide a clear and insightful reflection enclosed within {start_tag} {end_tag} tags. +""" + +@AgentRegistry.register("Reflexion Agent") +class ReflexionAgent(BaseAgent): + def __init__(self, tokenizer: PreTrainedTokenizer, processor,config: Any): + super().__init__("Reflexion Agent", PROMPT, tokenizer=tokenizer, processor=processor,config=config) + self.start_tag = "" + self.end_tag = "" + + def call(self, gen_batch: DataProto, env_obs: Dict[str, Any], team_context: List[str], actor_rollout_wg, step: int) -> Tuple[DataProto, List[str], List[str]]: + """Generate a summary of the conversation history.""" + if step == 0: + return None, None, team_context + + obs = self.build_prompt(env_obs, team_context, step) + batch = preprocess_batch(gen_batch=gen_batch, + obs=obs, + config=self.config, + tokenizer=self.tokenizer, + processor=self.processor, + ) + batch, text_repsonses = self._generate_with_llm(batch, actor_rollout_wg, gen_batch.meta_info) + + team_context = self.postprocess_batch(team_context, text_repsonses) + return batch, text_repsonses, team_context \ No newline at end of file diff --git a/agent_system/agent/agents/search_agent.py b/agent_system/agent/agents/search_agent.py new file mode 100644 index 00000000..dd753ff1 --- /dev/null +++ b/agent_system/agent/agents/search_agent.py @@ -0,0 +1,43 @@ + +from typing import Dict, Any, Callable, 
Optional, List, Tuple +from verl import DataProto +from transformers import PreTrainedTokenizer +from agent_system.multi_turn_rollout.utils import preprocess_batch +from agent_system.agent.registry import AgentRegistry +from agent_system.agent.base import BaseAgent + +PROMPT = """ +{env_prompt} + +{team_context} + +------- +You are a "Search Agent", and your primary responsibility is to call a search engine to obtain valuable external information that supports the team's goals. + +You are now at step {step}. +You should first reason step-by-step about the current situation and historical context to identify the core objective of the task, the information that is already known, and what is still missing. Consider how external information could provide value, and develop a search direction that includes specific key entities, relevant actions or focus areas. This reasoning process MUST be enclosed within tags. + +Once you've finished your reasoning, write the final search query inside {start_tag} {end_tag} tags. Ensure your query is precise, information-rich, avoids vagueness, and maximizes the likelihood of retrieving valuable, directly relevant information rather than repeating what is already known. 
+""" + +@AgentRegistry.register("Search Agent") +class SearchAgent(BaseAgent): + def __init__(self, tokenizer: PreTrainedTokenizer, processor, config: Any): + super().__init__("Search Agent", PROMPT, tokenizer=tokenizer, processor=processor, config=config) + self.start_tag = "" + self.end_tag = "" + + def call(self, gen_batch: DataProto, env_obs: Dict[str, Any], team_context: List[str], actor_rollout_wg, step: int) -> Tuple[DataProto, List[str], List[str]]: + """Generate a summary of the conversation history.""" + obs = self.build_prompt(env_obs, team_context, step) + batch = preprocess_batch(gen_batch=gen_batch, + obs=obs, + config=self.config, + tokenizer=self.tokenizer, + processor=self.processor, + ) + batch, text_repsonses = self._generate_with_llm(batch, actor_rollout_wg, gen_batch.meta_info) + + team_context = self.postprocess_batch(team_context, text_repsonses) + return batch, text_repsonses, team_context + \ No newline at end of file diff --git a/agent_system/agent/base.py b/agent_system/agent/base.py new file mode 100644 index 00000000..3babbfdb --- /dev/null +++ b/agent_system/agent/base.py @@ -0,0 +1,119 @@ +from __future__ import annotations +""" +Agent definitions. +""" +from typing import Dict, Any, List, Tuple +import copy +from verl import DataProto +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from transformers import PreTrainedTokenizer +from agent_system.agent.utils import tag_projection + +class BaseAgent: + """Abstract agent. 
All subclasses *must* implement :py:meth:`call`.""" + + def __init__(self, name: str, prompt: str, tokenizer: PreTrainedTokenizer, processor, config: Any): + self.name = name + self.prompt = prompt + self.tokenizer = tokenizer + self.processor = processor + self.config = config + + self.start_tag = None + self.end_tag = None + + # Validate that the prompt template passed to the constructor is a string + if not hasattr(self, 'prompt') or not isinstance(self.prompt, str): + raise ValueError(f"Agent '{self.name}' must define a 'prompt' property.") + + def reset(self): + pass + + def build_prompt(self, env_obs: Dict[str, Any], team_context: List[str], step: int) -> str: + """Build the prompt for the agent based on the observation.""" + # Naive Implementation + obs = copy.deepcopy(env_obs) + bs = len(obs['text']) + for i in range(bs): + if self.start_tag is not None and self.end_tag is not None: + obs['text'][i] = self.prompt.format(env_prompt=obs['text'][i], + team_context=team_context[i], + step=step, + start_tag=self.start_tag, + end_tag=self.end_tag) + else: + obs['text'][i] = self.prompt.format(env_prompt=obs['text'][i], + team_context=team_context[i], + step=step) + return obs + + def postprocess_batch(self, team_context: List[str], text_response: str) -> List[str]: + """Append this agent's latest response to each team-context entry.""" + # Naive append of the latest responses to the team context + for i in range(len(team_context)): + if team_context[i] == "": + team_context[i] = "Some of your teammates have already shared their thoughts for the current step. 
Their outputs are as follows:\n" + team_context[i] = team_context[i] + f"\n{self.name}:\n{text_response[i]}\n" + return team_context + + def _generate_with_llm(self, batch: DataProto, actor_rollout_wg, meta_info) -> Tuple[DataProto, List[str]]: + """Helper: prompt → input_ids → actor_rollout_wg → decoded str.""" + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + if "multi_modal_data" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("multi_modal_data") + if "raw_prompt" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("raw_prompt") + if "tools_kwargs" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("tools_kwargs") + batch_input = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) + + batch_input.meta_info = meta_info + + # pad to be divisible by dp_size + batch_input_padded, pad_size = pad_dataproto_to_divisor(batch_input, actor_rollout_wg.world_size) + batch_output_padded = actor_rollout_wg.generate_sequences(batch_input_padded) + # # unpad + batch_output = unpad_dataproto(batch_output_padded, pad_size=pad_size) + + batch = batch.union(batch_output) + + text_repsonses = self.tokenizer.batch_decode(batch.batch['responses'], skip_special_tokens=True) + + text_repsonses, valids = tag_projection(text_repsonses, start_tag=self.start_tag, end_tag=self.end_tag) + batch.non_tensor_batch['is_action_valid'] = valids + + return batch, text_repsonses + + def call( + self, + gen_batch: DataProto, + env_obs: Dict[str, Any], + team_context: List[str], + actor_rollout_wg, + step: int, + ) -> Tuple[DataProto, List[str], List[str]]: + """Generate a response based on the observation and the batch. + Args: + gen_batch (DataProto): The input batch for generation. + env_obs (Dict[str, Any]): Observations from the environment. 
+ - 'text' (List[str]): Text observation data + - 'image' (np.ndarray or torch.Tensor): Image observation data + - 'anchor' (None or Any): Anchor observation without any histories or additional info. (for GiGPO only). + team_context (List[str]): Contextual information from the team. + actor_rollout_wg: The LLM policy for acting. + step: environment step + Returns: + Tuple[DataProto, List[str], List[str]]: + - batch (DataProto): The processed batch after generation. + - text_repsonses (List[str]): The generated text responses. + - team_context (List[str]): Updated team context after processing. + """ + raise NotImplementedError + +__all__ = [ + "BaseAgent", +] diff --git a/agent_system/agent/executor.py b/agent_system/agent/executor.py index 1a418e9f..335c2762 100644 --- a/agent_system/agent/executor.py +++ b/agent_system/agent/executor.py @@ -3,7 +3,9 @@ """ from typing import List, Dict, Any, Optional, Tuple from transformers import PreTrainedTokenizer -from agent_system.agent.agent import AgentRegistry, BaseAgent +from agent_system.agent.base import BaseAgent +from agent_system.agent.registry import AgentRegistry +from agent_system.agent.agents import * from verl import DataProto @@ -20,8 +22,8 @@ def __init__( self.config = config self.tokenizer = tokenizer self.processor = processor - if agent_list is None: - agent_list = ["Reflexion Agent", "Action Agent", "Memory Agent"] + + assert agent_list is not None, "agent_list must be provided." self.agents: Dict[str, BaseAgent] = { name: AgentRegistry.create(name=name, @@ -38,6 +40,7 @@ def __init__( def reset(self): """Reset the executor, all agents, and buffer.""" self.reset_buffer() + self.memory = None for ag in self.agents.values(): ag.reset() @@ -64,7 +67,6 @@ def initialize_context(self, env_obs): def update_memory(self, text_repsonses: List[str]): """Update the memory of the agents with the latest text responses.""" - assert "Memory Agent" in self.agent_list, "Memory Agent is required to update memory. 
Please add it to the agent_list list." if self.memory is None: self.memory = text_repsonses else: @@ -111,7 +113,7 @@ class MultiAgentChainExecutor(BaseExecutor): """ def __init__( self, - agent_list: Optional[List[str]] = ["Reflexion Agent", "Planning Agent", "Action Agent"], + agent_list: Optional[List[str]] = ["Reflexion Agent", "Action Agent", "Memory Agent"], tokenizer: PreTrainedTokenizer = None, processor=None, config: Any = None, @@ -124,6 +126,9 @@ def __init__( ) if not self.agents: raise ValueError("ChainExecutor requires at least one agent.") + + if self.config.agent.use_agent_memory and "Memory Agent" not in self.agent_list: + raise ValueError("Memory Agent is required to use agent memory. Please add it to the agent_list.") # The order of agents is the execution order. self.agent_order = self.agent_list @@ -146,7 +151,7 @@ def run(self, gen_batch: DataProto, env_obs: Dict[str, Any], actor_rollout_wg, s if name == "Action Agent": text_actions = text_repsonses - if name == "Memory Agent": + if self.config.agent.use_agent_memory and name == "Memory Agent": self.update_memory(text_repsonses) # if len(self.multiagent_batch_buffer) != len(self.agent_order): @@ -193,4 +198,63 @@ def run(self, gen_batch: DataProto, env_obs: Dict[str, Any], actor_rollout_wg, s pass +class SearchMultiAgentExecutor(BaseExecutor): + """Sequentially run agents, passing observation and batch through each agent. + This executor runs agents in a chain, where each agent processes the output + of the previous agent and passes its output to the next agent. + It is useful for scenarios where agents need to work in a sequence, such as + in a pipeline or a multi‑step process. + Args: + agent_list (List[str]): List of agent names to be executed in sequence. + tokenizer (PreTrainedTokenizer): Tokenizer for processing text. + processor: Processor for handling data. + config (Any): Configuration object containing settings for the executor. 
+ """ + def __init__( + self, + agent_list: Optional[List[str]] = ["Planning Agent", "Action Agent", "Memory Agent"], + tokenizer: PreTrainedTokenizer = None, + processor=None, + config: Any = None, + ): + super().__init__( + agent_list=agent_list, + tokenizer=tokenizer, + processor=processor, + config=config, + ) + if not self.agents: + raise ValueError("ChainExecutor requires at least one agent.") + + if self.config.agent.use_agent_memory and "Memory Agent" not in self.agent_list: + raise ValueError("Memory Agent is required to use agent memory. Please add it to the agent_list.") + + # The order of agents is the execution order. + self.agent_order = self.agent_list + # if self.agent_order[-1] != "ActionAgent": + # raise ValueError("The last agent must be ActionAgent.") + + def run(self, gen_batch: DataProto, env_obs: Dict[str, Any], actor_rollout_wg, step: int) -> Tuple[List[str], Dict[str, DataProto]]: + # clear and reset multiagent batch buffer + self.reset_buffer() + team_context, env_obs = self.initialize_context(env_obs) + + # run agents sequentially, passing observation and batch + for name in self.agent_order: + batch, text_repsonses, team_context = self.agents[name].call(gen_batch=gen_batch, env_obs=env_obs, team_context=team_context, actor_rollout_wg=actor_rollout_wg, step=step) + if batch is None: + continue # skip if the agent did not produce a batch + + # save the batch to the multiagent buffer + self.save_to_buffer(name, batch) + + if name == "Action Agent": + text_actions = text_repsonses + if self.config.agent.use_agent_memory and name == "Memory Agent": + self.update_memory(text_repsonses) + + # if len(self.multiagent_batch_buffer) != len(self.agent_order): + # raise Warning("Multiagent output batch buffer length does not match number of agents. 
This may lead to unexpected behavior.") + return text_actions, self.multiagent_batch_buffer + __all__ = ["ChainExecutor", "HierarchicalExecutor"] diff --git a/agent_system/agent/prompts.py b/agent_system/agent/prompts.py deleted file mode 100644 index d760382a..00000000 --- a/agent_system/agent/prompts.py +++ /dev/null @@ -1,129 +0,0 @@ -AGENT_PROMPTS = { -"Reflexion Agent": -""" -{env_prompt} - -{team_context} - -------- - -You are a "Reflexion Agent", and your role within your team is to analyze the team's past actions and identify any mistakes, inefficiencies, missed opportunities, or incorrect assumptions that may have occurred. -Your reflection will help the your team understand what could have been done better and how to improve in future steps. - -Your responsibilities are strictly limited to: -- Review past actions, decisions, and outcomes. -- Identify mistakes, missed opportunities, inefficiencies, or false assumptions. -- Suggest improvements that could guide better decisions in the future. - -You are now at step {step}. Based on all information above, you should first reason step-by-step about the past events. This reasoning process MUST be enclosed within tags. -Once you've finished your reasoning, provide a clear and insightful reflection enclosed within {start_tag} {end_tag} tags. -""" -, - -"Planning Agent": -""" -{env_prompt} - -{team_context} - -------- - -You are a "Planning Agent", and your role within your team is to formulate a high-level plan and identify the most appropriate strategic objective. - -Your responsibilities are strictly limited to: -- Formulating a high-level plan that addresses the current situation -- Ensuring the plan aligns with long-term task success - -You are now at step {step}. Based on all information above, you should first reason step-by-step about the planning process. This reasoning process MUST be enclosed within tags. 
-Once you've finished your reasoning, present your final plan enclosed within {start_tag} {end_tag} tags. -""" -, - -"Action Agent": -""" -{env_prompt} - -{team_context} - -------- - -You are an "Action Agent", and your role within your team is to determine the final action for the current step. - -You are now at step {step}. Based on all information above, please decide on the most appropriate admissible action. -You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within tags. -Once you've finished your reasoning, select one admissible action and MUST present it enclosed within {start_tag} {end_tag} tags. -""", - - -"Memory Agent": -""" -{env_prompt} - -{team_context} - -------- - -You are a "Memory Agent", and your role within your team is to maintain a complete memory for all important history details. - -Your responsibilities: -- Maintain an objective and accurate log of important environmental details and the team's actions. -- Do not include internal team reasoning, planning, or discussions. -- Record one entry for each environment step using the format: "Step N: ..." -- The environment observation must be a high-level summary in your own words — do NOT copy raw observation text. -- Be sure to record meaningful and high-impact details (e.g., number, price, names, and identifiers) from the environment observation that could inform future decisions, or help recover from incorrect or suboptimal decisions. -- In this update, append one new entry for the current step to the existing memory buffer. -- You MUST output the full memory buffer, from step 1 to the current step, including all previous entries. - -You are now at step {step}. Based on all the information above, provide an updated, concise memory buffer enclosed within {start_tag} {end_tag} tags. -The memory buffer should look like: -{start_tag} -step 1: ... -step 2: ... -... 
-{end_tag} -""", - -# "Memory Agent": -# """ -# {env_prompt} - -# {team_context} - -# ------- - -# You are a "Memory Agent", and your role within your team is to maintain a complete summary of the interaction history between your team and the environment. - -# Responsibilities: -# - Ensure the summary is objective, factual, and captures only meaningful events. -# - Do not include task descriptions, internal reasoning, your team discussions, or strategic planning. -# - You MUST maintain a complete and continuous history from step 1 to the current step (do not skip or omit any steps). -# - Record one entry for each environment step using the format: "Step N: summary of environment observation and team action" (summarize clearly and concisely in your own words, keeping each entry under 100 characters). -# - For each step, only summarize the environment observation, the team's action, and any feedback or result from the environment. -# - In this update, you only need to append a new entry for the current step to the existing history memory. - -# Now, based on all the information above, you should first reason step-by-step about what should be recorded. This reasoning process MUST be enclosed within tags. -# Once you've finished your reasoning, provide an updated, concise summary enclosed within {start_tag} {end_tag} tags. -# """, - -# "Memory Agent": -# """ -# {env_prompt} - -# The following is the output provided by your teammates: -# {team_context} - -# ------- - -# You are a **Memory Agent**, responsible for recording a clear, step-by-step history of the task based solely on the final action chosen by the team and the environment's feedback. 
- -# Your output must: -# - Record only the final action executed by the team at each environment step and the corresponding environment feedback -# - Do not include any teammate suggestions, discussion, or reasoning -# - Avoid copying raw output; summarize events in clear, natural language -# - Keep the memory short, factual, informative, and easy to follow -# - Write one entry per environment step using the format: "Step N: what happened" (where N corresponds to the environment's step number) -# - If the memory context becomes too long, try to compress earlier steps into a compact summary. - -# Now, based on the information above, please update the memory to reflect the full history. Be clear and brief in your response. -# """, -} \ No newline at end of file diff --git a/agent_system/agent/registry.py b/agent_system/agent/registry.py new file mode 100644 index 00000000..36f42684 --- /dev/null +++ b/agent_system/agent/registry.py @@ -0,0 +1,28 @@ +from __future__ import annotations +""" +Agent registry. +""" +from typing import Dict, Callable, List + + +class AgentRegistry: + _REGISTRY: Dict[str, Callable[..., "BaseAgent"]] = {} + + @classmethod + def register(cls, name: str): + def decorator(agent_cls: Callable[..., "BaseAgent"]): + if name in cls._REGISTRY: + raise ValueError(f"Agent '{name}' already registered.") + cls._REGISTRY[name] = agent_cls + return agent_cls + return decorator + + @classmethod + def create(cls, name: str, **kwargs): + if name not in cls._REGISTRY: + raise KeyError(f"Unknown agent '{name}'. 
Registered: {list(cls._REGISTRY)}") + return cls._REGISTRY[name](**kwargs) + + @classmethod + def names(cls) -> List[str]: + return list(cls._REGISTRY) \ No newline at end of file diff --git a/agent_system/environments/base.py b/agent_system/environments/base.py index 825fd975..724d62cb 100644 --- a/agent_system/environments/base.py +++ b/agent_system/environments/base.py @@ -30,9 +30,11 @@ def __init__(self, envs, projection_f, config): self.projection_f = projection_f self.config = config - def reset(self) -> Dict[str, Any]: + def reset(self, kwargs) -> Dict[str, Any]: """ Reset all environments and return the initial observations. + Parameters: + - kwargs (Dict): Additional keyword arguments for resetting the environment, such as 'tools_kwargs'. Returns: - next_observations (Dict): diff --git a/agent_system/environments/env_manager.py b/agent_system/environments/env_manager.py index 069eae15..16771a7d 100644 --- a/agent_system/environments/env_manager.py +++ b/agent_system/environments/env_manager.py @@ -6,7 +6,7 @@ import os from agent_system.environments.prompts import * from agent_system.environments.base import EnvironmentManagerBase, to_numpy -from agent_system.memory import SimpleMemory +from agent_system.memory import SimpleMemory, SearchMemory def parse_gamefile(infos): gamefile = [] @@ -31,7 +31,7 @@ def __init__(self, envs, projection_f, config): self.memory = SimpleMemory() super().__init__(envs, projection_f, config) - def reset(self): + def reset(self, kwargs): text_obs, image_obs, infos = self.envs.reset() self.gamefile = parse_gamefile(infos) # initialize the history buffer @@ -97,6 +97,7 @@ def build_text_obs(self, text_obs: List[str], admissible_actions: List[List[str] if self.config.agent.multi_agent: obs = ALFWORLD_MULTIAGENT_TEMPLATE_NO_HIS.format( task_description=self.tasks[i], + current_step=len(self.memory[i]) + 1, current_observation=text_obs[i], admissible_actions=reformatted_admissible_actions ) @@ -111,7 +112,7 @@ def 
build_text_obs(self, text_obs: List[str], admissible_actions: List[List[str] obs = ALFWORLD_MULTIAGENT_TEMPLATE.format( task_description=self.tasks[i], step_count=len(self.memory[i]), - memory="{memory}", + memory="{memory}" if self.config.agent.use_agent_memory else memory_contexts[i], current_step=len(self.memory[i]) + 1, current_observation=text_obs[i], admissible_actions=reformatted_admissible_actions @@ -174,7 +175,7 @@ def __init__(self, envs, projection_f, config): self.memory = SimpleMemory() super().__init__(envs, projection_f, config) - def reset(self): + def reset(self, kwargs): obs, infos = self.envs.reset() if self.is_multi_modal: obs = np.array(obs, obs[0].dtype) @@ -266,7 +267,7 @@ class GymCardEnvironmentManager(EnvironmentManagerBase): def __init__(self, envs, projection_f, config): super().__init__(envs, projection_f, config) - def reset(self) -> Dict[str, Any]: + def reset(self, kwargs) -> Dict[str, Any]: obs, infos = self.envs.reset() # infos = [None] * self.envs.num_envs observations = {'text': self.build_text_obs(infos), 'image': obs, 'anchor': obs.copy()} @@ -310,8 +311,9 @@ def __init__(self, envs, projection_f, config): self.memory = SimpleMemory() super().__init__(envs, projection_f, config) - def reset(self) -> Dict[str, Any]: + def reset(self, kwargs) -> Dict[str, Any]: obs, infos = self.envs.reset() + self.memory.reset(batch_size = len(infos)) self.tasks = self.extract_task(obs) obs = self.format_obs(obs) # infos = [None] * self.envs.num_envs @@ -320,7 +322,6 @@ def reset(self) -> Dict[str, Any]: 'anchor': obs.copy() } self.pre_text_obs = obs - self.memory.reset(batch_size = len(infos)) return observations, infos def step(self, text_actions: List[str]): @@ -409,6 +410,7 @@ def build_text_obs(self, text_obs: List[str], infos: List[List[str]], init: bool if self.config.agent.multi_agent: obs = WEBSHOP_MULTIAGENT_TEMPLATE_NO_HIS.format( task_description=self.tasks[i], + current_step=len(self.memory[i]) + 1, current_observation=text_obs[i], 
available_actions=reformatted_available_actions ) @@ -426,7 +428,7 @@ def build_text_obs(self, text_obs: List[str], infos: List[List[str]], init: bool obs = WEBSHOP_MULTIAGENT_TEMPLATE.format( task_description=self.tasks[i], step_count=len(self.memory[i]), - memory="{memory}", + memory="{memory}" if self.config.agent.use_agent_memory else memory_contexts[i], current_step=len(self.memory[i]) + 1, current_observation=text_obs[i], available_actions=reformatted_available_actions @@ -462,7 +464,7 @@ def __init__(self, envs, projection_f, config): self.memory = SimpleMemory() super().__init__(envs, projection_f, config) - def reset(self): + def reset(self, kwargs): text_obs, infos = self.envs.reset() self.supervisors = [info['supervisor'] for info in infos] @@ -543,6 +545,104 @@ def build_text_obs(self, text_obs: List[str], init: bool = False) -> List[str]: postprocess_text_obs.append(obs) return postprocess_text_obs +import time + +class SearchEnvironmentManager(EnvironmentManagerBase): + """ + EnvironmentManager for SearchEnv. 
+ """ + def __init__(self, envs, projection_f, config): + self.memory = SearchMemory() + super().__init__(envs, projection_f, config) + + def reset(self, kwargs) -> Tuple[Dict[str, Any], List[Dict]]: + obs, infos = self.envs.reset(kwargs=kwargs) + self.tasks = obs + + self.memory.reset(batch_size=len(obs)) + + observations = { + "text": self.build_text_obs(obs, init=True), + "image": None, + "anchor": obs.copy() + } + + return observations, infos + + def step(self, text_actions: List[str]): + if not self.config.agent.multi_agent: + actions, valids = self.projection_f(text_actions) + else: + actions = text_actions + + time1 = time.time() + next_obs, rewards, dones, infos = self.envs.step(actions) + time2 = time.time() + print(f"SearchEnv step time: {time2 - time1:.4f} seconds") + + self.memory.store({ + "search": actions, + "information": next_obs, + }) + + next_observations = { + "text": self.build_text_obs(next_obs), + "image": None, + "anchor": next_obs.copy() + } + + if not self.config.agent.multi_agent: + for i, info in enumerate(infos): + info["is_action_valid"] = to_numpy(valids[i]) + + rewards = to_numpy(rewards) + dones = to_numpy(dones) + + return next_observations, rewards, dones, infos + + def build_text_obs( + self, + text_obs: List[str], + init: bool = False + ) -> List[str]: + postprocess_text_obs: List[str] = [] + + if not init and self.config.env.history_length > 0: + memory_ctx, _ = self.memory.fetch( + self.config.env.history_length, + obs_key="information", + action_key="search" + ) + + for i in range(len(text_obs)): + if init or self.config.env.history_length <= 0: + obs_i = SEARCH_TEMPLATE_NO_HIS.format( + task_description=self.tasks[i] + ) + else: + obs_i = SEARCH_TEMPLATE.format( + task_description=self.tasks[i], + memory_context=memory_ctx[i], + step_count=len(self.memory[i]), + ) + postprocess_text_obs.append(obs_i) + + return postprocess_text_obs + + def _process_batch(self, batch_idx, total_batch_list, total_infos, success): + # Find the 
last entry with active masks + for i in reversed(range(len(total_batch_list[batch_idx]))): + batch_item = total_batch_list[batch_idx][i] + if batch_item['active_masks']: + info = total_infos[batch_idx][i] + won_value = float(info['won']) + success['success_rate'].append(won_value) + + # Process game file if it exists + data_source = info.get("data_source") + success[f"{data_source}_success_rate"].append(won_value) + return # Exit after finding the first active mask + def make_envs(config): """ Create enviroments @@ -627,6 +727,15 @@ def make_envs(config): envs = AppWorldEnvironmentManager(_envs, projection_f, config) val_envs = AppWorldEnvironmentManager(_val_envs, projection_f, config) return envs, val_envs + elif "search" in config.env.env_name.lower(): + from agent_system.environments.env_package.search import build_search_envs, search_projection + _envs = build_search_envs(seed=config.env.seed, env_num=config.data.train_batch_size, group_n=group_n, is_train=True, env_config=config.env) + _val_envs = build_search_envs(seed=config.env.seed + 1000, env_num=config.data.val_batch_size, group_n=1, is_train=False, env_config=config.env) + + projection_f = partial(search_projection) + envs = SearchEnvironmentManager(_envs, projection_f, config) + val_envs = SearchEnvironmentManager(_val_envs, projection_f, config) + return envs, val_envs else: print("Environment not supported") exit(1) \ No newline at end of file diff --git a/agent_system/environments/env_package/search/__init__.py b/agent_system/environments/env_package/search/__init__.py new file mode 100644 index 00000000..8e79a934 --- /dev/null +++ b/agent_system/environments/env_package/search/__init__.py @@ -0,0 +1,2 @@ +from .projection import search_projection +from .envs import build_search_envs \ No newline at end of file diff --git a/agent_system/environments/env_package/search/envs.py b/agent_system/environments/env_package/search/envs.py new file mode 100644 index 00000000..bee332d8 --- /dev/null +++ 
b/agent_system/environments/env_package/search/envs.py @@ -0,0 +1,165 @@ +import asyncio +import concurrent.futures +from typing import Any, Dict, List, Tuple + +import gym +import numpy as np +from omegaconf import DictConfig, ListConfig +from copy import deepcopy + + +class SearchMultiProcessEnv(gym.Env): + """ + - env_num : 分组数(逻辑分片,保留形参以兼容外部调用) + - group_n : 每组环境数 + - total_envs = env_num * group_n + 外部仍按 “一个元素 = 一环境” 的形式传 action / kwargs。 + """ + + def __init__( + self, + seed: int = 0, + env_num: int = 1, + group_n: int = 1, + is_train: bool = True, + env_config: DictConfig | None = None, + ) -> None: + super().__init__() + + from agent_system.environments.env_package.search.third_party.skyrl_gym.envs.search.env import SearchEnv + + self.env_num = env_num + self.group_n = group_n + self.batch_size = env_num * group_n + self.is_train = is_train + self.max_steps = env_config.max_steps + + self._rng = np.random.RandomState(seed) + + # ---------- 关键改动开始 ---------- + # 1) 把 search_url 统一转成 list + search_cfg = env_config.search + search_urls = search_cfg.search_url + if not isinstance(search_urls, ListConfig): + search_urls = [search_urls] + + n_clients = len(search_urls) + + # 2) round‑robin 为每个 env 选一个 url + self.envs = [] + for idx in range(self.batch_size): + cfg_i = deepcopy(search_cfg) # 避免修改原始 config + cfg_i.search_url = search_urls[idx % n_clients] + self.envs.append(SearchEnv(cfg_i)) + + max_workers = min(self.batch_size, 256) + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) + + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + def _sync_reset(self, env, kwargs): + extras = { + "ground_truth": kwargs["search"]["create_kwargs"]["ground_truth"], + "max_turns": self.max_steps, + "data_source": kwargs["search"]["create_kwargs"].get("data_source", "unknown") + } + env.reset(extras) + obs = kwargs["search"]["create_kwargs"]["question"] + info = {'data_source': 
kwargs["search"]["create_kwargs"].get("data_source", "unknown")} + return obs, info + + def _sync_step(self, env, action: str): + out = env.step(action) + + obs = out["observations"] + obs = "" if len(obs) == 0 else obs[0]["content"].strip() + reward = out["reward"] + done = out["done"] + + info = dict(out.get("metadata", {})) + info["postprocessed_action"] = out.get("postprocessed_action") + info["won"] = bool(done and reward >= 1.0) + return obs, reward, done, info + + def reset(self, kwargs: List[Dict]): + if len(kwargs) > self.batch_size: + raise ValueError(f"Got {len(kwargs)} kwarg dicts, but the env was initialised with total_envs={self.batch_size}") + + pad_n = self.batch_size - len(kwargs) + dummy_kw = { + "search": { + "create_kwargs": { + "ground_truth": "", + "question": "", + "data_source": "unkown", + } + } + } + + padded_kwargs = list(kwargs) + [dummy_kw] * pad_n + valid_mask = [True] * len(kwargs) + [False] * pad_n + + tasks = [ + self._loop.run_in_executor(self._executor, self._sync_reset, env, kw) + for env, kw in zip(self.envs, padded_kwargs) + ] + results = self._loop.run_until_complete(asyncio.gather(*tasks)) + + obs_list, info_list = map(list, zip(*results)) + + obs_list = [o for o, keep in zip(obs_list, valid_mask) if keep] + info_list = [i for i, keep in zip(info_list, valid_mask) if keep] + + return obs_list, info_list + + def step(self, actions: List[str]): + if len(actions) > self.batch_size: + raise ValueError(f"Got {len(actions)} actions, but the env was initialized with total_envs={self.batch_size}") + + pad_n = self.batch_size - len(actions) + padded_actions = list(actions) + [""] * pad_n + valid_mask = [True] * len(actions) + [False] * pad_n + + tasks = [ + self._loop.run_in_executor(self._executor, self._sync_step, env, act) + for env, act in zip(self.envs, padded_actions) + ] + results = self._loop.run_until_complete(asyncio.gather(*tasks)) + + obs_list, reward_list, done_list, info_list = map(list, zip(*results)) + + obs_list = [o 
for o, keep in zip(obs_list, valid_mask) if keep] + reward_list = [r for r, keep in zip(reward_list, valid_mask) if keep] + done_list = [d for d, keep in zip(done_list, valid_mask) if keep] + info_list = [i for i, keep in zip(info_list, valid_mask) if keep] + + return obs_list, reward_list, done_list, info_list + + def close(self): + if getattr(self, "_closed", False): + return + for env in self.envs: + env.close() + self._executor.shutdown(wait=True) + self._loop.close() + self._closed = True + + def __del__(self): + self.close() + + +def build_search_envs( + seed: int = 0, + env_num: int = 1, + group_n: int = 1, + is_train: bool = True, + env_config=None, +): + return SearchMultiProcessEnv( + seed=seed, + env_num=env_num, + group_n=group_n, + is_train=is_train, + env_config=env_config, + ) \ No newline at end of file diff --git a/agent_system/environments/env_package/search/projection.py b/agent_system/environments/env_package/search/projection.py new file mode 100644 index 00000000..9238914a --- /dev/null +++ b/agent_system/environments/env_package/search/projection.py @@ -0,0 +1,73 @@ +from typing import List, Tuple +import re + + +def _postprocess_action(action: str) -> str: + """Trim everything *after* the first closing `` or `` tag. + + This guards against a common LLM hallucination where an action contains + several concatenated XML‑like snippets. By hard‑cutting at the first + relevant close tag we can safely apply non‑greedy regex below. + """ + if "" in action: + return action.split("", 1)[0] + "" + if "" in action: + return action.split("", 1)[0] + "" + return action + + +def search_projection(actions: List[str]) -> Tuple[List[str], List[int]]: + """Project a list of LLM *actions* into (`results`, `valids`). + + Extraction logic (order matters): + 1. Grab the **first** complete ```` block (case‑insensitive). + 2. If absent, grab the **first** complete ```` block. + 3. If still absent, store an empty string. 
+ + Validity logic (independent of extraction): ``valids[i]`` flips to **0** when + the *original* action text satisfies any of: + 1. Contains **both** ```` and ```` tags. + 2. Contains more than one ```` tag or more than one ```` tag. + + The extracted block (if any) is **not** cleared when a validity rule fails – + downstream callers can still inspect the fragment while trusting the flag. + """ + + results: List[str] = [] + valids: List[int] = [1] * len(actions) + + # --- Pre‑compiled patterns ------------------------------------------------ + re_search_block = re.compile(r"(.*?)", re.IGNORECASE | re.DOTALL) + re_answer_block = re.compile(r"(.*?)", re.IGNORECASE | re.DOTALL) + re_search_tag = re.compile(r"", re.IGNORECASE) + re_answer_tag = re.compile(r"", re.IGNORECASE) + + for i, action in enumerate(actions): + original_action = action # Keep untouched for validity checks + trimmed_action = _postprocess_action(action) + + # --- Extraction ----------------------------------------------------- + m = re_search_block.search(trimmed_action) + if m: + results.append(f"{m.group(1).strip()}") + else: + m = re_answer_block.search(trimmed_action) + if m: + results.append(f"{m.group(1).strip()}") + else: + results.append("") + valids[i] = 0 + + # --- Validity checks ------------------------------------------------- + n_search = len(re_search_tag.findall(original_action)) + n_answer = len(re_answer_tag.findall(original_action)) + + # Both search and answer present + if n_search and n_answer: + valids[i] = 0 + continue + # Multiple identical tags + if n_search > 1 or n_answer > 1: + valids[i] = 0 + + return results, valids diff --git a/agent_system/environments/env_package/search/third_party/README.md b/agent_system/environments/env_package/search/third_party/README.md new file mode 100644 index 00000000..ebdd23fd --- /dev/null +++ b/agent_system/environments/env_package/search/third_party/README.md @@ -0,0 +1,32 @@ +# SkyRL-Gym + +A library of RL environments for LLMs 
implemented with the Gymnasium API. + +## Key Features + +- Simple `Environment` interface following the Gymnasium API. +- Library of ready-built environments for math, code, search, and text-to-SQL. +- A reusable `tool` interface. Developers can implement a tool once, and use it across any environment. +- Supports multi-tool environments + +## Installation + +You can install the latest release from PyPI: + +```bash +pip install skyrl-gym +``` + +or install from source: + +```bash +git clone https://github.com/NovaSky-AI/SkyRL.git +cd SkyRL/skyrl-gym +pip install -e . +``` + +## Documentation + +To build your first environment, see our [Walkthrough Docs](https://skyrl.readthedocs.io/en/latest/tutorials/new_env.html). + +All docs are available at [https://skyrl.readthedocs.io/en/latest/](https://skyrl.readthedocs.io/en/latest/). diff --git a/agent_system/environments/env_package/search/third_party/pyproject.toml b/agent_system/environments/env_package/search/third_party/pyproject.toml new file mode 100644 index 00000000..ab70c22b --- /dev/null +++ b/agent_system/environments/env_package/search/third_party/pyproject.toml @@ -0,0 +1,81 @@ +# basic pyproject.toml for skyrl-gym +[build-system] +requires = ["setuptools>=45", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "skyrl-gym" +version = "0.1.1" +description = "RL environments for LLMs implemented with the Gymnasium API." 
+authors = [ + { name = "NovaSkyAI AI Team"} +] +license = { text = "MIT" } +readme = "README.md" +requires-python = ">=3.10" + +dependencies = [ + "func_timeout", + "pandas", + "requests", + "omegaconf", +] + +[project.urls] +Repository = "https://github.com/NovaSky-AI/SkyRL" + +[tool.setuptools.packages.find] +include = ["skyrl_gym*"] + +[project.optional-dependencies] +dev = [ + "pytest" +] + +[tool.black] +line-length = 120 +include = '\.pyi?$' +extend-exclude = ''' +/( + # directories + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | build + | dist +)/ +''' + +[tool.flake8] +max-line-length = 120 +max-doc-length = 120 +extend-ignore = [ + # Default ignored errors by flake8 + "E121", "E123", "E126", "E226", "E24", "E704", + # F401 module imported but unused + "F401", + # E203 whitespace before ':' (conflict with black) + "E203", + # E231 missing whitespace after ',' (conflict with black) + "E231", + # E501 line too long (conflict with black) + "E501", + # E741 do not use variables named 'l', 'O', or 'I' + "E741", + # W503 line break before binary operator (conflict with black) + "W503", + # W504 line break after binary operator (conflict with black) + "W504", + # W505 doc line too long (conflict with black) + "W505", + # W605 invalid escape sequence 'x' (conflict with latex within docs) + "W605", +] + +[tool.ruff.lint] +ignore = [ + "F722" # Syntax error in annotation - ignored because this doesn't play well with jaxtyping +] \ No newline at end of file diff --git a/agent_system/environments/env_package/search/third_party/skyrl_gym/__init__.py b/agent_system/environments/env_package/search/third_party/skyrl_gym/__init__.py new file mode 100644 index 00000000..89182b41 --- /dev/null +++ b/agent_system/environments/env_package/search/third_party/skyrl_gym/__init__.py @@ -0,0 +1,30 @@ +"""Root `__init__` of the gym module setting the `__all__` of skyrl modules.""" + +from agent_system.environments.env_package.search.third_party.skyrl_gym.core 
import Env +from agent_system.environments.env_package.search.third_party.skyrl_gym import error + +from agent_system.environments.env_package.search.third_party.skyrl_gym.envs.registration import ( + make, + spec, + register, + registry, + pprint_registry, +) +from agent_system.environments.env_package.search.third_party.skyrl_gym import tools, envs + +# Define __all__ to control what's exposed +__all__ = [ + # core classes + "Env", + # registration + "make", + "spec", + "register", + "registry", + "pprint_registry", + # module folders + "envs", + "tools", + "error", +] +__version__ = "0.0.0" diff --git a/agent_system/environments/env_package/search/third_party/skyrl_gym/core.py b/agent_system/environments/env_package/search/third_party/skyrl_gym/core.py new file mode 100644 index 00000000..28bbb50e --- /dev/null +++ b/agent_system/environments/env_package/search/third_party/skyrl_gym/core.py @@ -0,0 +1,97 @@ +"""Core API for Environment""" + +from __future__ import annotations + +from typing import Generic, Any, SupportsFloat, TypeVar, TypedDict, Optional +from typing import Tuple, Dict + +ObsType = TypeVar("ObsType") +ActType = TypeVar("ActType") + + +class EnvStepOutput(TypedDict): + observations: ObsType + reward: SupportsFloat + done: bool + metadata: Optional[Dict[str, Any]] = None + + +class Env(Generic[ObsType, ActType]): + """ + The main SkyRL Gym class for implementing Reinforcement Learning Agents environments. + + The main API methods that users of this class need to know are: + + - `step` - Perform actions (e.g. tool calls) in the environment. + Return the observation, the reward for taking that actions, and a boolean value `done`. + + - `init` - Initializes the environment to an initial state, required before calling step. + Returns the first observation for a turn and information, i.e. metrics, debug info. + + - `close` - Closes the environment. + Important when external software is used, i.e. 
pygame for rendering, databases + """ + + def step(self, action: ActType) -> EnvStepOutput: + """ + Parse and run one step of action in the environment. + + Args: + action (ActType): An action provided to the environment. + For example, in our case, the action can be a [str] response generated by an LLM, + which must be parsed and executed accordingly. + + Returns: + observation (ObsType): The resulting observation after executing the action. + For example, this could involve executing a SQL query derived from the LLM response + and observing {'role': 'user', 'content': 'str(observation)'} output or any error messages from database. + + reward (SupportsFloat): The reward obtained by taking the action. + + done (bool): A boolean value for if the episode has ended, in which case further `step` calls will + return undefined results. + + info (Dict): Contains auxiliary diagnostic information (helpful for debugging, learning, and logging). + This might, for instance, contain: metrics that describe the performance state, variables that are + hidden from observations, or individual reward terms that are combined to produce the total reward. + """ + raise NotImplementedError + + def init(self, *kwargs) -> Tuple[ObsType, Dict[str, Any]]: + """ + Initialize the environment, returning initial observation and optional metadata. + + Returns: + observation (ObsType): Observation of the initial state. This is analogous to the observation returned by `step`. + info (Dict: This dictionary contains auxiliary information complementing ``observation``. It should be analogous to + the ``info`` returned by `step`. + """ + raise NotImplementedError + + def close(self): + """ + After the user has finished using the environment, close contains the code necessary to "clean up" the environment. + + This is critical for closing rendering windows, database or HTTP connections. + Calling ``close`` on an already closed environment has no effect and won't raise an error. 
+ """ + pass + + def __str__(self): + """ + Returns a string of the environment. + + Returns: + A string identifying the environment + """ + return f"Env({type(self).__name__})" + + def __enter__(self): + """Support with-statement for the environment.""" + return self + + def __exit__(self, *args: Any): + """Support with-statement for the environment and closes the environment.""" + self.close() + # propagate exception + return False diff --git a/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/README.md b/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/README.md new file mode 100644 index 00000000..5f12a95b --- /dev/null +++ b/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/README.md @@ -0,0 +1,16 @@ +# Supported Environments + +## Basic Environments +Search, SQL, Math, + Simple Code (LCB)? + +## Multi-Tool Environment +The model decides which tool to call in each step, showcase Search + Python Code example scripts +- New environment, tool parsing logic different + +## Mix Dataset Environment +Datasets consist of multi-domain, switch between environment initialization +- Reuse environment, agent loop different + +## Replicate results +- Sky-SQL-RL-7B +- Search-r1 3B \ No newline at end of file diff --git a/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/__init__.py b/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/__init__.py new file mode 100644 index 00000000..e4d32353 --- /dev/null +++ b/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/__init__.py @@ -0,0 +1,8 @@ +"""Registers the internal gym envs.""" + +from agent_system.environments.env_package.search.third_party.skyrl_gym.envs.registration import register + +register( + id="search", + entry_point="skyrl_gym.envs.search.env:SearchEnv", +) diff --git a/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/base_text_env.py 
b/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/base_text_env.py new file mode 100644 index 00000000..2ac11f99 --- /dev/null +++ b/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/base_text_env.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, List, Optional, TypedDict +from agent_system.environments.env_package.search.third_party.skyrl_gym import Env +from typing import Tuple + +MessageType = Dict[str, str] +ConversationType = List[MessageType] + + +class BaseTextEnvStepOutput(TypedDict): + observations: List[Dict[str, str]] # OpenAI API Messages Format + reward: float + done: bool + metadata: Dict[str, Any] + postprocessed_action: Optional[str] = None + + +class BaseTextEnv(Env[str, str]): + """ + Base environment class for all text-in / text-out environments. + Supports tool-calling and multi-turn trajectories. + + Exposes only `step`, `init` and `close`. + + Input Types: + - ObsType: str (tool output, LLM input) + - ActType: str (LLM output) + """ + + def __init__(self): + super().__init__() + + # Metadata + self.turns = 0 + self.max_turns = 1 + + # Tool groups + self.tool_groups = [] + self.tool_to_toolgroup = {} + + def init_tool_groups(self, tool_groups: List = []) -> None: + """ + Initialize the tool groups for the environment. + """ + # Find ToolGroup for a given tool + self.tool_groups = tool_groups + self.tool_to_toolgroup = {} + for tool_group in self.tool_groups: + self.tool_to_toolgroup.update(tool_group.get_tool_to_group_mapping()) + + def _execute_tool(self, tool_group_name: str, tool_name: str, tool_input: Any) -> str: + """ + Find the right ToolGroup and Tool and execute it. + """ + for group in self.tool_groups: + if group.name == tool_group_name: + return group.execute_tool(tool_name, *tool_input) # tool_input must be tuple or list + + raise ValueError(f"ToolGroup '{tool_group_name}' not found.") + + def step(self, action: str) -> BaseTextEnvStepOutput: + """ + Runs one environment step. 
+ + Return: + - new_obs: [{"role": "user", "content": observation}] + - reward: float + - done: bool + - postprocessed_action: Optional[str] + - Dict[str, Any]: any metadata + """ + pass + + def init(self, prompt: ConversationType) -> Tuple[ConversationType, Dict[str, Any]]: + """ + Return the first prompt to be given to the model and optional metadata. + """ + return prompt, {} + + def close(self): + """ + Closes the environment, override if needed by subclasses. + """ + pass diff --git a/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/registration.py b/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/registration.py new file mode 100644 index 00000000..4a14abc6 --- /dev/null +++ b/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/registration.py @@ -0,0 +1,345 @@ +"""Functions for registering environments within skyrl_gym using public functions ``make``, ``register`` and ``spec``.""" + +from __future__ import annotations + +import copy +import dataclasses +import importlib +import json +from dataclasses import dataclass, field +from agent_system.environments.env_package.search.third_party.skyrl_gym import Env, error + +from typing import Protocol, Dict, Any + +__all__ = [ + "registry", + "EnvSpec", + # Functions + "make", + "spec", + "register", + "pprint_registry", +] + + +class EnvCreator(Protocol): + """Function type expected for an environment.""" + + def __call__(self, **kwargs: Any) -> Env: ... + + +@dataclass +class EnvSpec: + """A specification for creating environments with `skyrl_gym.make`. + + * **id**: The string used to create the environment with `skyrl_gym.make` + + * **entry_point**: A string for the environment location, ``(import path):(environment name)`` or a function that creates the environment. + NOTE[shu]: example is like a database path? or local RAG environment? 
+ + * **kwargs**: Additional keyword arguments passed to the environment during initialisation + + """ + + id: str + entry_point: EnvCreator | str | None = field(default=None) + + # Environment arguments + kwargs: Dict[str, Any] = field(default_factory=dict) + + # post-init attributes + name: str = field(init=False) + + def __post_init__(self): + """Calls after the spec is created to extract name from the environment id.""" + self.name = self.id + + def make(self, **kwargs: Any) -> Env: + """Calls ``make`` using the environment spec and any keyword arguments.""" + return make(self, **kwargs) + + def to_json(self) -> str: + """Converts the environment spec into a json compatible string. + + Returns: + A jsonifyied string for the environment spec + """ + env_spec_dict = dataclasses.asdict(self) + env_spec_dict.pop("name") + + # To check that the environment spec can be transformed to a json compatible type + self._check_can_jsonify(env_spec_dict) + + return json.dumps(env_spec_dict) + + @staticmethod + def _check_can_jsonify(env_spec: Dict[str, Any]): + """Warns the user about serialisation failing if the spec contains a callable. + + Args: + env_spec: An environment. + + Returns: The specification with lambda functions converted to strings. + + """ + spec_name = env_spec["name"] if "name" in env_spec else env_spec["id"] + + for key, value in env_spec.items(): + if callable(value): + raise ValueError( + f"Callable found in {spec_name} for {key} attribute with value={value}. Currently, skyrl_gym does not support serialising callables." + ) + + @staticmethod + def from_json(json_env_spec: str) -> EnvSpec: + """Converts a JSON string into a specification stack. + + Args: + json_env_spec: A JSON string representing the env specification. 
+ + Returns: + An environment spec + """ + parsed_env_spec = json.loads(json_env_spec) + + try: + env_spec = EnvSpec(**parsed_env_spec) + except Exception as e: + raise ValueError(f"An issue occurred when trying to make {parsed_env_spec} an EnvSpec") from e + + return env_spec + + def pprint( + self, + disable_print: bool = False, + include_entry_points: bool = False, + ) -> str | None: + """Pretty prints the environment spec. + + Args: + disable_print: If to disable print and return the output + include_entry_points: If to include the entry_points in the output + print_all: If to print all information, including variables with default values + + Returns: + If ``disable_print is True`` a string otherwise ``None`` + """ + output = f"id={self.id}" + if include_entry_points: + output += f"\nentry_point={self.entry_point}" + + if disable_print: + return output + else: + print(output) + + +# Global registry of environments. Meant to be accessed through `register` and `make` +registry: Dict[str, EnvSpec] = {} + + +def _find_spec(env_id: str) -> EnvSpec: + # For string id's, load the environment spec from the registry then make the environment spec + assert isinstance(env_id, str) + + # load the env spec from the registry + env_name = env_id + env_spec = registry.get(env_name) + + if env_spec is None: + raise error.Error( + f"No registered env with id: {env_name}. Did you register it, or import the package that registers it? Use `skyrl_gym.pprint_registry()` to see all of the registered environments." + ) + + return env_spec + + +def load_env_creator(name: str) -> EnvCreator: + """Loads an environment with name of style ``"(import path):(environment name)"`` and returns the environment creation function, normally the environment class type. + + Args: + name: The environment name + + Returns: + The environment constructor for the given environment name. 
+ """ + mod_name, attr_name = name.split(":") + mod = importlib.import_module(mod_name) + fn = getattr(mod, attr_name) + return fn + + +def _check_spec_register(testing_spec: EnvSpec): + """NOTE[shu]: Checks that no environment with the same name already exists in the registry.""" + for env_spec in registry.values(): + if env_spec.name == testing_spec.name: + raise error.RegistrationError( + f"An environment with name `{testing_spec.name}` is already registered (id=`{env_spec.id}`). " + "Environment names must be unique." + ) + + +def register( + id: str, + entry_point: EnvCreator | str | None = None, + kwargs: Dict[str, Any] | None = None, +): + """ + Registers an environment in skyrl_gym with an ``id`` to use with `skyrl_gym.make` with the ``entry_point`` + being a string or callable for creating the environment. + + The ``id`` parameter corresponds to the name of the environment. + + It takes arbitrary keyword arguments, which are passed to the :class:`EnvSpec` ``kwargs`` parameter. + + Args: + id: The environment id + entry_point: The entry point for creating the environment + kwargs: arbitrary keyword arguments which are passed to the environment constructor on initialisation. + """ + assert entry_point is not None, "`entry_point` must be provided" + global registry + + if kwargs is None: + kwargs = dict() + + new_spec = EnvSpec( + id=id, + entry_point=entry_point, + kwargs=kwargs, + ) + _check_spec_register(new_spec) + registry[new_spec.id] = new_spec + + +def make( + id: str | EnvSpec, + **kwargs: Any, +) -> Env: + """Creates an environment previously registered with `skyrl_gym.register` or a :class:`EnvSpec`. + + To find all available environments use ``skyrl_gym.envs.registry.keys()`` for all valid ids. + + Args: + id: A string for the environment id or a :class:`EnvSpec`. + kwargs: Additional arguments to pass to the environment constructor. + + Returns: + An instance of the environment. 
+ + Raises: + Error: If the ``id`` doesn't exist in the `registry` + """ + if isinstance(id, EnvSpec): + env_spec = id + else: + # For string id's, load the environment spec from the registry then make the environment spec + assert isinstance(id, str) + + # The environment name can include an unloaded module in "module:env_name" style + env_spec = _find_spec(id) + + assert isinstance(env_spec, EnvSpec) + + # Update the env spec kwargs with the `make` kwargs + env_spec_kwargs = copy.deepcopy(env_spec.kwargs) + env_spec_kwargs.update(kwargs) + + # Load the environment creator + if env_spec.entry_point is None: + raise error.Error(f"{env_spec.id} registered but entry_point is not specified") + elif callable(env_spec.entry_point): + env_creator = env_spec.entry_point + else: + # Assume it's a string + env_creator = load_env_creator(env_spec.entry_point) + + try: + env = env_creator(**env_spec_kwargs) + except TypeError as e: + raise type(e)(f"{e} was raised from the environment creator for {env_spec.id} with kwargs ({env_spec_kwargs})") + + if not isinstance(env, Env): + if str(env.__class__.__base__) == "": + raise TypeError( + "Gym is incompatible with skyrl_gym, please update the environment class to `skyrl_gym.Env`. " + ) + else: + raise TypeError(f"The environment must inherit from the skyrl_gym.Env class, actual class: {type(env)}. ") + + # Set the minimal env spec for the environment. + env.spec = EnvSpec( + id=env_spec.id, + entry_point=env_spec.entry_point, + kwargs=env_spec_kwargs, + ) + + return env + + +def spec(env_id: str) -> EnvSpec: + """Retrieve the :class:`EnvSpec` for the environment id from the `registry`. 
+ + Args: + env_id: The environment id with the expected format of ``[(namespace)/]id[-v(version)]`` + + Returns: + The environment spec if it exists + + Raises: + Error: If the environment id doesn't exist + """ + env_spec = registry.get(env_id) + if env_spec is None: + raise error.Error(f"No registered env with id: {env_id}") + else: + assert isinstance( + env_spec, EnvSpec + ), f"Expected the registry for {env_id} to be an `EnvSpec`, actual type is {type(env_spec)}" + return env_spec + + +def pprint_registry( + print_registry: Dict[str, EnvSpec] = registry, + *, + num_cols: int = 3, + disable_print: bool = False, +) -> str | None: + """Pretty prints all environments in the registry without grouping by namespace. + + Args: + print_registry: Environment registry to be printed. By default, uses the global `registry`. + num_cols: Number of columns to arrange environments in. + disable_print: Whether to return a string instead of printing it. + """ + # Get all environment ids + env_ids = sorted(print_registry.keys()) + + if not env_ids: + output = "No environments registered." 
+ if disable_print: + return output + else: + print(output) + return + + # Find the max width for nice column alignment + max_justify = max(len(env_id) for env_id in env_ids) + + # Build the output + output_lines = [] + current_line = "" + + for count, env_id in enumerate(env_ids, 1): + current_line += env_id.ljust(max_justify + 2) + + if count % num_cols == 0 or count == len(env_ids): + output_lines.append(current_line.rstrip()) + current_line = "" + + final_output = "\n".join(output_lines) + + if disable_print: + return final_output + else: + print(final_output) diff --git a/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/search/env.py b/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/search/env.py new file mode 100644 index 00000000..57d1df27 --- /dev/null +++ b/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/search/env.py @@ -0,0 +1,131 @@ +from agent_system.environments.env_package.search.third_party.skyrl_gym.envs.base_text_env import BaseTextEnv, BaseTextEnvStepOutput, ConversationType +from typing import Any +from agent_system.environments.env_package.search.third_party.skyrl_gym.envs.search.utils import compute_score +from agent_system.environments.env_package.search.third_party.skyrl_gym.tools import SearchToolGroup +import re +from typing import Dict, Optional, List +from omegaconf import DictConfig + + +class SearchEnv(BaseTextEnv): + """ + Environment for Search execution tasks. 
+ + Based on Verl + Search-R1 integration + """ + + def __init__(self, env_config: DictConfig): + super().__init__() + # Initialize the tools + # name is hardcoded to "SearchToolGroup", with tool name "search" + self.tool_group = SearchToolGroup( + search_url=env_config.search_url, + topk=env_config.topk, + timeout=env_config.timeout, + log_requests=env_config.log_requests, + ) + self.init_tool_groups([self.tool_group]) + + def reset(self, extras: Dict[str, Any] = {}) -> None: + assert "ground_truth" in extras, "ground_truth is required in extras field" + self.ground_truth = extras["ground_truth"] + self.max_turns = extras["max_turns"] if "max_turns" in extras else 3 + + self.data_source = extras.get("data_source", "unknown") + + # Chat history + # role (user, assistant), content (tool observation or LLM response) + self.chat_history: ConversationType = [] + self.done = False + self.turns = 0 + + + def _parse_action(self, action: str) -> List[Optional[str]]: + match = None + if "" in action and "" in action: + match = re.search(r"(.*?)", action, re.DOTALL) + return [match.group(1)] if match else [None] + + def _get_reward(self, action: str, done: bool) -> float: + if done: + # Concat all chat history into a single string and compute reward + chat_history_str = "".join([item["content"] for item in self.chat_history]) + return compute_score(chat_history_str, self.ground_truth) + else: + # No reward for intermediate steps for Search tasks + return 0 + + def _is_done(self, action: str) -> bool: + if self.turns >= self.max_turns: + return True + return "" in action and "" in action + + def _postprocess_action(self, action: str) -> str: + if "" in action: + return action.split("")[0] + "" + elif "" in action: + return action.split("")[0] + "" + else: + return action + + def _execute_tool(self, tool_group_name: str, tool_name: str, tool_input: Any) -> str: + tool_output = super()._execute_tool(tool_group_name, tool_name, tool_input) + if len(tool_output) > 0: + return 
"\n" + tool_output + "\n" + else: + return None + + def step(self, action: str) -> BaseTextEnvStepOutput: + self.turns += 1 + # action = self._postprocess_action(action) + self.chat_history.append({"role": "assistant", "content": action}) + + error = None + if not self.done: + done = self._is_done(action) + self.done = done + else: + done = True + + reward = self._get_reward(action, done) + + if done: + return BaseTextEnvStepOutput( + observations=[], reward=reward, done=done, metadata={"data_source": self.data_source}, postprocessed_action=action + ) + + try: + query = self._parse_action(action) + observation = self._execute_tool("SearchToolGroup", "search", query) + except Exception as e: + error = str(e) + observation = None + + # Wrap the observation properly as a message + if observation: + new_obs = {"role": "user", "content": observation} + elif error: + # Give error as observation if any + print(f"!!(Warning) an error when calling tools: {error}") + new_obs = {"role": "user", "content": error} + else: + new_obs = None + + info = { + "tool_group": "SearchToolGroup", + "tool_name": "search", + "tool_input": query, + "data_source": self.data_source, + } + + # Update chat history + if new_obs: + self.chat_history.append(new_obs) + + return BaseTextEnvStepOutput( + observations=[new_obs] if new_obs else [], + reward=reward, + done=done, + metadata=info, + postprocessed_action=action, + ) \ No newline at end of file diff --git a/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/search/utils.py b/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/search/utils.py new file mode 100644 index 00000000..916e4ccb --- /dev/null +++ b/agent_system/environments/env_package/search/third_party/skyrl_gym/envs/search/utils.py @@ -0,0 +1,118 @@ +# Copyright 2024 Bytedance Ltd. 
and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 Search-R1 Contributors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/verl/utils/reward_score/qa_em.py + +import re +import string + + +def normalize_answer(s): + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def em_check(prediction, golden_answers): + if isinstance(golden_answers, str): + golden_answers = [golden_answers] + normalized_prediction = normalize_answer(prediction) + score = 0 + for golden_answer in golden_answers: + golden_answer = normalize_answer(golden_answer) + if golden_answer == normalized_prediction: + score = 1 + break + return score + + +def subem_check(prediction, golden_answers): + if isinstance(golden_answers, str): + golden_answers = [golden_answers] + normalized_prediction = normalize_answer(prediction) + score = 0 + for golden_answer in golden_answers: + golden_answer = normalize_answer(golden_answer) + if golden_answer in normalized_prediction: + score = 1 + break + return score + + +def extract_solution(solution_str): + """Extract the equation from the solution string.""" + 
answer_pattern = r"(.*?)" + match = re.finditer(answer_pattern, solution_str, re.DOTALL) + matches = list(match) + + # If there are 0 matches, return None + if len(matches) < 1: + return None + + # If there are 2 or more matches, return the last one + return matches[-1].group(1).strip() + + +def compute_score(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0): + """The scoring function for exact match (EM). + + Args: + solution_str: the solution text + ground_truth: the ground truth + method: the method to extract the solution, choices are 'strict' and 'flexible' + format_score: the score for the format + score: the score for the correct answer + """ + answer = extract_solution(solution_str=solution_str) + + if answer is None: + return 0 + else: + if em_check(answer, ground_truth["target"]): + return score + else: + return format_score + + +def compute_score_subem(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0): + """The scoring function for substring exact match (EM). 
+ + Args: + solution_str: the solution text + ground_truth: the ground truth + method: the method to extract the solution, choices are 'strict' and 'flexible' + format_score: the score for the format + score: the score for the correct answer + """ + answer = extract_solution(solution_str=solution_str) + + if answer is None: + return 0 + else: + if subem_check(answer, ground_truth["target"]): + return score + else: + return format_score diff --git a/agent_system/environments/env_package/search/third_party/skyrl_gym/error.py b/agent_system/environments/env_package/search/third_party/skyrl_gym/error.py new file mode 100644 index 00000000..6fdd7fe7 --- /dev/null +++ b/agent_system/environments/env_package/search/third_party/skyrl_gym/error.py @@ -0,0 +1,9 @@ +"""Set of Error classes for skyrl_gym.""" + + +class Error(Exception): + """Error superclass.""" + + +class RegistrationError(Error): + """Raised when the user attempts to register an invalid env.""" diff --git a/agent_system/environments/env_package/search/third_party/skyrl_gym/tools/__init__.py b/agent_system/environments/env_package/search/third_party/skyrl_gym/tools/__init__.py new file mode 100644 index 00000000..05bd9e2f --- /dev/null +++ b/agent_system/environments/env_package/search/third_party/skyrl_gym/tools/__init__.py @@ -0,0 +1,3 @@ +from .search import SearchToolGroup + +__all__ = ["SearchToolGroup"] diff --git a/agent_system/environments/env_package/search/third_party/skyrl_gym/tools/core.py b/agent_system/environments/env_package/search/third_party/skyrl_gym/tools/core.py new file mode 100644 index 00000000..12d9a98c --- /dev/null +++ b/agent_system/environments/env_package/search/third_party/skyrl_gym/tools/core.py @@ -0,0 +1,60 @@ +from typing import Dict, Callable, Any, Optional, List + + +class tool: + """ + A tool that can be used to execute a function. 
+ """ + + def __init__(self, func: Callable): + self.func = func + self.name = func.__name__ + + def __get__(self, instance, owner): + if instance is None: + return self # Return the descriptor itself when accessed from the class + return lambda *args, **kwargs: self.func(instance, *args, **kwargs) + + +class ToolGroup: + """ + A group of tools that can be used together. + """ + + def __init__(self, name: str): + self.name = name + self._tool_registry: Dict[str, Callable] = {} + self._register_tools() + + def get_name(self): + return self.name + + def _register_tools(self): + # Register all methods decorated with @tool + + # Tool names must be unique across tool groups. + # TODO: Support duplicate tool names across tool groups via namespacing + for attr_name in dir(self): + # Look for the descriptor on the class, not the instance + raw = getattr(type(self), attr_name, None) + if isinstance(raw, tool): + self._tool_registry[raw.name] = getattr(self, attr_name) + + def get_tool(self, name: str) -> Optional[Callable]: + # Get a tool by name, returns None if not found + return self._tool_registry.get(name) + + def get_tool_names(self) -> List[str]: + # Get all available tool names + return list(self._tool_registry.keys()) + + def execute_tool(self, name: str, *args, **kwargs) -> Any: + # Execute a tool by name with given arguments + tool_func = self.get_tool(name) + if tool_func: + return tool_func(*args, **kwargs) + raise ValueError(f"Tool '{name}' not found in group '{self.name}'.") + + def get_tool_to_group_mapping(self) -> Dict[str, str]: + # Get mapping of tool names to group name + return {name: self.name for name in self._tool_registry} diff --git a/agent_system/environments/env_package/search/third_party/skyrl_gym/tools/search.py b/agent_system/environments/env_package/search/third_party/skyrl_gym/tools/search.py new file mode 100644 index 00000000..e061e163 --- /dev/null +++ b/agent_system/environments/env_package/search/third_party/skyrl_gym/tools/search.py 
@@ -0,0 +1,254 @@ +import json +import logging +import requests +import uuid +import time +import threading +from typing import Tuple, Optional, Any, Dict +from urllib.parse import urlparse + +from agent_system.environments.env_package.search.third_party.skyrl_gym.tools.core import tool, ToolGroup + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +DEFAULT_TIMEOUT = 30 +MAX_RETRIES = 10 +INITIAL_RETRY_DELAY = 1 + + +def call_search_api( + retrieval_service_url: str, + query: str, + topk: int = 3, + return_scores: bool = True, + timeout: int = DEFAULT_TIMEOUT, + log_requests: bool = True, + session: Optional[requests.Session] = None, +) -> Tuple[Optional[Dict[str, Any]], Optional[str]]: + """ + Calls the search API with a single query. + + Args: + retrieval_service_url: The URL of the search API. + query: The query to search for. + topk: The number of results to return. + return_scores: Whether to return scores for the results. + timeout: The timeout for the request. + log_requests: Whether to log requests. + session: The session to use for the request. If none is provided, a new session will be created. + + Returns: + response: The response from the search API (json if successful, None otherwise) + error_msg: The error message if the request failed. 
+ """ + request_id = str(uuid.uuid4()) + log_prefix = f"[Search Request ID: {request_id}] " + + payload = {"query": query, "topk": topk, "return_scores": return_scores} + headers = {"Content-Type": "application/json", "Accept": "application/json"} + + # Use provided session or create a new one for this request + if session is None: + session = requests.Session() + should_close_session = True + else: + should_close_session = False + + last_error = None + for attempt in range(MAX_RETRIES): + try: + if log_requests: + logger.info( + f"{log_prefix}Attempt {attempt + 1}/{MAX_RETRIES}: Calling search API at {retrieval_service_url}" + ) + response = session.post( + retrieval_service_url, + headers=headers, + json=payload, + timeout=timeout, + ) + + # Check for Gateway Timeout (504) and other server errors for retrying + if response.status_code in [500, 502, 503, 504]: + last_error = f"{log_prefix}API Request Error: Server Error ({response.status_code}) on attempt {attempt + 1}/{MAX_RETRIES}" + logger.warning(last_error) + if attempt < MAX_RETRIES - 1: + delay = INITIAL_RETRY_DELAY * (attempt + 1) + logger.info(f"{log_prefix}Retrying after {delay} seconds...") + time.sleep(delay) + continue + + # Check for other HTTP errors (e.g., 4xx) + response.raise_for_status() + + # If successful (status code 2xx) + if log_requests: + logger.info(f"{log_prefix}Search API call successful on attempt {attempt + 1}") + + # Close session if we created it + if should_close_session: + session.close() + + return response.json(), None + + except requests.exceptions.ConnectionError as e: + last_error = f"{log_prefix}Connection Error: {e}" + logger.warning(last_error) + if attempt < MAX_RETRIES - 1: + delay = INITIAL_RETRY_DELAY * (attempt + 1) + logger.info(f"{log_prefix}Retrying after {delay} seconds...") + time.sleep(delay) + continue + except requests.exceptions.Timeout as e: + last_error = f"{log_prefix}Timeout Error: {e}" + logger.warning(last_error) + if attempt < MAX_RETRIES - 1: + delay 
= INITIAL_RETRY_DELAY * (attempt + 1) + logger.info(f"{log_prefix}Retrying after {delay} seconds...") + time.sleep(delay) + continue + except requests.exceptions.RequestException as e: + last_error = f"{log_prefix}API Request Error: {e}" + break # Exit retry loop on other request errors + except json.JSONDecodeError as e: + raw_response_text = response.text if "response" in locals() else "N/A" + last_error = f"{log_prefix}API Response JSON Decode Error: {e}, Response: {raw_response_text[:200]}" + break # Exit retry loop on JSON decode errors + except Exception as e: + last_error = f"{log_prefix}Unexpected Error: {e}" + break # Exit retry loop on other unexpected errors + + # If we reach here, all attempts failed + logger.error(f"{log_prefix}API Request Failed after {MAX_RETRIES} attempts: {last_error}") + + # Close session if we created it + if should_close_session: + session.close() + + return None, last_error + + +def _passages2string(retrieval_result): + format_reference = "" + for idx, doc_item in enumerate(retrieval_result): + content = doc_item["document"]["contents"].strip() + format_reference += f"Doc {idx+1}: {content}\n" + return format_reference + + +class SearchToolGroup(ToolGroup): + # Class-level session pool shared across all instances + _session_pool = {} + _session_lock = threading.Lock() + + @classmethod + def _get_shared_session(cls, base_url: str) -> requests.Session: + """Get or create a shared session for the given base URL""" + with cls._session_lock: + if base_url not in cls._session_pool: + session = requests.Session() + # Configure connection pooling + adapter = requests.adapters.HTTPAdapter( + pool_connections=512, # Number of connection pools + pool_maxsize=512, # Max connections per pool + max_retries=0, # We handle retries ourselves + pool_block=False, # Don't block if pool is full + ) + session.mount("http://", adapter) + session.mount("https://", adapter) + cls._session_pool[base_url] = session + logger.info(f"Created shared session 
pool for {base_url}") + return cls._session_pool[base_url] + + def __init__(self, search_url="http://127.0.0.1:8000/retrieve", topk=3, timeout=DEFAULT_TIMEOUT, log_requests=True): + self.search_url = search_url + self.topk = topk + self.timeout = timeout + self.log_requests = log_requests + + # Extract base URL for session sharing + parsed_url = urlparse(self.search_url) + self.base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + # Get shared session for this base URL + self.session = self._get_shared_session(self.base_url) + if self.log_requests: + logger.info(f"SearchToolGroup initialized using shared session pool for {self.base_url}") + + super().__init__(name="SearchToolGroup") + + @tool + def search(self, query: str) -> str: + # NOTE(shu): add warning messages here? + if query is None: + return "" + + query = query.strip() + + try: + api_response, error_msg = call_search_api( + retrieval_service_url=self.search_url, + query=query, + topk=self.topk, + timeout=self.timeout, + log_requests=self.log_requests, + session=self.session, # Pass our shared session for connection reuse + ) + except Exception as e: + error_msg = f"API Request Exception during batch search: {e}" + logger.error(f"Batch search: {error_msg}") + + metadata = { + "query": query, + "api_request_error": error_msg, + "api_response": None, + "status": "unknown", + "total_results": 0, + "formatted_result": None, + } + + result_text = json.dumps({"result": "Search request failed or timed out after retries."}) + + if error_msg: + metadata["status"] = "api_error" + result_text = json.dumps({"result": f"Search error: {error_msg}"}) + logger.error(f"Batch search: API error occurred: {error_msg}") + elif api_response: + logger.debug(f"Batch search: API Response: {api_response}") + metadata["api_response"] = api_response + + try: + raw_results = api_response.get("result", []) + if raw_results: + pretty_results = [] + total_results = 0 + for retrieval in raw_results: + formatted = 
_passages2string(retrieval) + pretty_results.append(formatted) + total_results += len(retrieval) if isinstance(retrieval, list) else 1 + + final_result = "\n---\n".join(pretty_results) + result_text = json.dumps({"result": final_result}) + metadata["status"] = "success" + metadata["total_results"] = total_results + metadata["formatted_result"] = final_result + if self.log_requests: + logger.info(f"Batch search: Successful, got {total_results} total results") + else: + result_text = json.dumps({"result": "No search results found."}) + metadata["status"] = "no_results" + metadata["total_results"] = 0 + if self.log_requests: + logger.info("Batch search: No results found") + except Exception as e: + error_msg = f"Error processing search results: {e}" + result_text = json.dumps({"result": error_msg}) + metadata["status"] = "processing_error" + logger.error(f"Batch search: {error_msg}") + else: + metadata["status"] = "unknown_api_state" + result_text = json.dumps({"result": "Unknown API state (no response and no error message)."}) + logger.error("Batch search: Unknown API state.") + + return result_text diff --git a/agent_system/environments/prompts/__init__.py b/agent_system/environments/prompts/__init__.py index 41fde5a9..9783e0fb 100644 --- a/agent_system/environments/prompts/__init__.py +++ b/agent_system/environments/prompts/__init__.py @@ -2,4 +2,5 @@ from .webshop import * from .sokoban import * from .gym_cards import * -from .appworld import * \ No newline at end of file +from .appworld import * +from .search import * \ No newline at end of file diff --git a/agent_system/environments/prompts/alfworld.py b/agent_system/environments/prompts/alfworld.py index 9379af97..3604df7c 100644 --- a/agent_system/environments/prompts/alfworld.py +++ b/agent_system/environments/prompts/alfworld.py @@ -27,7 +27,7 @@ The team's overall task is to: {task_description} At each step, you and your teammates collaborate to decide the best next action. 
-Your team is now at "step 1" and the current observation is: {current_observation} +Your team is now at "step {current_step}" and the current observation is: {current_observation} The admissible actions in the current situation are: [{admissible_actions}]. Now, you and your teammates must work together to determine the action. diff --git a/agent_system/environments/prompts/search.py b/agent_system/environments/prompts/search.py new file mode 100644 index 00000000..131419e0 --- /dev/null +++ b/agent_system/environments/prompts/search.py @@ -0,0 +1,41 @@ +# SEARCH_TEMPLATE_NO_HIS = """ +# Answer the given question. You must conduct reasoning inside and first every time you get new information. After reasoning, if you find you lack some knowledge, you can call a search engine by query and it will return the top searched results between and . You can search as many times as your want. If you find no further external knowledge needed, you can directly provide the answer inside and , without detailed illustrations. For example, Beijing . Question: {task_description} +# """ + +# SEARCH_TEMPLATE = """ +# Answer the given question. You must conduct reasoning inside and first every time you get new information. After reasoning, if you find you lack some knowledge, you can call a search engine by query and it will return the top searched results between and . You can search as many times as your want. If you find no further external knowledge needed, you can directly provide the answer inside and , without detailed illustrations. For example, Beijing . Question: {task_description} +# {memory_context} +# """ + + +SEARCH_TEMPLATE_NO_HIS = """ +You are an expert agent tasked with answering the given question step-by-step. +Your question: {task_description} + +Now it's your turn to respond for the current step. +You should first conduct reasoning process. This process MUST be enclosed within tags. 
+After completing your reasoning, choose only one of the following actions (do not perform both): +(1) If you find you lack some knowledge, you can call a search engine to get more external information using format: your query . +(2) If you have enough knowledge to answer the question confidently, provide your final answer within tags, without detailed illustrations. For example, Beijing. + +Your response should be in one of the following forms: "... ..." or "... ...". +""" + +SEARCH_TEMPLATE = """ +You are an expert agent tasked with answering the given question step-by-step. +Your question: {task_description} + +Prior to this step, you have already taken {step_count} step(s). Below is the interaction history where wrapped your past search queries and wrapped the corresponding search results returned by the external search engine. History: +{memory_context} + +Now it's your turn to respond for the current step. +You should first conduct reasoning process. This process MUST be enclosed within tags. +After completing your reasoning, choose only one of the following actions (do not perform both): +(1) If you find you lack some knowledge, you can call a search engine to get more external information using format: your query . +(2) If you have enough knowledge to answer the question confidently, provide your final answer within tags, without detailed illustrations. For example, Beijing. + +Your response should be in one of the following forms: "... ..." or "... ...". +""" + + + diff --git a/agent_system/environments/prompts/webshop.py b/agent_system/environments/prompts/webshop.py index 8251744b..b8258bf4 100644 --- a/agent_system/environments/prompts/webshop.py +++ b/agent_system/environments/prompts/webshop.py @@ -35,7 +35,7 @@ The team's overall task is: {task_description}. At each step, you and your teammates collaborate to decide the best next action. -Your team is now at "step 1" and the current observation is: {current_observation}. 
+Your team is now at "step {current_step}" and the current observation is: {current_observation}. The admissible actions in the current situation are: [ {available_actions} @@ -50,7 +50,7 @@ The team's overall task is: {task_description} At each step, you and your teammates collaborate to decide the best next action. -Prior to this step, your team has taken {step_count} environment step(s). Below is the history memory from "step 1" to "step {step_count}": +Prior to this step, your team has taken {step_count} environment step(s). Below is the history memory: [ {memory} ] diff --git a/agent_system/memory/__init__.py b/agent_system/memory/__init__.py index e8fef65f..22e1a787 100644 --- a/agent_system/memory/__init__.py +++ b/agent_system/memory/__init__.py @@ -1 +1 @@ -from .memory import SimpleMemory \ No newline at end of file +from .memory import SimpleMemory, SearchMemory \ No newline at end of file diff --git a/agent_system/memory/memory.py b/agent_system/memory/memory.py index 5c6896c4..eaef476c 100644 --- a/agent_system/memory/memory.py +++ b/agent_system/memory/memory.py @@ -82,4 +82,89 @@ def fetch( memory_contexts.append("\n".join(lines)) valid_lengths.append(valid_len) + return memory_contexts, valid_lengths + + +class SearchMemory(BaseMemory): + """ + Memory manager: responsible for storing & fetching per‑environment history records. + """ + def __init__(self): + self._data = None + self.keys = None + self.batch_size = 0 + + def __len__(self): + return len(self._data) + + def __getitem__(self, idx): + return self._data[idx] + + def reset(self, batch_size: int): + if self._data is not None: + self._data.clear() + self._data = [[] for _ in range(batch_size)] + self.batch_size = batch_size + self.keys = None + + def store(self, record: Dict[str, List[Any]]): + """ + Store a new record (one step of history) for each environment instance. 
+ + Args: + record (Dict[str, List[Any]]): + A dictionary where each key corresponds to a type of data + (e.g., 'text_obs', 'action'), and each value is a list of + length `batch_size`, containing the data for each environment. + """ + if self.keys is None: + self.keys = list(record.keys()) + assert self.keys == list(record.keys()) + + for env_idx in range(self.batch_size): + self._data[env_idx].append({k: record[k][env_idx] for k in self.keys}) + + def fetch( + self, + history_length: int, + obs_key: str, + action_key: str, + ) -> Tuple[List[str], List[int]]: + """ + Fetch and format recent interaction history for each environment instance. + Args: + history_length (int): + Maximum number of past steps to retrieve per environment. + obs_key (str): + The key name used to access the observation in stored records. + For example: "text_obs" or "Observation", depending on the environment. + action_key (str): + The key name used to access the action in stored records. + For example: "action" or "Action". + Returns: + memory_contexts : List[str] + A list of formatted action history strings for each environment. + valid_lengths : List[int] + A list of the actual number of valid history steps per environment. 
+ """ + memory_contexts, valid_lengths = [], [] + + for env_idx in range(self.batch_size): + recent = self._data[env_idx][-history_length:] + valid_len = len(recent) + start_idx = len(self._data[env_idx]) - valid_len + + lines = [] + for j, rec in enumerate(recent): + step_num = start_idx + j + 1 + act = rec[action_key] + obs = rec[obs_key] + lines.append( + f"Step {step_num}:{act} {obs}\n" + # f"{act}\n{obs}\n" + ) + + memory_contexts.append("\n".join(lines)) + valid_lengths.append(valid_len) + return memory_contexts, valid_lengths \ No newline at end of file diff --git a/agent_system/multi_turn_rollout/rollout_loop.py b/agent_system/multi_turn_rollout/rollout_loop.py index 3c721830..d4c3cb6a 100644 --- a/agent_system/multi_turn_rollout/rollout_loop.py +++ b/agent_system/multi_turn_rollout/rollout_loop.py @@ -7,6 +7,7 @@ from agent_system.multi_turn_rollout.utils import to_list_of_dict, torch_to_numpy, filter_group_data, preprocess_batch from agent_system.environments import EnvironmentManagerBase from typing import List, Dict, Any, Optional +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto class TrajectoryCollector: def __init__(self, config, tokenizer: PreTrainedTokenizer, processor=None): @@ -105,18 +106,12 @@ def vanilla_multi_turn_loop( success (Dict[str, np.ndarray]): Success samples for each environment traj_uid (np.ndarray): Trajectory unique identifiers """ - # Initial observations from the environment - obs, infos = envs.reset() - - # Initialize trajectory collection - lenght_obs = len(obs['text']) if obs['text'] is not None else len(obs['image']) - if len(gen_batch.batch) != lenght_obs and self.config.env.rollout.n > 0: - gen_batch = gen_batch.repeat(repeat_times=self.config.env.rollout.n, interleave=True) - assert len(gen_batch.batch) == lenght_obs, f"gen_batch size {len(gen_batch.batch)} does not match obs size {lenght_obs}" - batch_size = len(gen_batch.batch['input_ids']) - batch_output = None + batch_size = len(gen_batch.batch) + 
# Initial observations from the environment + obs, infos = envs.reset(kwargs=gen_batch.non_tensor_batch.get('tools_kwargs', None)) + if self.config.env.rollout.n > 0: # env grouping uid_batch = [] for i in range(batch_size): @@ -159,7 +154,11 @@ def vanilla_multi_turn_loop( batch_input.meta_info = gen_batch.meta_info - batch_output = actor_rollout_wg.generate_sequences(batch_input) + # pad to be divisible by dp_size + batch_input_padded, pad_size = pad_dataproto_to_divisor(batch_input, actor_rollout_wg.world_size) + batch_output_padded = actor_rollout_wg.generate_sequences(batch_input_padded) + # # unpad + batch_output = unpad_dataproto(batch_output_padded, pad_size=pad_size) batch.non_tensor_batch['uid'] = uid_batch batch.non_tensor_batch['traj_uid'] = traj_uid @@ -299,7 +298,9 @@ def multi_turn_loop( Returns: DataProto: Final collected trajectory data with metadata. """ - # Initial observations from the environment + if is_train: + gen_batch = gen_batch.repeat(repeat_times=self.config.env.rollout.n, interleave=True) + if self.config.algorithm.filter_groups.enable and is_train: # Dynamic Sampling (for DAPO and Dynamic GiGPO) total_batch_list, total_episode_rewards, total_episode_lengths, total_success, total_traj_uid = \ @@ -373,22 +374,17 @@ def __init__( raise ValueError(f"Unknown executor_type '{executor_type}'.") # ------------------------------------------------------------------ - def vanilla_multi_turn_loop( # noqa: D401 – doc in base class + def vanilla_multi_turn_loop( self, gen_batch: DataProto, actor_rollout_wg, envs: EnvironmentManagerBase, ): - obs, infos = envs.reset() - # Initialize trajectory collection - lenght_obs = len(obs['text']) if obs['text'] is not None else len(obs['image']) - if len(gen_batch.batch) != lenght_obs and self.config.env.rollout.n > 0: - gen_batch = gen_batch.repeat(repeat_times=self.config.env.rollout.n, interleave=True) - assert len(gen_batch.batch) == lenght_obs, f"gen_batch size {len(gen_batch.batch)} does not match obs 
size {lenght_obs}" + batch_size = len(gen_batch.batch) - batch_size = len(gen_batch.batch['input_ids']) - batch_output = None + obs, infos = envs.reset(kwargs=gen_batch.non_tensor_batch.get('tools_kwargs', None)) + self.multiagent_executor.reset() if self.config.env.rollout.n > 0: # env grouping uid_batch = [] diff --git a/agent_system/multi_turn_rollout/utils.py b/agent_system/multi_turn_rollout/utils.py index 3e154508..b836f93d 100644 --- a/agent_system/multi_turn_rollout/utils.py +++ b/agent_system/multi_turn_rollout/utils.py @@ -68,9 +68,9 @@ def preprocess_fn( dict: Contains processed input data such as input_ids, attention_mask, etc. """ - # raw_prompt = gen_batch.non_tensor_batch['raw_prompt'][item] + raw_prompt = gen_batch.non_tensor_batch['raw_prompt'][item] data_source = gen_batch.non_tensor_batch['data_source'][item] - + # Get observation components obs_texts = obs.get('text', None) obs_images = obs.get('image', None) @@ -88,18 +88,23 @@ def preprocess_fn( # obs_content = obs_content.replace('', '') # Build chat structure - obs_content = '' - if obs_text is not None: - obs_content += obs_text - else: - print(f"Warning: No text observation found!") - - - chat = np.array([{ - "content": obs_content, - "role": "user", - }]) - + system_prompt = "You are a helpful and harmless assistant." 
+ for message in raw_prompt: + if message['role'] == 'system': + system_prompt = message['content'] + # if message['role'] == 'user': + # context_from_dataset = message['content'] + + # if len(context_from_dataset) > 0 and obs_text is not None: + # obs_content = obs_text.replace('{placeholder_of_dataset_context}', context_from_dataset) + # else: + # print(f"Warning: No text observation found!") + + obs_content = obs_text + chat = np.array([ + {"content": system_prompt, "role": "system"}, + {"content": obs_content, "role": "user",} + ]) # Apply chat template prompt_with_chat_template = tokenizer.apply_chat_template( chat, @@ -141,7 +146,7 @@ def preprocess_fn( max_length=config.data.max_prompt_length, pad_token_id=tokenizer.pad_token_id, left_pad=True, - truncation='error') + truncation=config.data.truncation,) @@ -156,12 +161,26 @@ def preprocess_fn( else: position_ids = compute_position_id_with_mask(attention_mask) + + raw_prompt_ids = tokenizer.encode(raw_prompt, add_special_tokens=False) + if len(raw_prompt_ids) > config.data.max_prompt_length: + if config.data.truncation == "left": + raw_prompt_ids = raw_prompt_ids[-config.data.max_prompt_length :] + elif config.data.truncation == "right": + raw_prompt_ids = raw_prompt_ids[: config.data.max_prompt_length] + elif config.data.truncation == "middle": + left_half = config.data.max_prompt_length // 2 + right_half = config.data.max_prompt_length - left_half + raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:] + elif config.data.truncation == "error": + raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {config.data.max_prompt_length}.") + # Build final output dict row_dict.update({ 'input_ids': input_ids[0], 'attention_mask': attention_mask[0], 'position_ids': position_ids[0], - 'raw_prompt_ids': tokenizer.encode(raw_prompt, add_special_tokens=False), + 'raw_prompt_ids': raw_prompt_ids, 'anchor_obs': _obs_anchor, 'index': item, 'data_source': data_source @@ -247,41 +266,54 
@@ def process_image(image, max_pixels: int = 2048 * 2048, min_pixels: int = 256 * def adjust_batch(config, data: DataProto, mode="copy") -> DataProto: - size_divisor_ref = config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu * config.trainer.n_gpus_per_node + use_adaptive_bs = config.actor_rollout_ref.actor.use_adaptive_ppo_mini_batch_size + ppo_mini_update_num = config.actor_rollout_ref.actor.ppo_mini_update_num + + size_divisor_ref = config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu * config.trainer.n_gpus_per_node size_divisor_rollout = config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu * config.trainer.n_gpus_per_node - size_divisor_actor = config.actor_rollout_ref.actor.ppo_mini_batch_size + if use_adaptive_bs: + size_divisor_actor = config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu * config.trainer.n_gpus_per_node * ppo_mini_update_num + else: + size_divisor_actor = config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu * config.trainer.n_gpus_per_node + size_divisor = np.lcm.reduce(np.array([size_divisor_ref, size_divisor_rollout, size_divisor_actor])).item() # check if the batch size is divisible by the dp size, if not, delete the last few samples to make it divisible bs = len(data) remainder = bs % size_divisor if remainder == 0: - return data - - if mode == "delete": - # Generate indices to remove, rather than indices to keep - remove_indices = np.random.choice(bs, remainder, replace=False) - # Sort remove_indices to maintain stability when deleting - remove_indices = np.sort(remove_indices) - - # Create a boolean mask for elements to keep - keep_mask = np.ones(bs, dtype=bool) - keep_mask[remove_indices] = False - - keep_mask_tensor = torch.tensor(keep_mask, dtype=torch.bool, device=data.batch['input_ids'].device) - # Apply the mask to keep elements in their original order - tensor_data = data.batch[keep_mask_tensor] - non_tensor_data = {key: val[keep_mask] for key, val in 
data.non_tensor_batch.items()} - adjusted_batch = DataProto(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=data.meta_info) - del data - elif mode == "copy": - to_add = size_divisor - remainder - dup_indices = np.random.choice(bs, to_add, replace=False) - dup_proto = data.select_idxs(dup_indices) - - adjusted_batch = DataProto.concat([data, dup_proto]) + adjusted_batch = data else: - raise ValueError(f"Unsupported mode: {mode}") + if mode == "delete": + # Generate indices to remove, rather than indices to keep + remove_indices = np.random.choice(bs, remainder, replace=False) + # Sort remove_indices to maintain stability when deleting + remove_indices = np.sort(remove_indices) + + # Create a boolean mask for elements to keep + keep_mask = np.ones(bs, dtype=bool) + keep_mask[remove_indices] = False + + keep_mask_tensor = torch.tensor(keep_mask, dtype=torch.bool, device=data.batch['input_ids'].device) + # Apply the mask to keep elements in their original order + tensor_data = data.batch[keep_mask_tensor] + non_tensor_data = {key: val[keep_mask] for key, val in data.non_tensor_batch.items()} + adjusted_batch = DataProto(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=data.meta_info) + del data + elif mode == "copy": + to_add = size_divisor - remainder + dup_indices = np.random.choice(bs, to_add, replace=False) + dup_proto = data.select_idxs(dup_indices) + + adjusted_batch = DataProto.concat([data, dup_proto]) + else: + raise ValueError(f"Unsupported mode: {mode}") + + if use_adaptive_bs: + adjusted_bs = len(adjusted_batch) + assert adjusted_bs % ppo_mini_update_num == 0, f"Adjusted batch size {adjusted_bs} is not divisible by update_num {ppo_mini_update_num}." + adjusted_batch.meta_info["ppo_mini_batch_size"] = (adjusted_bs // ppo_mini_update_num) + assert adjusted_batch.meta_info["ppo_mini_batch_size"] > 0, "ppo_mini_batch_size must be greater than 0." 
return adjusted_batch diff --git a/examples/data_preprocess/preprocess_search_r1_dataset.py b/examples/data_preprocess/preprocess_search_r1_dataset.py index a602d020..269b53d7 100644 --- a/examples/data_preprocess/preprocess_search_r1_dataset.py +++ b/examples/data_preprocess/preprocess_search_r1_dataset.py @@ -30,15 +30,18 @@ # Configuration constants DEFAULT_SYSTEM_CONTENT = "You are a helpful and harmless assistant." +# DEFAULT_USER_CONTENT_PREFIX = ( +# "Answer the given question. You must conduct reasoning inside and " +# "first every time you get new information. After reasoning, if you find you lack " +# "some knowledge, you can call a search engine by query " +# "and it will return the top searched results between and " +# ". You can search as many times as your want. If you find no " +# "further external knowledge needed, you can directly provide the answer inside " +# " and , without detailed illustrations. For example, " +# " Beijing . Question: " +# ) DEFAULT_USER_CONTENT_PREFIX = ( - "Answer the given question. You must conduct reasoning inside and " - "first every time you get new information. After reasoning, if you find you lack " - "some knowledge, you can call a search engine by query " - "and it will return the top searched results between and " - ". You can search as many times as your want. If you find no " - "further external knowledge needed, you can directly provide the answer inside " - " and , without detailed illustrations. For example, " - " Beijing . 
Question: " + "" ) @@ -68,10 +71,14 @@ def process_single_row(row, current_split_name, row_index): ground_truth = row.get("golden_answers", []) # Process data source - data_source_tagged = "searchR1_" + str(row.get("data_source", "")) + data_source_tagged = str(row.get("data_source", "")) # Build tools kwargs structure - tools_kwargs = {"search": {"create_kwargs": {"ground_truth": ground_truth, "question": question, "data_source": data_source_tagged}}} + tools_kwargs = { + "search": { + "create_kwargs": {"ground_truth": ground_truth, "question": question, "data_source": data_source_tagged} + } + } # Build complete extra_info structure extra_info = { @@ -155,8 +162,14 @@ def apply_process_row(row, split_name=split): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Download Search-R1 from HuggingFace, process, and save to Parquet.") - parser.add_argument("--hf_repo_id", default="PeterJinGo/nq_hotpotqa_train", help="HuggingFace dataset repository ID.") - parser.add_argument("--local_dir", default="~/data/searchR1_processed_direct", help="Local directory to save the processed Parquet files.") + parser.add_argument( + "--hf_repo_id", default="PeterJinGo/nq_hotpotqa_train", help="HuggingFace dataset repository ID." 
+ ) + parser.add_argument( + "--local_dir", + default="~/data/searchR1_processed_direct", + help="Local directory to save the processed Parquet files.", + ) parser.add_argument("--hdfs_dir", default=None, help="Optional HDFS directory to copy the Parquet files to.") args = parser.parse_args() @@ -165,4 +178,4 @@ def apply_process_row(row, split_name=split): system_content = DEFAULT_SYSTEM_CONTENT user_content_prefix = DEFAULT_USER_CONTENT_PREFIX - main() + main() \ No newline at end of file diff --git a/examples/gigpo_trainer/run_search.sh b/examples/gigpo_trainer/run_search.sh new file mode 100644 index 00000000..9c294b1d --- /dev/null +++ b/examples/gigpo_trainer/run_search.sh @@ -0,0 +1,69 @@ +set -x + +ENGINE=${1:-vllm} +export CUDA_VISIBLE_DEVICES=4,5 + +train_data_size=256 +val_data_size=512 +group_size=8 +mode="mean_std_norm" # "mean_norm" or "mean_std_norm" + +TRAIN_DATA="$HOME/data/searchR1_processed_direct/train.parquet" +VAL_DATA="$HOME/data/searchR1_processed_direct/test.parquet" + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=gigpo \ + data.train_files=$TRAIN_DATA \ + data.val_files=$VAL_DATA \ + data.train_batch_size=$train_data_size \ + data.val_batch_size=$val_data_size \ + data.max_prompt_length=4096 \ + data.max_response_length=500 \ + data.filter_overlong_prompts=True \ + data.truncation='left' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-3B \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=512 \ + actor_rollout_ref.actor.use_adaptive_ppo_mini_batch_size=True \ + actor_rollout_ref.actor.ppo_mini_update_num=8 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + 
actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=False \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.use_invalid_action_penalty=True \ + actor_rollout_ref.actor.invalid_action_penalty_coef=0.01 \ + algorithm.use_kl_in_reward=False \ + algorithm.gamma=0.95 \ + algorithm.gigpo.step_advantage_w=1.0 \ + algorithm.gigpo.sim_thresh=0.8 \ + algorithm.gigpo.mode=$mode \ + env.env_name=search \ + env.seed=0 \ + env.max_steps=4 \ + env.rollout.n=$group_size \ + env.history_length=4 \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + trainer.project_name='verl_agent_search' \ + trainer.experiment_name='gigpo_qwen2.5_3b_step4_sim0.8_warm0.1_upnum8' \ + trainer.n_gpus_per_node=2 \ + trainer.nnodes=1 \ + trainer.save_freq=50 \ + trainer.test_freq=150 \ + trainer.total_epochs=1 \ + trainer.val_before_train=False $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_search.sh b/examples/grpo_trainer/run_search.sh new file mode 100644 index 00000000..90672aa7 --- /dev/null +++ b/examples/grpo_trainer/run_search.sh @@ -0,0 +1,64 @@ +set -x +ENGINE=${1:-vllm} + +export CUDA_VISIBLE_DEVICES=2,3 + +train_data_size=256 +val_data_size=512 +group_size=8 + +TRAIN_DATA="$HOME/data/searchR1_processed_direct/train.parquet" +VAL_DATA="$HOME/data/searchR1_processed_direct/test.parquet" + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo 
\ + data.train_files=$TRAIN_DATA \ + data.val_files=$VAL_DATA \ + data.train_batch_size=$train_data_size \ + data.val_batch_size=$val_data_size \ + data.max_prompt_length=4096 \ + data.max_response_length=500 \ + data.filter_overlong_prompts=True \ + data.truncation='left' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-3B \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=512 \ + actor_rollout_ref.actor.use_adaptive_ppo_mini_batch_size=True \ + actor_rollout_ref.actor.ppo_mini_update_num=8 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=False \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.use_invalid_action_penalty=True \ + actor_rollout_ref.actor.invalid_action_penalty_coef=0.01 \ + algorithm.use_kl_in_reward=False \ + env.env_name=search \ + env.seed=0 \ + env.max_steps=4 \ + env.rollout.n=$group_size \ + env.history_length=4 \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + trainer.project_name='verl_agent_search' \ + 
trainer.experiment_name='grpo_qwen2.5_3b_step4_warm0.1_upnum8' \ + trainer.n_gpus_per_node=2 \ + trainer.nnodes=1 \ + trainer.save_freq=50 \ + trainer.test_freq=150 \ + trainer.total_epochs=1 \ + trainer.val_before_train=False $@ \ No newline at end of file diff --git a/examples/multi_agent_trainer/run_webshop.sh b/examples/multi_agent_trainer/run_webshop.sh index bd87be13..1e7b7c75 100644 --- a/examples/multi_agent_trainer/run_webshop.sh +++ b/examples/multi_agent_trainer/run_webshop.sh @@ -2,18 +2,25 @@ set -x ENGINE=${1:-vllm} export CUDA_VISIBLE_DEVICES=4,5 +multi_agent=True +agent_list='["Reflexion Agent","Action Agent","Memory Agent"]' + train_data_size=16 val_data_size=128 group_size=8 -multi_agent=True -agent_list='["Action Agent","Memory Agent"]' +algorithm=grpo +gigpo_mode=mean_std_norm # "mean_norm" or "mean_std_norm" +model=Qwen/Qwen3-4B-Instruct-2507 + +if [ "$multi_agent" = "True" ]; then + agent_name_tag=$(echo "$agent_list" | jq -r '.[]' | sed 's/ Agent//g' | paste -sd+ -) +else + agent_name_tag="Single" +fi -algorithm=gigpo -mode=mean_std_norm # "mean_norm" or "mean_std_norm" -model=Qwen/Qwen2.5-3B +experiment_name="${algorithm}_$(basename $model)_${group_size}group_${agent_name_tag}" -experiment_name="${algorithm}_$(basename $model)_${group_size}group_${mode}_ma_${multi_agent}_AM" # We only use data preparation to indicate the modality and the data size. 
python3 -m examples.data_preprocess.prepare \ @@ -30,15 +37,16 @@ python3 -m verl.trainer.main_ppo \ data.max_prompt_length=5120 \ data.max_response_length=512 \ data.filter_overlong_prompts=True \ - data.truncation='error' \ + data.truncation='left' \ data.return_raw_chat=True \ actor_rollout_ref.model.path=$model \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=192 \ + actor_rollout_ref.actor.use_adaptive_ppo_mini_batch_size=True \ + actor_rollout_ref.actor.ppo_mini_update_num=10 \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ actor_rollout_ref.actor.kl_loss_type=low_var_kl \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ @@ -59,13 +67,14 @@ python3 -m verl.trainer.main_ppo \ algorithm.use_kl_in_reward=False \ algorithm.gamma=0.985 \ algorithm.gigpo.step_advantage_w=1.0 \ - algorithm.gigpo.mode=$mode \ + algorithm.gigpo.mode=$gigpo_mode \ env.env_name=Webshop \ env.seed=0 \ env.max_steps=15 \ env.rollout.n=$group_size \ agent.agent_list="$agent_list" \ agent.multi_agent=$multi_agent \ + agent.use_agent_memory=True \ trainer.critic_warmup=0 \ trainer.logger=['console','wandb'] \ trainer.project_name='verl_multiagent_webshop' \ diff --git a/examples/search/README.md b/examples/search/README.md new file mode 100644 index 00000000..17cd5f68 --- /dev/null +++ b/examples/search/README.md @@ -0,0 +1,61 @@ +# SearchR1 Replication Setup Instructions + +We provide scripts to reproduce our results for training a multi-turn search agent using the dataset and recipe from [SearchR1](https://raw.githubusercontent.com/PeterGriffinJin/Search-R1/refs/heads/main/docs/retriever.md). 
+ +Additional Reference: [Verl+Sglang Instructions](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/tool_examples/verl-multiturn-searchR1-like.md). + +## Prepare Datasets +```bash +local_dir=~/data/searchR1 +uv run --isolated examples/search/searchr1_dataset.py --local_dir $local_dir +``` + +# Start the Search Engine +Since faiss-gpu is not available via pip, we set up a separate conda environment for the local retrieval server. Running this server will use around 6GB of GPU memory per GPU, so make sure to account for this in your training run configuration. + +## Retriever environments +```bash +# Create and activate the retriever environment with Python 3.10 +conda create -n retriever python=3.10 -y +conda activate retriever + +# Install PyTorch (with GPU support) and related libraries +conda install numpy==1.26.4 # needed to stop incompatible version of numpy from being installed via pip +pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124 + +# Install other Python packages +pip install transformers datasets pyserini huggingface_hub + +# Install the GPU version of faiss +conda install faiss-gpu==1.8.0 -c pytorch -c nvidia -y + +# Install the API service framework +pip install uvicorn fastapi +``` + +## Download the Index +```bash +conda activate retriever + +local_dir=~/data/searchR1 +python examples/search/searchr1_download.py --local_dir $local_dir +cat $local_dir/part_* > $local_dir/e5_Flat.index +gzip -d $local_dir/wiki-18.jsonl.gz +``` + +## Start the Local Flat e5 Retrieval Server +```bash +conda activate retriever + +# redirect the output to a file to avoid cluttering the terminal +# we have observed outputting to the terminal causing spikes in server response times +bash examples/search/retriever/retrieval_launch.sh > retrieval_server.log +``` + +## Launch your Training Job +Now from your base environment, you can launch your training run (which will use 
uv to package dependencies, separately from the retriever environment). + +```bash + export WANDB_API_KEY=your_wandb_api_key + bash examples/search/run_search.sh +``` diff --git a/examples/search/retriever/retrieval_launch.sh b/examples/search/retriever/retrieval_launch.sh new file mode 100644 index 00000000..d556845a --- /dev/null +++ b/examples/search/retriever/retrieval_launch.sh @@ -0,0 +1,15 @@ +save_path=$HOME/data/searchR1 + +index_file=$save_path/e5_Flat.index +corpus_file=$save_path/wiki-18.jsonl +retriever_name=e5 +retriever_path=intfloat/e5-base-v2 + +python examples/search/retriever/retrieval_server.py \ + --index_path $index_file \ + --corpus_path $corpus_file \ + --topk 3 \ + --retriever_name $retriever_name \ + --retriever_model $retriever_path \ + --faiss_gpu \ + --port 8000 \ \ No newline at end of file diff --git a/examples/search/retriever/retrieval_server.py b/examples/search/retriever/retrieval_server.py new file mode 100644 index 00000000..e280f45f --- /dev/null +++ b/examples/search/retriever/retrieval_server.py @@ -0,0 +1,389 @@ +import json +import warnings +from typing import List, Optional +import argparse + +import faiss +import torch +import numpy as np +from transformers import AutoTokenizer, AutoModel +import datasets + +import uvicorn +from fastapi import FastAPI +from pydantic import BaseModel + + +def load_corpus(corpus_path: str): + corpus = datasets.load_dataset("json", data_files=corpus_path, split="train", num_proc=4) + return corpus + + +def read_jsonl(file_path): + data = [] + with open(file_path, "r") as f: + for line in f: + data.append(json.loads(line)) + return data + + +def load_docs(corpus, doc_idxs): + results = [corpus[int(idx)] for idx in doc_idxs] + return results + + +def load_model(model_path: str, use_fp16: bool = False): + # model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + model = AutoModel.from_pretrained(model_path, trust_remote_code=True) + model.eval() + model.cuda() + if 
use_fp16: + model = model.half() + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True) + return model, tokenizer + + +def pooling(pooler_output, last_hidden_state, attention_mask=None, pooling_method="mean"): + if pooling_method == "mean": + last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0) + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + elif pooling_method == "cls": + return last_hidden_state[:, 0] + elif pooling_method == "pooler": + return pooler_output + else: + raise NotImplementedError("Pooling method not implemented!") + + +class Encoder: + def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16): + self.model_name = model_name + self.model_path = model_path + self.pooling_method = pooling_method + self.max_length = max_length + self.use_fp16 = use_fp16 + + self.model, self.tokenizer = load_model(model_path=model_path, use_fp16=use_fp16) + self.model.eval() + + @torch.no_grad() + def encode(self, query_list: List[str], is_query=True) -> np.ndarray: + # processing query for different encoders + if isinstance(query_list, str): + query_list = [query_list] + + if "e5" in self.model_name.lower(): + if is_query: + query_list = [f"query: {query}" for query in query_list] + else: + query_list = [f"passage: {query}" for query in query_list] + + if "bge" in self.model_name.lower(): + if is_query: + query_list = [ + f"Represent this sentence for searching relevant passages: {query}" for query in query_list + ] + + inputs = self.tokenizer( + query_list, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt" + ) + inputs = {k: v.cuda() for k, v in inputs.items()} + + if "T5" in type(self.model).__name__: + # T5-based retrieval model + decoder_input_ids = torch.zeros((inputs["input_ids"].shape[0], 1), dtype=torch.long).to( + inputs["input_ids"].device + ) + output = self.model(**inputs, decoder_input_ids=decoder_input_ids, 
return_dict=True) + query_emb = output.last_hidden_state[:, 0, :] + else: + output = self.model(**inputs, return_dict=True) + query_emb = pooling( + output.pooler_output, output.last_hidden_state, inputs["attention_mask"], self.pooling_method + ) + if "dpr" not in self.model_name.lower(): + query_emb = torch.nn.functional.normalize(query_emb, dim=-1) + + query_emb = query_emb.detach().cpu().numpy() + query_emb = query_emb.astype(np.float32, order="C") + + return query_emb + + +class BaseRetriever: + def __init__(self, config): + self.config = config + self.retrieval_method = config.retrieval_method + self.topk = config.retrieval_topk + + self.index_path = config.index_path + self.corpus_path = config.corpus_path + + def _search(self, query: str, num: int, return_score: bool): + raise NotImplementedError + + def _batch_search(self, query_list: List[str], num: int, return_score: bool): + raise NotImplementedError + + def search(self, query: str, num: int = None, return_score: bool = False): + return self._search(query, num, return_score) + + def batch_search(self, query_list: List[str], num: int = None, return_score: bool = False): + return self._batch_search(query_list, num, return_score) + + +class BM25Retriever(BaseRetriever): + def __init__(self, config): + super().__init__(config) + from pyserini.search.lucene import LuceneSearcher + + self.searcher = LuceneSearcher(self.index_path) + self.contain_doc = self._check_contain_doc() + if not self.contain_doc: + self.corpus = load_corpus(self.corpus_path) + self.max_process_num = 8 + + def _check_contain_doc(self): + return self.searcher.doc(0).raw() is not None + + def _search(self, query: str, num: int = None, return_score: bool = False): + if num is None: + num = self.topk + hits = self.searcher.search(query, num) + if len(hits) < 1: + if return_score: + return [], [] + else: + return [] + scores = [hit.score for hit in hits] + if len(hits) < num: + warnings.warn("Not enough documents retrieved!") + else: + hits = 
hits[:num] + + if self.contain_doc: + all_contents = [json.loads(self.searcher.doc(hit.docid).raw())["contents"] for hit in hits] + results = [ + { + "title": content.split("\n")[0].strip('"'), + "text": "\n".join(content.split("\n")[1:]), + "contents": content, + } + for content in all_contents + ] + else: + results = load_docs(self.corpus, [hit.docid for hit in hits]) + + if return_score: + return results, scores + else: + return results + + def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False): + results = [] + scores = [] + for query in query_list: + item_result, item_score = self._search(query, num, True) + results.append(item_result) + scores.append(item_score) + if return_score: + return results, scores + else: + return results + + +class DenseRetriever(BaseRetriever): + def __init__(self, config): + super().__init__(config) + self.index = faiss.read_index(self.index_path) + if config.faiss_gpu: + co = faiss.GpuMultipleClonerOptions() + co.useFloat16 = True + co.shard = True + self.index = faiss.index_cpu_to_all_gpus(self.index, co=co) + + self.corpus = load_corpus(self.corpus_path) + self.encoder = Encoder( + model_name=self.retrieval_method, + model_path=config.retrieval_model_path, + pooling_method=config.retrieval_pooling_method, + max_length=config.retrieval_query_max_length, + use_fp16=config.retrieval_use_fp16, + ) + self.topk = config.retrieval_topk + self.batch_size = config.retrieval_batch_size + + def _search(self, query: str, num: int = None, return_score: bool = False): + if num is None: + num = self.topk + query_emb = self.encoder.encode(query) + scores, idxs = self.index.search(query_emb, k=num) + idxs = idxs[0] + scores = scores[0] + results = load_docs(self.corpus, idxs) + if return_score: + return results, scores + else: + return results + + def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False): + if isinstance(query_list, str): + query_list = [query_list] + if num 
is None: + num = self.topk + + results = [] + scores = [] + for start_idx in range(0, len(query_list), self.batch_size): + query_batch = query_list[start_idx : start_idx + self.batch_size] + batch_emb = self.encoder.encode(query_batch) + batch_scores, batch_idxs = self.index.search(batch_emb, k=num) + + batch_scores = batch_scores.tolist() + batch_idxs = batch_idxs.tolist() + # load_docs is not vectorized, but is a python list approach + flat_idxs = sum(batch_idxs, []) + batch_results = load_docs(self.corpus, flat_idxs) + # chunk them back + batch_results = [batch_results[i * num : (i + 1) * num] for i in range(len(batch_idxs))] + results.extend(batch_results) + scores.extend(batch_scores) + if return_score: + return results, scores + else: + return results + + +def get_retriever(config): + if config.retrieval_method == "bm25": + return BM25Retriever(config) + else: + return DenseRetriever(config) + + +##################################### +# FastAPI server below +##################################### + + +class Config: + """ + Minimal config class (simulating your argparse) + Replace this with your real arguments or load them dynamically. 
+ """ + + def __init__( + self, + retrieval_method: str = "bm25", + retrieval_topk: int = 10, + index_path: str = "./index/bm25", + corpus_path: str = "./data/corpus.jsonl", + dataset_path: str = "./data", + data_split: str = "train", + faiss_gpu: bool = True, + retrieval_model_path: str = "./model", + retrieval_pooling_method: str = "mean", + retrieval_query_max_length: int = 256, + retrieval_use_fp16: bool = False, + retrieval_batch_size: int = 128, + ): + self.retrieval_method = retrieval_method + self.retrieval_topk = retrieval_topk + self.index_path = index_path + self.corpus_path = corpus_path + self.dataset_path = dataset_path + self.data_split = data_split + self.faiss_gpu = faiss_gpu + self.retrieval_model_path = retrieval_model_path + self.retrieval_pooling_method = retrieval_pooling_method + self.retrieval_query_max_length = retrieval_query_max_length + self.retrieval_use_fp16 = retrieval_use_fp16 + self.retrieval_batch_size = retrieval_batch_size + + +class QueryRequest(BaseModel): + query: str + topk: Optional[int] = None + return_scores: bool = False + + +app = FastAPI() + + +@app.post("/retrieve") +def retrieve_endpoint(request: QueryRequest): + """ + Endpoint that accepts a single query and performs retrieval. 
+ Input format: + { + "query": "What is Python?", + "topk": 3, + "return_scores": true + } + """ + if not request.topk: + request.topk = config.retrieval_topk # fallback to default + + # Perform retrieval + if request.return_scores: + results, scores = retriever.search(query=request.query, num=request.topk, return_score=True) + else: + results = retriever.search(query=request.query, num=request.topk, return_score=False) + scores = None + + # Format response + resp = [] + if request.return_scores and scores is not None: + # If scores are returned, combine them with results + combined = [] + for doc, score in zip(results, scores): + # Convert numpy float32 to regular Python float for JSON serialization + combined.append({"document": doc, "score": float(score)}) + resp.append(combined) + else: + resp.append(results) + return {"result": resp} + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="Launch the local faiss retriever.") + parser.add_argument( + "--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file." + ) + parser.add_argument( + "--corpus_path", + type=str, + default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", + help="Local corpus file.", + ) + parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.") + parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.") + parser.add_argument( + "--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model." + ) + parser.add_argument("--faiss_gpu", action="store_true", help="Use GPU for computation") + parser.add_argument("--port", type=int, default=8000, help="Port to run the FastAPI server on.") + + args = parser.parse_args() + + # 1) Build a config (could also parse from arguments). + # In real usage, you'd parse your CLI arguments or environment variables. 
+ config = Config( + retrieval_method=args.retriever_name, # or "dense" + index_path=args.index_path, + corpus_path=args.corpus_path, + retrieval_topk=args.topk, + faiss_gpu=args.faiss_gpu, + retrieval_model_path=args.retriever_model, + retrieval_pooling_method="mean", + retrieval_query_max_length=256, + retrieval_use_fp16=True, + retrieval_batch_size=512, # this is unused in the current retrieval implementation, which only supports single query + ) + + # 2) Instantiate a global retriever so it is loaded once and reused. + retriever = get_retriever(config) + + # 3) Launch the server. By default, it listens on http://127.0.0.1:8000 + uvicorn.run(app, host="0.0.0.0", port=args.port) diff --git a/examples/search/searchr1_download.py b/examples/search/searchr1_download.py new file mode 100644 index 00000000..79dc4650 --- /dev/null +++ b/examples/search/searchr1_download.py @@ -0,0 +1,25 @@ +import argparse +from huggingface_hub import hf_hub_download + +parser = argparse.ArgumentParser(description="Download files from a Hugging Face dataset repository.") +parser.add_argument("--repo_id", type=str, default="PeterJinGo/wiki-18-e5-index", help="Hugging Face repository ID") +parser.add_argument("--local_dir", type=str, required=True, help="Local directory to save files") + +args = parser.parse_args() + +repo_id = "PeterJinGo/wiki-18-e5-index" +for file in ["part_aa", "part_ab"]: + hf_hub_download( + repo_id=repo_id, + filename=file, # e.g., "e5_Flat.index" + repo_type="dataset", + local_dir=args.local_dir, + ) + +repo_id = "PeterJinGo/wiki-18-corpus" +hf_hub_download( + repo_id=repo_id, + filename="wiki-18.jsonl.gz", + repo_type="dataset", + local_dir=args.local_dir, +) diff --git a/gigpo/core_gigpo.py b/gigpo/core_gigpo.py index 7c5bfd34..a6c9df2a 100644 --- a/gigpo/core_gigpo.py +++ b/gigpo/core_gigpo.py @@ -9,6 +9,9 @@ from verl import DataProto import uuid +from difflib import SequenceMatcher +from typing import Sequence, List, Dict, Any + # 
---------------------------------------------------------- # # --------------- General Functions of GiGPO --------------- # # ---------------------------------------------------------- # @@ -49,6 +52,9 @@ def summarize_group_size(group_size: list): if prop: print(f"{size:>4} | {cnt:>5} | {prop:>9.2%}") +def are_similar(a: str, b: str, threshold: float = 0.95) -> bool: + """Return True if similarity ratio ≥ threshold.""" + return SequenceMatcher(None, a, b).ratio() >= threshold def compute_step_discounted_returns(batch: DataProto, gamma: float): rewards = batch.non_tensor_batch['rewards'].astype(np.float32) @@ -99,7 +105,8 @@ def compute_gigpo_outcome_advantage(token_level_rewards: torch.Tensor, traj_index: np.array, epsilon: float = 1e-6, step_advantage_w: float = 1.0, - mode: str = "mean_norm" + mode: str = "mean_norm", + sim_thresh: float = 1.0, ): if mode == "mean_std_norm": @@ -113,7 +120,7 @@ def compute_gigpo_outcome_advantage(token_level_rewards: torch.Tensor, episode_advantages = episode_norm_reward(token_level_rewards, response_mask, index, traj_index, epsilon, remove_std) # Compute step_group_uids - step_group_uids = build_step_group(anchor_obs, index) + step_group_uids = build_step_group(anchor_obs, index, sim_thresh) # Compute step-level group reward step_advantages = step_norm_reward(step_rewards, response_mask, step_group_uids, epsilon, remove_std) @@ -191,7 +198,7 @@ def episode_norm_reward(token_level_rewards: torch.Tensor, return episode_advantages -def build_step_group(anchor_obs: np.array, index: np.array, summarize: bool = False): +def build_step_group(anchor_obs: np.array, index: np.array, sim_thresh: float = 1.0, summarize: bool = False): """ Group observations by index and then cluster identical observations within each index group. Assigns a unique step_group_uid (UUID) to each cluster. 
@@ -204,40 +211,71 @@ def build_step_group(anchor_obs: np.array, index: np.array, summarize: bool = Fa Array of episode_group_uid summarize : bool Whether to summarize the group sizes (default: True) + sim_thresh : float + Threshold for similarity to consider two observations as identical (default: 1.0, meaning exact match) Returns: -------- np.array Array of step_group_uid values corresponding to the original anchor_obs array """ + assert sim_thresh >= 0.0 and sim_thresh <= 1.0, "sim_thresh should be in [0, 1]" + # Initialize the result array with placeholder values step_group_uids = np.empty(len(anchor_obs), dtype=object) # Get unique indices unique_indices = np.unique(index) - group_size = [] + group_size: List[int] = [] # Process each unique index for idx in unique_indices: - # Get all observations for this index using np.where - indices = np.where(index == idx)[0] - obs_group = anchor_obs[indices] - - # Create clusters for identical observations - clusters = defaultdict(list) - for i, obs in enumerate(obs_group): - clusters[to_hashable(obs)].append(indices[i]) # Store the original index position - - # Assign unique step_group_uid to each cluster - for obs, original_indices in clusters.items(): - # Generate a UUID for this cluster - uid = str(uuid.uuid4()) + if sim_thresh == 1.0: + # Get all observations for this index using np.where + indices = np.where(index == idx)[0] + obs_group = anchor_obs[indices] - # Assign the same step_group_uid to all elements in this cluster - group_size.append(len(original_indices)) - for original_idx in original_indices: - step_group_uids[original_idx] = uid + # Create clusters for identical observations + clusters = defaultdict(list) + for i, obs in enumerate(obs_group): + clusters[to_hashable(obs)].append(indices[i]) # Store the original index position + + # Assign unique step_group_uid to each cluster + for obs, original_indices in clusters.items(): + # Generate a UUID for this cluster + uid = str(uuid.uuid4()) + + # Assign the 
same step_group_uid to all elements in this cluster + group_size.append(len(original_indices)) + for original_idx in original_indices: + step_group_uids[original_idx] = uid + elif sim_thresh < 1.0 and sim_thresh >= 0.0: + locs = np.where(index == idx)[0] + obs_group = anchor_obs[locs] + + # 动态维护簇:[{rep: str, locs: List[int]} ...] + clusters: List[Dict[str, Any]] = [] + + for obs, loc in zip(obs_group, locs): + # 尝试放入已有簇 + placed = False + for cluster in clusters: + if are_similar(obs, cluster["rep"], sim_thresh): + cluster["locs"].append(loc) + placed = True + break + # 若没有匹配簇,则建立新簇 + if not placed: + clusters.append({"rep": obs, "locs": [loc]}) + # 为每个簇分配 UUID + for cluster in clusters: + uid = str(uuid.uuid4()) + group_size.append(len(cluster["locs"])) + for loc in cluster["locs"]: + step_group_uids[loc] = uid + else: + raise ValueError(f"sim_thresh should be in [0, 1], but got {sim_thresh}") # Validate that all elements have been assigned a uid if None in step_group_uids or np.any(step_group_uids == None): missing_indices = np.where(step_group_uids == None)[0] diff --git a/requirements-npu.txt b/requirements-npu.txt index 601e8f9f..c978f8fc 100644 --- a/requirements-npu.txt +++ b/requirements-npu.txt @@ -11,7 +11,7 @@ pyarrow>=15.0.0 pybind11 pylatexenc ray -tensordict<=0.6.2 +tensordict<=0.10.0 transformers>=4.52.0 wandb mathruler diff --git a/requirements.txt b/requirements.txt index 3d2d80e2..e7045963 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ pybind11 pylatexenc pre-commit ray[default] -tensordict<=0.6.2 +tensordict<=0.10.0 torchdata transformers==4.51.1 # vllm==0.8.4 diff --git a/requirements_sglang.txt b/requirements_sglang.txt index 57d5e0be..9f0cd109 100644 --- a/requirements_sglang.txt +++ b/requirements_sglang.txt @@ -12,7 +12,7 @@ pyarrow>=19.0.0 pybind11 pylatexenc ray[default]>=2.10 -tensordict<=0.6.2 +tensordict<=0.10.0 torchdata torchvision transformers diff --git a/verl/trainer/config/ppo_trainer.yaml 
b/verl/trainer/config/ppo_trainer.yaml index 23e36d03..e5014442 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -41,6 +41,8 @@ actor_rollout_ref: trust_remote_code: False actor: strategy: fsdp # [fsdp, fsdp2], This is for backward-compatibility + use_adaptive_ppo_mini_batch_size: False + ppo_mini_update_num: 10 ppo_mini_batch_size: 256 ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu ppo_micro_batch_size_per_gpu: null @@ -247,6 +249,7 @@ algorithm: gigpo: step_advantage_w: 1.0 mode: "mean_norm" # "mean_norm" or "mean_std_norm" + sim_thresh: 1.0 # similarity threshold for gigpo [0.0, 1.0] filter_groups: # DAPO from https://arxiv.org/abs/2503.14476 enable: False max_num_gen_batches: 10 @@ -283,10 +286,11 @@ ray_init: num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. -agent: +agent: # only for multi-agent training multi_agent: False # True for multi-agent training, False for single-agent training agent_list: ["Reflexion Agent", "Action Agent", "Memory Agent"] # ["Reflexion Agent", "Planning Agent", "Action Agent", "Memory Agent"] executor_type: chain + use_agent_memory: False env: @@ -305,3 +309,8 @@ env: webshop: use_small: True human_goals: False + search: + log_requests: false + search_url: "http://127.0.0.1:8000/retrieve" # also support multiple urls: ["http://127.0.0.1:8000/retrieve", "http://127.0.0.1:8001/retrieve"] + topk: 3 + timeout: 60 diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index de04edcb..ab06436f 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -241,7 +241,7 @@ def compute_response_mask(data: DataProto): return attention_mask[:, -response_length:] -def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1, multi_turn=False, norm_adv_by_std_in_grpo=True, step_advantage_w=1.0, 
gigpo_mode="mean_std_norm", **kwargs): +def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1, multi_turn=False, norm_adv_by_std_in_grpo=True, step_advantage_w=1.0, gigpo_mode="mean_std_norm", sim_thresh=1.0, **kwargs): """Compute advantage estimates for policy optimization. This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc. @@ -352,6 +352,7 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re traj_index=data.non_tensor_batch['traj_uid'], step_advantage_w=step_advantage_w, mode=gigpo_mode, + sim_thresh=sim_thresh, ) data.batch['advantages'] = advantages data.batch['returns'] = returns @@ -1202,6 +1203,7 @@ def fit(self): pf_ppo_weight_pow=self.config.algorithm.pf_ppo.weight_pow, step_advantage_w=self.config.algorithm.gigpo.step_advantage_w, gigpo_mode=self.config.algorithm.gigpo.mode, + sim_thresh=self.config.algorithm.gigpo.sim_thresh, ) # update critic diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 81d701eb..3be6713e 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -328,6 +328,11 @@ def update_policy(self, data: DataProto): batch = data.select(batch_keys=select_keys).batch has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + if self.config.use_adaptive_ppo_mini_batch_size: + print(f"[update_policy] pre self.config.ppo_mini_batch_size: {self.config.ppo_mini_batch_size}") + self.config.ppo_mini_batch_size = data.meta_info.get("ppo_mini_batch_size", self.config.ppo_mini_batch_size) + print(f"[update_policy] post self.config.ppo_mini_batch_size: {self.config.ppo_mini_batch_size}") + # Split to make minibatch iterator for updating the actor # See PPO paper for details. https://arxiv.org/abs/1707.06347 if has_multi_modal_inputs: