Skip to content

Commit

Permalink
Merge pull request lcompilers#1726 from Smit-create/i-1608
Browse files Browse the repository at this point in the history
Initial support for callbacks
  • Loading branch information
Smit-create committed Apr 24, 2023
2 parents 3d3a53c + e535783 commit 91bc4aa
Show file tree
Hide file tree
Showing 8 changed files with 739 additions and 14 deletions.
2 changes: 2 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,8 @@ RUN(NAME global_syms_04 LABELS cpython llvm c wasm wasm_x64)
RUN(NAME global_syms_05 LABELS cpython llvm c)
RUN(NAME global_syms_06 LABELS cpython llvm c)

RUN(NAME callback_01 LABELS cpython llvm)

# Intrinsic Functions
RUN(NAME intrinsics_01 LABELS cpython llvm) # any

Expand Down
26 changes: 26 additions & 0 deletions integration_tests/callback_01.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from lpython import i32, Callable


def f(x: i32) -> i32:
return x + 1

def f2(x: i32) -> i32:
return x + 10

def f3(x: i32) -> i32:
return f(x) + f2(x)


def g(func: Callable[[i32], i32], arg: i32) -> i32:
ret: i32
ret = func(arg)
return ret


def check():
assert g(f, 10) == 11
assert g(f2, 20) == 30
assert g(f3, 5) == 21


check()
26 changes: 26 additions & 0 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1551,6 +1551,11 @@ inline int extract_dimensions_from_ttype(ASR::ttype_t *x,
m_dims = nullptr;
break;
}
case ASR::ttypeType::FunctionType: {
n_dims = 0;
m_dims = nullptr;
break;
}
case ASR::ttypeType::Dict: {
n_dims = 0;
m_dims = nullptr;
Expand Down Expand Up @@ -2340,6 +2345,27 @@ inline bool check_equal_type(ASR::ttype_t* x, ASR::ttype_t* y) {
std::string left_param = left_tp->m_param;
std::string right_param = right_tp->m_param;
return left_param.compare(right_param) == 0;
} else if (ASR::is_a<ASR::FunctionType_t>(*x) && ASR::is_a<ASR::FunctionType_t>(*y)) {
ASR::FunctionType_t* left_ft = ASR::down_cast<ASR::FunctionType_t>(x);
ASR::FunctionType_t* right_ft = ASR::down_cast<ASR::FunctionType_t>(y);
if (left_ft->n_arg_types != right_ft->n_arg_types) {
return false;
}
bool result;
for (size_t i=0; i<left_ft->n_arg_types; i++) {
result = check_equal_type(left_ft->m_arg_types[i],
right_ft->m_arg_types[i]);
if (!result) return false;
}
if (left_ft->m_return_var_type == nullptr &&
right_ft->m_return_var_type == nullptr) {
return true;
} else if (left_ft->m_return_var_type != nullptr &&
right_ft->m_return_var_type != nullptr) {
return check_equal_type(left_ft->m_return_var_type,
right_ft->m_return_var_type);
}
return false;
}

return types_equal(x, y);
Expand Down
120 changes: 107 additions & 13 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1521,6 +1521,31 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_Tuple_t(al, loc,
types.p, types.size()));
return type;
} else if (var_annotation == "Callable") {
LCOMPILERS_ASSERT(AST::is_a<AST::Tuple_t>(*s->m_slice));
AST::Tuple_t *t = AST::down_cast<AST::Tuple_t>(s->m_slice);
LCOMPILERS_ASSERT(t->n_elts <= 2 && t->n_elts >= 1);
Vec<ASR::ttype_t*> arg_types;
LCOMPILERS_ASSERT(AST::is_a<AST::List_t>(*t->m_elts[0]));

AST::List_t *arg_list = AST::down_cast<AST::List_t>(t->m_elts[0]);
if (arg_list->n_elts > 0) {
arg_types.reserve(al, arg_list->n_elts);
for (size_t i=0; i<arg_list->n_elts; i++) {
arg_types.push_back(al, ast_expr_to_asr_type(loc, *arg_list->m_elts[i]));
}
} else {
arg_types.reserve(al, 1);
}
ASR::ttype_t* ret_type = nullptr;
if (t->n_elts == 2) {
ret_type = ast_expr_to_asr_type(loc, *t->m_elts[1]);
}
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_FunctionType_t(al, loc, arg_types.p,
arg_types.size(), ret_type, ASR::abiType::Source,
ASR::deftypeType::Interface, nullptr, false, false,
false, false, false, nullptr, 0, nullptr, 0, false));
return type;
} else if (var_annotation == "set") {
if (AST::is_a<AST::Name_t>(*s->m_slice)) {
ASR::ttype_t *type = ast_expr_to_asr_type(loc, *s->m_slice);
Expand Down Expand Up @@ -3380,6 +3405,69 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
tmp = tmp0;
}

ASR::symbol_t* create_implicit_interface_function(Location &loc, ASR::FunctionType_t *func, std::string func_name) {
SymbolTable *parent_scope = current_scope;
current_scope = al.make_new<SymbolTable>(parent_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, func->n_arg_types);
std::string sym_name = to_lower(func_name);
for (size_t i=0; i<func->n_arg_types; i++) {
std::string arg_name = sym_name + "_arg_" + std::to_string(i);
arg_name = to_lower(arg_name);
ASR::symbol_t *v;
SetChar variable_dependencies_vec;
variable_dependencies_vec.reserve(al, 1);
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec,
func->m_arg_types[i]);
v = ASR::down_cast<ASR::symbol_t>(
ASR::make_Variable_t(al, loc,
current_scope, s2c(al, arg_name), variable_dependencies_vec.p,
variable_dependencies_vec.size(), ASRUtils::intent_unspecified,
nullptr, nullptr, ASR::storage_typeType::Default, func->m_arg_types[i],
ASR::abiType::Source, ASR::Public, ASR::presenceType::Required,
false));
current_scope->add_symbol(arg_name, v);
LCOMPILERS_ASSERT(v != nullptr)
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc,
v)));
}

ASR::expr_t *to_return = nullptr;
if (func->m_return_var_type) {
std::string return_var_name = sym_name + "_return_var_name";
SetChar variable_dependencies_vec;
variable_dependencies_vec.reserve(al, 1);
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec,
func->m_return_var_type);
ASR::asr_t *return_var = ASR::make_Variable_t(al, loc,
current_scope, s2c(al, return_var_name), variable_dependencies_vec.p,
variable_dependencies_vec.size(), ASRUtils::intent_return_var,
nullptr, nullptr, ASR::storage_typeType::Default, func->m_return_var_type,
ASR::abiType::Source, ASR::Public, ASR::presenceType::Required,
false);
current_scope->add_symbol(return_var_name, ASR::down_cast<ASR::symbol_t>(return_var));
to_return = ASRUtils::EXPR(ASR::make_Var_t(al, loc,
ASR::down_cast<ASR::symbol_t>(return_var)));
}

tmp = ASRUtils::make_Function_t_util(
al, loc,
/* a_symtab */ current_scope,
/* a_name */ s2c(al, sym_name),
nullptr, 0,
/* a_args */ args.p,
/* n_args */ args.size(),
/* a_body */ nullptr,
/* n_body */ 0,
/* a_return_var */ to_return,
ASR::abiType::BindC, ASR::accessType::Public, ASR::deftypeType::Interface,
nullptr, false, false, false, false, false, /* a_type_parameters */ nullptr,
/* n_type_parameters */ 0, nullptr, 0, false, false, false);
current_scope = parent_scope;
return ASR::down_cast<ASR::symbol_t>(tmp);
}

void visit_FunctionDef(const AST::FunctionDef_t &x) {
dependencies.clear(al);
SymbolTable *parent_scope = current_scope;
Expand Down Expand Up @@ -3497,20 +3585,26 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
if (current_procedure_abi_type == ASR::abiType::BindC) {
value_attr = true;
}
SetChar variable_dependencies_vec;
variable_dependencies_vec.reserve(al, 1);
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec, arg_type, init_expr, value);
ASR::asr_t *v = ASR::make_Variable_t(al, loc, current_scope,
s2c(al, arg_s), variable_dependencies_vec.p,
variable_dependencies_vec.size(),
s_intent, init_expr, value, storage_type, arg_type,
current_procedure_abi_type, s_access, s_presence,
value_attr);
current_scope->add_symbol(arg_s, ASR::down_cast<ASR::symbol_t>(v));

ASR::symbol_t *var = current_scope->get_symbol(arg_s);
ASR::symbol_t *v;
if (ASR::is_a<ASR::FunctionType_t>(*arg_type)) {
ASR::FunctionType_t *func = ASR::down_cast<ASR::FunctionType_t>(arg_type);
v = create_implicit_interface_function(loc, func, arg_s);
} else {
SetChar variable_dependencies_vec;
variable_dependencies_vec.reserve(al, 1);
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec, arg_type, init_expr, value);
ASR::asr_t *_tmp = ASR::make_Variable_t(al, loc, current_scope,
s2c(al, arg_s), variable_dependencies_vec.p,
variable_dependencies_vec.size(),
s_intent, init_expr, value, storage_type, arg_type,
current_procedure_abi_type, s_access, s_presence,
value_attr);
v = ASR::down_cast<ASR::symbol_t>(_tmp);

}
current_scope->add_symbol(arg_s, v);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc,
var)));
v)));
}
ASR::accessType s_access = ASR::accessType::Public;
ASR::deftypeType deftype = ASR::deftypeType::Implementation;
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/lpython/lpython.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
__slots__ = ["i8", "i16", "i32", "i64", "f32", "f64", "c32", "c64", "CPtr",
"overload", "ccall", "TypeVar", "pointer", "c_p_pointer", "Pointer",
"p_c_pointer", "vectorize", "inline", "Union", "static", "with_goto",
"packed", "Const", "sizeof", "ccallable", "ccallback"]
"packed", "Const", "sizeof", "ccallable", "ccallback", "Callable"]

# data-types

Expand Down Expand Up @@ -55,6 +55,7 @@ def __init__(self, type, dims):
c64 = Type("c64")
CPtr = Type("c_ptr")
Const = ConstType("Const")
Callable = Type("Callable")
Union = ctypes.Union
Pointer = PointerType("Pointer")

Expand Down
13 changes: 13 additions & 0 deletions tests/reference/asr-callback_01-64f7a94.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "asr-callback_01-64f7a94",
"cmd": "lpython --show-asr --indent --no-color {infile} -o {outfile}",
"infile": "tests/../integration_tests/callback_01.py",
"infile_hash": "c3ab71a93f40edda000ae863149c38c388bb43a8329ebae9320a7ab4",
"outfile": null,
"outfile_hash": null,
"stdout": "asr-callback_01-64f7a94.stdout",
"stdout_hash": "0b2b8730f07fc9aad59a2c4f1dc9060bcd022d05fb5bf2f34e5f8b4b",
"stderr": null,
"stderr_hash": null,
"returncode": 0
}
Loading

0 comments on commit 91bc4aa

Please sign in to comment.