Skip to content

Commit 424ca67

Browse files
authored
Merge pull request #7 from ComputationalDesignLab/add-neural-network-model
Add neural network model
2 parents 7d50bdc + 8c2565a commit 424ca67

3 files changed

Lines changed: 375 additions & 0 deletions

File tree

scimlstudio/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .polynomial import Polynomial
22
from .rbf import RBF
33
from .single_output_gp_model import SingleOutputGP
4+
from .feed_forward_nn_model import FeedForwardNeuralNetwork
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import torch
2+
from ..base_models import BaseModel
3+
from ..utils import Standardize, Normalize
4+
5+
class FeedForwardNeuralNetwork(BaseModel):
    """
    Feed-forward neural network model for supervised problems.

    Wraps a user-provided ``torch.nn.Sequential`` network together with the
    training data and optional input/output scaling transforms, and exposes
    ``fit`` / ``predict`` methods consistent with the other models in this
    package.
    """

    def __init__(
        self,
        x_train: torch.Tensor,
        y_train: torch.Tensor,
        network: torch.nn.Sequential,
        input_transform: Standardize | Normalize | None = None,
        output_transform: Standardize | Normalize | None = None,
    ):
        """
        Class definition for training and predicting using
        a feed-forward neural network for supervised problems

        Parameters
        ----------
        x_train: torch.Tensor
            Input training data for the network in a 2D tensor

        y_train: torch.Tensor
            Output training data for the network in a 2D tensor

        network: torch.nn.Sequential
            Sequential object defining the network

        input_transform: Normalize or Standardize or None
            Data scaling class for the inputs of the network

        output_transform: Normalize or Standardize or None
            Data scaling class for the outputs of the network

        Raises
        ------
        RuntimeError
            If a forward pass of ``network`` on a training sample fails,
            i.e. the architecture is incompatible with the data.
        """

        # Some checks
        assert isinstance(x_train, torch.Tensor) and x_train.ndim == 2, "xtrain must be a 2D tensor array"
        assert isinstance(y_train, torch.Tensor) and y_train.ndim == 2, "ytrain must be a 2D tensor array"
        assert x_train.shape[0] == y_train.shape[0], "number of samples in input and output training data must be the same"
        assert x_train.device == y_train.device, "input and output training data must be on the same device"
        assert isinstance(network, torch.nn.Sequential), "network should be an instance of sequential class from torch.nn module"
        for param in network.parameters():
            assert param.device == x_train.device, "network parameters should be on the same device as the training data"

        # isinstance with a tuple is the idiomatic form of the or-chain
        if input_transform is not None:
            assert isinstance(input_transform, (Normalize, Standardize)), "input_transform should be an instance of Normalize or Standardize class"

        if output_transform is not None:
            assert isinstance(output_transform, (Normalize, Standardize)), "output_transform should be an instance of Normalize or Standardize class"

        # Smoke-test the network on a single training sample to surface
        # architecture/data incompatibilities at construction time.
        try:
            network.eval()
            with torch.no_grad():
                network(x_train[0])
        except Exception as e:
            # chain the original exception so the full traceback is preserved
            raise RuntimeError(f"Network architecture is not correct and/or not compatible with the provided data: {e}") from e

        super().__init__()

        network.train()  # set network in train mode

        self.x_train = x_train
        self.y_train = y_train
        self.network = network
        self.input_transform = input_transform
        self.output_transform = output_transform

    @property
    def parameters(self):
        # Expose the underlying network's parameters so callers can hand
        # them directly to a torch optimizer.
        return self.network.parameters()  # network parameters

    def fit(
        self,
        optimizer: torch.optim.Optimizer,
        loss_func: torch.nn.modules.loss._Loss,
        batch_size: int = 1,
        epochs: int = 100,
        convert_to_eval_mode: bool = True
    ):
        """
        Method to fit the network to the training data

        `NOTE`: This method supports mini-batch training

        Parameters
        ----------
        optimizer: torch.optim.Optimizer
            Optimizer object from torch.optim module to optimize the network parameters

        loss_func: torch.nn.modules.loss._Loss
            Loss function object from the torch.nn.modules.loss module to compute the loss during training

        batch_size: int
            Batch size to use during training, default = 1

        epochs: int
            Number of epochs to train the network, default = 100

        convert_to_eval_mode: bool
            Flag to set the network to eval mode after training is done, default = True
        """

        assert isinstance(optimizer, torch.optim.Optimizer), "`optimizer` should be an instance of PyTorch optimizer class"
        assert isinstance(loss_func, torch.nn.modules.loss._Loss), "`loss_func` should be an instance of a PyTorch loss function class"
        assert isinstance(batch_size, int) and batch_size > 0, "`batch_size` should be a positive integer"
        assert isinstance(epochs, int) and epochs > 0, "`epochs` should be a positive integer"
        assert isinstance(convert_to_eval_mode, bool), "`convert_to_eval_mode` should be a boolean value"

        self.network.train()  # set network in train mode

        # transform the training data
        if self.input_transform is not None:
            x_train = self.input_transform.transform(self.x_train)
        else:
            x_train = self.x_train

        if self.output_transform is not None:
            y_train = self.output_transform.transform(self.y_train)
        else:
            y_train = self.y_train

        # dataset and dataloader
        dataset = torch.utils.data.TensorDataset(x_train, y_train)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # training loop
        for epoch in range(epochs):

            # loop over all batches
            for x_batch, y_batch in dataloader:

                optimizer.zero_grad()  # zero the grads

                y_pred = self.network(x_batch)  # forward pass

                loss = loss_func(y_pred, y_batch)  # compute the loss

                loss.backward()  # backward pass

                optimizer.step()  # update the parameters

        if convert_to_eval_mode:
            self.network.eval()

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Method to predict the output for the given input data

        `NOTE`: predictions are made in no grad context

        Parameters
        ----------
        x: torch.Tensor
            a torch tensor representing the input data used for prediction

        Returns
        -------
        y_pred: torch.Tensor
            a torch tensor representing the predicted output for the given input data

        Raises
        ------
        ValueError
            If ``x`` does not match the feature shape of the training data.
        RuntimeError
            If the network is still in train mode (``fit`` not called).
        """

        assert isinstance(x, torch.Tensor), "`x` should be a torch tensor"
        assert x.device == self.x_train.device, "input data should be on the same device as the training data"

        x_ndim = x.ndim  # number of dimensions in the given input data

        # check input shape and add batch dim if necessary
        if x_ndim == self.x_train.ndim:
            assert x.shape[1:] == self.x_train.shape[1:], "input data should have the same feature size as the training data"
        elif x_ndim == self.x_train.ndim - 1:
            assert x.shape == self.x_train.shape[1:], "input data should have the same feature size as the training data"
            x = x.unsqueeze(0)  # add batch dim as 1
        else:
            raise ValueError("input data should be of similar shape as the training data")

        # check if network is in train mode
        if self.network.training:
            raise RuntimeError("Network is in train mode, please use the `fit` method to train the network first and then call the `predict` method")

        # transform the input data
        if self.input_transform is not None:
            x = self.input_transform.transform(x)

        # predict in no grad context
        with torch.no_grad():
            y_pred = self.network(x)

        # inverse transform the predicted output
        if self.output_transform is not None:
            y_pred = self.output_transform.inverse_transform(y_pred)

        # remove batch dim, if it was added
        if x_ndim == self.x_train.ndim - 1:
            y_pred = y_pred.squeeze(0)

        return y_pred

tests/test_feed_forward_nn.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
import unittest, torch
2+
from scimlstudio.models import FeedForwardNeuralNetwork
3+
from scimlstudio.utils import evaluate_scalar, Standardize
4+
5+
# Select the GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
dtype = torch.float32

# Keyword arguments shared by every tensor/network factory call below.
args = {"device": device, "dtype": dtype}
11+
12+
class TestFeedForwardNeuralNetwork(unittest.TestCase):
    """
    Class defining the test cases for the feed forward neural network model
    """

    @staticmethod
    def _init_weights(m):
        """
        Initialize Linear layers using glorot (or xavier) initialization
        with zero biases; other module types are left untouched.
        """
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_normal_(m.weight)
            m.bias.data.fill_(0.0)

    @staticmethod
    def _build_network(in_features, out_features, hidden_sizes, activation):
        """
        Build a fully-connected MLP with the given hidden layer widths and
        activation class, moved to the shared device/dtype.
        """
        layers = []
        prev = in_features
        for width in hidden_sizes:
            layers.append(torch.nn.Linear(in_features=prev, out_features=width))
            layers.append(activation())
            prev = width
        layers.append(torch.nn.Linear(in_features=prev, out_features=out_features))
        return torch.nn.Sequential(*layers).to(**args)

    def test_nn_model_1d(self):
        """Fit sin(x) on [0, 2*pi] and check test-set accuracy."""

        # training data
        xtrain = torch.linspace(0, 2*torch.pi, 7, **args).reshape(-1, 1)
        ytrain = torch.sin(xtrain)

        # testing data
        xtest = torch.linspace(0, 2*torch.pi, 100, **args).reshape(-1, 1)
        ytest = torch.sin(xtest)

        # network with glorot-initialized weights
        network = self._build_network(xtrain.shape[1], ytrain.shape[1], (32, 32, 32), torch.nn.GELU)
        network.apply(self._init_weights)

        # data transforms
        input_transform = Standardize(xtrain)
        output_transform = Standardize(ytrain)

        # create model instance
        model = FeedForwardNeuralNetwork(xtrain, ytrain, network, input_transform=input_transform, output_transform=output_transform)

        # optimizer
        optimizer = torch.optim.Adam(model.parameters, lr=0.01)

        # loss function
        loss_func = torch.nn.MSELoss()

        # fit the model (full-batch training)
        model.fit(optimizer, loss_func, batch_size=xtrain.shape[0], epochs=100)

        # predict
        ytest_pred = model.predict(xtest)

        # metrics — unittest assertions give informative failure messages
        r2 = evaluate_scalar(ytest.reshape(-1,), ytest_pred.reshape(-1,), "r2")
        nrmse = evaluate_scalar(ytest.reshape(-1,), ytest_pred.reshape(-1,), "nrmse")

        self.assertLess(nrmse, 2e-2)
        self.assertGreater(r2, 0.99)

    def test_nn_model_2d(self):
        """Fit cos(x1+x2)*exp(x1*x2) on the unit square and check accuracy."""

        # train
        x1 = torch.linspace(0, 1, 5, **args)
        x2 = torch.linspace(0, 1, 5, **args)
        X1, X2 = torch.meshgrid(x1, x2, indexing="ij")
        xtrain = torch.hstack(( X1.reshape(-1, 1), X2.reshape(-1, 1) ))
        ytrain = torch.cos(torch.sum(xtrain, dim=1))*torch.exp(torch.prod(xtrain, dim=1))
        ytrain = ytrain.reshape(-1, 1)

        # test
        x1 = torch.linspace(0, 1, 15, **args)
        x2 = torch.linspace(0, 1, 15, **args)
        X1, X2 = torch.meshgrid(x1, x2, indexing="ij")
        xtest = torch.hstack(( X1.reshape(-1, 1), X2.reshape(-1, 1) ))
        ytest = torch.cos(xtest[:, 0]+xtest[:, 1])*torch.exp(xtest[:, 0]*xtest[:, 1])
        ytest = ytest.reshape(-1, 1)

        # network with glorot-initialized weights
        network = self._build_network(xtrain.shape[1], ytrain.shape[1], (32, 32, 32), torch.nn.GELU)
        network.apply(self._init_weights)

        # data transforms
        input_transform = Standardize(xtrain)
        output_transform = Standardize(ytrain)

        # create model instance
        model = FeedForwardNeuralNetwork(xtrain, ytrain, network, input_transform=input_transform, output_transform=output_transform)

        # optimizer
        optimizer = torch.optim.Adam(model.parameters, lr=0.01)

        # loss function
        loss_func = torch.nn.MSELoss()

        # fit the model (full-batch training)
        model.fit(optimizer, loss_func, batch_size=xtrain.shape[0], epochs=100)

        # predict
        ytest_pred = model.predict(xtest)

        # metrics
        r2 = evaluate_scalar(ytest.reshape(-1,), ytest_pred.reshape(-1,), "r2")
        nrmse = evaluate_scalar(ytest.reshape(-1,), ytest_pred.reshape(-1,), "nrmse")

        self.assertLess(nrmse, 1e-2)
        self.assertGreater(r2, 0.99)

    def test_input_output_shapes(self):
        """Check predict() output shapes for batched and single-sample input."""

        # dummy training data
        xtrain = torch.rand(10, 5, **args)
        ytrain = torch.rand(10, 1, **args)

        # network (default torch initialization, as in the original test)
        network = self._build_network(xtrain.shape[1], ytrain.shape[1], (16, 16), torch.nn.Tanh)

        # create model instance
        model = FeedForwardNeuralNetwork(xtrain, ytrain, network)

        # optimizer
        optimizer = torch.optim.Adam(model.parameters, lr=0.01)

        # loss function
        loss_func = torch.nn.MSELoss()

        # fit the model
        model.fit(optimizer, loss_func, batch_size=xtrain.shape[0], epochs=100)

        # predict - 1 sample: batch dim should be squeezed away
        ypred = model.predict(xtrain[0])
        self.assertEqual(ypred.ndim, 1)
        self.assertEqual(ypred.shape[0], 1)

        # predict - 5 samples: batch dim preserved
        ypred = model.predict(xtrain[:5])
        self.assertEqual(ypred.ndim, 2)
        self.assertEqual(ypred.shape[0], 5)
175+
176+
# Allow running this test module directly: `python test_feed_forward_nn.py`.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)