diff --git a/agents/agents/agents/__init__.py b/agents/agents/agents/__init__.py
index 75c339d..6152c72 100644
--- a/agents/agents/agents/__init__.py
+++ b/agents/agents/agents/__init__.py
@@ -1,5 +1,6 @@
from .react.react_agent import ReactAgent
from .specialized.code_agent import CodeAgent
from .specialized.think_agent import ThinkAgent
+from .specialized.gui_agent import GUIAgent
-__all__ = ["ReactAgent", "CodeAgent", "ThinkAgent"]
\ No newline at end of file
+__all__ = ["ReactAgent", "CodeAgent", "ThinkAgent", "GUIAgent"]
\ No newline at end of file
diff --git a/agents/agents/agents/agent_base.py b/agents/agents/agents/agent_base.py
index 43b01f2..2a57ecd 100644
--- a/agents/agents/agents/agent_base.py
+++ b/agents/agents/agents/agent_base.py
@@ -157,6 +157,7 @@ def _init_llm_engine(self, model_name_or_path: str, backend: str):
def set_llm_engine(self, llm_engine: Any, tokenizer: Any, processor: Any):
assert self.backend == "async_verl", "Only async verl backend is supported for now"
+
self.llm_engine.llm_engine = llm_engine
self.tokenizer = tokenizer
self.processor = processor
diff --git a/agents/agents/agents/auto.py b/agents/agents/agents/auto.py
index 84050f4..c62e8f7 100644
--- a/agents/agents/agents/auto.py
+++ b/agents/agents/agents/auto.py
@@ -1,11 +1,12 @@
from typing import Any, Callable, Dict, List, Optional, Type, Union
from .specialized.think_agent import ThinkAgent
-from agents.agents.specialized.openai_agent import OpenAIAgent
+from .specialized.openai_agent import OpenAIAgent
from ..tools import get_tools_from_names
from .agent_base import BaseAgent
from .react.react_agent import ReactAgent
from .specialized.code_agent import CodeAgent
+from .specialized.gui_agent import GUIAgent
from ..rewards.reward_base import get_reward_from_name
@@ -165,4 +166,5 @@ def from_pretrained(
AutoAgent.register_agent("react", ReactAgent)
AutoAgent.register_agent("code", CodeAgent)
AutoAgent.register_agent("openai", OpenAIAgent)
-AutoAgent.register_agent("think", ThinkAgent)
\ No newline at end of file
+AutoAgent.register_agent("think", ThinkAgent)
+AutoAgent.register_agent("gui", GUIAgent)
\ No newline at end of file
diff --git a/agents/agents/agents/chain/chain_base.py b/agents/agents/agents/chain/chain_base.py
index 02fa27d..94e9e8a 100644
--- a/agents/agents/agents/chain/chain_base.py
+++ b/agents/agents/agents/chain/chain_base.py
@@ -312,6 +312,10 @@ async def _run_single_chain(self,
)
thought_node.is_terminal = new_msg.get("status", "continue") in self.terminal_status
current_node = thought_node
+
+ # Check if the thought node is terminal - if so, break the loop
+ if current_node.is_terminal:
+ break
# Handle tool calls
if current_node.messages[-1].get("tool_calls"):
diff --git a/agents/agents/agents/specialized/gui_agent.py b/agents/agents/agents/specialized/gui_agent.py
new file mode 100644
index 0000000..e9dd812
--- /dev/null
+++ b/agents/agents/agents/specialized/gui_agent.py
@@ -0,0 +1,226 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+from typing import List, Any, Tuple, Dict, Optional
+from ..agent_base import BaseAgent
+from agents.utils.ui_action_parser import parse_action_to_structure_output, IMAGE_FACTOR
+
+# Default image dimensions
+TEST_IMAGE_HEIGHT = 1080
+TEST_IMAGE_WIDTH = 1920
+
+# GUI Agent system prompt
+GUI_AGENT_SYSTEM_PROMPT = """You are a GUI automation agent. Given a task and screenshot, you must analyze the screen and perform the next action.
+
+## Response Format (REQUIRED)
+You MUST always respond with exactly two lines:
+Thought: [Describe what you see and what action to take]
+Action: [Choose ONE action from the list below]
+
+## Action Space
+
+click(start_box='<|box_start|>(x1,y1)<|box_end|>')
+type(content='xxx') # Use escape characters \\', \\", and \\n in content part
+scroll(direction='down or up or right or left')
+
+## Examples
+Example 1:
+Thought: I need to click on the search button at coordinates (100, 200).
+Action: click(start_box='<|box_start|>(100,200)<|box_end|>')
+
+Example 2:
+Thought: I need to type "hello world" in the text field.
+Action: type(content='hello world')
+
+## Note
+- Use English in `Thought` and `Action` part
+- Always provide both Thought and Action lines
+- Coordinates should be in pixel values"""
+
+
+class GUIAgent(BaseAgent):
+ """GUI Agent for interacting with graphical user interfaces."""
+
+ def __init__(
+ self,
+ model_name_or_path: str,
+ template: str,
+ tools: List = None,
+ **kwargs
+ ):
+ """
+ Initialize GUI Agent.
+
+ Args:
+ model_name_or_path: Path to the vision-language model
+ template: Template name for formatting prompts
+ tools: List of tools available to the agent
+ **kwargs: Additional arguments
+ """
+ super().__init__(
+ model_name_or_path=model_name_or_path,
+ template=template,
+ system_prompt=GUI_AGENT_SYSTEM_PROMPT,
+ tools=tools,
+ max_length=kwargs.get("max_length", 8192),
+ **kwargs
+ )
+ self.action_counter = 0 # Track number of actions taken
+ self.max_retries = 3 # Maximum retries for empty responses
+
+ # def _init_llm_engine(self, model_name_or_path: str, backend: str = "vllm"):
+ # """
+ # Override to handle vision-language models properly.
+
+ # For GUI agents using vision-language models like Qwen2.5-VL,
+ # we need special handling since they're not standard causal LM models.
+ # """
+ # # For unit tests or when model loading should be skipped
+ # # if model_name_or_path == "ByteDance-Seed/UI-TARS-1.5-7B":
+ # # # Return mock objects for testing
+ # # print(f"[GUIAgent] Skipping actual model load for testing: {model_name_or_path}")
+ # # return None, None, None
+
+ # # Otherwise use parent's initialization
+ # return super()._init_llm_engine(model_name_or_path, backend)
+
+ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]:
+ """
+ Parse model responses into structured messages.
+
+ Args:
+ responses: List of model response strings
+ tools: List of available tools
+
+ Returns:
+ List of structured messages with tool calls
+ """
+ print(f"[GUIAgent.parse] Number of responses: {len(responses)}")
+ print(f"[GUIAgent.parse] Raw responses type: {type(responses)}")
+
+ new_messages_list = []
+
+ # Process each response
+ processed_responses = []
+ for resp in responses:
+ if resp and "Thought:" in resp and "Action:" in resp:
+ processed_responses.append(resp)
+ elif resp and resp.strip():
+ # Try to reformat responses that don't have the expected format
+ resp_lower = resp.lower()
+ print(f"[GUIAgent.parse] Response missing format, reformatting: {resp[:100]}")
+
+ # Check if it contains action-like content
+ if any(action in resp_lower for action in ['click', 'type', 'scroll']):
+ formatted_resp = f"Thought: Executing action based on response.\nAction: {resp.strip()}"
+ else:
+ # Default to click at center if no clear action
+ formatted_resp = f"Thought: {resp.strip()}\nAction: click(start_box='<|box_start|>(960,540)<|box_end|>')"
+ processed_responses.append(formatted_resp)
+ else:
+ # Handle empty responses with default click at center
+ self.action_counter += 1
+ processed_responses.append(f"Thought: Processing the screen (attempt {self.action_counter}).\nAction: click(start_box='<|box_start|>(960,540)<|box_end|>')")
+
+ responses = processed_responses
+
+ # Log responses for debugging
+ for idx, resp in enumerate(responses[:3]): # Log first 3 responses
+ if resp:
+ print(f"[GUIAgent.parse] Response {idx} length: {len(resp)}, preview: {resp[:200]}")
+ else:
+ print(f"[GUIAgent.parse] Response {idx} is None or empty")
+
+ # Parse actions from responses
+ action_list = []
+ for response in responses:
+ parsed = parse_action_to_structure_output(
+ response,
+ IMAGE_FACTOR,
+ TEST_IMAGE_HEIGHT,
+ TEST_IMAGE_WIDTH
+ )
+ action_list.append(parsed)
+
+ # Create messages with tool calls
+ for i, (response, actions) in enumerate(zip(responses, action_list)):
+ print(f"[GUIAgent.parse] Processing response {i+1}: response_length={len(response) if response else 0}, actions={actions}")
+
+ tool_calls = []
+
+ if actions is not None and len(actions) > 0:
+ if len(actions) > 1:
+ print(f"[GUIAgent.parse] Warning: Multiple actions found ({len(actions)}), using first one")
+ action = actions[0]
+ tool_calls = [{
+ "id": str(i),
+ "type": "function",
+ "function": {
+ "name": "pyautogui_code_generator",
+ "arguments": json.dumps({"action": action})
+ }
+ }]
+ else:
+ # If no action was parsed, create a default click action at center
+ print(f"[GUIAgent.parse] No action parsed from response, creating default click action")
+ default_action = {
+ "action_type": "click",
+ "action_inputs": {"start_box": "(960, 540)"},
+ "thought": "Clicking at screen center",
+ "reflection": None
+ }
+ tool_calls = [{
+ "id": str(i),
+ "type": "function",
+ "function": {
+ "name": "pyautogui_code_generator",
+ "arguments": json.dumps({"action": default_action})
+ }
+ }]
+
+ # Always terminate after one turn since we only have 3 action types
+ # and no explicit termination action
+ status = "terminal"
+ if actions and isinstance(actions[0], dict):
+ action_type = actions[0].get("action_type", "")
+ print(f"[GUIAgent.parse] Action type: {action_type}, terminating after one turn")
+
+ message = {
+ "role": "assistant",
+ "content": [{"type": "text", "text": response}] if response else [{"type": "text", "text": ""}],
+ "tool_calls": tool_calls,
+ "loss": True,
+ "status": status
+ }
+ print(f"[GUIAgent.parse] Created message with status={status}, tool_calls={len(tool_calls)}, content_length={len(response)}")
+ new_messages_list.append(message)
+
+ return new_messages_list
+
+ def format_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Format messages for the vision-language model.
+
+ Args:
+ messages: List of messages to format
+
+ Returns:
+ Formatted messages suitable for VLM input
+ """
+ formatted_messages = []
+
+ for msg in messages:
+ formatted_msg = {
+ "role": msg.get("role"),
+ "content": msg.get("content", "")
+ }
+
+ # Handle image content if present
+ if "images" in msg:
+ # Convert images to appropriate format for the model
+ formatted_msg["images"] = msg["images"]
+
+ formatted_messages.append(formatted_msg)
+
+ return formatted_messages
\ No newline at end of file
diff --git a/agents/agents/rewards/__init__.py b/agents/agents/rewards/__init__.py
index 1d0b712..a9a6185 100644
--- a/agents/agents/rewards/__init__.py
+++ b/agents/agents/rewards/__init__.py
@@ -4,6 +4,7 @@
from .webshop_reward import webshop_reward
from .alfworld_reward import alfworld_episode_reward
from .scienceworld_reward import scienceworld_reward
+from .gui_reward import gui_reward
-__all__ = ["alfworld_episode_reward","qa_f1_reward", "math_reward", "math_reward_tool", "math_reward_think", "RewardFunction", "get_reward_from_name", "get_rewards_from_names", "list_available_rewards", "register_reward", "llm_as_judge_client_math_reward", "webshop_reward", "alfworld_episode_reward"]
+__all__ = ["alfworld_episode_reward","qa_f1_reward", "math_reward", "math_reward_tool", "math_reward_think", "RewardFunction", "get_reward_from_name", "get_rewards_from_names", "list_available_rewards", "register_reward", "llm_as_judge_client_math_reward", "webshop_reward", "gui_reward"]
\ No newline at end of file
diff --git a/agents/agents/rewards/gui_reward.py b/agents/agents/rewards/gui_reward.py
new file mode 100644
index 0000000..ae24d7a
--- /dev/null
+++ b/agents/agents/rewards/gui_reward.py
@@ -0,0 +1,357 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+import json
+import ast
+from typing import Dict, Any, List, Tuple, Optional
+
+from .reward_base import reward
+from agents.utils.ui_action_parser import parse_action_to_structure_output, IMAGE_FACTOR
+
+# Image dimensions for testing
+TEST_IMAGE_HEIGHT = 1080
+TEST_IMAGE_WIDTH = 1920
+
+
+def normalize_answer(s: str) -> set:
+ """Normalize answer string for comparison."""
+ def remove_punctuation(text):
+ return re.sub(r"[^\w\s]", "", text)
+
+ def lower(text):
+ return text.lower()
+
+ return set(lower(remove_punctuation(s)).split())
+
+
+def f1_score(prediction: str, ground_truth: str) -> Tuple[float, float, float]:
+ """Calculate F1 score between prediction and ground truth."""
+ normalized_prediction = normalize_answer(prediction)
+ normalized_ground_truth = normalize_answer(ground_truth)
+
+ if not normalized_prediction and not normalized_ground_truth:
+ return 1.0, 1.0, 1.0
+
+ common_tokens = normalized_prediction.intersection(normalized_ground_truth)
+
+ precision = len(common_tokens) / len(normalized_prediction) if normalized_prediction else 0.0
+ recall = len(common_tokens) / len(normalized_ground_truth) if normalized_ground_truth else 0.0
+
+ f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
+
+ return f1, precision, recall
+
+
+def extract_action(content: str) -> str:
+ """Extract action type from response content."""
+ try:
+ parsed = parse_action_to_structure_output(content, IMAGE_FACTOR, TEST_IMAGE_HEIGHT, TEST_IMAGE_WIDTH, model_type="default")
+ if parsed and len(parsed) > 0:
+ action_dict = parsed[0]
+ action_type = action_dict.get("action_type", "no action")
+
+ # Check for specific action types in raw text
+ if action_type == "click" and "Action:" in content:
+ action_text = content.split("Action:")[1].strip()
+ if action_text.startswith("left_double"):
+ return "left_double"
+ elif action_text.startswith("right_single"):
+ return "right_single"
+ elif action_text.startswith("click"):
+ return "click"
+ else:
+ return "click"
+
+ # Map normalized action types
+ if action_type == "hotkey":
+ return "hotkey"
+ elif action_type == "drag":
+ return "drag"
+
+ return action_type
+ except Exception as e:
+ print(f"[extract_action] Error: {e}")
+ return "no action"
+
+
+def extract_input_text(content: str) -> str:
+ """Extract input text from action content."""
+ try:
+ parsed = parse_action_to_structure_output(content, IMAGE_FACTOR, TEST_IMAGE_HEIGHT, TEST_IMAGE_WIDTH, model_type="default")
+ if parsed and len(parsed) > 0:
+ action_dict = parsed[0]
+ action_type = action_dict.get("action_type")
+ action_inputs = action_dict.get("action_inputs", {})
+
+ # Extract text based on action type
+ if action_type == 'type':
+ return action_inputs.get('content', '')
+ elif action_type == 'scroll':
+ return action_inputs.get('direction', 'down')
+ elif action_type == 'hotkey':
+ return action_inputs.get('key', action_inputs.get('hotkey', ''))
+ elif action_type == 'finished':
+ return action_inputs.get('content', '')
+
+ return ""
+ except Exception as e:
+ print(f"[extract_input_text] Error: {e}")
+ return ""
+
+
+def extract_coord(content: str) -> Tuple[list, bool]:
+ """Extract coordinates from action content and normalize them to 0-1 range."""
+ try:
+ parsed = parse_action_to_structure_output(content, IMAGE_FACTOR, TEST_IMAGE_HEIGHT, TEST_IMAGE_WIDTH, model_type="default")
+ if parsed and len(parsed) > 0:
+ action_dict = parsed[0]
+ action_inputs = action_dict.get("action_inputs", {})
+
+ # Try to get coordinates from start_box
+ if "start_box" in action_inputs:
+ try:
+ coords = ast.literal_eval(action_inputs["start_box"])
+ # Ensure coords is a list
+ if isinstance(coords, (list, tuple)):
+ if len(coords) == 2:
+ # Point format [x, y] in pixels - normalize to 0-1
+ normalized_coords = [
+ coords[0] / TEST_IMAGE_WIDTH,
+ coords[1] / TEST_IMAGE_HEIGHT
+ ]
+ print(f"[extract_coord] Normalized point from {coords} to {normalized_coords}")
+ return normalized_coords, True
+ elif len(coords) == 4:
+ # Box format [x1, y1, x2, y2] in pixels - normalize to 0-1
+ normalized_coords = [
+ coords[0] / TEST_IMAGE_WIDTH,
+ coords[1] / TEST_IMAGE_HEIGHT,
+ coords[2] / TEST_IMAGE_WIDTH,
+ coords[3] / TEST_IMAGE_HEIGHT
+ ]
+ print(f"[extract_coord] Normalized box from {coords} to {normalized_coords}")
+ return normalized_coords, True
+ else:
+ print(f"[extract_coord] Unexpected coord format: {coords}")
+ except Exception as e:
+ print(f"[extract_coord] Error parsing coordinates: {e}")
+
+ except Exception as e:
+ print(f"[extract_coord] Error: {e}")
+ return [], False
+
+
+def gui_format_score(predict_str: str) -> float:
+ """Calculate format score for GUI prediction."""
+ try:
+ parsed_actions = parse_action_to_structure_output(predict_str, IMAGE_FACTOR, TEST_IMAGE_HEIGHT, TEST_IMAGE_WIDTH, model_type="default")
+ return 1.0 if parsed_actions else 0.0
+ except Exception:
+ return 0.0
+
+
+def gui_accuracy_score(predict_str: str, gt_action: str, gt_bbox: list, gt_input_text: str) -> float:
+ """Calculate accuracy score for GUI prediction (0.5 for action type, 0.5 for parameters)."""
+ try:
+ gt_action = gt_action.lower() if gt_action else ''
+
+ pred_action = extract_action(predict_str).lower()
+ pred_coord, has_coord = extract_coord(predict_str)
+ pred_input_text = extract_input_text(predict_str)
+
+ print(f"[gui_accuracy_score] gt_action: {gt_action}, pred_action: {pred_action}")
+ print(f"[gui_accuracy_score] gt_bbox: {gt_bbox}, pred_coord: {pred_coord}, has_coord: {has_coord}")
+ print(f"[gui_accuracy_score] gt_input_text: {gt_input_text}, pred_input_text: {pred_input_text}")
+
+ # Map all click variants to 'click' for the 3-action space
+ action_mapping = {
+ 'left_single': 'click',
+ 'left_double': 'click',
+ 'right_single': 'click',
+ 'click': 'click',
+ 'type': 'type',
+ 'scroll': 'scroll'
+ }
+
+ # Normalize actions
+ pred_action_normalized = action_mapping.get(pred_action, pred_action)
+ gt_action_normalized = action_mapping.get(gt_action, gt_action)
+
+ score = 0.0
+
+ # Score calculation: 0.5 for action type, 0.5 for parameters (bbox/text)
+
+ # 1. Action type matching (0.5 points)
+ if pred_action_normalized == gt_action_normalized:
+ score += 0.5
+ print(f"[gui_accuracy_score] Action matched: +0.5 points")
+ else:
+ print(f"[gui_accuracy_score] Action mismatch: {pred_action_normalized} vs {gt_action_normalized}")
+
+ # 2. Parameter matching (0.5 points) - depends on action type
+ if gt_action_normalized == 'click':
+ # For click: check bbox coordinates
+ if gt_bbox and len(gt_bbox) > 0:
+ if has_coord:
+ # Both have coordinates, calculate distance
+ # Handle different gt_bbox formats (already normalized 0-1)
+ if len(gt_bbox) == 2:
+ gt_x, gt_y = gt_bbox
+ elif len(gt_bbox) == 4:
+ gt_x = (gt_bbox[0] + gt_bbox[2]) / 2
+ gt_y = (gt_bbox[1] + gt_bbox[3]) / 2
+ else:
+ print(f"[gui_accuracy_score] Invalid gt_bbox format: {gt_bbox}")
+ return score
+
+ # Get predicted center (already normalized 0-1)
+ if len(pred_coord) == 2:
+ pred_x, pred_y = pred_coord
+ elif len(pred_coord) == 4:
+ pred_x = (pred_coord[0] + pred_coord[2]) / 2
+ pred_y = (pred_coord[1] + pred_coord[3]) / 2
+ else:
+ print(f"[gui_accuracy_score] Invalid pred_coord format: {pred_coord}")
+ return score
+
+ # Calculate distance in normalized space
+ distance = ((pred_x - gt_x) ** 2 + (pred_y - gt_y) ** 2) ** 0.5
+
+ # Threshold in normalized space (5% of diagonal)
+ threshold = 0.05 * (2 ** 0.5) # sqrt(1^2 + 1^2) ≈ 0.07
+
+ if distance < threshold:
+ score += 0.5
+ print(f"[gui_accuracy_score] Bbox matched (distance={distance:.4f}): +0.5 points")
+ else:
+ print(f"[gui_accuracy_score] Bbox too far (distance={distance:.4f}, threshold={threshold:.4f})")
+ else:
+ print(f"[gui_accuracy_score] No predicted coordinates for click action")
+ else:
+ # No gt_bbox required, any click gets parameter points
+ score += 0.5
+ print(f"[gui_accuracy_score] No gt_bbox required: +0.5 points")
+
+ elif gt_action_normalized == 'type':
+ # For type: check text content
+ if gt_input_text and gt_input_text != "no input text":
+ f1, _, _ = f1_score(pred_input_text, gt_input_text)
+ if f1 >= 0.5:
+ score += 0.5
+ print(f"[gui_accuracy_score] Type text matched (f1={f1:.2f}): +0.5 points")
+ else:
+ print(f"[gui_accuracy_score] Type text mismatch (f1={f1:.2f})")
+ else:
+ # No text required, any type action gets parameter points
+ score += 0.5
+ print(f"[gui_accuracy_score] No text required: +0.5 points")
+
+ elif gt_action_normalized == 'scroll':
+ # For scroll: only check direction (no bbox needed)
+ if gt_input_text and gt_input_text != "no input text":
+ if pred_input_text.lower() == gt_input_text.lower():
+ score += 0.5
+ print(f"[gui_accuracy_score] Scroll direction matched: +0.5 points")
+ else:
+ print(f"[gui_accuracy_score] Scroll direction mismatch: {pred_input_text} vs {gt_input_text}")
+ else:
+ # No direction specified, any scroll gets parameter points
+ score += 0.5
+ print(f"[gui_accuracy_score] No scroll direction required: +0.5 points")
+
+ print(f"[gui_accuracy_score] Final score: {score}")
+ return score
+
+ except Exception as e:
+ print(f"Error in gui_accuracy_score: {e}")
+ print(f"predict_str: {predict_str}")
+ print(f"gt_action: {gt_action}, gt_bbox: {gt_bbox}, gt_input_text: {gt_input_text}")
+ return 0.0
+
+
+@reward(name="gui_reward")
+def gui_reward(prediction: str, trajectory: List[Dict] = None, gt_action: str = "", gt_bbox: list = None, gt_input_text: str = "", **kwargs) -> Dict[str, float]:
+ """
+ Calculate GUI reward based on prediction accuracy.
+
+ Args:
+ prediction: Model prediction string
+ trajectory: Conversation trajectory (optional)
+ **kwargs: Additional parameters including ground truth
+
+ Returns:
+ Dictionary with reward scores
+ """
+ print(f"[gui_reward] Called with prediction: {prediction[:200] if prediction else 'None'}")
+ print(f"[gui_reward] kwargs keys: {list(kwargs.keys())}")
+
+ # Handle empty predictions
+ if not prediction or prediction.strip() == "":
+ print(f"[gui_reward] Warning: Empty prediction received")
+ # Check if there's a default action in trajectory
+ if trajectory and len(trajectory) > 0:
+ for msg in reversed(trajectory):
+ if msg.get('role') == 'assistant' and msg.get('content'):
+ prediction = msg['content']
+ print(f"[gui_reward] Using trajectory content as prediction: {prediction[:100]}")
+ break
+
+ # if not prediction or prediction.strip() == "":
+ # prediction = "Thought: No response generated.\nAction: wait()"
+ # print(f"[gui_reward] Using default prediction")
+
+ # Handle None values for parameters
+ if gt_bbox is None:
+ gt_bbox = []
+
+ # Convert numpy array to list if needed
+ if hasattr(gt_bbox, 'tolist'):
+ gt_bbox = gt_bbox.tolist()
+
+ print(f"[gui_reward] gt_action: {gt_action}, gt_bbox: {gt_bbox}, gt_input_text: {gt_input_text}")
+
+ # Handle "no input text" as empty
+ if gt_input_text == "no input text":
+ gt_input_text = ""
+
+ # Keep bbox in normalized coordinates (0-1 range)
+ # Both prediction and ground truth use normalized coordinates
+
+ if not gt_action and not gt_bbox and not gt_input_text:
+ print(f"[gui_reward] Warning: No ground truth data provided - returning 0 reward")
+ return {
+ "reward": 0.0,
+ "format": gui_format_score(prediction),
+ "accuracy": 0.0,
+ "f1": 0.0,
+ "precision": 0.0,
+ "recall": 0.0,
+ }
+
+ # Calculate scores
+ format_score = gui_format_score(prediction)
+ accuracy_score = gui_accuracy_score(prediction, gt_action, gt_bbox, gt_input_text)
+
+ print(f"[gui_reward] format_score: {format_score}, accuracy_score: {accuracy_score}")
+
+ # For f1_score, create answer string for backward compatibility
+ answer_dict = {
+ "action": gt_action,
+ "gt_bbox": gt_bbox,
+ "input_text": gt_input_text
+ }
+ answer = json.dumps(answer_dict)
+ f1, precision, recall = f1_score(prediction, answer)
+
+ # Calculate final reward (weighted combination)
+ final_reward = 0.8 * accuracy_score + 0.2 * format_score
+
+ return {
+ "reward": final_reward,
+ "format": format_score,
+ "accuracy": accuracy_score,
+ "f1": f1,
+ "precision": precision,
+ "recall": recall,
+ }
\ No newline at end of file
diff --git a/agents/agents/rewards/qa_reward.py b/agents/agents/rewards/qa_reward.py
index bd335b5..a6dd139 100644
--- a/agents/agents/rewards/qa_reward.py
+++ b/agents/agents/rewards/qa_reward.py
@@ -1,7 +1,7 @@
import re
import string
from collections import Counter
-from typing import List
+from typing import List, Dict, Union
from .reward_base import reward
@@ -118,4 +118,24 @@ def ok_vqa_reward(prediction: str, answers: List[str], trajectory: List[str]) ->
f1, precision, recall = f1_score(prediction, answer)
f1_scores.append(f1)
# All answers are the correct answer, take the max f1 score
- return max(f1_scores)
\ No newline at end of file
+ return max(f1_scores)
+
+
+@reward(name="infoseek_reward")
+def infoseek_reward(prediction: str, answer: Union[str, List[str]], answer_eval: List[Union[str, Dict]], trajectory: List[str]) -> float:
+ f1_scores = []
+ answers = []
+ if isinstance(answer, str):
+ answers.append(answer)
+ elif isinstance(answer, list):
+ answers.extend(answer)
+
+    if answer_eval and isinstance(answer_eval[0], str):
+ answers.extend(answer_eval)
+
+ for _answer in answers:
+ f1, precision, recall = f1_score(prediction, _answer)
+ f1_scores.append(f1)
+
+ # All answers are the correct answer, take the max f1 score
+ return max(f1_scores)
diff --git a/agents/agents/rewards/scienceworld_reward.py b/agents/agents/rewards/scienceworld_reward.py
index bdef743..03c7922 100644
--- a/agents/agents/rewards/scienceworld_reward.py
+++ b/agents/agents/rewards/scienceworld_reward.py
@@ -1,4 +1,4 @@
-from agents.envs.scienceworld_env import ScienceWorldEnv
+from ..envs.scienceworld_env import ScienceWorldEnv
from .reward_base import reward
@reward(name="scienceworld_reward", env_cls=ScienceWorldEnv, pool_size=8)
diff --git a/agents/agents/rewards/webshop_reward.py b/agents/agents/rewards/webshop_reward.py
index 1bf0b1a..52ba44b 100644
--- a/agents/agents/rewards/webshop_reward.py
+++ b/agents/agents/rewards/webshop_reward.py
@@ -1,4 +1,4 @@
-from agents.envs.webshop_text_env import WebAgentTextEnv
+from ..envs.webshop_text_env import WebAgentTextEnv
from .reward_base import reward
@reward(name="webshop_reward", env_cls=WebAgentTextEnv, pool_size=8)
diff --git a/agents/agents/tools/__init__.py b/agents/agents/tools/__init__.py
index bd0f768..98d33ea 100644
--- a/agents/agents/tools/__init__.py
+++ b/agents/agents/tools/__init__.py
@@ -14,6 +14,7 @@
from .src.react.tools import answer_qa, answer_math
from .src.search.async_dense_retriever import asyncdense_retrieve
from .src.scienceworld.tools import scienceworld_explorer
+from .src.ui.tools import pyautogui_code_generator
# Export the tools
__all__ = [
@@ -36,6 +37,7 @@
"alfworld_get_task_objective"
"alfworld_reset"
"asyncdense_retrieve"
+    "pyautogui_code_generator",
# "current_env"
]
@@ -54,7 +56,8 @@
"answer_math": answer_math,
"hallucination_tool": hallucination_tool,
"invalid_input_tool": invalid_input_tool,
- "dense_retrieve": dense_retrieve
+ "dense_retrieve": dense_retrieve,
+ "pyautogui_code_generator": pyautogui_code_generator
}
# Update the registry with explicit tools
diff --git a/agents/agents/tools/src/ui/__init__.py b/agents/agents/tools/src/ui/__init__.py
new file mode 100644
index 0000000..f847595
--- /dev/null
+++ b/agents/agents/tools/src/ui/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: Apache-2.0
+
+from .tools import pyautogui_code_generator
+
+__all__ = ["pyautogui_code_generator"]
\ No newline at end of file
diff --git a/agents/agents/tools/src/ui/tools.py b/agents/agents/tools/src/ui/tools.py
new file mode 100644
index 0000000..02769ab
--- /dev/null
+++ b/agents/agents/tools/src/ui/tools.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+from typing import Any
+from ...tool_base import tool
+from agents.utils.ui_action_parser import parsing_response_to_pyautogui_code
+
+# Default image dimensions for UI interactions
+DEFAULT_IMAGE_HEIGHT = 1080
+DEFAULT_IMAGE_WIDTH = 1920
+
+
+@tool(name="pyautogui_code_generator")
+def pyautogui_code_generator(action: dict, **kwargs) -> str:
+ """
+ Generate PyAutoGUI code from a structured action dictionary.
+
+ Args:
+ action (dict): Dictionary containing action_type, action_inputs, thought, etc.
+ **kwargs: Additional parameters like image dimensions
+
+ Returns:
+ PyAutoGUI code string or execution result
+ """
+ print(f"[pyautogui_code_generator] Received action: {action}")
+
+ # Extract image dimensions from kwargs or use defaults
+ image_height = kwargs.get("image_height", DEFAULT_IMAGE_HEIGHT)
+ image_width = kwargs.get("image_width", DEFAULT_IMAGE_WIDTH)
+
+ # Handle the action
+ if isinstance(action, str):
+ # Try to parse if it's a JSON string
+ try:
+ action = json.loads(action)
+ except json.JSONDecodeError:
+ return f"Error: Invalid action format - {action}"
+
+ if not isinstance(action, dict):
+ return f"Error: Action must be a dictionary, got {type(action)}"
+
+ # Check if this is a terminal action
+ action_type = action.get("action_type", "")
+ if action_type in ["finished", "call_user"]:
+ content = action.get("action_inputs", {}).get("content", "Task completed")
+ return f"Task completed: {content}"
+
+ # Generate PyAutoGUI code for the action
+ try:
+ pyautogui_code = parsing_response_to_pyautogui_code(
+ [action],
+ image_height=image_height,
+ image_width=image_width,
+ input_swap=True
+ )
+
+ # For non-terminal actions, return the code
+ if pyautogui_code == "DONE":
+ return "Task completed successfully"
+
+ return f"Generated PyAutoGUI code:\n{pyautogui_code}"
+
+ except Exception as e:
+ return f"Error generating PyAutoGUI code: {str(e)}"
+
+
+@tool(name="capture_screenshot")
+def capture_screenshot(**kwargs) -> str:
+ """
+ Capture a screenshot of the current screen.
+
+ Returns:
+ Base64 encoded screenshot or error message
+ """
+ try:
+ import pyautogui
+ import base64
+ from io import BytesIO
+
+ # Take screenshot
+ screenshot = pyautogui.screenshot()
+
+ # Convert to base64
+ buffered = BytesIO()
+ screenshot.save(buffered, format="PNG")
+ img_str = base64.b64encode(buffered.getvalue()).decode()
+
+ return f"Screenshot captured successfully (base64): {img_str[:100]}..."
+
+ except ImportError:
+ return "Error: PyAutoGUI not installed. Please install it to capture screenshots."
+ except Exception as e:
+ return f"Error capturing screenshot: {str(e)}"
\ No newline at end of file
diff --git a/agents/agents/utils/monitor.py b/agents/agents/utils/monitor.py
index fdcab41..504f1f5 100644
--- a/agents/agents/utils/monitor.py
+++ b/agents/agents/utils/monitor.py
@@ -24,6 +24,7 @@
from PIL import Image
import io
import base64
+import numpy as np
@dataclass(slots=True)
class MetricEvent:
@@ -74,7 +75,13 @@ def __repr__(self) -> str: # noqa: D401
def serialize_for_json(obj):
- if isinstance(obj, Image.Image):
+ if isinstance(obj, np.ndarray):
+ # Convert numpy array to list
+ return obj.tolist()
+ elif isinstance(obj, (np.integer, np.floating)):
+ # Convert numpy scalars to Python types
+ return obj.item()
+ elif isinstance(obj, Image.Image):
# Convert image to base64 string
buffer = io.BytesIO()
obj.save(buffer, format="PNG")
diff --git a/agents/agents/utils/ui_action_parser.py b/agents/agents/utils/ui_action_parser.py
new file mode 100644
index 0000000..e992fce
--- /dev/null
+++ b/agents/agents/utils/ui_action_parser.py
@@ -0,0 +1,423 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: Apache-2.0
+import re
+import ast
+import math
+from typing import List, Dict, Any, Optional, Tuple
+
+IMAGE_FACTOR = 1 # Changed to match gui_reward.py
+MIN_PIXELS = 100 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
+
+
+def convert_point_to_coordinates(text: str, is_answer: bool = False) -> str:
+ """Convert point format to coordinates."""
+    pattern = r"<point>(\d+)\s+(\d+)</point>"
+
+ def replace_match(match):
+ x1, y1 = map(int, match.groups())
+ x = (x1 + x1) // 2 # Truncate to integer
+ y = (y1 + y1) // 2 # Truncate to integer
+ if is_answer:
+ return f"({x},{y})" # Only return (x, y) format
+ return f"({x},{y})" # Return the format with tags
+
+ # Remove [EOS] and replace coordinates
+ text = re.sub(r"\[EOS\]", "", text)
+ return re.sub(pattern, replace_match, text).strip()
+
+
+def parse_action(action_str: str) -> Optional[Dict[str, Any]]:
+ """Parse an action string into function name and arguments."""
+ try:
+ # Parse the string to AST node
+ node = ast.parse(action_str, mode='eval')
+
+ # Ensure the node is an expression
+ if not isinstance(node, ast.Expression):
+ raise ValueError("Not an expression")
+
+ # Get the body of the expression
+ call = node.body
+
+ # Ensure the body is a function call
+ if not isinstance(call, ast.Call):
+ raise ValueError("Not a function call")
+
+ # Get the function name
+ if isinstance(call.func, ast.Name):
+ func_name = call.func.id
+ elif isinstance(call.func, ast.Attribute):
+ func_name = call.func.attr
+ else:
+ func_name = None
+
+ # Get the keyword arguments
+ kwargs = {}
+ for kw in call.keywords:
+ key = kw.arg
+ # Handle different types of values
+ if isinstance(kw.value, ast.Constant):
+ value = kw.value.value
+ elif isinstance(kw.value, ast.Str): # Compatible with old version Python
+ value = kw.value.s
+ else:
+ value = None
+ kwargs[key] = value
+
+ return {'function': func_name, 'args': kwargs}
+
+ except Exception as e:
+ print(f"Failed to parse action '{action_str}': {e}")
+ return None
+
+
+def escape_single_quotes(text: str) -> str:
+ """Escape unescaped single quotes."""
+    pattern = r"(?<!\\)'"
+    return re.sub(pattern, r"\\'", text)
+
+
+def round_by_factor(number: int, factor: int) -> int:
+ """Returns the closest integer to 'number' that is divisible by 'factor'."""
+ return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+ """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+ return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+ """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+ return math.floor(number / factor) * factor
+
+
+def linear_resize(height: int,
+ width: int,
+ factor: int = IMAGE_FACTOR,
+ min_pixels: int = MIN_PIXELS,
+ max_pixels: int = MAX_PIXELS) -> Tuple[int, int]:
+ """Resize image to fit within pixel limits while maintaining aspect ratio."""
+ if width * height > max_pixels:
+ resize_factor = math.sqrt(max_pixels / (width * height))
+ width, height = int(width * resize_factor), int(height * resize_factor)
+ if width * height < min_pixels:
+ resize_factor = math.sqrt(min_pixels / (width * height))
+ width, height = math.ceil(width * resize_factor), math.ceil(height * resize_factor)
+ return height, width
+
+
+def smart_resize(height: int,
+ width: int,
+ factor: int = IMAGE_FACTOR,
+ min_pixels: int = MIN_PIXELS,
+ max_pixels: int = MAX_PIXELS) -> Tuple[int, int]:
+ """
+ Rescales the image so that:
+ 1. Both dimensions are divisible by 'factor'.
+ 2. Total pixels is within [min_pixels, max_pixels].
+ 3. Aspect ratio is maintained as closely as possible.
+ """
+ if max(height, width) / min(height, width) > MAX_RATIO:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+ )
+ h_bar = max(factor, round_by_factor(height, factor))
+ w_bar = max(factor, round_by_factor(width, factor))
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = floor_by_factor(height / beta, factor)
+ w_bar = floor_by_factor(width / beta, factor)
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = ceil_by_factor(height * beta, factor)
+ w_bar = ceil_by_factor(width * beta, factor)
+ return h_bar, w_bar
+
+
+def parse_action_to_structure_output(text: str,
+ factor: int,
+ origin_resized_height: int,
+ origin_resized_width: int,
+ model_type: str = "qwen25vl",
+ max_pixels: int = 16384 * 28 * 28,
+ min_pixels: int = 100 * 28 * 28) -> Optional[List[Dict[str, Any]]]:
+ """Parse action text to structured output."""
+ print(f"[parse_action_to_structure_output] Input text: {text[:500] if text else 'Empty text'}")
+
+ # Handle empty or None responses
+ if not text:
+ print(f"[parse_action_to_structure_output] Empty text, returning None")
+ return None
+
+ text = text.strip()
+
+ # Handle various point/box formats
+    if "<point>" in text:
+ text = convert_point_to_coordinates(text)
+ if "start_point=" in text:
+ text = text.replace("start_point=", "start_box=")
+ if "end_point=" in text:
+ text = text.replace("end_point=", "end_box=")
+ if "point=" in text:
+ text = text.replace("point=", "start_box=")
+
+ smart_resize_height, smart_resize_width = origin_resized_height, origin_resized_width
+ if model_type == "qwen25vl":
+ smart_resize_height, smart_resize_width = smart_resize(
+ origin_resized_height,
+ origin_resized_width,
+ factor=IMAGE_FACTOR,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels)
+
+ # Extract thought and reflection
+ if text.startswith("Thought:"):
+ thought_pattern = r"Thought: (.+?)(?=\s*Action: |$)"
+ thought_hint = "Thought: "
+ elif text.startswith("Reflection:"):
+ thought_pattern = r"Reflection: (.+?)Action_Summary: (.+?)(?=\s*Action: |$)"
+ thought_hint = "Reflection: "
+ elif text.startswith("Action_Summary:"):
+ thought_pattern = r"Action_Summary: (.+?)(?=\s*Action: |$)"
+ thought_hint = "Action_Summary: "
+ else:
+ thought_pattern = r"Thought: (.+?)(?=\s*Action: |$)"
+ thought_hint = "Thought: "
+
+ reflection, thought = None, None
+ thought_match = re.search(thought_pattern, text, re.DOTALL)
+ if thought_match:
+ if len(thought_match.groups()) == 1:
+ thought = thought_match.group(1).strip()
+ elif len(thought_match.groups()) == 2:
+ thought = thought_match.group(2).strip()
+ reflection = thought_match.group(1).strip()
+
+ if "Action:" not in text:
+ print(f"[parse_action_to_structure_output] No 'Action:' found in text, returning None")
+ return None
+
+ action_str = text.split("Action: ")[-1]
+ print(f"[parse_action_to_structure_output] Extracted action string: {action_str[:200]}")
+
+ # Parse multiple actions
+ tmp_all_action = action_str.split(")\n\n")
+ all_action = []
+ for action_str in tmp_all_action:
+ if "type(content" in action_str:
+ if not action_str.strip().endswith(")"):
+ action_str = action_str.strip() + ")"
+ # Handle type content escaping
+ pattern = r"type\(content='(.*?)'\)"
+ if re.search(pattern, action_str):
+ content = re.sub(pattern, lambda m: m.group(1), action_str)
+ action_str = escape_single_quotes(content)
+ action_str = "type(content='" + action_str + "')"
+ if not action_str.strip().endswith(")"):
+ action_str = action_str.strip() + ")"
+ all_action.append(action_str)
+
+ parsed_actions = [
+ parse_action(action.replace("\n", "\\n").lstrip())
+ for action in all_action
+ ]
+
+ actions = []
+ for action_instance, raw_str in zip(parsed_actions, all_action):
+ if action_instance is None:
+ print(f"Action can't parse: {raw_str}")
+ continue
+
+ action_type = action_instance["function"]
+ params = action_instance["args"]
+
+ action_inputs = {}
+ for param_name, param in params.items():
+ if param is None or param == "":
+ continue
+
+ # Only apply lstrip to string parameters
+ if isinstance(param, str):
+ param = param.lstrip()
+
+ action_inputs[param_name.strip()] = param
+
+ # Handle box coordinates (only for string parameters)
+ if isinstance(param, str) and ("start_box" in param_name or "end_box" in param_name):
+ ori_box = param
+ # Remove box tags if present
+ ori_box = ori_box.replace("<|box_start|>", "").replace("<|box_end|>", "")
+ numbers = ori_box.replace("(", "").replace(")", "").split(",")
+
+ try:
+ for num in numbers:
+ float(num.strip())
+ except ValueError:
+ print(f"Warning: Invalid coordinate format in '{param_name}': '{ori_box}'")
+ return None
+
+ # Convert coordinates based on model type
+ if model_type == "qwen25vl":
+ float_numbers = []
+ for num_idx, num in enumerate(numbers):
+ num = float(num)
+ if (num_idx + 1) % 2 == 0:
+ float_numbers.append(float(num / smart_resize_height))
+ else:
+ float_numbers.append(float(num / smart_resize_width))
+ else:
+ # For IMAGE_FACTOR = 1, keep coordinates as pixel values
+ float_numbers = [float(num.strip()) for num in numbers]
+
+ if len(float_numbers) == 2:
+ float_numbers = [
+ float_numbers[0], float_numbers[1],
+ float_numbers[0], float_numbers[1]
+ ]
+ action_inputs[param_name.strip()] = str(float_numbers)
+
+ # Normalize action types for consistency
+ normalized_action_type = action_type
+ if action_type in ["left_single", "left_double", "right_single"]:
+ normalized_action_type = "click"
+ elif action_type in ["press", "keydown", "release", "keyup"]:
+ normalized_action_type = "hotkey"
+ elif action_type in ["select"]:
+ normalized_action_type = "drag"
+
+ actions.append({
+ "reflection": reflection,
+ "thought": thought,
+ "action_type": normalized_action_type,
+ "action_inputs": action_inputs,
+ "text": text,
+ })
+
+ return actions
+
+
+def parsing_response_to_pyautogui_code(responses: List[Dict[str, Any]],
+ image_height: int,
+ image_width: int,
+ input_swap: bool = True) -> str:
+ """Convert parsed responses to PyAutoGUI code."""
+ pyautogui_code = f"import pyautogui\nimport time\n"
+ if isinstance(responses, dict):
+ responses = [responses]
+
+ for response_id, response in enumerate(responses):
+ observation = response.get("observation", "")
+ thought = response.get("thought", "")
+
+ if response_id == 0:
+ pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
+ else:
+ pyautogui_code += f"\ntime.sleep(1)\n"
+
+ action_type = response.get("action_type")
+ action_inputs = response.get("action_inputs", {})
+
+ if action_type == "hotkey":
+ hotkey = action_inputs.get("key", action_inputs.get("hotkey", ""))
+ # Convert arrow keys
+ hotkey = hotkey.replace("arrowleft", "left").replace("arrowright", "right")
+ hotkey = hotkey.replace("arrowup", "up").replace("arrowdown", "down")
+
+ if hotkey:
+ keys = hotkey.split()
+ convert_keys = []
+ for key in keys:
+ if key == "space":
+ key = ' '
+ convert_keys.append(key)
+ pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in convert_keys])})"
+
+ elif action_type == "type":
+ content = action_inputs.get("content", "")
+ content = escape_single_quotes(content)
+ stripped_content = content.rstrip("\\n").rstrip("\n")
+
+ if content:
+ if input_swap:
+ pyautogui_code += f"\nimport pyperclip"
+ pyautogui_code += f"\npyperclip.copy('{stripped_content}')"
+ pyautogui_code += f"\npyautogui.hotkey('ctrl', 'v')"
+ pyautogui_code += f"\ntime.sleep(0.5)\n"
+ if content.endswith("\n") or content.endswith("\\n"):
+ pyautogui_code += f"\npyautogui.press('enter')"
+ else:
+ pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)"
+ pyautogui_code += f"\ntime.sleep(0.5)\n"
+ if content.endswith("\n") or content.endswith("\\n"):
+ pyautogui_code += f"\npyautogui.press('enter')"
+
+ elif action_type in ["drag", "select"]:
+ start_box = action_inputs.get("start_box")
+ end_box = action_inputs.get("end_box")
+ if start_box and end_box:
+ x1, y1, x2, y2 = eval(start_box)
+ sx = round(float((x1 + x2) / 2) * image_width, 3)
+ sy = round(float((y1 + y2) / 2) * image_height, 3)
+ x1, y1, x2, y2 = eval(end_box)
+ ex = round(float((x1 + x2) / 2) * image_width, 3)
+ ey = round(float((y1 + y2) / 2) * image_height, 3)
+ pyautogui_code += (
+ f"\npyautogui.moveTo({sx}, {sy})\n"
+ f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n")
+
+ elif action_type == "scroll":
+ start_box = action_inputs.get("start_box")
+ if start_box:
+ x1, y1, x2, y2 = eval(start_box)
+ x = round(float((x1 + x2) / 2) * image_width, 3)
+ y = round(float((y1 + y2) / 2) * image_height, 3)
+ else:
+ x = None
+ y = None
+
+ direction = action_inputs.get("direction", "")
+
+ if x is None:
+ if "up" in direction.lower():
+ pyautogui_code += f"\npyautogui.scroll(5)"
+ elif "down" in direction.lower():
+ pyautogui_code += f"\npyautogui.scroll(-5)"
+ else:
+ if "up" in direction.lower():
+ pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})"
+ elif "down" in direction.lower():
+ pyautogui_code += f"\npyautogui.scroll(-5, x={x}, y={y})"
+
+ elif action_type in ["click", "left_single", "left_double", "right_single", "hover"]:
+ start_box = action_inputs.get("start_box")
+ start_box = str(start_box)
+ if start_box:
+ start_box = eval(start_box)
+ if len(start_box) == 4:
+ x1, y1, x2, y2 = start_box
+ elif len(start_box) == 2:
+ x1, y1 = start_box
+ x2 = x1
+ y2 = y1
+ x = round(float((x1 + x2) / 2) * image_width, 3)
+ y = round(float((y1 + y2) / 2) * image_height, 3)
+
+ if action_type == "left_single" or action_type == "click":
+ pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')"
+ elif action_type == "left_double":
+ pyautogui_code += f"\npyautogui.doubleClick({x}, {y}, button='left')"
+ elif action_type == "right_single":
+ pyautogui_code += f"\npyautogui.click({x}, {y}, button='right')"
+ elif action_type == "hover":
+ pyautogui_code += f"\npyautogui.moveTo({x}, {y})"
+
+ elif action_type in ["finished"]:
+ pyautogui_code = f"DONE"
+
+ else:
+ pyautogui_code += f"\n# Unrecognized action type: {action_type}"
+
+ return pyautogui_code
\ No newline at end of file
diff --git a/agents/pytest.ini b/agents/pytest.ini
new file mode 100644
index 0000000..f476b91
--- /dev/null
+++ b/agents/pytest.ini
@@ -0,0 +1,6 @@
+[pytest]
+asyncio_mode = auto
+testpaths = tests
+python_files = test_*.py
+python_classes = Test*
+python_functions = test_*
\ No newline at end of file
diff --git a/agents/tests/unit/agents/test_gui_agent.py b/agents/tests/unit/agents/test_gui_agent.py
new file mode 100644
index 0000000..bd1eca7
--- /dev/null
+++ b/agents/tests/unit/agents/test_gui_agent.py
@@ -0,0 +1,201 @@
+#!/usr/bin/env python3
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import sys
+from pathlib import Path
+
+# Add the agents module to path
+sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+from agents.agents.specialized.gui_agent import GUIAgent
+from agents.rewards.gui_reward import gui_reward
+from agents.utils.ui_action_parser import parse_action_to_structure_output, IMAGE_FACTOR
+
+
+class TestGUIAgent:
+ """Test suite for GUI Agent implementation."""
+
+ def test_gui_agent_initialization(self):
+ """Test GUI agent can be initialized."""
+ # Skip loading actual model for unit test
+ agent = GUIAgent(
+ model_name_or_path="ByteDance-Seed/UI-TARS-1.5-7B",
+ template="qwen2.5-vl",
+ tools=["pyautogui_code_generator"]
+ )
+ assert agent is not None
+ assert agent.system_prompt is not None
+ assert "GUI automation agent" in agent.system_prompt
+
+ def test_gui_agent_parse_valid_response(self):
+ """Test GUI agent can parse valid responses."""
+ # Skip loading actual model for unit test
+ agent = GUIAgent(
+ model_name_or_path="ByteDance-Seed/UI-TARS-1.5-7B",
+ template="qwen2.5-vl",
+ tools=[]
+ )
+
+ responses = [
+ "Thought: I need to click on the button.\nAction: click(start_box='<|box_start|>(100,200)<|box_end|>')"
+ ]
+
+ messages = agent.parse(responses, tools=[])
+
+ assert len(messages) == 1
+ assert messages[0]["role"] == "assistant"
+ assert messages[0]["content"][0]["text"] == responses[0]
+ assert len(messages[0]["tool_calls"]) == 1
+ assert messages[0]["status"] == "continue"
+
+ def test_gui_agent_parse_terminal_action(self):
+ """Test GUI agent recognizes terminal actions."""
+ # Skip loading actual model for unit test
+ agent = GUIAgent(
+ model_name_or_path="ByteDance-Seed/UI-TARS-1.5-7B",
+ template="qwen2.5-vl",
+ tools=[]
+ )
+
+ responses = [
+ "Thought: Task is complete.\nAction: finished(content='Task completed successfully')"
+ ]
+
+ messages = agent.parse(responses, tools=[])
+
+ assert len(messages) == 1
+ assert messages[0]["status"] == "terminal"
+
+ def test_gui_agent_parse_empty_response(self):
+ """Test GUI agent handles empty responses gracefully."""
+ # Skip loading actual model for unit test
+ agent = GUIAgent(
+ model_name_or_path="ByteDance-Seed/UI-TARS-1.5-7B",
+ template="qwen2.5-vl",
+ tools=[]
+ )
+
+ responses = [""]
+
+ messages = agent.parse(responses, tools=[])
+
+ assert len(messages) == 1
+ assert "wait" in messages[0]["content"][0]["text"].lower()
+ assert messages[0]["status"] == "continue"
+
+
+class TestGUIReward:
+ """Test suite for GUI reward function."""
+
+ @pytest.mark.asyncio
+ async def test_gui_reward_with_ground_truth(self):
+ """Test GUI reward with ground truth data."""
+ prediction = "Thought: I need to click on the button.\nAction: click(start_box='<|box_start|>(100,200)<|box_end|>')"
+
+ result = await gui_reward(
+ prediction=prediction,
+ gt_action="click",
+ gt_bbox=[100, 200],
+ gt_input_text=""
+ )
+
+ assert isinstance(result, dict)
+ assert "reward" in result
+ assert "format" in result
+ assert "accuracy" in result
+ assert result["format"] == 1.0 # Valid format
+ assert result["accuracy"] == 1.0 # Exact match
+ assert result["reward"] > 0.9 # High reward for perfect match
+
+ @pytest.mark.asyncio
+ async def test_gui_reward_without_ground_truth(self):
+ """Test GUI reward without ground truth data."""
+ prediction = "Thought: I need to click.\nAction: click(start_box='<|box_start|>(100,200)<|box_end|>')"
+
+ result = await gui_reward(prediction=prediction)
+
+ assert isinstance(result, dict)
+ assert result["reward"] == 0.0 # No ground truth
+
+ @pytest.mark.asyncio
+ async def test_gui_reward_empty_prediction(self):
+ """Test GUI reward with empty prediction."""
+ result = await gui_reward(
+ prediction="",
+ gt_action="click",
+ gt_bbox=[100, 200],
+ gt_input_text=""
+ )
+
+ assert isinstance(result, dict)
+ assert result["format"] == 0.0 # Invalid format
+
+ @pytest.mark.asyncio
+ async def test_gui_reward_type_action(self):
+ """Test GUI reward for typing action."""
+ prediction = "Thought: I need to type text.\nAction: type(content='hello world')"
+
+ result = await gui_reward(
+ prediction=prediction,
+ gt_action="type",
+ gt_bbox=[],
+ gt_input_text="hello world"
+ )
+
+ assert isinstance(result, dict)
+ assert result["format"] == 1.0
+ assert result["accuracy"] == 1.0 # Text matches
+
+
+class TestUIActionParser:
+ """Test suite for UI action parser."""
+
+ def test_parse_click_action(self):
+ """Test parsing click action."""
+ text = "Thought: Click button.\nAction: click(start_box='<|box_start|>(100,200)<|box_end|>')"
+
+ result = parse_action_to_structure_output(
+ text, IMAGE_FACTOR, 1080, 1920
+ )
+
+ assert result is not None
+ assert len(result) == 1
+ assert result[0]["action_type"] == "click"
+ assert "start_box" in result[0]["action_inputs"]
+
+ def test_parse_type_action(self):
+ """Test parsing type action."""
+ text = "Thought: Type text.\nAction: type(content='hello world')"
+
+ result = parse_action_to_structure_output(
+ text, IMAGE_FACTOR, 1080, 1920
+ )
+
+ assert result is not None
+ assert len(result) == 1
+ assert result[0]["action_type"] == "type"
+ assert result[0]["action_inputs"]["content"] == "hello world"
+
+ def test_parse_invalid_action(self):
+ """Test parsing invalid action returns None."""
+ text = "This is not a valid action format"
+
+ result = parse_action_to_structure_output(
+ text, IMAGE_FACTOR, 1080, 1920
+ )
+
+ assert result is None
+
+ def test_parse_empty_text(self):
+ """Test parsing empty text returns None."""
+ result = parse_action_to_structure_output(
+ "", IMAGE_FACTOR, 1080, 1920
+ )
+
+ assert result is None
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
\ No newline at end of file
diff --git a/verl b/verl
index 237d9ca..0ba7136 160000
--- a/verl
+++ b/verl
@@ -1 +1 @@
-Subproject commit 237d9cacd2ede001c21f1a1daa44e8e8598993e1
+Subproject commit 0ba71360604c85ca7e83168520169fa858681633