Skip to content

Commit

Permalink
[TS] Add complex support for more ops (pytorch#54541)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#54541

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D27599114

Pulled By: anjali411

fbshipit-source-id: 182d4480fd788599c408bfaf0d23baf3d9a4e967
  • Loading branch information
anjali411 authored and facebook-github-bot committed May 13, 2021
1 parent 7a95ccc commit 82d7149
Show file tree
Hide file tree
Showing 5 changed files with 266 additions and 97 deletions.
74 changes: 71 additions & 3 deletions test/jit/test_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import sys
from torch.testing._internal.jit_utils import JitTestCase, execWrapper
from torch.testing._internal.common_utils import IS_MACOS
from typing import List, Dict
from itertools import product
from textwrap import dedent
Expand Down Expand Up @@ -79,7 +80,13 @@ def checkCmath(func_name, funcs_template=funcs_template):
f_script = cu.func
f = scope['func']

for a in complex_vals:
if func_name in ['isinf', 'isnan', 'isfinite']:
new_vals = vals + ([float('inf'), float('nan'), -1 * float('inf')])
final_vals = tuple(complex(x, y) for x, y in product(new_vals, new_vals))
else:
final_vals = complex_vals

for a in final_vals:
res_python = None
res_script = None
try:
Expand All @@ -99,7 +106,7 @@ def checkCmath(func_name, funcs_template=funcs_template):
self.assertEqual(res_python, res_script, msg=msg)

unary_ops = ['log', 'log10', 'sqrt', 'exp', 'sin', 'cos', 'asin', 'acos', 'atan', 'sinh', 'cosh',
'tanh', 'asinh', 'acosh', 'atanh', 'phase']
'tanh', 'asinh', 'acosh', 'atanh', 'phase', 'isinf', 'isnan', 'isfinite']

# --- Unary ops ---
for op in unary_ops:
Expand All @@ -111,6 +118,32 @@ def fn(x: complex):
for val in complex_vals:
self.checkScript(fn, (val, ))

def pow_complex_float(x: complex, y: float):
return pow(x, y)

def pow_float_complex(x: float, y: complex):
return pow(x, y)


self.checkScript(pow_float_complex, (2, 3j))
self.checkScript(pow_complex_float, (3j, 2))

def pow_complex_complex(x: complex, y: complex):
return pow(x, y)

for x, y in zip(complex_vals, complex_vals):
# Reference: https://github.com/pytorch/pytorch/issues/54622
if (x == 0):
continue
self.checkScript(pow_complex_complex, (x, y))

if not IS_MACOS:
# --- Binary op ---
def rect_fn(x: float, y: float):
return cmath.rect(x, y)
for x, y in product(vals, vals):
self.checkScript(rect_fn, (x, y, ))

func_constants_template = dedent('''
def func():
return cmath.{func_or_const}
Expand All @@ -120,7 +153,6 @@ def func():
for x in (float_consts + complex_consts):
checkCmath(x, funcs_template=func_constants_template)


def test_infj_nanj_pickle(self):
class ComplexModule(torch.jit.ScriptModule):
def __init__(self):
Expand Down Expand Up @@ -247,3 +279,39 @@ def fn_tensor_tensor(real, img):

for x, y in product(tensors, tensors):
self.checkScript(fn_tensor_tensor, (x, y, ))

def test_comparison_ops(self):
def fn1(a: complex, b: complex):
return a == b

def fn2(a: complex, b: complex):
return a != b

def fn3(a: complex, b: float):
return a == b

def fn4(a: complex, b: float):
return a != b

x, y = 2 - 3j, 4j
self.checkScript(fn1, (x, x))
self.checkScript(fn1, (x, y))
self.checkScript(fn2, (x, x))
self.checkScript(fn2, (x, y))

x1, y1 = 1 + 0j, 1.0
self.checkScript(fn3, (x1, y1))
self.checkScript(fn4, (x1, y1))

def test_div(self):
def fn1(a: complex, b: complex):
return a / b

x, y = 2 - 3j, 4j
self.checkScript(fn1, (x, y))

def test_complex_list_sum(self):
def fn(x: List[complex]):
return sum(x)

self.checkScript(fn, (torch.randn(4, dtype=torch.cdouble).tolist(), ))
Loading

0 comments on commit 82d7149

Please sign in to comment.