Skip to content

Commit

Permalink
Added support for casting within visit_SubroutineCall
Browse files Browse the repository at this point in the history
  • Loading branch information
anutosh491 authored and certik committed Sep 24, 2023
1 parent 46df3e5 commit 37cc6eb
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
2 changes: 1 addition & 1 deletion integration_tests/symbolics_09.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def call_addInteger():
e: S = cos(b)
addInteger(c, d, e, 2)
addInteger(c, sin(a), cos(b), 2)
addInteger(c, sin(Symbol("x")), cos(Symbol("y")), 2)
addInteger(pi, sin(Symbol("x")), cos(Symbol("y")), 2)

def main0():
call_addInteger()
Expand Down
14 changes: 12 additions & 2 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,17 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg));
process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target);

ASR::call_arg_t call_arg;
call_arg.loc = x.base.base.loc;
call_arg.m_value = target;
call_args.push_back(al, call_arg);
} else if (ASR::is_a<ASR::Cast_t>(*val)) {
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(val);
if(cast_t->m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return;
this->visit_Cast(*cast_t);
ASR::symbol_t *var_sym = current_scope->get_symbol(symengine_stack.pop());
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym));

ASR::call_arg_t call_arg;
call_arg.loc = x.base.base.loc;
call_arg.m_value = target;
Expand Down Expand Up @@ -793,9 +804,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
} else if (ASR::is_a<ASR::Cast_t>(*val)) {
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(val);
if(cast_t->m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return;
ASR::symbol_t *var_sym = nullptr;
this->visit_Cast(*cast_t);
var_sym = current_scope->get_symbol(symengine_stack.pop());
ASR::symbol_t *var_sym = current_scope->get_symbol(symengine_stack.pop());
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym));

// Now create the FunctionCall node for basic_str
Expand Down

0 comments on commit 37cc6eb

Please sign in to comment.