-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathextractor.py
More file actions
64 lines (52 loc) · 1.91 KB
/
extractor.py
File metadata and controls
64 lines (52 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Module imports and local-path setup for the Bacteria_TL project.
import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable  # NOTE(review): Variable is deprecated since PyTorch 0.4 — plain tensors carry autograd now
os.chdir(os.getcwd())  # NOTE(review): no-op — changing directory to the current directory has no effect
models_dir = 'Bacteria_TL'  # relative path to the directory holding the local resnet module
sys.path.append(models_dir)  # make `resnet` importable; assumes the script runs from the repo root — TODO confirm
from resnet import ResidualBlock  # project-local 1-D residual block used by Extractor
class Extractor(nn.Module):
    """1-D convolutional ResNet-style feature extractor (no linear head).

    Pipeline: Conv1d(1 -> in_channels, k=5, same-length padding) + BatchNorm +
    ReLU stem, followed by a configurable stack of residual stages built from
    the project-local ``ResidualBlock``. ``forward`` returns the raw feature
    map; no flattening or linear projection is applied.

    Args:
        hidden_sizes: output channel count of each residual stage.
        num_blocks: number of ResidualBlocks in each stage; must have the
            same length as ``hidden_sizes``.
        input_dim: length of the 1-D input signal (default 1174); used only
            to probe the encoder output size for ``self.z_dim``.
        in_channels: channel count produced by the stem convolution.
    """

    def __init__(self, hidden_sizes, num_blocks, input_dim=1174, in_channels=64):
        super(Extractor, self).__init__()
        assert len(num_blocks) == len(hidden_sizes)  # one block count per stage
        self.input_dim = input_dim
        self.in_channels = in_channels
        # Stem: widen a single input channel to in_channels; padding=2 with
        # kernel_size=5 preserves the sequence length.
        self.conv1 = nn.Conv1d(1, self.in_channels, kernel_size=5, stride=1,
                               padding=2, bias=False)
        self.bn1 = nn.BatchNorm1d(self.in_channels)
        # Flexible number of residual encoding stages: the first keeps
        # resolution (stride 1), each subsequent stage halves it (stride 2).
        layers = []
        strides = [1] + [2] * (len(hidden_sizes) - 1)
        for idx, hidden_size in enumerate(hidden_sizes):
            layers.append(self._make_layer(hidden_size, num_blocks[idx],
                                           stride=strides[idx]))
        self.encoder = nn.Sequential(*layers)
        self.z_dim = self._get_encoding_size()
        # No linear layer by design — downstream code consumes the feature map.

    def encode(self, x):
        """Run the stem and residual stages; return the encoder feature map."""
        x = F.relu(self.bn1(self.conv1(x)))
        return self.encoder(x)

    def forward(self, x):
        """Alias for :meth:`encode`; the module has no classification head."""
        return self.encode(x)

    def _make_layer(self, out_channels, num_blocks, stride=1):
        """Build one residual stage; only its first block may downsample."""
        blocks = []
        # Renamed loop variable so it no longer shadows the `stride` parameter.
        for block_stride in [stride] + [1] * (num_blocks - 1):
            blocks.append(ResidualBlock(self.in_channels, out_channels,
                                        stride=block_stride))
            self.in_channels = out_channels  # the next block consumes this width
        return nn.Sequential(*blocks)

    def _get_encoding_size(self):
        """Probe the encoder with a dummy input and return its channel count.

        NOTE(review): this returns ``size(1)`` (channels), not the flattened
        feature count — the commented-out flatten in the original suggested a
        deliberate change; confirm consumers of ``z_dim`` expect channels.
        """
        # torch.autograd.Variable is deprecated; a plain tensor under
        # no_grad() yields the same shape without building an autograd graph.
        with torch.no_grad():
            probe = torch.rand(1, 1, self.input_dim)
            return self.encode(probe).size(1)