diff --git a/packages/llm/src/llm.ts b/packages/llm/src/llm.ts index 820ff8356..fee4ebd31 100644 --- a/packages/llm/src/llm.ts +++ b/packages/llm/src/llm.ts @@ -23,15 +23,20 @@ export class LLM { constructor(private openAI: OpenAI) {} async call(params: { model: string; input: Array; tools?: Array }) { - const flattenInput = params.input - .map((m) => { - if (m instanceof ResponseMessage) { - return m.output // - .map((o) => o.toPlain()); - } - return m.toPlain(); - }) - .flat(); + const flattenInput = params.input.flatMap((m) => { + if (m instanceof ResponseMessage) { + const outputMessages = m.output; + const containsToolCall = outputMessages.some((entry) => entry instanceof ToolCallMessage); + return outputMessages + .filter((entry) => { + if (!containsToolCall) return true; + if (!(entry instanceof AIMessage)) return true; + return entry.text.trim().length > 0; + }) + .map((entry) => entry.toPlain()); + } + return [m.toPlain()]; + }); const toolDefinitions = params.tools?.map((tool) => tool.definition()); diff --git a/packages/platform-server/__tests__/llm.full_flow.duplication.test.ts b/packages/platform-server/__tests__/llm.full_flow.duplication.test.ts new file mode 100644 index 000000000..5346008fe --- /dev/null +++ b/packages/platform-server/__tests__/llm.full_flow.duplication.test.ts @@ -0,0 +1,613 @@ +import 'reflect-metadata'; +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { Test } from '@nestjs/testing'; +import z from 'zod'; + +import { AgentNode } from '../src/nodes/agent/agent.node'; +import { ConfigService } from '../src/core/services/config.service'; +import { registerTestConfig, clearTestConfig, runnerConfigDefaults } from './helpers/config'; +import { LLMProvisioner } from '../src/llm/provisioners/llm.provisioner'; +import { RunSignalsRegistry } from '../src/agents/run-signals.service'; +import { AgentsPersistenceService } from '../src/agents/agents.persistence.service'; +import { PrismaService } from '../src/core/services/prisma.service'; +import { RunEventsService } from '../src/events/run-events.service'; +import { EventsBusService } from '../src/events/events-bus.service'; +import { createRunEventsStub, createEventsBusStub } from './helpers/runEvents.stub'; +import { BaseToolNode } from '../src/nodes/tools/baseToolNode'; + +import type { LLM } from '@agyn/llm'; +import { AIMessage, FunctionTool, HumanMessage, ResponseMessage, SystemMessage, ToolCallMessage, ToolCallOutputMessage } from '@agyn/llm'; + +vi.mock('@agyn/docker-runner', () => ({})); + +type MixedOutput = ReturnType['output']; + +type ScriptStep = + | { kind: 'tool_call'; callId: string; name: string; args?: string } + | { kind: 'text'; text: string } + | { kind: 'response'; output: MixedOutput }; + +class ScriptableLLM implements Pick { + readonly inputs: Array<{ raw: Parameters[0]['input']; flat: unknown[] }> = []; + private script: ScriptStep[] = []; + private pointer = 0; + + setScript(steps: ScriptStep[]): void { + this.script = [...steps]; + this.pointer = 0; + this.inputs.length = 0; + } + + async call(params: Parameters[0]): Promise { + const flat = params.input.flatMap((msg) => { + if (msg instanceof ResponseMessage) { + const outputMessages = msg.output; + const containsToolCall = outputMessages.some((entry) => entry instanceof ToolCallMessage); + return outputMessages + .filter((entry) => { + if (!containsToolCall) return true; + if (!(entry instanceof AIMessage)) return true; + return entry.text.trim().length > 0; + }) + .map((entry) => entry.toPlain()); + } + return [msg.toPlain()]; + }); + + this.inputs.push({ raw: params.input, flat }); + + const step = this.script[this.pointer]; + this.pointer += 1; + if (!step) { + throw new Error('ScriptableLLM received more calls than scripted'); + } + + if (step.kind === 'tool_call') { + const toolCall = new ToolCallMessage({ + type: 'function_call', + call_id: step.callId, + name: step.name, + arguments: step.args ?? '{}', + } as any); + return new ResponseMessage({ output: [toolCall.toPlain()] as any }); + } + + if (step.kind === 'response') { + return new ResponseMessage({ output: step.output as any }); + } + + return ResponseMessage.fromText(step.text); + } +} + +class SilentLLM implements Pick { + async call(): Promise { + throw new Error('Summarization LLM should not be invoked in these tests'); + } +} + +class FakeProvisioner extends LLMProvisioner { + private pendingCallModelLLM: ScriptableLLM | null = null; + + constructor(private readonly summarizationLLM: SilentLLM) { + super(); + } + + setNextCallModelLLM(llm: ScriptableLLM): void { + this.pendingCallModelLLM = llm; + } + + async init(): Promise {} + + async getLLM(): Promise { + if (this.pendingCallModelLLM) { + const llm = this.pendingCallModelLLM; + this.pendingCallModelLLM = null; + return llm as unknown as LLM; + } + return this.summarizationLLM as unknown as LLM; + } + + async teardown(): Promise {} +} + +const TOOL_SCHEMA = z.object({}); + +class DemoFunctionTool extends FunctionTool { + constructor(private readonly toolName: string) { + super(); + } + + get name(): string { + return this.toolName; + } + + get description(): string { + return `${this.toolName} integration tool`; + } + + get schema(): typeof TOOL_SCHEMA { + return TOOL_SCHEMA; + } + + async execute(): Promise { + return 'ok'; + } +} + +class DemoToolNode extends BaseToolNode { + constructor(private readonly tool: FunctionTool) { + super(); + this.init({ nodeId: 'tool-demo' }); + } + + getTool(): FunctionTool { + return this.tool; + } + + getPortConfig() { + return { sourcePorts: {}, targetPorts: {} }; + } +} + +const createToolCallPlain = (callId: string, name = 'demo', args = '{}') => + new ToolCallMessage({ + type: 'function_call', + call_id: callId, + name, + arguments: args, + } as any).toPlain(); + +type AgentFixture = { + agent: AgentNode; + moduleRef: Awaited>; + provisioner: FakeProvisioner; + conversationState: Map; + registerCallModelLLM: (llm: ScriptableLLM) => void; +}; + +const createAgentFixture = async (): Promise => { + const config = registerTestConfig({ + llmProvider: 'litellm', + litellmBaseUrl: 'http://127.0.0.1:4000', + litellmMasterKey: 'sk-test-master', + ...runnerConfigDefaults, + }); + + const runEvents = createRunEventsStub(); + const eventsBus = createEventsBusStub(); + const summarizationLLM = new SilentLLM(); + const provisioner = new FakeProvisioner(summarizationLLM); + + const conversationState = new Map(); + + const prismaClient = { + conversationState: { + findUnique: async ({ where }: { where: { threadId_nodeId: { threadId: string; nodeId: string } } }) => { + const { threadId, nodeId } = where.threadId_nodeId; + const key = `${threadId}::${nodeId}`; + if (!conversationState.has(key)) return null; + return { threadId, nodeId, state: conversationState.get(key) }; + }, + upsert: async ({ + where, + create, + update, + }: { + where: { threadId_nodeId: { threadId: string; nodeId: string } }; + create: { threadId: string; nodeId: string; state: unknown }; + update: { state: unknown }; + }) => { + const { threadId, nodeId } = where.threadId_nodeId; + const key = `${threadId}::${nodeId}`; + const payload = conversationState.has(key) ? update.state : create.state; + conversationState.set(key, payload); + return { threadId, nodeId, state: payload }; + }, + }, + }; + + let runCounter = 0; + const threadModels = new Map(); + + const moduleRef = await Test.createTestingModule({ + providers: [ + { provide: ConfigService, useValue: config }, + AgentNode, + RunSignalsRegistry, + { provide: LLMProvisioner, useValue: provisioner }, + { + provide: PrismaService, + useValue: { + getClient: () => prismaClient, + }, + }, + { provide: RunEventsService, useValue: runEvents }, + { provide: EventsBusService, useValue: eventsBus }, + { + provide: AgentsPersistenceService, + useValue: { + beginRunThread: vi.fn(async () => ({ runId: `run-${++runCounter}` })), + completeRun: vi.fn(async () => {}), + recordInjected: vi.fn(async () => ({ messageIds: [] })), + ensureThreadModel: vi.fn(async (threadId: string, model: string) => { + if (threadModels.has(threadId)) { + return threadModels.get(threadId) ?? model; + } + threadModels.set(threadId, model); + return model; + }), + }, + }, + ], + }).compile(); + + const agent = await moduleRef.resolve(AgentNode); + agent.init({ nodeId: 'agent-node' }); + await agent.setConfig({ + debounceMs: 0, + sendFinalResponseToThread: false, + summarizationKeepTokens: 0, + summarizationMaxTokens: 8192, + }); + + const tool = new DemoFunctionTool('demo'); + agent.addTool(new DemoToolNode(tool)); + + return { + agent, + moduleRef, + provisioner, + conversationState, + registerCallModelLLM: (llm: ScriptableLLM) => provisioner.setNextCallModelLLM(llm), + } satisfies AgentFixture; +}; + +const summarizeInput = (input: Parameters[0]['input']) => { + const order = input.map((msg) => msg.constructor.name); + const counts = { + system: input.filter((msg) => msg instanceof SystemMessage).length, + human: input.filter((msg) => msg instanceof HumanMessage).length, + response: input.filter((msg) => msg instanceof ResponseMessage).length, + toolCallOutput: input.filter((msg) => msg instanceof ToolCallOutputMessage).length, + }; + return { order, counts }; +}; + +describe('LLM full-flow duplication integration', () => { + beforeEach(() => { + vi.restoreAllMocks(); + }); + + afterEach(() => { + clearTestConfig(); + }); + + it('captures second model call input within a single run', async () => { + const fixture = await createAgentFixture(); + const { agent, moduleRef, registerCallModelLLM } = fixture; + + try { + const scriptedLLM = new ScriptableLLM(); + scriptedLLM.setScript([ + { kind: 'tool_call', callId: 'call-1', name: 'demo' }, + { kind: 'text', text: 'final' }, + ]); + registerCallModelLLM(scriptedLLM); + + const result = await agent.invoke('thread-alpha', [HumanMessage.fromText('start')]); + expect(result).toBeInstanceOf(ResponseMessage); + expect(result.text).toBe('final'); + + expect(scriptedLLM.inputs.length).toBe(2); + const secondCall = scriptedLLM.inputs[1]; + expect(secondCall).toBeDefined(); + + const summary = summarizeInput(secondCall?.raw ?? []); + console.info('Second call input (single run):', JSON.stringify(summary, null, 2)); + + expect(secondCall?.flat ?? []).toMatchObject([ + { + role: 'system', + content: [{ type: 'input_text', text: 'You are a helpful AI assistant.' }], + }, + { + type: 'message', + role: 'user', + content: [{ type: 'input_text', text: 'start' }], + }, + { + type: 'function_call', + call_id: 'call-1', + name: 'demo', + arguments: '{}', + }, + { + type: 'function_call_output', + call_id: 'call-1', + output: 'ok', + }, + ]); + } finally { + await moduleRef.close(); + } + }); + + it('filters empty assistant outputs when paired with tool calls', async () => { + const fixture = await createAgentFixture(); + const { agent, moduleRef, registerCallModelLLM } = fixture; + + try { + const scriptedLLM = new ScriptableLLM(); + const toolCallPlain = createToolCallPlain('call-mixed'); + const emptyAssistantPlain = AIMessage.fromText('').toPlain(); + + scriptedLLM.setScript([ + { kind: 'response', output: [toolCallPlain, emptyAssistantPlain] }, + { kind: 'text', text: 'final' }, + ]); + registerCallModelLLM(scriptedLLM); + + const result = await agent.invoke('thread-mixed', [HumanMessage.fromText('start')] ); + expect(result).toBeInstanceOf(ResponseMessage); + expect(result.text).toBe('final'); + + expect(scriptedLLM.inputs.length).toBe(2); + const secondCallInput = scriptedLLM.inputs[1]; + const rawMessages = secondCallInput?.raw ?? []; + const flattenedMessages = secondCallInput?.flat ?? []; + + const summary = summarizeInput(rawMessages); + console.info('Second call input (mixed response):', JSON.stringify(summary, null, 2)); + console.debug('Second call flattened input (mixed response):', JSON.stringify(flattenedMessages, null, 2)); + + const responseMessages = rawMessages.filter((msg): msg is ResponseMessage => msg instanceof ResponseMessage); + expect(responseMessages.length).toBe(1); + const responsePayloads = responseMessages.map((msg) => msg.toPlain()); + console.debug('Second call response payloads (mixed response):', JSON.stringify(responsePayloads, null, 2)); + + const [firstResponse] = responseMessages; + const responseOutputs = firstResponse.output; + expect(responseOutputs.length).toBe(2); + const toolCallOutputs = responseOutputs.filter((output) => output instanceof ToolCallMessage); + const assistantOutputs = responseOutputs.filter((output) => output instanceof AIMessage); + + expect(toolCallOutputs.length).toBe(1); + expect(assistantOutputs.length).toBe(1); + expect(assistantOutputs[0]?.text).toBe(''); + + expect(flattenedMessages).toMatchObject([ + { + role: 'system', + content: [{ type: 'input_text', text: 'You are a helpful AI assistant.' }], + }, + { + type: 'message', + role: 'user', + content: [{ type: 'input_text', text: 'start' }], + }, + { + type: 'function_call', + call_id: 'call-mixed', + name: 'demo', + arguments: '{}', + }, + { + type: 'function_call_output', + call_id: 'call-mixed', + output: 'ok', + }, + ]); + } finally { + await moduleRef.close(); + } + }); + + it('captures duplicate tool_call assistant outputs within a single run', async () => { + const fixture = await createAgentFixture(); + const { agent, moduleRef, registerCallModelLLM } = fixture; + + try { + const scriptedLLM = new ScriptableLLM(); + const duplicateCallId = 'call-duplicate'; + scriptedLLM.setScript([ + { + kind: 'response', + output: [createToolCallPlain(duplicateCallId), createToolCallPlain(duplicateCallId)], + }, + { kind: 'text', text: 'final' }, + ]); + registerCallModelLLM(scriptedLLM); + + const result = await agent.invoke('thread-duplicate-single', [HumanMessage.fromText('start')]); + expect(result).toBeInstanceOf(ResponseMessage); + expect(result.text).toBe('final'); + + expect(scriptedLLM.inputs.length).toBe(2); + const secondCallInput = scriptedLLM.inputs[1]; + const rawMessages = secondCallInput?.raw ?? []; + const flattenedMessages = secondCallInput?.flat ?? []; + + const summary = summarizeInput(rawMessages); + console.info('Second call input (duplicate tool calls, single run):', JSON.stringify(summary, null, 2)); + console.debug( + 'Second call flattened input (duplicate tool calls, single run):', + JSON.stringify(flattenedMessages, null, 2), + ); + + const responseMessages = rawMessages.filter((msg): msg is ResponseMessage => msg instanceof ResponseMessage); + expect(responseMessages.length).toBe(1); + const [response] = responseMessages; + const toolCallOutputs = response.output.filter((entry) => entry instanceof ToolCallMessage) as ToolCallMessage[]; + const assistantOutputs = response.output.filter((entry) => entry instanceof AIMessage); + + expect(toolCallOutputs.length).toBe(2); + expect(toolCallOutputs[0].toPlain()).toEqual(toolCallOutputs[1].toPlain()); + expect(assistantOutputs.length).toBe(0); + + const flattenedFunctionCalls = flattenedMessages.filter((entry) => entry?.type === 'function_call'); + expect(flattenedFunctionCalls.length).toBe(2); + console.debug( + 'Second call flattened function calls (duplicate tool calls, single run):', + JSON.stringify(flattenedFunctionCalls, null, 2), + ); + expect(flattenedFunctionCalls[0]).toEqual(flattenedFunctionCalls[1]); + } finally { + await moduleRef.close(); + } + }); + + it('captures first model call input after loading persisted state in a new run', async () => { + const fixture = await createAgentFixture(); + const { agent, moduleRef, registerCallModelLLM } = fixture; + + try { + const firstRunLLM = new ScriptableLLM(); + firstRunLLM.setScript([ + { kind: 'tool_call', callId: 'call-1', name: 'demo' }, + { kind: 'text', text: 'final' }, + ]); + registerCallModelLLM(firstRunLLM); + + const firstResult = await agent.invoke('thread-beta', [HumanMessage.fromText('initial')]); + expect(firstResult).toBeInstanceOf(ResponseMessage); + + const secondRunLLM = new ScriptableLLM(); + secondRunLLM.setScript([{ kind: 'text', text: 'follow-up' }]); + registerCallModelLLM(secondRunLLM); + + const followUp = await agent.invoke('thread-beta', [HumanMessage.fromText('next')]); + expect(followUp).toBeInstanceOf(ResponseMessage); + + expect(secondRunLLM.inputs.length).toBe(1); + const secondRunCall = secondRunLLM.inputs[0]; + const freshRunInput = secondRunCall?.raw ?? []; + expect(freshRunInput.length).toBeGreaterThan(0); + + const summary = summarizeInput(freshRunInput); + console.info('First call input after load (new run):', JSON.stringify(summary, null, 2)); + + const responseMessages = freshRunInput.filter( + (msg): msg is ResponseMessage => msg instanceof ResponseMessage, + ); + const responsePayloads = responseMessages.map((msg) => msg.toPlain()); + + if (responsePayloads.length > 0) { + console.debug( + 'First call response payloads:', + JSON.stringify(responsePayloads, null, 2), + ); + } + + console.debug( + 'First call flattened input after load (new run):', + JSON.stringify(secondRunCall?.flat ?? [], null, 2), + ); + + expect(secondRunCall?.flat ?? []).toMatchObject([ + { + role: 'system', + content: [{ type: 'input_text', text: 'You are a helpful AI assistant.' }], + }, + { + type: 'message', + role: 'user', + content: [{ type: 'input_text', text: 'initial' }], + }, + { + type: 'function_call', + call_id: 'call-1', + name: 'demo', + arguments: '{}', + }, + { + type: 'function_call_output', + call_id: 'call-1', + output: 'ok', + }, + { + type: 'message', + role: 'assistant', + content: [{ type: 'output_text', text: 'final' }], + }, + { + type: 'message', + role: 'user', + content: [{ type: 'input_text', text: 'next' }], + }, + ]); + } finally { + await moduleRef.close(); + } + }); + + it('persists duplicate tool_call outputs across runs', async () => { + const fixture = await createAgentFixture(); + const { agent, moduleRef, registerCallModelLLM } = fixture; + + try { + const callId = 'call-duplicate-persist'; + const firstRunLLM = new ScriptableLLM(); + firstRunLLM.setScript([ + { + kind: 'response', + output: [createToolCallPlain(callId), createToolCallPlain(callId)], + }, + { kind: 'text', text: 'final' }, + ]); + registerCallModelLLM(firstRunLLM); + + const firstResult = await agent.invoke('thread-duplicate-persist', [HumanMessage.fromText('initial')]); + expect(firstResult).toBeInstanceOf(ResponseMessage); + + const secondRunLLM = new ScriptableLLM(); + secondRunLLM.setScript([{ kind: 'text', text: 'follow-up' }]); + registerCallModelLLM(secondRunLLM); + + const followUp = await agent.invoke('thread-duplicate-persist', [HumanMessage.fromText('next')]); + expect(followUp).toBeInstanceOf(ResponseMessage); + expect(followUp.text).toBe('follow-up'); + + expect(secondRunLLM.inputs.length).toBe(1); + const firstCallInput = secondRunLLM.inputs[0]; + const rawMessages = firstCallInput?.raw ?? []; + const flattenedMessages = firstCallInput?.flat ?? []; + + const summary = summarizeInput(rawMessages); + console.info( + 'First call input after load (duplicate tool calls, new run):', + JSON.stringify(summary, null, 2), + ); + console.debug( + 'First call flattened input after load (duplicate tool calls, new run):', + JSON.stringify(flattenedMessages, null, 2), + ); + + const responseMessages = rawMessages.filter((msg): msg is ResponseMessage => msg instanceof ResponseMessage); + expect(responseMessages.length).toBeGreaterThan(0); + const duplicateResponse = responseMessages.find((msg) => + msg.output.some((entry) => entry instanceof ToolCallMessage), + ); + expect(duplicateResponse).toBeDefined(); + + const toolCallOutputs = duplicateResponse!.output.filter( + (entry): entry is ToolCallMessage => entry instanceof ToolCallMessage, + ); + const assistantOutputs = duplicateResponse!.output.filter((entry) => entry instanceof AIMessage); + + expect(toolCallOutputs.length).toBe(2); + expect(toolCallOutputs[0].toPlain()).toEqual(toolCallOutputs[1].toPlain()); + expect(assistantOutputs.length).toBe(0); + console.debug( + 'Persisted duplicate tool call payloads:', + JSON.stringify(toolCallOutputs.map((entry) => entry.toPlain()), null, 2), + ); + + const flattenedFunctionCalls = flattenedMessages.filter((entry) => entry?.type === 'function_call'); + expect(flattenedFunctionCalls.length).toBeGreaterThanOrEqual(2); + expect(flattenedFunctionCalls[0]).toEqual(flattenedFunctionCalls[1]); + } finally { + await moduleRef.close(); + } + }); +}); diff --git a/packages/platform-server/__tests__/llm.second_call.input.duplication.test.ts b/packages/platform-server/__tests__/llm.second_call.input.duplication.test.ts new file mode 100644 index 000000000..91d24a5dc --- /dev/null +++ b/packages/platform-server/__tests__/llm.second_call.input.duplication.test.ts @@ -0,0 +1,244 @@ +import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; +import { Test } from '@nestjs/testing'; + +import { AgentNode } from '../src/nodes/agent/agent.node'; +import { ConfigService } from '../src/core/services/config.service'; +import { LLMProvisioner } from '../src/llm/provisioners/llm.provisioner'; +import type { LLM } from '@agyn/llm'; +import { AIMessage, HumanMessage, ResponseMessage, ToolCallMessage } from '@agyn/llm'; +import { PrismaService } from '../src/core/services/prisma.service'; +import { RunEventsService } from '../src/events/run-events.service'; +import { EventsBusService } from '../src/events/events-bus.service'; +import { RunSignalsRegistry } from '../src/agents/run-signals.service'; +import { AgentsPersistenceService } from '../src/agents/agents.persistence.service'; +import { createEventsBusStub, createRunEventsStub } from './helpers/runEvents.stub'; + +vi.mock('@agyn/docker-runner', () => ({})); + +class FakeLLM implements Pick { + public readonly calls: Array<{ + model: string; + input: Parameters[0]['input']; + tools?: unknown[]; + flat: unknown[]; + }> = []; + + async call(params: Parameters[0]): Promise { + const flat = params.input.flatMap((msg) => { + if (msg instanceof ResponseMessage) { + const outputMessages = msg.output; + const containsToolCall = outputMessages.some((entry) => entry instanceof ToolCallMessage); + return outputMessages + .filter((entry) => { + if (!containsToolCall) return true; + if (!(entry instanceof AIMessage)) return true; + return entry.text.trim().length > 0; + }) + .map((entry) => entry.toPlain()); + } + return [msg.toPlain()]; + }); + + this.calls.push({ model: params.model, input: params.input, tools: params.tools, flat }); + const order = this.calls.length; + if (order === 1) { + return this.toolCallResponse(); + } + if (order === 2) { + return ResponseMessage.fromText('final'); + } + return ResponseMessage.fromText(`extra-${order}`); + } + + private toolCallResponse(): ResponseMessage { + const toolCallPlain = { + id: 'call-1', + type: 'function_call', + call_id: 'call-1', + name: 'demo', + arguments: '{}', + status: 'completed', + } as ReturnType & { status: string }; + + const emptyAssistantPlain = { + id: 'msg-1', + type: 'message', + role: 'assistant', + status: 'completed', + content: [ + { + type: 'output_text', + text: '', + annotations: [], + }, + ], + } as ReturnType & { status: string }; + + return new ResponseMessage({ output: [emptyAssistantPlain, toolCallPlain] as any }); + } +} + +class SilentLLM implements Pick { + async call(): Promise { + throw new Error('Summarization LLM should not be invoked in this test'); + } +} + +class FakeProvisioner extends LLMProvisioner { + private callIndex = 0; + + constructor(private readonly callModelLLM: FakeLLM, private readonly summarizationLLM: SilentLLM) { + super(); + } + + async init(): Promise {} + + async getLLM(): Promise { + this.callIndex += 1; + if (this.callIndex === 1) { + return this.callModelLLM as unknown as LLM; + } + return this.summarizationLLM as unknown as LLM; + } + + async teardown(): Promise {} +} + +describe('AgentNode second LLM call input', () => { + const baseConfig: Partial = { + llmProvider: 'fake', + }; + + let moduleRef: Awaited>; + let agent: AgentNode; + let fakeLLM: FakeLLM; + + const conversationState = new Map(); + + beforeEach(async () => { + fakeLLM = new FakeLLM(); + const summaryLLM = new SilentLLM(); + const provisioner = new FakeProvisioner(fakeLLM, summaryLLM); + const runEvents = createRunEventsStub(); + const eventsBus = createEventsBusStub(); + + const prismaClient = { + conversationState: { + findUnique: async ({ where }: { where: { threadId_nodeId: { threadId: string; nodeId: string } } }) => { + const { threadId, nodeId } = where.threadId_nodeId; + const key = `${threadId}::${nodeId}`; + if (!conversationState.has(key)) return null; + return { threadId, nodeId, state: conversationState.get(key) }; + }, + upsert: async ({ + where, + create, + update, + }: { + where: { threadId_nodeId: { threadId: string; nodeId: string } }; + create: { threadId: string; nodeId: string; state: unknown }; + update: { state: unknown }; + }) => { + const { threadId, nodeId } = where.threadId_nodeId; + const key = `${threadId}::${nodeId}`; + const payload = conversationState.has(key) ? update.state : create.state; + conversationState.set(key, payload); + return { threadId, nodeId, state: payload }; + }, + }, + }; + + let runCounter = 0; + const threadModels = new Map(); + + moduleRef = await Test.createTestingModule({ + providers: [ + AgentNode, + RunSignalsRegistry, + { provide: ConfigService, useValue: baseConfig }, + { provide: LLMProvisioner, useValue: provisioner }, + { + provide: PrismaService, + useValue: { + getClient: () => prismaClient, + }, + }, + { provide: RunEventsService, useValue: runEvents }, + { provide: EventsBusService, useValue: eventsBus }, + { + provide: AgentsPersistenceService, + useValue: { + beginRunThread: async () => ({ runId: `run-${++runCounter}` }), + completeRun: async () => {}, + recordInjected: async () => ({ messageIds: [] }), + ensureThreadModel: async (threadId: string, model: string) => { + if (!threadModels.has(threadId)) { + threadModels.set(threadId, model); + return model; + } + return threadModels.get(threadId) ?? model; + }, + }, + }, + ], + }).compile(); + + agent = await moduleRef.resolve(AgentNode); + agent.init({ nodeId: 'agent-node' }); + await agent.setConfig({ + debounceMs: 0, + sendFinalResponseToThread: false, + summarizationKeepTokens: 0, + summarizationMaxTokens: 1024, + }); + }); + + afterEach(async () => { + await moduleRef?.close(); + conversationState.clear(); + }); + + it('emits a single tool_call entry in the second model invocation', async () => { + const result = await agent.invoke('thread-dup', [HumanMessage.fromText('start')]); + expect(result).toBeInstanceOf(ResponseMessage); + + expect(fakeLLM.calls.length).toBeGreaterThanOrEqual(2); + const secondCall = fakeLLM.calls[1]; + expect(secondCall).toBeDefined(); + const flattened = secondCall?.flat ?? []; + + const assistantMessages = flattened.filter( + (entry: any) => entry?.type === 'message' && entry?.role === 'assistant', + ); + const functionCalls = flattened.filter((entry: any) => entry?.type === 'function_call'); + + expect(assistantMessages).toHaveLength(0); + expect(functionCalls).toHaveLength(1); + expect(functionCalls[0]).toMatchObject({ call_id: 'call-1', name: 'demo', arguments: '{}' }); + + const functionCallOutputs = flattened.filter((entry: any) => entry?.type === 'function_call_output'); + expect(functionCallOutputs).toHaveLength(1); + expect(functionCallOutputs[0]).toMatchObject({ call_id: 'call-1' }); + + expect(flattened.length).toBe(4); + expect(flattened[0]).toMatchObject({ + role: 'system', + content: [{ type: 'input_text', text: 'You are a helpful AI assistant.' }], + }); + expect(flattened[1]).toMatchObject({ + type: 'message', + role: 'user', + content: [{ type: 'input_text', text: 'start' }], + }); + expect(flattened[2]).toMatchObject({ + type: 'function_call', + call_id: 'call-1', + name: 'demo', + arguments: '{}', + }); + expect(flattened[3]).toMatchObject({ + type: 'function_call_output', + call_id: 'call-1', + }); + }); +}); diff --git a/packages/platform-server/__tests__/load.reducer.merge.behavior.test.ts b/packages/platform-server/__tests__/load.reducer.merge.behavior.test.ts new file mode 100644 index 000000000..370bf3c96 --- /dev/null +++ b/packages/platform-server/__tests__/load.reducer.merge.behavior.test.ts @@ -0,0 +1,151 @@ +import { afterEach, describe, expect, it, vi } from 'vitest'; +import { LoadLLMReducer } from '../src/llm/reducers/load.llm.reducer'; +import { CallModelLLMReducer } from '../src/llm/reducers/callModel.llm.reducer'; +import { ConversationStateRepository } from '../src/llm/repositories/conversationState.repository'; +import { Signal } from '../src/signal'; +import type { LLMContext, LLMState } from '../src/llm/types'; +import type { PrismaService } from '../src/core/services/prisma.service'; +import type { RunEventsService } from '../src/events/run-events.service'; +import type { EventsBusService } from '../src/events/events-bus.service'; +import { AIMessage, ResponseMessage, ToolCallMessage, type LLM } from '@agyn/llm'; + +const prismaServiceStub = { + getClient: () => ({}), +} as unknown as PrismaService; + +const THREAD_ID = 'thread-merge'; +const NODE_ID = 'agent'; + +function baseContext(): LLMContext { + const response = ResponseMessage.fromText('noop'); + return { + threadId: THREAD_ID, + runId: 'run-1', + finishSignal: new Signal(), + terminateSignal: new Signal(), + callerAgent: { + getAgentNodeId: () => NODE_ID, + invoke: async () => response, + }, + }; +} + +function deepClone(value: T): T { + return JSON.parse(JSON.stringify(value)); +} + +function setupPersistedState( + reducer: LoadLLMReducer, + persisted: LLMState, + ctx: LLMContext, +): void { + vi.spyOn(ConversationStateRepository.prototype, 'get').mockResolvedValue({ + threadId: ctx.threadId, + nodeId: NODE_ID, + state: reducer['serializeState'](persisted), + }); +} + +function callReducerWithMocks(llmCallMock: ReturnType): CallModelLLMReducer { + const runEventsStub = { + startLLMCall: vi.fn(async () => ({ id: 'llm-event-id' })), + completeLLMCall: vi.fn(async () => {}), + createContextItems: vi.fn(async () => ['ctx-assistant']), + publishEvent: vi.fn(async () => null), + } as unknown as RunEventsService; + + const eventsBusStub = { + publishEvent: vi.fn(async () => null), + } as unknown as EventsBusService; + + const reducer = new CallModelLLMReducer(runEventsStub, eventsBusStub); + reducer.init({ + llm: { call: llmCallMock } as unknown as LLM, + model: 'gpt-test', + systemPrompt: 'system prompt', + tools: [], + }); + return reducer; +} + +afterEach(() => { + vi.restoreAllMocks(); +}); + +describe('LoadLLMReducer merge behavior', () => { + it('concatenates persisted and incoming response messages without deduplication', async () => { + const reducer = new LoadLLMReducer(prismaServiceStub); + const ctx = baseContext(); + + const persistedMessage = ResponseMessage.fromText('persisted'); + const incomingMessage = ResponseMessage.fromText('incoming'); + + const persistedState: LLMState = { messages: [persistedMessage], context: { messageIds: [], memory: [] } }; + const incomingState: LLMState = { messages: [incomingMessage], context: { messageIds: [], memory: [] } }; + + setupPersistedState(reducer, persistedState, ctx); + + const merged = await reducer.invoke(incomingState, ctx); + + expect(merged.messages).toHaveLength(2); + const responseMessages = merged.messages.filter((msg): msg is ResponseMessage => msg instanceof ResponseMessage); + expect(responseMessages).toHaveLength(2); + expect(responseMessages[0].text).toBe('persisted'); + expect(responseMessages[1].text).toBe('incoming'); + }); + + it('keeps tool calls while filtering empty assistant text during LLM input assembly', async () => { + const reducer = new LoadLLMReducer(prismaServiceStub); + const ctx = baseContext(); + + const toolCallPlain = { + type: 'function_call', + call_id: 'call-1', + name: 'lookup_user', + arguments: '{"id":42}', + } as ReturnType; + + const emptyAssistantPlain = AIMessage.fromText('').toPlain(); + + const persistedMessage = new ResponseMessage({ output: [deepClone(emptyAssistantPlain), deepClone(toolCallPlain)] }); + + const persistedState: LLMState = { messages: [persistedMessage], context: { messageIds: [], memory: [] } }; + const incomingState: LLMState = { messages: [], context: { messageIds: [], memory: [] } }; + + setupPersistedState(reducer, persistedState, ctx); + + const merged = await reducer.invoke(incomingState, ctx); + + const llmCallMock = vi.fn(async ({ input }: Parameters[0]) => { + const flatten = input.flatMap((msg) => { + if (msg instanceof ResponseMessage) { + const output = msg.output; + const includesToolCall = output.some((entry) => entry instanceof ToolCallMessage); + return output + .filter((entry) => { + if (!includesToolCall) return true; + if (!(entry instanceof AIMessage)) return true; + return entry.text.trim().length > 0; + }) + .map((entry) => entry.toPlain()); + } + return [msg.toPlain()]; + }); + + const assistantMessages = flatten.filter((entry: any) => entry?.role === 'assistant'); + const toolCalls = flatten.filter((entry: any) => entry?.type === 'function_call'); + + expect(assistantMessages).toHaveLength(0); + expect(toolCalls).toHaveLength(1); + expect(toolCalls[0]).toMatchObject({ call_id: 'call-1', name: 'lookup_user', arguments: '{"id":42}' }); + + return ResponseMessage.fromText('ok'); + }); + + const callReducer = callReducerWithMocks(llmCallMock); + + await callReducer.invoke(merged, ctx); + + expect(llmCallMock).toHaveBeenCalledTimes(1); + }); +}); diff --git a/packages/platform-server/src/llm/reducers/load.llm.reducer.ts b/packages/platform-server/src/llm/reducers/load.llm.reducer.ts index 298f727b5..3f2d14624 100644 --- a/packages/platform-server/src/llm/reducers/load.llm.reducer.ts +++ b/packages/platform-server/src/llm/reducers/load.llm.reducer.ts @@ -39,9 +39,12 @@ export class LoadLLMReducer extends PersistenceBaseLLMReducer { system: persistedContext.system ?? incomingContext.system, }; + // Preserve the persisted transcript order and append the newly received messages. + const mergedMessages = [...persisted.messages, ...state.messages]; + const merged: LLMState = { summary: persisted.summary ?? state.summary, - messages: [...persisted.messages, ...state.messages], + messages: mergedMessages, context: mergedContext, meta: state.meta, }; @@ -55,7 +58,6 @@ export class LoadLLMReducer extends PersistenceBaseLLMReducer { return { ...state, context: this.ensureContext(state.context) }; } } - private ensureContext(context: LLMContextState | undefined): LLMContextState { if (!context) return { messageIds: [], memory: [], pendingNewContextItemIds: [] }; return {