Skip to content

Commit 65c5691

Browse files
Improve performance of _build_sweep_const (#7601)
- This function was building the Const message twice. - Instead, write directly to the proto message. - This speeds up run contexts with Const values up by about 40%. --------- Co-authored-by: Pavol Juhas <[email protected]>
1 parent f841c5d commit 65c5691

File tree

2 files changed

+18
-14
lines changed

2 files changed

+18
-14
lines changed

cirq-google/cirq_google/api/v2/sweeps.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,36 +30,38 @@
3030
from cirq.study import sweeps
3131

3232

33-
def _build_sweep_const(value: Any, use_float64: bool = False) -> run_context_pb2.ConstValue:
33+
def _add_sweep_const(
34+
sweep: run_context_pb2.SingleSweep, value: Any, use_float64: bool = False
35+
) -> None:
3436
"""Build the sweep const message from a value."""
3537
if isinstance(value, float):
3638
# comparing to float is ~5x than testing numbers.Real
3739
# if modifying the below, also modify the block below for numbers.Real
3840
if use_float64:
39-
return run_context_pb2.ConstValue(double_value=value)
41+
sweep.const_value.double_value = value
4042
else:
4143
# Note: A loss of precision for floating-point numbers may occur here.
42-
return run_context_pb2.ConstValue(float_value=value)
44+
sweep.const_value.float_value = value
4345
elif isinstance(value, int):
4446
# comparing to int is ~5x than testing numbers.Integral
4547
# if modifying the below, also modify the block below for numbers.Integral
46-
return run_context_pb2.ConstValue(int_value=value)
48+
sweep.const_value.int_value = value
4749
elif value is None:
48-
return run_context_pb2.ConstValue(is_none=True)
50+
sweep.const_value.is_none = True
4951
elif isinstance(value, str):
50-
return run_context_pb2.ConstValue(string_value=value)
52+
sweep.const_value.string_value = value
5153
elif isinstance(value, numbers.Integral):
5254
# more general than isinstance(int) but also slower
53-
return run_context_pb2.ConstValue(int_value=int(value))
55+
sweep.const_value.int_value = int(value)
5456
elif isinstance(value, numbers.Real):
5557
# more general than isinstance(float) but also slower
5658
if use_float64:
57-
return run_context_pb2.ConstValue(double_value=float(value))
59+
sweep.const_value.double_value = float(value) # pragma: no cover
5860
else:
5961
# Note: A loss of precision for floating-point numbers may occur here.
60-
return run_context_pb2.ConstValue(float_value=float(value))
62+
sweep.const_value.float_value = float(value)
6163
elif isinstance(value, tunits.Value):
62-
return run_context_pb2.ConstValue(with_unit_value=value.to_proto())
64+
value.to_proto(sweep.const_value.with_unit_value)
6365
else:
6466
raise ValueError(
6567
f"Unsupported type for serializing const sweep: {value=} and {type(value)=}"
@@ -191,7 +193,7 @@ def sweep_to_proto(
191193
sweep = cast(cirq.Points, sweep_transformer(sweep))
192194
out.single_sweep.parameter_key = sweep.key
193195
if len(sweep.points) == 1:
194-
out.single_sweep.const_value.MergeFrom(_build_sweep_const(sweep.points[0], use_float64))
196+
_add_sweep_const(out.single_sweep, sweep.points[0], use_float64)
195197
else:
196198
if isinstance(sweep.points[0], tunits.Value):
197199
unit = sweep.points[0].unit
@@ -404,7 +406,7 @@ def sweepable_to_proto(
404406
for key, val in sweepable.items():
405407
single_sweep = zip_proto.sweeps.add().single_sweep
406408
single_sweep.parameter_key = key
407-
single_sweep.const_value.MergeFrom(_build_sweep_const(val, use_float64))
409+
_add_sweep_const(single_sweep, val, use_float64)
408410
return out
409411
if isinstance(sweepable, Iterable):
410412
for sweepable_element in sweepable:

cirq-google/cirq_google/api/v2/sweeps_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,9 @@ def test_sweep_to_proto_linspace():
164164

165165
@pytest.mark.parametrize("val", [None, 1, 1.5, 's'])
166166
def test_build_recover_const(val):
167-
val2 = v2.sweeps._recover_sweep_const(v2.sweeps._build_sweep_const(val))
167+
sweep = v2.run_context_pb2.SingleSweep()
168+
v2.sweeps._add_sweep_const(sweep, val)
169+
val2 = v2.sweeps._recover_sweep_const(sweep.const_value)
168170
if isinstance(val, float):
169171
assert math.isclose(val, val2) # avoid the floating precision issue.
170172
else:
@@ -179,7 +181,7 @@ def test_build_covert_const_double():
179181

180182
def test_build_const_unsupported_type():
181183
with pytest.raises(ValueError, match='Unsupported type for serializing const sweep'):
182-
v2.sweeps._build_sweep_const((1, 2))
184+
v2.sweeps._add_sweep_const(v2.run_context_pb2.SingleSweep(), (1, 2))
183185

184186

185187
def test_list_sweep_bad_expression():

0 commit comments

Comments
 (0)