Skip to content

Commit

Permalink
Add meta implementation for aten.max.dim (pytorch#88005)
Browse files Browse the repository at this point in the history
  • Loading branch information
tugsbayasgalan authored and pytorchmergebot committed Nov 1, 2022
1 parent 97b3eea commit 2c7de4a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,6 @@ def f(a, b, c, d, e):
xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition
xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decomposition
xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition
xfail('max', 'reduction_with_dim'), # aten.max.dim - couldn't find symbolic meta function/decomposition
xfail('median', ''), # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau...
xfail('meshgrid', 'list_of_tensors'), # Tensors of type TensorImpl do not have numel
xfail('meshgrid', 'variadic_tensors'), # Tensors of type TensorImpl do not have numel
Expand Down
10 changes: 10 additions & 0 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,16 @@ def meta_max(self):
return self.new_empty(())


@register_meta(aten.max.dim)
def meta_max_dim(self, dim, keepdim=False):
dim = utils.reduction_dims(self.shape, (dim,))
output_shape = _compute_reduction_shape(self, dim, keepdim)
return (
self.new_empty(output_shape),
self.new_empty(output_shape, dtype=torch.long),
)


@register_meta([aten.min.default])
def meta_min(self):
return self.new_empty(())
Expand Down

0 comments on commit 2c7de4a

Please sign in to comment.