-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.cpp
More file actions
70 lines (58 loc) · 2.27 KB
/
main.cpp
File metadata and controls
70 lines (58 loc) · 2.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <algorithm>
#include <cstddef>
#include <exception>
#include <iostream>
#include <random>
#include <stdexcept>
#include <thread>
#include <vector>

#include "mnist_loader.h"
#include "Matrix.h"
#include "NeuralNetwork.h"
/// Builds a one-hot target column for a class label.
///
/// @param lbl          Class index to mark; must satisfy 0 <= lbl < num_classes.
/// @param num_classes  Number of rows of the returned (num_classes x 1) Matrix.
///                     Defaults to 10 (the digit classes used by main()).
/// @return A num_classes x 1 Matrix whose entry `vals[lbl]` is 1.0; all other
///         entries are whatever Matrix::zeroes() leaves (presumably 0 — confirm).
/// @throws std::out_of_range if `lbl` is outside [0, num_classes).
Matrix one_hot(int lbl, int num_classes = 10) {
    // Without this check a bad label (e.g. from a corrupt label file) would
    // write past the end of `vals`.
    if (lbl < 0 || lbl >= num_classes) {
        throw std::out_of_range("one_hot: label out of range");
    }
    Matrix result(num_classes, 1);
    result.zeroes();
    result.vals[lbl] = 1.0;
    return result;
}
int main() {
// Load training and test data
auto train_data = load_images("data/train-images.idx3-ubyte", "data/train-labels.idx1-ubyte");
auto test_data = load_images("data/t10k-images.idx3-ubyte", "data/t10k-labels.idx1-ubyte");
NeuralNetwork model;
int batch_size = 100;
float lr = 0.01;
int num_threads = 4;
int chunk_size = batch_size / num_threads;
int seed = 32;
for (int epoch = 0; epoch < 10; ++epoch) {
std::shuffle(train_data.begin(), train_data.end(), std::default_random_engine(seed + epoch));
for (int i = 0; i + batch_size <= train_data.size(); i += batch_size) {
std::vector<NeuralNetwork> thread_models(num_threads, model);
std::vector<std::thread> threads;
for (int t = 0; t < num_threads; ++t) {
threads.emplace_back([&, t]() {
int start = i + t * chunk_size;
int end = start + chunk_size;
for (int j = start; j < end; ++j) {
Matrix pred = thread_models[t].forward(train_data[j].img);
Matrix dz = pred - one_hot(train_data[j].label);
thread_models[t].backward(dz);
}
});
}
for (auto& thread : threads) thread.join();
model.zero_gradients();
for (int t = 0; t < num_threads; ++t) {
model.accumulate_gradients(thread_models[t]);
}
model.apply(batch_size, lr);
}
// Evaluate accuracy after each epoch
int correct = 0;
for (int j = 0; j < test_data.size(); ++j) {
Matrix pred = model.forward(test_data[j].img);
if (pred.argmax() == test_data[j].label) correct++;
}
float acc = 100.0f * correct / test_data.size();
std::cout << "Epoch " << epoch + 1 << " Accuracy: " << acc << "%" << std::endl;
}
std::cout << "Training Complete & Testing Complete!" << std::endl;
return 0;
}