Skip to content

Commit

Permalink
Added IntrinsicFunction in ASR (lcompilers#1616)
Browse files Browse the repository at this point in the history
Co-authored-by: Thirumalai-Shaktivel <thirumalaishaktivel@gmail.com>
  • Loading branch information
anutosh491 and Thirumalai-Shaktivel authored Apr 5, 2023
1 parent dfb552e commit b35e744
Show file tree
Hide file tree
Showing 51 changed files with 971 additions and 127 deletions.
107 changes: 107 additions & 0 deletions doc/src/asr/asr_nodes/expression_nodes/IntrinsicFunction.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# IntrinsicFunction

An intrinsic function. An **expr** node.

## Declaration

### Syntax

```
IntrinsicFunction(expr* args, int intrinsic_id, int overload_id,
ttype type, expr? value)
```

### Arguments

* `args` represents all arguments passed to the function
* `intrinsic_id` is the unique ID of the generic intrinsic function
* `overload_id` is the ID of the signature within the given generic function
* `type` represents the type of the output
* `value` is an optional compile time value

### Return values

The return value is the expression that the `IntrinsicFunction` represents.

## Description

**IntrinsicFunction** represents an intrinsic function (such as `Abs`,
`Modulo`, `Sin`, `Cos`, `LegendreP`, `FlipSign`, ...) that either the backend
or the middle-end (optimizer) needs to have some special logic for. Typically a
math function, but does not have to be.

IntrinsicFunction is both side-effect-free (no writes to global variables) and
deterministic (no reads from global variables). They are also elemental: can be
vectorized over any argument(s). They can be used inside parallel code and
cached.

The `intrinsic_id` determines the generic function uniquely (`Sin` and `Abs`
have different number, but `IntegerAbs` and `RealAbs` share the number) and
`overload_id` uniquely determines the signature starting from 0 for each
generic function (e.g., `IntegerAbs`, `RealAbs` and `ComplexAbs` can have
`overload_id` equal to 0, 1 and 2, and `RealSin`, `ComplexSin` can be 0, 1).

Backend use cases: Some architectures have special hardware instructions for
operations like Sqrt or Sin and if they are faster than a software
implementation, the backend will use it. This includes the `FlipSign` function
which is our own "special function" that the optimizer emits for certain
conditional floating point operations, and the backend emits an efficient bit
manipulation implementation for architectures that support it.

Middle-end use cases: the middle-end can use the high level semantics to
simplify, such as `sin(e)**2 + cos(e)**2 -> 1`, or it could approximate
expressions like `if (abs(sin(x) - 0.5) < 0.3)` with a lower accuracy version
of `sin`.

We provide ASR -> ASR lowering transformations that substitute the given
intrinsic function with an ASR implementation using more primitive ASR nodes,
typically implemented in the surface language (say a `sin` implementation using
argument reduction and a polynomial fit, or a `sqrt` implementation using a
general power formula `x**(0.5)`, or `LegendreP(2,x)` implementation using a
formula `(3*x**2-1)/2`).

This design also makes it possible to allow selecting using command line
options how certain intrinsic functions should be implemented, for example if
trigonometric functions should be implemented using our own fast
implementation, `libm` accurate implementation, we could also call into other
libraries. These choices should happen at the ASR level, and then the result
further optimized (such as inlined) as needed.

## Types

The argument types in `args` have the types of the corresponding signature as
determined by `intrinsic_id`. For example `IntegerAbs` accepts an integer, but
`RealAbs` accepts a real.

## Examples

The following example code creates `IntrinsicFunction` ASR node:

```fortran
sin(0.5)
```

ASR:

```
(TranslationUnit
(SymbolTable
1
{
})
[(IntrinsicFunction
[(RealConstant
0.500000
(Real 4 [])
)]
0
0
(Real 4 [])
(RealConstant 0.479426 (Real 4 []))
)]
)
```

## See Also

[FunctionCall]()
4 changes: 2 additions & 2 deletions integration_tests/test_builtin_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def test_abs():

b: bool
b = True
assert abs(b) == 1
assert abs(i32(b)) == 1
b = False
assert abs(b) == 0
assert abs(i32(b)) == 0

test_abs()
2 changes: 2 additions & 0 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ expr
| NamedExpr(expr target, expr value, ttype type)
| FunctionCall(symbol name, symbol? original_name, call_arg* args,
ttype type, expr? value, expr? dt)
| IntrinsicFunction(int intrinsic_id, expr* args, int overload_id,
ttype type, expr? value)
| StructTypeConstructor(symbol dt_sym, call_arg* args, ttype type, expr? value)
| EnumTypeConstructor(symbol dt_sym, expr* args, ttype type, expr? value)
| UnionTypeConstructor(symbol dt_sym, expr* args, ttype type, expr? value)
Expand Down
1 change: 1 addition & 0 deletions src/libasr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ set(SRC
pass/unused_functions.cpp
pass/flip_sign.cpp
pass/div_to_mul.cpp
pass/intrinsic_function.cpp
pass/fma.cpp
pass/loop_vectorise.cpp
pass/sign_from_value.cpp
Expand Down
5 changes: 4 additions & 1 deletion src/libasr/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1520,7 +1520,10 @@ def visitField(self, field, cons):
self.emit( 's.append("()");', 3)
self.emit("}", 2)
else:
self.emit('s.append(std::to_string(x.m_%s));' % field.name, 2)
if field.name == "intrinsic_id":
self.emit('s.append(self().convert_intrinsic_id(x.m_%s));' % field.name, 2)
else:
self.emit('s.append(std::to_string(x.m_%s));' % field.name, 2)
elif field.type == "float" and not field.seq and not field.opt:
self.emit('s.append(std::to_string(x.m_%s));' % field.name, 2)
elif field.type == "bool" and not field.seq and not field.opt:
Expand Down
18 changes: 17 additions & 1 deletion src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <libasr/string_utils.h>
#include <libasr/utils.h>

#include <complex>

namespace LCompilers {

namespace ASRUtils {
Expand Down Expand Up @@ -801,7 +803,21 @@ static inline bool all_args_evaluated(const Vec<ASR::array_index_t> &args) {
return true;
}

template <typename T>
static inline bool extract_value(ASR::expr_t* value_expr,
std::complex<double>& value) {
if( !ASR::is_a<ASR::ComplexConstant_t>(*value_expr) ) {
return false;
}

ASR::ComplexConstant_t* value_const = ASR::down_cast<ASR::ComplexConstant_t>(value_expr);
value = std::complex(value_const->m_re, value_const->m_im);
return true;
}

template <typename T,
typename = typename std::enable_if<
std::is_same<T, std::complex<double>>::value == false &&
std::is_same<T, std::complex<float>>::value == false>::type>
static inline bool extract_value(ASR::expr_t* value_expr, T& value) {
if( !is_value_constant(value_expr) ) {
return false;
Expand Down
25 changes: 13 additions & 12 deletions src/libasr/codegen/asr_to_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <libasr/pass/class_constructor.h>
#include <libasr/pass/array_op.h>
#include <libasr/pass/subroutine_from_function.h>
#include <libasr/pass/intrinsic_function.h>

#include <map>
#include <utility>
Expand Down Expand Up @@ -680,6 +681,18 @@ R"(
}
}

// Process global functions
size_t i;
for (i = 0; i < global_func_order.size(); i++) {
ASR::symbol_t* sym = x.m_global_scope->get_symbol(global_func_order[i]);
// Ignore external symbols because they are already defined by the loop above.
if( !sym || ASR::is_a<ASR::ExternalSymbol_t>(*sym) ) {
continue ;
}
visit_symbol(*sym);
unit_src += src;
}

// Process modules in the right order
std::vector<std::string> build_order
= ASRUtils::determine_module_dependencies(x);
Expand All @@ -693,18 +706,6 @@ R"(
}
}

// Process global functions
size_t i;
for (i = 0; i < global_func_order.size(); i++) {
ASR::symbol_t* sym = x.m_global_scope->get_symbol(global_func_order[i]);
// Ignore external symbols because they are already defined by the loop above.
if( !sym || ASR::is_a<ASR::ExternalSymbol_t>(*sym) ) {
continue ;
}
visit_symbol(*sym);
unit_src += src;
}

// Then the main program:
for (auto &item : x.m_global_scope->get_scope()) {
if (ASR::is_a<ASR::Program_t>(*item.second)) {
Expand Down
2 changes: 2 additions & 0 deletions src/libasr/codegen/asr_to_cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <libasr/asr_utils.h>
#include <libasr/string_utils.h>
#include <libasr/pass/unused_functions.h>
#include <libasr/pass/intrinsic_function.h>


namespace LCompilers {
Expand Down Expand Up @@ -745,6 +746,7 @@ Result<std::string> asr_to_cpp(Allocator &al, ASR::TranslationUnit_t &asr,
LCompilers::PassOptions pass_options;
pass_options.always_run = true;
pass_unused_functions(al, asr, pass_options);
pass_replace_intrinsic_function(al, asr, pass_options);
ASRToCPPVisitor v(diagnostics, co, default_lower_bound);
try {
v.visit_asr((ASR::asr_t &)asr);
Expand Down
4 changes: 4 additions & 0 deletions src/libasr/codegen/asr_to_wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
#include <libasr/pass/unused_functions.h>
#include <libasr/pass/pass_array_by_data.h>
#include <libasr/pass/print_arr.h>
#include <libasr/pass/intrinsic_function.h>
#include <libasr/exception.h>
#include <libasr/asr_utils.h>

// #define SHOW_ASR

Expand Down Expand Up @@ -3136,6 +3139,7 @@ Result<Vec<uint8_t>> asr_to_wasm_bytes_stream(ASR::TranslationUnit_t &asr,
pass_array_by_data(al, asr, pass_options);
pass_replace_print_arr(al, asr, pass_options);
pass_replace_do_loops(al, asr, pass_options);
pass_replace_intrinsic_function(al, asr, pass_options);
pass_options.always_run = true;
pass_unused_functions(al, asr, pass_options);

Expand Down
83 changes: 83 additions & 0 deletions src/libasr/pass/array_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,89 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {
replace_ArrayOpCommon<ASR::LogicalCompare_t>(x, "_logical_comp_op_res");
}

void replace_IntrinsicFunction(ASR::IntrinsicFunction_t* x) {
LCOMPILERS_ASSERT(current_scope != nullptr);
const Location& loc = x->base.base.loc;
std::vector<bool> array_mask(x->n_args, false);
bool at_least_one_array = false;
for( size_t iarg = 0; iarg < x->n_args; iarg++ ) {
array_mask[iarg] = ASRUtils::is_array(
ASRUtils::expr_type(x->m_args[iarg]));
at_least_one_array = at_least_one_array || array_mask[iarg];
}
if (!at_least_one_array) {
return ;
}
std::string res_prefix = "_elemental_func_call_res";
ASR::expr_t* result_var_copy = result_var;
bool is_all_rank_0 = true;
std::vector<ASR::expr_t*> operands;
ASR::expr_t* operand = nullptr;
int common_rank = 0;
bool are_all_rank_same = true;
for( size_t iarg = 0; iarg < x->n_args; iarg++ ) {
result_var = nullptr;
ASR::expr_t** current_expr_copy_9 = current_expr;
current_expr = &(x->m_args[iarg]);
self().replace_expr(x->m_args[iarg]);
operand = *current_expr;
current_expr = current_expr_copy_9;
operands.push_back(operand);
int rank_operand = PassUtils::get_rank(operand);
if( common_rank == 0 ) {
common_rank = rank_operand;
}
if( common_rank != rank_operand &&
rank_operand > 0 ) {
are_all_rank_same = false;
}
array_mask[iarg] = (rank_operand > 0);
is_all_rank_0 = is_all_rank_0 && (rank_operand <= 0);
}
if( is_all_rank_0 ) {
return ;
}
if( !are_all_rank_same ) {
throw LCompilersException("Broadcasting support not yet available "
"for different shape arrays.");
}
result_var = result_var_copy;
if( result_var == nullptr ) {
result_var = PassUtils::create_var(result_counter, res_prefix,
loc, operand, al, current_scope);
result_counter += 1;
}
*current_expr = result_var;

Vec<ASR::expr_t*> idx_vars, loop_vars;
std::vector<int> loop_var_indices;
Vec<ASR::stmt_t*> doloop_body;
create_do_loop(loc, common_rank,
idx_vars, loop_vars, loop_var_indices, doloop_body,
[=, &operands, &idx_vars, &doloop_body] () {
Vec<ASR::expr_t*> ref_args;
ref_args.reserve(al, x->n_args);
for( size_t iarg = 0; iarg < x->n_args; iarg++ ) {
ASR::expr_t* ref = operands[iarg];
if( array_mask[iarg] ) {
ref = PassUtils::create_array_ref(operands[iarg], idx_vars, al);
}
ref_args.push_back(al, ref);
}
Vec<ASR::dimension_t> empty_dim;
empty_dim.reserve(al, 1);
ASR::ttype_t* dim_less_type = ASRUtils::duplicate_type(al, x->m_type, &empty_dim);
ASR::expr_t* op_el_wise = ASRUtils::EXPR(ASR::make_IntrinsicFunction_t(al, loc,
x->m_intrinsic_id, ref_args.p, ref_args.size(), x->m_overload_id,
dim_less_type, nullptr));
ASR::expr_t* res = PassUtils::create_array_ref(result_var, idx_vars, al);
ASR::stmt_t* assign = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, res, op_el_wise, nullptr));
doloop_body.push_back(al, assign);
});
use_custom_loop_params = false;
result_var = nullptr;
}

void replace_FunctionCall(ASR::FunctionCall_t* x) {
std::string x_name;
if( x->m_name->type == ASR::symbolType::ExternalSymbol ) {
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/pass/global_symbols.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace LCompilers {

void pass_wrap_global_syms_into_module(Allocator &al,
ASR::TranslationUnit_t &unit,
const LCompilers::PassOptions& pass_options) {
const LCompilers::PassOptions &/*pass_options*/) {
Location loc = unit.base.base.loc;
char *module_name = s2c(al, "_global_symbols");
SymbolTable *module_scope = al.make_new<SymbolTable>(unit.m_global_scope);
Expand Down
Loading

0 comments on commit b35e744

Please sign in to comment.