Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions scripts/generate_docs_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ class ModuleDict(TypedDict):

def resolve_annotation(annotation: str | Expr | None) -> str:
    """Converts a griffe annotation to its string representation.

    Args:
        annotation: The annotation to convert. May be an ``Expr``, a plain
            string, or ``None``.

    Returns:
        ``str(annotation)`` when the annotation is an ``Expr`` or a ``str``;
        an empty string for anything else (including ``None``).
    """
    # PEP 604 union form of the isinstance check (requires Python 3.10+,
    # which the ``str | Expr | None`` annotation above already assumes).
    if isinstance(annotation, Expr | str):
        return str(annotation)
    return ""

Expand Down Expand Up @@ -391,16 +391,14 @@ def parse_docstring(docstring_object: object | None) -> DocstringSections:
parse_admonition_section(section, sections_data)
elif isinstance(
section,
(
DocstringSectionDeprecated,
DocstringSectionWarns,
DocstringSectionYields,
DocstringSectionReceives,
DocstringSectionOtherParameters,
DocstringSectionClasses,
DocstringSectionFunctions,
DocstringSectionModules,
),
DocstringSectionDeprecated
| DocstringSectionWarns
| DocstringSectionYields
| DocstringSectionReceives
| DocstringSectionOtherParameters
| DocstringSectionClasses
| DocstringSectionFunctions
| DocstringSectionModules,
):
sections_data[kind] = parse_generic_section(section)
else:
Expand Down
18 changes: 16 additions & 2 deletions src/lean_explore/util/embedding_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

import asyncio
import logging
import os

import torch
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)

DEFAULT_BATCH_SIZE = 8


class EmbeddingResponse(BaseModel):
"""Response from embedding generation."""
Expand All @@ -27,7 +30,11 @@ class EmbeddingClient:
"""Client for generating text embeddings."""

def __init__(
self, model_name: str, device: str | None = None, max_length: int | None = None
self,
model_name: str,
device: str | None = None,
max_length: int | None = None,
batch_size: int | None = None,
):
"""Initialize the embedding client.

Expand All @@ -36,10 +43,17 @@ def __init__(
device: Device to use ("cuda", "mps", "cpu"). Auto-detects if None.
max_length: Maximum sequence length for tokenization. If None, uses
model default. Lower values reduce memory usage.
batch_size: Batch size for encode calls. Falls back to the
LEAN_EXPLORE_EMBEDDING_BATCH_SIZE env var, then to
DEFAULT_BATCH_SIZE. Raise on large-VRAM hardware for better
throughput.
"""
self.model_name = model_name
self.device = device or self._select_device()
self.max_length = max_length
self.batch_size = batch_size or int(
os.getenv("LEAN_EXPLORE_EMBEDDING_BATCH_SIZE", DEFAULT_BATCH_SIZE)
)
logger.info("Loading embedding model %s on %s", model_name, self.device)
self.model = SentenceTransformer(model_name, device=self.device)

Expand Down Expand Up @@ -78,7 +92,7 @@ def _encode():
encode_kwargs = {
"show_progress_bar": False,
"convert_to_numpy": True,
"batch_size": 8,
"batch_size": self.batch_size,
}
if is_query:
encode_kwargs["prompt_name"] = "query"
Expand Down
21 changes: 19 additions & 2 deletions src/lean_explore/util/reranker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import logging
import os

import torch
from pydantic import BaseModel
Expand All @@ -10,6 +11,8 @@
logger = logging.getLogger(__name__)

DEFAULT_INSTRUCTION = "Find relevant Lean 4 math declarations"
DEFAULT_CUDA_BATCH_SIZE = 16
DEFAULT_CPU_BATCH_SIZE = 32


class RerankerResponse(BaseModel):
Expand All @@ -34,6 +37,7 @@ def __init__(
device: str | None = None,
max_length: int = 512,
instruction: str = DEFAULT_INSTRUCTION,
batch_size: int | None = None,
):
"""Initialize the reranker client.

Expand All @@ -42,12 +46,26 @@ def __init__(
device: Device to use ("cuda", "mps", "cpu"). Auto-detects if None.
max_length: Maximum sequence length for tokenization.
instruction: Task instruction prepended to each query-document pair.
batch_size: Default batch size for rerank() when the caller doesn't
specify one. Falls back to the LEAN_EXPLORE_RERANKER_BATCH_SIZE
env var, then to device-specific defaults (16 on CUDA, 32 on
CPU). Raise on large-VRAM hardware for better throughput.
"""
self.model_name = model_name
self.device = device or self._select_device()
self.max_length = max_length
self.instruction = instruction

env_batch_size = os.getenv("LEAN_EXPLORE_RERANKER_BATCH_SIZE")
if batch_size is not None:
self.batch_size = batch_size
elif env_batch_size:
self.batch_size = int(env_batch_size)
elif self.device == "cuda":
self.batch_size = DEFAULT_CUDA_BATCH_SIZE
else:
self.batch_size = DEFAULT_CPU_BATCH_SIZE

logger.info("Loading reranker model %s on %s", model_name, self.device)

self.tokenizer = AutoTokenizer.from_pretrained(
Expand Down Expand Up @@ -162,9 +180,8 @@ async def rerank(
if not documents:
return RerankerResponse(query=query, scores=[], model=self.model_name)

# Default batch size: 16 on GPU (fits 8GB VRAM), 32 on CPU
if batch_size is None:
batch_size = 16 if self.device == "cuda" else 32
batch_size = self.batch_size

# For small batches, run synchronously to avoid executor overhead
if len(documents) <= batch_size:
Expand Down
Loading