Replace Variable.volatile with torch.no_grad() (pytorch#3970)
This removes volatile from Variable. The functionality is mostly
replaced by a global (thread-local) flag, which is controlled by
torch.set_grad_enabled() and the context manager torch.no_grad().

In C++, the flag is exposed through GradMode::is_enabled() and GradMode::set_enabled().

Fixes pytorch#3627
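
For reference, a minimal usage sketch of the new Python API described above (not part of this diff; it uses the torch.autograd.Variable wrapper in the same style as the tests below, and calls set_grad_enabled as a plain function, matching the commit message):

    import torch
    from torch.autograd import Variable

    x = Variable(torch.ones(5, 5), requires_grad=True)

    # Inside the context manager the thread-local grad mode is disabled,
    # so results of operations neither require grad nor get a grad_fn.
    with torch.no_grad():
        y = x * 2
    assert not y.requires_grad
    assert y.grad_fn is None

    # The same flag can also be toggled imperatively; restore it afterwards.
    torch.set_grad_enabled(False)
    z = x * 2
    assert not z.requires_grad
    torch.set_grad_enabled(True)
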
colesbury committed Dec 18, 2017
1 parent 0876bab commit d605058
Showing 52 changed files with 552 additions and 528 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/Declarations.cwrap
@@ -2276,6 +2276,7 @@
name: zero_
cname: zero
return: self
aten_sparse: True
arguments:
- THTensor* self
]]
@@ -3931,6 +3932,7 @@
name: assign_
cname: copy
cpu_half: True
aten_sparse: True
return: self
arguments:
- THTensor* self
1 change: 1 addition & 0 deletions setup.py
@@ -474,6 +474,7 @@ def check_file(f):
"torch/csrc/jit/passes/onnx/peephole.cpp",
"torch/csrc/jit/generated/aten_dispatch.cpp",
"torch/csrc/autograd/init.cpp",
"torch/csrc/autograd/grad_mode.cpp",
"torch/csrc/autograd/engine.cpp",
"torch/csrc/autograd/function.cpp",
"torch/csrc/autograd/variable.cpp",
183 changes: 98 additions & 85 deletions test/test_autograd.py
@@ -5,6 +5,7 @@
import torch
import unittest
import random
import warnings
from copy import deepcopy
from collections import OrderedDict
from itertools import product
@@ -56,13 +57,10 @@ def _function_test(self, cls):
y = Variable(torch.randn(5, 5), requires_grad=True)
result = cls.apply(x, 2, y)
go = Variable(torch.ones(1), requires_grad=True)
result.sum().backward(go)
result.sum().backward(go, create_graph=True)

self.assertEqual(x.grad.data, y.data + torch.ones(5, 5))
self.assertEqual(y.grad.data, x.data + torch.ones(5, 5) * 2)

self.assertFalse(x.grad.volatile)
self.assertFalse(y.grad.volatile)
self.assertIsNotNone(x.grad.grad_fn)
self.assertIsNotNone(y.grad.grad_fn)

@@ -93,11 +91,11 @@ def backward(ctx, grad_output):
y_grad_desc = graph_desc(y.grad.grad_fn)
self.assertEqual(
x_grad_desc,
'Identity(AddBackward1(ExpandBackward(AccumulateGrad()), '
'CloneBackward(AddBackward1(ExpandBackward(AccumulateGrad()), '
'MulBackward1(ExpandBackward(AccumulateGrad()), AccumulateGrad())))')
self.assertEqual(
y_grad_desc,
'Identity(AddBackward1(MulBackward0(ExpandBackward(AccumulateGrad())), '
'CloneBackward(AddBackward1(MulBackward0(ExpandBackward(AccumulateGrad())), '
'MulBackward1(ExpandBackward(AccumulateGrad()), AccumulateGrad())))')

def test_once_differentiable(self):
@@ -122,9 +120,9 @@ def backward(ctx, grad_output):

x, y = self._function_test(MyFunction)
self.assertEqual(graph_desc(x.grad.grad_fn),
'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
'CloneBackward(Error(AccumulateGrad(), None, AccumulateGrad()))')
self.assertEqual(graph_desc(y.grad.grad_fn),
'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
'CloneBackward(Error(AccumulateGrad(), None, AccumulateGrad()))')

def test_function_returns_input(self):
class MyFunction(Function):
@@ -146,31 +144,30 @@ def backward(ctx, grad):

def test_accumulate_grad(self):
grad_output = Variable(torch.ones(5, 5))
for start_volatile, end_volatile in product((True, False), repeat=2):
go1 = grad_output.data if start_volatile else grad_output
go2 = grad_output.data if end_volatile else grad_output

def compute_grad(create_graph):
x = Variable(torch.randn(5, 5), requires_grad=True)
y = x + 2
y.backward(go1, retain_graph=True)
y.backward(grad_output, retain_graph=True)
x_grad = x.grad
x_grad_clone = x.grad.data.clone()
y.backward(go2)
x_grad_clone = x.grad.clone()
y.backward(grad_output, create_graph=create_graph)
return x_grad, x_grad_clone

# That's the only case when we can accumulate in-place
# TODO: reconsider this logic (see accumulate_grad.cpp)
if start_volatile:
expected_grad = x_grad_clone * 2
else:
expected_grad = x_grad_clone
self.assertEqual(x_grad.data, expected_grad)
# Accumulate in-place when create_graph is False
x_grad, x_grad_clone = compute_grad(create_graph=False)
self.assertEqual(x_grad, x_grad_clone * 2)

# Accumulate out-of-place when create_graph is False
x_grad, x_grad_clone = compute_grad(create_graph=True)
self.assertEqual(x_grad, x_grad_clone)

def test_hessian_vector(self):
x = Variable(torch.randn(2, 2), requires_grad=True)
y = Variable(torch.randn(2, 2), requires_grad=True)

z = x ** 2 + y * x + y ** 2
z.backward(Variable(torch.ones(2, 2), requires_grad=True), retain_graph=True)
z.backward(torch.ones(2, 2), create_graph=True)

x_grad = 2 * x.data + y.data
y_grad = x.data + 2 * y.data
@@ -188,7 +185,7 @@ def test_grad(self):
x = Variable(torch.randn(2, 2), requires_grad=True)
y = Variable(torch.randn(2, 2), requires_grad=True)
z = x ** 2 + y * x + y ** 2
z.backward(Variable(torch.ones(2, 2)), retain_graph=True)
z.backward(torch.ones(2, 2), create_graph=True)

x_grad = 2 * x.data + y.data
y_grad = x.data + 2 * y.data
@@ -605,19 +602,11 @@ def backward(ctx, grad_b):

TestFn.apply(b).sum().backward()

def test_volatile(self):
def test_no_grad(self):
x = Variable(torch.ones(5, 5), requires_grad=True)
y = Variable(torch.ones(5, 5) * 4, volatile=True)

z = x ** 2
self.assertFalse(z.volatile)
self.assertTrue(z.requires_grad)
self.assertIsNotNone(z.grad_fn)
z.backward(torch.ones(5, 5))
self.assertEqual(x.grad.data, torch.ones(5, 5) * 2)

w = z + y
self.assertTrue(w.volatile)
y = Variable(torch.ones(5, 5) * 4)
with torch.no_grad():
w = x + y
self.assertFalse(w.requires_grad)
self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
self.assertIsNone(w.grad_fn)
@@ -744,6 +733,12 @@ def test_indexing_duplicates(self):
expected_grad[1].fill_(3)
self.assertEqual(y.grad.data, expected_grad)

def test_volatile_deprecated(self):
v = torch.autograd.Variable(torch.randn(3, 3))
with warnings.catch_warnings(record=True) as w:
self.assertFalse(v.volatile)
self.assertIn('volatile', str(w[0].message))

def test_requires_grad(self):
x = Variable(torch.randn(5, 5))
y = Variable(torch.randn(5, 5))
@@ -902,6 +897,28 @@ def backward(ctx, grad_output):
y = x.masked_fill(mask, 0)
y.sum().backward()

def test_mark_non_differentiable_mixed(self):
class MyFunction(Function):
@staticmethod
def forward(ctx, input):
a = input + 1
b = input + 2
ctx.mark_non_differentiable(a)
return a, b

@staticmethod
def backward(ctx, grad_a, grad_b):
self.assertTrue((grad_a == 0).all())
self.assertTrue((grad_b == 1).all())
return grad_b

x = Variable(torch.randn(5, 5), requires_grad=True)
a, b = MyFunction.apply(x)
self.assertFalse(a.requires_grad)
self.assertTrue(b.requires_grad)
b.sum().backward()
self.assertEqual(x.grad.data, torch.ones(5, 5))

def test_mark_non_differentiable_none(self):
# This used to segfault because MyFunction would send back null
# gradients to MulBackward, which is implemented in C++. C++
@@ -1140,12 +1157,6 @@ def test_detach(self):
# wanted. detach() is an advanced option.
self.assertEqual(x.grad.data, torch.ones(10, 10))

# detach() should preserve volatile flag
x = Variable(torch.randn(10, 10), volatile=True)
y = x * 2
y = y.detach()
self.assertTrue(y.volatile)

# in-place detach
x = Variable(torch.randn(10, 10), requires_grad=True)
y = Variable(torch.randn(10, 10), requires_grad=True)
@@ -1161,6 +1172,16 @@ def test_detach(self):
view = x.narrow(0, 1, 4)
self.assertRaisesRegex(RuntimeError, 'view', lambda: view.detach_())

def test_detach_base(self):
"detaching base does not detach view"
x = Variable(torch.randn(10, 10), requires_grad=True)
view = x.narrow(0, 1, 4)
x.detach_()
self.assertFalse(x.requires_grad)
self.assertTrue(view.requires_grad)
self.assertIsNotNone(view.grad_fn)
self.assertIs(view._base, x)

def _test_type_conversion_backward(self, t, ):
fvar = Variable(t(torch.randn(5, 5).float()), requires_grad=True)
fvar.double().sum().backward()
@@ -1369,12 +1390,23 @@ def test_leaf_assignment(self):
self.assertEqual(y.grad.data, torch.ones(5))
self.assertEqual(z.grad.data, torch.ones(5) * 2)

def test_volatile_assignment(self):
x = Variable(torch.randn(5, 5))
y = Variable(torch.randn(5), volatile=True)
def test_no_grad_assignment(self):
x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5))
with torch.no_grad():
x[0] = y

x[0] = y
self.assertTrue(x.volatile)
self.assertTrue(x.requires_grad)
self.assertIsNone(x.grad_fn)

def test_no_grad_modifies_version(self):
x = Variable(torch.randn(5), requires_grad=True)
y = Variable(torch.randn(5), requires_grad=True)
z = (x * y).sum()
with torch.no_grad():
x *= 2
self.assertRaisesRegex(RuntimeError, 'modified by an inplace operation',
lambda: z.backward())

def test_backward_copy(self):
# This test checks the backward engine for a very subtle bug that appeared
@@ -1503,20 +1535,17 @@ def backward(self, grad_output):

def test_pickle(self):
x = Variable(torch.randn(10, 10), requires_grad=True)
y = Variable(torch.randn(10, 10), volatile=True)
z = Variable(torch.randn(10, 10), requires_grad=False)
y = Variable(torch.randn(10, 10), requires_grad=False)

def assert_strict_equal(var1, var2):
self.assertEqual(var1.data, var2.data)
self.assertEqual(var1.requires_grad, var2.requires_grad)
self.assertEqual(var1.volatile, var2.volatile)

serialized = [pickle.dumps([x, y, z], protocol=p) for p in range(3)]
serialized = [pickle.dumps([x, y], protocol=p) for p in range(3)]
for dump in serialized:
xc, yc, zc = pickle.loads(dump)
xc, yc = pickle.loads(dump)
assert_strict_equal(xc, x)
assert_strict_equal(yc, y)
assert_strict_equal(zc, z)

def test_dep_nograd(self):
class F1(Function):
@@ -1564,7 +1593,7 @@ def backward(ctx, grad_output):

x = Variable(torch.randn(2, 2), requires_grad=True)
out = Reenter.apply(x)
out.sum().backward()
out.sum().backward(create_graph=True)
self.assertEqual(x.grad.data, y_data)

def test_cat(self):
@@ -1720,25 +1749,6 @@ def test_inplace_view_of_view(self):
x.sum().backward()
self.assertEqual(root.grad.data.tolist(), [[1, 2], [1, 1]])

def test_inplace_view_volatile(self):
# an in-place operation on a view that makes the view volatile should
# make the base volatile too
base = Variable(torch.randn(2, 2))
view = base.narrow(0, 0, 1)
view.add_(Variable(torch.randn(1, 2), volatile=True))
self.assertTrue(view.volatile)
self.assertTrue(base.volatile)

def test_inplace_base_volatile(self):
# an in-place operation on a base that makes the base volatile should
# trigger a consistency exception if the view is used in a differentiable
# op
x = Variable(torch.randn(2, 2), requires_grad=True)
base = Variable(torch.randn(2, 2))
view = base.narrow(0, 0, 1)
base.add_(Variable(torch.randn(1, 2), volatile=True))
self.assertRaisesRegex(RuntimeError, 'is_volatile', lambda: view + x)

def test_inplace_view_gradcheck(self):
# gradcheck modifications to views
a = Variable(torch.randn(4, 4), requires_grad=True)
@@ -1779,20 +1789,23 @@ def test_inplace_view_backprop_view(self):
self.assertEqual(b.grad.data.tolist(), [5])
self.assertIsNone(a.grad)

def test_inplace_view_flags(self):
# check that an exception is thrown if the flags on the base do not
# match the flags on the view
x = Variable(torch.ones(5))
def test_inplace_view_modify_base(self):
# Test that an in-place operation on a base that forced it to require
# grad also forces any previous views to require grad and backprop
# correctly
r = Variable(torch.ones(1), requires_grad=True)
r2 = Variable(torch.ones(1), requires_grad=True)
v = x.select(0, 1)
x.add_(r)
self.assertFalse(v.requires_grad)
self.assertTrue(x.requires_grad)
# v is dependent on r due to the addition above, but v still doesn't
# requires_grad. The addition to r2 should raise an error until we
# share requires_grad between base and views.
self.assertRaisesRegex(RuntimeError, 'requires_grad', lambda: v + r2)

def fn(r):
x = Variable(torch.ones(5))
v = x.select(0, 1)
self.assertFalse(v.requires_grad)
self.assertIsNone(v.grad_fn)
x.add_(r) # v is now dependent on r due to the in-place op on x
self.assertTrue(v.requires_grad)
return v

gradcheck(fn, [r])
gradgradcheck(fn, [r])

def test_inplace_view_python(self):
# in-place modifications of Python-autograd created view
20 changes: 8 additions & 12 deletions test/test_jit.py
@@ -228,14 +228,12 @@ def test_arg_configurations(self):
"""Different arg configurations should trigger different traces"""
x = Variable(torch.FloatTensor(4, 4).uniform_())
x_double = Variable(x.data.double())
x_volatile = Variable(x.data.clone(), volatile=True)
x_grad = Variable(x.data.clone(), requires_grad=True)
y = Variable(torch.randn(4))

configurations = [
(x,),
(x_double,),
(x_volatile,),
(x_grad,),
(y,),
([x, x],),
@@ -703,8 +701,8 @@ def fn(x, y):
self.assertEqual(grad_v, expected_grad)
self.assertEqual(fn.has_trace_for(x, y), rx or ry)

def test_volatile_fallback(self):
"""Check that Traceable falls back to num_backwards=0 if given volatile inputs"""
def test_no_grad_fallback(self):
"""Check that Traceable falls back to num_backwards=0 if in no-backprop mode"""
x = Variable(torch.randn(2, 2))
y = Variable(torch.randn(2, 2), requires_grad=True)

@@ -714,14 +712,12 @@ def fn(x, y):

out = fn(x, y)
self.assertFalse(fn.has_trace_for(x, y))

x.volatile = True
self.assertFalse(fn.has_trace_for(x, y))
out = fn(x, y)
self.assertTrue(fn.has_trace_for(x, y))
with self.assertCompiled(fn):
out2 = fn(x, y)
self.assertEqual(out, out2)
with torch.no_grad():
out = fn(x, y)
self.assertTrue(fn.has_trace_for(x, y))
with self.assertCompiled(fn):
out2 = fn(x, y)
self.assertEqual(out, out2)

def test_backward_flag_checks(self):
x = Variable(torch.randn(1), requires_grad=True)