Skip to content

Commit

Permalink
Update FMA pass to use intrinsic function
Browse files Browse the repository at this point in the history
  • Loading branch information
Smit-create committed Aug 17, 2023
1 parent fad2ad3 commit e510c2b
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 13 deletions.
5 changes: 3 additions & 2 deletions src/libasr/pass/fma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ class FMAVisitor : public PassUtils::SkipOptimizationFunctionVisitor<FMAVisitor>
}

fma_var = PassUtils::get_fma(other_expr, first_arg, second_arg,
al, unit, pass_options, current_scope, x.base.base.loc,
[&](const std::string &msg, const Location &) { throw LCompilersException(msg); });
al, unit, x.base.base.loc);
from_fma = false;
}

Expand Down Expand Up @@ -170,6 +169,8 @@ void pass_replace_fma(Allocator &al, ASR::TranslationUnit_t &unit,
const LCompilers::PassOptions& pass_options) {
FMAVisitor v(al, unit, pass_options);
v.visit_TranslationUnit(unit);
PassUtils::UpdateDependenciesVisitor u(al);
u.visit_TranslationUnit(unit);
}


Expand Down
5 changes: 5 additions & 0 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -2392,6 +2392,8 @@ namespace IntrinsicScalarFunctionRegistry {
{nullptr, &UnaryIntrinsicFunction::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Expm1),
{nullptr, &UnaryIntrinsicFunction::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
{&FMA::instantiate_FMA, &FMA::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Abs),
{&Abs::instantiate_Abs, &Abs::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Partition),
Expand Down Expand Up @@ -2478,6 +2480,8 @@ namespace IntrinsicScalarFunctionRegistry {
"exp"},
{static_cast<int64_t>(IntrinsicScalarFunctions::Exp2),
"exp2"},
{static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
"fma"},
{static_cast<int64_t>(IntrinsicScalarFunctions::Expm1),
"expm1"},
{static_cast<int64_t>(IntrinsicScalarFunctions::ListIndex),
Expand Down Expand Up @@ -2552,6 +2556,7 @@ namespace IntrinsicScalarFunctionRegistry {
{"exp", {&Exp::create_Exp, &Exp::eval_Exp}},
{"exp2", {&Exp2::create_Exp2, &Exp2::eval_Exp2}},
{"expm1", {&Expm1::create_Expm1, &Expm1::eval_Expm1}},
{"fma", {&FMA::create_FMA, &FMA::eval_FMA}},
{"list.index", {&ListIndex::create_ListIndex, &ListIndex::eval_list_index}},
{"list.reverse", {&ListReverse::create_ListReverse, &ListReverse::eval_list_reverse}},
{"list.pop", {&ListPop::create_ListPop, &ListPop::eval_list_pop}},
Expand Down
22 changes: 14 additions & 8 deletions src/libasr/pass/pass_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,11 +666,17 @@ namespace LCompilers {
}

ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
Allocator& al, ASR::TranslationUnit_t& unit, LCompilers::PassOptions& pass_options,
SymbolTable*& current_scope, Location& loc,
const std::function<void (const std::string &, const Location &)> err) {
ASR::symbol_t *v = import_generic_procedure("fma", "lfortran_intrinsic_optimization",
al, unit, pass_options, current_scope, arg0->base.loc);
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc){

ASRUtils::impl_function instantiate_function =
ASRUtils::IntrinsicScalarFunctionRegistry::get_instantiate_function(
static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FMA));
Vec<ASR::ttype_t*> arg_types;
ASR::ttype_t* type = ASRUtils::expr_type(arg0);
arg_types.reserve(al, 3);
arg_types.push_back(al, ASRUtils::expr_type(arg0));
arg_types.push_back(al, ASRUtils::expr_type(arg1));
arg_types.push_back(al, ASRUtils::expr_type(arg2));
Vec<ASR::call_arg_t> args;
args.reserve(al, 3);
ASR::call_arg_t arg0_, arg1_, arg2_;
Expand All @@ -680,9 +686,9 @@ namespace LCompilers {
args.push_back(al, arg1_);
arg2_.loc = arg2->base.loc, arg2_.m_value = arg2;
args.push_back(al, arg2_);
return ASRUtils::EXPR(
ASRUtils::symbol_resolve_external_generic_procedure_without_eval(
loc, v, args, current_scope, al, err));
return instantiate_function(al, loc,
unit.m_global_scope, arg_types, type, args, 0,
nullptr);
}

ASR::symbol_t* insert_fallback_vector_copy(Allocator& al, ASR::TranslationUnit_t& unit,
Expand Down
4 changes: 1 addition & 3 deletions src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,7 @@ namespace LCompilers {
ASR::intentType var_intent=ASR::intentType::Local);

ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
Allocator& al, ASR::TranslationUnit_t& unit, LCompilers::PassOptions& pass_options,
SymbolTable*& current_scope,Location& loc,
const std::function<void (const std::string &, const Location &)> err);
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc);

ASR::expr_t* get_sign_from_value(ASR::expr_t* arg0, ASR::expr_t* arg1,
Allocator& al, ASR::TranslationUnit_t& unit,
Expand Down

0 comments on commit e510c2b

Please sign in to comment.