Enable instance norm running mean test (pytorch#89793)
Follow-up to pytorch#88697.
Pull Request resolved: pytorch#89793
Approved by: https://github.com/bdhirsh
janeyx99 authored and pytorchmergebot committed Nov 29, 2022
1 parent c599cf2 commit fcb5d6e
Showing 1 changed file with 32 additions and 28 deletions.
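
For context, here is a minimal standalone sketch (not part of the commit; it assumes a recent PyTorch build) of the behavior this test now exercises: with use_input_stats=True, torch.instance_norm updates the running_mean tensor passed to it in place, and that input mutation is exactly what functionalization has to capture and replay.

# Sketch of the mutation under test (assumption: recent PyTorch; not from the diff).
import torch

size = 100
running_mean = torch.zeros(size)
x = torch.arange(20 * size * 35 * 45, dtype=torch.float32).reshape(20, size, 35, 45)

# With use_input_stats=True, running_mean is updated in place from the
# per-instance batch statistics (scaled by momentum=0.1).
torch.instance_norm(x, None, None, running_mean, torch.ones(size),
                    use_input_stats=True, momentum=0.1, eps=1e-5, cudnn_enabled=False)

assert not torch.equal(running_mean, torch.zeros(size))  # mutated in place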
test/test_functionalization.py (+32 −28)

@@ -1305,13 +1305,15 @@ def forward(self, a_1):


    def test_instance_norm_running_mean_is_x(self):
+        size = 100
+
        def f(x):
            with enable_python_dispatcher():
-                return torch.instance_norm(torch.randn(20, 100, 35, 45), None, None, running_mean=x, running_var=torch.ones(100),
-                                           use_input_stats=True, momentum=0.1, eps=1e-5, cudnn_enabled=False)
-        # TODO: uncomment following line after functionalization can handle input mutations
-        # self.assert_functionalization(f, torch.zeros(100))
-        logs = self.get_logs(f, torch.zeros(100))
+                return torch.instance_norm(
+                    torch.arange(20 * size * 35 * 45, dtype=torch.float32).reshape(20, size, 35, 45), None, None,
+                    x, torch.ones(size), use_input_stats=True, momentum=0.1, eps=1e-5, cudnn_enabled=False)
+        self.assert_functionalization(f, torch.zeros(size))
+        logs = self.get_logs(f, torch.zeros(size))
        # On Windows, for instance_norm, the alias_copy's are reordered to come right before they need to be used
        # whereas on other platforms, the alias_copy's are before the view_copy's.
        # e.g., the alias_copy after the getitem_4 assignment would be moved to be right before the copy assignment.
@@ -1321,66 +1323,68 @@ def f(x):
def forward(self, a_1):
-    randn = torch.ops.aten.randn.default([20, 100, 35, 45], device = device(type='cpu'), pin_memory = False)
+    arange = torch.ops.aten.arange.default(3150000, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    view_copy = torch.ops.aten.view_copy.default(arange, [20, 100, 35, 45]); arange = None
    ones = torch.ops.aten.ones.default([100], device = device(type='cpu'), pin_memory = False)
    repeat = torch.ops.aten.repeat.default(a_1, [20])
    repeat_1 = torch.ops.aten.repeat.default(ones, [20])
-    view_copy = torch.ops.aten.view_copy.default(randn, [1, 2000, 35, 45]); randn = None
+    view_copy_1 = torch.ops.aten.view_copy.default(view_copy, [1, 2000, 35, 45]); view_copy = None
    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
-    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None
+    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy_1, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy_1 = repeat = repeat_1 = None
    getitem = _native_batch_norm_legit_functional[0]
    getitem_1 = _native_batch_norm_legit_functional[1]
    getitem_2 = _native_batch_norm_legit_functional[2]
    getitem_3 = _native_batch_norm_legit_functional[3]
    getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
    alias_copy = torch.ops.aten.alias_copy.default(a_1)
-    view_copy_1 = torch.ops.aten.view_copy.default(getitem_3, [20, 100])
-    view_copy_2 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]); getitem_3 = None
-    mean = torch.ops.aten.mean.dim(view_copy_2, [0]); view_copy_2 = None
+    view_copy_2 = torch.ops.aten.view_copy.default(getitem_3, [20, 100])
+    view_copy_3 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]); getitem_3 = None
+    mean = torch.ops.aten.mean.dim(view_copy_3, [0]); view_copy_3 = None
    copy = torch.ops.aten.copy.default(alias_copy, mean); alias_copy = mean = None
    alias_copy_1 = torch.ops.aten.alias_copy.default(ones); ones = None
-    view_copy_3 = torch.ops.aten.view_copy.default(getitem_4, [20, 100])
-    view_copy_4 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]); getitem_4 = None
-    mean_1 = torch.ops.aten.mean.dim(view_copy_4, [0]); view_copy_4 = None
+    view_copy_4 = torch.ops.aten.view_copy.default(getitem_4, [20, 100])
+    view_copy_5 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]); getitem_4 = None
+    mean_1 = torch.ops.aten.mean.dim(view_copy_5, [0]); view_copy_5 = None
    copy_1 = torch.ops.aten.copy.default(alias_copy_1, mean_1); alias_copy_1 = mean_1 = None
-    view_copy_5 = torch.ops.aten.view_copy.default(getitem, [20, 100, 35, 45]); getitem = None
+    view_copy_6 = torch.ops.aten.view_copy.default(getitem, [20, 100, 35, 45]); getitem = None
    alias_copy_2 = torch.ops.aten.alias_copy.default(copy); copy = None
    copy_ = torch.ops.aten.copy_.default(a_1, alias_copy_2); a_1 = alias_copy_2 = None
-    return view_copy_5
+    return view_copy_6
""") # noqa: B950

-        reinplaced_logs = self.get_logs(f, torch.zeros(100), reapply_views=True, run_reinplace=True)
+        reinplaced_logs = self.get_logs(f, torch.zeros(size), reapply_views=True, run_reinplace=True)
        self.assertExpectedInline(reinplaced_logs, """\
def forward(self, a_1):
-    randn = torch.ops.aten.randn.default([20, 100, 35, 45], device = device(type='cpu'), pin_memory = False)
+    arange = torch.ops.aten.arange.default(3150000, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    view = torch.ops.aten.view.default(arange, [20, 100, 35, 45]); arange = None
    ones = torch.ops.aten.ones.default([100], device = device(type='cpu'), pin_memory = False)
    repeat = torch.ops.aten.repeat.default(a_1, [20])
    repeat_1 = torch.ops.aten.repeat.default(ones, [20])
-    view = torch.ops.aten.view.default(randn, [1, 2000, 35, 45]); randn = None
+    view_1 = torch.ops.aten.view.default(view, [1, 2000, 35, 45]); view = None
    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu'))
-    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None
+    _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_1, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_1 = repeat = repeat_1 = None
    getitem = _native_batch_norm_legit_functional[0]
    getitem_1 = _native_batch_norm_legit_functional[1]
    getitem_2 = _native_batch_norm_legit_functional[2]
    getitem_3 = _native_batch_norm_legit_functional[3]
    getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
    alias = torch.ops.aten.alias.default(a_1)
-    view_1 = torch.ops.aten.view.default(getitem_3, [20, 100])
-    view_2 = torch.ops.aten.view.default(getitem_3, [20, 100]); getitem_3 = None
-    mean = torch.ops.aten.mean.dim(view_2, [0]); view_2 = None
+    view_2 = torch.ops.aten.view.default(getitem_3, [20, 100])
+    view_3 = torch.ops.aten.view.default(getitem_3, [20, 100]); getitem_3 = None
+    mean = torch.ops.aten.mean.dim(view_3, [0]); view_3 = None
    copy = torch.ops.aten.copy.default(alias, mean); alias = mean = None
    alias_1 = torch.ops.aten.alias.default(ones); ones = None
-    view_3 = torch.ops.aten.view.default(getitem_4, [20, 100])
-    view_4 = torch.ops.aten.view.default(getitem_4, [20, 100]); getitem_4 = None
-    mean_1 = torch.ops.aten.mean.dim(view_4, [0]); view_4 = None
+    view_4 = torch.ops.aten.view.default(getitem_4, [20, 100])
+    view_5 = torch.ops.aten.view.default(getitem_4, [20, 100]); getitem_4 = None
+    mean_1 = torch.ops.aten.mean.dim(view_5, [0]); view_5 = None
    copy_1 = torch.ops.aten.copy_.default(alias_1, mean_1); alias_1 = mean_1 = None
-    view_5 = torch.ops.aten.view.default(getitem, [20, 100, 35, 45]); getitem = None
+    view_6 = torch.ops.aten.view.default(getitem, [20, 100, 35, 45]); getitem = None
    alias_2 = torch.ops.aten.alias.default(copy); copy = None
    copy_ = torch.ops.aten.copy_.default(a_1, alias_2); a_1 = alias_2 = None
-    return view_5
+    return view_6
""") # noqa: B950


