From 3b0d48de79d914abf9a992eeccdcca3e4f68b962 Mon Sep 17 00:00:00 2001 From: TheHackerPuppy <66805474+TheHackerPuppy@users.noreply.github.com> Date: Thu, 27 Nov 2025 13:53:35 +0100 Subject: [PATCH 1/2] fix(workers-ai-provider): image generation Refactored response handling to support non-stream outputs and added missing unit test. --- .../src/workersai-image-model.ts | 81 +++++++---- .../test/image-generation.test.ts | 131 ++++++++++++++++++ 2 files changed, 182 insertions(+), 30 deletions(-) create mode 100644 packages/workers-ai-provider/test/image-generation.test.ts diff --git a/packages/workers-ai-provider/src/workersai-image-model.ts b/packages/workers-ai-provider/src/workersai-image-model.ts index 0ca8a7aa..dcfbe58e 100644 --- a/packages/workers-ai-provider/src/workersai-image-model.ts +++ b/packages/workers-ai-provider/src/workersai-image-model.ts @@ -44,18 +44,14 @@ export class WorkersAIImageModel implements ImageModelV2 { } const generateImage = async () => { - const outputStream: ReadableStream = await this.config.binding.run( - this.modelId, - { - height, - prompt, - seed, - width, - }, - ); - - // Convert the output stream to a Uint8Array. - return streamToUint8Array(outputStream); + const output = await this.config.binding.run(this.modelId, { + height, + prompt, + seed, + width, + }); + + return toUint8Array(output as ReadableStream | Uint8Array | ArrayBuffer | { image: string }); }; const images: Uint8Array[] = await Promise.all( @@ -91,25 +87,50 @@ function parseInteger(value?: string) { return Number.isInteger(number) ? number : undefined; } -async function streamToUint8Array(stream: ReadableStream): Promise { - const reader = stream.getReader(); - const chunks: Uint8Array[] = []; - let totalLength = 0; - - // Read the stream until it is finished. - while (true) { - const { done, value } = await reader.read(); - if (done) break; - chunks.push(value); - totalLength += value.length; +async function toUint8Array( + output: ReadableStream | Uint8Array | ArrayBuffer | { image: string }, +): Promise { + // Already a Uint8Array + if (output instanceof Uint8Array) { + return output; + } + + // ArrayBuffer - wrap it + if (output instanceof ArrayBuffer) { + return new Uint8Array(output); } - // Allocate a new Uint8Array to hold all the data. - const result = new Uint8Array(totalLength); - let offset = 0; - for (const chunk of chunks) { - result.set(chunk, offset); - offset += chunk.length; + // REST API response with base64 image + if (output && typeof output === "object" && "image" in output && typeof output.image === "string") { + const binaryString = atob(output.image); + const bytes = new Uint8Array(binaryString.length); + for (let i = 0; i < binaryString.length; i++) { + bytes[i] = binaryString.charCodeAt(i); + } + return bytes; } - return result; + + // ReadableStream - read all chunks + if (output && typeof (output as ReadableStream).getReader === "function") { + const reader = (output as ReadableStream).getReader(); + const chunks: Uint8Array[] = []; + let totalLength = 0; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + chunks.push(value); + totalLength += value.length; + } + + const result = new Uint8Array(totalLength); + let offset = 0; + for (const chunk of chunks) { + result.set(chunk, offset); + offset += chunk.length; + } + return result; + } + + throw new Error(`Unexpected output type from image model: ${typeof output}`); } diff --git a/packages/workers-ai-provider/test/image-generation.test.ts b/packages/workers-ai-provider/test/image-generation.test.ts new file mode 100644 index 00000000..9e0bda0e --- /dev/null +++ b/packages/workers-ai-provider/test/image-generation.test.ts @@ -0,0 +1,131 @@ +import { experimental_generateImage as generateImage } from "ai"; +import { HttpResponse, http } from "msw"; +import { setupServer } from "msw/node"; +import { afterAll, afterEach, beforeAll, describe, expect, it } from "vitest"; +import { createWorkersAI } from "../src/index"; + +const TEST_ACCOUNT_ID = "test-account-id"; +const TEST_API_KEY = "test-api-key"; +const TEST_IMAGE_MODEL = "@cf/black-forest-labs/flux-1-schnell"; + +// Base64 encoded 1x1 red PNG for testing +const TEST_IMAGE_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="; + +const imageGenerationHandler = http.post( + `https://api.cloudflare.com/client/v4/accounts/${TEST_ACCOUNT_ID}/ai/run/${TEST_IMAGE_MODEL}`, + async () => { + return HttpResponse.json({ result: { image: TEST_IMAGE_BASE64 } }); + }, +); + +const server = setupServer(imageGenerationHandler); + +describe("REST API - Image Generation Tests", () => { + beforeAll(() => server.listen()); + afterEach(() => server.resetHandlers()); + afterAll(() => server.close()); + + it("should generate an image", async () => { + const workersai = createWorkersAI({ + accountId: TEST_ACCOUNT_ID, + apiKey: TEST_API_KEY, + }); + + const result = await generateImage({ + model: workersai.image(TEST_IMAGE_MODEL), + prompt: "A futuristic city", + size: "512x512", + }); + + expect(result.images).toHaveLength(1); + expect(result.images[0].uint8Array).toBeInstanceOf(Uint8Array); + expect(result.images[0].uint8Array.length).toBeGreaterThan(0); + }); +}); + +describe("Binding - Image Generation Tests", () => { + it("should handle Uint8Array output directly", async () => { + const expectedData = new Uint8Array([1, 2, 3, 4, 5]); + + const workersai = createWorkersAI({ + binding: { + run: async () => expectedData, + }, + }); + + const result = await generateImage({ + model: workersai.image(TEST_IMAGE_MODEL), + prompt: "test image", + size: "512x512", + }); + + expect(result.images).toHaveLength(1); + expect(result.images[0].uint8Array).toEqual(expectedData); + }); + + it("should handle ArrayBuffer output", async () => { + const data = new Uint8Array([10, 20, 30, 40]); + + const workersai = createWorkersAI({ + binding: { + run: async () => data.buffer, + }, + }); + + const result = await generateImage({ + model: workersai.image(TEST_IMAGE_MODEL), + prompt: "test image", + size: "256x256", + }); + + expect(result.images).toHaveLength(1); + expect(result.images[0].uint8Array).toEqual(data); + }); + + it("should handle ReadableStream output", async () => { + const chunk1 = new Uint8Array([1, 2, 3]); + const chunk2 = new Uint8Array([4, 5, 6]); + const expectedResult = new Uint8Array([1, 2, 3, 4, 5, 6]); + + const workersai = createWorkersAI({ + binding: { + run: async () => { + return new ReadableStream({ + start(controller) { + controller.enqueue(chunk1); + controller.enqueue(chunk2); + controller.close(); + }, + }); + }, + }, + }); + + const result = await generateImage({ + model: workersai.image(TEST_IMAGE_MODEL), + prompt: "test image", + size: "512x512", + }); + + expect(result.images).toHaveLength(1); + expect(result.images[0].uint8Array).toEqual(expectedResult); + }); + + it("should handle base64 image response (REST API format)", async () => { + const workersai = createWorkersAI({ + binding: { + run: async () => ({ image: TEST_IMAGE_BASE64 }), + }, + }); + + const result = await generateImage({ + model: workersai.image(TEST_IMAGE_MODEL), + prompt: "test image", + size: "512x512", + }); + + expect(result.images).toHaveLength(1); + expect(result.images[0].uint8Array).toBeInstanceOf(Uint8Array); + expect(result.images[0].uint8Array.length).toBeGreaterThan(0); + }); +}); From 97b79e2815db61a014de2e3d319e9aa5259dfdb1 Mon Sep 17 00:00:00 2001 From: TheHackerPuppy <66805474+TheHackerPuppy@users.noreply.github.com> Date: Thu, 27 Nov 2025 14:23:55 +0100 Subject: [PATCH 2/2] chore: add changeset --- .changeset/cozy-hands-give.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/cozy-hands-give.md diff --git a/.changeset/cozy-hands-give.md b/.changeset/cozy-hands-give.md new file mode 100644 index 00000000..f60eaa06 --- /dev/null +++ b/.changeset/cozy-hands-give.md @@ -0,0 +1,5 @@ +--- +"workers-ai-provider": patch +--- + +Fix image generation output types handling