Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Instead of replacing the intrinsic function, use the compile-time value #2287

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 9 additions & 41 deletions src/libasr/pass/intrinsic_array_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -809,11 +809,7 @@ namespace Shape {

static inline ASR::expr_t* instantiate_Shape(Allocator &al,
const Location &loc, SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types,
ASR::ttype_t *return_type, Vec<ASR::call_arg_t>& new_args,
int64_t, ASR::expr_t* compile_time_value) {
if (compile_time_value) {
return compile_time_value;
}
ASR::ttype_t *return_type, Vec<ASR::call_arg_t>& new_args, int64_t) {
declare_basic_variables("_lcompilers_shape");
fill_func_arg("source", arg_types[0]);
auto result = declare(fn_name, return_type, ReturnVar);
Expand Down Expand Up @@ -1020,11 +1016,7 @@ namespace Any {

static inline ASR::expr_t* instantiate_Any(Allocator &al, const Location &loc,
SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types, ASR::ttype_t *logical_return_type,
Vec<ASR::call_arg_t>& new_args, int64_t overload_id,
ASR::expr_t* compile_time_value) {
if (compile_time_value) {
return compile_time_value;
}
Vec<ASR::call_arg_t>& new_args, int64_t overload_id) {
ASRBuilder builder(al, loc);
ASRBuilder& b = builder;
ASR::ttype_t* arg_type = arg_types[0];
Expand Down Expand Up @@ -1142,10 +1134,7 @@ namespace Sum {
static inline ASR::expr_t* instantiate_Sum(Allocator &al,
const Location &loc, SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types,
ASR::ttype_t *return_type, Vec<ASR::call_arg_t>& new_args,
int64_t overload_id, ASR::expr_t* compile_time_value) {
if (compile_time_value) {
return compile_time_value;
}
int64_t overload_id) {
return ArrIntrinsic::instantiate_ArrIntrinsic(al, loc, scope, arg_types,
return_type, new_args, overload_id, IntrinsicArrayFunctions::Sum,
&get_constant_zero_with_given_type, &ASRBuilder::ElementalAdd);
Expand Down Expand Up @@ -1176,10 +1165,7 @@ namespace Product {
static inline ASR::expr_t* instantiate_Product(Allocator &al,
const Location &loc, SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types,
ASR::ttype_t *return_type, Vec<ASR::call_arg_t>& new_args,
int64_t overload_id, ASR::expr_t* compile_time_value) {
if (compile_time_value) {
return compile_time_value;
}
int64_t overload_id) {
return ArrIntrinsic::instantiate_ArrIntrinsic(al, loc, scope, arg_types,
return_type, new_args, overload_id, IntrinsicArrayFunctions::Product,
&get_constant_one_with_given_type, &ASRBuilder::ElementalMul);
Expand Down Expand Up @@ -1210,10 +1196,7 @@ namespace MaxVal {
static inline ASR::expr_t* instantiate_MaxVal(Allocator &al,
const Location &loc, SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types,
ASR::ttype_t *return_type, Vec<ASR::call_arg_t>& new_args,
int64_t overload_id, ASR::expr_t* compile_time_value) {
if (compile_time_value) {
return compile_time_value;
}
int64_t overload_id) {
return ArrIntrinsic::instantiate_ArrIntrinsic(al, loc, scope, arg_types,
return_type, new_args, overload_id, IntrinsicArrayFunctions::MaxVal,
&get_minimum_value_with_given_type, &ASRBuilder::ElementalMax);
Expand All @@ -1238,11 +1221,7 @@ namespace MaxLoc {
static inline ASR::expr_t *instantiate_MaxLoc(Allocator &al,
const Location &loc, SymbolTable *scope,
Vec<ASR::ttype_t*>& arg_types, ASR::ttype_t *return_type,
Vec<ASR::call_arg_t>& m_args, int64_t overload_id,
ASR::expr_t* compile_time_value) {
if (compile_time_value) {
return compile_time_value;
}
Vec<ASR::call_arg_t>& m_args, int64_t overload_id) {
return ArrIntrinsic::instantiate_MaxMinLoc(al, loc, scope,
static_cast<int>(IntrinsicArrayFunctions::MaxLoc), arg_types, return_type,
m_args, overload_id);
Expand Down Expand Up @@ -1366,11 +1345,7 @@ namespace Merge {
static inline ASR::expr_t* instantiate_Merge(Allocator &al,
const Location &loc, SymbolTable *scope,
Vec<ASR::ttype_t*>& arg_types, ASR::ttype_t *return_type,
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/,
ASR::expr_t* compile_time_value) {
if (compile_time_value) {
return compile_time_value;
}
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/) {
LCOMPILERS_ASSERT(arg_types.size() == 3);

// Array inputs should be elementalised in array_op pass already
Expand Down Expand Up @@ -1442,10 +1417,7 @@ namespace MinVal {
static inline ASR::expr_t* instantiate_MinVal(Allocator &al,
const Location &loc, SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types,
ASR::ttype_t *return_type, Vec<ASR::call_arg_t>& new_args,
int64_t overload_id, ASR::expr_t* compile_time_value) {
if (compile_time_value) {
return compile_time_value;
}
int64_t overload_id) {
return ArrIntrinsic::instantiate_ArrIntrinsic(al, loc, scope, arg_types,
return_type, new_args, overload_id, IntrinsicArrayFunctions::MinVal,
&get_maximum_value_with_given_type, &ASRBuilder::ElementalMin);
Expand All @@ -1470,11 +1442,7 @@ namespace MinLoc {
static inline ASR::expr_t *instantiate_MinLoc(Allocator &al,
const Location &loc, SymbolTable *scope,
Vec<ASR::ttype_t*>& arg_types, ASR::ttype_t *return_type,
Vec<ASR::call_arg_t>& m_args, int64_t overload_id,
ASR::expr_t* compile_time_value) {
if (compile_time_value) {
return compile_time_value;
}
Vec<ASR::call_arg_t>& m_args, int64_t overload_id) {
return ArrIntrinsic::instantiate_MaxMinLoc(al, loc, scope,
static_cast<int>(IntrinsicArrayFunctions::MinLoc), arg_types, return_type,
m_args, overload_id);
Expand Down
59 changes: 31 additions & 28 deletions src/libasr/pass/intrinsic_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,21 @@ class ReplaceIntrinsicFunctions: public ASR::BaseExprReplacer<ReplaceIntrinsicFu


void replace_IntrinsicScalarFunction(ASR::IntrinsicScalarFunction_t* x) {
Vec<ASR::call_arg_t> new_args;
if (x->m_value) {
*current_expr = x->m_value;
return;
}
Vec<ASR::call_arg_t> new_args; new_args.reserve(al, x->n_args);
// Replace any IntrinsicScalarFunctions in the argument first:
{
new_args.reserve(al, x->n_args);
for( size_t i = 0; i < x->n_args; i++ ) {
ASR::expr_t** current_expr_copy_ = current_expr;
current_expr = &(x->m_args[i]);
replace_expr(x->m_args[i]);
ASR::call_arg_t arg0;
arg0.loc = (*current_expr)->base.loc;
arg0.m_value = *current_expr; // Use the converted arg
new_args.push_back(al, arg0);
current_expr = current_expr_copy_;
}
for( size_t i = 0; i < x->n_args; i++ ) {
ASR::expr_t** current_expr_copy_ = current_expr;
current_expr = &(x->m_args[i]);
replace_expr(x->m_args[i]);
ASR::call_arg_t arg0;
arg0.loc = (*current_expr)->base.loc;
arg0.m_value = *current_expr; // Use the converted arg
new_args.push_back(al, arg0);
current_expr = current_expr_copy_;
}
// TODO: currently we always instantiate a new function.
// Rather we should reuse the old instantiation if it has
Expand All @@ -73,7 +74,7 @@ class ReplaceIntrinsicFunctions: public ASR::BaseExprReplacer<ReplaceIntrinsicFu
arg_types.push_back(al, ASRUtils::expr_type(x->m_args[i]));
}
ASR::expr_t* current_expr_ = instantiate_function(al, x->base.base.loc,
global_scope, arg_types, x->m_type, new_args, x->m_overload_id, x->m_value);
global_scope, arg_types, x->m_type, new_args, x->m_overload_id);
if( ASR::is_a<ASR::ArrayPhysicalCast_t>(*(*current_expr)) ) {
ASR::ArrayPhysicalCast_t* array_physical_cast_t = ASR::down_cast<ASR::ArrayPhysicalCast_t>(*current_expr);
array_physical_cast_t->m_arg = current_expr_;
Expand All @@ -83,21 +84,23 @@ class ReplaceIntrinsicFunctions: public ASR::BaseExprReplacer<ReplaceIntrinsicFu
}

void replace_IntrinsicArrayFunction(ASR::IntrinsicArrayFunction_t* x) {
Vec<ASR::call_arg_t> new_args;
if (x->m_value) {
*current_expr = x->m_value;
return;
}
Vec<ASR::call_arg_t> new_args; new_args.reserve(al, x->n_args);
// Replace any IntrinsicArrayFunctions in the argument first:
{
new_args.reserve(al, x->n_args);
for( size_t i = 0; i < x->n_args; i++ ) {
ASR::expr_t** current_expr_copy_ = current_expr;
current_expr = &(x->m_args[i]);
replace_expr(x->m_args[i]);
ASR::call_arg_t arg0;
arg0.loc = (*current_expr)->base.loc;
arg0.m_value = *current_expr; // Use the converted arg
new_args.push_back(al, arg0);
current_expr = current_expr_copy_;
}
for( size_t i = 0; i < x->n_args; i++ ) {
ASR::expr_t** current_expr_copy_ = current_expr;
current_expr = &(x->m_args[i]);
replace_expr(x->m_args[i]);
ASR::call_arg_t arg0;
arg0.loc = (*current_expr)->base.loc;
arg0.m_value = *current_expr; // Use the converted arg
new_args.push_back(al, arg0);
current_expr = current_expr_copy_;
}

// TODO: currently we always instantiate a new function.
// Rather we should reuse the old instantiation if it has
// exactly the same arguments. For that we could use the
Expand All @@ -115,7 +118,7 @@ class ReplaceIntrinsicFunctions: public ASR::BaseExprReplacer<ReplaceIntrinsicFu
arg_types.push_back(al, ASRUtils::expr_type(x->m_args[i]));
}
ASR::expr_t* current_expr_ = instantiate_function(al, x->base.base.loc,
global_scope, arg_types, x->m_type, new_args, x->m_overload_id, x->m_value);
global_scope, arg_types, x->m_type, new_args, x->m_overload_id);
ASR::expr_t* func_call = current_expr_;
if( ASR::is_a<ASR::ArrayPhysicalCast_t>(*(*current_expr)) ) {
ASR::ArrayPhysicalCast_t* array_physical_cast_t = ASR::down_cast<ASR::ArrayPhysicalCast_t>(*current_expr);
Expand Down
66 changes: 20 additions & 46 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ inline std::string get_intrinsic_name(int x) {
typedef ASR::expr_t* (*impl_function)(
Allocator&, const Location &,
SymbolTable*, Vec<ASR::ttype_t*>&, ASR::ttype_t *,
Vec<ASR::call_arg_t>&, int64_t, ASR::expr_t*);
Vec<ASR::call_arg_t>&, int64_t);

typedef ASR::expr_t* (*eval_intrinsic_function)(
Allocator&, const Location &, ASR::ttype_t *,
Expand Down Expand Up @@ -709,11 +709,7 @@ namespace UnaryIntrinsicFunction {
static inline ASR::expr_t* instantiate_functions(Allocator &al,
const Location &loc, SymbolTable *scope, std::string new_name,
ASR::ttype_t *arg_type, ASR::ttype_t *return_type,
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/,
ASR::expr_t *value) {
if (value) {
return value;
}
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/) {
std::string c_func_name;
switch (arg_type->type) {
case ASR::ttypeType::Complex : {
Expand All @@ -738,7 +734,7 @@ static inline ASR::expr_t* instantiate_functions(Allocator &al,
if (scope->get_symbol(new_name)) {
ASR::symbol_t *s = scope->get_symbol(new_name);
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(s);
return b.Call(s, new_args, expr_type(f->m_return_var), value);
return b.Call(s, new_args, expr_type(f->m_return_var));
}
fill_func_arg("x", arg_type);
auto result = declare(new_name, return_type, ReturnVar);
Expand Down Expand Up @@ -768,7 +764,7 @@ static inline ASR::expr_t* instantiate_functions(Allocator &al,
ASR::symbol_t *new_symbol = make_Function_t(fn_name, fn_symtab, dep, args,
body, result, Source, Implementation, nullptr);
scope->add_symbol(fn_name, new_symbol);
return b.Call(new_symbol, new_args, return_type, value);
return b.Call(new_symbol, new_args, return_type);
}

static inline ASR::asr_t* create_UnaryFunction(Allocator& al, const Location& loc,
Expand Down Expand Up @@ -904,7 +900,9 @@ static inline ASR::asr_t* create_LogGamma(Allocator& al, const Location& loc,
const std::function<void (const std::string &, const Location &)> err) {
ASR::ttype_t *type = ASRUtils::expr_type(args[0]);

if (!ASRUtils::is_real(*type)) {
if (args.n != 1) {
err("Intrinsic `log_gamma` accepts exactly one argument", loc);
} else if (!ASRUtils::is_real(*type)) {
err("`x` argument of `log_gamma` must be real",
args[0]->base.loc);
}
Expand All @@ -917,12 +915,11 @@ static inline ASR::asr_t* create_LogGamma(Allocator& al, const Location& loc,
static inline ASR::expr_t* instantiate_LogGamma (Allocator &al,
const Location &loc, SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types,
ASR::ttype_t *return_type, Vec<ASR::call_arg_t>& new_args,
int64_t overload_id,ASR::expr_t* compile_time_value) {
if (compile_time_value) return compile_time_value;
int64_t overload_id) {
LCOMPILERS_ASSERT(arg_types.size() == 1);
ASR::ttype_t* arg_type = arg_types[0];
return UnaryIntrinsicFunction::instantiate_functions(al, loc, scope,
"log_gamma", arg_type, return_type, new_args, overload_id, nullptr);
"log_gamma", arg_type, return_type, new_args, overload_id);
}

} // namespace LogGamma
Expand Down Expand Up @@ -956,7 +953,9 @@ namespace X {
const std::function<void (const std::string &, const Location &)> err) \
{ \
ASR::ttype_t *type = ASRUtils::expr_type(args[0]); \
if (!ASRUtils::is_real(*type) && !ASRUtils::is_complex(*type)) { \
if (args.n != 1) { \
err("Intrinsic `"#X"` accepts exactly one argument", loc); \
} else if (!ASRUtils::is_real(*type) && !ASRUtils::is_complex(*type)) { \
err("`x` argument of `"#X"` must be real or complex", \
args[0]->base.loc); \
} \
Expand All @@ -967,14 +966,10 @@ namespace X {
static inline ASR::expr_t* instantiate_##X (Allocator &al, \
const Location &loc, SymbolTable *scope, \
Vec<ASR::ttype_t*>& arg_types, ASR::ttype_t *return_type, \
Vec<ASR::call_arg_t>& new_args,int64_t overload_id, \
ASR::expr_t* compile_time_value) { \
if (compile_time_value) return compile_time_value; \
LCOMPILERS_ASSERT(arg_types.size() == 1); \
Vec<ASR::call_arg_t>& new_args,int64_t overload_id) { \
ASR::ttype_t* arg_type = arg_types[0]; \
return UnaryIntrinsicFunction::instantiate_functions(al, loc, scope, \
#lcompilers_name, arg_type, return_type, new_args, overload_id, \
nullptr); \
#lcompilers_name, arg_type, return_type, new_args, overload_id); \
} \
} // namespace X

Expand Down Expand Up @@ -1068,10 +1063,7 @@ namespace Abs {

static inline ASR::expr_t* instantiate_Abs(Allocator &al, const Location &loc,
SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types, ASR::ttype_t *return_type,
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/, ASR::expr_t* compile_time_value) {
if (compile_time_value) {
return compile_time_value;
}
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/) {
std::string func_name = "_lcompilers_abs_" + type_to_str_python(arg_types[0]);
declare_basic_variables(func_name);
if (scope->get_symbol(func_name)) {
Expand Down Expand Up @@ -1238,11 +1230,7 @@ namespace Sign {

static inline ASR::expr_t* instantiate_Sign(Allocator &al, const Location &loc,
SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types, ASR::ttype_t *return_type,
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/,
ASR::expr_t* compile_time_value) {
if (compile_time_value) {
return compile_time_value;
}
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/) {
declare_basic_variables("_lcompilers_sign_" + type_to_str_python(arg_types[0]));
fill_func_arg("x", arg_types[0]);
fill_func_arg("y", arg_types[0]);
Expand Down Expand Up @@ -1333,11 +1321,7 @@ namespace FMA {

static inline ASR::expr_t* instantiate_FMA(Allocator &al, const Location &loc,
SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types, ASR::ttype_t *return_type,
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/,
ASR::expr_t* compile_time_value) {
if (compile_time_value) {
return compile_time_value;
}
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/) {
declare_basic_variables("_lcompilers_optimization_fma_" + type_to_str_python(arg_types[0]));
fill_func_arg("a", arg_types[0]);
fill_func_arg("b", arg_types[0]);
Expand Down Expand Up @@ -1880,10 +1864,7 @@ namespace Max {

static inline ASR::expr_t* instantiate_Max(Allocator &al, const Location &loc,
SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types, ASR::ttype_t *return_type,
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/, ASR::expr_t* compile_time_value) {
if (compile_time_value) {
return compile_time_value;
}
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/) {
std::string func_name = "_lcompilers_max0_" + type_to_str_python(arg_types[0]);
std::string fn_name = scope->get_unique_name(func_name);
SymbolTable *fn_symtab = al.make_new<SymbolTable>(scope);
Expand Down Expand Up @@ -1995,10 +1976,7 @@ namespace Min {

static inline ASR::expr_t* instantiate_Min(Allocator &al, const Location &loc,
SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types, ASR::ttype_t *return_type,
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/, ASR::expr_t* compile_time_value) {
if (compile_time_value) {
return compile_time_value;
}
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/) {
std::string func_name = "_lcompilers_min0_" + type_to_str_python(arg_types[0]);
std::string fn_name = scope->get_unique_name(func_name);
SymbolTable *fn_symtab = al.make_new<SymbolTable>(scope);
Expand Down Expand Up @@ -2130,11 +2108,7 @@ namespace Partition {
static inline ASR::expr_t *instantiate_Partition(Allocator &al,
const Location &loc, SymbolTable *scope,
Vec<ASR::ttype_t*>& /*arg_types*/, ASR::ttype_t *return_type,
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/,
ASR::expr_t* compile_time_value) {
if (compile_time_value) {
return compile_time_value;
}
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/) {
// TODO: show runtime error for empty separator or pattern
declare_basic_variables("_lpython_str_partition");
fill_func_arg("target_string", character(-2));
Expand Down
Loading
Loading