From 6ddcd1881dcafefbff6a8981b1af6c9642f9ed7f Mon Sep 17 00:00:00 2001
From: Celina Hanouti <hanouticelina@gmail.com>
Date: Fri, 25 Apr 2025 17:57:56 +0200
Subject: [PATCH 1/6] auto select provider

---
 .../src/lib/getInferenceProviderMapping.ts    | 76 ++++++++++++++-----
 .../inference/src/lib/makeRequestOptions.ts   |  7 +-
 .../inference/src/providers/providerHelper.ts |  2 +-
 .../src/tasks/audio/audioClassification.ts    |  4 +-
 .../inference/src/tasks/audio/audioToAudio.ts |  5 +-
 .../tasks/audio/automaticSpeechRecognition.ts |  4 +-
 .../inference/src/tasks/audio/textToSpeech.ts |  3 +-
 .../inference/src/tasks/custom/request.ts     |  4 +-
 .../src/tasks/custom/streamingRequest.ts      |  4 +-
 .../src/tasks/cv/imageClassification.ts       |  4 +-
 .../src/tasks/cv/imageSegmentation.ts         |  4 +-
 .../inference/src/tasks/cv/imageToImage.ts    |  4 +-
 .../inference/src/tasks/cv/imageToText.ts     |  4 +-
 .../inference/src/tasks/cv/objectDetection.ts |  4 +-
 .../inference/src/tasks/cv/textToImage.ts     |  3 +-
 .../inference/src/tasks/cv/textToVideo.ts     |  3 +-
 .../tasks/cv/zeroShotImageClassification.ts   |  4 +-
 .../multimodal/documentQuestionAnswering.ts   |  4 +-
 .../multimodal/visualQuestionAnswering.ts     |  4 +-
 .../inference/src/tasks/nlp/chatCompletion.ts |  4 +-
 .../src/tasks/nlp/chatCompletionStream.ts     |  4 +-
 .../src/tasks/nlp/featureExtraction.ts        |  4 +-
 packages/inference/src/tasks/nlp/fillMask.ts  |  4 +-
 .../src/tasks/nlp/questionAnswering.ts        |  5 +-
 .../src/tasks/nlp/sentenceSimilarity.ts       |  4 +-
 .../inference/src/tasks/nlp/summarization.ts  |  4 +-
 .../src/tasks/nlp/tableQuestionAnswering.ts   |  4 +-
 .../src/tasks/nlp/textClassification.ts       |  4 +-
 .../inference/src/tasks/nlp/textGeneration.ts |  4 +-
 .../src/tasks/nlp/textGenerationStream.ts     |  4 +-
 .../src/tasks/nlp/tokenClassification.ts      |  4 +-
 .../inference/src/tasks/nlp/translation.ts    |  4 +-
 .../src/tasks/nlp/zeroShotClassification.ts   |  4 +-
 .../tasks/tabular/tabularClassification.ts    |  4 +-
 .../src/tasks/tabular/tabularRegression.ts    |  4 +-
 packages/inference/src/types.ts               |  6 +-
 36 files changed, 162 insertions(+), 56 deletions(-)

diff --git a/packages/inference/src/lib/getInferenceProviderMapping.ts b/packages/inference/src/lib/getInferenceProviderMapping.ts
index d897cb72a0..08a8ee5a8b 100644
--- a/packages/inference/src/lib/getInferenceProviderMapping.ts
+++ b/packages/inference/src/lib/getInferenceProviderMapping.ts
@@ -1,8 +1,8 @@
 import type { WidgetType } from "@huggingface/tasks";
-import type { InferenceProvider, ModelId } from "../types";
 import { HF_HUB_URL } from "../config";
 import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts";
 import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference";
+import type { InferenceProvider, InferenceProviderPolicy, ModelId } from "../types";
 import { typedInclude } from "../utils/typedInclude";
 
 export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
@@ -20,44 +20,62 @@ export interface InferenceProviderModelMapping {
 	task: WidgetType;
 }
 
-export async function getInferenceProviderMapping(
-	params: {
-		accessToken?: string;
-		modelId: ModelId;
-		provider: InferenceProvider;
-		task: WidgetType;
-	},
-	options: {
+export async function fetchInferenceProviderMappingForModel(
+	modelId: ModelId,
+	accessToken?: string,
+	options?: {
 		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
 	}
-): Promise<InferenceProviderModelMapping | null> {
-	if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
-		return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
-	}
+): Promise<InferenceProviderMapping> {
 	let inferenceProviderMapping: InferenceProviderMapping | null;
-	if (inferenceProviderMappingCache.has(params.modelId)) {
+	if (inferenceProviderMappingCache.has(modelId)) {
 		// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-		inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId)!;
+		inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
 	} else {
 		const resp = await (options?.fetch ?? fetch)(
-			`${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
+			`${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
 			{
-				headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {},
+				headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
 			}
 		);
 		if (resp.status === 404) {
-			throw new Error(`Model ${params.modelId} does not exist`);
+			throw new Error(`Model ${modelId} does not exist`);
 		}
 		inferenceProviderMapping = await resp
 			.json()
 			.then((json) => json.inferenceProviderMapping)
 			.catch(() => null);
+
+		if (inferenceProviderMapping) {
+			inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
+		}
 	}
 	if (!inferenceProviderMapping) {
-		throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
+		throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
 	}
+	return inferenceProviderMapping;
+}
+
+export async function getInferenceProviderMapping(
+	params: {
+		accessToken?: string;
+		modelId: ModelId;
+		provider: InferenceProvider;
+		task: WidgetType;
+	},
+	options: {
+		fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
+	}
+): Promise<InferenceProviderModelMapping | null> {
+	if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
+		return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
+	}
+	const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
+		params.modelId,
+		params.accessToken,
+		options
+	);
 	const providerMapping = inferenceProviderMapping[params.provider];
 	if (providerMapping) {
 		const equivalentTasks =
@@ -94,3 +112,23 @@ export async function getInferenceProviderMapping(
 	}
 	return null;
 }
+
+export async function resolveProvider(
+	provider?: InferenceProviderPolicy,
+	modelId?: string
+): Promise<InferenceProvider> {
+	if (!provider && !modelId) {
+		provider = "hf-inference";
+	}
+	if (!provider) {
+		provider = "auto";
+	}
+	if (provider === "auto") {
+		if (!modelId) {
+			throw new Error("Specifying a model is required when provider is 'auto'");
+		}
+		const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
+		provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider;
+	}
+	return provider;
+}
diff --git a/packages/inference/src/lib/makeRequestOptions.ts b/packages/inference/src/lib/makeRequestOptions.ts
index 06db7d3693..08f1059cc7 100644
--- a/packages/inference/src/lib/makeRequestOptions.ts
+++ b/packages/inference/src/lib/makeRequestOptions.ts
@@ -27,8 +27,8 @@ export async function makeRequestOptions(
 		task?: InferenceTask;
 	}
 ): Promise<{ url: string; info: RequestInit }> {
-	const { provider: maybeProvider, model: maybeModel } = args;
-	const provider = maybeProvider ?? "hf-inference";
+	const { model: maybeModel } = args;
+	const provider = providerHelper.provider;
 	const { task } = options ?? {};
 
 	// Validate inputs
@@ -113,8 +113,9 @@ export function makeRequestOptionsFromResolvedModel(
 ): { url: string; info: RequestInit } {
 	const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
 	void model;
+	void maybeProvider;
 
-	const provider = maybeProvider ?? "hf-inference";
+	const provider = providerHelper.provider;
 	const { includeCredentials, task, signal, billTo } = options ?? {};
 
 	const authMethod = (() => {
diff --git a/packages/inference/src/providers/providerHelper.ts b/packages/inference/src/providers/providerHelper.ts
index a0da0f7c6f..994f3ab0ea 100644
--- a/packages/inference/src/providers/providerHelper.ts
+++ b/packages/inference/src/providers/providerHelper.ts
@@ -56,7 +56,7 @@ import { toArray } from "../utils/toArray";
  */
 export abstract class TaskProviderHelper {
 	constructor(
-		private provider: InferenceProvider,
+		readonly provider: InferenceProvider,
 		private baseUrl: string,
 		readonly clientSideRoutingOnly: boolean = false
 	) {}
diff --git a/packages/inference/src/tasks/audio/audioClassification.ts b/packages/inference/src/tasks/audio/audioClassification.ts
index f1ff1c20c7..bc495b0227 100644
--- a/packages/inference/src/tasks/audio/audioClassification.ts
+++ b/packages/inference/src/tasks/audio/audioClassification.ts
@@ -1,4 +1,5 @@
 import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -15,7 +16,8 @@ export async function audioClassification(
 	args: AudioClassificationArgs,
 	options?: Options
 ): Promise<AudioClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "audio-classification");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/audio/audioToAudio.ts b/packages/inference/src/tasks/audio/audioToAudio.ts
index fa055d3234..b4f37c1563 100644
--- a/packages/inference/src/tasks/audio/audioToAudio.ts
+++ b/packages/inference/src/tasks/audio/audioToAudio.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -36,7 +37,9 @@ export interface AudioToAudioOutput {
  * Example model: speechbrain/sepformer-wham does audio source separation.
  */
 export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
+	const model = "inputs" in args ? args.model : undefined;
+	const provider = await resolveProvider(args.provider, model);
+	const providerHelper = getProviderHelper(provider, "audio-to-audio");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<AudioToAudioOutput>(payload, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts b/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
index c71af07427..347c01b833 100644
--- a/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
+++ b/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
@@ -1,4 +1,5 @@
 import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import { FAL_AI_SUPPORTED_BLOB_TYPES } from "../../providers/fal-ai";
@@ -18,7 +19,8 @@ export async function automaticSpeechRecognition(
 	args: AutomaticSpeechRecognitionArgs,
 	options?: Options
 ): Promise<AutomaticSpeechRecognitionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
 	const payload = await buildPayload(args);
 	const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/audio/textToSpeech.ts b/packages/inference/src/tasks/audio/textToSpeech.ts
index 11f01f436f..24694062fc 100644
--- a/packages/inference/src/tasks/audio/textToSpeech.ts
+++ b/packages/inference/src/tasks/audio/textToSpeech.ts
@@ -1,4 +1,5 @@
 import type { TextToSpeechInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,7 @@ interface OutputUrlTextToSpeechGeneration {
  * Recommended model: espnet/kan-bayashi_ljspeech_vits
 */
 export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
-	const provider = args.provider ?? "hf-inference";
+	const provider = await resolveProvider(args.provider, args.model);
 	const providerHelper = getProviderHelper(provider, "text-to-speech");
 	const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/custom/request.ts b/packages/inference/src/tasks/custom/request.ts
index 62f45f28b3..df103ad5e0 100644
--- a/packages/inference/src/tasks/custom/request.ts
+++ b/packages/inference/src/tasks/custom/request.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { InferenceTask, Options, RequestArgs } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -16,7 +17,8 @@ export async function request<T>(
 	console.warn(
 		"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
 	);
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, options?.task);
 	const result = await innerRequest<T>(args, providerHelper, options);
 	return result.data;
 }
diff --git a/packages/inference/src/tasks/custom/streamingRequest.ts b/packages/inference/src/tasks/custom/streamingRequest.ts
index 45ae99f323..8e45ad5535 100644
--- a/packages/inference/src/tasks/custom/streamingRequest.ts
+++ b/packages/inference/src/tasks/custom/streamingRequest.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { InferenceTask, Options, RequestArgs } from "../../types";
 import { innerStreamingRequest } from "../../utils/request";
@@ -16,6 +17,7 @@ export async function* streamingRequest<T>(
 	console.warn(
 		"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
 	);
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, options?.task);
 	yield* innerStreamingRequest(args, providerHelper, options);
 }
diff --git a/packages/inference/src/tasks/cv/imageClassification.ts b/packages/inference/src/tasks/cv/imageClassification.ts
index e683ecb3e2..a324faa1db 100644
--- a/packages/inference/src/tasks/cv/imageClassification.ts
+++ b/packages/inference/src/tasks/cv/imageClassification.ts
@@ -1,4 +1,5 @@
 import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -14,7 +15,8 @@ export async function imageClassification(
 	args: ImageClassificationArgs,
 	options?: Options
 ): Promise<ImageClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "image-classification");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<ImageClassificationOutput>(payload, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/cv/imageSegmentation.ts b/packages/inference/src/tasks/cv/imageSegmentation.ts
index 9de0e9a2ef..719fa684ec 100644
--- a/packages/inference/src/tasks/cv/imageSegmentation.ts
+++ b/packages/inference/src/tasks/cv/imageSegmentation.ts
@@ -1,4 +1,5 @@
 import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -14,7 +15,8 @@ export async function imageSegmentation(
 	args: ImageSegmentationArgs,
 	options?: Options
 ): Promise<ImageSegmentationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "image-segmentation");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/cv/imageToImage.ts b/packages/inference/src/tasks/cv/imageToImage.ts
index 49d8ca2be5..a0bab24bab 100644
--- a/packages/inference/src/tasks/cv/imageToImage.ts
+++ b/packages/inference/src/tasks/cv/imageToImage.ts
@@ -1,4 +1,5 @@
 import type { ImageToImageInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -11,7 +12,8 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
  * Recommended model: lllyasviel/sd-controlnet-depth
 */
 export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "image-to-image");
 	let reqArgs: RequestArgs;
 	if (!args.parameters) {
 		reqArgs = {
diff --git a/packages/inference/src/tasks/cv/imageToText.ts b/packages/inference/src/tasks/cv/imageToText.ts
index cdee706fa4..8d2d791893 100644
--- a/packages/inference/src/tasks/cv/imageToText.ts
+++ b/packages/inference/src/tasks/cv/imageToText.ts
@@ -1,4 +1,5 @@
 import type { ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
  * This task reads some image input and outputs the text caption.
 */
 export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "image-to-text");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/cv/objectDetection.ts b/packages/inference/src/tasks/cv/objectDetection.ts
index d103feeb96..d216be78eb 100644
--- a/packages/inference/src/tasks/cv/objectDetection.ts
+++ b/packages/inference/src/tasks/cv/objectDetection.ts
@@ -1,4 +1,5 @@
 import type { ObjectDetectionInput, ObjectDetectionOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -11,7 +12,8 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImageInput);
  * Recommended model: facebook/detr-resnet-50
 */
 export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "object-detection");
 	const payload = preparePayload(args);
 	const { data: res } = await innerRequest<ObjectDetectionOutput>(payload, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/cv/textToImage.ts b/packages/inference/src/tasks/cv/textToImage.ts
index 490a8e10b6..76544065df 100644
--- a/packages/inference/src/tasks/cv/textToImage.ts
+++ b/packages/inference/src/tasks/cv/textToImage.ts
@@ -1,4 +1,5 @@
 import type { TextToImageInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import { makeRequestOptions } from "../../lib/makeRequestOptions";
 import type { BaseArgs, Options } from "../../types";
@@ -23,7 +24,7 @@ export async function textToImage(
 	options?: TextToImageOptions & { outputType?: undefined | "blob" }
 ): Promise<Blob>;
 export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise<Blob | string> {
-	const provider = args.provider ?? "hf-inference";
+	const provider = await resolveProvider(args.provider, args.model);
 	const providerHelper = getProviderHelper(provider, "text-to-image");
 	const { data: res } = await innerRequest<Record<string, unknown>>(args, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/cv/textToVideo.ts b/packages/inference/src/tasks/cv/textToVideo.ts
index 9143e147a8..798799d61b 100644
--- a/packages/inference/src/tasks/cv/textToVideo.ts
+++ b/packages/inference/src/tasks/cv/textToVideo.ts
@@ -1,4 +1,5 @@
 import type { TextToVideoInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import { makeRequestOptions } from "../../lib/makeRequestOptions";
 import type { FalAiQueueOutput } from "../../providers/fal-ai";
@@ -12,7 +13,7 @@ export type TextToVideoArgs = BaseArgs & TextToVideoInput;
 export type TextToVideoOutput = Blob;
 
 export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
-	const provider = args.provider ?? "hf-inference";
+	const provider = await resolveProvider(args.provider, args.model);
 	const providerHelper = getProviderHelper(provider, "text-to-video");
 	const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(
 		args,
diff --git a/packages/inference/src/tasks/cv/zeroShotImageClassification.ts b/packages/inference/src/tasks/cv/zeroShotImageClassification.ts
index 8f09e74aac..e75b439853 100644
--- a/packages/inference/src/tasks/cv/zeroShotImageClassification.ts
+++ b/packages/inference/src/tasks/cv/zeroShotImageClassification.ts
@@ -1,4 +1,5 @@
 import type { ZeroShotImageClassificationInput, ZeroShotImageClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -44,7 +45,8 @@ export async function zeroShotImageClassification(
 	args: ZeroShotImageClassificationArgs,
 	options?: Options
 ): Promise<ZeroShotImageClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "zero-shot-image-classification");
 	const payload = await preparePayload(args);
 	const { data: res } = await innerRequest<ZeroShotImageClassificationOutput>(payload, providerHelper, {
 		...options,
diff --git a/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts b/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts
index 033708dce0..c17234ee2d 100644
--- a/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts
+++ b/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts
@@ -3,6 +3,7 @@ import type {
 	DocumentQuestionAnsweringInputData,
 	DocumentQuestionAnsweringOutput,
 } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -19,7 +20,8 @@ export async function documentQuestionAnswering(
 	args: DocumentQuestionAnsweringArgs,
 	options?: Options
 ): Promise<DocumentQuestionAnsweringOutput[number]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "document-question-answering");
 	const reqArgs: RequestArgs = {
 		...args,
 		inputs: {
diff --git a/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts b/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts
index 98a6616206..62e2d9e9da 100644
--- a/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts
+++ b/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts
@@ -3,6 +3,7 @@ import type {
 	VisualQuestionAnsweringInputData,
 	VisualQuestionAnsweringOutput,
 } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -19,7 +20,8 @@ export async function visualQuestionAnswering(
 	args: VisualQuestionAnsweringArgs,
 	options?: Options
 ): Promise<VisualQuestionAnsweringOutput[number]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "visual-question-answering");
 	const reqArgs: RequestArgs = {
 		...args,
 		inputs: {
diff --git a/packages/inference/src/tasks/nlp/chatCompletion.ts b/packages/inference/src/tasks/nlp/chatCompletion.ts
index 4ad9be5f1c..6ad4da6aea 100644
--- a/packages/inference/src/tasks/nlp/chatCompletion.ts
+++ b/packages/inference/src/tasks/nlp/chatCompletion.ts
@@ -1,4 +1,5 @@
 import type { ChatCompletionInput, ChatCompletionOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export async function chatCompletion(
 	args: BaseArgs & ChatCompletionInput,
 	options?: Options
 ): Promise<ChatCompletionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "conversational");
 	const { data: response } = await innerRequest<ChatCompletionOutput>(args, providerHelper, {
 		...options,
 		task: "conversational",
diff --git a/packages/inference/src/tasks/nlp/chatCompletionStream.ts b/packages/inference/src/tasks/nlp/chatCompletionStream.ts
index cf88044112..260b9374ff 100644
--- a/packages/inference/src/tasks/nlp/chatCompletionStream.ts
+++ b/packages/inference/src/tasks/nlp/chatCompletionStream.ts
@@ -1,4 +1,5 @@
 import type { ChatCompletionInput, ChatCompletionStreamOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerStreamingRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export async function* chatCompletionStream(
 	args: BaseArgs & ChatCompletionInput,
 	options?: Options
 ): AsyncGenerator<ChatCompletionStreamOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "conversational");
 	yield* innerStreamingRequest<ChatCompletionStreamOutput>(args, providerHelper, {
 		...options,
 		task: "conversational",
diff --git a/packages/inference/src/tasks/nlp/featureExtraction.ts b/packages/inference/src/tasks/nlp/featureExtraction.ts
index 5d0fe93ded..30fc3a6766 100644
--- a/packages/inference/src/tasks/nlp/featureExtraction.ts
+++ b/packages/inference/src/tasks/nlp/featureExtraction.ts
@@ -1,4 +1,5 @@
 import type { FeatureExtractionInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -22,7 +23,8 @@ export async function featureExtraction(
 	args: FeatureExtractionArgs,
 	options?: Options
 ): Promise<FeatureExtractionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "feature-extraction");
 	const { data: res } = await innerRequest<FeatureExtractionOutput>(args, providerHelper, {
 		...options,
 		task: "feature-extraction",
diff --git a/packages/inference/src/tasks/nlp/fillMask.ts b/packages/inference/src/tasks/nlp/fillMask.ts
index 663db87d99..8a992f938d 100644
--- a/packages/inference/src/tasks/nlp/fillMask.ts
+++ b/packages/inference/src/tasks/nlp/fillMask.ts
@@ -1,4 +1,5 @@
 import type { FillMaskInput, FillMaskOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -9,7 +10,8 @@ export type FillMaskArgs = BaseArgs & FillMaskInput;
  * Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
 */
 export async function fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "fill-mask");
 	const { data: res } = await innerRequest<FillMaskOutput>(args, providerHelper, {
 		...options,
 		task: "fill-mask",
diff --git a/packages/inference/src/tasks/nlp/questionAnswering.ts b/packages/inference/src/tasks/nlp/questionAnswering.ts
index 6559c80c1a..a6b3cc0570 100644
--- a/packages/inference/src/tasks/nlp/questionAnswering.ts
+++ b/packages/inference/src/tasks/nlp/questionAnswering.ts
@@ -1,4 +1,6 @@
 import type { QuestionAnsweringInput, QuestionAnsweringOutput } from "@huggingface/tasks";
+
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +14,8 @@ export async function questionAnswering(
 	args: QuestionAnsweringArgs,
 	options?: Options
 ): Promise<QuestionAnsweringOutput[number]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "question-answering");
 	const { data: res } = await innerRequest<QuestionAnsweringOutput | QuestionAnsweringOutput[number]>(
 		args,
 		providerHelper,
diff --git a/packages/inference/src/tasks/nlp/sentenceSimilarity.ts b/packages/inference/src/tasks/nlp/sentenceSimilarity.ts
index faa751f73e..f76b8c795c 100644
--- a/packages/inference/src/tasks/nlp/sentenceSimilarity.ts
+++ b/packages/inference/src/tasks/nlp/sentenceSimilarity.ts
@@ -1,4 +1,5 @@
 import type { SentenceSimilarityInput, SentenceSimilarityOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function sentenceSimilarity(
 	args: SentenceSimilarityArgs,
 	options?: Options
 ): Promise<SentenceSimilarityOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "sentence-similarity");
 	const { data: res } = await innerRequest<SentenceSimilarityOutput>(args, providerHelper, {
 		...options,
 		task: "sentence-similarity",
diff --git a/packages/inference/src/tasks/nlp/summarization.ts b/packages/inference/src/tasks/nlp/summarization.ts
index 4b4205bf4b..29cda2cd5b 100644
--- a/packages/inference/src/tasks/nlp/summarization.ts
+++ b/packages/inference/src/tasks/nlp/summarization.ts
@@ -1,4 +1,5 @@
 import type { SummarizationInput, SummarizationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -9,7 +10,8 @@ export type SummarizationArgs = BaseArgs & SummarizationInput;
  * This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
 */
 export async function summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "summarization");
 	const { data: res } = await innerRequest<SummarizationOutput[]>(args, providerHelper, {
 		...options,
 		task: "summarization",
diff --git a/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts b/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts
index 3939115862..03127a2fd8 100644
--- a/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts
+++ b/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts
@@ -1,4 +1,5 @@
 import type { TableQuestionAnsweringInput, TableQuestionAnsweringOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function tableQuestionAnswering(
 	args: TableQuestionAnsweringArgs,
 	options?: Options
 ): Promise<TableQuestionAnsweringOutput[number]> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "table-question-answering");
 	const { data: res } = await innerRequest<TableQuestionAnsweringOutput | TableQuestionAnsweringOutput[number]>(
 		args,
 		providerHelper,
diff --git a/packages/inference/src/tasks/nlp/textClassification.ts b/packages/inference/src/tasks/nlp/textClassification.ts
index 7631d82286..8440d81bff 100644
--- a/packages/inference/src/tasks/nlp/textClassification.ts
+++ b/packages/inference/src/tasks/nlp/textClassification.ts
@@ -1,4 +1,5 @@
 import type { TextClassificationInput, TextClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function textClassification(
 	args: TextClassificationArgs,
 	options?: Options
 ): Promise<TextClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "text-classification");
 	const { data: res } = await innerRequest<TextClassificationOutput>(args, providerHelper, {
 		...options,
 		task: "text-classification",
diff --git a/packages/inference/src/tasks/nlp/textGeneration.ts b/packages/inference/src/tasks/nlp/textGeneration.ts
index 5d84e543cb..a1e441e6bc 100644
--- a/packages/inference/src/tasks/nlp/textGeneration.ts
+++ b/packages/inference/src/tasks/nlp/textGeneration.ts
@@ -1,4 +1,5 @@
 import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { HyperbolicTextCompletionOutput } from "../../providers/hyperbolic";
 import type { BaseArgs, Options } from "../../types";
@@ -13,7 +14,8 @@ export async function textGeneration(
 	args: BaseArgs & TextGenerationInput,
 	options?: Options
 ): Promise<TextGenerationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "text-generation");
 	const { data: response } = await innerRequest<
 		HyperbolicTextCompletionOutput | TextGenerationOutput | TextGenerationOutput[]
 	>(args, providerHelper, {
diff --git a/packages/inference/src/tasks/nlp/textGenerationStream.ts b/packages/inference/src/tasks/nlp/textGenerationStream.ts
index 59706eaa5c..935e533a2a 100644
--- a/packages/inference/src/tasks/nlp/textGenerationStream.ts
+++ b/packages/inference/src/tasks/nlp/textGenerationStream.ts
@@ -1,4 +1,5 @@
 import type { TextGenerationInput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerStreamingRequest } from "../../utils/request";
@@ -90,7 +91,8 @@ export async function* textGenerationStream(
 	args: BaseArgs & TextGenerationInput,
 	options?: Options
 ): AsyncGenerator<TextGenerationStreamOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "text-generation");
 	yield* innerStreamingRequest<TextGenerationStreamOutput>(args, providerHelper, {
 		...options,
 		task: "text-generation",
diff --git a/packages/inference/src/tasks/nlp/tokenClassification.ts b/packages/inference/src/tasks/nlp/tokenClassification.ts
index 0c52b9e6a6..a822c7efcf 100644
--- a/packages/inference/src/tasks/nlp/tokenClassification.ts
+++ b/packages/inference/src/tasks/nlp/tokenClassification.ts
@@ -1,4 +1,5 @@
 import type { TokenClassificationInput, TokenClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function tokenClassification(
 	args: TokenClassificationArgs,
 	options?: Options
 ): Promise<TokenClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "token-classification");
 	const { data: res } = await innerRequest<TokenClassificationOutput[number] | TokenClassificationOutput>(
 		args,
 		providerHelper,
diff --git a/packages/inference/src/tasks/nlp/translation.ts b/packages/inference/src/tasks/nlp/translation.ts
index 1f463e3e67..1347527d87 100644
--- a/packages/inference/src/tasks/nlp/translation.ts
+++ b/packages/inference/src/tasks/nlp/translation.ts
@@ -1,4 +1,5 @@
 import type { TranslationInput, TranslationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -8,7 +9,8 @@ export type TranslationArgs = BaseArgs & TranslationInput;
  * This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
 */
 export async function translation(args: TranslationArgs, options?: Options): Promise<TranslationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "translation");
 	const { data: res } = await innerRequest<TranslationOutput>(args, providerHelper, {
 		...options,
 		task: "translation",
diff --git a/packages/inference/src/tasks/nlp/zeroShotClassification.ts b/packages/inference/src/tasks/nlp/zeroShotClassification.ts
index 30d6d0c156..83642749b1 100644
--- a/packages/inference/src/tasks/nlp/zeroShotClassification.ts
+++ b/packages/inference/src/tasks/nlp/zeroShotClassification.ts
@@ -1,4 +1,5 @@
 import type { ZeroShotClassificationInput, ZeroShotClassificationOutput } from "@huggingface/tasks";
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -12,7 +13,8 @@ export async function zeroShotClassification(
 	args: ZeroShotClassificationArgs,
 	options?: Options
 ): Promise<ZeroShotClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "zero-shot-classification");
 	const { data: res } = await innerRequest<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(
 		args,
 		providerHelper,
diff --git a/packages/inference/src/tasks/tabular/tabularClassification.ts b/packages/inference/src/tasks/tabular/tabularClassification.ts
index 9174c17718..5d615fdbed 100644
--- a/packages/inference/src/tasks/tabular/tabularClassification.ts
+++ b/packages/inference/src/tasks/tabular/tabularClassification.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -25,7 +26,8 @@ export async function tabularClassification(
 	args: TabularClassificationArgs,
 	options?: Options
 ): Promise<TabularClassificationOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "tabular-classification");
 	const { data: res } = await innerRequest<TabularClassificationOutput>(args, providerHelper, {
 		...options,
 		task: "tabular-classification",
diff --git a/packages/inference/src/tasks/tabular/tabularRegression.ts b/packages/inference/src/tasks/tabular/tabularRegression.ts
index 2c2408ffde..284db1c8a6 100644
--- a/packages/inference/src/tasks/tabular/tabularRegression.ts
+++ b/packages/inference/src/tasks/tabular/tabularRegression.ts
@@ -1,3 +1,4 @@
+import { resolveProvider } from "../../lib/getInferenceProviderMapping";
 import { getProviderHelper } from "../../lib/getProviderHelper";
 import type { BaseArgs, Options } from "../../types";
 import { innerRequest } from "../../utils/request";
@@ -25,7 +26,8 @@ export async function tabularRegression(
 	args: TabularRegressionArgs,
 	options?: Options
 ): Promise<TabularRegressionOutput> {
-	const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression");
+	const provider = await resolveProvider(args.provider, args.model);
+	const providerHelper = getProviderHelper(provider, "tabular-regression");
 	const { data: res } = await innerRequest<TabularRegressionOutput>(args, providerHelper, {
 		...options,
 		task: "tabular-regression",
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index e5870f6ef3..51ad16156f 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -56,8 +56,12 @@ export const INFERENCE_PROVIDERS = [
 	"together",
 ] as const;
 
+export const PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"] as const;
+
 export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
 
+export type InferenceProviderPolicy = (typeof PROVIDERS_OR_POLICIES)[number];
+
 export interface BaseArgs {
 	/**
 	 * The access token to use. Without it, you'll get rate-limited quickly.
@@ -90,7 +94,7 @@ export interface BaseArgs {
 	 *
 	 * Defaults to the first provider in your user settings that is compatible with this model.
 	 */
-	provider?: InferenceProvider;
+	provider?: InferenceProviderPolicy;
 }
 
 export type RequestArgs = BaseArgs &

From bab28f946f2964ae0358c2282f88e63be6f47f3d Mon Sep 17 00:00:00 2001
From: Celina Hanouti <hanouticelina@gmail.com>
Date: Fri, 25 Apr 2025 18:03:22 +0200
Subject: [PATCH 2/6] nit

---
 packages/inference/src/types.ts | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index 51ad16156f..f362a5647e 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -92,7 +92,7 @@ export interface BaseArgs {
 	/**
 	 * Set an Inference provider to run this model on.
 	 *
-	 * Defaults to the first provider in your user settings that is compatible with this model.
+	 * Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
 	 */
 	provider?: InferenceProviderPolicy;
 }

From c7383d335f27dcf36179d1e02f05a13e5cde237d Mon Sep 17 00:00:00 2001
From: Celina Hanouti <hanouticelina@gmail.com>
Date: Fri, 25 Apr 2025 18:07:10 +0200
Subject: [PATCH 3/6] logging

---
 packages/inference/src/lib/getInferenceProviderMapping.ts | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/packages/inference/src/lib/getInferenceProviderMapping.ts b/packages/inference/src/lib/getInferenceProviderMapping.ts
index 08a8ee5a8b..25418a50a5 100644
--- a/packages/inference/src/lib/getInferenceProviderMapping.ts
+++ b/packages/inference/src/lib/getInferenceProviderMapping.ts
@@ -121,6 +121,9 @@ export async function resolveProvider(
 		provider = "hf-inference";
 	}
 	if (!provider) {
+		console.log(
+			"Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
+		);
 		provider = "auto";
 	}
 	if (provider === "auto") {
@@ -129,6 +132,7 @@ export async function resolveProvider(
 		}
 		const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
 		provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider;
+		console.log("Auto-selected provider:", provider);
 	}
 	return provider;
 }

From 4a104793d28f90b183345c7d98c878219130bb90 Mon Sep 17 00:00:00 2001
From: Celina Hanouti <hanouticelina@gmail.com>
Date: Fri, 25 Apr 2025 18:21:43 +0200
Subject: [PATCH 4/6] fix

---
 packages/inference/src/lib/getInferenceProviderMapping.ts | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/packages/inference/src/lib/getInferenceProviderMapping.ts b/packages/inference/src/lib/getInferenceProviderMapping.ts
index 25418a50a5..46153a102f 100644
--- a/packages/inference/src/lib/getInferenceProviderMapping.ts
+++ b/packages/inference/src/lib/getInferenceProviderMapping.ts
@@ -117,9 +117,6 @@ export async function resolveProvider(
 	provider?: InferenceProviderPolicy,
 	modelId?: string
 ): Promise<InferenceProvider> {
-	if (!provider && !modelId) {
-		provider = "hf-inference";
-	}
 	if (!provider) {
 		console.log(
 			"Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."

From 38530d057e20a863ce23e166149e96a0ad95c78f Mon Sep 17 00:00:00 2001
From: Celina Hanouti <hanouticelina@gmail.com>
Date: Tue, 29 Apr 2025 15:39:26 +0200
Subject: [PATCH 5/6] rename InferenceProviderPolicy -> InferenceProviderOrPolicy

---
 packages/inference/src/lib/getInferenceProviderMapping.ts | 4 ++--
 packages/inference/src/types.ts                           | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/packages/inference/src/lib/getInferenceProviderMapping.ts b/packages/inference/src/lib/getInferenceProviderMapping.ts
index 3220188efe..96904d5d9c 100644
--- a/packages/inference/src/lib/getInferenceProviderMapping.ts
+++ b/packages/inference/src/lib/getInferenceProviderMapping.ts
@@ -2,7 +2,7 @@ import type { WidgetType } from "@huggingface/tasks";
 import { HF_HUB_URL } from "../config";
 import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts";
 import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference";
-import type { InferenceProvider, InferenceProviderPolicy, ModelId } from "../types";
+import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../types";
 import { typedInclude } from "../utils/typedInclude";
 
 export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
@@ -98,7 +98,7 @@ export async function getInferenceProviderMapping(
 }
 
 export async function resolveProvider(
-	provider?: InferenceProviderPolicy,
+	provider?: InferenceProviderOrPolicy,
 	modelId?: string
 ): Promise<InferenceProvider> {
 	if (!provider) {
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index f362a5647e..a81b1ddde5 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -60,7 +60,7 @@ export const PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"] as const;
 
 export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
 
-export type InferenceProviderPolicy = (typeof PROVIDERS_OR_POLICIES)[number];
+export type InferenceProviderOrPolicy = (typeof PROVIDERS_OR_POLICIES)[number];
 
 export interface BaseArgs {
 	/**
@@ -94,7 +94,7 @@ export interface BaseArgs {
 	 *
 	 * Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
 	 */
-	provider?: InferenceProviderPolicy;
+	provider?: InferenceProviderOrPolicy;
 }
 
 export type RequestArgs = BaseArgs &

From a4bce4f37c9894df6a0ac224282f628ec2eea566 Mon Sep 17 00:00:00 2001
From: Celina Hanouti <hanouticelina@gmail.com>
Date: Wed, 30 Apr 2025 11:12:46 +0200
Subject: [PATCH 6/6] remove unnecessary log

---
 packages/inference/src/lib/getInferenceProviderMapping.ts | 1 -
 1 file changed, 1 deletion(-)

diff --git a/packages/inference/src/lib/getInferenceProviderMapping.ts b/packages/inference/src/lib/getInferenceProviderMapping.ts
index 96904d5d9c..d751beae6d 100644
--- a/packages/inference/src/lib/getInferenceProviderMapping.ts
+++ b/packages/inference/src/lib/getInferenceProviderMapping.ts
@@ -113,7 +113,6 @@ export async function resolveProvider(
 		}
 		const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
 		provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider;
-		console.log("Auto-selected provider:", provider);
 	}
 	return provider;
 }
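Usage sketch (reviewer note, not part of the patch series): with these patches applied, callers can omit `provider` or pass "auto" and let the client resolve it from the model's inferenceProviderMapping on the Hub. The model id and access token below are placeholders, and this assumes the model has at least one live provider mapping.

import { chatCompletion } from "@huggingface/inference";

// No `provider` given: resolveProvider() treats it as "auto", fetches the
// model's inferenceProviderMapping from the Hub, and picks the first provider
// listed (sorted by the user's order in
// https://hf.co/settings/inference-providers).
const output = await chatCompletion({
	accessToken: "hf_...", // placeholder token
	model: "meta-llama/Llama-3.1-8B-Instruct", // example model id
	messages: [{ role: "user", content: "Hello!" }],
});

// Equivalent call with the policy spelled out:
const sameOutput = await chatCompletion({
	accessToken: "hf_...",
	provider: "auto",
	model: "meta-llama/Llama-3.1-8B-Instruct",
	messages: [{ role: "user", content: "Hello!" }],
});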