Skip to content

Commit

Permalink
Merge pull request lcompilers#2282 from Smit-create/i-1671-1
Browse files Browse the repository at this point in the history
Fix FMA pass to use IntrinsicFunction
  • Loading branch information
Smit-create authored Aug 18, 2023
2 parents 91207c9 + 9638342 commit 6e9fb0d
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 13 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ RUN(NAME expr_18 FAIL LABELS cpython llvm c)
RUN(NAME expr_19 LABELS cpython llvm c)
RUN(NAME expr_20 LABELS cpython llvm c)
RUN(NAME expr_21 LABELS cpython llvm c)
RUN(NAME expr_22 LABELS cpython llvm c)

RUN(NAME expr_01u LABELS cpython llvm c NOFAST)
RUN(NAME expr_02u LABELS cpython llvm c NOFAST)
Expand Down
10 changes: 10 additions & 0 deletions integration_tests/expr_22.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from lpython import f64

# test issue 1671
def test_fast_fma() -> f64:
a : f64 = 5.00
a = a + a * 10.00
assert abs(a - 55.00) < 1e-12
return a

print(test_fast_fma())
5 changes: 3 additions & 2 deletions src/libasr/pass/fma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ class FMAVisitor : public PassUtils::SkipOptimizationFunctionVisitor<FMAVisitor>
}

fma_var = PassUtils::get_fma(other_expr, first_arg, second_arg,
al, unit, pass_options, current_scope, x.base.base.loc,
[&](const std::string &msg, const Location &) { throw LCompilersException(msg); });
al, unit, x.base.base.loc);
from_fma = false;
}

Expand Down Expand Up @@ -170,6 +169,8 @@ void pass_replace_fma(Allocator &al, ASR::TranslationUnit_t &unit,
const LCompilers::PassOptions& pass_options) {
FMAVisitor v(al, unit, pass_options);
v.visit_TranslationUnit(unit);
PassUtils::UpdateDependenciesVisitor u(al);
u.visit_TranslationUnit(unit);
}


Expand Down
83 changes: 83 additions & 0 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ enum class IntrinsicScalarFunctions : int64_t {
Exp,
Exp2,
Expm1,
FMA,
ListIndex,
Partition,
ListReverse,
Expand Down Expand Up @@ -93,6 +94,7 @@ inline std::string get_intrinsic_name(int x) {
INTRINSIC_NAME_CASE(Exp)
INTRINSIC_NAME_CASE(Exp2)
INTRINSIC_NAME_CASE(Expm1)
INTRINSIC_NAME_CASE(FMA)
INTRINSIC_NAME_CASE(ListIndex)
INTRINSIC_NAME_CASE(Partition)
INTRINSIC_NAME_CASE(ListReverse)
Expand Down Expand Up @@ -1281,6 +1283,82 @@ namespace Sign {

} // namespace Sign

namespace FMA {

static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, diag::Diagnostics& diagnostics) {
ASRUtils::require_impl(x.n_args == 3,
"ASR Verify: Call to FMA must have exactly 3 arguments",
x.base.base.loc, diagnostics);
ASR::ttype_t *type1 = ASRUtils::expr_type(x.m_args[0]);
ASR::ttype_t *type2 = ASRUtils::expr_type(x.m_args[1]);
ASR::ttype_t *type3 = ASRUtils::expr_type(x.m_args[2]);
ASRUtils::require_impl((is_real(*type1) && is_real(*type2) && is_real(*type3)),
"ASR Verify: Arguments to FMA must be of real type",
x.base.base.loc, diagnostics);
}

static ASR::expr_t *eval_FMA(Allocator &al, const Location &loc,
ASR::ttype_t* t1, Vec<ASR::expr_t*> &args) {
double a = ASR::down_cast<ASR::RealConstant_t>(args[0])->m_r;
double b = ASR::down_cast<ASR::RealConstant_t>(args[1])->m_r;
double c = ASR::down_cast<ASR::RealConstant_t>(args[2])->m_r;
return make_ConstantWithType(make_RealConstant_t, a + b*c, t1, loc);
}

static inline ASR::asr_t* create_FMA(Allocator& al, const Location& loc,
Vec<ASR::expr_t*>& args,
const std::function<void (const std::string &, const Location &)> err) {
if (args.size() != 3) {
err("Intrinsic FMA function accepts exactly 3 arguments", loc);
}
ASR::ttype_t *type1 = ASRUtils::expr_type(args[0]);
ASR::ttype_t *type2 = ASRUtils::expr_type(args[1]);
ASR::ttype_t *type3 = ASRUtils::expr_type(args[2]);
if (!ASRUtils::is_real(*type1) || !ASRUtils::is_real(*type2) || !ASRUtils::is_real(*type3)) {
err("Argument of the FMA function must be Real",
args[0]->base.loc);
}
ASR::expr_t *m_value = nullptr;
if (all_args_evaluated(args)) {
Vec<ASR::expr_t*> arg_values; arg_values.reserve(al, 3);
arg_values.push_back(al, expr_value(args[0]));
arg_values.push_back(al, expr_value(args[1]));
arg_values.push_back(al, expr_value(args[2]));
m_value = eval_FMA(al, loc, expr_type(args[0]), arg_values);
}
return ASR::make_IntrinsicScalarFunction_t(al, loc,
static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
args.p, args.n, 0, ASRUtils::expr_type(args[0]), m_value);
}

static inline ASR::expr_t* instantiate_FMA(Allocator &al, const Location &loc,
SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types, ASR::ttype_t *return_type,
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/,
ASR::expr_t* compile_time_value) {
if (compile_time_value) {
return compile_time_value;
}
declare_basic_variables("_lcompilers_optimization_fma_" + type_to_str_python(arg_types[0]));
fill_func_arg("a", arg_types[0]);
fill_func_arg("b", arg_types[0]);
fill_func_arg("c", arg_types[0]);
auto result = declare(fn_name, return_type, ReturnVar);
/*
* result = a + b*c
*/

ASR::expr_t *op1 = b.ElementalMul(args[1], args[2], loc);
body.push_back(al, b.Assignment(result,
b.ElementalAdd(args[0], op1, loc)));

ASR::symbol_t *f_sym = make_Function_t(fn_name, fn_symtab, dep, args,
body, result, Source, Implementation, nullptr);
scope->add_symbol(fn_name, f_sym);
return b.Call(f_sym, new_args, return_type, nullptr);
}

} // namespace FMA

#define create_exp_macro(X, stdeval) \
namespace X { \
static inline ASR::expr_t* eval_##X(Allocator &al, const Location &loc, \
Expand Down Expand Up @@ -2314,6 +2392,8 @@ namespace IntrinsicScalarFunctionRegistry {
{nullptr, &UnaryIntrinsicFunction::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Expm1),
{nullptr, &UnaryIntrinsicFunction::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
{&FMA::instantiate_FMA, &FMA::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Abs),
{&Abs::instantiate_Abs, &Abs::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Partition),
Expand Down Expand Up @@ -2400,6 +2480,8 @@ namespace IntrinsicScalarFunctionRegistry {
"exp"},
{static_cast<int64_t>(IntrinsicScalarFunctions::Exp2),
"exp2"},
{static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
"fma"},
{static_cast<int64_t>(IntrinsicScalarFunctions::Expm1),
"expm1"},
{static_cast<int64_t>(IntrinsicScalarFunctions::ListIndex),
Expand Down Expand Up @@ -2474,6 +2556,7 @@ namespace IntrinsicScalarFunctionRegistry {
{"exp", {&Exp::create_Exp, &Exp::eval_Exp}},
{"exp2", {&Exp2::create_Exp2, &Exp2::eval_Exp2}},
{"expm1", {&Expm1::create_Expm1, &Expm1::eval_Expm1}},
{"fma", {&FMA::create_FMA, &FMA::eval_FMA}},
{"list.index", {&ListIndex::create_ListIndex, &ListIndex::eval_list_index}},
{"list.reverse", {&ListReverse::create_ListReverse, &ListReverse::eval_list_reverse}},
{"list.pop", {&ListPop::create_ListPop, &ListPop::eval_list_pop}},
Expand Down
22 changes: 14 additions & 8 deletions src/libasr/pass/pass_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,11 +666,17 @@ namespace LCompilers {
}

ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
Allocator& al, ASR::TranslationUnit_t& unit, LCompilers::PassOptions& pass_options,
SymbolTable*& current_scope, Location& loc,
const std::function<void (const std::string &, const Location &)> err) {
ASR::symbol_t *v = import_generic_procedure("fma", "lfortran_intrinsic_optimization",
al, unit, pass_options, current_scope, arg0->base.loc);
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc){

ASRUtils::impl_function instantiate_function =
ASRUtils::IntrinsicScalarFunctionRegistry::get_instantiate_function(
static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FMA));
Vec<ASR::ttype_t*> arg_types;
ASR::ttype_t* type = ASRUtils::expr_type(arg0);
arg_types.reserve(al, 3);
arg_types.push_back(al, ASRUtils::expr_type(arg0));
arg_types.push_back(al, ASRUtils::expr_type(arg1));
arg_types.push_back(al, ASRUtils::expr_type(arg2));
Vec<ASR::call_arg_t> args;
args.reserve(al, 3);
ASR::call_arg_t arg0_, arg1_, arg2_;
Expand All @@ -680,9 +686,9 @@ namespace LCompilers {
args.push_back(al, arg1_);
arg2_.loc = arg2->base.loc, arg2_.m_value = arg2;
args.push_back(al, arg2_);
return ASRUtils::EXPR(
ASRUtils::symbol_resolve_external_generic_procedure_without_eval(
loc, v, args, current_scope, al, err));
return instantiate_function(al, loc,
unit.m_global_scope, arg_types, type, args, 0,
nullptr);
}

ASR::symbol_t* insert_fallback_vector_copy(Allocator& al, ASR::TranslationUnit_t& unit,
Expand Down
4 changes: 1 addition & 3 deletions src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,7 @@ namespace LCompilers {
ASR::intentType var_intent=ASR::intentType::Local);

ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
Allocator& al, ASR::TranslationUnit_t& unit, LCompilers::PassOptions& pass_options,
SymbolTable*& current_scope,Location& loc,
const std::function<void (const std::string &, const Location &)> err);
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc);

ASR::expr_t* get_sign_from_value(ASR::expr_t* arg0, ASR::expr_t* arg1,
Allocator& al, ASR::TranslationUnit_t& unit,
Expand Down

0 comments on commit 6e9fb0d

Please sign in to comment.