-
Notifications
You must be signed in to change notification settings - Fork 348
/
storage.cuh
200 lines (177 loc) · 6.67 KB
/
storage.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
/* Copyright (c) 2020 NVIDIA CORPORATION.
* Copyright (c) 2018-2020 Chris Choy (chrischoy@ai.stanford.edu)
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*
* Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
* Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
* of the code.
*/
#ifndef STORAGE_CUH
#define STORAGE_CUH
#include "utils.hpp"
#include <algorithm>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
namespace minkowski {
/**
 * Owning, typed wrapper around a GPU buffer obtained from a byte allocator.
 *
 * The buffer holds `size()` elements of `Dtype`. Memory is requested from
 * `ByteAllocator` in bytes (`allocate(nbytes)`) and returned in bytes
 * (`deallocate(char *, nbytes)`). All host<->device and device<->device
 * transfers use blocking `cudaMemcpy` wrapped in `CUDA_CHECK`.
 *
 * Invariant maintained by every member: the storage is empty
 * (`m_data == nullptr`, `m_num_elements == 0`) or owns exactly one buffer of
 * `m_num_elements` elements (assuming the allocator never returns nullptr for
 * a non-zero request — NOTE(review): confirm for the allocators in use).
 *
 * Copy construction deep-copies device-to-device; move construction steals the
 * buffer. Copy/move assignment are not available (the implicit copy assignment
 * is deleted because a move constructor is declared).
 */
template <typename Dtype, typename ByteAllocator> class gpu_storage {
public:
  using data_type = Dtype;
  using byte_allocator_type = ByteAllocator;
  using self_type = gpu_storage<data_type, byte_allocator_type>;

  /// Empty storage; holds no device memory.
  gpu_storage() : m_data(nullptr), m_num_elements(0) {}

  /// Allocates room for `num_elements` elements; contents are not initialized.
  gpu_storage(uint64_t const num_elements) { allocate(num_elements); }

  /// Deep copy: allocates a new buffer and copies the elements device-to-device.
  gpu_storage(self_type const &other_storage) {
    LOG_DEBUG("copy storage constructor");
    if (other_storage.size() == 0)
      return;
    allocate(other_storage.size());
    CUDA_CHECK(cudaMemcpy(m_data, other_storage.cdata(),
                          other_storage.size() * sizeof(data_type),
                          cudaMemcpyDeviceToDevice));
  }

  /// Takes over the buffer of `other_storage`, leaving it empty.
  /// NOTE(review): the allocator object itself is not transferred; this
  /// assumes any ByteAllocator instance can release memory obtained from
  /// another instance (true for stateless allocators) — confirm.
  gpu_storage(self_type &&other_storage) {
    LOG_DEBUG("move storage constructor from", other_storage.m_data,
              "with size", other_storage.m_num_elements);
    if (other_storage.size() == 0)
      return;
    m_num_elements = other_storage.size();
    m_data = other_storage.data();
    other_storage.m_data = nullptr;
    other_storage.m_num_elements = 0;
  }

  /// Allocates `vec.size()` elements and uploads `vec` host-to-device.
  gpu_storage(std::vector<Dtype> const &vec) {
    LOG_DEBUG("vector storage constructor");
    from_vector(vec);
  }

  ~gpu_storage() { deallocate(); }

  /**
   * (Re)allocates room for `num_elements` elements.
   *
   * Any buffer already held is released first, so calling this on non-empty
   * storage does not leak. The new contents are not initialized.
   *
   * @return the new device pointer, or nullptr when `num_elements == 0`
   *         (the storage is then empty).
   */
  data_type *allocate(uint64_t const num_elements) {
    if (m_num_elements > 0)
      deallocate();
    if (num_elements == 0)
      return nullptr;
    // Only record the size once the allocation has succeeded, so a throwing
    // allocator cannot leave a size without a matching buffer.
    m_data =
        (data_type *)m_allocator.allocate(num_elements * sizeof(data_type));
    m_num_elements = num_elements;
    LOG_DEBUG("Allocating", num_elements, "gpu storage at", m_data);
    return m_data;
  }

  /**
   * Releases the held buffer (if any) and resets to the empty state.
   *
   * Resetting the pointer and size makes repeated calls (for example an
   * explicit call followed by the destructor) harmless instead of a
   * double free.
   */
  void deallocate() {
    LOG_DEBUG("Deallocating", m_num_elements, "gpu storage at", m_data);
    if (m_num_elements > 0) {
      m_allocator.deallocate((char *)m_data,
                             m_num_elements * sizeof(data_type));
    }
    m_data = nullptr;
    m_num_elements = 0;
  }

  /// Replaces the contents with a host-to-device copy of `vec`; afterwards
  /// `size() == vec.size()`.
  void from_vector(std::vector<Dtype> const &vec) {
    if (vec.size() != m_num_elements) {
      // The old contents are overwritten below, so reallocate directly instead
      // of resize(), which would first copy the old data device-to-device.
      allocate(vec.size());
    }
    if (m_num_elements > 0) {
      CUDA_CHECK(cudaMemcpy(m_data, vec.data(),
                            m_num_elements * sizeof(data_type),
                            cudaMemcpyHostToDevice));
    }
  }

  // Raw device-pointer accessors. In DEBUG builds they throw if the storage
  // is empty (see check_pointer).
  data_type *data() {
    check_pointer("data");
    return m_data;
  }
  data_type const *cdata() const {
    check_pointer("cdata");
    return m_data;
  }
  data_type *begin() {
    check_pointer("begin");
    return m_data;
  }
  data_type *end() {
    check_pointer("end");
    return m_data + m_num_elements;
  }
  data_type const *cbegin() const {
    check_pointer("cbegin");
    return m_data;
  }
  data_type const *cend() const {
    check_pointer("cend");
    return m_data + m_num_elements;
  }

  /// Downloads all elements into a host vector.
  std::vector<data_type> to_vector() const { return to_vector(size()); }

  /**
   * Downloads the first `num_elements` elements into a host vector.
   *
   * @throws std::runtime_error if `num_elements > size()` (the copy would
   *         otherwise read past the end of the device buffer).
   */
  std::vector<data_type> to_vector(uint64_t const num_elements) const {
    if (num_elements > m_num_elements)
      throw std::runtime_error(
          "storage.cuh: to_vector requested more elements than stored");
    std::vector<data_type> cpu_storage(num_elements);
    if (num_elements > 0)
      CUDA_CHECK(cudaMemcpy(cpu_storage.data(), m_data,
                            num_elements * sizeof(data_type),
                            cudaMemcpyDeviceToHost));
    return cpu_storage;
  }

  /// Number of elements currently held.
  uint64_t size() const { return m_num_elements; }

  /**
   * Changes the element count to `new_num_elements`, preserving the first
   * `min(size(), new_num_elements)` elements. Elements beyond the old size
   * are uninitialized. Resizing to 0 releases the buffer.
   */
  void resize(uint64_t const new_num_elements) {
    LOG_DEBUG("resizing from", m_num_elements, "to", new_num_elements);
    if (new_num_elements == m_num_elements)
      return;
    if (new_num_elements == 0) {
      deallocate();
      return;
    }
    data_type *new_data =
        (data_type *)m_allocator.allocate(new_num_elements * sizeof(data_type));
    if (m_num_elements > 0) {
      // Copy only the overlap: copying new_num_elements would read past the
      // end of the old buffer when growing.
      uint64_t const num_copy = std::min(m_num_elements, new_num_elements);
      CUDA_CHECK(cudaMemcpy(new_data, m_data, num_copy * sizeof(data_type),
                            cudaMemcpyDeviceToDevice));
      m_allocator.deallocate((char *)m_data,
                             m_num_elements * sizeof(data_type));
    }
    m_data = new_data;
    m_num_elements = new_num_elements;
  }

  /**
   * Prints up to `num_vec` consecutive rows of `vec_size` elements each, one
   * row per line, via PtrToString. Prints nothing when `vec_size == 0`.
   */
  void print_by_vector(uint64_t const num_vec, uint64_t const vec_size) const {
    if (vec_size == 0)
      return; // avoid division by zero below
    auto const print_n = std::min(num_vec, size() / vec_size);
    auto const cpu_storage = to_vector(vec_size * print_n);
    for (uint64_t i = 0; i < print_n; ++i) {
      std::cout << PtrToString(&cpu_storage[i * vec_size], vec_size) << "\n";
    }
  }

private:
  /// In DEBUG builds, throws std::runtime_error naming accessor `fn` when the
  /// storage is empty; compiles to nothing otherwise.
  void check_pointer(std::string const &fn) const {
#ifdef DEBUG
    if (m_data == nullptr) {
      throw std::runtime_error("storage.cuh: m_data == nullptr on " + fn);
    } else if (m_num_elements == 0) {
      throw std::runtime_error("storage.cuh: m_num_elements == 0 on " + fn);
    }
#endif
  }

  byte_allocator_type m_allocator;
  data_type *m_data = nullptr;
  uint64_t m_num_elements = 0;
};
/**
 * Prints every element of `v` on one line, each preceded by a space, in fixed
 * notation with 3 decimal places, followed by a newline.
 *
 * The data is copied to the host through the const accessors `size()` and
 * `cdata()`, so this compiles for a const reference regardless of whether
 * `to_vector()` is const-qualified. `std::cout`'s format flags and precision
 * are restored afterwards so later output is not affected.
 */
template <typename Dtype, typename ByteAllocator>
void print(const gpu_storage<Dtype, ByteAllocator> &v) {
  std::vector<Dtype> cpu_storage(v.size());
  if (!cpu_storage.empty()) {
    CUDA_CHECK(cudaMemcpy(cpu_storage.data(), v.cdata(),
                          cpu_storage.size() * sizeof(Dtype),
                          cudaMemcpyDeviceToHost));
  }
  auto const old_flags = std::cout.flags();
  auto const old_precision = std::cout.precision();
  for (auto const &value : cpu_storage)
    std::cout << " " << std::fixed << std::setprecision(3) << value;
  std::cout << "\n";
  std::cout.flags(old_flags);
  std::cout.precision(old_precision);
}
// template void print(const thrust::device_vector<float> &v);
// template void print(const thrust::device_vector<int32_t> &v);
// template <typename Dtype1, typename Dtype2>
// void print(const thrust::device_vector<Dtype1> &v1,
// const thrust::device_vector<Dtype2> &v2) {
// for (size_t i = 0; i < v1.size(); i++)
// std::cout << " (" << v1[i] << "," << std::setw(2) << v2[i] << ")";
// std::cout << "\n";
// }
//
// template void print(const thrust::device_vector<int32_t> &v1,
// const thrust::device_vector<int32_t> &v2);
} // namespace minkowski
#endif