Skip to content
2 changes: 2 additions & 0 deletions MaxKernel/auto_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
for the human-in-the-loop kernel generation process.
"""

from auto_agent.subagents.autotuning.agent import autotune_agent
from auto_agent.subagents.kernel_writing import (
implement_kernel_agent,
plan_kernel_agent,
Expand All @@ -23,6 +24,7 @@
validate_agent=validate_kernel_compilation_agent,
test_gen_agent=validated_test_generation_agent,
test_run_agent=unified_test_agent,
autotune_agent=autotune_agent,
profile_agent=profile_agent,
max_iterations=5,
)
Expand Down
76 changes: 76 additions & 0 deletions MaxKernel/auto_agent/client_utils/eval_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import asyncio
import logging
import time

import aiohttp


async def call_eval_server_async(
    session: aiohttp.ClientSession,
    eval_server_url: str,
    payload: dict,
    poll_interval: int = 10,
    client_wait_timeout: int = 3600 * 3,  # Default to 3 hours
) -> dict:
    """Submit an evaluation task to the eval server and poll until it finishes.

    The task is submitted with ``POST {eval_server_url}/evaluate`` and then
    ``GET {eval_server_url}/status/{task_id}`` is polled every
    ``poll_interval`` seconds until the reported status is terminal.

    Args:
      session: aiohttp.ClientSession to use for all requests.
      eval_server_url: Base URL of the eval server (e.g.,
        "http://localhost:1245"). A trailing slash is tolerated.
      payload: The request payload. It is not mutated; a copy with
        ``client_wait_timeout`` added is what gets sent.
      poll_interval: Seconds to wait between polls.
      client_wait_timeout: Max seconds to wait for task completion, measured
        from the moment the task was accepted by the server.

    Returns:
      The ``result`` field of the status response once the status is
      "success".

    Raises:
      RuntimeError: If submission does not return HTTP 202, if a status poll
        does not return HTTP 200, or if the task ends with status "failed" or
        "timeout".
      TimeoutError: If the task has not finished within
        ``client_wait_timeout`` seconds.
    """
    # Avoid "//evaluate" when the caller passes a URL ending in "/".
    base_url = eval_server_url.rstrip("/")

    # 1. Submit the task
    submit_url = f"{base_url}/evaluate"
    logging.info("Submitting async task to %s", submit_url)

    # Copy so the caller's dict is left untouched.
    request_body = {**payload, "client_wait_timeout": client_wait_timeout}

    async with session.post(submit_url, json=request_body) as response:
        if response.status != 202:
            error_text = await response.text()
            raise RuntimeError(
                f"Failed to submit task. Status: {response.status}, Error: {error_text}"
            )
        resp_data = await response.json()

    task_id = resp_data["task_id"]
    logging.info("Task submitted successfully. ID: %s", task_id)

    # 2. Poll for status
    # A monotonic clock is used so wall-clock adjustments (NTP, DST, manual
    # changes) cannot shorten or extend the wait.
    deadline = time.monotonic() + client_wait_timeout
    status_url = f"{base_url}/status/{task_id}"

    while True:
        if time.monotonic() > deadline:
            raise TimeoutError(
                f"Client timed out waiting for task {task_id} after {client_wait_timeout} seconds"
            )

        async with session.get(status_url) as response:
            if response.status != 200:
                error_text = await response.text()
                raise RuntimeError(
                    f"Failed to get task status. Status: {response.status}, Error: {error_text}"
                )
            status_data = await response.json()

        status = status_data["status"]

        if status == "success":
            logging.info("Task %s completed successfully.", task_id)
            return status_data["result"]
        if status in ("failed", "timeout"):
            raise RuntimeError(
                f"Task {task_id} ended with status {status}: {status_data.get('error')}"
            )

        # Any other status is treated as "still in progress".
        logging.info(
            "Task %s status: %s. Waiting %ss...", task_id, status, poll_interval
        )
        await asyncio.sleep(poll_interval)
4 changes: 1 addition & 3 deletions MaxKernel/auto_agent/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
TEMPERATURE = 0.1
TOP_P = 0.9
TOP_K = 5
TPU_TIMEOUT = 120
REQUEST_TIMEOUT = 1800
REQUEST_TIMEOUT = 3600 * 3
Comment thread
NinaCai marked this conversation as resolved.
TPU_SERVER_PORT = 5463
CPU_SERVER_PORT = 5464
EVAL_SERVER_PORT = 1245
PERF_THRESHOLD = 1.1
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it not needed anymore?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't see it used anywhere in the code.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is for evaluation? When perf improves by more than 10%, we consider it as an improvement?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The evaluation code is independent of the agent code, and its threshold was set to 1.05.

124 changes: 103 additions & 21 deletions MaxKernel/auto_agent/server_utils/eval_server.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
import asyncio
import logging
import time
import uuid
from enum import Enum
from typing import Optional

import aiohttp
import yaml
from fastapi import FastAPI, HTTPException
from fastapi import BackgroundTasks, FastAPI, HTTPException
from pydantic import BaseModel

from auto_agent.constants import (
EVAL_SERVER_PORT,
)
from auto_agent.server_utils.tpu_server import (
CodeResponse,
get_tpu_version,
)
from auto_agent.constants import EVAL_SERVER_PORT
from auto_agent.server_utils.tpu_server import get_tpu_version

logging.basicConfig(
level=logging.INFO,
Expand Down Expand Up @@ -62,14 +59,38 @@ class EvalTypes(Enum):
PERFORMANCE_TEST = "performance_test"
UNIFIED_TEST = "unified_test"
PROFILE = "profile"
AUTOTUNE = "autotune"


class EvalRequest(BaseModel):
    """Request body accepted by ``POST /evaluate``.

    NOTE(review): ``code`` and ``timeout`` are each declared twice below. This
    looks like the before/after lines of a diff shown together -- confirm
    which declarations are the live ones before relying on defaults.
    """

    # Selects the evaluation; its ``.value`` is sent to the backend and used
    # as the backend URL path.
    eval_type: EvalTypes
    code: str
    timeout: Optional[int] = 30
    # Forwarded to the backend as "timeout"; for non-autotune requests it is
    # also the basis of the backend HTTP timeout.
    timeout: int
    # Sent to the backend as "code" for non-autotune requests.
    code: Optional[str] = None
    code_template: Optional[str] = None  # For autotune
    search_space: Optional[dict] = None  # For autotune
    backend_type: Optional[str] = None  # "tpu", "cpu", or None for any available
    # Sent to the backend as "dependencies" for non-autotune requests.
    dependencies: Optional[dict] = None
    # For autotune: forwarded as "total_timeout" and used as the basis of the
    # backend HTTP timeout.
    total_timeout: Optional[int] = None  # For autotune
    # Max seconds to wait in the queue for a free backend; 3 hours when None.
    client_wait_timeout: Optional[int] = None


class TaskStatus(str, Enum):
    """Lifecycle states of an evaluation task tracked in ``_tasks``.

    Mixing in ``str`` makes each member compare equal to its plain string
    value.
    """

    # Set when the task is registered by POST /evaluate.
    QUEUED = "queued"
    # Set when run_evaluation_task starts working on the task.
    RUNNING = "running"
    # Set when the evaluation returned a result.
    SUCCESS = "success"
    # Set on any exception other than an HTTPException with status 408.
    FAILED = "failed"
    # Set when the evaluation raised an HTTPException with status 408.
    TIMEOUT = "timeout"


class TaskResponse(BaseModel):
    """Response body of ``GET /status/{task_id}``."""

    # Identifier handed out by POST /evaluate.
    task_id: str
    # Current lifecycle state of the task.
    status: TaskStatus
    # Copied from the task record's "result" entry; None when absent.
    result: Optional[dict] = None
    # Copied from the task record's "error" entry; None when absent.
    error: Optional[str] = None


# In-memory registry for evaluation task tracking, keyed by task_id (a uuid4
# string). Each value is a dict with "status" and "request", plus "result"
# once the task succeeds or "error" once it fails or times out.
# NOTE(review): nothing visible here removes finished entries -- confirm
# whether the registry is pruned elsewhere.
_tasks = {}


class Evaluator:
Expand Down Expand Up @@ -115,13 +136,64 @@ async def health_check():
return {"status": "healthy"}


@app.post("/evaluate", response_model=CodeResponse)
async def evaluate(request: EvalRequest):
# Asynchronous task polling architecture for evaluation.
@app.post("/evaluate", status_code=202)
async def evaluate(request: EvalRequest, background_tasks: BackgroundTasks):
    """Register a new evaluation task and schedule it to run in the background.

    Responds immediately (HTTP 202) with the generated task id and the
    initial status; the client is expected to poll ``/status/{task_id}``.
    """
    new_task_id = str(uuid.uuid4())
    initial_status = TaskStatus.QUEUED

    # Record the task before scheduling it so a status poll can find it.
    _tasks[new_task_id] = dict(status=initial_status, request=request)
    background_tasks.add_task(run_evaluation_task, new_task_id, request)

    return dict(task_id=new_task_id, status=initial_status)


@app.get("/status/{task_id}", response_model=TaskResponse)
async def get_task_status(task_id: str):
    """Report the current state of a previously submitted evaluation task.

    Raises:
      HTTPException: 404 when ``task_id`` is not a known task.
    """
    record = _tasks.get(task_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Task not found")

    # "result" and "error" only exist once the task has finished, so they are
    # read with .get() and default to None.
    return dict(
        task_id=task_id,
        status=record["status"],
        result=record.get("result"),
        error=record.get("error"),
    )


async def run_evaluation_task(task_id: str, request: EvalRequest):
    """Run one evaluation and record its outcome in the task registry.

    The task's record in ``_tasks`` is moved to RUNNING, then to SUCCESS (with
    "result"), TIMEOUT (HTTPException with status 408) or FAILED (anything
    else), with "error" set for the non-success outcomes. Exceptions derived
    from ``Exception`` are recorded rather than propagated.

    Args:
      task_id: Key of the task's record in ``_tasks``.
      request: The evaluation request to execute.
    """
    task = _tasks[task_id]
    task["status"] = TaskStatus.RUNNING
    try:
        result = await _perform_evaluation(request)
    except HTTPException as e:
        task["status"] = (
            TaskStatus.TIMEOUT if e.status_code == 408 else TaskStatus.FAILED
        )
        # HTTPException.detail is not guaranteed to be a string, while the
        # status endpoint's response model declares ``error`` as a string.
        task["error"] = str(e.detail)
        logging.error(
            "Task %s ended with HTTP status %s: %s",
            task_id,
            e.status_code,
            e.detail,
        )
    except Exception as e:
        task["status"] = TaskStatus.FAILED
        task["error"] = str(e)
        # Keep the traceback in the server log; the client only sees str(e).
        logging.exception("Task %s failed with an unexpected error", task_id)
    else:
        # Store the result before flipping the status so a record marked
        # SUCCESS always carries its result.
        task["result"] = result
        task["status"] = TaskStatus.SUCCESS


async def _perform_evaluation(request: EvalRequest):
if request.eval_type not in EvalTypes:
raise HTTPException(status_code=400, detail="Invalid evaluation type")

# Acquire backend, retry if busy
start_queue_time = time.time()
max_queue_time = (
request.client_wait_timeout
if request.client_wait_timeout is not None
else 3600 * 3
)

while True:
if time.time() - start_queue_time > max_queue_time:
raise HTTPException(
status_code=408,
detail="Timed out waiting for an available backend in queue",
)

async with backend_semaphore:
backend = evaluator.get_available_backend(
backend_type=request.backend_type
Expand All @@ -146,18 +218,26 @@ async def evaluate(request: EvalRequest):
f"{requested_type_msg}"
)

# Send request to backend server
backend_timeout = request.timeout if request.timeout is not None else 30
# Construct payload based on eval type
payload = {
"eval_type": request.eval_type.value,
"timeout": request.timeout,
}
if request.eval_type == EvalTypes.AUTOTUNE:
payload["code_template"] = request.code_template
payload["search_space"] = request.search_space
backend_timeout = request.total_timeout
payload["total_timeout"] = request.total_timeout
else:
payload["code"] = request.code
payload["dependencies"] = request.dependencies
backend_timeout = request.timeout

client_timeout = aiohttp.ClientTimeout(total=backend_timeout + 10)
async with aiohttp.ClientSession(timeout=client_timeout) as session:
async with session.post(
f"http://{backend_ip}:{backend_port}/{request.eval_type.value}",
json={
"eval_type": request.eval_type.value,
"code": request.code,
"timeout": request.timeout,
"dependencies": request.dependencies,
},
json=payload,
) as response:
result = await response.json()
logging.info(
Expand Down Expand Up @@ -186,7 +266,9 @@ async def evaluate(request: EvalRequest):
raise e
except Exception as e:
logging.error(f"Error occurred while evaluating on {backend_name}: {e}")
raise HTTPException(status_code=500, detail="Backend evaluation failed")
raise HTTPException(
status_code=500, detail=f"Backend evaluation failed: {e}"
)
finally:
backend.set_status("available")

Expand Down
Loading