From 6ddcd1881dcafefbff6a8981b1af6c9642f9ed7f Mon Sep 17 00:00:00 2001
From: Celina Hanouti <hanouticelina@gmail.com>
Date: Fri, 25 Apr 2025 17:57:56 +0200
Subject: [PATCH 1/6] auto-select provider
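
Introduce an "auto" provider policy so callers no longer have to name a
concrete inference provider:

- Split getInferenceProviderMapping so the Hub fetch (with caching of
  the mapping) lives in a reusable fetchInferenceProviderMappingForModel
  helper.
- Add resolveProvider, which turns the user-facing provider argument
  into a concrete InferenceProvider. For "auto", the first provider
  available for the model is picked, sorted by the user's order in
  https://hf.co/settings/inference-providers.
- Resolve the provider in every task method before calling
  getProviderHelper, and read the provider off providerHelper in
  makeRequestOptions.
- Make TaskProviderHelper.provider readonly instead of private so
  makeRequestOptions can read it.

A rough usage sketch (the model id is illustrative):

    import { chatCompletion } from "@huggingface/inference";

    const output = await chatCompletion({
      accessToken: "hf_...",
      model: "meta-llama/Llama-3.1-8B-Instruct",
      provider: "auto", // a concrete provider still works as before
      messages: [{ role: "user", content: "Hello!" }],
    });
    console.log(output.choices[0].message);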

---
 .../src/lib/getInferenceProviderMapping.ts    | 76 ++++++++++++++-----
 .../inference/src/lib/makeRequestOptions.ts   |  7 +-
 .../inference/src/providers/providerHelper.ts |  2 +-
 .../src/tasks/audio/audioClassification.ts    |  4 +-
 .../inference/src/tasks/audio/audioToAudio.ts |  5 +-
 .../tasks/audio/automaticSpeechRecognition.ts |  4 +-
 .../inference/src/tasks/audio/textToSpeech.ts |  3 +-
 .../inference/src/tasks/custom/request.ts     |  4 +-
 .../src/tasks/custom/streamingRequest.ts      |  4 +-
 .../src/tasks/cv/imageClassification.ts       |  4 +-
 .../src/tasks/cv/imageSegmentation.ts         |  4 +-
 .../inference/src/tasks/cv/imageToImage.ts    |  4 +-
 .../inference/src/tasks/cv/imageToText.ts     |  4 +-
 .../inference/src/tasks/cv/objectDetection.ts |  4 +-
 .../inference/src/tasks/cv/textToImage.ts     |  3 +-
 .../inference/src/tasks/cv/textToVideo.ts     |  3 +-
 .../tasks/cv/zeroShotImageClassification.ts   |  4 +-
 .../multimodal/documentQuestionAnswering.ts   |  4 +-
 .../multimodal/visualQuestionAnswering.ts     |  4 +-
 .../inference/src/tasks/nlp/chatCompletion.ts |  4 +-
 .../src/tasks/nlp/chatCompletionStream.ts     |  4 +-
 .../src/tasks/nlp/featureExtraction.ts        |  4 +-
 packages/inference/src/tasks/nlp/fillMask.ts  |  4 +-
 .../src/tasks/nlp/questionAnswering.ts        |  5 +-
 .../src/tasks/nlp/sentenceSimilarity.ts       |  4 +-
 .../inference/src/tasks/nlp/summarization.ts  |  4 +-
 .../src/tasks/nlp/tableQuestionAnswering.ts   |  4 +-
 .../src/tasks/nlp/textClassification.ts       |  4 +-
 .../inference/src/tasks/nlp/textGeneration.ts |  4 +-
 .../src/tasks/nlp/textGenerationStream.ts     |  4 +-
 .../src/tasks/nlp/tokenClassification.ts      |  4 +-
 .../inference/src/tasks/nlp/translation.ts    |  4 +-
 .../src/tasks/nlp/zeroShotClassification.ts   |  4 +-
 .../tasks/tabular/tabularClassification.ts    |  4 +-
 .../src/tasks/tabular/tabularRegression.ts    |  4 +-
 packages/inference/src/types.ts               |  6 +-
 36 files changed, 162 insertions(+), 56 deletions(-)

diff --git a/packages/inference/src/lib/getInferenceProviderMapping.ts b/packages/inference/src/lib/getInferenceProviderMapping.ts
index d897cb72a0..08a8ee5a8b 100644
--- a/packages/inference/src/lib/getInferenceProviderMapping.ts
+++ b/packages/inference/src/lib/getInferenceProviderMapping.ts
@@ -1,8 +1,8 @@
 import type { WidgetType } from "@huggingface/tasks";
-import type { InferenceProvider, ModelId } from "../types";
 import { HF_HUB_URL } from "../config";
 import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts";
 import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference";
+import type { InferenceProvider, InferenceProviderPolicy, ModelId } from "../types";
 import { typedInclude } from "../utils/typedInclude";
 
 export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
@@ -20,44 +20,62 @@ export interface InferenceProviderModelMapping {
 	task: WidgetType;
 }
 
-export async function getInferenceProviderMapping(
-	params: {
-		accessToken?: string;
-		modelId: ModelId;
-		provider: InferenceProvider;
-		task: WidgetType;
-	},
-	options: {
+export async function fetchInferenceProviderMappingForModel(
+	modelId: ModelId,
+	accessToken?: string,
+	options?: {
 		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
 	}
-): Promise<InferenceProviderModelMapping | null> {
-	if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
-		return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
-	}
+): Promise<InferenceProviderMapping> {
 	let inferenceProviderMapping: InferenceProviderMapping | null;
-	if (inferenceProviderMappingCache.has(params.modelId)) {
+	if (inferenceProviderMappingCache.has(modelId)) {
 		// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-		inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId)!;
+		inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
 	} else {
 		const resp = await (options?.fetch ?? fetch)(
-			`${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
+			`${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
 			{
-				headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {},
+				headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
 			}
 		);
 		if (resp.status === 404) {
-			throw new Error(`Model ${params.modelId} does not exist`);
+			throw new Error(`Model ${modelId} does not exist`);
 		}
 		inferenceProviderMapping = await resp
 			.json()
 			.then((json) => json.inferenceProviderMapping)
 			.catch(() => null);
+
+		if (inferenceProviderMapping) {
+			inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
+		}
 	}
 
 	if (!inferenceProviderMapping) {
-		throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
+		throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
 	}
+	return inferenceProviderMapping;
+}
 
+export async function getInferenceProviderMapping(
+	params: {
+		accessToken?: string;
+		modelId: ModelId;
+		provider: InferenceProvider;
+		task: WidgetType;
+	},
+	options: {
+		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
+	}
+): Promise<InferenceProviderModelMapping | null> {
+	if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
+		return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
+	}
+	const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
+		params.modelId,
+		params.accessToken,
+		options
+	);
 	const providerMapping = inferenceProviderMapping[params.provider];
 	if (providerMapping) {
 		const equivalentTasks =
@@ -94,3 +112,23 @@ export async function getInferenceProviderMapping(
 	}
 	return null;
 }
+
+export async function resolveProvider(
+	provider?: InferenceProviderPolicy,
+	modelId?: string
+): Promise<InferenceProvider> {
+	if (!provider && !modelId) {
+		provider = "hf-inference";
+	}
+	if (!provider) {
+		provider = "auto";
+	}
+	if (provider === "auto") {
+		if (!modelId) {
+			throw new Error("Specifying a model is required when provider is 'auto'");
+		}
+		const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
+		provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider;
+	}
+	return provider;
+}
diff --git a/packages/inference/src/lib/makeRequestOptions.ts b/packages/inference/src/lib/makeRequestOptions.ts
index 06db7d3693..08f1059cc7 100644
--- a/packages/inference/src/lib/makeRequestOptions.ts
+++ b/packages/inference/src/lib/makeRequestOptions.ts
@@ -27,8 +27,8 @@ export async function makeRequestOptions(
 		task?: InferenceTask;
 	}
 ): Promise<{ url: string; info: RequestInit }> {
-	const { provider: maybeProvider, model: maybeModel } = args;
-	const provider = maybeProvider ?? "hf-inference";
+	const { model: maybeModel } = args;
+	const provider = providerHelper.provider;
 	const { task } = options ?? {};
 
 	// Validate inputs
@@ -113,8 +113,9 @@ export function makeRequestOptionsFromResolvedModel(
 ): { url: string; info: RequestInit } {
 	const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
 	void model;
+	void maybeProvider;
 
-	const provider = maybeProvider ?? "hf-inference";
+	const provider = providerHelper.provider;
 
 	const { includeCredentials, task, signal, billTo } = options ?? {};
 	const authMethod = (() => {
diff --git a/packages/inference/src/providers/providerHelper.ts b/packages/inference/src/providers/providerHelper.ts
index a0da0f7c6f..994f3ab0ea 100644
--- a/packages/inference/src/providers/providerHelper.ts
+++ b/packages/inference/src/providers/providerHelper.ts
@@ -56,7 +56,7 @@ import { toArray } from "../utils/toArray";
  */
 export abstract class TaskProviderHelper {
 	constructor(
-		private provider: InferenceProvider,
+		readonly provider: InferenceProvider,
 		private baseUrl: string,
 		readonly clientSideRoutingOnly: boolean = false
 	) {}
diff --git a/packages/inference/src/tasks/audio/audioClassification.ts b/packages/inference/src/tasks/audio/audioClassification.ts
index f1ff1c20c7..bc495b0227 100644
--- a/packages/inference/src/tasks/audio/audioClassification.ts
+++ b/packages/inference/src/tasks/audio/audioClassification.ts
@@ -1,4 +1,5 @@
 import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -15,7 +16,8 @@ export async function audioClassification(
 	args: AudioClassificationArgs,
 	options?: Options
 ): Promise<AudioClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "audio-classification");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/audio/audioToAudio.ts b/packages/inference/src/tasks/audio/audioToAudio.ts
index fa055d3234..b4f37c1563 100644
--- a/packages/inference/src/tasks/audio/audioToAudio.ts
+++ b/packages/inference/src/tasks/audio/audioToAudio.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -36,7 +37,9 @@ export interface AudioToAudioOutput {
  * Example model: speechbrain/sepformer-wham does audio source separation.
  */
 export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
+	const model = "inputs" in args ? args.model : undefined;
+	const provider = await resolveProvider(args.provider, model);
+	const providerHelper = getProviderHelper(provider, "audio-to-audio");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<AudioToAudioOutput>(payload, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts b/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
index c71af07427..347c01b833 100644
--- a/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
+++ b/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
@@ -1,4 +1,5 @@
 import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import { FAL_AI_SUPPORTED_BLOB_TYPES } from "../../providers/fal-ai";
@@ -18,7 +19,8 @@ export async function automaticSpeechRecognition(
 	args: AutomaticSpeechRecognitionArgs,
 	options?: Options
 ): Promise<AutomaticSpeechRecognitionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
 	const payload = await buildPayload(args);
 	const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/audio/textToSpeech.ts b/packages/inference/src/tasks/audio/textToSpeech.ts
index 11f01f436f..24694062fc 100644
--- a/packages/inference/src/tasks/audio/textToSpeech.ts
+++ b/packages/inference/src/tasks/audio/textToSpeech.ts
@@ -1,4 +1,5 @@
 import type { TextToSpeechInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,7 @@ interface OutputUrlTextToSpeechGeneration {
  * Recommended model: espnet/kan-bayashi_ljspeech_vits
  */
 export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
-	const provider = args.provider ?? "hf-inference";
+	const provider = await resolveProvider(args.provider, args.model);
 	const providerHelper = getProviderHelper(provider, "text-to-speech");
 	const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/custom/request.ts b/packages/inference/src/tasks/custom/request.ts
index 62f45f28b3..df103ad5e0 100644
--- a/packages/inference/src/tasks/custom/request.ts
+++ b/packages/inference/src/tasks/custom/request.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { InferenceTask, Options, RequestArgs } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -16,7 +17,8 @@ export async function request<T>(
 	console.warn(
 		"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
 	);
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, options?.task);
 	const result = await innerRequest<T>(args, providerHelper, options);
 	return result.data;
 }
diff --git a/packages/inference/src/tasks/custom/streamingRequest.ts b/packages/inference/src/tasks/custom/streamingRequest.ts
index 45ae99f323..8e45ad5535 100644
--- a/packages/inference/src/tasks/custom/streamingRequest.ts
+++ b/packages/inference/src/tasks/custom/streamingRequest.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { InferenceTask, Options, RequestArgs } from "../../types";
 import { innerStreamingRequest } from "../../utils/request";
@@ -16,6 +17,7 @@ export async function* streamingRequest<T>(
 	console.warn(
 		"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
 	);
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, options?.task);
 	yield* innerStreamingRequest(args, providerHelper, options);
 }
diff --git a/packages/inference/src/tasks/cv/imageClassification.ts b/packages/inference/src/tasks/cv/imageClassification.ts
index e683ecb3e2..a324faa1db 100644
--- a/packages/inference/src/tasks/cv/imageClassification.ts
+++ b/packages/inference/src/tasks/cv/imageClassification.ts
@@ -1,4 +1,5 @@
 import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -14,7 +15,8 @@ export async function imageClassification(
 	args: ImageClassificationArgs,
 	options?: Options
 ): Promise<ImageClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "image-classification");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<ImageClassificationOutput>(payload, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/cv/imageSegmentation.ts b/packages/inference/src/tasks/cv/imageSegmentation.ts
index 9de0e9a2ef..719fa684ec 100644
--- a/packages/inference/src/tasks/cv/imageSegmentation.ts
+++ b/packages/inference/src/tasks/cv/imageSegmentation.ts
@@ -1,4 +1,5 @@
 import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -14,7 +15,8 @@ export async function imageSegmentation(
 	args: ImageSegmentationArgs,
 	options?: Options
 ): Promise<ImageSegmentationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "image-segmentation");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/cv/imageToImage.ts b/packages/inference/src/tasks/cv/imageToImage.ts
index 49d8ca2be5..a0bab24bab 100644
--- a/packages/inference/src/tasks/cv/imageToImage.ts
+++ b/packages/inference/src/tasks/cv/imageToImage.ts
@@ -1,4 +1,5 @@
 import type { ImageToImageInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -11,7 +12,8 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
  * Recommended model: lllyasviel/sd-controlnet-depth
  */
 export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "image-to-image");
 	let reqArgs: RequestArgs;
 	if (!args.parameters) {
 		reqArgs = {
diff --git a/packages/inference/src/tasks/cv/imageToText.ts b/packages/inference/src/tasks/cv/imageToText.ts
index cdee706fa4..8d2d791893 100644
--- a/packages/inference/src/tasks/cv/imageToText.ts
+++ b/packages/inference/src/tasks/cv/imageToText.ts
@@ -1,4 +1,5 @@
 import type { ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
  * This task reads some image input and outputs the text caption.
  */
 export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "image-to-text");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/cv/objectDetection.ts b/packages/inference/src/tasks/cv/objectDetection.ts
index d103feeb96..d216be78eb 100644
--- a/packages/inference/src/tasks/cv/objectDetection.ts
+++ b/packages/inference/src/tasks/cv/objectDetection.ts
@@ -1,4 +1,5 @@
 import type { ObjectDetectionInput, ObjectDetectionOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -11,7 +12,8 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImage
  * Recommended model: facebook/detr-resnet-50
  */
 export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "object-detection");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<ObjectDetectionOutput>(payload, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/cv/textToImage.ts b/packages/inference/src/tasks/cv/textToImage.ts
index 490a8e10b6..76544065df 100644
--- a/packages/inference/src/tasks/cv/textToImage.ts
+++ b/packages/inference/src/tasks/cv/textToImage.ts
@@ -1,4 +1,5 @@
 import type { TextToImageInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import { makeRequestOptions } from "../../lib/makeRequestOptions";
 import type { BaseArgs, Options } from "../../types";
@@ -23,7 +24,7 @@ export async function textToImage(
 	options?: TextToImageOptions & { outputType?: undefined | "blob" }
 ): Promise<Blob>;
 export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise<Blob | string> {
-	const provider = args.provider ?? "hf-inference";
+	const provider = await resolveProvider(args.provider, args.model);
 	const providerHelper = getProviderHelper(provider, "text-to-image");
 	const { data: res } = await innerRequest<Record<string, unknown>>(args, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/cv/textToVideo.ts b/packages/inference/src/tasks/cv/textToVideo.ts
index 9143e147a8..798799d61b 100644
--- a/packages/inference/src/tasks/cv/textToVideo.ts
+++ b/packages/inference/src/tasks/cv/textToVideo.ts
@@ -1,4 +1,5 @@
 import type { TextToVideoInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import { makeRequestOptions } from "../../lib/makeRequestOptions";
 import type { FalAiQueueOutput } from "../../providers/fal-ai";
@@ -12,7 +13,7 @@ export type TextToVideoArgs = BaseArgs & TextToVideoInput;
 export type TextToVideoOutput = Blob;
 
 export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
-	const provider = args.provider ?? "hf-inference";
+	const provider = await resolveProvider(args.provider, args.model);
 	const providerHelper = getProviderHelper(provider, "text-to-video");
 	const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(
 		args,
diff --git a/packages/inference/src/tasks/cv/zeroShotImageClassification.ts b/packages/inference/src/tasks/cv/zeroShotImageClassification.ts
index 8f09e74aac..e75b439853 100644
--- a/packages/inference/src/tasks/cv/zeroShotImageClassification.ts
+++ b/packages/inference/src/tasks/cv/zeroShotImageClassification.ts
@@ -1,4 +1,5 @@
 import type { ZeroShotImageClassificationInput, ZeroShotImageClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -44,7 +45,8 @@ export async function zeroShotImageClassification(
 	args: ZeroShotImageClassificationArgs,
 	options?: Options
 ): Promise<ZeroShotImageClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "zero-shot-image-classification");
 	const payload = await preparePayload(args);
 	const { data: res } = await innerRequest<ZeroShotImageClassificationOutput>(payload, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts b/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts
index 033708dce0..c17234ee2d 100644
--- a/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts
+++ b/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts
@@ -3,6 +3,7 @@ import type {
 	DocumentQuestionAnsweringInputData,
 	DocumentQuestionAnsweringOutput,
 } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -19,7 +20,8 @@ export async function documentQuestionAnswering(
 	args: DocumentQuestionAnsweringArgs,
 	options?: Options
 ): Promise<DocumentQuestionAnsweringOutput[number]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "document-question-answering");
 	const reqArgs: RequestArgs = {
 		...args,
 		inputs: {
diff --git a/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts b/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts
index 98a6616206..62e2d9e9da 100644
--- a/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts
+++ b/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts
@@ -3,6 +3,7 @@ import type {
 	VisualQuestionAnsweringInputData,
 	VisualQuestionAnsweringOutput,
 } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -19,7 +20,8 @@ export async function visualQuestionAnswering(
 	args: VisualQuestionAnsweringArgs,
 	options?: Options
 ): Promise<VisualQuestionAnsweringOutput[number]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "visual-question-answering");
 	const reqArgs: RequestArgs = {
 		...args,
 		inputs: {
diff --git a/packages/inference/src/tasks/nlp/chatCompletion.ts b/packages/inference/src/tasks/nlp/chatCompletion.ts
index 4ad9be5f1c..6ad4da6aea 100644
--- a/packages/inference/src/tasks/nlp/chatCompletion.ts
+++ b/packages/inference/src/tasks/nlp/chatCompletion.ts
@@ -1,4 +1,5 @@
 import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export async function chatCompletion(
 	args: BaseArgs & ChatCompletionInput,
 	options?: Options
 ): Promise<ChatCompletionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "conversational");
 	const { data: response } = await innerRequest<ChatCompletionOutput>(args, providerHelper, {
 		...options,
 		task: "conversational",
diff --git a/packages/inference/src/tasks/nlp/chatCompletionStream.ts b/packages/inference/src/tasks/nlp/chatCompletionStream.ts
index cf88044112..260b9374ff 100644
--- a/packages/inference/src/tasks/nlp/chatCompletionStream.ts
+++ b/packages/inference/src/tasks/nlp/chatCompletionStream.ts
@@ -1,4 +1,5 @@
 import type { ChatCompletionInput, ChatCompletionStreamOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerStreamingRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export async function* chatCompletionStream(
 	args: BaseArgs & ChatCompletionInput,
 	options?: Options
 ): AsyncGenerator<ChatCompletionStreamOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "conversational");
 	yield* innerStreamingRequest<ChatCompletionStreamOutput>(args, providerHelper, {
 		...options,
 		task: "conversational",
diff --git a/packages/inference/src/tasks/nlp/featureExtraction.ts b/packages/inference/src/tasks/nlp/featureExtraction.ts
index 5d0fe93ded..30fc3a6766 100644
--- a/packages/inference/src/tasks/nlp/featureExtraction.ts
+++ b/packages/inference/src/tasks/nlp/featureExtraction.ts
@@ -1,4 +1,5 @@
 import type { FeatureExtractionInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -22,7 +23,8 @@ export async function featureExtraction(
 	args: FeatureExtractionArgs,
 	options?: Options
 ): Promise<FeatureExtractionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "feature-extraction");
 	const { data: res } = await innerRequest<FeatureExtractionOutput>(args, providerHelper, {
 		...options,
 		task: "feature-extraction",
diff --git a/packages/inference/src/tasks/nlp/fillMask.ts b/packages/inference/src/tasks/nlp/fillMask.ts
index 663db87d99..8a992f938d 100644
--- a/packages/inference/src/tasks/nlp/fillMask.ts
+++ b/packages/inference/src/tasks/nlp/fillMask.ts
@@ -1,4 +1,5 @@
 import type { FillMaskInput, FillMaskOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -9,7 +10,8 @@ export type FillMaskArgs = BaseArgs & FillMaskInput;
  * Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
  */
 export async function fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "fill-mask");
 	const { data: res } = await innerRequest<FillMaskOutput>(args, providerHelper, {
 		...options,
 		task: "fill-mask",
diff --git a/packages/inference/src/tasks/nlp/questionAnswering.ts b/packages/inference/src/tasks/nlp/questionAnswering.ts
index 6559c80c1a..a6b3cc0570 100644
--- a/packages/inference/src/tasks/nlp/questionAnswering.ts
+++ b/packages/inference/src/tasks/nlp/questionAnswering.ts
@@ -1,4 +1,6 @@
 import type { QuestionAnsweringInput, QuestionAnsweringOutput } from "@huggingface/tasks";
+
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +14,8 @@ export async function questionAnswering(
 	args: QuestionAnsweringArgs,
 	options?: Options
 ): Promise<QuestionAnsweringOutput[number]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "question-answering");
 	const { data: res } = await innerRequest<QuestionAnsweringOutput | QuestionAnsweringOutput[number]>(
 		args,
 		providerHelper,
diff --git a/packages/inference/src/tasks/nlp/sentenceSimilarity.ts b/packages/inference/src/tasks/nlp/sentenceSimilarity.ts
index faa751f73e..f76b8c795c 100644
--- a/packages/inference/src/tasks/nlp/sentenceSimilarity.ts
+++ b/packages/inference/src/tasks/nlp/sentenceSimilarity.ts
@@ -1,4 +1,5 @@
 import type { SentenceSimilarityInput, SentenceSimilarityOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function sentenceSimilarity(
 	args: SentenceSimilarityArgs,
 	options?: Options
 ): Promise<SentenceSimilarityOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "sentence-similarity");
 	const { data: res } = await innerRequest<SentenceSimilarityOutput>(args, providerHelper, {
 		...options,
 		task: "sentence-similarity",
diff --git a/packages/inference/src/tasks/nlp/summarization.ts b/packages/inference/src/tasks/nlp/summarization.ts
index 4b4205bf4b..29cda2cd5b 100644
--- a/packages/inference/src/tasks/nlp/summarization.ts
+++ b/packages/inference/src/tasks/nlp/summarization.ts
@@ -1,4 +1,5 @@
 import type { SummarizationInput, SummarizationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -9,7 +10,8 @@ export type SummarizationArgs = BaseArgs & SummarizationInput;
  * This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
  */
 export async function summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "summarization");
 	const { data: res } = await innerRequest<SummarizationOutput[]>(args, providerHelper, {
 		...options,
 		task: "summarization",
diff --git a/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts b/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts
index 3939115862..03127a2fd8 100644
--- a/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts
+++ b/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts
@@ -1,4 +1,5 @@
 import type { TableQuestionAnsweringInput, TableQuestionAnsweringOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function tableQuestionAnswering(
 	args: TableQuestionAnsweringArgs,
 	options?: Options
 ): Promise<TableQuestionAnsweringOutput[number]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "table-question-answering");
 	const { data: res } = await innerRequest<TableQuestionAnsweringOutput | TableQuestionAnsweringOutput[number]>(
 		args,
 		providerHelper,
diff --git a/packages/inference/src/tasks/nlp/textClassification.ts b/packages/inference/src/tasks/nlp/textClassification.ts
index 7631d82286..8440d81bff 100644
--- a/packages/inference/src/tasks/nlp/textClassification.ts
+++ b/packages/inference/src/tasks/nlp/textClassification.ts
@@ -1,4 +1,5 @@
 import type { TextClassificationInput, TextClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function textClassification(
 	args: TextClassificationArgs,
 	options?: Options
 ): Promise<TextClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "text-classification");
 	const { data: res } = await innerRequest<TextClassificationOutput>(args, providerHelper, {
 		...options,
 		task: "text-classification",
diff --git a/packages/inference/src/tasks/nlp/textGeneration.ts b/packages/inference/src/tasks/nlp/textGeneration.ts
index 5d84e543cb..a1e441e6bc 100644
--- a/packages/inference/src/tasks/nlp/textGeneration.ts
+++ b/packages/inference/src/tasks/nlp/textGeneration.ts
@@ -1,4 +1,5 @@
 import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { HyperbolicTextCompletionOutput } from "../../providers/hyperbolic";
 import type { BaseArgs, Options } from "../../types";
@@ -13,7 +14,8 @@ export async function textGeneration(
 	args: BaseArgs & TextGenerationInput,
 	options?: Options
 ): Promise<TextGenerationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "text-generation");
 	const { data: response } = await innerRequest<
 		HyperbolicTextCompletionOutput | TextGenerationOutput | TextGenerationOutput[]
 	>(args, providerHelper, {
diff --git a/packages/inference/src/tasks/nlp/textGenerationStream.ts b/packages/inference/src/tasks/nlp/textGenerationStream.ts
index 59706eaa5c..935e533a2a 100644
--- a/packages/inference/src/tasks/nlp/textGenerationStream.ts
+++ b/packages/inference/src/tasks/nlp/textGenerationStream.ts
@@ -1,4 +1,5 @@
 import type { TextGenerationInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerStreamingRequest } from "../../utils/request";
@@ -90,7 +91,8 @@ export async function* textGenerationStream(
 	args: BaseArgs & TextGenerationInput,
 	options?: Options
 ): AsyncGenerator<TextGenerationStreamOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "text-generation");
 	yield* innerStreamingRequest<TextGenerationStreamOutput>(args, providerHelper, {
 		...options,
 		task: "text-generation",
diff --git a/packages/inference/src/tasks/nlp/tokenClassification.ts b/packages/inference/src/tasks/nlp/tokenClassification.ts
index 0c52b9e6a6..a822c7efcf 100644
--- a/packages/inference/src/tasks/nlp/tokenClassification.ts
+++ b/packages/inference/src/tasks/nlp/tokenClassification.ts
@@ -1,4 +1,5 @@
 import type { TokenClassificationInput, TokenClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function tokenClassification(
 	args: TokenClassificationArgs,
 	options?: Options
 ): Promise<TokenClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "token-classification");
 	const { data: res } = await innerRequest<TokenClassificationOutput[number] | TokenClassificationOutput>(
 		args,
 		providerHelper,
diff --git a/packages/inference/src/tasks/nlp/translation.ts b/packages/inference/src/tasks/nlp/translation.ts
index 1f463e3e67..1347527d87 100644
--- a/packages/inference/src/tasks/nlp/translation.ts
+++ b/packages/inference/src/tasks/nlp/translation.ts
@@ -1,4 +1,5 @@
 import type { TranslationInput, TranslationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -8,7 +9,8 @@ export type TranslationArgs = BaseArgs & TranslationInput;
  * This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
  */
 export async function translation(args: TranslationArgs, options?: Options): Promise<TranslationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "translation");
 	const { data: res } = await innerRequest<TranslationOutput>(args, providerHelper, {
 		...options,
 		task: "translation",
diff --git a/packages/inference/src/tasks/nlp/zeroShotClassification.ts b/packages/inference/src/tasks/nlp/zeroShotClassification.ts
index 30d6d0c156..83642749b1 100644
--- a/packages/inference/src/tasks/nlp/zeroShotClassification.ts
+++ b/packages/inference/src/tasks/nlp/zeroShotClassification.ts
@@ -1,4 +1,5 @@
 import type { ZeroShotClassificationInput, ZeroShotClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function zeroShotClassification(
 	args: ZeroShotClassificationArgs,
 	options?: Options
 ): Promise<ZeroShotClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "zero-shot-classification");
 	const { data: res } = await innerRequest<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(
 		args,
 		providerHelper,
diff --git a/packages/inference/src/tasks/tabular/tabularClassification.ts b/packages/inference/src/tasks/tabular/tabularClassification.ts
index 9174c17718..5d615fdbed 100644
--- a/packages/inference/src/tasks/tabular/tabularClassification.ts
+++ b/packages/inference/src/tasks/tabular/tabularClassification.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -25,7 +26,8 @@ export async function tabularClassification(
 	args: TabularClassificationArgs,
 	options?: Options
 ): Promise<TabularClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "tabular-classification");
 	const { data: res } = await innerRequest<TabularClassificationOutput>(args, providerHelper, {
 		...options,
 		task: "tabular-classification",
diff --git a/packages/inference/src/tasks/tabular/tabularRegression.ts b/packages/inference/src/tasks/tabular/tabularRegression.ts
index 2c2408ffde..284db1c8a6 100644
--- a/packages/inference/src/tasks/tabular/tabularRegression.ts
+++ b/packages/inference/src/tasks/tabular/tabularRegression.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -25,7 +26,8 @@ export async function tabularRegression(
 	args: TabularRegressionArgs,
 	options?: Options
 ): Promise<TabularRegressionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "tabular-regression");
 	const { data: res } = await innerRequest<TabularRegressionOutput>(args, providerHelper, {
 		...options,
 		task: "tabular-regression",
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index e5870f6ef3..51ad16156f 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -56,8 +56,12 @@ export const INFERENCE_PROVIDERS = [
 	"together",
 ] as const;
 
+export const PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"] as const;
+
 export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
 
+export type InferenceProviderPolicy = (typeof PROVIDERS_OR_POLICIES)[number];
+
 export interface BaseArgs {
 	/**
 	 * The access token to use. Without it, you'll get rate-limited quickly.
@@ -90,7 +94,7 @@ export interface BaseArgs {
 	 *
 	 * Defaults to the first provider in your user settings that is compatible with this model.
 	 */
-	provider?: InferenceProvider;
+	provider?: InferenceProviderPolicy;
 }
 
 export type RequestArgs = BaseArgs &

From bab28f946f2964ae0358c2282f88e63be6f47f3d Mon Sep 17 00:00:00 2001
From: Celina Hanouti <hanouticelina@gmail.com>
Date: Fri, 25 Apr 2025 18:03:22 +0200
Subject: [PATCH 2/6] nit: clarify provider doc comment
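
The previous wording ("the first provider in your user settings that is
compatible with this model") predates the explicit "auto" policy; point
the doc comment on BaseArgs.provider at the new default instead.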

---
 packages/inference/src/types.ts | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index 51ad16156f..f362a5647e 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -92,7 +92,7 @@ export interface BaseArgs {
 	/**
 	 * Set an Inference provider to run this model on.
 	 *
-	 * Defaults to the first provider in your user settings that is compatible with this model.
+	 * Defaults to "auto", i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
 	 */
 	provider?: InferenceProviderPolicy;
 }

From c7383d335f27dcf36179d1e02f05a13e5cde237d Mon Sep 17 00:00:00 2001
From: Celina Hanouti <hanouticelina@gmail.com>
Date: Fri, 25 Apr 2025 18:07:10 +0200
Subject: [PATCH 3/6] log provider auto-selection
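
Make implicit provider selection visible: log when we default to "auto"
and log which provider was ultimately auto-selected for the model.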

---
 packages/inference/src/lib/getInferenceProviderMapping.ts | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/packages/inference/src/lib/getInferenceProviderMapping.ts b/packages/inference/src/lib/getInferenceProviderMapping.ts
index 08a8ee5a8b..25418a50a5 100644
--- a/packages/inference/src/lib/getInferenceProviderMapping.ts
+++ b/packages/inference/src/lib/getInferenceProviderMapping.ts
@@ -121,6 +121,9 @@ export async function resolveProvider(
 		provider = "hf-inference";
 	}
 	if (!provider) {
+		console.log(
+			"Defaulting to 'auto', which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
+		);
 		provider = "auto";
 	}
 	if (provider === "auto") {
@@ -129,6 +132,7 @@ export async function resolveProvider(
 		}
 		const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
 		provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider;
+		console.log("Auto-selected provider:", provider);
 	}
 	return provider;
 }

From 4a104793d28f90b183345c7d98c878219130bb90 Mon Sep 17 00:00:00 2001
From: Celina Hanouti <hanouticelina@gmail.com>
Date: Fri, 25 Apr 2025 18:21:43 +0200
Subject: [PATCH 4/6] fix: always default to "auto" when no provider is given
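
Remove the special case that silently fell back to "hf-inference" when
neither a provider nor a model was given; resolveProvider now always
defaults to "auto", which requires a model:

    // before: await resolveProvider(undefined, undefined) -> "hf-inference"
    // after:  await resolveProvider(undefined, undefined) -> throws
    //         "Specifying a model is required when provider is 'auto'"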

---
 packages/inference/src/lib/getInferenceProviderMapping.ts | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/packages/inference/src/lib/getInferenceProviderMapping.ts b/packages/inference/src/lib/getInferenceProviderMapping.ts
index 25418a50a5..46153a102f 100644
--- a/packages/inference/src/lib/getInferenceProviderMapping.ts
+++ b/packages/inference/src/lib/getInferenceProviderMapping.ts
@@ -117,9 +117,6 @@ export async function resolveProvider(
 	provider?: InferenceProviderPolicy,
 	modelId?: string
 ): Promise<InferenceProvider> {
-	if (!provider && !modelId) {
-		provider = "hf-inference";
-	}
 	if (!provider) {
 		console.log(
 			"Defaulting to 'auto', which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."

From 38530d057e20a863ce23e166149e96a0ad95c78f Mon Sep 17 00:00:00 2001
From: Celina Hanouti <hanouticelina@gmail.com>
Date: Tue, 29 Apr 2025 15:39:26 +0200
Subject: [PATCH 5/6] rename InferenceProviderPolicy ->
 InferenceProviderOrPolicy
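
"Policy" alone suggested the type only covered policies such as "auto",
when it is really a union of every concrete provider plus the "auto"
policy; InferenceProviderOrPolicy makes that explicit.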

---
 packages/inference/src/lib/getInferenceProviderMapping.ts | 4 ++--
 packages/inference/src/types.ts                           | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/packages/inference/src/lib/getInferenceProviderMapping.ts b/packages/inference/src/lib/getInferenceProviderMapping.ts
index 3220188efe..96904d5d9c 100644
--- a/packages/inference/src/lib/getInferenceProviderMapping.ts
+++ b/packages/inference/src/lib/getInferenceProviderMapping.ts
@@ -2,7 +2,7 @@ import type { WidgetType } from "@huggingface/tasks";
 import { HF_HUB_URL } from "../config";
 import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts";
 import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference";
-import type { InferenceProvider, InferenceProviderPolicy, ModelId } from "../types";
+import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../types";
 import { typedInclude } from "../utils/typedInclude";
 
 export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
@@ -98,7 +98,7 @@ export async function getInferenceProviderMapping(
 }
 
 export async function resolveProvider(
-	provider?: InferenceProviderPolicy,
+	provider?: InferenceProviderOrPolicy,
 	modelId?: string
 ): Promise<InferenceProvider> {
 	if (!provider) {
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index f362a5647e..a81b1ddde5 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -60,7 +60,7 @@ export const PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"] as const;
 
 export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
 
-export type InferenceProviderPolicy = (typeof PROVIDERS_OR_POLICIES)[number];
+export type InferenceProviderOrPolicy = (typeof PROVIDERS_OR_POLICIES)[number];
 
 export interface BaseArgs {
 	/**
@@ -94,7 +94,7 @@ export interface BaseArgs {
 	 *
 	 * Defaults to "auto", i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
 	 */
-	provider?: InferenceProviderPolicy;
+	provider?: InferenceProviderOrPolicy;
 }
 
 export type RequestArgs = BaseArgs &

From a4bce4f37c9894df6a0ac224282f628ec2eea566 Mon Sep 17 00:00:00 2001
From: Celina Hanouti <hanouticelina@gmail.com>
Date: Wed, 30 Apr 2025 11:12:46 +0200
Subject: [PATCH 6/6] remove unnecessary log
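
The "Auto-selected provider" message fired on every request and added
console noise; the "Defaulting to 'auto'" message already signals that
automatic selection is happening.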

---
 packages/inference/src/lib/getInferenceProviderMapping.ts | 1 -
 1 file changed, 1 deletion(-)

diff --git a/packages/inference/src/lib/getInferenceProviderMapping.ts b/packages/inference/src/lib/getInferenceProviderMapping.ts
index 96904d5d9c..d751beae6d 100644
--- a/packages/inference/src/lib/getInferenceProviderMapping.ts
+++ b/packages/inference/src/lib/getInferenceProviderMapping.ts
@@ -113,7 +113,6 @@ export async function resolveProvider(
 		}
 		const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
 		provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider;
-		console.log("Auto-selected provider:", provider);
 	}
 	return provider;
 }