From 1df32d7f553d1caa9c527b2b176a0cf63308a368 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Fri, 21 Feb 2025 12:25:15 +0100
Subject: [PATCH] wip: add HW requirements calculator

---
 packages/hub/index.ts                         | 19 +++
 packages/hub/src/lib/hardware-requirements.ts | 72 +++++++++++++++++++
 2 files changed, 91 insertions(+)
 create mode 100644 packages/hub/src/lib/hardware-requirements.ts

diff --git a/packages/hub/index.ts b/packages/hub/index.ts
index 3bd16e178a..758e4eb4b1 100644
--- a/packages/hub/index.ts
+++ b/packages/hub/index.ts
@@ -1 +1,20 @@
 export * from "./src";
+
+// TODO: remove this before merging
+// Run with: npx ts-node index.ts
+import { getHardwareRequirements } from "./src/lib/hardware-requirements";
+(async () => {
+	const models = [
+		"hexgrad/Kokoro-82M",
+		"microsoft/OmniParser-v2.0",
+		"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+		"deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+		"NousResearch/DeepHermes-3-Llama-3-8B-Preview",
+		"unsloth/DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit",
+	];
+
+	for (const name of models) {
+		const mem = await getHardwareRequirements({ name });
+		console.log("mem", JSON.stringify(mem, null, 2));
+	}
+})();
diff --git a/packages/hub/src/lib/hardware-requirements.ts b/packages/hub/src/lib/hardware-requirements.ts
new file mode 100644
index 0000000000..1f87c4c6cc
--- /dev/null
+++ b/packages/hub/src/lib/hardware-requirements.ts
@@ -0,0 +1,72 @@
+import { ListFileEntry, listFiles } from "./list-files";
+
+export interface MemoryRequirements {
+	minimumGigabytes: number;
+	recommendedGigabytes: number;
+}
+
+export interface HardwareRequirements {
+	name: string;
+	memory: MemoryRequirements;
+}
+
+export async function getHardwareRequirements(params: {
+	/**
+	 * The model name in the format of `namespace/repo`.
+	 */
+	name: string;
+	/**
+	 * The context size in tokens, defaults to 2048.
+	 */
+	contextSize?: number;
+}) {
+	const files = await getFiles(params.name);
+	const hasSafetensors = files.some((file) => file.path.endsWith(".safetensors"));
+	const hasPytorch = files.some((file) => file.path.endsWith(".pth"));
+
+	// Total size of the model weights in bytes (regardless of quantization scheme)
+	let totalWeightBytes = 0;
+	if (hasSafetensors) {
+		totalWeightBytes = sumFileSize(files.filter((file) => file.path.endsWith(".safetensors")));
+	} else if (hasPytorch) {
+		totalWeightBytes = sumFileSize(files.filter((file) => file.path.endsWith(".pth")));
+	}
+
+	// Calculate the memory for the context window
+	// TODO: this also scales with the weight size, to be implemented later
+	const contextWindow = params.contextSize ?? 2048;
+	const batchSize = 256; // a bit of overhead for batching
+	const contextMemoryBytes = (contextWindow + batchSize) * 0.5 * 1e6;
+
+	// Calculate the memory overhead taken by the system
+	const osOverheadBytes = Math.max(512 * 1e6, 0.2 * totalWeightBytes);
+
+	// Calculate the total memory requirements
+	const totalMemoryGb = (totalWeightBytes + contextMemoryBytes + osOverheadBytes) / 1e9;
+
+	return {
+		name: params.name,
+		memory: {
+			minimumGigabytes: totalMemoryGb,
+			recommendedGigabytes: totalMemoryGb * 1.1,
+		},
+	} satisfies HardwareRequirements;
+}
+
+async function getFiles(name: string): Promise<ListFileEntry[]> {
+	const files: ListFileEntry[] = [];
+	const cursor = listFiles({
+		repo: {
+			name,
+			type: "model",
+		},
+	});
+	for await (const entry of cursor) {
+		files.push(entry);
+	}
+	return files;
+}
+
+function sumFileSize(files: ListFileEntry[]): number {
+	return files.reduce((total, file) => total + file.size, 0);
+}
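
For reviewers, a quick sanity check of the formula above. This is a standalone sketch, not part of the patch; the weight size is a hypothetical example (roughly an 8B-parameter model stored in fp16, i.e. ~16e9 bytes), not a measurement of any real repo:

	// Hypothetical example: ~8B params in fp16 => ~16e9 bytes of weights
	const totalWeightBytes = 16e9;
	// Default context of 2048 tokens plus 256 tokens of batching overhead, at 0.5 MB per token
	const contextMemoryBytes = (2048 + 256) * 0.5 * 1e6; // = 1.152e9 bytes
	// System overhead: at least 512 MB, or 20% of the weight size if larger
	const osOverheadBytes = Math.max(512 * 1e6, 0.2 * totalWeightBytes); // = 3.2e9 bytes
	const totalMemoryGb = (totalWeightBytes + contextMemoryBytes + osOverheadBytes) / 1e9;
	console.log(totalMemoryGb); // = 20.352 -> minimum; recommended = 22.39 (x 1.1)

The recommended figure is just the minimum plus a 10% safety margin, matching the `totalMemoryGb * 1.1` in the patch.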