Skip to content

Commit

Permalink
network speed script, minor
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Dec 15, 2020
1 parent 537c00d commit a28ede6
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 15 deletions.
4 changes: 3 additions & 1 deletion MinkowskiEngine/MinkowskiCoordinateManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,4 +436,6 @@ def origin_map(self, key: CoordinateMapKey):
# self.CPPCoordsManager.printDiagnostics(coords_key.CPPCoordsKey)

def __repr__(self):
return "CoordinateManager(\n" + str(self._manager) + " )\n"
return (
self._CoordinateManagerClass.__name__ + "(\n" + str(self._manager) + " )\n"
)
3 changes: 2 additions & 1 deletion examples/minkunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class MinkUNetBase(ResNetBase):
PLANES = None
DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1)
LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
PLANES = (32, 64, 128, 256, 256, 128, 96, 96)
INIT_DIM = 32
OUT_TENSOR_STRIDE = 1

Expand Down Expand Up @@ -112,7 +113,7 @@ def network_initialization(self, in_channels, out_channels, D):
self.LAYERS[7])

self.final = ME.MinkowskiConvolution(
self.PLANES[7],
self.PLANES[7] * self.BLOCK.expansion,
out_channels,
kernel_size=1,
bias=True,
Expand Down
13 changes: 11 additions & 2 deletions src/3rdparty/hash/hash_allocator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ struct managed_allocator {
/**
 * Allocate storage for `n` objects of type T with a single cudaMalloc call.
 *
 * @param n      number of elements; the request is n * sizeof(T) bytes.
 * @param stream accepted for interface compatibility; unused by this body.
 * @return the pointer written by cudaMalloc.
 * @throws std::runtime_error if cudaMalloc does not return cudaSuccess.
 */
T* allocate(std::size_t n, cudaStream_t stream = 0) const
{
  T* d_tmp = nullptr;
  cudaError_t error = cudaMalloc((void**) &d_tmp, n * sizeof(T));
  if (error != cudaSuccess) {
    cudaGetLastError(); // clear CUDA error
    // The exception must actually be thrown; constructing it as a discarded
    // temporary would let the caller receive an invalid pointer.
    throw std::runtime_error("cudaMalloc failed in the hash_allocator.cuh:managed_allocator.");
  }
  return d_tmp;
  // return static_cast<T*>(mr->allocate(n * sizeof(T), stream));
}
Expand Down Expand Up @@ -73,7 +77,12 @@ struct default_allocator {
/**
 * Allocate storage for `n` objects of type T with a single cudaMalloc call.
 *
 * @param n      number of elements; the request is n * sizeof(T) bytes.
 * @param stream accepted for interface compatibility; unused by this body.
 * @return the pointer written by cudaMalloc.
 * @throws std::runtime_error if cudaMalloc does not return cudaSuccess.
 */
T* allocate(std::size_t n, cudaStream_t stream = 0) const
{
  T* d_tmp = nullptr;
  cudaError_t error = cudaMalloc((void**) &d_tmp, n * sizeof(T));
  if (error != cudaSuccess) {
    cudaGetLastError(); // clear CUDA error
    // The exception must actually be thrown; constructing it as a discarded
    // temporary would let the caller receive an invalid pointer.
    throw std::runtime_error("cudaMalloc failed in the hash_allocator.cuh:default_allocator.");
  }

  return d_tmp;
  // return static_cast<T*>(mr->allocate(n * sizeof(T), stream));
}
Expand Down
8 changes: 7 additions & 1 deletion src/allocators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ template <class T> struct default_allocator {

/**
 * Allocate storage for `n` objects of type T with cudaMalloc.
 *
 * If the first attempt fails, the CUDA error state is cleared, the c10 CUDA
 * caching allocator's cache is emptied, and the allocation is retried once
 * under CUDA_CHECK.
 *
 * @param n      number of elements; the request is n * sizeof(T) bytes.
 * @param stream accepted for interface compatibility; unused by this body.
 * @return the pointer written by the successful cudaMalloc call.
 */
T *allocate(std::size_t n, cudaStream_t stream = 0) const {
  T *d_tmp = nullptr;
  const std::size_t nbytes = n * sizeof(T);
  // Exactly one allocation on the success path: a second unconditional
  // cudaMalloc into the same pointer would leak the first buffer.
  cudaError_t error = cudaMalloc((void **)&d_tmp, nbytes);
  if (error != cudaSuccess) {
    cudaGetLastError(); // clear error
    c10::cuda::CUDACachingAllocator::emptyCache();
    LOG_DEBUG("Automatically called empty cache");
    // NOTE(review): CUDA_CHECK is defined elsewhere; presumably it reports
    // the failure if the retry also fails -- confirm.
    CUDA_CHECK(cudaMalloc((void **)&d_tmp, nbytes));
  }
  return d_tmp;
  // return static_cast<T*>(mr->allocate(n * sizeof(T), stream));
}
Expand Down
14 changes: 6 additions & 8 deletions src/convolution_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,6 @@ void ConvolutionForwardKernelGPU(
default_types::size_type const out_nrows, //
cublasHandle_t cuhandle, cudaStream_t stream) {

CUDA_CHECK_ARGS(cudaDeviceSynchronize(),
". Error triggered from a previous kernel call.");

size_t n_active_in_volume, shared_mem_size = -1;

// Define the shared memory size
Expand Down Expand Up @@ -386,7 +383,7 @@ void ConvolutionForwardKernelGPU(
#endif
CUDA_CHECK(cudaGetLastError());
}
CUDA_CHECK(cudaDeviceSynchronize());
CUDA_CHECK(cudaStreamSynchronize(stream));
}

// default_allocator
Expand Down Expand Up @@ -493,8 +490,8 @@ void ConvolutionBackwardKernelGPU(
<<<grid_out, threads, threads.x * sizeof(Itype), stream>>>(
n_active_in_volume, out_nchannel, d_grad_out_feat,
d_output_buffer, d_out_map);
CUDA_CHECK(cudaStreamSynchronize(stream));
#ifdef DEBUG
CUDA_CHECK(cudaStreamSynchronize(stream));
LOG_DEBUG("copy input", t.toc());
t.tic();
#endif
Expand All @@ -508,8 +505,8 @@ void ConvolutionBackwardKernelGPU(
0, // beta
d_input_buffer // C
);
CUDA_CHECK(cudaStreamSynchronize(0));
#ifdef DEBUG
CUDA_CHECK(cudaStreamSynchronize(0));
LOG_DEBUG("input grad gemm", t.toc());
t.tic();
#endif
Expand All @@ -525,8 +522,8 @@ void ConvolutionBackwardKernelGPU(
n_active_in_volume, // In channel
d_grad_in_feat, in_nchannel, // Out
d_in_map); // Out channel
CUDA_CHECK(cudaStreamSynchronize(stream));
#ifdef DEBUG
CUDA_CHECK(cudaStreamSynchronize(stream));
LOG_DEBUG("accumulate in grad", t.toc());
t.tic();
#endif
Expand All @@ -539,12 +536,13 @@ void ConvolutionBackwardKernelGPU(
<<<grid_in, threads, threads.x * sizeof(Itype), stream>>>(
n_active_in_volume, in_nchannel, d_in_feat, d_input_buffer,
d_in_map);
CUDA_CHECK(cudaStreamSynchronize(stream));
#ifdef DEBUG
LOG_DEBUG("copy in feat to buffer", t.toc());
t.tic();
#endif

// sync before the blas call
CUDA_CHECK(cudaStreamSynchronize(stream));
gpu_gemm<Dtype>(cuhandle, CblasTrans, CblasNoTrans,
in_nchannel, // M
out_nchannel, // N
Expand Down
1 change: 0 additions & 1 deletion src/pooling_max_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ void MaxPoolingForwardKernelGPU(

// Second, create number of in_feat per out, and starting index
Itype *d_index, *d_in_map_min, *d_reduced_out_map;
// CUDA_CHECK(cudaMalloc((void **)&d_index, 3 * nmap * sizeof(Itype)));
d_index = d_scr + 2 * nmap;
d_in_map_min = d_scr + 3 * nmap;
d_reduced_out_map = d_scr + 4 * nmap;
Expand Down
148 changes: 148 additions & 0 deletions tests/python/network_speed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
# of the code.
import os
import argparse
import numpy as np
from urllib.request import urlretrieve

try:
import open3d as o3d
except ImportError:
raise ImportError("Please install open3d with `pip install open3d`.")

import torch
import MinkowskiEngine as ME
from MinkowskiCommon import convert_to_int_list
import examples.minkunet as UNets
from tests.python.common import data_loader, load_file, batched_coordinates
from examples.common import Timer

# Make sure both the pretrained weights and the sample room scan are present
# in the working directory. Each file is checked on its own so that a missing
# "1.ply" is still fetched when "weights.pth" already exists.
_ASSETS = (
    ("http://cvgl.stanford.edu/data2/minkowskiengine/weights.pth", "weights.pth"),
    ("http://cvgl.stanford.edu/data2/minkowskiengine/1.ply", "1.ply"),
)
_missing_assets = [(_url, _path) for _url, _path in _ASSETS if not os.path.isfile(_path)]
if _missing_assets:
    print("Downloading weights and a room ply file...")
    for _url, _path in _missing_assets:
        urlretrieve(_url, _path)

# Command-line options for the benchmark, registered from a table of
# (flag, add_argument keyword arguments) pairs.
parser = argparse.ArgumentParser()
for _flag, _options in (
    ("--file_name", {"type": str, "default": "1.ply"}),
    ("--weights", {"type": str, "default": "weights.pth"}),
    ("--use_cpu", {"action": "store_true"}),
    ("--backward", {"action": "store_true"}),
    ("--max_batch", {"type": int, "default": 12}),
):
    parser.add_argument(_flag, **_options)


def quantize(coordinates):
    """Insert ``coordinates`` into a fresh CPU coordinate map.

    The spatial dimension ``D`` is taken as the number of columns minus one
    (presumably the remaining column is a batch index -- confirm against the
    caller).

    Returns:
        The ``(unique_map, inverse_map)`` pair produced by
        ``CoordinateManager.insert_and_map``.
    """
    D = coordinates.size(1) - 1
    manager = ME.CoordinateManager(D=D, coordinate_map_type=ME.CoordinateMapType.CPU)
    # Key built from convert_to_int_list(1, D) and an empty string id.
    map_key = ME.CoordinateMapKey(convert_to_int_list(1, D), "")
    _, maps = manager.insert_and_map(coordinates, *map_key.get_key())
    unique_map, inverse_map = maps
    return unique_map, inverse_map


def load_file(file_name, voxel_size):
    """Read a point cloud with Open3D and keep the rows selected by quantize().

    Args:
        file_name: path handed to ``o3d.io.read_point_cloud``.
        voxel_size: divisor applied to the point coordinates before flooring.

    Returns:
        ``(coords, feats, pcd)``: the floored integer coordinates and the
        float colors, both indexed by the first map returned from
        ``quantize``, plus the loaded Open3D point cloud.
    """
    pcd = o3d.io.read_point_cloud(file_name)
    points = torch.from_numpy(np.array(pcd.points))
    colors = torch.from_numpy(np.array(pcd.colors)).float()

    voxel_coords = torch.floor(points / voxel_size).int()
    # Only the first returned map is needed to select rows.
    keep, _ = quantize(voxel_coords)

    return voxel_coords[keep], colors[keep], pcd


def forward(coords, colors, model):
    """Time five forward passes of ``model``.

    Each timed iteration builds an ``ME.SparseTensor`` on the module-level
    ``device`` and feeds it to ``model``.

    Returns:
        ``(timer.min_time, len(logits))`` where ``logits`` is the output of
        the last pass.
    """
    stopwatch = Timer()
    for _ in range(5):
        stopwatch.tic()
        # Tensor construction is inside the timed region, as is the model call.
        sparse_input = ME.SparseTensor(
            features=colors,
            coordinates=coords,
            device=device,
            allocator_type=ME.GPUMemoryAllocatorType.PYTORCH,
        )
        logits = model(sparse_input)
        stopwatch.toc()
    return stopwatch.min_time, len(logits)


def train(coords, colors, model):
    """Time five forward + backward passes of ``model``.

    Each timed iteration builds an ``ME.SparseTensor`` on the module-level
    ``device``, runs ``model`` on it, and back-propagates the sum of the
    output features.

    Returns:
        ``(timer.min_time, len(logits))`` where ``logits`` is the output of
        the last pass.
    """
    stopwatch = Timer()
    for _ in range(5):
        stopwatch.tic()
        sparse_input = ME.SparseTensor(
            colors,
            coords,
            device=device,
            allocator_type=ME.GPUMemoryAllocatorType.PYTORCH,
        )
        logits = model(sparse_input)
        # Scalar loss: sum of all output features.
        loss = logits.F.sum()
        loss.backward()
        stopwatch.toc()
    return stopwatch.min_time, len(logits)


def test_network(coords, feats, model, batch_sizes, forward_only=True):
    """Benchmark ``model`` for every batch size and print one row per size.

    A batch is formed by repeating the same ``coords`` / ``feats`` pair
    ``batch_size`` times. With ``forward_only`` the pass runs under
    ``torch.no_grad()`` via ``forward``; otherwise ``train`` is used.

    NOTE(review): the printed row reads the module-level names ``net`` and
    ``voxel_size``, which are only bound by the ``__main__`` driver loop.
    """
    for batch_size in batch_sizes:
        bcoords = batched_coordinates([coords] * batch_size)
        bfeats = torch.cat([feats] * batch_size, 0)

        if not forward_only:
            time, length = train(bcoords, bfeats, model)
        else:
            with torch.no_grad():
                time, length = forward(bcoords, bfeats, model)

        print(f"{net.__name__}\t{voxel_size}\t{batch_size}\t{length}\t{time}")
        torch.cuda.empty_cache()


if __name__ == "__main__":
    config = parser.parse_args()

    # `device` is read as a module-level name by forward() and train().
    use_cuda = torch.cuda.is_available() and not config.use_cpu
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using {device}")
    print(f"Using backward {config.backward}")

    # Batch sizes to benchmark: 1, then every even size up to --max_batch.
    batch_sizes = [1, *range(2, config.max_batch + 1, 2)]

    # `net` and `voxel_size` are read as module-level names by test_network(),
    # so the loop variables must keep these exact names.
    # NOTE(review): config.weights is parsed but not read in this block.
    for net in (
        UNets.MinkUNet14,
        UNets.MinkUNet18,
        UNets.MinkUNet34,
        UNets.MinkUNet50,
    ):
        model = net(3, 20).to(device)
        model.eval()
        for voxel_size in (0.02,):
            print(voxel_size)
            coords, feats, _ = load_file(config.file_name, voxel_size)
            test_network(coords, feats, model, batch_sizes, not config.backward)
            torch.cuda.empty_cache()
        del model
3 changes: 2 additions & 1 deletion tests/python/sparse_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def test_device(self):
SparseTensor(feats, coords)
SparseTensor(feats.to(0), coords.to(0))
feats = torch.FloatTensor([[0, 1, 2, 3, 5, 6, 7]]).T.to(0)
SparseTensor(feats, coords, device=feats.device)
st = SparseTensor(feats, coords, device=feats.device)
print(st)

def test_duplicate_coords(self):
print(f"{self.__class__.__name__}: test_duplicate_coords")
Expand Down

0 comments on commit a28ede6

Please sign in to comment.