diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index f600032..a8d2cb9 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -165,3 +165,26 @@ jobs: env: FMA_APIKEY: ${{ secrets.FMA_APIKEY }} run: pytest tests/test_async_inference.py --tb=short + + test_m1_clients: + needs: + - lint + runs-on: ubuntu-latest + + container: + image: python:3.8 + + steps: + - name: Check out git repo + uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Fix + run: git config --global --add safe.directory '*' + + - name: Install dependencies + run: pip3 install poetry pytest-asyncio && poetry config virtualenvs.create false && poetry install + + - name: Test + run: pytest tests/test_flymyai_m1_client.py --tb=short diff --git a/README.md b/README.md index d37d102..ecc6686 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # FlyMy.AI +

Generated with FlyMy.AI in 🚀 70ms
Generated with FlyMy.AI in 🚀 70ms @@ -6,7 +7,6 @@ Welcome to FlyMy.AI inference platform. Our goal is to provide the fastest and most affordable deployment solutions for neural networks and AI applications. - - **Fast Inference**: Experience the fastest Stable Diffusion inference globally. - **Scalability**: Autoscaling to millions of users per second. - **Ease of Use**: One-click deployment for any publicly available neural networks. @@ -16,12 +16,10 @@ Welcome to FlyMy.AI inference platform. Our goal is to provide the fastest and m For more information, visit our website: [FlyMy.AI](https://flymy.ai) Or connect with us and other users on Discord: [Join Discord](https://discord.com/invite/t6hPBpSebw) - ## Getting Started This is a Python client for [FlyMyAI](https://flymy.ai). It allows you to easily run models and get predictions from your Python code in sync and async mode. - ## Requirements - Python 3.8+ @@ -35,12 +33,15 @@ pip install flymyai ``` ## Authentication + Before using the client, you need to have your API key, username, and project name. In order to get credentials, you have to sign up on flymy.ai and get your personal data on [the profile](https://app.flymy.ai/profile). 
## Basic Usage + Here's a simple example of how to use the FlyMyAI client: #### BERT Sentiment analysis + ```python import flymyai @@ -52,11 +53,12 @@ response = flymyai.run( print(response.output_data["logits"][0]) ``` - ## Sync Streams + For llms you should use stream method #### llama 3.1 8b + ```python from flymyai import client, FlyMyAIPredictException @@ -87,6 +89,7 @@ finally: ``` ## Async Streams + For llms you should use stream method #### Stable Code Instruct 3b @@ -126,10 +129,10 @@ async def run_stable_code(): asyncio.run(run_stable_code()) ``` - - ## File Inputs + #### ResNet image classification + You can pass file inputs to models using file paths: ```python @@ -145,9 +148,10 @@ response = flymyai.run( print(response.output_data["495"]) ``` - ## File Response Handling + Files received from the neural network are always encoded in base64 format. To process these files, you need to decode them first. Here's an example of how to handle an image file: + #### StableDiffusion Turbo image generation in ~50ms 🚀 ```python @@ -167,8 +171,8 @@ with open("generated_image.jpg", "wb") as file: file.write(image_data) ``` - ## Asynchronous Requests + FlyMyAI supports asynchronous requests for improved performance. Here's how to use it: ```python @@ -207,6 +211,7 @@ asyncio.run(main()) ``` ## Running Models in the Background + To run a model in the background, simply use the async_run() method: ```python @@ -233,7 +238,6 @@ asyncio.run(main()) # Continue with other operations while the model runs in the background ``` - ## Asynchronous Prediction Tasks For long-running operations, FlyMyAI provides asynchronous prediction tasks. This allows you to submit a task and check its status later, which is useful for handling time-consuming predictions without blocking your application. 
@@ -280,25 +284,25 @@ from flymyai.core.exceptions import ( async def run_prediction(): # Initialize async client fma_client = async_client(apikey="fly-secret-key") - + # Submit async prediction task prediction_task = await fma_client.predict_async_task( model="flymyai/flux-schnell", payload={"prompt": "Funny Cat with Stupid Dog"} ) - + try: # Await result with default timeout result = await prediction_task.result() print(f"Prediction completed: {result.inference_responses}") - + # Check response status all_successful = all( - resp.infer_details["status"] == 200 + resp.infer_details["status"] == 200 for resp in result.inference_responses ) print(f"All predictions successful: {all_successful}") - + except RetryTimeoutExceededException: print("Prediction is taking longer than expected") except FlyMyAIExceptionGroup as e: @@ -306,4 +310,74 @@ async def run_prediction(): # Run async function asyncio.run(run_prediction()) -``` \ No newline at end of file +``` + +## M1 Agent Usage + +### Using Synchronous Client + +```python +from flymyai import m1_client + +client = m1_client(apikey="fly-secret-key") +result = client.generate("An Iron Man") +print(result.data.text, result.data.file_url) +``` + +FlymyAI M1 client also stores request history for later generation context: + +```python +from flymyai import m1_client + +client = m1_client(apikey="fly-secret-key") + +result = client.generate("An Iron Man") +print(result.data.text, result.data.file_url) + +result = client.generate("Add him Captain America's shield") +print(result.data.text, result.data.file_url) +``` + +#### Passing image + +```python +from pathlib import Path +from flymyai import m1_client + +client = m1_client(apikey="fly-secret-key") +result = client.generate("An Iron Man", image=Path("./image.png")) +print(result.data.text, result.data.file_url) +``` + +### Using Asynchronous Client + +```python +import asyncio +from flymyai import async_m1_client + + +async def main(): + client = 
async_m1_client(apikey="fly-secret-key") + result = await client.generate("An Iron Man") + print(result.data.text, result.data.file_url) + + +asyncio.run(main()) +``` + +#### Passing image + +```python +import asyncio +from pathlib import Path +from flymyai import async_m1_client + + +async def main(): + client = async_m1_client(apikey="fly-secret-key") + result = await client.generate("An Iron Man", image=Path("./image.png")) + print(result.data.text, result.data.file_url) + + +asyncio.run(main()) +``` diff --git a/flymyai/__init__.py b/flymyai/__init__.py index 3fd3ca3..5542c2f 100644 --- a/flymyai/__init__.py +++ b/flymyai/__init__.py @@ -1,9 +1,8 @@ import httpx -from flymyai.core.client import FlyMyAI, AsyncFlyMyAI +from flymyai.core.client import FlyMyAI, AsyncFlyMyAI, FlyMyAIM1, AsyncFlymyAIM1 from flymyai.core.exceptions import FlyMyAIPredictException, FlyMyAIExceptionGroup - __all__ = [ "run", "httpx", @@ -19,3 +18,6 @@ async_client = AsyncFlyMyAI run = client.run_predict async_run = async_client.arun_predict + +m1_client = FlyMyAIM1 +async_m1_client = AsyncFlymyAIM1 diff --git a/flymyai/core/_response.py b/flymyai/core/_response.py index 3a49535..6f7a209 100644 --- a/flymyai/core/_response.py +++ b/flymyai/core/_response.py @@ -1,5 +1,7 @@ import json +import os import typing +from dataclasses import dataclass import httpx @@ -23,3 +25,47 @@ def json(self, **kwargs) -> typing.Any: return json.loads(self.content.removeprefix(b"event")) else: return super().json(**kwargs) + + +@dataclass +class ChatResponseData: + text: typing.Optional[str] + tool_used: typing.Optional[str] + file_url: typing.Optional[str] + + @classmethod + def from_dict( + cls, data: typing.Optional[dict] + ) -> typing.Optional["ChatResponseData"]: + if not data: + return None + return cls( + text=data.get("text"), + tool_used=data.get("tool_used"), + file_url=( + "".join( + [ + os.getenv("FLYMYAI_M1_DSN", "https://api.chat.flymy.ai/"), + data.get("file_url"), + ] + ) + if 
data.get("file_url") + else None + ), + ) + + +@dataclass +class FlyMyAIM1Response: + success: bool + error: typing.Optional[str] + data: typing.Optional[ChatResponseData] + + @classmethod + def from_httpx(cls, response): + json_data = response.json() + return cls( + success=json_data.get("success", False), + error=json_data.get("error"), + data=ChatResponseData.from_dict(json_data.get("data")), + ) diff --git a/flymyai/core/client.py b/flymyai/core/client.py index 62eded1..ac8f63d 100644 --- a/flymyai/core/client.py +++ b/flymyai/core/client.py @@ -1,8 +1,16 @@ from flymyai.core.clients.AsyncClient import BaseAsyncClient from flymyai.core.clients.SyncClient import BaseSyncClient +from flymyai.core.clients.m1Client import BaseM1SyncClient +from flymyai.core.clients.m1AsyncClient import BaseM1AsyncClient class FlyMyAI(BaseSyncClient): ... class AsyncFlyMyAI(BaseAsyncClient): ... + + +class FlyMyAIM1(BaseM1SyncClient): ... + + +class AsyncFlymyAIM1(BaseM1AsyncClient): ... diff --git a/flymyai/core/clients/base_m1_client.py b/flymyai/core/clients/base_m1_client.py new file mode 100644 index 0000000..eb9e994 --- /dev/null +++ b/flymyai/core/clients/base_m1_client.py @@ -0,0 +1,93 @@ +import os +from pathlib import Path +from typing import overload, Union, TypeVar, Generic, Optional + +import httpx + +from flymyai.core._response import FlyMyAIM1Response +from flymyai.core.types.m1 import M1GenerationTask +from flymyai.core.models.m1_history import M1History + +DEFAULT_RETRY_COUNT = int(os.getenv("FLYMYAI_MAX_RETRIES", 2)) + +_PossibleClients = TypeVar( + "_PossibleClients", bound=Union[httpx.Client, httpx.AsyncClient] +) + + +_predict_timeout = httpx.Timeout( + connect=int(os.getenv("FMA_CONNECT_TIMEOUT", 999999)), + read=int(os.getenv("FMA_READ_TIMEOUT", 999999)), + write=int(os.getenv("FMA_WRITE_TIMEOUT", 999999)), + pool=int(os.getenv("FMA_POOL_TIMEOUT", 999999)), +) + + +class BaseM1Client(Generic[_PossibleClients]): + client: _PossibleClients + _m1_history: M1History + 
_image: Optional[str] + + def __init__(self, apikey: str): + self._apikey = apikey + self._client = self._construct_client() + self._m1_history = M1History() + self._image = None + + def reset_history(self): + self._m1_history = M1History() + + @overload + def generate( + self, prompt: str, image: Optional[Union[str, Path]] = None + ) -> FlyMyAIM1Response: ... + + @overload + def generation_task( + self, prompt: str, image: Optional[Union[str, Path]] = None + ) -> M1GenerationTask: ... + + @overload + def generation_task_result( + self, generation_task: M1GenerationTask + ) -> FlyMyAIM1Response: ... + + @overload + def upload_image(self, image: Union[str, Path]): ... + + @overload + async def generate( + self, prompt: str, image: Optional[Union[str, Path]] = None + ) -> FlyMyAIM1Response: ... + + @overload + async def generation_task( + self, prompt: str, image: Optional[Union[str, Path]] = None + ) -> M1GenerationTask: ... + + @overload + async def generation_task_result( + self, generation_task: M1GenerationTask + ) -> FlyMyAIM1Response: ... + + @overload + async def upload_image(self, image: Union[str, Path]): ... 
+ + @property + def _headers(self): + return {"X-API-KEY": self._apikey} + + @property + def _generation_path(self): + return "/chat" + + @property + def _result_path(self): + return "/chat-result/" + + def _populate_result_path(self, generation_task: M1GenerationTask): + return "".join([self._result_path, generation_task.request_id]) + + @property + def _image_upload_path(self): + return "/upload-image" diff --git a/flymyai/core/clients/m1AsyncClient.py b/flymyai/core/clients/m1AsyncClient.py new file mode 100644 index 0000000..78eb176 --- /dev/null +++ b/flymyai/core/clients/m1AsyncClient.py @@ -0,0 +1,122 @@ +import asyncio +import os +from typing import Union, Optional +from pathlib import Path + +import httpx + +from flymyai.core._response import FlyMyAIM1Response +from flymyai.core.types.m1 import M1GenerationTask, M1Record, M1Role +from flymyai.core.clients.base_m1_client import BaseM1Client, _predict_timeout + + +class BaseM1AsyncClient(BaseM1Client[httpx.AsyncClient]): + """Asynchronous client for interacting with FlyMyAI M1 chat generation models. + Handles image uploads, chat history tracking, and result polling. + """ + + def _construct_client(self): + return httpx.AsyncClient( + http2=True, + headers=self._headers, + base_url=os.getenv("FLYMYAI_M1_DSN", "https://api.chat.flymy.ai/"), + timeout=_predict_timeout, + ) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if hasattr(self, "_client"): + await self._client.aclose() + + async def generate( + self, prompt: str, image: Union[str, Path, None] = None + ) -> FlyMyAIM1Response: + """Submit a chat prompt with optional image input and return the final generation result. + + :param prompt: User input string to send to the model. + :param image: Local image file (as `Path`) or remote image URL (as `str`). + :return: FlyMyAIM1Response with generated content and metadata. 
+ """ + await self._process_image(image) + self._m1_history.add(M1Record(role=M1Role.user, content=prompt)) + generation_task = await self.generation_task() + result = await self.generation_task_result(generation_task) + return result + + async def _process_image(self, image: Optional[Union[str, Path]]) -> Optional[str]: + if image is None: + return + + image_url = None + + if isinstance(image, Path): + image_url = await self.upload_image(image) + elif isinstance(image, str): + image_url = image + + self._image = image_url + return image_url + + async def generation_task(self) -> M1GenerationTask: + payload = { + "chat_history": self._m1_history.serialize(), + "image_url": self._image, + } + response = await self._client.post( + self._generation_path, json=payload, headers=self._headers + ) + response.raise_for_status() + response_data = response.json() + return M1GenerationTask(request_id=response_data["request_id"]) + + async def generation_task_result( + self, generation_task: M1GenerationTask + ) -> FlyMyAIM1Response: + while True: + response = await self._client.get( + self._populate_result_path(generation_task) + ) + response.raise_for_status() + response_data = response.json() + + if response_data.get("success"): + self._m1_history.add( + M1Record( + role=M1Role.assistant, + content=response_data.get("data", {}).get("text", ""), + ) + ) + if file_url := response_data.get("data", {}).get("file_url", ""): + if not file_url.endswith(".mp4"): + self._image = ( + os.getenv("FLYMYAI_M1_DSN", "https://api.chat.flymy.ai/") + + file_url + ) + return FlyMyAIM1Response.from_httpx(response) + + if response_data.get("error") == "Still processing": + await asyncio.sleep(1) + continue + + raise RuntimeError( + f"Generation failed with status {response_data.get('status')}: {response_data.get('error')}" + ) + + async def upload_image(self, image: Union[str, Path]) -> str: + """Upload a local image file and receive a hosted URL. 
+ + :param image: Local file path (as `str` or `Path`). + :return: Hosted image URL returned by the server. + """ + image_path = Path(image) if isinstance(image, str) else image + with image_path.open("rb") as f: + files = {"file": (image_path.name, f, "image/png")} + response = await self._client.post(self._image_upload_path, files=files) + response.raise_for_status() + response_data = response.json() + return ( + os.getenv("FLYMYAI_M1_DSN", "https://api.chat.flymy.ai/") + + response_data["url"] + ) diff --git a/flymyai/core/clients/m1Client.py b/flymyai/core/clients/m1Client.py new file mode 100644 index 0000000..582d5f2 --- /dev/null +++ b/flymyai/core/clients/m1Client.py @@ -0,0 +1,116 @@ +import os +import time +from typing import Union, Optional +from pathlib import Path + +import httpx + +from flymyai.core._response import FlyMyAIM1Response +from flymyai.core.types.m1 import M1GenerationTask, M1Record, M1Role +from flymyai.core.clients.base_m1_client import BaseM1Client, _predict_timeout + + +class BaseM1SyncClient(BaseM1Client[httpx.Client]): + """Synchronous client for interacting with FlyMyAI M1 chat generation models. + Handles image uploads, chat history tracking, and result polling. + """ + + def _construct_client(self): + return httpx.Client( + http2=True, + headers=self._headers, + base_url=os.getenv("FLYMYAI_M1_DSN", "https://api.chat.flymy.ai/"), + timeout=_predict_timeout, + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._client.close() + + def generate(self, prompt: str, image: Optional[Union[str, Path]] = None): + """Submit a chat prompt with optional image input and return the final generation result. + + :param prompt: User input string to send to the model. + :param image: Local image file (as `Path`) or remote image URL (as `str`). + :return: FlyMyAIM1Response with generated content and metadata. 
+ """ + self._process_image(image) + self._m1_history.add(M1Record(role=M1Role.user, content=prompt)) + generation_task = self.generation_task() + result = self.generation_task_result(generation_task) + return result + + def _process_image(self, image: Optional[Union[str, Path]]) -> Optional[str]: + if image is None: + return + image_url = None + + if isinstance(image, Path): + image_url = self.upload_image(image) + elif isinstance(image, str): + image_url = image + + self._image = image_url + return image_url + + def generation_task(self) -> M1GenerationTask: + payload = { + "chat_history": self._m1_history.serialize(), + "image_url": self._image, + } + response = self._client.post( + self._generation_path, json=payload, headers=self._headers + ) + response.raise_for_status() + response_data = response.json() + return M1GenerationTask(request_id=response_data["request_id"]) + + def generation_task_result( + self, generation_task: M1GenerationTask + ) -> FlyMyAIM1Response: + while True: + response = self._client.get(self._populate_result_path(generation_task)) + response.raise_for_status() + response_data = response.json() + + if response_data.get("success"): + self._m1_history.add( + M1Record( + role=M1Role.assistant, + content=response_data.get("data", {}).get("text", ""), + ) + ) + if file_url := response_data.get("data", {}).get("file_url", ""): + if not file_url.endswith(".mp4"): + self._image = ( + os.getenv("FLYMYAI_M1_DSN", "https://api.chat.flymy.ai/") + + file_url + ) + return FlyMyAIM1Response.from_httpx(response) + + if response_data.get("error") == "Still processing": + time.sleep(1) + continue + + raise RuntimeError( + f"Generation failed with status {response_data.get('status')}: {response_data.get('error')}" + ) + + def upload_image(self, image: Union[str, Path]) -> str: + """Upload a local image file and receive a hosted URL. + + :param image: Local file path (as `str` or `Path`). + :return: Hosted image URL returned by the server. 
""" + image_path = Path(image) if isinstance(image, str) else image + with image_path.open("rb") as f: + files = {"file": (image_path.name, f, "image/png")} + response = self._client.post(self._image_upload_path, files=files) + response.raise_for_status() + response_data = response.json() + return ( + os.getenv("FLYMYAI_M1_DSN", "https://api.chat.flymy.ai/") + + response_data["url"] + ) diff --git a/flymyai/core/models/m1_history.py b/flymyai/core/models/m1_history.py new file mode 100644 index 0000000..525f5b1 --- /dev/null +++ b/flymyai/core/models/m1_history.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass +from typing import List, Dict + +from flymyai.core.types.m1 import M1Record + + +@dataclass +class M1History: + _records: List[M1Record] + + def __init__(self): + self._records = [] + + def add(self, record: M1Record): + self._records.append(record) + + def serialize(self) -> List[Dict]: + return [ + { + "role": record.role.value, + "content": record.content, + } + for record in self._records + ] + + def pop(self) -> M1Record: + return self._records.pop() diff --git a/flymyai/core/stream_iterators/AsyncPredictionStream.py b/flymyai/core/stream_iterators/AsyncPredictionStream.py index 433a15d..35fae69 100644 --- a/flymyai/core/stream_iterators/AsyncPredictionStream.py +++ b/flymyai/core/stream_iterators/AsyncPredictionStream.py @@ -13,7 +13,6 @@ from flymyai.core.stream_iterators.exceptions import StreamCancellationException from flymyai.core.types.event_types import EventType - _AsyncEventCallbackType = TypeVar( "_AsyncEventCallbackType", bound=Union[ diff --git a/flymyai/core/stream_iterators/PredictionStream.py b/flymyai/core/stream_iterators/PredictionStream.py index 78dd1b3..096ce5d 100644 --- a/flymyai/core/stream_iterators/PredictionStream.py +++ b/flymyai/core/stream_iterators/PredictionStream.py @@ -12,7 +12,6 @@ from flymyai.core.stream_iterators.exceptions import StreamCancellationException from flymyai.core.types.event_types import EventType - 
_SyncEventCallbackType = TypeVar( "_SyncEventCallbackType", bound=Callable[[PredictionEvent], None] ) diff --git a/flymyai/core/types/m1.py b/flymyai/core/types/m1.py new file mode 100644 index 0000000..131bfcd --- /dev/null +++ b/flymyai/core/types/m1.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Optional + + +class M1Role(Enum): + user = "user" + assistant = "assistant" + + +@dataclass +class M1Record: + role: M1Role + content: str + + +@dataclass +class M1GenerationTask: + request_id: str diff --git a/pyproject.toml b/pyproject.toml index 4a42885..6aef77f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ httpx = ">=0.26.0" [tool.poetry.group.dev.dependencies] tomli = ">=2.0.1" pytest-asyncio = ">=0.23.7" +respx = "^0.22.0" [build-system] requires = ["poetry-core"] diff --git a/tests/fixtures/test_flymyai_m1_client.json b/tests/fixtures/test_flymyai_m1_client.json new file mode 100644 index 0000000..6b905ba --- /dev/null +++ b/tests/fixtures/test_flymyai_m1_client.json @@ -0,0 +1,6 @@ +{ + "m1_env_fixture": "https://mock.flymy.ai/", + "client_auth_fixture": { + "apikey": "test-api-key-123" + } +} diff --git a/tests/test_flymyai_m1_client.py b/tests/test_flymyai_m1_client.py new file mode 100644 index 0000000..bef3f66 --- /dev/null +++ b/tests/test_flymyai_m1_client.py @@ -0,0 +1,159 @@ +import os +import pathlib + +import pytest +import respx +import httpx +from flymyai import m1_client, async_m1_client +from flymyai.core.types.m1 import M1Role +from .FixtureFactory import FixtureFactory +from httpx import Response + +factory = FixtureFactory(__file__) + + +@pytest.fixture +def m1_env_fixture(): + os.environ["FLYMYAI_M1_DSN"] = factory("m1_env_fixture") + + +@pytest.fixture +def test_prompt(): + return "Hello, generate something interesting!" 
+ + +@pytest.fixture +def dummy_image_path(): + return pathlib.Path(__file__).parent / "fixtures" / "Untitled.png" + + +@pytest.fixture +def apikey_fixture(): + auth_data = factory("client_auth_fixture") + return auth_data.get("apikey", "dummy-test-apikey") + + +@respx.mock +def test_generate_text_only(m1_env_fixture, test_prompt, apikey_fixture): + base_url = os.getenv("FLYMYAI_M1_DSN") + + respx.post(f"{base_url}chat").mock( + return_value=Response(200, json={"request_id": "abc123"}) + ) + respx.get(f"{base_url}chat-result/abc123").mock( + return_value=Response( + 200, json={"success": True, "data": {"text": "This is a response"}} + ) + ) + + client = m1_client(apikey_fixture) + response = client.generate(prompt=test_prompt) + + assert response.data.text == "This is a response" + assert response.success + assert client._m1_history._records[0].role == M1Role.user + assert client._m1_history._records[1].role == M1Role.assistant + + +@respx.mock +def test_generate_with_image( + m1_env_fixture, test_prompt, dummy_image_path, apikey_fixture +): + base_url = os.getenv("FLYMYAI_M1_DSN") + + respx.post(f"{base_url}upload-image").mock( + return_value=Response(200, json={"url": "/static/images/xyz.png"}) + ) + respx.post(f"{base_url}chat").mock( + return_value=Response(200, json={"request_id": "img123"}) + ) + respx.get(f"{base_url}chat-result/img123").mock( + return_value=Response( + 200, json={"success": True, "data": {"text": "Image-based response"}} + ) + ) + + client = m1_client(apikey_fixture) + response = client.generate(prompt=test_prompt, image=dummy_image_path) + + assert response.data.text == "Image-based response" + + +@respx.mock +def test_image_upload(m1_env_fixture, dummy_image_path, apikey_fixture): + base_url = os.getenv("FLYMYAI_M1_DSN") + + respx.post(f"{base_url}upload-image").mock( + return_value=Response(200, json={"url": "/uploads/fake123.png"}) + ) + + client = m1_client(apikey_fixture) + image_url = client.upload_image(dummy_image_path) + + assert 
image_url.endswith("/uploads/fake123.png") + + +@pytest.mark.asyncio +@respx.mock +async def test_async_generate_text_only(m1_env_fixture, test_prompt, apikey_fixture): + base_url = os.getenv("FLYMYAI_M1_DSN") + + respx.post(f"{base_url}chat").mock( + return_value=Response(200, json={"request_id": "async123"}) + ) + respx.get(f"{base_url}chat-result/async123").mock( + return_value=Response( + 200, json={"success": True, "data": {"text": "Async response"}} + ) + ) + + client = async_m1_client(apikey=apikey_fixture) + async with client: + response = await client.generate(prompt=test_prompt) + + assert response.data.text == "Async response" + assert response.success + assert client._m1_history._records[0].role == M1Role.user + assert client._m1_history._records[1].role == M1Role.assistant + + +@pytest.mark.asyncio +@respx.mock +async def test_async_generate_with_image( + m1_env_fixture, test_prompt, dummy_image_path, apikey_fixture +): + base_url = os.getenv("FLYMYAI_M1_DSN") + + respx.post(f"{base_url}upload-image").mock( + return_value=Response(200, json={"url": "/static/images/xyz.png"}) + ) + respx.post(f"{base_url}chat").mock( + return_value=Response(200, json={"request_id": "img123"}) + ) + respx.get(f"{base_url}chat-result/img123").mock( + return_value=Response( + 200, json={"success": True, "data": {"text": "Image-based async response"}} + ) + ) + + client = async_m1_client(apikey=apikey_fixture) + async with client: + response = await client.generate(prompt=test_prompt, image=dummy_image_path) + + assert response.data.text == "Image-based async response" + + +@pytest.mark.asyncio +@respx.mock +async def test_async_image_upload(m1_env_fixture, dummy_image_path, apikey_fixture): + base_url = os.getenv("FLYMYAI_M1_DSN") + + respx.post(f"{base_url}upload-image").mock( + return_value=Response(200, json={"url": "/uploads/fake123.png"}) + ) + + client = async_m1_client(apikey=apikey_fixture) + async with client: + image_url = await 
client.upload_image(dummy_image_path) + + assert image_url.endswith("/uploads/fake123.png")