-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain_unstructured.py
More file actions
104 lines (86 loc) · 4.17 KB
/
train_unstructured.py
File metadata and controls
104 lines (86 loc) · 4.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import argparse
import os
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from datasets import get_dataset, DATASETS
from architectures_unstructured import ARCHITECTURES, get_architecture
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import StepLR, MultiStepLR
import time
import datetime
from train_utils import AverageMeter, accuracy, init_logfile, log
from utils import train, test
# Command-line interface. It is parsed once, at import time, into the
# module-level `args`. main() reads `args`, and also passes it whole to
# get_architecture().
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('dataset', type=str, choices=DATASETS)
parser.add_argument('arch', type=str, choices=ARCHITECTURES)
parser.add_argument('outdir', type=str, help='folder to save model and training log')
parser.add_argument('--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=160, type=int, metavar='N',
                    help='number of total epochs to run (default: 160)')
parser.add_argument('--batch', default=256, type=int, metavar='N',
                    help='batchsize (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    help='initial learning rate', dest='lr')
# main() uses MultiStepLR, so this StepLR setting is currently not consumed.
parser.add_argument('--lr_step_size', type=int, default=30,
                    help='How often to decrease learning rate by gamma '
                         '(StepLR interval; not used by the MultiStepLR schedule in main).')
# type/nargs make `--lr_milestones 60 90` parse as [60, 90]. Without them the
# CLI value was a single string. main() currently derives its milestones
# from --epochs instead of reading this value.
parser.add_argument('--lr_milestones', type=int, nargs='+', default=[80, 120],
                    help='epochs at which to decay the LR (default: 80 120; '
                         'not read by main, which derives milestones from --epochs)')
parser.add_argument('--gamma', type=float, default=0.1,
                    help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 5e-4)')
parser.add_argument('--noise_sd', default=0.0, type=float,
                    help="standard deviation of Gaussian noise for weight augmentation")
parser.add_argument('--gpu', default=0, type=int,
                    help='CUDA device index passed to torch.cuda.set_device (default: 0)')
parser.add_argument('--print-freq', default=200, type=int,
                    metavar='N', help='print frequency (default: 200)')
parser.add_argument('--stride', type=int, default=1, help='conv1 stride')
args = parser.parse_args()
def main():
    """Train ``args.arch`` on ``args.dataset`` and keep the best checkpoint.

    Reads the module-level ``args``. Each epoch runs one training pass and one
    evaluation pass on the test split.

    Whenever test top-1 accuracy strictly improves, the architecture name,
    model state and optimizer state are written to
    ``<outdir>/best_checkpoint.pth.tar``.

    The learning rate is multiplied by ``args.gamma`` at 1/2 and 3/4 of
    ``args.epochs``.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.outdir, exist_ok=True)

    device = torch.device("cuda")
    # Make args.gpu the current CUDA device, so the bare "cuda" device above
    # refers to it.
    torch.cuda.set_device(args.gpu)

    train_dataset = get_dataset(args.dataset, 'train')
    test_dataset = get_dataset(args.dataset, 'test')
    # Pinned host memory is only enabled for ImageNet.
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch,
                              num_workers=args.workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch,
                             num_workers=args.workers, pin_memory=pin_memory)

    model = get_architecture(args.arch, args.dataset, device, args)
    # Freeze every parameter whose name contains 'alpha'. Such parameters stay
    # in the optimizer's param list but receive no gradients.
    for name, param in model.named_parameters():
        if 'alpha' in name:
            param.requires_grad = False

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                    weight_decay=args.weight_decay)
    # Decay at 50% and 75% of training. Pass --gamma explicitly; otherwise
    # MultiStepLR falls back to its own default and the CLI flag is ignored.
    milestones = [args.epochs // 2, 3 * args.epochs // 4]
    scheduler = MultiStepLR(optimizer, milestones=milestones,
                            gamma=args.gamma, last_epoch=-1)

    best_top1 = 0.
    checkpoint_path = os.path.join(args.outdir, 'best_checkpoint.pth.tar')
    for epoch in range(args.epochs):
        # training
        train(train_loader, model, criterion, optimizer, epoch, device)
        # validation; cur_step is the number of training batches processed so far
        cur_step = (epoch + 1) * len(train_loader)
        _, top1, _ = test(test_loader, model, criterion, device, cur_step)
        scheduler.step()
        # Save only on strict improvement; ties keep the earlier checkpoint.
        if top1 > best_top1:
            best_top1 = top1
            torch.save({
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, checkpoint_path)
    print("")
    # NOTE(review): presumably test() reports top-1 as a percentage (0-100).
    # The /100 before the {:.4%} format relies on that; confirm against utils.test.
    print("Best model's validation acc: {:.4%}".format(best_top1 / 100))


if __name__ == "__main__":
    main()