1010import pandas as pd
1111from datasets import load_dataset
1212
13- TOXIC_COLUMNS = ["toxic" , "severe_toxic" , "obscene" , "threat" , "insult" , "identity_hate" ]
13+ TOXIC_COLUMNS = [
14+ "toxic" ,
15+ "severe_toxic" ,
16+ "obscene" ,
17+ "threat" ,
18+ "insult" ,
19+ "identity_hate" ,
20+ ]
1421TEXT_COLUMN = "comment_text"
1522OUTPUT_DIR = Path (__file__ ).parent / "data" / "processed"
1623
@@ -46,29 +53,51 @@ def load_single(
4653 if label_source == "paradetox" :
4754 # toxic = 1, neutral/detox = 0
4855 input_col = next (
49- (c for c in [
50- "input" , "source" , "toxic" ,
51- "en_toxic_comment" , "ru_toxic_comment" , "toxic_sentence" ,
52- ] if c in df .columns ),
56+ (
57+ c
58+ for c in [
59+ "input" ,
60+ "source" ,
61+ "toxic" ,
62+ "en_toxic_comment" ,
63+ "ru_toxic_comment" ,
64+ "toxic_sentence" ,
65+ ]
66+ if c in df .columns
67+ ),
5368 None ,
5469 )
5570 output_col = next (
56- (c for c in [
57- "output" , "target" , "detox" ,
58- "en_neutral_comment" , "ru_neutral_comment" , "neutral_sentence" ,
59- ] if c in df .columns ),
71+ (
72+ c
73+ for c in [
74+ "output" ,
75+ "target" ,
76+ "detox" ,
77+ "en_neutral_comment" ,
78+ "ru_neutral_comment" ,
79+ "neutral_sentence" ,
80+ ]
81+ if c in df .columns
82+ ),
6083 None ,
6184 )
6285 if not input_col or not output_col :
63- raise ValueError (f"ParaDetox format needs toxic/neutral columns. Columns: { list (df .columns )} " )
86+ raise ValueError (
87+ f"ParaDetox format needs toxic/neutral columns. Columns: { list (df .columns )} "
88+ )
6489 toxic_df = df [[input_col ]].rename (columns = {input_col : TEXT_COLUMN })
6590 toxic_df ["label" ] = 1
6691 clean_df = df [[output_col ]].rename (columns = {output_col : TEXT_COLUMN })
6792 clean_df ["label" ] = 0
6893 df = pd .concat ([toxic_df , clean_df ], ignore_index = True )
6994 else :
7095 text_col = text_col or next (
71- (c for c in ["comment_text" , "text" , "comment" , "sentence" , "content" ] if c in df .columns ),
96+ (
97+ c
98+ for c in ["comment_text" , "text" , "comment" , "sentence" , "content" ]
99+ if c in df .columns
100+ ),
72101 None ,
73102 )
74103 if not text_col :
@@ -81,7 +110,9 @@ def load_single(
81110 # civil_comments: toxicity 0-1, threshold 0.5
82111 tox_col = next ((c for c in ["toxicity" , "toxic" ] if c in df .columns ), None )
83112 if not tox_col :
84- raise ValueError (f"Toxicity column not found. Columns: { list (df .columns )} " )
113+ raise ValueError (
114+ f"Toxicity column not found. Columns: { list (df .columns )} "
115+ )
85116 df ["label" ] = (df [tox_col ].fillna (0 ) >= 0.5 ).astype (int )
86117 elif label_source .startswith ("toxic" ):
87118 toxic_cols = [c for c in TOXIC_COLUMNS if c in df .columns ]
@@ -132,7 +163,10 @@ def load_multilingual(max_samples_per_dataset: int | None = None) -> pd.DataFram
132163
133164 # English + Russian + multilingual paradetox
134165 for name , (ds , _ , src ) in DATASET_PRESETS .items ():
135- if name in ("paradetox" , "ru_paradetox" , "multilingual_paradetox" ) and src == "paradetox" :
166+ if (
167+ name in ("paradetox" , "ru_paradetox" , "multilingual_paradetox" )
168+ and src == "paradetox"
169+ ):
136170 try :
137171 df = load_single (ds , src , None , max_samples_per_dataset , 3 , 512 )
138172 dfs .append (df )
@@ -144,7 +178,9 @@ def load_multilingual(max_samples_per_dataset: int | None = None) -> pd.DataFram
144178 return pd .concat (dfs , ignore_index = True ).drop_duplicates (subset = [TEXT_COLUMN ])
145179
146180
147- def balance (df : pd .DataFrame , ratio : float = 0.3 , max_total : int | None = None ) -> pd .DataFrame :
181+ def balance (
182+ df : pd .DataFrame , ratio : float = 0.3 , max_total : int | None = None
183+ ) -> pd .DataFrame :
148184 """Balance classes. ratio = fraction of positive samples. max_total caps result size."""
149185 pos = df [df ["label" ] == 1 ]
150186 neg = df [df ["label" ] == 0 ]
@@ -213,7 +249,9 @@ def main() -> None:
213249 ds_name , text_col , label_src = DATASET_PRESETS [args .preset ]
214250 df = load_single (ds_name , label_src , text_col , args .max_samples , 3 , 512 )
215251
216- print (f"Total: { len (df )} samples, { df ['label' ].sum ()} positive ({ df ['label' ].mean ():.2%} )" )
252+ print (
253+ f"Total: { len (df )} samples, { df ['label' ].sum ()} positive ({ df ['label' ].mean ():.2%} )"
254+ )
217255
218256 df_balanced = balance (df , ratio = args .positive_ratio , max_total = args .max_total )
219257 print (f"Balanced: { len (df_balanced )} samples" )
0 commit comments