Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions src/agent/__tests__/investigation-orchestrator.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
// Create mock LLM that returns appropriate responses based on call order
// The order is: triage -> hypothesis generation -> evidence eval (repeat) -> conclusion -> remediation
mockLLM = {
complete: vi.fn().mockImplementation(async (prompt: string) => {

Check warning on line 129 in src/agent/__tests__/investigation-orchestrator.test.ts

View workflow job for this annotation

GitHub Actions / Test

'prompt' is defined but never used. Allowed unused args must match /^_/u
llmCallIndex++;

// First call is triage
Expand Down Expand Up @@ -161,7 +161,7 @@

// Create mock tool executor
mockToolExecutor = {
execute: vi.fn().mockImplementation(async (tool: string, params: Record<string, unknown>) => {

Check warning on line 164 in src/agent/__tests__/investigation-orchestrator.test.ts

View workflow job for this annotation

GitHub Actions / Test

'params' is defined but never used. Allowed unused args must match /^_/u
if (tool === 'cloudwatch_alarms') {
return [{ alarmName: 'HighLatency', state: 'ALARM' }];
}
Expand Down Expand Up @@ -326,6 +326,67 @@
expect(mockToolExecutor.execute).toHaveBeenCalled();
});

it('should execute hypothesis queries with parallelism', async () => {
  // Scripted LLM: replies are chosen purely by how many times `complete`
  // has been called so far (1-based).
  const hypothesisResponse = JSON.stringify({
    hypotheses: [
      {
        statement: 'Unexpected behavior in custom subsystem',
        category: 'application',
        priority: 1,
        confirmingEvidence: 'Correlated anomalies in telemetry',
        refutingEvidence: 'No telemetry anomalies',
        queries: [],
      },
    ],
    reasoning: 'Start with broad telemetry correlation.',
  });

  let llmCalls = 0;
  const complete = vi.fn().mockImplementation(async () => {
    llmCalls += 1;
    switch (llmCalls) {
      case 1:
        return mockTriageResponse;
      case 2:
        return hypothesisResponse;
      case 3:
        return mockEvidenceEvaluationConfirm;
      case 4:
        return mockConclusionResponse;
      case 5:
        return mockRemediationResponse;
      default:
        // Any call past the scripted sequence gets the "prune" evaluation.
        return mockEvidenceEvaluationPrune;
    }
  });
  const llm: LLMClient = { complete };

  // Tools that are artificially slowed down so overlapping invocations can
  // be observed; each entry builds the (empty) payload that tool returns.
  const slowToolPayloads = new Map<string, () => unknown>([
    ['cloudwatch_alarms', () => []],
    ['cloudwatch_logs', () => ({ events: [], count: 0 })],
    ['datadog', () => ({ triggeredMonitors: [], count: 0 })],
  ]);

  let inFlight = 0;
  let peakInFlight = 0;
  const execute = vi.fn().mockImplementation(async (tool: string) => {
    const buildPayload = slowToolPayloads.get(tool);
    if (buildPayload !== undefined) {
      // Track how many slow calls overlap while this one is pending.
      inFlight += 1;
      peakInFlight = Math.max(peakInFlight, inFlight);
      await new Promise((resolve) => setTimeout(resolve, 25));
      inFlight -= 1;
      return buildPayload();
    }

    return tool === 'aws_query' ? { totalResources: 0, results: {} } : { success: true };
  });
  const toolExecutor: ToolExecutor = { execute };

  const orchestrator = createOrchestrator(llm, toolExecutor, {
    availableTools: ['cloudwatch_alarms', 'cloudwatch_logs', 'datadog', 'aws_query'],
  });
  await orchestrator.investigate('Investigate incident behavior');

  // More than one slow tool call must have been pending at the same time.
  expect(peakInFlight).toBeGreaterThan(1);
});

it('should handle tool errors gracefully', async () => {
// Reset call index for this test
llmCallIndex = 0;
Expand Down
56 changes: 52 additions & 4 deletions src/agent/investigation-orchestrator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ export interface RemediationContext {
export interface InvestigationOptions {
incidentId?: string;
maxIterations?: number;
queryExecutionConcurrency?: number;
autoApproveRemediation?: boolean;
approveRemediationStep?: (step: RemediationStep) => Promise<boolean>;
knownServices?: string[];
Expand Down Expand Up @@ -627,6 +628,37 @@ export class InvestigationOrchestrator {
return formatted;
}

/**
 * Resolve how many queries may run at the same time.
 *
 * Reads `options.queryExecutionConcurrency`, defaulting to 3 when it is not
 * set. A non-finite setting (NaN or +/-Infinity) yields 1. Otherwise the
 * setting is floored, capped at `totalQueries`, and raised to at least 1.
 *
 * @param totalQueries - Number of queries about to be executed.
 * @returns The number of concurrent workers to use.
 */
private getQueryExecutionConcurrency(totalQueries: number): number {
  const requested = this.options.queryExecutionConcurrency ?? 3;

  if (Number.isFinite(requested)) {
    const cappedByWorkload = Math.min(totalQueries, Math.floor(requested));
    return Math.max(1, cappedByWorkload);
  }

  return 1;
}

/**
 * Run `worker` once for every element of `items`, with at most `concurrency`
 * invocations pending at any time.
 *
 * A fixed number of worker loops share a single cursor into `items`; each
 * loop claims the next unclaimed index, awaits `worker` for that item, and
 * repeats until the cursor passes the end of the array. Items are therefore
 * started in array order, but may finish in any order.
 *
 * @param items - Work items to process. An empty array resolves immediately.
 * @param concurrency - Requested number of simultaneous workers. The value is
 *   floored and clamped to the range [1, items.length]; NaN is treated as 1.
 * @param worker - Async callback invoked with each item.
 * @returns A promise that resolves once every worker loop has finished. If
 *   `worker` rejects, the returned promise rejects with that error
 *   (`Promise.all` semantics); the remaining loops are not cancelled.
 */
private async runWithConcurrency<T>(
  items: T[],
  concurrency: number,
  worker: (item: T) => Promise<void>
): Promise<void> {
  if (items.length === 0) {
    return;
  }

  // Math.min/Math.max propagate NaN, and Array.from({ length: NaN }) builds
  // an empty array, so a NaN concurrency would start no workers and silently
  // skip every item. Fall back to a single worker instead.
  const requested = Number.isNaN(concurrency) ? 1 : Math.floor(concurrency);
  const workerCount = Math.max(1, Math.min(requested, items.length));

  // Shared cursor: the read-and-increment below is synchronous, so two loops
  // can never claim the same index.
  let index = 0;
  const drainQueue = async (): Promise<void> => {
    while (index < items.length) {
      const currentIndex = index++;
      await worker(items[currentIndex]);
    }
  };

  await Promise.all(Array.from({ length: workerCount }, () => drainQueue()));
}

/**
* Run a full investigation
*/
Expand Down Expand Up @@ -971,14 +1003,20 @@ export class InvestigationOrchestrator {
return q;
});

// Execute each query
const runnableQueries: CausalQuery[] = [];
for (const query of refinedQueries) {
const runnableQuery = this.adaptQueryToEnvironment(query);
if (!runnableQuery) {
results.set(query.id, { error: `No compatible tool available for ${query.tool}` });
continue;
}
runnableQueries.push(runnableQuery);
}

const concurrency = this.getQueryExecutionConcurrency(runnableQueries.length);
const queryResults = new Map<string, unknown>();

await this.runWithConcurrency(runnableQueries, concurrency, async (runnableQuery) => {
this.emit({ type: 'query_executing', query: runnableQuery });

try {
Expand All @@ -987,13 +1025,23 @@ export class InvestigationOrchestrator {
runnableQuery.parameters
);
this.updateCloudWatchHints(runnableQuery.tool, result, runnableQuery.parameters);
results.set(runnableQuery.id, result);
queryResults.set(runnableQuery.id, result);
machine.recordQueryResult(hypothesis.id, runnableQuery.id, result);

this.emit({ type: 'query_complete', query: runnableQuery, result });
} catch (error) {
results.set(runnableQuery.id, { error: String(error) });
const failure = { error: error instanceof Error ? error.message : String(error) };
queryResults.set(runnableQuery.id, failure);
machine.recordQueryResult(hypothesis.id, runnableQuery.id, failure);
this.emit({ type: 'query_complete', query: runnableQuery, result: failure });
}
});

// Preserve deterministic result ordering based on generated query order.
// `??` (not `||`) is used so that legitimate falsy tool results such as 0,
// '' or false are kept; only a missing (null/undefined) entry is replaced
// by the "No result returned" placeholder.
for (const runnableQuery of runnableQueries) {
  results.set(
    runnableQuery.id,
    queryResults.get(runnableQuery.id) ?? { error: 'No result returned' }
  );
}

return results;
Expand Down
Loading