From b5a95c3bbe0fef6b2b6b59e0a2ba8d46286093cf Mon Sep 17 00:00:00 2001 From: justincasher Date: Sun, 12 Apr 2026 21:38:22 -0400 Subject: [PATCH] Make embedding and reranker batch sizes tunable via env vars --- scripts/generate_docs_data.py | 20 +++++++++----------- src/lean_explore/util/embedding_client.py | 18 ++++++++++++++++-- src/lean_explore/util/reranker_client.py | 21 +++++++++++++++++++-- 3 files changed, 44 insertions(+), 15 deletions(-) diff --git a/scripts/generate_docs_data.py b/scripts/generate_docs_data.py index ed04a03..8b31194 100644 --- a/scripts/generate_docs_data.py +++ b/scripts/generate_docs_data.py @@ -193,7 +193,7 @@ class ModuleDict(TypedDict): def resolve_annotation(annotation: str | Expr | None) -> str: """Converts a griffe annotation to its string representation.""" - if isinstance(annotation, (Expr, str)): + if isinstance(annotation, Expr | str): return str(annotation) return "" @@ -391,16 +391,14 @@ def parse_docstring(docstring_object: object | None) -> DocstringSections: parse_admonition_section(section, sections_data) elif isinstance( section, - ( - DocstringSectionDeprecated, - DocstringSectionWarns, - DocstringSectionYields, - DocstringSectionReceives, - DocstringSectionOtherParameters, - DocstringSectionClasses, - DocstringSectionFunctions, - DocstringSectionModules, - ), + DocstringSectionDeprecated + | DocstringSectionWarns + | DocstringSectionYields + | DocstringSectionReceives + | DocstringSectionOtherParameters + | DocstringSectionClasses + | DocstringSectionFunctions + | DocstringSectionModules, ): sections_data[kind] = parse_generic_section(section) else: diff --git a/src/lean_explore/util/embedding_client.py b/src/lean_explore/util/embedding_client.py index cef1a71..e257214 100644 --- a/src/lean_explore/util/embedding_client.py +++ b/src/lean_explore/util/embedding_client.py @@ -2,6 +2,7 @@ import asyncio import logging +import os import torch from pydantic import BaseModel @@ -9,6 +10,8 @@ logger = logging.getLogger(__name__) +DEFAULT_BATCH_SIZE = 8 + class EmbeddingResponse(BaseModel): """Response from embedding generation.""" @@ -27,7 +30,11 @@ class EmbeddingClient: """Client for generating text embeddings.""" def __init__( - self, model_name: str, device: str | None = None, max_length: int | None = None + self, + model_name: str, + device: str | None = None, + max_length: int | None = None, + batch_size: int | None = None, ): """Initialize the embedding client. @@ -36,10 +43,17 @@ def __init__( device: Device to use ("cuda", "mps", "cpu"). Auto-detects if None. max_length: Maximum sequence length for tokenization. If None, uses model default. Lower values reduce memory usage. + batch_size: Batch size for encode calls. Falls back to the + LEAN_EXPLORE_EMBEDDING_BATCH_SIZE env var, then to + DEFAULT_BATCH_SIZE. Raise on large-VRAM hardware for better + throughput. """ self.model_name = model_name self.device = device or self._select_device() self.max_length = max_length + self.batch_size = batch_size or int( + os.getenv("LEAN_EXPLORE_EMBEDDING_BATCH_SIZE", DEFAULT_BATCH_SIZE) + ) logger.info("Loading embedding model %s on %s", model_name, self.device) self.model = SentenceTransformer(model_name, device=self.device) @@ -78,7 +92,7 @@ def _encode(): encode_kwargs = { "show_progress_bar": False, "convert_to_numpy": True, - "batch_size": 8, + "batch_size": self.batch_size, } if is_query: encode_kwargs["prompt_name"] = "query" diff --git a/src/lean_explore/util/reranker_client.py b/src/lean_explore/util/reranker_client.py index f93e17c..4da1f94 100644 --- a/src/lean_explore/util/reranker_client.py +++ b/src/lean_explore/util/reranker_client.py @@ -2,6 +2,7 @@ import asyncio import logging +import os import torch from pydantic import BaseModel @@ -10,6 +11,8 @@ logger = logging.getLogger(__name__) DEFAULT_INSTRUCTION = "Find relevant Lean 4 math declarations" +DEFAULT_CUDA_BATCH_SIZE = 16 +DEFAULT_CPU_BATCH_SIZE = 32 class RerankerResponse(BaseModel): @@ -34,6 +37,7 @@ def __init__( device: str | None = None, max_length: int = 512, instruction: str = DEFAULT_INSTRUCTION, + batch_size: int | None = None, ): """Initialize the reranker client. @@ -42,12 +46,26 @@ def __init__( device: Device to use ("cuda", "mps", "cpu"). Auto-detects if None. max_length: Maximum sequence length for tokenization. instruction: Task instruction prepended to each query-document pair. + batch_size: Default batch size for rerank() when the caller doesn't + specify one. Falls back to the LEAN_EXPLORE_RERANKER_BATCH_SIZE + env var, then to device-specific defaults (16 on CUDA, 32 on + CPU). Raise on large-VRAM hardware for better throughput. """ self.model_name = model_name self.device = device or self._select_device() self.max_length = max_length self.instruction = instruction + env_batch_size = os.getenv("LEAN_EXPLORE_RERANKER_BATCH_SIZE") + if batch_size is not None: + self.batch_size = batch_size + elif env_batch_size: + self.batch_size = int(env_batch_size) + elif self.device == "cuda": + self.batch_size = DEFAULT_CUDA_BATCH_SIZE + else: + self.batch_size = DEFAULT_CPU_BATCH_SIZE + logger.info("Loading reranker model %s on %s", model_name, self.device) self.tokenizer = AutoTokenizer.from_pretrained( @@ -162,9 +180,8 @@ async def rerank( if not documents: return RerankerResponse(query=query, scores=[], model=self.model_name) - # Default batch size: 16 on GPU (fits 8GB VRAM), 32 on CPU if batch_size is None: - batch_size = 16 if self.device == "cuda" else 32 + batch_size = self.batch_size # For small batches, run synchronously to avoid executor overhead if len(documents) <= batch_size: