Skip to content

Commit

Permalink
Add meta kernel coverage for aten.unsafe_split, aten.unsafe_chunk (py…
Browse files Browse the repository at this point in the history
…torch#92608)

Pull Request resolved: pytorch#92608
Approved by: https://github.com/ngimel
  • Loading branch information
tugsbayasgalan authored and pytorchmergebot committed Jan 20, 2023
1 parent 274958e commit 4386f31
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
1 change: 0 additions & 1 deletion test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2582,7 +2582,6 @@ def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op):
})

symbolic_aot_autograd_module_failures = {
torch.nn.GRU, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.Transformer, # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
torch.nn.TransformerEncoder, # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
torch.nn.TransformerEncoderLayer, # RuntimeError: tried to get Double out of SymFloat
Expand Down
16 changes: 14 additions & 2 deletions torch/_decomp/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,7 @@ def prod(x: List[int]):
return r


@register_decomposition(aten.split_with_sizes)
@register_decomposition([aten.split_with_sizes, aten.unsafe_split_with_sizes])
def split_with_sizes(
self: Tensor, split_sizes: List[int], dim: int = 0
) -> List[Tensor]:
Expand All @@ -1099,7 +1099,7 @@ def split_with_sizes(
return splits


@register_decomposition(aten.split.Tensor)
@register_decomposition([aten.split.Tensor, aten.unsafe_split.Tensor])
def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
input_sizes = self.shape
dim_size = input_sizes[dim]
Expand Down Expand Up @@ -1462,6 +1462,18 @@ def native_batch_norm_decomposition(
)


@aten.unsafe_chunk.default.py_impl(DispatchKey.CompositeImplicitAutograd)
def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> List[Tensor]:
dim_size = tensor.size(dim)
split_size = (dim_size + chunks - 1) // chunks

if split_size == 0 and dim_size == 0:
split_sizes = [split_size for _ in chunks]
split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size)
return torch.ops.aten.unsafe_split_with_sizes.default(tensor, split_sizes, dim)
return torch.ops.aten.unsafe_split.Tensor(tensor, split_size, dim)


@register_decomposition(aten._native_batch_norm_legit.default)
def _native_batch_norm_legit(
input: Tensor,
Expand Down

0 comments on commit 4386f31

Please sign in to comment.