Skip to content

Commit

Permalink
[MPS] Register index.Tensor_out (pytorch#82507)
Browse files Browse the repository at this point in the history
* Add more tests from test_indexing into test_mps
* Cache the indexing library on the MPSDevice
Pull Request resolved: pytorch#82507
Approved by: https://github.com/malfet
  • Loading branch information
kulinseth authored and pytorchmergebot committed Aug 18, 2022
1 parent 6dc8673 commit ce7177f
Show file tree
Hide file tree
Showing 8 changed files with 580 additions and 9 deletions.
132 changes: 132 additions & 0 deletions aten/src/ATen/mps/IndexKernels.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#pragma once

namespace at {
namespace mps {

static const char * indexing_metal_shaders = R"INDEX_METAL(
#include <metal_stdlib>
using namespace metal;
constant uint32_t num_indices [[function_constant(0)]];
struct IndexAB {
// Allow up to 16 indices
metal::array<constant const void *, 16> indexArray [[ id(0) ]];
};
template<typename T>
kernel void index_select(
constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]) {
constant const int64_t * index_sizes = (constant const int64_t *)indexSizes;
constant const int64_t * index_strides = (constant const int64_t *)indexStrides;
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
int64_t index = ((constant const int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
index += index_sizes[i];
}
offset += index * index_strides[i];
}
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x);
constant const T * in = (constant const T*)((constant const char*)inputData + offsets[thread_index].y + offset);
*out = *in;
}
template
[[host_name("index_select_float")]]
kernel void index_select<float>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_half")]]
kernel void index_select<half>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_long")]]
kernel void index_select<long>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_int")]]
kernel void index_select<int>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_short")]]
kernel void index_select<short>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_char")]]
kernel void index_select<char>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_uchar")]]
kernel void index_select<uchar>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_bool")]]
kernel void index_select<bool>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
kernel void kernel_index_offsets(constant const packed_uint3 * strides [[buffer(0)]],
device uint3 * data_offsets [[buffer(1)]],
constant const uint * iter_shape [[buffer(2)]],
constant const uint & num_dimensions [[buffer(3)]],
constant const uint & num_offsets [[buffer(4)]],
uint thread_index [[thread_position_in_grid]]) {
uint32_t idx = thread_index;
for (uint32_t dim = 0; dim < num_dimensions; dim++) {
uint32_t remainder = idx % iter_shape[dim];
idx /= iter_shape[dim];
for (uint32_t offset = 0; offset < num_offsets; offset++)
data_offsets[thread_index][offset] += remainder * strides[dim][offset];
}
}
)INDEX_METAL";
}
}
9 changes: 9 additions & 0 deletions aten/src/ATen/mps/MPSDevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
typedef id<MTLDevice> MTLDevice_t;
typedef id<MTLLibrary> MTLLibrary_t;
typedef id<MTLFunction> MTLFunction_t;
typedef MTLFunctionConstantValues* MTLFunctionConstantValues_t;
#else
typedef void* MTLDevice;
typedef void* MTLDevice_t;
typedef void* MTLLibrary_t;
typedef void* MTLFunction_t;
typedef void* MTLFunctionConstantValues_t;
#endif

using namespace std;
Expand Down Expand Up @@ -48,11 +54,14 @@ class TORCH_API MPSDevice {
return _mtl_device;
}

MTLFunction_t metalIndexingFunction(const std::string &kernel, MTLFunctionConstantValues_t constantValues);

~MPSDevice();

private:
static MPSDevice* _device;
MTLDevice_t _mtl_device;
MTLLibrary_t _mtl_indexing_library;
MPSDevice();
};

Expand Down
43 changes: 41 additions & 2 deletions aten/src/ATen/mps/MPSDevice.mm
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,65 @@
#include <c10/util/CallOnce.h>

#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/IndexKernels.h>

namespace at {
namespace mps {

static std::unique_ptr<MPSDevice> mps_device;
static c10::once_flag mpsdev_init;

static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device) {
// MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
// host_name attribute needs at least Metal 2.2
MTLLanguageVersion languageVersion = MTLLanguageVersion2_2;

TORCH_CHECK([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
return languageVersion;
}

MPSDevice* MPSDevice::getInstance() {
c10::call_once(mpsdev_init, [] {
mps_device = std::unique_ptr<MPSDevice>(new MPSDevice());
});
return mps_device.get();
}

id<MTLFunction> MPSDevice::metalIndexingFunction(const std::string& kernel, MTLFunctionConstantValues* constantValues) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
NSError* error = nil;
if (!_mtl_indexing_library) {
MTLCompileOptions *options = [MTLCompileOptions new];
[options setLanguageVersion: getMetalLanguageVersion(_mtl_device)];
[options setFastMathEnabled: YES];
_mtl_indexing_library = [_mtl_device newLibraryWithSource: [NSString stringWithCString: mps::indexing_metal_shaders encoding:NSASCIIStringEncoding]
options: options
error: &error];
TORCH_CHECK(_mtl_indexing_library, "Failed to create indexing library, error: ", [[error description] UTF8String]);
}

id<MTLFunction> indexFunction = nil;
if (constantValues) {
indexFunction = [[_mtl_indexing_library newFunctionWithName: [NSString stringWithUTF8String: kernel.c_str()]
constantValues: constantValues
error: &error] autorelease];
} else {
indexFunction = [[_mtl_indexing_library newFunctionWithName: [NSString stringWithUTF8String: kernel.c_str()]] autorelease];
}

TORCH_CHECK(indexFunction, "Failed to create specialized function state object: ", kernel, ", error: ", [[error description] UTF8String]);

return indexFunction;
}

MPSDevice::~MPSDevice() {
[_mtl_device release];
[_mtl_indexing_library release];
_mtl_device = nil;
_mtl_indexing_library = nil;
}

MPSDevice::MPSDevice(): _mtl_device(nil) {
MPSDevice::MPSDevice(): _mtl_device(nil), _mtl_indexing_library(nil) {
// Check that MacOS 12.3+ version of MPS framework is available
// Create the MPSGraph and check method introduced in 12.3+
// which is used by MPS backend.
Expand All @@ -45,7 +84,7 @@
break;
}
}
assert(_mtl_device);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device);
}

at::Allocator* getMPSSharedAllocator();
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/native/TensorAdvancedIndexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -521,9 +521,9 @@ AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
}
}

// For CUDA tensors, force all index tensors to have the same striding to
// simplify the CUDA kernel.
if (indices.size() >= 2 && this->src.device().type() == kCUDA) {
// For CUDA/MPS tensors, force all index tensors to have the same striding to
// simplify the CUDA/MPS kernel.
if (indices.size() >= 2 && (this->src.device().type() == kCUDA || this->src.device().type() == kMPS)) {
if (!all_strides_match(indices)) {
for (auto & indice : indices) {
indice = indice.contiguous();
Expand Down
51 changes: 51 additions & 0 deletions aten/src/ATen/native/mps/operations/Indexing.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright © 2022 Apple Inc.

#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/TensorFactory.h>
#include <c10/core/ScalarType.h>
#include <torch/library.h>
#include <unordered_map>

using namespace at::mps;

namespace at {
namespace native {
namespace mps {

std::string getMetalScalarType(ScalarType scalar_type) {
std::string res = "";
switch (scalar_type) {
case ScalarType::Float:
res = "float"; break;
case ScalarType::Half:
res = "half"; break;
case ScalarType::Long:
res = "long"; break;
case ScalarType::Int:
res = "int"; break;
case ScalarType::Short:
res = "short"; break;
case ScalarType::Char:
res = "char"; break;
case ScalarType::Byte:
res = "uchar"; break;
case ScalarType::Bool:
res = "bool"; break;
default:
break;
}
return res;
}

std::string getIndexFunctionName(ScalarType scalar_type, bool index_select, bool accumulate) {
std::string indexFunction = index_select ? "index_select_" :
(accumulate && (scalar_type != kBool)) ? "index_put_accumulate_" : "index_put_";

return indexFunction + getMetalScalarType(scalar_type);
}
}
}
}
Loading

0 comments on commit ce7177f

Please sign in to comment.