Skip to content

Commit

Permalink
Add FMA in intrinsic registry
Browse files Browse the repository at this point in the history
  • Loading branch information
Smit-create committed Aug 17, 2023
1 parent 91207c9 commit fad2ad3
Showing 1 changed file with 78 additions and 0 deletions.
78 changes: 78 additions & 0 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ enum class IntrinsicScalarFunctions : int64_t {
Exp,
Exp2,
Expm1,
FMA,
ListIndex,
Partition,
ListReverse,
Expand Down Expand Up @@ -93,6 +94,7 @@ inline std::string get_intrinsic_name(int x) {
INTRINSIC_NAME_CASE(Exp)
INTRINSIC_NAME_CASE(Exp2)
INTRINSIC_NAME_CASE(Expm1)
INTRINSIC_NAME_CASE(FMA)
INTRINSIC_NAME_CASE(ListIndex)
INTRINSIC_NAME_CASE(Partition)
INTRINSIC_NAME_CASE(ListReverse)
Expand Down Expand Up @@ -1281,6 +1283,82 @@ namespace Sign {

} // namespace Sign

namespace FMA {

    /// Verification hook for an FMA intrinsic node.
    ///
    /// Checks, via `require_impl`, that the node carries exactly three
    /// arguments and that the type of each argument satisfies `is_real`.
    /// Both checks are reported at the location of the intrinsic node itself.
    static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, diag::Diagnostics& diagnostics) {
        ASRUtils::require_impl(x.n_args == 3,
            "ASR Verify: Call to FMA must have exactly 3 arguments",
            x.base.base.loc, diagnostics);
        ASR::ttype_t *type1 = ASRUtils::expr_type(x.m_args[0]);
        ASR::ttype_t *type2 = ASRUtils::expr_type(x.m_args[1]);
        ASR::ttype_t *type3 = ASRUtils::expr_type(x.m_args[2]);
        ASRUtils::require_impl((ASRUtils::is_real(*type1) && ASRUtils::is_real(*type2) && ASRUtils::is_real(*type3)),
            "ASR Verify: Arguments to FMA must be of real type",
            x.base.base.loc, diagnostics);
    }

    /// Compile-time evaluation of FMA on three constant arguments.
    ///
    /// Each entry of `args` is down-cast to `RealConstant_t`, so the caller
    /// must only pass real constants. The value produced is `a + b*c`,
    /// computed in `double` as a separate multiply and add, and wrapped in a
    /// constant of type `t1`.
    static inline ASR::expr_t *eval_FMA(Allocator &al, const Location &loc,
            ASR::ttype_t* t1, Vec<ASR::expr_t*> &args) {
        const double a = ASR::down_cast<ASR::RealConstant_t>(args[0])->m_r;
        const double b = ASR::down_cast<ASR::RealConstant_t>(args[1])->m_r;
        const double c = ASR::down_cast<ASR::RealConstant_t>(args[2])->m_r;
        return make_ConstantWithType(make_RealConstant_t, a + b*c, t1, loc);
    }

    /// Builds the `IntrinsicScalarFunction` node for `FMA(a, b, c)`.
    ///
    /// Validates that there are exactly three arguments and that each one is
    /// of real type, reporting problems through `err`. When every argument
    /// already has a compile-time value, the result is folded with `eval_FMA`
    /// and attached to the node as its value. The node's type is the type of
    /// the first argument.
    ///
    /// NOTE(review): `err` is presumably a callback that throws and never
    /// returns -- confirm against callers. The `return nullptr` statements
    /// below only matter if it does return; they keep this function from
    /// indexing past the end of `args` or folding non-real constants.
    static inline ASR::asr_t* create_FMA(Allocator& al, const Location& loc,
            Vec<ASR::expr_t*>& args,
            const std::function<void (const std::string &, const Location &)> err) {
        if (args.size() != 3) {
            err("Intrinsic FMA function accepts exactly 3 arguments", loc);
            return nullptr;
        }
        // Report the type error at the first argument that is actually not
        // real, rather than unconditionally at args[0].
        for (size_t i = 0; i < 3; i++) {
            if (!ASRUtils::is_real(*ASRUtils::expr_type(args[i]))) {
                err("Argument of the FMA function must be Real",
                    args[i]->base.loc);
                return nullptr;
            }
        }
        ASR::ttype_t *result_type = ASRUtils::expr_type(args[0]);
        ASR::expr_t *m_value = nullptr;
        if (all_args_evaluated(args)) {
            Vec<ASR::expr_t*> arg_values; arg_values.reserve(al, 3);
            for (size_t i = 0; i < 3; i++) {
                arg_values.push_back(al, expr_value(args[i]));
            }
            m_value = eval_FMA(al, loc, result_type, arg_values);
        }
        return ASR::make_IntrinsicScalarFunction_t(al, loc,
            static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
            args.p, args.n, 0, result_type, m_value);
    }

    /// Lowers an FMA intrinsic to a call.
    ///
    /// If a compile-time value is available it is returned directly.
    /// Otherwise a helper function with three parameters `a`, `b`, `c` and a
    /// body of `result = a + b*c` is created, registered in `scope`, and a
    /// call to it with `new_args` is returned.
    ///
    /// NOTE(review): all three parameters are declared with `arg_types[0]`,
    /// and the helper's base name is derived from `arg_types[0]` only. This
    /// assumes the three arguments share one type -- confirm, since
    /// `create_FMA` only checks that each argument is real.
    static inline ASR::expr_t* instantiate_FMA(Allocator &al, const Location &loc,
            SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types, ASR::ttype_t *return_type,
            Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/,
            ASR::expr_t* compile_time_value) {
        if (compile_time_value) {
            return compile_time_value;
        }
        declare_basic_variables("_lcompilers_optimization_fma_" + type_to_str_python(arg_types[0]));
        fill_func_arg("a", arg_types[0]);
        fill_func_arg("b", arg_types[0]);
        fill_func_arg("c", arg_types[0]);
        auto result = declare(fn_name, return_type, ReturnVar);
        /*
         * result = a + b*c
         */

        ASR::expr_t *op1 = b.ElementalMul(args[1], args[2], loc);
        body.push_back(al, b.Assignment(result,
            b.ElementalAdd(args[0], op1, loc)));

        ASR::symbol_t *f_sym = make_Function_t(fn_name, fn_symtab, dep, args,
            body, result, Source, Implementation, nullptr);
        scope->add_symbol(fn_name, f_sym);
        return b.Call(f_sym, new_args, return_type, nullptr);
    }

} // namespace FMA

#define create_exp_macro(X, stdeval) \
namespace X { \
static inline ASR::expr_t* eval_##X(Allocator &al, const Location &loc, \
Expand Down

0 comments on commit fad2ad3

Please sign in to comment.