forked from qibin0506/Cortex
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfile_dataset.py
More file actions
108 lines (82 loc) · 3.16 KB
/
file_dataset.py
File metadata and controls
108 lines (82 loc) · 3.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import threading
from llm_trainer import FileDataset, TrainerTools
from constant import data_root_dir
from modelscope import dataset_snapshot_download
class FileDatasetBase(FileDataset):
    """Ordered list of remote data files that are fetched on access.

    Indexing an instance returns the local path of the requested file. On
    the main process the file is downloaded when it is missing, the
    following file is re-fetched in a background thread, and the preceding
    file is deleted from disk.
    """

    # Dataset repository that every file is downloaded from.
    _DATASET_ID = 'qibin0506/cortex-train-data-v2'

    def __init__(self, file_names: list):
        """Remember the file names in the order they will be indexed.

        Args:
            file_names: Names of the data files; each one is used both as
                the download pattern and as the file name under
                ``data_root_dir``.
        """
        self.file_names = file_names

    def __len__(self) -> int:
        """Return the number of files in the dataset."""
        return len(self.file_names)

    def __getitem__(self, idx) -> str:
        """Return the local path of file ``idx``, downloading it if missing.

        Args:
            idx: Position of the wanted file in ``file_names``.

        Returns:
            The path of the file under ``data_root_dir``.

        Raises:
            IndexError: If ``idx`` is outside ``file_names``.
        """
        file_name = self.file_names[idx]
        file_path = self._local_path(file_name)

        # Download the current file.
        if not os.path.exists(file_path):
            if self._is_main_process():
                print(f"正在下载{file_path}")
                self._download(file_name)

            # Presumably a cross-process barrier so that other ranks do not
            # return before the main process has finished downloading.
            # NOTE(review): only ranks that saw the file as missing reach
            # this call - confirm that all ranks take this branch together.
            TrainerTools().parallel.wait()

        # Delete any existing copy of the next file and download it again
        # in the background.
        if idx < len(self.file_names) - 1 and self._is_main_process():
            next_file = self.file_names[idx + 1]
            self._remove_if_exists(self._local_path(next_file))
            threading.Thread(target=self._download, args=(next_file,)).start()

        # Delete the previous file.
        if idx > 0 and self._is_main_process():
            self._remove_if_exists(self._local_path(self.file_names[idx - 1]))

        return file_path

    @staticmethod
    def _is_main_process() -> bool:
        """Report whether the current process is the main one."""
        return TrainerTools().parallel.is_main_process

    @staticmethod
    def _local_path(file_name: str) -> str:
        """Return the path of ``file_name`` inside ``data_root_dir``."""
        # os.path.join gives the same result as plain concatenation when
        # data_root_dir already ends with a separator, and still yields a
        # path inside the directory when it does not.
        return os.path.join(data_root_dir, file_name)

    @staticmethod
    def _remove_if_exists(path: str) -> None:
        """Delete ``path`` when it is present on disk."""
        if os.path.exists(path):
            os.remove(path)

    def _download(self, file_name: str) -> None:
        """Download one file of the dataset into ``data_root_dir``."""
        dataset_snapshot_download(
            self._DATASET_ID,
            allow_file_pattern=[file_name],
            local_dir=data_root_dir,
        )
class PretrainStage0FileDataset(FileDatasetBase):
    """File list made of nine ``mobvoi_seq_monkey_short_*`` shards and ``wikipedia.pkl``."""

    def __init__(self):
        """Register the shard files in the order they are indexed."""
        shard_names = [
            'mobvoi_seq_monkey_short_0.pkl',
            'mobvoi_seq_monkey_short_1.pkl',
            'mobvoi_seq_monkey_short_2.pkl',
            'mobvoi_seq_monkey_short_3.pkl',
            'mobvoi_seq_monkey_short_4.pkl',
            'mobvoi_seq_monkey_short_5.pkl',
            'mobvoi_seq_monkey_short_6.pkl',
            'mobvoi_seq_monkey_short_7.pkl',
            'mobvoi_seq_monkey_short_8.pkl',
            'wikipedia.pkl',
        ]
        super().__init__(shard_names)
class PretrainStage1FileDataset(FileDatasetBase):
    """File list made of the two ``mobvoi_seq_monkey_long_*`` shards."""

    def __init__(self):
        """Register the shard files in the order they are indexed."""
        shard_names = [
            'mobvoi_seq_monkey_long_0.pkl',
            'mobvoi_seq_monkey_long_1.pkl',
        ]
        super().__init__(shard_names)
class COTFileDataset(FileDatasetBase):
    """Single-file dataset backed by ``cot_sft.pkl``."""

    def __init__(self):
        """Register the one data file of this dataset."""
        only_file = 'cot_sft.pkl'
        super().__init__([only_file])
class GRPOFileDataset(FileDatasetBase):
    """Single-file dataset backed by ``grpo.pkl``."""

    def __init__(self):
        """Register the one data file of this dataset."""
        only_file = 'grpo.pkl'
        super().__init__([only_file])
class MixFileDataset(FileDatasetBase):
    """Single-file dataset backed by ``mix_sft.pkl``."""

    def __init__(self):
        """Register the one data file of this dataset."""
        only_file = 'mix_sft.pkl'
        super().__init__([only_file])
class DPOFileDataset(FileDatasetBase):
    """Single-file dataset backed by ``dpo.pkl``."""

    def __init__(self):
        """Register the one data file of this dataset."""
        only_file = 'dpo.pkl'
        super().__init__([only_file])
class DistillDataset(FileDatasetBase):
    """Two-file dataset: ``cot_sft.pkl`` followed by ``mix_sft.pkl``."""

    def __init__(self):
        """Register the data files in the order they are indexed."""
        data_files = [
            'cot_sft.pkl',
            'mix_sft.pkl',
        ]
        super().__init__(data_files)