Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 39 additions & 32 deletions openai_finetuner.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,35 @@
import json
import os
import time
from typing import List, Dict, Any, Tuple, Optional
from typing import List, Dict, Any, Optional, NamedTuple

import yaml
import logging
from openai import OpenAI
import pandas as pd
from dotenv import load_dotenv


def load_config(config_path="config.yml") -> Optional[Dict[str, Any]]:
    """Load configuration from a YAML file.

    Failures are reported on stdout instead of being raised, so callers
    must check the result for ``None``.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The value parsed from the YAML document, or ``None`` when the file
        is missing, is not valid YAML, or any other error occurs while
        reading it.
    """
    try:
        # Explicit encoding so decoding does not depend on the platform
        # locale.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except FileNotFoundError:
        print(f"ERROR: Configuration file not found at {config_path}")
        return None
    except yaml.YAMLError as e:
        print(
            f"ERROR: Could not parse YAML configuration from {config_path}: {e}")
        return None
    except Exception as e:
        # Best-effort loader: any other failure (permissions, decoding, ...)
        # is reported and turned into None rather than propagated.
        print(f"An unexpected error occurred while loading configuration: {e}")
        return None

    print(f"Configuration loaded successfully from {config_path}")
    return config
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def load_config(config_path="config.yml") -> Dict[str, Any]:
    """Load configuration from a YAML file and apply defaults.

    An empty document is treated as an empty mapping. The
    ``openai_settings`` and ``prompt_customization`` sections are created
    when absent, and any missing keys in them are filled with defaults.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The configuration mapping with defaults applied.

    Raises:
        ValueError: If the document, or one of the sections that receive
            defaults, is not a mapping.
        OSError: If the file cannot be opened (propagated from ``open``).
        yaml.YAMLError: If the file is not valid YAML (propagated from
            ``yaml.safe_load``).
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f) or {}

    if not isinstance(config, dict):
        raise ValueError(f"Configuration in {config_path} must be a mapping.")

    section_defaults = {
        'openai_settings': {
            'classification_temperature': 0.0,
            'classification_max_tokens': 200,
        },
        'prompt_customization': {
            'output_wrapper_tag': 'category',
        },
    }

    for section_name, defaults in section_defaults.items():
        section = config.get(section_name)
        if section is None:
            # A key written with no value in YAML ("openai_settings:")
            # parses to None; dict.setdefault would return that None and
            # the follow-up .setdefault call would fail, so replace it.
            section = config[section_name] = {}
        elif not isinstance(section, dict):
            raise ValueError(
                f"Section '{section_name}' in {config_path} must be a mapping.")
        for key, value in defaults.items():
            section.setdefault(key, value)

    logger.info("Configuration loaded successfully from %s", config_path)
    return config


# Global configuration, loaded once at import time from load_config()'s
# default path ("config.yml").
CONFIG = load_config()
Expand All @@ -47,13 +51,23 @@ def initialize_openai_client() -> Optional[OpenAI]:
return None



class ClassificationData(NamedTuple):
    """Bundle of datasets returned by ``load_classification_data``.

    The field order matches the 7-item tuple that function returned
    before this type existed, so positional unpacking keeps working
    alongside access by attribute name.
    """

    # Training DataFrame.
    train_data: pd.DataFrame
    # Test DataFrame.
    test_data: pd.DataFrame
    # Training texts as strings (presumably the text column of train_data,
    # mirroring test_texts - confirm against the loader).
    train_texts: List[str]
    # Training labels as strings (presumably the label column of
    # train_data, mirroring test_labels - confirm against the loader).
    train_labels: List[str]
    # Text column of test_data, converted to str.
    test_texts: List[str]
    # Label column of test_data, converted to str.
    test_labels: List[str]
    # Category names (presumably the distinct label values - confirm
    # against the loader).
    categories: List[str]

def load_classification_data(
train_file_path: str,
test_file_path: str,
sep: str = "\t",
text_column: str = "text",
label_column: str = "label"
) -> Optional[Tuple[pd.DataFrame, pd.DataFrame, List[str], List[str], List[str], List[str], List[str]]]:
) -> Optional[ClassificationData]:
"""
Loads training and testing data from specified TSV/CSV files.
Uses text_column and label_column from global CONFIG.
Expand Down Expand Up @@ -93,7 +107,7 @@ def load_classification_data(
test_texts = test_data[text_col].astype(str).tolist()
test_labels = test_data[label_col].astype(str).tolist()

return train_data, test_data, train_texts, train_labels, test_texts, test_labels, categories
return ClassificationData(train_data, test_data, train_texts, train_labels, test_texts, test_labels, categories)

except FileNotFoundError as e:
print(f"ERROR: Data file not found: {e.filename}")
Expand Down Expand Up @@ -378,18 +392,11 @@ def classify_items_with_model(
Classifies a list of items using a specified OpenAI model.
Reads prompt, temperature, and max_tokens from global CONFIG.
"""
if not CONFIG:
print("ERROR: Configuration not loaded. Using default classification settings.")
temperature = 0.0
max_tokens = 200
output_wrapper_tag = "category"
else:
# Assuming these might be added to config.yml under openai_settings or a new section
openai_cfg = CONFIG.get('openai_settings', {})
temperature = openai_cfg.get('classification_temperature', 0.0)
max_tokens = openai_cfg.get('classification_max_tokens', 200)
prompt_cfg = CONFIG.get('prompt_customization', {})
output_wrapper_tag = prompt_cfg.get('output_wrapper_tag', "category")
openai_cfg = CONFIG.get('openai_settings', {})
temperature = openai_cfg.get('classification_temperature', 0.0)
max_tokens = openai_cfg.get('classification_max_tokens', 200)
prompt_cfg = CONFIG.get('prompt_customization', {})
output_wrapper_tag = prompt_cfg.get('output_wrapper_tag', "category")

responses = []
stop_sequence = [f"</{output_wrapper_tag}>"]
Expand Down