Skip to content

Commit

Permalink
Update ASR from LFortran
Browse files Browse the repository at this point in the history
  • Loading branch information
certik committed Jun 4, 2022
1 parent 8178c43 commit 0945ecb
Show file tree
Hide file tree
Showing 26 changed files with 2,266 additions and 760 deletions.
18 changes: 13 additions & 5 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ abi -- External ABI


stmt
= Allocate(alloc_arg* args, expr? stat, expr? errmsg)
= Allocate(alloc_arg* args, expr? stat, expr? errmsg, expr? source)
| Assign(int label, identifier variable)
| Assignment(expr target, expr value, stmt? overloaded)
| Associate(expr target, expr value)
Expand Down Expand Up @@ -202,6 +202,7 @@ stmt
| Flush(int label, expr unit, expr? err, expr? iomsg, expr? iostat)
| ListAppend(symbol a, expr ele)
| AssociateBlockCall(symbol m)
| CFPointer(expr cptr, expr fptr, expr? shape)
| BlockCall(symbol m)
| SetInsert(symbol a, expr ele)
| SetRemove(symbol a, expr ele)
Expand All @@ -218,10 +219,10 @@ expr
| ComplexConstructor(expr re, expr im, ttype type, expr? value)
| NamedExpr(expr target, expr value, ttype type)
| Compare(expr left, cmpop op, expr right, ttype type, expr? value, expr? overloaded)
| IfExp(expr test, expr body, expr orelse, ttype type)
| IfExp(expr test, expr body, expr orelse, ttype type, expr? value)
| FunctionCall(symbol name, symbol? original_name,
call_arg* args, ttype type, expr? value, expr? dt)
| DerivedTypeConstructor(symbol dt_sym, expr* args, ttype type)
| DerivedTypeConstructor(symbol dt_sym, expr* args, ttype type, expr? value)
| ImpliedDoLoop(expr* values, expr var, expr start, expr end,
expr? increment, ttype type, expr? value)
| IntegerConstant(int n, ttype type)
Expand Down Expand Up @@ -259,20 +260,26 @@ expr
| Var(symbol v)
| ArrayRef(symbol v, array_index* args, ttype type, expr? value)
| ArraySize(expr v, expr? dim, ttype type, expr? value)
| ArrayBound(expr v, expr? dim ttype type, arraybound bound,
| ArrayBound(expr v, expr? dim, ttype type, arraybound bound,
expr? value)
| ArrayTranspose(expr matrix, ttype type, expr? value)
| ArrayMatMul(expr matrix_a, expr matrix_b, ttype type, expr? value)
| ArrayPack(expr array, expr mask, expr? vector, ttype type, expr? value)
| Transfer(expr source, expr mold, expr? size, ttype type, expr? value)
| DerivedRef(expr v, symbol m, ttype type, expr? value)
| Cast(expr arg, cast_kind kind, ttype type, expr? value)
| ComplexRe(expr arg, ttype type, expr? value)
| ComplexIm(expr arg, ttype type, expr? value)
| DictItem(symbol a, expr key, expr? default, ttype type)
| DictItem(symbol a, expr key, expr? default, ttype type, expr? value)
| CLoc(expr arg, ttype type, expr? value)
| ListItem(symbol a, expr pos, ttype type, expr? value)
| TupleItem(symbol a, expr pos, ttype type, expr? value)
| ListSection(expr a, array_index section, ttype type, expr? value)
| ListPop(symbol a, expr? index, ttype type, expr? value)
| DictPop(symbol a, expr key, ttype type, expr? value)
| SetPop(symbol a, ttype type, expr? value)


-- `len` in Character:
-- >=0 ... the length of the string, known at compile time
-- -1 ... character(*), i.e., inferred at runtime
Expand Down Expand Up @@ -305,6 +312,7 @@ ttype
| Class(symbol class_type, dimension* dims)
| Dict(ttype key_type, ttype value_type)
| Pointer(ttype type)
| CPtr()

boolop = And | Or | Xor | NEqv | Eqv

Expand Down
4 changes: 3 additions & 1 deletion src/libasr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ set(SRC
codegen/x86_assembler.cpp
codegen/asr_to_x86.cpp
codegen/asr_to_wasm.cpp

codegen/wasm_to_wat.cpp
codegen/wasm_utils.cpp

pass/param_to_const.cpp
pass/do_loops.cpp
pass/for_all.cpp
Expand Down
144 changes: 120 additions & 24 deletions src/libasr/asr_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ void set_intrinsic(ASR::symbol_t* sym) {
switch( sym->type ) {
case ASR::symbolType::Module: {
ASR::Module_t* module_sym = ASR::down_cast<ASR::Module_t>(sym);
module_sym->m_intrinsic = true;
for( auto& itr: module_sym->m_symtab->get_scope() ) {
set_intrinsic(itr.second);
}
Expand Down Expand Up @@ -302,14 +303,15 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
ASR::ttype_t *right_type = LFortran::ASRUtils::expr_type(right);
bool found = false;
if( is_op_overloaded(op, intrinsic_op_name, curr_scope) ) {
ASR::symbol_t* sym = curr_scope->get_symbol(intrinsic_op_name);
ASR::symbol_t* sym = curr_scope->resolve_symbol(intrinsic_op_name);
ASR::symbol_t* orig_sym = ASRUtils::symbol_get_past_external(sym);
ASR::CustomOperator_t* gen_proc = ASR::down_cast<ASR::CustomOperator_t>(orig_sym);
for( size_t i = 0; i < gen_proc->n_procs && !found; i++ ) {
ASR::symbol_t* proc = gen_proc->m_procs[i];
switch(proc->type) {
case ASR::symbolType::Function: {
ASR::Function_t* func = ASR::down_cast<ASR::Function_t>(proc);
std::string matched_func_name = "";
if( func->n_args == 2 ) {
ASR::ttype_t* left_arg_type = ASRUtils::expr_type(func->m_args[0]);
ASR::ttype_t* right_arg_type = ASRUtils::expr_type(func->m_args[1]);
Expand All @@ -323,7 +325,18 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
a_args.push_back(al, left_call_arg);
right_call_arg.loc = right->base.loc, right_call_arg.m_value = right;
a_args.push_back(al, right_call_arg);
asr = ASR::make_FunctionCall_t(al, loc, curr_scope->get_symbol(std::string(func->m_name)), orig_sym,
std::string func_name = to_lower(func->m_name);
if( curr_scope->resolve_symbol(func_name) ) {
matched_func_name = func_name;
} else {
std::string mangled_name = func_name + "@" + intrinsic_op_name;
matched_func_name = mangled_name;
}
ASR::symbol_t* a_name = curr_scope->resolve_symbol(matched_func_name);
if( a_name == nullptr ) {
err("Unable to resolve matched function for operator overloading, " + matched_func_name, loc);
}
asr = ASR::make_FunctionCall_t(al, loc, a_name, sym,
a_args.p, 2,
ASRUtils::expr_type(func->m_return_var),
nullptr, nullptr);
Expand Down Expand Up @@ -416,24 +429,13 @@ bool use_overloaded_assignment(ASR::expr_t* target, ASR::expr_t* value,
} else {
std::string mangled_name = subrout_name + "@~assign";
matched_subrout_name = mangled_name;
ASR::symbol_t* imported_subrout = nullptr;
if( sym->type == ASR::symbolType::ExternalSymbol &&
curr_scope->resolve_symbol(mangled_name) == nullptr) {
ASR::ExternalSymbol_t* ext_sym = ASR::down_cast<ASR::ExternalSymbol_t>(sym);
imported_subrout = ASR::down_cast<ASR::symbol_t>(
ASR::make_ExternalSymbol_t(al,
loc, curr_scope,
s2c(al, mangled_name), proc,
ext_sym->m_module_name, nullptr, 0,
subrout->m_name, ASR::accessType::Private));
curr_scope->add_symbol(mangled_name, imported_subrout);
}
}
if( curr_scope->resolve_symbol(matched_subrout_name) == nullptr ) {
err("Unable to resolve matched subroutine for assignment overloading, " + std::string(matched_subrout_name), loc);
ASR::symbol_t *a_name = curr_scope->resolve_symbol(matched_subrout_name);
if( a_name == nullptr ) {
err("Unable to resolve matched subroutine for assignment overloading, " + matched_subrout_name, loc);
}
asr = ASR::make_SubroutineCall_t(al, loc, curr_scope->resolve_symbol(matched_subrout_name), orig_sym,
a_args.p, 2, nullptr);
asr = ASR::make_SubroutineCall_t(al, loc, a_name, sym,
a_args.p, 2, nullptr);
}
}
}
Expand All @@ -458,6 +460,7 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
switch(proc->type) {
case ASR::symbolType::Function: {
ASR::Function_t* func = ASR::down_cast<ASR::Function_t>(proc);
std::string matched_func_name = "";
if( func->n_args == 2 ) {
ASR::ttype_t* left_arg_type = ASRUtils::expr_type(func->m_args[0]);
ASR::ttype_t* right_arg_type = ASRUtils::expr_type(func->m_args[1]);
Expand All @@ -471,7 +474,18 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
a_args.push_back(al, left_call_arg);
right_call_arg.loc = right->base.loc, right_call_arg.m_value = right;
a_args.push_back(al, right_call_arg);
asr = ASR::make_FunctionCall_t(al, loc, curr_scope->resolve_symbol(std::string(func->m_name)), orig_sym,
std::string func_name = to_lower(func->m_name);
if( curr_scope->resolve_symbol(func_name) ) {
matched_func_name = func_name;
} else {
std::string mangled_name = func_name + "@" + intrinsic_op_name;
matched_func_name = mangled_name;
}
ASR::symbol_t* a_name = curr_scope->resolve_symbol(matched_func_name);
if( a_name == nullptr ) {
err("Unable to resolve matched function for operator overloading, " + matched_func_name, loc);
}
asr = ASR::make_FunctionCall_t(al, loc, a_name, sym,
a_args.p, 2,
ASRUtils::expr_type(func->m_return_var),
nullptr, nullptr);
Expand Down Expand Up @@ -536,13 +550,26 @@ bool is_op_overloaded(ASR::cmpopType op, std::string& intrinsic_op_name,
return result;
}

bool is_parent(ASR::DerivedType_t* a, ASR::DerivedType_t* b) {
ASR::symbol_t* current_parent = b->m_parent;
while( current_parent ) {
if( current_parent == (ASR::symbol_t*) a ) {
return true;
}
LFORTRAN_ASSERT(ASR::is_a<ASR::DerivedType_t>(*current_parent));
current_parent = ASR::down_cast<ASR::DerivedType_t>(current_parent)->m_parent;
}
return false;
}

bool is_derived_type_similar(ASR::DerivedType_t* a, ASR::DerivedType_t* b) {
return a == b || is_parent(a, b) || is_parent(b, a);
}

bool types_equal(const ASR::ttype_t &a, const ASR::ttype_t &b) {
// TODO: If anyone of the input or argument is derived type then
// add support for checking member wise types and do not compare
// directly. From stdlib_string len(pattern) error.
if (b.type == ASR::ttypeType::Derived || b.type == ASR::ttypeType::Class) {
return true;
}
if (a.type == b.type) {
// TODO: check dims
// TODO: check all types
Expand Down Expand Up @@ -597,8 +624,74 @@ bool types_equal(const ASR::ttype_t &a, const ASR::ttype_t &b) {
}
break;
}
case (ASR::ttypeType::Derived) : {
ASR::Derived_t *a2 = ASR::down_cast<ASR::Derived_t>(&a);
ASR::Derived_t *b2 = ASR::down_cast<ASR::Derived_t>(&b);
ASR::DerivedType_t *a2_type = ASR::down_cast<ASR::DerivedType_t>(
ASRUtils::symbol_get_past_external(
a2->m_derived_type));
ASR::DerivedType_t *b2_type = ASR::down_cast<ASR::DerivedType_t>(
ASRUtils::symbol_get_past_external(
b2->m_derived_type));
return a2_type == b2_type;
}
case (ASR::ttypeType::Class) : {
ASR::Class_t *a2 = ASR::down_cast<ASR::Class_t>(&a);
ASR::Class_t *b2 = ASR::down_cast<ASR::Class_t>(&b);
ASR::symbol_t* a2_typesym = ASRUtils::symbol_get_past_external(a2->m_class_type);
ASR::symbol_t* b2_typesym = ASRUtils::symbol_get_past_external(b2->m_class_type);
if( a2_typesym->type != b2_typesym->type ) {
return false;
}
if( a2_typesym->type == ASR::symbolType::ClassType ) {
ASR::ClassType_t *a2_type = ASR::down_cast<ASR::ClassType_t>(a2_typesym);
ASR::ClassType_t *b2_type = ASR::down_cast<ASR::ClassType_t>(b2_typesym);
return a2_type == b2_type;
} else if( a2_typesym->type == ASR::symbolType::DerivedType ) {
ASR::DerivedType_t *a2_type = ASR::down_cast<ASR::DerivedType_t>(a2_typesym);
ASR::DerivedType_t *b2_type = ASR::down_cast<ASR::DerivedType_t>(b2_typesym);
return is_derived_type_similar(a2_type, b2_type);
}
return false;
}
default : return false;
}
} else if( a.type == ASR::ttypeType::Derived &&
b.type == ASR::ttypeType::Class ) {
ASR::Derived_t *a2 = ASR::down_cast<ASR::Derived_t>(&a);
ASR::Class_t *b2 = ASR::down_cast<ASR::Class_t>(&b);
ASR::symbol_t* a2_typesym = ASRUtils::symbol_get_past_external(a2->m_derived_type);
ASR::symbol_t* b2_typesym = ASRUtils::symbol_get_past_external(b2->m_class_type);
if( a2_typesym->type != b2_typesym->type ) {
return false;
}
if( a2_typesym->type == ASR::symbolType::ClassType ) {
ASR::ClassType_t *a2_type = ASR::down_cast<ASR::ClassType_t>(a2_typesym);
ASR::ClassType_t *b2_type = ASR::down_cast<ASR::ClassType_t>(b2_typesym);
return a2_type == b2_type;
} else if( a2_typesym->type == ASR::symbolType::DerivedType ) {
ASR::DerivedType_t *a2_type = ASR::down_cast<ASR::DerivedType_t>(a2_typesym);
ASR::DerivedType_t *b2_type = ASR::down_cast<ASR::DerivedType_t>(b2_typesym);
return is_derived_type_similar(a2_type, b2_type);
}
} else if( a.type == ASR::ttypeType::Class &&
b.type == ASR::ttypeType::Derived ) {
ASR::Class_t *a2 = ASR::down_cast<ASR::Class_t>(&a);
ASR::Derived_t *b2 = ASR::down_cast<ASR::Derived_t>(&b);
ASR::symbol_t* a2_typesym = ASRUtils::symbol_get_past_external(a2->m_class_type);
ASR::symbol_t* b2_typesym = ASRUtils::symbol_get_past_external(b2->m_derived_type);
if( a2_typesym->type != b2_typesym->type ) {
return false;
}
if( a2_typesym->type == ASR::symbolType::ClassType ) {
ASR::ClassType_t *a2_type = ASR::down_cast<ASR::ClassType_t>(a2_typesym);
ASR::ClassType_t *b2_type = ASR::down_cast<ASR::ClassType_t>(b2_typesym);
return a2_type == b2_type;
} else if( a2_typesym->type == ASR::symbolType::DerivedType ) {
ASR::DerivedType_t *a2_type = ASR::down_cast<ASR::DerivedType_t>(a2_typesym);
ASR::DerivedType_t *b2_type = ASR::down_cast<ASR::DerivedType_t>(b2_typesym);
return is_derived_type_similar(a2_type, b2_type);
}
}
return false;
}
Expand Down Expand Up @@ -651,7 +744,8 @@ bool select_func_subrout(const ASR::symbol_t* proc, const Vec<ASR::call_arg_t>&

int select_generic_procedure(const Vec<ASR::call_arg_t>& args,
const ASR::GenericProcedure_t &p, Location loc,
const std::function<void (const std::string &, const Location &)> err) {
const std::function<void (const std::string &, const Location &)> err,
bool raise_error) {
for (size_t i=0; i < p.n_procs; i++) {
if( ASR::is_a<ASR::ClassProcedure_t>(*p.m_procs[i]) ) {
ASR::ClassProcedure_t *clss_fn
Expand All @@ -666,7 +760,9 @@ int select_generic_procedure(const Vec<ASR::call_arg_t>& args,
}
}
}
err("Arguments do not match for any generic procedure", loc);
if( raise_error ) {
err("Arguments do not match for any generic procedure, " + std::string(p.m_name), loc);
}
return -1;
}

Expand Down
Loading

0 comments on commit 0945ecb

Please sign in to comment.