Skip to content

Commit

Permalink
functionalization: remove some unnecessary view_copies in inplace views
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#77713

Approved by: https://github.com/ezyang
  • Loading branch information
bdhirsh authored and pytorchmergebot committed May 26, 2022
1 parent 7ff091f commit e9c54ae
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 32 deletions.
4 changes: 1 addition & 3 deletions aten/src/ATen/FunctionalTensorWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,7 @@ void FunctionalTensorWrapper::mutate_view_meta(at::functionalization::ViewMeta m
// So, these ops are special - they're mutation AND view ops. They get special codegen.
// An example is transpose_, e.g. `a.transpose_()`
// Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
// We also need to force a sync (even if a is already up to date), because a's underlying tensor hasn't actually
// been updated to reflect the new view yet.
regenerate_from_base();
value_ = meta.forward_fn(value_, meta.out_index);
}

// Note [Functionalization: Mutation Removal]
Expand Down
51 changes: 22 additions & 29 deletions test/test_functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,36 +376,29 @@ def f(x):
$2 = torch._ops.aten.view_copy.default($1, [8])
$3 = torch._ops.aten._reshape_alias_copy.default($2, [2, 4], [4, 1])
$4 = torch._ops.aten.transpose_copy.int($3, 1, 0)
$5 = torch._ops.aten.view_copy.default($1, [8])
$6 = torch._ops.aten._reshape_alias_copy.default($5, [2, 4], [4, 1])
$7 = torch._ops.aten.transpose_copy.int($6, 1, 0)
$8 = torch._ops.aten.unsqueeze_copy.default($7, 0)
$9 = torch._ops.aten.view_copy.default($1, [8])
$10 = torch._ops.aten._reshape_alias_copy.default($9, [2, 4], [4, 1])
$11 = torch._ops.aten.transpose_copy.int($10, 1, 0)
$12 = torch._ops.aten.unsqueeze_copy.default($11, 0)
$13 = torch._ops.aten.squeeze_copy.default($12)
$14, $15 = torch._ops.aten.split_copy.Tensor($13, 2)
$16 = torch._ops.aten.add.Tensor($14, tensor([[1., 1.],
$5 = torch._ops.aten.unsqueeze_copy.default($4, 0)
$6 = torch._ops.aten.squeeze_copy.default($5)
$7, $8 = torch._ops.aten.split_copy.Tensor($6, 2)
$9 = torch._ops.aten.add.Tensor($7, tensor([[1., 1.],
[1., 1.]]))
$17 = torch._ops.aten.select_copy.int($3, 0, 0)
$18 = torch._ops.aten.clone.default($16, memory_format=torch.contiguous_format)
$19 = torch._ops.aten._unsafe_view.default($18, [4])
$20 = torch._ops.aten.view_copy.default($1, [8])
$21 = torch._ops.aten._reshape_alias_copy.default($20, [2, 4], [4, 1])
$22 = torch._ops.aten.transpose_copy.int($21, 1, 0)
$23 = torch._ops.aten.unsqueeze_copy.default($22, 0)
$24 = torch._ops.aten.squeeze_copy.default($23)
$25 = torch._ops.aten.slice_scatter.default($24, $16, 0, 0, 2)
$26 = torch._ops.aten.unsqueeze_copy.default($25, 0)
$27 = torch._ops.aten.squeeze_copy.dim($26, 0)
$28 = torch._ops.aten.transpose_copy.int($27, 1, 0)
$29 = torch._ops.aten._reshape_alias_copy.default($28, [8], [1])
$30 = torch._ops.aten.view_copy.default($29, [4, 2])
$31 = torch._ops.aten.view_copy.default($30, [8])
$32 = torch._ops.aten._reshape_alias_copy.default($31, [2, 4], [4, 1])
$33 = torch._ops.aten.select_copy.int($32, 0, 0)
$34 = torch._ops.aten.add.Tensor($33, $19)""")
$10 = torch._ops.aten.select_copy.int($3, 0, 0)
$11 = torch._ops.aten.clone.default($9, memory_format=torch.contiguous_format)
$12 = torch._ops.aten._unsafe_view.default($11, [4])
$13 = torch._ops.aten.view_copy.default($1, [8])
$14 = torch._ops.aten._reshape_alias_copy.default($13, [2, 4], [4, 1])
$15 = torch._ops.aten.transpose_copy.int($14, 1, 0)
$16 = torch._ops.aten.unsqueeze_copy.default($15, 0)
$17 = torch._ops.aten.squeeze_copy.default($16)
$18 = torch._ops.aten.slice_scatter.default($17, $9, 0, 0, 2)
$19 = torch._ops.aten.unsqueeze_copy.default($18, 0)
$20 = torch._ops.aten.squeeze_copy.dim($19, 0)
$21 = torch._ops.aten.transpose_copy.int($20, 1, 0)
$22 = torch._ops.aten._reshape_alias_copy.default($21, [8], [1])
$23 = torch._ops.aten.view_copy.default($22, [4, 2])
$24 = torch._ops.aten.view_copy.default($23, [8])
$25 = torch._ops.aten._reshape_alias_copy.default($24, [2, 4], [4, 1])
$26 = torch._ops.aten.select_copy.int($25, 0, 0)
$27 = torch._ops.aten.add.Tensor($26, $12)""")

def test_reapply_views_simple(self):
def f(x):
Expand Down

0 comments on commit e9c54ae

Please sign in to comment.