Skip to content

Commit

Permalink
[Bugfix] Re-allow non-integral line/grid qubits (#7110)
Browse files Browse the repository at this point in the history
* Allow non-integral line/grid qubits

* tests

* fix numpy overflow

* fix tests

* fix tests on windows

* Address PR comments
  • Loading branch information
daxfohl authored Mar 3, 2025
1 parent 2ad4136 commit 6ff3d66
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 23 deletions.
8 changes: 2 additions & 6 deletions cirq-core/cirq/devices/grid_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,6 @@ def __new__(cls, row: int, col: int, *, dimension: int) -> 'cirq.GridQid':
dimension: The dimension of the qid's Hilbert space, i.e.
the number of quantum levels.
"""
row = int(row)
col = int(col)
dimension = int(dimension)
key = (row, col, dimension)
inst = cls._cache.get(key)
Expand All @@ -224,7 +222,7 @@ def __new__(cls, row: int, col: int, *, dimension: int) -> 'cirq.GridQid':
inst._row = row
inst._col = col
inst._dimension = dimension
inst._hash = ((dimension - 2) * 1_000_003 + col) * 1_000_003 + row
inst._hash = ((dimension - 2) * 1_000_003 + hash(col)) * 1_000_003 + hash(row)
cls._cache[key] = inst
return inst

Expand Down Expand Up @@ -380,15 +378,13 @@ def __new__(cls, row: int, col: int) -> 'cirq.GridQubit':
row: the row coordinate
col: the column coordinate
"""
row = int(row)
col = int(col)
key = (row, col)
inst = cls._cache.get(key)
if inst is None:
inst = super().__new__(cls)
inst._row = row
inst._col = col
inst._hash = col * 1_000_003 + row
inst._hash = hash(col) * 1_000_003 + hash(row)
cls._cache[key] = inst
return inst

Expand Down
24 changes: 16 additions & 8 deletions cirq-core/cirq/devices/grid_qubit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,22 +393,30 @@ def test_complex():
assert isinstance(complex(cirq.GridQubit(row=1, col=2)), complex)


def test_numpy_index():
np5, np6, np3 = [np.int64(i) for i in [5, 6, 3]]
@pytest.mark.parametrize('dtype', (np.int8, np.int64, float, np.float64))
def test_numpy_index(dtype):
np5, np6, np3 = [dtype(i) for i in [5, 6, 3]]
q = cirq.GridQubit(np5, np6)
hash(q) # doesn't throw
assert hash(q) == hash(cirq.GridQubit(5, 6))
assert q.row == 5
assert q.col == 6
assert q.dimension == 2
assert isinstance(q.row, int)
assert isinstance(q.col, int)
assert isinstance(q.dimension, int)

q = cirq.GridQid(np5, np6, dimension=np3)
hash(q) # doesn't throw
assert hash(q) == hash(cirq.GridQid(5, 6, dimension=3))
assert q.row == 5
assert q.col == 6
assert q.dimension == 3
assert isinstance(q.row, int)
assert isinstance(q.col, int)
assert isinstance(q.dimension, int)


@pytest.mark.parametrize('dtype', (float, np.float64))
def test_non_integer_index(dtype):
# Not supported type-wise, but is used in practice, so behavior needs to be preserved.
q = cirq.GridQubit(dtype(5.5), dtype(6.5))
assert hash(q) == hash(cirq.GridQubit(5.5, 6.5))
assert q.row == 5.5
assert q.col == 6.5
assert isinstance(q.row, dtype)
assert isinstance(q.col, dtype)
6 changes: 2 additions & 4 deletions cirq-core/cirq/devices/line_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ def __new__(cls, x: int, dimension: int) -> 'cirq.LineQid':
dimension: The dimension of the qid's Hilbert space, i.e.
the number of quantum levels.
"""
x = int(x)
dimension = int(dimension)
key = (x, dimension)
inst = cls._cache.get(key)
Expand All @@ -200,7 +199,7 @@ def __new__(cls, x: int, dimension: int) -> 'cirq.LineQid':
inst = super().__new__(cls)
inst._x = x
inst._dimension = dimension
inst._hash = (dimension - 2) * 1_000_003 + x
inst._hash = (dimension - 2) * 1_000_003 + hash(x)
cls._cache[key] = inst
return inst

Expand Down Expand Up @@ -302,12 +301,11 @@ def __new__(cls, x: int) -> 'cirq.LineQubit':
Args:
x: The x coordinate.
"""
x = int(x)
inst = cls._cache.get(x)
if inst is None:
inst = super().__new__(cls)
inst._x = x
inst._hash = x
inst._hash = hash(x)
cls._cache[x] = inst
return inst

Expand Down
18 changes: 13 additions & 5 deletions cirq-core/cirq/devices/line_qubit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,18 +287,26 @@ def test_numeric():
assert isinstance(complex(cirq.LineQubit(x=5)), complex)


def test_numpy_index():
np5 = np.int64(5)
@pytest.mark.parametrize('dtype', (np.int8, np.int64, float, np.float64))
def test_numpy_index(dtype):
np5 = dtype(5)
q = cirq.LineQubit(np5)
assert hash(q) == 5
assert q.x == 5
assert q.dimension == 2
assert isinstance(q.x, int)
assert isinstance(q.dimension, int)

q = cirq.LineQid(np5, np.int64(3))
q = cirq.LineQid(np5, dtype(3))
hash(q) # doesn't throw
assert q.x == 5
assert q.dimension == 3
assert isinstance(q.x, int)
assert isinstance(q.dimension, int)


@pytest.mark.parametrize('dtype', (float, np.float64))
def test_non_integer_index(dtype):
# Not supported type-wise, but is used in practice, so behavior needs to be preserved.
q = cirq.LineQubit(dtype(5.5))
assert q.x == 5.5
assert q.x == dtype(5.5)
assert isinstance(q.x, dtype)

0 comments on commit 6ff3d66

Please sign in to comment.