Skip to content

Add attribute to Dummy: dummy_index #517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion symengine/lib/symengine.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
RCP[const Rational] rcp_static_cast_Rational "SymEngine::rcp_static_cast<const SymEngine::Rational>"(rcp_const_basic &b) nogil
RCP[const Complex] rcp_static_cast_Complex "SymEngine::rcp_static_cast<const SymEngine::Complex>"(rcp_const_basic &b) nogil
RCP[const Number] rcp_static_cast_Number "SymEngine::rcp_static_cast<const SymEngine::Number>"(rcp_const_basic &b) nogil
RCP[const Dummy] rcp_static_cast_Dummy "SymEngine::rcp_static_cast<const SymEngine::Dummy>"(rcp_const_basic &b) nogil
RCP[const Add] rcp_static_cast_Add "SymEngine::rcp_static_cast<const SymEngine::Add>"(rcp_const_basic &b) nogil
RCP[const Mul] rcp_static_cast_Mul "SymEngine::rcp_static_cast<const SymEngine::Mul>"(rcp_const_basic &b) nogil
RCP[const Pow] rcp_static_cast_Pow "SymEngine::rcp_static_cast<const SymEngine::Pow>"(rcp_const_basic &b) nogil
Expand Down Expand Up @@ -180,7 +181,7 @@ cdef extern from "<symengine/symbol.h>" namespace "SymEngine":
Symbol(string name) nogil
string get_name() nogil
cdef cppclass Dummy(Symbol):
pass
size_t get_index()

cdef extern from "<symengine/number.h>" namespace "SymEngine":
cdef cppclass Number(Basic):
Expand Down Expand Up @@ -322,6 +323,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
rcp_const_basic make_rcp_Symbol "SymEngine::make_rcp<const SymEngine::Symbol>"(string name) nogil
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"() nogil
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"(string name) nogil
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"(string &name, size_t index) nogil
rcp_const_basic make_rcp_PySymbol "SymEngine::make_rcp<const SymEngine::PySymbol>"(string name, PyObject * pyobj, bool use_pickle) except +
rcp_const_basic make_rcp_Constant "SymEngine::make_rcp<const SymEngine::Constant>"(string name) nogil
rcp_const_basic make_rcp_Infty "SymEngine::make_rcp<const SymEngine::Infty>"(RCP[const Number] i) nogil
Expand Down
31 changes: 21 additions & 10 deletions symengine/lib/symengine_wrapper.in.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,10 @@ def sympy2symengine(a, raise_error=False):
"""
import sympy
from sympy.core.function import AppliedUndef as sympy_AppliedUndef
if isinstance(a, sympy.Symbol):
if isinstance(a, sympy.Dummy):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since Dummy is a subclass of Symbol, we need to check for it first.

return Dummy(a.name, a.dummy_index)
elif isinstance(a, sympy.Symbol):
return Symbol(a.name)
elif isinstance(a, sympy.Dummy):
return Dummy(a.name)
elif isinstance(a, sympy.Mul):
return mul(*[sympy2symengine(x, raise_error) for x in a.args])
elif isinstance(a, sympy.Add):
Expand Down Expand Up @@ -1301,10 +1301,10 @@ cdef class Symbol(Expr):
return sympy.Symbol(str(self))

def __reduce__(self):
if type(self) == Symbol:
if type(self) in (Symbol, Dummy):
return Basic.__reduce__(self)
else:
raise NotImplementedError("pickling for Symbol subclass not implemented")
raise NotImplementedError("pickling for subclass of Symbol or Dummy not implemented")

def _sage_(self):
import sage.all as sage
Expand Down Expand Up @@ -1337,15 +1337,20 @@ cdef class Symbol(Expr):

cdef class Dummy(Symbol):

def __init__(Basic self, name=None, *args, **kwargs):
if name is None:
self.thisptr = symengine.make_rcp_Dummy()
def __init__(Basic self, name=None, dummy_index=None, *args, **kwargs):
cdef size_t index
if dummy_index is None:
if name is None:
self.thisptr = symengine.make_rcp_Dummy()
else:
self.thisptr = symengine.make_rcp_Dummy(name.encode("utf-8"))
else:
self.thisptr = symengine.make_rcp_Dummy(name.encode("utf-8"))
index = dummy_index
self.thisptr = symengine.make_rcp_Dummy(name.encode("utf-8"), index)

def _sympy_(self):
import sympy
return sympy.Dummy(str(self)[1:])
return sympy.Dummy(name=self.name, dummy_index=self.dummy_index)

@property
def is_Dummy(self):
Expand All @@ -1355,6 +1360,12 @@ cdef class Dummy(Symbol):
def func(self):
return self.__class__

@property
def dummy_index(self):
cdef RCP[const symengine.Dummy] this = \
symengine.rcp_static_cast_Dummy(self.thisptr)
cdef size_t index = deref(this).get_index()
return index

def symarray(prefix, shape, **kwargs):
""" Creates an nd-array of symbols
Expand Down
18 changes: 17 additions & 1 deletion symengine/tests/test_pickling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol
from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol, Dummy
from symengine.test_utilities import raises
import pickle
import unittest
Expand Down Expand Up @@ -57,3 +57,19 @@ def test_llvm_double():
ll = pickle.loads(ss)
inp = [1, 2, 3]
assert np.allclose(l(inp), ll(inp))


def _check_pickling_roundtrip(arg):
s2 = pickle.dumps(arg)
arg2 = pickle.loads(s2)
assert arg == arg2
s3 = pickle.dumps(arg2)
arg3 = pickle.loads(s3)
assert arg == arg3


def test_pickling_roundtrip():
x, y, z = symbols('x y z')
_check_pickling_roundtrip(x+y)
_check_pickling_roundtrip(Dummy('d'))
_check_pickling_roundtrip(Dummy('d') - z)
3 changes: 3 additions & 0 deletions symengine/tests/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ def test_dummy():
x2 = Symbol('x')
xdummy1 = Dummy('x')
xdummy2 = Dummy('x')
assert xdummy1.dummy_index != xdummy2.dummy_index # maybe test using "less than"?
assert xdummy1.name == 'x'
assert xdummy2.name == 'x'

assert x1 == x2
assert x1 != xdummy1
Expand Down
23 changes: 22 additions & 1 deletion symengine/tests/test_sympy_conv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from symengine import (Symbol, Integer, sympify, SympifyError, log,
function_symbol, I, E, pi, oo, zoo, nan, true, false,
exp, gamma, have_mpfr, have_mpc, DenseMatrix, sin, cos, tan, cot,
exp, gamma, have_mpfr, have_mpc, DenseMatrix, Dummy, sin, cos, tan, cot,
csc, sec, asin, acos, atan, acot, acsc, asec, sinh, cosh, tanh, coth,
asinh, acosh, atanh, acoth, atan2, Add, Mul, Pow, diff, GoldenRatio,
Catalan, EulerGamma, UnevaluatedExpr, RealDouble)
Expand Down Expand Up @@ -833,3 +833,24 @@ def test_conv_large_integers():
if have_sympy:
c = a._sympy_()
d = sympify(c)


def _check_sympy_roundtrip(arg):
arg_sy1 = sympy.sympify(arg)
arg_se2 = sympify(arg_sy1)
assert arg == arg_se2
arg_sy2 = sympy.sympify(arg_se2)
assert arg_sy2 == arg_sy1
arg_se3 = sympify(arg_sy2)
assert arg_se3 == arg


@unittest.skipIf(not have_sympy, "SymPy not installed")
def test_sympy_roundtrip():
x = Symbol("x")
y = Symbol("y")
d = Dummy("d")
_check_sympy_roundtrip(x)
_check_sympy_roundtrip(x+y)
_check_sympy_roundtrip(x**y)
_check_sympy_roundtrip(d)
2 changes: 1 addition & 1 deletion symengine_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
c9510fb4b5c30b84adb993573a51f2a9a38a4cfe
c574fa8d7018a850481afa7a59809d30e774d78d
Loading