# LiteLLM
import litellm
from litellm.utils import ModelResponse
+from litellm import batch_completion

# Internal imports
from .llm_utils import get_messages

Messages = list[Message]
T = TypeVar('T', bound=BaseModel)

+
class LLMProvider(ABC):
    """Abstract base class for LLM providers."""
-
+
    def __init__(
        self,
        model_id: str,
@@ -39,7 +41,7 @@ def __init__(
        rpm_limit: int | None = None,
    ):
        """Initialize the LLM provider with common parameters.
-
+
        Args:
            model_id: The model identifier
            api_key: API key (if None, will get from environment)
@@ -50,7 +52,7 @@ def __init__(
        """
        self.model_id = model_id
        self.api_key = api_key or self._get_api_key()
-
+
        # Set generation parameters
        self.temperature = temperature
        self.max_completion_tokens = max_completion_tokens
@@ -60,22 +62,22 @@ def __init__(
        # Rate limiting
        self.rpm_limit = rpm_limit
        self._request_timestamps: list[float] = []
-
+
        # Configure environment with API key if needed
        self._configure_env()
-
+
    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Return the provider name used by LiteLLM."""
        pass
-
+
    @property
    @abstractmethod
    def env_key_name(self) -> str:
        """Return the environment variable name for API key."""
        pass
-
+
    def _get_api_key(self) -> str:
        """Get API key from environment variables."""
        api_key = os.getenv(self.env_key_name)
@@ -85,12 +87,12 @@ def _get_api_key(self) -> str:
                f"Please set it or provide an API key when initializing the provider."
            )
        return api_key
-
+
    def _configure_env(self) -> None:
        """Configure environment variables for API key."""
        if self.api_key:
            os.environ[self.env_key_name] = self.api_key
-
+
    def _get_model_string(self) -> str:
        """Get the full model string for LiteLLM."""
        return f"{self.provider_name}/{self.model_id}"
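    # For example, OpenAIProvider below resolves to the LiteLLM model string
    # "openai/gpt-4.1-mini-2025-04-14" when left at its default model ID.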
@@ -101,7 +103,8 @@ def _respect_rate_limit(self) -> None:
            return
        current = time.monotonic()
        # Keep only timestamps within the last minute
-        self._request_timestamps = [ts for ts in self._request_timestamps if current - ts < 60]
+        self._request_timestamps = [
+            ts for ts in self._request_timestamps if current - ts < 60]
        if len(self._request_timestamps) < self.rpm_limit:
            return
        # Need to wait until the earliest request is outside the 60-second window
@@ -111,87 +114,132 @@ def _respect_rate_limit(self) -> None:
        if sleep_time > 0:
            print("Waiting for rate limit...")
            time.sleep(sleep_time)
-
+
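    # Worked example (illustrative): with rpm_limit=2 and requests recorded at
    # t=0s and t=10s, a third call at t=30s still finds two timestamps inside the
    # 60-second window, so it sleeps ~30s until the t=0s entry leaves the window.
    # A provider opts in via the constructor, e.g. GeminiProvider(rpm_limit=15);
    # the 15 here is an assumed quota, not a value this diff sets.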
    def generate(
-        self,
-        prompt: str | None = None,
-        messages: Messages | None = None,
-        response_format: Type[T] | None = None
-    ) -> str | T:
-        """Generate a response from the LLM.
+        self,
+        prompt: str | list[str] | None = None,
+        messages: list[Messages] | Messages | None = None,
+        response_format: Type[T] | None = None,
+    ) -> str | list[str] | T | list[T]:
+        """
+        Generate responses from the LLM using single or batch inference.

        Args:
-            prompt: Text prompt (use either prompt or messages, not both)
-            messages: List of message dictionaries with role and content (use either prompt or messages, not both)
+            prompt: Single text prompt (str) or list of text prompts for batch processing
+            messages: Single message list or list of message lists for batch processing
            response_format: Optional Pydantic model class for structured output

        Returns:
-            Either a string response or a Pydantic model instance if response_format is provided
-
+            Single string/model or list of strings/models depending on input type.
+
        Raises:
-            ValueError: If neither prompt nor messages is provided, or if both are provided
-            RuntimeError: If there's an error during generation
+            ValueError: If neither prompt nor messages is provided, or if both are provided.
+            RuntimeError: If there's an error during generation.
        """
        # Validate inputs
        if prompt is None and messages is None:
-            raise ValueError("Either prompt or messages must be provided")
+            raise ValueError("Either prompts or messages must be provided")
        if prompt is not None and messages is not None:
-            raise ValueError("Provide either prompt or messages, not both")
-
+            raise ValueError("Provide either prompts or messages, not both")
+
+        # Determine if this is a single input or batch input
+        single_input = False
+        batch_prompts = None
+        batch_messages = None
+
+        if prompt is not None:
+            if isinstance(prompt, str):
+                # Single prompt - convert to batch
+                batch_prompts = [prompt]
+                single_input = True
+            elif isinstance(prompt, list):
+                # Already a list of prompts
+                batch_prompts = prompt
+                single_input = False
+            else:
+                raise ValueError("prompt must be a string or list of strings")
+
+        if messages is not None:
+            if isinstance(messages, list) and len(messages) > 0:
+                # Check if it's a single message list or batch
+                if isinstance(messages[0], dict):
+                    # Single message list - convert to batch
+                    batch_messages = [messages]
+                    single_input = True
+                elif isinstance(messages[0], list):
+                    # Already a batch of message lists
+                    batch_messages = messages
+                    single_input = False
+                else:
+                    raise ValueError("Invalid messages format")
+            else:
+                raise ValueError("messages cannot be empty")
+
        try:
-            # Convert string prompt to messages if needed
-            if prompt is not None:
-                messages_to_send = get_messages(prompt)
+            # Convert batch prompts to messages if needed
+            batch_to_send = []
+            if batch_prompts is not None:
+                for one_prompt in batch_prompts:
+                    batch_to_send.append(get_messages(one_prompt))
            else:
-                messages_to_send = messages
-
-            # Enforce rate limit if set
+                batch_to_send = batch_messages
+
+            # Enforce rate limit per batch
            self._respect_rate_limit()
-            # Prepare completion parameters
+
+            # Prepare completion parameters for batch
            completion_params = {
                "model": self._get_model_string(),
-                "messages": messages_to_send,
+                "messages": batch_to_send,
                "temperature": self.temperature,
                "max_tokens": self.max_completion_tokens,
                "top_p": self.top_p,
                "frequency_penalty": self.frequency_penalty,
            }
-
-            # Add response format if provided
            if response_format is not None:
                completion_params["response_format"] = response_format
-
-            # Call LiteLLM completion
-            response: ModelResponse = litellm.completion(**completion_params)
+
+            # Call LiteLLM completion with batch messages
+            response: list[ModelResponse] = litellm.batch_completion(
+                **completion_params)
+
            # Record timestamp for rate limiting
            if self.rpm_limit is not None:
                self._request_timestamps.append(time.monotonic())
-
-            # Extract content from response
-            content = response.choices[0].message.content
-
-            # Parse and validate if response_format is provided
-            if response_format is not None:
-                return response_format.model_validate_json(content)
-            else:
-                return content
-
+
+            # Extract content from each response
+            results = []
+            for one_response in response:
+                content = one_response.choices[0].message.content
+                if response_format is not None:
+                    results.append(
+                        response_format.model_validate_json(content))
+                else:
+                    results.append(content)
+
+            # Return single result for backward compatibility
+            if single_input and len(results) == 1:
+                return results[0]
+            return results
+
        except Exception as e:
            error_trace = traceback.format_exc()
-            raise RuntimeError(f"Error generating response with {self.provider_name}:\n{error_trace}")
+            raise RuntimeError(
+                f"Error generating batch response with {self.provider_name}:\n{error_trace}"
+            )

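# Usage sketch (illustrative; the prompts and the Answer model below are
# assumptions, not part of this diff):
#
#     provider = OpenAIProvider()
#     one = provider.generate(prompt="Say hi")                # -> str
#     many = provider.generate(prompt=["Say hi", "Say bye"])  # -> list[str]
#
#     class Answer(BaseModel):
#         text: str
#
#     parsed = provider.generate(prompt="Say hi", response_format=Answer)  # -> Answer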

class OpenAIProvider(LLMProvider):
    """OpenAI provider using litellm."""
-
+
    @property
    def provider_name(self) -> str:
        return "openai"
-
+
    @property
    def env_key_name(self) -> str:
        return "OPENAI_API_KEY"
-
+
    def __init__(
        self,
        model_id: str = "gpt-4.1-mini-2025-04-14",
@@ -200,9 +248,9 @@ def __init__(
        max_completion_tokens: int | None = None,
        top_p: float | None = None,
        frequency_penalty: float | None = None,
-    ):
+    ):
        """Initialize the OpenAI provider.
-
+
        Args:
            model_id: The model ID (defaults to gpt-4.1-mini-2025-04-14)
            api_key: API key (if None, will get from environment)
@@ -223,15 +271,15 @@ def __init__(

class AnthropicProvider(LLMProvider):
    """Anthropic provider using litellm."""
-
+
    @property
    def provider_name(self) -> str:
        return "anthropic"
-
+
    @property
    def env_key_name(self) -> str:
        return "ANTHROPIC_API_KEY"
-
+
    def __init__(
        self,
        model_id: str = "claude-3-5-haiku-latest",
@@ -240,9 +288,9 @@ def __init__(
        max_completion_tokens: int | None = None,
        top_p: float | None = None,
        # frequency_penalty: float | None = None, # Not supported by anthropic
-    ):
+    ):
        """Initialize the Anthropic provider.
-
+
        Args:
            model_id: The model ID (defaults to claude-3-5-haiku-latest)
            api_key: API key (if None, will get from environment)
@@ -261,15 +309,15 @@ def __init__(

class GeminiProvider(LLMProvider):
    """Google Gemini provider using litellm."""
-
+
    @property
    def provider_name(self) -> str:
        return "gemini"
-
+
    @property
    def env_key_name(self) -> str:
        return "GEMINI_API_KEY"
-
+
    def __init__(
        self,
        model_id: str = "gemini-2.0-flash",
@@ -278,10 +326,10 @@ def __init__(
        max_completion_tokens: int | None = None,
        top_p: float | None = None,
        frequency_penalty: float | None = None,
-        rpm_limit: int | None = None,
-    ):
+        rpm_limit: int | None = None,
+    ):
        """Initialize the Gemini provider.
-
+
        Args:
            model_id: The model ID (defaults to gemini-2.0-flash)
            api_key: API key (if None, will get from environment)
@@ -303,26 +351,26 @@ def __init__(

class OllamaProvider(LLMProvider):
    """Ollama provider using litellm.
-
+
    Note: Ollama typically doesn't require an API key as it's usually run locally.
    """
-
+
    @property
    def provider_name(self) -> str:
        return "ollama_chat"
-
+
    @property
    def env_key_name(self) -> str:
        return "OLLAMA_API_BASE"
-
+
    def _get_api_key(self) -> str:
        """Override to handle Ollama not requiring an API key.
-
+
        Returns an empty string since Ollama typically doesn't need an API key.
        OLLAMA_API_BASE can be used to set a custom base URL.
        """
        return ""
-
+
    def __init__(
        self,
        model_id: str = "gemma3:4b",
@@ -332,9 +380,9 @@ def __init__(
        frequency_penalty: float | None = None,
        api_base: str | None = None,
        rpm_limit: int | None = None,
-    ):
+    ):
        """Initialize the Ollama provider.
-
+
        Args:
            model_id: The model ID (defaults to gemma3:4b)
            temperature: Temperature for generation (0.0 to 1.0)
@@ -346,7 +394,7 @@ def __init__(
        # Set API base URL if provided
        if api_base:
            os.environ["OLLAMA_API_BASE"] = api_base
-
+
        super().__init__(
            model_id=model_id,
            api_key="",  # Pass empty string since parent class requires this parameter
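# Local-inference sketch (illustrative; assumes an Ollama server is already
# running, with http://localhost:11434 as the usual default address):
#
#     ollama = OllamaProvider(model_id="gemma3:4b",
#                             api_base="http://localhost:11434")
#     reply = ollama.generate(prompt="Why is the sky blue?")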