Commit: gpu memory manager with cudamalloc
chrischoy committed Dec 19, 2019
1 parent 0edbdaa commit 8195f63
Showing 17 changed files with 133 additions and 191 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@
- CoordsMap size initialization updates
- Added MinkowskiUnion
- Updated MinkowskiUnion, MinkowskiPruning docs
- Use `cudaMalloc` instead of `at::Tensor` for GPU memory management to fix illegal memory access and invalid argument errors.


## [0.3.1] - 2019-12-15
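The entry above summarizes the switch from `at::Tensor`-backed scratch buffers to raw CUDA allocations. A minimal sketch of the allocation pattern the commit adopts, assuming the repository's CUDA_CHECK macro; the function name and the sizes are illustrative only:

void run_op_with_scratch(int nnz, int nrows) {
  // Raw device allocation instead of resizing a cached at::Tensor.
  void *d_scratch = nullptr;
  size_t bytes = 2 * nnz * sizeof(int)       // e.g. in/out maps
               + (nrows + 1) * sizeof(int);  // e.g. a CSR row pointer
  CUDA_CHECK(cudaMalloc(&d_scratch, bytes));
  // ... launch kernels that consume d_scratch ...
  cudaFree(d_scratch);  // released as soon as the op is done
}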
6 changes: 0 additions & 6 deletions pybind/minkowski.cpp
@@ -204,12 +204,6 @@ void instantiate(py::module &m) {
}

void bind_native(py::module &m) {
#ifndef CPU_ONLY
py::class_<GPUMemoryManager>(m, "MemoryManager")
.def(py::init<>())
.def("resize", &GPUMemoryManager::resize);
#endif

std::string name = std::string("CoordsKey");
py::class_<CoordsKey>(m, name.c_str())
.def(py::init<>())
16 changes: 1 addition & 15 deletions src/broadcast.cpp
@@ -122,20 +122,6 @@ void BroadcastBackwardGPU(at::Tensor in_feat, at::Tensor grad_in_feat,
grad_in_feat_glob.resize_as_(in_feat_glob);
grad_in_feat_glob.zero_();

const int nnz = getInOutMapsSize(p_coords_manager->d_in_maps[map_key]);

int *d_scr = (int *)p_coords_manager->getScratchGPUMemory(
2 * nnz * sizeof(int) + // in, out maps to sort
(in_feat_glob.size(0) + 1) * sizeof(int) // d_csr_row
);

Dtype *d_dscr = (Dtype *)p_coords_manager->getScratchGPUMemory2(
in_feat.size(0) * sizeof(Dtype) + // d_csr_val
in_feat.size(0) * in_feat.size(1) * sizeof(Dtype) + // tmp_grad_infeat
in_feat_glob.size(0) * in_feat.size(1) *
sizeof(Dtype) // tmp_grad_infeat_global
);

cusparseHandle_t handle = at::cuda::getCurrentCUDASparseHandle();
cusparseSetStream(handle, at::cuda::getCurrentCUDAStream());

@@ -144,7 +130,7 @@ void BroadcastBackwardGPU(at::Tensor in_feat, at::Tensor grad_in_feat,
in_feat_glob.data<Dtype>(), grad_in_feat_glob.data<Dtype>(),
in_feat_glob.size(0), grad_out_feat.data<Dtype>(), in_feat.size(1), op,
p_coords_manager->d_in_maps[map_key],
p_coords_manager->d_out_maps[map_key], d_scr, d_dscr, handle,
p_coords_manager->d_out_maps[map_key], handle,
at::cuda::getCurrentCUDAStream());
}
#endif
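The block removed above computed the scratch sizes at the call site and handed raw pointers to the kernel; after this commit the kernel allocates its own scratch. Hedged helpers, not part of the commit, spelling out the same arithmetic:

#include <cstddef>

// Integer scratch: the in/out maps (COO columns and rows) plus a CSR row
// pointer, which always needs one entry per row plus a terminating offset.
inline size_t broadcast_scratch_int_bytes(size_t nnz, size_t in_nrows_global) {
  return 2 * nnz * sizeof(int) + (in_nrows_global + 1) * sizeof(int);
}

// Dtype scratch: csr_val plus the two temporary gradient buffers.
template <typename Dtype>
size_t broadcast_scratch_dtype_bytes(size_t nnz, size_t in_nrows,
                                     size_t in_nrows_global, size_t nchannel) {
  return (nnz + (in_nrows + in_nrows_global) * nchannel) * sizeof(Dtype);
}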
74 changes: 36 additions & 38 deletions src/broadcast.cu
@@ -123,7 +123,7 @@ void BroadcastForwardKernelGPU(
}

CUDA_CHECK(cudaGetLastError());
// cudaFree(d_out_map);
CUDA_CHECK(cudaDeviceSynchronize());
}

template void BroadcastForwardKernelGPU<float, int32_t>(
@@ -139,16 +139,13 @@ template void BroadcastForwardKernelGPU<double, int32_t>(
cusparseHandle_t cuhandle, cudaStream_t stream);

template <typename Dtype, typename Itype>
void BroadcastBackwardKernelGPU(const Dtype *d_in_feat, Dtype *d_grad_in_feat,
int in_nrows, const Dtype *d_in_feat_global,
Dtype *d_grad_in_feat_global,
int in_nrows_global,
const Dtype *d_grad_out_feat, int nchannel,
int op, const pInOutMaps<Itype> &in_maps,
const pInOutMaps<Itype> &out_maps, Itype *d_scr,
Dtype *d_dscr, cusparseHandle_t cushandle,
cudaStream_t stream) {
Itype *d_in_map, *d_out_map, *d_csr_row;
void BroadcastBackwardKernelGPU(
const Dtype *d_in_feat, Dtype *d_grad_in_feat, int in_nrows,
const Dtype *d_in_feat_global, Dtype *d_grad_in_feat_global,
int in_nrows_global, const Dtype *d_grad_out_feat, int nchannel, int op,
const pInOutMaps<Itype> &in_maps, const pInOutMaps<Itype> &out_maps,
cusparseHandle_t cushandle, cudaStream_t stream) {
Itype *d_scr, *d_in_map, *d_out_map, *d_csr_row;
Dtype *d_dtype, *d_csr_val, *d_tmp_grad_in_feat_global, *d_tmp_grad_in_feat;
cusparseMatDescr_t descr = 0;
const Dtype alpha = 1;
@@ -164,32 +161,20 @@ void BroadcastBackwardKernelGPU(const Dtype *d_in_feat, Dtype *d_grad_in_feat,
if (in_maps[0].size() != in_nrows)
throw std::invalid_argument("Invalid in_map");

/* In Out Map prep */
// Malloc d_in_map, d_out_map, d_csr_row
// CSR returns n_row + 1
// CUDA_CHECK(cudaMalloc((void **)&d_in_map,
// (in_maps[0].size() + out_maps[0].size()
// + in_nrows_global + 1) * sizeof(Itype)));

// GPUMemoryManager<Dtype> dmem((nnz + (in_nrows + in_nrows_global) *
// nchannel)); CUDA_CHECK(cudaMalloc((void **)&d_dtype,
// (nnz + (in_nrows + in_nrows_global) * nchannel) *
// sizeof(Dtype)));
// d_dtype =
// (Dtype *)(d_scr + in_maps[0].size() + out_maps[0].size()
// + in_nrows_global + 1);

// Divide the memory space into multiple chunks
d_dtype = d_dscr;
d_tmp_grad_in_feat_global = d_dtype;
d_tmp_grad_in_feat = d_tmp_grad_in_feat_global + in_nrows_global * nchannel;
d_csr_val = d_tmp_grad_in_feat + in_nrows * nchannel;
CUDA_CHECK(cudaMalloc((void **)&d_scr,
2 * nnz * sizeof(Itype) + // in out maps
(in_nrows_global + 1) * sizeof(Itype) // d_csr_row
));

// COO cols
d_in_map = d_scr;
d_in_map = d_scr; // nnz
// COO rows
d_out_map = d_scr + nnz;
d_out_map = d_scr + nnz; // nnz
// CSR row indices
d_csr_row = d_scr + 2 * nnz;
d_csr_row = d_scr + 2 * nnz; // in_nrows_global + 1

CUDA_CHECK(cudaMemcpy(d_in_map,
in_maps[0].data(), // in_maps are contiguous of size nnz
Expand All @@ -200,6 +185,21 @@ void BroadcastBackwardKernelGPU(const Dtype *d_in_feat, Dtype *d_grad_in_feat,
out_maps[0].data(), // out_maps are contiguous of size nnz
nnz * sizeof(int), cudaMemcpyDeviceToDevice));

/* tmp in out feat */
// sparse gemm output
CUDA_CHECK(cudaMalloc(
(void **)&d_dtype,
nnz * sizeof(Dtype) + // d_csr_val
in_nrows * nchannel * sizeof(Dtype) + // tmp_grad_infeat
in_nrows_global * nchannel * sizeof(Dtype) // tmp_grad_infeat_global
));

// Divide the memory space into multiple chunks
d_tmp_grad_in_feat_global = d_dtype; // in_nrows_global * nchannel
d_tmp_grad_in_feat = d_tmp_grad_in_feat_global +
in_nrows_global * nchannel; // in_nrows * nchannel
d_csr_val = d_tmp_grad_in_feat + in_nrows * nchannel;

// thrust::fill(d_csr_val.begin(), d_csr_val.end(), 1);
fill<Dtype><<<GET_BLOCKS(in_nrows), CUDA_NUM_THREADS, 0, stream>>>(
nnz, d_csr_val, (Dtype)1.);
@@ -301,27 +301,25 @@ void BroadcastBackwardKernelGPU(const Dtype *d_in_feat, Dtype *d_grad_in_feat,
<< std::to_string(op));
}

cudaFree(d_scr);
cudaFree(d_dtype);
CUSPARSE_CHECK(cusparseDestroyMatDescr(descr));

CUDA_CHECK(cudaGetLastError());

// cudaFree(d_in_map);
// cudaFree(d_dtype);
CUDA_CHECK(cudaDeviceSynchronize());
}

template void BroadcastBackwardKernelGPU<float, int32_t>(
const float *d_in_feat, float *d_grad_in_feat, int in_nrows,
const float *d_in_feat_global, float *d_grad_in_feat_global,
int in_nrows_global, const float *d_grad_out_feat, int nchannel, int op,
const pInOutMaps<int32_t> &in_map, const pInOutMaps<int32_t> &out_map,
int32_t *d_scr, float *d_dscr, cusparseHandle_t cushandle,
cudaStream_t stream);
cusparseHandle_t cushandle, cudaStream_t stream);

template void BroadcastBackwardKernelGPU<double, int32_t>(
const double *d_in_feat, double *d_grad_in_feat, int in_nrows,
const double *d_in_feat_global, double *d_grad_in_feat_global,
int in_nrows_global, const double *d_grad_out_feat, int nchannel, int op,
const pInOutMaps<int32_t> &in_map, const pInOutMaps<int32_t> &out_map,
int32_t *d_scr, double *d_dscr, cusparseHandle_t cushandle,
cudaStream_t stream);
cusparseHandle_t cushandle, cudaStream_t stream);
#endif
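Above, d_csr_row is sized in_nrows_global + 1 because a CSR row pointer holds one offset per row plus a terminating entry. A hedged sketch of the conversion that consumes these buffers, assuming the kernel builds the CSR structure from the sorted COO row indices with cuSPARSE (consistent with the "in, out maps to sort" comments, but not quoted from this diff):

// Turn nnz sorted COO row indices into a CSR row-pointer array of
// in_nrows_global + 1 entries.
CUSPARSE_CHECK(cusparseXcoo2csr(cushandle,
                                d_out_map,        // sorted COO row indices
                                nnz,              // number of nonzeros
                                in_nrows_global,  // number of CSR rows
                                d_csr_row,        // in_nrows_global + 1 entries
                                CUSPARSE_INDEX_BASE_ZERO));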
3 changes: 1 addition & 2 deletions src/broadcast.cuh
@@ -49,7 +49,6 @@ void BroadcastBackwardKernelGPU(
const Dtype *d_in_feat_global, Dtype *d_grad_in_feat_global,
int in_nrows_global, const Dtype *d_grad_out_feat, int nchannel, int op,
const pInOutMaps<Itype> &d_in_map, const pInOutMaps<Itype> &d_out_map,
Itype *d_scr, Dtype *d_dscr, cusparseHandle_t cushandle,
cudaStream_t stream);
cusparseHandle_t cushandle, cudaStream_t stream);

#endif
2 changes: 1 addition & 1 deletion src/convolution.cu
@@ -133,7 +133,7 @@ __global__ void matmul2(const Dtype *A, const int wA, const int hA,
const int tx = threadIdx.x;
const int ty = threadIdx.y;

// Coordinate. x is for rows, y is for columns.
// Coordinate. y is for rows, x is for columns.
const int x = BLOCK_SIZE * bx + tx;
const int y = BLOCK_SIZE * by + ty;

7 changes: 4 additions & 3 deletions src/coords_manager.hpp
@@ -241,11 +241,12 @@ class CoordsManager {
getUnionInOutMapsGPU(vector<py::object> py_in_coords_keys,
py::object py_out_coords_key);

void *getScratchGPUMemory(int size) { return gpu_memory_manager.data(size); }
void *getScratchGPUMemory2(int size) {
return gpu_memory_manager.data2(size);
void *getScratchGPUMemory(size_t size) {
return gpu_memory_manager.tmp_data(size);
}

void clearScratchGPUMemory() { gpu_memory_manager.clear_tmp(); }

#endif // CPU_ONLY
};
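getScratchGPUMemory now takes a size_t and is backed by tmp_data, and the new clearScratchGPUMemory releases all temporary buffers at once. A hedged caller-side sketch, not taken from this commit; nnz and out_nrows are placeholders:

// Request temporary device memory through the coords manager, run the op,
// then drop every temporary buffer in a single call.
int *d_scr = (int *)p_coords_manager->getScratchGPUMemory(
    2 * nnz * sizeof(int) + (out_nrows + 1) * sizeof(int));
// ... launch kernels that read and write d_scr ...
p_coords_manager->clearScratchGPUMemory();  // cudaFree on every tmp_data buffer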

60 changes: 28 additions & 32 deletions src/gpu_memory_manager.hpp
@@ -8,57 +8,53 @@
#include "gpu.cuh"
#include "types.hpp"

#include <torch/extension.h>

using namespace std;

class GPUMemoryManager {
int initial_size = 256;
int device_id;
torch::TensorOptions options;

public:
// Scratch space, the user should keep track of the validity of the pointer
torch::Tensor scratch_data;
torch::Tensor scratch_data2; // cusparse_csrmm requires 2 memory spaces

// A set of data that will not be freed until the class is destroyed.
vector<torch::Tensor> vec_data;
vector<void *> persist_vec_ptr;
vector<void *> tmp_vec_ptr;

// Memory manager simply allocates and free memory when done.
GPUMemoryManager() {
CUDA_CHECK(cudaGetDevice(&device_id));
options = torch::TensorOptions()
.dtype(torch::kByte)
.device(torch::kCUDA, device_id)
.requires_grad(false);
scratch_data = torch::zeros({initial_size}, options).contiguous();
scratch_data2 = torch::zeros({initial_size}, options).contiguous();
}
GPUMemoryManager() { CUDA_CHECK(cudaGetDevice(&device_id)); }
GPUMemoryManager(int size) : initial_size(size) { GPUMemoryManager(); }
~GPUMemoryManager() {
for (auto p_buffer : persist_vec_ptr) {
cudaFree(p_buffer);
}
}

pInOutMaps<int> copyInOutMapToGPU(const InOutMaps<int> &map);

void resize(int size) { scratch_data.resize_({size}).contiguous(); }
void resize2(int size) { scratch_data2.resize_({size}).contiguous(); }

void *data(int size) {
if (scratch_data.numel() < size)
resize(size);
return (void *)scratch_data.data<unsigned char>();
void clear_tmp() {
for (auto p_buffer : tmp_vec_ptr) {
cudaFree(p_buffer);
}
tmp_vec_ptr.clear();
}

void *data2(int size) {
if (scratch_data2.numel() < size)
resize2(size);
return (void *)scratch_data2.data<unsigned char>();
void set_device() {
CUDA_CHECK(cudaSetDevice(device_id));
}

void *gpuMalloc(int size) {
torch::Tensor data = torch::zeros({size}, options).contiguous();
vec_data.push_back(data);
void *tmp_data(size_t size) {
void *p_buffer = NULL;
CUDA_CHECK(cudaSetDevice(device_id));
CUDA_CHECK(cudaMalloc(&p_buffer, size));
tmp_vec_ptr.push_back(p_buffer);
return p_buffer;
}

return (void *)data.data<unsigned char>();
void *gpuMalloc(size_t size) {
void *p_buffer = NULL;
CUDA_CHECK(cudaSetDevice(device_id));
CUDA_CHECK(cudaMalloc(&p_buffer, size));
persist_vec_ptr.push_back(p_buffer);
return p_buffer;
}
};
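The rewritten manager keeps two lifetimes apart: tmp_data allocations live until clear_tmp, while gpuMalloc allocations live until the manager is destroyed. A short usage sketch with illustrative sizes (nnz, nrows, nchannel are placeholders):

GPUMemoryManager mem;  // records the current CUDA device
int *d_maps = (int *)mem.tmp_data(2 * nnz * sizeof(int));                  // per-op scratch
float *d_feat = (float *)mem.gpuMalloc(nrows * nchannel * sizeof(float));  // persistent
// ... kernels ...
mem.clear_tmp();  // frees d_maps; d_feat stays valid until mem is destroyed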

16 changes: 1 addition & 15 deletions src/pooling_avg.cpp
@@ -112,27 +112,13 @@ void AvgPoolingForwardGPU(at::Tensor in_feat, at::Tensor out_feat,
num_nonzero_data = num_nonzero.data<Dtype>();
}

// int dtype_mult = dtypeMultiplier<Dtype, int>(), nnz = 0;
int nnz = getInOutMapsSize(in_out.first);

int *d_scr = (int *)p_coords_manager->getScratchGPUMemory(
2 * nnz * sizeof(int) + // in, out maps to sort
(out_nrows + 1) * sizeof(int)); // csr_row

Dtype *d_dscr = (Dtype *)p_coords_manager->getScratchGPUMemory2(
((use_avg ? in_feat.size(0) : 0) + // d_ones
nnz + // d_csr_val
in_feat.size(1) * out_nrows // d_tmp_out_feat
) *
sizeof(Dtype));

cusparseHandle_t handle = at::cuda::getCurrentCUDASparseHandle();
cusparseSetStream(handle, at::cuda::getCurrentCUDAStream());

NonzeroAvgPoolingForwardKernelGPU<Dtype, int>(
in_feat.data<Dtype>(), in_feat.size(0), out_feat.data<Dtype>(), out_nrows,
num_nonzero_data, in_feat.size(1), in_out.first, in_out.second, use_avg,
d_scr, d_dscr, handle, at::cuda::getCurrentCUDAStream());
handle, at::cuda::getCurrentCUDAStream());
}

template <typename Dtype>
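The block removed above also reserved a d_ones buffer when use_avg is set; the averaging denominator is simply the number of inputs mapped to each output row. An illustrative host-side analogue of that step (out_map, out_feat, and the max guard are assumptions, not the kernel's code):

#include <algorithm>
#include <vector>

// Count how many inputs feed each output row, then divide the accumulated sums.
std::vector<float> num_nonzero(out_nrows, 0.f);
for (int k = 0; k < nnz; ++k)
  num_nonzero[out_map[k]] += 1.f;  // equivalent to multiplying the 0/1 map by a ones vector
for (int row = 0; row < out_nrows; ++row)
  for (int c = 0; c < nchannel; ++c)
    out_feat[row * nchannel + c] /= std::max(num_nonzero[row], 1.f);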