Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync libasr from LFortran #2148

Merged
merged 6 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ RUN(NAME loop_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
RUN(NAME loop_03 LABELS cpython llvm c wasm wasm_x64)
RUN(NAME loop_04 LABELS cpython llvm c)
RUN(NAME loop_05 LABELS cpython llvm c)
RUN(NAME loop_06 LABELS cpython llvm c NOFAST)
RUN(NAME if_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
RUN(NAME if_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
RUN(NAME if_03 FAIL LABELS cpython llvm c NOFAST)
Expand Down
9 changes: 9 additions & 0 deletions tests/loop2.py → integration_tests/loop_06.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from lpython import i32
from sys import exit

def test_for():
i: i32
j: i32; k: i32;
k = 0
for i in range(0, 10):
if i == 0:
j = 0
continue
if i > 5:
k = k + i
break
if i == 3:
print(j, k)
assert j == 0
assert k == 0
quit()
print(j, k)
exit(2)

test_for()
20 changes: 17 additions & 3 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,10 @@ symbol
identifier proc_name, symbol proc, abi abi, bool is_deferred)
| AssociateBlock(symbol_table symtab, identifier name, stmt* body)
| Block(symbol_table symtab, identifier name, stmt* body)
| Requirement(symbol_table symtab, identifier name, identifier* args)
| Template(symbol_table symtab, identifier name, identifier* args)
| Requirement(symbol_table symtab, identifier name, identifier* args,
require_instantiation* requires)
| Template(symbol_table symtab, identifier name, identifier* args,
require_instantiation* requires)

storage_type = Default | Save | Parameter
access = Public | Private
Expand Down Expand Up @@ -237,6 +239,8 @@ expr
ttype type, expr? value, expr? dt)
| IntrinsicFunction(int intrinsic_id, expr* args, int overload_id,
ttype? type, expr? value)
| IntrinsicImpureFunction(int impure_intrinsic_id, expr* args, int overload_id,
ttype? type, expr? value)
| StructTypeConstructor(symbol dt_sym, call_arg* args, ttype type, expr? value)
| EnumTypeConstructor(symbol dt_sym, expr* args, ttype type, expr? value)
| UnionTypeConstructor(symbol dt_sym, expr* args, ttype type, expr? value)
Expand Down Expand Up @@ -289,6 +293,7 @@ expr
| StringCompare(expr left, cmpop op, expr right, ttype type, expr? value)
| StringOrd(expr arg, ttype type, expr? value)
| StringChr(expr arg, ttype type, expr? value)
| StringFormat(expr fmt, expr* args, string_format_kind kind, ttype type, expr? value)

| CPtrCompare(expr left, cmpop op, expr right, ttype type, expr? value)
| SymbolicCompare(expr left, cmpop op, expr right, ttype type, expr? value)
Expand Down Expand Up @@ -461,7 +466,7 @@ cast_kind

dimension = (expr? start, expr? length)

alloc_arg = (expr a, dimension* dims)
alloc_arg = (expr a, dimension* dims, expr? len_expr, ttype? type)

attribute = Attribute(identifier name, attribute_arg *args)

Expand All @@ -484,4 +489,13 @@ type_stmt

enumtype = IntegerConsecutiveFromZero | IntegerUnique | IntegerNotUnique | NonInteger

require_instantiation = Require(identifier name, identifier* args)

string_format_kind
= FormatFortran -- "(f8.3,i4.2)", a, b
| FormatC -- "%f: %d", a, b
| FormatPythonPercent -- "%f: %d" % (a, b)
| FormatPythonFString -- f"{a}: {b}"
| FormatPythonFormat -- "{}: {}".format(a, b)

}
4 changes: 4 additions & 0 deletions src/libasr/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,8 @@ def visitField(self, field):
self.emit(" ASR::alloc_arg_t alloc_arg_copy;", level)
self.emit(" alloc_arg_copy.loc = x->m_%s[i].loc;"%(field.name), level)
self.emit(" alloc_arg_copy.m_a = self().duplicate_expr(x->m_%s[i].m_a);"%(field.name), level)
self.emit(" alloc_arg_copy.m_len_expr = self().duplicate_expr(x->m_%s[i].m_len_expr);"%(field.name), level)
self.emit(" alloc_arg_copy.m_type = self().duplicate_ttype(x->m_%s[i].m_type);"%(field.name), level)
self.emit(" alloc_arg_copy.n_dims = x->m_%s[i].n_dims;"%(field.name), level)
self.emit(" Vec<ASR::dimension_t> dims_copy;", level)
self.emit(" dims_copy.reserve(al, alloc_arg_copy.n_dims);", level)
Expand Down Expand Up @@ -1678,6 +1680,8 @@ def visitField(self, field, cons):
else:
if field.name == "intrinsic_id":
self.emit('s.append(self().convert_intrinsic_id(x.m_%s));' % field.name, 2)
elif field.name == "impure_intrinsic_id":
self.emit('s.append(self().convert_impure_intrinsic_id(x.m_%s));' % field.name, 2)
else:
self.emit('s.append(std::to_string(x.m_%s));' % field.name, 2)
elif field.type == "float" and not field.seq and not field.opt:
Expand Down
5 changes: 4 additions & 1 deletion src/libasr/asr_scopes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,11 @@ ASR::symbol_t *SymbolTable::find_scoped_symbol(const std::string &name,
}
}

std::string SymbolTable::get_unique_name(const std::string &name) {
std::string SymbolTable::get_unique_name(const std::string &name, bool use_unique_id) {
std::string unique_name = name;
if( use_unique_id ) {
unique_name += "_" + lcompilers_unique_ID;
}
int counter = 1;
while (scope.find(unique_name) != scope.end()) {
unique_name = name + std::to_string(counter);
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/asr_scopes.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ struct SymbolTable {
ASR::symbol_t *find_scoped_symbol(const std::string &name,
size_t n_scope_names, char **m_scope_names);

std::string get_unique_name(const std::string &name);
std::string get_unique_name(const std::string &name, bool use_unique_id=true);

void move_symbols_from_global_scope(Allocator &al,
SymbolTable *module_scope, Vec<char *> &syms,
Expand Down
91 changes: 88 additions & 3 deletions src/libasr/asr_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,91 @@ void extract_module_python(const ASR::TranslationUnit_t &m,
}
}

void update_call_args(Allocator &al, SymbolTable *current_scope, bool implicit_interface) {
/*
Iterate over body of program, check if there are any subroutine calls if yes, iterate over its args
and update the args if they are equal to the old symbol
For example:
function func(f)
double precision c
call sub2(c)
print *, c(d)
end function
This function updates `sub2` to use the new symbol `c` that is now a function, not a variable.
Along with this, it also updates the args of `sub2` to use the new symbol `c` instead of the old one.
*/
class UpdateArgsVisitor : public PassUtils::PassVisitor<UpdateArgsVisitor>
{
public:
SymbolTable* scope = current_scope;
UpdateArgsVisitor(Allocator &al) : PassVisitor(al, nullptr) {}

ASR::symbol_t* fetch_sym(ASR::symbol_t* arg_sym_underlying) {
ASR::symbol_t* sym = nullptr;
if (ASR::is_a<ASR::Variable_t>(*arg_sym_underlying)) {
ASR::Variable_t* arg_variable = ASR::down_cast<ASR::Variable_t>(arg_sym_underlying);
std::string arg_variable_name = std::string(arg_variable->m_name);
sym = arg_variable->m_parent_symtab->get_symbol(arg_variable_name);
} else if (ASR::is_a<ASR::Function_t>(*arg_sym_underlying)) {
ASR::Function_t* arg_function = ASR::down_cast<ASR::Function_t>(arg_sym_underlying);
std::string arg_function_name = std::string(arg_function->m_name);
sym = arg_function->m_symtab->parent->get_symbol(arg_function_name);
}
return sym;
}

void visit_SubroutineCall(const ASR::SubroutineCall_t& x) {
ASR::SubroutineCall_t* subrout_call = (ASR::SubroutineCall_t*)(&x);
for (size_t j = 0; j < subrout_call->n_args; j++) {
ASR::call_arg_t arg = subrout_call->m_args[j];
ASR::expr_t* arg_expr = arg.m_value;
if (ASR::is_a<ASR::Var_t>(*arg_expr)) {
ASR::Var_t* arg_var = ASR::down_cast<ASR::Var_t>(arg_expr);
ASR::symbol_t* arg_sym = arg_var->m_v;
ASR::symbol_t* arg_sym_underlying = ASRUtils::symbol_get_past_external(arg_sym);
ASR::symbol_t* sym = fetch_sym(arg_sym_underlying);
if (sym != arg_sym) {
subrout_call->m_args[j].m_value = ASRUtils::EXPR(ASR::make_Var_t(al, arg_expr->base.loc, sym));
}
}
}
}

void visit_Function(const ASR::Function_t& x) {
ASR::Function_t* func = (ASR::Function_t*)(&x);
for (size_t i = 0; i < func->n_args; i++) {
ASR::expr_t* arg_expr = func->m_args[i];
if (ASR::is_a<ASR::Var_t>(*arg_expr)) {
ASR::Var_t* arg_var = ASR::down_cast<ASR::Var_t>(arg_expr);
ASR::symbol_t* arg_sym = arg_var->m_v;
ASR::symbol_t* arg_sym_underlying = ASRUtils::symbol_get_past_external(arg_sym);
ASR::symbol_t* sym = fetch_sym(arg_sym_underlying);
if (sym != arg_sym) {
func->m_args[i] = ASRUtils::EXPR(ASR::make_Var_t(al, arg_expr->base.loc, sym));
}
}
}
scope = func->m_symtab;
for (auto &it: scope->get_scope()) {
visit_symbol(*it.second);
}
scope = func->m_symtab;
for (size_t i = 0; i < func->n_body; i++) {
visit_stmt(*func->m_body[i]);
}
scope = func->m_symtab;
}
};

if (implicit_interface) {
UpdateArgsVisitor v(al);
SymbolTable *tu_symtab = ASRUtils::get_tu_symtab(current_scope);
ASR::asr_t* asr_ = tu_symtab->asr_owner;
ASR::TranslationUnit_t* tu = ASR::down_cast2<ASR::TranslationUnit_t>(asr_);
v.visit_TranslationUnit(*tu);
}
}

ASR::Module_t* extract_module(const ASR::TranslationUnit_t &m) {
LCOMPILERS_ASSERT(m.m_global_scope->get_scope().size()== 1);
for (auto &a : m.m_global_scope->get_scope()) {
Expand Down Expand Up @@ -368,7 +453,7 @@ ASR::asr_t* getStructInstanceMember_t(Allocator& al, const Location& loc,
}
std::string mangled_name = current_scope->get_unique_name(
std::string(module_name) + "_" +
std::string(der_type_name));
std::string(der_type_name), false);
char* mangled_name_char = s2c(al, mangled_name);
if( current_scope->get_symbol(mangled_name) == nullptr ) {
bool make_new_ext_sym = true;
Expand Down Expand Up @@ -789,7 +874,7 @@ void process_overloaded_assignment_function(ASR::symbol_t* proc, ASR::expr_t* ta
ASRUtils::insert_module_dependency(a_name, al, current_module_dependencies);
ASRUtils::set_absent_optional_arguments_to_null(a_args, subrout, al);
asr = ASRUtils::make_SubroutineCall_t_util(al, loc, a_name, sym,
a_args.p, 2, nullptr);
a_args.p, 2, nullptr, nullptr, false);
}
}
}
Expand Down Expand Up @@ -1129,7 +1214,7 @@ ASR::asr_t* symbol_resolve_external_generic_procedure_without_eval(
}
return ASRUtils::make_SubroutineCall_t_util(al, loc, final_sym,
v, args.p, args.size(),
nullptr);
nullptr, nullptr, false);
} else {
if( func ) {
ASRUtils::set_absent_optional_arguments_to_null(args, func, al);
Expand Down
Loading