Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions MaxKernel/auto_agent/agent_client/run_batch_agent_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
import os
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

Expand Down Expand Up @@ -89,6 +90,11 @@ def process_problem(
user_id = "user_0"
session_id = f"session_{problem_id}_attempt_{attempt}_{int(time.time())}"

# Add random jitter to avoid SQLite database lock contention
jitter = random.uniform(0.1, 2.0)
logger.info(f"Sleeping for {jitter:.2f}s (jitter) to avoid DB lock.")
time.sleep(jitter)

client = AutoAgentClient(
user_id=user_id,
session_id=session_id,
Expand Down
51 changes: 49 additions & 2 deletions MaxKernel/auto_agent/subagents/pipeline_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import os
import re
import shutil
from typing import AsyncGenerator

from google.adk.agents import BaseAgent
Expand Down Expand Up @@ -85,6 +86,9 @@ async def _run_async_impl(
logging.error(
f"[{self.name}] Compilation failed. Looping back to planning."
)
self._save_iteration_files(
ctx, iteration, keys_to_save=["optimized_kernel_path"]
)
iteration += 1
continue

Expand All @@ -99,6 +103,11 @@ async def _run_async_impl(
logging.error(
f"[{self.name}] Test generation/validation failed. Looping back to planning."
)
self._save_iteration_files(
ctx,
iteration,
keys_to_save=["optimized_kernel_path", "test_file_path"],
)
iteration += 1
continue

Expand All @@ -111,6 +120,11 @@ async def _run_async_impl(
test_results = ctx.session.state.get("test_results", {})
if not test_results.get("success", False):
logging.error(f"[{self.name}] Tests failed. Looping back to planning.")
self._save_iteration_files(
ctx,
iteration,
keys_to_save=["optimized_kernel_path", "test_file_path"],
)
iteration += 1
continue

Expand Down Expand Up @@ -161,9 +175,11 @@ async def _run_async_impl(
)
logging.info(f"[{self.name}] Saved snapshot for iteration {iteration}")

# Step 7: Check if improvement is needed
needs_improvement = ctx.session.state.get("needs_improvement", False)
self._save_iteration_files(ctx, iteration)

# Step 7: Check if improvement is needed
# needs_improvement = ctx.session.state.get("needs_improvement", False)
needs_improvement = True
if not needs_improvement:
logging.info(
f"[{self.name}] No further improvement needed or agent decided to stop. Stopping pipeline."
Expand Down Expand Up @@ -193,6 +209,35 @@ async def _run_async_impl(
),
)

def _save_iteration_files(
    self,
    ctx: InvocationContext,
    iteration: int,
    keys_to_save: list[str] | None = None,
):
    """Copy tracked artifact files to iteration-suffixed siblings.

    For every requested key, the path stored under that key in
    ``ctx.session.state`` is copied into the same directory as
    ``<stem>_<iteration><ext>``. Keys with no stored path, or whose path
    does not exist on disk, are skipped without logging. A failed copy is
    logged as an error and never propagates to the caller.

    Args:
        ctx: Invocation context; ``ctx.session.state`` is read for the paths.
        iteration: Value appended to each file stem to form the copy's name.
        keys_to_save: State keys to snapshot. ``None`` selects the kernel,
            test-file, autotune-specs and autotune-results keys.
    """
    selected_keys = (
        keys_to_save
        if keys_to_save is not None
        else [
            "optimized_kernel_path",
            "test_file_path",
            "autotune_specs_path",
            "autotune_results_path",
        ]
    )

    for path_key in selected_keys:
        source_path = ctx.session.state.get(path_key)
        # Nothing to snapshot when the key is unset or the file is absent.
        if not source_path or not os.path.exists(source_path):
            continue

        # Build "<dir>/<stem>_<iteration><ext>" alongside the original file.
        stem, suffix = os.path.splitext(os.path.basename(source_path))
        new_path = os.path.join(
            os.path.dirname(source_path), f"{stem}_{iteration}{suffix}"
        )

        # Best-effort: a snapshot failure must not abort the pipeline loop.
        try:
            shutil.copy2(source_path, new_path)
            logging.info(f"[{self.name}] Copied {path_key} to {new_path}")
        except Exception as e:
            logging.error(
                f"[{self.name}] Failed to copy {path_key} to {new_path}: {e}"
            )

def _initialize_state(self, ctx: InvocationContext) -> Event:
"""Initializes session state with standard paths and returns the event."""
# Initialize history
Expand Down Expand Up @@ -274,6 +319,8 @@ def _initialize_state(self, ctx: InvocationContext) -> Event:
"kernel_plan_path": ctx.session.state["kernel_plan_path"],
"test_file_path": ctx.session.state["test_file_path"],
"profiling_script_path": ctx.session.state["profiling_script_path"],
"autotune_specs_path": ctx.session.state["autotune_specs_path"],
"autotune_results_path": ctx.session.state["autotune_results_path"],
}
),
)
Expand Down
1 change: 1 addition & 0 deletions MaxKernel/auto_agent/subagents/testing/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
COMPILE_VALIDATION_TIMEOUT = 60 * 1
MOCK_EXECUTION_TIMEOUT = 60 * 3
TEST_EXECUTION_TIMEOUT = 60 * 5
TEST_EXECUTION_POLL_INTERVAL = 20
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is TEST_EXECUTION_POLL_INTERVAL used?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is the polling interval. The client polls the TPU backend for the status of the test job every 20s. The default is 10s, and since the tests usually take longer, I increased the interval so that it does not poll too often.



class TestRunner(BaseAgent):
Expand Down