Improve optimizer serialization
Also, add optimizer.load_state_dict
apaszke authored and soumith committed Jan 24, 2017
1 parent 3975a26 commit ecfcf39
Showing 11 changed files with 105 additions and 12 deletions.
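In practice, the state_dict() / load_state_dict() pair added here lets optimizer buffers (momentum, running averages, step counts) be checkpointed and restored alongside model parameters. A minimal usage sketch, using the Variable API of this release; the parameter shapes, learning rate, and file name are illustrative:

    import torch
    from torch.autograd import Variable
    import torch.optim as optim

    weight = Variable(torch.randn(10, 5), requires_grad=True)
    bias = Variable(torch.randn(10), requires_grad=True)
    optimizer = optim.Adam([weight, bias], lr=1e-3)

    # ... train for a while, then checkpoint the optimizer state ...
    torch.save(optimizer.state_dict(), 'optimizer.pth')

    # Later, possibly in a fresh process: rebuild an optimizer over the
    # corresponding parameters and restore its buffers and step counts.
    optimizer = optim.Adam([weight, bias], lr=1e-3)
    optimizer.load_state_dict(torch.load('optimizer.pth'))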
46 changes: 44 additions & 2 deletions test/test_optim.py
@@ -1,4 +1,6 @@
import unittest
import functools
from copy import deepcopy
import torch
import torch.optim as optim
import torch.legacy.optim as old_optim
@@ -59,7 +61,7 @@ def eval():
def _test_basic_cases_template(self, weight, bias, input, constructor):
weight = Variable(weight, requires_grad=True)
bias = Variable(bias, requires_grad=True)
input = Variable(input, requires_grad=False)
input = Variable(input)
optimizer = constructor(weight, bias)

def fn():
@@ -74,10 +76,50 @@ def fn():
initial_value = fn().data[0]
for i in range(200):
optimizer.step(fn)

self.assertLess(fn().data[0], initial_value)

def _test_state_dict(self, weight, bias, input, constructor):
weight = Variable(weight, requires_grad=True)
bias = Variable(bias, requires_grad=True)
input = Variable(input)

def fn_base(optimizer, weight, bias):
optimizer.zero_grad()
loss = (weight.mv(input) + bias).pow(2).sum()
loss.backward()
return loss

optimizer = constructor(weight, bias)
fn = functools.partial(fn_base, optimizer, weight, bias)

# Prime the optimizer
for i in range(20):
optimizer.step(fn)
# Clone the weights and construct new optimizer for them
weight_c = Variable(weight.data.clone(), requires_grad=True)
bias_c = Variable(bias.data.clone(), requires_grad=True)
optimizer_c = constructor(weight_c, bias_c)
fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
# Load state dict
state_dict = deepcopy(optimizer.state_dict())
state_dict_c = deepcopy(optimizer.state_dict())
optimizer_c.load_state_dict(state_dict_c)
# Run both optimizations in parallel
for i in range(20):
optimizer.step(fn)
optimizer_c.step(fn_c)
self.assertEqual(weight, weight_c)
self.assertEqual(bias, bias_c)
# Make sure state dict wasn't modified
self.assertEqual(state_dict, state_dict_c)

def _test_basic_cases(self, constructor, ignore_multidevice=False):
self._test_state_dict(
torch.randn(10, 5),
torch.randn(10),
torch.randn(5),
constructor
)
self._test_basic_cases_template(
torch.randn(10, 5),
torch.randn(10),
2 changes: 1 addition & 1 deletion torch/optim/adadelta.py
@@ -38,7 +38,7 @@ def step(self, closure=None):
for group in self.param_groups:
for p in group['params']:
grad = p.grad.data
state = self.state[id(p)]
state = self.state[p]

# State initialization
if len(state) == 0:
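The same one-line change repeats in each optimizer file below: per-parameter state is now keyed by the parameter Variable itself rather than by id(p). A rough sketch of why that helps serialization; the variable names are hypothetical, and the real pack/remap logic lives in the optimizer.py changes further down:

    import torch
    from torch.autograd import Variable

    p = Variable(torch.randn(4), requires_grad=True)

    # Old scheme: keyed by id(p). Ids are only meaningful inside one
    # process, so saved state could not be matched to a rebuilt optimizer.
    old_state = {id(p): {'step': 10}}

    # New scheme: keyed by the Variable object itself (Variables are
    # hashable), so lookups in step() are unchanged ...
    state = {p: {'step': 10}}

    # ... while saving can swap keys for ids, and loading can map those
    # ids back onto parameters matched positionally across param groups.
    packed = {id(k): v for k, v in state.items()}
    id_map = {id(p): p}                    # old id -> current parameter
    restored = {id_map[k]: v for k, v in packed.items()}
    assert restored[p]['step'] == 10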
2 changes: 1 addition & 1 deletion torch/optim/adagrad.py
@@ -34,7 +34,7 @@ def step(self, closure=None):
for group in self.param_groups:
for p in group['params']:
grad = p.grad.data
state = self.state[id(p)]
state = self.state[p]

# State initialization
if len(state) == 0:
2 changes: 1 addition & 1 deletion torch/optim/adam.py
@@ -40,7 +40,7 @@ def step(self, closure=None):
for group in self.param_groups:
for p in group['params']:
grad = p.grad.data
state = self.state[id(p)]
state = self.state[p]

# State initialization
if len(state) == 0:
2 changes: 1 addition & 1 deletion torch/optim/adamax.py
@@ -39,7 +39,7 @@ def step(self, closure=None):
for group in self.param_groups:
for p in group['params']:
grad = p.grad.data
state = self.state[id(p)]
state = self.state[p]

# State initialization
if len(state) == 0:
2 changes: 1 addition & 1 deletion torch/optim/asgd.py
@@ -39,7 +39,7 @@ def step(self, closure=None):
for group in self.param_groups:
for p in group['params']:
grad = p.grad.data
state = self.state[id(p)]
state = self.state[p]

# State initialization
if len(state) == 0:
1 change: 0 additions & 1 deletion torch/optim/lbfgs.py
@@ -88,7 +88,6 @@ def step(self, closure):
state.setdefault('func_evals', 0)
state.setdefault('n_iter', 0)


# evaluate initial f(x) and df/dx
orig_loss = closure()
loss = orig_loss.data[0]
54 changes: 53 additions & 1 deletion torch/optim/optimizer.py
@@ -1,6 +1,8 @@
from collections import defaultdict

import torch
from copy import deepcopy
from itertools import chain
from torch.autograd import Variable

required = object()
@@ -49,6 +51,9 @@ def __init__(self, params, defaults):

for group in self.param_groups:
for param in group['params']:
if not isinstance(param, Variable):
raise TypeError("optimizer can only optimize Variables, "
"but one of the params is " + torch.typename(param))
if not param.requires_grad:
raise ValueError("optimizing a parameter that doesn't "
"require gradients")
@@ -70,7 +75,54 @@ def state_dict(self):
differs between optimizer classes.
* param_groups - a dict containing all parameter groups
"""
return self.__getstate__()
# Save ids instead of Variables
def pack_group(group):
packed = {k: v for k, v in group.items() if k != 'params'}
packed['params'] = [id(p) for p in group['params']]
return packed
param_groups = [pack_group(g) for g in self.param_groups]
# Remap state to use ids as keys
packed_state = {(id(k) if isinstance(k, Variable) else k): v
for k, v in self.state.items()}
return {
'state': packed_state,
'param_groups': param_groups,
}

def load_state_dict(self, state_dict):
"""Loads the optimizer state.
Arguments:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
"""
# deepcopy, to be consistent with module API
state_dict = deepcopy(state_dict)
# Validate the state_dict
groups = self.param_groups
saved_groups = state_dict['param_groups']

if len(groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of "
"parameter groups")
param_lens = (len(g['params']) for g in groups)
saved_lens = (len(g['params']) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError("loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group")

# Update the state
id_map = {old_id: p for old_id, p in
zip(chain(*(g['params'] for g in saved_groups)),
chain(*(g['params'] for g in groups)))}
self.state = {id_map.get(k, k): v for k, v in state_dict['state'].items()}

# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
new_group['params'] = group['params']
return new_group
self.param_groups = [
update_group(g, ng) for g, ng in zip(groups, saved_groups)]

def zero_grad(self):
"""Clears the gradients of all optimized :class:`Variable` s."""
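For reference, a sketch of what the new state_dict() returns and how it round-trips, assuming SGD with momentum on a single parameter (the exact group keys vary by optimizer):

    import torch
    from torch.autograd import Variable
    import torch.optim as optim

    w = Variable(torch.randn(3, 2), requires_grad=True)
    opt = optim.SGD([w], lr=0.1, momentum=0.9)

    w.sum().backward()
    opt.step()                    # creates the momentum buffer for w

    sd = opt.state_dict()
    # 'param_groups' holds the hyperparameters plus parameter ids ...
    print(sd['param_groups'][0]['params'])          # -> [id(w)]
    # ... and 'state' is keyed by those same ids.
    print(sd['state'][id(w)]['momentum_buffer'])    # the momentum Tensor

    # Restoring into a freshly constructed optimizer over matching params:
    opt2 = optim.SGD([w], lr=0.1, momentum=0.9)
    opt2.load_state_dict(sd)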
2 changes: 1 addition & 1 deletion torch/optim/rmsprop.py
@@ -32,7 +32,7 @@ def step(self, closure=None):
for group in self.param_groups:
for p in group['params']:
grad = p.grad.data
state = self.state[id(p)]
state = self.state[p]

# State initialization
if len(state) == 0:
2 changes: 1 addition & 1 deletion torch/optim/rprop.py
@@ -32,7 +32,7 @@ def step(self, closure=None):
for group in self.param_groups:
for p in group['params']:
grad = p.grad.data
state = self.state[id(p)]
state = self.state[p]

# State initialization
if len(state) == 0:
2 changes: 1 addition & 1 deletion torch/optim/sgd.py
@@ -46,7 +46,7 @@ def step(self, closure=None):
if weight_decay != 0:
d_p.add_(weight_decay, p.data)
if momentum != 0:
param_state = self.state[id(p)]
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
param_state['momentum_buffer'] = d_p.clone()
else:
