From cd5409d9fefcda6b385955fb1b33cd24fa3a5d05 Mon Sep 17 00:00:00 2001 From: shangkunwang Date: Tue, 19 May 2026 21:38:20 +0000 Subject: [PATCH 1/8] feat: implement AutotuneAgent with planner, runner, and summarizer subagents to automate Pallas kernel optimization. --- .../auto_agent/server_utils/eval_server.py | 35 ++-- .../auto_agent/server_utils/tpu_server.py | 125 ++++++++++++++ .../auto_agent/subagents/autotuning/agent.py | 160 ++++++++++++++++++ .../subagents/autotuning/autotune_tool.py | 128 ++++++++++++++ .../autotuning/prompts/autotune_prompt.py | 26 +++ .../autotuning/prompts/summary_prompt.py | 17 ++ .../auto_agent/subagents/pipeline_agent.py | 16 ++ MaxKernel/auto_agent/tools/file_tools.py | 4 + 8 files changed, 500 insertions(+), 11 deletions(-) create mode 100644 MaxKernel/auto_agent/subagents/autotuning/agent.py create mode 100644 MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py create mode 100644 MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py create mode 100644 MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py diff --git a/MaxKernel/auto_agent/server_utils/eval_server.py b/MaxKernel/auto_agent/server_utils/eval_server.py index c3418bf..72d50b8 100644 --- a/MaxKernel/auto_agent/server_utils/eval_server.py +++ b/MaxKernel/auto_agent/server_utils/eval_server.py @@ -62,12 +62,15 @@ class EvalTypes(Enum): PERFORMANCE_TEST = "performance_test" UNIFIED_TEST = "unified_test" PROFILE = "profile" + AUTOTUNE = "autotune" class EvalRequest(BaseModel): eval_type: EvalTypes - code: str - timeout: Optional[int] = 30 + timeout: int + code: Optional[str] = None + code_template: Optional[str] = None # For autotune + search_space: Optional[dict] = None # For autotune backend_type: Optional[str] = None # "tpu", "cpu", or None for any available dependencies: Optional[dict] = None @@ -146,18 +149,28 @@ async def evaluate(request: EvalRequest): f"{requested_type_msg}" ) - # Send request to backend server - backend_timeout = request.timeout if request.timeout is not None else 30 - client_timeout = aiohttp.ClientTimeout(total=backend_timeout + 10) + # Construct payload based on eval type + payload = { + "eval_type": request.eval_type.value, + "timeout": request.timeout, + } + if request.eval_type == EvalTypes.AUTOTUNE: + payload["code_template"] = request.code_template + payload["search_space"] = request.search_space + else: + payload["code"] = request.code + payload["dependencies"] = request.dependencies + + backend_timeout = ( + request.total_timeout + if request.total_timeout is not None + else request.timeout + ) + client_timeout = aiohttp.ClientTimeout(total=backend_timeout) async with aiohttp.ClientSession(timeout=client_timeout) as session: async with session.post( f"http://{backend_ip}:{backend_port}/{request.eval_type.value}", - json={ - "eval_type": request.eval_type.value, - "code": request.code, - "timeout": request.timeout, - "dependencies": request.dependencies, - }, + json=payload, ) as response: result = await response.json() logging.info( diff --git a/MaxKernel/auto_agent/server_utils/tpu_server.py b/MaxKernel/auto_agent/server_utils/tpu_server.py index 3377677..fa1a6a4 100644 --- a/MaxKernel/auto_agent/server_utils/tpu_server.py +++ b/MaxKernel/auto_agent/server_utils/tpu_server.py @@ -1,4 +1,5 @@ import asyncio +import itertools import json import logging import os @@ -27,6 +28,7 @@ correctness_semaphore = asyncio.Semaphore(1) performance_semaphore = asyncio.Semaphore(1) profile_semaphore = asyncio.Semaphore(1) +autotune_semaphore = asyncio.Semaphore(1) class CodeRequest(BaseModel): @@ -41,6 +43,12 @@ class CodeResponse(BaseModel): exit_code: int +class AutotuneRequest(BaseModel): + code_template: str + search_space: dict[str, list] + timeout: Optional[int] = 300 + + class GetTpuVersionResponse(BaseModel): tpu_version: str @@ -498,6 +506,123 @@ async def profile(request: CodeRequest): logging.info("Profile analysis finished") +@app.post("/autotune", response_model=CodeResponse) +async def autotune(request: AutotuneRequest): + logging.info("Starting autotune") + async with performance_semaphore: + try: + # Generate all combinations + keys = list(request.search_space.keys()) + values = list(request.search_space.values()) + combinations = list(itertools.product(*values)) + + best_time = float("inf") + best_cfg = None + best_output = "" + all_results = [] + + for combo in combinations: + cfg = dict(zip(keys, combo)) + try: + code_content = request.code_template + for k, v in cfg.items(): + code_content = code_content.replace(f"{{{k}}}", str(v)) + except Exception as e: + logging.error(f"Error during template formatting: {e}. Config: {cfg}") + continue + + # Execute the code + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", prefix="hitl_eval_", delete=False + ) as temp_file: + temp_file.write(code_content) + temp_file_path = temp_file.name + + process = None + try: + process = await asyncio.create_subprocess_exec( + sys.executable, + temp_file_path, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=tempfile.gettempdir(), + ) + + stdout, stderr = await asyncio.wait_for( + process.communicate(), timeout=request.timeout + ) + + output = stdout.decode("utf-8") if stdout else "" + error = stderr.decode("utf-8") if stderr else "" + exit_code = process.returncode + + if exit_code == 0: + # Parse RESULT_TIME + match = re.search(r"RESULT_TIME:\s*([0-9.]+)", output) + if match: + time_taken = float(match.group(1)) + all_results.append( + {"cfg": cfg, "time": time_taken, "status": "success"} + ) + if time_taken < best_time: + best_time = time_taken + best_cfg = cfg + best_output = output + else: + logging.warning( + f"No RESULT_TIME found in output for config {cfg}" + ) + all_results.append( + {"cfg": cfg, "status": "no_result_time", "output": output} + ) + else: + logging.warning( + f"Config {cfg} failed with exit code {exit_code}. Stderr: {error}" + ) + all_results.append( + { + "cfg": cfg, + "status": "failed", + "error": error, + "exit_code": exit_code, + } + ) + + except asyncio.TimeoutError: + logging.warning(f"Config {cfg} timed out") + if process: + process.kill() + await process.wait() + all_results.append({"cfg": cfg, "status": "timeout"}) + except Exception as e: + logging.error(f"Error running config {cfg}: {e}") + all_results.append( + {"cfg": cfg, "status": "exception", "error": str(e)} + ) + finally: + if "temp_file_path" in locals(): + try: + os.unlink(temp_file_path) + except OSError: + pass + + return CodeResponse( + output=json.dumps( + { + "best_cfg": best_cfg, + "best_time": best_time, + "best_output": best_output, + "all_results": all_results, + } + ), + exit_code=0 if best_cfg is not None else 1, + ) + + except Exception as e: + logging.error(f"Autotune failed with error: {str(e)}") + raise HTTPException(status_code=500, detail=f"Autotune error: {str(e)}") + + @app.post("/get_tpu_version", response_model=GetTpuVersionResponse) def get_tpu_version() -> str: """Attempts to determine the TPU version by trying three methods. diff --git a/MaxKernel/auto_agent/subagents/autotuning/agent.py b/MaxKernel/auto_agent/subagents/autotuning/agent.py new file mode 100644 index 0000000..b1f46a6 --- /dev/null +++ b/MaxKernel/auto_agent/subagents/autotuning/agent.py @@ -0,0 +1,160 @@ +"""Autotuning agent following the split pattern (Planner + Runner).""" + +import json +import logging +import os +from typing import AsyncGenerator + +from google.adk.agents import BaseAgent, SequentialAgent +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events import Event, EventActions + +from auto_agent.config import model_config, thinking_planner +from auto_agent.constants import MODEL_NAME +from auto_agent.custom_types import CustomLlmAgent +from auto_agent.subagents.autotuning.autotune_tool import autotune_kernel +from auto_agent.subagents.autotuning.prompts import ( + autotune_prompt, + summary_prompt, +) +from auto_agent.tools.search_api_tool import search_api_tool +from auto_agent.tools.tools import filesystem_tool_r, write_autotune_specs_tool + +# 1. Planner Agent (LLM) +# This agent identifies parameters, creates the template, and defines the search space. +# It saves them to session state instead of calling the tool directly. +autotune_planner_agent = CustomLlmAgent( + name="AutotunePlannerAgent", + model=MODEL_NAME, + generate_content_config=model_config, + planner=thinking_planner, + instruction=autotune_prompt.PROMPT, + description="Prepares code template and search space for auto-tuning Pallas kernels.", + tools=[filesystem_tool_r, write_autotune_specs_tool, search_api_tool], +) + + +# 2. Runner Agent +class AutotuneRunner(BaseAgent): + """Executes autotuning via HTTP endpoint.""" + + def __init__( + self, + name: str, + output_key: str, + ): + BaseAgent.__init__(self, name=name) + self.output_key = output_key + + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + + autotune_specs_path = ctx.session.state.get("autotune_specs_path", "") + autotune_results_path = ctx.session.state.get("autotune_results_path", "") + + if not os.path.exists(autotune_specs_path): + error_msg = f"Autotune specs file not found at {autotune_specs_path}" + logging.error(error_msg) + yield Event( + author=self.name, + actions=EventActions( + state_delta={ + self.output_key: {"status": "error", "message": error_msg} + } + ), + ) + return + + try: + with open(autotune_specs_path, "r") as f: + specs = json.load(f) + kernel_name = specs.get("kernel_name", "") + code_template = specs.get("code_template", "") + search_space = specs.get("search_space", {}) + except Exception as e: + error_msg = f"Failed to parse autotune specs JSON: {e}" + logging.error(error_msg) + yield Event( + author=self.name, + actions=EventActions( + state_delta={ + self.output_key: {"status": "error", "message": error_msg} + } + ), + ) + return + + if not kernel_name or not code_template or not search_space: + error_msg = "Missing required inputs in autotune specs file." + logging.error(error_msg) + yield Event( + author=self.name, + actions=EventActions( + state_delta={ + self.output_key: {"status": "error", "message": error_msg} + } + ), + ) + return + + logging.info(f"[{self.name}] Running autotune for {kernel_name}") + + try: + results = autotune_kernel( + kernel_name=kernel_name, + code_template=code_template, + search_space=search_space, + backend="tpu", + ) + + try: + with open(autotune_results_path, "w") as f: + json.dump(results, f) + logging.info(f"[{self.name}] Saved results to {autotune_results_path}") + except Exception as e: + logging.error(f"[{self.name}] Failed to save results to file: {e}") + + yield Event( + author=self.name, + actions=EventActions( + state_delta={ + self.output_key: results, + } + ), + ) + + except Exception as e: + logging.error(f"Exception during autotuning: {e}") + yield Event( + author=self.name, + actions=EventActions( + state_delta={self.output_key: {"status": "error", "message": str(e)}} + ), + ) + + +autotune_runner = AutotuneRunner( + name="AutotuneRunner", + output_key="autotune_results", +) + +# 3. Summarizer Agent (LLM) +# This agent reads results from state and talks to the user. +autotune_summary_agent = CustomLlmAgent( + name="AutotuneSummaryAgent", + model=MODEL_NAME, + generate_content_config=model_config, + planner=thinking_planner, + instruction=summary_prompt.PROMPT, + description="Summarizes autotuning results for the user.", + tools=[filesystem_tool_r], +) + +# 4. Combined Sequential Agent +autotune_agent = SequentialAgent( + name="AutotuneAgent", + sub_agents=[autotune_planner_agent, autotune_runner, autotune_summary_agent], +) + +__all__ = ["autotune_agent"] diff --git a/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py b/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py new file mode 100644 index 0000000..7ad8ca7 --- /dev/null +++ b/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py @@ -0,0 +1,128 @@ +"""Standalone tool for auto-tuning Pallas kernels using grid search on remote servers.""" + +import json +import logging +import subprocess +from typing import Any + +import requests + +from auto_agent.constants import EVAL_SERVER_PORT + +AUTOTUNE_TIMEOUT = 5400 + + +def autotune_kernel( + kernel_name: str, + code_template: str, + search_space: dict[str, list[Any]], + backend: str = None, + server_addr: str = "http://localhost", +) -> dict: + """Runs a grid search to auto-tune a Pallas kernel on a remote server. + + Args: + kernel_name: Name of the kernel. + code_template: Python code containing placeholders for parameters to be + tuned. It should produce a line like "RESULT_TIME: " in its + output to indicate performance. + search_space: A dictionary mapping placeholder names to lists of feasible + values. + backend: 'tpu' or 'cpu'. + server_addr: Address of the server (default: http://localhost). + + Returns: + A dictionary containing the status, optimal parameters, and a summary of + results. + """ + logging.info( + f"Starting remote autotuning for kernel: {kernel_name} on {backend}" + ) + + url = f"{server_addr}:{EVAL_SERVER_PORT}/evaluate" + + try: + response = requests.post( + url, + json={ + "eval_type": "autotune", + "code_template": code_template, + "search_space": search_space, + "timeout": 300, # timeout for each individual evaluation + "backend_type": backend, + "total_timeout": AUTOTUNE_TIMEOUT, + }, + timeout=AUTOTUNE_TIMEOUT + 10, + ) + + if response.status_code == 200: + result = response.json() + if result["exit_code"] == 0: + try: + output_data = json.loads(result["output"]) + logging.info( + f"Autotuning completed. Best config: {output_data['best_cfg']}" + f" with time {output_data['best_time']} ms" + ) + return { + "status": "success", + "message": "Autotuning completed", + "best_config": output_data["best_cfg"], + "best_time_ms": output_data["best_time"], + "best_output": output_data["best_output"], + "all_results": output_data.get("all_results", []), + } + except json.JSONDecodeError: + logging.warning("Failed to decode JSON from server output.") + return { + "status": "success", + "message": "Autotuning completed (raw output)", + "raw_output": result["output"], + } + else: + try: + output_data = json.loads(result["output"]) + return { + "status": "failed", + "message": result["error"] or "Autotune failed on server", + "all_results": output_data.get("all_results", []), + } + except Exception: + return { + "status": "failed", + "message": result["error"] or "Autotune failed on server", + "server_output": result["output"], + } + else: + return { + "status": "error", + "message": ( + f"Server returned status code {response.status_code}: {response.text}" + ), + } + + except requests.exceptions.ConnectionError: + return { + "status": "error", + "message": ( + f"Could not connect to server at {url}. Make sure it is running." + ), + } + except requests.exceptions.Timeout: + logging.warning( + "Autotune timed out on client side. Cleaning up dangling subprocesses on server..." + ) + try: + subprocess.run(["pkill", "-f", "/tmp/hitl_eval_.*\\.py"], check=False) + except Exception as cleanup_error: + logging.error(f"Failed to run cleanup commands: {cleanup_error}") + + return { + "status": "error", + "message": ( + f"Autotune request timed out after {AUTOTUNE_TIMEOUT} seconds. " + "Dangling processes were killed." + ), + } + except Exception as e: + return {"status": "error", "message": str(e)} diff --git a/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py b/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py new file mode 100644 index 0000000..16fef95 --- /dev/null +++ b/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py @@ -0,0 +1,26 @@ +"""Prompt for AutotuneAgent.""" + +PROMPT = """You are a specialized agent for preparing autotuning specifications for Pallas kernels. +Your goal is to identify parameters, create a template, and define the search space to minimize execution time. + +To prepare for autotuning, you must: +1. Identify the parameters that can be tuned in the kernel (e.g., BLOCK_M, BLOCK_N). +2. Create a code template from the kernel code, replacing the specific parameter values with placeholders enclosed in curly braces (for example, if the parameter is BLOCK_M, use it enclosed in curly braces as the placeholder). +3. Ensure the template code prints "RESULT_TIME: " to indicate the execution time. You may need to wrap the kernel call in a loop or use `jax.block_until_ready()` to get accurate timing. WARNING: If you wrap the kernel call in a loop, check if the kernel donates its input buffers (look for `donate_argnames` in the kernel decorator). If it does, calling it repeatedly with the same inputs will fail. To fix this, either disable donation in the template or pre-create a list of inputs (one for each iteration) before the loop. +4. Define a search space as a dictionary mapping placeholder names to lists of suggested values. +5. Write the `kernel_name`, `code_template`, and `search_space` to a JSON file named `autotune_specs.json`. You MUST save this file (and any helper scripts you create like `create_specs.py`) in the directory specified by `{workdir?}`. Use that full path with `filesystem_tool_rw`. +The JSON file must have exactly this structure: +{ + "kernel_name": "...", + "code_template": "...", + "search_space": { ... } +} + +## Tools Available +1. **`search_api`**: Search for API definitions +2. **`read_file`**: Read the kernel code file. + - Required Argument: `path` +3. **`restricted_write_file`**: Write the json file + - Required Argument: `content` (The complete file content) + - Example: `restricted_write_file(content=...)` +""" diff --git a/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py b/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py new file mode 100644 index 0000000..b99110b --- /dev/null +++ b/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py @@ -0,0 +1,17 @@ +"""Prompt for AutotuneSummarizerAgent.""" + +PROMPT = """ +You are providing a summary of autotuning results. + +Your goal is to read the results from a file and report ONLY the best configuration and latency to the user. + +You must: +1. Use the `read_file` tool to read the file at {autotune_results_path?}. +2. Parse the JSON content of the file. +3. The file contains `all_results` (a list of all tested configurations). You should IGNORE the full list in your conversation response. +4. Find the `best_cfg` and `best_time` (or `best_time_ms`) in the JSON. +5. Report the best configuration and its execution time to the user in a clear, readable format. +6. If the status is "failed" or "error", report the error message. + +Be concise and friendly. Do NOT output the full list of results in the conversation. +""" diff --git a/MaxKernel/auto_agent/subagents/pipeline_agent.py b/MaxKernel/auto_agent/subagents/pipeline_agent.py index f351da7..df4a30c 100644 --- a/MaxKernel/auto_agent/subagents/pipeline_agent.py +++ b/MaxKernel/auto_agent/subagents/pipeline_agent.py @@ -239,6 +239,22 @@ def _initialize_state(self, ctx: InvocationContext) -> Event: f"[{self.name}] Set profiling_script_path: {ctx.session.state['profiling_script_path']}" ) + if "autotune_specs_path" not in ctx.session.state: + ctx.session.state["autotune_specs_path"] = os.path.join( + session_dir, "autotune_specs.json" + ) + logging.info( + f"[{self.name}] Set autotune_specs_path: {ctx.session.state['autotune_specs_path']}" + ) + + if "autotune_results_path" not in ctx.session.state: + ctx.session.state["autotune_results_path"] = os.path.join( + session_dir, "autotune_results.json" + ) + logging.info( + f"[{self.name}] Set autotune_results_path: {ctx.session.state['autotune_results_path']}" + ) + logging.info(f"[{self.name}] Published explicit path state update Event.") return Event( author=self.name, diff --git a/MaxKernel/auto_agent/tools/file_tools.py b/MaxKernel/auto_agent/tools/file_tools.py index 700eb55..bc24a5f 100644 --- a/MaxKernel/auto_agent/tools/file_tools.py +++ b/MaxKernel/auto_agent/tools/file_tools.py @@ -86,6 +86,9 @@ def _write_file(content: str, tool_context: ToolContext) -> str: write_profiling_script_tool = restricted_write_file( "profiling_script_path", "Writes the profiling script." ) +write_autotune_specs_tool = restricted_write_file( + "autotune_specs_path", "Writes the autotuning specifications." +) __all__ = [ "filesystem_tool_r", @@ -94,4 +97,5 @@ def _write_file(content: str, tool_context: ToolContext) -> str: "write_optimized_kernel_tool", "write_optimization_plan_tool", "write_profiling_script_tool", + "write_autotune_specs_tool", ] From 1ca5e0793faf2eab9ff1b0fbcbf467a1afcf92b6 Mon Sep 17 00:00:00 2001 From: shangkunwang Date: Tue, 19 May 2026 22:04:54 +0000 Subject: [PATCH 2/8] refactor: migrate autotune_kernel to use asynchronous aiohttp requests --- .../auto_agent/subagents/autotuning/agent.py | 2 +- .../subagents/autotuning/autotune_tool.py | 122 +++++++++--------- 2 files changed, 63 insertions(+), 61 deletions(-) diff --git a/MaxKernel/auto_agent/subagents/autotuning/agent.py b/MaxKernel/auto_agent/subagents/autotuning/agent.py index b1f46a6..393c063 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/agent.py +++ b/MaxKernel/auto_agent/subagents/autotuning/agent.py @@ -101,7 +101,7 @@ async def _run_async_impl( logging.info(f"[{self.name}] Running autotune for {kernel_name}") try: - results = autotune_kernel( + results = await autotune_kernel( kernel_name=kernel_name, code_template=code_template, search_space=search_space, diff --git a/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py b/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py index 7ad8ca7..76ea1e8 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py +++ b/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py @@ -1,18 +1,19 @@ """Standalone tool for auto-tuning Pallas kernels using grid search on remote servers.""" +import asyncio import json import logging import subprocess from typing import Any -import requests +import aiohttp from auto_agent.constants import EVAL_SERVER_PORT AUTOTUNE_TIMEOUT = 5400 -def autotune_kernel( +async def autotune_kernel( kernel_name: str, code_template: str, search_space: dict[str, list[Any]], @@ -42,73 +43,74 @@ def autotune_kernel( url = f"{server_addr}:{EVAL_SERVER_PORT}/evaluate" try: - response = requests.post( - url, - json={ - "eval_type": "autotune", - "code_template": code_template, - "search_space": search_space, - "timeout": 300, # timeout for each individual evaluation - "backend_type": backend, - "total_timeout": AUTOTUNE_TIMEOUT, - }, - timeout=AUTOTUNE_TIMEOUT + 10, - ) - - if response.status_code == 200: - result = response.json() - if result["exit_code"] == 0: - try: - output_data = json.loads(result["output"]) - logging.info( - f"Autotuning completed. Best config: {output_data['best_cfg']}" - f" with time {output_data['best_time']} ms" - ) - return { - "status": "success", - "message": "Autotuning completed", - "best_config": output_data["best_cfg"], - "best_time_ms": output_data["best_time"], - "best_output": output_data["best_output"], - "all_results": output_data.get("all_results", []), - } - except json.JSONDecodeError: - logging.warning("Failed to decode JSON from server output.") - return { - "status": "success", - "message": "Autotuning completed (raw output)", - "raw_output": result["output"], - } - else: - try: - output_data = json.loads(result["output"]) - return { - "status": "failed", - "message": result["error"] or "Autotune failed on server", - "all_results": output_data.get("all_results", []), - } - except Exception: + client_timeout = aiohttp.ClientTimeout(total=AUTOTUNE_TIMEOUT + 10) + async with aiohttp.ClientSession(timeout=client_timeout) as session: + async with session.post( + url, + json={ + "eval_type": "autotune", + "code_template": code_template, + "search_space": search_space, + "timeout": 300, # timeout for each individual evaluation + "backend_type": backend, + "total_timeout": AUTOTUNE_TIMEOUT, + }, + ) as response: + if response.status == 200: + result = await response.json() + if result["exit_code"] == 0: + try: + output_data = json.loads(result["output"]) + logging.info( + f"Autotuning completed. Best config: {output_data['best_cfg']}" + f" with time {output_data['best_time']} ms" + ) + return { + "status": "success", + "message": "Autotuning completed", + "best_config": output_data["best_cfg"], + "best_time_ms": output_data["best_time"], + "best_output": output_data["best_output"], + "all_results": output_data.get("all_results", []), + } + except json.JSONDecodeError: + logging.warning("Failed to decode JSON from server output.") + return { + "status": "success", + "message": "Autotuning completed (raw output)", + "raw_output": result["output"], + } + else: + try: + output_data = json.loads(result["output"]) + return { + "status": "failed", + "message": result["error"] or "Autotune failed on server", + "all_results": output_data.get("all_results", []), + } + except Exception: + return { + "status": "failed", + "message": result["error"] or "Autotune failed on server", + "server_output": result["output"], + } + else: + response_text = await response.text() return { - "status": "failed", - "message": result["error"] or "Autotune failed on server", - "server_output": result["output"], + "status": "error", + "message": ( + f"Server returned status code {response.status}: {response_text}" + ), } - else: - return { - "status": "error", - "message": ( - f"Server returned status code {response.status_code}: {response.text}" - ), - } - except requests.exceptions.ConnectionError: + except aiohttp.ClientConnectorError: return { "status": "error", "message": ( f"Could not connect to server at {url}. Make sure it is running." ), } - except requests.exceptions.Timeout: + except asyncio.TimeoutError: logging.warning( "Autotune timed out on client side. Cleaning up dangling subprocesses on server..." ) From 6e52d68ec42549cf8c6feecc0c0f71e743dc9f3a Mon Sep 17 00:00:00 2001 From: shangkunwang Date: Wed, 20 May 2026 16:32:25 +0000 Subject: [PATCH 3/8] feat: introduce asynchronous task polling for eval server client and implement total timeout logic for TPU autotuning --- .../auto_agent/client_utils/eval_client.py | 76 ++++++++ MaxKernel/auto_agent/constants.py | 4 +- .../auto_agent/server_utils/eval_server.py | 102 ++++++++-- .../auto_agent/server_utils/tpu_server.py | 11 ++ .../subagents/autotuning/autotune_tool.py | 130 ++++++------- .../kernel_writing/kernel_compilation.py | 125 ++++++------- .../subagents/profiling/kernel_profile.py | 177 ++++++++---------- .../auto_agent/subagents/testing/agent.py | 58 ++---- 8 files changed, 390 insertions(+), 293 deletions(-) create mode 100644 MaxKernel/auto_agent/client_utils/eval_client.py diff --git a/MaxKernel/auto_agent/client_utils/eval_client.py b/MaxKernel/auto_agent/client_utils/eval_client.py new file mode 100644 index 0000000..b10e92d --- /dev/null +++ b/MaxKernel/auto_agent/client_utils/eval_client.py @@ -0,0 +1,76 @@ +import asyncio +import logging +import time + +import aiohttp + + +async def call_eval_server_async( + session: aiohttp.ClientSession, + eval_server_url: str, + payload: dict, + poll_interval: int = 10, + client_wait_timeout: int = 3600 * 3, # Default to 3 hours +) -> dict: + """Calls the evaluation server asynchronously and polls for status. + + Args: + session: aiohttp.ClientSession to use. + eval_server_url: Base URL of the eval server (e.g., + "http://localhost:1245"). + payload: The request payload. + poll_interval: Seconds to wait between polls. + client_wait_timeout: Max seconds to wait for task completion. + + Returns: + The result from the evaluation server. + """ + # 1. Submit the task + submit_url = f"{eval_server_url}/evaluate" + logging.info(f"Submitting async task to {submit_url}") + + payload = payload.copy() + payload["client_wait_timeout"] = client_wait_timeout + + async with session.post(submit_url, json=payload) as response: + if response.status != 202: + error_text = await response.text() + raise Exception( + f"Failed to submit task. Status: {response.status}, Error: {error_text}" + ) + resp_data = await response.json() + task_id = resp_data["task_id"] + logging.info(f"Task submitted successfully. ID: {task_id}") + + # 2. Poll for status + start_time = time.time() + status_url = f"{eval_server_url}/status/{task_id}" + + while True: + if time.time() - start_time > client_wait_timeout: + raise Exception( + f"Client timed out waiting for task {task_id} after {client_wait_timeout} seconds" + ) + + async with session.get(status_url) as response: + if response.status != 200: + error_text = await response.text() + raise Exception( + f"Failed to get task status. Status: {response.status}, Error: {error_text}" + ) + + status_data = await response.json() + status = status_data["status"] + + if status == "success": + logging.info(f"Task {task_id} completed successfully.") + return status_data["result"] + elif status in ["failed", "timeout"]: + raise Exception( + f"Task {task_id} ended with status {status}: {status_data.get('error')}" + ) + + logging.info( + f"Task {task_id} status: {status}. Waiting {poll_interval}s..." + ) + await asyncio.sleep(poll_interval) diff --git a/MaxKernel/auto_agent/constants.py b/MaxKernel/auto_agent/constants.py index ad1aec0..24ff4c1 100644 --- a/MaxKernel/auto_agent/constants.py +++ b/MaxKernel/auto_agent/constants.py @@ -4,9 +4,7 @@ TEMPERATURE = 0.1 TOP_P = 0.9 TOP_K = 5 -TPU_TIMEOUT = 120 -REQUEST_TIMEOUT = 1800 +REQUEST_TIMEOUT = 3600 * 3 TPU_SERVER_PORT = 5463 CPU_SERVER_PORT = 5464 EVAL_SERVER_PORT = 1245 -PERF_THRESHOLD = 1.1 diff --git a/MaxKernel/auto_agent/server_utils/eval_server.py b/MaxKernel/auto_agent/server_utils/eval_server.py index 72d50b8..591567a 100644 --- a/MaxKernel/auto_agent/server_utils/eval_server.py +++ b/MaxKernel/auto_agent/server_utils/eval_server.py @@ -1,20 +1,17 @@ import asyncio import logging +import time +import uuid from enum import Enum from typing import Optional import aiohttp import yaml -from fastapi import FastAPI, HTTPException +from fastapi import BackgroundTasks, FastAPI, HTTPException from pydantic import BaseModel -from auto_agent.constants import ( - EVAL_SERVER_PORT, -) -from auto_agent.server_utils.tpu_server import ( - CodeResponse, - get_tpu_version, -) +from auto_agent.constants import EVAL_SERVER_PORT +from auto_agent.server_utils.tpu_server import get_tpu_version logging.basicConfig( level=logging.INFO, @@ -73,6 +70,27 @@ class EvalRequest(BaseModel): search_space: Optional[dict] = None # For autotune backend_type: Optional[str] = None # "tpu", "cpu", or None for any available dependencies: Optional[dict] = None + total_timeout: Optional[int] = None # For autotune + client_wait_timeout: Optional[int] = None + + +class TaskStatus(str, Enum): + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + TIMEOUT = "timeout" + + +class TaskResponse(BaseModel): + task_id: str + status: TaskStatus + result: Optional[dict] = None + error: Optional[str] = None + + +# For evaluation task tracking. task_id -> {status, request, result_or_error} +_tasks = {} class Evaluator: @@ -118,13 +136,64 @@ async def health_check(): return {"status": "healthy"} -@app.post("/evaluate", response_model=CodeResponse) -async def evaluate(request: EvalRequest): +# Asynchronous task polling architecture for evaluation. +@app.post("/evaluate", status_code=202) +async def evaluate(request: EvalRequest, background_tasks: BackgroundTasks): + task_id = str(uuid.uuid4()) + _tasks[task_id] = {"status": TaskStatus.QUEUED, "request": request} + background_tasks.add_task(run_evaluation_task, task_id, request) + return {"task_id": task_id, "status": TaskStatus.QUEUED} + + +@app.get("/status/{task_id}", response_model=TaskResponse) +async def get_task_status(task_id: str): + if task_id not in _tasks: + raise HTTPException(status_code=404, detail="Task not found") + task_info = _tasks[task_id] + return { + "task_id": task_id, + "status": task_info["status"], + "result": task_info.get("result"), + "error": task_info.get("error"), + } + + +async def run_evaluation_task(task_id: str, request: EvalRequest): + _tasks[task_id]["status"] = TaskStatus.RUNNING + try: + result = await _perform_evaluation(request) + _tasks[task_id]["status"] = TaskStatus.SUCCESS + _tasks[task_id]["result"] = result + except HTTPException as e: + if e.status_code == 408: + _tasks[task_id]["status"] = TaskStatus.TIMEOUT + else: + _tasks[task_id]["status"] = TaskStatus.FAILED + _tasks[task_id]["error"] = e.detail + except Exception as e: + _tasks[task_id]["status"] = TaskStatus.FAILED + _tasks[task_id]["error"] = str(e) + + +async def _perform_evaluation(request: EvalRequest): if request.eval_type not in EvalTypes: raise HTTPException(status_code=400, detail="Invalid evaluation type") # Acquire backend, retry if busy + start_queue_time = time.time() + max_queue_time = ( + request.client_wait_timeout + if request.client_wait_timeout is not None + else 3600 * 3 + ) + while True: + if time.time() - start_queue_time > max_queue_time: + raise HTTPException( + status_code=408, + detail="Timed out waiting for an available backend in queue", + ) + async with backend_semaphore: backend = evaluator.get_available_backend( backend_type=request.backend_type @@ -157,16 +226,13 @@ async def evaluate(request: EvalRequest): if request.eval_type == EvalTypes.AUTOTUNE: payload["code_template"] = request.code_template payload["search_space"] = request.search_space + backend_timeout = request.total_timeout else: payload["code"] = request.code payload["dependencies"] = request.dependencies + backend_timeout = request.timeout - backend_timeout = ( - request.total_timeout - if request.total_timeout is not None - else request.timeout - ) - client_timeout = aiohttp.ClientTimeout(total=backend_timeout) + client_timeout = aiohttp.ClientTimeout(total=backend_timeout + 10) async with aiohttp.ClientSession(timeout=client_timeout) as session: async with session.post( f"http://{backend_ip}:{backend_port}/{request.eval_type.value}", @@ -199,7 +265,9 @@ async def evaluate(request: EvalRequest): raise e except Exception as e: logging.error(f"Error occurred while evaluating on {backend_name}: {e}") - raise HTTPException(status_code=500, detail="Backend evaluation failed") + raise HTTPException( + status_code=500, detail=f"Backend evaluation failed: {e}" + ) finally: backend.set_status("available") diff --git a/MaxKernel/auto_agent/server_utils/tpu_server.py b/MaxKernel/auto_agent/server_utils/tpu_server.py index fa1a6a4..495f472 100644 --- a/MaxKernel/auto_agent/server_utils/tpu_server.py +++ b/MaxKernel/auto_agent/server_utils/tpu_server.py @@ -7,6 +7,7 @@ import subprocess import sys import tempfile +import time from typing import Optional from fastapi import FastAPI, HTTPException @@ -47,6 +48,7 @@ class AutotuneRequest(BaseModel): code_template: str search_space: dict[str, list] timeout: Optional[int] = 300 + total_timeout: Optional[int] = None class GetTpuVersionResponse(BaseModel): @@ -521,7 +523,16 @@ async def autotune(request: AutotuneRequest): best_output = "" all_results = [] + start_time = time.time() for combo in combinations: + if ( + request.total_timeout + and (time.time() - start_time) > request.total_timeout + ): + logging.warning( + f"Total timeout of {request.total_timeout}s reached in autotune" + ) + break cfg = dict(zip(keys, combo)) try: code_content = request.code_template diff --git a/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py b/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py index 76ea1e8..6cf3151 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py +++ b/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py @@ -1,16 +1,16 @@ """Standalone tool for auto-tuning Pallas kernels using grid search on remote servers.""" -import asyncio import json import logging -import subprocess from typing import Any import aiohttp -from auto_agent.constants import EVAL_SERVER_PORT +from auto_agent.client_utils.eval_client import call_eval_server_async +from auto_agent.constants import EVAL_SERVER_PORT, REQUEST_TIMEOUT -AUTOTUNE_TIMEOUT = 5400 +AUTOTUNE_INDIVIDUAL_TIMEOUT = 300 +AUTOTUNE_TOTAL_TIMEOUT = 5400 async def autotune_kernel( @@ -43,64 +43,59 @@ async def autotune_kernel( url = f"{server_addr}:{EVAL_SERVER_PORT}/evaluate" try: - client_timeout = aiohttp.ClientTimeout(total=AUTOTUNE_TIMEOUT + 10) + client_timeout = aiohttp.ClientTimeout(total=REQUEST_TIMEOUT + 10) async with aiohttp.ClientSession(timeout=client_timeout) as session: - async with session.post( - url, - json={ - "eval_type": "autotune", - "code_template": code_template, - "search_space": search_space, - "timeout": 300, # timeout for each individual evaluation - "backend_type": backend, - "total_timeout": AUTOTUNE_TIMEOUT, - }, - ) as response: - if response.status == 200: - result = await response.json() - if result["exit_code"] == 0: - try: - output_data = json.loads(result["output"]) - logging.info( - f"Autotuning completed. Best config: {output_data['best_cfg']}" - f" with time {output_data['best_time']} ms" - ) - return { - "status": "success", - "message": "Autotuning completed", - "best_config": output_data["best_cfg"], - "best_time_ms": output_data["best_time"], - "best_output": output_data["best_output"], - "all_results": output_data.get("all_results", []), - } - except json.JSONDecodeError: - logging.warning("Failed to decode JSON from server output.") - return { - "status": "success", - "message": "Autotuning completed (raw output)", - "raw_output": result["output"], - } - else: - try: - output_data = json.loads(result["output"]) - return { - "status": "failed", - "message": result["error"] or "Autotune failed on server", - "all_results": output_data.get("all_results", []), - } - except Exception: - return { - "status": "failed", - "message": result["error"] or "Autotune failed on server", - "server_output": result["output"], - } - else: - response_text = await response.text() + payload = { + "eval_type": "autotune", + "code_template": code_template, + "search_space": search_space, + "timeout": AUTOTUNE_INDIVIDUAL_TIMEOUT, + "backend_type": backend, + "total_timeout": AUTOTUNE_TOTAL_TIMEOUT, + } + result = await call_eval_server_async( + session, + f"{server_addr}:{EVAL_SERVER_PORT}", + payload, + poll_interval=10, + client_wait_timeout=REQUEST_TIMEOUT, + ) + + if result["exit_code"] == 0: + try: + output_data = json.loads(result["output"]) + logging.info( + f"Autotuning completed. Best config: {output_data['best_cfg']}" + f" with time {output_data['best_time']} ms" + ) + return { + "status": "success", + "message": "Autotuning completed", + "best_config": output_data["best_cfg"], + "best_time_ms": output_data["best_time"], + "best_output": output_data["best_output"], + "all_results": output_data.get("all_results", []), + } + except json.JSONDecodeError: + logging.warning("Failed to decode JSON from server output.") + return { + "status": "success", + "message": "Autotuning completed (raw output)", + "raw_output": result["output"], + } + else: + try: + output_data = json.loads(result["output"]) + return { + "status": "failed", + "message": result["error"] or "Autotune failed on server", + "all_results": output_data.get("all_results", []), + } + except Exception: return { - "status": "error", - "message": ( - f"Server returned status code {response.status}: {response_text}" - ), + "status": "failed", + "message": result["error"] or "Autotune failed on server", + "server_output": result["output"], } except aiohttp.ClientConnectorError: @@ -110,21 +105,6 @@ async def autotune_kernel( f"Could not connect to server at {url}. Make sure it is running." ), } - except asyncio.TimeoutError: - logging.warning( - "Autotune timed out on client side. Cleaning up dangling subprocesses on server..." - ) - try: - subprocess.run(["pkill", "-f", "/tmp/hitl_eval_.*\\.py"], check=False) - except Exception as cleanup_error: - logging.error(f"Failed to run cleanup commands: {cleanup_error}") - return { - "status": "error", - "message": ( - f"Autotune request timed out after {AUTOTUNE_TIMEOUT} seconds. " - "Dangling processes were killed." - ), - } except Exception as e: return {"status": "error", "message": str(e)} diff --git a/MaxKernel/auto_agent/subagents/kernel_writing/kernel_compilation.py b/MaxKernel/auto_agent/subagents/kernel_writing/kernel_compilation.py index 9fb505a..c2adbca 100644 --- a/MaxKernel/auto_agent/subagents/kernel_writing/kernel_compilation.py +++ b/MaxKernel/auto_agent/subagents/kernel_writing/kernel_compilation.py @@ -6,11 +6,10 @@ from google.adk.agents.invocation_context import InvocationContext from google.adk.events import Event, EventActions -from auto_agent.constants import ( - EVAL_SERVER_PORT, - REQUEST_TIMEOUT, - TPU_TIMEOUT, -) +from auto_agent.client_utils.eval_client import call_eval_server_async +from auto_agent.constants import EVAL_SERVER_PORT, REQUEST_TIMEOUT + +COMPILATION_TIMEOUT = 120 class KernelCompilationChecker(BaseAgent): @@ -47,69 +46,63 @@ async def _run_async_impl( # Call the TPU server to execute the code logging.info(f"[{self.name}] Running code") async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=REQUEST_TIMEOUT) + timeout=aiohttp.ClientTimeout(total=REQUEST_TIMEOUT + 10) ) as session: - async with session.post( - f"http://localhost:{EVAL_SERVER_PORT}/evaluate", - json={ - "eval_type": "compilation_test", - "code": code, - "timeout": TPU_TIMEOUT, - "backend_type": "tpu", - }, - ) as response: - if response.status == 200: - result = await response.json() - logging.info(f"[{self.name}] Compilation test result: {result}") - if result["exit_code"] == 0: - logging.info(f"[{self.name}] Code execution successful.") - yield Event( - author=self.name, - actions=EventActions(state_delta={self.output_key: "Success"}), - ) - elif ( - result["error"] is None - and result["output"] == "" - and result["exit_code"] == 1 - ): - logging.info( - f"[{self.name}] Code execution had exit code 1, but no error, indicating success." - ) - yield Event( - author=self.name, - actions=EventActions(state_delta={self.output_key: "Success"}), - ) - else: - logging.info( - f"[{self.name}] Code execution failed. Loop will continue." - ) - # Use 'or' to handle None case - result.get() returns None if key exists with None value - error_msg = ( - result.get("error") - or result.get("output") - or "Unknown error: No error message or output available" - ) - # Add diagnostic logging when error field is None or empty - if result.get("error") is None: - logging.warning( - f"[{self.name}] Error field is None. " - f"exit_code: {result.get('exit_code')}, " - f"output length: {len(result.get('output', ''))}, " - f"Using output as fallback: {result.get('output', '')[:200]}" - ) - yield Event( - author=self.name, - actions=EventActions(state_delta={self.output_key: error_msg}), - ) - else: - error_detail = await response.text() - logging.error( - f"[{self.name}] HTTP error {response.status}: {error_detail}" - ) - ctx.session.state[self.output_key] = ( - f"HTTP error {response.status}: {error_detail}" + payload = { + "eval_type": "compilation_test", + "code": code, + "timeout": COMPILATION_TIMEOUT, + "backend_type": "tpu", + } + result = await call_eval_server_async( + session, + f"http://localhost:{EVAL_SERVER_PORT}", + payload, + poll_interval=10, + client_wait_timeout=REQUEST_TIMEOUT, + ) + + logging.info(f"[{self.name}] Compilation test result: {result}") + if result["exit_code"] == 0: + logging.info(f"[{self.name}] Code execution successful.") + yield Event( + author=self.name, + actions=EventActions(state_delta={self.output_key: "Success"}), + ) + elif ( + result["error"] is None + and result["output"] == "" + and result["exit_code"] == 1 + ): + logging.info( + f"[{self.name}] Code execution had exit code 1, but no error, indicating success." + ) + yield Event( + author=self.name, + actions=EventActions(state_delta={self.output_key: "Success"}), + ) + else: + logging.info( + f"[{self.name}] Code execution failed. Loop will continue." + ) + # Use 'or' to handle None case - result.get() returns None if key exists with None value + error_msg = ( + result.get("error") + or result.get("output") + or "Unknown error: No error message or output available" + ) + # Add diagnostic logging when error field is None or empty + if result.get("error") is None: + logging.warning( + f"[{self.name}] Error field is None. " + f"exit_code: {result.get('exit_code')}, " + f"output length: {len(result.get('output', ''))}, " + f"Using output as fallback: {result.get('output', '')[:200]}" ) - yield Event(author=self.name) + yield Event( + author=self.name, + actions=EventActions(state_delta={self.output_key: error_msg}), + ) except Exception as e: logging.error(f"[{self.name}] Exception during code execution: {str(e)}") ctx.session.state[self.output_key] = ( diff --git a/MaxKernel/auto_agent/subagents/profiling/kernel_profile.py b/MaxKernel/auto_agent/subagents/profiling/kernel_profile.py index 72138f2..4a5f642 100644 --- a/MaxKernel/auto_agent/subagents/profiling/kernel_profile.py +++ b/MaxKernel/auto_agent/subagents/profiling/kernel_profile.py @@ -6,11 +6,10 @@ from google.adk.agents.invocation_context import InvocationContext from google.adk.events import Event, EventActions -from auto_agent.constants import ( - EVAL_SERVER_PORT, - REQUEST_TIMEOUT, - TPU_TIMEOUT, -) +from auto_agent.constants import EVAL_SERVER_PORT, REQUEST_TIMEOUT +from auto_agent.tools.eval_client import call_eval_server_async + +PROFILE_TIMEOUT = 120 class KernelProfiler(BaseAgent): @@ -50,106 +49,96 @@ async def _run_async_impl( # Call the TPU server to execute the code logging.info(f"[{self.name}] Running code") async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=REQUEST_TIMEOUT) + timeout=aiohttp.ClientTimeout(total=REQUEST_TIMEOUT + 10) ) as session: - async with session.post( - f"http://localhost:{EVAL_SERVER_PORT}/evaluate", - json={ - "eval_type": "profile", - "code": profile_code, - "timeout": TPU_TIMEOUT, - "backend_type": "tpu", - }, - ) as response: - if response.status == 200: - result = await response.json() - logging.info(f"[{self.name}] Profiling result: {result}") + payload = { + "eval_type": "profile", + "code": profile_code, + "timeout": PROFILE_TIMEOUT, + "backend_type": "tpu", + } + result = await call_eval_server_async( + session, + f"http://localhost:{EVAL_SERVER_PORT}", + payload, + poll_interval=10, + client_wait_timeout=REQUEST_TIMEOUT, + ) - # Check if profiling was successful based on exit code and output - exit_code = result.get("exit_code", 0) - output = result.get("output", "") - error_msg = result.get("error", "") + logging.info(f"[{self.name}] Profiling result: {result}") - # Profiling succeeds if exit_code is 0 and we have output - # Stderr may contain warnings (like TensorFlow import warnings) which are not failures - if exit_code != 0: - full_error = f"Profiling script failed with exit code {exit_code}" - if error_msg: - full_error += f": {error_msg}" - logging.error(f"[{self.name}] {full_error}") - yield Event( - author=self.name, - actions=EventActions(state_delta={self.output_key: full_error}), - ) - elif not output or output.strip() == "": - full_error = "Profiling script produced no output" - if error_msg: - full_error += f". Stderr: {error_msg}" - logging.error(f"[{self.name}] {full_error}") - yield Event( - author=self.name, - actions=EventActions(state_delta={self.output_key: full_error}), - ) - else: - # Successful profiling - parse the ratio and xplane path - try: - try: - import json + # Check if profiling was successful based on exit code and output + exit_code = result.get("exit_code", 0) + output = result.get("output", "") + error_msg = result.get("error", "") - res = json.loads(output.strip()) - ratio = float(res.get("ratio", 0)) - xplane_path = res.get("xplane_path", "") - except (ValueError, json.JSONDecodeError): - # Fallback for old servers returning raw ratio - ratio = float(output.strip()) - xplane_path = "" + # Profiling succeeds if exit_code is 0 and we have output + # Stderr may contain warnings (like TensorFlow import warnings) which are not failures + if exit_code != 0: + full_error = f"Profiling script failed with exit code {exit_code}" + if error_msg: + full_error += f": {error_msg}" + logging.error(f"[{self.name}] {full_error}") + yield Event( + author=self.name, + actions=EventActions(state_delta={self.output_key: full_error}), + ) + elif not output or output.strip() == "": + full_error = "Profiling script produced no output" + if error_msg: + full_error += f". Stderr: {error_msg}" + logging.error(f"[{self.name}] {full_error}") + yield Event( + author=self.name, + actions=EventActions(state_delta={self.output_key: full_error}), + ) + else: + # Successful profiling - parse the ratio and xplane path + try: + try: + import json - # Log warnings if present, but don't fail - if error_msg: - logging.warning( - f"[{self.name}] Profiling succeeded but had warnings in" - f" stderr: {error_msg[:200]}" - ) + res = json.loads(output.strip()) + ratio = float(res.get("ratio", 0)) + xplane_path = res.get("xplane_path", "") + except (ValueError, json.JSONDecodeError): + # Fallback for old servers returning raw ratio + ratio = float(output.strip()) + xplane_path = "" - logging.info( - f"[{self.name}] Profiling succeeded with ratio: {ratio}," - f" xplane_path: {xplane_path}" - ) - yield Event( - author=self.name, - actions=EventActions( - escalate=False, - state_delta={ - self.output_key: { - "DMAs_and_memory_transfers_ratio": ratio, - "compute_ratio": 1 - ratio, - "xplane_path": xplane_path, - } - }, - ), - ) - except (ValueError, KeyError) as e: - error_msg_full = ( - f"Failed to parse profiling output: '{output}'. Error: {e}" - ) - logging.error(f"[{self.name}] {error_msg_full}") - yield Event( - author=self.name, - actions=EventActions( - state_delta={self.output_key: error_msg_full} - ), - ) - else: - error_detail = await response.text() - logging.error( - f"[{self.name}] HTTP error {response.status}: {error_detail}" + # Log warnings if present, but don't fail + if error_msg: + logging.warning( + f"[{self.name}] Profiling succeeded but had warnings in" + f" stderr: {error_msg[:200]}" + ) + + logging.info( + f"[{self.name}] Profiling succeeded with ratio: {ratio}," + f" xplane_path: {xplane_path}" ) yield Event( author=self.name, actions=EventActions( + escalate=False, state_delta={ - self.output_key: f"HTTP error {response.status}: {error_detail}" - } + self.output_key: { + "DMAs_and_memory_transfers_ratio": ratio, + "compute_ratio": 1 - ratio, + "xplane_path": xplane_path, + } + }, + ), + ) + except (ValueError, KeyError) as e: + error_msg_full = ( + f"Failed to parse profiling output: '{output}'. Error: {e}" + ) + logging.error(f"[{self.name}] {error_msg_full}") + yield Event( + author=self.name, + actions=EventActions( + state_delta={self.output_key: error_msg_full} ), ) except Exception as e: diff --git a/MaxKernel/auto_agent/subagents/testing/agent.py b/MaxKernel/auto_agent/subagents/testing/agent.py index a57e0bc..04225eb 100644 --- a/MaxKernel/auto_agent/subagents/testing/agent.py +++ b/MaxKernel/auto_agent/subagents/testing/agent.py @@ -16,7 +16,7 @@ from google.adk.events import Event, EventActions from auto_agent.config import model_config, thinking_planner -from auto_agent.constants import EVAL_SERVER_PORT, MODEL_NAME +from auto_agent.constants import EVAL_SERVER_PORT, MODEL_NAME, REQUEST_TIMEOUT from auto_agent.custom_types import CustomLlmAgent from auto_agent.subagents.testing.prompts import ( fix_test_script, @@ -24,6 +24,7 @@ summarize_test_results_prompt, validation_summary, ) +from auto_agent.tools.eval_client import call_eval_server_async from auto_agent.tools.file_tools import filesystem_tool_r, write_test_file_tool from auto_agent.tools.search_api_tool import search_api_tool from auto_agent.tools.tools import vertex_ai_rag_tool @@ -149,17 +150,15 @@ async def _run_async_impl( "dependencies": dependencies, } - async with aiohttp.ClientSession() as session: - async with session.post( - f"http://localhost:{EVAL_SERVER_PORT}/evaluate", - json=payload, - ) as response: - if response.status != 200: - error_text = await response.text() - raise Exception( - f"Eval server returned status {response.status}: {error_text}" - ) - result_json = await response.json() + client_timeout = aiohttp.ClientTimeout(total=REQUEST_TIMEOUT + 10) + async with aiohttp.ClientSession(timeout=client_timeout) as session: + result_json = await call_eval_server_async( + session, + f"http://localhost:{EVAL_SERVER_PORT}", + payload, + poll_interval=10, + client_wait_timeout=REQUEST_TIMEOUT, + ) full_output = f"STDOUT:\n{result_json.get('output', '')}\n\nSTDERR:\n{result_json.get('error', '') or ''}" @@ -177,21 +176,6 @@ async def _run_async_impl( actions=EventActions(state_delta={self.output_key: test_results}), ) - except subprocess.TimeoutExpired: - error_msg = "Test execution timed out after 5 minutes" - logging.error(f"[{self.name}] {error_msg}") - yield Event( - author=self.name, - actions=EventActions( - state_delta={ - self.output_key: { - "exit_code": -1, - "output": error_msg, - "success": False, - } - } - ), - ) except Exception as e: error_msg = f"Exception during test execution: {str(e)}" logging.error(f"[{self.name}] {error_msg}") @@ -639,17 +623,15 @@ async def _run_async_impl( "dependencies": dependencies, } - async with aiohttp.ClientSession() as session: - async with session.post( - f"http://localhost:{EVAL_SERVER_PORT}/evaluate", - json=payload, - ) as response: - if response.status != 200: - error_text = await response.text() - raise Exception( - f"Eval server returned status {response.status}: {error_text}" - ) - result_json = await response.json() + client_timeout = aiohttp.ClientTimeout(total=REQUEST_TIMEOUT + 10) + async with aiohttp.ClientSession(timeout=client_timeout) as session: + result_json = await call_eval_server_async( + session, + f"http://localhost:{EVAL_SERVER_PORT}", + payload, + poll_interval=10, + client_wait_timeout=REQUEST_TIMEOUT, + ) exit_code = result_json.get("exit_code", -1) stdout = result_json.get("output", "") From cff421f245eb5533e49922e78060d2ae60865a32 Mon Sep 17 00:00:00 2001 From: shangkunwang Date: Wed, 20 May 2026 18:07:46 +0000 Subject: [PATCH 4/8] fix: fix import path for eval client --- MaxKernel/auto_agent/subagents/profiling/kernel_profile.py | 2 +- MaxKernel/auto_agent/subagents/testing/agent.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/MaxKernel/auto_agent/subagents/profiling/kernel_profile.py b/MaxKernel/auto_agent/subagents/profiling/kernel_profile.py index 4a5f642..a4299e9 100644 --- a/MaxKernel/auto_agent/subagents/profiling/kernel_profile.py +++ b/MaxKernel/auto_agent/subagents/profiling/kernel_profile.py @@ -6,8 +6,8 @@ from google.adk.agents.invocation_context import InvocationContext from google.adk.events import Event, EventActions +from auto_agent.client_utils.eval_client import call_eval_server_async from auto_agent.constants import EVAL_SERVER_PORT, REQUEST_TIMEOUT -from auto_agent.tools.eval_client import call_eval_server_async PROFILE_TIMEOUT = 120 diff --git a/MaxKernel/auto_agent/subagents/testing/agent.py b/MaxKernel/auto_agent/subagents/testing/agent.py index 04225eb..c120de1 100644 --- a/MaxKernel/auto_agent/subagents/testing/agent.py +++ b/MaxKernel/auto_agent/subagents/testing/agent.py @@ -15,6 +15,7 @@ from google.adk.agents.invocation_context import InvocationContext from google.adk.events import Event, EventActions +from auto_agent.client_utils.eval_client import call_eval_server_async from auto_agent.config import model_config, thinking_planner from auto_agent.constants import EVAL_SERVER_PORT, MODEL_NAME, REQUEST_TIMEOUT from auto_agent.custom_types import CustomLlmAgent @@ -24,7 +25,6 @@ summarize_test_results_prompt, validation_summary, ) -from auto_agent.tools.eval_client import call_eval_server_async from auto_agent.tools.file_tools import filesystem_tool_r, write_test_file_tool from auto_agent.tools.search_api_tool import search_api_tool from auto_agent.tools.tools import vertex_ai_rag_tool From eeb35845f49cd248d4d5728246ad1a7ef84ee7c9 Mon Sep 17 00:00:00 2001 From: shangkunwang Date: Wed, 20 May 2026 20:22:52 +0000 Subject: [PATCH 5/8] feat: add a apply_best_config subagent into autotune agent and integrate with pipeline agent --- MaxKernel/auto_agent/agent.py | 2 + .../auto_agent/subagents/autotuning/agent.py | 60 ++++++++++++++++--- .../prompts/apply_best_config_prompt.py | 13 ++++ .../autotuning/prompts/autotune_prompt.py | 7 ++- .../autotuning/prompts/summary_prompt.py | 26 ++++---- .../auto_agent/subagents/pipeline_agent.py | 32 +++++++--- 6 files changed, 113 insertions(+), 27 deletions(-) create mode 100644 MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py diff --git a/MaxKernel/auto_agent/agent.py b/MaxKernel/auto_agent/agent.py index e57004c..cc2e2e1 100644 --- a/MaxKernel/auto_agent/agent.py +++ b/MaxKernel/auto_agent/agent.py @@ -4,6 +4,7 @@ for the human-in-the-loop kernel generation process. """ +from auto_agent.subagents.autotuning.agent import autotune_agent from auto_agent.subagents.kernel_writing import ( implement_kernel_agent, plan_kernel_agent, @@ -23,6 +24,7 @@ validate_agent=validate_kernel_compilation_agent, test_gen_agent=validated_test_generation_agent, test_run_agent=unified_test_agent, + autotune_agent=autotune_agent, profile_agent=profile_agent, max_iterations=5, ) diff --git a/MaxKernel/auto_agent/subagents/autotuning/agent.py b/MaxKernel/auto_agent/subagents/autotuning/agent.py index 393c063..d933801 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/agent.py +++ b/MaxKernel/auto_agent/subagents/autotuning/agent.py @@ -5,7 +5,7 @@ import os from typing import AsyncGenerator -from google.adk.agents import BaseAgent, SequentialAgent +from google.adk.agents import BaseAgent from google.adk.agents.invocation_context import InvocationContext from google.adk.events import Event, EventActions @@ -14,9 +14,11 @@ from auto_agent.custom_types import CustomLlmAgent from auto_agent.subagents.autotuning.autotune_tool import autotune_kernel from auto_agent.subagents.autotuning.prompts import ( + apply_best_config_prompt, autotune_prompt, summary_prompt, ) +from auto_agent.tools.file_tools import write_optimized_kernel_tool from auto_agent.tools.search_api_tool import search_api_tool from auto_agent.tools.tools import filesystem_tool_r, write_autotune_specs_tool @@ -139,7 +141,18 @@ async def _run_async_impl( output_key="autotune_results", ) -# 3. Summarizer Agent (LLM) +# 3. Apply Best Config Agent +apply_best_config_agent = CustomLlmAgent( + name="ApplyBestConfigAgent", + model=MODEL_NAME, + generate_content_config=model_config, + planner=thinking_planner, + instruction=apply_best_config_prompt.PROMPT, + description="Applies autotuning results to the optimized kernel file.", + tools=[filesystem_tool_r, write_optimized_kernel_tool], +) + +# 4. Summarizer Agent # This agent reads results from state and talks to the user. autotune_summary_agent = CustomLlmAgent( name="AutotuneSummaryAgent", @@ -151,10 +164,43 @@ async def _run_async_impl( tools=[filesystem_tool_r], ) -# 4. Combined Sequential Agent -autotune_agent = SequentialAgent( - name="AutotuneAgent", - sub_agents=[autotune_planner_agent, autotune_runner, autotune_summary_agent], -) + +class CombinedAutotuneAgent(BaseAgent): + """Chains autotuning steps and conditionally applies best config.""" + + def __init__(self, name: str): + super().__init__(name=name) + + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + logging.info(f"[{self.name}] Running AutotunePlannerAgent...") + async for event in autotune_planner_agent.run_async(ctx): + yield event + + logging.info(f"[{self.name}] Running AutotuneRunner...") + async for event in autotune_runner.run_async(ctx): + yield event + + autotune_results = ctx.session.state.get("autotune_results", {}) + if ( + autotune_results.get("status") == "success" + and autotune_results.get("best_cfg") is not None + ): + logging.info(f"[{self.name}] Running ApplyBestConfigAgent...") + async for event in apply_best_config_agent.run_async(ctx): + yield event + else: + logging.warning( + f"[{self.name}] Autotune was not successful or no best configuration" + " found. Skipping ApplyBestConfigAgent." + ) + + logging.info(f"[{self.name}] Running AutotuneSummaryAgent...") + async for event in autotune_summary_agent.run_async(ctx): + yield event + + +autotune_agent = CombinedAutotuneAgent(name="AutotuneAgent") __all__ = ["autotune_agent"] diff --git a/MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py b/MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py new file mode 100644 index 0000000..b408697 --- /dev/null +++ b/MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py @@ -0,0 +1,13 @@ +"""Prompt for ApplyBestConfigAgent.""" + +PROMPT = """You are a specialized agent for applying autotuning results to a Pallas kernel file. +Your goal is to read the best configuration from autotuning results and update the `optimized_kernel.py` file with these values. + +You must: +1. use the `read_file` tool to read the file at {autotune_specs_path?} to get the context of autotuning experiment. +2. Use the `read_file` tool to read the file at {autotune_results_path?} and parse the JSON content of the autotune results to find the best configuration `best_cfg`. +3. Use the `read_file` tool to read the current kernel file at {optimized_kernel_path?}. +4. Use the `restricted_write_file` tool to save the updated kernel file, replacing the old parameter values with the values from `best_cfg`. + +Be precise and ensure you only change the specific parameter values identified in the best configuration. +""" diff --git a/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py b/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py index 16fef95..aa5c484 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py +++ b/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py @@ -6,8 +6,11 @@ To prepare for autotuning, you must: 1. Identify the parameters that can be tuned in the kernel (e.g., BLOCK_M, BLOCK_N). 2. Create a code template from the kernel code, replacing the specific parameter values with placeholders enclosed in curly braces (for example, if the parameter is BLOCK_M, use it enclosed in curly braces as the placeholder). -3. Ensure the template code prints "RESULT_TIME: " to indicate the execution time. You may need to wrap the kernel call in a loop or use `jax.block_until_ready()` to get accurate timing. WARNING: If you wrap the kernel call in a loop, check if the kernel donates its input buffers (look for `donate_argnames` in the kernel decorator). If it does, calling it repeatedly with the same inputs will fail. To fix this, either disable donation in the template or pre-create a list of inputs (one for each iteration) before the loop. -4. Define a search space as a dictionary mapping placeholder names to lists of suggested values. +3. Ensure the template code prints "RESULT_TIME: " to indicate the average execution time. To get accurate and quick timing, wrap the kernel call in a loop of exactly 10 iterations (preceded by 1 warm-up execution) and use `jax.block_until_ready()`. Limit iterations strictly to 10 to keep profiling runs fast. WARNING: If you wrap the kernel call in a loop, check if the kernel donates its input buffers (look for `donate_argnames` in the kernel decorator). If it does, calling it repeatedly with the same inputs will fail. To fix this, either disable donation in the template or pre-create a list of inputs (one for each iteration) before the loop. +4. Define a highly optimized, high-probability search space as a dictionary mapping placeholder names to lists of suggested values. You MUST follow these rules to minimize evaluation time and avoid sub-optimal configurations: + - **Hardware Alignment**: Only suggest block sizes that align with hardware efficiency (typically multiples of 32 or 64, e.g., `[32, 64, 128]`). Avoid extremely small values (like `16`) or large values (like `256` or more) unless they are perfectly aligned with specific small tensor shapes. + - **Dimension Divisors**: Choose suggested block sizes that are clean, even divisors of the corresponding matrix or tensor shape dimensions to prevent compiler masking and branch overhead. + - **Total Combinations Limit**: Proactively limit the size of individual parameter lists so that the total Cartesian product (all possible combinations) stays small—ideally between **10 to 100 total combinations max**. Keep each parameter list to 2 or 3 high-probability values (e.g., `[64, 128]`). Do not generate massive combinatorial sweeps. 5. Write the `kernel_name`, `code_template`, and `search_space` to a JSON file named `autotune_specs.json`. You MUST save this file (and any helper scripts you create like `create_specs.py`) in the directory specified by `{workdir?}`. Use that full path with `filesystem_tool_rw`. The JSON file must have exactly this structure: { diff --git a/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py b/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py index b99110b..ba05f4d 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py +++ b/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py @@ -1,17 +1,23 @@ """Prompt for AutotuneSummarizerAgent.""" PROMPT = """ -You are providing a summary of autotuning results. +You are an AI assistant summarizing autotuning results for a Pallas kernel optimization task. -Your goal is to read the results from a file and report ONLY the best configuration and latency to the user. +Your goal is to read the results from a file, report the best configuration and latency to the user, and state whether this configuration was applied. -You must: -1. Use the `read_file` tool to read the file at {autotune_results_path?}. -2. Parse the JSON content of the file. -3. The file contains `all_results` (a list of all tested configurations). You should IGNORE the full list in your conversation response. -4. Find the `best_cfg` and `best_time` (or `best_time_ms`) in the JSON. -5. Report the best configuration and its execution time to the user in a clear, readable format. -6. If the status is "failed" or "error", report the error message. +Instructions: +1. **Read Results**: Use the `read_file` tool to read the file at {autotune_results_path?} and parse its JSON content. +2. **Extract Metrics**: Find the `"best_cfg"` and `"best_time_ms"` in the JSON. +3. **Summarize**: Provide a clear summary in your response. Do NOT list all tested configurations from `all_results`. +4. **Verify Application**: State whether the best configuration was applied to the file at {optimized_kernel_path?}. (Note: In the current pipeline, if the autotune status is "success", the best configuration is automatically applied before you run). +5. **Handle Errors**: If the status is `"failed"` or `"error"`, report the error message provided in the file. -Be concise and friendly. Do NOT output the full list of results in the conversation. +Please use the following format for your summary: +### Autotuning Results +- **Status**: [Success / Failed] +- **Best Configuration**: `[JSON or description of best config]` +- **Latency**: `[Time]` ms +- **Applied to File**: [Yes / No] + +[Any additional brief notes or error messages] """ diff --git a/MaxKernel/auto_agent/subagents/pipeline_agent.py b/MaxKernel/auto_agent/subagents/pipeline_agent.py index df4a30c..ad06020 100644 --- a/MaxKernel/auto_agent/subagents/pipeline_agent.py +++ b/MaxKernel/auto_agent/subagents/pipeline_agent.py @@ -2,6 +2,7 @@ import logging import os +import re from typing import AsyncGenerator from google.adk.agents import BaseAgent @@ -21,6 +22,7 @@ class AutonomousPipelineAgent(BaseAgent): validate_agent: BaseAgent test_gen_agent: BaseAgent test_run_agent: BaseAgent + autotune_agent: BaseAgent profile_agent: BaseAgent max_iterations: int = 2 @@ -32,6 +34,7 @@ def __init__( validate_agent: BaseAgent, test_gen_agent: BaseAgent, test_run_agent: BaseAgent, + autotune_agent: BaseAgent, profile_agent: BaseAgent, max_iterations: int = 2, ): @@ -42,6 +45,7 @@ def __init__( validate_agent=validate_agent, test_gen_agent=test_gen_agent, test_run_agent=test_run_agent, + autotune_agent=autotune_agent, profile_agent=profile_agent, max_iterations=max_iterations, ) @@ -110,7 +114,12 @@ async def _run_async_impl( iteration += 1 continue - # Step 6: Profile + # Step 6: Autotune + logging.info(f"[{self.name}] Running AutotuneAgent...") + async for event in self.autotune_agent.run_async(ctx): + yield event + + # Step 7: Profile logging.info(f"[{self.name}] Running ProfileAgentOrchestrator...") async for event in self.profile_agent.run_async(ctx): yield event @@ -128,8 +137,7 @@ async def _run_async_impl( ) # Extract latency - test_output = ctx.session.state.get("test_results", {}).get("output", "") - latency = self._extract_latency(test_output) + latency = self._extract_latency(ctx) snapshot = { "iteration": iteration, @@ -270,18 +278,26 @@ def _initialize_state(self, ctx: InvocationContext) -> Event: ), ) - def _extract_latency(self, test_output: str): - """Extracts execution time from test results output.""" + def _extract_latency(self, ctx: InvocationContext): + """Extracts execution time from autotune results or test results output.""" + autotune_results = ctx.session.state.get("autotune_results", {}) + if autotune_results.get("status") == "success": + latency = autotune_results.get("best_time_ms") + if latency is not None: + logging.info( + f"[{self.name}] Extracted latency from autotune results: {latency} ms" + ) + return latency + + test_output = ctx.session.state.get("test_results", {}).get("output", "") if not test_output: return None try: - import re - match = re.search(r"PERF_METRICS:\s*([\d.]+)", test_output) if match: latency = float(match.group(1)) logging.info( - f"[{self.name}] Extracted execution time from test results: {latency} ms" + f"[{self.name}] Extracted execution time from test results: {latency} ms (fallback)" ) return latency except Exception as e: From 980244de25fcf579559b7ceefc80468732909b9d Mon Sep 17 00:00:00 2001 From: shangkunwang Date: Thu, 21 May 2026 22:42:55 +0000 Subject: [PATCH 6/8] fix: update regex to include 'ms' in RESULT_TIME, standardize best configuration naming in prompts and improve summary and autotune prompt. --- MaxKernel/auto_agent/server_utils/tpu_server.py | 2 +- MaxKernel/auto_agent/subagents/autotuning/agent.py | 14 ++++++++++---- .../autotuning/prompts/apply_best_config_prompt.py | 4 ++-- .../autotuning/prompts/autotune_prompt.py | 6 ++++-- .../subagents/autotuning/prompts/summary_prompt.py | 14 ++++++++------ 5 files changed, 25 insertions(+), 15 deletions(-) diff --git a/MaxKernel/auto_agent/server_utils/tpu_server.py b/MaxKernel/auto_agent/server_utils/tpu_server.py index 495f472..f23ab31 100644 --- a/MaxKernel/auto_agent/server_utils/tpu_server.py +++ b/MaxKernel/auto_agent/server_utils/tpu_server.py @@ -569,7 +569,7 @@ async def autotune(request: AutotuneRequest): if exit_code == 0: # Parse RESULT_TIME - match = re.search(r"RESULT_TIME:\s*([0-9.]+)", output) + match = re.search(r"RESULT_TIME:\s*([0-9.]+)\s*ms", output) if match: time_taken = float(match.group(1)) all_results.append( diff --git a/MaxKernel/auto_agent/subagents/autotuning/agent.py b/MaxKernel/auto_agent/subagents/autotuning/agent.py index d933801..817978c 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/agent.py +++ b/MaxKernel/auto_agent/subagents/autotuning/agent.py @@ -3,7 +3,7 @@ import json import logging import os -from typing import AsyncGenerator +from typing import AsyncGenerator, Optional from google.adk.agents import BaseAgent from google.adk.agents.invocation_context import InvocationContext @@ -18,9 +18,12 @@ autotune_prompt, summary_prompt, ) -from auto_agent.tools.file_tools import write_optimized_kernel_tool +from auto_agent.tools.file_tools import ( + filesystem_tool_r, + write_autotune_specs_tool, + write_optimized_kernel_tool, +) from auto_agent.tools.search_api_tool import search_api_tool -from auto_agent.tools.tools import filesystem_tool_r, write_autotune_specs_tool # 1. Planner Agent (LLM) # This agent identifies parameters, creates the template, and defines the search space. @@ -40,6 +43,9 @@ class AutotuneRunner(BaseAgent): """Executes autotuning via HTTP endpoint.""" + name: Optional[str] = None + output_key: Optional[str] = None + def __init__( self, name: str, @@ -185,7 +191,7 @@ async def _run_async_impl( autotune_results = ctx.session.state.get("autotune_results", {}) if ( autotune_results.get("status") == "success" - and autotune_results.get("best_cfg") is not None + and autotune_results.get("best_config") is not None ): logging.info(f"[{self.name}] Running ApplyBestConfigAgent...") async for event in apply_best_config_agent.run_async(ctx): diff --git a/MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py b/MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py index b408697..363b123 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py +++ b/MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py @@ -5,9 +5,9 @@ You must: 1. use the `read_file` tool to read the file at {autotune_specs_path?} to get the context of autotuning experiment. -2. Use the `read_file` tool to read the file at {autotune_results_path?} and parse the JSON content of the autotune results to find the best configuration `best_cfg`. +2. Use the `read_file` tool to read the file at {autotune_results_path?} and parse the JSON content of the autotune results to find the best configuration `best_config`. 3. Use the `read_file` tool to read the current kernel file at {optimized_kernel_path?}. -4. Use the `restricted_write_file` tool to save the updated kernel file, replacing the old parameter values with the values from `best_cfg`. +4. Use the `restricted_write_file` tool to save the updated kernel file, replacing the old parameter values with the values from `best_config`. Be precise and ensure you only change the specific parameter values identified in the best configuration. """ diff --git a/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py b/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py index aa5c484..23a5e2a 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py +++ b/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py @@ -3,15 +3,17 @@ PROMPT = """You are a specialized agent for preparing autotuning specifications for Pallas kernels. Your goal is to identify parameters, create a template, and define the search space to minimize execution time. +CRITICAL: Do NOT attempt to optimize the kernel code, improve its logic, or fix any bugs. Your task is strictly to prepare the template for autotuning by replacing hardcoded parameters with placeholders and adding timing code. + To prepare for autotuning, you must: 1. Identify the parameters that can be tuned in the kernel (e.g., BLOCK_M, BLOCK_N). 2. Create a code template from the kernel code, replacing the specific parameter values with placeholders enclosed in curly braces (for example, if the parameter is BLOCK_M, use it enclosed in curly braces as the placeholder). -3. Ensure the template code prints "RESULT_TIME: " to indicate the average execution time. To get accurate and quick timing, wrap the kernel call in a loop of exactly 10 iterations (preceded by 1 warm-up execution) and use `jax.block_until_ready()`. Limit iterations strictly to 10 to keep profiling runs fast. WARNING: If you wrap the kernel call in a loop, check if the kernel donates its input buffers (look for `donate_argnames` in the kernel decorator). If it does, calling it repeatedly with the same inputs will fail. To fix this, either disable donation in the template or pre-create a list of inputs (one for each iteration) before the loop. +3. Ensure the template code prints "RESULT_TIME: ms" to indicate the average execution time in microseconds. To get accurate and quick timing, wrap the kernel call in a loop of exactly 10 iterations (preceded by 1 warm-up execution) and use `jax.block_until_ready()`. Limit iterations strictly to 10 to keep profiling runs fast. WARNING: If you wrap the kernel call in a loop, check if the kernel donates its input buffers (look for `donate_argnames` in the kernel decorator). If it does, calling it repeatedly with the same inputs will fail. To fix this, either disable donation in the template or pre-create a list of inputs (one for each iteration) before the loop. 4. Define a highly optimized, high-probability search space as a dictionary mapping placeholder names to lists of suggested values. You MUST follow these rules to minimize evaluation time and avoid sub-optimal configurations: - **Hardware Alignment**: Only suggest block sizes that align with hardware efficiency (typically multiples of 32 or 64, e.g., `[32, 64, 128]`). Avoid extremely small values (like `16`) or large values (like `256` or more) unless they are perfectly aligned with specific small tensor shapes. - **Dimension Divisors**: Choose suggested block sizes that are clean, even divisors of the corresponding matrix or tensor shape dimensions to prevent compiler masking and branch overhead. - **Total Combinations Limit**: Proactively limit the size of individual parameter lists so that the total Cartesian product (all possible combinations) stays small—ideally between **10 to 100 total combinations max**. Keep each parameter list to 2 or 3 high-probability values (e.g., `[64, 128]`). Do not generate massive combinatorial sweeps. -5. Write the `kernel_name`, `code_template`, and `search_space` to a JSON file named `autotune_specs.json`. You MUST save this file (and any helper scripts you create like `create_specs.py`) in the directory specified by `{workdir?}`. Use that full path with `filesystem_tool_rw`. +5. Write the `kernel_name`, `code_template`, and `search_space` to a JSON and save it using the `restricted_write_file` tool. The JSON file must have exactly this structure: { "kernel_name": "...", diff --git a/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py b/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py index ba05f4d..70adb3e 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py +++ b/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py @@ -3,14 +3,16 @@ PROMPT = """ You are an AI assistant summarizing autotuning results for a Pallas kernel optimization task. -Your goal is to read the results from a file, report the best configuration and latency to the user, and state whether this configuration was applied. +Your goal is to summarize the autotuning results provided below, report the best configuration and latency to the user, and state whether this configuration was applied. + +Autotuning Results: +{autotune_results} Instructions: -1. **Read Results**: Use the `read_file` tool to read the file at {autotune_results_path?} and parse its JSON content. -2. **Extract Metrics**: Find the `"best_cfg"` and `"best_time_ms"` in the JSON. -3. **Summarize**: Provide a clear summary in your response. Do NOT list all tested configurations from `all_results`. -4. **Verify Application**: State whether the best configuration was applied to the file at {optimized_kernel_path?}. (Note: In the current pipeline, if the autotune status is "success", the best configuration is automatically applied before you run). -5. **Handle Errors**: If the status is `"failed"` or `"error"`, report the error message provided in the file. +1. **Extract Metrics**: Find the `"best_cfg"` and `"best_time_ms"` in the results above. +2. **Summarize**: Provide a clear summary in your response. Do NOT list all tested configurations from `all_results`. +3. **Verify Application**: To determine if the best configuration was applied, read the file at {optimized_kernel_path?} and verify that the configuration parameters in the file match the values listed in `"best_config"` from the autotuning results. State whether it was applied. +4. **Handle Errors**: If the status is `"failed"` or `"error"`, report the error message provided in the file. Please use the following format for your summary: ### Autotuning Results From cb365a479cba468e9826d2cf9197f541e20c06f1 Mon Sep 17 00:00:00 2001 From: shangkunwang Date: Fri, 22 May 2026 00:09:15 +0000 Subject: [PATCH 7/8] feat: incorporate autotuning summary into kernel planning prompts and agent output --- MaxKernel/auto_agent/subagents/autotuning/agent.py | 1 + .../subagents/kernel_writing/prompts/kernel_planning_prompt.py | 1 + 2 files changed, 2 insertions(+) diff --git a/MaxKernel/auto_agent/subagents/autotuning/agent.py b/MaxKernel/auto_agent/subagents/autotuning/agent.py index 817978c..80fff69 100644 --- a/MaxKernel/auto_agent/subagents/autotuning/agent.py +++ b/MaxKernel/auto_agent/subagents/autotuning/agent.py @@ -168,6 +168,7 @@ async def _run_async_impl( instruction=summary_prompt.PROMPT, description="Summarizes autotuning results for the user.", tools=[filesystem_tool_r], + output_key="autotuning_summary", ) diff --git a/MaxKernel/auto_agent/subagents/kernel_writing/prompts/kernel_planning_prompt.py b/MaxKernel/auto_agent/subagents/kernel_writing/prompts/kernel_planning_prompt.py index 1a5b0fd..a6ce62c 100644 --- a/MaxKernel/auto_agent/subagents/kernel_writing/prompts/kernel_planning_prompt.py +++ b/MaxKernel/auto_agent/subagents/kernel_writing/prompts/kernel_planning_prompt.py @@ -35,6 +35,7 @@ 2. **Review execution results:** Analyze the following to identify what needs improvement: * Compilation Status: `{kernel_compilation_status?}` * Test Results: `{test_results?}` + *. Autotune Summary: `{autotuning_summary?}` * Profiling Summary: `{profiling_summary?}` 3. **Follow Guidelines:** * Preserve good ideas from the original plan that are not causing issues. From 98defd6f85e3d1f267968d4b247f40345fd622d6 Mon Sep 17 00:00:00 2001 From: shangkunwang Date: Sat, 30 May 2026 17:30:30 +0000 Subject: [PATCH 8/8] fix: propagate total_timeout to eval server payload --- MaxKernel/auto_agent/server_utils/eval_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/MaxKernel/auto_agent/server_utils/eval_server.py b/MaxKernel/auto_agent/server_utils/eval_server.py index 591567a..0ba22a5 100644 --- a/MaxKernel/auto_agent/server_utils/eval_server.py +++ b/MaxKernel/auto_agent/server_utils/eval_server.py @@ -227,6 +227,7 @@ async def _perform_evaluation(request: EvalRequest): payload["code_template"] = request.code_template payload["search_space"] = request.search_space backend_timeout = request.total_timeout + payload["total_timeout"] = request.total_timeout else: payload["code"] = request.code payload["dependencies"] = request.dependencies