Skip to content

Commit

Permalink
Use SetChar inplace of std::set<std::string> and Vec<char*>
Browse files Browse the repository at this point in the history
  • Loading branch information
czgdp1807 committed Apr 12, 2023
1 parent 43bfbae commit 5699f80
Show file tree
Hide file tree
Showing 12 changed files with 122 additions and 123 deletions.
2 changes: 1 addition & 1 deletion src/libasr/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2178,7 +2178,7 @@ def visitProduct(self, prod, name):
if field.type == "identifier":
self.emit('{', 2)
self.emit('uint64_t n = self().read_int64();', 3)
self.emit("Vec<char*> v;", 3)
self.emit("Vec<char*> v_%s;" % (field.name), 3)
self.emit("v.reserve(al, n);", 3)
self.emit("for (uint64_t i=0; i<n; i++) {", 3)
self.emit("v.push_back(al, self().read_cstring());", 4)
Expand Down
10 changes: 3 additions & 7 deletions src/libasr/asr_scopes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ std::string SymbolTable::get_unique_name(const std::string &name) {

void SymbolTable::move_symbols_from_global_scope(Allocator &al,
SymbolTable *module_scope, Vec<char *> &syms,
Vec<char *> &mod_dependencies) {
SetChar &mod_dependencies) {
// TODO: This isn't scalable. We have write a visitor in asdl_cpp.py
syms.reserve(al, 4);
mod_dependencies.reserve(al, 4);
Expand Down Expand Up @@ -180,9 +180,7 @@ void SymbolTable::move_symbols_from_global_scope(Allocator &al,
if (s != nullptr && ASR::is_a<ASR::ExternalSymbol_t>(*s)) {
char *es_name = ASR::down_cast<
ASR::ExternalSymbol_t>(s)->m_module_name;
if (!present(mod_dependencies, es_name)) {
mod_dependencies.push_back(al, es_name);
}
mod_dependencies.push_back(al, es_name);
}
}
fn->m_symtab->parent = module_scope;
Expand All @@ -197,9 +195,7 @@ void SymbolTable::move_symbols_from_global_scope(Allocator &al,
break;
} case (ASR::symbolType::ExternalSymbol) : {
ASR::ExternalSymbol_t *es = ASR::down_cast<ASR::ExternalSymbol_t>(a.second);
if (!present(mod_dependencies, es->m_module_name)) {
mod_dependencies.push_back(al, es->m_module_name);
}
mod_dependencies.push_back(al, es->m_module_name);
es->m_parent_symtab = module_scope;
ASR::symbol_t *s = ASRUtils::symbol_get_past_external(a.second);
LCOMPILERS_ASSERT(s);
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/asr_scopes.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ struct SymbolTable {

void move_symbols_from_global_scope(Allocator &al,
SymbolTable *module_scope, Vec<char *> &syms,
Vec<char *> &mod_dependencies);
SetChar &mod_dependencies);
};

} // namespace LCompilers
Expand Down
22 changes: 11 additions & 11 deletions src/libasr/asr_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,8 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
ASR::binopType op, std::string& intrinsic_op_name,
SymbolTable* curr_scope, ASR::asr_t*& asr,
Allocator &al, const Location& loc,
std::set<std::string>& current_function_dependencies,
Vec<char*>& current_module_dependencies,
SetChar& current_function_dependencies,
SetChar& current_module_dependencies,
const std::function<void (const std::string &, const Location &)> err) {
ASR::ttype_t *left_type = ASRUtils::expr_type(left);
ASR::ttype_t *right_type = ASRUtils::expr_type(right);
Expand Down Expand Up @@ -478,7 +478,7 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
} else {
return_type = ASRUtils::expr_type(func->m_return_var);
}
current_function_dependencies.insert(matched_func_name);
current_function_dependencies.push_back(al, s2c(al, matched_func_name));
ASRUtils::insert_module_dependency(a_name, al, current_module_dependencies);
asr = ASR::make_FunctionCall_t(al, loc, a_name, sym,
a_args.p, 2,
Expand Down Expand Up @@ -544,8 +544,8 @@ bool is_op_overloaded(ASR::binopType op, std::string& intrinsic_op_name,

void process_overloaded_assignment_function(ASR::symbol_t* proc, ASR::expr_t* target, ASR::expr_t* value,
ASR::ttype_t* target_type, ASR::ttype_t* value_type, bool& found, Allocator& al, const Location& target_loc,
const Location& value_loc, SymbolTable* curr_scope, std::set<std::string>& current_function_dependencies,
Vec<char*>& current_module_dependencies, ASR::asr_t*& asr, ASR::symbol_t* sym, const Location& loc, ASR::expr_t* expr_dt,
const Location& value_loc, SymbolTable* curr_scope, SetChar& current_function_dependencies,
SetChar& current_module_dependencies, ASR::asr_t*& asr, ASR::symbol_t* sym, const Location& loc, ASR::expr_t* expr_dt,
const std::function<void (const std::string &, const Location &)> err, char* pass_arg=nullptr) {
ASR::Function_t* subrout = ASR::down_cast<ASR::Function_t>(proc);
std::string matched_subrout_name = "";
Expand Down Expand Up @@ -592,7 +592,7 @@ void process_overloaded_assignment_function(ASR::symbol_t* proc, ASR::expr_t* ta
if( a_name == nullptr ) {
err("Unable to resolve matched subroutine for assignment overloading, " + matched_subrout_name, loc);
}
current_function_dependencies.insert(matched_subrout_name);
current_function_dependencies.push_back(al, s2c(al, matched_subrout_name));
ASRUtils::insert_module_dependency(a_name, al, current_module_dependencies);
asr = ASR::make_SubroutineCall_t(al, loc, a_name, sym,
a_args.p, 2, nullptr);
Expand All @@ -603,8 +603,8 @@ void process_overloaded_assignment_function(ASR::symbol_t* proc, ASR::expr_t* ta
bool use_overloaded_assignment(ASR::expr_t* target, ASR::expr_t* value,
SymbolTable* curr_scope, ASR::asr_t*& asr,
Allocator &al, const Location& loc,
std::set<std::string>& current_function_dependencies,
Vec<char*>& current_module_dependencies,
SetChar& current_function_dependencies,
SetChar& current_module_dependencies,
const std::function<void (const std::string &, const Location &)> err) {
ASR::ttype_t *target_type = ASRUtils::expr_type(target);
ASR::ttype_t *value_type = ASRUtils::expr_type(value);
Expand Down Expand Up @@ -659,8 +659,8 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
ASR::cmpopType op, std::string& intrinsic_op_name,
SymbolTable* curr_scope, ASR::asr_t*& asr,
Allocator &al, const Location& loc,
std::set<std::string>& current_function_dependencies,
Vec<char*>& current_module_dependencies,
SetChar& current_function_dependencies,
SetChar& current_module_dependencies,
const std::function<void (const std::string &, const Location &)> err) {
ASR::ttype_t *left_type = ASRUtils::expr_type(left);
ASR::ttype_t *right_type = ASRUtils::expr_type(right);
Expand Down Expand Up @@ -731,7 +731,7 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
} else {
return_type = ASRUtils::expr_type(func->m_return_var);
}
current_function_dependencies.insert(matched_func_name);
current_function_dependencies.push_back(al, s2c(al, matched_func_name));
ASRUtils::insert_module_dependency(a_name, al, current_module_dependencies);
asr = ASR::make_FunctionCall_t(al, loc, a_name, sym,
a_args.p, 2,
Expand Down
29 changes: 14 additions & 15 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1339,8 +1339,8 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
ASR::binopType op, std::string& intrinsic_op_name,
SymbolTable* curr_scope, ASR::asr_t*& asr,
Allocator &al, const Location& loc,
std::set<std::string>& current_function_dependencies,
Vec<char*>& current_module_dependencies,
SetChar& current_function_dependencies,
SetChar& current_module_dependencies,
const std::function<void (const std::string &, const Location &)> err);

bool is_op_overloaded(ASR::binopType op, std::string& intrinsic_op_name,
Expand All @@ -1350,8 +1350,8 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
ASR::cmpopType op, std::string& intrinsic_op_name,
SymbolTable* curr_scope, ASR::asr_t*& asr,
Allocator &al, const Location& loc,
std::set<std::string>& current_function_dependencies,
Vec<char*>& current_module_dependencies,
SetChar& current_function_dependencies,
SetChar& current_module_dependencies,
const std::function<void (const std::string &, const Location &)> err);

bool is_op_overloaded(ASR::cmpopType op, std::string& intrinsic_op_name,
Expand All @@ -1360,8 +1360,8 @@ bool is_op_overloaded(ASR::cmpopType op, std::string& intrinsic_op_name,
bool use_overloaded_assignment(ASR::expr_t* target, ASR::expr_t* value,
SymbolTable* curr_scope, ASR::asr_t*& asr,
Allocator &al, const Location& loc,
std::set<std::string>& current_function_dependencies,
Vec<char*>& /*current_module_dependencies*/,
SetChar& current_function_dependencies,
SetChar& /*current_module_dependencies*/,
const std::function<void (const std::string &, const Location &)> err);

void set_intrinsic(ASR::symbol_t* sym);
Expand Down Expand Up @@ -2351,16 +2351,15 @@ static inline bool is_dimension_empty(ASR::dimension_t* dims, size_t n) {
}

static inline void insert_module_dependency(ASR::symbol_t* a,
Allocator& al, Vec<char*>& module_dependencies) {
Allocator& al, SetChar& module_dependencies) {
if( ASR::is_a<ASR::ExternalSymbol_t>(*a) ) {
ASR::ExternalSymbol_t* a_ext = ASR::down_cast<ASR::ExternalSymbol_t>(a);
ASR::symbol_t* a_sym_module = ASRUtils::get_asr_owner(a_ext->m_external);
if( a_sym_module ) {
while( a_sym_module && !ASR::is_a<ASR::Module_t>(*a_sym_module) ) {
a_sym_module = ASRUtils::get_asr_owner(a_sym_module);
}
if( a_sym_module && !LCompilers::present(module_dependencies,
ASRUtils::symbol_name(a_sym_module)) ) {
if( a_sym_module ) {
module_dependencies.push_back(al, ASRUtils::symbol_name(a_sym_module));
}
}
Expand Down Expand Up @@ -2494,13 +2493,13 @@ class ReplaceArgVisitor: public ASR::BaseExprReplacer<ReplaceArgVisitor> {

Vec<ASR::call_arg_t>& orig_args;

std::set<std::string>& current_function_dependencies;
SetChar& current_function_dependencies;

public:

ReplaceArgVisitor(Allocator& al_, SymbolTable* current_scope_,
ASR::Function_t* orig_func_, Vec<ASR::call_arg_t>& orig_args_,
std::set<std::string>& current_function_dependencies_) :
SetChar& current_function_dependencies_) :
al(al_), current_scope(current_scope_), orig_func(orig_func_),
orig_args(orig_args_), current_function_dependencies(current_function_dependencies_)
{}
Expand Down Expand Up @@ -2572,7 +2571,7 @@ class ReplaceArgVisitor: public ASR::BaseExprReplacer<ReplaceArgVisitor> {
default:
break;
}
current_function_dependencies.insert(std::string(ASRUtils::symbol_name(new_es)));
current_function_dependencies.push_back(al, ASRUtils::symbol_name(new_es));
x->m_name = new_es;
}

Expand Down Expand Up @@ -3236,11 +3235,11 @@ class CollectIdentifiersFromASRExpression: public ASR::BaseWalkVisitor<CollectId
private:

Allocator& al;
Vec<char*>& identifiers;
SetChar& identifiers;

public:

CollectIdentifiersFromASRExpression(Allocator& al_, Vec<char*>& identifiers_) :
CollectIdentifiersFromASRExpression(Allocator& al_, SetChar& identifiers_) :
al(al_), identifiers(identifiers_)
{}

Expand All @@ -3249,7 +3248,7 @@ class CollectIdentifiersFromASRExpression: public ASR::BaseWalkVisitor<CollectId
}
};

static inline void collect_variable_dependencies(Allocator& al, Vec<char*>& deps_vec,
static inline void collect_variable_dependencies(Allocator& al, SetChar& deps_vec,
ASR::ttype_t* type=nullptr, ASR::expr_t* init_expr=nullptr,
ASR::expr_t* value=nullptr) {
ASRUtils::CollectIdentifiersFromASRExpression collector(al, deps_vec);
Expand Down
29 changes: 29 additions & 0 deletions src/libasr/asr_verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>

// Requires the condition `cond` to be true. Raise an exception otherwise.
#define require(cond, error_msg) require_impl((cond), (error_msg), x.base.base.loc)
#define require_with_loc(cond, error_msg, loc) require_impl((cond), (error_msg), loc)
void require_impl(bool cond, const std::string &error_msg, const Location &loc) {
if (!cond) {
diagnostics.message_label("ASR verify: " + error_msg,
Expand Down Expand Up @@ -221,6 +222,20 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
current_symtab = parent_symtab;
}

void verify_unique_dependencies(char** m_dependencies,
size_t n_dependencies, std::string m_name, const Location& loc) {
// Check if any dependency is duplicated
// in the dependency list of the function
std::set<std::string> dependencies_set;
for( size_t i = 0; i < n_dependencies; i++ ) {
std::string found_dep = m_dependencies[i];
require_with_loc(dependencies_set.find(found_dep) == dependencies_set.end(),
"Symbol " + found_dep + " is duplicated in the dependency "
"list of " + m_name, loc);
dependencies_set.insert(found_dep);
}
}

void visit_Module(const Module_t &x) {
module_dependencies.clear();
module_dependencies.reserve(x.n_dependencies);
Expand All @@ -242,6 +257,10 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
for (auto &a : x.m_symtab->get_scope()) {
this->visit_symbol(*a.second);
}

verify_unique_dependencies(x.m_dependencies, x.n_dependencies,
x.m_name, x.base.base.loc);

for (size_t i=0; i < x.n_dependencies; i++) {
require(x.m_dependencies[i] != nullptr,
"A module dependency must not be a nullptr");
Expand Down Expand Up @@ -351,6 +370,10 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
if (x.m_return_var) {
visit_expr(*x.m_return_var);
}

verify_unique_dependencies(x.m_dependencies, x.n_dependencies,
x.m_name, x.base.base.loc);

// Check if there are unnecessary dependencies
// present in the dependency list of the function
for( size_t i = 0; i < x.n_dependencies; i++ ) {
Expand Down Expand Up @@ -426,6 +449,9 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
std::string(x.m_dependencies[i]) + " is not a dependency of " + std::string(x.m_name)
+ " but it is present in its dependency list.");
}

verify_unique_dependencies(x.m_dependencies, x.n_dependencies,
x.m_name, x.base.base.loc);
current_symtab = parent_symtab;
}

Expand Down Expand Up @@ -538,6 +564,9 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
visit_expr(*x.m_symbolic_value);
visit_ttype(*x.m_type);

verify_unique_dependencies(x.m_dependencies, x.n_dependencies,
x.m_name, x.base.base.loc);

// Verify dependencies
for( size_t i = 0; i < x.n_dependencies; i++ ) {
require(std::find(
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/pass/global_stmts_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ void pass_wrap_global_stmts_into_program(Allocator &al,
std::string prog_name = "main_program";
Vec<ASR::stmt_t*> prog_body;
prog_body.reserve(al, 1);
Vec<char *> prog_dep;
SetChar prog_dep;
prog_dep.reserve(al, 1);
if (unit.n_items > 0) {
pass_wrap_global_stmts_into_function(al, unit, pass_options);
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/pass/global_symbols.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ void pass_wrap_global_syms_into_module(Allocator &al,
char *module_name = s2c(al, "_global_symbols");
SymbolTable *module_scope = al.make_new<SymbolTable>(unit.m_global_scope);
Vec<char *> moved_symbols;
Vec<char *> mod_dependencies;
SetChar mod_dependencies;

// Move all the symbols from global into the module scope
unit.m_global_scope->move_symbols_from_global_scope(al, module_scope,
Expand Down
Loading

0 comments on commit 5699f80

Please sign in to comment.