Skip to content

Commit

Permalink
Implement global struct to hold needed values for nested functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dpoerio committed Apr 16, 2021
1 parent dcbbacb commit 8b82c33
Show file tree
Hide file tree
Showing 16 changed files with 543 additions and 7 deletions.
2 changes: 2 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -252,5 +252,7 @@ RUN(NAME string3 LABELS gfortran llvm)

RUN(NAME nested_01 LABELS gfortran llvm)
RUN(NAME nested_02 LABELS gfortran llvm)
RUN(NAME nested_03 LABELS gfortran llvm)
RUN(NAME nested_04 LABELS gfortran llvm)

RUN(NAME intent_01 LABELS gfortran llvm)
25 changes: 25 additions & 0 deletions integration_tests/nested_03.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module nested_03_a
implicit none

contains

subroutine b()
real :: x = 6
print *, "b()"
call c()
contains
subroutine c()
print *, 5
print *, x
end subroutine c
end subroutine b

end module

program nested_03
use nested_03_a, only: b
implicit none

call b()

end
31 changes: 31 additions & 0 deletions integration_tests/nested_04.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
module nested_04_a
implicit none

contains

integer function b(x)
integer, intent(in) :: x
integer y
real :: yy = 6.6
y = x
print *, "b()"
b = c(6)
contains
integer function c(z)
integer, intent(in) :: z
print *, z
print *, y
print *, yy
c = z
end function c
end function b

end module

program nested_04
use nested_04_a, only: b
implicit none
integer test
test = b(5)

end
1 change: 1 addition & 0 deletions src/lfortran/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ if (WITH_LLVM)
set(SRC ${SRC}
codegen/evaluator.cpp
codegen/asr_to_llvm.cpp
pass/nested_vars.cpp
)
# We use deprecated API in LLVM, so we disable the warning until we upgrade
if (NOT MSVC)
Expand Down
110 changes: 103 additions & 7 deletions src/lfortran/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include <lfortran/pass/select_case.h>
#include <lfortran/pass/global_stmts.h>
#include <lfortran/pass/param_to_const.h>
#include <lfortran/pass/nested_vars.h>
#include <lfortran/exception.h>
#include <lfortran/asr_utils.h>
#include <lfortran/pickle.h>
Expand Down Expand Up @@ -135,6 +136,18 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
std::map<uint64_t, llvm::Value*> llvm_symtab; // llvm_symtab_value
std::map<uint64_t, llvm::Function*> llvm_symtab_fn;

// Data members for handling nested functions
std::vector<uint64_t> needed_globals; /* For saving the hash of variables
from a parent scope needed in a nested function */
std::map<uint64_t, std::vector<llvm::Type*>> runtime_descriptor; /* For
saving the hash of a parent function needing to give access to
variables in a nested function, as well as the variable types */
llvm::StructType* needed_global_struct; /*The struct type that will hold
variables needed in a nested function; will contain types as given in
the runtime descriptor member */
std::string desc_name; // For setting the name of the global struct


ASRToLLVMVisitor(llvm::LLVMContext &context) : context(context),
prototype_only(false), dim_des(llvm::StructType::create(
context,
Expand Down Expand Up @@ -483,6 +496,13 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
return builder->CreateLoad(pres);
}

llvm::Value *nested_struct_rd(std::vector<llvm::Value*> vals,
llvm::StructType* rd) {
llvm::AllocaInst *pres = builder->CreateAlloca(rd, nullptr);
llvm::Value *pim = builder->CreateGEP(pres, vals);
return builder->CreateLoad(pim);
}

/**
* @brief This function generates the
* @detail This is converted to
Expand Down Expand Up @@ -758,6 +778,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>

template<typename T>
void declare_vars(const T &x) {
std::vector<llvm::Value*> needed_glob_vals;
llvm::Value *target_var;
for (auto &item : x.m_symtab->scope) {
if (is_a<ASR::Variable_t>(*item.second)) {
ASR::Variable_t *v = down_cast<ASR::Variable_t>(item.second);
Expand Down Expand Up @@ -835,10 +857,22 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm_symtab[h] = ptr;
fill_array_details(ptr, m_dims, n_dims);
if( v->m_value != nullptr ) {
llvm::Value *target_var = ptr;
target_var = ptr;
this->visit_expr_wrapper(v->m_value, true);
llvm::Value *init_value = tmp;
needed_glob_vals.push_back(tmp);
builder->CreateStore(init_value, target_var);

auto finder = std::find(needed_globals.begin(),
needed_globals.end(), h);
if (finder != needed_globals.end()) {
llvm::Value* ptr = module->getOrInsertGlobal(desc_name,
needed_global_struct);
int idx = std::distance(needed_globals.begin(),
finder);
builder->CreateStore(builder->CreateLoad(target_var),
create_gep(ptr, idx));
}
}
}
}
Expand Down Expand Up @@ -910,10 +944,25 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
x.m_abi != ASR::abiType::Interactive) {
return;
}
// Check if the procedure has a nested function that needs access to
// some variables in its local scope
uint32_t h = get_hash((ASR::asr_t*)&x);
std::vector<llvm::Type*> nested_type;
if (runtime_descriptor[h].size() > 0) {
nested_type = runtime_descriptor[h];
needed_global_struct = llvm::StructType::create(
context, nested_type, x.m_name);
desc_name = x.m_name;
std::string desc_string = "_rtd";
desc_name += desc_string;
module->getOrInsertGlobal(desc_name, needed_global_struct);
llvm::ConstantAggregateZero* initializer =
llvm::ConstantAggregateZero::get(needed_global_struct);
module->getNamedGlobal(desc_name)->setInitializer(initializer);
}
visit_procedures(x);
bool interactive = (x.m_abi == ASR::abiType::Interactive);

uint32_t h = get_hash((ASR::asr_t*)&x);
llvm::Function *F = nullptr;
if (llvm_symtab_fn.find(h) != llvm_symtab_fn.end()) {
/*
Expand Down Expand Up @@ -992,8 +1041,23 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
return;
}
bool interactive = (x.m_abi == ASR::abiType::Interactive);
visit_procedures(x);
// Check if the procedure has a nested function that needs access to
// some variables in its local scope
uint32_t h = get_hash((ASR::asr_t*)&x);
std::vector<llvm::Type*> nested_type;
if (runtime_descriptor[h].size() > 0) {
nested_type = runtime_descriptor[h];
needed_global_struct = llvm::StructType::create(
context, nested_type, x.m_name);
desc_name = x.m_name;
std::string desc_string = "_rtd";
desc_name += desc_string;
module->getOrInsertGlobal(desc_name, needed_global_struct);
llvm::ConstantAggregateZero* initializer =
llvm::ConstantAggregateZero::get(needed_global_struct);
module->getNamedGlobal(desc_name)->setInitializer(initializer);
}
visit_procedures(x);
llvm::Function *F = nullptr;
if (llvm_symtab_fn.find(h) != llvm_symtab_fn.end()) {
/*
Expand Down Expand Up @@ -1054,12 +1118,13 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>

void visit_Assignment(const ASR::Assignment_t &x) {
llvm::Value *target, *value;
uint32_t h;
if( x.m_target->type == ASR::exprType::ArrayRef ) {
this->visit_expr(*x.m_target);
target = tmp;
} else {
ASR::Variable_t *asr_target = EXPR2VAR(x.m_target);
uint32_t h = get_hash((ASR::asr_t*)asr_target);
h = get_hash((ASR::asr_t*)asr_target);
switch( asr_target->m_type->type ) {
case ASR::ttypeType::IntegerPointer:
case ASR::ttypeType::RealPointer:
Expand All @@ -1079,6 +1144,17 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
this->visit_expr_wrapper(x.m_value, true);
value = tmp;
builder->CreateStore(value, target);

auto finder = std::find(needed_globals.begin(),
needed_globals.end(), h);
if (finder != needed_globals.end()) {
llvm::Value* ptr = module->getOrInsertGlobal(desc_name,
needed_global_struct);
int idx = std::distance(needed_globals.begin(),
finder);
builder->CreateStore(builder->CreateLoad(target),
create_gep(ptr, idx));
}
}

inline void visit_expr_wrapper(const ASR::expr_t* x, bool load_array_ref=false) {
Expand Down Expand Up @@ -1542,9 +1618,26 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>

inline void fetch_val(ASR::Variable_t* x) {
uint32_t x_h = get_hash((ASR::asr_t*)x);
LFORTRAN_ASSERT(llvm_symtab.find(x_h) != llvm_symtab.end());
llvm::Value* x_v = llvm_symtab[x_h];
tmp = builder->CreateLoad(x_v);
llvm::Value* x_v;
// Check if x is a needed global here, if so, it should exist as an
// element in the runtime descriptor, get element pointer and create
// load
if (llvm_symtab.find(x_h) == llvm_symtab.end()) {
LFORTRAN_ASSERT(std::find(needed_globals.begin(),
needed_globals.end(), x_h) != needed_globals.end());
auto finder = std::find(needed_globals.begin(),
needed_globals.end(), x_h);
llvm::Constant *ptr = module->getOrInsertGlobal(desc_name,
needed_global_struct);
int idx = std::distance(needed_globals.begin(), finder);
std::vector<llvm::Value*> idx_vec = {
llvm::ConstantInt::get(context, llvm::APInt(32, 0)),
llvm::ConstantInt::get(context, llvm::APInt(32, idx))};
x_v = builder->CreateGEP(ptr, idx_vec);
} else {
x_v = llvm_symtab[x_h];
}
tmp = builder->CreateLoad(x_v);
}

inline void fetch_var(ASR::Variable_t* x) {
Expand Down Expand Up @@ -2064,8 +2157,11 @@ std::unique_ptr<LLVMModule> asr_to_llvm(ASR::TranslationUnit_t &asr,
pass_replace_param_to_const(al, asr);
// Uncomment for debugging the ASR after the transformation
// std::cout << pickle(asr) << std::endl;

pass_replace_do_loops(al, asr);
pass_replace_select_case(al, asr);
v.runtime_descriptor = pass_find_nested_vars(asr, context,
v.needed_globals);
v.visit_asr((ASR::asr_t&)asr);
std::string msg;
llvm::raw_string_ostream err(msg);
Expand Down
Loading

0 comments on commit 8b82c33

Please sign in to comment.