Skip to content

Commit

Permalink
clean u tests
Browse files Browse the repository at this point in the history
  • Loading branch information
senecameeks committed Oct 4, 2024
1 parent 1aa8182 commit 6ad91b2
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 33 deletions.
8 changes: 4 additions & 4 deletions cirq-google/cirq_google/engine/engine_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,11 +442,11 @@ async def create_job_async(
# Create job.
if snapshot_id:
selector = quantum.DeviceConfigSelector(
snapshot_id=snapshot_id, config_alias=device_config_name
snapshot_id=snapshot_id or None, config_alias=device_config_name
)
else:
selector = quantum.DeviceConfigSelector(
run_name=run_name, config_alias=device_config_name
run_name=run_name or None, config_alias=device_config_name
)
job_name = _job_name_from_ids(project_id, program_id, job_id) if job_id else ''
job = quantum.QuantumJob(
Expand Down Expand Up @@ -817,11 +817,11 @@ def run_job_over_stream(

if snapshot_id:
selector = quantum.DeviceConfigSelector(
snapshot_id=snapshot_id, config_alias=device_config_name
snapshot_id=snapshot_id or None, config_alias=device_config_name
)
else:
selector = quantum.DeviceConfigSelector(
run_name=run_name, config_alias=device_config_name
run_name=run_name or None, config_alias=device_config_name
)

job = quantum.QuantumJob(
Expand Down
97 changes: 68 additions & 29 deletions cirq-google/cirq_google/engine/engine_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def test_delete_program(client_constructor):

@mock.patch.dict(os.environ, clear='CIRQ_TESTING')
@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_create_job(client_constructor):
def test_create_job_passes(client_constructor):
grpc_client = _setup_client_mock(client_constructor)

result = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0')
Expand All @@ -373,9 +373,7 @@ def test_create_job(client_constructor):
priority=10,
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(
run_name="", config_alias=""
),
device_config_selector=quantum.DeviceConfigSelector(),
),
),
description='A job',
Expand All @@ -398,9 +396,7 @@ def test_create_job(client_constructor):
priority=10,
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(
run_name="", config_alias=""
),
device_config_selector=quantum.DeviceConfigSelector(),
),
),
description='A job',
Expand All @@ -421,9 +417,7 @@ def test_create_job(client_constructor):
priority=10,
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(
run_name="", config_alias=""
),
device_config_selector=quantum.DeviceConfigSelector(),
),
),
labels=labels,
Expand All @@ -445,9 +439,7 @@ def test_create_job(client_constructor):
priority=10,
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(
run_name="", config_alias=""
),
device_config_selector=quantum.DeviceConfigSelector(),
),
),
),
Expand All @@ -466,9 +458,7 @@ def test_create_job(client_constructor):
priority=10,
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(
run_name="", config_alias=""
),
device_config_selector=quantum.DeviceConfigSelector(),
),
),
),
Expand Down Expand Up @@ -542,7 +532,7 @@ def test_create_job_with_invalid_processor_and_device_config_arguments_throws(
'run_name, snapshot_id, device_config_name',
[('RUN_NAME', '', 'CONFIG_NAME'), ('', '', ''), ('', '', '')],
)
def test_create_job_with_run_name_and_device_config_name(
def test_create_job_with_run_name_and_device_config_name_succeeds(
client_constructor, processor_id, run_name, snapshot_id, device_config_name
):
grpc_client = _setup_client_mock(client_constructor)
Expand Down Expand Up @@ -573,7 +563,7 @@ def test_create_job_with_run_name_and_device_config_name(
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(
run_name=run_name, config_alias=device_config_name
run_name=run_name or None, config_alias=device_config_name
),
),
),
Expand Down Expand Up @@ -619,7 +609,7 @@ def test_create_job_with_snapshot_id_and_device_config_name_succeeds(
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(
snapshot_id=snapshot_id, config_alias=device_config_name
snapshot_id=snapshot_id or None, config_alias=device_config_name
),
),
),
Expand Down Expand Up @@ -660,7 +650,7 @@ def test_create_job_with_snapshot_id_and_device_config_name_succeeds(
priority=10,
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(run_name=""),
device_config_selector=quantum.DeviceConfigSelector(),
),
),
description='A job',
Expand All @@ -681,8 +671,6 @@ def test_create_job_with_snapshot_id_and_device_config_name_succeeds(
'priority': 10,
'job_description': 'A job',
'job_labels': {'hello': 'world'},
'snapshot_id': 'SNAPSHOT_ID',
'device_config_name': 'CONFIG_NAME',
},
[
'projects/proj',
Expand All @@ -696,9 +684,7 @@ def test_create_job_with_snapshot_id_and_device_config_name_succeeds(
priority=10,
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(
snapshot_id="SNAPSHOT_ID", config_alias="CONFIG_NAME"
),
device_config_selector=quantum.DeviceConfigSelector(),
),
),
description='A job',
Expand Down Expand Up @@ -729,7 +715,7 @@ def test_create_job_with_snapshot_id_and_device_config_name_succeeds(
priority=10,
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(run_name=""),
device_config_selector=quantum.DeviceConfigSelector(),
),
),
description='A job',
Expand Down Expand Up @@ -766,7 +752,7 @@ def test_create_job_with_snapshot_id_and_device_config_name_succeeds(
priority=10,
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(run_name=""),
device_config_selector=quantum.DeviceConfigSelector(),
),
),
description='A job',
Expand Down Expand Up @@ -801,7 +787,7 @@ def test_create_job_with_snapshot_id_and_device_config_name_succeeds(
priority=10,
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(run_name=""),
device_config_selector=quantum.DeviceConfigSelector(),
),
),
),
Expand Down Expand Up @@ -833,7 +819,7 @@ def test_create_job_with_snapshot_id_and_device_config_name_succeeds(
scheduling_config=quantum.SchedulingConfig(
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(run_name=""),
device_config_selector=quantum.DeviceConfigSelector(),
)
),
),
Expand Down Expand Up @@ -861,6 +847,59 @@ def test_run_job_over_stream(
stream_manager.submit.assert_called_with(*expected_submit_args)


@pytest.mark.parametrize(
'run_job_kwargs, expected_submit_args',
[
(
{
'project_id': 'proj',
'program_id': 'prog',
'code': any_pb2.Any(),
'job_id': 'job0',
'processor_id': 'processor0',
'run_context': any_pb2.Any(),
'snapshot_id': 'SNAPSHOT_ID',
'device_config_name': 'CONFIG_NAME',
},
[
'projects/proj',
quantum.QuantumProgram(name='projects/proj/programs/prog', code=any_pb2.Any()),
quantum.QuantumJob(
name='projects/proj/programs/prog/jobs/job0',
run_context=any_pb2.Any(),
scheduling_config=quantum.SchedulingConfig(
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor='projects/proj/processors/processor0',
device_config_selector=quantum.DeviceConfigSelector(
snapshot_id="SNAPSHOT_ID", config_alias="CONFIG_NAME"
),
)
),
),
],
)
],
)
@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
@mock.patch.object(engine_stream_manager, 'StreamManager', autospec=True)
def test_run_job_over_stream_with_snapshot_id_passes(
manager_constructor, client_constructor, run_job_kwargs, expected_submit_args
):
_setup_client_mock(client_constructor)
stream_manager = _setup_stream_manager_mock(manager_constructor)

result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0')
expected_future = duet.AwaitableFuture()
expected_future.try_set_result(result)
stream_manager.submit.return_value = expected_future
client = EngineClient()

actual_future = client.run_job_over_stream(**run_job_kwargs)

assert actual_future == expected_future
stream_manager.submit.assert_called_with(*expected_submit_args)


def test_run_job_over_stream_with_priority_out_of_bound_raises():
client = EngineClient()

Expand Down

0 comments on commit 6ad91b2

Please sign in to comment.