Skip to content

Commit

Permalink
Merge pull request #2260 from anutosh491/GSoC_PR_12
Browse files Browse the repository at this point in the history
Removed Redundant symbolic support through the C backend
  • Loading branch information
certik committed Aug 9, 2023
2 parents 320841b + a554bbd commit 818e2cd
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 242 deletions.
25 changes: 1 addition & 24 deletions src/libasr/codegen/asr_to_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1045,11 +1045,7 @@ R"( // Initialise Numpy
bracket_open++;
visit_expr(*x.m_test);
std::string test_condition = src;
if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test)){
out = symengine_src;
symengine_src = "";
out += indent;
}

if (x.m_msg) {
this->visit_expr(*x.m_msg);
std::string tmp_gen = "";
Expand All @@ -1065,19 +1061,10 @@ R"( // Initialise Numpy
if( ASRUtils::is_array(value_type) ) {
src += "->data";
}
if(ASR::is_a<ASR::SymbolicExpression_t>(*value_type)) {
src += symengine_src;
symengine_src = "";
}
if (ASR::is_a<ASR::Complex_t>(*value_type)) {
tmp_gen += "creal(" + src + ")";
tmp_gen += ", ";
tmp_gen += "cimag(" + src + ")";
} else if(ASR::is_a<ASR::SymbolicExpression_t>(*value_type)){
tmp_gen += "basic_str(" + src + ")";
if(ASR::is_a<ASR::Var_t>(*x.m_msg)) {
symengine_queue.pop();
}
} else {
tmp_gen += src;
}
Expand Down Expand Up @@ -1152,10 +1139,6 @@ R"( // Initialise Numpy
if( ASRUtils::is_array(value_type) ) {
src += "->data";
}
if(ASR::is_a<ASR::SymbolicExpression_t>(*value_type)) {
out += symengine_src;
symengine_src = "";
}
if( ASR::is_a<ASR::List_t>(*value_type) ||
ASR::is_a<ASR::Tuple_t>(*value_type)) {
tmp_gen += "\"";
Expand All @@ -1178,12 +1161,6 @@ R"( // Initialise Numpy
v.pop_back();
v.push_back("creal(" + src + ")");
v.push_back("cimag(" + src + ")");
} else if(ASR::is_a<ASR::SymbolicExpression_t>(*value_type)){
v.pop_back();
v.push_back("basic_str(" + src + ")");
if(ASR::is_a<ASR::Var_t>(*x.m_values[i])) {
symengine_queue.pop();
}
}
if (i+1!=x.n_values) {
tmp_gen += "\%s";
Expand Down
218 changes: 0 additions & 218 deletions src/libasr/codegen/asr_to_c_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include <iostream>
#include <memory>
#include <set>
#include <unordered_set>

#include <libasr/asr.h>
#include <libasr/containers.h>
Expand Down Expand Up @@ -84,36 +83,6 @@ struct CPPDeclarationOptions: public DeclarationOptions {
}
};

class SymEngineQueue {
public:
std::vector<std::string> queue;
int queue_front = -1;
std::string& symengine_src;
std::unordered_set<std::string> variables_to_free;

SymEngineQueue(std::string& symengine_src) : symengine_src(symengine_src) {}

std::string push() {
std::string indent(4, ' ');
std::string var;
if(queue_front == -1 || queue_front >= static_cast<int>(queue.size())) {
var = "queue" + std::to_string(queue.size());
queue.push_back(var);
if(queue_front == -1) queue_front++;
symengine_src = indent + "basic " + var + ";\n";
symengine_src += indent + "basic_new_stack(" + var + ");\n";
}
variables_to_free.insert(queue[queue_front]);
return queue[queue_front++];
}

void pop() {
LCOMPILERS_ASSERT(queue_front != -1 && queue_front < static_cast<int>(queue.size()));
variables_to_free.insert(queue[queue_front]);
queue_front++;
}
};

template <class Struct>
class BaseCCPPVisitor : public ASR::BaseVisitor<Struct>
{
Expand Down Expand Up @@ -147,8 +116,6 @@ class BaseCCPPVisitor : public ASR::BaseVisitor<Struct>
bool is_c;
std::set<std::string> headers, user_headers, user_defines;
std::vector<std::string> tmp_buffer_src;
std::string symengine_src;
SymEngineQueue symengine_queue{symengine_src};

SymbolTable* global_scope;
int64_t lower_bound;
Expand Down Expand Up @@ -512,8 +479,6 @@ R"(#include <stdio.h>
}
}
}
} else if (ASR::is_a<ASR::SymbolicExpression_t>(*return_var->m_type)) {
sub = "basic ";
} else if (ASR::is_a<ASR::CPtr_t>(*return_var->m_type)) {
sub = "void* ";
} else if (ASR::is_a<ASR::List_t>(*return_var->m_type)) {
Expand Down Expand Up @@ -841,10 +806,6 @@ R"(#include <stdio.h>
if (v->m_intent == ASRUtils::intent_local ||
v->m_intent == ASRUtils::intent_return_var) {
std::string d = indent + self().convert_variable_decl(*v) + ";\n";
if (ASR::is_a<ASR::SymbolicExpression_t>(*v->m_type)) {
std::string v_m_name = v->m_name;
d += indent + "basic_new_stack(" + v_m_name + ");\n";
}
decl += check_tmp_buffer() + d;
}
if (ASR::is_a<ASR::TypeParameter_t>(*v->m_type)) {
Expand Down Expand Up @@ -879,10 +840,6 @@ R"(#include <stdio.h>
+ ";\n";
}

for (const auto& var : symengine_queue.variables_to_free) {
current_body += indent + "basic_free_stack(" + var + ");\n";
}
symengine_queue.variables_to_free.clear();
if (decl.size() > 0 || current_body.size() > 0) {
sub += "{\n" + decl + current_body + "}\n";
} else {
Expand Down Expand Up @@ -1337,17 +1294,6 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
target = "&" + target;
}
}
if( ASR::is_a<ASR::SymbolicExpression_t>(*value_type) ) {
if(ASR::is_a<ASR::Var_t>(*x.m_value)){
src = indent + "basic_assign(" + target + ", " + value + ");\n";
symengine_queue.pop();
symengine_queue.pop();
return;
}
src = symengine_src;
symengine_src = "";
return;
}
if( !from_std_vector_helper.empty() ) {
src = from_std_vector_helper;
} else {
Expand Down Expand Up @@ -1825,15 +1771,6 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
src = std::string(ASR::down_cast<ASR::Variable_t>(s)->m_name);
}
last_expr_precedence = 2;
ASR::ttype_t* var_type = sv->m_type;
if( ASR::is_a<ASR::SymbolicExpression_t>(*var_type)) {
std::string var_name = std::string(ASR::down_cast<ASR::Variable_t>(s)->m_name);
symengine_queue.queue.push_back(var_name);
if (symengine_queue.queue_front == -1) {
symengine_queue.queue_front = 0;
}
symengine_src = "";
}
}

void visit_StructInstanceMember(const ASR::StructInstanceMember_t& x) {
Expand Down Expand Up @@ -2048,11 +1985,6 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
last_expr_precedence = 2;
break;
}
case (ASR::cast_kindType::IntegerToSymbolicExpression): {
self().visit_expr(*x.m_value);
last_expr_precedence = 2;
break;
}
default : throw CodeGenError("Cast kind " + std::to_string(x.m_kind) + " not implemented",
x.base.base.loc);
}
Expand Down Expand Up @@ -2100,40 +2032,6 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
handle_Compare(x);
}

void visit_SymbolicCompare(const ASR::SymbolicCompare_t &x) {
CHECK_FAST_C_CPP(compiler_options, x)
self().visit_expr(*x.m_left);
std::string left_src = symengine_src;
if(ASR::is_a<ASR::Var_t>(*x.m_left)){
symengine_queue.pop();
}
std::string left = std::move(src);

self().visit_expr(*x.m_right);
std::string right_src = symengine_src;
if(ASR::is_a<ASR::Var_t>(*x.m_right)){
symengine_queue.pop();
}
std::string right = std::move(src);
std::string op_str = ASRUtils::cmpop_to_str(x.m_op);
switch (x.m_op) {
case (ASR::cmpopType::Eq) : {
src = "basic_eq(" + left + ", " + right + ") " + op_str + " 1";
break;
}
case (ASR::cmpopType::NotEq) : {
src = "basic_neq(" + left + ", " + right + ") " + op_str + " 0";
break;
}
default : {
throw LCompilersException("Symbolic comparison operator: '"
+ op_str
+ "' is not implemented");
}
}
symengine_src = left_src + right_src;
}

template<typename T>
void handle_Compare(const T &x) {
CHECK_FAST_C_CPP(compiler_options, x)
Expand Down Expand Up @@ -2846,48 +2744,6 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
out += func_name; break; \
}

std::string performBinarySymbolicOperation(const std::string& functionName, const ASR::IntrinsicFunction_t& x) {
headers.insert("symengine/cwrapper.h");
std::string indent(4, ' ');
LCOMPILERS_ASSERT(x.n_args == 2);
std::string target = symengine_queue.push();
std::string target_src = symengine_src;
this->visit_expr(*x.m_args[0]);
std::string arg1 = src;
std::string arg1_src = symengine_src;
// Check if x.m_args[0] is a Var
if (ASR::is_a<ASR::Var_t>(*x.m_args[0])) {
symengine_queue.pop();
}
this->visit_expr(*x.m_args[1]);
std::string arg2 = src;
std::string arg2_src = symengine_src;
// Check if x.m_args[0] is a Var
if (ASR::is_a<ASR::Var_t>(*x.m_args[1])) {
symengine_queue.pop();
}
symengine_src = target_src + arg1_src + arg2_src;
symengine_src += indent + functionName + "(" + target + ", " + arg1 + ", " + arg2 + ");\n";
return target;
}

std::string performUnarySymbolicOperation(const std::string& functionName, const ASR::IntrinsicFunction_t& x) {
headers.insert("symengine/cwrapper.h");
std::string indent(4, ' ');
LCOMPILERS_ASSERT(x.n_args == 1);
std::string target = symengine_queue.push();
std::string target_src = symengine_src;
this->visit_expr(*x.m_args[0]);
std::string arg1 = src;
std::string arg1_src = symengine_src;
if (ASR::is_a<ASR::Var_t>(*x.m_args[0])) {
symengine_queue.pop();
}
symengine_src = target_src + arg1_src;
symengine_src += indent + functionName + "(" + target + ", " + arg1 + ");\n";
return target;
}

void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t &x) {
std::string out;
std::string indent(4, ' ');
Expand All @@ -2905,80 +2761,6 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
SET_INTRINSIC_NAME(Exp, "exp");
SET_INTRINSIC_NAME(Exp2, "exp2");
SET_INTRINSIC_NAME(Expm1, "expm1");
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd)): {
src = performBinarySymbolicOperation("basic_add", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSub)): {
src = performBinarySymbolicOperation("basic_sub", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicMul)): {
src = performBinarySymbolicOperation("basic_mul", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiv)): {
src = performBinarySymbolicOperation("basic_div", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPow)): {
src = performBinarySymbolicOperation("basic_pow", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiff)): {
src = performBinarySymbolicOperation("basic_diff", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSin)): {
src = performUnarySymbolicOperation("basic_sin", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicCos)): {
src = performUnarySymbolicOperation("basic_cos", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicLog)): {
src = performUnarySymbolicOperation("basic_log", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExp)): {
src = performUnarySymbolicOperation("basic_exp", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAbs)): {
src = performUnarySymbolicOperation("basic_abs", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExpand)): {
src = performUnarySymbolicOperation("basic_expand", x);
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPi)): {
headers.insert("symengine/cwrapper.h");
LCOMPILERS_ASSERT(x.n_args == 0);
std::string target = symengine_queue.push();
symengine_src += indent + "basic_const_pi(" + target + ");\n";
src = target;
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSymbol)): {
headers.insert("symengine/cwrapper.h");
LCOMPILERS_ASSERT(x.n_args == 1);
this->visit_expr(*x.m_args[0]);
std::string target = symengine_queue.push();
symengine_src += indent + "symbol_set(" + target + ", " + src + ");\n";
src = target;
return;
}
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicInteger)): {
headers.insert("symengine/cwrapper.h");
LCOMPILERS_ASSERT(x.n_args == 1);
this->visit_expr(*x.m_args[0]);
std::string target = symengine_queue.push();
symengine_src += indent + "integer_set_si(" + target + ", " + src + ");\n";
src = target;
return;
}
default : {
throw LCompilersException("IntrinsicFunction: `"
+ ASRUtils::get_intrinsic_name(x.m_intrinsic_id)
Expand Down

0 comments on commit 818e2cd

Please sign in to comment.