Skip to content

Commit

Permalink
[MPS] Fix conv1d backwards crash for channels last case (pytorch#85283)
Browse files Browse the repository at this point in the history
Fixes pytorch#84511

Use the same logic as in the forward pass for the backward pass (in case of channels last, manually set the shape to NHWC)

Pull Request resolved: pytorch#85283
Approved by: https://github.com/malfet, https://github.com/razarmehr
  • Loading branch information
kulinseth authored and pytorchmergebot committed Sep 20, 2022
1 parent bcdef58 commit 077db3d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 17 deletions.
38 changes: 21 additions & 17 deletions aten/src/ATen/native/mps/operations/Convolution.mm
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
descriptor_.groups = groups;
}

// Returns the MPSShape to use for `tensor` in a convolution graph.
// For channels-last tensors, MPSGraph must be told the permuted layout
// explicitly, so the NCHW sizes are reordered into NHWC; otherwise the
// tensor's native shape is returned unchanged.
static MPSShape* get_mps_conv_shape(const Tensor& tensor, bool is_channels_last) {
  if (!is_channels_last) {
    return at::native::mps::getMPSShape(tensor);
  }
  // Reorder NCHW -> NHWC (indices 0, 2, 3, 1).
  const auto sizes = tensor.sizes();
  return @[@(static_cast<NSUInteger>(sizes[0])),
           @(static_cast<NSUInteger>(sizes[2])),
           @(static_cast<NSUInteger>(sizes[3])),
           @(static_cast<NSUInteger>(sizes[1]))];
}

Tensor _mps_convolution(
const Tensor& input_t,
const Tensor& weight_t,
Expand Down Expand Up @@ -126,19 +139,7 @@ Tensor _mps_convolution(
+ mps::getTensorsStringKey({input_t, weight_t}) + ":"
+ to_string(bias_defined) + ":" + bias_shape_key;
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
MPSShape* inputShape = nil;

if (is_channels_last) {
const auto inputSizes = input_t.sizes();
const NSUInteger N = inputSizes[0];
const NSUInteger C = inputSizes[1];
const NSUInteger H = inputSizes[2];
const NSUInteger W = inputSizes[3];
inputShape = @[@(N), @(H), @(W), @(C)];
} else {
inputShape = native_mps::getMPSShape(input_t);
}

MPSShape* inputShape = get_mps_conv_shape(input_t, is_channels_last);
if(!cachedGraph) {
native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {

Expand Down Expand Up @@ -333,6 +334,9 @@ Tensor mps_convolution_backward_weights(
using namespace mps;
CheckedFrom c = "mps_convolution_backward_weights";
auto memory_format = input_t.suggest_memory_format();
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
MPSShape* inputShape = get_mps_conv_shape(input_t, is_channels_last);
MPSShape* gradOutputShape = get_mps_conv_shape(grad_output_t, is_channels_last);

// For uniformity with everything else, although it seems grad_weight
// would be unambiguous too.
Expand Down Expand Up @@ -399,8 +403,8 @@ Tensor mps_convolution_backward_weights(
padding[1], padding[0],
memory_format, groups);

MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, grad_output_t);
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape);

MPSGraphTensor* gradWeightTensor = [mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensor
sourceTensor:inputTensor
Expand All @@ -417,8 +421,8 @@ Tensor mps_convolution_backward_weights(
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}

auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t);
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t, inputShape);
auto outputPlaceholder = Placeholder(cachedGraph->gradWeightTensor_, grad_weight_t);

NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
Expand Down
13 changes: 13 additions & 0 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6005,6 +6005,19 @@ def test_conv_transpose_1d_nn_functional(self):

self.assertEqual(tcpu, tgpu.cpu(), rtol=2.6e-05, atol=2e-04)

def test_conv_backward_1d_channels_last(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/84511
    # (conv1d backward crashed on MPS for channels-last-shaped inputs).
    conv_cpu = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)
    conv_mps = copy.deepcopy(conv_cpu).to(device='mps')

    data = torch.rand(1, 176, 1, dtype=torch.float32)
    x_cpu = data.permute(0, 2, 1).contiguous()
    x_mps = data.permute(0, 2, 1).contiguous().to("mps")

    # Note: Tensor.backward() returns None, so comparing its return values
    # would be a vacuous None == None check. Run forward + backward on both
    # devices, then compare the accumulated parameter gradients instead.
    conv_cpu(x_cpu).sum().backward()
    conv_mps(x_mps).sum().backward()

    self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad.cpu())
    self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad.cpu())

def test_conv1d_contiguous(self):
model_cpu = torch.nn.Conv1d(1, 128, 3)
a_cpu = torch.ones(128, 1, 176)
Expand Down

0 comments on commit 077db3d

Please sign in to comment.