Skip to content

Commit 4f8eb80

Browse files
committed
Take draft status into account when importing
1 parent e5a89c0 commit 4f8eb80

File tree

3 files changed

+135
-64
lines changed

3 files changed

+135
-64
lines changed

deepset_cloud_sdk/_service/pipeline_service.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -374,32 +374,59 @@ async def _create_index(self, name: str, pipeline_yaml: str) -> Response:
374374
)
375375

376376
async def _overwrite_pipeline(self, name: str, pipeline_yaml: str) -> Response:
377-
"""Overwrite a pipeline in deepset AI Platform by creating a new version.
378-
If creating a new version fails (e.g. pipeline doesn't exist), create the
379-
pipeline instead.
377+
"""Overwrite a pipeline in deepset AI Platform.
378+
379+
Behavior:
380+
- First try to fetch the latest version.
381+
- If the pipeline doesn't exist (404), create it instead.
382+
- If the latest version is a draft (is_draft == True), PATCH that version.
383+
- Otherwise, create a new version via POST /pipelines/{name}/versions.
380384
381385
:param name: Name of the pipeline.
382386
:param pipeline_yaml: Generated pipeline YAML string.
383387
"""
384-
# First try to create a new version of the existing pipeline
385-
version_response = await self._api.post(
388+
# First get the (last) version id if available
389+
version_response = await self._api.get(
386390
workspace_name=self._workspace_name,
387391
endpoint=f"pipelines/{name}/versions",
388-
json={"config_yaml": pipeline_yaml},
389392
)
390393

391-
if version_response.status_code == HTTPStatus.CREATED:
392-
logger.debug("Created new version for pipeline %s.", name)
393-
return version_response
394-
# If creating a version fails, assume the pipeline doesn't exist and create it
395-
logger.debug(
396-
"Failed to create new version for pipeline %s (status %s). "
397-
"Assuming pipeline does not exist and creating it instead.",
398-
name,
399-
version_response.status_code,
400-
)
394+
# If pipeline doesn't exist (404), create it instead
395+
if version_response.status_code == HTTPStatus.NOT_FOUND:
396+
logger.debug("Pipeline %s not found, creating new pipeline.", name)
397+
response = await self._create_pipeline(name=name, pipeline_yaml=pipeline_yaml)
398+
return response
399+
400+
version_body = version_response.json()
401+
latest_version = version_body["data"][0]
402+
version_id = latest_version["version_id"]
403+
is_draft = latest_version.get("is_draft", False)
404+
405+
if is_draft:
406+
# If the latest version is a draft, patch that version
407+
logger.debug(
408+
"Latest version %s of pipeline %s is a draft, patching existing version.",
409+
version_id,
410+
name,
411+
)
412+
response = await self._api.patch(
413+
workspace_name=self._workspace_name,
414+
endpoint=f"pipelines/{name}/versions/{version_id}",
415+
json={"config_yaml": pipeline_yaml},
416+
)
417+
else:
418+
# Otherwise, create a new version
419+
logger.debug(
420+
"Latest version %s of pipeline %s is not a draft, creating new version.",
421+
version_id,
422+
name,
423+
)
424+
response = await self._api.post(
425+
workspace_name=self._workspace_name,
426+
endpoint=f"pipelines/{name}/versions",
427+
json={"config_yaml": pipeline_yaml, "is_draft": True},
428+
)
401429

402-
response = await self._create_pipeline(name=name, pipeline_yaml=pipeline_yaml)
403430
return response
404431

405432
async def _create_pipeline(self, name: str, pipeline_yaml: str) -> Response:

tests/integration/workflows/test_integration_pipeline_client.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -488,10 +488,10 @@ async def test_import_pipeline_with_overwrite_fallback_to_create_async(
488488
return_value=Response(status_code=HTTPStatus.NO_CONTENT)
489489
)
490490

491-
# Mock failed version creation (non-201 for POST /pipelines/{name}/versions)
492-
version_create_route = respx.post(
491+
# Mock 404 response for GET (resource not found)
492+
version_check_route = respx.get(
493493
"https://test-api-url.com/workspaces/test-workspace/pipelines/test-pipeline-fallback/versions"
494-
).mock(return_value=Response(status_code=HTTPStatus.BAD_REQUEST))
494+
).mock(return_value=Response(status_code=HTTPStatus.NOT_FOUND))
495495

496496
# Mock successful creation
497497
create_route = respx.post("https://test-api-url.com/workspaces/test-workspace/pipelines").mock(
@@ -508,24 +508,20 @@ async def test_import_pipeline_with_overwrite_fallback_to_create_async(
508508

509509
await test_async_client.import_into_deepset(sample_pipeline, pipeline_config)
510510

511-
# Verify all three endpoints were called
511+
# Verify all three endpoints were called in sequence
512512
assert validation_route.called
513-
assert version_create_route.called
513+
assert version_check_route.called
514514
assert create_route.called
515515

516516
# Check validation request
517517
validation_request = validation_route.calls[0].request
518518
assert validation_request.headers["Authorization"] == "Bearer test-api-key"
519519
validation_body = json.loads(validation_request.content)
520520
assert "query_yaml" in validation_body
521-
# When overwrite=True, name should be excluded from validation payload (if your code does that)
522-
# assert "name" not in validation_body
523-
524-
# Check attempted version creation request
525-
version_create_request = version_create_route.calls[0].request
526-
assert version_create_request.headers["Authorization"] == "Bearer test-api-key"
527-
version_body = json.loads(version_create_request.content)
528-
assert "config_yaml" in version_body
521+
522+
# Check GET attempt
523+
version_check_request = version_check_route.calls[0].request
524+
assert version_check_request.headers["Authorization"] == "Bearer test-api-key"
529525

530526
# Check fallback creation
531527
create_request = create_route.calls[0].request

tests/unit/service/test_pipeline_service.py

Lines changed: 82 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ async def test_import_index_with_overwrite_fallback_to_create(
587587
async def test_import_pipeline_with_overwrite_true(
588588
self, pipeline_service: PipelineService, index_pipeline: Pipeline, mock_api: AsyncMock
589589
) -> None:
590-
"""Test importing a pipeline with overwrite=True creates a new version via POST endpoint."""
590+
"""Test importing a pipeline with overwrite=True patches latest draft version."""
591591
config = PipelineConfig(
592592
name="test_pipeline_overwrite",
593593
inputs=PipelineInputs(query=["retriever.query"]),
@@ -600,39 +600,93 @@ async def test_import_pipeline_with_overwrite_true(
600600
validation_response = Mock(spec=Response)
601601
validation_response.status_code = HTTPStatus.NO_CONTENT.value
602602

603-
# Mock successful "create new version" response
603+
# Mock successful versions response, latest version is a draft
604+
versions_response = Mock(status_code=HTTPStatus.OK.value)
605+
versions_response.json.return_value = {
606+
"data": [{"version_id": "42abcd", "is_draft": True}],
607+
}
608+
609+
# Mock successful overwrite (PATCH) response
604610
overwrite_response = Mock(spec=Response)
605-
overwrite_response.status_code = HTTPStatus.CREATED.value
611+
overwrite_response.status_code = HTTPStatus.OK.value
612+
613+
mock_api.post.return_value = validation_response
614+
mock_api.get.return_value = versions_response
615+
mock_api.patch.return_value = overwrite_response
616+
617+
await pipeline_service.import_async(index_pipeline, config)
618+
619+
# validation + GET versions + PATCH draft version
620+
assert mock_api.post.call_count == 1
621+
assert mock_api.get.call_count == 1
622+
assert mock_api.patch.call_count == 1
623+
624+
# Check validation call
625+
validation_call = mock_api.post.call_args_list[0]
626+
assert validation_call.kwargs["endpoint"] == "pipeline_validations"
627+
assert "query_yaml" in validation_call.kwargs["json"]
628+
# When overwrite=True, name should be excluded from validation payload
629+
assert "name" not in validation_call.kwargs["json"]
630+
631+
# Check PATCH call
632+
overwrite_call = mock_api.patch.call_args_list[0]
633+
assert overwrite_call.kwargs["endpoint"] == "pipelines/test_pipeline_overwrite/versions/42abcd"
634+
assert "config_yaml" in overwrite_call.kwargs["json"]
635+
636+
@pytest.mark.asyncio
637+
async def test_import_pipeline_with_overwrite_true_creates_new_version_when_not_draft(
638+
self, pipeline_service: PipelineService, index_pipeline: Pipeline, mock_api: AsyncMock
639+
) -> None:
640+
"""Test importing a pipeline with overwrite=True creates a new version when latest version is not draft."""
641+
config = PipelineConfig(
642+
name="test_pipeline_overwrite",
643+
inputs=PipelineInputs(query=["retriever.query"]),
644+
outputs=PipelineOutputs(documents="meta_ranker.documents"),
645+
strict_validation=False,
646+
overwrite=True,
647+
)
648+
649+
# Mock successful validation response
650+
validation_response = Mock(spec=Response)
651+
validation_response.status_code = HTTPStatus.NO_CONTENT.value
652+
653+
# Mock versions response, latest version is NOT a draft
654+
versions_response = Mock(status_code=HTTPStatus.OK.value)
655+
versions_response.json.return_value = {
656+
"data": [{"version_id": "42abcd", "is_draft": False}],
657+
}
658+
659+
# Mock successful "create new version" response
660+
new_version_response = Mock(spec=Response)
661+
new_version_response.status_code = HTTPStatus.CREATED.value
606662

607663
# First POST is validation, second POST is "create new version"
608-
mock_api.post.side_effect = [validation_response, overwrite_response]
664+
mock_api.post.side_effect = [validation_response, new_version_response]
665+
mock_api.get.return_value = versions_response
609666

610667
await pipeline_service.import_async(index_pipeline, config)
611668

612-
# Should call validation endpoint first, then create-version endpoint
669+
# validation + GET versions + POST versions (new version)
613670
assert mock_api.post.call_count == 2
614-
# No GET/PATCH/PUT calls in the overwrite path anymore
615-
assert mock_api.get.call_count == 0
671+
assert mock_api.get.call_count == 1
616672
assert mock_api.patch.call_count == 0
617-
assert mock_api.put.call_count == 0
618673

619674
# Check validation call
620675
validation_call = mock_api.post.call_args_list[0]
621676
assert validation_call.kwargs["endpoint"] == "pipeline_validations"
622677
assert "query_yaml" in validation_call.kwargs["json"]
623-
# When overwrite=True, name should be excluded from validation payload
624678
assert "name" not in validation_call.kwargs["json"]
625679

626-
# Check create-version call
627-
overwrite_call = mock_api.post.call_args_list[1]
628-
assert overwrite_call.kwargs["endpoint"] == "pipelines/test_pipeline_overwrite/versions"
629-
assert "config_yaml" in overwrite_call.kwargs["json"]
680+
# Check create-version POST call
681+
create_version_call = mock_api.post.call_args_list[1]
682+
assert create_version_call.kwargs["endpoint"] == "pipelines/test_pipeline_overwrite/versions"
683+
assert "config_yaml" in create_version_call.kwargs["json"]
630684

631685
@pytest.mark.asyncio
632686
async def test_import_pipeline_with_overwrite_fallback_to_create(
633687
self, pipeline_service: PipelineService, index_pipeline: Pipeline, mock_api: AsyncMock
634688
) -> None:
635-
"""Test importing a pipeline with overwrite=True that falls back to create when version creation fails."""
689+
"""Test importing a pipeline with overwrite=True that falls back to create when resource doesn't exist."""
636690

637691
config = PipelineConfig(
638692
name="test_pipeline_fallback",
@@ -646,41 +700,35 @@ async def test_import_pipeline_with_overwrite_fallback_to_create(
646700
validation_response = Mock(spec=Response)
647701
validation_response.status_code = HTTPStatus.NO_CONTENT.value
648702

649-
# Mock non-201 response for POST /pipelines/{name}/versions (version creation fails)
650-
version_fail_response = Mock(spec=Response)
651-
version_fail_response.status_code = HTTPStatus.BAD_REQUEST.value
703+
# Mock 404 response for GET (resource not found)
704+
not_found_response = Mock(spec=Response)
705+
not_found_response.status_code = HTTPStatus.NOT_FOUND.value
652706

653-
# Mock successful creation response for POST /pipelines
707+
# Mock successful creation response
654708
create_response = Mock(spec=Response)
655709
create_response.status_code = HTTPStatus.CREATED.value
656710

657-
# POST calls: validation, create-version (fails), create-pipeline (fallback)
658-
mock_api.post.side_effect = [validation_response, version_fail_response, create_response]
711+
mock_api.post.side_effect = [validation_response, create_response]
712+
mock_api.get.return_value = not_found_response
659713

660714
await pipeline_service.import_async(index_pipeline, config)
661715

662-
# Should call validation endpoint, then POST to create new version (fails),
663-
# then POST to create the pipeline
664-
assert mock_api.post.call_count == 3
665-
# No GET anymore; overwrite logic doesn't fetch versions
666-
assert mock_api.get.call_count == 0
667-
assert mock_api.patch.call_count == 0
668-
assert mock_api.put.call_count == 0
716+
# validation + GET (404) + POST create
717+
assert mock_api.post.call_count == 2
718+
assert mock_api.get.call_count == 1
669719

670720
# Check validation call
671721
validation_call = mock_api.post.call_args_list[0]
672722
assert validation_call.kwargs["endpoint"] == "pipeline_validations"
673723
assert "query_yaml" in validation_call.kwargs["json"]
674-
# When overwrite=True, name should be excluded from validation payload
675724
assert "name" not in validation_call.kwargs["json"]
676725

677-
# Check attempted version creation call
678-
version_call = mock_api.post.call_args_list[1]
679-
assert version_call.kwargs["endpoint"] == "pipelines/test_pipeline_fallback/versions"
680-
assert "config_yaml" in version_call.kwargs["json"]
726+
# Check GET versions attempt
727+
get_call = mock_api.get.call_args_list[0]
728+
assert get_call.kwargs["endpoint"] == "pipelines/test_pipeline_fallback/versions"
681729

682-
# Check fallback create-pipeline call
683-
create_call = mock_api.post.call_args_list[2]
730+
# Check fallback POST call
731+
create_call = mock_api.post.call_args_list[1]
684732
assert create_call.kwargs["endpoint"] == "pipelines"
685733
assert create_call.kwargs["json"]["name"] == "test_pipeline_fallback"
686734
assert "query_yaml" in create_call.kwargs["json"]

0 commit comments

Comments
 (0)