diff --git a/deepset_cloud_sdk/_service/pipeline_service.py b/deepset_cloud_sdk/_service/pipeline_service.py index b5dac1fc..42d925d0 100644 --- a/deepset_cloud_sdk/_service/pipeline_service.py +++ b/deepset_cloud_sdk/_service/pipeline_service.py @@ -7,6 +7,7 @@ from io import StringIO from typing import Any, List, Optional, Protocol, runtime_checkable +import httpx import structlog from httpx import Response from pydantic import BaseModel @@ -376,28 +377,47 @@ async def _create_index(self, name: str, pipeline_yaml: str) -> Response: async def _overwrite_pipeline(self, name: str, pipeline_yaml: str) -> Response: """Overwrite a pipeline in deepset AI Platform. - :param name: Name of the pipeline. - :param pipeline_yaml: Generated pipeline YAML string. + Behavior: + - First try to fetch the latest version. + - If the pipeline doesn't exist (404), create it instead. + - If the latest version is a draft (is_draft == True), PATCH that version. + - Otherwise, create a new version via POST /pipelines/{name}/versions. """ - # First get the (last) version id if available - version_response = await self._api.get( - workspace_name=self._workspace_name, endpoint=f"pipelines/{name}/versions" - ) - - # If pipeline doesn't exist (404), create it instead - if version_response.status_code == HTTPStatus.NOT_FOUND: - logger.debug(f"Pipeline {name} not found, creating new pipeline.") - response = await self._create_pipeline(name=name, pipeline_yaml=pipeline_yaml) - else: - version_body = version_response.json() - version_id = version_body["data"][0]["version_id"] - response = await self._api.patch( + # Fetch versions + try: + version_response = await self._api.get( + workspace_name=self._workspace_name, + endpoint=f"pipelines/{name}/versions", + ) + version_response.raise_for_status() + except httpx.HTTPStatusError as e: + if e.response.status_code != HTTPStatus.NOT_FOUND: + raise + # the pipeline does not exist, let's create it. + logger.debug(f"Pipeline '{name}' not found, creating new pipeline.") + return await self._create_pipeline(name=name, pipeline_yaml=pipeline_yaml) + + version_body = version_response.json() + latest_version = version_body["data"][0] + version_id = latest_version["version_id"] + is_draft = latest_version.get("is_draft", False) + + if is_draft: + # Patch existing draft version + logger.debug(f"Patching existing draft version '{version_id}' of pipeline '{name}'.") + return await self._api.patch( workspace_name=self._workspace_name, endpoint=f"pipelines/{name}/versions/{version_id}", json={"config_yaml": pipeline_yaml}, ) - return response + # Create a new version + logger.debug(f"Latest version '{version_id}' of pipeline '{name}' is not a draft, creating new version.") + return await self._api.post( + workspace_name=self._workspace_name, + endpoint=f"pipelines/{name}/versions", + json={"config_yaml": pipeline_yaml}, + ) async def _create_pipeline(self, name: str, pipeline_yaml: str) -> Response: """Create a pipeline in deepset AI Platform. diff --git a/tests/unit/service/test_pipeline_service.py b/tests/unit/service/test_pipeline_service.py index fea821ca..5df9a1df 100644 --- a/tests/unit/service/test_pipeline_service.py +++ b/tests/unit/service/test_pipeline_service.py @@ -6,6 +6,7 @@ from typing import Any from unittest.mock import AsyncMock, Mock +import httpx import pytest from haystack import AsyncPipeline, Pipeline from haystack.components.converters import CSVToDocument, TextFileToDocument @@ -14,10 +15,7 @@ from httpx import Response from structlog.testing import capture_logs -from deepset_cloud_sdk._service.pipeline_service import ( - DeepsetValidationError, - PipelineService, -) +from deepset_cloud_sdk._service.pipeline_service import DeepsetValidationError, PipelineService from deepset_cloud_sdk.models import ( IndexConfig, IndexInputs, @@ -587,7 +585,7 @@ async def test_import_index_with_overwrite_fallback_to_create( async def test_import_pipeline_with_overwrite_true( self, pipeline_service: PipelineService, index_pipeline: Pipeline, mock_api: AsyncMock ) -> None: - """Test importing a pipeline with overwrite=True uses PUT endpoint.""" + """Test importing a pipeline with overwrite=True patches latest draft version.""" config = PipelineConfig( name="test_pipeline_overwrite", inputs=PipelineInputs(query=["retriever.query"]), @@ -600,24 +598,25 @@ async def test_import_pipeline_with_overwrite_true( validation_response = Mock(spec=Response) validation_response.status_code = HTTPStatus.NO_CONTENT.value - # Mock successful versions response + # Mock successful versions response, latest version is a draft versions_response = Mock(status_code=HTTPStatus.OK.value) versions_response.json.return_value = { - "data": [{"version_id": "42abcd"}], + "data": [{"version_id": "42abcd", "is_draft": True}], } - # Mock successful overwrite response + # Mock successful overwrite (PATCH) response overwrite_response = Mock(spec=Response) overwrite_response.status_code = HTTPStatus.OK.value mock_api.post.return_value = validation_response mock_api.get.return_value = versions_response - mock_api.put.return_value = overwrite_response + mock_api.patch.return_value = overwrite_response await pipeline_service.import_async(index_pipeline, config) - # Should call validation endpoint first, then overwrite endpoint + # validation + GET versions + PATCH draft version assert mock_api.post.call_count == 1 + assert mock_api.get.call_count == 1 assert mock_api.patch.call_count == 1 # Check validation call @@ -627,11 +626,60 @@ async def test_import_pipeline_with_overwrite_true( # When overwrite=True, name should be excluded from validation payload assert "name" not in validation_call.kwargs["json"] - # Check overwrite call + # Check PATCH call overwrite_call = mock_api.patch.call_args_list[0] assert overwrite_call.kwargs["endpoint"] == "pipelines/test_pipeline_overwrite/versions/42abcd" assert "config_yaml" in overwrite_call.kwargs["json"] + @pytest.mark.asyncio + async def test_import_pipeline_with_overwrite_true_creates_new_version_when_not_draft( + self, pipeline_service: PipelineService, index_pipeline: Pipeline, mock_api: AsyncMock + ) -> None: + """Test importing a pipeline with overwrite=True creates a new version when latest version is not draft.""" + config = PipelineConfig( + name="test_pipeline_overwrite", + inputs=PipelineInputs(query=["retriever.query"]), + outputs=PipelineOutputs(documents="meta_ranker.documents"), + strict_validation=False, + overwrite=True, + ) + + # Mock successful validation response + validation_response = Mock(spec=Response) + validation_response.status_code = HTTPStatus.NO_CONTENT.value + + # Mock versions response, latest version is NOT a draft + versions_response = Mock(status_code=HTTPStatus.OK.value) + versions_response.json.return_value = { + "data": [{"version_id": "42abcd", "is_draft": False}], + } + + # Mock successful "create new version" response + new_version_response = Mock(spec=Response) + new_version_response.status_code = HTTPStatus.CREATED.value + + # First POST is validation, second POST is "create new version" + mock_api.post.side_effect = [validation_response, new_version_response] + mock_api.get.return_value = versions_response + + await pipeline_service.import_async(index_pipeline, config) + + # validation + GET versions + POST versions (new version) + assert mock_api.post.call_count == 2 + assert mock_api.get.call_count == 1 + assert mock_api.patch.call_count == 0 + + # Check validation call + validation_call = mock_api.post.call_args_list[0] + assert validation_call.kwargs["endpoint"] == "pipeline_validations" + assert "query_yaml" in validation_call.kwargs["json"] + assert "name" not in validation_call.kwargs["json"] + + # Check create-version POST call + create_version_call = mock_api.post.call_args_list[1] + assert create_version_call.kwargs["endpoint"] == "pipelines/test_pipeline_overwrite/versions" + assert "config_yaml" in create_version_call.kwargs["json"] + @pytest.mark.asyncio async def test_import_pipeline_with_overwrite_fallback_to_create( self, pipeline_service: PipelineService, index_pipeline: Pipeline, mock_api: AsyncMock @@ -653,6 +701,9 @@ async def test_import_pipeline_with_overwrite_fallback_to_create( # Mock 404 response for GET (resource not found) not_found_response = Mock(spec=Response) not_found_response.status_code = HTTPStatus.NOT_FOUND.value + not_found_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Not Found", request=Mock(), response=not_found_response + ) # Mock successful creation response create_response = Mock(spec=Response) @@ -663,7 +714,7 @@ async def test_import_pipeline_with_overwrite_fallback_to_create( await pipeline_service.import_async(index_pipeline, config) - # Should call validation endpoint, then GET (which returns 404), then POST to create + # validation + GET (404) + POST create assert mock_api.post.call_count == 2 assert mock_api.get.call_count == 1 @@ -671,7 +722,6 @@ async def test_import_pipeline_with_overwrite_fallback_to_create( validation_call = mock_api.post.call_args_list[0] assert validation_call.kwargs["endpoint"] == "pipeline_validations" assert "query_yaml" in validation_call.kwargs["json"] - # When overwrite=True, name should be excluded from validation payload assert "name" not in validation_call.kwargs["json"] # Check GET versions attempt