Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 36 additions & 16 deletions deepset_cloud_sdk/_service/pipeline_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from io import StringIO
from typing import Any, List, Optional, Protocol, runtime_checkable

import httpx
import structlog
from httpx import Response
from pydantic import BaseModel
Expand Down Expand Up @@ -376,28 +377,47 @@ async def _create_index(self, name: str, pipeline_yaml: str) -> Response:
async def _overwrite_pipeline(self, name: str, pipeline_yaml: str) -> Response:
"""Overwrite a pipeline in deepset AI Platform.

:param name: Name of the pipeline.
:param pipeline_yaml: Generated pipeline YAML string.
Behavior:
- First try to fetch the latest version.
- If the pipeline doesn't exist (404), create it instead.
- If the latest version is a draft (is_draft == True), PATCH that version.
- Otherwise, create a new version via POST /pipelines/{name}/versions.
"""
# First get the (last) version id if available
version_response = await self._api.get(
workspace_name=self._workspace_name, endpoint=f"pipelines/{name}/versions"
)

# If pipeline doesn't exist (404), create it instead
if version_response.status_code == HTTPStatus.NOT_FOUND:
logger.debug(f"Pipeline {name} not found, creating new pipeline.")
response = await self._create_pipeline(name=name, pipeline_yaml=pipeline_yaml)
else:
version_body = version_response.json()
version_id = version_body["data"][0]["version_id"]
response = await self._api.patch(
# Fetch versions
try:
version_response = await self._api.get(
workspace_name=self._workspace_name,
endpoint=f"pipelines/{name}/versions",
)
version_response.raise_for_status()
except httpx.HTTPStatusError as e:
if e.response.status_code != HTTPStatus.NOT_FOUND:
raise
# the pipeline does not exist, let's create it.
logger.debug(f"Pipeline '{name}' not found, creating new pipeline.")
return await self._create_pipeline(name=name, pipeline_yaml=pipeline_yaml)

version_body = version_response.json()
latest_version = version_body["data"][0]
version_id = latest_version["version_id"]
is_draft = latest_version.get("is_draft", False)
Comment on lines +400 to +403
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are assuming the response is valid without checking its status. If we get another status code (say, 401 or 403), these lines may raise an error.

Instead let's do the following:

import httpx

try:
    version_response = await self._api.get(...)
    version_response.raise_for_status()
except httpx.HTTPStatusError as e:
    if e.response.status_code != HTTPStatus.NOT_FOUND:
        raise
    # the pipeline does not exist, let's create it.
    return await self._api.create_pipeline(...)

# pass the response
version_body = ...

if is_draft:
    # do draft things
    return

# do non-draft things

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @abrahamy Updated this accordingly.


if is_draft:
# Patch existing draft version
logger.debug(f"Patching existing draft version '{version_id}' of pipeline '{name}'.")
return await self._api.patch(
workspace_name=self._workspace_name,
endpoint=f"pipelines/{name}/versions/{version_id}",
json={"config_yaml": pipeline_yaml},
)

return response
# Create a new version
logger.debug(f"Latest version '{version_id}' of pipeline '{name}' is not a draft, creating new version.")
return await self._api.post(
workspace_name=self._workspace_name,
endpoint=f"pipelines/{name}/versions",
json={"config_yaml": pipeline_yaml},
)

async def _create_pipeline(self, name: str, pipeline_yaml: str) -> Response:
"""Create a pipeline in deepset AI Platform.
Expand Down
76 changes: 63 additions & 13 deletions tests/unit/service/test_pipeline_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any
from unittest.mock import AsyncMock, Mock

import httpx
import pytest
from haystack import AsyncPipeline, Pipeline
from haystack.components.converters import CSVToDocument, TextFileToDocument
Expand All @@ -14,10 +15,7 @@
from httpx import Response
from structlog.testing import capture_logs

from deepset_cloud_sdk._service.pipeline_service import (
DeepsetValidationError,
PipelineService,
)
from deepset_cloud_sdk._service.pipeline_service import DeepsetValidationError, PipelineService
from deepset_cloud_sdk.models import (
IndexConfig,
IndexInputs,
Expand Down Expand Up @@ -587,7 +585,7 @@ async def test_import_index_with_overwrite_fallback_to_create(
async def test_import_pipeline_with_overwrite_true(
self, pipeline_service: PipelineService, index_pipeline: Pipeline, mock_api: AsyncMock
) -> None:
"""Test importing a pipeline with overwrite=True uses PUT endpoint."""
"""Test importing a pipeline with overwrite=True patches latest draft version."""
config = PipelineConfig(
name="test_pipeline_overwrite",
inputs=PipelineInputs(query=["retriever.query"]),
Expand All @@ -600,24 +598,25 @@ async def test_import_pipeline_with_overwrite_true(
validation_response = Mock(spec=Response)
validation_response.status_code = HTTPStatus.NO_CONTENT.value

# Mock successful versions response
# Mock successful versions response, latest version is a draft
versions_response = Mock(status_code=HTTPStatus.OK.value)
versions_response.json.return_value = {
"data": [{"version_id": "42abcd"}],
"data": [{"version_id": "42abcd", "is_draft": True}],
}

# Mock successful overwrite response
# Mock successful overwrite (PATCH) response
overwrite_response = Mock(spec=Response)
overwrite_response.status_code = HTTPStatus.OK.value

mock_api.post.return_value = validation_response
mock_api.get.return_value = versions_response
mock_api.put.return_value = overwrite_response
mock_api.patch.return_value = overwrite_response

await pipeline_service.import_async(index_pipeline, config)

# Should call validation endpoint first, then overwrite endpoint
# validation + GET versions + PATCH draft version
assert mock_api.post.call_count == 1
assert mock_api.get.call_count == 1
assert mock_api.patch.call_count == 1

# Check validation call
Expand All @@ -627,11 +626,60 @@ async def test_import_pipeline_with_overwrite_true(
# When overwrite=True, name should be excluded from validation payload
assert "name" not in validation_call.kwargs["json"]

# Check overwrite call
# Check PATCH call
overwrite_call = mock_api.patch.call_args_list[0]
assert overwrite_call.kwargs["endpoint"] == "pipelines/test_pipeline_overwrite/versions/42abcd"
assert "config_yaml" in overwrite_call.kwargs["json"]

@pytest.mark.asyncio
async def test_import_pipeline_with_overwrite_true_creates_new_version_when_not_draft(
    self, pipeline_service: PipelineService, index_pipeline: Pipeline, mock_api: AsyncMock
) -> None:
    """Check that overwrite=True POSTs a new version when the latest version is not a draft.

    The mocked versions listing reports ``is_draft: False`` for the latest
    version, so the import is expected to issue one validation POST, one GET
    for the versions, one POST to ``pipelines/<name>/versions`` and no PATCH.
    """
    config = PipelineConfig(
        name="test_pipeline_overwrite",
        inputs=PipelineInputs(query=["retriever.query"]),
        outputs=PipelineOutputs(documents="meta_ranker.documents"),
        strict_validation=False,
        overwrite=True,
    )

    # Canned responses for the two POSTs: validation first, then version creation.
    validation_ok = Mock(spec=Response, status_code=HTTPStatus.NO_CONTENT.value)
    version_created = Mock(spec=Response, status_code=HTTPStatus.CREATED.value)

    # Versions listing whose first (latest) entry is explicitly not a draft.
    non_draft_listing = Mock(status_code=HTTPStatus.OK.value)
    non_draft_listing.json.return_value = {
        "data": [{"version_id": "42abcd", "is_draft": False}],
    }

    mock_api.get.return_value = non_draft_listing
    # side_effect hands the responses out in call order.
    mock_api.post.side_effect = [validation_ok, version_created]

    await pipeline_service.import_async(index_pipeline, config)

    # Call budget: validation POST + GET versions + POST new version, and no PATCH.
    assert mock_api.post.call_count == 2
    assert mock_api.get.call_count == 1
    assert mock_api.patch.call_count == 0

    # Safe to unpack: the count assertion above guarantees exactly two POSTs.
    validation_call, create_version_call = mock_api.post.call_args_list

    # The validation request must carry the YAML but omit the pipeline name.
    validation_payload = validation_call.kwargs["json"]
    assert validation_call.kwargs["endpoint"] == "pipeline_validations"
    assert "query_yaml" in validation_payload
    assert "name" not in validation_payload

    # The second POST must target the versions collection with the new config.
    assert create_version_call.kwargs["endpoint"] == "pipelines/test_pipeline_overwrite/versions"
    assert "config_yaml" in create_version_call.kwargs["json"]

@pytest.mark.asyncio
async def test_import_pipeline_with_overwrite_fallback_to_create(
self, pipeline_service: PipelineService, index_pipeline: Pipeline, mock_api: AsyncMock
Expand All @@ -653,6 +701,9 @@ async def test_import_pipeline_with_overwrite_fallback_to_create(
# Mock 404 response for GET (resource not found)
not_found_response = Mock(spec=Response)
not_found_response.status_code = HTTPStatus.NOT_FOUND.value
not_found_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"Not Found", request=Mock(), response=not_found_response
)

# Mock successful creation response
create_response = Mock(spec=Response)
Expand All @@ -663,15 +714,14 @@ async def test_import_pipeline_with_overwrite_fallback_to_create(

await pipeline_service.import_async(index_pipeline, config)

# Should call validation endpoint, then GET (which returns 404), then POST to create
# validation + GET (404) + POST create
assert mock_api.post.call_count == 2
assert mock_api.get.call_count == 1

# Check validation call
validation_call = mock_api.post.call_args_list[0]
assert validation_call.kwargs["endpoint"] == "pipeline_validations"
assert "query_yaml" in validation_call.kwargs["json"]
# When overwrite=True, name should be excluded from validation payload
assert "name" not in validation_call.kwargs["json"]

# Check GET versions attempt
Expand Down
Loading