Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions src/openfermion/ops/operators/symbolic_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class SymbolicOperator(metaclass=abc.ABCMeta):

@staticmethod
def _issmall(val, tol=EQ_TOLERANCE):
'''Checks whether a value is near-zero
'''Checks whether a value is near zero.

Parses the allowed coefficients above for near-zero tests.

Expand Down Expand Up @@ -618,34 +618,64 @@ def __next__(self):
term, coefficient = next(self._iter)
return self.__class__(term=term, coefficient=coefficient)

def isclose(self, other, tol=EQ_TOLERANCE):
def isclose(self, other, tol=None, rtol=EQ_TOLERANCE, atol=EQ_TOLERANCE):
"""Check if other (SymbolicOperator) is close to self.

Comparison is done for each term individually. Return True
if the difference between each term in self and other is
less than EQ_TOLERANCE
less than the specified tolerance.

Args:
other(SymbolicOperator): SymbolicOperator to compare against.
tol(float): This parameter is deprecated since version 1.8.0.
Use `rtol` and/or `atol` instead. If `tol` is provided, it
is used as the value of `atol`.
rtol(float): Relative tolerance used in comparing each term in
self and other.
atol(float): Absolute tolerance used in comparing each term in
self and other.
"""
if not isinstance(self, type(other)):
return NotImplemented

if tol is not None:
if rtol != EQ_TOLERANCE or atol != EQ_TOLERANCE:
raise ValueError(
'Parameters rtol and atol are mutually exclusive with the'
' deprecated parameter tol; use either tol or the other two,'
' not in combination.'
)
warnings.warn(
'Parameter tol is deprecated. Use rtol and/or atol instead.',
DeprecationWarning,
stacklevel=2, # Identify the location of the warning.
)
atol = tol

# terms which are in both:
for term in set(self.terms).intersection(set(other.terms)):
a = self.terms[term]
b = other.terms[term]
if not (isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr)):
tol *= max(1, abs(a), abs(b))
if self._issmall(a - b, tol) is False:
if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
if not self._issmall(a - b, atol):
return False
elif not abs(a - b) <= atol + rtol * max(abs(a), abs(b)):
return False
# terms only in one (compare to 0.0 so only abs_tol)
# terms only in one (compare to 0.0 so only atol)
for term in set(self.terms).symmetric_difference(set(other.terms)):
if term in self.terms:
if self._issmall(self.terms[term], tol) is False:
coeff = self.terms[term]
if isinstance(coeff, sympy.Expr):
if not self._issmall(coeff, atol):
return False
elif not abs(coeff) <= atol:
return False
else:
if self._issmall(other.terms[term], tol) is False:
coeff = other.terms[term]
if isinstance(coeff, sympy.Expr):
if not self._issmall(coeff, atol):
return False
elif not abs(coeff) <= atol:
return False
return True

Expand Down
79 changes: 74 additions & 5 deletions src/openfermion/ops/operators/symbolic_operator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@
"""Tests symbolic_operator.py."""

import copy
import unittest
import warnings

import numpy
import sympy
import unittest
import warnings

from openfermion.config import EQ_TOLERANCE
from openfermion.testing.testing_utils import EqualsTester

from openfermion.ops.operators.fermion_operator import FermionOperator
from openfermion.ops.operators.symbolic_operator import SymbolicOperator
from openfermion.testing.testing_utils import EqualsTester


class DummyOperator1(SymbolicOperator):
Expand Down Expand Up @@ -868,7 +868,76 @@ def test_pow_high_term(self):
term = DummyOperator1(ops, coeff)
high = term**10
expected = DummyOperator1(ops * 10, coeff**10)
self.assertTrue(expected == high)
self.assertTrue(high.isclose(expected, rtol=1e-12, atol=1e-12))

def test_isclose_parameter_deprecation(self):
op1 = DummyOperator1('0^ 1', 1.0)
op2 = DummyOperator1('0^ 1', 1.001)

with self.assertWarns(DeprecationWarning):
op1.isclose(op2, tol=0.01)

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
self.assertTrue(op1.isclose(op2, tol=0.001))
self.assertFalse(op1.isclose(op2, tol=0.0001))

def test_isclose_parameter_combos(self):
op1 = DummyOperator1('0^ 1', 1.0)
op2 = DummyOperator1('0^ 1', 1.001)

with self.assertRaises(ValueError):
op1.isclose(op2, tol=0.01, rtol=1e-5)

with self.assertRaises(ValueError):
op1.isclose(op2, tol=0.01, atol=1e-5)

def test_isclose_atol_rtol(self):
op1 = DummyOperator1('0^ 1', 1.0)
op2 = DummyOperator1('0^ 1', 1.001)

op_a = DummyOperator1('0^ 1', 1.0)
op_b = DummyOperator1('0^ 1', 1.001)
self.assertTrue(op_a.isclose(op_b, atol=0.001))
self.assertFalse(op_a.isclose(op_b, atol=0.0001))

op_c = DummyOperator1('0^ 1', 1000)
op_d = DummyOperator1('0^ 1', 1001)
self.assertTrue(op_c.isclose(op_d, rtol=0.001))
self.assertFalse(op_c.isclose(op_d, rtol=0.0001))

op_e = DummyOperator1('0^ 1', 1.0)
op_f = DummyOperator1('0^ 1', 1.001)
self.assertTrue(op_e.isclose(op_f, rtol=1e-4, atol=1e-3))
self.assertFalse(op_e.isclose(op_f, rtol=1e-4, atol=1e-5))

def test_isclose(self):
op1 = DummyOperator1()
op2 = DummyOperator1()
op1 += DummyOperator1('0^ 1', 1000000)
op1 += DummyOperator1('2^ 3', 1)
op2 += DummyOperator1('0^ 1', 1000000)
op2 += DummyOperator1('2^ 3', 1.001)
self.assertFalse(op1.isclose(op2, atol=1e-4))
self.assertTrue(op1.isclose(op2, atol=1e-2))

# Case from https://github.com/quantumlib/OpenFermion/issues/764
x = FermionOperator("0^ 0")
y = FermionOperator("0^ 0")

# construct two identical operators up to some number of terms
num_terms_before_ineq = 30
for i in range(num_terms_before_ineq):
x += FermionOperator(f" (10+0j) [0^ {i}]")
y += FermionOperator(f" (10+0j) [0^ {i}]")

xfinal = FermionOperator(f" (1+0j) [0^ {num_terms_before_ineq + 1}]")
yfinal = FermionOperator(f" (2+0j) [0^ {num_terms_before_ineq + 1}]")
assert xfinal != yfinal

x += xfinal
y += yfinal
assert x != y

def test_pow_neg_error(self):
with self.assertRaises(ValueError):
Expand Down
Loading