diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index d6e510a06e322..44802e7d7a18e 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -138,7 +138,7 @@ void set_axes_and_shapes(const Tensor& input_t, } void reduction_out_mps - (const Tensor& input_t, + (const Tensor& input_tensor, OptionalIntArrayRef opt_dim, bool keepdim, c10::optional dtype, @@ -146,6 +146,8 @@ void set_axes_and_shapes(const Tensor& input_t, MPSReductionType reduction_type, const std::string& func_name) { + auto input_t = (input_tensor.sizes().size() == 0) ? input_tensor.view({1}) : input_tensor; + IntArrayRef input_shape = input_t.sizes(); if (opt_dim.has_value()) { @@ -391,14 +393,17 @@ Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){ } TORCH_IMPL_FUNC(norm_out_mps) -(const Tensor& input_t, +(const Tensor& input_tensor, const OptionalScalarRef opt_p, IntArrayRef dim, bool keepdim, const Tensor& output_t) { - if (input_t.numel() == 0) + if (input_tensor.numel() == 0) return; + + auto input_t = (input_tensor.sizes().size() == 0) ? input_tensor.view({1}) : input_tensor; + IntArrayRef input_shape = input_t.sizes(); for(int i = 0; i < dim.size(); i++) { diff --git a/test/test_mps.py b/test/test_mps.py index ccb3a294539b9..2a5c2d03c6046 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -6510,6 +6510,11 @@ class TestConsistency(TestCase): 'arange': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'argmax': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'argmin': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'amax': ['f32'], + 'amin': ['f32'], + 'logsumexp': ['f32'], + 'mean': ['f32'], + 'sum': ['f32'], 'asin': ['f32', 'i16', 'i32', 'u8'], 'asinh': ['f32', 'i16', 'i32', 'u8'], 'atan': ['f32', 'i16', 'i32', 'u8'],