Skip to content

Commit

Permalink
Merge pull request lcompilers#2330 from khushi-411/lp_fix
Browse files Browse the repository at this point in the history
[numpy] add fix
  • Loading branch information
certik authored Sep 20, 2023
2 parents 5569d74 + abd6ef1 commit 0d9fff9
Show file tree
Hide file tree
Showing 32 changed files with 1,343 additions and 1,191 deletions.
64 changes: 62 additions & 2 deletions integration_tests/elemental_13.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from lpython import f32, f64
from numpy import trunc, empty, sqrt, reshape, int32, float32, float64
from numpy import trunc, fix, empty, sqrt, reshape, int32, float32, float64


def elemental_trunc64():
Expand Down Expand Up @@ -60,5 +60,65 @@ def elemental_trunc32():
assert abs(trunc(arraynd[i, j, k, l]) - observed[i, j, k, l]) <= eps


def elemental_fix64():
i: i32
j: i32
k: i32
l: i32
eps: f32
eps = f32(1e-6)

arraynd: f64[32, 16, 8, 4] = empty((32, 16, 8, 4), dtype=float64)

newshape: i32[1] = empty(1, dtype = int32)
newshape[0] = 16384

for i in range(32):
for j in range(16):
for k in range(8):
for l in range(4):
arraynd[i, j, k, l] = f64((-1)**l) * sqrt(float(i + j + j + l))

observed: f64[32, 16, 8, 4] = empty((32, 16, 8, 4), dtype=float64)
observed = fix(arraynd)

observed1d: f64[16384] = empty(16384, dtype=float64)
observed1d = reshape(observed, newshape)

array: f64[16384] = empty(16384, dtype=float64)
array = reshape(arraynd, newshape)

for i in range(16384):
assert f32(abs(fix(array[i]) - observed1d[i])) <= eps


def elemental_fix32():
i: i32
j: i32
k: i32
l: i32
eps: f32
eps = f32(1e-6)

arraynd: f32[32, 16, 8, 4] = empty((32, 16, 8, 4), dtype=float32)

for i in range(32):
for j in range(16):
for k in range(8):
for l in range(4):
arraynd[i, j, k, l] = f32(f64((-1)**l) * sqrt(float(i + j + j + l)))

observed: f32[32, 16, 8, 4] = empty((32, 16, 8, 4), dtype=float32)
observed = fix(arraynd)

for i in range(32):
for j in range(16):
for k in range(8):
for l in range(4):
assert abs(fix(arraynd[i, j, k, l]) - observed[i, j, k, l]) <= eps


elemental_trunc64()
elemental_trunc32()
elemental_trunc32()
elemental_fix64()
elemental_fix32()
1 change: 1 addition & 0 deletions src/libasr/codegen/asr_to_c_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -2799,6 +2799,7 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
SET_INTRINSIC_NAME(Exp2, "exp2");
SET_INTRINSIC_NAME(Expm1, "expm1");
SET_INTRINSIC_NAME(Trunc, "trunc");
SET_INTRINSIC_NAME(Fix, "fix");
default : {
throw LCompilersException("IntrinsicScalarFunction: `"
+ ASRUtils::get_intrinsic_name(x.m_intrinsic_id)
Expand Down
1 change: 1 addition & 0 deletions src/libasr/codegen/asr_to_julia.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1900,6 +1900,7 @@ class ASRToJuliaVisitor : public ASR::BaseVisitor<ASRToJuliaVisitor>
SET_INTRINSIC_NAME(Exp2, "exp2");
SET_INTRINSIC_NAME(Expm1, "expm1");
SET_INTRINSIC_NAME(Trunc, "trunc");
SET_INTRINSIC_NAME(Fix, "fix");
default : {
throw LCompilersException("IntrinsicFunction: `"
+ ASRUtils::get_intrinsic_name(x.m_intrinsic_id)
Expand Down
48 changes: 48 additions & 0 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ enum class IntrinsicScalarFunctions : int64_t {
Gamma,
LogGamma,
Trunc,
Fix,
Abs,
Exp,
Exp2,
Expand Down Expand Up @@ -98,6 +99,7 @@ inline std::string get_intrinsic_name(int x) {
INTRINSIC_NAME_CASE(Gamma)
INTRINSIC_NAME_CASE(LogGamma)
INTRINSIC_NAME_CASE(Trunc)
INTRINSIC_NAME_CASE(Fix)
INTRINSIC_NAME_CASE(Abs)
INTRINSIC_NAME_CASE(Exp)
INTRINSIC_NAME_CASE(Exp2)
Expand Down Expand Up @@ -1182,6 +1184,46 @@ namespace X {

create_trunc_macro(Trunc, trunc)

namespace Fix {
static inline ASR::expr_t *eval_Fix(Allocator &al, const Location &loc,
ASR::ttype_t *t, Vec<ASR::expr_t*>& args) {
LCOMPILERS_ASSERT(args.size() == 1);
double rv = ASR::down_cast<ASR::RealConstant_t>(args[0])->m_r;
double val;
if (rv > 0.0) {
val = floor(rv);
} else {
val = ceil(rv);
}
return make_ConstantWithType(make_RealConstant_t, val, t, loc);
}

static inline ASR::asr_t* create_Fix(Allocator& al, const Location& loc,
Vec<ASR::expr_t*>& args,
const std::function<void (const std::string &, const Location &)> err) {
ASR::ttype_t *type = ASRUtils::expr_type(args[0]);
if (args.n != 1) {
err("Intrinsic `fix` accepts exactly one argument", loc);
} else if (!ASRUtils::is_real(*type)) {
err("`fix` argument of `fix` must be real",
args[0]->base.loc);
}
return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args,
eval_Fix, static_cast<int64_t>(IntrinsicScalarFunctions::Fix),
0, type);
}

static inline ASR::expr_t* instantiate_Fix (Allocator &al,
const Location &loc, SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types,
ASR::ttype_t *return_type, Vec<ASR::call_arg_t>& new_args,
int64_t overload_id) {
ASR::ttype_t* arg_type = arg_types[0];
return UnaryIntrinsicFunction::instantiate_functions(al, loc, scope,
"fix", arg_type, return_type, new_args, overload_id);
}

} // namespace Fix

// `X` is the name of the function in the IntrinsicScalarFunctions enum and
// we use the same name for `create_X` and other places
// `stdeval` is the name of the function in the `std` namespace for compile
Expand Down Expand Up @@ -2921,6 +2963,8 @@ namespace IntrinsicScalarFunctionRegistry {
{&LogGamma::instantiate_LogGamma, &UnaryIntrinsicFunction::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Trunc),
{&Trunc::instantiate_Trunc, &UnaryIntrinsicFunction::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Fix),
{&Fix::instantiate_Fix, &UnaryIntrinsicFunction::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Sin),
{&Sin::instantiate_Sin, &UnaryIntrinsicFunction::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Cos),
Expand Down Expand Up @@ -3021,6 +3065,8 @@ namespace IntrinsicScalarFunctionRegistry {

{static_cast<int64_t>(IntrinsicScalarFunctions::Trunc),
"trunc"},
{static_cast<int64_t>(IntrinsicScalarFunctions::Fix),
"fix"},
{static_cast<int64_t>(IntrinsicScalarFunctions::Sin),
"sin"},
{static_cast<int64_t>(IntrinsicScalarFunctions::Cos),
Expand Down Expand Up @@ -3119,6 +3165,7 @@ namespace IntrinsicScalarFunctionRegistry {
eval_intrinsic_function>>& intrinsic_function_by_name_db = {
{"log_gamma", {&LogGamma::create_LogGamma, &LogGamma::eval_log_gamma}},
{"trunc", {&Trunc::create_Trunc, &Trunc::eval_Trunc}},
{"fix", {&Fix::create_Fix, &Fix::eval_Fix}},
{"sin", {&Sin::create_Sin, &Sin::eval_Sin}},
{"cos", {&Cos::create_Cos, &Cos::eval_Cos}},
{"tan", {&Tan::create_Tan, &Tan::eval_Tan}},
Expand Down Expand Up @@ -3180,6 +3227,7 @@ namespace IntrinsicScalarFunctionRegistry {
id_ == IntrinsicScalarFunctions::Gamma ||
id_ == IntrinsicScalarFunctions::LogGamma ||
id_ == IntrinsicScalarFunctions::Trunc ||
id_ == IntrinsicScalarFunctions::Fix ||
id_ == IntrinsicScalarFunctions::Sin ||
id_ == IntrinsicScalarFunctions::Exp ||
id_ == IntrinsicScalarFunctions::Exp2 ||
Expand Down
20 changes: 20 additions & 0 deletions src/libasr/runtime/lfortran_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,26 @@ LFORTRAN_API double _lfortran_dtrunc(double x)
return trunc(x);
}

// fix -----------------------------------------------------------------------

LFORTRAN_API float _lfortran_sfix(float x)
{
if (x > 0.0) {
return floorf(x);
} else {
return ceilf(x);
}
}

LFORTRAN_API double _lfortran_dfix(double x)
{
if (x > 0.0) {
return floor(x);
} else {
return ceil(x);
}
}

// phase --------------------------------------------------------------------

LFORTRAN_API float _lfortran_cphase(float_complex_t x)
Expand Down
2 changes: 2 additions & 0 deletions src/libasr/runtime/lfortran_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ LFORTRAN_API float_complex_t _lfortran_catanh(float_complex_t x);
LFORTRAN_API double_complex_t _lfortran_zatanh(double_complex_t x);
LFORTRAN_API float _lfortran_strunc(float x);
LFORTRAN_API double _lfortran_dtrunc(double x);
LFORTRAN_API float _lfortran_sfix(float x);
LFORTRAN_API double _lfortran_dfix(double x);
LFORTRAN_API float _lfortran_cphase(float_complex_t x);
LFORTRAN_API double _lfortran_zphase(double_complex_t x);
LFORTRAN_API bool _lpython_str_compare_eq(char** s1, char** s2);
Expand Down
2 changes: 1 addition & 1 deletion src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7313,7 +7313,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
if (!s) {
std::string intrinsic_name = call_name;
std::set<std::string> not_cpython_builtin = {
"sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand", "trunc",
"sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand", "trunc", "fix",
"sum" // For sum called over lists
};
std::set<std::string> symbolic_functions = {
Expand Down
20 changes: 20 additions & 0 deletions src/runtime/lpython_intrinsic_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,3 +430,23 @@ def _lfortran_strunc(x: f32) -> f32:
@vectorize
def trunc(x: f32) -> f32:
return _lfortran_strunc(x)

########## fix ##########

@ccall
def _lfortran_dfix(x: f64) -> f64:
pass

@overload
@vectorize
def fix(x: f64) -> f64:
return _lfortran_dfix(x)

@ccall
def _lfortran_sfix(x: f32) -> f32:
pass

@overload
@vectorize
def fix(x: f32) -> f32:
return _lfortran_sfix(x)
2 changes: 1 addition & 1 deletion tests/reference/asr-array_01_decl-39cf894.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-array_01_decl-39cf894.stdout",
"stdout_hash": "2aa47467473392c970bb1ddde961e3007d4c157bb0ea507b5e0db4a4",
"stdout_hash": "b0dc16e057dc08b7ec8adac23b2d98fa29d536fca17934c2689425d8",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
Loading

0 comments on commit 0d9fff9

Please sign in to comment.