diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts index ad0eb89a6d..619bce2133 100644 --- a/packages/inference/src/lib/getProviderHelper.ts +++ b/packages/inference/src/lib/getProviderHelper.ts @@ -138,6 +138,7 @@ export const PROVIDERS: Record): Record { + return { + input: { + ...omit(params.args, ["inputs", "parameters"]), + ...params.args.parameters, + input_image: params.args.inputs, // This will be processed in preparePayloadAsync + }, + version: params.model.includes(":") ? params.model.split(":")[1] : undefined, + }; + } + + async preparePayloadAsync(args: ImageToImageArgs): Promise { + const { inputs, ...restArgs } = args; + + // Convert Blob to base64 data URL + const bytes = new Uint8Array(await inputs.arrayBuffer()); + const base64 = base64FromBytes(bytes); + const imageInput = `data:${inputs.type || "image/jpeg"};base64,${base64}`; + + return { + ...restArgs, + inputs: imageInput, + }; + } + + override async getResponse(response: ReplicateOutput): Promise { + if ( + typeof response === "object" && + !!response && + "output" in response && + Array.isArray(response.output) && + response.output.length > 0 && + typeof response.output[0] === "string" + ) { + const urlResponse = await fetch(response.output[0]); + return await urlResponse.blob(); + } + + if ( + typeof response === "object" && + !!response && + "output" in response && + typeof response.output === "string" && + isUrl(response.output) + ) { + const urlResponse = await fetch(response.output); + return await urlResponse.blob(); + } + + throw new InferenceClientProviderOutputError("Received malformed response from Replicate image-to-image API"); + } +} diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index 8f9809f23c..29cccce8eb 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -1277,6 +1277,18 @@ describe.skip("InferenceClient", () => { expect(res).toBeInstanceOf(Blob); }); + + it("imageToImage - FLUX Kontext Dev", async () => { + const res = await client.imageToImage({ + model: "black-forest-labs/flux-kontext-dev", + provider: "replicate", + inputs: new Blob([readTestFile("stormtrooper_depth.png")], { type: "image/png" }), + parameters: { + prompt: "Change the stormtrooper armor to golden color while keeping the same pose and helmet design", + }, + }); + expect(res).toBeInstanceOf(Blob); + }); }, TIMEOUT );