Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/cozy-hands-give.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"workers-ai-provider": patch
---

Fix image generation output types handling
81 changes: 51 additions & 30 deletions packages/workers-ai-provider/src/workersai-image-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,14 @@ export class WorkersAIImageModel implements ImageModelV2 {
}

const generateImage = async () => {
const outputStream: ReadableStream<Uint8Array> = await this.config.binding.run(
this.modelId,
{
height,
prompt,
seed,
width,
},
);

// Convert the output stream to a Uint8Array.
return streamToUint8Array(outputStream);
const output = await this.config.binding.run(this.modelId, {
height,
prompt,
seed,
width,
});

return toUint8Array(output as ReadableStream<Uint8Array> | Uint8Array | ArrayBuffer | { image: string });
};

const images: Uint8Array[] = await Promise.all(
Expand Down Expand Up @@ -90,25 +86,50 @@ function parseInteger(value?: string) {
return Number.isInteger(number) ? number : undefined;
}

async function streamToUint8Array(stream: ReadableStream<Uint8Array>): Promise<Uint8Array> {
const reader = stream.getReader();
const chunks: Uint8Array[] = [];
let totalLength = 0;

// Read the stream until it is finished.
while (true) {
const { done, value } = await reader.read();
if (done) break;
chunks.push(value);
totalLength += value.length;
async function toUint8Array(
output: ReadableStream<Uint8Array> | Uint8Array | ArrayBuffer | { image: string },
): Promise<Uint8Array> {
// Already a Uint8Array
if (output instanceof Uint8Array) {
return output;
}

// ArrayBuffer - wrap it
if (output instanceof ArrayBuffer) {
return new Uint8Array(output);
}

// Allocate a new Uint8Array to hold all the data.
const result = new Uint8Array(totalLength);
let offset = 0;
for (const chunk of chunks) {
result.set(chunk, offset);
offset += chunk.length;
// REST API response with base64 image
if (output && typeof output === "object" && "image" in output && typeof output.image === "string") {
const binaryString = atob(output.image);
const bytes = new Uint8Array(binaryString.length);
for (let i = 0; i < binaryString.length; i++) {
bytes[i] = binaryString.charCodeAt(i);
}
return bytes;
}
return result;

// ReadableStream - read all chunks
if (output && typeof (output as ReadableStream).getReader === "function") {
const reader = (output as ReadableStream<Uint8Array>).getReader();
const chunks: Uint8Array[] = [];
let totalLength = 0;

while (true) {
const { done, value } = await reader.read();
if (done) break;
chunks.push(value);
totalLength += value.length;
}

const result = new Uint8Array(totalLength);
let offset = 0;
for (const chunk of chunks) {
result.set(chunk, offset);
offset += chunk.length;
}
return result;
}

throw new Error(`Unexpected output type from image model: ${typeof output}`);
}
131 changes: 131 additions & 0 deletions packages/workers-ai-provider/test/image-generation.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import { experimental_generateImage as generateImage } from "ai";
import { HttpResponse, http } from "msw";
import { setupServer } from "msw/node";
import { afterAll, afterEach, beforeAll, describe, expect, it } from "vitest";
import { createWorkersAI } from "../src/index";

const TEST_ACCOUNT_ID = "test-account-id";
const TEST_API_KEY = "test-api-key";
const TEST_IMAGE_MODEL = "@cf/black-forest-labs/flux-1-schnell";

// Base64 encoded 1x1 red PNG for testing
const TEST_IMAGE_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==";

const imageGenerationHandler = http.post(
`https://api.cloudflare.com/client/v4/accounts/${TEST_ACCOUNT_ID}/ai/run/${TEST_IMAGE_MODEL}`,
async () => {
return HttpResponse.json({ result: { image: TEST_IMAGE_BASE64 } });
},
);

const server = setupServer(imageGenerationHandler);

describe("REST API - Image Generation Tests", () => {
beforeAll(() => server.listen());
afterEach(() => server.resetHandlers());
afterAll(() => server.close());

it("should generate an image", async () => {
const workersai = createWorkersAI({
accountId: TEST_ACCOUNT_ID,
apiKey: TEST_API_KEY,
});

const result = await generateImage({
model: workersai.image(TEST_IMAGE_MODEL),
prompt: "A futuristic city",
size: "512x512",
});

expect(result.images).toHaveLength(1);
expect(result.images[0].uint8Array).toBeInstanceOf(Uint8Array);
expect(result.images[0].uint8Array.length).toBeGreaterThan(0);
});
});

describe("Binding - Image Generation Tests", () => {
it("should handle Uint8Array output directly", async () => {
const expectedData = new Uint8Array([1, 2, 3, 4, 5]);

const workersai = createWorkersAI({
binding: {
run: async () => expectedData,
},
});

const result = await generateImage({
model: workersai.image(TEST_IMAGE_MODEL),
prompt: "test image",
size: "512x512",
});

expect(result.images).toHaveLength(1);
expect(result.images[0].uint8Array).toEqual(expectedData);
});

it("should handle ArrayBuffer output", async () => {
const data = new Uint8Array([10, 20, 30, 40]);

const workersai = createWorkersAI({
binding: {
run: async () => data.buffer,
},
});

const result = await generateImage({
model: workersai.image(TEST_IMAGE_MODEL),
prompt: "test image",
size: "256x256",
});

expect(result.images).toHaveLength(1);
expect(result.images[0].uint8Array).toEqual(data);
});

it("should handle ReadableStream output", async () => {
const chunk1 = new Uint8Array([1, 2, 3]);
const chunk2 = new Uint8Array([4, 5, 6]);
const expectedResult = new Uint8Array([1, 2, 3, 4, 5, 6]);

const workersai = createWorkersAI({
binding: {
run: async () => {
return new ReadableStream<Uint8Array>({
start(controller) {
controller.enqueue(chunk1);
controller.enqueue(chunk2);
controller.close();
},
});
},
},
});

const result = await generateImage({
model: workersai.image(TEST_IMAGE_MODEL),
prompt: "test image",
size: "512x512",
});

expect(result.images).toHaveLength(1);
expect(result.images[0].uint8Array).toEqual(expectedResult);
});

it("should handle base64 image response (REST API format)", async () => {
const workersai = createWorkersAI({
binding: {
run: async () => ({ image: TEST_IMAGE_BASE64 }),
},
});

const result = await generateImage({
model: workersai.image(TEST_IMAGE_MODEL),
prompt: "test image",
size: "512x512",
});

expect(result.images).toHaveLength(1);
expect(result.images[0].uint8Array).toBeInstanceOf(Uint8Array);
expect(result.images[0].uint8Array.length).toBeGreaterThan(0);
});
});
Loading