Skip to content

Commit

Permalink
get coords map close #70
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Dec 31, 2019
1 parent 235b815 commit 3c38fae
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 79 deletions.
96 changes: 56 additions & 40 deletions MinkowskiEngine/MinkowskiCoords.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
# of the code.
import os
import numpy as np
from collections import Sequence

import torch
from Common import convert_to_int_list, convert_to_int_tensor, prep_args
import MinkowskiEngineBackend as MEB
Expand Down Expand Up @@ -131,16 +133,23 @@ def transposed_stride(self,
force_creation))
return strided_key

def get_coords_key(self, tensor_strides):
tensor_strides = convert_to_int_list(tensor_strides, self.D)
key = self.CPPCoordsManager.getCoordsKey(tensor_strides)
coords_key = CoordsKey(self.D)
coords_key.setKey(key)
coords_key.setTensorStride(tensor_strides)
return coords_key

def get_coords(self, coords_key):
assert isinstance(coords_key, CoordsKey)
def get_coords_key(self, key_or_tensor_strides):
assert isinstance(key_or_tensor_strides, CoordsKey) or \
isinstance(key_or_tensor_strides, (Sequence, np.ndarray, torch.IntTensor, int)), \
f"The input must be either a CoordsKey or tensor_stride of type (int, list, tuple, array, Tensor). Invalid: {key_or_tensor_strides}"
if isinstance(key_or_tensor_strides, CoordsKey):
# Do nothing and return the input
return key_or_tensor_strides
else:
tensor_strides = convert_to_int_list(key_or_tensor_strides, self.D)
key = self.CPPCoordsManager.getCoordsKey(tensor_strides)
coords_key = CoordsKey(self.D)
coords_key.setKey(key)
coords_key.setTensorStride(tensor_strides)
return coords_key

def get_coords(self, coords_key_or_tensor_strides):
coords_key = self.get_coords_key(coords_key_or_tensor_strides)
coords = torch.IntTensor()
self.CPPCoordsManager.getCoords(coords, coords_key.CPPCoordsKey)
return coords
Expand Down Expand Up @@ -172,16 +181,25 @@ def get_row_indices_per_batch(self, coords_key, out_coords_key=None):
coords_key.CPPCoordsKey, out_coords_key.CPPCoordsKey)

def get_kernel_map(self,
in_tensor_strides,
out_tensor_strides,
in_key_or_tensor_strides,
out_key_or_tensor_strides,
stride=1,
kernel_size=3,
dilation=1,
region_type=0,
is_transpose=False,
is_pool=False):
in_coords_key = self.get_coords_key(in_tensor_strides)
out_coords_key = self.get_coords_key(out_tensor_strides)
r"""Get kernel in-out maps for the specified coords keys or tensor strides.
"""

if isinstance(in_key_or_tensor_strides, CoordsKey):
in_tensor_strides = in_key_or_tensor_strides.getTensorStride()
else:
in_tensor_strides = in_key_or_tensor_strides

in_coords_key = self.get_coords_key(in_key_or_tensor_strides)
out_coords_key = self.get_coords_key(out_key_or_tensor_strides)

tensor_strides = convert_to_int_tensor(in_tensor_strides, self.D)
strides = convert_to_int_tensor(stride, self.D)
Expand All @@ -191,9 +209,7 @@ def get_kernel_map(self,
tensor_strides, strides, kernel_sizes, dilations, region_type = prep_args(
tensor_strides, strides, kernel_sizes, dilations, region_type, D)

kernel_map = torch.IntTensor()
self.CPPCoordsManager.getKernelMap(
kernel_map,
kernel_map = self.CPPCoordsManager.getKernelMap(
convert_to_int_list(tensor_strides, D), #
convert_to_int_list(strides, D), #
convert_to_int_list(kernel_sizes, D), #
Expand All @@ -206,27 +222,29 @@ def get_kernel_map(self,

return kernel_map

def get_kernel_map_by_key(self,
in_coords_key,
out_coords_key,
tensor_strides=1,
stride=1,
kernel_size=3,
dilation=1,
region_type=0,
is_transpose=False):
tensor_strides = convert_to_int_list(tensor_strides, self.D)
strides = convert_to_int_list(stride, self.D)
kernel_sizes = convert_to_int_list(kernel_size, self.D)
dilations = convert_to_int_list(dilation, self.D)

kernel_map = torch.IntTensor()
self.CPPCoordsManager.getKernelMap(kernel_map, tensor_strides, strides,
kernel_sizes, dilations, region_type,
in_coords_key.CPPCoordsKey,
out_coords_key.CPPCoordsKey,
is_transpose)
return kernel_map
def get_coords_map(self, in_key_or_tensor_strides,
out_key_or_tensor_strides):
r"""Extract input coords indices that maps to output coords indices.
.. code-block:: python
sp_tensor = ME.SparseTensor(features, coords=coordinates)
out_sp_tensor = stride_2_conv(sp_tensor)
cm = sp_tensor.coords_man
# cm = out_sp_tensor.coords_man # doesn't matter which tensor you pick
ins, outs = cm.get_coords_map(1, # in stride
2) # out stride
for i, o in zip(ins, outs):
print(f"{i} -> {o}")
r"""

in_coords_key = self.get_coords_key(in_key_or_tensor_strides)
out_coords_key = self.get_coords_key(out_key_or_tensor_strides)

return self.CPPCoordsManager.getCoordsMap(in_coords_key.CPPCoordsKey,
out_coords_key.CPPCoordsKey)

def get_coords_size_by_coords_key(self, coords_key):
assert isinstance(coords_key, CoordsKey)
Expand All @@ -243,8 +261,6 @@ def permute_label(self,
max_label,
target_tensor_stride,
label_tensor_stride=1):
"""
"""
if target_tensor_stride == label_tensor_stride:
return label

Expand Down
6 changes: 6 additions & 0 deletions MinkowskiEngine/SparseTensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,12 @@ def D(self):
"""
return self.coords_key.D

def float(self):
self._F = self._F.float()

def double(self):
self._F = self._F.double()

def stride(self, s):
ss = convert_to_int_list(s)
tensor_strides = self.coords_key.getTensorStride()
Expand Down
12 changes: 8 additions & 4 deletions MinkowskiEngine/utils/collation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,13 @@ def batched_coordinates(coords):
if isinstance(cs, np.ndarray):
cs = torch.from_numpy(np.floor(cs).astype(np.int32))
else:
cs = cs.floor().int()
if isinstance(cs, torch.IntTensor) or isinstance(cs, torch.LongTensor):
cs = cs
else:
cs = cs.floor()

cn = len(cs)
bcoords[s:s + cn, :D] = cs
bcoords[s:s + cn, :D] = cs.int()
bcoords[s:s + cn, D] = b
s += cn
return bcoords
Expand Down Expand Up @@ -112,14 +116,14 @@ def sparse_collate(coords, feats, labels=None):
if isinstance(coord, np.ndarray):
coord = torch.from_numpy(coord)
else:
assert isinstance( coord, torch.Tensor), \
assert isinstance(coord, torch.Tensor), \
"Coords must be of type numpy.ndarray or torch.Tensor"
coord = coord.int()

if isinstance(feat, np.ndarray):
feat = torch.from_numpy(feat)
else:
assert isinstance( feat, torch.Tensor), \
assert isinstance(feat, torch.Tensor), \
"Features must be of type numpy.ndarray or torch.Tensor"

# Labels
Expand Down
1 change: 1 addition & 0 deletions pybind/minkowski.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ void instantiate_coordsman(py::module &m) {
CoordsManager::existsCoordsKey)
.def("getCoordsKey", &CoordsManager::getCoordsKey)
.def("getKernelMap", &CoordsManager::getKernelMap)
.def("getCoordsMap", &CoordsManager::getCoordsMap)
.def("getCoordsSize", (int (CoordsManager::*)(py::object) const) &
CoordsManager::getCoordsSize)
.def("getCoords", &CoordsManager::getCoords)
Expand Down
95 changes: 73 additions & 22 deletions src/coords_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,10 @@ namespace py = pybind11;
* Given tensor_stride_src and tensor_stride_dst, find the respective coord_maps
* and return the indices of the coord_map_ind in coord_map_dst
*/
void CoordsManager::getKernelMap(at::Tensor kernel_map,
vector<int> tensor_strides,
vector<int> strides, vector<int> kernel_sizes,
vector<int> dilations, int region_type,
py::object py_in_coords_key,
py::object py_out_coords_key,
bool is_transpose, bool is_pool) const {
vector<at::Tensor> CoordsManager::getKernelMap(
vector<int> tensor_strides, vector<int> strides, vector<int> kernel_sizes,
vector<int> dilations, int region_type, py::object py_in_coords_key,
py::object py_out_coords_key, bool is_transpose, bool is_pool) const {
const InOutMapKey map_key = getMapHashKey(
tensor_strides, strides, kernel_sizes, dilations, region_type,
py_in_coords_key, py_out_coords_key, is_transpose, is_pool);
Expand All @@ -55,18 +52,73 @@ void CoordsManager::getKernelMap(at::Tensor kernel_map,
for (int k = 0; k < kernel_volume; k++)
all_volume += in_map[k].size();

kernel_map.resize_({all_volume, 3});
int *p_kernel_map = kernel_map.data<int>();

vector<at::Tensor> vec_tensors;
for (int k = 0; k < kernel_volume; k++) {
int curr_volume = in_map[k].size();
auto curr_volume = in_map[k].size();

at::Tensor kernel_map = torch::empty(
{(long)curr_volume, 2}, torch::TensorOptions().dtype(torch::kInt64));
long *p_kernel_map = kernel_map.data<long>();

for (int i = 0; i < curr_volume; i++) {
p_kernel_map[0] = k;
p_kernel_map[1] = in_map[k][i];
p_kernel_map[2] = out_map[k][i];
p_kernel_map += 3;
p_kernel_map[0] = in_map[k][i];
p_kernel_map[1] = out_map[k][i];
p_kernel_map += 2;
}

vec_tensors.push_back(::move(kernel_map));
}

return vec_tensors;
}

vector<at::Tensor>
CoordsManager::getCoordsMap(py::object py_in_coords_key,
py::object py_out_coords_key) const {
CoordsKey *p_in_coords_key = py_in_coords_key.cast<CoordsKey *>();
CoordsKey *p_out_coords_key = py_out_coords_key.cast<CoordsKey *>();
const uint64_t in_coords_key = p_in_coords_key->getKey();
const uint64_t out_coords_key = p_out_coords_key->getKey();

const auto in_map_iter = coords_maps.find(in_coords_key);
const auto out_map_iter = coords_maps.find(out_coords_key);

ASSERT(in_map_iter != coords_maps.end(), "Input coords not found at",
to_string(in_coords_key));
ASSERT(out_map_iter != coords_maps.end(), "Output coords not found at",
to_string(out_coords_key));

const auto &out_tensor_strides = p_out_coords_key->getTensorStride();
auto in_out =
in_map_iter->second.stride_map(out_map_iter->second, out_tensor_strides);

const auto &ins = in_out.first;
const auto &outs = in_out.second;
// All size
auto N =
::accumulate(ins.begin(), ins.end(), 0, [](size_t curr_sum, const vector<int>& map) {
return curr_sum + map.size();
});

at::Tensor in_out_1 =
torch::empty({N}, torch::TensorOptions().dtype(torch::kInt64));
at::Tensor in_out_2 =
torch::empty({N}, torch::TensorOptions().dtype(torch::kInt64));

auto a_in_out_1 = in_out_1.accessor<long int, 1>();
auto a_in_out_2 = in_out_2.accessor<long int, 1>();

size_t curr_it = 0;
for (const auto &in : ins)
for (const auto i : in)
a_in_out_1[curr_it++] = i;

curr_it = 0;
for (const auto &out : outs)
for (const auto o : out)
a_in_out_2[curr_it++] = o;

return {in_out_1, in_out_2};
}

uint64_t CoordsManager::getCoordsKey(const vector<int> &tensor_strides) const {
Expand Down Expand Up @@ -205,10 +257,12 @@ uint64_t CoordsManager::initializeCoords(at::Tensor coords, at::Tensor mapping,
batch_indices = map_batch_pair.second;

if (!allow_duplicate_coords && !force_remap) {
ASSERT(nrows == coords_map.size(), "A duplicate coordinate found. ",
ASSERT(nrows == coords_map.size(), "Duplicate coordinates found. ",
"Number of input coords:", nrows,
" != Number of unique coords:", coords_map.size(),
"If the duplication was intentional, set force_remap to true."
"For more information, please refer to the SparseTensor creation "
"documentation available at:"
"documentation available at: "
"https://stanfordvl.github.io/MinkowskiEngine/sparse_tensor.html");
}

Expand Down Expand Up @@ -248,7 +302,6 @@ uint64_t CoordsManager::createStridedCoords(uint64_t coords_key,
const vector<int> &tensor_strides,
const vector<int> &strides,
bool force_creation) {

// Basic assertions
ASSERT(existsCoordsKey(coords_key),
"The coord map doesn't exist for the given coords_key: ",
Expand Down Expand Up @@ -285,7 +338,6 @@ uint64_t CoordsManager::createTransposedStridedRegionCoords(
uint64_t coords_key, const vector<int> &tensor_strides,
const vector<int> &strides, vector<int> kernel_sizes, vector<int> dilations,
int region_type, at::Tensor offsets, bool force_creation) {

const vector<int> out_tensor_strides =
computeOutTensorStride(tensor_strides, strides, true /* is_transpose */);

Expand Down Expand Up @@ -401,7 +453,6 @@ const InOutMapKey CoordsManager::getMapHashKey(
vector<int> tensor_strides, vector<int> strides, vector<int> kernel_sizes,
vector<int> dilations, int region_type, py::object py_in_coords_key,
py::object py_out_coords_key, bool is_transpose, bool is_pool) const {

int D = tensor_strides.size();
ASSERT(D == tensor_strides.size() and D == strides.size() and
D == kernel_sizes.size() and D == dilations.size(),
Expand Down Expand Up @@ -493,7 +544,6 @@ const InOutMapsRefPair<int> CoordsManager::getInOutMaps(
int region_type, const at::Tensor &offsets, py::object py_in_coords_key,
py::object py_out_coords_key, bool is_transpose, bool is_pool,
bool force_creation) {

int D = tensor_strides.size();
ASSERT(D == tensor_strides.size() and D == strides.size() and
D == kernel_sizes.size() and D == dilations.size(),
Expand Down Expand Up @@ -659,7 +709,8 @@ CoordsManager::getPruningInOutMaps(at::Tensor use_feat,
else
out_coords_key = p_out_coords_key->getKey();

// Use the map key for origin hash map (stride, dilation, kernel are all NULL)
// Use the map key for origin hash map (stride, dilation, kernel are all
// NULL)
const InOutMapKey map_key =
getOriginMapHashKey(py_in_coords_key, py_out_coords_key);

Expand Down
15 changes: 9 additions & 6 deletions src/coords_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

template <typename VType> int getInOutMapsSize(const VType &map) {
int n = 0;
for (auto cmap = begin(map); cmap != end(map); cmap++)
for (auto cmap = ::begin(map); cmap != ::end(map); ++cmap)
n += cmap->size();
return n;
}
Expand Down Expand Up @@ -119,11 +119,14 @@ class CoordsManager {
long int getBatchSize() const { return batch_indices.size(); }
set<int> getBatchIndices() const { return batch_indices; }
void getCoords(at::Tensor coords, py::object py_coords_key) const;
void getKernelMap(at::Tensor kernel_map, vector<int> tensor_strides,
vector<int> strides, vector<int> kernel_sizes,
vector<int> dilations, int region_type,
py::object py_in_coords_key, py::object py_out_coords_key,
bool is_transpose, bool is_pool) const;
vector<at::Tensor> getKernelMap(vector<int> tensor_strides,
vector<int> strides, vector<int> kernel_sizes,
vector<int> dilations, int region_type,
py::object py_in_coords_key,
py::object py_out_coords_key,
bool is_transpose, bool is_pool) const;
vector<at::Tensor> getCoordsMap(py::object py_in_coords_key,
py::object py_out_coords_key) const;

// Set the py_coords_key to the origin coords map key
void setOriginCoordsKey(py::object py_coords_key);
Expand Down
Loading

0 comments on commit 3c38fae

Please sign in to comment.