
Commit 173403f

Merge pull request #110 from LamQam/feature/batch-inference
Implemented batch inference support
2 parents 85cfa26 + a3e2a9c commit 173403f

2 files changed: 722 additions & 130 deletions

File tree

datafast/llms.py: 123 additions & 75 deletions
@@ -16,6 +16,7 @@
 # LiteLLM
 import litellm
 from litellm.utils import ModelResponse
+from litellm import batch_completion
 
 # Internal imports
 from .llm_utils import get_messages
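
The new import is litellm's batch API: batch_completion fans a list of message lists out to one model and returns a list of ModelResponse objects in input order. A minimal sketch of the call shape (the prompts are illustrative, not from this commit):

    import litellm

    # Two independent conversations sent in one call; litellm returns a
    # list of ModelResponse objects in the same order as the inputs.
    responses = litellm.batch_completion(
        model="openai/gpt-4.1-mini-2025-04-14",
        messages=[
            [{"role": "user", "content": "Say hello in French."}],
            [{"role": "user", "content": "Say hello in Spanish."}],
        ],
        temperature=0.7,
    )
    print([r.choices[0].message.content for r in responses])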
@@ -25,9 +26,10 @@
 Messages = list[Message]
 T = TypeVar('T', bound=BaseModel)
 
+
 class LLMProvider(ABC):
     """Abstract base class for LLM providers."""
-
+
     def __init__(
         self,
         model_id: str,
@@ -39,7 +41,7 @@ def __init__(
         rpm_limit: int | None = None,
     ):
         """Initialize the LLM provider with common parameters.
-
+
         Args:
             model_id: The model identifier
             api_key: API key (if None, will get from environment)
@@ -50,7 +52,7 @@
         """
         self.model_id = model_id
         self.api_key = api_key or self._get_api_key()
-
+
         # Set generation parameters
         self.temperature = temperature
         self.max_completion_tokens = max_completion_tokens
@@ -60,22 +62,22 @@
         # Rate limiting
         self.rpm_limit = rpm_limit
         self._request_timestamps: list[float] = []
-
+
         # Configure environment with API key if needed
         self._configure_env()
-
+
     @property
     @abstractmethod
     def provider_name(self) -> str:
         """Return the provider name used by LiteLLM."""
         pass
-
+
     @property
     @abstractmethod
     def env_key_name(self) -> str:
         """Return the environment variable name for API key."""
         pass
-
+
     def _get_api_key(self) -> str:
         """Get API key from environment variables."""
         api_key = os.getenv(self.env_key_name)
@@ -85,12 +87,12 @@ def _get_api_key(self) -> str:
                 f"Please set it or provide an API key when initializing the provider."
             )
         return api_key
-
+
     def _configure_env(self) -> None:
         """Configure environment variables for API key."""
         if self.api_key:
             os.environ[self.env_key_name] = self.api_key
-
+
     def _get_model_string(self) -> str:
         """Get the full model string for LiteLLM."""
         return f"{self.provider_name}/{self.model_id}"
@@ -101,7 +103,8 @@ def _respect_rate_limit(self) -> None:
             return
         current = time.monotonic()
         # Keep only timestamps within the last minute
-        self._request_timestamps = [ts for ts in self._request_timestamps if current - ts < 60]
+        self._request_timestamps = [
+            ts for ts in self._request_timestamps if current - ts < 60]
         if len(self._request_timestamps) < self.rpm_limit:
             return
         # Need to wait until the earliest request is outside the 60-second window
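
The hunk above only reflows the timestamp-pruning line; the sliding-window logic around it is unchanged. As a standalone illustration of that pattern, a minimal sketch (an approximation, since the full method body sits outside the diff):

    import time

    class SlidingWindowLimiter:
        """Allow at most rpm_limit requests in any 60-second window."""

        def __init__(self, rpm_limit: int):
            self.rpm_limit = rpm_limit
            self._timestamps: list[float] = []

        def wait(self) -> None:
            now = time.monotonic()
            # Drop timestamps older than one minute
            self._timestamps = [ts for ts in self._timestamps if now - ts < 60]
            if len(self._timestamps) >= self.rpm_limit:
                # Sleep until the oldest request leaves the window
                sleep_time = 60 - (now - self._timestamps[0])
                if sleep_time > 0:
                    time.sleep(sleep_time)
            self._timestamps.append(time.monotonic())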
@@ -111,87 +114,132 @@ def _respect_rate_limit(self) -> None:
         if sleep_time > 0:
             print("Waiting for rate limit...")
             time.sleep(sleep_time)
-
+
     def generate(
-        self,
-        prompt: str | None = None,
-        messages: Messages | None = None,
-        response_format: Type[T] | None = None
-    ) -> str | T:
-        """Generate a response from the LLM.
+        self,
+        prompt: str | list[str] | None = None,
+        messages: list[Messages] | Messages | None = None,
+        response_format: Type[T] | None = None,
+    ) -> str | list[str] | T | list[T]:
+        """
+        Generate responses from the LLM using single or batch inference.
 
         Args:
-            prompt: Text prompt (use either prompt or messages, not both)
-            messages: List of message dictionaries with role and content (use either prompt or messages, not both)
+            prompt: Single text prompt (str) or list of text prompts for batch processing
+            messages: Single message list or list of message lists for batch processing
             response_format: Optional Pydantic model class for structured output
 
         Returns:
-            Either a string response or a Pydantic model instance if response_format is provided
-
+            Single string/model or list of strings/models depending on input type.
+
         Raises:
-            ValueError: If neither prompt nor messages is provided, or if both are provided
-            RuntimeError: If there's an error during generation
+            ValueError: If neither prompt nor messages is provided, or if both are provided.
+            RuntimeError: If there's an error during generation.
         """
         # Validate inputs
         if prompt is None and messages is None:
-            raise ValueError("Either prompt or messages must be provided")
+            raise ValueError("Either prompts or messages must be provided")
         if prompt is not None and messages is not None:
-            raise ValueError("Provide either prompt or messages, not both")
-
+            raise ValueError("Provide either prompts or messages, not both")
+
+        # Determine if this is a single input or batch input
+        single_input = False
+        batch_prompts = None
+        batch_messages = None
+
+        if prompt is not None:
+            if isinstance(prompt, str):
+                # Single prompt - convert to batch
+                batch_prompts = [prompt]
+                single_input = True
+            elif isinstance(prompt, list):
+                # Already a list of prompts
+                batch_prompts = prompt
+                single_input = False
+            else:
+                raise ValueError("prompt must be a string or list of strings")
+
+        if messages is not None:
+            if isinstance(messages, list) and len(messages) > 0:
+                # Check if it's a single message list or batch
+                if isinstance(messages[0], dict):
+                    # Single message list - convert to batch
+                    batch_messages = [messages]
+                    single_input = True
+                elif isinstance(messages[0], list):
+                    # Already a batch of message lists
+                    batch_messages = messages
+                    single_input = False
+                else:
+                    raise ValueError("Invalid messages format")
+            else:
+                raise ValueError("messages cannot be empty")
+
         try:
-            # Convert string prompt to messages if needed
-            if prompt is not None:
-                messages_to_send = get_messages(prompt)
+            # Convert batch prompts to messages if needed
+            batch_to_send = []
+            if batch_prompts is not None:
+                for one_prompt in batch_prompts:
+                    batch_to_send.append(get_messages(one_prompt))
             else:
-                messages_to_send = messages
-
-            # Enforce rate limit if set
+                batch_to_send = batch_messages
+
+            # Enforce rate limit per batch
             self._respect_rate_limit()
-            # Prepare completion parameters
+
+            # Prepare completion parameters for batch
             completion_params = {
                 "model": self._get_model_string(),
-                "messages": messages_to_send,
+                "messages": batch_to_send,
                 "temperature": self.temperature,
                 "max_tokens": self.max_completion_tokens,
                 "top_p": self.top_p,
                 "frequency_penalty": self.frequency_penalty,
             }
-
-            # Add response format if provided
             if response_format is not None:
                 completion_params["response_format"] = response_format
-
-            # Call LiteLLM completion
-            response: ModelResponse = litellm.completion(**completion_params)
+
+            # Call LiteLLM completion with batch messages
+            response: list[ModelResponse] = litellm.batch_completion(
+                **completion_params)
+
             # Record timestamp for rate limiting
             if self.rpm_limit is not None:
                 self._request_timestamps.append(time.monotonic())
-
-            # Extract content from response
-            content = response.choices[0].message.content
-
-            # Parse and validate if response_format is provided
-            if response_format is not None:
-                return response_format.model_validate_json(content)
-            else:
-                return content
-
+
+            # Extract content from each response
+            results = []
+            for one_response in response:
+                content = one_response.choices[0].message.content
+                if response_format is not None:
+                    results.append(
+                        response_format.model_validate_json(content))
+                else:
+                    results.append(content)
+
+            # Return single result for backward compatibility
+            if single_input and len(results) == 1:
+                return results[0]
+            return results
+
         except Exception as e:
             error_trace = traceback.format_exc()
-            raise RuntimeError(f"Error generating response with {self.provider_name}:\n{error_trace}")
+            raise RuntimeError(
+                f"Error generating batch response with {self.provider_name}:\n{error_trace}"
+            )
 
 
 class OpenAIProvider(LLMProvider):
     """OpenAI provider using litellm."""
-
+
     @property
     def provider_name(self) -> str:
         return "openai"
-
+
     @property
     def env_key_name(self) -> str:
         return "OPENAI_API_KEY"
-
+
     def __init__(
         self,
         model_id: str = "gpt-4.1-mini-2025-04-14",
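
With the hunk above applied, generate accepts either a single prompt or a list of prompts while preserving the old single-input return type. A usage sketch under that contract (the prompts are illustrative):

    provider = OpenAIProvider()  # reads OPENAI_API_KEY from the environment

    # Single prompt: returns a single string, as before this commit
    answer = provider.generate(prompt="Name one prime number.")

    # Batch of prompts: one batch_completion call, returns a list of strings
    answers = provider.generate(prompt=[
        "Name one prime number.",
        "Name one even number.",
    ])
    assert isinstance(answers, list) and len(answers) == 2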
@@ -200,9 +248,9 @@ def __init__(
         max_completion_tokens: int | None = None,
         top_p: float | None = None,
         frequency_penalty: float | None = None,
-    ):
+    ):
         """Initialize the OpenAI provider.
-
+
         Args:
             model_id: The model ID (defaults to gpt-4.1-mini-2025-04-14)
             api_key: API key (if None, will get from environment)
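
One consequence of the batch path: response_format is forwarded once for the whole batch, so every prompt in a list is validated against the same Pydantic schema. A sketch assuming that behavior (the Sentiment model is made up for illustration):

    from pydantic import BaseModel

    class Sentiment(BaseModel):
        label: str
        confidence: float

    provider = OpenAIProvider()
    results = provider.generate(
        prompt=["Review: great phone!", "Review: battery died in a day."],
        response_format=Sentiment,
    )
    # results is a list[Sentiment], one validated instance per prompt
    for r in results:
        print(r.label, r.confidence)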
@@ -223,15 +271,15 @@ def __init__(
 
 class AnthropicProvider(LLMProvider):
     """Anthropic provider using litellm."""
-
+
     @property
     def provider_name(self) -> str:
         return "anthropic"
-
+
     @property
     def env_key_name(self) -> str:
         return "ANTHROPIC_API_KEY"
-
+
     def __init__(
         self,
         model_id: str = "claude-3-5-haiku-latest",
@@ -240,9 +288,9 @@ def __init__(
         max_completion_tokens: int | None = None,
         top_p: float | None = None,
         # frequency_penalty: float | None = None, # Not supported by anthropic
-    ):
+    ):
         """Initialize the Anthropic provider.
-
+
         Args:
             model_id: The model ID (defaults to claude-3-5-haiku-latest)
             api_key: API key (if None, will get from environment)
@@ -261,15 +309,15 @@ def __init__(
 
 class GeminiProvider(LLMProvider):
     """Google Gemini provider using litellm."""
-
+
     @property
     def provider_name(self) -> str:
         return "gemini"
-
+
     @property
     def env_key_name(self) -> str:
         return "GEMINI_API_KEY"
-
+
     def __init__(
         self,
         model_id: str = "gemini-2.0-flash",
@@ -278,10 +326,10 @@ def __init__(
         max_completion_tokens: int | None = None,
         top_p: float | None = None,
         frequency_penalty: float | None = None,
-        rpm_limit: int | None = None,
-    ):
+        rpm_limit: int | None = None,
+    ):
         """Initialize the Gemini provider.
-
+
         Args:
             model_id: The model ID (defaults to gemini-2.0-flash)
            api_key: API key (if None, will get from environment)
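
Note that _respect_rate_limit() runs once per generate call and one timestamp is recorded per call, so a whole batch consumes a single slot in the 60-second window. An illustrative instantiation (the limit value is arbitrary):

    # One batch_completion call counts as one request against rpm_limit
    gemini = GeminiProvider(rpm_limit=15)
    facts = gemini.generate(prompt=[f"One fact about the number {i}" for i in range(10)])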
@@ -303,26 +351,26 @@ def __init__(
 
 class OllamaProvider(LLMProvider):
     """Ollama provider using litellm.
-
+
     Note: Ollama typically doesn't require an API key as it's usually run locally.
     """
-
+
     @property
     def provider_name(self) -> str:
         return "ollama_chat"
-
+
     @property
     def env_key_name(self) -> str:
         return "OLLAMA_API_BASE"
-
+
     def _get_api_key(self) -> str:
         """Override to handle Ollama not requiring an API key.
-
+
         Returns an empty string since Ollama typically doesn't need an API key.
         OLLAMA_API_BASE can be used to set a custom base URL.
         """
         return ""
-
+
     def __init__(
         self,
         model_id: str = "gemma3:4b",
@@ -332,9 +380,9 @@ def __init__(
         frequency_penalty: float | None = None,
         api_base: str | None = None,
         rpm_limit: int | None = None,
-    ):
+    ):
         """Initialize the Ollama provider.
-
+
         Args:
             model_id: The model ID (defaults to llama3)
             temperature: Temperature for generation (0.0 to 1.0)
@@ -346,7 +394,7 @@ def __init__(
         # Set API base URL if provided
         if api_base:
             os.environ["OLLAMA_API_BASE"] = api_base
-
+
         super().__init__(
             model_id=model_id,
             api_key="",  # Pass empty string since parent class requires this parameter
