Skip to content

Commit

Permalink
[ASR Pass] Symbolic: Simplify process_attributes to use macros
Browse files Browse the repository at this point in the history
  • Loading branch information
Thirumalai-Shaktivel committed Nov 25, 2023
1 parent f18ae18 commit 760380a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 50 deletions.
8 changes: 4 additions & 4 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,11 @@ class ASRBuilder {
false, nullptr, 0, false, false, false));

// Types -------------------------------------------------------------------
#define int32 TYPE(ASR::make_Integer_t(al, loc, 4))
#define int32 ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4))
#define int64 TYPE(ASR::make_Integer_t(al, loc, 8))
#define real32 TYPE(ASR::make_Real_t(al, loc, 4))
#define real64 TYPE(ASR::make_Real_t(al, loc, 8))
#define logical TYPE(ASR::make_Logical_t(al, loc, 4))
#define logical ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4))
#define character(x) TYPE(ASR::make_Character_t(al, loc, 1, x, nullptr))
#define List(x) TYPE(ASR::make_List_t(al, loc, x))

Expand Down Expand Up @@ -285,7 +285,7 @@ class ASRBuilder {

// Expressions -------------------------------------------------------------
#define i(x, t) EXPR(ASR::make_IntegerConstant_t(al, loc, x, t))
#define i32(x) EXPR(ASR::make_IntegerConstant_t(al, loc, x, int32))
#define i32(x) ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, x, int32))
#define i32_n(x) EXPR(ASR::make_IntegerUnaryMinus_t(al, loc, i32(abs(x)), \
int32, i32(x)))
#define i32_neg(x, t) EXPR(ASR::make_IntegerUnaryMinus_t(al, loc, x, t, nullptr))
Expand Down Expand Up @@ -414,7 +414,7 @@ class ASRBuilder {
}

// Compare -----------------------------------------------------------------
#define iEq(x, y) EXPR(ASR::make_IntegerCompare_t(al, loc, x, \
#define iEq(x, y) ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, x, \
ASR::cmpopType::Eq, y, logical, nullptr))
#define iNotEq(x, y) EXPR(ASR::make_IntegerCompare_t(al, loc, x, \
ASR::cmpopType::NotEq, y, logical, nullptr))
Expand Down
60 changes: 14 additions & 46 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
target, x->m_args[0])); \
break; }

#define BASIC_ATTR(SYM, N) \
case LCompilers::ASRUtils::IntrinsicScalarFunctions::Symbolic##SYM: { \
ASR::expr_t* function_call = basic_get_type(loc, \
intrinsic_func->m_args[0]); \
return iEq(function_call, i32(N)); }

ASR::stmt_t *basic_new_stack(const Location &loc, ASR::expr_t *x) {
std::string fn_name = "basic_new_stack";
symbolic_dependencies.push_back(fn_name);
Expand Down Expand Up @@ -454,7 +460,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
Vec<ASR::call_arg_t> call_args; call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = value;
call_arg.m_value = handle_argument(al, loc, value);
call_args.push_back(al, call_arg);
return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n,
Expand Down Expand Up @@ -1011,51 +1017,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr, nullptr));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicAddQ: {
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
ASR::expr_t* function_call = basic_get_type(loc, value1);
// Using 16 as the right value of the IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 16, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMulQ: {
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
ASR::expr_t* function_call = basic_get_type(loc, value1);
// Using 15 as the right value of the IntegerCompare node as it represents SYMENGINE_MUL through SYMENGINE_ENUM
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 15, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPowQ: {
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
ASR::expr_t* function_call = basic_get_type(loc, value1);
// Using 17 as the right value of the IntegerCompare node as it represents SYMENGINE_POW through SYMENGINE_ENUM
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 17, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicLogQ: {
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
ASR::expr_t* function_call = basic_get_type(loc, value1);
// Using 29 as the right value of the IntegerCompare node as it represents SYMENGINE_LOG through SYMENGINE_ENUM
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 29, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicSinQ: {
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
ASR::expr_t* function_call = basic_get_type(loc, value1);
// Using 35 as the right value of the IntegerCompare node as it represents SYMENGINE_SIN through SYMENGINE_ENUM
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 35, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
break;
}
// (sym_name, n) where n = 16, 15, ... as the right value of the
// IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM
BASIC_ATTR(AddQ, 16)
BASIC_ATTR(MulQ, 15)
BASIC_ATTR(PowQ, 17)
BASIC_ATTR(LogQ, 29)
BASIC_ATTR(SinQ, 35)
default: {
throw LCompilersException("IntrinsicFunction: `"
+ ASRUtils::get_intrinsic_name(intrinsic_id)
Expand Down

0 comments on commit 760380a

Please sign in to comment.