From 82d714935ea7c54f9e5966f47adf4b8460eb2c3d Mon Sep 17 00:00:00 2001 From: anjali411 Date: Wed, 12 May 2021 23:30:11 -0700 Subject: [PATCH] [TS] Add complex support for more ops (#54541) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54541 Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D27599114 Pulled By: anjali411 fbshipit-source-id: 182d4480fd788599c408bfaf0d23baf3d9a4e967 --- test/jit/test_complex.py | 74 +++++- torch/csrc/jit/runtime/register_ops_utils.h | 216 +++++++++++------- torch/csrc/jit/runtime/register_prim_ops.cpp | 34 ++- .../jit/runtime/register_prim_ops_fulljit.cpp | 35 ++- torch/jit/_builtins.py | 4 + 5 files changed, 266 insertions(+), 97 deletions(-) diff --git a/test/jit/test_complex.py b/test/jit/test_complex.py index d2b86eb89c89f..0eed4a7eb0bb3 100644 --- a/test/jit/test_complex.py +++ b/test/jit/test_complex.py @@ -2,6 +2,7 @@ import os import sys from torch.testing._internal.jit_utils import JitTestCase, execWrapper +from torch.testing._internal.common_utils import IS_MACOS from typing import List, Dict from itertools import product from textwrap import dedent @@ -79,7 +80,13 @@ def checkCmath(func_name, funcs_template=funcs_template): f_script = cu.func f = scope['func'] - for a in complex_vals: + if func_name in ['isinf', 'isnan', 'isfinite']: + new_vals = vals + ([float('inf'), float('nan'), -1 * float('inf')]) + final_vals = tuple(complex(x, y) for x, y in product(new_vals, new_vals)) + else: + final_vals = complex_vals + + for a in final_vals: res_python = None res_script = None try: @@ -99,7 +106,7 @@ def checkCmath(func_name, funcs_template=funcs_template): self.assertEqual(res_python, res_script, msg=msg) unary_ops = ['log', 'log10', 'sqrt', 'exp', 'sin', 'cos', 'asin', 'acos', 'atan', 'sinh', 'cosh', - 'tanh', 'asinh', 'acosh', 'atanh', 'phase'] + 'tanh', 'asinh', 'acosh', 'atanh', 'phase', 'isinf', 'isnan', 'isfinite'] # --- Unary ops --- for op in unary_ops: @@ -111,6 +118,32 @@ def fn(x: complex): for val in complex_vals: self.checkScript(fn, (val, )) + def pow_complex_float(x: complex, y: float): + return pow(x, y) + + def pow_float_complex(x: float, y: complex): + return pow(x, y) + + + self.checkScript(pow_float_complex, (2, 3j)) + self.checkScript(pow_complex_float, (3j, 2)) + + def pow_complex_complex(x: complex, y: complex): + return pow(x, y) + + for x, y in zip(complex_vals, complex_vals): + # Reference: https://github.com/pytorch/pytorch/issues/54622 + if (x == 0): + continue + self.checkScript(pow_complex_complex, (x, y)) + + if not IS_MACOS: + # --- Binary op --- + def rect_fn(x: float, y: float): + return cmath.rect(x, y) + for x, y in product(vals, vals): + self.checkScript(rect_fn, (x, y, )) + func_constants_template = dedent(''' def func(): return cmath.{func_or_const} @@ -120,7 +153,6 @@ def func(): for x in (float_consts + complex_consts): checkCmath(x, funcs_template=func_constants_template) - def test_infj_nanj_pickle(self): class ComplexModule(torch.jit.ScriptModule): def __init__(self): @@ -247,3 +279,39 @@ def fn_tensor_tensor(real, img): for x, y in product(tensors, tensors): self.checkScript(fn_tensor_tensor, (x, y, )) + + def test_comparison_ops(self): + def fn1(a: complex, b: complex): + return a == b + + def fn2(a: complex, b: complex): + return a != b + + def fn3(a: complex, b: float): + return a == b + + def fn4(a: complex, b: float): + return a != b + + x, y = 2 - 3j, 4j + self.checkScript(fn1, (x, x)) + self.checkScript(fn1, (x, y)) + self.checkScript(fn2, (x, x)) 
+ self.checkScript(fn2, (x, y)) + + x1, y1 = 1 + 0j, 1.0 + self.checkScript(fn3, (x1, y1)) + self.checkScript(fn4, (x1, y1)) + + def test_div(self): + def fn1(a: complex, b: complex): + return a / b + + x, y = 2 - 3j, 4j + self.checkScript(fn1, (x, y)) + + def test_complex_list_sum(self): + def fn(x: List[complex]): + return sum(x) + + self.checkScript(fn, (torch.randn(4, dtype=torch.cdouble).tolist(), )) diff --git a/torch/csrc/jit/runtime/register_ops_utils.h b/torch/csrc/jit/runtime/register_ops_utils.h index 5844f8160ede8..239989bc0bd96 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.h +++ b/torch/csrc/jit/runtime/register_ops_utils.h @@ -435,23 +435,36 @@ void listCopyAndSort(Stack* stack); void listSetItem(Stack* stack); -#define DEFINE_GENERIC_BINARY_OP(aten_op, op, result) \ - OperatorGenerator( \ - TORCH_SELECTIVE_SCHEMA(#aten_op ".int_int(int a, int b) -> " #result), \ - [](Stack* stack) { \ - int64_t a, b; \ - pop(stack, a, b); \ - push(stack, op); \ - }, \ - aliasAnalysisFromSchema()), \ - OperatorGenerator( \ - TORCH_SELECTIVE_SCHEMA( \ - #aten_op ".float_float(float a, float b) -> " #result), \ - [](Stack* stack) { \ - double a, b; \ - pop(stack, a, b); \ - push(stack, op); \ - }, \ +#define DEFINE_GENERIC_BINARY_OP( \ + aten_op, op, int_float_result, complex_result) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA(#aten_op \ + ".int_int(int a, int b) -> " #int_float_result), \ + [](Stack* stack) { \ + int64_t a, b; \ + pop(stack, a, b); \ + push(stack, op); \ + }, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA( \ + #aten_op \ + ".float_float(float a, float b) -> " #int_float_result), \ + [](Stack* stack) { \ + double a, b; \ + pop(stack, a, b); \ + push(stack, op); \ + }, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA( \ + #aten_op \ + ".complex_complex(complex a, complex b) -> " #complex_result), \ + [](Stack* stack) { \ + c10::complex a, b; \ + pop(stack, a, b); \ + push(stack, op); \ + }, \ aliasAnalysisFromSchema()) // define implementations for primitive number ops @@ -520,66 +533,46 @@ void listSetItem(Stack* stack); // it's necessary to register this overload following // int/float variations to avoid trapping Scalar args // in unintended implicit conversions -#define DEFINE_SCALAR_BINARY_OP(aten_op, int_op, float_op, result) \ - OperatorGenerator( \ - TORCH_SELECTIVE_SCHEMA(#aten_op "(Scalar a, Scalar b) -> " #result), \ - [](Stack* stack) { \ - IValue x, y; \ - pop(stack, x, y); \ - if (x.isDouble()) { \ - if (y.isDouble()) { \ - double a = x.toDouble(); \ - double b = y.toDouble(); \ - push(stack, float_op); \ - } else { \ - double a = x.toDouble(); \ - int64_t b = y.toInt(); \ - push(stack, float_op); \ - } \ - } else { \ - if (y.isDouble()) { \ - int64_t a = x.toInt(); \ - double b = y.toDouble(); \ - push(stack, float_op); \ - } else { \ - int64_t a = x.toInt(); \ - int64_t b = y.toInt(); \ - push(stack, int_op); \ - } \ - } \ - }, \ +#define DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION_GENERIC( \ + aten_op, int_op, float_op, result, string_val) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA(#aten_op string_val \ + "(Scalar a, Scalar b) -> " #result), \ + [](Stack* stack) { \ + IValue x, y; \ + pop(stack, x, y); \ + if (x.isDouble()) { \ + if (y.isDouble()) { \ + double a = x.toDouble(); \ + double b = y.toDouble(); \ + push(stack, float_op); \ + } else { \ + double a = x.toDouble(); \ + int64_t b = y.toInt(); \ + push(stack, float_op); \ + } \ + } else { \ + if (y.isDouble()) { \ + 
int64_t a = x.toInt(); \ + double b = y.toDouble(); \ + push(stack, float_op); \ + } else { \ + int64_t a = x.toInt(); \ + int64_t b = y.toInt(); \ + push(stack, int_op); \ + } \ + } \ + }, \ aliasAnalysisFromSchema()) -#define DEFINE_SCALAR_SCALAR_BINARY_OP(aten_op, int_op, float_op, result) \ - OperatorGenerator( \ - TORCH_SELECTIVE_SCHEMA( \ - #aten_op ".Scalar_Scalar(Scalar a, Scalar b) -> " #result), \ - [](Stack* stack) { \ - IValue x, y; \ - pop(stack, x, y); \ - if (x.isDouble()) { \ - if (y.isDouble()) { \ - double a = x.toDouble(); \ - double b = y.toDouble(); \ - push(stack, float_op); \ - } else { \ - double a = x.toDouble(); \ - int64_t b = y.toInt(); \ - push(stack, float_op); \ - } \ - } else { \ - if (y.isDouble()) { \ - int64_t a = x.toInt(); \ - double b = y.toDouble(); \ - push(stack, float_op); \ - } else { \ - int64_t a = x.toInt(); \ - int64_t b = y.toInt(); \ - push(stack, int_op); \ - } \ - } \ - }, \ - aliasAnalysisFromSchema()) +#define DEFINE_SCALAR_BINARY_OP(aten_op, int_op, float_op, result) \ + DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION_GENERIC( \ + aten_op, int_op, float_op, result, "") + +#define DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION( \ + aten_op, int_op, float_op, result) \ + DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION_GENERIC( \ + aten_op, int_op, float_op, result, ".Scalar_Scalar") #define DEFINE_BINARY_OP(aten_op, op) \ DEFINE_GENERIC_OP(aten_op, op, op, int, float), \ @@ -783,7 +776,55 @@ void listSetItem(Stack* stack); }, \ aliasAnalysisFromSchema()) -#define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX( \ +#define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_AVOID_COLLISION_GENERIC( \ + aten_op, int_op, float_op, complex_op, result, string_val) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA(#aten_op string_val \ + "(Scalar a, Scalar b) -> " #result), \ + [](Stack* stack) { \ + IValue x, y; \ + pop(stack, x, y); \ + if (x.isComplexDouble()) { \ + c10::complex a = x.toComplexDouble(); \ + if (y.isComplexDouble()) { \ + c10::complex b = y.toComplexDouble(); \ + push(stack, complex_op); \ + } else if (y.isDouble()) { \ + double b = y.toDouble(); \ + push(stack, complex_op); \ + } else { \ + int64_t b = y.toInt(); \ + push(stack, complex_op); \ + } \ + } else if (x.isDouble()) { \ + double a = x.toDouble(); \ + if (y.isComplexDouble()) { \ + c10::complex b = y.toComplexDouble(); \ + push(stack, complex_op); \ + } else if (y.isDouble()) { \ + double b = y.toDouble(); \ + push(stack, float_op); \ + } else { \ + int64_t b = y.toInt(); \ + push(stack, float_op); \ + } \ + } else { \ + int64_t a = x.toInt(); \ + if (y.isComplexDouble()) { \ + c10::complex b = y.toComplexDouble(); \ + push(stack, complex_op); \ + } else if (y.isDouble()) { \ + double b = y.toDouble(); \ + push(stack, float_op); \ + } else { \ + int64_t b = y.toInt(); \ + push(stack, int_op); \ + } \ + } \ + }, \ + aliasAnalysisFromSchema()) + +#define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_WITHOUT_INT_COMPLEX_PAIR( \ aten_op, int_op, float_op, complex_op, result) \ OperatorGenerator( \ TORCH_SELECTIVE_SCHEMA(#aten_op "(Scalar a, Scalar b) -> " #result), \ @@ -798,9 +839,6 @@ void listSetItem(Stack* stack); } else if (y.isDouble()) { \ double b = y.toDouble(); \ push(stack, complex_op); \ - } else { \ - int64_t b = y.toInt(); \ - push(stack, complex_op); \ } \ } else if (x.isDouble()) { \ double a = x.toDouble(); \ @@ -816,13 +854,10 @@ void listSetItem(Stack* stack); } \ } else { \ int64_t a = x.toInt(); \ - if (y.isComplexDouble()) { \ - c10::complex b = y.toComplexDouble(); \ - push(stack, complex_op); \ - } else 
if (y.isDouble()) { \ + if (y.isDouble()) { \ double b = y.toDouble(); \ push(stack, float_op); \ - } else { \ + } else if (y.isInt()) { \ int64_t b = y.toInt(); \ push(stack, int_op); \ } \ @@ -830,6 +865,11 @@ void listSetItem(Stack* stack); }, \ aliasAnalysisFromSchema()) +#define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX( \ + aten_op, int_op, float_op, complex_op, result) \ + DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_AVOID_COLLISION_GENERIC( \ + aten_op, int_op, float_op, complex_op, result, "") + #define DEFINE_BINARY_OP_WITH_COMPLEX(aten_op, op) \ DEFINE_GENERIC_OP_WITH_COMPLEX(aten_op, op, op, op, int, float, complex), \ DEFINE_INT_COMPLEX_OP(aten_op, op, complex), \ @@ -837,5 +877,13 @@ void listSetItem(Stack* stack); DEFINE_INT_FLOAT_OP(aten_op, op, float), \ DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX(aten_op, op, op, op, Scalar) +#define DEFINE_COMPARISON_OP_WITH_COMPLEX(aten_op, op) \ + DEFINE_GENERIC_OP_WITH_COMPLEX(aten_op, op, op, op, bool, bool, bool), \ + DEFINE_INT_FLOAT_OP(aten_op, op, bool), \ + DEFINE_FLOAT_COMPLEX_OP(aten_op, op, bool), \ + DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_WITHOUT_INT_COMPLEX_PAIR( \ + aten_op, op, op, op, bool), \ + DEFINE_STR_CMP_OP(aten_op, op) + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index d1bb9563d86e1..b270d1f946abb 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -792,8 +792,23 @@ RegisterOperators reg( aliasAnalysisFromSchema()), DEFINE_UNARY_OP_WITH_COMPLEX(aten::log, std::log(a), float, float), DEFINE_STRING_OP(aten::add, a + b, str), - DEFINE_COMPARISON_OP(aten::eq, a == b), - DEFINE_COMPARISON_OP(aten::ne, a != b), + DEFINE_COMPARISON_OP_WITH_COMPLEX(aten::eq, a == b), + DEFINE_COMPARISON_OP_WITH_COMPLEX(aten::ne, a != b), + DEFINE_GENERIC_OP( + aten::polar, + c10::polar(static_cast(a), static_cast(b)), + c10::polar(static_cast(a), static_cast(b)), + complex, + complex), + DEFINE_INT_FLOAT_OP( + aten::polar, + c10::polar(static_cast(a), static_cast(b)), + complex), + DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION( + aten::polar, + c10::polar(static_cast(a), static_cast(b)), + c10::polar(static_cast(a), static_cast(b)), + Scalar), DEFINE_COMPARISON_OP(aten::lt, a < b), DEFINE_COMPARISON_OP(aten::gt, a > b), DEFINE_COMPARISON_OP(aten::le, a <= b), @@ -826,12 +841,14 @@ RegisterOperators reg( fmod((b + fmod(a, b)), b), Scalar), // NB: This is the python truediv operation - DEFINE_GENERIC_OP( + DEFINE_GENERIC_OP_WITH_COMPLEX( aten::div, static_cast(a) / static_cast(b), a / b, + a / b, float, - float), + float, + complex), DEFINE_SCALAR_BINARY_OP( aten::div, static_cast(a) / static_cast(b), @@ -851,14 +868,17 @@ RegisterOperators reg( Scalar), // int ** int produces a float, because negative exponents produce float // results - DEFINE_GENERIC_OP( + DEFINE_GENERIC_OP_WITH_COMPLEX( aten::pow, static_cast(pow(a, b)), static_cast(pow(a, b)), + static_cast>(pow(a, b)), float, - float), + float, + complex), DEFINE_INT_FLOAT_OP(aten::pow, pow(a, b), float), - DEFINE_SCALAR_SCALAR_BINARY_OP( + DEFINE_FLOAT_COMPLEX_OP(aten::pow, pow(a, b), complex), + DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION( aten::pow, static_cast(pow(a, b)), static_cast(pow(a, b)), diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp index 5e1368daa9608..c8120cf3d2c42 100644 --- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp +++ 
b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp @@ -947,9 +947,15 @@ RegisterOperators reg2({ DEFINE_INT_OP(aten::__lshift__, a << b), DEFINE_INT_OP(aten::__rshift__, a >> b), - DEFINE_GENERIC_BINARY_OP(aten::log, std::log(a) / std::log(b), float), + DEFINE_GENERIC_BINARY_OP( + aten::log, + std::log(a) / std::log(b), + float, + complex), DEFINE_INT_FLOAT_OP(aten::log, std::log(a) / std::log(b), float), - DEFINE_SCALAR_SCALAR_BINARY_OP( + DEFINE_INT_COMPLEX_OP(aten::log, std::log(a) / std::log(b), complex), + DEFINE_FLOAT_COMPLEX_OP(aten::log, std::log(a) / std::log(b), complex), + DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION( aten::log, std::log(a) / std::log(b), std::log(a) / std::log(b), @@ -967,7 +973,7 @@ RegisterOperators reg2({ float, float), DEFINE_INT_FLOAT_OP(aten::atan2, std::atan2(a, b), float), - DEFINE_SCALAR_SCALAR_BINARY_OP( + DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION( aten::atan2, std::atan2(a, b), std::atan2(a, b), @@ -995,6 +1001,18 @@ RegisterOperators reg2({ DEFINE_UNARY_FLOAT_OP(aten::isnan, std::isnan(a), bool), DEFINE_UNARY_FLOAT_OP(aten::isfinite, std::isfinite(a), bool), DEFINE_UNARY_FLOAT_OP(aten::isinf, std::isinf(a), bool), + DEFINE_UNARY_COMPLEX_OP( + aten::isnan, + std::isnan(a.real()) || std::isnan(a.imag()), + bool), + DEFINE_UNARY_COMPLEX_OP( + aten::isfinite, + std::isfinite(a.real()) && std::isfinite(a.imag()), + bool), + DEFINE_UNARY_COMPLEX_OP( + aten::isinf, + std::isinf(a.real()) || std::isinf(a.imag()), + bool), DEFINE_UNARY_OP(aten::gamma, std::tgamma(a), float, float), DEFINE_UNARY_OP(aten::erf, std::erf(a), float, float), DEFINE_UNARY_OP(aten::erfc, std::erfc(a), float, float), @@ -1080,6 +1098,17 @@ RegisterOperators reg2({ push(stack, sum); }, aliasAnalysisFromSchema()), + Operator( + "aten::sum.complex(complex[] self) -> complex", + [](Stack* stack) { + c10::List> l = pop(stack).toComplexDoubleList(); + c10::complex sum = 0.0; + for (int i = 0; i < l.size(); i++) { + sum = sum + l.extract(i); + } + push(stack, sum); + }, + aliasAnalysisFromSchema()), Operator( "aten::sum.bool(bool[] self) -> int", [](Stack* stack) { diff --git a/torch/jit/_builtins.py b/torch/jit/_builtins.py index 10a0b159dbee3..659a4f9dac9df 100644 --- a/torch/jit/_builtins.py +++ b/torch/jit/_builtins.py @@ -63,7 +63,11 @@ (math.isinf, "aten::isinf"), (math.degrees, "aten::degrees"), (math.radians, "aten::radians"), + (cmath.isnan, "aten::isnan"), + (cmath.isfinite, "aten::isfinite"), + (cmath.isinf, "aten::isinf"), (cmath.phase, "aten::angle"), + (cmath.rect, "aten::polar"), (cmath.log, "aten::log"), (cmath.log10, "aten::log10"), (cmath.sqrt, "aten::sqrt"),
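
A rough usage sketch (not part of the patch, and assuming a PyTorch build that already includes it): the hunks above only add operator registrations and tests, so the snippet below shows how the newly scriptable complex ops could be exercised from TorchScript. The function names (classify, from_polar, power_and_compare, total) are illustrative only.

import cmath
from typing import List

import torch


@torch.jit.script
def classify(z: complex):
    # cmath.isnan / isinf / isfinite now script for complex inputs
    # (registered above via DEFINE_UNARY_COMPLEX_OP).
    return cmath.isnan(z), cmath.isinf(z), cmath.isfinite(z)


@torch.jit.script
def from_polar(r: float, phi: float) -> complex:
    # cmath.rect maps to aten::polar in the scripted builtins table.
    return cmath.rect(r, phi)


@torch.jit.script
def power_and_compare(x: complex, y: float) -> bool:
    # complex ** float and complex-to-complex equality are both scriptable here.
    return pow(x, y) == x


@torch.jit.script
def total(zs: List[complex]) -> complex:
    # sum() over a List[complex] uses the new aten::sum.complex overload.
    return sum(zs)


if __name__ == "__main__":
    print(classify(complex(float("inf"), 0.0)))
    print(from_polar(1.0, 1.5707963267948966))
    print(power_and_compare(3j, 2.0))
    print(total([1 + 2j, 3 - 4j]))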