-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathPredictionUI_MNIST_Conv.py
More file actions
108 lines (80 loc) · 3.24 KB
/
PredictionUI_MNIST_Conv.py
File metadata and controls
108 lines (80 loc) · 3.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import gradio as gr
# DEFINE MODEL (required for loading weights)
class ConvNet(nn.Module):
    """Small CNN classifier: two 3x3 convolutions, a 2x2 max-pool, two linear layers.

    Layer names and shapes must match the saved checkpoint, because the weights
    are restored with ``load_state_dict(..., strict=True)``.

    Args:
        input_channels: Number of channels in the input images (1 for grayscale).
        num_classes: Number of output logits.
        image_size: Height/width of the square input images. The default of 28
            reproduces the original ``fc1`` input size of 9216 (64 * 12 * 12).

    Raises:
        ValueError: If ``image_size`` is too small to survive the two unpadded
            3x3 convolutions followed by the 2x2 max-pool.
    """

    def __init__(self, input_channels, num_classes, image_size=28):
        super().__init__()
        # Spatial size after two unpadded 3x3 convs (-2 each) and a 2x2 max-pool (//2).
        pooled_size = (image_size - 4) // 2
        if pooled_size < 1:
            raise ValueError(f'image_size must be at least 6, got {image_size}')
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=32, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.max_pool = nn.MaxPool2d(kernel_size=2)
        self.relu = nn.ReLU()
        # Flattened conv features: 64 channels * pooled_size**2 (9216 for 28x28 input).
        self.fc1 = nn.Linear(64 * pooled_size * pooled_size, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Image batch of shape (N, input_channels, image_size, image_size).

        Returns:
            Logits of shape (N, num_classes).
        """
        #######################
        # Convolutional Part (shapes shown for the default 1x28x28 input)
        #######################
        x = self.conv1(x)  # (N, 1, 28, 28) -> (N, 32, 26, 26)
        x = self.relu(x)  # no dim change
        x = self.conv2(x)  # (N, 32, 26, 26) -> (N, 64, 24, 24)
        x = self.relu(x)  # no dim change
        x = self.max_pool(x)  # (N, 64, 24, 24) -> (N, 64, 12, 12)
        #######################
        ## Fully Connected Part
        #######################
        x = torch.flatten(x, 1)  # (N, 64, 12, 12) -> (N, 64*12*12) = (N, 9216)
        x = self.fc1(x)  # (N, 9216) -> (N, 128)
        x = self.relu(x)  # no dim change
        logits = self.fc2(x)  # (N, 128) -> (N, num_classes)
        return logits
# LOAD PRE-TRAINED MODEL
model = ConvNet(
    input_channels=1,  # 1 for grayscale images
    num_classes=10,  # 10 for MNIST
)
# map_location='cpu' lets a checkpoint that was saved from a GPU run load on a
# CPU-only machine; the model is never moved to another device in this script.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load('convnet_mnist_checkpoint.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'], strict=True)
# SWITCH MODEL TO PREDICTION ONLY MODE
model.eval()
# Preprocessing; this must be the same transform pipeline that was used in training.
# ToTensor converts the image to a float tensor; Normalize then computes
# (x - 0.5) / 0.5 on the single channel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# Function for processing input image
# Since we're only interested in prediction, we disable the gradient computations
@torch.no_grad()
def recognize_digit(image):
    """Classify a hand-drawn digit and return per-class probabilities.

    Args:
        image: The sketchpad drawing as handed over by Gradio. Presumably a
            28x28 grayscale array (TODO confirm against the Gradio version in
            use), or ``None`` when the user submits without drawing.

    Returns:
        Dict mapping each class label as a string (``'0'`` to ``'9'`` for this
        10-class model) to its softmax probability. Returns an empty dict when
        ``image`` is ``None``.
    """
    if image is None:
        # Nothing was drawn: return an empty result instead of crashing in the transform.
        return {}
    image_tensor = transform(image)  # (1, 28, 28)
    image_tensor = image_tensor.unsqueeze(0)  # add dummy batch dimension -> (1, 1, 28, 28)
    logits = model(image_tensor)
    probs = F.softmax(logits, dim=1)[0]  # convert to probabilities; take the only batch item
    return {str(i): p for i, p in enumerate(probs.tolist())}
# Output widget: shows the three most probable classes with their scores.
# NOTE(review): the `gr.outputs` namespace appears to be deprecated in Gradio 3
# and removed in Gradio 4 (replacement: `gr.Label`) -- confirm the pinned version.
output_labels = gr.outputs.Label(num_top_classes=3)

# Main UI: a drawing sketchpad feeds recognize_digit, and the
# label -> probability dict it returns is rendered by output_labels.
interface = gr.Interface(
    title='MNIST Drawing Application (ConvNet)',
    description='Draw a number 0 through 9 on the sketchpad, and click submit to see the model predictions.',
    fn=recognize_digit,
    inputs='sketchpad',
    outputs=output_labels,
)
# Entry point: serve the app only when this file is run directly. Importing the
# module still builds the model and the interface, but starts no server.
if "__main__" == __name__:
    interface.launch()