Merge pull request #2738 from hlohaus/16Feb
Use gradio api in flux dev
hlohaus authored Feb 21, 2025
2 parents 4cfe35f + 69d0b09 commit ba2e6eb
Showing 3 changed files with 69 additions and 43 deletions.
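
The core change swaps the one-shot call to /gradio_api/call/infer for the Gradio queue API: a POST to /gradio_api/queue/join enqueues the job, and a GET to /gradio_api/queue/data streams progress for the same session_hash as server-sent events. A minimal sketch of that protocol, written outside the g4f codebase with plain requests, is shown below; the endpoint paths, the fn_index/trigger_id values, and the message names are taken from the diff, while the prompt, the generation settings, and the omission of the x-zerogpu-token/x-zerogpu-uuid headers are illustrative assumptions.

# Sketch of the Gradio queue protocol that the new run() helper drives.
# Paths, fn_index/trigger_id and message names come from the diff below;
# prompt and settings are placeholders, ZeroGPU token headers are omitted.
import json
import uuid

import requests

BASE = "https://black-forest-labs-flux-1-dev.hf.space"
session_hash = uuid.uuid4().hex

# 1. Enqueue the job; "data" mirrors [prompt, seed, randomize_seed, width,
#    height, guidance_scale, num_inference_steps] from the provider.
requests.post(
    f"{BASE}/gradio_api/queue/join?__theme=light",
    json={
        "data": ["a lighthouse at dusk", 0, True, 1024, 1024, 3.5, 28],
        "event_data": None,
        "fn_index": 2,
        "trigger_id": 4,
        "session_hash": session_hash,
    },
    timeout=30,
).raise_for_status()

# 2. Stream queue events for this session as server-sent events.
with requests.get(
    f"{BASE}/gradio_api/queue/data?session_hash={session_hash}",
    headers={"accept": "text/event-stream"},
    stream=True,
    timeout=300,
) as response:
    for line in response.iter_lines():
        if not line.startswith(b"data: "):
            continue
        msg = json.loads(line[6:])
        if msg is None:
            continue
        if msg.get("msg") == "process_generating":
            print("preview:", msg["output"]["data"][0])
        elif msg.get("msg") == "process_completed":
            print("result:", msg["output"]["data"][0])
            break
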
107 changes: 68 additions & 39 deletions g4f/Provider/hf_space/BlackForestLabsFlux1Dev.py
@@ -1,10 +1,11 @@
from __future__ import annotations

import json
from aiohttp import ClientSession
import uuid

from ...typing import AsyncResult, Messages
from ...providers.response import ImageResponse, ImagePreview, JsonConversation
from ...providers.response import ImageResponse, ImagePreview, JsonConversation, Reasoning
from ...requests import StreamSession
from ...errors import ResponseError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt
@@ -14,7 +15,7 @@
class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://black-forest-labs-flux-1-dev.hf.space"
space = "black-forest-labs/FLUX.1-dev"
api_endpoint = "/gradio_api/call/infer"
referer = f"{url}/?__theme=light"

working = True

@@ -24,6 +25,29 @@ class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
image_models = [default_image_model, *model_aliases.keys()]
models = image_models

@classmethod
def run(cls, method: str, session: StreamSession, conversation: JsonConversation, data: list = None):
headers = {
"accept": "application/json",
"content-type": "application/json",
"x-zerogpu-token": conversation.zerogpu_token,
"x-zerogpu-uuid": conversation.zerogpu_uuid,
"referer": cls.referer,
}
if method == "post":
return session.post(f"{cls.url}/gradio_api/queue/join?__theme=light", **{
"headers": {k: v for k, v in headers.items() if v is not None},
"json": {"data": data,"event_data":None,"fn_index":2,"trigger_id":4,"session_hash":conversation.session_hash}

})
return session.get(f"{cls.url}/gradio_api/queue/data?session_hash={conversation.session_hash}", **{
"headers": {
"accept": "text/event-stream",
"content-type": "application/json",
"referer": cls.referer,
}
})

@classmethod
async def create_async_generator(
cls,
@@ -43,44 +67,49 @@ async def create_async_generator
**kwargs
) -> AsyncResult:
model = cls.get_model(model)
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
async with ClientSession(headers=headers) as session:
async with StreamSession(impersonate="chrome", proxy=proxy) as session:
prompt = format_image_prompt(messages, prompt)
data = {
"data": [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]
}
if zerogpu_token is None:
zerogpu_uuid, zerogpu_token = await get_zerogpu_token(cls.space, session, JsonConversation(), cookies)
headers = {
"x-zerogpu-token": zerogpu_token,
"x-zerogpu-uuid": zerogpu_uuid,
}
headers = {k: v for k, v in headers.items() if v is not None}
async with session.post(f"{cls.url}{cls.api_endpoint}", json=data, proxy=proxy, headers=headers) as response:
data = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]
conversation = JsonConversation(zerogpu_token=zerogpu_token, zerogpu_uuid=zerogpu_uuid, session_hash=uuid.uuid4().hex)
if conversation.zerogpu_token is None:
conversation.zerogpu_uuid, conversation.zerogpu_token = await get_zerogpu_token(cls.space, session, conversation, cookies)
async with cls.run(f"post", session, conversation, data) as response:
await raise_for_status(response)
event_id = (await response.json()).get("event_id")
async with session.get(f"{cls.url}{cls.api_endpoint}/{event_id}") as event_response:
assert (await response.json()).get("event_id")
async with cls.run("get", session, conversation) as event_response:
await raise_for_status(event_response)
event = None
async for chunk in event_response.content:
if chunk.startswith(b"event: "):
event = chunk[7:].decode(errors="replace").strip()
async for chunk in event_response.iter_lines():
if chunk.startswith(b"data: "):
if event == "error":
raise ResponseError(f"GPU token limit exceeded: {chunk.decode(errors='replace')}")
if event in ("complete", "generating"):
try:
data = json.loads(chunk[6:])
if data is None:
continue
url = data[0]["url"]
except (json.JSONDecodeError, KeyError, TypeError) as e:
raise RuntimeError(f"Failed to parse image URL: {chunk.decode(errors='replace')}", e)
if event == "generating":
yield ImagePreview(url, prompt)
else:
yield ImageResponse(url, prompt)
try:
json_data = json.loads(chunk[6:])
if json_data is None:
continue
if json_data.get('msg') == 'log':
yield Reasoning(status=json_data["log"])

if json_data.get('msg') == 'progress':
if 'progress_data' in json_data:
if json_data['progress_data']:
progress = json_data['progress_data'][0]
yield Reasoning(status=f"{progress['desc']} {progress['index']}/{progress['length']}")
else:
yield Reasoning(status=f"Generating")

elif json_data.get('msg') == 'process_generating':
for item in json_data['output']['data'][0]:
if isinstance(item, dict) and "url" in item:
yield ImagePreview(item["url"], prompt)
elif isinstance(item, list) and len(item) > 2 and "url" in item[1]:
yield ImagePreview(item[2], prompt)

elif json_data.get('msg') == 'process_completed':
if 'output' in json_data and 'error' in json_data['output']:
json_data['output']['error'] = json_data['output']['error'].split(" <a ")[0]
raise ResponseError(json_data['output']['error'])
if 'output' in json_data and 'data' in json_data['output']:
yield Reasoning(status="Finished")
if len(json_data['output']['data']) > 0:
yield ImageResponse(json_data['output']['data'][0]["url"], prompt)
break
except (json.JSONDecodeError, KeyError, TypeError) as e:
raise RuntimeError(f"Failed to parse message: {chunk.decode(errors='replace')}", e)
1 change: 1 addition & 0 deletions g4f/Provider/hf_space/G4F.py
@@ -14,6 +14,7 @@
class FluxDev(BlackForestLabsFlux1Dev):
url = "https://roxky-flux-1-dev.hf.space"
space = "roxky/FLUX.1-dev"
referer = f"{url}/?__theme=light"

class G4F(Janus_Pro_7B):
label = "G4F framework"
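
The one-line addition to G4F.py shows the override pattern the new referer attribute enables: a mirror of the Space only needs to set url, space, and now referer on a subclass, as FluxDev does above. A hypothetical provider for another mirror (Space name and URL are placeholders) would follow the same shape:

# Hypothetical mirror provider; Space name and URL are placeholders.
from g4f.Provider.hf_space.BlackForestLabsFlux1Dev import BlackForestLabsFlux1Dev

class ExampleFluxDevMirror(BlackForestLabsFlux1Dev):
    url = "https://example-flux-1-dev.hf.space"
    space = "example/FLUX.1-dev"
    referer = f"{url}/?__theme=light"
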
4 changes: 0 additions & 4 deletions g4f/gui/server/api.py
@@ -19,7 +19,6 @@
from ... import ChatCompletion, get_model_and_provider
from ... import debug

logger = logging.getLogger(__name__)
conversations: dict[dict[str, BaseConversation]] = {}

class Api:
@@ -156,7 +155,6 @@ def decorated_log(text: str, file = None):
has_images="images" in kwargs,
)
except Exception as e:
logger.exception(e)
debug.error(e)
yield self._format_json('error', type(e).__name__, message=get_error_message(e))
return
@@ -186,7 +184,6 @@ def decorated_log(text: str, file = None):
else:
yield self._format_json("conversation_id", conversation_id)
elif isinstance(chunk, Exception):
logger.exception(chunk)
debug.error(chunk)
yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__)
elif isinstance(chunk, PreviewResponse):
@@ -222,7 +219,6 @@ def decorated_log(text: str, file = None):
yield self._format_json("content", str(chunk))
yield from self._yield_logs()
except Exception as e:
logger.exception(e)
debug.error(e)
yield from self._yield_logs()
yield self._format_json('error', type(e).__name__, message=get_error_message(e))
