[MPS] Fix and unblock TestConsistency for median (pytorch#94489)
- fix the num_output_dims calculation
- fix the median_out_mps graph cache key
- cast tensors passed to sortWithTensor and argSortWithTensor
- note the same issue in Unique.mm
- remove median from the TestConsistency blocklist
- add a test_median_int16 test

Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#94489
Approved by: https://github.com/razarmehr
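
The user-visible effect: torch.median on MPS now matches CPU for dtypes that sortWithTensor cannot consume directly, such as int16. A minimal check mirroring the new test, assuming a macOS build with the MPS backend available:

    import torch

    # Mirrors test_median_int16 below; assumes torch.backends.mps.is_available().
    cpu_x = torch.randint(-9999, 9999, (2, 8, 4, 5), dtype=torch.int16)
    mps_x = cpu_x.detach().clone().to("mps")

    # int16 is cast to int32 internally before sorting, so the global
    # median now agrees with the CPU result.
    assert torch.median(mps_x).cpu() == torch.median(cpu_x)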
kulinseth authored and pytorchmergebot committed Feb 9, 2023
1 parent 69e0bda commit 105f720
Showing 3 changed files with 58 additions and 22 deletions.
67 changes: 46 additions & 21 deletions aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -1751,11 +1751,21 @@ Tensor median_mps(const Tensor& input_t) {
       @autoreleasepool {
         MPSGraph* mpsGraph = make_mps_graph();
         newCachedGraph = new CachedGraph(mpsGraph);
 
         auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
         auto reshapedTensor = [mpsGraph reshapeTensor: inputTensor
                                             withShape: @[@-1]
                                                  name: nil];
+        MPSDataType dataType = [inputTensor dataType];
+        // #issue 104398441 sortWithTensor only supports following types, cast if necessary
+        if (dataType != MPSDataTypeInt32 &&
+            dataType != MPSDataTypeFloat32 &&
+            dataType != MPSDataTypeFloat16) {
+          dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
+          reshapedTensor = [mpsGraph castTensor:reshapedTensor
+                                         toType:dataType
+                                           name:@"castReshapedTensor"];
+        }
         auto sortedTensor = [mpsGraph sortWithTensor: reshapedTensor
                                                 axis: ((NSUInteger) (int)0)
                                                 name: nil];
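
The cast mirrors the workaround already in Unique.mm: sortWithTensor and argSortWithTensor only accept MPSDataTypeInt32, MPSDataTypeFloat32, and MPSDataTypeFloat16, so anything else is first cast to the nearest supported type. A rough Python sketch of the rule (illustrative names, not a real API):

    SORT_SUPPORTED = {"int32", "float32", "float16"}

    def sort_cast_target(dtype: str, is_floating: bool) -> str:
        # Matches the (dataType & MPSDataTypeFloatBit) branch above:
        # unsupported float types (e.g. bfloat16) fall back to float32,
        # unsupported integer/bool types (e.g. int16, uint8) to int32.
        if dtype in SORT_SUPPORTED:
            return dtype
        return "float32" if is_floating else "int32"

    assert sort_cast_target("int16", is_floating=False) == "int32"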
@@ -1835,7 +1845,7 @@ void median_out_mps(
   auto stream = at::mps::getCurrentMPSStream();
 
   @autoreleasepool {
-    string key = func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t);
+    string key = func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + getTensorsStringKey(indices_t);
     CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
 
     if (!cachedGraph) {
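
Adding indices_t to the key matters because these graphs are memoized per process: under the old key, two median_out_mps calls that differ only in the index-output tensor could collide on the same cached graph. Conceptually (a sketch, with tensor_key values standing in for getTensorsStringKey output):

    def graph_cache_key(func_name: str, dim: int, input_key: str, indices_key: str) -> str:
        # Each component encodes dtype and shape, so graphs built for
        # different index tensors no longer alias each other in the cache.
        return f"{func_name}:{dim}:{input_key}:{indices_key}"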
@@ -1847,24 +1857,39 @@ void median_out_mps(
       auto mpsGraph = make_mps_graph();
       newCachedGraph = new CachedGraph(mpsGraph);
 
-      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
-      auto sortedTensor = [mpsGraph sortWithTensor: inputTensor
-                                              axis: (NSUInteger)dim_
-                                              name: nil];
-      const NSUInteger midpoint = (dim_total_elements + 1) / 2 - 1;
-      auto outputTensor = [mpsGraph sliceTensor:sortedTensor
-                                      dimension:dim_
-                                          start:midpoint
-                                         length:1
-                                           name:nil];
-      auto argreduceOutTensor = [mpsGraph argSortWithTensor:inputTensor
-                                                       axis:(NSInteger)dim_
-                                                       name:@"argmax_out"];
-      auto argOutputTensor = [mpsGraph sliceTensor:argreduceOutTensor
-                                         dimension:dim_
-                                             start:midpoint
-                                            length:1
-                                              name:nil];
+      MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()));
+      MPSGraphTensor* outputTensor = nil;
+      MPSGraphTensor* castInputTensor = inputTensor;
+      MPSDataType dataType = getMPSDataType(input_t.scalar_type());
+      // #issue 104398441 sortWithTensor only supports following types, cast if necessary
+      if (dataType != MPSDataTypeInt32 &&
+          dataType != MPSDataTypeFloat32 &&
+          dataType != MPSDataTypeFloat16) {
+        dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
+        castInputTensor = [mpsGraph castTensor:inputTensor
+                                        toType:dataType
+                                          name:@"castInputTensor"];
+      }
+
+      MPSGraphTensor * sortedTensor = [mpsGraph
+                                        sortWithTensor:castInputTensor
+                                                  axis:((NSUInteger) (int)dim_)
+                                                  name:nil];
+
+      outputTensor = [mpsGraph sliceTensor:sortedTensor
+                                 dimension:dim_
+                                     start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1)
+                                    length:1
+                                      name:nil];
+      MPSGraphTensor* argreduceOutTensor = nil;
+      argreduceOutTensor = [mpsGraph argSortWithTensor:castInputTensor
+                                                  axis:(NSInteger)dim_
+                                                  name:@"argmax_out"];
+      MPSGraphTensor* argOutputTensor = [mpsGraph sliceTensor:argreduceOutTensor
+                                                    dimension:dim_
+                                                        start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1)
+                                                       length:1
+                                                         name:nil];
 
       newCachedGraph->inputTensor_ = inputTensor;
       newCachedGraph->outputTensor_ = outputTensor;
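
The graph computes the lower median by sorting along dim and slicing out the element at index (dim_total_elements + 1) / 2 - 1; argSortWithTensor yields the matching indices. The same computation sketched with ordinary PyTorch ops, runnable on CPU:

    import torch

    def median_via_sort(x: torch.Tensor, dim: int):
        # Lower median: the same slice the MPS graph takes after
        # sortWithTensor / argSortWithTensor.
        mid = (x.size(dim) + 1) // 2 - 1
        values, indices = torch.sort(x, dim=dim)
        return values.narrow(dim, mid, 1), indices.narrow(dim, mid, 1)

    x = torch.randn(2, 8, 4, 5)
    vals, _ = median_via_sort(x, dim=1)
    torch.testing.assert_close(vals, torch.median(x, dim=1, keepdim=True).values)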
@@ -1934,7 +1959,7 @@ void median_out_mps(
   int64_t num_input_dims = input_shape.size();
   NSMutableArray<NSNumber*> *apparent_out_shape = nil;
   // Use this if keepdim is false
-  int64_t num_output_dims = num_input_dims - 1;
+  int64_t num_output_dims = num_input_dims - 1 < 0 ? 0 : num_input_dims - 1;
 
   std::vector<int64_t> vec_apparent_out_shape(num_input_dims);
   std::vector<int64_t> vec_out_shape(num_output_dims);
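
The ternary guards the 0-dimensional case: for a scalar input num_input_dims is 0, so the old num_input_dims - 1 produced -1, and vec_out_shape was constructed with a negative size (which wraps to a huge value for std::vector). It is just a clamp; in Python terms:

    num_output_dims = max(num_input_dims - 1, 0)  # a 0-d input keeps a 0-d output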
2 changes: 1 addition & 1 deletion aten/src/ATen/native/mps/operations/Unique.mm
@@ -57,7 +57,7 @@
     return {resultTensor, inverseIndicesTensor, countTensor, lengthTensor};
   }
 
-  // Sort only supports following types, cast if necessary
+  // #issue 104398441 sortWithTensor only supports following types, cast if necessary
   if (dataType != MPSDataTypeInt32 &&
       dataType != MPSDataTypeFloat32 &&
       dataType != MPSDataTypeFloat16) {
11 changes: 11 additions & 0 deletions test/test_mps.py
@@ -2325,6 +2325,17 @@ def helper(dtype, noncontiguous, dim):
                 helper(dtype, noncontiguous, dim)
 
 
+    def test_median_int16(self):
+        def helper(shape, dtype):
+            cpu_x = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype)
+            x = cpu_x.detach().clone().to('mps')
+
+            median_result = torch.median(x)
+            median_result_cpu = torch.median(cpu_x)
+            self.assertEqual(median_result, median_result_cpu)
+
+        helper((2, 8, 4, 5), torch.int16)
+
 class TestLogical(TestCase):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
         return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)
