Skip to content

Commit

Permalink
Merge pull request lcompilers#1693 from czgdp1807/deps
Browse files Browse the repository at this point in the history
Fix for duplicate dependencies in ASR
  • Loading branch information
czgdp1807 authored Apr 12, 2023
2 parents d377bf9 + f248074 commit ae6c547
Show file tree
Hide file tree
Showing 79 changed files with 280 additions and 189 deletions.
2 changes: 1 addition & 1 deletion src/libasr/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2178,7 +2178,7 @@ def visitProduct(self, prod, name):
if field.type == "identifier":
self.emit('{', 2)
self.emit('uint64_t n = self().read_int64();', 3)
self.emit("Vec<char*> v;", 3)
self.emit("Vec<char*> v_%s;" % (field.name), 3)
self.emit("v.reserve(al, n);", 3)
self.emit("for (uint64_t i=0; i<n; i++) {", 3)
self.emit("v.push_back(al, self().read_cstring());", 4)
Expand Down
10 changes: 3 additions & 7 deletions src/libasr/asr_scopes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ std::string SymbolTable::get_unique_name(const std::string &name) {

void SymbolTable::move_symbols_from_global_scope(Allocator &al,
SymbolTable *module_scope, Vec<char *> &syms,
Vec<char *> &mod_dependencies) {
SetChar &mod_dependencies) {
// TODO: This isn't scalable. We have write a visitor in asdl_cpp.py
syms.reserve(al, 4);
mod_dependencies.reserve(al, 4);
Expand Down Expand Up @@ -180,9 +180,7 @@ void SymbolTable::move_symbols_from_global_scope(Allocator &al,
if (s != nullptr && ASR::is_a<ASR::ExternalSymbol_t>(*s)) {
char *es_name = ASR::down_cast<
ASR::ExternalSymbol_t>(s)->m_module_name;
if (!present(mod_dependencies, es_name)) {
mod_dependencies.push_back(al, es_name);
}
mod_dependencies.push_back(al, es_name);
}
}
fn->m_symtab->parent = module_scope;
Expand All @@ -197,9 +195,7 @@ void SymbolTable::move_symbols_from_global_scope(Allocator &al,
break;
} case (ASR::symbolType::ExternalSymbol) : {
ASR::ExternalSymbol_t *es = ASR::down_cast<ASR::ExternalSymbol_t>(a.second);
if (!present(mod_dependencies, es->m_module_name)) {
mod_dependencies.push_back(al, es->m_module_name);
}
mod_dependencies.push_back(al, es->m_module_name);
es->m_parent_symtab = module_scope;
ASR::symbol_t *s = ASRUtils::symbol_get_past_external(a.second);
LCOMPILERS_ASSERT(s);
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/asr_scopes.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ struct SymbolTable {

void move_symbols_from_global_scope(Allocator &al,
SymbolTable *module_scope, Vec<char *> &syms,
Vec<char *> &mod_dependencies);
SetChar &mod_dependencies);
};

} // namespace LCompilers
Expand Down
22 changes: 11 additions & 11 deletions src/libasr/asr_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,8 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
ASR::binopType op, std::string& intrinsic_op_name,
SymbolTable* curr_scope, ASR::asr_t*& asr,
Allocator &al, const Location& loc,
std::set<std::string>& current_function_dependencies,
Vec<char*>& current_module_dependencies,
SetChar& current_function_dependencies,
SetChar& current_module_dependencies,
const std::function<void (const std::string &, const Location &)> err) {
ASR::ttype_t *left_type = ASRUtils::expr_type(left);
ASR::ttype_t *right_type = ASRUtils::expr_type(right);
Expand Down Expand Up @@ -478,7 +478,7 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
} else {
return_type = ASRUtils::expr_type(func->m_return_var);
}
current_function_dependencies.insert(matched_func_name);
current_function_dependencies.push_back(al, s2c(al, matched_func_name));
ASRUtils::insert_module_dependency(a_name, al, current_module_dependencies);
asr = ASR::make_FunctionCall_t(al, loc, a_name, sym,
a_args.p, 2,
Expand Down Expand Up @@ -544,8 +544,8 @@ bool is_op_overloaded(ASR::binopType op, std::string& intrinsic_op_name,

void process_overloaded_assignment_function(ASR::symbol_t* proc, ASR::expr_t* target, ASR::expr_t* value,
ASR::ttype_t* target_type, ASR::ttype_t* value_type, bool& found, Allocator& al, const Location& target_loc,
const Location& value_loc, SymbolTable* curr_scope, std::set<std::string>& current_function_dependencies,
Vec<char*>& current_module_dependencies, ASR::asr_t*& asr, ASR::symbol_t* sym, const Location& loc, ASR::expr_t* expr_dt,
const Location& value_loc, SymbolTable* curr_scope, SetChar& current_function_dependencies,
SetChar& current_module_dependencies, ASR::asr_t*& asr, ASR::symbol_t* sym, const Location& loc, ASR::expr_t* expr_dt,
const std::function<void (const std::string &, const Location &)> err, char* pass_arg=nullptr) {
ASR::Function_t* subrout = ASR::down_cast<ASR::Function_t>(proc);
std::string matched_subrout_name = "";
Expand Down Expand Up @@ -592,7 +592,7 @@ void process_overloaded_assignment_function(ASR::symbol_t* proc, ASR::expr_t* ta
if( a_name == nullptr ) {
err("Unable to resolve matched subroutine for assignment overloading, " + matched_subrout_name, loc);
}
current_function_dependencies.insert(matched_subrout_name);
current_function_dependencies.push_back(al, s2c(al, matched_subrout_name));
ASRUtils::insert_module_dependency(a_name, al, current_module_dependencies);
asr = ASR::make_SubroutineCall_t(al, loc, a_name, sym,
a_args.p, 2, nullptr);
Expand All @@ -603,8 +603,8 @@ void process_overloaded_assignment_function(ASR::symbol_t* proc, ASR::expr_t* ta
bool use_overloaded_assignment(ASR::expr_t* target, ASR::expr_t* value,
SymbolTable* curr_scope, ASR::asr_t*& asr,
Allocator &al, const Location& loc,
std::set<std::string>& current_function_dependencies,
Vec<char*>& current_module_dependencies,
SetChar& current_function_dependencies,
SetChar& current_module_dependencies,
const std::function<void (const std::string &, const Location &)> err) {
ASR::ttype_t *target_type = ASRUtils::expr_type(target);
ASR::ttype_t *value_type = ASRUtils::expr_type(value);
Expand Down Expand Up @@ -659,8 +659,8 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
ASR::cmpopType op, std::string& intrinsic_op_name,
SymbolTable* curr_scope, ASR::asr_t*& asr,
Allocator &al, const Location& loc,
std::set<std::string>& current_function_dependencies,
Vec<char*>& current_module_dependencies,
SetChar& current_function_dependencies,
SetChar& current_module_dependencies,
const std::function<void (const std::string &, const Location &)> err) {
ASR::ttype_t *left_type = ASRUtils::expr_type(left);
ASR::ttype_t *right_type = ASRUtils::expr_type(right);
Expand Down Expand Up @@ -731,7 +731,7 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
} else {
return_type = ASRUtils::expr_type(func->m_return_var);
}
current_function_dependencies.insert(matched_func_name);
current_function_dependencies.push_back(al, s2c(al, matched_func_name));
ASRUtils::insert_module_dependency(a_name, al, current_module_dependencies);
asr = ASR::make_FunctionCall_t(al, loc, a_name, sym,
a_args.p, 2,
Expand Down
29 changes: 14 additions & 15 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1339,8 +1339,8 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
ASR::binopType op, std::string& intrinsic_op_name,
SymbolTable* curr_scope, ASR::asr_t*& asr,
Allocator &al, const Location& loc,
std::set<std::string>& current_function_dependencies,
Vec<char*>& current_module_dependencies,
SetChar& current_function_dependencies,
SetChar& current_module_dependencies,
const std::function<void (const std::string &, const Location &)> err);

bool is_op_overloaded(ASR::binopType op, std::string& intrinsic_op_name,
Expand All @@ -1350,8 +1350,8 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
ASR::cmpopType op, std::string& intrinsic_op_name,
SymbolTable* curr_scope, ASR::asr_t*& asr,
Allocator &al, const Location& loc,
std::set<std::string>& current_function_dependencies,
Vec<char*>& current_module_dependencies,
SetChar& current_function_dependencies,
SetChar& current_module_dependencies,
const std::function<void (const std::string &, const Location &)> err);

bool is_op_overloaded(ASR::cmpopType op, std::string& intrinsic_op_name,
Expand All @@ -1360,8 +1360,8 @@ bool is_op_overloaded(ASR::cmpopType op, std::string& intrinsic_op_name,
bool use_overloaded_assignment(ASR::expr_t* target, ASR::expr_t* value,
SymbolTable* curr_scope, ASR::asr_t*& asr,
Allocator &al, const Location& loc,
std::set<std::string>& current_function_dependencies,
Vec<char*>& /*current_module_dependencies*/,
SetChar& current_function_dependencies,
SetChar& /*current_module_dependencies*/,
const std::function<void (const std::string &, const Location &)> err);

void set_intrinsic(ASR::symbol_t* sym);
Expand Down Expand Up @@ -2351,16 +2351,15 @@ static inline bool is_dimension_empty(ASR::dimension_t* dims, size_t n) {
}

static inline void insert_module_dependency(ASR::symbol_t* a,
Allocator& al, Vec<char*>& module_dependencies) {
Allocator& al, SetChar& module_dependencies) {
if( ASR::is_a<ASR::ExternalSymbol_t>(*a) ) {
ASR::ExternalSymbol_t* a_ext = ASR::down_cast<ASR::ExternalSymbol_t>(a);
ASR::symbol_t* a_sym_module = ASRUtils::get_asr_owner(a_ext->m_external);
if( a_sym_module ) {
while( a_sym_module && !ASR::is_a<ASR::Module_t>(*a_sym_module) ) {
a_sym_module = ASRUtils::get_asr_owner(a_sym_module);
}
if( a_sym_module && !LCompilers::present(module_dependencies,
ASRUtils::symbol_name(a_sym_module)) ) {
if( a_sym_module ) {
module_dependencies.push_back(al, ASRUtils::symbol_name(a_sym_module));
}
}
Expand Down Expand Up @@ -2494,13 +2493,13 @@ class ReplaceArgVisitor: public ASR::BaseExprReplacer<ReplaceArgVisitor> {

Vec<ASR::call_arg_t>& orig_args;

std::set<std::string>& current_function_dependencies;
SetChar& current_function_dependencies;

public:

ReplaceArgVisitor(Allocator& al_, SymbolTable* current_scope_,
ASR::Function_t* orig_func_, Vec<ASR::call_arg_t>& orig_args_,
std::set<std::string>& current_function_dependencies_) :
SetChar& current_function_dependencies_) :
al(al_), current_scope(current_scope_), orig_func(orig_func_),
orig_args(orig_args_), current_function_dependencies(current_function_dependencies_)
{}
Expand Down Expand Up @@ -2572,7 +2571,7 @@ class ReplaceArgVisitor: public ASR::BaseExprReplacer<ReplaceArgVisitor> {
default:
break;
}
current_function_dependencies.insert(std::string(ASRUtils::symbol_name(new_es)));
current_function_dependencies.push_back(al, ASRUtils::symbol_name(new_es));
x->m_name = new_es;
}

Expand Down Expand Up @@ -3236,11 +3235,11 @@ class CollectIdentifiersFromASRExpression: public ASR::BaseWalkVisitor<CollectId
private:

Allocator& al;
Vec<char*>& identifiers;
SetChar& identifiers;

public:

CollectIdentifiersFromASRExpression(Allocator& al_, Vec<char*>& identifiers_) :
CollectIdentifiersFromASRExpression(Allocator& al_, SetChar& identifiers_) :
al(al_), identifiers(identifiers_)
{}

Expand All @@ -3249,7 +3248,7 @@ class CollectIdentifiersFromASRExpression: public ASR::BaseWalkVisitor<CollectId
}
};

static inline void collect_variable_dependencies(Allocator& al, Vec<char*>& deps_vec,
static inline void collect_variable_dependencies(Allocator& al, SetChar& deps_vec,
ASR::ttype_t* type=nullptr, ASR::expr_t* init_expr=nullptr,
ASR::expr_t* value=nullptr) {
ASRUtils::CollectIdentifiersFromASRExpression collector(al, deps_vec);
Expand Down
29 changes: 29 additions & 0 deletions src/libasr/asr_verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>

// Requires the condition `cond` to be true. Raise an exception otherwise.
#define require(cond, error_msg) require_impl((cond), (error_msg), x.base.base.loc)
#define require_with_loc(cond, error_msg, loc) require_impl((cond), (error_msg), loc)
void require_impl(bool cond, const std::string &error_msg, const Location &loc) {
if (!cond) {
diagnostics.message_label("ASR verify: " + error_msg,
Expand Down Expand Up @@ -221,6 +222,20 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
current_symtab = parent_symtab;
}

void verify_unique_dependencies(char** m_dependencies,
size_t n_dependencies, std::string m_name, const Location& loc) {
// Check if any dependency is duplicated
// in the dependency list of the function
std::set<std::string> dependencies_set;
for( size_t i = 0; i < n_dependencies; i++ ) {
std::string found_dep = m_dependencies[i];
require_with_loc(dependencies_set.find(found_dep) == dependencies_set.end(),
"Symbol " + found_dep + " is duplicated in the dependency "
"list of " + m_name, loc);
dependencies_set.insert(found_dep);
}
}

void visit_Module(const Module_t &x) {
module_dependencies.clear();
module_dependencies.reserve(x.n_dependencies);
Expand All @@ -242,6 +257,10 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
for (auto &a : x.m_symtab->get_scope()) {
this->visit_symbol(*a.second);
}

verify_unique_dependencies(x.m_dependencies, x.n_dependencies,
x.m_name, x.base.base.loc);

for (size_t i=0; i < x.n_dependencies; i++) {
require(x.m_dependencies[i] != nullptr,
"A module dependency must not be a nullptr");
Expand Down Expand Up @@ -351,6 +370,10 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
if (x.m_return_var) {
visit_expr(*x.m_return_var);
}

verify_unique_dependencies(x.m_dependencies, x.n_dependencies,
x.m_name, x.base.base.loc);

// Check if there are unnecessary dependencies
// present in the dependency list of the function
for( size_t i = 0; i < x.n_dependencies; i++ ) {
Expand Down Expand Up @@ -426,6 +449,9 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
std::string(x.m_dependencies[i]) + " is not a dependency of " + std::string(x.m_name)
+ " but it is present in its dependency list.");
}

verify_unique_dependencies(x.m_dependencies, x.n_dependencies,
x.m_name, x.base.base.loc);
current_symtab = parent_symtab;
}

Expand Down Expand Up @@ -538,6 +564,9 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
visit_expr(*x.m_symbolic_value);
visit_ttype(*x.m_type);

verify_unique_dependencies(x.m_dependencies, x.n_dependencies,
x.m_name, x.base.base.loc);

// Verify dependencies
for( size_t i = 0; i < x.n_dependencies; i++ ) {
require(std::find(
Expand Down
92 changes: 92 additions & 0 deletions src/libasr/containers.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,49 @@ struct Vec {
#endif
}

template <class Q = T>
typename std::enable_if<std::is_same<Q, char*>::value, bool>::type present(Q x, size_t& index) {
for( size_t i = 0; i < n; i++ ) {
if( strcmp(p[i], x) == 0 ) {
index = i;
return true;
}
}
return false;
}

template <class Q = T>
typename std::enable_if<!std::is_same<Q, char*>::value, bool>::type present(Q x, size_t& index) {
for( size_t i = 0; i < n; i++ ) {
if( p[i] == x ) {
index = i;
return true;
}
}
return false;
}

void erase(T x) {
size_t delete_index;
if( !present(x, delete_index) ) {
return ;
}

for( int64_t i = delete_index; i < (int64_t) n - 1; i++ ) {
p[i] = p[i + 1];
}
if( n >= 1 ) {
n = n - 1;
}
}

void push_back_unique(Allocator &al, T x) {
size_t index;
if( !Vec<T>::present(x, index) ) {
Vec<T>::push_back(al, x);
}
}

void push_back(Allocator &al, T x) {
// This can pass by accident even if reserve() is not called (if
// reserve_called happens to be equal to vec_called_const when Vec is
Expand Down Expand Up @@ -135,6 +178,55 @@ struct Vec {
static_assert(std::is_standard_layout<Vec<int>>::value);
static_assert(std::is_trivial<Vec<int>>::value);

/*
SetChar emulates the std::set<std::string> API
so that it acts as a drop in replacement.
*/
struct SetChar: Vec<char*> {

bool reserved;

SetChar():
reserved(false) {
clear();
}

void clear() {
n = 0;
p = nullptr;
max = 0;
}

void clear(Allocator& al) {
reserve(al, 0);
}

void reserve(Allocator& al, size_t max) {
Vec<char*>::reserve(al, max);
reserved = true;
}

void from_pointer_n_copy(Allocator &al, char** p, size_t n) {
reserve(al, n);
for (size_t i = 0; i < n; i++) {
push_back(al, p[i]);
}
}

void from_pointer_n(char** p, size_t n) {
Vec<char*>::from_pointer_n(p, n);
reserved = true;
}

void push_back(Allocator &al, char* x) {
if( !reserved ) {
reserve(al, 0);
}

Vec<char*>::push_back_unique(al, x);
}
};

// String implementation (not null-terminated)
struct Str {
size_t n;
Expand Down
Loading

0 comments on commit ae6c547

Please sign in to comment.