[MPS] Handle broadcasting by expanding src tensor in Copy.mm (pytorch#95272)

Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#95272
Approved by: https://github.com/DenisVieriu97
kulinseth authored and pytorchmergebot committed Feb 22, 2023
1 parent 5a8092f commit 02a6d43
Showing 2 changed files with 9 additions and 3 deletions.
11 changes: 8 additions & 3 deletions aten/src/ATen/native/mps/operations/Copy.mm
@@ -300,22 +300,27 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
   TORCH_CHECK(dst.defined(), "dst is undefined");
   TORCH_CHECK(src.defined(), "src is undefined");
 
+  bool needs_broadcasting = false;
+
   if (src.numel() == 0 || dst.is_same(src)) {
     return dst;
   }
   if (dst.numel() == 0) {
     dst.resize_as_(src);
   }
+  if (dst.dim() > src.dim()) {
+    needs_broadcasting = true;
+  }
 
   if (src.device().type() == at::kMPS && dst.device().type() == at::kCPU) {
-    return copy_from_mps_(dst, src, non_blocking);
+    return copy_from_mps_(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
   }
   if (src.device().type() == at::kCPU && dst.device().type() == at::kMPS) {
-    return copy_to_mps_(dst, src, non_blocking);
+    return copy_to_mps_(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
   }
 
   if (src.device().type() == at::kMPS && dst.device().type() == at::kMPS) {
-    return copy_kernel_mps(dst, src, non_blocking);
+    return copy_kernel_mps(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
   }
   TORCH_INTERNAL_ASSERT(
       src.device().type() == DeviceType::MPS,
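For context, `src.expand_as(dst)` returns a broadcast view of `src` with `dst`'s shape (no data is copied), so each of the three copy paths above now receives a source whose dimensionality matches the destination whenever `dst.dim() > src.dim()`. A minimal sketch of the user-visible case this fixes, assuming an MPS-enabled build where `torch.backends.mps.is_available()` returns True:

```python
import torch

# Broadcasting copy: dst has more dimensions than src.
dst = torch.zeros(4, 3, device="mps")
src = torch.arange(3.0)  # 1-D CPU tensor, broadcast along dim 0

# This takes the copy_to_mps_ path; with this change the source is
# passed along as src.expand_as(dst) before the copy.
dst.copy_(src)
print(dst.cpu())  # each of the 4 rows equals tensor([0., 1., 2.])
```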
1 change: 1 addition & 0 deletions test/test_mps.py
@@ -9264,6 +9264,7 @@ class TestConsistency(TestCaseMPS):
         'isreal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'kron': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'linalg.matrix_norm': ['f16'],
+        'linalg.matrix_power': ['f32'],
         'linalg.svd': ['f32'],
         'linalg.vector_norm': ['f16', 'f32'],
         'linspace': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
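The new allow-list entry enables the float32 consistency test for `torch.linalg.matrix_power` on MPS, an op that relies on the broadcast-aware copy above. Roughly, `TestConsistency` checks each listed op's MPS output against a CPU reference; a simplified sketch of that comparison (not the actual harness code), again assuming an MPS-enabled build:

```python
import torch

a = torch.randn(3, 3)
expected = torch.linalg.matrix_power(a, 3)          # CPU reference: a @ a @ a
actual = torch.linalg.matrix_power(a.to("mps"), 3)  # same op on the MPS device
torch.testing.assert_close(actual.cpu(), expected)  # results must agree
```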
