[Inference Providers] provider="auto" #1390

Open · wants to merge 6 commits into base: main
77 changes: 58 additions & 19 deletions packages/inference/src/lib/getInferenceProviderMapping.ts
@@ -1,8 +1,8 @@
 import type { WidgetType } from "@huggingface/tasks";
-import type { InferenceProvider, ModelId } from "../types";
 import { HF_HUB_URL } from "../config";
 import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts";
 import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference";
+import type { InferenceProvider, InferenceProviderPolicy, ModelId } from "../types";
 import { typedInclude } from "../utils/typedInclude";

 export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
@@ -20,44 +20,62 @@ export interface InferenceProviderModelMapping {
 	task: WidgetType;
 }

-export async function getInferenceProviderMapping(
-	params: {
-		accessToken?: string;
-		modelId: ModelId;
-		provider: InferenceProvider;
-		task: WidgetType;
-	},
-	options: {
+export async function fetchInferenceProviderMappingForModel(
+	modelId: ModelId,
+	accessToken?: string,
+	options?: {
 		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
 	}
-): Promise<InferenceProviderModelMapping | null> {
-	if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
-		return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
-	}
+): Promise<InferenceProviderMapping> {
 	let inferenceProviderMapping: InferenceProviderMapping | null;
-	if (inferenceProviderMappingCache.has(params.modelId)) {
+	if (inferenceProviderMappingCache.has(modelId)) {
 		// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-		inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId)!;
+		inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
 	} else {
 		const resp = await (options?.fetch ?? fetch)(
-			`${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
+			`${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
 			{
-				headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {},
+				headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
 			}
 		);
 		if (resp.status === 404) {
-			throw new Error(`Model ${params.modelId} does not exist`);
+			throw new Error(`Model ${modelId} does not exist`);
 		}
 		inferenceProviderMapping = await resp
 			.json()
 			.then((json) => json.inferenceProviderMapping)
 			.catch(() => null);
+
+		if (inferenceProviderMapping) {
+			inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
+		}
 	}

 	if (!inferenceProviderMapping) {
-		throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
+		throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
 	}
+	return inferenceProviderMapping;
+}
+
+export async function getInferenceProviderMapping(
+	params: {
+		accessToken?: string;
+		modelId: ModelId;
+		provider: InferenceProvider;
+		task: WidgetType;
+	},
+	options: {
+		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
+	}
+): Promise<InferenceProviderModelMapping | null> {
+	if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
+		return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
+	}
+	const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
+		params.modelId,
+		params.accessToken,
+		options
+	);
 	const providerMapping = inferenceProviderMapping[params.provider];
 	if (providerMapping) {
 		const equivalentTasks =

[Review comment, Contributor Author, on the new cache write] If provider="auto", we call fetchInferenceProviderMappingForModel twice: once to resolve the provider and a second time in makeRequestOptions. This cache avoids the extra HTTP call.
@@ -78,3 +96,24 @@ export async function getInferenceProviderMapping(
 	}
 	return null;
 }
+
+export async function resolveProvider(
+	provider?: InferenceProviderPolicy,
+	modelId?: string
+): Promise<InferenceProvider> {
+	if (!provider) {
+		console.log(
+			"Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
+		);
+		provider = "auto";
+	}
+	if (provider === "auto") {
+		if (!modelId) {
+			throw new Error("Specifying a model is required when provider is 'auto'");
+		}
+		const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
+		provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider;
+		console.log("Auto-selected provider:", provider);
+	}
+	return provider;
+}

[Review comment, @hanouticelina (Contributor Author), Apr 25, 2025, on resolveProvider] This could be done in getProviderHelper, as we did for the Python client, but that would make the function async and we would also have to update the snippet generation.
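
For illustration, here is a minimal sketch of the new resolution flow from a caller's point of view. The model id, the resolved provider name, and the relative import path are placeholders, not taken from the PR:

	import { resolveProvider } from "./lib/getInferenceProviderMapping";

	// Explicit "auto": fetch the model's inferenceProviderMapping from the Hub once,
	// cache it, and return the first provider in the user's preference order
	// (https://hf.co/settings/inference-providers), e.g. "together".
	const provider = await resolveProvider("auto", "org/some-model");

	// Omitted provider: logs the "Defaulting to 'auto'" message, then resolves as above.
	// The mapping lookup now hits inferenceProviderMappingCache, so no second HTTP call.
	const provider2 = await resolveProvider(undefined, "org/some-model");

	// "auto" without a model id throws:
	// await resolveProvider("auto"); // Error: Specifying a model is required when provider is 'auto'

The review comment above also sketches a road not taken: resolving the provider inside getProviderHelper itself, as the Python client does. Roughly, that alternative would look like the following (hypothetical signature, not part of this PR, since it would make the helper lookup async and require updating snippet generation):

	// Hypothetical alternative, NOT what this PR implements.
	async function getProviderHelperAsync(
		provider: InferenceProviderPolicy | undefined,
		modelId: string | undefined,
		task: InferenceTask
	): Promise<TaskProviderHelper> {
		return getProviderHelper(await resolveProvider(provider, modelId), task);
	}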
7 changes: 4 additions & 3 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -27,8 +27,8 @@ export async function makeRequestOptions(
 		task?: InferenceTask;
 	}
 ): Promise<{ url: string; info: RequestInit }> {
-	const { provider: maybeProvider, model: maybeModel } = args;
-	const provider = maybeProvider ?? "hf-inference";
+	const { model: maybeModel } = args;
+	const provider = providerHelper.provider;
 	const { task } = options ?? {};

 	// Validate inputs
@@ -113,8 +113,9 @@ export function makeRequestOptionsFromResolvedModel(
 ): { url: string; info: RequestInit } {
 	const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
 	void model;
+	void maybeProvider;

-	const provider = maybeProvider ?? "hf-inference";
+	const provider = providerHelper.provider;

 	const { includeCredentials, task, signal, billTo } = options ?? {};
 	const authMethod = (() => {
2 changes: 1 addition & 1 deletion packages/inference/src/providers/providerHelper.ts
@@ -56,7 +56,7 @@ import { toArray } from "../utils/toArray";
  */
 export abstract class TaskProviderHelper {
 	constructor(
-		private provider: InferenceProvider,
+		readonly provider: InferenceProvider,
 		private baseUrl: string,
 		readonly clientSideRoutingOnly: boolean = false
 	) {}
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/audio/audioClassification.ts
@@ -1,4 +1,5 @@
 import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -15,7 +16,8 @@ export async function audioClassification(
 	args: AudioClassificationArgs,
 	options?: Options
 ): Promise<AudioClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "audio-classification");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {
 		...options,
5 changes: 4 additions & 1 deletion packages/inference/src/tasks/audio/audioToAudio.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -36,7 +37,9 @@ export interface AudioToAudioOutput {
  * Example model: speechbrain/sepformer-wham does audio source separation.
  */
 export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
+	const model = "inputs" in args ? args.model : undefined;
+	const provider = await resolveProvider(args.provider, model);
+	const providerHelper = getProviderHelper(provider, "audio-to-audio");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<AudioToAudioOutput>(payload, providerHelper, {
 		...options,
packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
@@ -1,4 +1,5 @@
 import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import { FAL_AI_SUPPORTED_BLOB_TYPES } from "../../providers/fal-ai";
@@ -18,7 +19,8 @@ export async function automaticSpeechRecognition(
 	args: AutomaticSpeechRecognitionArgs,
 	options?: Options
 ): Promise<AutomaticSpeechRecognitionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
 	const payload = await buildPayload(args);
 	const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
 		...options,
3 changes: 2 additions & 1 deletion packages/inference/src/tasks/audio/textToSpeech.ts
@@ -1,4 +1,5 @@
 import type { TextToSpeechInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,7 @@ interface OutputUrlTextToSpeechGeneration {
  * Recommended model: espnet/kan-bayashi_ljspeech_vits
  */
 export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
-	const provider = args.provider ?? "hf-inference";
+	const provider = await resolveProvider(args.provider, args.model);
 	const providerHelper = getProviderHelper(provider, "text-to-speech");
 	const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, providerHelper, {
 		...options,
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/custom/request.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { InferenceTask, Options, RequestArgs } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -16,7 +17,8 @@ export async function request<T>(
 	console.warn(
 		"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
 	);
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, options?.task);
 	const result = await innerRequest<T>(args, providerHelper, options);
 	return result.data;
 }
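
Since request() is deprecated in favor of task-specific functions, a migration could look like the following sketch (the model id is a placeholder; with this PR, provider can be omitted or set to "auto"):

	import { textGeneration } from "@huggingface/inference";

	// Before (deprecated):
	// const out = await request<unknown>({ model: "org/some-model", inputs: "Hello" }, { task: "text-generation" });

	// After: the provider is resolved via resolveProvider under the hood.
	const out = await textGeneration({
		model: "org/some-model",
		inputs: "Hello",
		provider: "auto",
	});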
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/custom/streamingRequest.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { InferenceTask, Options, RequestArgs } from "../../types";
 import { innerStreamingRequest } from "../../utils/request";
@@ -16,6 +17,7 @@ export async function* streamingRequest<T>(
 	console.warn(
 		"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
 	);
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, options?.task);
 	yield* innerStreamingRequest(args, providerHelper, options);
 }
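
The streaming variant migrates the same way, e.g. to textGenerationStream (again a sketch with a placeholder model id):

	import { textGenerationStream } from "@huggingface/inference";

	for await (const chunk of textGenerationStream({
		model: "org/some-model",
		inputs: "Hello",
		provider: "auto",
	})) {
		process.stdout.write(chunk.token.text);
	}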
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/cv/imageClassification.ts
@@ -1,4 +1,5 @@
 import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -14,7 +15,8 @@ export async function imageClassification(
 	args: ImageClassificationArgs,
 	options?: Options
 ): Promise<ImageClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "image-classification");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<ImageClassificationOutput>(payload, providerHelper, {
 		...options,
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/cv/imageSegmentation.ts
@@ -1,4 +1,5 @@
 import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -14,7 +15,8 @@ export async function imageSegmentation(
 	args: ImageSegmentationArgs,
 	options?: Options
 ): Promise<ImageSegmentationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "image-segmentation");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, providerHelper, {
 		...options,
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/cv/imageToImage.ts
@@ -1,4 +1,5 @@
 import type { ImageToImageInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -11,7 +12,8 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
  * Recommended model: lllyasviel/sd-controlnet-depth
  */
 export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "image-to-image");
 	let reqArgs: RequestArgs;
 	if (!args.parameters) {
 		reqArgs = {
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/cv/imageToText.ts
@@ -1,4 +1,5 @@
 import type { ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
  * This task reads some image input and outputs the text caption.
  */
 export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "image-to-text");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, providerHelper, {
 		...options,
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/cv/objectDetection.ts
@@ -1,4 +1,5 @@
 import type { ObjectDetectionInput, ObjectDetectionOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -11,7 +12,8 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImageInput);
  * Recommended model: facebook/detr-resnet-50
  */
 export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "object-detection");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<ObjectDetectionOutput>(payload, providerHelper, {
 		...options,
3 changes: 2 additions & 1 deletion packages/inference/src/tasks/cv/textToImage.ts
@@ -1,4 +1,5 @@
 import type { TextToImageInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import { makeRequestOptions } from "../../lib/makeRequestOptions";
 import type { BaseArgs, Options } from "../../types";
@@ -23,7 +24,7 @@ export async function textToImage(
 	options?: TextToImageOptions & { outputType?: undefined | "blob" }
 ): Promise<Blob>;
 export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise<Blob | string> {
-	const provider = args.provider ?? "hf-inference";
+	const provider = await resolveProvider(args.provider, args.model);
 	const providerHelper = getProviderHelper(provider, "text-to-image");
 	const { data: res } = await innerRequest<Record<string, unknown>>(args, providerHelper, {
 		...options,
3 changes: 2 additions & 1 deletion packages/inference/src/tasks/cv/textToVideo.ts
@@ -1,4 +1,5 @@
 import type { TextToVideoInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import { makeRequestOptions } from "../../lib/makeRequestOptions";
 import type { FalAiQueueOutput } from "../../providers/fal-ai";
@@ -12,7 +13,7 @@ export type TextToVideoArgs = BaseArgs & TextToVideoInput;
 export type TextToVideoOutput = Blob;

 export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
-	const provider = args.provider ?? "hf-inference";
+	const provider = await resolveProvider(args.provider, args.model);
 	const providerHelper = getProviderHelper(provider, "text-to-video");
 	const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(
 		args,
packages/inference/src/tasks/cv/zeroShotImageClassification.ts
@@ -1,4 +1,5 @@
 import type { ZeroShotImageClassificationInput, ZeroShotImageClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -44,7 +45,8 @@ export async function zeroShotImageClassification(
 	args: ZeroShotImageClassificationArgs,
 	options?: Options
 ): Promise<ZeroShotImageClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "zero-shot-image-classification");
 	const payload = await preparePayload(args);
 	const { data: res } = await innerRequest<ZeroShotImageClassificationOutput>(payload, providerHelper, {
 		...options,
packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts
@@ -3,6 +3,7 @@ import type {
 	DocumentQuestionAnsweringInputData,
 	DocumentQuestionAnsweringOutput,
 } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -19,7 +20,8 @@ export async function documentQuestionAnswering(
 	args: DocumentQuestionAnsweringArgs,
 	options?: Options
 ): Promise<DocumentQuestionAnsweringOutput[number]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "document-question-answering");
 	const reqArgs: RequestArgs = {
 		...args,
 		inputs: {