Skip to content

Commit

Permalink
Merge pull request lcompilers#1384 from Thirumalai-Shaktivel/global_s…
Browse files Browse the repository at this point in the history
…tmts

Initial support for global Lists
  • Loading branch information
certik authored Mar 15, 2023
2 parents 9da3fce + f86901d commit 78c8347
Show file tree
Hide file tree
Showing 11 changed files with 83 additions and 13 deletions.
3 changes: 2 additions & 1 deletion integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -432,4 +432,5 @@ RUN(NAME bit_operations_i32 LABELS cpython llvm wasm wasm_x64)
RUN(NAME bit_operations_i64 LABELS cpython llvm wasm)

RUN(NAME test_argv_01 LABELS llvm) # TODO: Test using CPython
RUN(NAME global_syms_01 LABELS cpython)
RUN(NAME global_syms_01 LABELS cpython llvm)
RUN(NAME global_syms_02 LABELS cpython llvm)
23 changes: 23 additions & 0 deletions integration_tests/global_syms_02.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from ltypes import i32

x: list[i32]
x = [0, 1]
x.append(3)

def test_global_symbols():
assert len(x) == 3
x.insert(2, 2)

test_global_symbols()

i: i32
for i in range(len(x)):
assert i == x[i]

tmp: list[i32]
tmp = x

tmp.remove(0)
assert len(tmp) == 3
tmp.clear()
assert len(tmp) == 0
18 changes: 17 additions & 1 deletion src/libasr/asr_scopes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,11 @@ std::string SymbolTable::get_unique_name(const std::string &name) {

void SymbolTable::move_symbols_from_global_scope(Allocator &al,
SymbolTable *module_scope, Vec<char *> &syms,
Vec<char *> &mod_dependencies) {
Vec<char *> &mod_dependencies, Vec<ASR::stmt_t*> &var_init) {
// TODO: This isn't scalable. We have write a visitor in asdl_cpp.py
syms.reserve(al, 4);
mod_dependencies.reserve(al, 4);
var_init.reserve(al, 4);
for (auto &a : scope) {
switch (a.second->type) {
case (ASR::symbolType::Module): {
Expand Down Expand Up @@ -225,6 +226,21 @@ void SymbolTable::move_symbols_from_global_scope(Allocator &al,
} case (ASR::symbolType::Variable) : {
ASR::Variable_t *v = ASR::down_cast<ASR::Variable_t>(a.second);
v->m_parent_symtab = module_scope;
if (v->m_symbolic_value && !ASR::is_a<ASR::Const_t>(*v->m_type)) {
ASR::expr_t* v_expr = ASRUtils::EXPR(ASR::make_Var_t(
al, v->base.base.loc, (ASR::symbol_t *) v));
ASR::asr_t* assign = ASR::make_Assignment_t(al,
v->base.base.loc, v_expr, v->m_symbolic_value, nullptr);
var_init.push_back(al, ASRUtils::STMT(assign));
v->m_symbolic_value = nullptr;
v->m_value = nullptr;
Vec<char*> v_dependencies;
v_dependencies.reserve(al, 1);
ASRUtils::collect_variable_dependencies(al,
v_dependencies, v->m_type);
v->m_dependencies = v_dependencies.p;
v->n_dependencies = v_dependencies.size();
}
module_scope->add_symbol(a.first, (ASR::symbol_t *) v);
syms.push_back(al, s2c(al, a.first));
break;
Expand Down
3 changes: 2 additions & 1 deletion src/libasr/asr_scopes.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace LCompilers {

namespace ASR {
struct asr_t;
struct stmt_t;
struct symbol_t;
}

Expand Down Expand Up @@ -84,7 +85,7 @@ struct SymbolTable {

void move_symbols_from_global_scope(Allocator &al,
SymbolTable *module_scope, Vec<char *> &syms,
Vec<char *> &mod_dependencies);
Vec<char *> &mod_dependencies, Vec<ASR::stmt_t*> &var_init);
};

} // namespace LCompilers
Expand Down
10 changes: 9 additions & 1 deletion src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2296,6 +2296,14 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
}
llvm_symtab[h] = ptr;
} else if (x.m_type->type == ASR::ttypeType::List) {
llvm::StructType* list_type = static_cast<llvm::StructType*>(
get_type_from_ttype_t_util(x.m_type));
llvm::Constant *ptr = module->getOrInsertGlobal(x.m_name, list_type);
module->getNamedGlobal(x.m_name)->setInitializer(
llvm::ConstantStruct::get(list_type,
llvm::Constant::getNullValue(list_type)));
llvm_symtab[h] = ptr;
} else if (x.m_type->type == ASR::ttypeType::TypeParameter) {
// Ignore type variables
} else {
Expand Down Expand Up @@ -7015,7 +7023,7 @@ Result<std::unique_ptr<LLVMModule>> asr_to_llvm(ASR::TranslationUnit_t &asr,
pass_manager.apply_passes(al, &asr, pass_options, diagnostics);

// Uncomment for debugging the ASR after the transformation
// std::cout << pickle(asr, true, true, true) << std::endl;
// std::cout << LPython::pickle(asr, true, true, true) << std::endl;

v.nested_func_types = pass_find_nested_vars(asr, context,
v.nested_globals, v.nested_call_out, v.nesting_map);
Expand Down
17 changes: 15 additions & 2 deletions src/libasr/pass/global_symbols.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,35 @@ namespace LCompilers {

void pass_wrap_global_syms_into_module(Allocator &al,
ASR::TranslationUnit_t &unit,
const LCompilers::PassOptions& /*pass_options*/) {
const LCompilers::PassOptions& pass_options) {
Location loc = unit.base.base.loc;
char *module_name = s2c(al, "_global_symbols");
SymbolTable *module_scope = al.make_new<SymbolTable>(unit.m_global_scope);
Vec<char *> moved_symbols;
Vec<char *> mod_dependencies;
Vec<ASR::stmt_t*> var_init;

// Move all the symbols from global into the module scope
unit.m_global_scope->move_symbols_from_global_scope(al, module_scope,
moved_symbols, mod_dependencies);
moved_symbols, mod_dependencies, var_init);

// Erase the symbols that are moved into the module
for (auto &sym: moved_symbols) {
unit.m_global_scope->erase_symbol(sym);
}

if (module_scope->get_symbol(pass_options.run_fun) && var_init.n > 0) {
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(
module_scope->get_symbol(pass_options.run_fun));
for (size_t i = 0; i < f->n_body; i++) {
var_init.push_back(al, f->m_body[i]);
}
f->m_body = var_init.p;
f->n_body = var_init.n;
// Overwrites the function: `_lpython_main_program`
module_scope->add_symbol(f->m_name, (ASR::symbol_t *) f);
}

Vec<char *> m_dependencies;
m_dependencies.reserve(al, mod_dependencies.size());
for( auto &dep: mod_dependencies) {
Expand Down
4 changes: 4 additions & 0 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3800,6 +3800,10 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
void visit_For(const AST::For_t &/*x*/) {
// We skip this in the SymbolTable visitor, but visit it in the BodyVisitor
}

void visit_Assert(const AST::Assert_t &/*x*/) {
// We skip this in the SymbolTable visitor, but visit it in the BodyVisitor
}
};

Result<ASR::asr_t*> symbol_table_visitor(Allocator &al, LocationManager &lm, const AST::Module_t &ast,
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-expr_07-7742668.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-expr_07-7742668.stdout",
"stdout_hash": "36af7cdd0a8bed977355197e5fb0512e1b23f1b36966eef5e83ef244",
"stdout_hash": "e419602ad57368da314d299a740459bc2fd912e9302ef68207e70105",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-expr_07-7742668.stdout
Original file line number Diff line number Diff line change
@@ -1 +1 @@
(TranslationUnit (SymbolTable 1 {_global_symbols: (Module (SymbolTable 7 {_lpython_main_program: (Function (SymbolTable 6 {}) _lpython_main_program (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [f bool_to_str] [] [(SubroutineCall 7 f () [] ()) (SubroutineCall 7 bool_to_str () [] ())] () Public .false. .false.), bool_to_str: (Function (SymbolTable 4 {var: (Variable 4 var [] Local () () Default (Logical 4 []) Source Public Required .false.)}) bool_to_str (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [] [] [(= (Var 4 var) (LogicalConstant .true. (Logical 4 [])) ()) (Print () [(Cast (LogicalConstant .true. (Logical 4 [])) LogicalToCharacter (Character 1 -2 () []) (StringConstant "True" (Character 1 4 () [])))] () ()) (Assert (StringCompare (Cast (Var 4 var) LogicalToCharacter (Character 1 -2 () []) ()) Eq (StringConstant "True" (Character 1 4 () [])) (Logical 4 []) ()) ()) (= (Var 4 var) (LogicalConstant .false. (Logical 4 [])) ()) (Assert (StringCompare (Cast (Var 4 var) LogicalToCharacter (Character 1 -2 () []) ()) Eq (StringConstant "False" (Character 1 5 () [])) (Logical 4 []) ()) ()) (Assert (StringCompare (Cast (LogicalConstant .true. (Logical 4 [])) LogicalToCharacter (Character 1 -2 () []) (StringConstant "True" (Character 1 4 () []))) Eq (StringConstant "True" (Character 1 4 () [])) (Logical 4 []) (LogicalConstant .true. (Logical 4 []))) ())] () Public .false. .false.), f: (Function (SymbolTable 3 {a: (Variable 3 a [] Local (IntegerConstant 5 (Integer 4 [])) () Default (Integer 4 []) Source Public Required .false.), b: (Variable 3 b [x] Local (IntegerBinOp (Var 3 x) Add (IntegerConstant 1 (Integer 4 [])) (Integer 4 []) ()) () Default (Integer 4 []) Source Public Required .false.), x: (Variable 3 x [] Local (IntegerConstant 3 (Integer 4 [])) () Default (Integer 4 []) Source Public Required .false.)}) f (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [g] [] [(= (Var 3 a) (IntegerConstant 5 (Integer 4 [])) ()) (= (Var 3 x) (IntegerConstant 3 (Integer 4 [])) ()) (= (Var 3 x) (IntegerConstant 5 (Integer 4 [])) ()) (= (Var 3 b) (IntegerBinOp (Var 3 x) Add (IntegerConstant 1 (Integer 4 [])) (Integer 4 []) ()) ()) (Print () [(Var 3 a) (Var 3 b)] () ()) (Assert (IntegerCompare (Var 3 b) Eq (IntegerConstant 6 (Integer 4 [])) (Logical 4 []) ()) ()) (SubroutineCall 7 g () [((IntegerBinOp (IntegerBinOp (Var 3 a) Mul (Var 3 b) (Integer 4 []) ()) Add (IntegerConstant 3 (Integer 4 [])) (Integer 4 []) ()))] ())] () Public .false. .false.), g: (Function (SymbolTable 2 {x: (Variable 2 x [] In () () Default (Integer 4 []) Source Public Required .false.)}) g (FunctionType [(Integer 4 [])] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [] [(Var 2 x)] [(Print () [(Var 2 x)] () ())] () Public .false. .false.), x: (Variable 7 x [] Local (IntegerConstant 7 (Integer 4 [])) () Default (Integer 4 []) Source Public Required .false.)}) _global_symbols [] .false. .false.), main_program: (Program (SymbolTable 5 {_lpython_main_program: (ExternalSymbol 5 _lpython_main_program 7 _lpython_main_program _global_symbols [] _lpython_main_program Public)}) main_program [_global_symbols] [(SubroutineCall 5 _lpython_main_program () [] ())])}) [])
(TranslationUnit (SymbolTable 1 {_global_symbols: (Module (SymbolTable 7 {_lpython_main_program: (Function (SymbolTable 6 {}) _lpython_main_program (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [f bool_to_str] [] [(= (Var 7 x) (IntegerConstant 7 (Integer 4 [])) ()) (SubroutineCall 7 f () [] ()) (SubroutineCall 7 bool_to_str () [] ())] () Public .false. .false.), bool_to_str: (Function (SymbolTable 4 {var: (Variable 4 var [] Local () () Default (Logical 4 []) Source Public Required .false.)}) bool_to_str (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [] [] [(= (Var 4 var) (LogicalConstant .true. (Logical 4 [])) ()) (Print () [(Cast (LogicalConstant .true. (Logical 4 [])) LogicalToCharacter (Character 1 -2 () []) (StringConstant "True" (Character 1 4 () [])))] () ()) (Assert (StringCompare (Cast (Var 4 var) LogicalToCharacter (Character 1 -2 () []) ()) Eq (StringConstant "True" (Character 1 4 () [])) (Logical 4 []) ()) ()) (= (Var 4 var) (LogicalConstant .false. (Logical 4 [])) ()) (Assert (StringCompare (Cast (Var 4 var) LogicalToCharacter (Character 1 -2 () []) ()) Eq (StringConstant "False" (Character 1 5 () [])) (Logical 4 []) ()) ()) (Assert (StringCompare (Cast (LogicalConstant .true. (Logical 4 [])) LogicalToCharacter (Character 1 -2 () []) (StringConstant "True" (Character 1 4 () []))) Eq (StringConstant "True" (Character 1 4 () [])) (Logical 4 []) (LogicalConstant .true. (Logical 4 []))) ())] () Public .false. .false.), f: (Function (SymbolTable 3 {a: (Variable 3 a [] Local (IntegerConstant 5 (Integer 4 [])) () Default (Integer 4 []) Source Public Required .false.), b: (Variable 3 b [x] Local (IntegerBinOp (Var 3 x) Add (IntegerConstant 1 (Integer 4 [])) (Integer 4 []) ()) () Default (Integer 4 []) Source Public Required .false.), x: (Variable 3 x [] Local (IntegerConstant 3 (Integer 4 [])) () Default (Integer 4 []) Source Public Required .false.)}) f (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [g] [] [(= (Var 3 a) (IntegerConstant 5 (Integer 4 [])) ()) (= (Var 3 x) (IntegerConstant 3 (Integer 4 [])) ()) (= (Var 3 x) (IntegerConstant 5 (Integer 4 [])) ()) (= (Var 3 b) (IntegerBinOp (Var 3 x) Add (IntegerConstant 1 (Integer 4 [])) (Integer 4 []) ()) ()) (Print () [(Var 3 a) (Var 3 b)] () ()) (Assert (IntegerCompare (Var 3 b) Eq (IntegerConstant 6 (Integer 4 [])) (Logical 4 []) ()) ()) (SubroutineCall 7 g () [((IntegerBinOp (IntegerBinOp (Var 3 a) Mul (Var 3 b) (Integer 4 []) ()) Add (IntegerConstant 3 (Integer 4 [])) (Integer 4 []) ()))] ())] () Public .false. .false.), g: (Function (SymbolTable 2 {x: (Variable 2 x [] In () () Default (Integer 4 []) Source Public Required .false.)}) g (FunctionType [(Integer 4 [])] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [] [(Var 2 x)] [(Print () [(Var 2 x)] () ())] () Public .false. .false.), x: (Variable 7 x [] Local () () Default (Integer 4 []) Source Public Required .false.)}) _global_symbols [] .false. .false.), main_program: (Program (SymbolTable 5 {_lpython_main_program: (ExternalSymbol 5 _lpython_main_program 7 _lpython_main_program _global_symbols [] _lpython_main_program Public)}) main_program [_global_symbols] [(SubroutineCall 5 _lpython_main_program () [] ())])}) [])
2 changes: 1 addition & 1 deletion tests/reference/llvm-print_04-443a8d8.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "llvm-print_04-443a8d8.stdout",
"stdout_hash": "82ed5155deef7a5b597b62a427ca9c33a1dd6650f757223c5c604889",
"stdout_hash": "740498e6d0b9c0a6ffc7f6711e1f8c5bae27b834b726d108fb0e0c18",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
12 changes: 8 additions & 4 deletions tests/reference/llvm-print_04-443a8d8.stdout
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
; ModuleID = 'LFortran'
source_filename = "LFortran"

@u = global i64 -922337203685477580
@x = global i32 -2147483648
@y = global i16 -32768
@z = global i8 -128
@u = global i64 0
@x = global i32 0
@y = global i16 0
@z = global i8 0
@0 = private unnamed_addr constant [2 x i8] c" \00", align 1
@1 = private unnamed_addr constant [2 x i8] c"\0A\00", align 1
@2 = private unnamed_addr constant [7 x i8] c"%lld%s\00", align 1
Expand All @@ -20,6 +20,10 @@ source_filename = "LFortran"

define void @__module__global_symbols__lpython_main_program() {
.entry:
store i64 -922337203685477580, i64* @u, align 4
store i32 -2147483648, i32* @x, align 4
store i16 -32768, i16* @y, align 2
store i8 -128, i8* @z, align 1
%0 = load i64, i64* @u, align 4
call void (i8*, ...) @_lfortran_printf(i8* getelementptr inbounds ([7 x i8], [7 x i8]* @2, i32 0, i32 0), i64 %0, i8* getelementptr inbounds ([2 x i8], [2 x i8]* @1, i32 0, i32 0))
%1 = load i32, i32* @x, align 4
Expand Down

0 comments on commit 78c8347

Please sign in to comment.