Skip to content

Commit 40f75bc

Browse files
authored
Merge pull request #4 from ComputationalDesignLab/abhi-add-rbf-implementation
Add RBF model implementation
2 parents 0a26f28 + e2b7621 commit 40f75bc

3 files changed

Lines changed: 307 additions & 1 deletion

File tree

scimlstudio/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .polynomial import Polynomial
1+
from .polynomial import Polynomial
2+
from .rbf import RBF

scimlstudio/models/rbf.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
import torch
2+
from ..base_model import BaseModel
3+
from ..utils.transformations import Standardize, Normalize
4+
5+
class RBF(BaseModel):
6+
7+
def __init__(self, x_train: torch.Tensor, y_train: torch.Tensor, sigma: float,
8+
input_transform: Normalize | Standardize | None = None, output_transform: Normalize | Standardize | None = None,
9+
basis: str = "gaussian"):
10+
11+
"""
12+
Class definition for parametric radial basis function models
13+
14+
f(x) = w^T psi(x), psi_i(x) = psi(r), where r = ||x - c_i||
15+
16+
NOTE: Only the Gaussian and multiquadric basis functions are implemented
17+
18+
Parameters
19+
----------
20+
xtrain: torch.Tensor
21+
Input training data for the RBF model with size (N, d)
22+
23+
ytrain: torch.Tensor
24+
Output training data for the RBF model with size (N, 1)
25+
26+
sigma: float
27+
Value of sigma for calculating the basis function
28+
29+
input_transform: Normalize | Standardize | None
30+
Object of a transformation class for applying a transform to the inputs of the model
31+
The default is None which means the application of no transform to inputs
32+
33+
output_transform: Normalize | Standardize | None
34+
Object of a transformation for applying a transform to the outputs of the model
35+
The default is None which means the application of no transform to outputs
36+
37+
basis: str
38+
String that specifies which parametric basis function should be used
39+
User has a choice between "gaussian", "multiquadric" and "inverse"
40+
41+
gaussian --> exp(-r^2/(2 * sigma^2))
42+
multiquadric --> (r^2 + sigma^2)^0.5
43+
inverse --> (r^2 + sigma^2)^(-0.5)
44+
"""
45+
46+
super().__init__()
47+
48+
# Checking inputs
49+
assert isinstance(x_train, torch.Tensor) and x_train.ndim == 2, "xtrain must be a 2D tensor array"
50+
assert isinstance(y_train, torch.Tensor) and y_train.ndim == 2, "ytrain must be a 2D tensor array"
51+
assert x_train.shape[0] == y_train.shape[0], "number of samples in input and output training data must be the same"
52+
assert x_train.device == y_train.device, "input and output training data must be on the same device"
53+
assert isinstance(sigma, float) and sigma > 0, "sigma value must be a positive floating point value"
54+
assert isinstance(basis, str), "basis choice must be a string"
55+
assert basis in ["gaussian", "multiquadric", "inverse"], "basis choice is not a valid choice. choice must be one of [gaussian, multiquadric, inverse]"
56+
57+
self.sigma = sigma
58+
self.basis = basis
59+
self.input_transform = input_transform
60+
self.output_transform = output_transform
61+
self.basis_function = self.set_basis_function(self.basis)
62+
63+
# Set training data for the model
64+
self.xtrain = self.transform_values(x_train, input_transform)
65+
self.ytrain = self.transform_values(y_train, output_transform)
66+
67+
def fit(self) -> torch.Tensor:
68+
69+
"""
70+
Method for fitting the model to the provided training data
71+
72+
Returns
73+
-------
74+
rbf_weights: torch.Tensor
75+
Tensor array with the fitted weights of the model
76+
77+
basis_matrix: torch.Tensor
78+
Tensor array containing the basis matrix for the model
79+
"""
80+
81+
# Basis matrix definition
82+
ns = self.xtrain.shape[0]
83+
basis_matrix = torch.zeros((ns, ns)).to(self.xtrain)
84+
85+
# Assigning values to the basis matrix
86+
for i in range(ns):
87+
for j in range(ns):
88+
# Calculating the value of r
89+
r = torch.linalg.norm(self.xtrain[i] - self.xtrain[j], ord=2)
90+
91+
# Putting value in basis matrix
92+
basis_matrix[i, j] = self.basis_function(r)
93+
94+
basis_matrix_inverse = torch.linalg.pinv(basis_matrix)
95+
self.rbf_weights = torch.matmul(basis_matrix_inverse, self.ytrain)
96+
97+
return self.rbf_weights.clone(), basis_matrix
98+
99+
def predict(self, x: torch.Tensor) -> torch.Tensor:
100+
101+
"""
102+
Method for to predict values from the model for specified points
103+
104+
Parameters
105+
----------
106+
x: torch.Tensor
107+
Data for which predictions must be made. Must be a 2D tensor array
108+
109+
Returns
110+
-------
111+
ypred: torch.Tensor
112+
Predictions of the model on the specified points
113+
"""
114+
# Checks for training and inputs
115+
assert hasattr(self, "rbf_weights"), "model has not been trained. call fit method before predict"
116+
assert isinstance(x, torch.Tensor) and x.ndim == 2, "input data provided to the model must be a 2D tensor array"
117+
assert x.device == self.xtrain.device, "provided input data must be on the same device as the training data"
118+
119+
x = self.transform_values(x, self.input_transform)
120+
121+
# Basis matrix definition
122+
basis_matrix = torch.zeros((x.shape[0], self.xtrain.shape[0])).to(self.xtrain)
123+
124+
# Assigning values to the basis matrix
125+
for i in range(x.shape[0]):
126+
for j in range(self.xtrain.shape[0]):
127+
r = torch.linalg.norm(x[i] - self.xtrain[j], ord=2)
128+
basis_matrix[i, j] = self.basis_function(r)
129+
130+
ypred = torch.matmul(basis_matrix, self.rbf_weights)
131+
if self.output_transform is not None:
132+
ypred = self.output_transform.inverse_transform(ypred)
133+
134+
return ypred
135+
136+
def set_basis_function(self, basis: str):
137+
138+
"""
139+
Method for setting the basis function of the RBF model based on user input
140+
141+
Parameters
142+
----------
143+
basis: str
144+
String that specifies the user choice for the basis function
145+
146+
Returns
147+
-------
148+
basis_function:
149+
Python function that calculates the values of the basis function based on the value of r
150+
"""
151+
if basis == "gaussian":
152+
basis_function = lambda r: torch.exp(-r**2/(2 * self.sigma ** 2))
153+
elif basis == "multiquadric":
154+
basis_function = lambda r: torch.sqrt(r**2 + self.sigma ** 2)
155+
elif basis == "inverse":
156+
basis_function = lambda r: 1 / torch.sqrt(r**2 + self.sigma ** 2)
157+
158+
return basis_function
159+
160+
def transform_values(self, data: torch.Tensor, transform: Normalize | Standardize | None) -> torch.Tensor:
161+
162+
"""
163+
Method for transforming values based on given transform
164+
165+
Parameters
166+
----------
167+
data: torch.Tensor
168+
Data that must be transformed
169+
170+
transform: Normalize | Standardize | None
171+
Object of a transformation class that will be used to transform the data
172+
173+
Returns
174+
-------
175+
transformed_data: torch.Tensor
176+
Data after the application of the transform
177+
If transform is None, then the original data is returned
178+
"""
179+
180+
# Checking if transform is None and applying the transform accordingly
181+
if transform is None:
182+
transformed_data = data
183+
else:
184+
assert isinstance(transform, Normalize) or isinstance(transform, Standardize), "transform must be an instance of Normalize or Standardize class"
185+
transformed_data = transform.transform(data)
186+
187+
return transformed_data

tests/test_rbf.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import unittest, torch, math
2+
from scimlstudio.utils import evaluate_scalar
3+
from scimlstudio.models import RBF
4+
from scimlstudio.utils import Normalize, Standardize
5+
from pyDOE3 import lhs, halton_sequence
6+
7+
# Defining the device and data type
8+
tkwargs = {"device": torch.device("cuda" if torch.cuda.is_available() else "cpu"), "dtype": torch.float64}
9+
10+
class TestRBF(unittest.TestCase):
11+
12+
"""
13+
Class defining the test cases for the RBF model
14+
"""
15+
16+
def test_rbf(self):
17+
18+
# Generating some data for the test cases
19+
# Function used for the test cases is a sinusoidal function (https://www.sfu.ca/~ssurjano/curretal88sin.html)
20+
x_train = torch.linspace(0, 1, 10, **tkwargs)
21+
y_train = torch.sin(2*math.pi*(x_train - 0.1))
22+
23+
x_test = torch.linspace(0, 1, 100, **tkwargs)
24+
y_test = torch.sin(2*math.pi*(x_test - 0.1))
25+
26+
# Creating and fitting RBF models
27+
for basis in ["gaussian", "multiquadric", "inverse"]:
28+
rbf = RBF(x_train.reshape(-1,1), y_train.reshape(-1,1), sigma = 0.1, basis = basis)
29+
_, basis_matrix = rbf.fit()
30+
train_pred = rbf.predict(x_train.reshape(-1,1))
31+
test_pred = rbf.predict(x_test.reshape(-1,1))
32+
r2_value = evaluate_scalar(test_pred.reshape(-1,), y_test, "r2")
33+
34+
torch.testing.assert_close(train_pred.reshape(-1,), y_train, rtol=0, atol=1e-6, check_device=True, check_dtype=True) # interpolation check
35+
assert basis_matrix.ndim == 2 # basis matrix must be 2D
36+
if basis == "gaussian":
37+
assert (torch.diag(basis_matrix) == 1).all() # diagonal of basis matrix must be all 1s if basis is gaussian
38+
assert round(r2_value, 6) < 1
39+
40+
# Creating and fitting RBF models with normalize transformations
41+
input_transform = Normalize(x_train.reshape(-1,1))
42+
output_transform = Normalize(y_train.reshape(-1,1))
43+
for basis in ["gaussian", "multiquadric", "inverse"]:
44+
rbf = RBF(x_train.reshape(-1,1), y_train.reshape(-1,1), sigma = 0.1, basis = basis, input_transform=input_transform, output_transform=output_transform)
45+
_, basis_matrix = rbf.fit()
46+
train_pred = rbf.predict(x_train.reshape(-1,1))
47+
test_pred = rbf.predict(x_test.reshape(-1,1))
48+
r2_value = evaluate_scalar(test_pred.reshape(-1,), y_test, "r2")
49+
50+
torch.testing.assert_close(train_pred.reshape(-1,), y_train, rtol=0, atol=1e-6, check_device=True, check_dtype=True) # interpolation check
51+
assert basis_matrix.ndim == 2 # basis matrix must be 2D
52+
if basis == "gaussian":
53+
assert (torch.diag(basis_matrix) == 1).all() # diagonal of basis matrix must be all 1s if basis is gaussian
54+
assert round(r2_value, 6) < 1
55+
56+
def test_rbf_5D(self):
57+
58+
# Generating some data for the test cases
59+
# Function used for the test cases is the Friedman function (https://www.sfu.ca/~ssurjano/fried.html)
60+
x_train = torch.tensor(halton_sequence(num_points=15, dimension=5), **tkwargs)
61+
y_train = 10*torch.sin(math.pi*x_train[:,0]*x_train[:,1]) + 20 * (x_train[:,2] - 0.5) ** 2 + 10 * x_train[:,3] + 5 * x_train[:,4]
62+
63+
x_test = torch.tensor(lhs(n=5, samples=100, criterion='cm', iterations=100, seed=10), **tkwargs)
64+
y_test = 10*torch.sin(math.pi*x_test[:,0]*x_test[:,1]) + 20 * (x_test[:,2] - 0.5) ** 2 + 10 * x_test[:,3] + 5 * x_test[:,4]
65+
66+
# Creating and fitting RBF models
67+
for basis in ["gaussian", "multiquadric", "inverse"]:
68+
rbf = RBF(x_train, y_train.reshape(-1,1), sigma = 0.1, basis = basis)
69+
_, basis_matrix = rbf.fit()
70+
train_pred = rbf.predict(x_train)
71+
test_pred = rbf.predict(x_test)
72+
r2_value = evaluate_scalar(test_pred.reshape(-1,), y_test, "r2")
73+
74+
torch.testing.assert_close(train_pred.reshape(-1,), y_train, rtol=0, atol=1e-6, check_device=True, check_dtype=True) # interpolation check
75+
assert basis_matrix.ndim == 2 # basis matrix must be 2D
76+
if basis == "gaussian":
77+
assert (torch.diag(basis_matrix) == 1).all() # diagonal of basis matrix must be all 1s if basis is gaussian
78+
assert round(r2_value, 6) < 1
79+
80+
# Creating and fitting RBF models with standardize transformations
81+
input_transform = Standardize(x_train)
82+
output_transform = Standardize(y_train.reshape(-1,1))
83+
for basis in ["gaussian", "multiquadric", "inverse"]:
84+
rbf = RBF(x_train, y_train.reshape(-1,1), sigma = 0.1, basis = basis, input_transform=input_transform, output_transform=output_transform)
85+
_, basis_matrix = rbf.fit()
86+
train_pred = rbf.predict(x_train)
87+
test_pred = rbf.predict(x_test)
88+
r2_value = evaluate_scalar(test_pred.reshape(-1,), y_test, "r2")
89+
90+
torch.testing.assert_close(train_pred.reshape(-1,), y_train, rtol=0, atol=1e-6, check_device=True, check_dtype=True) # interpolation check
91+
assert basis_matrix.ndim == 2 # basis matrix must be 2D
92+
if basis == "gaussian":
93+
assert (torch.diag(basis_matrix) == 1).all() # diagonal of basis matrix must be all 1s if basis is gaussian
94+
assert round(r2_value, 6) < 1
95+
96+
def test_inputs(self):
97+
98+
# Generate dummy data
99+
x_random = torch.rand(15)
100+
y_random = torch.rand(15)
101+
102+
with self.assertRaises(Exception):
103+
_ = RBF(x_random.reshape(-1,1), y_random, sigma = 0.1)
104+
105+
with self.assertRaises(Exception):
106+
_ = RBF(x_random, y_random.reshape(-1,1), sigma = 0.1)
107+
108+
with self.assertRaises(Exception):
109+
_ = RBF(x_random.reshape(-1,1), y_random.reshape(-1,1), sigma = -0.1)
110+
111+
with self.assertRaises(Exception):
112+
_ = RBF(x_random.reshape(-1,1), y_random.reshape(-1,1), sigma = 0.1, basis="linear")
113+
114+
_ = RBF(x_random.reshape(-1,1), y_random.reshape(-1,1), sigma = 0.1, basis="inverse")
115+
116+
if __name__ == '__main__':
117+
unittest.main()
118+

0 commit comments

Comments
 (0)