[Inference Providers] provider="auto" #1390

Merged · 11 commits · Apr 30, 2025

76 changes: 57 additions & 19 deletions packages/inference/src/lib/getInferenceProviderMapping.ts
@@ -1,8 +1,8 @@
 import type { WidgetType } from "@huggingface/tasks";
-import type { InferenceProvider, ModelId } from "../types";
 import { HF_HUB_URL } from "../config";
 import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts";
 import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference";
+import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../types";
 import { typedInclude } from "../utils/typedInclude";
 
 export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
@@ -20,44 +20,62 @@ export interface InferenceProviderModelMapping {
 	task: WidgetType;
 }
 
-export async function getInferenceProviderMapping(
-	params: {
-		accessToken?: string;
-		modelId: ModelId;
-		provider: InferenceProvider;
-		task: WidgetType;
-	},
-	options: {
+export async function fetchInferenceProviderMappingForModel(
+	modelId: ModelId,
+	accessToken?: string,
+	options?: {
 		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
 	}
-): Promise<InferenceProviderModelMapping | null> {
-	if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
-		return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
-	}
+): Promise<InferenceProviderMapping> {
 	let inferenceProviderMapping: InferenceProviderMapping | null;
-	if (inferenceProviderMappingCache.has(params.modelId)) {
+	if (inferenceProviderMappingCache.has(modelId)) {
 		// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-		inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId)!;
+		inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
 	} else {
 		const resp = await (options?.fetch ?? fetch)(
-			`${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
+			`${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
 			{
-				headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {},
+				headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
 			}
 		);
 		if (resp.status === 404) {
-			throw new Error(`Model ${params.modelId} does not exist`);
+			throw new Error(`Model ${modelId} does not exist`);
 		}
 		inferenceProviderMapping = await resp
 			.json()
 			.then((json) => json.inferenceProviderMapping)
 			.catch(() => null);
+
+		if (inferenceProviderMapping) {
+			inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
+		}

Review comment (author): if provider="auto", we call fetchInferenceProviderMappingForModel 2 times: once to resolve the provider and a second time in makeRequestOptions. this cache avoids the extra HTTP call.

 	}
 
 	if (!inferenceProviderMapping) {
-		throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
+		throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
 	}
+	return inferenceProviderMapping;
+}
+
+export async function getInferenceProviderMapping(
+	params: {
+		accessToken?: string;
+		modelId: ModelId;
+		provider: InferenceProvider;
+		task: WidgetType;
+	},
+	options: {
+		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
+	}
+): Promise<InferenceProviderModelMapping | null> {
+	if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
+		return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
+	}
+	const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
+		params.modelId,
+		params.accessToken,
+		options
+	);
 	const providerMapping = inferenceProviderMapping[params.provider];
 	if (providerMapping) {
 		const equivalentTasks =
@@ -78,3 +96,23 @@ export async function getInferenceProviderMapping(
 	}
 	return null;
 }
+

Review comment (@hanouticelina, Apr 25, 2025): this could be done in getProviderHelper as we did for the python client, but that would make the function async and we would have to update the snippets generation as well

+export async function resolveProvider(
+	provider?: InferenceProviderOrPolicy,
+	modelId?: string
+): Promise<InferenceProvider> {
+	if (!provider) {
+		console.log(
+			"Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
+		);
+		provider = "auto";
+	}
+	if (provider === "auto") {
+		if (!modelId) {
+			throw new Error("Specifying a model is required when provider is 'auto'");
+		}
+		const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
+		provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider;
+	}
+	return provider;
+}
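
To make the new flow concrete, here is a minimal illustrative sketch of resolveProvider end to end (the model id is only an example; which provider comes first depends on the user's settings):

// With provider="auto", the model's inferenceProviderMapping is fetched once
const provider = await resolveProvider("auto", "meta-llama/Llama-3.1-8B-Instruct");
// → e.g. "together": the first key of the mapping returned by the Hub,
// sorted by the user's order in https://hf.co/settings/inference-providers

// The second lookup (inside makeRequestOptions) hits inferenceProviderMappingCache,
// so no extra HTTP request is made for the same model id.

// Omitting the model throws, per the guard above:
// await resolveProvider("auto"); // Error: Specifying a model is required when provider is 'auto'
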
7 changes: 4 additions & 3 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -27,8 +27,8 @@ export async function makeRequestOptions(
 		task?: InferenceTask;
 	}
 ): Promise<{ url: string; info: RequestInit }> {
-	const { provider: maybeProvider, model: maybeModel } = args;
-	const provider = maybeProvider ?? "hf-inference";
+	const { model: maybeModel } = args;
+	const provider = providerHelper.provider;
 	const { task } = options ?? {};
 
 	// Validate inputs
@@ -113,8 +113,9 @@ export function makeRequestOptionsFromResolvedModel(
 ): { url: string; info: RequestInit } {
 	const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
 	void model;
+	void maybeProvider;
 
-	const provider = maybeProvider ?? "hf-inference";
+	const provider = providerHelper.provider;
 
 	const { includeCredentials, task, signal, billTo } = options ?? {};
 	const authMethod = (() => {
2 changes: 1 addition & 1 deletion packages/inference/src/providers/providerHelper.ts
@@ -56,7 +56,7 @@ import { toArray } from "../utils/toArray";
  */
 export abstract class TaskProviderHelper {
 	constructor(
-		private provider: InferenceProvider,
+		readonly provider: InferenceProvider,
 		private baseUrl: string,
 		readonly clientSideRoutingOnly: boolean = false
 	) {}
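
Loosening the field from private to readonly is what lets makeRequestOptions read the resolved provider back off the helper (providerHelper.provider above) instead of re-deriving it from args. Illustrative sketch (provider and task are example values):

const providerHelper = getProviderHelper("fal-ai", "text-to-image");
providerHelper.provider; // "fal-ai": now readable outside the class, still immutable
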
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/audio/audioClassification.ts
@@ -1,4 +1,5 @@
 import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -15,7 +16,8 @@ export async function audioClassification(
 	args: AudioClassificationArgs,
 	options?: Options
 ): Promise<AudioClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "audio-classification");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {
 		...options,
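
All task functions below follow this same two-line pattern, so existing call sites keep working unchanged. A sketch of a call that now goes through resolveProvider (the model id and audio URL are illustrative):

import { audioClassification } from "@huggingface/inference";

const result = await audioClassification({
	accessToken: "hf_...",
	model: "superb/hubert-large-superb-er", // example model id
	// provider omitted: resolveProvider logs the defaulting message and uses "auto"
	data: await (await fetch("https://example.com/sample.flac")).blob(),
});
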
5 changes: 4 additions & 1 deletion packages/inference/src/tasks/audio/audioToAudio.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -36,7 +37,9 @@ export interface AudioToAudioOutput {
  * Example model: speechbrain/sepformer-wham does audio source separation.
  */
 export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
+	const model = "inputs" in args ? args.model : undefined;
+	const provider = await resolveProvider(args.provider, model);
+	const providerHelper = getProviderHelper(provider, "audio-to-audio");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<AudioToAudioOutput>(payload, providerHelper, {
 		...options,
packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
@@ -1,4 +1,5 @@
 import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import { FAL_AI_SUPPORTED_BLOB_TYPES } from "../../providers/fal-ai";
@@ -18,7 +19,8 @@ export async function automaticSpeechRecognition(
 	args: AutomaticSpeechRecognitionArgs,
 	options?: Options
 ): Promise<AutomaticSpeechRecognitionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
 	const payload = await buildPayload(args);
 	const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
 		...options,
3 changes: 2 additions & 1 deletion packages/inference/src/tasks/audio/textToSpeech.ts
@@ -1,4 +1,5 @@
 import type { TextToSpeechInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,7 @@ interface OutputUrlTextToSpeechGeneration {
  * Recommended model: espnet/kan-bayashi_ljspeech_vits
  */
 export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
-	const provider = args.provider ?? "hf-inference";
+	const provider = await resolveProvider(args.provider, args.model);
 	const providerHelper = getProviderHelper(provider, "text-to-speech");
 	const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, providerHelper, {
 		...options,
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/custom/request.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { InferenceTask, Options, RequestArgs } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -16,7 +17,8 @@ export async function request<T>(
 	console.warn(
 		"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
 	);
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, options?.task);
 	const result = await innerRequest<T>(args, providerHelper, options);
 	return result.data;
 }
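
Since request() is deprecated in favor of the task functions, a migration looks like this (sketch; the model id is illustrative):

// Before (deprecated):
const out = await request<TextClassificationOutput>(
	{ model: "distilbert-base-uncased-finetuned-sst-2-english", inputs: "Great movie!" },
	{ task: "text-classification" }
);

// After: the same provider resolution, including "auto", happens inside the task function
const res = await textClassification({
	model: "distilbert-base-uncased-finetuned-sst-2-english",
	inputs: "Great movie!",
});
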
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/custom/streamingRequest.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { InferenceTask, Options, RequestArgs } from "../../types";
 import { innerStreamingRequest } from "../../utils/request";
@@ -16,6 +17,7 @@ export async function* streamingRequest<T>(
 	console.warn(
 		"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
 	);
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, options?.task);
 	yield* innerStreamingRequest(args, providerHelper, options);
 }
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/cv/imageClassification.ts
@@ -1,4 +1,5 @@
 import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -14,7 +15,8 @@ export async function imageClassification(
 	args: ImageClassificationArgs,
 	options?: Options
 ): Promise<ImageClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "image-classification");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<ImageClassificationOutput>(payload, providerHelper, {
 		...options,
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/cv/imageSegmentation.ts
@@ -1,4 +1,5 @@
 import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -14,7 +15,8 @@ export async function imageSegmentation(
 	args: ImageSegmentationArgs,
 	options?: Options
): Promise<ImageSegmentationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "image-segmentation");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, providerHelper, {
 		...options,
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/cv/imageToImage.ts
@@ -1,4 +1,5 @@
 import type { ImageToImageInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -11,7 +12,8 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
  * Recommended model: lllyasviel/sd-controlnet-depth
  */
 export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "image-to-image");
 	let reqArgs: RequestArgs;
 	if (!args.parameters) {
 		reqArgs = {
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/cv/imageToText.ts
@@ -1,4 +1,5 @@
 import type { ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
  * This task reads some image input and outputs the text caption.
  */
 export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "image-to-text");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, providerHelper, {
 		...options,
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/cv/objectDetection.ts
@@ -1,4 +1,5 @@
 import type { ObjectDetectionInput, ObjectDetectionOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -11,7 +12,8 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImage
  * Recommended model: facebook/detr-resnet-50
  */
 export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "object-detection");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<ObjectDetectionOutput>(payload, providerHelper, {
 		...options,
3 changes: 2 additions & 1 deletion packages/inference/src/tasks/cv/textToImage.ts
@@ -1,4 +1,5 @@
 import type { TextToImageInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import { makeRequestOptions } from "../../lib/makeRequestOptions";
 import type { BaseArgs, Options } from "../../types";
@@ -23,7 +24,7 @@ export async function textToImage(
 	options?: TextToImageOptions & { outputType?: undefined | "blob" }
 ): Promise<Blob>;
 export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise<Blob | string> {
-	const provider = args.provider ?? "hf-inference";
+	const provider = await resolveProvider(args.provider, args.model);
 	const providerHelper = getProviderHelper(provider, "text-to-image");
 	const { data: res } = await innerRequest<Record<string, unknown>>(args, providerHelper, {
 		...options,
3 changes: 2 additions & 1 deletion packages/inference/src/tasks/cv/textToVideo.ts
@@ -1,4 +1,5 @@
 import type { TextToVideoInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import { makeRequestOptions } from "../../lib/makeRequestOptions";
 import type { FalAiQueueOutput } from "../../providers/fal-ai";
@@ -12,7 +13,7 @@ export type TextToVideoArgs = BaseArgs & TextToVideoInput;
 export type TextToVideoOutput = Blob;
 
 export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
-	const provider = args.provider ?? "hf-inference";
+	const provider = await resolveProvider(args.provider, args.model);
 	const providerHelper = getProviderHelper(provider, "text-to-video");
 	const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(
 		args,
packages/inference/src/tasks/cv/zeroShotImageClassification.ts
@@ -1,4 +1,5 @@
 import type { ZeroShotImageClassificationInput, ZeroShotImageClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -44,7 +45,8 @@ export async function zeroShotImageClassification(
 	args: ZeroShotImageClassificationArgs,
 	options?: Options
 ): Promise<ZeroShotImageClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "zero-shot-image-classification");
 	const payload = await preparePayload(args);
 	const { data: res } = await innerRequest<ZeroShotImageClassificationOutput>(payload, providerHelper, {
 		...options,
packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts
@@ -3,6 +3,7 @@ import type {
 	DocumentQuestionAnsweringInputData,
 	DocumentQuestionAnsweringOutput,
 } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -19,7 +20,8 @@ export async function documentQuestionAnswering(
 	args: DocumentQuestionAnsweringArgs,
 	options?: Options
 ): Promise<DocumentQuestionAnsweringOutput[number]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "document-question-answering");
 	const reqArgs: RequestArgs = {
 		...args,
 		inputs: {
packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts
@@ -3,6 +3,7 @@ import type {
 	VisualQuestionAnsweringInputData,
 	VisualQuestionAnsweringOutput,
 } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -19,7 +20,8 @@ export async function visualQuestionAnswering(
 	args: VisualQuestionAnsweringArgs,
 	options?: Options
 ): Promise<VisualQuestionAnsweringOutput[number]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "visual-question-answering");
 	const reqArgs: RequestArgs = {
 		...args,
 		inputs: {
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/nlp/chatCompletion.ts
@@ -1,4 +1,5 @@
 import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export async function chatCompletion(
 	args: BaseArgs & ChatCompletionInput,
 	options?: Options
 ): Promise<ChatCompletionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "conversational");
 	const { data: response } = await innerRequest<ChatCompletionOutput>(args, providerHelper, {
 		...options,
 		task: "conversational",
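
With the widened provider type (see types.ts below), callers can pass the "auto" policy explicitly as well as any concrete provider. Illustrative sketch:

import { chatCompletion } from "@huggingface/inference";

const out = await chatCompletion({
	accessToken: "hf_...",
	model: "meta-llama/Llama-3.1-8B-Instruct", // example model id
	provider: "auto", // resolved to the first provider available for this model
	messages: [{ role: "user", content: "Hello!" }],
});
console.log(out.choices[0].message);
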
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/nlp/chatCompletionStream.ts
@@ -1,4 +1,5 @@
 import type { ChatCompletionInput, ChatCompletionStreamOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerStreamingRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export async function* chatCompletionStream(
 	args: BaseArgs & ChatCompletionInput,
 	options?: Options
 ): AsyncGenerator<ChatCompletionStreamOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "conversational");
 	yield* innerStreamingRequest<ChatCompletionStreamOutput>(args, providerHelper, {
 		...options,
 		task: "conversational",
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/nlp/featureExtraction.ts
@@ -1,4 +1,5 @@
 import type { FeatureExtractionInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -22,7 +23,8 @@ export async function featureExtraction(
 	args: FeatureExtractionArgs,
 	options?: Options
 ): Promise<FeatureExtractionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "feature-extraction");
 	const { data: res } = await innerRequest<FeatureExtractionOutput>(args, providerHelper, {
 		...options,
 		task: "feature-extraction",
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/nlp/fillMask.ts
@@ -1,4 +1,5 @@
 import type { FillMaskInput, FillMaskOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -9,7 +10,8 @@ export type FillMaskArgs = BaseArgs & FillMaskInput;
  * Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
  */
 export async function fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "fill-mask");
 	const { data: res } = await innerRequest<FillMaskOutput>(args, providerHelper, {
 		...options,
 		task: "fill-mask",
5 changes: 4 additions & 1 deletion packages/inference/src/tasks/nlp/questionAnswering.ts
@@ -1,4 +1,6 @@
 import type { QuestionAnsweringInput, QuestionAnsweringOutput } from "@huggingface/tasks";
+
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +14,8 @@ export async function questionAnswering(
 	args: QuestionAnsweringArgs,
 	options?: Options
 ): Promise<QuestionAnsweringOutput[number]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "question-answering");
 	const { data: res } = await innerRequest<QuestionAnsweringOutput | QuestionAnsweringOutput[number]>(
 		args,
 		providerHelper,
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/nlp/sentenceSimilarity.ts
@@ -1,4 +1,5 @@
 import type { SentenceSimilarityInput, SentenceSimilarityOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function sentenceSimilarity(
 	args: SentenceSimilarityArgs,
 	options?: Options
 ): Promise<SentenceSimilarityOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "sentence-similarity");
 	const { data: res } = await innerRequest<SentenceSimilarityOutput>(args, providerHelper, {
 		...options,
 		task: "sentence-similarity",
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/nlp/summarization.ts
@@ -1,4 +1,5 @@
 import type { SummarizationInput, SummarizationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -9,7 +10,8 @@ export type SummarizationArgs = BaseArgs & SummarizationInput;
  * This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
  */
 export async function summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "summarization");
 	const { data: res } = await innerRequest<SummarizationOutput[]>(args, providerHelper, {
 		...options,
 		task: "summarization",
packages/inference/src/tasks/nlp/tableQuestionAnswering.ts
@@ -1,4 +1,5 @@
 import type { TableQuestionAnsweringInput, TableQuestionAnsweringOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function tableQuestionAnswering(
 	args: TableQuestionAnsweringArgs,
 	options?: Options
 ): Promise<TableQuestionAnsweringOutput[number]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "table-question-answering");
 	const { data: res } = await innerRequest<TableQuestionAnsweringOutput | TableQuestionAnsweringOutput[number]>(
 		args,
 		providerHelper,
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/nlp/textClassification.ts
@@ -1,4 +1,5 @@
 import type { TextClassificationInput, TextClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function textClassification(
 	args: TextClassificationArgs,
 	options?: Options
 ): Promise<TextClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "text-classification");
 	const { data: res } = await innerRequest<TextClassificationOutput>(args, providerHelper, {
 		...options,
 		task: "text-classification",
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/nlp/textGeneration.ts
@@ -1,4 +1,5 @@
 import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { HyperbolicTextCompletionOutput } from "../../providers/hyperbolic";
 import type { BaseArgs, Options } from "../../types";
@@ -13,7 +14,8 @@ export async function textGeneration(
 	args: BaseArgs & TextGenerationInput,
 	options?: Options
 ): Promise<TextGenerationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "text-generation");
 	const { data: response } = await innerRequest<
 		HyperbolicTextCompletionOutput | TextGenerationOutput | TextGenerationOutput[]
 	>(args, providerHelper, {
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/nlp/textGenerationStream.ts
@@ -1,4 +1,5 @@
 import type { TextGenerationInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerStreamingRequest } from "../../utils/request";
@@ -90,7 +91,8 @@ export async function* textGenerationStream(
 	args: BaseArgs & TextGenerationInput,
 	options?: Options
 ): AsyncGenerator<TextGenerationStreamOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "text-generation");
 	yield* innerStreamingRequest<TextGenerationStreamOutput>(args, providerHelper, {
 		...options,
 		task: "text-generation",
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/nlp/tokenClassification.ts
@@ -1,4 +1,5 @@
 import type { TokenClassificationInput, TokenClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function tokenClassification(
 	args: TokenClassificationArgs,
 	options?: Options
 ): Promise<TokenClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "token-classification");
 	const { data: res } = await innerRequest<TokenClassificationOutput[number] | TokenClassificationOutput>(
 		args,
 		providerHelper,
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/nlp/translation.ts
@@ -1,4 +1,5 @@
 import type { TranslationInput, TranslationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -8,7 +9,8 @@ export type TranslationArgs = BaseArgs & TranslationInput;
  * This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
  */
 export async function translation(args: TranslationArgs, options?: Options): Promise<TranslationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "translation");
 	const { data: res } = await innerRequest<TranslationOutput>(args, providerHelper, {
 		...options,
 		task: "translation",
packages/inference/src/tasks/nlp/zeroShotClassification.ts
@@ -1,4 +1,5 @@
 import type { ZeroShotClassificationInput, ZeroShotClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function zeroShotClassification(
 	args: ZeroShotClassificationArgs,
 	options?: Options
 ): Promise<ZeroShotClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "zero-shot-classification");
 	const { data: res } = await innerRequest<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(
 		args,
 		providerHelper,
packages/inference/src/tasks/tabular/tabularClassification.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -25,7 +26,8 @@ export async function tabularClassification(
 	args: TabularClassificationArgs,
 	options?: Options
 ): Promise<TabularClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "tabular-classification");
 	const { data: res } = await innerRequest<TabularClassificationOutput>(args, providerHelper, {
 		...options,
 		task: "tabular-classification",
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/tabular/tabularRegression.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -25,7 +26,8 @@ export async function tabularRegression(
 	args: TabularRegressionArgs,
 	options?: Options
 ): Promise<TabularRegressionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "tabular-regression");
 	const { data: res } = await innerRequest<TabularRegressionOutput>(args, providerHelper, {
 		...options,
 		task: "tabular-regression",
8 changes: 6 additions & 2 deletions packages/inference/src/types.ts
@@ -57,8 +57,12 @@ export const INFERENCE_PROVIDERS = [
 	"together",
 ] as const;
 
+export const PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"] as const;
+
 export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
 
+export type InferenceProviderOrPolicy = (typeof PROVIDERS_OR_POLICIES)[number];
+
 export interface BaseArgs {
 	/**
 	 * The access token to use. Without it, you'll get rate-limited quickly.
@@ -89,9 +93,9 @@ export interface BaseArgs {
 	/**
 	 * Set an Inference provider to run this model on.
 	 *
-	 * Defaults to the first provider in your user settings that is compatible with this model.
+	 * Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
 	 */
-	provider?: InferenceProvider;
+	provider?: InferenceProviderOrPolicy;
 }
 
 export type RequestArgs = BaseArgs &
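
A quick type-level illustration of the widened union (sketch):

const a: InferenceProviderOrPolicy = "together"; // any concrete provider still type-checks
const b: InferenceProviderOrPolicy = "auto"; // the new policy value
// @ts-expect-error "gpt4" is neither a provider nor a policy
const c: InferenceProviderOrPolicy = "gpt4";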