-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
144 lines (112 loc) · 4.94 KB
/
utils.py
File metadata and controls
144 lines (112 loc) · 4.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import argparse
import torch
import numpy as np
import json
from datetime import datetime
import pandas as pd
def get_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which
            is what the original zero-argument call did.

    Returns:
        argparse.Namespace with:
            clf: list of str, the ``--clf`` value split on commas.
            dataset: str, the ``--dataset`` value.
            test_num: str, the ``--test-num`` value.
    """
    parser = argparse.ArgumentParser(description='deeplog')
    parser.add_argument(
        '--clf',
        default='',
        help='set model name:ref var clf_list'
    )
    parser.add_argument(
        '--dataset',
        default='',
        help='set dataset'
    )
    parser.add_argument(
        '--test-num',
        default='',
        # The help text used to be a copy of the --dataset one.
        help='set test num'
    )
    args = parser.parse_args(argv)
    # NOTE: when --clf is omitted this yields [''] (one empty name), not [].
    args.clf = args.clf.split(',')
    return args
def validate_clf(clf_list, input_clfs):
    """Check that every requested classifier name is a known one.

    Args:
        clf_list: Container of the accepted classifier names.
        input_clfs: Sequence of requested classifier names.

    Raises:
        ValueError: If ``input_clfs`` is empty, or if any of its entries is
            not in ``clf_list``. ``ValueError`` is a subclass of
            ``Exception``, so callers catching ``Exception`` still work.
    """
    if not input_clfs:
        # Wrap in a 1-tuple so an empty tuple is formatted instead of being
        # taken as the (empty) argument list of the % operator.
        raise ValueError('empty clfs: %s' % (input_clfs,))
    for clf in input_clfs:
        if clf not in clf_list:
            raise ValueError('input clf error: %s' % clf)
def shuffle_data(X, y):
    """Shuffle samples and labels with one shared random permutation.

    Args:
        X: Array-like with a ``shape`` attribute, indexable along its first
            axis by an integer index array. Any number of dimensions >= 1.
        y: Labels, indexable by the same integer index array.

    Returns:
        tuple: ``(new_X, new_Y)``, the rows of ``X`` and the entries of
        ``y`` reordered by the same permutation, so pairs stay aligned.
    """
    random_indices = np.random.permutation(X.shape[0])
    # Index only the first axis: same result as X[idx, :] for 2-D+ input,
    # and it also accepts 1-D X.
    return X[random_indices], y[random_indices]
def get_temp_file():
    """Return a timestamp tag and the temp directory, creating the directory.

    Returns:
        tuple: ``(formatted_time, rela_path)`` where ``formatted_time`` is
        the current local time formatted as ``"%m-%d_%H_%M_%S"`` and
        ``rela_path`` is the relative directory ``'data/temp/'``.
    """
    # Format the current date and time into a name-friendly tag.
    formatted_time = datetime.now().strftime("%m-%d_%H_%M_%S")
    rela_path = 'data/temp/'
    # exist_ok avoids the race between a separate existence check and the
    # directory creation.
    os.makedirs(rela_path, exist_ok=True)
    return formatted_time, rela_path
def map2event(y, map_relation):
    """Replace every element of ``y`` with ``map_relation[element]``.

    The replacement is done through ``y.apply_``, and the object returned
    by that call is handed back to the caller.

    Args:
        y: Object exposing ``apply_`` (presumably a torch tensor; verify
            against the caller).
        map_relation: Subscriptable mapping from old label to new label.

    Returns:
        The result of ``y.apply_`` with the lookup applied.
    """
    def lookup(label):
        return map_relation[label]

    return y.apply_(lookup)
def get_dataset(path):
    """Load a saved dataset and print a short class summary.

    The file is read with ``torch.load`` and must contain a mapping with
    the keys ``'X'`` and ``'y'``.

    Args:
        path: Location of the file written by ``torch.save``.

    Returns:
        tuple: ``(X, y)`` where ``X`` is ``dataset['X']`` as stored and
        ``y`` is ``dataset['y']`` converted to a NumPy array.
    """
    # SECURITY: torch.load unpickles the file; only load trusted paths.
    dataset = torch.load(path)
    X = dataset['X']
    # detach()/cpu() make the conversion work for labels that require grad
    # or live on another device.
    y = dataset['y'].detach().cpu().numpy()
    unique_values, unique_counts = np.unique(y, return_counts=True)
    print('dataset path %s: %s' % (path, X.shape))
    print('class number %s' % (unique_values))
    print('counts %s' % unique_counts)
    print('---------------------------------')
    return X, y
def save_dict_data(dir_path, create_time, dataset_name, *json_data, **json_name_data):
    """Dump dictionaries as indented JSON into one text file.

    The file is ``<dir_path>/<dataset_name>_<create_time>_.txt``. Positional
    objects are written first, each under a ``Dictionary :`` heading, then
    keyword objects, each under a ``Name: <keyword>`` heading.

    Args:
        dir_path: Existing directory that receives the file.
        create_time: Tag inserted into the file name.
        dataset_name: Dataset name inserted into the file name.
        *json_data: Unnamed objects to serialize.
        **json_name_data: Named objects to serialize.

    Raises:
        TypeError: If a value is not JSON serializable and is neither a
            NumPy array nor a torch tensor.
    """
    file_name = '%s_' % dataset_name + '%s_' % create_time + '.txt'
    file_abs_path = os.path.join(dir_path, file_name)

    def default(obj):
        # NumPy arrays and torch tensors are serialized as nested lists.
        if isinstance(obj, (np.ndarray, torch.Tensor)):
            return obj.tolist()
        raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")

    # Open once: a second open in 'w' mode would truncate the file and
    # discard the positional dictionaries already written.
    with open(file_abs_path, 'w', encoding='utf-8') as file:
        for d in json_data:
            file.write("\n\nDictionary :\n")
            file.write(json.dumps(d, indent=4, default=default))
        for k, v in json_name_data.items():
            file.write("\n\n Name: %s\n" % k)
            file.write(json.dumps(v, indent=4, default=default))
def save_excel_reports(report_list, dir_path, create_time, test_index, dataset_name):
    """Write each classifier's text report to its own sheet of one workbook.

    The workbook is created in ``dir_path`` with a name built from
    ``create_time``, ``dataset_name`` and ``test_index``. For every report
    the surrounding whitespace is stripped, the first two lines are
    dropped (presumably the column header and a blank line; verify against
    the report producer), and each remaining line is split on whitespace
    into one table row.

    Args:
        report_list: Mapping of classifier name -> report text. The name is
            used as the sheet name.
        dir_path: Directory that receives the workbook.
        create_time: Tag used as the file-name prefix.
        test_index: Index inserted into the file name.
        dataset_name: Dataset name inserted into the file name.
    """
    header = [' ', 'class', 'precision', 'recall', 'f1-score', 'support']
    workbook_name = create_time + '_%s_' % dataset_name + '_num-%s.xlsx' % test_index
    workbook_path = os.path.join(dir_path, workbook_name)
    with pd.ExcelWriter(workbook_path, engine='xlsxwriter') as writer:
        for clf_name, report_text in report_list.items():
            print('clf save:', clf_name)
            body_lines = report_text.strip().split('\n')[2:]
            rows = [body_line.split() for body_line in body_lines]
            frame = pd.DataFrame(rows, columns=header)
            # Store the table on the sheet named after the classifier.
            frame.to_excel(writer, sheet_name=clf_name, index=False)
def save_excel_reports_testnum(report_list_testnum, dir_path, create_time, test_num, dataset_name):
    """Stack the reports of several test runs into one workbook.

    Each classifier gets one sheet. The tables of successive runs are
    written one below another on that sheet, separated by blank rows.
    Report text is parsed by stripping it, dropping its first two lines
    and splitting every remaining line on whitespace.

    Args:
        report_list_testnum: Mapping of test index -> mapping of classifier
            name -> report text. Only the inner mappings are used.
        dir_path: Directory that receives the workbook.
        create_time: Tag used as the file-name prefix.
        test_num: Value inserted into the file name.
        dataset_name: Dataset name inserted into the file name.
    """
    header = [' ', 'class', 'precision', 'recall', 'f1-score', 'support']
    workbook_name = create_time + '_%s_' % dataset_name + '_num-%s.xlsx' % test_num
    workbook_path = os.path.join(dir_path, workbook_name)
    # Next free row on each classifier's sheet.
    next_row = {}
    with pd.ExcelWriter(workbook_path, engine='xlsxwriter') as writer:
        for reports in report_list_testnum.values():
            for clf_name, report_text in reports.items():
                print('clf save:', clf_name)
                rows = [ln.split() for ln in report_text.strip().split('\n')[2:]]
                top = next_row.get(clf_name, 0)
                frame = pd.DataFrame(rows, columns=header)
                frame.to_excel(writer, sheet_name=clf_name, index=False, startrow=top)
                # +1 for the title row, +3 blank rows of spacing
                next_row[clf_name] = top + len(rows) + 1 + 3