Skip to content

Commit

Permalink
[nvFuser] Add real and imag to nvfuser and its python frontend (pytor…
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored and pytorchmergebot committed Jul 7, 2022
1 parent 6f1d99b commit dad071d
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 0 deletions.
4 changes: 4 additions & 0 deletions torch/_prims/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@
"exp",
"expm1",
"floor",
"imag",
"isfinite",
"lgamma",
"log",
Expand All @@ -205,6 +206,7 @@
"log10",
"reciprocal",
"neg",
"real",
"round",
"rsqrt",
"sin",
Expand Down Expand Up @@ -733,6 +735,7 @@ def _fill_aten(a: Tensor, value: NumberType) -> Tensor:
),
return_type=RETURN_TYPE.VIEW,
impl_aten=torch.imag,
impl_nvfuser=_imag_nvfuser, # type: ignore[name-defined]
doc="",
)

Expand Down Expand Up @@ -792,6 +795,7 @@ def _fill_aten(a: Tensor, value: NumberType) -> Tensor:
),
return_type=RETURN_TYPE.VIEW,
impl_aten=torch.real,
impl_nvfuser=_real_nvfuser, # type: ignore[name-defined]
doc="",
)

Expand Down
40 changes: 40 additions & 0 deletions torch/csrc/jit/codegen/cuda/arith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,46 @@ TensorView* abs(TensorView* tv) {
return abs(tv->as<Val>())->as<TensorView>();
}

// The output of real(complex_tensor) are real numbers
Val* real(Val* v) {
if (v->getDataType() == DataType::ComplexDouble) {
Val* out = newValLike(v, DataType::Double);
IrBuilder::create<UnaryOp>(UnaryOpType::Real, out, v);
return out;
}
if (v->getDataType() == DataType::ComplexFloat) {
Val* out = newValLike(v, DataType::Float);
IrBuilder::create<UnaryOp>(UnaryOpType::Real, out, v);
return out;
}
// We use UnaryOpType::Set instead of UnaryOpType::Real to support non-complex
// tensors
return unaryOp(UnaryOpType::Set, v);
}

TensorView* real(TensorView* tv) {
return real(tv->as<Val>())->as<TensorView>();
}

// The output of imag(complex_tensor) are real numbers
Val* imag(Val* v) {
if (v->getDataType() == DataType::ComplexDouble) {
Val* out = newValLike(v, DataType::Double);
IrBuilder::create<UnaryOp>(UnaryOpType::Imag, out, v);
return out;
}
if (v->getDataType() == DataType::ComplexFloat) {
Val* out = newValLike(v, DataType::Float);
IrBuilder::create<UnaryOp>(UnaryOpType::Imag, out, v);
return out;
}
TORCH_CHECK(false, "imag not supported for non-complex tensors");
}

TensorView* imag(TensorView* tv) {
return imag(tv->as<Val>())->as<TensorView>();
}

// UNARY FLOAT CAST OPERATIONS

#define NVFUSER_DEFINE_UNARY_FLOAT_OP(op_name, op_type) \
Expand Down
6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/arith.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ TORCH_CUDA_CU_API TensorView* neg(TensorView*);
// randlike
TORCH_CUDA_CU_API Val* randlike(Val*);
TORCH_CUDA_CU_API TensorView* randlike(TensorView*);
// real
TORCH_CUDA_CU_API Val* real(Val*);
TORCH_CUDA_CU_API TensorView* real(TensorView*);
// reciprocal
TORCH_CUDA_CU_API Val* reciprocal(Val*);
TORCH_CUDA_CU_API TensorView* reciprocal(TensorView*);
Expand Down Expand Up @@ -229,6 +232,9 @@ TORCH_CUDA_CU_API TensorView* trunc(TensorView*);
// bitwise_not
TORCH_CUDA_CU_API Val* bitwise_not(Val*);
TORCH_CUDA_CU_API TensorView* bitwise_not(TensorView*);
// imag
TORCH_CUDA_CU_API Val* imag(Val*);
TORCH_CUDA_CU_API TensorView* imag(TensorView*);
// isfinite
TORCH_CUDA_CU_API Val* isfinite(Val*);
TORCH_CUDA_CU_API TensorView* isfinite(TensorView*);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ void initNvFuserPythonBindings(PyObject* module) {
NVFUSER_PYTHON_BINDING_UNARY_OP("isneginf", isneginf)
NVFUSER_PYTHON_BINDING_UNARY_OP("isposinf", isposinf)
NVFUSER_PYTHON_BINDING_UNARY_OP("isreal", isreal)
NVFUSER_PYTHON_BINDING_UNARY_OP("real", real)
NVFUSER_PYTHON_BINDING_UNARY_OP("imag", imag)
#undef NVFUSER_PYTHON_BINDING_UNARY_OP

#define NVFUSER_PYTHON_BINDING_BINARY_OP(op_str, op_name) \
Expand Down
11 changes: 11 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4051,6 +4051,14 @@ TEST_F(NVFuserTest, FusionUnaryOps_CUDA) {
OpTuple{at::isposinf, UnaryOpType::IsPosInf, "isposinf"},
};

// The following ops only supports complex
std::vector<OpTuple> ops_complex_only{
// real is supported via UnaryOpType::Set for non-complex types, and
// UnaryOpType::Real requires input to be complex
OpTuple{at::real, UnaryOpType::Real, "real"},
OpTuple{at::imag, UnaryOpType::Imag, "imag"},
};

// Complex support for the following op is not working in nvFuser yet
std::vector<OpTuple> ops_skip_complex{
// TODO: abs is actually supported in nvFuser, but it has bug!!!
Expand Down Expand Up @@ -4082,6 +4090,9 @@ TEST_F(NVFuserTest, FusionUnaryOps_CUDA) {
ops_without_complex.end());
ops_to_test.insert(
ops_to_test.end(), ops_skip_complex.begin(), ops_skip_complex.end());
} else {
ops_to_test.insert(
ops_to_test.end(), ops_complex_only.begin(), ops_complex_only.end());
}
std::for_each(ops.begin(), ops.end(), [&](OpTuple& op) {
test_op(
Expand Down
6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,8 @@ bool needFloatSuffix(UnaryOpType t) {
case UnaryOpType::IsNegInf:
case UnaryOpType::IsPosInf:
case UnaryOpType::IsReal:
case UnaryOpType::Real:
case UnaryOpType::Imag:
return false;
default:
return true;
Expand Down Expand Up @@ -466,6 +468,10 @@ static const char* unary_op_type2string(UnaryOpType t) {
return "isposinf";
case UnaryOpType::IsReal:
return "isreal";
case UnaryOpType::Real:
return "std::real";
case UnaryOpType::Imag:
return "std::imag";
default:
TORCH_INTERNAL_ASSERT(false, "No string found for unary op type.");
}
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ enum class UnaryOpType {
Floor,
Frac,
Gelu,
Imag,
Silu,
Lgamma,
Log,
Expand All @@ -161,6 +162,7 @@ enum class UnaryOpType {
BitCast,
Neg,
RandLike,
Real,
Reciprocal,
Relu,
Rsqrt,
Expand Down

0 comments on commit dad071d

Please sign in to comment.