Skip to content

Commit

Permalink
Merge pull request lcompilers#2337 from anutosh491/Fixing_symbolic_at…
Browse files Browse the repository at this point in the history
…tributes

Adding support for executing attribute/query calls without assigning to a prior variable
  • Loading branch information
certik committed Sep 27, 2023
2 parents 5b51c3c + a4e8f9c commit a872c0b
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 1 deletion.
14 changes: 13 additions & 1 deletion integration_tests/symbolics_05.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sympy import Symbol, expand, diff
from sympy import Symbol, expand, diff, sin, cos, exp, pi
from lpython import S

def test_operations():
Expand All @@ -21,4 +21,16 @@ def test_operations():
print(a.diff(x))
print(diff(b, x))

# test diff 2
c:S = sin(x)
d:S = cos(x)
assert(sin(Symbol("x")).diff(x) == d)
assert(sin(x).diff(Symbol("x")) == d)
assert(sin(x).diff(x) == d)
assert(sin(x).diff(x).diff(x) == S(-1)*c)
assert(sin(x).expand().diff(x).diff(x) == S(-1)*c)
assert((sin(x) + cos(x)).diff(x) == S(-1)*c + d)
assert((sin(x) + cos(x) + exp(x) + pi).diff(x).expand().diff(x) == exp(x) + S(-1)*c + S(-1)*d)


test_operations()
82 changes: 82 additions & 0 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7166,6 +7166,27 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
st = current_scope->get_symbol(call_name_store);
} else {
st = current_scope->resolve_symbol(mod_name);
std::set<std::string> symbolic_attributes = {
"diff", "expand"
};
std::set<std::string> symbolic_constants = {
"pi"
};
if (symbolic_attributes.find(call_name) != symbolic_attributes.end() &&
symbolic_constants.find(mod_name) != symbolic_constants.end()){
ASRUtils::create_intrinsic_function create_func;
create_func = ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function(mod_name);
Vec<ASR::expr_t*> eles; eles.reserve(al, args.size());
Vec<ASR::expr_t*> args_; args_.reserve(al, 1);
for (size_t i=0; i<args.size(); i++) {
eles.push_back(al, args[i].m_value);
}
tmp = create_func(al, at->base.base.loc, args_,
[&](const std::string &msg, const Location &loc) {
throw SemanticError(msg, loc); });
handle_symbolic_attribute(ASRUtils::EXPR(tmp), call_name, loc, eles);
return;
}
if (!st) {
throw SemanticError("NameError: '" + mod_name + "' is not defined", n->base.base.loc);
}
Expand Down Expand Up @@ -7220,6 +7241,32 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
ASR::expr_t* expr = ASR::down_cast<ASR::expr_t>(tmp);
handle_builtin_attribute(expr, at->m_attr, loc, eles);
return;
} else if (AST::is_a<AST::BinOp_t>(*at->m_value)) {
AST::BinOp_t* bop = AST::down_cast<AST::BinOp_t>(at->m_value);
std::set<std::string> symbolic_attributes = {
"diff", "expand"
};
if (symbolic_attributes.find(at->m_attr) != symbolic_attributes.end()){
switch (bop->m_op) {
case (AST::operatorType::Add) :
case (AST::operatorType::Sub) :
case (AST::operatorType::Mult) :
case (AST::operatorType::Div) :
case (AST::operatorType::Pow) : {
visit_BinOp(*bop);
Vec<ASR::expr_t*> eles;
eles.reserve(al, args.size());
for (size_t i=0; i<args.size(); i++) {
eles.push_back(al, args[i].m_value);
}
handle_symbolic_attribute(ASRUtils::EXPR(tmp), at->m_attr, loc, eles);
return;
}
default : {
throw SemanticError("Binary operator type not supported", loc);
}
}
}
} else if (AST::is_a<AST::ConstantInt_t>(*at->m_value)) {
if (std::string(at->m_attr) == std::string("bit_length")) {
//bit_length() attribute:
Expand All @@ -7241,6 +7288,41 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
std::string res = n->m_value;
handle_constant_string_attributes(res, args, at->m_attr, loc);
return;
} else if (AST::is_a<AST::Call_t>(*at->m_value)) {
AST::Call_t* call = AST::down_cast<AST::Call_t>(at->m_value);
std::set<std::string> symbolic_attributes = {
"diff", "expand"
};
if (symbolic_attributes.find(at->m_attr) != symbolic_attributes.end()){
std::set<std::string> symbolic_functions = {
"sin", "cos", "log", "exp", "Abs", "Symbol"
};
if (AST::is_a<AST::Attribute_t>(*call->m_func)) {
visit_Call(*call);
Vec<ASR::expr_t*> eles;
eles.reserve(al, args.size());
for (size_t i=0; i<args.size(); i++) {
eles.push_back(al, args[i].m_value);
}
handle_symbolic_attribute(ASRUtils::EXPR(tmp), at->m_attr, loc, eles);
return;
} else if (AST::is_a<AST::Name_t>(*call->m_func)) {
AST::Name_t *n = AST::down_cast<AST::Name_t>(call->m_func);
std::string call_name = n->m_id;
if (symbolic_functions.find(call_name) != symbolic_functions.end()) {
visit_Call(*call);
Vec<ASR::expr_t*> eles;
eles.reserve(al, args.size());
for (size_t i=0; i<args.size(); i++) {
eles.push_back(al, args[i].m_value);
}
handle_symbolic_attribute(ASRUtils::EXPR(tmp), at->m_attr, loc, eles);
return;
} else {
throw SemanticError(std::string(call_name) + " not supported in Call", loc);
}
}
}
} else {
throw SemanticError("Only Name type and constant integers supported in Call", loc);
}
Expand Down

0 comments on commit a872c0b

Please sign in to comment.