Skip to content

Commit

Permalink
Following changes have been made,
Browse files Browse the repository at this point in the history
1. Allow IntegerBinOp in dimension specification i.e., i16[n*k]
2. Check for type mismatches in c_p_pointer and its target
  • Loading branch information
czgdp1807 committed Feb 27, 2023
1 parent c803bef commit fcae464
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 10 deletions.
135 changes: 129 additions & 6 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1908,7 +1908,100 @@ inline bool is_derived_type_similar(ASR::StructType_t* a, ASR::StructType_t* b)
std::string(b->m_name) == "~abstract_type");
}

inline bool types_equal(ASR::ttype_t *a, ASR::ttype_t *b) {
// TODO: Scaled up implementation for all exprTypes
// One way is to do it in asdl_cpp.py
inline bool expr_equal(ASR::expr_t* x, ASR::expr_t* y) {
if( x->type != y->type ) {
return false;
}

switch( x->type ) {
case ASR::exprType::IntegerBinOp: {
ASR::IntegerBinOp_t* intbinop_x = ASR::down_cast<ASR::IntegerBinOp_t>(x);
ASR::IntegerBinOp_t* intbinop_y = ASR::down_cast<ASR::IntegerBinOp_t>(y);
if( intbinop_x->m_op != intbinop_y->m_op ) {
return false;
}
bool left_left = expr_equal(intbinop_x->m_left, intbinop_y->m_left);
bool left_right = expr_equal(intbinop_x->m_left, intbinop_y->m_right);
bool right_left = expr_equal(intbinop_x->m_right, intbinop_y->m_left);
bool right_right = expr_equal(intbinop_x->m_right, intbinop_y->m_right);
switch( intbinop_x->m_op ) {
case ASR::binopType::Add:
case ASR::binopType::Mul:
case ASR::binopType::BitAnd:
case ASR::binopType::BitOr:
case ASR::binopType::BitXor: {
return (left_left && right_right) || (left_right && right_left);
}
case ASR::binopType::Sub:
case ASR::binopType::Div:
case ASR::binopType::Pow:
case ASR::binopType::BitLShift:
case ASR::binopType::BitRShift: {
return (left_left && right_right);
}
}
break;
}
case ASR::exprType::Var: {
ASR::Var_t* var_x = ASR::down_cast<ASR::Var_t>(x);
ASR::Var_t* var_y = ASR::down_cast<ASR::Var_t>(y);
return var_x->m_v == var_y->m_v;
}
default: {
// Let it pass for now.
return true;
}
}

// Let it pass for now.
return true;
}

inline bool dimension_expr_equal(ASR::expr_t* dim_a, ASR::expr_t* dim_b) {
if( !(dim_a && dim_b) ) {
return true;
}
ASR::expr_t* dim_a_fallback = nullptr;
ASR::expr_t* dim_b_fallback = nullptr;
if( ASR::is_a<ASR::Var_t>(*dim_a) &&
ASR::is_a<ASR::Variable_t>(
*ASR::down_cast<ASR::Var_t>(dim_a)->m_v) ) {
dim_a_fallback = ASRUtils::EXPR2VAR(dim_a)->m_symbolic_value;
}
if( ASR::is_a<ASR::Var_t>(*dim_b) &&
ASR::is_a<ASR::Variable_t>(
*ASR::down_cast<ASR::Var_t>(dim_b)->m_v) ) {
dim_b_fallback = ASRUtils::EXPR2VAR(dim_b)->m_symbolic_value;
}
if( !ASRUtils::expr_equal(dim_a, dim_b) &&
!(dim_a_fallback && ASRUtils::expr_equal(dim_a_fallback, dim_b)) &&
!(dim_b_fallback && ASRUtils::expr_equal(dim_a, dim_b_fallback)) ) {
return false;
}
return true;
}

inline bool dimensions_equal(ASR::dimension_t* dims_a, size_t n_dims_a,
ASR::dimension_t* dims_b, size_t n_dims_b) {
if( n_dims_a != n_dims_b ) {
return false;
}

for( size_t i = 0; i < n_dims_a; i++ ) {
ASR::dimension_t dim_a = dims_a[i];
ASR::dimension_t dim_b = dims_b[i];
if( !dimension_expr_equal(dim_a.m_length, dim_b.m_length) ||
!dimension_expr_equal(dim_a.m_start, dim_b.m_start) ) {
return false;
}
}
return true;
}

inline bool types_equal(ASR::ttype_t *a, ASR::ttype_t *b,
bool check_for_dimensions=false) {
// TODO: If anyone of the input or argument is derived type then
// add support for checking member wise types and do not compare
// directly. From stdlib_string len(pattern) error
Expand All @@ -1922,7 +2015,13 @@ inline bool types_equal(ASR::ttype_t *a, ASR::ttype_t *b) {
ASR::Integer_t *a2 = ASR::down_cast<ASR::Integer_t>(a);
ASR::Integer_t *b2 = ASR::down_cast<ASR::Integer_t>(b);
if (a2->m_kind == b2->m_kind) {
return true;
if( check_for_dimensions ) {
return ASRUtils::dimensions_equal(
a2->m_dims, a2->n_dims,
b2->m_dims, b2->n_dims);
} else {
return true;
}
} else {
return false;
}
Expand All @@ -1935,7 +2034,13 @@ inline bool types_equal(ASR::ttype_t *a, ASR::ttype_t *b) {
ASR::Real_t *a2 = ASR::down_cast<ASR::Real_t>(a);
ASR::Real_t *b2 = ASR::down_cast<ASR::Real_t>(b);
if (a2->m_kind == b2->m_kind) {
return true;
if( check_for_dimensions ) {
return ASRUtils::dimensions_equal(
a2->m_dims, a2->n_dims,
b2->m_dims, b2->n_dims);
} else {
return true;
}
} else {
return false;
}
Expand All @@ -1945,7 +2050,13 @@ inline bool types_equal(ASR::ttype_t *a, ASR::ttype_t *b) {
ASR::Complex_t *a2 = ASR::down_cast<ASR::Complex_t>(a);
ASR::Complex_t *b2 = ASR::down_cast<ASR::Complex_t>(b);
if (a2->m_kind == b2->m_kind) {
return true;
if( check_for_dimensions ) {
return ASRUtils::dimensions_equal(
a2->m_dims, a2->n_dims,
b2->m_dims, b2->n_dims);
} else {
return true;
}
} else {
return false;
}
Expand All @@ -1955,7 +2066,13 @@ inline bool types_equal(ASR::ttype_t *a, ASR::ttype_t *b) {
ASR::Logical_t *a2 = ASR::down_cast<ASR::Logical_t>(a);
ASR::Logical_t *b2 = ASR::down_cast<ASR::Logical_t>(b);
if (a2->m_kind == b2->m_kind) {
return true;
if( check_for_dimensions ) {
return ASRUtils::dimensions_equal(
a2->m_dims, a2->n_dims,
b2->m_dims, b2->n_dims);
} else {
return true;
}
} else {
return false;
}
Expand All @@ -1965,7 +2082,13 @@ inline bool types_equal(ASR::ttype_t *a, ASR::ttype_t *b) {
ASR::Character_t *a2 = ASR::down_cast<ASR::Character_t>(a);
ASR::Character_t *b2 = ASR::down_cast<ASR::Character_t>(b);
if (a2->m_kind == b2->m_kind) {
return true;
if( check_for_dimensions ) {
return ASRUtils::dimensions_equal(
a2->m_dims, a2->n_dims,
b2->m_dims, b2->n_dims);
} else {
return true;
}
} else {
return false;
}
Expand Down
37 changes: 33 additions & 4 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1384,7 +1384,8 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
ASR::dimension_t dim;
dim.loc = loc;
if (ASR::is_a<ASR::IntegerConstant_t>(*value) ||
ASR::is_a<ASR::Var_t>(*value)) {
ASR::is_a<ASR::Var_t>(*value) ||
ASR::is_a<ASR::IntegerBinOp_t>(*value)) {
ASR::ttype_t *itype = ASRUtils::expr_type(value);
ASR::expr_t* one = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 1, itype));
ASR::expr_t* zero = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 0, itype));
Expand Down Expand Up @@ -2235,11 +2236,24 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
}

ASR::asr_t* create_CPtrToPointerFromArgs(AST::expr_t* ast_cptr, AST::expr_t* ast_pptr,
const Location& loc) {
AST::expr_t* ast_type_expr, const Location& loc) {
this->visit_expr(*ast_cptr);
ASR::expr_t* cptr = ASRUtils::EXPR(tmp);
this->visit_expr(*ast_pptr);
ASR::expr_t* pptr = ASRUtils::EXPR(tmp);
ASR::ttype_t* asr_alloc_type = ast_expr_to_asr_type(ast_type_expr->base.loc, *ast_type_expr);
ASR::ttype_t* target_type = ASRUtils::type_get_past_pointer(ASRUtils::expr_type(pptr));
if( !ASRUtils::types_equal(target_type, asr_alloc_type, true) ) {
diag.add(diag::Diagnostic(
"Type mismatch in c_p_pointer and target variable, the types must match",
diag::Level::Error, diag::Stage::Semantic, {
diag::Label("type mismatch between target variable type "
"and c_p_pointer allocation type)",
{target_type->base.loc, asr_alloc_type->base.loc})
})
);
throw SemanticAbort();
}
return ASR::make_CPtrToPointer_t(al, loc, cptr,
pptr, nullptr);
}
Expand Down Expand Up @@ -2269,7 +2283,8 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
AST::Call_t* c_p_pointer_call = AST::down_cast<AST::Call_t>(x.m_value);
AST::expr_t* cptr = c_p_pointer_call->m_args[0];
AST::expr_t* pptr = assign_ast_target;
tmp = create_CPtrToPointerFromArgs(cptr, pptr, x.base.base.loc);
tmp = create_CPtrToPointerFromArgs(cptr, pptr, c_p_pointer_call->m_args[1],
x.base.base.loc);
// if( current_body ) {
// current_body->push_back(al, ASRUtils::STMT(tmp));
// }
Expand Down Expand Up @@ -4034,7 +4049,8 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
AST::Call_t* c_p_pointer_call = AST::down_cast<AST::Call_t>(x.m_value);
AST::expr_t* cptr = c_p_pointer_call->m_args[0];
AST::expr_t* pptr = x.m_targets[0];
tmp = create_CPtrToPointerFromArgs(cptr, pptr, x.base.base.loc);
tmp = create_CPtrToPointerFromArgs(cptr, pptr, c_p_pointer_call->m_args[1],
x.base.base.loc);
is_c_p_pointer_call = is_c_p_pointer_call;
return ;
}
Expand Down Expand Up @@ -5508,6 +5524,19 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
ASR::expr_t* cptr = ASRUtils::EXPR(tmp);
visit_expr(*x.m_args[1]);
ASR::expr_t* pptr = ASRUtils::EXPR(tmp);
ASR::ttype_t* asr_alloc_type = ast_expr_to_asr_type(x.m_args[1]->base.loc, *x.m_args[1]);
ASR::ttype_t* target_type = ASRUtils::type_get_past_pointer(ASRUtils::expr_type(pptr));
if( !ASRUtils::types_equal(target_type, asr_alloc_type, true) ) {
diag.add(diag::Diagnostic(
"Type mismatch in c_p_pointer and target variable, the types must match",
diag::Level::Error, diag::Stage::Semantic, {
diag::Label("type mismatch ('" + ASRUtils::type_to_str_python(target_type) +
"' and '" + ASRUtils::type_to_str_python(asr_alloc_type) + "')",
{target_type->base.loc, asr_alloc_type->base.loc})
})
);
throw SemanticAbort();
}
return ASR::make_CPtrToPointer_t(al, x.base.base.loc, cptr,
pptr, nullptr);
}
Expand Down
24 changes: 24 additions & 0 deletions tests/errors/bindc_03.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

from ltypes import c_p_pointer, CPtr, i32, Pointer, i16

def fill_A(k: i32, n: i32, b: CPtr) -> None:
nk: i32 = n * k
A: Pointer[i16[nk]] = c_p_pointer(b, i16[k * n])
i: i32; j: i32
for j in range(k):
for i in range(n):
A[j*n+i] = i16((i+j))

def fill_B(k: i32, n: i32, b: CPtr) -> None:
B: Pointer[i16[n * k]] = c_p_pointer(b, i16[k * n])
i: i32; j: i32
for j in range(k):
for i in range(n):
B[j*n+i] = i16((i+j))

def fill_C(k: i32, n: i32, b: CPtr) -> None:
C: Pointer[i16[n]] = c_p_pointer(b, i16[n * k])
i: i32; j: i32
for j in range(k):
for i in range(n):
C[j*n+i] = i16((i+j))

0 comments on commit fcae464

Please sign in to comment.