Skip to content

Commit

Permalink
Redesigned operator overloading mechanism with CustomOperator
Browse files Browse the repository at this point in the history
  • Loading branch information
czgdp1807 committed Sep 22, 2021
1 parent 87b703c commit 5a8ae34
Show file tree
Hide file tree
Showing 12 changed files with 160 additions and 118 deletions.
4 changes: 3 additions & 1 deletion grammar/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ symbol
string? bindc_name)
| GenericProcedure(symbol_table parent_symtab, identifier name,
symbol* procs, access access)
| CustomOperator(symbol_table parent_symtab, identifier name,
symbol* procs, access access)
| ExternalSymbol(symbol_table parent_symtab, identifier name,
symbol external, identifier module_name, identifier* scope_names,
identifier original_name, access access)
Expand Down Expand Up @@ -179,7 +181,7 @@ stmt

expr
= BoolOp(expr left, boolop op, expr right, ttype type, expr? value)
| BinOp(expr left, binop op, expr right, ttype type, expr? value)
| BinOp(expr left, binop op, expr right, ttype type, expr? value, expr? overloaded)
| StrOp(expr left, strop op, expr right, ttype type, expr? value)
| UnaryOp(unaryop op, expr operand, ttype type, expr? value)
| Compare(expr left, cmpop op, expr right, ttype type, expr? value)
Expand Down
85 changes: 85 additions & 0 deletions src/lfortran/asr_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,91 @@ ASR::asr_t* getDerivedRef_t(Allocator& al, const Location& loc,
}
return ASR::make_DerivedRef_t(al, loc, LFortran::ASRUtils::EXPR(v_var), member, member_type, nullptr);
}

bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
ASR::binopType op, std::string& intrinsic_op_name,
SymbolTable* curr_scope, ASR::asr_t*& asr,
Allocator &al, const Location& loc) {
ASR::ttype_t *left_type = LFortran::ASRUtils::expr_type(left);
ASR::ttype_t *right_type = LFortran::ASRUtils::expr_type(right);
bool found = false;
if( is_op_overloaded(op, intrinsic_op_name, curr_scope) ) {
ASR::symbol_t* sym = curr_scope->scope[intrinsic_op_name];
const ASR::symbol_t* orig_sym = ASRUtils::symbol_get_past_external(sym);
ASR::CustomOperator_t* gen_proc = ASR::down_cast<ASR::CustomOperator_t>(orig_sym);
for( size_t i = 0; i < gen_proc->n_procs && !found; i++ ) {
ASR::symbol_t* proc = gen_proc->m_procs[i];
switch(proc->type) {
case ASR::symbolType::Function: {
ASR::Function_t* func = ASR::down_cast<ASR::Function_t>(proc);
if( func->n_args == 2 ) {
ASR::ttype_t* left_arg_type = ASRUtils::expr_type(func->m_args[0]);
ASR::ttype_t* right_arg_type = ASRUtils::expr_type(func->m_args[1]);
if( left_arg_type->type == left_type->type &&
right_arg_type->type == right_type->type ) {
found = true;
Vec<ASR::expr_t*> a_args;
a_args.reserve(al, 2);
a_args.push_back(al, left);
a_args.push_back(al, right);
asr = ASR::make_FunctionCall_t(al, loc, curr_scope->scope[std::string(func->m_name)], nullptr,
a_args.p, 2, nullptr, 0,
ASRUtils::expr_type(func->m_return_var),
nullptr, nullptr);
}
}
break;
}
default: {
throw SemanticError("While overloading binary operators only functions can be used",
proc->base.loc);
}
}
}
}
return found;
}

bool is_op_overloaded(ASR::binopType op, std::string& intrinsic_op_name,
SymbolTable* curr_scope) {
bool result = true;
switch(op) {
case ASR::binopType::Add: {
if(intrinsic_op_name != "~add") {
result = false;
}
break;
}
case ASR::binopType::Sub: {
if(intrinsic_op_name != "~sub") {
result = false;
}
break;
}
case ASR::binopType::Mul: {
if(intrinsic_op_name != "~mul") {
result = false;
}
break;
}
case ASR::binopType::Div: {
if(intrinsic_op_name != "~div") {
result = false;
}
break;
}
case ASR::binopType::Pow: {
if(intrinsic_op_name != "~pow") {
result = false;
}
break;
}
}
if( result && curr_scope->scope.find(intrinsic_op_name) == curr_scope->scope.end() ) {
result = false;
}
return result;
}
} // namespace ASRUtils


Expand Down
14 changes: 14 additions & 0 deletions src/lfortran/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ static inline char *symbol_name(const ASR::symbol_t *f)
case ASR::symbolType::ClassProcedure: {
return ASR::down_cast<ASR::ClassProcedure_t>(f)->m_name;
}
case ASR::symbolType::CustomOperator: {
return ASR::down_cast<ASR::CustomOperator_t>(f)->m_name;
}
default : throw LFortranException("Not implemented");
}
}
Expand Down Expand Up @@ -179,6 +182,9 @@ static inline SymbolTable *symbol_parent_symtab(const ASR::symbol_t *f)
case ASR::symbolType::ClassProcedure: {
return ASR::down_cast<ASR::ClassProcedure_t>(f)->m_parent_symtab;
}
case ASR::symbolType::CustomOperator: {
return ASR::down_cast<ASR::CustomOperator_t>(f)->m_parent_symtab;
}
default : throw LFortranException("Not implemented");
}
}
Expand Down Expand Up @@ -293,6 +299,14 @@ ASR::asr_t* getDerivedRef_t(Allocator& al, const Location& loc,
ASR::asr_t* v_var, ASR::symbol_t* member,
SymbolTable* current_scope);

bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
ASR::binopType op, std::string& intrinsic_op_name,
SymbolTable* curr_scope, ASR::asr_t*& asr,
Allocator &al, const Location& loc);

bool is_op_overloaded(ASR::binopType op, std::string& intrinsic_op_name,
SymbolTable* curr_scope);

void set_intrinsic(ASR::symbol_t* sym);

static inline int extract_kind_from_ttype_t(const ASR::ttype_t* curr_type) {
Expand Down
4 changes: 4 additions & 0 deletions src/lfortran/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2612,6 +2612,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}

void visit_BinOp(const ASR::BinOp_t &x) {
if( x.m_overloaded ) {
this->visit_expr(*x.m_overloaded);
return ;
}
if (x.m_value) {
this->visit_expr_wrapper(x.m_value, true);
return;
Expand Down
8 changes: 4 additions & 4 deletions src/lfortran/pass/arr_slice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,16 @@ class ArrSliceVisitor : public ASR::BaseWalkVisitor<ArrSliceVisitor>

ASR::expr_t* gap = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc,
end, ASR::binopType::Sub, start,
int32_type, nullptr));
int32_type, nullptr, nullptr));
// ASR::expr_t* slice_size = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc,
// gap, ASR::binopType::Add, const_1,
// int64_type, nullptr));
ASR::expr_t* slice_size = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc,
gap, ASR::binopType::Div, step,
int32_type, nullptr));
int32_type, nullptr, nullptr));
ASR::expr_t* actual_size = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc,
slice_size, ASR::binopType::Add, const_1,
int32_type, nullptr));
int32_type, nullptr, nullptr));
ASR::dimension_t curr_dim;
curr_dim.loc = x.base.base.loc;
curr_dim.m_start = const_1;
Expand Down Expand Up @@ -281,7 +281,7 @@ class ArrSliceVisitor : public ASR::BaseWalkVisitor<ArrSliceVisitor>
doloop_body.push_back(al, set_to_one);
doloop_body.push_back(al, doloop);
}
ASR::expr_t* inc_expr = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc, idx_vars_target[i], ASR::binopType::Add, const_1, int32_type, nullptr));
ASR::expr_t* inc_expr = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc, idx_vars_target[i], ASR::binopType::Add, const_1, int32_type, nullptr, nullptr));
ASR::stmt_t* assign_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, idx_vars_target[i], inc_expr));
doloop_body.push_back(al, assign_stmt);
doloop = LFortran::ASRUtils::STMT(ASR::make_DoLoop_t(al, x.base.base.loc, head, doloop_body.p, doloop_body.size()));
Expand Down
12 changes: 8 additions & 4 deletions src/lfortran/pass/array_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,8 @@ class ArrayOpVisitor : public ASR::BaseWalkVisitor<ArrayOpVisitor>
case ASR::exprType::BinOp:
op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(
al, x.base.base.loc,
ref_1, (ASR::binopType)x.m_op, ref_2, x.m_type, nullptr));
ref_1, (ASR::binopType)x.m_op, ref_2,
x.m_type, nullptr, nullptr));
break;
case ASR::exprType::Compare:
op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_Compare_t(
Expand All @@ -541,7 +542,9 @@ class ArrayOpVisitor : public ASR::BaseWalkVisitor<ArrayOpVisitor>
doloop_body.push_back(al, set_to_one);
doloop_body.push_back(al, doloop);
}
ASR::expr_t* inc_expr = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc, idx_vars_value[i], ASR::binopType::Add, const_1, int32_type, nullptr));
ASR::expr_t* inc_expr = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc, idx_vars_value[i],
ASR::binopType::Add, const_1, int32_type,
nullptr, nullptr));
ASR::stmt_t* assign_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, idx_vars_value[i], inc_expr));
doloop_body.push_back(al, assign_stmt);
doloop = LFortran::ASRUtils::STMT(ASR::make_DoLoop_t(al, x.base.base.loc, head, doloop_body.p, doloop_body.size()));
Expand Down Expand Up @@ -599,7 +602,8 @@ class ArrayOpVisitor : public ASR::BaseWalkVisitor<ArrayOpVisitor>
case ASR::exprType::BinOp:
op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(
al, x.base.base.loc,
ref, (ASR::binopType)x.m_op, other_expr, x.m_type, nullptr));
ref, (ASR::binopType)x.m_op, other_expr,
x.m_type, nullptr, nullptr));
break;
case ASR::exprType::Compare:
op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_Compare_t(
Expand All @@ -622,7 +626,7 @@ class ArrayOpVisitor : public ASR::BaseWalkVisitor<ArrayOpVisitor>
doloop_body.push_back(al, set_to_one);
doloop_body.push_back(al, doloop);
}
ASR::expr_t* inc_expr = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc, idx_vars_value[i], ASR::binopType::Add, const_1, int32_type, nullptr));
ASR::expr_t* inc_expr = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc, idx_vars_value[i], ASR::binopType::Add, const_1, int32_type, nullptr, nullptr));
ASR::stmt_t* assign_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, idx_vars_value[i], inc_expr));
doloop_body.push_back(al, assign_stmt);
doloop = LFortran::ASRUtils::STMT(ASR::make_DoLoop_t(al, x.base.base.loc, head, doloop_body.p, doloop_body.size()));
Expand Down
3 changes: 0 additions & 3 deletions src/lfortran/pass/class_constructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,6 @@ class ClassConstructorVisitor : public ASR::BaseWalkVisitor<ClassConstructorVisi
LFortran::ASRUtils::symbol_get_past_external(dt_der->m_derived_type)->base));
for( size_t i = 0; i < dt_dertype->n_members; i++ ) {
ASR::symbol_t* member = dt_dertype->m_symtab->resolve_symbol(std::string(dt_dertype->m_members[i], strlen(dt_dertype->m_members[i])));
ASR::Variable_t* member_variable = down_cast<ASR::Variable_t>
(LFortran::ASRUtils::symbol_get_past_external(member));
ASR::ttype_t* member_type = member_variable->m_type;
ASR::expr_t* derived_ref = LFortran::ASRUtils::EXPR(ASRUtils::getDerivedRef_t(al, x.base.base.loc, (ASR::asr_t*)result_var, member, current_scope));
ASR::stmt_t* assign = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, derived_ref, x.m_args[i]));
class_constructor_result.push_back(al, assign);
Expand Down
6 changes: 3 additions & 3 deletions src/lfortran/pass/do_loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,16 @@ Vec<ASR::stmt_t*> replace_doloop(Allocator &al, const ASR::DoLoop_t &loop) {
ASR::expr_t *target = loop.m_head.m_v;
ASR::ttype_t *type = LFortran::ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4, nullptr, 0));
ASR::stmt_t *stmt1 = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target,
LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, a, ASR::binopType::Sub, c, type, nullptr))
LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, a, ASR::binopType::Sub, c, type, nullptr, nullptr))
));

ASR::expr_t *cond = LFortran::ASRUtils::EXPR(ASR::make_Compare_t(al, loc,
LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, target, ASR::binopType::Add, c, type, nullptr)),
LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, target, ASR::binopType::Add, c, type, nullptr, nullptr)),
cmp_op, b, type, nullptr));
Vec<ASR::stmt_t*> body;
body.reserve(al, loop.n_body+1);
body.push_back(al, LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target,
LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, target, ASR::binopType::Add, c, type, nullptr))
LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, target, ASR::binopType::Add, c, type, nullptr, nullptr))
)));
for (size_t i=0; i<loop.n_body; i++) {
body.push_back(al, loop.m_body[i]);
Expand Down
12 changes: 6 additions & 6 deletions src/lfortran/pass/implied_do_loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ class ImpliedDoLoopVisitor : public ASR::BaseWalkVisitor<ImpliedDoLoopVisitor>
const_n = offset = num_grps = grp_start = nullptr;
if( arr_idx == nullptr ) {
const_n = LFortran::ASRUtils::EXPR(ASR::make_ConstantInteger_t(al, arr_var->base.base.loc, idoloop->n_values, _type));
offset = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc, idoloop->m_var, ASR::binopType::Sub, idoloop->m_start, _type, nullptr));
num_grps = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc, offset, ASR::binopType::Mul, const_n, _type, nullptr));
grp_start = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc, num_grps, ASR::binopType::Add, const_1, _type, nullptr));
offset = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc, idoloop->m_var, ASR::binopType::Sub, idoloop->m_start, _type, nullptr, nullptr));
num_grps = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc, offset, ASR::binopType::Mul, const_n, _type, nullptr, nullptr));
grp_start = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc, num_grps, ASR::binopType::Add, const_1, _type, nullptr, nullptr));
}
for( size_t i = 0; i < idoloop->n_values; i++ ) {
Vec<ASR::array_index_t> args;
Expand All @@ -164,7 +164,7 @@ class ImpliedDoLoopVisitor : public ASR::BaseWalkVisitor<ImpliedDoLoopVisitor>
ASR::expr_t* const_i = LFortran::ASRUtils::EXPR(ASR::make_ConstantInteger_t(al, arr_var->base.base.loc, i, _type));
ASR::expr_t* idx = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc,
grp_start, ASR::binopType::Add, const_i,
_type, nullptr));
_type, nullptr, nullptr));
ai.m_right = idx;
} else {
ai.m_right = arr_idx;
Expand All @@ -181,7 +181,7 @@ class ImpliedDoLoopVisitor : public ASR::BaseWalkVisitor<ImpliedDoLoopVisitor>
ASR::stmt_t* doloop_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, arr_var->base.base.loc, array_ref, idoloop->m_values[i]));
doloop_body.push_back(al, doloop_stmt);
if( arr_idx != nullptr ) {
ASR::expr_t* increment = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc, arr_idx, ASR::binopType::Add, const_1, LFortran::ASRUtils::expr_type(arr_idx), nullptr));
ASR::expr_t* increment = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc, arr_idx, ASR::binopType::Add, const_1, LFortran::ASRUtils::expr_type(arr_idx), nullptr, nullptr));
ASR::stmt_t* assign_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, arr_var->base.base.loc, arr_idx, increment));
doloop_body.push_back(al, assign_stmt);
}
Expand Down Expand Up @@ -237,7 +237,7 @@ class ImpliedDoLoopVisitor : public ASR::BaseWalkVisitor<ImpliedDoLoopVisitor>
LFortran::ASRUtils::expr_type(LFortran::ASRUtils::EXPR((ASR::asr_t*)arr_var)), nullptr));
ASR::stmt_t* assign_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, arr_var->base.base.loc, array_ref, arr_init->m_args[k]));
implied_do_loop_result.push_back(al, assign_stmt);
ASR::expr_t* increment = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc, idx_var, ASR::binopType::Add, const_1, LFortran::ASRUtils::expr_type(idx_var), nullptr));
ASR::expr_t* increment = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc, idx_var, ASR::binopType::Add, const_1, LFortran::ASRUtils::expr_type(idx_var), nullptr, nullptr));
assign_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, arr_var->base.base.loc, idx_var, increment));
implied_do_loop_result.push_back(al, assign_stmt);
}
Expand Down
Loading

0 comments on commit 5a8ae34

Please sign in to comment.