Skip to content

Commit

Permalink
Merge pull request lcompilers#2313 from Smit-create/lf_2248
Browse files Browse the repository at this point in the history
Update FMA/flip_sign pass
  • Loading branch information
Smit-create committed Sep 6, 2023
2 parents f6a4606 + 588ef41 commit d0e857a
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 11 deletions.
31 changes: 30 additions & 1 deletion src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1884,6 +1884,30 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
break ;
}
case ASRUtils::IntrinsicScalarFunctions::FlipSign: {
Vec<ASR::call_arg_t> args;
args.reserve(al, 2);
ASR::call_arg_t arg0_, arg1_;
arg0_.loc = x.m_args[0]->base.loc, arg0_.m_value = x.m_args[0];
args.push_back(al, arg0_);
arg1_.loc = x.m_args[1]->base.loc, arg1_.m_value = x.m_args[1];
args.push_back(al, arg1_);
generate_flip_sign(args.p);
break;
}
case ASRUtils::IntrinsicScalarFunctions::FMA: {
Vec<ASR::call_arg_t> args;
args.reserve(al, 3);
ASR::call_arg_t arg0_, arg1_, arg2_;
arg0_.loc = x.m_args[0]->base.loc, arg0_.m_value = x.m_args[0];
args.push_back(al, arg0_);
arg1_.loc = x.m_args[1]->base.loc, arg1_.m_value = x.m_args[1];
args.push_back(al, arg1_);
arg2_.loc = x.m_args[2]->base.loc, arg2_.m_value = x.m_args[2];
args.push_back(al, arg2_);
generate_fma(args.p);
break;
}
default: {
throw CodeGenError( ASRUtils::IntrinsicScalarFunctionRegistry::
get_intrinsic_function_name(x.m_intrinsic_id) +
Expand Down Expand Up @@ -7372,7 +7396,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm::Value* int_var = builder->CreateBitCast(CreateLoad(variable), shifted_signal->getType());
tmp = builder->CreateXor(shifted_signal, int_var);
llvm::Type* variable_type = llvm_utils->get_type_from_ttype_t_util(asr_variable->m_type, module.get());
builder->CreateStore(builder->CreateBitCast(tmp, variable_type->getPointerTo()), variable);
tmp = builder->CreateBitCast(tmp, variable_type);
}

void generate_fma(ASR::call_arg_t* m_args) {
Expand Down Expand Up @@ -8300,7 +8324,12 @@ Result<std::unique_ptr<LLVMModule>> asr_to_llvm(ASR::TranslationUnit_t &asr,
pass_options.run_fun = run_fn;
pass_options.always_run = false;
pass_options.verbose = co.verbose;
std::vector<int64_t> skip_optimization_func_instantiation;
skip_optimization_func_instantiation.push_back(static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FlipSign));
skip_optimization_func_instantiation.push_back(static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FMA));
pass_options.skip_optimization_func_instantiation = skip_optimization_func_instantiation;
pass_manager.rtlib = co.rtlib;

pass_manager.apply_passes(al, &asr, pass_options, diagnostics);

// Uncomment for debugging the ASR after the transformation
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/pass/flip_sign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class FlipSignVisitor : public PassUtils::SkipOptimizationFunctionVisitor<FlipSi
LCOMPILERS_ASSERT(flip_sign_signal_variable);
LCOMPILERS_ASSERT(flip_sign_variable);
ASR::expr_t* flip_sign_result = PassUtils::get_flipsign(flip_sign_signal_variable,
flip_sign_variable, al, unit, x.base.base.loc);
flip_sign_variable, al, unit, x.base.base.loc, pass_options);
pass_result.push_back(al, ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc,
flip_sign_variable, flip_sign_result, nullptr)));
}
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/pass/fma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class FMAVisitor : public PassUtils::SkipOptimizationFunctionVisitor<FMAVisitor>
}

fma_var = PassUtils::get_fma(other_expr, first_arg, second_arg,
al, unit, x.base.base.loc);
al, unit, x.base.base.loc, pass_options);
from_fma = false;
}

Expand Down
2 changes: 1 addition & 1 deletion src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -2451,7 +2451,7 @@ namespace IntrinsicScalarFunctionRegistry {
{static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
{&FMA::instantiate_FMA, &FMA::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::FlipSign),
{&FlipSign::instantiate_FlipSign, &FMA::verify_args}},
{&FlipSign::instantiate_FlipSign, &FlipSign::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Abs),
{&Abs::instantiate_Abs, &Abs::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::Partition),
Expand Down
40 changes: 35 additions & 5 deletions src/libasr/pass/pass_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,14 +587,34 @@ namespace LCompilers {
int32_type, bound_type, nullptr));
}

bool skip_instantiation(PassOptions pass_options, int64_t id) {
if (!pass_options.skip_optimization_func_instantiation.empty()) {
for (size_t i=0; i<pass_options.skip_optimization_func_instantiation.size(); i++) {
if (pass_options.skip_optimization_func_instantiation[i] == id) {
return true;
}
}
}
return false;
}

ASR::expr_t* get_flipsign(ASR::expr_t* arg0, ASR::expr_t* arg1,
Allocator& al, ASR::TranslationUnit_t& unit, const Location& loc){
Allocator& al, ASR::TranslationUnit_t& unit, const Location& loc,
PassOptions pass_options){
ASR::ttype_t* type = ASRUtils::expr_type(arg1);
int64_t fp_s = static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FlipSign);
if (skip_instantiation(pass_options, fp_s)) {
Vec<ASR::expr_t*> args;
args.reserve(al, 2);
args.push_back(al, arg0);
args.push_back(al, arg1);
return ASRUtils::EXPR(ASRUtils::make_IntrinsicScalarFunction_t_util(al, loc, fp_s,
args.p, args.n, 0, type, nullptr));
}
ASRUtils::impl_function instantiate_function =
ASRUtils::IntrinsicScalarFunctionRegistry::get_instantiate_function(
static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FlipSign));
Vec<ASR::ttype_t*> arg_types;
ASR::ttype_t* type = ASRUtils::expr_type(arg1);
arg_types.reserve(al, 2);
arg_types.push_back(al, ASRUtils::expr_type(arg0));
arg_types.push_back(al, ASRUtils::expr_type(arg1));
Expand Down Expand Up @@ -667,13 +687,23 @@ namespace LCompilers {
}

ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc){

Allocator& al, ASR::TranslationUnit_t& unit, Location& loc,
PassOptions pass_options){
int64_t fma_id = static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FMA);
ASR::ttype_t* type = ASRUtils::expr_type(arg0);
if (skip_instantiation(pass_options, fma_id)) {
Vec<ASR::expr_t*> args;
args.reserve(al, 3);
args.push_back(al, arg0);
args.push_back(al, arg1);
args.push_back(al, arg2);
return ASRUtils::EXPR(ASRUtils::make_IntrinsicScalarFunction_t_util(al, loc, fma_id,
args.p, args.n, 0, type, nullptr));
}
ASRUtils::impl_function instantiate_function =
ASRUtils::IntrinsicScalarFunctionRegistry::get_instantiate_function(
static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FMA));
Vec<ASR::ttype_t*> arg_types;
ASR::ttype_t* type = ASRUtils::expr_type(arg0);
arg_types.reserve(al, 3);
arg_types.push_back(al, ASRUtils::expr_type(arg0));
arg_types.push_back(al, ASRUtils::expr_type(arg1));
Expand Down
6 changes: 4 additions & 2 deletions src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ namespace LCompilers {
Allocator& al);

ASR::expr_t* get_flipsign(ASR::expr_t* arg0, ASR::expr_t* arg1,
Allocator& al, ASR::TranslationUnit_t& unit, const Location& loc);
Allocator& al, ASR::TranslationUnit_t& unit, const Location& loc,
PassOptions pass_options);

ASR::expr_t* to_int32(ASR::expr_t* x, ASR::ttype_t* int32type, Allocator& al);

Expand All @@ -86,7 +87,8 @@ namespace LCompilers {
ASR::intentType var_intent=ASR::intentType::Local);

ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc);
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc,
PassOptions pass_options);

ASR::expr_t* get_sign_from_value(ASR::expr_t* arg0, ASR::expr_t* arg1,
Allocator& al, ASR::TranslationUnit_t& unit,
Expand Down
1 change: 1 addition & 0 deletions src/libasr/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ namespace LCompilers {
bool verbose = false; // For developer debugging
bool pass_cumulative = false; // Apply passes cumulatively
bool disable_main = false;
std::vector<int64_t> skip_optimization_func_instantiation;
bool module_name_mangling = false;
bool global_symbols_mangling = false;
bool intrinsic_symbols_mangling = false;
Expand Down

0 comments on commit d0e857a

Please sign in to comment.