meta.py
from torch.optim.sgd import SGD


class MetaSGD(SGD):
    """SGD variant that performs a differentiable update step on `net`.

    Instead of mutating parameters in place (which would break the autograd
    graph), `meta_step` rebuilds each parameter as a new tensor, so gradients
    can later flow through the update itself.
    """

    def __init__(self, net, *args, **kwargs):
        super(MetaSGD, self).__init__(*args, **kwargs)
        self.net = net

    def set_parameter(self, current_module, name, parameters):
        """Assign `parameters` to the (possibly nested) parameter `name`,
        recursing into submodules on each dot in the name."""
        if '.' in name:
            module_name, rest_name = name.split('.', 1)
            for children_name, children in current_module.named_children():
                if module_name == children_name:
                    self.set_parameter(children, rest_name, parameters)
                    break
        else:
            # Write into _parameters directly so the new tensor keeps its
            # autograd history (assigning via nn.Module's setattr would
            # require an nn.Parameter, which detaches the graph).
            current_module._parameters[name] = parameters

    def meta_step(self, grads):
        """Apply one SGD update using the externally supplied `grads`,
        keeping the updated parameters differentiable w.r.t. those grads."""
        group = self.param_groups[0]
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']
        lr = group['lr']
        for (name, parameter), grad in zip(self.net.named_parameters(), grads):
            # Detach the old value; the new value re-enters the graph
            # through `grad_n` below.
            parameter.detach_()
            # Fold L2 regularization into the gradient, as torch's SGD does.
            if weight_decay != 0:
                grad_wd = grad.add(parameter, alpha=weight_decay)
            else:
                grad_wd = grad
            # Reuse an existing momentum buffer if one is present; otherwise
            # fall back to the plain gradient (first-step behavior of SGD).
            if momentum != 0 and 'momentum_buffer' in self.state[parameter]:
                buffer = self.state[parameter]['momentum_buffer']
                grad_b = buffer.mul(momentum).add(grad_wd, alpha=1 - dampening)
            else:
                grad_b = grad_wd
            if nesterov:
                grad_n = grad_wd.add(grad_b, alpha=momentum)
            else:
                grad_n = grad_b
            # Rebuild the parameter as a fresh tensor: theta' = theta - lr * g.
            self.set_parameter(self.net, name, parameter.add(grad_n, alpha=-lr))
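
Usage sketch (not part of the original file; the model, data, and hyperparameters below are hypothetical placeholders): the typical pattern computes gradients with create_graph=True so that a later backward pass can differentiate through the update performed by meta_step.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical toy setup: a one-layer net and a dummy regression loss.
net = nn.Linear(4, 1)
optimizer = MetaSGD(net, net.parameters(), lr=0.1)

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = F.mse_loss(net(x), y)

# create_graph=True retains the graph of the gradients themselves, so the
# parameters produced by meta_step remain differentiable functions of them.
grads = torch.autograd.grad(loss, tuple(net.parameters()), create_graph=True)
optimizer.meta_step(grads)

After meta_step, the net's parameters are non-leaf tensors, so a meta-loss computed on the updated net can be backpropagated to whatever produced the grads (for example, the weights of a meta-network).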