diff --git a/agents/agents/agents/__init__.py b/agents/agents/agents/__init__.py index 75c339d..6152c72 100644 --- a/agents/agents/agents/__init__.py +++ b/agents/agents/agents/__init__.py @@ -1,5 +1,6 @@ from .react.react_agent import ReactAgent from .specialized.code_agent import CodeAgent from .specialized.think_agent import ThinkAgent +from .specialized.gui_agent import GUIAgent -__all__ = ["ReactAgent", "CodeAgent", "ThinkAgent"] \ No newline at end of file +__all__ = ["ReactAgent", "CodeAgent", "ThinkAgent", "GUIAgent"] \ No newline at end of file diff --git a/agents/agents/agents/agent_base.py b/agents/agents/agents/agent_base.py index 43b01f2..2a57ecd 100644 --- a/agents/agents/agents/agent_base.py +++ b/agents/agents/agents/agent_base.py @@ -157,6 +157,7 @@ def _init_llm_engine(self, model_name_or_path: str, backend: str): def set_llm_engine(self, llm_engine: Any, tokenizer: Any, processor: Any): assert self.backend == "async_verl", "Only async verl backend is supported for now" + self.llm_engine.llm_engine = llm_engine self.tokenizer = tokenizer self.processor = processor diff --git a/agents/agents/agents/auto.py b/agents/agents/agents/auto.py index 84050f4..c62e8f7 100644 --- a/agents/agents/agents/auto.py +++ b/agents/agents/agents/auto.py @@ -1,11 +1,12 @@ from typing import Any, Callable, Dict, List, Optional, Type, Union from .specialized.think_agent import ThinkAgent -from agents.agents.specialized.openai_agent import OpenAIAgent +from .specialized.openai_agent import OpenAIAgent from ..tools import get_tools_from_names from .agent_base import BaseAgent from .react.react_agent import ReactAgent from .specialized.code_agent import CodeAgent +from .specialized.gui_agent import GUIAgent from ..rewards.reward_base import get_reward_from_name @@ -165,4 +166,5 @@ def from_pretrained( AutoAgent.register_agent("react", ReactAgent) AutoAgent.register_agent("code", CodeAgent) AutoAgent.register_agent("openai", OpenAIAgent) -AutoAgent.register_agent("think", 
ThinkAgent) \ No newline at end of file +AutoAgent.register_agent("think", ThinkAgent) +AutoAgent.register_agent("gui", GUIAgent) \ No newline at end of file diff --git a/agents/agents/agents/chain/chain_base.py b/agents/agents/agents/chain/chain_base.py index 02fa27d..94e9e8a 100644 --- a/agents/agents/agents/chain/chain_base.py +++ b/agents/agents/agents/chain/chain_base.py @@ -312,6 +312,10 @@ async def _run_single_chain(self, ) thought_node.is_terminal = new_msg.get("status", "continue") in self.terminal_status current_node = thought_node + + # Check if the thought node is terminal - if so, break the loop + if current_node.is_terminal: + break # Handle tool calls if current_node.messages[-1].get("tool_calls"): diff --git a/agents/agents/agents/specialized/gui_agent.py b/agents/agents/agents/specialized/gui_agent.py new file mode 100644 index 0000000..e9dd812 --- /dev/null +++ b/agents/agents/agents/specialized/gui_agent.py @@ -0,0 +1,226 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 + +import json +from typing import List, Any, Tuple, Dict, Optional +from ..agent_base import BaseAgent +from agents.utils.ui_action_parser import parse_action_to_structure_output, IMAGE_FACTOR + +# Default image dimensions +TEST_IMAGE_HEIGHT = 1080 +TEST_IMAGE_WIDTH = 1920 + +# GUI Agent system prompt +GUI_AGENT_SYSTEM_PROMPT = """You are a GUI automation agent. Given a task and screenshot, you must analyze the screen and perform the next action. + +## Response Format (REQUIRED) +You MUST always respond with exactly two lines: +Thought: [Describe what you see and what action to take] +Action: [Choose ONE action from the list below] + +## Action Space + +click(start_box='<|box_start|>(x1,y1)<|box_end|>') +type(content='xxx') # Use escape characters \\', \\", and \\n in content part +scroll(direction='down or up or right or left') + +## Examples +Example 1: +Thought: I need to click on the search button at coordinates (100, 200). 
+Action: click(start_box='<|box_start|>(100,200)<|box_end|>') + +Example 2: +Thought: I need to type "hello world" in the text field. +Action: type(content='hello world') + +## Note +- Use English in `Thought` and `Action` part +- Always provide both Thought and Action lines +- Coordinates should be in pixel values""" + + +class GUIAgent(BaseAgent): + """GUI Agent for interacting with graphical user interfaces.""" + + def __init__( + self, + model_name_or_path: str, + template: str, + tools: List = None, + **kwargs + ): + """ + Initialize GUI Agent. + + Args: + model_name_or_path: Path to the vision-language model + template: Template name for formatting prompts + tools: List of tools available to the agent + **kwargs: Additional arguments + """ + super().__init__( + model_name_or_path=model_name_or_path, + template=template, + system_prompt=GUI_AGENT_SYSTEM_PROMPT, + tools=tools, + max_length=kwargs.get("max_length", 8192), + **kwargs + ) + self.action_counter = 0 # Track number of actions taken + self.max_retries = 3 # Maximum retries for empty responses + + # def _init_llm_engine(self, model_name_or_path: str, backend: str = "vllm"): + # """ + # Override to handle vision-language models properly. + + # For GUI agents using vision-language models like Qwen2.5-VL, + # we need special handling since they're not standard causal LM models. + # """ + # # For unit tests or when model loading should be skipped + # # if model_name_or_path == "ByteDance-Seed/UI-TARS-1.5-7B": + # # # Return mock objects for testing + # # print(f"[GUIAgent] Skipping actual model load for testing: {model_name_or_path}") + # # return None, None, None + + # # Otherwise use parent's initialization + # return super()._init_llm_engine(model_name_or_path, backend) + + def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]: + """ + Parse model responses into structured messages. 
+ + Args: + responses: List of model response strings + tools: List of available tools + + Returns: + List of structured messages with tool calls + """ + print(f"[GUIAgent.parse] Number of responses: {len(responses)}") + print(f"[GUIAgent.parse] Raw responses type: {type(responses)}") + + new_messages_list = [] + + # Process each response + processed_responses = [] + for resp in responses: + if resp and "Thought:" in resp and "Action:" in resp: + processed_responses.append(resp) + elif resp and resp.strip(): + # Try to reformat responses that don't have the expected format + resp_lower = resp.lower() + print(f"[GUIAgent.parse] Response missing format, reformatting: {resp[:100]}") + + # Check if it contains action-like content + if any(action in resp_lower for action in ['click', 'type', 'scroll']): + formatted_resp = f"Thought: Executing action based on response.\nAction: {resp.strip()}" + else: + # Default to click at center if no clear action + formatted_resp = f"Thought: {resp.strip()}\nAction: click(start_box='<|box_start|>(960,540)<|box_end|>')" + processed_responses.append(formatted_resp) + else: + # Handle empty responses with default click at center + self.action_counter += 1 + processed_responses.append(f"Thought: Processing the screen (attempt {self.action_counter}).\nAction: click(start_box='<|box_start|>(960,540)<|box_end|>')") + + responses = processed_responses + + # Log responses for debugging + for idx, resp in enumerate(responses[:3]): # Log first 3 responses + if resp: + print(f"[GUIAgent.parse] Response {idx} length: {len(resp)}, preview: {resp[:200]}") + else: + print(f"[GUIAgent.parse] Response {idx} is None or empty") + + # Parse actions from responses + action_list = [] + for response in responses: + parsed = parse_action_to_structure_output( + response, + IMAGE_FACTOR, + TEST_IMAGE_HEIGHT, + TEST_IMAGE_WIDTH + ) + action_list.append(parsed) + + # Create messages with tool calls + for i, (response, actions) in enumerate(zip(responses, 
action_list)): + print(f"[GUIAgent.parse] Processing response {i+1}: response_length={len(response) if response else 0}, actions={actions}") + + tool_calls = [] + + if actions is not None and len(actions) > 0: + if len(actions) > 1: + print(f"[GUIAgent.parse] Warning: Multiple actions found ({len(actions)}), using first one") + action = actions[0] + tool_calls = [{ + "id": str(i), + "type": "function", + "function": { + "name": "pyautogui_code_generator", + "arguments": json.dumps({"action": action}) + } + }] + else: + # If no action was parsed, create a default click action at center + print(f"[GUIAgent.parse] No action parsed from response, creating default click action") + default_action = { + "action_type": "click", + "action_inputs": {"start_box": "(960, 540)"}, + "thought": "Clicking at screen center", + "reflection": None + } + tool_calls = [{ + "id": str(i), + "type": "function", + "function": { + "name": "pyautogui_code_generator", + "arguments": json.dumps({"action": default_action}) + } + }] + + # Always terminate after one turn since we only have 3 action types + # and no explicit termination action + status = "terminal" + if actions and isinstance(actions[0], dict): + action_type = actions[0].get("action_type", "") + print(f"[GUIAgent.parse] Action type: {action_type}, terminating after one turn") + + message = { + "role": "assistant", + "content": [{"type": "text", "text": response}] if response else [{"type": "text", "text": ""}], + "tool_calls": tool_calls, + "loss": True, + "status": status + } + print(f"[GUIAgent.parse] Created message with status={status}, tool_calls={len(tool_calls)}, content_length={len(response)}") + new_messages_list.append(message) + + return new_messages_list + + def format_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Format messages for the vision-language model. 
+ + Args: + messages: List of messages to format + + Returns: + Formatted messages suitable for VLM input + """ + formatted_messages = [] + + for msg in messages: + formatted_msg = { + "role": msg.get("role"), + "content": msg.get("content", "") + } + + # Handle image content if present + if "images" in msg: + # Convert images to appropriate format for the model + formatted_msg["images"] = msg["images"] + + formatted_messages.append(formatted_msg) + + return formatted_messages \ No newline at end of file diff --git a/agents/agents/rewards/__init__.py b/agents/agents/rewards/__init__.py index 1d0b712..a9a6185 100644 --- a/agents/agents/rewards/__init__.py +++ b/agents/agents/rewards/__init__.py @@ -4,6 +4,7 @@ from .webshop_reward import webshop_reward from .alfworld_reward import alfworld_episode_reward from .scienceworld_reward import scienceworld_reward +from .gui_reward import gui_reward -__all__ = ["alfworld_episode_reward","qa_f1_reward", "math_reward", "math_reward_tool", "math_reward_think", "RewardFunction", "get_reward_from_name", "get_rewards_from_names", "list_available_rewards", "register_reward", "llm_as_judge_client_math_reward", "webshop_reward", "alfworld_episode_reward"] +__all__ = ["alfworld_episode_reward","qa_f1_reward", "math_reward", "math_reward_tool", "math_reward_think", "RewardFunction", "get_reward_from_name", "get_rewards_from_names", "list_available_rewards", "register_reward", "llm_as_judge_client_math_reward", "webshop_reward", "alfworld_episode_reward", "gui_reward"] \ No newline at end of file diff --git a/agents/agents/rewards/gui_reward.py b/agents/agents/rewards/gui_reward.py new file mode 100644 index 0000000..ae24d7a --- /dev/null +++ b/agents/agents/rewards/gui_reward.py @@ -0,0 +1,357 @@ +# Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 + +import re +import json +import ast +from typing import Dict, Any, List, Tuple, Optional + +from .reward_base import reward +from agents.utils.ui_action_parser import parse_action_to_structure_output, IMAGE_FACTOR + +# Image dimensions for testing +TEST_IMAGE_HEIGHT = 1080 +TEST_IMAGE_WIDTH = 1920 + + +def normalize_answer(s: str) -> set: + """Normalize answer string for comparison.""" + def remove_punctuation(text): + return re.sub(r"[^\w\s]", "", text) + + def lower(text): + return text.lower() + + return set(lower(remove_punctuation(s)).split()) + + +def f1_score(prediction: str, ground_truth: str) -> Tuple[float, float, float]: + """Calculate F1 score between prediction and ground truth.""" + normalized_prediction = normalize_answer(prediction) + normalized_ground_truth = normalize_answer(ground_truth) + + if not normalized_prediction and not normalized_ground_truth: + return 1.0, 1.0, 1.0 + + common_tokens = normalized_prediction.intersection(normalized_ground_truth) + + precision = len(common_tokens) / len(normalized_prediction) if normalized_prediction else 0.0 + recall = len(common_tokens) / len(normalized_ground_truth) if normalized_ground_truth else 0.0 + + f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0 + + return f1, precision, recall + + +def extract_action(content: str) -> str: + """Extract action type from response content.""" + try: + parsed = parse_action_to_structure_output(content, IMAGE_FACTOR, TEST_IMAGE_HEIGHT, TEST_IMAGE_WIDTH, model_type="default") + if parsed and len(parsed) > 0: + action_dict = parsed[0] + action_type = action_dict.get("action_type", "no action") + + # Check for specific action types in raw text + if action_type == "click" and "Action:" in content: + action_text = content.split("Action:")[1].strip() + if action_text.startswith("left_double"): + return "left_double" + elif action_text.startswith("right_single"): + return 
"right_single" + elif action_text.startswith("click"): + return "click" + else: + return "click" + + # Map normalized action types + if action_type == "hotkey": + return "hotkey" + elif action_type == "drag": + return "drag" + + return action_type + except Exception as e: + print(f"[extract_action] Error: {e}") + return "no action" + + +def extract_input_text(content: str) -> str: + """Extract input text from action content.""" + try: + parsed = parse_action_to_structure_output(content, IMAGE_FACTOR, TEST_IMAGE_HEIGHT, TEST_IMAGE_WIDTH, model_type="default") + if parsed and len(parsed) > 0: + action_dict = parsed[0] + action_type = action_dict.get("action_type") + action_inputs = action_dict.get("action_inputs", {}) + + # Extract text based on action type + if action_type == 'type': + return action_inputs.get('content', '') + elif action_type == 'scroll': + return action_inputs.get('direction', 'down') + elif action_type == 'hotkey': + return action_inputs.get('key', action_inputs.get('hotkey', '')) + elif action_type == 'finished': + return action_inputs.get('content', '') + + return "" + except Exception as e: + print(f"[extract_input_text] Error: {e}") + return "" + + +def extract_coord(content: str) -> Tuple[list, bool]: + """Extract coordinates from action content and normalize them to 0-1 range.""" + try: + parsed = parse_action_to_structure_output(content, IMAGE_FACTOR, TEST_IMAGE_HEIGHT, TEST_IMAGE_WIDTH, model_type="default") + if parsed and len(parsed) > 0: + action_dict = parsed[0] + action_inputs = action_dict.get("action_inputs", {}) + + # Try to get coordinates from start_box + if "start_box" in action_inputs: + try: + coords = ast.literal_eval(action_inputs["start_box"]) + # Ensure coords is a list + if isinstance(coords, (list, tuple)): + if len(coords) == 2: + # Point format [x, y] in pixels - normalize to 0-1 + normalized_coords = [ + coords[0] / TEST_IMAGE_WIDTH, + coords[1] / TEST_IMAGE_HEIGHT + ] + print(f"[extract_coord] Normalized point from 
{coords} to {normalized_coords}") + return normalized_coords, True + elif len(coords) == 4: + # Box format [x1, y1, x2, y2] in pixels - normalize to 0-1 + normalized_coords = [ + coords[0] / TEST_IMAGE_WIDTH, + coords[1] / TEST_IMAGE_HEIGHT, + coords[2] / TEST_IMAGE_WIDTH, + coords[3] / TEST_IMAGE_HEIGHT + ] + print(f"[extract_coord] Normalized box from {coords} to {normalized_coords}") + return normalized_coords, True + else: + print(f"[extract_coord] Unexpected coord format: {coords}") + except Exception as e: + print(f"[extract_coord] Error parsing coordinates: {e}") + + except Exception as e: + print(f"[extract_coord] Error: {e}") + return [], False + + +def gui_format_score(predict_str: str) -> float: + """Calculate format score for GUI prediction.""" + try: + parsed_actions = parse_action_to_structure_output(predict_str, IMAGE_FACTOR, TEST_IMAGE_HEIGHT, TEST_IMAGE_WIDTH, model_type="default") + return 1.0 if parsed_actions else 0.0 + except Exception: + return 0.0 + + +def gui_accuracy_score(predict_str: str, gt_action: str, gt_bbox: list, gt_input_text: str) -> float: + """Calculate accuracy score for GUI prediction (0.5 for action type, 0.5 for parameters).""" + try: + gt_action = gt_action.lower() if gt_action else '' + + pred_action = extract_action(predict_str).lower() + pred_coord, has_coord = extract_coord(predict_str) + pred_input_text = extract_input_text(predict_str) + + print(f"[gui_accuracy_score] gt_action: {gt_action}, pred_action: {pred_action}") + print(f"[gui_accuracy_score] gt_bbox: {gt_bbox}, pred_coord: {pred_coord}, has_coord: {has_coord}") + print(f"[gui_accuracy_score] gt_input_text: {gt_input_text}, pred_input_text: {pred_input_text}") + + # Map all click variants to 'click' for the 3-action space + action_mapping = { + 'left_single': 'click', + 'left_double': 'click', + 'right_single': 'click', + 'click': 'click', + 'type': 'type', + 'scroll': 'scroll' + } + + # Normalize actions + pred_action_normalized = 
action_mapping.get(pred_action, pred_action) + gt_action_normalized = action_mapping.get(gt_action, gt_action) + + score = 0.0 + + # Score calculation: 0.5 for action type, 0.5 for parameters (bbox/text) + + # 1. Action type matching (0.5 points) + if pred_action_normalized == gt_action_normalized: + score += 0.5 + print(f"[gui_accuracy_score] Action matched: +0.5 points") + else: + print(f"[gui_accuracy_score] Action mismatch: {pred_action_normalized} vs {gt_action_normalized}") + + # 2. Parameter matching (0.5 points) - depends on action type + if gt_action_normalized == 'click': + # For click: check bbox coordinates + if gt_bbox and len(gt_bbox) > 0: + if has_coord: + # Both have coordinates, calculate distance + # Handle different gt_bbox formats (already normalized 0-1) + if len(gt_bbox) == 2: + gt_x, gt_y = gt_bbox + elif len(gt_bbox) == 4: + gt_x = (gt_bbox[0] + gt_bbox[2]) / 2 + gt_y = (gt_bbox[1] + gt_bbox[3]) / 2 + else: + print(f"[gui_accuracy_score] Invalid gt_bbox format: {gt_bbox}") + return score + + # Get predicted center (already normalized 0-1) + if len(pred_coord) == 2: + pred_x, pred_y = pred_coord + elif len(pred_coord) == 4: + pred_x = (pred_coord[0] + pred_coord[2]) / 2 + pred_y = (pred_coord[1] + pred_coord[3]) / 2 + else: + print(f"[gui_accuracy_score] Invalid pred_coord format: {pred_coord}") + return score + + # Calculate distance in normalized space + distance = ((pred_x - gt_x) ** 2 + (pred_y - gt_y) ** 2) ** 0.5 + + # Threshold in normalized space (5% of diagonal) + threshold = 0.05 * (2 ** 0.5) # sqrt(1^2 + 1^2) ≈ 0.07 + + if distance < threshold: + score += 0.5 + print(f"[gui_accuracy_score] Bbox matched (distance={distance:.4f}): +0.5 points") + else: + print(f"[gui_accuracy_score] Bbox too far (distance={distance:.4f}, threshold={threshold:.4f})") + else: + print(f"[gui_accuracy_score] No predicted coordinates for click action") + else: + # No gt_bbox required, any click gets parameter points + score += 0.5 + 
print(f"[gui_accuracy_score] No gt_bbox required: +0.5 points") + + elif gt_action_normalized == 'type': + # For type: check text content + if gt_input_text and gt_input_text != "no input text": + f1, _, _ = f1_score(pred_input_text, gt_input_text) + if f1 >= 0.5: + score += 0.5 + print(f"[gui_accuracy_score] Type text matched (f1={f1:.2f}): +0.5 points") + else: + print(f"[gui_accuracy_score] Type text mismatch (f1={f1:.2f})") + else: + # No text required, any type action gets parameter points + score += 0.5 + print(f"[gui_accuracy_score] No text required: +0.5 points") + + elif gt_action_normalized == 'scroll': + # For scroll: only check direction (no bbox needed) + if gt_input_text and gt_input_text != "no input text": + if pred_input_text.lower() == gt_input_text.lower(): + score += 0.5 + print(f"[gui_accuracy_score] Scroll direction matched: +0.5 points") + else: + print(f"[gui_accuracy_score] Scroll direction mismatch: {pred_input_text} vs {gt_input_text}") + else: + # No direction specified, any scroll gets parameter points + score += 0.5 + print(f"[gui_accuracy_score] No scroll direction required: +0.5 points") + + print(f"[gui_accuracy_score] Final score: {score}") + return score + + except Exception as e: + print(f"Error in gui_accuracy_score: {e}") + print(f"predict_str: {predict_str}") + print(f"gt_action: {gt_action}, gt_bbox: {gt_bbox}, gt_input_text: {gt_input_text}") + return 0.0 + + +@reward(name="gui_reward") +def gui_reward(prediction: str, trajectory: List[Dict] = None, gt_action: str = "", gt_bbox: list = None, gt_input_text: str = "", **kwargs) -> Dict[str, float]: + """ + Calculate GUI reward based on prediction accuracy. 
+ + Args: + prediction: Model prediction string + trajectory: Conversation trajectory (optional) + **kwargs: Additional parameters including ground truth + + Returns: + Dictionary with reward scores + """ + print(f"[gui_reward] Called with prediction: {prediction[:200] if prediction else 'None'}") + print(f"[gui_reward] kwargs keys: {list(kwargs.keys())}") + + # Handle empty predictions + if not prediction or prediction.strip() == "": + print(f"[gui_reward] Warning: Empty prediction received") + # Check if there's a default action in trajectory + if trajectory and len(trajectory) > 0: + for msg in reversed(trajectory): + if msg.get('role') == 'assistant' and msg.get('content'): + prediction = msg['content'] + print(f"[gui_reward] Using trajectory content as prediction: {prediction[:100]}") + break + + # if not prediction or prediction.strip() == "": + # prediction = "Thought: No response generated.\nAction: wait()" + # print(f"[gui_reward] Using default prediction") + + # Handle None values for parameters + if gt_bbox is None: + gt_bbox = [] + + # Convert numpy array to list if needed + if hasattr(gt_bbox, 'tolist'): + gt_bbox = gt_bbox.tolist() + + print(f"[gui_reward] gt_action: {gt_action}, gt_bbox: {gt_bbox}, gt_input_text: {gt_input_text}") + + # Handle "no input text" as empty + if gt_input_text == "no input text": + gt_input_text = "" + + # Keep bbox in normalized coordinates (0-1 range) + # Both prediction and ground truth use normalized coordinates + + if not gt_action and not gt_bbox and not gt_input_text: + print(f"[gui_reward] Warning: No ground truth data provided - returning 0 reward") + return { + "reward": 0.0, + "format": gui_format_score(prediction), + "accuracy": 0.0, + "f1": 0.0, + "precision": 0.0, + "recall": 0.0, + } + + # Calculate scores + format_score = gui_format_score(prediction) + accuracy_score = gui_accuracy_score(prediction, gt_action, gt_bbox, gt_input_text) + + print(f"[gui_reward] format_score: {format_score}, accuracy_score: 
{accuracy_score}") + + # For f1_score, create answer string for backward compatibility + answer_dict = { + "action": gt_action, + "gt_bbox": gt_bbox, + "input_text": gt_input_text + } + answer = json.dumps(answer_dict) + f1, precision, recall = f1_score(prediction, answer) + + # Calculate final reward (weighted combination) + final_reward = 0.8 * accuracy_score + 0.2 * format_score + + return { + "reward": final_reward, + "format": format_score, + "accuracy": accuracy_score, + "f1": f1, + "precision": precision, + "recall": recall, + } \ No newline at end of file diff --git a/agents/agents/rewards/qa_reward.py b/agents/agents/rewards/qa_reward.py index bd335b5..a6dd139 100644 --- a/agents/agents/rewards/qa_reward.py +++ b/agents/agents/rewards/qa_reward.py @@ -1,7 +1,7 @@ import re import string from collections import Counter -from typing import List +from typing import List, Dict, Union from .reward_base import reward @@ -118,4 +118,24 @@ def ok_vqa_reward(prediction: str, answers: List[str], trajectory: List[str]) -> f1, precision, recall = f1_score(prediction, answer) f1_scores.append(f1) # All answers are the correct answer, take the max f1 score - return max(f1_scores) \ No newline at end of file + return max(f1_scores) + + +@reward(name="infoseek_reward") +def infoseek_reward(prediction: str, answer: Union[str, List[str]], answer_eval: List[str | Dict], trajectory: List[str]) -> float: + f1_scores = [] + answers = [] + if isinstance(answer, str): + answers.append(answer) + elif isinstance(answer, list): + answers.extend(answer) + + if isinstance(answer_eval[0], str): + answers.extend(answer_eval) + + for _answer in answers: + f1, precision, recall = f1_score(prediction, _answer) + f1_scores.append(f1) + + # All answers are the correct answer, take the max f1 score + return max(f1_scores) diff --git a/agents/agents/rewards/scienceworld_reward.py b/agents/agents/rewards/scienceworld_reward.py index bdef743..03c7922 100644 --- 
a/agents/agents/rewards/scienceworld_reward.py +++ b/agents/agents/rewards/scienceworld_reward.py @@ -1,4 +1,4 @@ -from agents.envs.scienceworld_env import ScienceWorldEnv +from ..envs.scienceworld_env import ScienceWorldEnv from .reward_base import reward @reward(name="scienceworld_reward", env_cls=ScienceWorldEnv, pool_size=8) diff --git a/agents/agents/rewards/webshop_reward.py b/agents/agents/rewards/webshop_reward.py index 1bf0b1a..52ba44b 100644 --- a/agents/agents/rewards/webshop_reward.py +++ b/agents/agents/rewards/webshop_reward.py @@ -1,4 +1,4 @@ -from agents.envs.webshop_text_env import WebAgentTextEnv +from ..envs.webshop_text_env import WebAgentTextEnv from .reward_base import reward @reward(name="webshop_reward", env_cls=WebAgentTextEnv, pool_size=8) diff --git a/agents/agents/tools/__init__.py b/agents/agents/tools/__init__.py index bd0f768..98d33ea 100644 --- a/agents/agents/tools/__init__.py +++ b/agents/agents/tools/__init__.py @@ -14,6 +14,7 @@ from .src.react.tools import answer_qa, answer_math from .src.search.async_dense_retriever import asyncdense_retrieve from .src.scienceworld.tools import scienceworld_explorer +from .src.ui.tools import pyautogui_code_generator # Export the tools __all__ = [ @@ -36,6 +37,7 @@ "alfworld_get_task_objective" "alfworld_reset" "asyncdense_retrieve" + "pyautogui_code_generator" # "current_env" ] @@ -54,7 +56,8 @@ "answer_math": answer_math, "hallucination_tool": hallucination_tool, "invalid_input_tool": invalid_input_tool, - "dense_retrieve": dense_retrieve + "dense_retrieve": dense_retrieve, + "pyautogui_code_generator": pyautogui_code_generator } # Update the registry with explicit tools diff --git a/agents/agents/tools/src/ui/__init__.py b/agents/agents/tools/src/ui/__init__.py new file mode 100644 index 0000000..f847595 --- /dev/null +++ b/agents/agents/tools/src/ui/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 + +from .tools import pyautogui_code_generator + +__all__ = ["pyautogui_code_generator"] \ No newline at end of file diff --git a/agents/agents/tools/src/ui/tools.py b/agents/agents/tools/src/ui/tools.py new file mode 100644 index 0000000..02769ab --- /dev/null +++ b/agents/agents/tools/src/ui/tools.py @@ -0,0 +1,94 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 + +import json +from typing import Any +from ...tool_base import tool +from agents.utils.ui_action_parser import parsing_response_to_pyautogui_code + +# Default image dimensions for UI interactions +DEFAULT_IMAGE_HEIGHT = 1080 +DEFAULT_IMAGE_WIDTH = 1920 + + +@tool(name="pyautogui_code_generator") +def pyautogui_code_generator(action: dict, **kwargs) -> str: + """ + Generate PyAutoGUI code from a structured action dictionary. + + Args: + action (dict): Dictionary containing action_type, action_inputs, thought, etc. + **kwargs: Additional parameters like image dimensions + + Returns: + PyAutoGUI code string or execution result + """ + print(f"[pyautogui_code_generator] Received action: {action}") + + # Extract image dimensions from kwargs or use defaults + image_height = kwargs.get("image_height", DEFAULT_IMAGE_HEIGHT) + image_width = kwargs.get("image_width", DEFAULT_IMAGE_WIDTH) + + # Handle the action + if isinstance(action, str): + # Try to parse if it's a JSON string + try: + action = json.loads(action) + except json.JSONDecodeError: + return f"Error: Invalid action format - {action}" + + if not isinstance(action, dict): + return f"Error: Action must be a dictionary, got {type(action)}" + + # Check if this is a terminal action + action_type = action.get("action_type", "") + if action_type in ["finished", "call_user"]: + content = action.get("action_inputs", {}).get("content", "Task completed") + return f"Task completed: {content}" + + # Generate PyAutoGUI code for the action + try: + 
pyautogui_code = parsing_response_to_pyautogui_code( + [action], + image_height=image_height, + image_width=image_width, + input_swap=True + ) + + # For non-terminal actions, return the code + if pyautogui_code == "DONE": + return "Task completed successfully" + + return f"Generated PyAutoGUI code:\n{pyautogui_code}" + + except Exception as e: + return f"Error generating PyAutoGUI code: {str(e)}" + + +@tool(name="capture_screenshot") +def capture_screenshot(**kwargs) -> str: + """ + Capture a screenshot of the current screen. + + Returns: + Base64 encoded screenshot or error message + """ + try: + import pyautogui + import base64 + from io import BytesIO + + # Take screenshot + screenshot = pyautogui.screenshot() + + # Convert to base64 + buffered = BytesIO() + screenshot.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode() + + return f"Screenshot captured successfully (base64): {img_str[:100]}..." + + except ImportError: + return "Error: PyAutoGUI not installed. Please install it to capture screenshots." 
+ except Exception as e: + return f"Error capturing screenshot: {str(e)}" \ No newline at end of file diff --git a/agents/agents/utils/monitor.py b/agents/agents/utils/monitor.py index fdcab41..504f1f5 100644 --- a/agents/agents/utils/monitor.py +++ b/agents/agents/utils/monitor.py @@ -24,6 +24,7 @@ from PIL import Image import io import base64 +import numpy as np @dataclass(slots=True) class MetricEvent: @@ -74,7 +75,13 @@ def __repr__(self) -> str: # noqa: D401 def serialize_for_json(obj): - if isinstance(obj, Image.Image): + if isinstance(obj, np.ndarray): + # Convert numpy array to list + return obj.tolist() + elif isinstance(obj, (np.integer, np.floating)): + # Convert numpy scalars to Python types + return obj.item() + elif isinstance(obj, Image.Image): # Convert image to base64 string buffer = io.BytesIO() obj.save(buffer, format="PNG") diff --git a/agents/agents/utils/ui_action_parser.py b/agents/agents/utils/ui_action_parser.py new file mode 100644 index 0000000..e992fce --- /dev/null +++ b/agents/agents/utils/ui_action_parser.py @@ -0,0 +1,423 @@ +# Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 +import re +import ast +import math +from typing import List, Dict, Any, Optional, Tuple + +IMAGE_FACTOR = 1 # Changed to match gui_reward.py +MIN_PIXELS = 100 * 28 * 28 +MAX_PIXELS = 16384 * 28 * 28 +MAX_RATIO = 200 + + +def convert_point_to_coordinates(text: str, is_answer: bool = False) -> str: + """Convert point format to coordinates.""" + pattern = r"(\d+)\s+(\d+)" + + def replace_match(match): + x1, y1 = map(int, match.groups()) + x = (x1 + x1) // 2 # Truncate to integer + y = (y1 + y1) // 2 # Truncate to integer + if is_answer: + return f"({x},{y})" # Only return (x, y) format + return f"({x},{y})" # Return the format with tags + + # Remove [EOS] and replace coordinates + text = re.sub(r"\[EOS\]", "", text) + return re.sub(pattern, replace_match, text).strip() + + +def parse_action(action_str: str) -> Optional[Dict[str, Any]]: + """Parse an action string into function name and arguments.""" + try: + # Parse the string to AST node + node = ast.parse(action_str, mode='eval') + + # Ensure the node is an expression + if not isinstance(node, ast.Expression): + raise ValueError("Not an expression") + + # Get the body of the expression + call = node.body + + # Ensure the body is a function call + if not isinstance(call, ast.Call): + raise ValueError("Not a function call") + + # Get the function name + if isinstance(call.func, ast.Name): + func_name = call.func.id + elif isinstance(call.func, ast.Attribute): + func_name = call.func.attr + else: + func_name = None + + # Get the keyword arguments + kwargs = {} + for kw in call.keywords: + key = kw.arg + # Handle different types of values + if isinstance(kw.value, ast.Constant): + value = kw.value.value + elif isinstance(kw.value, ast.Str): # Compatible with old version Python + value = kw.value.s + else: + value = None + kwargs[key] = value + + return {'function': func_name, 'args': kwargs} + + except Exception as e: + print(f"Failed to parse action 
'{action_str}': {e}")
+        return None
+
+
+def escape_single_quotes(text: str) -> str:
+    """Escape unescaped single quotes."""
+    pattern = r"(?<!\\)'"
+    return re.sub(pattern, r"\\'", text)
+
+
+def round_by_factor(number: int, factor: int) -> int:
+    """Returns the closest integer to 'number' that is divisible by 'factor'."""
+    return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+    return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+    return math.floor(number / factor) * factor
+
+
+def linear_resize(height: int,
+                  width: int,
+                  factor: int = IMAGE_FACTOR,
+                  min_pixels: int = MIN_PIXELS,
+                  max_pixels: int = MAX_PIXELS) -> Tuple[int, int]:
+    """Resize image to fit within pixel limits while maintaining aspect ratio."""
+    if width * height > max_pixels:
+        resize_factor = math.sqrt(max_pixels / (width * height))
+        width, height = int(width * resize_factor), int(height * resize_factor)
+    if width * height < min_pixels:
+        resize_factor = math.sqrt(min_pixels / (width * height))
+        width, height = math.ceil(width * resize_factor), math.ceil(height * resize_factor)
+    return height, width
+
+
+def smart_resize(height: int,
+                 width: int,
+                 factor: int = IMAGE_FACTOR,
+                 min_pixels: int = MIN_PIXELS,
+                 max_pixels: int = MAX_PIXELS) -> Tuple[int, int]:
+    """
+    Rescales the image so that:
+    1. Both dimensions are divisible by 'factor'.
+    2. Total pixels is within [min_pixels, max_pixels].
+    3. Aspect ratio is maintained as closely as possible. 
+ """ + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + +def parse_action_to_structure_output(text: str, + factor: int, + origin_resized_height: int, + origin_resized_width: int, + model_type: str = "qwen25vl", + max_pixels: int = 16384 * 28 * 28, + min_pixels: int = 100 * 28 * 28) -> Optional[List[Dict[str, Any]]]: + """Parse action text to structured output.""" + print(f"[parse_action_to_structure_output] Input text: {text[:500] if text else 'Empty text'}") + + # Handle empty or None responses + if not text: + print(f"[parse_action_to_structure_output] Empty text, returning None") + return None + + text = text.strip() + + # Handle various point/box formats + if "" in text: + text = convert_point_to_coordinates(text) + if "start_point=" in text: + text = text.replace("start_point=", "start_box=") + if "end_point=" in text: + text = text.replace("end_point=", "end_box=") + if "point=" in text: + text = text.replace("point=", "start_box=") + + smart_resize_height, smart_resize_width = origin_resized_height, origin_resized_width + if model_type == "qwen25vl": + smart_resize_height, smart_resize_width = smart_resize( + origin_resized_height, + origin_resized_width, + factor=IMAGE_FACTOR, + min_pixels=min_pixels, + max_pixels=max_pixels) + + # Extract thought and reflection + if text.startswith("Thought:"): + thought_pattern = r"Thought: 
(.+?)(?=\s*Action: |$)" + thought_hint = "Thought: " + elif text.startswith("Reflection:"): + thought_pattern = r"Reflection: (.+?)Action_Summary: (.+?)(?=\s*Action: |$)" + thought_hint = "Reflection: " + elif text.startswith("Action_Summary:"): + thought_pattern = r"Action_Summary: (.+?)(?=\s*Action: |$)" + thought_hint = "Action_Summary: " + else: + thought_pattern = r"Thought: (.+?)(?=\s*Action: |$)" + thought_hint = "Thought: " + + reflection, thought = None, None + thought_match = re.search(thought_pattern, text, re.DOTALL) + if thought_match: + if len(thought_match.groups()) == 1: + thought = thought_match.group(1).strip() + elif len(thought_match.groups()) == 2: + thought = thought_match.group(2).strip() + reflection = thought_match.group(1).strip() + + if "Action:" not in text: + print(f"[parse_action_to_structure_output] No 'Action:' found in text, returning None") + return None + + action_str = text.split("Action: ")[-1] + print(f"[parse_action_to_structure_output] Extracted action string: {action_str[:200]}") + + # Parse multiple actions + tmp_all_action = action_str.split(")\n\n") + all_action = [] + for action_str in tmp_all_action: + if "type(content" in action_str: + if not action_str.strip().endswith(")"): + action_str = action_str.strip() + ")" + # Handle type content escaping + pattern = r"type\(content='(.*?)'\)" + if re.search(pattern, action_str): + content = re.sub(pattern, lambda m: m.group(1), action_str) + action_str = escape_single_quotes(content) + action_str = "type(content='" + action_str + "')" + if not action_str.strip().endswith(")"): + action_str = action_str.strip() + ")" + all_action.append(action_str) + + parsed_actions = [ + parse_action(action.replace("\n", "\\n").lstrip()) + for action in all_action + ] + + actions = [] + for action_instance, raw_str in zip(parsed_actions, all_action): + if action_instance is None: + print(f"Action can't parse: {raw_str}") + continue + + action_type = action_instance["function"] + params = 
action_instance["args"] + + action_inputs = {} + for param_name, param in params.items(): + if param is None or param == "": + continue + + # Only apply lstrip to string parameters + if isinstance(param, str): + param = param.lstrip() + + action_inputs[param_name.strip()] = param + + # Handle box coordinates (only for string parameters) + if isinstance(param, str) and ("start_box" in param_name or "end_box" in param_name): + ori_box = param + # Remove box tags if present + ori_box = ori_box.replace("<|box_start|>", "").replace("<|box_end|>", "") + numbers = ori_box.replace("(", "").replace(")", "").split(",") + + try: + for num in numbers: + float(num.strip()) + except ValueError: + print(f"Warning: Invalid coordinate format in '{param_name}': '{ori_box}'") + return None + + # Convert coordinates based on model type + if model_type == "qwen25vl": + float_numbers = [] + for num_idx, num in enumerate(numbers): + num = float(num) + if (num_idx + 1) % 2 == 0: + float_numbers.append(float(num / smart_resize_height)) + else: + float_numbers.append(float(num / smart_resize_width)) + else: + # For IMAGE_FACTOR = 1, keep coordinates as pixel values + float_numbers = [float(num.strip()) for num in numbers] + + if len(float_numbers) == 2: + float_numbers = [ + float_numbers[0], float_numbers[1], + float_numbers[0], float_numbers[1] + ] + action_inputs[param_name.strip()] = str(float_numbers) + + # Normalize action types for consistency + normalized_action_type = action_type + if action_type in ["left_single", "left_double", "right_single"]: + normalized_action_type = "click" + elif action_type in ["press", "keydown", "release", "keyup"]: + normalized_action_type = "hotkey" + elif action_type in ["select"]: + normalized_action_type = "drag" + + actions.append({ + "reflection": reflection, + "thought": thought, + "action_type": normalized_action_type, + "action_inputs": action_inputs, + "text": text, + }) + + return actions + + +def parsing_response_to_pyautogui_code(responses: 
List[Dict[str, Any]], + image_height: int, + image_width: int, + input_swap: bool = True) -> str: + """Convert parsed responses to PyAutoGUI code.""" + pyautogui_code = f"import pyautogui\nimport time\n" + if isinstance(responses, dict): + responses = [responses] + + for response_id, response in enumerate(responses): + observation = response.get("observation", "") + thought = response.get("thought", "") + + if response_id == 0: + pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n" + else: + pyautogui_code += f"\ntime.sleep(1)\n" + + action_type = response.get("action_type") + action_inputs = response.get("action_inputs", {}) + + if action_type == "hotkey": + hotkey = action_inputs.get("key", action_inputs.get("hotkey", "")) + # Convert arrow keys + hotkey = hotkey.replace("arrowleft", "left").replace("arrowright", "right") + hotkey = hotkey.replace("arrowup", "up").replace("arrowdown", "down") + + if hotkey: + keys = hotkey.split() + convert_keys = [] + for key in keys: + if key == "space": + key = ' ' + convert_keys.append(key) + pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in convert_keys])})" + + elif action_type == "type": + content = action_inputs.get("content", "") + content = escape_single_quotes(content) + stripped_content = content.rstrip("\\n").rstrip("\n") + + if content: + if input_swap: + pyautogui_code += f"\nimport pyperclip" + pyautogui_code += f"\npyperclip.copy('{stripped_content}')" + pyautogui_code += f"\npyautogui.hotkey('ctrl', 'v')" + pyautogui_code += f"\ntime.sleep(0.5)\n" + if content.endswith("\n") or content.endswith("\\n"): + pyautogui_code += f"\npyautogui.press('enter')" + else: + pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)" + pyautogui_code += f"\ntime.sleep(0.5)\n" + if content.endswith("\n") or content.endswith("\\n"): + pyautogui_code += f"\npyautogui.press('enter')" + + elif action_type in ["drag", "select"]: + start_box = 
action_inputs.get("start_box") + end_box = action_inputs.get("end_box") + if start_box and end_box: + x1, y1, x2, y2 = eval(start_box) + sx = round(float((x1 + x2) / 2) * image_width, 3) + sy = round(float((y1 + y2) / 2) * image_height, 3) + x1, y1, x2, y2 = eval(end_box) + ex = round(float((x1 + x2) / 2) * image_width, 3) + ey = round(float((y1 + y2) / 2) * image_height, 3) + pyautogui_code += ( + f"\npyautogui.moveTo({sx}, {sy})\n" + f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n") + + elif action_type == "scroll": + start_box = action_inputs.get("start_box") + if start_box: + x1, y1, x2, y2 = eval(start_box) + x = round(float((x1 + x2) / 2) * image_width, 3) + y = round(float((y1 + y2) / 2) * image_height, 3) + else: + x = None + y = None + + direction = action_inputs.get("direction", "") + + if x is None: + if "up" in direction.lower(): + pyautogui_code += f"\npyautogui.scroll(5)" + elif "down" in direction.lower(): + pyautogui_code += f"\npyautogui.scroll(-5)" + else: + if "up" in direction.lower(): + pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})" + elif "down" in direction.lower(): + pyautogui_code += f"\npyautogui.scroll(-5, x={x}, y={y})" + + elif action_type in ["click", "left_single", "left_double", "right_single", "hover"]: + start_box = action_inputs.get("start_box") + start_box = str(start_box) + if start_box: + start_box = eval(start_box) + if len(start_box) == 4: + x1, y1, x2, y2 = start_box + elif len(start_box) == 2: + x1, y1 = start_box + x2 = x1 + y2 = y1 + x = round(float((x1 + x2) / 2) * image_width, 3) + y = round(float((y1 + y2) / 2) * image_height, 3) + + if action_type == "left_single" or action_type == "click": + pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')" + elif action_type == "left_double": + pyautogui_code += f"\npyautogui.doubleClick({x}, {y}, button='left')" + elif action_type == "right_single": + pyautogui_code += f"\npyautogui.click({x}, {y}, button='right')" + elif action_type == "hover": + 
pyautogui_code += f"\npyautogui.moveTo({x}, {y})" + + elif action_type in ["finished"]: + pyautogui_code = f"DONE" + + else: + pyautogui_code += f"\n# Unrecognized action type: {action_type}" + + return pyautogui_code \ No newline at end of file diff --git a/agents/pytest.ini b/agents/pytest.ini new file mode 100644 index 0000000..f476b91 --- /dev/null +++ b/agents/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +asyncio_mode = auto +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* \ No newline at end of file diff --git a/agents/tests/unit/agents/test_gui_agent.py b/agents/tests/unit/agents/test_gui_agent.py new file mode 100644 index 0000000..bd1eca7 --- /dev/null +++ b/agents/tests/unit/agents/test_gui_agent.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import sys +from pathlib import Path + +# Add the agents module to path +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from agents.agents.specialized.gui_agent import GUIAgent +from agents.rewards.gui_reward import gui_reward +from agents.utils.ui_action_parser import parse_action_to_structure_output, IMAGE_FACTOR + + +class TestGUIAgent: + """Test suite for GUI Agent implementation.""" + + def test_gui_agent_initialization(self): + """Test GUI agent can be initialized.""" + # Skip loading actual model for unit test + agent = GUIAgent( + model_name_or_path="ByteDance-Seed/UI-TARS-1.5-7B", + template="qwen2.5-vl", + tools=["pyautogui_code_generator"] + ) + assert agent is not None + assert agent.system_prompt is not None + assert "GUI automation agent" in agent.system_prompt + + def test_gui_agent_parse_valid_response(self): + """Test GUI agent can parse valid responses.""" + # Skip loading actual model for unit test + agent = GUIAgent( + model_name_or_path="ByteDance-Seed/UI-TARS-1.5-7B", + template="qwen2.5-vl", + tools=[] + ) + + responses = [ + 
"Thought: I need to click on the button.\nAction: click(start_box='<|box_start|>(100,200)<|box_end|>')" + ] + + messages = agent.parse(responses, tools=[]) + + assert len(messages) == 1 + assert messages[0]["role"] == "assistant" + assert messages[0]["content"][0]["text"] == responses[0] + assert len(messages[0]["tool_calls"]) == 1 + assert messages[0]["status"] == "continue" + + def test_gui_agent_parse_terminal_action(self): + """Test GUI agent recognizes terminal actions.""" + # Skip loading actual model for unit test + agent = GUIAgent( + model_name_or_path="ByteDance-Seed/UI-TARS-1.5-7B", + template="qwen2.5-vl", + tools=[] + ) + + responses = [ + "Thought: Task is complete.\nAction: finished(content='Task completed successfully')" + ] + + messages = agent.parse(responses, tools=[]) + + assert len(messages) == 1 + assert messages[0]["status"] == "terminal" + + def test_gui_agent_parse_empty_response(self): + """Test GUI agent handles empty responses gracefully.""" + # Skip loading actual model for unit test + agent = GUIAgent( + model_name_or_path="ByteDance-Seed/UI-TARS-1.5-7B", + template="qwen2.5-vl", + tools=[] + ) + + responses = [""] + + messages = agent.parse(responses, tools=[]) + + assert len(messages) == 1 + assert "wait" in messages[0]["content"][0]["text"].lower() + assert messages[0]["status"] == "continue" + + +class TestGUIReward: + """Test suite for GUI reward function.""" + + @pytest.mark.asyncio + async def test_gui_reward_with_ground_truth(self): + """Test GUI reward with ground truth data.""" + prediction = "Thought: I need to click on the button.\nAction: click(start_box='<|box_start|>(100,200)<|box_end|>')" + + result = await gui_reward( + prediction=prediction, + gt_action="click", + gt_bbox=[100, 200], + gt_input_text="" + ) + + assert isinstance(result, dict) + assert "reward" in result + assert "format" in result + assert "accuracy" in result + assert result["format"] == 1.0 # Valid format + assert result["accuracy"] == 1.0 # Exact 
match + assert result["reward"] > 0.9 # High reward for perfect match + + @pytest.mark.asyncio + async def test_gui_reward_without_ground_truth(self): + """Test GUI reward without ground truth data.""" + prediction = "Thought: I need to click.\nAction: click(start_box='<|box_start|>(100,200)<|box_end|>')" + + result = await gui_reward(prediction=prediction) + + assert isinstance(result, dict) + assert result["reward"] == 0.0 # No ground truth + + @pytest.mark.asyncio + async def test_gui_reward_empty_prediction(self): + """Test GUI reward with empty prediction.""" + result = await gui_reward( + prediction="", + gt_action="click", + gt_bbox=[100, 200], + gt_input_text="" + ) + + assert isinstance(result, dict) + assert result["format"] == 0.0 # Invalid format + + @pytest.mark.asyncio + async def test_gui_reward_type_action(self): + """Test GUI reward for typing action.""" + prediction = "Thought: I need to type text.\nAction: type(content='hello world')" + + result = await gui_reward( + prediction=prediction, + gt_action="type", + gt_bbox=[], + gt_input_text="hello world" + ) + + assert isinstance(result, dict) + assert result["format"] == 1.0 + assert result["accuracy"] == 1.0 # Text matches + + +class TestUIActionParser: + """Test suite for UI action parser.""" + + def test_parse_click_action(self): + """Test parsing click action.""" + text = "Thought: Click button.\nAction: click(start_box='<|box_start|>(100,200)<|box_end|>')" + + result = parse_action_to_structure_output( + text, IMAGE_FACTOR, 1080, 1920 + ) + + assert result is not None + assert len(result) == 1 + assert result[0]["action_type"] == "click" + assert "start_box" in result[0]["action_inputs"] + + def test_parse_type_action(self): + """Test parsing type action.""" + text = "Thought: Type text.\nAction: type(content='hello world')" + + result = parse_action_to_structure_output( + text, IMAGE_FACTOR, 1080, 1920 + ) + + assert result is not None + assert len(result) == 1 + assert 
result[0]["action_type"] == "type" + assert result[0]["action_inputs"]["content"] == "hello world" + + def test_parse_invalid_action(self): + """Test parsing invalid action returns None.""" + text = "This is not a valid action format" + + result = parse_action_to_structure_output( + text, IMAGE_FACTOR, 1080, 1920 + ) + + assert result is None + + def test_parse_empty_text(self): + """Test parsing empty text returns None.""" + result = parse_action_to_structure_output( + "", IMAGE_FACTOR, 1080, 1920 + ) + + assert result is None + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/verl b/verl index 237d9ca..0ba7136 160000 --- a/verl +++ b/verl @@ -1 +1 @@ -Subproject commit 237d9cacd2ede001c21f1a1daa44e8e8598993e1 +Subproject commit 0ba71360604c85ca7e83168520169fa858681633