-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMainSR.py
More file actions
58 lines (53 loc) · 2.46 KB
/
MainSR.py
File metadata and controls
58 lines (53 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import sys

from EquationLearning.SymbolicRegressor.SetGAP import SetGAP
from EquationLearning.SymbolicRegressor.MSSP import *
# Experiment labels whose data/models are stored on disk under a different name.
DATASET_ALIASES = {'E10': 'CS1', 'E11': 'CS2', 'E12': 'CS3', 'E13': 'CS4'}


def resolve_dataset_name(name):
    """Return the on-disk dataset name for the experiment label *name*.

    Labels listed in ``DATASET_ALIASES`` (``E10``-``E13``) map to ``CS1``-``CS4``;
    any other label is returned unchanged.
    """
    return DATASET_ALIASES.get(name, name)


def nn_model_paths(project_root, dataset_name, noise_name='', dimensions=''):
    """Return ``(full_model_path, weights_path)`` for a dataset's trained network.

    Both files live in ``<project_root>/EquationLearning/saved_models/saved_NNs/<dataset_name>``:

    * ``NNModel-NN-<suffix>.pth`` -- the whole network saved with ``torch.save``.
    * ``weights-NN-<suffix>``     -- only the weights, loaded via ``NNModel.loadModel``.

    where ``<suffix>`` is ``dataset_name + noise_name + dimensions``.

    The two names are built independently instead of deriving one from the other
    with ``str.replace("weights", ...)``, which would also rewrite any "weights"
    substring in the project root or dataset folder.
    """
    folder = os.path.join(project_root, "EquationLearning", "saved_models", "saved_NNs", dataset_name)
    suffix = dataset_name + noise_name + dimensions
    full_model_path = os.path.join(folder, "NNModel-NN-" + suffix + ".pth")
    weights_path = os.path.join(folder, "weights-NN-" + suffix)
    return full_model_path, weights_path


if __name__ == '__main__':
    ###########################################
    # Import data
    ###########################################
    # Heavy/project dependencies are only needed when running as a script.
    import torch
    from EquationLearning.models.NNModel import NNModel

    datasetNames = ['E6']
    seed = 8
    print("Seed ", seed)
    noise = 0  # 0.01
    print(datasetNames)
    print(noise)
    noise_name = '_noise-' + str(noise) if noise > 0 else ''
    dimensions = ''
    # The compute device does not depend on the dataset, so choose it once.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    for datasetName in datasetNames:
        datasetName = resolve_dataset_name(datasetName)
        data_loader = DataLoader(name=datasetName, noise=noise, seed=seed)
        data = data_loader.dataset

        ###########################################
        # Define NN and load weights
        ###########################################
        full_model_path, weights_path = nn_model_paths(get_project_root(), datasetName, noise_name, dimensions)
        nn_model = None
        if os.path.exists(full_model_path):
            # The whole network object was saved; wrap it in an NNModel.
            # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
            # On recent torch releases the default weights_only=True may refuse a full
            # module; confirm against the installed torch version.
            network = torch.load(full_model_path, map_location=device)
            nn_model = NNModel(device=device, n_features=data.n_features, loaded_NN=network)
        elif os.path.exists(weights_path):
            # Only weights were saved: build the architecture, then load them.
            nn_model = NNModel(device=device, n_features=data.n_features, NNtype=data_loader.modelType)
            nn_model.loadModel(weights_path)
        elif data.n_features > 1:
            sys.exit("We haven't trained a NN for this problem yet. Use the TrainNNModel.py file first.")
        # Single-feature problems without a saved network proceed with nn_model = None,
        # which is passed to SetGAP unchanged.

        ###########################################
        # Get Estimated Multivariate Expressions
        ###########################################
        regressor = SetGAP(dataset=data, bb_model=nn_model, n_candidates=3)
        results = regressor.run()