Skip to content

Commit

Permalink
Implemented symbolic constant E
Browse files Browse the repository at this point in the history
  • Loading branch information
anutosh491 committed Oct 22, 2023
1 parent f8fbd8d commit a67a76d
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 70 deletions.
57 changes: 34 additions & 23 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ enum class IntrinsicScalarFunctions : int64_t {
SymbolicDiv,
SymbolicPow,
SymbolicPi,
SymbolicE,
SymbolicInteger,
SymbolicDiff,
SymbolicExpand,
Expand Down Expand Up @@ -135,6 +136,7 @@ inline std::string get_intrinsic_name(int x) {
INTRINSIC_NAME_CASE(SymbolicDiv)
INTRINSIC_NAME_CASE(SymbolicPow)
INTRINSIC_NAME_CASE(SymbolicPi)
INTRINSIC_NAME_CASE(SymbolicE)
INTRINSIC_NAME_CASE(SymbolicInteger)
INTRINSIC_NAME_CASE(SymbolicDiff)
INTRINSIC_NAME_CASE(SymbolicExpand)
Expand Down Expand Up @@ -3004,30 +3006,34 @@ create_symbolic_binary_macro(SymbolicDiv)
create_symbolic_binary_macro(SymbolicPow)
create_symbolic_binary_macro(SymbolicDiff)

namespace SymbolicPi {

static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, diag::Diagnostics& diagnostics) {
ASRUtils::require_impl(x.n_args == 0, "SymbolicPi does not take arguments",
x.base.base.loc, diagnostics);
}

static inline ASR::expr_t *eval_SymbolicPi(Allocator &/*al*/,
const Location &/*loc*/, ASR::ttype_t *, Vec<ASR::expr_t*>& /*args*/) {
// TODO
return nullptr;
}

static inline ASR::asr_t* create_SymbolicPi(Allocator& al, const Location& loc,
Vec<ASR::expr_t*>& args,
const std::function<void (const std::string &, const Location &)> /*err*/) {
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc));
ASR::expr_t* compile_time_value = eval_SymbolicPi(al, loc, to_type, args);
return ASR::make_IntrinsicScalarFunction_t(al, loc,
static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicPi),
nullptr, 0, 0, to_type, compile_time_value);
}
#define create_symbolic_constants_macro(X) \
namespace X { \
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \
diag::Diagnostics& diagnostics) { \
const Location& loc = x.base.base.loc; \
ASRUtils::require_impl(x.n_args == 0, \
#X " does not take arguments", loc, diagnostics); \
} \
\
static inline ASR::expr_t* eval_##X(Allocator &/*al*/, const Location &/*loc*/, \
ASR::ttype_t *, Vec<ASR::expr_t*> &/*args*/) { \
/*TODO*/ \
return nullptr; \
} \
\
static inline ASR::asr_t* create_##X(Allocator& al, const Location& loc, \
Vec<ASR::expr_t*>& args, \
const std::function<void (const std::string &, const Location &)> /*err*/) { \
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc)); \
ASR::expr_t* compile_time_value = eval_##X(al, loc, to_type, args); \
return ASR::make_IntrinsicScalarFunction_t(al, loc, \
static_cast<int64_t>(IntrinsicScalarFunctions::X), \
nullptr, 0, 0, to_type, compile_time_value); \
} \
} // namespace X

} // namespace SymbolicPi
create_symbolic_constants_macro(SymbolicPi)
create_symbolic_constants_macro(SymbolicE)

namespace SymbolicInteger {

Expand Down Expand Up @@ -3286,6 +3292,8 @@ namespace IntrinsicScalarFunctionRegistry {
{nullptr, &SymbolicPow::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicPi),
{nullptr, &SymbolicPi::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicE),
{nullptr, &SymbolicE::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicInteger),
{nullptr, &SymbolicInteger::verify_args}},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicDiff),
Expand Down Expand Up @@ -3398,6 +3406,8 @@ namespace IntrinsicScalarFunctionRegistry {
"SymbolicPow"},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicPi),
"pi"},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicE),
"E"},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicInteger),
"SymbolicInteger"},
{static_cast<int64_t>(IntrinsicScalarFunctions::SymbolicDiff),
Expand Down Expand Up @@ -3470,6 +3480,7 @@ namespace IntrinsicScalarFunctionRegistry {
{"SymbolicDiv", {&SymbolicDiv::create_SymbolicDiv, &SymbolicDiv::eval_SymbolicDiv}},
{"SymbolicPow", {&SymbolicPow::create_SymbolicPow, &SymbolicPow::eval_SymbolicPow}},
{"pi", {&SymbolicPi::create_SymbolicPi, &SymbolicPi::eval_SymbolicPi}},
{"E", {&SymbolicE::create_SymbolicE, &SymbolicE::eval_SymbolicE}},
{"SymbolicInteger", {&SymbolicInteger::create_SymbolicInteger, &SymbolicInteger::eval_SymbolicInteger}},
{"diff", {&SymbolicDiff::create_SymbolicDiff, &SymbolicDiff::eval_SymbolicDiff}},
{"expand", {&SymbolicExpand::create_SymbolicExpand, &SymbolicExpand::eval_SymbolicExpand}},
Expand Down
102 changes: 57 additions & 45 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,50 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
pass_result.push_back(al, stmt);
}

void perform_symbolic_constant_operation(Allocator &al, const Location &loc, SymbolTable* module_scope,
const std::string& new_name, ASR::expr_t* value) {
symbolic_dependencies.push_back(new_name);
if (!module_scope->get_symbol(new_name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 1);
ASR::symbol_t* arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, new_name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* new_symbol = ASR::down_cast<ASR::symbol_t>(new_subrout);
module_scope->add_symbol(s2c(al, new_name), new_symbol);
}

ASR::symbol_t* func_sym = module_scope->get_symbol(new_name);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = value;
call_args.push_back(al, call_arg);

ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, func_sym,
func_sym, call_args.p, call_args.n, nullptr));
pass_result.push_back(al, stmt);
}

ASR::expr_t* handle_argument(Allocator &al, const Location &loc, ASR::expr_t* arg) {
if (ASR::is_a<ASR::Var_t>(*arg)) {
return arg;
Expand Down Expand Up @@ -397,55 +441,15 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
perform_symbolic_unary_operation(al, loc, module_scope, new_name, target, value1);
}

void process_constants(Allocator &al, const Location &loc, ASR::IntrinsicScalarFunction_t* /*x*/, SymbolTable* module_scope,
const std::string& new_name, ASR::expr_t* target) {
perform_symbolic_constant_operation(al, loc, module_scope, new_name, target);
}

void process_intrinsic_function(Allocator &al, const Location &loc, ASR::IntrinsicScalarFunction_t* x, SymbolTable* module_scope,
ASR::expr_t* target){
int64_t intrinsic_id = x->m_intrinsic_id;
switch (static_cast<LCompilers::ASRUtils::IntrinsicScalarFunctions>(intrinsic_id)) {
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPi: {
std::string new_name = "basic_const_pi";
symbolic_dependencies.push_back(new_name);
if (!module_scope->get_symbol(new_name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 1);
ASR::symbol_t* arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, new_name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* new_symbol = ASR::down_cast<ASR::symbol_t>(new_subrout);
module_scope->add_symbol(s2c(al, new_name), new_symbol);
}

// Create the function call statement for basic_const_pi
ASR::symbol_t* basic_const_pi_sym = module_scope->get_symbol(new_name);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = target;
call_args.push_back(al, call_arg);

ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_const_pi_sym,
basic_const_pi_sym, call_args.p, call_args.n, nullptr));
pass_result.push_back(al, stmt);
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicSymbol: {
std::string new_name = "symbol_set";
symbolic_dependencies.push_back(new_name);
Expand Down Expand Up @@ -499,6 +503,14 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
pass_result.push_back(al, stmt);
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPi: {
process_constants(al, loc, x, module_scope, "basic_const_pi", target);
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicE: {
process_constants(al, loc, x, module_scope, "basic_const_E", target);
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicAdd: {
process_binary_operator(al, loc, x, module_scope, "basic_add", target);
break;
Expand Down
4 changes: 2 additions & 2 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3172,7 +3172,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
std::string name = x.m_id;
ASR::symbol_t *s = current_scope->resolve_symbol(name);
std::set<std::string> not_cpython_builtin = {
"pi"};
"pi", "E"};
if (s) {
tmp = ASR::make_Var_t(al, x.base.base.loc, s);
} else if (name == "i32" || name == "i64" || name == "f32" ||
Expand Down Expand Up @@ -7116,7 +7116,7 @@ we will have to use something else.
"diff", "expand", "has"
};
std::set<std::string> symbolic_constants = {
"pi"
"pi", "E"
};
if (symbolic_attributes.find(call_name) != symbolic_attributes.end() &&
symbolic_constants.find(mod_name) != symbolic_constants.end()){
Expand Down

0 comments on commit a67a76d

Please sign in to comment.