Skip to content

Commit

Permalink
Merge pull request lcompilers#1796 from Thirumalai-Shaktivel/lpython_…
Browse files Browse the repository at this point in the history
…decorator

Fixes for lpython decorator
  • Loading branch information
certik committed May 12, 2023
2 parents 20d9536 + 0154fc3 commit bb8bfe7
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 19 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ inst/bin/*
*_lines.dat.txt
*__tmp__generated__.c
visualize*.html
lpython_decorator*/
a.c
a.h
a.py
Expand Down
6 changes: 3 additions & 3 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ RUN(NAME callback_01 LABELS cpython llvm)
# Intrinsic Functions
RUN(NAME intrinsics_01 LABELS cpython llvm) # any

COMPILE(NAME import_order_01 LABELS cpython llvm c) # any
# lpython decorator
RUN(NAME lpython_decorator_01 LABELS cpython)

# Jit
RUN(NAME test_lpython_decorator LABELS cpython)
COMPILE(NAME import_order_01 LABELS cpython llvm c) # any
File renamed without changes.
4 changes: 2 additions & 2 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3603,8 +3603,8 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
is_inline = true;
} else if (name == "static") {
is_static = true;
} else if (name == "jit") {
throw SemanticError("`@lpython.jit` decorator must be "
} else if (name == "lpython") {
throw SemanticError("`@lpython` decorator must be "
"run from CPython, not compiled using LPython",
dec->base.loc);
} else {
Expand Down
36 changes: 22 additions & 14 deletions src/runtime/lpython/lpython.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import platform
from dataclasses import dataclass as py_dataclass, is_dataclass as py_is_dataclass
from goto import with_goto
from numpy import get_include
from distutils.sysconfig import get_python_inc

# TODO: this does not seem to restrict other imports
__slots__ = ["i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f32", "f64", "c32", "c64", "CPtr",
Expand Down Expand Up @@ -572,11 +570,13 @@ def get_data_type(t):
source_code = getsource(function)
source_code = source_code[source_code.find('\n'):]

# TODO: Create a filename based on the function name
# filename = function.__name__ + ".py"
dir_name = "./lpython_decorator_" + self.fn_name
if not os.path.exists(dir_name):
os.mkdir(dir_name)
filename = dir_name + "/" + self.fn_name

# Open the file for writing
with open("a.py", "w") as file:
with open(filename + ".py", "w") as file:
# Write the Python source code to the file
file.write("@ccallable")
file.write(source_code)
Expand Down Expand Up @@ -682,7 +682,7 @@ def get_data_type(t):
#include <numpy/ndarrayobject.h>
// LPython generated C code
#include "a.h"
#include "{self.fn_name}.h"
// Define the Python module and method mappings
static PyObject* define_module(PyObject* self, PyObject* args) {{
Expand All @@ -700,13 +700,13 @@ def get_data_type(t):
// Define the module initialization function
static struct PyModuleDef module_def = {{
PyModuleDef_HEAD_INIT,
"lpython_jit_module",
"lpython_module_{self.fn_name}",
"Shared library to use LPython generated functions",
-1,
module_methods
}};
PyMODINIT_FUNC PyInit_lpython_jit_module(void) {{
PyMODINIT_FUNC PyInit_lpython_module_{self.fn_name}(void) {{
PyObject* module;
// Create the module object
Expand All @@ -720,33 +720,41 @@ def get_data_type(t):
"""
# ----------------------------------------------------------------------
# Write the C source code to the file
with open("a.c", "w") as file:
with open(filename + ".c", "w") as file:
file.write(template)

# ----------------------------------------------------------------------
# Generate the Shared library
# TODO: Use LLVM instead of C backend
r = os.system("lpython --show-c --disable-main a.py > a.h")
r = os.system("lpython --show-c --disable-main "
+ filename + ".py > " + filename + ".h")
assert r == 0, "Failed to create C file"

gcc_flags = ""
if platform.system() == "Linux":
gcc_flags = " -shared -fPIC "
elif platform.system() == "Darwin":
gcc_flags = " -bundle -flat_namespace -undefined suppress "
else:
raise NotImplementedError("Platform not implemented")

from numpy import get_include
from distutils.sysconfig import get_python_inc, get_python_lib
python_path = "-I" + get_python_inc() + " "
numpy_path = "-I" + get_include()
numpy_path = "-I" + get_include() + " "
rt_path_01 = "-I" + get_rtlib_dir() + "/../libasr/runtime "
rt_path_02 = "-L" + get_rtlib_dir() + " -Wl,-rpath " \
+ get_rtlib_dir() + " -llpython_runtime "
python_lib = "-L" "$CONDA_PREFIX/lib/ -lpython3.10 -lm"
python_lib = "-L" + get_python_lib() + "/../.. -lpython3.10 -lm"

r = os.system("gcc -g" + gcc_flags + python_path + numpy_path +
" a.c -o lpython_jit_module.so " + rt_path_01 + rt_path_02 + python_lib)
filename + ".c -o lpython_module_" + self.fn_name + ".so " +
rt_path_01 + rt_path_02 + python_lib)
assert r == 0, "Failed to create the shared library"

def __call__(self, *args, **kwargs):
import sys; sys.path.append('.')
# import the symbol from the shared library
function = getattr(__import__("lpython_jit_module"), self.fn_name)
function = getattr(__import__("lpython_module_" + self.fn_name),
self.fn_name)
return function(*args, **kwargs)

0 comments on commit bb8bfe7

Please sign in to comment.