Skip to content

Commit

Permalink
[MPS] Fix the crash in bitwise ops on x86 platforms. (pytorch#85285)
Browse files Browse the repository at this point in the history
Fixes a crash in MPS bitwise ops on x86 platforms by explicitly setting the Metal language version (2.3) when compiling the kernel library, and by skipping kernel dispatch entirely for zero-length outputs.

Pull Request resolved: pytorch#85285
Approved by: https://github.com/razarmehr, https://github.com/malfet
  • Loading branch information
kulinseth authored and pytorchmergebot committed Sep 20, 2022
1 parent 6c48a01 commit bcdef58
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion aten/src/ATen/native/mps/operations/BitwiseOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,10 @@ kernel void bitwise_not(constant uint& length [[buffer(0)]],
return it->second;
}
NSError *error = nil;
MTLCompileOptions *options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion: MTLLanguageVersion2_3];
auto rc = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(BITWISE_OPS_TEMPLATE, t1, t2, t3).c_str()]
options:nil
options:options
error:&error];
TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]);
libMap[key] = rc;
Expand Down Expand Up @@ -170,6 +172,9 @@ void handle_tensor_tensor_binary_op(const at::Tensor& self, const at::Tensor& ot
getMetalType(other),
kernel_name);
uint32_t length = output.numel();
if (length == 0) {
return;
}
dispatch_sync(stream->queue(), ^(){
id<MTLCommandBuffer> buffer = stream->commandBuffer();
id<MTLComputeCommandEncoder> commandEncoder = [buffer computeCommandEncoder];
Expand Down Expand Up @@ -200,6 +205,9 @@ void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& ot
kernel_name);
uint64_t sval = other.to<int64_t>();
uint32_t length = output.numel();
if (length == 0) {
return;
}
dispatch_sync(stream->queue(), ^(){
id<MTLCommandBuffer> buffer = stream->commandBuffer();
id<MTLComputeCommandEncoder> commandEncoder = [buffer computeCommandEncoder];
Expand Down

0 comments on commit bcdef58

Please sign in to comment.