Skip to content

Commit

Permalink
quantization np, torch
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Dec 15, 2020
1 parent 652ddc2 commit 397e5d2
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 53 deletions.
4 changes: 1 addition & 3 deletions MinkowskiEngine/MinkowskiSparseTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,9 @@ def __init__(
coordinate_map_key = CoordinateMapKey(
convert_to_int_list(tensor_stride, self.D), ""
)
self._manager = coordinate_manager
else:
# not (coordinate_map_key is None or coordinate_manager is None)
self.D = coordinate_manager.D
self._manager = coordinate_manager

##########################
# Setup CoordsManager
Expand Down Expand Up @@ -329,7 +327,7 @@ def __init__(
allocator_type=allocator_type,
kernel_map_mode=kernel_map_mode,
)
self._manager = coordinate_manager
self._manager = coordinate_manager

##########################
# Initialize coords
Expand Down
2 changes: 1 addition & 1 deletion MinkowskiEngine/utils/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch
import numpy as np
from collections import Sequence
import MinkowskiEngineBackend as MEB
import MinkowskiEngineBackend._C as MEB


def fnv_hash_vec(arr):
Expand Down
16 changes: 11 additions & 5 deletions pybind/extern.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

#include <torch/extension.h>

#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
Expand Down Expand Up @@ -172,17 +173,17 @@ std::pair<at::Tensor, at::Tensor> ConvolutionTransposeBackwardGPU(
/*************************************
* Quantization
*************************************/
/*
template <typename MapType>
std::vector<py::array>
quantize_np(py::array_t<int, py::array::c_style | py::array::forcecast> coords);
std::vector<py::array> quantize_np(
py::array_t<int32_t, py::array::c_style | py::array::forcecast> coords);

std::vector<at::Tensor> quantize_th(at::Tensor &coords);

/*
vector<py::array> quantize_label_np(
py::array_t<int, py::array::c_style | py::array::forcecast> coords,
py::array_t<int, py::array::c_style | py::array::forcecast> labels,
int invalid_label);
template <typename MapType> vector<at::Tensor> quantize_th(at::Tensor coords);
vector<at::Tensor> quantize_label_th(at::Tensor coords, at::Tensor labels,
int invalid_label);
Expand Down Expand Up @@ -367,6 +368,11 @@ void instantiate_gpu_func(py::module &m, const std::string &dtypestr) {
py::call_guard<py::gil_scoped_release>());
}

void non_templated_cpu_func(py::module &m) {
m.def("quantize_np", &minkowski::quantize_np);
m.def("quantize_th", &minkowski::quantize_th);
}

void non_templated_gpu_func(py::module &m) {
m.def("coo_spmm_int32", &minkowski::coo_spmm<int32_t>,
py::call_guard<py::gil_scoped_release>());
Expand Down
1 change: 1 addition & 0 deletions pybind/minkowski.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#endif

// Functions
non_templated_cpu_func(m);
instantiate_cpu_func<int32_t>(m, "");

#ifndef CPU_ONLY
Expand Down
1 change: 1 addition & 0 deletions pybind/minkowski.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#endif

// Functions
non_templated_cpu_func(m);
instantiate_cpu_func<int32_t>(m, "");

#ifndef CPU_ONLY
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def _argparse(pattern, argv, is_flag=True):
"coordinate_map_manager.cpp",
"convolution_cpu.cpp",
"convolution_transpose_cpu.cpp",
"quantization.cpp",
],
["pybind/minkowski.cpp"],
["-DCPU_ONLY"],
Expand All @@ -195,6 +196,7 @@ def _argparse(pattern, argv, is_flag=True):
"convolution_transpose_gpu.cu",
"spmm.cu",
"gpu.cu",
"quantization.cpp",
],
["pybind/minkowski.cu"],
[],
Expand Down
107 changes: 64 additions & 43 deletions src/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,17 @@
* of the code.
*/

#include <algorithm>

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <torch/extension.h>

#include "coordsmap.hpp"
#include "pooling_avg.hpp"
#ifndef CPU_ONLY
#include "pooling_avg.cuh"
#include <ATen/cuda/CUDAContext.h>
#endif
#include "coordinate_map_cpu.hpp"

// #ifndef CPU_ONLY
// #include <ATen/cuda/CUDAContext.h>
// #endif
#include "utils.hpp"

namespace py = pybind11;
Expand All @@ -48,36 +49,48 @@ struct IndexLabel {
IndexLabel(int index_, int label_) : index(index_), label(label_) {}
};
using CoordsLabelMap =
robin_hood::unordered_flat_map<vector<int>, IndexLabel, byte_hash_vec<int>>;
using cpu_map_type =
robin_hood::unordered_flat_map<std::vector<int>, int,
byte_hash_vec<int>>;
*/

template <typename MapType>
vector<py::array> quantize_np(
py::array_t<int, py::array::c_style | py::array::forcecast> coords) {
std::vector<py::array> quantize_np(
py::array_t<int32_t, py::array::c_style | py::array::forcecast> coords) {
using coordinate_type = int32_t;
LOG_DEBUG("quantize_np");
py::buffer_info coords_info = coords.request();
LOG_DEBUG("buffer info requenst");
auto &shape = coords_info.shape;

ASSERT(shape.size() == 2,
"Dimension must be 2. The dimension of the input: ", shape.size());

int *p_coords = (int *)coords_info.ptr;
coordinate_type *p_coords = (coordinate_type *)coords_info.ptr;
LOG_DEBUG("ptr requenst");
int nrows = shape[0], ncols = shape[1];

// Create coords map
CoordsMap<MapType> map;
auto results = map.initialize_batch(p_coords, nrows, ncols, true, true);
LOG_DEBUG("coordinate map generation");
std::vector<default_types::size_type> tensor_stride(ncols - 1);
std::for_each(tensor_stride.begin(), tensor_stride.end(),
[](auto &i) { i = 1; });

CoordinateMapCPU<coordinate_type> map(nrows, ncols, tensor_stride);
LOG_DEBUG("Map nrows:", nrows, "ncols:", ncols);
auto results = map.insert_and_map<true>(p_coords, p_coords + nrows * ncols);
LOG_DEBUG("insertion finished");
auto &mapping = std::get<0>(results);
auto &inverse_mapping = std::get<1>(results);

// Copy the concurrent vector to std vector
py::array_t<int> py_mapping = py::array_t<int>(mapping.size());
py::array_t<int> py_inverse_mapping =
py::array_t<int>(inverse_mapping.size());
py::array_t<int32_t> py_mapping = py::array_t<int32_t>(mapping.size());
py::array_t<int32_t> py_inverse_mapping =
py::array_t<int32_t>(inverse_mapping.size());

py::buffer_info py_mapping_info = py_mapping.request();
py::buffer_info py_inverse_mapping_info = py_inverse_mapping.request();
int *p_py_mapping = (int *)py_mapping_info.ptr;
int *p_py_inverse_mapping = (int *)py_inverse_mapping_info.ptr;
int32_t *p_py_mapping = (int32_t *)py_mapping_info.ptr;
int32_t *p_py_inverse_mapping = (int32_t *)py_inverse_mapping_info.ptr;

std::copy_n(mapping.data(), mapping.size(), p_py_mapping);
std::copy_n(inverse_mapping.data(), inverse_mapping.size(),
Expand All @@ -87,27 +100,33 @@ vector<py::array> quantize_np(
return {py_mapping, py_inverse_mapping};
}

template <typename MapType> vector<at::Tensor> quantize_th(at::Tensor coords) {
std::vector<at::Tensor> quantize_th(at::Tensor &coords) {
using coordinate_type = int32_t;
ASSERT(coords.dtype() == torch::kInt32,
"Coordinates must be an int type tensor.");
ASSERT(coords.dim() == 2,
"Coordinates must be represnted as a matrix. Dimensions: ",
coords.dim(), "!= 2.");
coordinate_type *p_coords = coords.template data_ptr<coordinate_type>();
size_t nrows = coords.size(0), ncols = coords.size(1);
std::vector<default_types::size_type> tensor_stride(ncols - 1);
std::for_each(tensor_stride.begin(), tensor_stride.end(),
[](auto &i) { i = 1; });

CoordinateMapCPU<coordinate_type> map(nrows, ncols, tensor_stride);

CoordsMap<MapType> map;
auto results = map.initialize_batch(
coords.template data<int>(), coords.size(0), coords.size(1), true, true);
auto results = map.insert_and_map<true>(p_coords, p_coords + nrows * ncols);
auto mapping = std::get<0>(results);
auto inverse_mapping = std::get<1>(results);

// Long tensor for for easier indexing
auto th_mapping = torch::empty({(long)mapping.size()},
auto th_mapping = torch::empty({(int64_t)mapping.size()},
torch::TensorOptions().dtype(torch::kInt64));
auto th_inverse_mapping =
torch::empty({(long)inverse_mapping.size()},
torch::empty({(int64_t)inverse_mapping.size()},
torch::TensorOptions().dtype(torch::kInt64));
auto a_th_mapping = th_mapping.accessor<long int, 1>();
auto a_th_inverse_mapping = th_inverse_mapping.accessor<long int, 1>();
auto a_th_mapping = th_mapping.accessor<int64_t, 1>();
auto a_th_inverse_mapping = th_inverse_mapping.accessor<int64_t, 1>();

// Copy the output
for (size_t i = 0; i < mapping.size(); ++i)
Expand All @@ -119,7 +138,8 @@ template <typename MapType> vector<at::Tensor> quantize_th(at::Tensor coords) {
return {th_mapping, th_inverse_mapping};
}

vector<py::array> quantize_label_np(
/*
std::vector<py::array> quantize_label_np(
py::array_t<int, py::array::c_style | py::array::forcecast> coords,
py::array_t<int, py::array::c_style | py::array::forcecast> labels,
int invalid_label) {
Expand All @@ -138,10 +158,10 @@ vector<py::array> quantize_label_np(
int nrows = shape[0], ncols = shape[1];
// Create coords map
CoordsLabelMap map;
cpu_map_type map;
map.reserve(nrows);
for (int i = 0; i < nrows; i++) {
vector<int> coord(ncols);
std::vector<int> coord(ncols);
std::copy_n(p_coords + i * ncols, ncols, coord.data());
auto map_iter = map.find(coord);
if (map_iter == map.end()) {
Expand Down Expand Up @@ -170,7 +190,7 @@ vector<py::array> quantize_label_np(
return {py_mapping, py_colabels};
}
vector<at::Tensor> quantize_label_th(at::Tensor coords, at::Tensor labels,
std::vector<at::Tensor> quantize_label_th(at::Tensor coords, at::Tensor labels,
int invalid_label) {
ASSERT(coords.dtype() == torch::kInt32,
"Coordinates must be an int type tensor.");
Expand All @@ -186,10 +206,10 @@ vector<at::Tensor> quantize_label_th(at::Tensor coords, at::Tensor labels,
int nrows = coords.size(0), ncols = coords.size(1);
// Create coords map
CoordsLabelMap map;
cpu_map_type map;
map.reserve(nrows);
for (int i = 0; i < nrows; i++) {
vector<int> coord(ncols);
std::vector<int> coord(ncols);
std::copy_n(p_coords + i * ncols, ncols, coord.data());
auto map_iter = map.find(coord);
if (map_iter == map.end()) {
Expand Down Expand Up @@ -219,19 +239,18 @@ vector<at::Tensor> quantize_label_th(at::Tensor coords, at::Tensor labels,
return {th_mapping, th_colabels};
}
template vector<py::array> quantize_np<CoordsToIndexMap>(
template std::vector<py::array> quantize_np<CoordsToIndexMap>(
py::array_t<int, py::array::c_style | py::array::forcecast> coords);
template vector<at::Tensor> quantize_th<CoordsToIndexMap>(at::Tensor coords);
template std::vector<at::Tensor> quantize_th<CoordsToIndexMap>(at::Tensor
coords);
template <typename Dtype> InOutMaps<Dtype> CopyToInOutMap(at::Tensor th_map) {
InOutMaps<Dtype> vec_map(1);
vec_map[0].resize(th_map.size(0));
std::copy_n(th_map.data<Dtype>(), th_map.size(0), vec_map[0].begin());
return vec_map;
}
*/
#ifndef CPU_ONLY
template <typename Dtype>
pInOutMaps<Dtype> CopyToInOutMapGPU(at::Tensor th_map) {
Expand All @@ -246,7 +265,7 @@ pInOutMaps<Dtype> CopyToInOutMapGPU(at::Tensor th_map) {
return vec_map;
}
#endif

*/
/**
* A collection of feature averaging methods
* mode == 0: non-weighted average
Expand All @@ -255,11 +274,12 @@ pInOutMaps<Dtype> CopyToInOutMapGPU(at::Tensor th_map) {
*
* in_feat[in_map[i], j] --> out_feat[out_map[i], j]
*/
at::Tensor quantization_average_features(
at::Tensor th_in_feat /* feature matrix */,
at::Tensor th_in_map /* inverse_map from the quantization functions */,
at::Tensor th_out_map /* range(N) */, int out_nrows,
int mode /* average types */) {
// at::Tensor quantization_average_features(
// at::Tensor th_in_feat /* feature matrix */,
// at::Tensor th_in_map /* inverse_map from the quantization functions */,
// at::Tensor th_out_map /* range(N) */, int out_nrows,
// int mode /* average types */) {
/*
ASSERT(th_in_feat.dim() == 2, " The feature tensor should be a matrix.");
ASSERT(th_in_feat.size(0) == th_in_map.size(0),
"The size of the input feature and the input map must match.");
Expand Down Expand Up @@ -373,5 +393,6 @@ at::Tensor quantization_average_features(
return th_out_feat;
}
*/

} // end namespace minkowski
3 changes: 2 additions & 1 deletion tests/python/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import numpy as np

from MinkowskiEngine.utils import sparse_quantize
import MinkowskiEngineBackend as MEB
import MinkowskiEngineBackend._C as MEB


class TestQuantization(unittest.TestCase):
Expand Down Expand Up @@ -54,6 +54,7 @@ def test_mapping(self):
coords = (np.random.rand(N, 3) * 100).astype(np.int32)
mapping, inverse_mapping = MEB.quantize_np(coords)
print('N unique:', len(mapping), 'N:', N)
self.assertTrue((coords == coords[mapping][inverse_mapping]).all())
self.assertTrue((coords == coords[mapping[inverse_mapping]]).all())

coords = torch.from_numpy(coords)
Expand Down

0 comments on commit 397e5d2

Please sign in to comment.