-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmodelfactory.py
More file actions
59 lines (43 loc) · 1.7 KB
/
modelfactory.py
File metadata and controls
59 lines (43 loc) · 1.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import models
import torch.nn as nn
def ModelFactory(args):
    """Build the network selected by ``args.model`` for the dataset ``args.task``.

    The function first writes task-specific presets onto ``args`` (it mutates
    the object it is given) and then instantiates the requested architecture
    from the project-local ``models`` module.

    Args:
        args: Attribute container (presumably an ``argparse.Namespace`` --
            confirm against the caller). Attributes read:

            * ``task``: dataset key. Known keys get presets written back to
              ``args``; any other value leaves ``args`` untouched, so
              ``num_classes`` (and ``depth``/``size`` for ``'convnet'``) must
              then already be present.
            * ``model``: architecture key, see "Raises" for the valid values.
            * ``img_width``: read only when ``task == 'cifar100'``.

    Returns:
        Whatever the selected ``models`` constructor returns.

    Raises:
        ValueError: If ``args.model`` is not one of the supported keys.
            (Previously this surfaced as an ``UnboundLocalError`` on the
            final ``return``.)
    """
    # Presets written onto ``args`` per task. 'imagenet' and 'imagenet100'
    # deliberately set only ``num_classes``, exactly as before.
    task_presets = {
        'mnist': (('depth', 1), ('num_classes', 10)),
        'cifar10': (('depth', 3), ('num_classes', 10)),
        'cifar100': (('depth', 3), ('num_classes', 100)),
        'tiny': (('depth', 4), ('num_classes', 200), ('size', 64)),
        'imagenet': (('num_classes', 1000),),
        'imagenet100': (('num_classes', 100),),
    }
    for attr_name, value in task_presets.get(args.task, ()):
        setattr(args, attr_name, value)

    if args.task == 'cifar100':
        # NOTE(review): the original indentation of this assignment was lost;
        # it is kept inside the cifar100 branch, mirroring how 'tiny' sets
        # ``size`` right after depth/num_classes. Confirm it was not meant as
        # a default for every task.
        args.size = args.img_width

    # Architectures whose constructor takes only ``num_classes``. Values are
    # attribute names on ``models`` and are resolved lazily, so only the
    # selected constructor is ever looked up.
    num_classes_only = {
        'resnet20': 'resnet20',
        'resnet18': 'resnet18',
        'efficient': 'efficientnet_b0',
        'resnet50': 'resnet50',
        'resnet56': 'resnet56',
        'densenet121': 'densenet121',
        'alexnet': 'AlexNet',
    }

    model_key = args.model

    if model_key in num_classes_only:
        constructor = getattr(models, num_classes_only[model_key])
        return constructor(num_classes=args.num_classes)

    if model_key == 'lenet':
        # MNIST gets its own LeNet variant; every other task uses the
        # generic one.
        if args.task == 'mnist':
            return models.LeNetMNIST(num_classes=args.num_classes)
        return models.LeNet(num_classes=args.num_classes)

    if model_key == 'convnet':
        # NOTE(review): the leading positional 3 is hard-coded for every
        # task (presumably the input channel count -- confirm it is right
        # for 'mnist'). ``args.size`` is only assigned here for 'cifar100'
        # and 'tiny'; other tasks must supply it themselves.
        return models.ConvNet(
            3,
            args.num_classes,
            128,
            args.depth,
            'relu',
            'instancenorm',
            'avgpooling',
            im_size=(args.size, args.size),
        )

    supported = sorted(list(num_classes_only) + ['lenet', 'convnet'])
    raise ValueError(
        "Unsupported model %r; expected one of: %s"
        % (model_key, ', '.join(supported))
    )