Skip to content

Commit

Permalink
Merge pull request lcompilers#2012 from Shaikh-Ubaid/pythoncall_array…
Browse files Browse the repository at this point in the history
…s_as_return_type

PythonCall: Support array of simple types as return type
  • Loading branch information
Shaikh-Ubaid committed Jun 23, 2023
2 parents 62e5ead + 832245c commit c459bb5
Show file tree
Hide file tree
Showing 8 changed files with 272 additions and 32 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@ RUN(NAME bindc_06 LABELS llvm c
EXTRAFILES bindc_06b.c)
RUN(NAME bindpy_01 LABELS cpython c_py ENABLE_CPYTHON NOFAST EXTRAFILES bindpy_01_module.py)
RUN(NAME bindpy_02 LABELS cpython c_py LINK_NUMPY EXTRAFILES bindpy_02_module.py)
RUN(NAME bindpy_03 LABELS cpython c_py LINK_NUMPY NOFAST EXTRAFILES bindpy_03_module.py)
RUN(NAME test_generics_01 LABELS cpython llvm c NOFAST)
RUN(NAME test_cmath LABELS cpython llvm c NOFAST)
RUN(NAME test_complex_01 LABELS cpython llvm c wasm wasm_x64)
Expand Down
136 changes: 136 additions & 0 deletions integration_tests/bindpy_03.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from lpython import i32, i64, f64, pythoncall, Const, TypeVar
from numpy import empty, int32, int64, float64

n = TypeVar("n")
m = TypeVar("m")
p = TypeVar("p")
q = TypeVar("q")
r = TypeVar("r")

@pythoncall(module = "bindpy_03_module")
def get_cpython_version() -> str:
pass

@pythoncall(module = "bindpy_03_module")
def get_int_array_sum(n: i32, a: i32[:], b: i32[:]) -> i32[n]:
pass

@pythoncall(module = "bindpy_03_module")
def get_int_array_product(n: i32, a: i32[:], b: i32[:]) -> i32[n]:
pass

@pythoncall(module = "bindpy_03_module")
def get_float_array_sum(n: i32, m: i32, a: f64[:], b: f64[:]) -> f64[n, m]:
pass

@pythoncall(module = "bindpy_03_module")
def get_float_array_product(n: i32, m: i32, a: f64[:], b: f64[:]) -> f64[n, m]:
pass

@pythoncall(module = "bindpy_03_module")
def get_array_dot_product(m: i32, a: i64[:], b: f64[:]) -> f64[m]:
pass

@pythoncall(module = "bindpy_03_module")
def get_multidim_array_i64(p: i32, q: i32, r: i32) -> i64[p, q, r]:
pass

# Integers:
def test_array_ints():
n: Const[i32] = 5
a: i32[n] = empty([n], dtype=int32)
b: i32[n] = empty([n], dtype=int32)

i: i32
for i in range(n):
a[i] = i + 10
for i in range(n):
b[i] = i + 20

c: i32[n] = get_int_array_sum(n, a, b)
print(c)
for i in range(n):
assert c[i] == (i + i + 30)


c = get_int_array_product(n, a, b)
print(c)
for i in range(n):
assert c[i] == ((i + 10) * (i + 20))

# Floats
def test_array_floats():
n: Const[i32] = 3
m: Const[i32] = 5
a: f64[n, m] = empty([n, m], dtype=float64)
b: f64[n, m] = empty([n, m], dtype=float64)

i: i32
j: i32

for i in range(n):
for j in range(m):
a[i, j] = f64((i + 10) * (j + 10))

for i in range(n):
for j in range(m):
b[i, j] = f64((i + 20) * (j + 20))

c: f64[n, m] = get_float_array_sum(n, m, a, b)
print(c)
for i in range(n):
for j in range(m):
assert abs(c[i, j] - (f64((i + 10) * (j + 10)) + f64((i + 20) * (j + 20)))) <= 1e-4

c = get_float_array_product(n, m, a, b)
print(c)
for i in range(n):
for j in range(m):
assert abs(c[i, j] - (f64((i + 10) * (j + 10)) * f64((i + 20) * (j + 20)))) <= 1e-4

def test_array_broadcast():
n: Const[i32] = 3
m: Const[i32] = 5
a: i64[n] = empty([n], dtype=int64)
b: f64[n, m] = empty([n, m], dtype=float64)

i: i32
j: i32
for i in range(n):
a[i] = i64(i + 10)

for i in range(n):
for j in range(m):
b[i, j] = f64((i + 1) * (j + 1))

c: f64[m] = get_array_dot_product(m, a, b)
print(c)
assert abs(c[0] - (68.0)) <= 1e-4
assert abs(c[1] - (136.0)) <= 1e-4
assert abs(c[2] - (204.0)) <= 1e-4
assert abs(c[3] - (272.0)) <= 1e-4
assert abs(c[4] - (340.0)) <= 1e-4

def test_multidim_array_return_i64():
p: Const[i32] = 3
q: Const[i32] = 4
r: Const[i32] = 5
a: i64[p, q, r] = empty([p, q, r], dtype=int64)
a = get_multidim_array_i64(p, q, r)
print(a)

i: i32; j: i32; k: i32
for i in range(p):
for j in range(q):
for k in range(r):
assert a[i, j, k] == i64(i * 2 + j * 3 + k * 4)

def main0():
print("CPython version: ", get_cpython_version())

test_array_ints()
test_array_floats()
test_array_broadcast()
test_multidim_array_return_i64()

main0()
31 changes: 31 additions & 0 deletions integration_tests/bindpy_03_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np

def get_cpython_version():
import platform
return platform.python_version()

def get_int_array_sum(n, a, b):
return np.add(a, b)

def get_int_array_product(n, a, b):
return np.multiply(a, b)

def get_float_array_sum(n, m, a, b):
return np.add(a, b)

def get_float_array_product(n, m, a, b):
return np.multiply(a, b)

def get_array_dot_product(m, a, b):
print(a, b)
c = a @ b
print(c)
return c

def get_multidim_array_i64(p, q, r):
a = np.empty([p, q, r], dtype = np.int64)
for i in range(p):
for j in range(q):
for k in range(r):
a[i, j, k] = i * 2 + j * 3 + k * 4
return a
3 changes: 3 additions & 0 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,9 @@ static inline ASR::symbol_t *get_asr_owner(const ASR::expr_t *expr) {
case ASR::exprType::GetPointer: {
return ASRUtils::get_asr_owner(ASR::down_cast<ASR::GetPointer_t>(expr)->m_arg);
}
case ASR::exprType::FunctionCall: {
return ASRUtils::get_asr_owner(ASR::down_cast<ASR::FunctionCall_t>(expr)->m_name);
}
default: {
throw LCompilersException("Cannot find the ASR owner of underlying symbol of expression "
+ std::to_string(expr->type));
Expand Down
3 changes: 0 additions & 3 deletions src/libasr/codegen/asr_to_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,13 @@ class ASRToCVisitor : public BaseCCPPVisitor<ASRToCVisitor>
{
public:

std::string array_types_decls;

std::unique_ptr<CUtils::CUtilFunctions> c_utils_functions;

int counter;

ASRToCVisitor(diag::Diagnostics &diag, CompilerOptions &co,
int64_t default_lower_bound)
: BaseCCPPVisitor(diag, co.platform, co, false, false, true, default_lower_bound),
array_types_decls(std::string("")),
c_utils_functions{std::make_unique<CUtils::CUtilFunctions>()},
counter{0} {
}
Expand Down
68 changes: 43 additions & 25 deletions src/libasr/codegen/asr_to_c_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class BaseCCPPVisitor : public ASR::BaseVisitor<Struct>
std::map<uint64_t, std::string> const_var_names;
std::map<int32_t, std::string> gotoid2name;
std::map<std::string, std::string> emit_headers;
std::string array_types_decls;

// Output configuration:
// Use std::string or char*
Expand Down Expand Up @@ -146,7 +147,7 @@ class BaseCCPPVisitor : public ASR::BaseVisitor<Struct>
BaseCCPPVisitor(diag::Diagnostics &diag, Platform &platform,
CompilerOptions &_compiler_options, bool gen_stdstring, bool gen_stdcomplex, bool is_c,
int64_t default_lower_bound) : diag{diag},
platform{platform}, compiler_options{_compiler_options},
platform{platform}, compiler_options{_compiler_options}, array_types_decls{std::string("")},
gen_stdstring{gen_stdstring}, gen_stdcomplex{gen_stdcomplex},
is_c{is_c}, global_scope{nullptr}, lower_bound{default_lower_bound},
template_number{0}, c_ds_api{std::make_unique<CCPPDSUtils>(is_c, platform)},
Expand Down Expand Up @@ -381,28 +382,32 @@ R"(#include <stdio.h>
}
if (x.m_return_var) {
ASR::Variable_t *return_var = ASRUtils::EXPR2VAR(x.m_return_var);
bool is_array = ASRUtils::is_array(return_var->m_type);
if (ASRUtils::is_integer(*return_var->m_type)) {
int kind = ASR::down_cast<ASR::Integer_t>(return_var->m_type)->m_kind;
switch (kind) {
case (1) : sub = "int8_t "; break;
case (2) : sub = "int16_t "; break;
case (4) : sub = "int32_t "; break;
case (8) : sub = "int64_t "; break;
int kind = ASRUtils::extract_kind_from_ttype_t(return_var->m_type);
if (is_array) {
sub = "struct i" + std::to_string(kind * 8) + "* ";
} else {
sub = "int" + std::to_string(kind * 8) + "_t ";
}
} else if (ASRUtils::is_unsigned_integer(*return_var->m_type)) {
int kind = ASR::down_cast<ASR::UnsignedInteger_t>(return_var->m_type)->m_kind;
switch (kind) {
case (1) : sub = "uint8_t "; break;
case (2) : sub = "uint16_t "; break;
case (4) : sub = "uint32_t "; break;
case (8) : sub = "uint64_t "; break;
int kind = ASRUtils::extract_kind_from_ttype_t(return_var->m_type);
if (is_array) {
sub = "struct u" + std::to_string(kind * 8) + "* ";
} else {
sub = "uint" + std::to_string(kind * 8) + "_t ";
}
} else if (ASRUtils::is_real(*return_var->m_type)) {
bool is_float = ASR::down_cast<ASR::Real_t>(return_var->m_type)->m_kind == 4;
if (is_float) {
sub = "float ";
int kind = ASRUtils::extract_kind_from_ttype_t(return_var->m_type);
bool is_float = (kind == 4);
if (is_array) {
sub = "struct r" + std::to_string(kind * 8) + "* ";
} else {
sub = "double ";
if (is_float) {
sub = "float ";
} else {
sub = "double ";
}
}
} else if (ASRUtils::is_logical(*return_var->m_type)) {
sub = "bool ";
Expand Down Expand Up @@ -534,17 +539,30 @@ R"(#include <stdio.h>
if (!x.m_return_var) return "";
ASR::Variable_t* r_v = ASRUtils::EXPR2VAR(x.m_return_var);
std::string indent = "\n ";
std::string py_val_cnvrt = CUtils::get_py_obj_return_type_conv_func_from_ttype_t(r_v->m_type) + "(pValue)";
std::string ret_var_decl = indent + CUtils::get_c_type_from_ttype_t(r_v->m_type) + " " + std::string(r_v->m_name) + ";";
std::string ret_assign = indent + std::string(r_v->m_name) + " = " + py_val_cnvrt + ";";
std::string ret_stmt = indent + "return " + std::string(r_v->m_name) + ";";
std::string clear_pValue = indent + "Py_DECREF(pValue);";
std::string copy_result = "";
std::string py_val_cnvrt, ret_var_decl, copy_result;
if (ASRUtils::is_aggregate_type(r_v->m_type)) {
if (ASRUtils::is_character(*r_v->m_type)) {
copy_result = indent + std::string(r_v->m_name) + " = _lfortran_str_copy(" + std::string(r_v->m_name) + ", 1, 0);";
if (ASRUtils::is_array(r_v->m_type)) {
ASR::ttype_t* array_type_asr = ASRUtils::type_get_past_array(r_v->m_type);
std::string array_type_name = CUtils::get_c_type_from_ttype_t(array_type_asr);
std::string array_encoded_type_name = ASRUtils::get_type_code(array_type_asr, true, false);
std::string return_type = c_ds_api->get_array_type(array_type_name, array_encoded_type_name, array_types_decls, true);
py_val_cnvrt = bind_py_utils_functions->get_conv_py_arr_to_c(return_type, array_type_name,
array_encoded_type_name) + "(pValue)";
ret_var_decl = indent + return_type + " _lpython_return_variable;";
} else {
if (ASRUtils::is_character(*r_v->m_type)) {
py_val_cnvrt = CUtils::get_py_obj_return_type_conv_func_from_ttype_t(r_v->m_type) + "(pValue)";
ret_var_decl = indent + CUtils::get_c_type_from_ttype_t(r_v->m_type) + " _lpython_return_variable;";
copy_result = indent + "_lpython_return_variable = _lfortran_str_copy(" + std::string(r_v->m_name) + ", 1, 0);";
}
}
} else {
py_val_cnvrt = CUtils::get_py_obj_return_type_conv_func_from_ttype_t(r_v->m_type) + "(pValue)";
ret_var_decl = indent + CUtils::get_c_type_from_ttype_t(r_v->m_type) + " _lpython_return_variable;";
}
std::string ret_assign = indent + std::string(r_v->m_name) + " = " + py_val_cnvrt + ";";
std::string ret_stmt = indent + "return _lpython_return_variable;";
std::string clear_pValue = indent + "Py_DECREF(pValue);";
return ret_var_decl + ret_assign + copy_result + clear_pValue + ret_stmt + "\n";
}

Expand Down
Loading

0 comments on commit c459bb5

Please sign in to comment.