diff --git a/MaxKernel/auto_agent/agent.py b/MaxKernel/auto_agent/agent.py index e57004c..cc2e2e1 100644 --- a/MaxKernel/auto_agent/agent.py +++ b/MaxKernel/auto_agent/agent.py @@ -4,6 +4,7 @@ for the human-in-the-loop kernel generation process. """ +from auto_agent.subagents.autotuning.agent import autotune_agent from auto_agent.subagents.kernel_writing import ( implement_kernel_agent, plan_kernel_agent, @@ -23,6 +24,7 @@ validate_agent=validate_kernel_compilation_agent, test_gen_agent=validated_test_generation_agent, test_run_agent=unified_test_agent, + autotune_agent=autotune_agent, profile_agent=profile_agent, max_iterations=5, ) diff --git a/MaxKernel/auto_agent/client_utils/eval_client.py b/MaxKernel/auto_agent/client_utils/eval_client.py new file mode 100644 index 0000000..b10e92d --- /dev/null +++ b/MaxKernel/auto_agent/client_utils/eval_client.py @@ -0,0 +1,76 @@ +import asyncio +import logging +import time + +import aiohttp + + +async def call_eval_server_async( + session: aiohttp.ClientSession, + eval_server_url: str, + payload: dict, + poll_interval: int = 10, + client_wait_timeout: int = 3600 * 3, # Default to 3 hours +) -> dict: + """Calls the evaluation server asynchronously and polls for status. + + Args: + session: aiohttp.ClientSession to use. + eval_server_url: Base URL of the eval server (e.g., + "http://localhost:1245"). + payload: The request payload. + poll_interval: Seconds to wait between polls. + client_wait_timeout: Max seconds to wait for task completion. + + Returns: + The result from the evaluation server. + """ + # 1. Submit the task + submit_url = f"{eval_server_url}/evaluate" + logging.info(f"Submitting async task to {submit_url}") + + payload = payload.copy() + payload["client_wait_timeout"] = client_wait_timeout + + async with session.post(submit_url, json=payload) as response: + if response.status != 202: + error_text = await response.text() + raise Exception( + f"Failed to submit task. 
Status: {response.status}, Error: {error_text}" + ) + resp_data = await response.json() + task_id = resp_data["task_id"] + logging.info(f"Task submitted successfully. ID: {task_id}") + + # 2. Poll for status + start_time = time.time() + status_url = f"{eval_server_url}/status/{task_id}" + + while True: + if time.time() - start_time > client_wait_timeout: + raise Exception( + f"Client timed out waiting for task {task_id} after {client_wait_timeout} seconds" + ) + + async with session.get(status_url) as response: + if response.status != 200: + error_text = await response.text() + raise Exception( + f"Failed to get task status. Status: {response.status}, Error: {error_text}" + ) + + status_data = await response.json() + status = status_data["status"] + + if status == "success": + logging.info(f"Task {task_id} completed successfully.") + return status_data["result"] + elif status in ["failed", "timeout"]: + raise Exception( + f"Task {task_id} ended with status {status}: {status_data.get('error')}" + ) + + logging.info( + f"Task {task_id} status: {status}. Waiting {poll_interval}s..." 
+ ) + await asyncio.sleep(poll_interval) diff --git a/MaxKernel/auto_agent/constants.py b/MaxKernel/auto_agent/constants.py index ad1aec0..24ff4c1 100644 --- a/MaxKernel/auto_agent/constants.py +++ b/MaxKernel/auto_agent/constants.py @@ -4,9 +4,7 @@ TEMPERATURE = 0.1 TOP_P = 0.9 TOP_K = 5 -TPU_TIMEOUT = 120 -REQUEST_TIMEOUT = 1800 +REQUEST_TIMEOUT = 3600 * 3 TPU_SERVER_PORT = 5463 CPU_SERVER_PORT = 5464 EVAL_SERVER_PORT = 1245 -PERF_THRESHOLD = 1.1 diff --git a/MaxKernel/auto_agent/server_utils/eval_server.py b/MaxKernel/auto_agent/server_utils/eval_server.py index c3418bf..0ba22a5 100644 --- a/MaxKernel/auto_agent/server_utils/eval_server.py +++ b/MaxKernel/auto_agent/server_utils/eval_server.py @@ -1,20 +1,17 @@ import asyncio import logging +import time +import uuid from enum import Enum from typing import Optional import aiohttp import yaml -from fastapi import FastAPI, HTTPException +from fastapi import BackgroundTasks, FastAPI, HTTPException from pydantic import BaseModel -from auto_agent.constants import ( - EVAL_SERVER_PORT, -) -from auto_agent.server_utils.tpu_server import ( - CodeResponse, - get_tpu_version, -) +from auto_agent.constants import EVAL_SERVER_PORT +from auto_agent.server_utils.tpu_server import get_tpu_version logging.basicConfig( level=logging.INFO, @@ -62,14 +59,38 @@ class EvalTypes(Enum): PERFORMANCE_TEST = "performance_test" UNIFIED_TEST = "unified_test" PROFILE = "profile" + AUTOTUNE = "autotune" class EvalRequest(BaseModel): eval_type: EvalTypes - code: str - timeout: Optional[int] = 30 + timeout: int + code: Optional[str] = None + code_template: Optional[str] = None # For autotune + search_space: Optional[dict] = None # For autotune backend_type: Optional[str] = None # "tpu", "cpu", or None for any available dependencies: Optional[dict] = None + total_timeout: Optional[int] = None # For autotune + client_wait_timeout: Optional[int] = None + + +class TaskStatus(str, Enum): + QUEUED = "queued" + RUNNING = "running" + SUCCESS = 
"success" + FAILED = "failed" + TIMEOUT = "timeout" + + +class TaskResponse(BaseModel): + task_id: str + status: TaskStatus + result: Optional[dict] = None + error: Optional[str] = None + + +# For evaluation task tracking. task_id -> {status, request, result_or_error} +_tasks = {} class Evaluator: @@ -115,13 +136,64 @@ async def health_check(): return {"status": "healthy"} -@app.post("/evaluate", response_model=CodeResponse) -async def evaluate(request: EvalRequest): +# Asynchronous task polling architecture for evaluation. +@app.post("/evaluate", status_code=202) +async def evaluate(request: EvalRequest, background_tasks: BackgroundTasks): + task_id = str(uuid.uuid4()) + _tasks[task_id] = {"status": TaskStatus.QUEUED, "request": request} + background_tasks.add_task(run_evaluation_task, task_id, request) + return {"task_id": task_id, "status": TaskStatus.QUEUED} + + +@app.get("/status/{task_id}", response_model=TaskResponse) +async def get_task_status(task_id: str): + if task_id not in _tasks: + raise HTTPException(status_code=404, detail="Task not found") + task_info = _tasks[task_id] + return { + "task_id": task_id, + "status": task_info["status"], + "result": task_info.get("result"), + "error": task_info.get("error"), + } + + +async def run_evaluation_task(task_id: str, request: EvalRequest): + _tasks[task_id]["status"] = TaskStatus.RUNNING + try: + result = await _perform_evaluation(request) + _tasks[task_id]["status"] = TaskStatus.SUCCESS + _tasks[task_id]["result"] = result + except HTTPException as e: + if e.status_code == 408: + _tasks[task_id]["status"] = TaskStatus.TIMEOUT + else: + _tasks[task_id]["status"] = TaskStatus.FAILED + _tasks[task_id]["error"] = e.detail + except Exception as e: + _tasks[task_id]["status"] = TaskStatus.FAILED + _tasks[task_id]["error"] = str(e) + + +async def _perform_evaluation(request: EvalRequest): if request.eval_type not in EvalTypes: raise HTTPException(status_code=400, detail="Invalid evaluation type") # Acquire backend, 
retry if busy + start_queue_time = time.time() + max_queue_time = ( + request.client_wait_timeout + if request.client_wait_timeout is not None + else 3600 * 3 + ) + while True: + if time.time() - start_queue_time > max_queue_time: + raise HTTPException( + status_code=408, + detail="Timed out waiting for an available backend in queue", + ) + async with backend_semaphore: backend = evaluator.get_available_backend( backend_type=request.backend_type @@ -146,18 +218,26 @@ async def evaluate(request: EvalRequest): f"{requested_type_msg}" ) - # Send request to backend server - backend_timeout = request.timeout if request.timeout is not None else 30 + # Construct payload based on eval type + payload = { + "eval_type": request.eval_type.value, + "timeout": request.timeout, + } + if request.eval_type == EvalTypes.AUTOTUNE: + payload["code_template"] = request.code_template + payload["search_space"] = request.search_space + backend_timeout = request.total_timeout + payload["total_timeout"] = request.total_timeout + else: + payload["code"] = request.code + payload["dependencies"] = request.dependencies + backend_timeout = request.timeout + client_timeout = aiohttp.ClientTimeout(total=backend_timeout + 10) async with aiohttp.ClientSession(timeout=client_timeout) as session: async with session.post( f"http://{backend_ip}:{backend_port}/{request.eval_type.value}", - json={ - "eval_type": request.eval_type.value, - "code": request.code, - "timeout": request.timeout, - "dependencies": request.dependencies, - }, + json=payload, ) as response: result = await response.json() logging.info( @@ -186,7 +266,9 @@ async def evaluate(request: EvalRequest): raise e except Exception as e: logging.error(f"Error occurred while evaluating on {backend_name}: {e}") - raise HTTPException(status_code=500, detail="Backend evaluation failed") + raise HTTPException( + status_code=500, detail=f"Backend evaluation failed: {e}" + ) finally: backend.set_status("available") diff --git 
a/MaxKernel/auto_agent/server_utils/tpu_server.py b/MaxKernel/auto_agent/server_utils/tpu_server.py index 3377677..f23ab31 100644 --- a/MaxKernel/auto_agent/server_utils/tpu_server.py +++ b/MaxKernel/auto_agent/server_utils/tpu_server.py @@ -1,4 +1,5 @@ import asyncio +import itertools import json import logging import os @@ -6,6 +7,7 @@ import subprocess import sys import tempfile +import time from typing import Optional from fastapi import FastAPI, HTTPException @@ -27,6 +29,7 @@ correctness_semaphore = asyncio.Semaphore(1) performance_semaphore = asyncio.Semaphore(1) profile_semaphore = asyncio.Semaphore(1) +autotune_semaphore = asyncio.Semaphore(1) class CodeRequest(BaseModel): @@ -41,6 +44,13 @@ class CodeResponse(BaseModel): exit_code: int +class AutotuneRequest(BaseModel): + code_template: str + search_space: dict[str, list] + timeout: Optional[int] = 300 + total_timeout: Optional[int] = None + + class GetTpuVersionResponse(BaseModel): tpu_version: str @@ -498,6 +508,132 @@ async def profile(request: CodeRequest): logging.info("Profile analysis finished") +@app.post("/autotune", response_model=CodeResponse) +async def autotune(request: AutotuneRequest): + logging.info("Starting autotune") + async with performance_semaphore: + try: + # Generate all combinations + keys = list(request.search_space.keys()) + values = list(request.search_space.values()) + combinations = list(itertools.product(*values)) + + best_time = float("inf") + best_cfg = None + best_output = "" + all_results = [] + + start_time = time.time() + for combo in combinations: + if ( + request.total_timeout + and (time.time() - start_time) > request.total_timeout + ): + logging.warning( + f"Total timeout of {request.total_timeout}s reached in autotune" + ) + break + cfg = dict(zip(keys, combo)) + try: + code_content = request.code_template + for k, v in cfg.items(): + code_content = code_content.replace(f"{{{k}}}", str(v)) + except Exception as e: + logging.error(f"Error during template 
formatting: {e}. Config: {cfg}") + continue + + # Execute the code + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", prefix="hitl_eval_", delete=False + ) as temp_file: + temp_file.write(code_content) + temp_file_path = temp_file.name + + process = None + try: + process = await asyncio.create_subprocess_exec( + sys.executable, + temp_file_path, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=tempfile.gettempdir(), + ) + + stdout, stderr = await asyncio.wait_for( + process.communicate(), timeout=request.timeout + ) + + output = stdout.decode("utf-8") if stdout else "" + error = stderr.decode("utf-8") if stderr else "" + exit_code = process.returncode + + if exit_code == 0: + # Parse RESULT_TIME + match = re.search(r"RESULT_TIME:\s*([0-9.]+)\s*ms", output) + if match: + time_taken = float(match.group(1)) + all_results.append( + {"cfg": cfg, "time": time_taken, "status": "success"} + ) + if time_taken < best_time: + best_time = time_taken + best_cfg = cfg + best_output = output + else: + logging.warning( + f"No RESULT_TIME found in output for config {cfg}" + ) + all_results.append( + {"cfg": cfg, "status": "no_result_time", "output": output} + ) + else: + logging.warning( + f"Config {cfg} failed with exit code {exit_code}. 
Stderr: {error}" + ) + all_results.append( + { + "cfg": cfg, + "status": "failed", + "error": error, + "exit_code": exit_code, + } + ) + + except asyncio.TimeoutError: + logging.warning(f"Config {cfg} timed out") + if process: + process.kill() + await process.wait() + all_results.append({"cfg": cfg, "status": "timeout"}) + except Exception as e: + logging.error(f"Error running config {cfg}: {e}") + all_results.append( + {"cfg": cfg, "status": "exception", "error": str(e)} + ) + finally: + if "temp_file_path" in locals(): + try: + os.unlink(temp_file_path) + except OSError: + pass + + return CodeResponse( + output=json.dumps( + { + "best_cfg": best_cfg, + "best_time": best_time, + "best_output": best_output, + "all_results": all_results, + } + ), + exit_code=0 if best_cfg is not None else 1, + ) + + except Exception as e: + logging.error(f"Autotune failed with error: {str(e)}") + raise HTTPException(status_code=500, detail=f"Autotune error: {str(e)}") + + @app.post("/get_tpu_version", response_model=GetTpuVersionResponse) def get_tpu_version() -> str: """Attempts to determine the TPU version by trying three methods. 
diff --git a/MaxKernel/auto_agent/subagents/autotuning/agent.py b/MaxKernel/auto_agent/subagents/autotuning/agent.py new file mode 100644 index 0000000..80fff69 --- /dev/null +++ b/MaxKernel/auto_agent/subagents/autotuning/agent.py @@ -0,0 +1,213 @@ +"""Autotuning agent following the split pattern (Planner + Runner).""" + +import json +import logging +import os +from typing import AsyncGenerator, Optional + +from google.adk.agents import BaseAgent +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events import Event, EventActions + +from auto_agent.config import model_config, thinking_planner +from auto_agent.constants import MODEL_NAME +from auto_agent.custom_types import CustomLlmAgent +from auto_agent.subagents.autotuning.autotune_tool import autotune_kernel +from auto_agent.subagents.autotuning.prompts import ( + apply_best_config_prompt, + autotune_prompt, + summary_prompt, +) +from auto_agent.tools.file_tools import ( + filesystem_tool_r, + write_autotune_specs_tool, + write_optimized_kernel_tool, +) +from auto_agent.tools.search_api_tool import search_api_tool + +# 1. Planner Agent (LLM) +# This agent identifies parameters, creates the template, and defines the search space. +# It saves them to session state instead of calling the tool directly. +autotune_planner_agent = CustomLlmAgent( + name="AutotunePlannerAgent", + model=MODEL_NAME, + generate_content_config=model_config, + planner=thinking_planner, + instruction=autotune_prompt.PROMPT, + description="Prepares code template and search space for auto-tuning Pallas kernels.", + tools=[filesystem_tool_r, write_autotune_specs_tool, search_api_tool], +) + + +# 2. 
Runner Agent +class AutotuneRunner(BaseAgent): + """Executes autotuning via HTTP endpoint.""" + + name: Optional[str] = None + output_key: Optional[str] = None + + def __init__( + self, + name: str, + output_key: str, + ): + BaseAgent.__init__(self, name=name) + self.output_key = output_key + + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + + autotune_specs_path = ctx.session.state.get("autotune_specs_path", "") + autotune_results_path = ctx.session.state.get("autotune_results_path", "") + + if not os.path.exists(autotune_specs_path): + error_msg = f"Autotune specs file not found at {autotune_specs_path}" + logging.error(error_msg) + yield Event( + author=self.name, + actions=EventActions( + state_delta={ + self.output_key: {"status": "error", "message": error_msg} + } + ), + ) + return + + try: + with open(autotune_specs_path, "r") as f: + specs = json.load(f) + kernel_name = specs.get("kernel_name", "") + code_template = specs.get("code_template", "") + search_space = specs.get("search_space", {}) + except Exception as e: + error_msg = f"Failed to parse autotune specs JSON: {e}" + logging.error(error_msg) + yield Event( + author=self.name, + actions=EventActions( + state_delta={ + self.output_key: {"status": "error", "message": error_msg} + } + ), + ) + return + + if not kernel_name or not code_template or not search_space: + error_msg = "Missing required inputs in autotune specs file." 
+ logging.error(error_msg) + yield Event( + author=self.name, + actions=EventActions( + state_delta={ + self.output_key: {"status": "error", "message": error_msg} + } + ), + ) + return + + logging.info(f"[{self.name}] Running autotune for {kernel_name}") + + try: + results = await autotune_kernel( + kernel_name=kernel_name, + code_template=code_template, + search_space=search_space, + backend="tpu", + ) + + try: + with open(autotune_results_path, "w") as f: + json.dump(results, f) + logging.info(f"[{self.name}] Saved results to {autotune_results_path}") + except Exception as e: + logging.error(f"[{self.name}] Failed to save results to file: {e}") + + yield Event( + author=self.name, + actions=EventActions( + state_delta={ + self.output_key: results, + } + ), + ) + + except Exception as e: + logging.error(f"Exception during autotuning: {e}") + yield Event( + author=self.name, + actions=EventActions( + state_delta={self.output_key: {"status": "error", "message": str(e)}} + ), + ) + + +autotune_runner = AutotuneRunner( + name="AutotuneRunner", + output_key="autotune_results", +) + +# 3. Apply Best Config Agent +apply_best_config_agent = CustomLlmAgent( + name="ApplyBestConfigAgent", + model=MODEL_NAME, + generate_content_config=model_config, + planner=thinking_planner, + instruction=apply_best_config_prompt.PROMPT, + description="Applies autotuning results to the optimized kernel file.", + tools=[filesystem_tool_r, write_optimized_kernel_tool], +) + +# 4. Summarizer Agent +# This agent reads results from state and talks to the user. 
+autotune_summary_agent = CustomLlmAgent( + name="AutotuneSummaryAgent", + model=MODEL_NAME, + generate_content_config=model_config, + planner=thinking_planner, + instruction=summary_prompt.PROMPT, + description="Summarizes autotuning results for the user.", + tools=[filesystem_tool_r], + output_key="autotuning_summary", +) + + +class CombinedAutotuneAgent(BaseAgent): + """Chains autotuning steps and conditionally applies best config.""" + + def __init__(self, name: str): + super().__init__(name=name) + + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + logging.info(f"[{self.name}] Running AutotunePlannerAgent...") + async for event in autotune_planner_agent.run_async(ctx): + yield event + + logging.info(f"[{self.name}] Running AutotuneRunner...") + async for event in autotune_runner.run_async(ctx): + yield event + + autotune_results = ctx.session.state.get("autotune_results", {}) + if ( + autotune_results.get("status") == "success" + and autotune_results.get("best_config") is not None + ): + logging.info(f"[{self.name}] Running ApplyBestConfigAgent...") + async for event in apply_best_config_agent.run_async(ctx): + yield event + else: + logging.warning( + f"[{self.name}] Autotune was not successful or no best configuration" + " found. Skipping ApplyBestConfigAgent." 
+ ) + + logging.info(f"[{self.name}] Running AutotuneSummaryAgent...") + async for event in autotune_summary_agent.run_async(ctx): + yield event + + +autotune_agent = CombinedAutotuneAgent(name="AutotuneAgent") + +__all__ = ["autotune_agent"] diff --git a/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py b/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py new file mode 100644 index 0000000..6cf3151 --- /dev/null +++ b/MaxKernel/auto_agent/subagents/autotuning/autotune_tool.py @@ -0,0 +1,110 @@ +"""Standalone tool for auto-tuning Pallas kernels using grid search on remote servers.""" + +import json +import logging +from typing import Any + +import aiohttp + +from auto_agent.client_utils.eval_client import call_eval_server_async +from auto_agent.constants import EVAL_SERVER_PORT, REQUEST_TIMEOUT + +AUTOTUNE_INDIVIDUAL_TIMEOUT = 300 +AUTOTUNE_TOTAL_TIMEOUT = 5400 + + +async def autotune_kernel( + kernel_name: str, + code_template: str, + search_space: dict[str, list[Any]], + backend: str = None, + server_addr: str = "http://localhost", +) -> dict: + """Runs a grid search to auto-tune a Pallas kernel on a remote server. + + Args: + kernel_name: Name of the kernel. + code_template: Python code containing placeholders for parameters to be + tuned. It should produce a line like "RESULT_TIME: " in its + output to indicate performance. + search_space: A dictionary mapping placeholder names to lists of feasible + values. + backend: 'tpu' or 'cpu'. + server_addr: Address of the server (default: http://localhost). + + Returns: + A dictionary containing the status, optimal parameters, and a summary of + results. 
+ """ + logging.info( + f"Starting remote autotuning for kernel: {kernel_name} on {backend}" + ) + + url = f"{server_addr}:{EVAL_SERVER_PORT}/evaluate" + + try: + client_timeout = aiohttp.ClientTimeout(total=REQUEST_TIMEOUT + 10) + async with aiohttp.ClientSession(timeout=client_timeout) as session: + payload = { + "eval_type": "autotune", + "code_template": code_template, + "search_space": search_space, + "timeout": AUTOTUNE_INDIVIDUAL_TIMEOUT, + "backend_type": backend, + "total_timeout": AUTOTUNE_TOTAL_TIMEOUT, + } + result = await call_eval_server_async( + session, + f"{server_addr}:{EVAL_SERVER_PORT}", + payload, + poll_interval=10, + client_wait_timeout=REQUEST_TIMEOUT, + ) + + if result["exit_code"] == 0: + try: + output_data = json.loads(result["output"]) + logging.info( + f"Autotuning completed. Best config: {output_data['best_cfg']}" + f" with time {output_data['best_time']} ms" + ) + return { + "status": "success", + "message": "Autotuning completed", + "best_config": output_data["best_cfg"], + "best_time_ms": output_data["best_time"], + "best_output": output_data["best_output"], + "all_results": output_data.get("all_results", []), + } + except json.JSONDecodeError: + logging.warning("Failed to decode JSON from server output.") + return { + "status": "success", + "message": "Autotuning completed (raw output)", + "raw_output": result["output"], + } + else: + try: + output_data = json.loads(result["output"]) + return { + "status": "failed", + "message": result["error"] or "Autotune failed on server", + "all_results": output_data.get("all_results", []), + } + except Exception: + return { + "status": "failed", + "message": result["error"] or "Autotune failed on server", + "server_output": result["output"], + } + + except aiohttp.ClientConnectorError: + return { + "status": "error", + "message": ( + f"Could not connect to server at {url}. Make sure it is running." 
+ ), + } + + except Exception as e: + return {"status": "error", "message": str(e)} diff --git a/MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py b/MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py new file mode 100644 index 0000000..363b123 --- /dev/null +++ b/MaxKernel/auto_agent/subagents/autotuning/prompts/apply_best_config_prompt.py @@ -0,0 +1,13 @@ +"""Prompt for ApplyBestConfigAgent.""" + +PROMPT = """You are a specialized agent for applying autotuning results to a Pallas kernel file. +Your goal is to read the best configuration from autotuning results and update the `optimized_kernel.py` file with these values. + +You must: +1. use the `read_file` tool to read the file at {autotune_specs_path?} to get the context of autotuning experiment. +2. Use the `read_file` tool to read the file at {autotune_results_path?} and parse the JSON content of the autotune results to find the best configuration `best_config`. +3. Use the `read_file` tool to read the current kernel file at {optimized_kernel_path?}. +4. Use the `restricted_write_file` tool to save the updated kernel file, replacing the old parameter values with the values from `best_config`. + +Be precise and ensure you only change the specific parameter values identified in the best configuration. +""" diff --git a/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py b/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py new file mode 100644 index 0000000..23a5e2a --- /dev/null +++ b/MaxKernel/auto_agent/subagents/autotuning/prompts/autotune_prompt.py @@ -0,0 +1,31 @@ +"""Prompt for AutotuneAgent.""" + +PROMPT = """You are a specialized agent for preparing autotuning specifications for Pallas kernels. +Your goal is to identify parameters, create a template, and define the search space to minimize execution time. + +CRITICAL: Do NOT attempt to optimize the kernel code, improve its logic, or fix any bugs. 
Your task is strictly to prepare the template for autotuning by replacing hardcoded parameters with placeholders and adding timing code. + +To prepare for autotuning, you must: +1. Identify the parameters that can be tuned in the kernel (e.g., BLOCK_M, BLOCK_N). +2. Create a code template from the kernel code, replacing the specific parameter values with placeholders enclosed in curly braces (for example, if the parameter is BLOCK_M, use it enclosed in curly braces as the placeholder). +3. Ensure the template code prints "RESULT_TIME: <time> ms" to indicate the average execution time in milliseconds. To get accurate and quick timing, wrap the kernel call in a loop of exactly 10 iterations (preceded by 1 warm-up execution) and use `jax.block_until_ready()`. Limit iterations strictly to 10 to keep profiling runs fast. WARNING: If you wrap the kernel call in a loop, check if the kernel donates its input buffers (look for `donate_argnames` in the kernel decorator). If it does, calling it repeatedly with the same inputs will fail. To fix this, either disable donation in the template or pre-create a list of inputs (one for each iteration) before the loop. +4. Define a highly optimized, high-probability search space as a dictionary mapping placeholder names to lists of suggested values. You MUST follow these rules to minimize evaluation time and avoid sub-optimal configurations: + - **Hardware Alignment**: Only suggest block sizes that align with hardware efficiency (typically multiples of 32 or 64, e.g., `[32, 64, 128]`). Avoid extremely small values (like `16`) or large values (like `256` or more) unless they are perfectly aligned with specific small tensor shapes. + - **Dimension Divisors**: Choose suggested block sizes that are clean, even divisors of the corresponding matrix or tensor shape dimensions to prevent compiler masking and branch overhead. 
+ - **Total Combinations Limit**: Proactively limit the size of individual parameter lists so that the total Cartesian product (all possible combinations) stays small—ideally between **10 to 100 total combinations max**. Keep each parameter list to 2 or 3 high-probability values (e.g., `[64, 128]`). Do not generate massive combinatorial sweeps. +5. Write the `kernel_name`, `code_template`, and `search_space` to a JSON and save it using the `restricted_write_file` tool. +The JSON file must have exactly this structure: +{ + "kernel_name": "...", + "code_template": "...", + "search_space": { ... } +} + +## Tools Available +1. **`search_api`**: Search for API definitions +2. **`read_file`**: Read the kernel code file. + - Required Argument: `path` +3. **`restricted_write_file`**: Write the JSON file + - Required Argument: `content` (The complete file content) + - Example: `restricted_write_file(content=...)` +""" diff --git a/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py b/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py new file mode 100644 index 0000000..70adb3e --- /dev/null +++ b/MaxKernel/auto_agent/subagents/autotuning/prompts/summary_prompt.py @@ -0,0 +1,25 @@ +"""Prompt for AutotuneSummarizerAgent.""" + +PROMPT = """ +You are an AI assistant summarizing autotuning results for a Pallas kernel optimization task. + +Your goal is to summarize the autotuning results provided below, report the best configuration and latency to the user, and state whether this configuration was applied. + +Autotuning Results: +{autotune_results} + +Instructions: +1. **Extract Metrics**: Find the `"best_config"` and `"best_time_ms"` in the results above. +2. **Summarize**: Provide a clear summary in your response. Do NOT list all tested configurations from `all_results`. +3. 
**Verify Application**: To determine if the best configuration was applied, read the file at {optimized_kernel_path?} and verify that the configuration parameters in the file match the values listed in `"best_config"` from the autotuning results. State whether it was applied. +4. **Handle Errors**: If the status is `"failed"` or `"error"`, report the error message provided in the file. + +Please use the following format for your summary: +### Autotuning Results +- **Status**: [Success / Failed] +- **Best Configuration**: `[JSON or description of best config]` +- **Latency**: `[Time]` ms +- **Applied to File**: [Yes / No] + +[Any additional brief notes or error messages] +""" diff --git a/MaxKernel/auto_agent/subagents/kernel_writing/kernel_compilation.py b/MaxKernel/auto_agent/subagents/kernel_writing/kernel_compilation.py index 9fb505a..c2adbca 100644 --- a/MaxKernel/auto_agent/subagents/kernel_writing/kernel_compilation.py +++ b/MaxKernel/auto_agent/subagents/kernel_writing/kernel_compilation.py @@ -6,11 +6,10 @@ from google.adk.agents.invocation_context import InvocationContext from google.adk.events import Event, EventActions -from auto_agent.constants import ( - EVAL_SERVER_PORT, - REQUEST_TIMEOUT, - TPU_TIMEOUT, -) +from auto_agent.client_utils.eval_client import call_eval_server_async +from auto_agent.constants import EVAL_SERVER_PORT, REQUEST_TIMEOUT + +COMPILATION_TIMEOUT = 120 class KernelCompilationChecker(BaseAgent): @@ -47,69 +46,63 @@ async def _run_async_impl( # Call the TPU server to execute the code logging.info(f"[{self.name}] Running code") async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=REQUEST_TIMEOUT) + timeout=aiohttp.ClientTimeout(total=REQUEST_TIMEOUT + 10) ) as session: - async with session.post( - f"http://localhost:{EVAL_SERVER_PORT}/evaluate", - json={ - "eval_type": "compilation_test", - "code": code, - "timeout": TPU_TIMEOUT, - "backend_type": "tpu", - }, - ) as response: - if response.status == 200: - result = 
await response.json() - logging.info(f"[{self.name}] Compilation test result: {result}") - if result["exit_code"] == 0: - logging.info(f"[{self.name}] Code execution successful.") - yield Event( - author=self.name, - actions=EventActions(state_delta={self.output_key: "Success"}), - ) - elif ( - result["error"] is None - and result["output"] == "" - and result["exit_code"] == 1 - ): - logging.info( - f"[{self.name}] Code execution had exit code 1, but no error, indicating success." - ) - yield Event( - author=self.name, - actions=EventActions(state_delta={self.output_key: "Success"}), - ) - else: - logging.info( - f"[{self.name}] Code execution failed. Loop will continue." - ) - # Use 'or' to handle None case - result.get() returns None if key exists with None value - error_msg = ( - result.get("error") - or result.get("output") - or "Unknown error: No error message or output available" - ) - # Add diagnostic logging when error field is None or empty - if result.get("error") is None: - logging.warning( - f"[{self.name}] Error field is None. 
" - f"exit_code: {result.get('exit_code')}, " - f"output length: {len(result.get('output', ''))}, " - f"Using output as fallback: {result.get('output', '')[:200]}" - ) - yield Event( - author=self.name, - actions=EventActions(state_delta={self.output_key: error_msg}), - ) - else: - error_detail = await response.text() - logging.error( - f"[{self.name}] HTTP error {response.status}: {error_detail}" - ) - ctx.session.state[self.output_key] = ( - f"HTTP error {response.status}: {error_detail}" + payload = { + "eval_type": "compilation_test", + "code": code, + "timeout": COMPILATION_TIMEOUT, + "backend_type": "tpu", + } + result = await call_eval_server_async( + session, + f"http://localhost:{EVAL_SERVER_PORT}", + payload, + poll_interval=10, + client_wait_timeout=REQUEST_TIMEOUT, + ) + + logging.info(f"[{self.name}] Compilation test result: {result}") + if result["exit_code"] == 0: + logging.info(f"[{self.name}] Code execution successful.") + yield Event( + author=self.name, + actions=EventActions(state_delta={self.output_key: "Success"}), + ) + elif ( + result["error"] is None + and result["output"] == "" + and result["exit_code"] == 1 + ): + logging.info( + f"[{self.name}] Code execution had exit code 1, but no error, indicating success." + ) + yield Event( + author=self.name, + actions=EventActions(state_delta={self.output_key: "Success"}), + ) + else: + logging.info( + f"[{self.name}] Code execution failed. Loop will continue." + ) + # Use 'or' to handle None case - result.get() returns None if key exists with None value + error_msg = ( + result.get("error") + or result.get("output") + or "Unknown error: No error message or output available" + ) + # Add diagnostic logging when error field is None or empty + if result.get("error") is None: + logging.warning( + f"[{self.name}] Error field is None. 
" + f"exit_code: {result.get('exit_code')}, " + f"output length: {len(result.get('output', ''))}, " + f"Using output as fallback: {result.get('output', '')[:200]}" ) - yield Event(author=self.name) + yield Event( + author=self.name, + actions=EventActions(state_delta={self.output_key: error_msg}), + ) except Exception as e: logging.error(f"[{self.name}] Exception during code execution: {str(e)}") ctx.session.state[self.output_key] = ( diff --git a/MaxKernel/auto_agent/subagents/kernel_writing/prompts/kernel_planning_prompt.py b/MaxKernel/auto_agent/subagents/kernel_writing/prompts/kernel_planning_prompt.py index 1a5b0fd..a6ce62c 100644 --- a/MaxKernel/auto_agent/subagents/kernel_writing/prompts/kernel_planning_prompt.py +++ b/MaxKernel/auto_agent/subagents/kernel_writing/prompts/kernel_planning_prompt.py @@ -35,6 +35,7 @@ 2. **Review execution results:** Analyze the following to identify what needs improvement: * Compilation Status: `{kernel_compilation_status?}` * Test Results: `{test_results?}` + * Autotune Summary: `{autotuning_summary?}` * Profiling Summary: `{profiling_summary?}` 3. **Follow Guidelines:** * Preserve good ideas from the original plan that are not causing issues. 
diff --git a/MaxKernel/auto_agent/subagents/pipeline_agent.py b/MaxKernel/auto_agent/subagents/pipeline_agent.py index f351da7..ad06020 100644 --- a/MaxKernel/auto_agent/subagents/pipeline_agent.py +++ b/MaxKernel/auto_agent/subagents/pipeline_agent.py @@ -2,6 +2,7 @@ import logging import os +import re from typing import AsyncGenerator from google.adk.agents import BaseAgent @@ -21,6 +22,7 @@ class AutonomousPipelineAgent(BaseAgent): validate_agent: BaseAgent test_gen_agent: BaseAgent test_run_agent: BaseAgent + autotune_agent: BaseAgent profile_agent: BaseAgent max_iterations: int = 2 @@ -32,6 +34,7 @@ def __init__( validate_agent: BaseAgent, test_gen_agent: BaseAgent, test_run_agent: BaseAgent, + autotune_agent: BaseAgent, profile_agent: BaseAgent, max_iterations: int = 2, ): @@ -42,6 +45,7 @@ def __init__( validate_agent=validate_agent, test_gen_agent=test_gen_agent, test_run_agent=test_run_agent, + autotune_agent=autotune_agent, profile_agent=profile_agent, max_iterations=max_iterations, ) @@ -110,7 +114,12 @@ async def _run_async_impl( iteration += 1 continue - # Step 6: Profile + # Step 6: Autotune + logging.info(f"[{self.name}] Running AutotuneAgent...") + async for event in self.autotune_agent.run_async(ctx): + yield event + + # Step 7: Profile logging.info(f"[{self.name}] Running ProfileAgentOrchestrator...") async for event in self.profile_agent.run_async(ctx): yield event @@ -128,8 +137,7 @@ async def _run_async_impl( ) # Extract latency - test_output = ctx.session.state.get("test_results", {}).get("output", "") - latency = self._extract_latency(test_output) + latency = self._extract_latency(ctx) snapshot = { "iteration": iteration, @@ -239,6 +247,22 @@ def _initialize_state(self, ctx: InvocationContext) -> Event: f"[{self.name}] Set profiling_script_path: {ctx.session.state['profiling_script_path']}" ) + if "autotune_specs_path" not in ctx.session.state: + ctx.session.state["autotune_specs_path"] = os.path.join( + session_dir, "autotune_specs.json" + ) 
+ logging.info( + f"[{self.name}] Set autotune_specs_path: {ctx.session.state['autotune_specs_path']}" + ) + + if "autotune_results_path" not in ctx.session.state: + ctx.session.state["autotune_results_path"] = os.path.join( + session_dir, "autotune_results.json" + ) + logging.info( + f"[{self.name}] Set autotune_results_path: {ctx.session.state['autotune_results_path']}" + ) + logging.info(f"[{self.name}] Published explicit path state update Event.") return Event( author=self.name, @@ -254,18 +278,26 @@ def _initialize_state(self, ctx: InvocationContext) -> Event: ), ) - def _extract_latency(self, test_output: str): - """Extracts execution time from test results output.""" + def _extract_latency(self, ctx: InvocationContext): + """Extracts execution time from autotune results or test results output.""" + autotune_results = ctx.session.state.get("autotune_results", {}) + if autotune_results.get("status") == "success": + latency = autotune_results.get("best_time_ms") + if latency is not None: + logging.info( + f"[{self.name}] Extracted latency from autotune results: {latency} ms" + ) + return latency + + test_output = ctx.session.state.get("test_results", {}).get("output", "") if not test_output: return None try: - import re - match = re.search(r"PERF_METRICS:\s*([\d.]+)", test_output) if match: latency = float(match.group(1)) logging.info( - f"[{self.name}] Extracted execution time from test results: {latency} ms" + f"[{self.name}] Extracted execution time from test results: {latency} ms (fallback)" ) return latency except Exception as e: diff --git a/MaxKernel/auto_agent/subagents/profiling/kernel_profile.py b/MaxKernel/auto_agent/subagents/profiling/kernel_profile.py index 72138f2..a4299e9 100644 --- a/MaxKernel/auto_agent/subagents/profiling/kernel_profile.py +++ b/MaxKernel/auto_agent/subagents/profiling/kernel_profile.py @@ -6,11 +6,10 @@ from google.adk.agents.invocation_context import InvocationContext from google.adk.events import Event, EventActions -from 
auto_agent.constants import ( - EVAL_SERVER_PORT, - REQUEST_TIMEOUT, - TPU_TIMEOUT, -) +from auto_agent.client_utils.eval_client import call_eval_server_async +from auto_agent.constants import EVAL_SERVER_PORT, REQUEST_TIMEOUT + +PROFILE_TIMEOUT = 120 class KernelProfiler(BaseAgent): @@ -50,106 +49,96 @@ async def _run_async_impl( # Call the TPU server to execute the code logging.info(f"[{self.name}] Running code") async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=REQUEST_TIMEOUT) + timeout=aiohttp.ClientTimeout(total=REQUEST_TIMEOUT + 10) ) as session: - async with session.post( - f"http://localhost:{EVAL_SERVER_PORT}/evaluate", - json={ - "eval_type": "profile", - "code": profile_code, - "timeout": TPU_TIMEOUT, - "backend_type": "tpu", - }, - ) as response: - if response.status == 200: - result = await response.json() - logging.info(f"[{self.name}] Profiling result: {result}") + payload = { + "eval_type": "profile", + "code": profile_code, + "timeout": PROFILE_TIMEOUT, + "backend_type": "tpu", + } + result = await call_eval_server_async( + session, + f"http://localhost:{EVAL_SERVER_PORT}", + payload, + poll_interval=10, + client_wait_timeout=REQUEST_TIMEOUT, + ) - # Check if profiling was successful based on exit code and output - exit_code = result.get("exit_code", 0) - output = result.get("output", "") - error_msg = result.get("error", "") + logging.info(f"[{self.name}] Profiling result: {result}") - # Profiling succeeds if exit_code is 0 and we have output - # Stderr may contain warnings (like TensorFlow import warnings) which are not failures - if exit_code != 0: - full_error = f"Profiling script failed with exit code {exit_code}" - if error_msg: - full_error += f": {error_msg}" - logging.error(f"[{self.name}] {full_error}") - yield Event( - author=self.name, - actions=EventActions(state_delta={self.output_key: full_error}), - ) - elif not output or output.strip() == "": - full_error = "Profiling script produced no output" - if 
error_msg: - full_error += f". Stderr: {error_msg}" - logging.error(f"[{self.name}] {full_error}") - yield Event( - author=self.name, - actions=EventActions(state_delta={self.output_key: full_error}), - ) - else: - # Successful profiling - parse the ratio and xplane path - try: - try: - import json + # Check if profiling was successful based on exit code and output + exit_code = result.get("exit_code", 0) + output = result.get("output", "") + error_msg = result.get("error", "") - res = json.loads(output.strip()) - ratio = float(res.get("ratio", 0)) - xplane_path = res.get("xplane_path", "") - except (ValueError, json.JSONDecodeError): - # Fallback for old servers returning raw ratio - ratio = float(output.strip()) - xplane_path = "" + # Profiling succeeds if exit_code is 0 and we have output + # Stderr may contain warnings (like TensorFlow import warnings) which are not failures + if exit_code != 0: + full_error = f"Profiling script failed with exit code {exit_code}" + if error_msg: + full_error += f": {error_msg}" + logging.error(f"[{self.name}] {full_error}") + yield Event( + author=self.name, + actions=EventActions(state_delta={self.output_key: full_error}), + ) + elif not output or output.strip() == "": + full_error = "Profiling script produced no output" + if error_msg: + full_error += f". 
Stderr: {error_msg}" + logging.error(f"[{self.name}] {full_error}") + yield Event( + author=self.name, + actions=EventActions(state_delta={self.output_key: full_error}), + ) + else: + # Successful profiling - parse the ratio and xplane path + try: + try: + import json - # Log warnings if present, but don't fail - if error_msg: - logging.warning( - f"[{self.name}] Profiling succeeded but had warnings in" - f" stderr: {error_msg[:200]}" - ) + res = json.loads(output.strip()) + ratio = float(res.get("ratio", 0)) + xplane_path = res.get("xplane_path", "") + except (ValueError, json.JSONDecodeError): + # Fallback for old servers returning raw ratio + ratio = float(output.strip()) + xplane_path = "" - logging.info( - f"[{self.name}] Profiling succeeded with ratio: {ratio}," - f" xplane_path: {xplane_path}" - ) - yield Event( - author=self.name, - actions=EventActions( - escalate=False, - state_delta={ - self.output_key: { - "DMAs_and_memory_transfers_ratio": ratio, - "compute_ratio": 1 - ratio, - "xplane_path": xplane_path, - } - }, - ), - ) - except (ValueError, KeyError) as e: - error_msg_full = ( - f"Failed to parse profiling output: '{output}'. 
Error: {e}" - ) - logging.error(f"[{self.name}] {error_msg_full}") - yield Event( - author=self.name, - actions=EventActions( - state_delta={self.output_key: error_msg_full} - ), - ) - else: - error_detail = await response.text() - logging.error( - f"[{self.name}] HTTP error {response.status}: {error_detail}" + # Log warnings if present, but don't fail + if error_msg: + logging.warning( + f"[{self.name}] Profiling succeeded but had warnings in" + f" stderr: {error_msg[:200]}" + ) + + logging.info( + f"[{self.name}] Profiling succeeded with ratio: {ratio}," + f" xplane_path: {xplane_path}" ) yield Event( author=self.name, actions=EventActions( + escalate=False, state_delta={ - self.output_key: f"HTTP error {response.status}: {error_detail}" - } + self.output_key: { + "DMAs_and_memory_transfers_ratio": ratio, + "compute_ratio": 1 - ratio, + "xplane_path": xplane_path, + } + }, + ), + ) + except (ValueError, KeyError) as e: + error_msg_full = ( + f"Failed to parse profiling output: '{output}'. 
Error: {e}" + ) + logging.error(f"[{self.name}] {error_msg_full}") + yield Event( + author=self.name, + actions=EventActions( + state_delta={self.output_key: error_msg_full} ), ) except Exception as e: diff --git a/MaxKernel/auto_agent/subagents/testing/agent.py b/MaxKernel/auto_agent/subagents/testing/agent.py index a57e0bc..c120de1 100644 --- a/MaxKernel/auto_agent/subagents/testing/agent.py +++ b/MaxKernel/auto_agent/subagents/testing/agent.py @@ -15,8 +15,9 @@ from google.adk.agents.invocation_context import InvocationContext from google.adk.events import Event, EventActions +from auto_agent.client_utils.eval_client import call_eval_server_async from auto_agent.config import model_config, thinking_planner -from auto_agent.constants import EVAL_SERVER_PORT, MODEL_NAME +from auto_agent.constants import EVAL_SERVER_PORT, MODEL_NAME, REQUEST_TIMEOUT from auto_agent.custom_types import CustomLlmAgent from auto_agent.subagents.testing.prompts import ( fix_test_script, @@ -149,17 +150,15 @@ async def _run_async_impl( "dependencies": dependencies, } - async with aiohttp.ClientSession() as session: - async with session.post( - f"http://localhost:{EVAL_SERVER_PORT}/evaluate", - json=payload, - ) as response: - if response.status != 200: - error_text = await response.text() - raise Exception( - f"Eval server returned status {response.status}: {error_text}" - ) - result_json = await response.json() + client_timeout = aiohttp.ClientTimeout(total=REQUEST_TIMEOUT + 10) + async with aiohttp.ClientSession(timeout=client_timeout) as session: + result_json = await call_eval_server_async( + session, + f"http://localhost:{EVAL_SERVER_PORT}", + payload, + poll_interval=10, + client_wait_timeout=REQUEST_TIMEOUT, + ) full_output = f"STDOUT:\n{result_json.get('output', '')}\n\nSTDERR:\n{result_json.get('error', '') or ''}" @@ -177,21 +176,6 @@ async def _run_async_impl( actions=EventActions(state_delta={self.output_key: test_results}), ) - except subprocess.TimeoutExpired: - error_msg 
= "Test execution timed out after 5 minutes" - logging.error(f"[{self.name}] {error_msg}") - yield Event( - author=self.name, - actions=EventActions( - state_delta={ - self.output_key: { - "exit_code": -1, - "output": error_msg, - "success": False, - } - } - ), - ) except Exception as e: error_msg = f"Exception during test execution: {str(e)}" logging.error(f"[{self.name}] {error_msg}") @@ -639,17 +623,15 @@ async def _run_async_impl( "dependencies": dependencies, } - async with aiohttp.ClientSession() as session: - async with session.post( - f"http://localhost:{EVAL_SERVER_PORT}/evaluate", - json=payload, - ) as response: - if response.status != 200: - error_text = await response.text() - raise Exception( - f"Eval server returned status {response.status}: {error_text}" - ) - result_json = await response.json() + client_timeout = aiohttp.ClientTimeout(total=REQUEST_TIMEOUT + 10) + async with aiohttp.ClientSession(timeout=client_timeout) as session: + result_json = await call_eval_server_async( + session, + f"http://localhost:{EVAL_SERVER_PORT}", + payload, + poll_interval=10, + client_wait_timeout=REQUEST_TIMEOUT, + ) exit_code = result_json.get("exit_code", -1) stdout = result_json.get("output", "") diff --git a/MaxKernel/auto_agent/tools/file_tools.py b/MaxKernel/auto_agent/tools/file_tools.py index 700eb55..bc24a5f 100644 --- a/MaxKernel/auto_agent/tools/file_tools.py +++ b/MaxKernel/auto_agent/tools/file_tools.py @@ -86,6 +86,9 @@ def _write_file(content: str, tool_context: ToolContext) -> str: write_profiling_script_tool = restricted_write_file( "profiling_script_path", "Writes the profiling script." ) +write_autotune_specs_tool = restricted_write_file( + "autotune_specs_path", "Writes the autotuning specifications." +) __all__ = [ "filesystem_tool_r", @@ -94,4 +97,5 @@ def _write_file(content: str, tool_context: ToolContext) -> str: "write_optimized_kernel_tool", "write_optimization_plan_tool", "write_profiling_script_tool", + "write_autotune_specs_tool", ]