Skip to content

Commit a872c0b

Browse files
authored
Merge pull request lcompilers#2337 from anutosh491/Fixing_symbolic_attributes
Adding support for executing attribute/query calls without assigning to a prior variable
2 parents 5b51c3c + a4e8f9c commit a872c0b

2 files changed

Lines changed: 95 additions & 1 deletion

File tree

integration_tests/symbolics_05.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sympy import Symbol, expand, diff
1+
from sympy import Symbol, expand, diff, sin, cos, exp, pi
22
from lpython import S
33

44
def test_operations():
@@ -21,4 +21,16 @@ def test_operations():
2121
print(a.diff(x))
2222
print(diff(b, x))
2323

24+
# test diff 2
25+
c:S = sin(x)
26+
d:S = cos(x)
27+
assert(sin(Symbol("x")).diff(x) == d)
28+
assert(sin(x).diff(Symbol("x")) == d)
29+
assert(sin(x).diff(x) == d)
30+
assert(sin(x).diff(x).diff(x) == S(-1)*c)
31+
assert(sin(x).expand().diff(x).diff(x) == S(-1)*c)
32+
assert((sin(x) + cos(x)).diff(x) == S(-1)*c + d)
33+
assert((sin(x) + cos(x) + exp(x) + pi).diff(x).expand().diff(x) == exp(x) + S(-1)*c + S(-1)*d)
34+
35+
2436
test_operations()

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7166,6 +7166,27 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
71667166
st = current_scope->get_symbol(call_name_store);
71677167
} else {
71687168
st = current_scope->resolve_symbol(mod_name);
7169+
std::set<std::string> symbolic_attributes = {
7170+
"diff", "expand"
7171+
};
7172+
std::set<std::string> symbolic_constants = {
7173+
"pi"
7174+
};
7175+
if (symbolic_attributes.find(call_name) != symbolic_attributes.end() &&
7176+
symbolic_constants.find(mod_name) != symbolic_constants.end()){
7177+
ASRUtils::create_intrinsic_function create_func;
7178+
create_func = ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function(mod_name);
7179+
Vec<ASR::expr_t*> eles; eles.reserve(al, args.size());
7180+
Vec<ASR::expr_t*> args_; args_.reserve(al, 1);
7181+
for (size_t i=0; i<args.size(); i++) {
7182+
eles.push_back(al, args[i].m_value);
7183+
}
7184+
tmp = create_func(al, at->base.base.loc, args_,
7185+
[&](const std::string &msg, const Location &loc) {
7186+
throw SemanticError(msg, loc); });
7187+
handle_symbolic_attribute(ASRUtils::EXPR(tmp), call_name, loc, eles);
7188+
return;
7189+
}
71697190
if (!st) {
71707191
throw SemanticError("NameError: '" + mod_name + "' is not defined", n->base.base.loc);
71717192
}
@@ -7220,6 +7241,32 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
72207241
ASR::expr_t* expr = ASR::down_cast<ASR::expr_t>(tmp);
72217242
handle_builtin_attribute(expr, at->m_attr, loc, eles);
72227243
return;
7244+
} else if (AST::is_a<AST::BinOp_t>(*at->m_value)) {
7245+
AST::BinOp_t* bop = AST::down_cast<AST::BinOp_t>(at->m_value);
7246+
std::set<std::string> symbolic_attributes = {
7247+
"diff", "expand"
7248+
};
7249+
if (symbolic_attributes.find(at->m_attr) != symbolic_attributes.end()){
7250+
switch (bop->m_op) {
7251+
case (AST::operatorType::Add) :
7252+
case (AST::operatorType::Sub) :
7253+
case (AST::operatorType::Mult) :
7254+
case (AST::operatorType::Div) :
7255+
case (AST::operatorType::Pow) : {
7256+
visit_BinOp(*bop);
7257+
Vec<ASR::expr_t*> eles;
7258+
eles.reserve(al, args.size());
7259+
for (size_t i=0; i<args.size(); i++) {
7260+
eles.push_back(al, args[i].m_value);
7261+
}
7262+
handle_symbolic_attribute(ASRUtils::EXPR(tmp), at->m_attr, loc, eles);
7263+
return;
7264+
}
7265+
default : {
7266+
throw SemanticError("Binary operator type not supported", loc);
7267+
}
7268+
}
7269+
}
72237270
} else if (AST::is_a<AST::ConstantInt_t>(*at->m_value)) {
72247271
if (std::string(at->m_attr) == std::string("bit_length")) {
72257272
//bit_length() attribute:
@@ -7241,6 +7288,41 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
72417288
std::string res = n->m_value;
72427289
handle_constant_string_attributes(res, args, at->m_attr, loc);
72437290
return;
7291+
} else if (AST::is_a<AST::Call_t>(*at->m_value)) {
7292+
AST::Call_t* call = AST::down_cast<AST::Call_t>(at->m_value);
7293+
std::set<std::string> symbolic_attributes = {
7294+
"diff", "expand"
7295+
};
7296+
if (symbolic_attributes.find(at->m_attr) != symbolic_attributes.end()){
7297+
std::set<std::string> symbolic_functions = {
7298+
"sin", "cos", "log", "exp", "Abs", "Symbol"
7299+
};
7300+
if (AST::is_a<AST::Attribute_t>(*call->m_func)) {
7301+
visit_Call(*call);
7302+
Vec<ASR::expr_t*> eles;
7303+
eles.reserve(al, args.size());
7304+
for (size_t i=0; i<args.size(); i++) {
7305+
eles.push_back(al, args[i].m_value);
7306+
}
7307+
handle_symbolic_attribute(ASRUtils::EXPR(tmp), at->m_attr, loc, eles);
7308+
return;
7309+
} else if (AST::is_a<AST::Name_t>(*call->m_func)) {
7310+
AST::Name_t *n = AST::down_cast<AST::Name_t>(call->m_func);
7311+
std::string call_name = n->m_id;
7312+
if (symbolic_functions.find(call_name) != symbolic_functions.end()) {
7313+
visit_Call(*call);
7314+
Vec<ASR::expr_t*> eles;
7315+
eles.reserve(al, args.size());
7316+
for (size_t i=0; i<args.size(); i++) {
7317+
eles.push_back(al, args[i].m_value);
7318+
}
7319+
handle_symbolic_attribute(ASRUtils::EXPR(tmp), at->m_attr, loc, eles);
7320+
return;
7321+
} else {
7322+
throw SemanticError(std::string(call_name) + " not supported in Call", loc);
7323+
}
7324+
}
7325+
}
72447326
} else {
72457327
throw SemanticError("Only Name type and constant integers supported in Call", loc);
72467328
}

0 commit comments

Comments
 (0)