-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnn_model.py
More file actions
124 lines (103 loc) · 4.49 KB
/
nn_model.py
File metadata and controls
124 lines (103 loc) · 4.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
from torch import nn, Tensor
from torch.nn import ModuleList
import torch.nn.functional as F
from torch_geometric.nn import Linear, APPNP, SplineConv, GATConv
from torch_geometric.typing import Adj, OptTensor
# Registry mapping activation names to activation modules.
# NOTE(review): these are module-level singleton instances. Most entries are
# stateless, but nn.PReLU carries a learnable weight, so every model handed
# the same 'prelu' entry would share that parameter — confirm this is intended.
ACTIVATION_FUNCTIONS = {
    'relu': nn.ReLU(),
    'leakyrelu': nn.LeakyReLU(negative_slope=0.2),
    'I': nn.LeakyReLU(negative_slope=1),  # slope 1 -> identity, hence the name 'I'
    'elu': nn.ELU(),
    'tanh': nn.Tanh(),
    'prelu': nn.PReLU()
}
def get_activation(activation):
    """Look up an activation by name and return a fresh module instance.

    A deep copy of the registry entry is returned so that stateful
    activations are not shared between models: nn.PReLU has a learnable
    weight, and the original code handed every caller the very same
    module instance, silently tying that parameter across models.

    Args:
        activation: name of the activation; must be a key of
            ACTIVATION_FUNCTIONS.

    Returns:
        A new nn.Module implementing the requested activation.

    Raises:
        ValueError: if `activation` is not a supported name.
    """
    import copy  # local import keeps this fix self-contained
    if activation not in ACTIVATION_FUNCTIONS:
        raise ValueError(f'Activation {activation} is not a supported activation function. Supported activation: {", ".join(ACTIVATION_FUNCTIONS.keys())}' )
    return copy.deepcopy(ACTIVATION_FUNCTIONS[activation])
def build_multi_layers(model, in_channels, out_channels, num_layers, n_units,
                       **kwargs):
    """Stack `num_layers` instances of `model` into a ModuleList.

    Layer 0 maps in_channels -> n_units[0], hidden layer i maps
    n_units[i-1] -> n_units[i], and the final layer maps the last hidden
    width to out_channels. For GATConv, the input width of every layer
    after the first is multiplied by kwargs['heads'], since GATConv
    concatenates its attention heads by default.

    Bug fixed: the original loop ran `range(1, num_layers-1)` with the
    output layer emitted at `i == num_layers-2`, which produced only
    `num_layers - 1` layers — and for num_layers == 2 it produced no
    projection to out_channels at all.

    Args:
        model: layer constructor (e.g. Linear, SplineConv, GATConv).
        in_channels: input feature dimension.
        out_channels: output feature dimension.
        num_layers: total number of layers to build (>= 1).
        n_units: hidden widths; entries 0..num_layers-2 are used.
        **kwargs: extra keyword arguments forwarded to every `model` call.

    Returns:
        ModuleList containing exactly `num_layers` layers.
    """
    if num_layers == 1:
        # Single layer: map input straight to output.
        return ModuleList([model(in_channels, out_channels, **kwargs)])
    model_list = ModuleList([model(in_channels, n_units[0], **kwargs)])
    for i in range(1, num_layers):
        # GATConv concatenates heads, so the effective input width of
        # every subsequent layer is the previous width times `heads`.
        width_in = n_units[i-1] * kwargs['heads'] if model == GATConv else n_units[i-1]
        width_out = out_channels if i == num_layers - 1 else n_units[i]
        model_list.append(model(width_in, width_out, **kwargs))
    return model_list
class APPNPModel(nn.Module):
    """Linear feature-transform stack followed by APPNP propagation.

    Takes a single dict of hyper-parameters; every key becomes an
    instance attribute (expects at least: in_channels, out_channels,
    num_layers, n_units, K, alpha, dropout, activation).
    """

    def __init__(self, kwargs):
        super().__init__()
        # Promote every hyper-parameter to an attribute.
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.model_list = build_multi_layers(
            model=Linear,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            num_layers=self.num_layers,
            n_units=self.n_units,
        )
        self.prop = APPNP(K=self.K, alpha=self.alpha)
        # Replace the activation *name* with the actual module.
        self.activation = get_activation(self.activation)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every linear layer in the stack."""
        for layer in self.model_list:
            layer.reset_parameters()

    def forward(self,
                x: Tensor,
                edge_index: Adj,
                edge_attr : OptTensor = None) -> Tensor:
        """Apply dropout + linear + activation per layer, then propagate."""
        for linear in self.model_list:
            h = F.dropout(x, p=self.dropout, training=self.training)
            x = self.activation(linear(h))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.prop(x, edge_index)
class SplineconvModel(nn.Module):
    """Stack of SplineConv layers, each followed by dropout and activation.

    Takes a single dict of hyper-parameters; every key becomes an
    instance attribute (expects at least: in_channels, out_channels,
    num_layers, n_units, kernel_size, dropout, activation).
    """

    def __init__(self, kwargs):
        super().__init__()
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.model_list = build_multi_layers(
            SplineConv, self.in_channels, self.out_channels,
            self.num_layers, self.n_units,
            dim=1, kernel_size=self.kernel_size,
        )
        self.activation = get_activation(self.activation)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every convolution in the stack."""
        for layer in self.model_list:
            layer.reset_parameters()

    def forward(self,
                x: Tensor,
                edge_index: Adj,
                edge_attr : OptTensor = None) -> Tensor:
        """Run each conv with edge attributes, then dropout and activation."""
        for conv in self.model_list:
            h = conv(x, edge_index, edge_attr)
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = self.activation(h)
        return x
class GATModel(nn.Module):
    """Stack of GATConv layers; activation on all but the last layer.

    Takes a single dict of hyper-parameters; every key becomes an
    instance attribute (expects at least: in_channels, out_channels,
    num_layers, n_units, heads, dropout, activation).
    """

    def __init__(self, kwargs):
        super().__init__()
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.model_list = build_multi_layers(
            GATConv, self.in_channels, self.out_channels,
            self.num_layers, self.n_units,
            heads=self.heads, dropout=self.dropout,
        )
        self.activation = get_activation(self.activation)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every attention layer in the stack."""
        for layer in self.model_list:
            layer.reset_parameters()

    def forward(self,
                x: Tensor,
                edge_index: Adj,
                edge_attr: OptTensor = None) -> Tensor:
        """Apply each GAT layer; skip the activation on the final one."""
        last = len(self.model_list) - 1
        for i, conv in enumerate(self.model_list):
            x = conv(x, edge_index, edge_attr)
            if i < last:
                x = self.activation(x)
        return x