From 105f7205bd487ae59cc14685ed330898c50f2f2c Mon Sep 17 00:00:00 2001 From: Kulin Seth Date: Thu, 9 Feb 2023 19:29:07 +0000 Subject: [PATCH] [MPS] Fix and unblock TestConsistency for median (#94489) - fix num_output_dims calculation - fix median_out_mps key - cast tensor sent to sortWithTensor and argSortWithTensor - note down same issue for unique - unblock median from blocklist - adding test_median_int16 test Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/94489 Approved by: https://github.com/razarmehr --- .../ATen/native/mps/operations/ReduceOps.mm | 67 +++++++++++++------ aten/src/ATen/native/mps/operations/Unique.mm | 2 +- test/test_mps.py | 11 +++ 3 files changed, 58 insertions(+), 22 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 88df3af523e8b..6f3b8d79f2c5a 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -1751,11 +1751,21 @@ Tensor median_mps(const Tensor& input_t) { @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); auto reshapedTensor = [mpsGraph reshapeTensor: inputTensor withShape: @[@-1] name: nil]; + MPSDataType dataType = [inputTensor dataType]; + // #issue 104398441 sortWithTensor only supports following types, cast if necessary + if (dataType != MPSDataTypeInt32 && + dataType != MPSDataTypeFloat32 && + dataType != MPSDataTypeFloat16) { + dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32; + reshapedTensor = [mpsGraph castTensor:reshapedTensor + toType:dataType + name:@"castReshapedTensor"]; + } + auto sortedTensor = [mpsGraph sortWithTensor: reshapedTensor axis: ((NSUInteger) (int)0) name: nil]; @@ -1835,7 +1845,7 @@ void median_out_mps( auto stream = at::mps::getCurrentMPSStream(); @autoreleasepool { - string key = func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t); + string key = func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + getTensorsStringKey(indices_t); CachedGraph* cachedGraph = cache_->LookUpAs(key); if (!cachedGraph) { @@ -1847,24 +1857,39 @@ void median_out_mps( auto mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); - auto sortedTensor = [mpsGraph sortWithTensor: inputTensor - axis: (NSUInteger)dim_ - name: nil]; - const NSUInteger midpoint = (dim_total_elements + 1) / 2 - 1; - auto outputTensor = [mpsGraph sliceTensor:sortedTensor - dimension:dim_ - start:midpoint - length:1 - name:nil]; - auto argreduceOutTensor = [mpsGraph argSortWithTensor:inputTensor - axis:(NSInteger)dim_ - name:@"argmax_out"]; - auto argOutputTensor = [mpsGraph sliceTensor:argreduceOutTensor - dimension:dim_ - start:midpoint - length:1 - name:nil]; + MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type())); + MPSGraphTensor* outputTensor = nil; + MPSGraphTensor* castInputTensor = inputTensor; + MPSDataType dataType = getMPSDataType(input_t.scalar_type()); + // #issue 104398441 sortWithTensor only supports following types, cast if necessary + if (dataType != MPSDataTypeInt32 && + dataType != MPSDataTypeFloat32 && + dataType != MPSDataTypeFloat16) { + dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32; + castInputTensor = [mpsGraph castTensor:inputTensor + toType:dataType + name:@"castInputTensor"]; + } + + MPSGraphTensor * sortedTensor = [mpsGraph + sortWithTensor:castInputTensor + axis:((NSUInteger) (int)dim_) + name:nil]; + + outputTensor = [mpsGraph sliceTensor:sortedTensor + dimension:dim_ + start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1) + length:1 + name:nil]; + MPSGraphTensor* argreduceOutTensor = nil; + argreduceOutTensor = [mpsGraph argSortWithTensor:castInputTensor + axis:(NSInteger)dim_ + name:@"argmax_out"]; + MPSGraphTensor* argOutputTensor = [mpsGraph sliceTensor:argreduceOutTensor + dimension:dim_ + start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1) + length:1 + name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; @@ -1934,7 +1959,7 @@ void median_out_mps( int64_t num_input_dims = input_shape.size(); NSMutableArray *apparent_out_shape = nil; // Use this if keepdim is false - int64_t num_output_dims = num_input_dims - 1; + int64_t num_output_dims = num_input_dims - 1 < 0 ? 0 : num_input_dims - 1; std::vector vec_apparent_out_shape(num_input_dims); std::vector vec_out_shape(num_output_dims); diff --git a/aten/src/ATen/native/mps/operations/Unique.mm b/aten/src/ATen/native/mps/operations/Unique.mm index 109244b73c036..eac16a74564ee 100644 --- a/aten/src/ATen/native/mps/operations/Unique.mm +++ b/aten/src/ATen/native/mps/operations/Unique.mm @@ -57,7 +57,7 @@ return {resultTensor, inverseIndicesTensor, countTensor, lengthTensor}; } - // Sort only supports following types, cast if necessary + // #issue 104398441 sortWithTensor only supports following types, cast if necessary if (dataType != MPSDataTypeInt32 && dataType != MPSDataTypeFloat32 && dataType != MPSDataTypeFloat16) { diff --git a/test/test_mps.py b/test/test_mps.py index 9002a0a879b25..3cd98df54cf58 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -2325,6 +2325,17 @@ def helper(dtype, noncontiguous, dim): helper(dtype, noncontiguous, dim) + def test_median_int16(self): + def helper(shape, dtype): + cpu_x = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype) + x = cpu_x.detach().clone().to('mps') + + median_result = torch.median(x) + median_result_cpu = torch.median(cpu_x) + self.assertEqual(median_result, median_result_cpu) + + helper((2, 8, 4, 5), torch.int16) + class TestLogical(TestCase): def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False): return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)