Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/guides/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ For more details, refer to the [Inference Providers pricing documentation](https
| [`~InferenceClient.fill_mask`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.image_classification`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.image_segmentation`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| [`~InferenceClient.image_to_image`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| [`~InferenceClient.image_to_text`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| [`~InferenceClient.object_detection`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ |
| [`~InferenceClient.question_answering`] | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
Expand Down
1 change: 1 addition & 0 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,6 +1338,7 @@ def image_to_image(
api_key=self.token,
)
response = self._inner_post(request_parameters)
response = provider_helper.get_response(response, request_parameters)
return _bytes_to_image(response)

def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
Expand Down
12 changes: 12 additions & 0 deletions src/huggingface_hub/inference/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io
import json
import logging
import mimetypes
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
Expand Down Expand Up @@ -197,6 +198,17 @@ def _b64_encode(content: ContentT) -> str:
return base64.b64encode(data_as_bytes).decode()


def _as_url(content: ContentT, default_mime_type: str) -> str:
if isinstance(content, str) and (content.startswith("https://") or content.startswith("http://")):
return content

mime_type = (
mimetypes.guess_type(content, strict=False)[0] if isinstance(content, (str, Path)) else None
) or default_mime_type
encoded_data = _b64_encode(content)
return f"data:{mime_type};base64,{encoded_data}"


def _b64_to_image(encoded_image: str) -> "Image":
"""Parse a base64-encoded string into a PIL Image."""
Image = _import_pil_image()
Expand Down
1 change: 1 addition & 0 deletions src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1384,6 +1384,7 @@ async def image_to_image(
api_key=self.token,
)
response = await self._inner_post(request_parameters)
response = provider_helper.get_response(response, request_parameters)
return _bytes_to_image(response)

async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput:
Expand Down
3 changes: 2 additions & 1 deletion src/huggingface_hub/inference/_providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask
from .nscale import NscaleConversationalTask, NscaleTextToImageTask
from .openai import OpenAIConversationalTask
from .replicate import ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask
from .replicate import ReplicateImageToImageTask, ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask

Expand Down Expand Up @@ -141,6 +141,7 @@
"conversational": OpenAIConversationalTask(),
},
"replicate": {
"image-to-image": ReplicateImageToImageTask(),
"text-to-image": ReplicateTextToImageTask(),
"text-to-speech": ReplicateTextToSpeechTask(),
"text-to-video": ReplicateTask("text-to-video"),
Expand Down
20 changes: 19 additions & 1 deletion src/huggingface_hub/inference/_providers/replicate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict, Optional, Union

from huggingface_hub.hf_api import InferenceProviderMapping
from huggingface_hub.inference._common import RequestParameters, _as_dict
from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import get_session

Expand Down Expand Up @@ -70,3 +70,21 @@ def _prepare_payload_as_dict(
payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore[assignment]
payload["input"]["text"] = payload["input"].pop("prompt") # rename "prompt" to "text" for TTS
return payload


class ReplicateImageToImageTask(ReplicateTask):
def __init__(self):
super().__init__("image-to-image")

def _prepare_payload_as_dict(
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
) -> Optional[Dict]:
image_url = _as_url(inputs, default_mime_type="image/jpeg")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice having a default one 👍


payload: Dict[str, Any] = {"input": {"input_image": image_url, **filter_none(parameters)}}

mapped_model = provider_mapping_info.provider_id
if ":" in mapped_model:
version = mapped_model.split(":", 1)[1]
payload["version"] = version
return payload
31 changes: 30 additions & 1 deletion tests/test_inference_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@
)
from huggingface_hub.errors import HfHubHTTPError, ValidationError
from huggingface_hub.inference._client import _open_as_binary
from huggingface_hub.inference._common import _stream_chat_completion_response, _stream_text_generation_response
from huggingface_hub.inference._common import (
_as_url,
_stream_chat_completion_response,
_stream_text_generation_response,
)
from huggingface_hub.inference._providers import get_provider_helper
from huggingface_hub.inference._providers.hf_inference import _build_chat_completion_url

Expand Down Expand Up @@ -1163,3 +1167,28 @@ def test_chat_completion_url_resolution(
assert request_params.url == expected_request_url
assert request_params.json is not None
assert request_params.json.get("model") == expected_payload_model


@pytest.mark.parametrize(
"content_input, default_mime_type, expected, is_exact_match",
[
("https://my-url.com/cat.gif", "image/jpeg", "https://my-url.com/cat.gif", True),
("assets/image.png", "image/jpeg", "data:image/png;base64,", False),
(Path("assets/image.png"), "image/jpeg", "data:image/png;base64,", False),
("assets/image.foo", "image/jpeg", "data:image/jpeg;base64,", False),
(b"some image bytes", "image/jpeg", "", True),
(io.BytesIO(b"some image bytes"), "image/jpeg", "", True),
],
)
def test_as_url(content_input, default_mime_type, expected, is_exact_match, tmp_path: Path):
if isinstance(content_input, (str, Path)) and not str(content_input).startswith("http"):
file_path = tmp_path / content_input
file_path.parent.mkdir(exist_ok=True, parents=True)
file_path.touch()
content_input = file_path

result = _as_url(content_input, default_mime_type)
if is_exact_match:
assert result == expected
else:
assert result.startswith(expected)
44 changes: 43 additions & 1 deletion tests/test_inference_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@
from huggingface_hub.inference._providers.novita import NovitaConversationalTask, NovitaTextGenerationTask
from huggingface_hub.inference._providers.nscale import NscaleConversationalTask, NscaleTextToImageTask
from huggingface_hub.inference._providers.openai import OpenAIConversationalTask
from huggingface_hub.inference._providers.replicate import ReplicateTask, ReplicateTextToSpeechTask
from huggingface_hub.inference._providers.replicate import (
ReplicateImageToImageTask,
ReplicateTask,
ReplicateTextToSpeechTask,
)
from huggingface_hub.inference._providers.sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
from huggingface_hub.inference._providers.together import TogetherTextToImageTask

Expand Down Expand Up @@ -1057,6 +1061,44 @@ def test_get_response_single_output(self, mocker):
mock.return_value.get.assert_called_once_with("https://example.com/image.jpg")
assert response == mock.return_value.get.return_value.content

def test_image_to_image_payload(self):
helper = ReplicateImageToImageTask()
dummy_image = b"dummy image data"
encoded_image = base64.b64encode(dummy_image).decode("utf-8")
image_uri = f"data:image/jpeg;base64,{encoded_image}"

# No model version
payload = helper._prepare_payload_as_dict(
dummy_image,
{"num_inference_steps": 20},
InferenceProviderMapping(
provider="replicate",
hf_model_id="google/gemini-pro-vision",
providerId="google/gemini-pro-vision",
task="image-to-image",
status="live",
),
)
assert payload == {
"input": {"input_image": image_uri, "num_inference_steps": 20},
}

payload = helper._prepare_payload_as_dict(
dummy_image,
{"num_inference_steps": 20},
InferenceProviderMapping(
provider="replicate",
hf_model_id="google/gemini-pro-vision",
providerId="google/gemini-pro-vision:123456",
task="image-to-image",
status="live",
),
)
assert payload == {
"input": {"input_image": image_uri, "num_inference_steps": 20},
"version": "123456",
}


class TestSambanovaProvider:
def test_prepare_url_conversational(self):
Expand Down