diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index 9467845456c..8fa140cd371 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -442,11 +442,11 @@ async def create_job_async( # Create job. if snapshot_id: selector = quantum.DeviceConfigSelector( - snapshot_id=snapshot_id, config_alias=device_config_name + snapshot_id=snapshot_id or None, config_alias=device_config_name ) else: selector = quantum.DeviceConfigSelector( - run_name=run_name, config_alias=device_config_name + run_name=run_name or None, config_alias=device_config_name ) job_name = _job_name_from_ids(project_id, program_id, job_id) if job_id else '' job = quantum.QuantumJob( @@ -817,11 +817,11 @@ def run_job_over_stream( if snapshot_id: selector = quantum.DeviceConfigSelector( - snapshot_id=snapshot_id, config_alias=device_config_name + snapshot_id=snapshot_id or None, config_alias=device_config_name ) else: selector = quantum.DeviceConfigSelector( - run_name=run_name, config_alias=device_config_name + run_name=run_name or None, config_alias=device_config_name ) job = quantum.QuantumJob( diff --git a/cirq-google/cirq_google/engine/engine_client_test.py b/cirq-google/cirq_google/engine/engine_client_test.py index d7cb1b4f49f..9f186cbb877 100644 --- a/cirq-google/cirq_google/engine/engine_client_test.py +++ b/cirq-google/cirq_google/engine/engine_client_test.py @@ -351,7 +351,7 @@ def test_delete_program(client_constructor): @mock.patch.dict(os.environ, clear='CIRQ_TESTING') @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) -def test_create_job(client_constructor): +def test_create_job_passes(client_constructor): grpc_client = _setup_client_mock(client_constructor) result = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0') @@ -373,9 +373,7 @@ def test_create_job(client_constructor): priority=10, processor_selector=quantum.SchedulingConfig.ProcessorSelector( processor='projects/proj/processors/processor0', - device_config_selector=quantum.DeviceConfigSelector( - run_name="", config_alias="" - ), + device_config_selector=quantum.DeviceConfigSelector(), ), ), description='A job', @@ -398,9 +396,7 @@ def test_create_job(client_constructor): priority=10, processor_selector=quantum.SchedulingConfig.ProcessorSelector( processor='projects/proj/processors/processor0', - device_config_selector=quantum.DeviceConfigSelector( - run_name="", config_alias="" - ), + device_config_selector=quantum.DeviceConfigSelector(), ), ), description='A job', @@ -421,9 +417,7 @@ def test_create_job(client_constructor): priority=10, processor_selector=quantum.SchedulingConfig.ProcessorSelector( processor='projects/proj/processors/processor0', - device_config_selector=quantum.DeviceConfigSelector( - run_name="", config_alias="" - ), + device_config_selector=quantum.DeviceConfigSelector(), ), ), labels=labels, @@ -445,9 +439,7 @@ def test_create_job(client_constructor): priority=10, processor_selector=quantum.SchedulingConfig.ProcessorSelector( processor='projects/proj/processors/processor0', - device_config_selector=quantum.DeviceConfigSelector( - run_name="", config_alias="" - ), + device_config_selector=quantum.DeviceConfigSelector(), ), ), ), @@ -466,9 +458,7 @@ def test_create_job(client_constructor): priority=10, processor_selector=quantum.SchedulingConfig.ProcessorSelector( processor='projects/proj/processors/processor0', - device_config_selector=quantum.DeviceConfigSelector( - run_name="", config_alias="" - ), + device_config_selector=quantum.DeviceConfigSelector(), ), ), ), @@ -542,7 +532,7 @@ def test_create_job_with_invalid_processor_and_device_config_arguments_throws( 'run_name, snapshot_id, device_config_name', [('RUN_NAME', '', 'CONFIG_NAME'), ('', '', ''), ('', '', '')], ) -def test_create_job_with_run_name_and_device_config_name( +def test_create_job_with_run_name_and_device_config_name_succeeds( client_constructor, processor_id, run_name, snapshot_id, device_config_name ): grpc_client = _setup_client_mock(client_constructor) @@ -573,7 +563,7 @@ def test_create_job_with_run_name_and_device_config_name( processor_selector=quantum.SchedulingConfig.ProcessorSelector( processor='projects/proj/processors/processor0', device_config_selector=quantum.DeviceConfigSelector( - run_name=run_name, config_alias=device_config_name + run_name=run_name or None, config_alias=device_config_name ), ), ), @@ -619,7 +609,7 @@ def test_create_job_with_snapshot_id_and_device_config_name_succeeds( processor_selector=quantum.SchedulingConfig.ProcessorSelector( processor='projects/proj/processors/processor0', device_config_selector=quantum.DeviceConfigSelector( - snapshot_id=snapshot_id, config_alias=device_config_name + snapshot_id=snapshot_id or None, config_alias=device_config_name ), ), ), @@ -660,7 +650,7 @@ def test_create_job_with_snapshot_id_and_device_config_name_succeeds( priority=10, processor_selector=quantum.SchedulingConfig.ProcessorSelector( processor='projects/proj/processors/processor0', - device_config_selector=quantum.DeviceConfigSelector(run_name=""), + device_config_selector=quantum.DeviceConfigSelector(), ), ), description='A job', @@ -681,8 +671,6 @@ def test_create_job_with_snapshot_id_and_device_config_name_succeeds( 'priority': 10, 'job_description': 'A job', 'job_labels': {'hello': 'world'}, - 'snapshot_id': 'SNAPSHOT_ID', - 'device_config_name': 'CONFIG_NAME', }, [ 'projects/proj', @@ -696,9 +684,7 @@ def test_create_job_with_snapshot_id_and_device_config_name_succeeds( priority=10, processor_selector=quantum.SchedulingConfig.ProcessorSelector( processor='projects/proj/processors/processor0', - device_config_selector=quantum.DeviceConfigSelector( - snapshot_id="SNAPSHOT_ID", config_alias="CONFIG_NAME" - ), + device_config_selector=quantum.DeviceConfigSelector(), ), ), description='A job', @@ -729,7 +715,7 @@ def test_create_job_with_snapshot_id_and_device_config_name_succeeds( priority=10, processor_selector=quantum.SchedulingConfig.ProcessorSelector( processor='projects/proj/processors/processor0', - device_config_selector=quantum.DeviceConfigSelector(run_name=""), + device_config_selector=quantum.DeviceConfigSelector(), ), ), description='A job', @@ -766,7 +752,7 @@ def test_create_job_with_snapshot_id_and_device_config_name_succeeds( priority=10, processor_selector=quantum.SchedulingConfig.ProcessorSelector( processor='projects/proj/processors/processor0', - device_config_selector=quantum.DeviceConfigSelector(run_name=""), + device_config_selector=quantum.DeviceConfigSelector(), ), ), description='A job', @@ -801,7 +787,7 @@ def test_create_job_with_snapshot_id_and_device_config_name_succeeds( priority=10, processor_selector=quantum.SchedulingConfig.ProcessorSelector( processor='projects/proj/processors/processor0', - device_config_selector=quantum.DeviceConfigSelector(run_name=""), + device_config_selector=quantum.DeviceConfigSelector(), ), ), ), @@ -833,7 +819,7 @@ def test_create_job_with_snapshot_id_and_device_config_name_succeeds( scheduling_config=quantum.SchedulingConfig( processor_selector=quantum.SchedulingConfig.ProcessorSelector( processor='projects/proj/processors/processor0', - device_config_selector=quantum.DeviceConfigSelector(run_name=""), + device_config_selector=quantum.DeviceConfigSelector(), ) ), ), @@ -861,6 +847,59 @@ def test_run_job_over_stream( stream_manager.submit.assert_called_with(*expected_submit_args) +@pytest.mark.parametrize( + 'run_job_kwargs, expected_submit_args', + [ + ( + { + 'project_id': 'proj', + 'program_id': 'prog', + 'code': any_pb2.Any(), + 'job_id': 'job0', + 'processor_id': 'processor0', + 'run_context': any_pb2.Any(), + 'snapshot_id': 'SNAPSHOT_ID', + 'device_config_name': 'CONFIG_NAME', + }, + [ + 'projects/proj', + quantum.QuantumProgram(name='projects/proj/programs/prog', code=any_pb2.Any()), + quantum.QuantumJob( + name='projects/proj/programs/prog/jobs/job0', + run_context=any_pb2.Any(), + scheduling_config=quantum.SchedulingConfig( + processor_selector=quantum.SchedulingConfig.ProcessorSelector( + processor='projects/proj/processors/processor0', + device_config_selector=quantum.DeviceConfigSelector( + snapshot_id="SNAPSHOT_ID", config_alias="CONFIG_NAME" + ), + ) + ), + ), + ], + ) + ], +) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) +@mock.patch.object(engine_stream_manager, 'StreamManager', autospec=True) +def test_run_job_over_stream_with_snapshot_id_passes( + manager_constructor, client_constructor, run_job_kwargs, expected_submit_args +): + _setup_client_mock(client_constructor) + stream_manager = _setup_stream_manager_mock(manager_constructor) + + result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + expected_future = duet.AwaitableFuture() + expected_future.try_set_result(result) + stream_manager.submit.return_value = expected_future + client = EngineClient() + + actual_future = client.run_job_over_stream(**run_job_kwargs) + + assert actual_future == expected_future + stream_manager.submit.assert_called_with(*expected_submit_args) + + def test_run_job_over_stream_with_priority_out_of_bound_raises(): client = EngineClient()