Skip to content

Commit eaa1b9c

Browse files
[Inference Providers] provider="auto" (#1390)
Same as huggingface/huggingface_hub#3011. This PR adds support for auto selection of the provider. Previously the default value was `hf-inference` (HF Inference API provider), now we default to "auto", meaning we will select the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. You can test with: ```ts import { chatCompletion } from "../src"; const res = await chatCompletion({ // provider="auto", model: "deepseek-ai/DeepSeek-V3-0324", messages: [ { role: "user", content: "What is the capital of France?", }, ], accessToken: process.env.HF_TOKEN, }); console.log(res.choices[0].message.content); ``` ``` Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. Auto-selected provider: sambanova The capital of France is **Paris**. It is known for its iconic landmarks such as the Eiffel Tower...blabla ``` the selected provider should be the first in `inferenceProviderMapping` mapping here: https://huggingface.co/api/models/deepseek-ai/DeepSeek-V3-0324?expand=inferenceProviderMapping
1 parent 9d2c7b8 commit eaa1b9c

36 files changed

+163
-57
lines changed

packages/inference/src/lib/getInferenceProviderMapping.ts

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import type { WidgetType } from "@huggingface/tasks";
2-
import type { InferenceProvider, ModelId } from "../types";
32
import { HF_HUB_URL } from "../config";
43
import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts";
54
import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference";
5+
import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../types";
66
import { typedInclude } from "../utils/typedInclude";
77

88
export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
@@ -20,44 +20,62 @@ export interface InferenceProviderModelMapping {
2020
task: WidgetType;
2121
}
2222

23-
export async function getInferenceProviderMapping(
24-
params: {
25-
accessToken?: string;
26-
modelId: ModelId;
27-
provider: InferenceProvider;
28-
task: WidgetType;
29-
},
30-
options: {
23+
export async function fetchInferenceProviderMappingForModel(
24+
modelId: ModelId,
25+
accessToken?: string,
26+
options?: {
3127
fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
3228
}
33-
): Promise<InferenceProviderModelMapping | null> {
34-
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
35-
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
36-
}
29+
): Promise<InferenceProviderMapping> {
3730
let inferenceProviderMapping: InferenceProviderMapping | null;
38-
if (inferenceProviderMappingCache.has(params.modelId)) {
31+
if (inferenceProviderMappingCache.has(modelId)) {
3932
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
40-
inferenceProviderMapping = inferenceProviderMappingCache.get(params.modelId)!;
33+
inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
4134
} else {
4235
const resp = await (options?.fetch ?? fetch)(
43-
`${HF_HUB_URL}/api/models/${params.modelId}?expand[]=inferenceProviderMapping`,
36+
`${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
4437
{
45-
headers: params.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${params.accessToken}` } : {},
38+
headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
4639
}
4740
);
4841
if (resp.status === 404) {
49-
throw new Error(`Model ${params.modelId} does not exist`);
42+
throw new Error(`Model ${modelId} does not exist`);
5043
}
5144
inferenceProviderMapping = await resp
5245
.json()
5346
.then((json) => json.inferenceProviderMapping)
5447
.catch(() => null);
48+
49+
if (inferenceProviderMapping) {
50+
inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
51+
}
5552
}
5653

5754
if (!inferenceProviderMapping) {
58-
throw new Error(`We have not been able to find inference provider information for model ${params.modelId}.`);
55+
throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
5956
}
57+
return inferenceProviderMapping;
58+
}
6059

60+
export async function getInferenceProviderMapping(
61+
params: {
62+
accessToken?: string;
63+
modelId: ModelId;
64+
provider: InferenceProvider;
65+
task: WidgetType;
66+
},
67+
options: {
68+
fetch?: (input: RequestInfo, init?: RequestInit) => Promise<Response>;
69+
}
70+
): Promise<InferenceProviderModelMapping | null> {
71+
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
72+
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
73+
}
74+
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(
75+
params.modelId,
76+
params.accessToken,
77+
options
78+
);
6179
const providerMapping = inferenceProviderMapping[params.provider];
6280
if (providerMapping) {
6381
const equivalentTasks =
@@ -78,3 +96,23 @@ export async function getInferenceProviderMapping(
7896
}
7997
return null;
8098
}
99+
100+
export async function resolveProvider(
101+
provider?: InferenceProviderOrPolicy,
102+
modelId?: string
103+
): Promise<InferenceProvider> {
104+
if (!provider) {
105+
console.log(
106+
"Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
107+
);
108+
provider = "auto";
109+
}
110+
if (provider === "auto") {
111+
if (!modelId) {
112+
throw new Error("Specifying a model is required when provider is 'auto'");
113+
}
114+
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
115+
provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider;
116+
}
117+
return provider;
118+
}

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ export async function makeRequestOptions(
2727
task?: InferenceTask;
2828
}
2929
): Promise<{ url: string; info: RequestInit }> {
30-
const { provider: maybeProvider, model: maybeModel } = args;
31-
const provider = maybeProvider ?? "hf-inference";
30+
const { model: maybeModel } = args;
31+
const provider = providerHelper.provider;
3232
const { task } = options ?? {};
3333

3434
// Validate inputs
@@ -113,8 +113,9 @@ export function makeRequestOptionsFromResolvedModel(
113113
): { url: string; info: RequestInit } {
114114
const { accessToken, endpointUrl, provider: maybeProvider, model, ...remainingArgs } = args;
115115
void model;
116+
void maybeProvider;
116117

117-
const provider = maybeProvider ?? "hf-inference";
118+
const provider = providerHelper.provider;
118119

119120
const { includeCredentials, task, signal, billTo } = options ?? {};
120121
const authMethod = (() => {

packages/inference/src/providers/providerHelper.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ import { toArray } from "../utils/toArray";
5656
*/
5757
export abstract class TaskProviderHelper {
5858
constructor(
59-
private provider: InferenceProvider,
59+
readonly provider: InferenceProvider,
6060
private baseUrl: string,
6161
readonly clientSideRoutingOnly: boolean = false
6262
) {}

packages/inference/src/tasks/audio/audioClassification.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks";
2+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
23
import { getProviderHelper } from "../../lib/getProviderHelper";
34
import type { BaseArgs, Options } from "../../types";
45
import { innerRequest } from "../../utils/request";
@@ -15,7 +16,8 @@ export async function audioClassification(
1516
args: AudioClassificationArgs,
1617
options?: Options
1718
): Promise<AudioClassificationOutput> {
18-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
19+
const provider = await resolveProvider(args.provider, args.model);
20+
const providerHelper = getProviderHelper(provider, "audio-classification");
1921
const payload = preparePayload(args);
2022
const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {
2123
...options,

packages/inference/src/tasks/audio/audioToAudio.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
12
import { getProviderHelper } from "../../lib/getProviderHelper";
23
import type { BaseArgs, Options } from "../../types";
34
import { innerRequest } from "../../utils/request";
@@ -36,7 +37,9 @@ export interface AudioToAudioOutput {
3637
* Example model: speechbrain/sepformer-wham does audio source separation.
3738
*/
3839
export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
39-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
40+
const model = "inputs" in args ? args.model : undefined;
41+
const provider = await resolveProvider(args.provider, model);
42+
const providerHelper = getProviderHelper(provider, "audio-to-audio");
4043
const payload = preparePayload(args);
4144
const { data: res } = await innerRequest<AudioToAudioOutput>(payload, providerHelper, {
4245
...options,

packages/inference/src/tasks/audio/automaticSpeechRecognition.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
2+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
23
import { getProviderHelper } from "../../lib/getProviderHelper";
34
import { InferenceOutputError } from "../../lib/InferenceOutputError";
45
import { FAL_AI_SUPPORTED_BLOB_TYPES } from "../../providers/fal-ai";
@@ -18,7 +19,8 @@ export async function automaticSpeechRecognition(
1819
args: AutomaticSpeechRecognitionArgs,
1920
options?: Options
2021
): Promise<AutomaticSpeechRecognitionOutput> {
21-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
22+
const provider = await resolveProvider(args.provider, args.model);
23+
const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
2224
const payload = await buildPayload(args);
2325
const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
2426
...options,

packages/inference/src/tasks/audio/textToSpeech.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { TextToSpeechInput } from "@huggingface/tasks";
2+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
23
import { getProviderHelper } from "../../lib/getProviderHelper";
34
import type { BaseArgs, Options } from "../../types";
45
import { innerRequest } from "../../utils/request";
@@ -12,7 +13,7 @@ interface OutputUrlTextToSpeechGeneration {
1213
* Recommended model: espnet/kan-bayashi_ljspeech_vits
1314
*/
1415
export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
15-
const provider = args.provider ?? "hf-inference";
16+
const provider = await resolveProvider(args.provider, args.model);
1617
const providerHelper = getProviderHelper(provider, "text-to-speech");
1718
const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, providerHelper, {
1819
...options,

packages/inference/src/tasks/custom/request.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
12
import { getProviderHelper } from "../../lib/getProviderHelper";
23
import type { InferenceTask, Options, RequestArgs } from "../../types";
34
import { innerRequest } from "../../utils/request";
@@ -16,7 +17,8 @@ export async function request<T>(
1617
console.warn(
1718
"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
1819
);
19-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
20+
const provider = await resolveProvider(args.provider, args.model);
21+
const providerHelper = getProviderHelper(provider, options?.task);
2022
const result = await innerRequest<T>(args, providerHelper, options);
2123
return result.data;
2224
}

packages/inference/src/tasks/custom/streamingRequest.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
12
import { getProviderHelper } from "../../lib/getProviderHelper";
23
import type { InferenceTask, Options, RequestArgs } from "../../types";
34
import { innerStreamingRequest } from "../../utils/request";
@@ -16,6 +17,7 @@ export async function* streamingRequest<T>(
1617
console.warn(
1718
"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
1819
);
19-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
20+
const provider = await resolveProvider(args.provider, args.model);
21+
const providerHelper = getProviderHelper(provider, options?.task);
2022
yield* innerStreamingRequest(args, providerHelper, options);
2123
}

packages/inference/src/tasks/cv/imageClassification.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks";
2+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
23
import { getProviderHelper } from "../../lib/getProviderHelper";
34
import type { BaseArgs, Options } from "../../types";
45
import { innerRequest } from "../../utils/request";
@@ -14,7 +15,8 @@ export async function imageClassification(
1415
args: ImageClassificationArgs,
1516
options?: Options
1617
): Promise<ImageClassificationOutput> {
17-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
18+
const provider = await resolveProvider(args.provider, args.model);
19+
const providerHelper = getProviderHelper(provider, "image-classification");
1820
const payload = preparePayload(args);
1921
const { data: res } = await innerRequest<ImageClassificationOutput>(payload, providerHelper, {
2022
...options,

packages/inference/src/tasks/cv/imageSegmentation.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks";
2+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
23
import { getProviderHelper } from "../../lib/getProviderHelper";
34
import type { BaseArgs, Options } from "../../types";
45
import { innerRequest } from "../../utils/request";
@@ -14,7 +15,8 @@ export async function imageSegmentation(
1415
args: ImageSegmentationArgs,
1516
options?: Options
1617
): Promise<ImageSegmentationOutput> {
17-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
18+
const provider = await resolveProvider(args.provider, args.model);
19+
const providerHelper = getProviderHelper(provider, "image-segmentation");
1820
const payload = preparePayload(args);
1921
const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, providerHelper, {
2022
...options,

packages/inference/src/tasks/cv/imageToImage.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { ImageToImageInput } from "@huggingface/tasks";
2+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
23
import { getProviderHelper } from "../../lib/getProviderHelper";
34
import type { BaseArgs, Options, RequestArgs } from "../../types";
45
import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -11,7 +12,8 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
1112
* Recommended model: lllyasviel/sd-controlnet-depth
1213
*/
1314
export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
14-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-image");
15+
const provider = await resolveProvider(args.provider, args.model);
16+
const providerHelper = getProviderHelper(provider, "image-to-image");
1517
let reqArgs: RequestArgs;
1618
if (!args.parameters) {
1719
reqArgs = {

packages/inference/src/tasks/cv/imageToText.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { ImageToTextInput, ImageToTextOutput } from "@huggingface/tasks";
2+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
23
import { getProviderHelper } from "../../lib/getProviderHelper";
34
import type { BaseArgs, Options } from "../../types";
45
import { innerRequest } from "../../utils/request";
@@ -10,7 +11,8 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
1011
* This task reads some image input and outputs the text caption.
1112
*/
1213
export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
13-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
14+
const provider = await resolveProvider(args.provider, args.model);
15+
const providerHelper = getProviderHelper(provider, "image-to-text");
1416
const payload = preparePayload(args);
1517
const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, providerHelper, {
1618
...options,

packages/inference/src/tasks/cv/objectDetection.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { ObjectDetectionInput, ObjectDetectionOutput } from "@huggingface/tasks";
2+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
23
import { getProviderHelper } from "../../lib/getProviderHelper";
34
import type { BaseArgs, Options } from "../../types";
45
import { innerRequest } from "../../utils/request";
@@ -11,7 +12,8 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImage
1112
* Recommended model: facebook/detr-resnet-50
1213
*/
1314
export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
14-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
15+
const provider = await resolveProvider(args.provider, args.model);
16+
const providerHelper = getProviderHelper(provider, "object-detection");
1517
const payload = preparePayload(args);
1618
const { data: res } = await innerRequest<ObjectDetectionOutput>(payload, providerHelper, {
1719
...options,

packages/inference/src/tasks/cv/textToImage.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { TextToImageInput } from "@huggingface/tasks";
2+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
23
import { getProviderHelper } from "../../lib/getProviderHelper";
34
import { makeRequestOptions } from "../../lib/makeRequestOptions";
45
import type { BaseArgs, Options } from "../../types";
@@ -23,7 +24,7 @@ export async function textToImage(
2324
options?: TextToImageOptions & { outputType?: undefined | "blob" }
2425
): Promise<Blob>;
2526
export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise<Blob | string> {
26-
const provider = args.provider ?? "hf-inference";
27+
const provider = await resolveProvider(args.provider, args.model);
2728
const providerHelper = getProviderHelper(provider, "text-to-image");
2829
const { data: res } = await innerRequest<Record<string, unknown>>(args, providerHelper, {
2930
...options,

packages/inference/src/tasks/cv/textToVideo.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { TextToVideoInput } from "@huggingface/tasks";
2+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
23
import { getProviderHelper } from "../../lib/getProviderHelper";
34
import { makeRequestOptions } from "../../lib/makeRequestOptions";
45
import type { FalAiQueueOutput } from "../../providers/fal-ai";
@@ -12,7 +13,7 @@ export type TextToVideoArgs = BaseArgs & TextToVideoInput;
1213
export type TextToVideoOutput = Blob;
1314

1415
export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
15-
const provider = args.provider ?? "hf-inference";
16+
const provider = await resolveProvider(args.provider, args.model);
1617
const providerHelper = getProviderHelper(provider, "text-to-video");
1718
const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(
1819
args,

packages/inference/src/tasks/cv/zeroShotImageClassification.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { ZeroShotImageClassificationInput, ZeroShotImageClassificationOutput } from "@huggingface/tasks";
2+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
23
import { getProviderHelper } from "../../lib/getProviderHelper";
34
import type { BaseArgs, Options, RequestArgs } from "../../types";
45
import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -44,7 +45,8 @@ export async function zeroShotImageClassification(
4445
args: ZeroShotImageClassificationArgs,
4546
options?: Options
4647
): Promise<ZeroShotImageClassificationOutput> {
47-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
48+
const provider = await resolveProvider(args.provider, args.model);
49+
const providerHelper = getProviderHelper(provider, "zero-shot-image-classification");
4850
const payload = await preparePayload(args);
4951
const { data: res } = await innerRequest<ZeroShotImageClassificationOutput>(payload, providerHelper, {
5052
...options,

packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import type {
33
DocumentQuestionAnsweringInputData,
44
DocumentQuestionAnsweringOutput,
55
} from "@huggingface/tasks";
6+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
67
import { getProviderHelper } from "../../lib/getProviderHelper";
78
import type { BaseArgs, Options, RequestArgs } from "../../types";
89
import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -19,7 +20,8 @@ export async function documentQuestionAnswering(
1920
args: DocumentQuestionAnsweringArgs,
2021
options?: Options
2122
): Promise<DocumentQuestionAnsweringOutput[number]> {
22-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "document-question-answering");
23+
const provider = await resolveProvider(args.provider, args.model);
24+
const providerHelper = getProviderHelper(provider, "document-question-answering");
2325
const reqArgs: RequestArgs = {
2426
...args,
2527
inputs: {

packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import type {
33
VisualQuestionAnsweringInputData,
44
VisualQuestionAnsweringOutput,
55
} from "@huggingface/tasks";
6+
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
67
import { getProviderHelper } from "../../lib/getProviderHelper";
78
import type { BaseArgs, Options, RequestArgs } from "../../types";
89
import { base64FromBytes } from "../../utils/base64FromBytes";
@@ -19,7 +20,8 @@ export async function visualQuestionAnswering(
1920
args: VisualQuestionAnsweringArgs,
2021
options?: Options
2122
): Promise<VisualQuestionAnsweringOutput[number]> {
22-
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "visual-question-answering");
23+
const provider = await resolveProvider(args.provider, args.model);
24+
const providerHelper = getProviderHelper(provider, "visual-question-answering");
2325
const reqArgs: RequestArgs = {
2426
...args,
2527
inputs: {

0 commit comments

Comments
 (0)