forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Tensor.h
408 lines (397 loc) · 15.7 KB
/
Tensor.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
#pragma once
#include "ATen/Generator.h"
#include "ATen/Scalar.h"
#include "ATen/ScalarType.h"
#include "ATen/TensorAccessor.h"
#include "ATen/TensorImpl.h"
#include "ATen/TensorBase.h"
#include "ATen/Storage.h"
#include "ATen/SparseTensorRef.h"
#include "ATen/Utils.h"
namespace at {
struct Type;
// Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which
// has an embedded reference count. In this way, Tensor is similar to boost::intrusive_ptr.
//
// For example:
//
// void func(Tensor a) {
// Tensor b = a;
// ...
// }
//
// In this example, when we say Tensor b = a, we are creating a new object that points to the
// same underlying TensorImpl, and bumps its reference count. When b goes out of scope, the
// destructor decrements the reference count by calling release() on the TensorImpl it points to.
// The existing constructors, operator overloads, etc. take care to implement the correct semantics.
//
// Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and
// special care must be taken to handle this.
struct Tensor : public detail::TensorBase {
Tensor() : TensorBase() {}
Tensor(TensorImpl * self, bool retain) : TensorBase(self, retain) {}
Tensor(const TensorBase & rhs) : TensorBase(rhs) {}
Tensor(const Tensor & rhs) = default;
Tensor(Tensor && rhs) noexcept = default;
Tensor & operator=(Tensor && rhs) & {
rhs.swap(*this);
return *this;
}
Tensor & operator=(Tensor const & rhs) & {
//Tensor ctor retains original rhs.pImpl
//then rhs.pImpl is swapped with this->pImpl
//finally Tensor dtor releases rhs.pImpl, which was originally this->pImpl
Tensor(rhs).swap(*this);
return *this;
}
Tensor & operator=(Tensor const & rhs) && {
return assign_(rhs);
}
Tensor & operator=(Scalar v) &&;
Tensor & assign_(Scalar v);
void reset() {
Tensor().swap(*this);
}
void reset(TensorImpl * rhs) {
Tensor(rhs, true).swap(*this);
}
void reset(TensorImpl * rhs, bool retain) {
Tensor(rhs, retain).swap(*this );
}
TensorImpl * get() const {
return pImpl;
}
TensorImpl * detach() {
TensorImpl * ret = pImpl;
pImpl = nullptr;
return ret;
}
bool defined() const {
return pImpl != nullptr;
}
void swap(Tensor & rhs) {
TensorImpl * tmp = pImpl;
pImpl = rhs.pImpl;
rhs.pImpl = tmp;
}
const char * toString() const {
return pImpl->toString();
}
IntList sizes() const {
return pImpl->sizes();
}
IntList strides() const {
return pImpl->strides();
}
int64_t ndimension() const {
return dim();
}
Type & type() const {
return pImpl->type();
}
inline Tensor toType(const Type & t) const;
inline Tensor & copy_(const Tensor & src);
inline Tensor toType(ScalarType t) const;
inline Tensor toBackend(Backend b) const;
template<typename T>
T * data() const;
void * unsafeGetTH(bool retain) const {
return pImpl->unsafeGetTH(retain);
}
//toLongData(), toFloatData() etc.
#define TO_TYPE_DATA(T,name,_) \
T * to##name##Data() const;
AT_FORALL_SCALAR_TYPES(TO_TYPE_DATA)
#undef TO_TYPE_DATA
template<typename T, size_t N>
TensorAccessor<T,N> accessor() {
static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data<T>()");
AT_ASSERT(dim() == N, "expected %d dims but tensor has %d",N,dim());
return TensorAccessor<T,N>(data<T>(),sizes().data(),strides().data());
}
Tensor operator-() const;
Tensor& operator+=(const Tensor & other);
Tensor& operator+=(Scalar other);
Tensor& operator-=(const Tensor & other);
Tensor& operator-=(Scalar other);
Tensor& operator*=(const Tensor & other);
Tensor& operator*=(Scalar other);
Tensor& operator/=(const Tensor & other);
Tensor& operator/=(Scalar other);
Tensor operator[](int64_t idx) const;
//example
//Tensor * add(Tensor & b);
int64_t storage_offset() const;
Tensor & resize_(IntList size);
int64_t numel() const;
Tensor & set_(Storage & storage);
Tensor & set_(Storage & sourceStorage, int64_t storage_offset, IntList size, IntList stride={});
Tensor & set_(const Tensor & source);
Tensor & set_();
Tensor & fill_(Scalar value);
bool is_same_size(const Tensor & other) const;
bool is_contiguous() const;
bool is_set_to(const Tensor & tensor) const;
Tensor & masked_fill_(const Tensor & mask, Scalar value);
Tensor & masked_scatter_(const Tensor & mask, const Tensor & source);
Tensor masked_select(const Tensor & mask) const;
Tensor transpose(int64_t dim0, int64_t dim1) const;
Tensor & transpose_(int64_t dim0, int64_t dim1);
Tensor t() const;
Tensor & t_();
Tensor squeeze(int64_t dim) const;
Tensor squeeze() const;
Tensor & squeeze_(int64_t dim);
Tensor & squeeze_();
Tensor unsqueeze(int64_t dim) const;
Tensor & unsqueeze_(int64_t dim);
Tensor nonzero() const;
Tensor contiguous() const;
Tensor clone() const;
Tensor view(IntList size) const;
Tensor expand(IntList size) const;
Tensor & resize_as_(const Tensor & the_template);
Tensor index_select(int64_t dim, const Tensor & index) const;
Tensor & index_copy_(int64_t dim, const Tensor & index, const Tensor & source);
Tensor & index_add_(int64_t dim, const Tensor & index, const Tensor & source);
Tensor & index_fill_(int64_t dim, const Tensor & index, Scalar value);
Tensor narrow(int64_t dimension, int64_t start, int64_t length) const;
Tensor unfold(int64_t dimension, int64_t size, int64_t step) const;
Tensor & scatter_(int64_t dim, const Tensor & index, const Tensor & src);
Tensor & scatter_(int64_t dim, const Tensor & index, Scalar value);
Tensor & scatter_add_(int64_t dim, const Tensor & index, const Tensor & src);
Tensor gather(int64_t dim, const Tensor & index) const;
void* data_ptr() const;
bool equal(const Tensor & other) const;
Tensor __and__(Scalar other) const;
Tensor __and__(const Tensor & other) const;
Tensor & __iand__(Scalar other);
Tensor & __iand__(const Tensor & other);
Tensor __or__(Scalar other) const;
Tensor __or__(const Tensor & other) const;
Tensor & __ior__(Scalar other);
Tensor & __ior__(const Tensor & other);
Tensor __xor__(Scalar other) const;
Tensor __xor__(const Tensor & other) const;
Tensor & __ixor__(Scalar other);
Tensor & __ixor__(const Tensor & other);
Tensor __lshift__(Scalar other) const;
Tensor __lshift__(const Tensor & other) const;
Tensor & __ilshift__(Scalar other);
Tensor & __ilshift__(const Tensor & other);
Tensor __rshift__(Scalar other) const;
Tensor __rshift__(const Tensor & other) const;
Tensor & __irshift__(Scalar other);
Tensor & __irshift__(const Tensor & other);
Tensor lt(Scalar other) const;
Tensor lt(const Tensor & other) const;
Tensor & lt_(Scalar other);
Tensor & lt_(const Tensor & other);
Tensor gt(Scalar other) const;
Tensor gt(const Tensor & other) const;
Tensor & gt_(Scalar other);
Tensor & gt_(const Tensor & other);
Tensor le(Scalar other) const;
Tensor le(const Tensor & other) const;
Tensor & le_(Scalar other);
Tensor & le_(const Tensor & other);
Tensor ge(Scalar other) const;
Tensor ge(const Tensor & other) const;
Tensor & ge_(Scalar other);
Tensor & ge_(const Tensor & other);
Tensor eq(Scalar other) const;
Tensor eq(const Tensor & other) const;
Tensor & eq_(Scalar other);
Tensor & eq_(const Tensor & other);
Tensor ne(Scalar other) const;
Tensor ne(const Tensor & other) const;
Tensor & ne_(Scalar other);
Tensor & ne_(const Tensor & other);
std::tuple<Tensor,Tensor> min(int64_t dim, bool keepdim=false) const;
Tensor min(const Tensor & other) const;
Scalar min() const;
std::tuple<Tensor,Tensor> max(int64_t dim, bool keepdim=false) const;
Tensor max(const Tensor & other) const;
Scalar max() const;
std::tuple<Tensor,Tensor> kthvalue(int64_t k, int64_t dim=-1, bool keepdim=false) const;
std::tuple<Tensor,Tensor> mode(int64_t dim=-1, bool keepdim=false) const;
std::tuple<Tensor,Tensor> median(int64_t dim, bool keepdim=false) const;
Scalar median() const;
std::tuple<Tensor,Tensor> sort(int64_t dim=-1, bool descending=false) const;
std::tuple<Tensor,Tensor> topk(int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) const;
bool all() const;
bool any() const;
int64_t get_device() const;
Tensor abs() const;
Tensor & abs_();
Tensor & sigmoid_();
Tensor sigmoid() const;
Tensor & log_();
Tensor log() const;
Tensor & log1p_();
Tensor log1p() const;
Tensor lgamma() const;
Tensor & lgamma_();
Tensor & exp_();
Tensor exp() const;
Tensor & cos_();
Tensor cos() const;
Tensor & acos_();
Tensor acos() const;
Tensor & cosh_();
Tensor cosh() const;
Tensor & sin_();
Tensor sin() const;
Tensor & asin_();
Tensor asin() const;
Tensor & sinh_();
Tensor sinh() const;
Tensor & tan_();
Tensor tan() const;
Tensor & atan_();
Tensor atan() const;
Tensor & tanh_();
Tensor tanh() const;
Tensor & erf_();
Tensor erf() const;
Tensor & erfinv_();
Tensor erfinv() const;
Tensor & sqrt_();
Tensor sqrt() const;
Tensor & rsqrt_();
Tensor rsqrt() const;
Tensor & ceil_();
Tensor ceil() const;
Tensor & floor_();
Tensor floor() const;
Tensor & round_();
Tensor round() const;
Tensor & trunc_();
Tensor trunc() const;
Tensor & frac_();
Tensor frac() const;
Tensor mean(int64_t dim, bool keepdim=false) const;
Scalar mean() const;
Tensor var(int64_t dim, bool unbiased=true, bool keepdim=false) const;
Scalar var(bool unbiased=true) const;
Tensor std(int64_t dim, bool unbiased=true, bool keepdim=false) const;
Scalar std(bool unbiased=true) const;
Tensor norm(Scalar p, int64_t dim, bool keepdim=false) const;
Scalar norm(Scalar p=2) const;
Tensor renorm(Scalar p, int64_t dim, Scalar maxnorm) const;
Tensor & renorm_(Scalar p, int64_t dim, Scalar maxnorm);
Scalar dist(const Tensor & other, Scalar p=2) const;
Tensor reciprocal() const;
Tensor & reciprocal_();
Tensor neg() const;
Tensor & neg_();
Tensor atan2(const Tensor & other) const;
Tensor & atan2_(const Tensor & other);
Tensor pow(Scalar exponent) const;
Tensor pow(const Tensor & exponent) const;
Tensor & pow_(Scalar exponent);
Tensor & pow_(const Tensor & exponent);
Tensor lerp(const Tensor & end, Scalar weight) const;
Tensor & lerp_(const Tensor & end, Scalar weight);
Tensor histc(int64_t bins=100, Scalar min=0, Scalar max=0) const;
Tensor & zero_();
Tensor sum(int64_t dim, bool keepdim=false) const;
Scalar sum() const;
Tensor prod(int64_t dim, bool keepdim=false) const;
Scalar prod() const;
Tensor cumsum(int64_t dim) const;
Tensor cumprod(int64_t dim) const;
Tensor sign() const;
Tensor & sign_();
Scalar trace() const;
Tensor add(Scalar other, Scalar alpha=1) const;
Tensor add(const Tensor & other, Scalar alpha=1) const;
Tensor add(SparseTensor other, Scalar alpha=1) const;
Tensor & add_(Scalar other, Scalar alpha=1);
Tensor & add_(const Tensor & other, Scalar alpha=1);
Tensor & add_(SparseTensor other, Scalar alpha=1);
Tensor sub(Scalar other, Scalar alpha=1) const;
Tensor sub(const Tensor & other, Scalar alpha=1) const;
Tensor & sub_(Scalar other, Scalar alpha=1);
Tensor & sub_(const Tensor & other, Scalar alpha=1);
Tensor mul(Scalar other) const;
Tensor mul(const Tensor & other) const;
Tensor & mul_(Scalar other);
Tensor & mul_(const Tensor & other);
Tensor div(Scalar other) const;
Tensor div(const Tensor & other) const;
Tensor & div_(Scalar other);
Tensor & div_(const Tensor & other);
Tensor fmod(Scalar other) const;
Tensor fmod(const Tensor & other) const;
Tensor & fmod_(Scalar other);
Tensor & fmod_(const Tensor & other);
Tensor remainder(Scalar other) const;
Tensor remainder(const Tensor & other) const;
Tensor & remainder_(Scalar other);
Tensor & remainder_(const Tensor & other);
Tensor clamp(Scalar min, Scalar max) const;
Tensor clamp(Scalar min) const;
Tensor & clamp_(Scalar min, Scalar max);
Tensor & clamp_(Scalar min);
Scalar dot(const Tensor & tensor) const;
Tensor tril(int64_t diagonal=0) const;
Tensor & tril_(int64_t diagonal=0);
Tensor triu(int64_t diagonal=0) const;
Tensor & triu_(int64_t diagonal=0);
Tensor cross(const Tensor & other, int64_t dim=-1) const;
Tensor diag(int64_t diagonal=0) const;
Tensor addmm(const Tensor & mat1, const Tensor & mat2, Scalar beta=1, Scalar alpha=1) const;
Tensor & addmm_(const Tensor & mat1, const Tensor & mat2, Scalar beta=1, Scalar alpha=1);
Tensor addmv(const Tensor & mat, const Tensor & vec, Scalar beta=1, Scalar alpha=1) const;
Tensor & addmv_(const Tensor & mat, const Tensor & vec, Scalar beta=1, Scalar alpha=1);
Tensor addr(const Tensor & vec1, const Tensor & vec2, Scalar beta=1, Scalar alpha=1) const;
Tensor & addr_(const Tensor & vec1, const Tensor & vec2, Scalar beta=1, Scalar alpha=1);
Tensor ger(const Tensor & vec2) const;
Tensor mv(const Tensor & vec) const;
Tensor mm(const Tensor & mat2) const;
Tensor bmm(const Tensor & mat2) const;
Tensor addbmm(const Tensor & batch1, const Tensor & batch2, Scalar beta=1, Scalar alpha=1) const;
Tensor & addbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta=1, Scalar alpha=1);
Tensor baddbmm(const Tensor & batch1, const Tensor & batch2, Scalar beta=1, Scalar alpha=1) const;
Tensor & baddbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta=1, Scalar alpha=1);
Tensor addcmul(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const;
Tensor & addcmul_(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1);
Tensor addcdiv(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const;
Tensor & addcdiv_(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1);
std::tuple<Tensor,Tensor> gesv(const Tensor & A) const;
std::tuple<Tensor,Tensor> gels(const Tensor & A) const;
std::tuple<Tensor,Tensor> trtrs(const Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) const;
std::tuple<Tensor,Tensor> symeig(bool eigenvectors=false, bool upper=true) const;
std::tuple<Tensor,Tensor> eig(bool eigenvectors=false) const;
std::tuple<Tensor,Tensor,Tensor> svd(bool some=true) const;
Tensor inverse() const;
Tensor potrf(bool upper=true) const;
Tensor potrs(const Tensor & input2, bool upper=true) const;
Tensor potri(bool upper=true) const;
std::tuple<Tensor,Tensor> pstrf(bool upper=true, Scalar tol=-1) const;
std::tuple<Tensor,Tensor> qr() const;
std::tuple<Tensor,Tensor> geqrf() const;
Tensor orgqr(const Tensor & input2) const;
Tensor ormqr(const Tensor & input2, const Tensor & input3, bool left=true, bool transpose=false) const;
std::tuple<Tensor,Tensor> btrifact(const Tensor & info={}, bool pivot=true) const;
Tensor btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const;
Tensor & random_(int64_t from, int64_t to, Generator * generator=nullptr);
Tensor & random_(int64_t to, Generator * generator=nullptr);
Tensor & random_(Generator * generator=nullptr);
Tensor multinomial(int64_t num_samples, bool replacement=false, Generator * generator=nullptr) const;
Tensor & uniform_(double from=0, double to=1, Generator * generator=nullptr);
Tensor & normal_(double mean=0, double std=1, Generator * generator=nullptr);
Tensor & cauchy_(double median=0, double sigma=1, Generator * generator=nullptr);
Tensor & log_normal_(double mean=1, double std=2, Generator * generator=nullptr);
Tensor & exponential_(double lambd=1, Generator * generator=nullptr);
Tensor & geometric_(double p, Generator * generator=nullptr);
int64_t size(int64_t dim) const;
int64_t stride(int64_t dim) const;
Tensor select(int64_t dim, int64_t sliceIndex) const;
Tensor _unnarrow(int64_t dimension, int64_t offset, int64_t dimSize) const;
Tensor & assign_(const Tensor & src);
std::vector<Tensor> split(int64_t split_size, int64_t dim=0) const;
std::vector<Tensor> chunk(int64_t chunks, int64_t dim=0) const;
};
} //namespace at