11from abc import ABC , abstractmethod
2- from ast import Not
2+ import numpy as np
33from pydantic import BaseModel , Field
44from pathlib import Path
55from typing import Any , Optional
66from datasets import Dataset
77from huggingface_hub import HfApi
88from datafast .llms import LLMProvider
9- from datafast .prompts import classification_prompts , text_prompts
10- from datafast .schema .config import ClassificationConfig , TextDatasetConfig
11- from datafast .schema .data_rows import TextClassificationRow , LabelSource , TextRow , TextSource
9+ from datafast .prompts import (
10+ classification_prompts ,
11+ question_generation_prompts ,
12+ text_prompts ,
13+ )
14+ from datafast .schema .config import (
15+ ClassificationConfig ,
16+ TextDatasetConfig ,
17+ UltraChatDatasetConfig ,
18+ )
19+ from datafast .schema .data_rows import (
20+ ChatRow ,
21+ TextClassificationRow ,
22+ LabelSource ,
23+ TextRow ,
24+ TextSource ,
25+ )
1226from datafast .expanders import expand_prompts
1327import os
1428
1529
class TextEntries(BaseModel):
    """Structured LLM response: a batch of generated text samples."""

    # All texts produced by the model in a single generation call.
    entries: list[str] = Field(..., description="List of generated texts")
33+
class UserQuestions(BaseModel):
    """Structured LLM response: a batch of opening user questions."""

    # Opening questions generated for a (domain, topic, subtopic) combination.
    questions: list[str] = Field(..., description="List of user questions")
37+
class ReformulatedUserQuestion(BaseModel):
    """Structured LLM response: one question rewritten in a persona's voice."""

    # The persona-flavoured rewrite of an opening or follow-up question.
    question: str = Field(..., description="Reformulated user question")
41+
class Answer(BaseModel):
    """Structured LLM response: the simulated assistant's answer."""

    # Assistant answer to the current user question.
    answer: str = Field(..., description="Answer to the user question")
45+
class FollowupQuestion(BaseModel):
    """Structured LLM response: a follow-up turn from the simulated user."""

    # Follow-up question the simulated user asks after an assistant reply.
    question: str = Field(
        ..., description="Followup question of a user to an AI assistant response."
    )
51+
1652class DatasetBase (ABC ):
1753 """Abstract base class for all dataset generators."""
1854
@@ -144,12 +180,6 @@ def push_to_hub(
144180 return f"https://huggingface.co/datasets/{ repo_id } "
145181
146182
147- class TextEntries (BaseModel ):
148- entries : list [str ] = Field (
149- ..., description = "List of generated texts"
150- )
151-
152-
153183class TextClassificationDataset (DatasetBase ):
154184 def __init__ (self , config : ClassificationConfig ):
155185 super ().__init__ (config )
@@ -233,11 +263,9 @@ def _get_default_prompts(self) -> list[str]:
233263
234264
235265class TextDataset (DatasetBase ):
236-
    def __init__(self, config: TextDatasetConfig):
        """Initialize with a text-generation config.

        The base class performs shared setup; the config is then stored
        (after super().__init__ so this assignment is authoritative).
        """
        super().__init__(config)
        self.config = config
240-
241269
242270 def generate (self , llms : list [LLMProvider ]) -> "TextDataset" :
243271 """Generate text data by calling multiple providers.
@@ -275,10 +303,10 @@ def generate(self, llms: list[LLMProvider]) -> "TextDataset":
275303 for prompt in base_prompts
276304 ]
277305
278-
279306 # 2. Expand prompts with configured variations
280307 expansions = expand_prompts (
281- prompt_templates = base_prompts , ** self .config .expansion .model_dump ()
308+ prompt_templates = base_prompts ,
309+ ** self .config .expansion .model_dump (),
282310 )
283311
284312 # 3. For each expanded prompt, call each provider
@@ -287,8 +315,7 @@ def generate(self, llms: list[LLMProvider]) -> "TextDataset":
287315 try :
288316 # Generate multiple examples using the LLM
289317 response = llm .generate (
290- expanded_prompt ,
291- response_format = TextEntries
318+ expanded_prompt , response_format = TextEntries
292319 )
293320
294321 # Create a row for each generated example
@@ -301,25 +328,201 @@ def generate(self, llms: list[LLMProvider]) -> "TextDataset":
301328 "language" : lang_code ,
302329 "document_type" : document_type ,
303330 "topic" : topic ,
304- }
331+ },
305332 )
306333 self .data_rows .append (row )
307- print (f" Generated total of { len (self .data_rows )} examples" )
334+ print (
335+ f" Generated total of { len (self .data_rows )} examples"
336+ )
308337
309338 except Exception as e :
310339 print (f"Error with llm provider { llm .name } : { e } " )
311340
312-
313341 # Final save at the end
314342 self .to_jsonl (self .config .output_file )
315343 return self
316344
317-
    def _get_default_prompts(self) -> list[str]:
        """Return the module-level default prompt templates for text generation.

        NOTE(review): returns the shared module constant directly (no copy) —
        callers presumably must not mutate it; confirm against usage.
        """
        return text_prompts.DEFAULT_TEMPLATES
321348
322-
323-
324349
325-
class UltraChatDataset(DatasetBase):
    """Generates multi-turn chat conversations following the UltraChat recipe.

    For every (language, topic, subtopic) combination the generator:
      1. asks an LLM for a batch of opening questions,
      2. reformulates each question in the voice of a randomly chosen persona,
      3. simulates an assistant answer,
      4. optionally continues the dialogue with follow-up turns (with
         probability ``conversation_continuation_prob`` per turn, up to
         ``max_turns`` turns),
    and stores each finished conversation as a ``ChatRow``.
    """

    def __init__(self, config: UltraChatDatasetConfig):
        super().__init__(config)
        self.config = config

    def generate(self, llms: list[LLMProvider]) -> "UltraChatDataset":
        """Generate chat conversations with every supplied provider.

        Args:
            llms: Non-empty list of LLM providers; each provider authors both
                the simulated user side and the assistant side of its dialogues.

        Returns:
            self, so calls can be chained.

        Raises:
            ValueError: If ``llms`` is empty.
        """
        if not llms:
            raise ValueError("At least one LLM provider must be supplied")

        # Get languages from config, default to English if not specified.
        languages = self.config.languages or {"en": "English"}

        for lang_code, language_name in languages.items():
            for topic, subtopics in self.config.topics_and_subtopics.items():
                for subtopic in subtopics:
                    # 1. Build the opening-question prompts for this combination.
                    base_prompts = (
                        self.config.question_generation_prompts
                        or self._get_default_question_generation_prompts()
                    )
                    base_prompts = [
                        prompt.format(
                            num_samples=self.config.num_samples,
                            language_name=language_name,
                            domain=self.config.domain,
                            topic=topic,
                            subtopic=subtopic,
                        )
                        for prompt in base_prompts
                    ]

                    # 2. Expand prompts with configured variations.
                    expansions = expand_prompts(
                        prompt_templates=base_prompts,
                        **self.config.expansion.model_dump(),
                    )

                    # 3. For each expanded prompt, run the UltraChat loop with
                    #    each provider.
                    for expanded_prompt, _meta in expansions:
                        for llm in llms:
                            try:
                                opening_questions = llm.generate(
                                    expanded_prompt, response_format=UserQuestions
                                )
                                for opening_question in opening_questions.questions:
                                    row = self._simulate_conversation(
                                        llm=llm,
                                        opening_question=opening_question,
                                        lang_code=lang_code,
                                        topic=topic,
                                        subtopic=subtopic,
                                    )
                                    self.data_rows.append(row)
                                    print(
                                        f" Generated total of {len(self.data_rows)} examples"
                                    )
                            except Exception:
                                # A failure aborts the remaining questions for this
                                # (prompt, provider) pair but keeps rows already built.
                                import traceback

                                print(
                                    f"\nError with llm provider {llm.name}:\n"
                                    f"{traceback.format_exc()}"
                                )

        # Final save at the end.
        self.to_jsonl(self.config.output_file)
        return self

    def _simulate_conversation(
        self,
        llm: LLMProvider,
        opening_question: str,
        lang_code: str,
        topic: str,
        subtopic: str,
    ) -> ChatRow:
        """Simulate one full user/assistant dialogue and wrap it in a ChatRow.

        Args:
            llm: Provider used for every turn of this conversation.
            opening_question: Raw opening question to be persona-reformulated.
            lang_code: Language code recorded in the row metadata.
            topic: Topic recorded in the row metadata and prompts.
            subtopic: Subtopic recorded in the row metadata and prompts.

        Returns:
            The completed ``ChatRow`` for this conversation.
        """
        # Reformulate the opening question in the voice of a random persona.
        random_persona = np.random.choice(self.config.personas)
        reformulation_prompt = self._get_default_persona_question_reformulation_prompt()
        reformulated_question = llm.generate(
            prompt=reformulation_prompt.format(
                question=opening_question,
                persona=random_persona,
                topic=topic,
                subtopic=subtopic,
            ),
            response_format=ReformulatedUserQuestion,
        )

        # Simulate the assistant response to the opening question.
        assistant_prompt = self._get_default_simulated_assistant_prompt()
        assistant_response = llm.generate(
            prompt=assistant_prompt.format(
                domain=self.config.domain,
                topic=topic,
                subtopic=subtopic,
                question=reformulated_question.question,
            ),
            response_format=Answer,
        )

        messages = [
            {"role": "user", "content": reformulated_question.question},
            {"role": "assistant", "content": assistant_response.answer},
        ]
        # Plain-text transcript used to prompt the simulated user for follow-ups.
        dialog_summary = f"{reformulated_question.question}\n{assistant_response.answer}"

        # Continue with probability conversation_continuation_prob per turn,
        # up to max_turns turns in total.
        count = 1
        while count < self.config.max_turns and (
            np.random.random() < self.config.conversation_continuation_prob
        ):
            # Simulate the user follow-up question.
            followup_prompt = self._get_default_user_followup_prompt()
            followup_question = llm.generate(
                prompt=followup_prompt.format(
                    dialog_summary=dialog_summary,
                    persona=random_persona,
                    subtopic=subtopic,
                    domain=self.config.domain,
                ),
                response_format=ReformulatedUserQuestion,
            )
            messages.append(
                {"role": "user", "content": followup_question.question}
            )
            # Simulate the assistant response on the full message history.
            ai_response = llm.generate(messages=messages, response_format=Answer)

            dialog_summary += f"\n{followup_question.question}\n{ai_response.answer}"
            messages.append(
                {"role": "assistant", "content": ai_response.answer}
            )
            count += 1

        return ChatRow(
            opening_question=messages[0]["content"],
            messages=messages,
            model_id=llm.model_id,
            metadata={
                "language": lang_code,
                "domain": self.config.domain,
                "topic": topic,
                "subtopic": subtopic,
            },
            persona=random_persona,
        )

    def _get_default_question_generation_prompts(self) -> list[str]:
        """Default templates for generating batches of opening questions."""
        return question_generation_prompts.DOMAIN_TOPIC_SUBTOPIC_N_QUESTION_GENERATION_DEFAULT_TEMPLATES

    def _get_default_persona_question_reformulation_prompt(self) -> str:
        """Default template for persona-flavoured question reformulation."""
        return (
            question_generation_prompts.PERSONA_QUESTION_REFORMULATION_DEFAULT_TEMPLATE
        )

    def _get_default_simulated_assistant_prompt(self) -> str:
        """Default template for the simulated assistant answer."""
        return question_generation_prompts.SIMULATED_ASSISTANT_DEFAULT_TEMPLATE

    def _get_default_user_followup_prompt(self) -> str:
        """Default template for the simulated user's follow-up question."""
        return question_generation_prompts.USER_FOLLOWUP_PROMPT_TEMPLATE
0 commit comments