Skip to content

Commit 6601a29

Browse files
Merge pull request #21 from patrickfleith/20-ultrachat-human-model-interaction-dataset-on-topics-and-subtopics
20 ultrachat human model interaction dataset on topics and subtopics
2 parents 8f9c70b + 4835da0 commit 6601a29

7 files changed

Lines changed: 407 additions & 46 deletions

File tree

datafast/datasets.py

Lines changed: 226 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,54 @@
11
from abc import ABC, abstractmethod
2-
from ast import Not
2+
import numpy as np
33
from pydantic import BaseModel, Field
44
from pathlib import Path
55
from typing import Any, Optional
66
from datasets import Dataset
77
from huggingface_hub import HfApi
88
from datafast.llms import LLMProvider
9-
from datafast.prompts import classification_prompts, text_prompts
10-
from datafast.schema.config import ClassificationConfig, TextDatasetConfig
11-
from datafast.schema.data_rows import TextClassificationRow, LabelSource, TextRow, TextSource
9+
from datafast.prompts import (
10+
classification_prompts,
11+
question_generation_prompts,
12+
text_prompts,
13+
)
14+
from datafast.schema.config import (
15+
ClassificationConfig,
16+
TextDatasetConfig,
17+
UltraChatDatasetConfig,
18+
)
19+
from datafast.schema.data_rows import (
20+
ChatRow,
21+
TextClassificationRow,
22+
LabelSource,
23+
TextRow,
24+
TextSource,
25+
)
1226
from datafast.expanders import expand_prompts
1327
import os
1428

1529

30+
class TextEntries(BaseModel):
    """Structured LLM output: a batch of generated text samples."""

    entries: list[str] = Field(
        ...,
        description="List of generated texts",
    )
32+
33+
34+
class UserQuestions(BaseModel):
    """Structured LLM output: opening questions proposed for a subtopic."""

    questions: list[str] = Field(
        ...,
        description="List of user questions",
    )
36+
37+
38+
class ReformulatedUserQuestion(BaseModel):
    """Structured LLM output: a question rephrased in a persona's voice."""

    question: str = Field(
        ...,
        description="Reformulated user question",
    )
40+
41+
42+
class Answer(BaseModel):
    """Structured LLM output: the simulated assistant's reply."""

    answer: str = Field(
        ...,
        description="Answer to the user question",
    )
44+
45+
46+
class FollowupQuestion(BaseModel):
    """Structured LLM output: a simulated user's follow-up turn."""

    question: str = Field(
        ..., description="Followup question of a user to an AI assistant response."
    )
50+
51+
1652
class DatasetBase(ABC):
1753
"""Abstract base class for all dataset generators."""
1854

@@ -144,12 +180,6 @@ def push_to_hub(
144180
return f"https://huggingface.co/datasets/{repo_id}"
145181

146182

147-
class TextEntries(BaseModel):
148-
entries: list[str] = Field(
149-
..., description="List of generated texts"
150-
)
151-
152-
153183
class TextClassificationDataset(DatasetBase):
154184
def __init__(self, config: ClassificationConfig):
155185
super().__init__(config)
@@ -233,11 +263,9 @@ def _get_default_prompts(self) -> list[str]:
233263

234264

235265
class TextDataset(DatasetBase):
236-
237266
def __init__(self, config: TextDatasetConfig):
238267
super().__init__(config)
239268
self.config = config
240-
241269

242270
def generate(self, llms: list[LLMProvider]) -> "TextDataset":
243271
"""Generate text data by calling multiple providers.
@@ -275,10 +303,10 @@ def generate(self, llms: list[LLMProvider]) -> "TextDataset":
275303
for prompt in base_prompts
276304
]
277305

278-
279306
# 2. Expand prompts with configured variations
280307
expansions = expand_prompts(
281-
prompt_templates=base_prompts, **self.config.expansion.model_dump()
308+
prompt_templates=base_prompts,
309+
**self.config.expansion.model_dump(),
282310
)
283311

284312
# 3. For each expanded prompt, call each provider
@@ -287,8 +315,7 @@ def generate(self, llms: list[LLMProvider]) -> "TextDataset":
287315
try:
288316
# Generate multiple examples using the LLM
289317
response = llm.generate(
290-
expanded_prompt,
291-
response_format=TextEntries
318+
expanded_prompt, response_format=TextEntries
292319
)
293320

294321
# Create a row for each generated example
@@ -301,25 +328,201 @@ def generate(self, llms: list[LLMProvider]) -> "TextDataset":
301328
"language": lang_code,
302329
"document_type": document_type,
303330
"topic": topic,
304-
}
331+
},
305332
)
306333
self.data_rows.append(row)
307-
print(f" Generated total of {len(self.data_rows)} examples")
334+
print(
335+
f" Generated total of {len(self.data_rows)} examples"
336+
)
308337

309338
except Exception as e:
310339
print(f"Error with llm provider {llm.name}: {e}")
311340

312-
313341
# Final save at the end
314342
self.to_jsonl(self.config.output_file)
315343
return self
316344

317-
318345
def _get_default_prompts(self) -> list[str]:
319346
"""Return the default prompt templates for text generation."""
320347
return text_prompts.DEFAULT_TEMPLATES
321348

322-
323-
324349

325-
350+
class UltraChatDataset(DatasetBase):
    """UltraChat-style multi-turn conversation dataset generator.

    For every (language, topic, subtopic) combination, opening questions are
    generated with an LLM, reformulated in the voice of a randomly chosen
    persona, answered by a simulated assistant, and optionally extended with
    follow-up turns. Each finished conversation is stored as a ``ChatRow``.
    """

    def __init__(self, config: UltraChatDatasetConfig):
        super().__init__(config)
        self.config = config

    def generate(self, llms: list[LLMProvider]) -> "UltraChatDataset":
        """Generate multi-turn chat data by calling multiple providers.

        Args:
            llms: LLM providers used both to propose opening questions and to
                simulate the user/assistant turns.

        Returns:
            Self, with ``self.data_rows`` populated with ``ChatRow`` entries
            and the dataset written to ``config.output_file``.

        Raises:
            ValueError: If no provider is supplied.
        """
        if not llms:
            raise ValueError("At least one LLM provider must be supplied")

        # Get languages from config, default to English if not specified
        languages = self.config.languages or {"en": "English"}

        for lang_code, language_name in languages.items():
            for topic, subtopics in self.config.topics_and_subtopics.items():
                for subtopic in subtopics:
                    # 1. Create base question-generation prompts for this
                    # (language, topic, subtopic) combination.
                    base_prompts = (
                        self.config.question_generation_prompts
                        or self._get_default_question_generation_prompts()
                    )
                    base_prompts = [
                        prompt.format(
                            num_samples=self.config.num_samples,
                            language_name=language_name,
                            domain=self.config.domain,
                            topic=topic,
                            subtopic=subtopic,
                        )
                        for prompt in base_prompts
                    ]

                    # 2. Expand prompts with configured variations
                    expansions = expand_prompts(
                        prompt_templates=base_prompts,
                        **self.config.expansion.model_dump(),
                    )

                    # 3. For each expanded prompt, run the UltraChat loop
                    # with each provider. (The loop index and expansion
                    # metadata were unused and have been dropped.)
                    for expanded_prompt, _meta in expansions:
                        for llm in llms:
                            try:
                                opening_questions = llm.generate(
                                    expanded_prompt, response_format=UserQuestions
                                )

                                for opening_question in opening_questions.questions:
                                    random_persona = np.random.choice(
                                        self.config.personas
                                    )
                                    messages = self._simulate_conversation(
                                        llm=llm,
                                        opening_question=opening_question,
                                        persona=random_persona,
                                        topic=topic,
                                        subtopic=subtopic,
                                    )

                                    # Create a row for each generated conversation
                                    row = ChatRow(
                                        opening_question=messages[0]["content"],
                                        messages=messages,
                                        model_id=llm.model_id,
                                        metadata={
                                            "language": lang_code,
                                            "domain": self.config.domain,
                                            "topic": topic,
                                            "subtopic": subtopic,
                                        },
                                        persona=random_persona,
                                    )
                                    self.data_rows.append(row)
                                    print(
                                        f" Generated total of {len(self.data_rows)} examples"
                                    )

                            except Exception:
                                # Report the full traceback. The previous
                                # handler checked `'response_format' in
                                # locals()`, which is never bound (it is only
                                # passed as a keyword argument) and therefore
                                # always printed 'unknown'; that dead
                                # diagnostic was removed.
                                import traceback

                                print(
                                    f"\nError with llm provider {llm.name}:\n"
                                    f"{traceback.format_exc()}"
                                )

        # Final save at the end
        self.to_jsonl(self.config.output_file)
        return self

    def _simulate_conversation(
        self,
        llm: LLMProvider,
        opening_question: str,
        persona: str,
        topic: str,
        subtopic: str,
    ) -> list[dict[str, str]]:
        """Simulate one user/assistant conversation from an opening question.

        The question is first rephrased in the persona's voice, then answered
        by the simulated assistant; follow-up turns continue with probability
        ``config.conversation_continuation_prob`` up to ``config.max_turns``.

        Returns:
            The conversation as a list of ``{"role": ..., "content": ...}``
            message dicts, starting with the reformulated user question.
        """
        # Rephrase the raw question in the persona's voice.
        reformulation_prompt = self._get_default_persona_question_reformulation_prompt()
        reformulated_question = llm.generate(
            prompt=reformulation_prompt.format(
                question=opening_question,
                persona=persona,
                topic=topic,
                subtopic=subtopic,
            ),
            response_format=ReformulatedUserQuestion,
        )

        # Simulate the assistant response to the opening question.
        assistant_prompt = self._get_default_simulated_assistant_prompt()
        assistant_response = llm.generate(
            prompt=assistant_prompt.format(
                domain=self.config.domain,
                topic=topic,
                subtopic=subtopic,
                question=reformulated_question.question,
            ),
            response_format=Answer,
        )

        messages = [
            {"role": "user", "content": reformulated_question.question},
            {"role": "assistant", "content": assistant_response.answer},
        ]
        # Plain-text transcript used to prompt the simulated user.
        dialog_summary = (
            f"{reformulated_question.question}\n{assistant_response.answer}"
        )

        # Continue the conversation with the configured probability, up to
        # `max_turns` turn pairs. (The original also had an inner
        # `if count >= max_turns: break`, which this condition already covers.)
        count = 1
        while (
            count < self.config.max_turns
            and np.random.random() < self.config.conversation_continuation_prob
        ):
            # Simulate the user follow-up question.
            followup_prompt = self._get_default_user_followup_prompt()
            followup_question = llm.generate(
                prompt=followup_prompt.format(
                    dialog_summary=dialog_summary,
                    persona=persona,
                    subtopic=subtopic,
                    domain=self.config.domain,
                ),
                response_format=ReformulatedUserQuestion,
            )
            messages.append(
                {"role": "user", "content": followup_question.question}
            )

            # Simulate the assistant response to the full dialog so far.
            ai_response = llm.generate(messages=messages, response_format=Answer)
            dialog_summary += (
                f"\n{followup_question.question}\n{ai_response.answer}"
            )
            messages.append(
                {"role": "assistant", "content": ai_response.answer}
            )
            count += 1

        return messages

    def _get_default_question_generation_prompts(self) -> list[str]:
        """Default templates for generating opening questions per subtopic."""
        return question_generation_prompts.DOMAIN_TOPIC_SUBTOPIC_N_QUESTION_GENERATION_DEFAULT_TEMPLATES

    def _get_default_persona_question_reformulation_prompt(self) -> str:
        """Default template for rephrasing a question in a persona's voice."""
        return (
            question_generation_prompts.PERSONA_QUESTION_REFORMULATION_DEFAULT_TEMPLATE
        )

    def _get_default_simulated_assistant_prompt(self) -> str:
        """Default template for the simulated assistant's answer."""
        return question_generation_prompts.SIMULATED_ASSISTANT_DEFAULT_TEMPLATE

    def _get_default_user_followup_prompt(self) -> str:
        """Default template for the simulated user's follow-up question."""
        return question_generation_prompts.USER_FOLLOWUP_PROMPT_TEMPLATE
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from datafast.datasets import UltraChatDataset, ChatRow
2+
from datafast.schema.config import UltraChatDatasetConfig
3+
from datafast.llms import OpenAIProvider
4+
from dotenv import load_dotenv
5+
6+
def main():
    """Run a small UltraChat generation example for materials science."""
    # Dataset configuration: one domain, two topics with two subtopics each,
    # and two personas used to voice the simulated user.
    config = UltraChatDatasetConfig(
        domain="Materials Science",
        topics_and_subtopics={
            "Polymers": ["Design", "Testing"],
            "Alloys": ["Nickel", "Steel"],
        },
        personas=[
            "Master Student at Paris Saclay University, aspiring to pursue a PhD",
            "Head of Materials Science and Technologies Department at Airbus Defense and Space",
        ],
        num_samples=4,
        output_file="materials_science_example_instruction.jsonl",
    )

    # LLM providers with specific models.
    llm_providers = [OpenAIProvider(model_id="gpt-4o-mini")]

    # Generate the dataset (writes to config.output_file as a side effect).
    UltraChatDataset(config).generate(llm_providers)
29+
30+
31+
if __name__ == "__main__":
    # Load API keys (e.g. OPENAI_API_KEY) from secrets.env before running.
    # `load_dotenv` is already imported at module level; the redundant
    # duplicate function-local import was removed.
    load_dotenv("secrets.env")
    main()

0 commit comments

Comments
 (0)