Skip to content

Commit

Permalink
Merge pull request lcompilers#263 from namannimmo10/np
Browse files Browse the repository at this point in the history
Overload numpy functions
  • Loading branch information
namannimmo10 committed Mar 20, 2022
2 parents cc1cfdd + 23aac95 commit 2dd1e4d
Showing 1 changed file with 92 additions and 7 deletions.
99 changes: 92 additions & 7 deletions integration_tests/test_numpy_02.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This test handles actual LPython implementations of functions from the numpy
# module.
from ltypes import i32, i64, f64, TypeVar, overload
from ltypes import i32, i64, f32, f64, TypeVar, overload
from numpy import empty, int64

e: f64 = 2.718281828459045
Expand Down Expand Up @@ -32,33 +32,88 @@ def arange(n: i32) -> i64[n]:
A[i] = i
return A

#: sqrt() as a generic procedure.
#: supported types for argument:
#: i32, i64, f32, f64, bool
@overload
def sqrt(n: i32) -> f64:
return n**(1/2)

@overload
def sqrt(n: i64) -> f64:
return n**(1/2)

@overload
def sqrt(f: f32) -> f32:
return f**(1/2)

@overload
def sqrt(f: f64) -> f64:
return f**(1/2)

@overload
def sqrt(b: bool) -> f64:
if b:
return 1.0
else:
return 0.0

#: exp() as a generic procedure.
#: supported types for argument:
#: i32, i64, f32, f64, bool
@overload
def exp(n: i32) -> f64:
return e**n

@overload
def exp(n: i64) -> f64:
return e**n

@overload
def exp(f: f32) -> f32:
return e**f

@overload
def exp(f: f64) -> f64:
return e**f

@overload
def exp(b: bool) -> f64:
if b:
return 2.719
else:
return 1.0

#: fabs() as a generic procedure.
#: supported types for argument:
#: i32, i64, f32, f64, bool
@overload
def fabs(n: i32) -> f64:
if n < 0:
return -1.0*n
return 1.0*n

@overload
def fabs(n: i64) -> f64:
if n < 0:
return -1.0*n
return 1.0*n

@overload
def fabs(f: f32) -> f32:
if f < 0.0:
return -f
return f

@overload
def fabs(f: f64) -> f64:
if f < 0.0:
return -f
return f

@overload
def fabs(n: i32) -> f64:
if n < 0:
return -1.0*n
return 1.0*n
def fabs(b: bool) -> f64:
return sqrt(b)

num: i32
num = TypeVar("num")
Expand Down Expand Up @@ -99,14 +154,24 @@ def test_arange():
assert a[3] == 3

def test_sqrt():
eps: f64
eps = 1e-12
a: f64
a2: f64
a = sqrt(2)
a2 = sqrt(5.6)
eps: f64
eps = 1e-12
assert abs(a - 1.4142135623730951) < eps
assert abs(a2 - 2.3664319132398464) < eps
assert abs(sqrt(False) - 0.0) < eps

i: i64
i = 4
a = sqrt(i)
assert abs(a - 2.0) < eps

f: f32
f = 4.0
assert abs(sqrt(f) - 2.0) < eps

def test_exp():
a: f64
Expand All @@ -117,6 +182,16 @@ def test_exp():
eps = 1e-12
assert abs(a - 403.4287934927351) < eps
assert abs(a2 - 270.42640742615254) < eps
assert abs(exp(True) - 2.719) < eps

i: i64
i = 4
a = exp(i)
assert abs(a - 54.598150033144236) < eps

f: f32
f = -4.0
print(exp(f))

def test_fabs():
a: f64
Expand All @@ -127,6 +202,16 @@ def test_fabs():
eps = 1e-12
assert abs(a - 3.7) < eps
assert abs(a2 - 3.0) < eps
assert abs(fabs(True) - 1.0) < eps

i: i64
i = -4
a = fabs(i)
assert abs(a - 4.0) < eps

f: f32
f = -4.0
assert abs(fabs(f) - 4.0) < eps

def test_linspace():
a: f64[4]
Expand Down

0 comments on commit 2dd1e4d

Please sign in to comment.