-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhello_caffe2.py
More file actions
55 lines (42 loc) · 1.61 KB
/
hello_caffe2.py
File metadata and controls
55 lines (42 loc) · 1.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
from caffe2.python import cnn, workspace
from caffe2.python import net_drawer
def gen_graph(net, output, rankdir='LR'):
    """Render a Caffe2 net as a graph and write it to a PNG file.

    Args:
        net: The net to draw (e.g. ``m.net`` or ``m.param_init_net``);
            passed unchanged to ``net_drawer.GetPydotGraph``.
        output: Path of the PNG file to write.
        rankdir: Layout direction forwarded to ``GetPydotGraph``. The
            default ``'LR'`` matches the previously hard-coded value.
    """
    graph = net_drawer.GetPydotGraph(net, rankdir=rankdir)
    # NOTE(review): GetPydotGraph presumably returns a pydot graph; pydot's
    # write_png shells out to the Graphviz `dot` executable, which must be
    # installed for this call to succeed.
    graph.write_png(output)
def _feed_random_batch(batch_size=16, input_dim=100, num_classes=10):
    """Feed a fresh random batch into the workspace as 'data' and 'label'.

    'data' is a (batch_size, input_dim) float32 array with values in [0, 1).
    'label' holds batch_size int32 class ids in [0, num_classes - 1].
    """
    data = np.random.rand(batch_size, input_dim).astype(np.float32)
    # Scaling [0, 1) by num_classes and truncating toward zero yields
    # integers in [0, num_classes - 1].
    label = (np.random.rand(batch_size) * num_classes).astype(np.int32)
    workspace.FeedBlob('data', data)
    workspace.FeedBlob('label', label)


# Feed an initial batch so 'data' and 'label' exist in the workspace before
# the nets are created and run below.
_feed_random_batch()

# Create a model using the model helper
m = cnn.CNNModelHelper(name='hello caffe2')
fc_1 = m.FC('data', 'fc1', dim_in=100, dim_out=10)
pred = m.Sigmoid(fc_1, 'pred')  # or, by blob name: m.Sigmoid('fc1', 'pred')
softmax, loss = m.SoftmaxWithLoss([pred, 'label'], ['softmax', 'loss'])
# NOTE: this adds backward (gradient) ops only. No parameter-update
# (optimizer) ops are added anywhere in this script, so running m.net
# computes gradients but presumably leaves the fc1 parameters unchanged.
# Add an update step (e.g. SGD) to actually train.
m.AddGradientOperators([loss])

# Debug the nets by printing their protobuf definitions
print(m.net.Proto())
print(m.param_init_net.Proto())

# Debug the nets by visualizing them
gen_graph(m.net, 'net.png')
gen_graph(m.param_init_net, 'param_init_net.png')

# Initialize the parameters
workspace.RunNetOnce(m.param_init_net)
# You can print them out if you are interested:
# print(workspace.FetchBlob('fc1_b'))
# print(workspace.FetchBlob('fc1_w'))

# Instantiate the main net once so it can be run repeatedly by name.
workspace.CreateNet(m.net)
# Use the net's registered name instead of m.name. NOTE(review): core.Net
# may de-duplicate names (e.g. when re-run in the same interpreter), in which
# case m.name would not match; confirm against the Caffe2 version in use.
net_name = m.net.Proto().name

# 100 outer steps x 10 runs each = 1000 forward/backward passes in total.
# Each group of 10 runs reuses the current batch; a new batch is fed after
# every group.
for _ in range(100):
    workspace.RunNet(net_name, 10)
    _feed_random_batch()
    # You can print them out if you are interested:
    # print(workspace.FetchBlob('softmax'))
    # print(workspace.FetchBlob('loss'))