Enabled torch.mm and torch.mv for bfloat16
Summary: Pull Request resolved: pytorch#24224

Test Plan: Imported from OSS

Differential Revision: D16779996

Pulled By: izdeby

fbshipit-source-id: c859d8945a564edfa3f8a1430f140ae30d484d19
izdeby authored and facebook-github-bot committed Aug 16, 2019
1 parent cf57f73 commit eee3e92
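
As a quick illustration of what this change enables, here is a minimal usage sketch (assuming a PyTorch build that includes this commit): CPU matrix-matrix and matrix-vector products on bfloat16 tensors.

```python
# Minimal sketch, assuming a build that includes this change:
# torch.mm / torch.mv on CPU bfloat16 tensors.
import torch

a = torch.randn(4, 3).to(torch.bfloat16)  # 4x3 matrix
b = torch.randn(3, 5).to(torch.bfloat16)  # 3x5 matrix
v = torch.randn(3).to(torch.bfloat16)     # length-3 vector

mm_out = torch.mm(a, b)  # matrix-matrix product -> shape (4, 5)
mv_out = torch.mv(a, v)  # matrix-vector product -> shape (4,)

print(mm_out.dtype, mv_out.dtype)  # torch.bfloat16 torch.bfloat16
```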
Showing 9 changed files with 597 additions and 375 deletions.
9 changes: 9 additions & 0 deletions aten/src/ATen/Declarations.cwrap
@@ -2102,6 +2102,7 @@
[[
  name: _th_addmm
  cname: addmm
  cpu_bfloat16: True
  variants:
    - function
  return: argument 0
@@ -2123,6 +2124,7 @@
]]
[[
  name: _th_addmm_
  cpu_bfloat16: True
  variants: [function]
  return: self
  options:
@@ -2142,6 +2144,7 @@
[[
  name: _th_addmv
  cname: addmv
  cpu_bfloat16: True
  variants: function
  return: argument 0
  arguments:
@@ -2161,6 +2164,7 @@
[[
  name: _th_addmv_
  cname: addmv
  cpu_bfloat16: True
  variants: function
  return: self
  arguments:
@@ -2178,6 +2182,7 @@
[[
  name: _th_addr
  cname: addr
  cpu_bfloat16: True
  variants: function
  return: argument 0
  scalar_check: false
@@ -2197,6 +2202,7 @@
]]
[[
  name: _th_addr_
  cpu_bfloat16: True
  cname: addr
  return: self
  variants: function
@@ -2215,6 +2221,7 @@
[[
  name: _th_ger
  cname: addr
  cpu_bfloat16: True
  variants: function
  return: argument 0
  scalar_check: false
@@ -2231,6 +2238,7 @@
]]
[[
  name: _th_mv
  cpu_bfloat16: True
  cname: addmv
  variants: function
  return: argument 0
@@ -2248,6 +2256,7 @@
]]
[[
  name: _th_mm
  cpu_bfloat16: True
  variants: function
  return: argument 0
  options:
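
Beyond torch.mm and torch.mv, the cwrap entries above also tag the TH bindings for addmm, addmv, addr, and ger with `cpu_bfloat16: True`. A hedged sketch of the corresponding public ops follows; whether all of them accept CPU bfloat16 tensors after this change is an assumption, since the commit message only names mm and mv.

```python
# Hedged sketch of the user-facing ops behind the TH bindings tagged above.
import torch

bf16 = torch.bfloat16
M  = torch.randn(4, 5).to(bf16)
m1 = torch.randn(4, 3).to(bf16)
m2 = torch.randn(3, 5).to(bf16)
v3 = torch.randn(3).to(bf16)
v4 = torch.randn(4).to(bf16)
v5 = torch.randn(5).to(bf16)

out_addmm = torch.addmm(M, m1, m2)   # _th_addmm: M + m1 @ m2       -> (4, 5)
out_addmv = torch.addmv(v4, m1, v3)  # _th_addmv: v4 + m1 @ v3      -> (4,)
out_addr  = torch.addr(M, v4, v5)    # _th_addr:  M + outer(v4, v5) -> (4, 5)
out_ger   = torch.ger(v4, v5)        # _th_ger:   outer(v4, v5)     -> (4, 5)
```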
3 changes: 3 additions & 0 deletions aten/src/TH/THBlas.cpp
@@ -2,3 +2,6 @@

#include <TH/generic/THBlas.cpp>
#include <TH/THGenerateAllTypes.h>

#include <TH/generic/THBlas.cpp>
#include <TH/THGenerateBFloat16Type.h>
3 changes: 3 additions & 0 deletions aten/src/TH/THBlas.h
@@ -8,4 +8,7 @@
#include <TH/generic/THBlas.h>
#include <TH/THGenerateAllTypes.h>

#include <TH/generic/THBlas.h>
#include <TH/THGenerateBFloat16Type.h>

#endif
3 changes: 3 additions & 0 deletions aten/src/TH/THTensorMath.cpp
@@ -6,5 +6,8 @@
#include <TH/generic/THTensorMath.cpp>
#include <TH/THGenerateAllTypes.h>

#include <TH/generic/THTensorMath.cpp>
#include <TH/THGenerateBFloat16Type.h>

#include <TH/generic/THTensorMath.cpp>
#include <TH/THGenerateBoolType.h>
54 changes: 28 additions & 26 deletions aten/src/TH/generic/THTensorEvenMoreMath.cpp
@@ -252,6 +252,34 @@ void THTensor_(maskedCopyBool)(THTensor *tensor, THBoolTensor *mask, THTensor* s
  c10::raw::intrusive_ptr::decref(srct);
}

#if !defined(TH_REAL_IS_BOOL)
void THTensor_(mul)(THTensor *r_, THTensor *t, scalar_t value)
{
  THTensor_(resizeAs)(r_, t);
  int64_t r_Size = THTensor_(nElement)(r_);
  int r_Contig = THTensor_(isContiguous)(r_);
  int tContig = THTensor_(isContiguous)(t);
  if (r_Contig && tContig) {
    TH_TENSOR_APPLY2_CONTIG(scalar_t, r_, scalar_t, t, THVector_(muls)(r__data, t_data, value, r__len););
  } else {
    TH_TENSOR_APPLY2_PARALLEL(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = *t_data * value;, ORDIN_TH_OMP_OVERHEAD_THRESHOLD)
  }
}

void THTensor_(div)(THTensor *r_, THTensor *t, scalar_t value)
{
  THTensor_(resizeAs)(r_, t);
  int64_t r_Size = THTensor_(nElement)(r_);
  int r_Contig = THTensor_(isContiguous)(r_);
  int tContig = THTensor_(isContiguous)(t);
  if (r_Contig && tContig) {
    TH_TENSOR_APPLY2_CONTIG(scalar_t, r_, scalar_t, t, THVector_(divs)(r__data, t_data, value, r__len););
  } else {
    TH_TENSOR_APPLY2_PARALLEL(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = *t_data / value;, ORDIN_TH_OMP_OVERHEAD_THRESHOLD)
  }
}
#endif

#if !defined(TH_REAL_IS_BFLOAT16) /* non bfloat16 part*/

accreal THTensor_(sumall)(THTensor *tensor)
@@ -765,32 +793,6 @@ void THTensor_(sub_scaled)(THTensor *r_, THTensor *t, scalar_t value, scalar_t a
  THTensor_(add)(r_, t, -value * alpha);
}

void THTensor_(mul)(THTensor *r_, THTensor *t, scalar_t value)
{
  THTensor_(resizeAs)(r_, t);
  int64_t r_Size = THTensor_(nElement)(r_);
  int r_Contig = THTensor_(isContiguous)(r_);
  int tContig = THTensor_(isContiguous)(t);
  if (r_Contig && tContig) {
    TH_TENSOR_APPLY2_CONTIG(scalar_t, r_, scalar_t, t, THVector_(muls)(r__data, t_data, value, r__len););
  } else {
    TH_TENSOR_APPLY2_PARALLEL(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = *t_data * value;, ORDIN_TH_OMP_OVERHEAD_THRESHOLD)
  }
}

void THTensor_(div)(THTensor *r_, THTensor *t, scalar_t value)
{
  THTensor_(resizeAs)(r_, t);
  int64_t r_Size = THTensor_(nElement)(r_);
  int r_Contig = THTensor_(isContiguous)(r_);
  int tContig = THTensor_(isContiguous)(t);
  if (r_Contig && tContig) {
    TH_TENSOR_APPLY2_CONTIG(scalar_t, r_, scalar_t, t, THVector_(divs)(r__data, t_data, value, r__len););
  } else {
    TH_TENSOR_APPLY2_PARALLEL(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = *t_data / value;, ORDIN_TH_OMP_OVERHEAD_THRESHOLD)
  }
}

void THTensor_(lshift)(THTensor *r_, THTensor *t, scalar_t value)
{
#if defined(TH_REAL_IS_FLOAT)
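
The hunk above moves THTensor_(mul) and THTensor_(div) (elementwise tensor-scalar multiply and divide) out of the region excluded for bfloat16, so they are also generated for the BFloat16 type. A minimal sketch of the corresponding Python-level operations, assuming CPU bfloat16 tensor-scalar mul/div is exposed once these kernels exist:

```python
# Minimal sketch, assuming CPU bfloat16 tensor-scalar mul/div is exposed
# at the Python level once these TH kernels are generated for BFloat16.
import torch

t = torch.randn(2, 3).to(torch.bfloat16)

scaled = t * 2.0   # elementwise scale, the operation THTensor_(mul) implements
halved = t / 2.0   # elementwise divide, the operation THTensor_(div) implements

print(scaled.dtype, halved.dtype)  # torch.bfloat16 torch.bfloat16
```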