-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathparams.py
More file actions
115 lines (102 loc) · 2.89 KB
/
params.py
File metadata and controls
115 lines (102 loc) · 2.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# -*- coding: utf-8 -*-
"""
配置参数文件 - 量化分析系统
"""
from types import SimpleNamespace
# 基础配置
params = SimpleNamespace()
# 数据路径配置
# 新增CSV数据路径配置
params.feat_path = "./test_feat.csv" # 特征数据CSV文件路径
params.label_csv_path = "./test_label.csv" # 标签数据CSV文件路径
# 训练参数
params.batch_size = 64
params.epochs = 50
params.learning_rate = 1e-3
params.patience = 10
params.num_workers = 8
params.max_grad_norm = 1.0
# 输出配置
params.output_root = "./project_haoxin/output/gru_only_baseline"
# 数据维度配置
params.time_steps = 15
params.num_nodes = 12
params.num_samples = 8
params.flat_dim = params.num_nodes * params.num_samples # 96
# 模型选择配置
params.model_type = "graph_transformer" # "gru" 或 "graph_transformer"
# GRU模型配置
params.gru_config = {
"input_dim": params.flat_dim, # 96
"hidden_dim": 256,
"num_layers": 2,
"dropout": 0.05,
"fc_dims": [64],
"use_batch_norm": False,
"bidirectional": False,
"pooling": "tail_mean",
"tail_k": 8,
}
# GraphTransformer模型配置
params.graph_transformer_config = {
# 图编码器配置
"graph_cfg": {
"num_nodes": 12,
"in_dim": 8,
"d_model": 512, # 图表示维度F(增大=更强表达)
"n_heads": 8,
"n_layers": 8, # 8层图Transformer(≈16-18M图侧参数)
"dropout": 0.10,
"ff_mult": 4,
"use_node_type_embed": True,
"prior_matrix": None, # 如有 12x12 相关性先验,可传入 torch.Tensor
"prior_strength": 0.2,
},
# 时序模块配置
"temporal_cfg": {
"input_dim": 512, # 自动接管,无需手填
"hidden_dim": 256,
"num_layers": 2,
"dropout": 0.05,
"fc_dims": [128, 64],
"use_batch_norm": False,
"bidirectional": False,
"pooling": "tail_mean",
"tail_k": 8
}
}
# 大模型配置(可选)
params.graph_transformer_large_config = {
# 更大容量(≈3千万级参数)
"graph_cfg": {
"num_nodes": 12,
"in_dim": 8,
"d_model": 640, # 或 768
"n_heads": 10, # d_model 必须能整除 n_heads
"n_layers": 10,
"dropout": 0.10,
"ff_mult": 4,
"use_node_type_embed": True,
"prior_matrix": None,
"prior_strength": 0.2,
},
"temporal_cfg": {
"input_dim": 640,
"hidden_dim": 320, # 适配更大F
"num_layers": 2,
"dropout": 0.05,
"fc_dims": [160, 64],
"use_batch_norm": False,
"bidirectional": False,
"pooling": "tail_mean",
"tail_k": 8
}
}
# 随机种子
params.random_seeds = [42, 123, 456, 789, 1000]
# 时间窗口配置
params.time_window = "day"
# 实验配置
params.n_experiments = 5
# 测试模式配置
params.test_mode = True # 是否启用测试模式