-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNeuralNetwork.cpp
More file actions
39 lines (26 loc) · 876 Bytes
/
NeuralNetwork.cpp
File metadata and controls
39 lines (26 loc) · 876 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include "NeuralNetwork.h"
// Default constructor: builds the two layers with fixed arguments,
// l1 from (784, 128) and l2 from (128, 10).
// NOTE(review): the arguments are presumably (input size, output size) of
// each DenseLayer — confirm against the DenseLayer constructor.
NeuralNetwork::NeuralNetwork()
    : l1(784, 128),
      l2(128, 10) {}
// Forward pass through the network.
//
// The input is fed to l1.forward(), the result goes through ReLU(), that is
// fed to l2.forward(), and softmax() of the second layer's result is returned.
//
// @param input  Matrix handed to the first layer.
// @return       The value produced by softmax() on the second layer's output.
Matrix NeuralNetwork::forward(const Matrix &input) {
    Matrix hidden_pre = l1.forward(input);
    Matrix hidden_act = ReLU(hidden_pre);
    Matrix logits = l2.forward(hidden_act);
    return softmax(logits);
}
void NeuralNetwork::backward(const Matrix &dz) {
Matrix dz1 = l2.backward(dz);
Matrix relu_grad = elementwise(dz1, deriv_ReLU(l1.out()));
Matrix dz0 = l1.backward(relu_grad);
}
void NeuralNetwork::apply(float batch_size, float lr) {
l1.apply(batch_size, lr);
l2.apply(batch_size, lr);
}
void NeuralNetwork::zero_gradients() {
l1.zero_gradients();
l2.zero_gradients();
}
void NeuralNetwork::accumulate_gradients(const NeuralNetwork& other) {
l1.accumulate_gradients(other.l1);
l2.accumulate_gradients(other.l2);
}