Skip to content

Commit 7e86848

Browse files
arabot777github-actions[bot]hanouticelina
authored
[inference provider] Add wavespeed.ai as an inference provider (#3474)
* init wavespeed ai * python wavespeed ai * mapping * Apply style fixes * Address review feedback for Wavespeed AI provider * Remove redundant _prepare_headers override in WavespeedAITask * Apply style fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: célina <[email protected]>
1 parent e5532d9 commit 7e86848

File tree

7 files changed

+359
-35
lines changed

7 files changed

+359
-35
lines changed

docs/source/en/guides/inference.md

Lines changed: 31 additions & 31 deletions
Large diffs are not rendered by default.

src/huggingface_hub/inference/_client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class InferenceClient:
135135
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
136136
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
137137
provider (`str`, *optional*):
138-
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `publicai`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"` or `"zai-org"`.
138+
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
139139
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
140140
If model is a URL or `base_url` is passed, then `provider` is not used.
141141
token (`str`, *optional*):
@@ -1321,6 +1321,7 @@ def image_to_image(
13211321
>>> image = client.image_to_image("cat.jpg", prompt="turn the cat into a tiger")
13221322
>>> image.save("tiger.jpg")
13231323
```
1324+
13241325
"""
13251326
model_id = model or self.model
13261327
provider_helper = get_provider_helper(self.provider, task="image-to-image", model=model_id)
@@ -2540,6 +2541,7 @@ def text_to_image(
25402541
... )
25412542
>>> image.save("astronaut.png")
25422543
```
2544+
25432545
"""
25442546
model_id = model or self.model
25452547
provider_helper = get_provider_helper(self.provider, task="text-to-image", model=model_id)
@@ -2560,7 +2562,7 @@ def text_to_image(
25602562
api_key=self.token,
25612563
)
25622564
response = self._inner_post(request_parameters)
2563-
response = provider_helper.get_response(response)
2565+
response = provider_helper.get_response(response, request_parameters)
25642566
return _bytes_to_image(response)
25652567

25662568
def text_to_video(
@@ -2638,6 +2640,7 @@ def text_to_video(
26382640
>>> with open("cat.mp4", "wb") as file:
26392641
... file.write(video)
26402642
```
2643+
26412644
"""
26422645
model_id = model or self.model
26432646
provider_helper = get_provider_helper(self.provider, task="text-to-video", model=model_id)

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class AsyncInferenceClient:
126126
Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2
127127
arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL.
128128
provider (`str`, *optional*):
129-
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `publicai`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"` or `"zai-org"`.
129+
Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"clarifai"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"`, `"scaleway"`, `"together"`, `"wavespeed"` or `"zai-org"`.
130130
Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
131131
If model is a URL or `base_url` is passed, then `provider` is not used.
132132
token (`str`, *optional*):
@@ -1353,6 +1353,7 @@ async def image_to_image(
13531353
>>> image = await client.image_to_image("cat.jpg", prompt="turn the cat into a tiger")
13541354
>>> image.save("tiger.jpg")
13551355
```
1356+
13561357
"""
13571358
model_id = model or self.model
13581359
provider_helper = get_provider_helper(self.provider, task="image-to-image", model=model_id)
@@ -2584,6 +2585,7 @@ async def text_to_image(
25842585
... )
25852586
>>> image.save("astronaut.png")
25862587
```
2588+
25872589
"""
25882590
model_id = model or self.model
25892591
provider_helper = get_provider_helper(self.provider, task="text-to-image", model=model_id)
@@ -2604,7 +2606,7 @@ async def text_to_image(
26042606
api_key=self.token,
26052607
)
26062608
response = await self._inner_post(request_parameters)
2607-
response = provider_helper.get_response(response)
2609+
response = provider_helper.get_response(response, request_parameters)
26082610
return _bytes_to_image(response)
26092611

26102612
async def text_to_video(
@@ -2682,6 +2684,7 @@ async def text_to_video(
26822684
>>> with open("cat.mp4", "wb") as file:
26832685
... file.write(video)
26842686
```
2687+
26852688
"""
26862689
model_id = model or self.model
26872690
provider_helper = get_provider_helper(self.provider, task="text-to-video", model=model_id)

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@
4242
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
4343
from .scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask
4444
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
45+
from .wavespeed import (
46+
WavespeedAIImageToImageTask,
47+
WavespeedAIImageToVideoTask,
48+
WavespeedAITextToImageTask,
49+
WavespeedAITextToVideoTask,
50+
)
4551
from .zai_org import ZaiConversationalTask
4652

4753

@@ -68,6 +74,7 @@
6874
"sambanova",
6975
"scaleway",
7076
"together",
77+
"wavespeed",
7178
"zai-org",
7279
]
7380

@@ -179,6 +186,12 @@
179186
"conversational": TogetherConversationalTask(),
180187
"text-generation": TogetherTextGenerationTask(),
181188
},
189+
"wavespeed": {
190+
"text-to-image": WavespeedAITextToImageTask(),
191+
"text-to-video": WavespeedAITextToVideoTask(),
192+
"image-to-image": WavespeedAIImageToImageTask(),
193+
"image-to-video": WavespeedAIImageToVideoTask(),
194+
},
182195
"zai-org": {
183196
"conversational": ZaiConversationalTask(),
184197
},

src/huggingface_hub/inference/_providers/_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"sambanova": {},
3737
"scaleway": {},
3838
"together": {},
39+
"wavespeed": {},
3940
"zai-org": {},
4041
}
4142

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import base64
2+
import time
3+
from abc import ABC
4+
from typing import Any, Optional, Union
5+
from urllib.parse import urlparse
6+
7+
from huggingface_hub.hf_api import InferenceProviderMapping
8+
from huggingface_hub.inference._common import RequestParameters, _as_dict
9+
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
10+
from huggingface_hub.utils import get_session, hf_raise_for_status
11+
from huggingface_hub.utils.logging import get_logger
12+
13+
14+
logger = get_logger(__name__)
15+
16+
# Polling interval (in seconds)
17+
_POLLING_INTERVAL = 0.5
18+
19+
20+
class WavespeedAITask(TaskProviderHelper, ABC):
    """Base helper for WaveSpeed AI tasks (asynchronous submit-then-poll workflow)."""

    def __init__(self, task: str):
        super().__init__(provider="wavespeed", base_url="https://api.wavespeed.ai", task=task)

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        # WaveSpeed routes are versioned under /api/v3/<model id>.
        return f"/api/v3/{mapped_model}"

    def get_response(
        self,
        response: Union[bytes, dict],
        request_params: Optional[RequestParameters] = None,
    ) -> Any:
        """Poll the WaveSpeed result endpoint until the task finishes and return the output bytes.

        Args:
            response: Raw submission response (bytes or parsed dict) containing the result URL
                under `data.urls.get`.
            request_params: Parameters of the original request; required to derive the polling
                URL and to reuse the authentication headers.

        Returns:
            The generated media content as raw bytes.

        Raises:
            ValueError: If no result URL is found, `request_params` is missing, the task
                reports `failed`, or an unknown status is returned.
        """
        response_dict = _as_dict(response)
        data = response_dict.get("data", {})
        result_path = data.get("urls", {}).get("get")

        if not result_path:
            raise ValueError("No result URL found in the response")
        if request_params is None:
            raise ValueError("A `RequestParameters` object should be provided to get responses with WaveSpeed AI.")

        # Parse the request URL to determine the base URL to poll against.
        parsed_url = urlparse(request_params.url)
        # Add /wavespeed to the base URL when going through the HF router so the poll
        # request is proxied to the provider as well.
        if parsed_url.netloc == "router.huggingface.co":
            base_url = f"{parsed_url.scheme}://{parsed_url.netloc}/wavespeed"
        else:
            base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

        # The API may return an absolute URL or a bare path; keep only the path part.
        if isinstance(result_path, str):
            result_url_path = urlparse(result_path).path
        else:
            result_url_path = result_path

        result_url = f"{base_url}{result_url_path}"

        logger.info("Processing request, polling for results...")

        # Poll until the task reaches a terminal state.
        # NOTE(review): there is no overall timeout — a task stuck in "processing"
        # polls forever. Consider bounding the total wait time.
        while True:
            time.sleep(_POLLING_INTERVAL)
            result_response = get_session().get(result_url, headers=request_params.headers)
            hf_raise_for_status(result_response)

            result = result_response.json()
            task_result = result.get("data", {})
            status = task_result.get("status")

            if status == "completed":
                # Fetch the content from the first output URL.
                if not task_result.get("outputs"):
                    raise ValueError("No output URL in completed response")

                output_url = task_result["outputs"][0]
                output_response = get_session().get(output_url)
                # Fail loudly on a bad download instead of silently returning an
                # HTML error page as media bytes.
                hf_raise_for_status(output_response)
                return output_response.content
            elif status == "failed":
                error_msg = task_result.get("error", "Task failed with no specific error message")
                raise ValueError(f"WaveSpeed AI task failed: {error_msg}")
            elif status in ("processing", "created"):
                continue
            else:
                raise ValueError(f"Unknown status: {status}")
83+
84+
85+
class WavespeedAITextToImageTask(WavespeedAITask):
    """Text-to-image task helper for WaveSpeed AI."""

    def __init__(self):
        super().__init__("text-to-image")

    def _prepare_payload_as_dict(
        self,
        inputs: Any,
        parameters: dict,
        provider_mapping_info: InferenceProviderMapping,
    ) -> Optional[dict]:
        # Send the prompt alongside every non-None generation parameter.
        payload = {"prompt": inputs}
        payload.update(filter_none(parameters))
        return payload
96+
97+
98+
class WavespeedAITextToVideoTask(WavespeedAITextToImageTask):
    """Text-to-video task helper; reuses the text-to-image payload format."""

    def __init__(self):
        # Deliberately skip WavespeedAITextToImageTask.__init__: only the task name differs.
        WavespeedAITask.__init__(self, "text-to-video")
101+
102+
103+
class WavespeedAIImageToImageTask(WavespeedAITask):
    """Image-to-image task helper for WaveSpeed AI."""

    def __init__(self):
        super().__init__("image-to-image")

    def _prepare_payload_as_dict(
        self,
        inputs: Any,
        parameters: dict,
        provider_mapping_info: InferenceProviderMapping,
    ) -> Optional[dict]:
        """Build the JSON payload, encoding the input image as a URL or a base64 data URI.

        Args:
            inputs: Image given as an http(s) URL, a local file path, or raw binary data.
            parameters: Extra generation parameters; a "prompt" key, if present, is forwarded.
            provider_mapping_info: Unused here; part of the task-helper interface.

        Returns:
            The payload dict with the encoded image, the optional prompt, and all
            non-None parameters.
        """
        # Normalize the input into an image reference (URL or base64 data URI).
        if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
            image = inputs
        else:
            if isinstance(inputs, str):
                # Input is a local file path: read its content.
                with open(inputs, "rb") as f:
                    file_content = f.read()
            else:
                # Input is already binary data.
                file_content = inputs
            image_b64 = base64.b64encode(file_content).decode("utf-8")
            # NOTE(review): MIME type is hard-coded to JPEG regardless of the actual
            # image format — confirm the API tolerates mismatched data URIs.
            image = f"data:image/jpeg;base64,{image_b64}"

        # Separate the prompt without mutating the caller's dict
        # (the original implementation popped "prompt" in place).
        prompt = parameters.get("prompt")
        other_parameters = {k: v for k, v in parameters.items() if k != "prompt"}
        payload = {"image": image, **filter_none(other_parameters)}
        if prompt is not None:
            payload["prompt"] = prompt

        return payload
134+
135+
136+
class WavespeedAIImageToVideoTask(WavespeedAIImageToImageTask):
    """Image-to-video task helper; reuses the image-to-image payload format."""

    def __init__(self):
        # Deliberately skip WavespeedAIImageToImageTask.__init__: only the task name differs.
        WavespeedAITask.__init__(self, "image-to-video")

0 commit comments

Comments
 (0)