Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ Each model entry also offers `cheapest` and `fastest` mode for each model. `fast
* VS Code 1.104.0 or higher.
* Hugging Face access token with `inference.serverless` permissions.

## Organization Billing
If you want to bill inference requests to a Hugging Face organization, set the `huggingface.billTo` setting to your org name. The extension will forward it as `X-HF-Bill-To` on inference requests.

## 🛠️ Development
```bash
git clone https://github.com/huggingface/huggingface-vscode-chat
Expand Down
12 changes: 11 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,17 @@
"command": "huggingface.manage",
"title": "Manage Hugging Face Provider"
}
]
],
"configuration": {
"title": "Hugging Face Provider",
"properties": {
"huggingface.billTo": {
"type": "string",
"default": "",
"description": "Optional organization name to bill Hugging Face inference requests to."
}
}
}
},
"main": "./out/extension.js",
"scripts": {
Expand Down
26 changes: 22 additions & 4 deletions src/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export function activate(context: vscode.ExtensionContext) {
// Register the Hugging Face provider under the vendor id used in package.json
vscode.lm.registerLanguageModelChatProvider("huggingface", provider);

// Management command to configure API key
// Management command to configure API key and optional organization billing
context.subscriptions.push(
vscode.commands.registerCommand("huggingface.manage", async () => {
const existing = await context.secrets.get("huggingface.apiKey");
Expand All @@ -30,10 +30,28 @@ export function activate(context: vscode.ExtensionContext) {
if (!apiKey.trim()) {
await context.secrets.delete("huggingface.apiKey");
vscode.window.showInformationMessage("Hugging Face API key cleared.");
return;
} else {
await context.secrets.store("huggingface.apiKey", apiKey.trim());
vscode.window.showInformationMessage("Hugging Face API key saved.");
}

const config = vscode.workspace.getConfiguration("huggingface");
const existingBillTo = config.get<string>("billTo") ?? "";
const billTo = await vscode.window.showInputBox({
title: "Hugging Face Organization Billing",
prompt: "Optional org name to bill inference requests to",
ignoreFocusOut: true,
value: existingBillTo,
});
if (billTo === undefined) {
return; // user canceled
}
await config.update("billTo", billTo.trim(), vscode.ConfigurationTarget.Global);
if (!billTo.trim()) {
vscode.window.showInformationMessage("Hugging Face organization billing cleared.");
} else {
vscode.window.showInformationMessage("Hugging Face organization billing saved.");
}
await context.secrets.store("huggingface.apiKey", apiKey.trim());
vscode.window.showInformationMessage("Hugging Face API key saved.");
})
);
}
Expand Down
137 changes: 94 additions & 43 deletions src/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,55 @@ const BASE_URL = "https://router.huggingface.co/v1";
const DEFAULT_MAX_OUTPUT_TOKENS = 16000;
const DEFAULT_CONTEXT_LENGTH = 128000;

type HFStreamReader = {
read(): Promise<{ done: boolean; value?: Uint8Array }>;
releaseLock(): void;
};

type HFStream = {
getReader(): HFStreamReader;
};

type HFResponse = {
ok: boolean;
status: number;
statusText: string;
text(): Promise<string>;
json(): Promise<unknown>;
body: HFStream | null;
};

const runtimeGlobals = globalThis as unknown as {
fetch?: (input: string, init?: { method?: string; headers?: Record<string, string>; body?: string }) => Promise<unknown>;
TextDecoder?: new () => { decode(input?: Uint8Array, options?: { stream?: boolean }): string };
console?: { error?: (...args: unknown[]) => void };
};

const hfFetch = async (
input: string,
init?: { method?: string; headers?: Record<string, string>; body?: string },
): Promise<HFResponse> => {
if (!runtimeGlobals.fetch) {
throw new Error("Fetch API is not available in this runtime.");
}
return (await runtimeGlobals.fetch(input, init)) as HFResponse;
};

const createTextDecoder = (): { decode(input?: Uint8Array, options?: { stream?: boolean }): string } => {
if (!runtimeGlobals.TextDecoder) {
throw new Error("TextDecoder is not available in this runtime.");
}
return new runtimeGlobals.TextDecoder();
};

const logError = (...args: unknown[]): void => {
runtimeGlobals.console?.error?.(...args);
};

/**
* VS Code Chat provider backed by Hugging Face Inference Providers.
*/
export class HuggingFaceChatModelProvider implements LanguageModelChatProvider {
private _chatEndpoints: { model: string; modelMaxPromptTokens: number }[] = [];
/** Buffer for assembling streamed tool calls by index. */
private _toolCallBuffers: Map<number, { id?: string; name?: string; args: string }> = new Map<
number,
Expand Down Expand Up @@ -80,6 +124,13 @@ export class HuggingFaceChatModelProvider implements LanguageModelChatProvider {
}
}

/** Optional organization to bill requests to, as configured by the user. */
private getBillTo(): string | undefined {
const billTo = vscode.workspace.getConfiguration("huggingface").get<string>("billTo");
const trimmed = billTo?.trim();
return trimmed ? trimmed : undefined;
}

/**
* Get the list of available language models contributed by this provider
* @param options Options which specify the calling context of this function
Expand Down Expand Up @@ -181,11 +232,6 @@ export class HuggingFaceChatModelProvider implements LanguageModelChatProvider {
return entries;
});

this._chatEndpoints = infos.map((info) => ({
model: info.id,
modelMaxPromptTokens: info.maxInputTokens + info.maxOutputTokens,
}));

return infos;
}

Expand All @@ -204,7 +250,7 @@ export class HuggingFaceChatModelProvider implements LanguageModelChatProvider {
apiKey: string
): Promise<{ models: HFModelItem[] }> {
const modelsList = (async () => {
const resp = await fetch(`${BASE_URL}/models`, {
const resp = await hfFetch(`${BASE_URL}/models`, {
method: "GET",
headers: { Authorization: `Bearer ${apiKey}`, "User-Agent": this.userAgent },
});
Expand All @@ -213,12 +259,12 @@ export class HuggingFaceChatModelProvider implements LanguageModelChatProvider {
try {
text = await resp.text();
} catch (error) {
console.error("[Hugging Face Model Provider] Failed to read response text", error);
logError("[Hugging Face Model Provider] Failed to read response text", error);
}
const err = new Error(
`Failed to fetch Hugging Face models: ${resp.status} ${resp.statusText}${text ? `\n${text}` : ""}`
);
console.error("[Hugging Face Model Provider] Failed to fetch Hugging Face models", err);
logError("[Hugging Face Model Provider] Failed to fetch Hugging Face models", err);
throw err;
}
const parsed = (await resp.json()) as HFModelsResponse;
Expand All @@ -229,7 +275,7 @@ export class HuggingFaceChatModelProvider implements LanguageModelChatProvider {
const models = await modelsList;
return { models };
} catch (err) {
console.error("[Hugging Face Model Provider] Failed to fetch Hugging Face models", err);
logError("[Hugging Face Model Provider] Failed to fetch Hugging Face models", err);
throw err;
}
}
Expand Down Expand Up @@ -268,7 +314,7 @@ export class HuggingFaceChatModelProvider implements LanguageModelChatProvider {
try {
progress.report(part);
} catch (e) {
console.error("[Hugging Face Model Provider] Progress.report failed", {
logError("[Hugging Face Model Provider] Progress.report failed", {
modelId: model.id,
error: e instanceof Error ? { name: e.name, message: e.message } : String(e),
});
Expand All @@ -280,6 +326,7 @@ export class HuggingFaceChatModelProvider implements LanguageModelChatProvider {
if (!apiKey) {
throw new Error("Hugging Face API key not found");
}
const billTo = this.getBillTo();

const openaiMessages = convertMessages(messages);

Expand All @@ -291,13 +338,13 @@ export class HuggingFaceChatModelProvider implements LanguageModelChatProvider {
throw new Error("Cannot have more than 128 tools per request.");
}

const inputTokenCount = this.estimateMessagesTokens(messages);
const toolTokenCount = this.estimateToolTokens(toolConfig.tools);
const tokenLimit = Math.max(1, model.maxInputTokens);
if (inputTokenCount + toolTokenCount > tokenLimit) {
console.error("[Hugging Face Model Provider] Message exceeds token limit", { total: inputTokenCount + toolTokenCount, tokenLimit });
throw new Error("Message exceeds token limit.");
}
const inputTokenCount = this.estimateMessagesTokens(messages);
const toolTokenCount = this.estimateToolTokens(toolConfig.tools);
const tokenLimit = Math.max(1, model.maxInputTokens);
if (inputTokenCount + toolTokenCount > tokenLimit) {
logError("[Hugging Face Model Provider] Message exceeds token limit", { total: inputTokenCount + toolTokenCount, tokenLimit });
throw new Error("Message exceeds token limit.");
}

requestBody = {
model: model.id,
Expand Down Expand Up @@ -327,19 +374,23 @@ export class HuggingFaceChatModelProvider implements LanguageModelChatProvider {
if (toolConfig.tool_choice) {
(requestBody as Record<string, unknown>).tool_choice = toolConfig.tool_choice;
}
const response = await fetch(`${BASE_URL}/chat/completions`, {
method: "POST",
headers: {
Authorization: `Bearer ${apiKey}`,
"Content-Type": "application/json",
"User-Agent": this.userAgent,
},
body: JSON.stringify(requestBody),
});
const headers: Record<string, string> = {
Authorization: `Bearer ${apiKey}`,
"Content-Type": "application/json",
"User-Agent": this.userAgent,
};
if (billTo) {
headers["X-HF-Bill-To"] = billTo;
}
const response = await hfFetch(`${BASE_URL}/chat/completions`, {
method: "POST",
headers,
body: JSON.stringify(requestBody),
});

if (!response.ok) {
const errorText = await response.text();
console.error("[Hugging Face Model Provider] HF API error response", errorText);
logError("[Hugging Face Model Provider] HF API error response", errorText);
throw new Error(
`Hugging Face API error: ${response.status} ${response.statusText}${errorText ? `\n${errorText}` : ""}`
);
Expand All @@ -350,7 +401,7 @@ export class HuggingFaceChatModelProvider implements LanguageModelChatProvider {
}
await this.processStreamingResponse(response.body, trackingProgress, token);
} catch (err) {
console.error("[Hugging Face Model Provider] Chat request failed", {
logError("[Hugging Face Model Provider] Chat request failed", {
modelId: model.id,
messageCount: messages.length,
error: err instanceof Error ? { name: err.name, message: err.message } : String(err),
Expand All @@ -367,7 +418,7 @@ export class HuggingFaceChatModelProvider implements LanguageModelChatProvider {
* @returns A promise that resolves to the number of tokens
*/
async provideTokenCount(
model: LanguageModelChatInformation,
_model: LanguageModelChatInformation,
text: string | LanguageModelChatMessage,
_token: CancellationToken
): Promise<number> {
Expand Down Expand Up @@ -411,14 +462,14 @@ export class HuggingFaceChatModelProvider implements LanguageModelChatProvider {
* @param progress Progress reporter for streamed parts.
* @param token Cancellation token.
*/
private async processStreamingResponse(
responseBody: ReadableStream<Uint8Array>,
progress: vscode.Progress<vscode.LanguageModelResponsePart>,
token: vscode.CancellationToken,
): Promise<void> {
const reader = responseBody.getReader();
const decoder = new TextDecoder();
let buffer = "";
private async processStreamingResponse(
responseBody: HFStream,
progress: vscode.Progress<vscode.LanguageModelResponsePart>,
token: vscode.CancellationToken,
): Promise<void> {
const reader = responseBody.getReader();
const decoder = createTextDecoder();
let buffer = "";

try {
while (!token.isCancellationRequested) {
Expand Down Expand Up @@ -779,11 +830,11 @@ export class HuggingFaceChatModelProvider implements LanguageModelChatProvider {
}
for (const [idx, buf] of Array.from(this._toolCallBuffers.entries())) {
const parsed = tryParseJSONObject(buf.args);
if (!parsed.ok) {
if (throwOnInvalid) {
console.error("[Hugging Face Model Provider] Invalid JSON for tool call", { idx, snippet: (buf.args || "").slice(0, 200) });
throw new Error("Invalid JSON for tool call");
}
if (!parsed.ok) {
if (throwOnInvalid) {
logError("[Hugging Face Model Provider] Invalid JSON for tool call", { idx, snippet: (buf.args || "").slice(0, 200) });
throw new Error("Invalid JSON for tool call");
}
// When not throwing (e.g. on [DONE]), drop silently to reduce noise
continue;
}
Expand Down
Loading