From 796d57a6d89a828fb54600d6d4284dd1b824e882 Mon Sep 17 00:00:00 2001 From: JasonZhangHub Date: Sat, 7 Mar 2026 00:06:52 +0800 Subject: [PATCH] Refine config loading and classification data typing --- openai_finetuner.py | 71 +++++++++++++++++++++++++-------------------- 1 file changed, 39 insertions(+), 32 deletions(-) diff --git a/openai_finetuner.py b/openai_finetuner.py index 6e436f4..0e1087a 100644 --- a/openai_finetuner.py +++ b/openai_finetuner.py @@ -1,31 +1,35 @@ import json import os import time -from typing import List, Dict, Any, Tuple, Optional +from typing import List, Dict, Any, Optional, NamedTuple import yaml +import logging from openai import OpenAI import pandas as pd from dotenv import load_dotenv -def load_config(config_path="config.yml") -> Optional[Dict[str, Any]]: - """Loads configuration from a YAML file.""" - try: - with open(config_path, 'r') as f: - config = yaml.safe_load(f) - print(f"Configuration loaded successfully from {config_path}") - return config - except FileNotFoundError: - print(f"ERROR: Configuration file not found at {config_path}") - return None - except yaml.YAMLError as e: - print( - f"ERROR: Could not parse YAML configuration from {config_path}: {e}") - return None - except Exception as e: - print(f"An unexpected error occurred while loading configuration: {e}") - return None +logger = logging.getLogger(__name__) + + +def load_config(config_path="config.yml") -> Dict[str, Any]: + """Loads configuration from a YAML file and applies defaults.""" + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) or {} + + if not isinstance(config, dict): + raise ValueError(f"Configuration in {config_path} must be a mapping.") + + openai_cfg = config.setdefault('openai_settings', {}) + openai_cfg.setdefault('classification_temperature', 0.0) + openai_cfg.setdefault('classification_max_tokens', 200) + + prompt_cfg = config.setdefault('prompt_customization', {}) + prompt_cfg.setdefault('output_wrapper_tag', 'category') + + logger.info("Configuration loaded successfully from %s", config_path) + return config CONFIG = load_config() @@ -47,13 +51,23 @@ def initialize_openai_client() -> Optional[OpenAI]: return None + +class ClassificationData(NamedTuple): + train_data: pd.DataFrame + test_data: pd.DataFrame + train_texts: List[str] + train_labels: List[str] + test_texts: List[str] + test_labels: List[str] + categories: List[str] + def load_classification_data( train_file_path: str, test_file_path: str, sep: str = "\t", text_column: str = "text", label_column: str = "label" -) -> Optional[Tuple[pd.DataFrame, pd.DataFrame, List[str], List[str], List[str], List[str], List[str]]]: +) -> Optional[ClassificationData]: """ Loads training and testing data from specified TSV/CSV files. Uses text_column and label_column from global CONFIG. @@ -93,7 +107,7 @@ def load_classification_data( test_texts = test_data[text_col].astype(str).tolist() test_labels = test_data[label_col].astype(str).tolist() - return train_data, test_data, train_texts, train_labels, test_texts, test_labels, categories + return ClassificationData(train_data, test_data, train_texts, train_labels, test_texts, test_labels, categories) except FileNotFoundError as e: print(f"ERROR: Data file not found: {e.filename}") @@ -378,18 +392,11 @@ def classify_items_with_model( Classifies a list of items using a specified OpenAI model. Reads prompt, temperature, and max_tokens from global CONFIG. """ - if not CONFIG: - print("ERROR: Configuration not loaded. Using default classification settings.") - temperature = 0.0 - max_tokens = 200 - output_wrapper_tag = "category" - else: - # Assuming these might be added to config.yml under openai_settings or a new section - openai_cfg = CONFIG.get('openai_settings', {}) - temperature = openai_cfg.get('classification_temperature', 0.0) - max_tokens = openai_cfg.get('classification_max_tokens', 200) - prompt_cfg = CONFIG.get('prompt_customization', {}) - output_wrapper_tag = prompt_cfg.get('output_wrapper_tag', "category") + openai_cfg = CONFIG.get('openai_settings', {}) + temperature = openai_cfg.get('classification_temperature', 0.0) + max_tokens = openai_cfg.get('classification_max_tokens', 200) + prompt_cfg = CONFIG.get('prompt_customization', {}) + output_wrapper_tag = prompt_cfg.get('output_wrapper_tag', "category") responses = [] stop_sequence = [f""]