Skip to content

Commit

Permalink
CUDA 11.0 issue (#290)
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Jan 7, 2021
1 parent 6f2b4be commit 6db3c49
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 28 deletions.
30 changes: 22 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ We visualized a sparse tensor network operation on a sparse tensor, convolution,
## Requirements

- Ubuntu >= 14.04
- 11.1 > CUDA >= 10.1.243
- CUDA >= 11.1 or 11.0 > CUDA >= 10.1.243, [No CUDA 11.0](https://github.com/NVIDIA/MinkowskiEngine/issues/290)
- pytorch >= 1.5
- python >= 3.6
- GCC >= 7
Expand All @@ -77,23 +77,22 @@ First, install pytorch following the [instruction](https://pytorch.org). Next, i
```
sudo apt install libopenblas-dev
pip install torch
pip install -U MinkowskiEngine --install-option="--blas=openblas" -v
pip install -U MinkowskiEngine --install-option="--blas=openblas" -v --no-deps
# For pip installation from the latest source
# pip install -U git+https://github.com/NVIDIA/MinkowskiEngine
# pip install -U git+https://github.com/NVIDIA/MinkowskiEngine --no-deps
```

If you want to specify arguments for the setup script, please refer to the following command.

```
# Uncomment some options if things don't work
pip install -U git+https://github.com/NVIDIA/MinkowskiEngine \
# export CUDA_HOME=/usr/local/cuda-11.1; # or select the correct cuda version on your system.
pip install -U git+https://github.com/NVIDIA/MinkowskiEngine --no-deps \
# \ # uncomment the following line if you want to force cuda installation
# --install-option="--force_cuda" \
# \ # uncomment the following line if you want to force no cuda installation. force_cuda supercedes cpu_only
# --install-option="--cpu_only" \
# \ # uncomment the following line when torch fails to find cuda_home.
# --install-option="--cuda_home=/usr/local/cuda" \
# \ # uncomment the following line to override to openblas, atlas, mkl, blas
# --install-option="--blas=openblas" \
```
Expand All @@ -108,7 +107,7 @@ sudo apt install libopenblas-dev
conda create -n py3-mink python=3.8
conda activate py3-mink
conda install numpy mkl-include pytorch cudatoolkit=11.0 -c pytorch
pip install -U git+https://github.com/NVIDIA/MinkowskiEngine
pip install -U git+https://github.com/NVIDIA/MinkowskiEngine --no-deps
```

### System Python
Expand All @@ -131,7 +130,7 @@ cd MinkowskiEngine
python setup.py install
# To specify blas, CUDA_HOME and force CUDA installation, use the following command
# python setup.py install --blas=openblas --cuda_home=/usr/local/cuda --force_cuda
# export CUDA_HOME=/usr/local/cuda-11.1; python setup.py install --blas=openblas --force_cuda
```


Expand Down Expand Up @@ -222,13 +221,28 @@ page](https://github.com/NVIDIA/MinkowskiEngine/issues).

### Too much GPU memory usage or Frequent Out of Memory

There are a few causes for this error.

1. Out of memory during a long running training

MinkowskiEngine is a specialized library that can handle different number of points or different number of non-zero elements at every iteration during training, which is common in point cloud data.
However, pytorch is implemented assuming that the number of point, or size of the activations do not change at every iteration. Thus, the GPU memory caching used by pytorch can result in unnecessarily large memory consumption.

Specifically, pytorch caches chunks of memory spaces to speed up allocation used in every tensor creation. If it fails to find the memory space, it splits an existing cached memory or allocate new space if there's no cached memory large enough for the requested size. Thus, every time we use different number of point (number of non-zero elements) with pytorch, it either split existing cache or reserve new memory. If the cache is too fragmented and allocated all GPU space, it will raise out of memory error.

**To prevent this, you must clear the cache at regular interval with `torch.cuda.empty_cache()`.**

2. CUDA 11.0

There is a known CUDA issues that force torch to allocate exorbitant memory when used with MinkowskiEngine. For more detail please refer to [the issue #290](https://github.com/NVIDIA/MinkowskiEngine/issues/290). To fix, please install CUDA toolkit 11.1 and compile MinkowskiEngine.

```
wget https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run
sudo sh cuda_11.1.1_455.32.00_linux.run --toolkit --silent --override
# Install MinkowskiEngine with CUDA 11.1
export CUDA_HOME=/usr/local/cuda-11.1; pip install MinkowskiEngine -v --no-deps
```

### Running the MinkowskiEngine on nodes with a large number of CPUs

Expand Down
38 changes: 24 additions & 14 deletions docs/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,9 @@ The Minkowski Engine is an auto-differentiation library for sparse tensors. It s

## News

- 2020-12-24 v0.5 is now available! Try with `pip install git+https://github.com/NVIDIA/MinkowskiEngine.git`.
- 2020-12-24 v0.5 is now available!
- 2020-08-18 v0.5 beta version is now available! [The speedup compared with v0.4.3 ranges from x2 to x5 depending on the network architectures](https://github.com/chrischoy/MinkowskiEngineBenchmark). Please install with the following command. The migration guide from v0.4 to v0.5 is available at [the wiki page](https://github.com/NVIDIA/MinkowskiEngine/wiki/Migration-Guide-from-v0.4.x-to-0.5.x). Feel free to update the wiki page to add and update any discrepancies you see.

```
pip install git+https://github.com/NVIDIA/MinkowskiEngine.git
```

## Example Networks

The Minkowski Engine supports various functions that can be built on a sparse tensor. We list a few popular network architectures and applications here. To run the examples, please install the package and run the command in the package root directory.
Expand Down Expand Up @@ -58,7 +54,7 @@ We visualized a sparse tensor network operation on a sparse tensor, convolution,
## Requirements

- Ubuntu >= 14.04
- 11.1 > CUDA >= 10.1.243
- CUDA >= 11.1 or 11.0 > CUDA >= 10.1.243, [No CUDA 11.0](https://github.com/NVIDIA/MinkowskiEngine/issues/290)
- pytorch >= 1.5
- python >= 3.6
- GCC >= 7
Expand All @@ -81,23 +77,22 @@ First, install pytorch following the [instruction](https://pytorch.org). Next, i
```
sudo apt install libopenblas-dev
pip install torch
pip install -U MinkowskiEngine --install-option="--blas=openblas" -v
pip install -U MinkowskiEngine --install-option="--blas=openblas" -v --no-deps
# For pip installation from the latest source
# pip install -U git+https://github.com/NVIDIA/MinkowskiEngine
# pip install -U git+https://github.com/NVIDIA/MinkowskiEngine --no-deps
```

If you want to specify arguments for the setup script, please refer to the following command.

```
# Uncomment some options if things don't work
pip install -U git+https://github.com/NVIDIA/MinkowskiEngine \
# export CUDA_HOME=/usr/local/cuda-11.1; # or select the correct cuda version on your system.
pip install -U git+https://github.com/NVIDIA/MinkowskiEngine --no-deps \
# \ # uncomment the following line if you want to force cuda installation
# --install-option="--force_cuda" \
# \ # uncomment the following line if you want to force no cuda installation. force_cuda supercedes cpu_only
# --install-option="--cpu_only" \
# \ # uncomment the following line when torch fails to find cuda_home.
# --install-option="--cuda_home=/usr/local/cuda" \
# \ # uncomment the following line to override to openblas, atlas, mkl, blas
# --install-option="--blas=openblas" \
```
Expand All @@ -112,7 +107,7 @@ sudo apt install libopenblas-dev
conda create -n py3-mink python=3.8
conda activate py3-mink
conda install numpy mkl-include pytorch cudatoolkit=11.0 -c pytorch
pip install -U git+https://github.com/NVIDIA/MinkowskiEngine
pip install -U git+https://github.com/NVIDIA/MinkowskiEngine --no-deps
```

### System Python
Expand All @@ -135,7 +130,7 @@ cd MinkowskiEngine
python setup.py install
# To specify blas, CUDA_HOME and force CUDA installation, use the following command
# python setup.py install --blas=openblas --cuda_home=/usr/local/cuda --force_cuda
# export CUDA_HOME=/usr/local/cuda-11.1; python setup.py install --blas=openblas --force_cuda
```


Expand Down Expand Up @@ -170,7 +165,7 @@ class ExampleNetwork(ME.MinkowskiNetwork):
kernel_size=3,
stride=2,
dilation=1,
bias=False,
has_bias=False,
dimension=D),
ME.MinkowskiBatchNorm(64),
ME.MinkowskiReLU())
Expand Down Expand Up @@ -226,13 +221,28 @@ page](https://github.com/NVIDIA/MinkowskiEngine/issues).

### Too much GPU memory usage or Frequent Out of Memory

There are a few causes for this error.

1. Out of memory during a long running training

MinkowskiEngine is a specialized library that can handle different number of points or different number of non-zero elements at every iteration during training, which is common in point cloud data.
However, pytorch is implemented assuming that the number of point, or size of the activations do not change at every iteration. Thus, the GPU memory caching used by pytorch can result in unnecessarily large memory consumption.

Specifically, pytorch caches chunks of memory spaces to speed up allocation used in every tensor creation. If it fails to find the memory space, it splits an existing cached memory or allocate new space if there's no cached memory large enough for the requested size. Thus, every time we use different number of point (number of non-zero elements) with pytorch, it either split existing cache or reserve new memory. If the cache is too fragmented and allocated all GPU space, it will raise out of memory error.

**To prevent this, you must clear the cache at regular interval with `torch.cuda.empty_cache()`.**

2. CUDA 11.0

There is a known CUDA issues that force torch to allocate exorbitant memory when used with MinkowskiEngine. For more detail please refer to [the issue #290](https://github.com/NVIDIA/MinkowskiEngine/issues/290). To fix, please install CUDA toolkit 11.1 and compile MinkowskiEngine.

```
wget https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run
sudo sh cuda_11.1.1_455.32.00_linux.run --toolkit --silent --override
# Install MinkowskiEngine with CUDA 11.1
export CUDA_HOME=/usr/local/cuda-11.1; pip install MinkowskiEngine -v --no-deps
```

### Running the MinkowskiEngine on nodes with a large number of CPUs

Expand Down
4 changes: 2 additions & 2 deletions src/broadcast_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,9 @@ void BroadcastBackwardKernelGPU(
#if defined(CUDART_VERSION) && (CUDART_VERSION < 10010)
TORCH_CHECK(false, "spmm sparse-dense requires CUDA 10.1 or greater");
#elif defined(CUDART_VERSION) && (CUDART_VERSION >= 10010) && \
(CUDART_VERSION < 11010)
(CUDART_VERSION < 11000)
mm_alg = CUSPARSE_MM_ALG_DEFAULT;
#elif defined(CUDART_VERSION) && (CUDART_VERSION >= 11010)
#elif defined(CUDART_VERSION) && (CUDART_VERSION >= 11000)
mm_alg = CUSPARSE_SPMM_ALG_DEFAULT;
#endif

Expand Down
4 changes: 2 additions & 2 deletions src/pooling_avg_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,10 @@ void NonzeroAvgPoolingForwardKernelGPU(
#if defined(CUDART_VERSION) && (CUDART_VERSION < 10010)
ASSERT(false, "spmm sparse-dense requires CUDA 10.1 or greater");
#elif defined(CUDART_VERSION) && (CUDART_VERSION >= 10010) && \
(CUDART_VERSION < 11010)
(CUDART_VERSION < 11000)
mm_alg = CUSPARSE_COOMM_ALG1;
static_assert(is_int32, "int64 cusparseSpMM requires CUDA 11.1 or greater");
#elif defined(CUDART_VERSION) && (CUDART_VERSION >= 11010)
#elif defined(CUDART_VERSION) && (CUDART_VERSION >= 11000)
mm_alg = CUSPARSE_SPMM_COO_ALG1;
static_assert(is_int32 || is_int64, "Invalid index type");
#endif
Expand Down
4 changes: 2 additions & 2 deletions src/spmm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols,
#if defined(CUDART_VERSION) && (CUDART_VERSION < 10010)
TORCH_CHECK(false, "spmm sparse-dense requires CUDA 10.1 or greater");
#elif defined(CUDART_VERSION) && (CUDART_VERSION >= 10010) && \
(CUDART_VERSION < 11010)
(CUDART_VERSION < 11000)
switch (spmm_algorithm_id) {
case 1:
mm_alg = CUSPARSE_COOMM_ALG1;
Expand All @@ -89,7 +89,7 @@ torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols,
mm_alg = CUSPARSE_MM_ALG_DEFAULT;
}
TORCH_CHECK(is_int32, "int64 cusparseSpMM requires CUDA 11.1 or greater");
#elif defined(CUDART_VERSION) && (CUDART_VERSION >= 11010)
#elif defined(CUDART_VERSION) && (CUDART_VERSION >= 11000)
switch (spmm_algorithm_id) {
case 1:
mm_alg = CUSPARSE_SPMM_COO_ALG1;
Expand Down

0 comments on commit 6db3c49

Please sign in to comment.