diff --git a/.env.example b/.env.example index 77ca0f3a3..7bc642714 100644 --- a/.env.example +++ b/.env.example @@ -76,6 +76,7 @@ # Reuses OPENAI_API_KEY / OPENAI_BASE_URL above when EMBEDDING_PROVIDER=openai. # OPENAI_EMBEDDING_MODEL=text-embedding-3-small # Embedding model when EMBEDDING_PROVIDER=openai # OPENAI_EMBEDDING_DIMENSIONS=1536 # Required when the model is not in the known-models table +# OPENAI_EMBEDDING_MAX_BATCH=256 # Max inputs per /embeddings request before splitting into sequential sub-batches # OPENROUTER_EMBEDDING_MODEL=openai/text-embedding-3-small # When EMBEDDING_PROVIDER=openrouter diff --git a/README.md b/README.md index c4ec2c1e0..668472497 100644 --- a/README.md +++ b/README.md @@ -1419,6 +1419,7 @@ Create `~/.agentmemory/.env`: # OPENAI_BASE_URL=https://api.openai.com # Override for Azure / vLLM / LM Studio / proxies # OPENAI_EMBEDDING_MODEL=text-embedding-3-small # OPENAI_EMBEDDING_DIMENSIONS=1536 # Required when the model is not in the known-models table +# OPENAI_EMBEDDING_MAX_BATCH=256 # Max inputs per /embeddings request before splitting into sequential sub-batches. Lower this when self-hosted runners (e.g. Ollama) crash on large batches # Outbound LLM / embedding timeout # AGENTMEMORY_LLM_TIMEOUT_MS=60000 # Default: 60 000 ms (60 s). Applies to every diff --git a/src/providers/embedding/openai.ts b/src/providers/embedding/openai.ts index 7384d5137..e84fa21ce 100644 --- a/src/providers/embedding/openai.ts +++ b/src/providers/embedding/openai.ts @@ -37,6 +37,25 @@ function resolveDimensions(model: string, override: string | undefined): number return MODEL_DIMENSIONS[model] ?? DEFAULT_DIMENSIONS; } +const DEFAULT_MAX_BATCH = 256; + +// Some self-hosted OpenAI-compatible runners (e.g. Ollama) crash the model +// subprocess when one /embeddings request carries too many inputs (observed: +// ~600 inputs -> HTTP 400 "tokenize: EOF"), regardless of total token volume. 
/**
 * Resolve the maximum number of inputs sent in a single /embeddings request.
 *
 * Some self-hosted OpenAI-compatible runners (e.g. Ollama) crash the model
 * subprocess when one request carries too many inputs, so embedBatch splits
 * its work into sequential sub-batches of this size to stay under that limit.
 *
 * @param override - raw OPENAI_EMBEDDING_MAX_BATCH value; `undefined`, empty,
 *   or whitespace-only falls back to DEFAULT_MAX_BATCH.
 * @returns a positive safe integer.
 * @throws Error when the override is not a plain positive decimal integer
 *   (no sign, no leading zero, no fraction/exponent) or is too large to be
 *   represented exactly as a JavaScript number.
 */
function resolveMaxBatch(override: string | undefined): number {
  const raw = override?.trim() ?? "";
  if (raw.length === 0) {
    return DEFAULT_MAX_BATCH;
  }

  // The digit-only pattern alone still admits strings that Number() rounds
  // (beyond 2^53 - 1) or converts to Infinity; neither is the integer this
  // function promises, so reject anything that is not a safe integer.
  const parsed = /^[1-9]\d*$/.test(raw) ? Number(raw) : Number.NaN;
  if (!Number.isSafeInteger(parsed)) {
    throw new Error(
      `OPENAI_EMBEDDING_MAX_BATCH must be a positive integer, got: ${override}`,
    );
  }
  return parsed;
}
// Boundary case: exactly 256 inputs (the default max) must go out as one
// request with all 256 inputs in its body.
it("performs a single POST when input count is at or below the default max (256)", async () => {
  const provider = new OpenAIEmbeddingProvider("test-key");
  const fetchSpy = vi.spyOn(globalThis, "fetch").mockResolvedValue(
    new Response(
      JSON.stringify({
        data: Array.from({ length: 256 }, () => ({ embedding: [0.1, 0.2, 0.3] })),
      }),
      { status: 200 },
    ),
  );

  try {
    const texts = Array.from({ length: 256 }, (_, i) => `text-${i}`);
    await provider.embedBatch(texts);

    expect(fetchSpy).toHaveBeenCalledTimes(1);
    const body = JSON.parse((fetchSpy.mock.calls[0][1] as RequestInit).body as string);
    expect(body.input).toHaveLength(256);
  } finally {
    // Restore even when an assertion throws so the stubbed fetch cannot leak
    // into subsequent tests.
    fetchSpy.mockRestore();
  }
});
// With OPENAI_EMBEDDING_MAX_BATCH=100, 250 inputs must be sent as three
// requests of 100, 100 and 50 inputs.
it("respects OPENAI_EMBEDDING_MAX_BATCH override", async () => {
  process.env["OPENAI_EMBEDDING_MAX_BATCH"] = "100";
  const provider = new OpenAIEmbeddingProvider("test-key");
  // Stub echoes back one embedding per input it receives.
  const fetchSpy = vi.spyOn(globalThis, "fetch").mockImplementation(
    async (_url, init) => {
      const body = JSON.parse((init as RequestInit).body as string);
      const len = (body.input as string[]).length;
      return new Response(
        JSON.stringify({
          data: Array.from({ length: len }, () => ({ embedding: [0.1, 0.2, 0.3] })),
        }),
        { status: 200 },
      );
    },
  );

  try {
    const texts = Array.from({ length: 250 }, (_, i) => `text-${i}`);
    await provider.embedBatch(texts);

    expect(fetchSpy).toHaveBeenCalledTimes(3);
    const sizes = fetchSpy.mock.calls.map(
      (c) => (JSON.parse((c[1] as RequestInit).body as string).input as string[]).length,
    );
    expect(sizes).toEqual([100, 100, 50]);
  } finally {
    // Restore even when an assertion throws so the stubbed fetch cannot leak
    // into subsequent tests.
    fetchSpy.mockRestore();
  }
});

// Each input is a numeric string and the stub returns it as the sole
// embedding component, so the output order reveals the input order.
it("preserves input order across chunked sub-batches", async () => {
  process.env["OPENAI_EMBEDDING_MAX_BATCH"] = "2";
  const provider = new OpenAIEmbeddingProvider("test-key");
  const fetchSpy = vi.spyOn(globalThis, "fetch").mockImplementation(
    async (_url, init) => {
      const body = JSON.parse((init as RequestInit).body as string);
      const inputs = body.input as string[];
      return new Response(
        JSON.stringify({
          data: inputs.map((t) => ({ embedding: [parseFloat(t)] })),
        }),
        { status: 200 },
      );
    },
  );

  try {
    const texts = ["0", "1", "2", "3", "4"];
    const result = await provider.embedBatch(texts);

    expect(result.map((e) => e[0])).toEqual([0, 1, 2, 3, 4]);
  } finally {
    // Restore even when an assertion throws so the stubbed fetch cannot leak
    // into subsequent tests.
    fetchSpy.mockRestore();
  }
});
// The stub counts overlapping fetch calls; each call stays in flight for
// 5 ms, so any parallel dispatch would push maxConcurrent above 1.
it("issues chunked sub-batches sequentially, not in parallel", async () => {
  process.env["OPENAI_EMBEDDING_MAX_BATCH"] = "2";
  const provider = new OpenAIEmbeddingProvider("test-key");
  let inFlight = 0;
  let maxConcurrent = 0;
  const fetchSpy = vi.spyOn(globalThis, "fetch").mockImplementation(
    async (_url, init) => {
      inFlight++;
      maxConcurrent = Math.max(maxConcurrent, inFlight);
      await new Promise((r) => setTimeout(r, 5));
      inFlight--;
      const body = JSON.parse((init as RequestInit).body as string);
      const len = (body.input as string[]).length;
      return new Response(
        JSON.stringify({
          data: Array.from({ length: len }, () => ({ embedding: [0.1] })),
        }),
        { status: 200 },
      );
    },
  );

  try {
    await provider.embedBatch(["a", "b", "c", "d", "e", "f"]);

    expect(maxConcurrent).toBe(1);
    expect(fetchSpy).toHaveBeenCalledTimes(3);
  } finally {
    // Restore even when an assertion throws so the stubbed fetch cannot leak
    // into subsequent tests.
    fetchSpy.mockRestore();
  }
});

// The second request answers HTTP 400; embedBatch must reject with the
// provider's 400 error rather than return partial results.
it("propagates errors from a failing sub-batch", async () => {
  process.env["OPENAI_EMBEDDING_MAX_BATCH"] = "2";
  const provider = new OpenAIEmbeddingProvider("test-key");
  let call = 0;
  const fetchSpy = vi.spyOn(globalThis, "fetch").mockImplementation(
    async (_url, init) => {
      call++;
      if (call === 2) {
        return new Response("upstream tokenize EOF", { status: 400 });
      }
      const body = JSON.parse((init as RequestInit).body as string);
      const len = (body.input as string[]).length;
      return new Response(
        JSON.stringify({
          data: Array.from({ length: len }, () => ({ embedding: [0.1] })),
        }),
        { status: 200 },
      );
    },
  );

  try {
    await expect(provider.embedBatch(["a", "b", "c", "d"])).rejects.toThrow(
      /OpenAI embedding failed \(400\)/,
    );
  } finally {
    // Restore even when the expectation throws so the stubbed fetch cannot
    // leak into subsequent tests.
    fetchSpy.mockRestore();
  }
});
process.env["OPENAI_EMBEDDING_MAX_BATCH"] = "-5"; + expect(() => new OpenAIEmbeddingProvider("test-key")).toThrow( + /OPENAI_EMBEDDING_MAX_BATCH must be a positive integer/, + ); + + process.env["OPENAI_EMBEDDING_MAX_BATCH"] = "0"; + expect(() => new OpenAIEmbeddingProvider("test-key")).toThrow( + /OPENAI_EMBEDDING_MAX_BATCH must be a positive integer/, + ); + }); }); describe("withDimensionGuard", () => {