diff --git a/cirq-google/cirq_google/engine/abstract_processor.py b/cirq-google/cirq_google/engine/abstract_processor.py index 64cfe92f986..261f11c5993 100644 --- a/cirq-google/cirq_google/engine/abstract_processor.py +++ b/cirq-google/cirq_google/engine/abstract_processor.py @@ -180,17 +180,8 @@ async def run_sweep_async( run_sweep = duet.sync(run_sweep_async) @abc.abstractmethod - def get_sampler(self, run_name: str = "", device_config_name: str = "") -> cg.ProcessorSampler: - """Returns a sampler backed by the processor. - - Args: - run_name: A unique identifier representing an automation run for the - processor. An Automation Run contains a collection of device - configurations for the processor. - device_config_name: An identifier used to select the processor configuration - utilized to run the job. A configuration identifies the set of - available qubits, couplers, and supported gates in the processor. - """ + def get_sampler(self) -> cg.ProcessorSampler: + """Returns a sampler backed by the processor.""" @abc.abstractmethod def engine(self) -> abstract_engine.AbstractEngine | None: diff --git a/cirq-google/cirq_google/engine/engine.py b/cirq-google/cirq_google/engine/engine.py index 9fe3f96b0d9..c6fb85a20fb 100644 --- a/cirq-google/cirq_google/engine/engine.py +++ b/cirq-google/cirq_google/engine/engine.py @@ -590,15 +590,14 @@ def get_processor(self, processor_id: str) -> engine_processor.EngineProcessor: """ return engine_processor.EngineProcessor(self.project_id, processor_id, self.context) - def get_sampler( + def get_sampler_from_run_name( self, - processor_id: str | list[str], - run_name: str = "", - device_config_name: str = "", - snapshot_id: str = "", + processor_id: str, + run_name: str, + device_config_name: str | None = None, max_concurrent_jobs: int = 100, ) -> cirq_google.ProcessorSampler: - """Returns a sampler backed by the engine. + """Returns a sampler backed by the engine and given `run_name`. Args: processor_id: String identifier of which processor should be used to sample. @@ -608,8 +607,67 @@ def get_sampler( device_config_name: An identifier used to select the processor configuration utilized to run the job. A configuration identifies the set of available qubits, couplers, and supported gates in the processor. - snapshot_id: A unique identifier for an immutable snapshot reference. A - snapshot contains a collection of device configurations for the processor. + max_concurrent_jobs: The maximum number of jobs to be sent + simultaneously to the Engine. This client-side throttle can be + used to proactively reduce load to the backends and avoid quota + violations when pipelining circuit executions. + + Returns: + A `cirq.Sampler` instance (specifically a `engine_sampler.ProcessorSampler` + that will send circuits to the Quantum Computing Service + when sampled. + """ + return self.get_processor(processor_id).get_sampler_from_run_name( + run_name=run_name, + device_config_name=device_config_name, + max_concurrent_jobs=max_concurrent_jobs, + ) + + def get_sampler_from_snapshot_id( + self, + processor_id: str, + snapshot_id: str, + device_config_name: str | None, + max_concurrent_jobs: int = 100, + ) -> cirq_google.ProcessorSampler: + """Returns a sampler backed by the engine. + Args: + processor_id: String identifier of which processor should be used to sample. + device_config_name: An identifier used to select the processor configuration + utilized to run the job. A configuration identifies the set of + available qubits, couplers, and supported gates in the processor. + snapshot_id: A unique identifier for an immutable snapshot reference. + A snapshot contains a collection of device configurations for the + processor. + max_concurrent_jobs: The maximum number of jobs to be sent + simultaneously to the Engine. This client-side throttle can be + used to proactively reduce load to the backends and avoid quota + violations when pipelining circuit executions. + + Returns: + A `cirq.Sampler` instance (specifically a `engine_sampler.ProcessorSampler` + that will send circuits to the Quantum Computing Service + when sampled. + """ + return self.get_processor(processor_id).get_sampler_from_snapshot_id( + snapshot_id=snapshot_id, + device_config_name=device_config_name, + max_concurrent_jobs=max_concurrent_jobs, + ) + + def get_sampler( + self, + processor_id: str | list[str], + device_config_name: str | None = None, + max_concurrent_jobs: int = 100, + ) -> cirq_google.ProcessorSampler: + """Returns a sampler backed by the engine. + + Args: + processor_id: String identifier of which processor should be used to sample. + device_config_name: An identifier used to select the processor configuration + utilized to run the job. A configuration identifies the set of + available qubits, couplers, and supported gates in the processor. max_concurrent_jobs: The maximum number of jobs to be sent concurrently to the Engine. This client-side throttle can be used to proactively reduce load to the backends and avoid quota @@ -629,11 +687,9 @@ def get_sampler( 'to get_sampler() no longer supported. Use Engine.run() instead if ' 'you need to specify a list.' ) + return self.get_processor(processor_id).get_sampler( - run_name=run_name, - device_config_name=device_config_name, - snapshot_id=snapshot_id, - max_concurrent_jobs=max_concurrent_jobs, + device_config_name=device_config_name, max_concurrent_jobs=max_concurrent_jobs ) async def get_processor_config_from_snapshot_async( diff --git a/cirq-google/cirq_google/engine/engine_processor.py b/cirq-google/cirq_google/engine/engine_processor.py index ad9cbe463a2..56798daaaaf 100644 --- a/cirq-google/cirq_google/engine/engine_processor.py +++ b/cirq-google/cirq_google/engine/engine_processor.py @@ -99,18 +99,48 @@ def engine(self) -> engine_base.Engine: return engine_base.Engine(self.project_id, context=self.context) - def get_sampler( - self, - run_name: str = "", - device_config_name: str = "", - snapshot_id: str = "", - max_concurrent_jobs: int = 100, + def get_sampler_from_run_name( + self, run_name: str, device_config_name: str | None = None, max_concurrent_jobs: int = 100 ) -> cg.engine.ProcessorSampler: - """Returns a sampler backed by the engine. + """Returns a sampler backed by the engine and given `run_name`. + Args: run_name: A unique identifier representing an automation run for the processor. An Automation Run contains a collection of device configurations for the processor. + device_config_name: An identifier used to select the processor configuration + utilized to run the job. A configuration identifies the set of + available qubits, couplers, and supported gates in the processor. + max_concurrent_jobs: The maximum number of jobs to be sent + simultaneously to the Engine. This client-side throttle can be + used to proactively reduce load to the backends and avoid quota + violations when pipelining circuit executions. + + Returns: + A `cirq.Sampler` instance (specifically a `engine_sampler.ProcessorSampler` + that will send circuits to the Quantum Computing Service + when sampled. + """ + processor = self._inner_processor() + return processor_sampler.ProcessorSampler( + processor=self, + run_name=run_name, + device_config_name=( + device_config_name + if device_config_name + else processor.default_device_config_key.config_alias + ), + max_concurrent_jobs=max_concurrent_jobs, + ) + + def get_sampler_from_snapshot_id( + self, + snapshot_id: str, + device_config_name: str | None = None, + max_concurrent_jobs: int = 100, + ) -> cg.engine.ProcessorSampler: + """Returns a sampler backed by the engine. + Args: device_config_name: An identifier used to select the processor configuration utilized to run the job. A configuration identifies the set of available qubits, couplers, and supported gates in the processor. @@ -126,29 +156,50 @@ def get_sampler( A `cirq.Sampler` instance (specifically a `engine_sampler.ProcessorSampler` that will send circuits to the Quantum Computing Service when sampled. + """ + processor = self._inner_processor() + return processor_sampler.ProcessorSampler( + processor=self, + snapshot_id=snapshot_id, + device_config_name=( + device_config_name + if device_config_name + else processor.default_device_config_key.config_alias + ), + max_concurrent_jobs=max_concurrent_jobs, + ) - Raises: - ValueError: If only one of `run_name` and `device_config_name` are specified. - ValueError: If both `run_name` and `snapshot_id` are specified. + def get_sampler( + self, device_config_name: str | None = None, max_concurrent_jobs: int = 100 + ) -> cg.engine.ProcessorSampler: + """Returns the default sampler backed by the engine. + + Args: + device_config_name: An identifier used to select the processor configuration + utilized to run the job. A configuration identifies the set of + available qubits, couplers, and supported gates in the processor. + max_concurrent_jobs: The maximum number of jobs to be sent + simultaneously to the Engine. This client-side throttle can be + used to proactively reduce load to the backends and avoid quota + violations when pipelining circuit executions. + + Returns: + A `cirq.Sampler` instance (specifically a `engine_sampler.ProcessorSampler` + that will send circuits to the Quantum Computing Service + when sampled. """ processor = self._inner_processor() - if run_name and snapshot_id: - raise ValueError('Cannot specify both `run_name` and `snapshot_id`') - if (bool(run_name) or bool(snapshot_id)) ^ bool(device_config_name): - raise ValueError( - 'Cannot specify only one of top level identifier and `device_config_name`' - ) - # If not provided, initialize the sampler with the Processor's default values. - if not run_name and not device_config_name and not snapshot_id: - run_name = processor.default_device_config_key.run - device_config_name = processor.default_device_config_key.config_alias - snapshot_id = processor.default_device_config_key.snapshot_id + return processor_sampler.ProcessorSampler( processor=self, - run_name=run_name, - snapshot_id=snapshot_id, - device_config_name=device_config_name, + run_name=processor.default_device_config_key.run, + snapshot_id=processor.default_device_config_key.snapshot_id, + device_config_name=( + device_config_name + if device_config_name + else processor.default_device_config_key.config_alias + ), max_concurrent_jobs=max_concurrent_jobs, ) diff --git a/cirq-google/cirq_google/engine/engine_processor_test.py b/cirq-google/cirq_google/engine/engine_processor_test.py index 1e4c74d4886..1a98d9c4244 100644 --- a/cirq-google/cirq_google/engine/engine_processor_test.py +++ b/cirq-google/cirq_google/engine/engine_processor_test.py @@ -324,7 +324,7 @@ def test_get_missing_device(): _ = processor.get_device() -def test_get_sampler_initializes_default_device_configuration() -> None: +def test_get_sampler_from_run_name() -> None: processor = cg.EngineProcessor( 'a', 'p', @@ -335,60 +335,96 @@ def test_get_sampler_initializes_default_device_configuration() -> None: ) ), ) - sampler = processor.get_sampler() + run_name = 'test_run_name' + device_config_name = 'test_device_name' - assert sampler.run_name == "run" - assert sampler.device_config_name == "config_alias" + sampler = processor.get_sampler_from_run_name( + run_name=run_name, device_config_name=device_config_name + ) + assert sampler.run_name == run_name + assert sampler.device_config_name == device_config_name -def test_get_sampler_uses_custom_default_device_configuration_key() -> None: + +def test_get_sampler_from_run_name_with_default_values() -> None: + default_config_alias = 'test_alias' processor = cg.EngineProcessor( 'a', 'p', EngineContext(), _processor=quantum.QuantumProcessor( default_device_config_key=quantum.DeviceConfigKey( - run="default_run", config_alias="default_config_alias" + run="run", config_alias=default_config_alias ) ), ) - sampler = processor.get_sampler(run_name="run1", device_config_name="config_alias1") + run_name = 'test_run' + + sampler = processor.get_sampler_from_run_name(run_name=run_name) - assert sampler.run_name == "run1" - assert sampler.device_config_name == "config_alias1" + assert sampler.run_name == run_name + assert sampler.device_config_name == default_config_alias -@pytest.mark.parametrize( - 'run, snapshot_id, config_alias, error_message', - [ - ('run', '', '', 'Cannot specify only one of top level identifier and `device_config_name`'), - ( - '', - '', - 'config', - 'Cannot specify only one of top level identifier and `device_config_name`', +def test_get_sampler_from_snapshot_id() -> None: + processor = cg.EngineProcessor( + 'a', + 'p', + EngineContext(), + _processor=quantum.QuantumProcessor( + default_device_config_key=quantum.DeviceConfigKey( + run="run", config_alias="config_alias" + ) ), - ('run', 'snapshot_id', 'config', 'Cannot specify both `run_name` and `snapshot_id`'), - ], -) -def test_get_sampler_with_incomplete_device_configuration_errors( - run, snapshot_id, config_alias, error_message -) -> None: + ) + snapshot_id = 'test_snapshot' + device_config_name = 'test_device_name' + + sampler = processor.get_sampler_from_snapshot_id( + snapshot_id=snapshot_id, device_config_name=device_config_name + ) + + assert sampler.snapshot_id == snapshot_id + assert sampler.device_config_name == device_config_name + + +def test_get_sampler_from_snapshot_id_with_default_device() -> None: + default_config_alias = 'test_alias' processor = cg.EngineProcessor( 'a', 'p', EngineContext(), _processor=quantum.QuantumProcessor( default_device_config_key=quantum.DeviceConfigKey( - run="default_run", config_alias="default_config_alias" + run="run", config_alias=default_config_alias ) ), ) + snapshot_id = 'test_snapshot' - with pytest.raises(ValueError, match=error_message): - processor.get_sampler( - run_name=run, device_config_name=config_alias, snapshot_id=snapshot_id - ) + sampler = processor.get_sampler_from_snapshot_id( + snapshot_id=snapshot_id, device_config_name=default_config_alias + ) + + assert sampler.snapshot_id == snapshot_id + assert sampler.device_config_name == default_config_alias + + +def test_get_sampler_initializes_default_device_configuration() -> None: + processor = cg.EngineProcessor( + 'a', + 'p', + EngineContext(), + _processor=quantum.QuantumProcessor( + default_device_config_key=quantum.DeviceConfigKey( + run="run", config_alias="config_alias" + ) + ), + ) + sampler = processor.get_sampler() + + assert sampler.run_name == "run" + assert sampler.device_config_name == "config_alias" @mock.patch('cirq_google.engine.engine_client.EngineClient.get_processor_async') diff --git a/cirq-google/cirq_google/engine/engine_test.py b/cirq-google/cirq_google/engine/engine_test.py index 99e75208981..b7db8ec5011 100644 --- a/cirq-google/cirq_google/engine/engine_test.py +++ b/cirq-google/cirq_google/engine/engine_test.py @@ -296,7 +296,9 @@ def test_engine_get_sampler_with_snapshot_id_passes_to_unary_rpc(client): project_id='proj', context=EngineContext(service_args={'client_info': 1}, enable_streaming=False), ) - sampler = engine.get_sampler('mysim', device_config_name="config", snapshot_id="123") + sampler = engine.get_sampler_from_snapshot_id( + 'mysim', device_config_name="config", snapshot_id="123" + ) _ = sampler.run_sweep(_CIRCUIT, params=[cirq.ParamResolver({'a': 1})]) kwargs = client().create_job_async.call_args_list[0].kwargs @@ -813,6 +815,44 @@ def test_get_sampler_initializes_max_concurrent_jobs(): assert sampler.max_concurrent_jobs == max_concurrent_jobs +def test_get_sampler_from_run_name(): + processor_id = 'test_processor_id' + run_name = 'test_run_name' + device_config_name = 'test_config_alias' + project_id = 'test_proj' + engine = cg.Engine(project_id=project_id) + processor = engine.get_processor(processor_id=processor_id) + + processor_sampler = processor.get_sampler_from_run_name( + run_name=run_name, device_config_name=device_config_name + ) + engine_sampler = engine.get_sampler_from_run_name( + processor_id=processor_id, run_name=run_name, device_config_name=device_config_name + ) + + assert processor_sampler.run_name == engine_sampler.run_name + assert processor_sampler.device_config_name == engine_sampler.device_config_name + + +def test_get_sampler_from_snapshot(): + processor_id = 'test_processor_id' + snapshot_id = 'test_snapshot_id' + device_config_name = 'test_config_alias' + project_id = 'test_proj' + engine = cg.Engine(project_id=project_id) + processor = engine.get_processor(processor_id=processor_id) + + processor_sampler = processor.get_sampler_from_snapshot_id( + snapshot_id=snapshot_id, device_config_name=device_config_name + ) + engine_sampler = engine.get_sampler_from_snapshot_id( + processor_id=processor_id, snapshot_id=snapshot_id, device_config_name=device_config_name + ) + + assert processor_sampler.snapshot_id == engine_sampler.snapshot_id + assert processor_sampler.device_config_name == engine_sampler.device_config_name + + @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_sampler_with_unary_rpcs(client): setup_run_circuit_with_result_(client, _RESULTS)