momo.py
"""
Implements the MoMo algorithm.
Authors: Fabian Schaipp, Ruben Ohana, Michael Eickenberg, Aaron Defazio, Robert Gower
"""
import torch
import warnings
from math import sqrt
from .types import Params, LossClosure, OptFloat
class Momo(torch.optim.Optimizer):
def __init__(self,
params: Params,
lr: float=1.0,
weight_decay: float=0,
beta: float=0.9,
lb: float=0,
bias_correction: bool=False,
use_fstar: bool=False) -> None:
"""
MoMo optimizer
Parameters
----------
params : Params
Model parameters.
lr : float, optional
Learning rate, by default 1.
weight_decay : float, optional
Weight decay parameter, by default 0.
beta : float, optional
Momentum parameter, should be in [0,1), by default 0.9.
lb : float, optional
Lower bound for loss. Zero is often a good guess.
If no good estimate for the minimal loss value is available, you can set use_fstar=True.
By default 0.
        bias_correction : bool, optional
            Whether to use the bias-corrected averaging scheme; see the paper for details. By default False.
use_fstar : bool, optional
Whether to use online estimation of loss lower bound.
Can be used if no good estimate is available, by default False.
"""
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if weight_decay < 0.0:
raise ValueError("Invalid weight decay: {}".format(weight_decay))
        if (beta < 0.0) or (beta >= 1.0):
            raise ValueError("Invalid beta parameter (should be in [0,1)): {}".format(beta))
defaults = dict(lr=lr, weight_decay=weight_decay)
super(Momo, self).__init__(params, defaults)
self.beta = beta
self.lb = lb
self._initial_lb = lb
self.bias_correction = bias_correction
self.use_fstar = use_fstar
# Initialization
self._number_steps = 0
self.state['step_size_list'] = list() # for storing the adaptive step size term
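        # Note: this list is only populated if the append at the end of step() is uncommented.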
return
def step(self, closure: LossClosure=None, loss=None) -> OptFloat:
"""
Performs a single optimization step.
Parameters
----------
closure : LossClosure, optional
A callable that evaluates the model (possibly with backprop) and returns the loss, by default None.
        loss : torch.Tensor, optional
            The loss tensor. Use this when the backward pass has already been performed. By default None.
Returns
-------
        The (stochastic) loss value.
"""
assert (closure is not None) or (loss is not None), "Either loss tensor or closure must be passed."
assert (closure is None) or (loss is None), "Pass either the loss tensor or the closure, not both."
if closure is not None:
with torch.enable_grad():
loss = closure()
        if len(self.param_groups) > 1:
            warnings.warn("More than one param group. This might cause issues for the step method; "
                          "step_size_list will contain the adaptive term of the last group only.")
self._number_steps += 1
beta = self.beta
###### Preliminaries
if self._number_steps == 1:
if self.bias_correction:
self.loss_avg = 0.
else:
self.loss_avg = loss.detach().clone()
self.loss_avg = beta*self.loss_avg + (1-beta)*loss.detach()
if self.bias_correction:
rho = 1-beta**self._number_steps # must be after incrementing k
else:
rho = 1
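        # loss_avg is the EMA  \bar f_k = beta * \bar f_{k-1} + (1 - beta) * f_k.
        # With bias_correction the EMA starts at 0, and rho = 1 - beta**k (the total EMA weight)
        # rescales the lower bound and the learning-rate cap below.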
_dot = 0.
_gamma = 0.
_norm = 0.
        ############################################################
        # Notation
        # d_k: state['grad_avg'] (per parameter), gamma_k: _gamma, \bar f_k: self.loss_avg
for group in self.param_groups:
for p in group['params']:
grad = p.grad.data.detach()
state = self.state[p]
# Initialize EMA
if self._number_steps == 1:
if self.bias_correction:
state['grad_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
state['grad_dot_w'] = torch.zeros(1).to(p.device)
else:
# Exponential moving average of gradients
state['grad_avg'] = grad.clone()
# Exponential moving average of inner product <grad, weight>
state['grad_dot_w'] = torch.sum(torch.mul(p.data, grad))
grad_avg, grad_dot_w = state['grad_avg'], state['grad_dot_w']
grad_avg.mul_(beta).add_(grad, alpha=1-beta)
grad_dot_w.mul_(beta).add_(torch.sum(torch.mul(p.data, grad)), alpha=1-beta)
_dot += torch.sum(torch.mul(p.data, grad_avg))
_gamma += grad_dot_w
_norm += torch.sum(torch.mul(grad_avg, grad_avg))
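        # After the loops:
        #   _dot   = <w_k, d_k>  (current weights dotted with the averaged gradient d_k)
        #   _gamma = EMA of <w_j, g_j>, summed over all parameter tensors
        #   _norm  = ||d_k||^2
        # These scalars define the model-based step size computed next.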
#################
# Update
for group in self.param_groups:
lr = group['lr']
lmbda = group['weight_decay']
if self.use_fstar:
cap = ((1+lr*lmbda)*self.loss_avg + _dot - (1+lr*lmbda)*_gamma).item()
# Reset
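                # (The adaptive numerator below equals cap - (1+lr*lmbda)*rho*self.lb, so the
                # reset fires exactly when it would turn negative; lb is moved to half the
                # value that would make it zero.)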
if cap < (1+lr*lmbda)*rho*self.lb:
self.lb = cap/(2*(1+lr*lmbda)*rho)
self.lb = max(self.lb, self._initial_lb) # safeguard
### Compute adaptive step size
if lmbda > 0:
nom = (1+lr*lmbda)*(self.loss_avg - rho*self.lb) + _dot - (1+lr*lmbda)*_gamma
t1 = max(nom, 0.)/_norm
else:
t1 = max(self.loss_avg - rho*self.lb + _dot - _gamma, 0.)/_norm
t1 = t1.item() # make scalar
tau = min(lr/rho, t1) # step size
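            # tau is the Polyak-type step suggested by the averaged model, truncated at
            # lr/rho; without bias correction (rho = 1) the cap is simply lr.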
### Update lb estimator
if self.use_fstar:
h = (self.loss_avg + _dot - _gamma).item()
self.lb = ((h - (1/2)*tau*_norm)/rho).item()
self.lb = max(self.lb, self._initial_lb) # safeguard
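                # h is the value of the averaged model at the current iterate; subtracting
                # (1/2)*tau*||d_k||^2 roughly accounts for the decrease predicted for this
                # step, giving an online estimate of the attainable loss.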
### Update params
for p in group['params']:
state = self.state[p]
grad_avg = state['grad_avg']
p.data.add_(other=grad_avg, alpha=-tau)
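                # The division below applies weight decay as a proximal (decoupled) step on
                # the squared L2 penalty, rather than adding lmbda * p to the gradient.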
if lmbda > 0:
p.data.div_(1+lr*lmbda)
############################################################
if self.use_fstar:
self.state['fstar'] = self.lb
# If you want to track the adaptive step size term, activate the following line.
# self.state['step_size_list'].append(t1)
return loss
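############################################################
# Minimal usage sketch (illustration only; `model`, `loss_fn` and `dataloader`
# are assumed to be defined elsewhere). MoMo needs the loss value, so either
# pass a closure that computes the loss and gradients, or pass the loss tensor
# after calling backward() yourself.
#
#   opt = Momo(model.parameters(), lr=1.0, weight_decay=0.0)
#   for x, y in dataloader:
#       def closure():
#           opt.zero_grad()
#           loss = loss_fn(model(x), y)
#           loss.backward()
#           return loss
#       loss = opt.step(closure=closure)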