Skip to content

Commit

Permalink
Merge branch 'llvm_classes' into 'master'
Browse files Browse the repository at this point in the history
Initial implementation of classes in LLVM backend

See merge request lfortran/lfortran!1079
  • Loading branch information
dpoerio committed Jul 19, 2021
2 parents 8561793 + 76aef43 commit 96c5ea8
Show file tree
Hide file tree
Showing 105 changed files with 319 additions and 113 deletions.
4 changes: 2 additions & 2 deletions grammar/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ stmt
| Return()
| Select(expr test, case_stmt* body, stmt* default)
| Stop(expr? code)
| SubroutineCall(symbol name, symbol? original_name, expr* args)
| SubroutineCall(symbol name, symbol? original_name, expr* args, expr? dt)
| Where(expr test, stmt* body, stmt* orelse)
| WhileLoop(expr test, stmt* body)

Expand All @@ -174,7 +174,7 @@ expr
| UnaryOp(unaryop op, expr operand, ttype type, expr? value)
| Compare(expr left, cmpop op, expr right, ttype type, expr? value)
| FunctionCall(symbol name, symbol? original_name, expr* args,
keyword* keywords, ttype type, expr? value)
keyword* keywords, ttype type, expr? value, expr? dt)
| ConstantArray(expr* args, ttype type)
| ImpliedDoLoop(expr* values, expr var, expr start, expr end,
expr? increment, ttype type, expr? value)
Expand Down
2 changes: 1 addition & 1 deletion integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -306,4 +306,4 @@ RUN(NAME return_02 LABELS gfortran llvm)
RUN(NAME return_03 LABELS gfortran llvm)

RUN(NAME class_01 LABELS gfortran)
RUN(NAME class_02 LABELS gfortran)
RUN(NAME class_02 LABELS gfortran llvm)
1 change: 1 addition & 0 deletions integration_tests/class_01.f90
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ module class_Circle
end type Circle
contains
function circle_area(this) result(area)
! F2003 standard 4.5.3.3 passed object dummy argument
class(Circle), intent(in) :: this
real :: area
area = pi * this%radius**2
Expand Down
6 changes: 3 additions & 3 deletions integration_tests/class_02.f90
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module class_Circle
module class_circle2
implicit none
private
real :: pi = 3.1415926535897931d0 ! Class-wide private constant
Expand All @@ -21,11 +21,11 @@ subroutine circle_print(this)
area = this%circle_area() ! Call the type-bound function
print *, 'Circle: r = ', this%radius, ' area = ', area
end subroutine circle_print
end module class_Circle
end module class_Circle2


program circle_test
use class_Circle, only: Circle
use class_circle2, only: Circle
implicit none

type(Circle) :: c ! Declare a variable of type Circle.
Expand Down
131 changes: 126 additions & 5 deletions src/lfortran/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,69 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}


llvm::Type* getClassType(ASR::ttype_t* _type, bool is_pointer=false) {
ASR::Class_t* der = (ASR::Class_t*)(&(_type->base));
ASR::symbol_t* der_sym;
if( der->m_class_type->type == ASR::symbolType::ExternalSymbol ) {
ASR::ExternalSymbol_t* der_extr = (ASR::ExternalSymbol_t*)(&(der->m_class_type->base));
der_sym = der_extr->m_external;
} else {
der_sym = der->m_class_type;
}
ASR::ClassType_t* der_type = (ASR::ClassType_t*)(&(der_sym->base));
std::string der_type_name = std::string(der_type->m_name);
llvm::StructType* der_type_llvm;
if( name2dertype.find(der_type_name) != name2dertype.end() ) {
der_type_llvm = name2dertype[der_type_name];
} else {
std::map<std::string, ASR::symbol_t*> scope = der_type->m_symtab->scope;
std::vector<llvm::Type*> member_types;
int member_idx = 0;
for( auto itr = scope.begin(); itr != scope.end(); itr++ ) {
if (!ASR::is_a<ASR::ClassProcedure_t>(*itr->second)) {
ASR::Variable_t* member = (ASR::Variable_t*)(&(itr->second->base));
llvm::Type* mem_type = nullptr;
switch( member->m_type->type ) {
case ASR::ttypeType::Integer: {
int a_kind = down_cast<ASR::Integer_t>(member->m_type)->m_kind;
mem_type = getIntType(a_kind);
break;
}
case ASR::ttypeType::Real: {
int a_kind = down_cast<ASR::Real_t>(member->m_type)->m_kind;
mem_type = getFPType(a_kind);
break;
}
case ASR::ttypeType::Class: {
mem_type = getClassType(member->m_type);
break;
}
case ASR::ttypeType::Complex: {
int a_kind = down_cast<ASR::Complex_t>(member->m_type)->m_kind;
mem_type = getComplexType(a_kind);
break;
}
default:
throw SemanticError("Cannot identify the type of member, '" +
std::string(member->m_name) +
"' in derived type, '" + der_type_name + "'.",
member->base.base.loc);
}
member_types.push_back(mem_type);
name2memidx[der_type_name][std::string(member->m_name)] = member_idx;
member_idx++;
}
}
der_type_llvm = llvm::StructType::create(context, member_types, der_type_name);
name2dertype[der_type_name] = der_type_llvm;
}
if( is_pointer ) {
return der_type_llvm->getPointerTo();
}
return (llvm::Type*) der_type_llvm;
}


/*
* Dispatches the required function from runtime library to
* perform the specified binary operation.
Expand Down Expand Up @@ -1230,6 +1293,24 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
break;
}
case (ASR::ttypeType::Class) : {
ASR::Class_t* v_type = down_cast<ASR::Class_t>(arg->m_type);
m_type_ = arg->m_type;
m_dims = v_type->m_dims;
n_dims = v_type->n_dims;
if( n_dims > 0 ) {
is_array_type = true;
llvm::Type* el_type = get_el_type(m_type_, a_kind);
if( v->m_storage == ASR::storage_typeType::Allocatable ) {
type = arr_descr->get_malloc_array_type(m_type_, a_kind, n_dims, el_type, true);
} else {
type = arr_descr->get_array_type(m_type_, a_kind, n_dims, m_dims, el_type, true);
}
} else {
type = getClassType(arg->m_type, true);
}
break;
}
default :
LFORTRAN_ASSERT(false);
}
Expand Down Expand Up @@ -2416,6 +2497,16 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
break;
}
case ASR::ttypeType::Class: {
ASR::Class_t* der = (ASR::Class_t*)(&(x->m_type->base));
ASR::ClassType_t* der_type = (ASR::ClassType_t*)(&(der->m_class_type->base));
der_type_name = std::string(der_type->m_name);
uint32_t h = get_hash((ASR::asr_t*)x);
if( llvm_symtab.find(h) != llvm_symtab.end() ) {
tmp = llvm_symtab[h];
}
break;
}
default: {
fetch_val(x);
break;
Expand Down Expand Up @@ -2928,8 +3019,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}

void visit_SubroutineCall(const ASR::SubroutineCall_t &x) {
ASR::Subroutine_t *s = ASR::down_cast<ASR::Subroutine_t>(
symbol_get_past_external(x.m_name));
ASR::Subroutine_t *s;
std::vector<llvm::Value*> args;
const ASR::symbol_t *proc_sym = symbol_get_past_external(x.m_name);
if (x.m_dt){
ASR::Variable_t *caller = EXPR2VAR(x.m_dt);
std::uint32_t h = get_hash((ASR::asr_t*)caller);
args.push_back(llvm_symtab[h]);
}
if (ASR::is_a<ASR::Subroutine_t>(*proc_sym)) {
s = ASR::down_cast<ASR::Subroutine_t>(proc_sym);
} else {
ASR::ClassProcedure_t *clss_proc = ASR::down_cast<
ASR::ClassProcedure_t>(proc_sym);
s = ASR::down_cast<ASR::Subroutine_t>(clss_proc->m_proc);
}
if (parent_function){
push_nested_stack(parent_function);
} else if (parent_subroutine){
Expand Down Expand Up @@ -2958,15 +3062,31 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
} else {
llvm::Function *fn = llvm_symtab_fn[h];
std::string m_name = std::string(((ASR::Subroutine_t*)(&(x.m_name->base)))->m_name);
std::vector<llvm::Value *> args = convert_call_args(x, m_name);
std::vector<llvm::Value *> args2 = convert_call_args(x, m_name);
args.insert(args.end(), args2.begin(), args2.end());
builder->CreateCall(fn, args);
}
calling_function_hash = h;
pop_nested_stack(s);
}

void visit_FunctionCall(const ASR::FunctionCall_t &x) {
ASR::Function_t *s = ASR::down_cast<ASR::Function_t>(symbol_get_past_external(x.m_name));
ASR::Function_t *s;
std::vector<llvm::Value*> args;
const ASR::symbol_t *proc_sym = symbol_get_past_external(x.m_name);
if (x.m_dt){
ASR::Variable_t *caller = EXPR2VAR(x.m_dt);
std::uint32_t h = get_hash((ASR::asr_t*)caller);
args.push_back(llvm_symtab[h]);
}
if (ASR::is_a<ASR::Function_t>(*proc_sym)) {
s = ASR::down_cast<ASR::Function_t>(proc_sym);
} else {
ASR::ClassProcedure_t *clss_proc = ASR::down_cast<
ASR::ClassProcedure_t>(proc_sym);
s = ASR::down_cast<ASR::Function_t>(clss_proc->m_proc);
}
s = ASR::down_cast<ASR::Function_t>(symbol_get_past_external(x.m_name));
if (parent_function){
push_nested_stack(parent_function);
} else if (parent_subroutine){
Expand Down Expand Up @@ -3021,7 +3141,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
} else {
llvm::Function *fn = llvm_symtab_fn[h];
std::string m_name = std::string(((ASR::Function_t*)(&(x.m_name->base)))->m_name);
std::vector<llvm::Value *> args = convert_call_args(x, m_name);
std::vector<llvm::Value *> args2 = convert_call_args(x, m_name);
args.insert(args.end(), args2.begin(), args2.end());
tmp = builder->CreateCall(fn, args);
}
calling_function_hash = h;
Expand Down
2 changes: 1 addition & 1 deletion src/lfortran/pass/array_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ class ArrayOpVisitor : public ASR::BaseWalkVisitor<ArrayOpVisitor>
tmp_val = result_var;
ASR::stmt_t* subrout_call = LFortran::ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc,
current_scope->scope[x_name], nullptr,
s_args.p, s_args.size()));
s_args.p, s_args.size(), nullptr));
array_op_result.push_back(al, subrout_call);
}
result_var = nullptr;
Expand Down
2 changes: 1 addition & 1 deletion src/lfortran/pass/pass_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ namespace LFortran {
LFortran::ASRUtils::symbol_get_past_external(v))->m_return_var)->m_type;
current_scope = current_scope_copy;
return LFortran::ASRUtils::EXPR(ASR::make_FunctionCall_t(al, arr_expr->base.loc, v, nullptr,
args.p, args.size(), nullptr, 0, type, nullptr));
args.p, args.size(), nullptr, 0, type, nullptr, nullptr));
}

ASR::expr_t* to_int32(ASR::expr_t* x, ASR::ttype_t* int64type, Allocator& al) {
Expand Down
27 changes: 22 additions & 5 deletions src/lfortran/semantics/ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1175,7 +1175,7 @@ class SymbolTableVisitor : public AST::BaseVisitor<SymbolTableVisitor>
LFortran::ASRUtils::symbol_get_past_external(v))
->m_return_var)->m_type;
asr = ASR::make_FunctionCall_t(al, x.base.base.loc, v, nullptr,
args.p, args.size(), nullptr, 0, type, nullptr);
args.p, args.size(), nullptr, 0, type, nullptr, nullptr);
}

void visit_DerivedType(const AST::DerivedType_t &x) {
Expand Down Expand Up @@ -2180,9 +2180,13 @@ class BodyVisitor : public AST::BaseVisitor<BodyVisitor>
void visit_SubroutineCall(const AST::SubroutineCall_t &x) {
std::string sub_name = x.m_name;
ASR::symbol_t *original_sym;
ASR::expr_t *v_expr = nullptr;
// If this is a type bound procedure (in a class) it won't be in the
// main symbol table. Need to check n_member.
if (x.n_member == 1) {
ASR::symbol_t *v = current_scope->resolve_symbol(x.m_member[0].m_name);
ASR::asr_t *v_var = ASR::make_Var_t(al, x.base.base.loc, v);
v_expr = LFortran::ASRUtils::EXPR(v_var);
original_sym = resolve_deriv_type_proc(x.base.base.loc, x.m_name,
x.m_member[0].m_name, current_scope);
} else {
Expand Down Expand Up @@ -2262,7 +2266,7 @@ class BodyVisitor : public AST::BaseVisitor<BodyVisitor>
}
}
tmp = ASR::make_SubroutineCall_t(al, x.base.base.loc,
final_sym, original_sym, args.p, args.size());
final_sym, original_sym, args.p, args.size(), v_expr);
}

int select_generic_procedure(const Vec<ASR::expr_t*> &args,
Expand Down Expand Up @@ -2510,10 +2514,21 @@ class BodyVisitor : public AST::BaseVisitor<BodyVisitor>
std::vector<std::string> all_intrinsics = {
"sin", "cos", "tan", "sinh", "cosh", "tanh",
"asin", "acos", "atan", "asinh", "acosh", "atanh"};

SymbolTable *scope = current_scope;
std::string var_name = x.m_func;
ASR::symbol_t *v = scope->resolve_symbol(var_name);
ASR::expr_t *v_expr = nullptr;
// If this is a type bound procedure (in a class) it won't be in the
// main symbol table. Need to check n_member.
if (x.n_member == 1) {
ASR::symbol_t *v = current_scope->resolve_symbol(x.m_member[0].m_name);
ASR::asr_t *v_var = ASR::make_Var_t(al, x.base.base.loc, v);
v_expr = LFortran::ASRUtils::EXPR(v_var);
v = resolve_deriv_type_proc(x.base.base.loc, x.m_func,
x.m_member[0].m_name, scope);
} else {
v = current_scope->resolve_symbol(var_name);
}
if (!v) {
std::string remote_sym = to_lower(var_name);
if (intrinsic_procedures.find(remote_sym)
Expand Down Expand Up @@ -2661,7 +2676,8 @@ class BodyVisitor : public AST::BaseVisitor<BodyVisitor>
ASR::ttype_t *type;
type = LFortran::ASRUtils::EXPR2VAR(ASR::down_cast<ASR::Function_t>(v)->m_return_var)->m_type;
tmp = ASR::make_FunctionCall_t(al, x.base.base.loc,
v, nullptr, args.p, args.size(), nullptr, 0, type, nullptr);
v, nullptr, args.p, args.size(), nullptr, 0, type, nullptr,
v_expr);
break;
}
case (ASR::symbolType::ExternalSymbol) : {
Expand All @@ -2672,7 +2688,8 @@ class BodyVisitor : public AST::BaseVisitor<BodyVisitor>
ASR::ttype_t *type;
type = LFortran::ASRUtils::EXPR2VAR(ASR::down_cast<ASR::Function_t>(f2)->m_return_var)->m_type;
tmp = ASR::make_FunctionCall_t(al, x.base.base.loc,
v, nullptr, args.p, args.size(), nullptr, 0, type, nullptr);
v, nullptr, args.p, args.size(), nullptr, 0, type,
nullptr, nullptr);
} else if (ASR::is_a<ASR::Variable_t>(*f2)) {
Vec<ASR::array_index_t> args;
args.reserve(al, x.n_args);
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-allocate_01-f3446f6.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-allocate_01-f3446f6.stdout",
"stdout_hash": "a5ca5ed5467083745ee498003b2872b30bb5b656540346376d527f50",
"stdout_hash": "b7179ca1ff9a3620bab6be4d80e196aa3666d88052716b3e37b57b3b",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
Loading

0 comments on commit 96c5ea8

Please sign in to comment.