Skip to content

Commit

Permalink
Implement generics in ASR
Browse files Browse the repository at this point in the history
Including instantiation.
  • Loading branch information
ansharlubis authored and certik committed Aug 5, 2022
1 parent 45fc4a2 commit 8963921
Show file tree
Hide file tree
Showing 11 changed files with 393 additions and 11 deletions.
8 changes: 5 additions & 3 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ symbol
| Subroutine(symbol_table symtab, identifier name, expr* args, stmt* body,
abi abi, access access, deftype deftype, string? bindc_name, bool pure,
bool module)
| Function(symbol_table symtab, identifier name, expr* args, stmt* body,
expr return_var, abi abi, access access, deftype deftype, bool elemental,
string? bindc_name)
| Function(symbol_table symtab, identifier name, expr* args, ttype* type_params,
stmt* body, expr return_var, abi abi, access access, deftype deftype,
bool elemental, string? bindc_name)
| GenericProcedure(symbol_table parent_symtab, identifier name,
symbol* procs, access access)
| CustomOperator(symbol_table parent_symtab, identifier name,
Expand Down Expand Up @@ -239,6 +239,7 @@ expr
| LogicalNot(expr arg, ttype type, expr? value)
| LogicalCompare(expr left, cmpop op, expr right, ttype type, expr? value)
| LogicalBinOp(expr left, logicalbinop op, expr right, ttype type, expr? value)
| TemplateBinOp(expr left, binop op, expr right, ttype type, expr? value)

| ListConstant(expr* args, ttype type)
| ListLen(expr arg, ttype type, expr? value)
Expand Down Expand Up @@ -328,6 +329,7 @@ ttype
| Dict(ttype key_type, ttype value_type)
| Pointer(ttype type)
| CPtr()
| TypeParameter(identifier param, dimension* dims)

binop = Add | Sub | Mul | Div | Pow | BitAnd | BitOr | BitXor | BitLShift | BitRShift

Expand Down
1 change: 1 addition & 0 deletions src/libasr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ set(SRC
pass/inline_function_calls.cpp
pass/loop_unroll.cpp
pass/dead_code_removal.cpp
pass/instantiate_template.cpp
pass/update_array_dim_intrinsic_calls.cpp

asr_verify.cpp
Expand Down
6 changes: 2 additions & 4 deletions src/libasr/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,10 +801,8 @@ def visitModule(self, mod):
self.emit("// Expression and statement Duplicator class")
self.emit("")
self.emit("class ExprStmtDuplicator {")
self.emit("private:")
self.emit(" Allocator& al;")
self.emit("")
self.emit("public:")
self.emit(" Allocator &al;")
self.emit(" bool success;")
self.emit(" bool allow_procedure_calls;")
self.emit("")
Expand Down Expand Up @@ -868,7 +866,7 @@ def visitConstructor(self, cons, _):

def make_visitor(self, name, fields):
self.emit("")
self.emit("ASR::asr_t* duplicate_%s(%s_t* x) {" % (name, name), 1)
self.emit("virtual ASR::asr_t* duplicate_%s(%s_t* x) {" % (name, name), 1)
self.used = False
arguments = []
for field in fields:
Expand Down
34 changes: 34 additions & 0 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,10 @@ static inline std::string type_to_str_python(const ASR::ttype_t *t,
ASR::Pointer_t* p = ASR::down_cast<ASR::Pointer_t>(t);
return "Pointer[" + type_to_str_python(p->m_type) + "]";
}
case ASR::ttypeType::TypeParameter: {
ASR::TypeParameter_t *p = ASR::down_cast<ASR::TypeParameter_t>(t);
return p->m_param;
}
default : throw LCompilersException("Not implemented " + std::to_string(t->type));
}
}
Expand Down Expand Up @@ -994,6 +998,21 @@ static inline bool is_logical(ASR::ttype_t &x) {
return ASR::is_a<ASR::Logical_t>(*type_get_past_pointer(&x));
}

static inline bool is_generic(ASR::ttype_t &x) {
return ASR::is_a<ASR::TypeParameter_t>(*type_get_past_pointer(&x));
}

static inline std::string get_parameter_name(const ASR::ttype_t* t) {
switch (t->type) {
case ASR::ttypeType::TypeParameter: {
ASR::TypeParameter_t* tp = ASR::down_cast<ASR::TypeParameter_t>(t);
return tp->m_param;
}
default: throw LCompilersException("Cannot obtain type parameter from this type");
}
}


static inline int get_body_size(ASR::symbol_t* s) {
int n_body = 0;
switch (s->type) {
Expand Down Expand Up @@ -1079,6 +1098,12 @@ inline int extract_dimensions_from_ttype(ASR::ttype_t *x,
m_dims = nullptr;
break;
}
case ASR::ttypeType::TypeParameter: {
ASR::TypeParameter_t* tp = ASR::down_cast<ASR::TypeParameter_t>(x);
n_dims = tp->n_dims;
m_dims = tp->m_dims;
break;
}
default:
throw LCompilersException("Not implemented.");
}
Expand Down Expand Up @@ -1142,6 +1167,13 @@ static inline ASR::ttype_t* duplicate_type(Allocator& al, const ASR::ttype_t* t,
return ASRUtils::TYPE(ASR::make_Pointer_t(al, ptr->base.base.loc,
dup_type));
}
case ASR::ttypeType::TypeParameter: {
ASR::TypeParameter_t* tp = ASR::down_cast<ASR::TypeParameter_t>(t);
ASR::dimension_t* dimsp = dims ? dims->p : tp->m_dims;
size_t dimsn = dims ? dims->n : tp->n_dims;
return ASRUtils::TYPE(ASR::make_TypeParameter_t(al, t->base.loc,
tp->m_param, dimsp, dimsn));
}
default : throw LCompilersException("Not implemented " + std::to_string(t->type));
}
}
Expand Down Expand Up @@ -1306,6 +1338,8 @@ inline bool check_equal_type(ASR::ttype_t* x, ASR::ttype_t* y) {
}
}
return result;
} else if (ASR::is_a<ASR::TypeParameter_t>(*x) && ASR::is_a<ASR::TypeParameter_t>(*y)) {
return true;
}

int64_t x_kind = ASRUtils::extract_kind_from_ttype_t(x);
Expand Down
3 changes: 2 additions & 1 deletion src/libasr/asr_verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,8 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
require(symtab_sym == current_sym,
"Variable's parent symbol table does not point to it");
require(id_symtab_map.find(symtab->counter) != id_symtab_map.end(),
"Variable::m_parent_symtab must be present in the ASR");
"Variable::m_parent_symtab must be present in the ASR ("
+ std::string(x.m_name) + ")");

if (x.m_symbolic_value)
visit_expr(*x.m_symbolic_value);
Expand Down
2 changes: 2 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1564,6 +1564,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
}
llvm_symtab[h] = ptr;
} else if (x.m_type->type == ASR::ttypeType::TypeParameter) {
// Ignore type variables
} else {
throw CodeGenError("Variable type not supported", x.base.base.loc);
}
Expand Down
4 changes: 2 additions & 2 deletions src/libasr/codegen/asr_to_wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class ASRToWASMVisitor : public ASR::BaseVisitor<ASRToWASMVisitor> {
}

auto func = ASR::make_Function_t(m_al, x.base.base.loc, x.m_global_scope, s2c(m_al, import_func.name),
params.data(), params.size(), nullptr, 0, nullptr, ASR::abiType::Source, ASR::accessType::Public,
params.data(), params.size(), nullptr, 0, nullptr, 0, nullptr, ASR::abiType::Source, ASR::accessType::Public,
ASR::deftypeType::Implementation, false, nullptr);
m_import_func_asr_map[import_func.name] = func;

Expand Down Expand Up @@ -275,7 +275,7 @@ class ASRToWASMVisitor : public ASR::BaseVisitor<ASRToWASMVisitor> {

// Generate main program code
auto main_func = ASR::make_Function_t(m_al, x.base.base.loc, x.m_symtab, s2c(m_al, "_lcompilers_main"),
nullptr, 0, x.m_body, x.n_body, nullptr, ASR::abiType::Source, ASR::accessType::Public,
nullptr, 0, nullptr, 0, x.m_body, x.n_body, nullptr, ASR::abiType::Source, ASR::accessType::Public,
ASR::deftypeType::Implementation, false, nullptr);
this->visit_Function(*((ASR::Function_t *)main_func));
}
Expand Down
2 changes: 2 additions & 0 deletions src/libasr/pass/global_stmts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ void pass_wrap_global_stmts_into_function(Allocator &al,
/* a_name */ fn_name,
/* a_args */ nullptr,
/* n_args */ 0,
/* a_type_params */ nullptr,
/* n_type_params*/ 0,
/* a_body */ body.p,
/* n_body */ body.size(),
/* a_return_var */ return_var_ref,
Expand Down
Loading

0 comments on commit 8963921

Please sign in to comment.