Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
# Reuses OPENAI_API_KEY / OPENAI_BASE_URL above when EMBEDDING_PROVIDER=openai.
# OPENAI_EMBEDDING_MODEL=text-embedding-3-small # Embedding model when EMBEDDING_PROVIDER=openai
# OPENAI_EMBEDDING_DIMENSIONS=1536 # Required when the model is not in the known-models table
# OPENAI_EMBEDDING_MAX_BATCH=256 # Max inputs per /embeddings request before splitting into sequential sub-batches

# OPENROUTER_EMBEDDING_MODEL=openai/text-embedding-3-small # When EMBEDDING_PROVIDER=openrouter

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1419,6 +1419,7 @@ Create `~/.agentmemory/.env`:
# OPENAI_BASE_URL=https://api.openai.com # Override for Azure / vLLM / LM Studio / proxies
# OPENAI_EMBEDDING_MODEL=text-embedding-3-small
# OPENAI_EMBEDDING_DIMENSIONS=1536 # Required when the model is not in the known-models table
# OPENAI_EMBEDDING_MAX_BATCH=256 # Max inputs per /embeddings request before splitting into sequential sub-batches. Lower this when self-hosted runners (e.g. Ollama) crash on large batches
# Outbound LLM / embedding timeout
# AGENTMEMORY_LLM_TIMEOUT_MS=60000 # Default: 60 000 ms (60 s). Applies to every
Expand Down
36 changes: 36 additions & 0 deletions src/providers/embedding/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,25 @@ function resolveDimensions(model: string, override: string | undefined): number
return MODEL_DIMENSIONS[model] ?? DEFAULT_DIMENSIONS;
}

/** Inputs sent per /embeddings request when OPENAI_EMBEDDING_MAX_BATCH is unset or blank. */
const DEFAULT_MAX_BATCH = 256;

// Some self-hosted OpenAI-compatible runners (e.g. Ollama) crash the model
// subprocess when one /embeddings request carries too many inputs (observed:
// ~600 inputs -> HTTP 400 "tokenize: EOF"), regardless of total token volume.
// embedBatch splits its input into sequential sub-batches of at most this many
// inputs so each request stays under that limit.
/**
 * Resolves the maximum number of inputs allowed in one /embeddings request.
 *
 * @param override Raw value of OPENAI_EMBEDDING_MAX_BATCH, or `undefined`
 *   when the variable is not set.
 * @returns `DEFAULT_MAX_BATCH` when `override` is `undefined` or contains only
 *   whitespace; otherwise the override parsed as a number.
 * @throws Error when the trimmed override is not a positive base-10 integer.
 */
function resolveMaxBatch(override: string | undefined): number {
  const raw = override === undefined ? "" : override.trim();
  if (raw.length === 0) {
    return DEFAULT_MAX_BATCH;
  }
  // Digits only with a non-zero leading digit: rejects "0", signed values,
  // decimals, exponent notation and anything non-numeric.
  if (!/^[1-9]\d*$/.test(raw)) {
    throw new Error(
      `OPENAI_EMBEDDING_MAX_BATCH must be a positive integer, got: ${override}`,
    );
  }
  return Number(raw);
}

/**
* OpenAI-compatible embedding provider.
*
Expand Down Expand Up @@ -72,6 +91,9 @@ function resolveDimensions(model: string, override: string | undefined): number
* OPENAI_EMBEDDING_DIMENSIONS — override reported dimensions (required for
* custom / self-hosted models not in the
* MODEL_DIMENSIONS table above)
* OPENAI_EMBEDDING_MAX_BATCH — max inputs per /embeddings request before
* splitting into sequential sub-batches
* (default: 256)
*/
export class OpenAIEmbeddingProvider implements EmbeddingProvider {
readonly name = "openai";
Expand All @@ -81,6 +103,7 @@ export class OpenAIEmbeddingProvider implements EmbeddingProvider {
private model: string;
private isAzure: boolean;
private azureApiVersion: string;
private maxBatch: number;

constructor(apiKey?: string) {
// Separate API key path: caller-passed wins, then OPENAI_EMBEDDING_API_KEY,
Expand Down Expand Up @@ -111,6 +134,7 @@ export class OpenAIEmbeddingProvider implements EmbeddingProvider {
this.isAzure = detectAzure(this.baseUrl);
this.azureApiVersion =
getEnvVar("OPENAI_API_VERSION") || DEFAULT_AZURE_API_VERSION;
this.maxBatch = resolveMaxBatch(getEnvVar("OPENAI_EMBEDDING_MAX_BATCH"));
}

async embed(text: string): Promise<Float32Array> {
Expand All @@ -119,6 +143,18 @@ export class OpenAIEmbeddingProvider implements EmbeddingProvider {
}

/**
 * Embeds every text, sending at most `this.maxBatch` inputs per request.
 *
 * Inputs that fit in one request are forwarded to `embedChunk` unchanged.
 * Larger inputs are cut into consecutive slices that are embedded one after
 * another (each request is awaited before the next starts), and the results
 * are appended in slice order.
 *
 * @param texts Inputs to embed.
 * @returns The concatenated `embedChunk` results, in slice order.
 * @throws Whatever `embedChunk` rejects with; a failing slice aborts the
 *   remaining slices.
 */
async embedBatch(texts: string[]): Promise<Float32Array[]> {
  const limit = this.maxBatch;
  if (texts.length <= limit) {
    return this.embedChunk(texts);
  }

  const embeddings: Float32Array[] = [];
  let start = 0;
  while (start < texts.length) {
    const slice = texts.slice(start, start + limit);
    const vectors = await this.embedChunk(slice);
    for (const vector of vectors) {
      embeddings.push(vector);
    }
    start += limit;
  }
  return embeddings;
}

private async embedChunk(texts: string[]): Promise<Float32Array[]> {
const url = buildEmbeddingUrl(
this.baseUrl,
this.isAzure,
Expand Down
177 changes: 177 additions & 0 deletions test/embedding-provider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ describe("OpenAIEmbeddingProvider", () => {
delete process.env["OPENAI_EMBEDDING_API_KEY"];
delete process.env["OPENAI_EMBEDDING_MODEL"];
delete process.env["OPENAI_EMBEDDING_DIMENSIONS"];
delete process.env["OPENAI_EMBEDDING_MAX_BATCH"];
});

afterEach(() => {
Expand Down Expand Up @@ -155,6 +156,182 @@ describe("OpenAIEmbeddingProvider", () => {
/OPENAI_EMBEDDING_DIMENSIONS must be a positive integer/,
);
});

// Boundary case: exactly 256 inputs (the default cap) must go out as one request.
it("performs a single POST when input count is at or below the default max (256)", async () => {
  const inputCount = 256;
  const embedder = new OpenAIEmbeddingProvider("test-key");
  const payload = JSON.stringify({
    data: Array.from({ length: inputCount }, () => ({ embedding: [0.1, 0.2, 0.3] })),
  });
  const fetchMock = vi
    .spyOn(globalThis, "fetch")
    .mockResolvedValue(new Response(payload, { status: 200 }));

  const inputs = Array.from({ length: inputCount }, (_, idx) => `text-${idx}`);
  await embedder.embedBatch(inputs);

  expect(fetchMock).toHaveBeenCalledTimes(1);
  // The second fetch argument is the RequestInit carrying the JSON body.
  const [, init] = fetchMock.mock.calls[0];
  const sent = JSON.parse((init as RequestInit).body as string);
  expect(sent.input).toHaveLength(inputCount);

  fetchMock.mockRestore();
});

// 600 inputs against the default cap of 256 should produce requests of 256, 256 and 88.
it("splits into sequential sub-batches when input count exceeds the max", async () => {
  const embedder = new OpenAIEmbeddingProvider("test-key");
  // Fake endpoint: answers with one fixed vector per input it received.
  const fetchMock = vi.spyOn(globalThis, "fetch").mockImplementation(async (_url, init) => {
    const request = JSON.parse((init as RequestInit).body as string);
    const received = (request.input as string[]).length;
    const data = Array.from({ length: received }, () => ({ embedding: [0.1, 0.2, 0.3] }));
    return new Response(JSON.stringify({ data }), { status: 200 });
  });

  const inputs = Array.from({ length: 600 }, (_, idx) => `text-${idx}`);
  const embeddings = await embedder.embedBatch(inputs);

  expect(embeddings).toHaveLength(600);
  expect(fetchMock).toHaveBeenCalledTimes(3);
  const batchSizes = fetchMock.mock.calls.map(([, init]) => {
    const request = JSON.parse((init as RequestInit).body as string);
    return (request.input as string[]).length;
  });
  expect(batchSizes).toEqual([256, 256, 88]);

  fetchMock.mockRestore();
});

// With the cap overridden to 100, 250 inputs should go out as 100, 100 and 50.
it("respects OPENAI_EMBEDDING_MAX_BATCH override", async () => {
  // The env var must be set before construction: the provider reads it in its constructor.
  process.env["OPENAI_EMBEDDING_MAX_BATCH"] = "100";
  const embedder = new OpenAIEmbeddingProvider("test-key");
  // Fake endpoint: answers with one fixed vector per input it received.
  const fetchMock = vi.spyOn(globalThis, "fetch").mockImplementation(async (_url, init) => {
    const request = JSON.parse((init as RequestInit).body as string);
    const received = (request.input as string[]).length;
    const data = Array.from({ length: received }, () => ({ embedding: [0.1, 0.2, 0.3] }));
    return new Response(JSON.stringify({ data }), { status: 200 });
  });

  const inputs = Array.from({ length: 250 }, (_, idx) => `text-${idx}`);
  await embedder.embedBatch(inputs);

  expect(fetchMock).toHaveBeenCalledTimes(3);
  const batchSizes = fetchMock.mock.calls.map(([, init]) => {
    const request = JSON.parse((init as RequestInit).body as string);
    return (request.input as string[]).length;
  });
  expect(batchSizes).toEqual([100, 100, 50]);

  fetchMock.mockRestore();
});

// Each input is a numeric string that the fake endpoint echoes back as a
// one-element vector, so the output order reveals how sub-batches were stitched.
it("preserves input order across chunked sub-batches", async () => {
  process.env["OPENAI_EMBEDDING_MAX_BATCH"] = "2";
  const embedder = new OpenAIEmbeddingProvider("test-key");
  const fetchMock = vi.spyOn(globalThis, "fetch").mockImplementation(async (_url, init) => {
    const request = JSON.parse((init as RequestInit).body as string);
    const received = request.input as string[];
    const data = received.map((text) => ({ embedding: [parseFloat(text)] }));
    return new Response(JSON.stringify({ data }), { status: 200 });
  });

  const inputs = ["0", "1", "2", "3", "4"];
  const embeddings = await embedder.embedBatch(inputs);

  const firstComponents = embeddings.map((vector) => vector[0]);
  expect(firstComponents).toEqual([0, 1, 2, 3, 4]);

  fetchMock.mockRestore();
});

// Verifies that sub-batch requests never overlap: the fake fetch counts how
// many calls are in flight at once and the test expects that peak to be 1.
it("issues chunked sub-batches sequentially, not in parallel", async () => {
// Cap of 2 with 6 inputs below => 3 sub-batches.
process.env["OPENAI_EMBEDDING_MAX_BATCH"] = "2";
const provider = new OpenAIEmbeddingProvider("test-key");
// inFlight: calls that have started but not yet passed the delay.
// maxConcurrent: highest inFlight value seen over the whole run.
let inFlight = 0;
let maxConcurrent = 0;
const fetchSpy = vi.spyOn(globalThis, "fetch").mockImplementation(
async (_url, init) => {
inFlight++;
maxConcurrent = Math.max(maxConcurrent, inFlight);
// Hold the call open briefly so a concurrently issued request would be
// counted while this one is still in flight.
await new Promise((r) => setTimeout(r, 5));
inFlight--;
const body = JSON.parse((init as RequestInit).body as string);
const len = (body.input as string[]).length;
// Reply with one single-element vector per input received.
return new Response(
JSON.stringify({
data: Array.from({ length: len }, () => ({ embedding: [0.1] })),
}),
{ status: 200 },
);
},
);

await provider.embedBatch(["a", "b", "c", "d", "e", "f"]);

// A peak of 1 means each request finished its delay before the next began.
expect(maxConcurrent).toBe(1);
expect(fetchSpy).toHaveBeenCalledTimes(3);

fetchSpy.mockRestore();
});

// The second of two sub-batches gets an HTTP 400; embedBatch must reject
// with the provider's error and not return partial results.
it("propagates errors from a failing sub-batch", async () => {
  process.env["OPENAI_EMBEDDING_MAX_BATCH"] = "2";
  const embedder = new OpenAIEmbeddingProvider("test-key");
  let callsSeen = 0;
  const fetchMock = vi.spyOn(globalThis, "fetch").mockImplementation(async (_url, init) => {
    callsSeen += 1;
    const isFailingCall = callsSeen === 2;
    if (isFailingCall) {
      return new Response("upstream tokenize EOF", { status: 400 });
    }
    const request = JSON.parse((init as RequestInit).body as string);
    const received = (request.input as string[]).length;
    const data = Array.from({ length: received }, () => ({ embedding: [0.1] }));
    return new Response(JSON.stringify({ data }), { status: 200 });
  });

  const pending = embedder.embedBatch(["a", "b", "c", "d"]);
  await expect(pending).rejects.toThrow(/OpenAI embedding failed \(400\)/);

  fetchMock.mockRestore();
});

// Non-numeric, negative and zero overrides must each make the constructor throw.
it("rejects invalid OPENAI_EMBEDDING_MAX_BATCH values", () => {
  const invalidValues = ["not-a-number", "-5", "0"];
  for (const value of invalidValues) {
    process.env["OPENAI_EMBEDDING_MAX_BATCH"] = value;
    const construct = () => new OpenAIEmbeddingProvider("test-key");
    expect(construct).toThrow(/OPENAI_EMBEDDING_MAX_BATCH must be a positive integer/);
  }
});
});

describe("withDimensionGuard", () => {
Expand Down