Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 93 additions & 10 deletions src/runtime/oauth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import type { OAuthSession } from '../oauth.js';
import { isUnauthorizedError } from '../runtime-oauth-support.js';

export const DEFAULT_OAUTH_CODE_TIMEOUT_MS = 60_000;
const OAUTH_FLOW_ERROR = Symbol('oauth-flow-error');
const POST_AUTH_CONNECT_ERROR = Symbol('post-auth-connect-error');

export class OAuthTimeoutError extends Error {
public readonly timeoutMs: number;
Expand All @@ -19,6 +21,43 @@ export class OAuthTimeoutError extends Error {
}
}

export function markOAuthFlowError(error: unknown): unknown {
return markError(error, OAUTH_FLOW_ERROR);
}

export function isOAuthFlowError(error: unknown): boolean {
return hasErrorMarker(error, OAUTH_FLOW_ERROR);
}

export function markPostAuthConnectError(error: unknown): unknown {
return markError(error, POST_AUTH_CONNECT_ERROR);
}

export function isPostAuthConnectError(error: unknown): boolean {
return hasErrorMarker(error, POST_AUTH_CONNECT_ERROR);
}

function markError(error: unknown, marker: symbol): unknown {
if (!error || (typeof error !== 'object' && typeof error !== 'function')) {
return error;
}
Object.defineProperty(error, marker, {
value: true,
enumerable: false,
configurable: true,
});
return error;
}

function hasErrorMarker(error: unknown, marker: symbol): boolean {
return (
!!error &&
(typeof error === 'object' || typeof error === 'function') &&
marker in error &&
Boolean((error as Record<PropertyKey, unknown>)[marker])
);
}

export async function connectWithAuth(
client: Client,
transport: Transport & {
Expand All @@ -27,21 +66,58 @@ export async function connectWithAuth(
},
session: OAuthSession | undefined,
logger: Logger,
options: { serverName?: string; maxAttempts?: number; oauthTimeoutMs?: number } = {}
): Promise<void> {
const { serverName, maxAttempts = 3, oauthTimeoutMs = DEFAULT_OAUTH_CODE_TIMEOUT_MS } = options;
options: {
serverName?: string;
maxAttempts?: number;
oauthTimeoutMs?: number;
recreateTransport?: (
transport: Transport & {
close(): Promise<void>;
finishAuth?: (authorizationCode: string) => Promise<void>;
}
) => Promise<
Transport & {
close(): Promise<void>;
finishAuth?: (authorizationCode: string) => Promise<void>;
}
>;
} = {}
): Promise<
Transport & {
close(): Promise<void>;
finishAuth?: (authorizationCode: string) => Promise<void>;
}
> {
const { serverName, maxAttempts = 3, oauthTimeoutMs = DEFAULT_OAUTH_CODE_TIMEOUT_MS, recreateTransport } = options;
let activeTransport = transport;
let attempt = 0;
let hasCompletedAuthFlow = false;

const closeReplacementTransport = async (): Promise<void> => {
if (activeTransport === transport) {
return;
}
await activeTransport.close().catch(() => {});
};

while (true) {
try {
await client.connect(transport);
return;
await client.connect(activeTransport);
return activeTransport;
} catch (error) {
if (!isUnauthorizedError(error) || !session) {
const unauthorized = isUnauthorizedError(error);
if (hasCompletedAuthFlow && !unauthorized) {
await closeReplacementTransport();
throw markPostAuthConnectError(error);
}
if (!unauthorized || !session) {
await closeReplacementTransport();
throw error;
}
attempt += 1;
if (attempt > maxAttempts) {
throw error;
await closeReplacementTransport();
throw hasCompletedAuthFlow ? markPostAuthConnectError(error) : error;
}
logger.warn(`OAuth authorization required for '${serverName ?? 'unknown'}'. Waiting for browser approval...`);
try {
Expand All @@ -51,16 +127,23 @@ export async function connectWithAuth(
serverName,
oauthTimeoutMs ?? DEFAULT_OAUTH_CODE_TIMEOUT_MS
);
if (typeof transport.finishAuth === 'function') {
await transport.finishAuth(code);
if (typeof activeTransport.finishAuth === 'function') {
await activeTransport.finishAuth(code);
if (recreateTransport) {
const nextTransport = await recreateTransport(activeTransport);
await activeTransport.close().catch(() => {});
activeTransport = nextTransport;
}
hasCompletedAuthFlow = true;
logger.info('Authorization code accepted. Retrying connection...');
} else {
logger.warn('Transport does not support finishAuth; cannot complete OAuth flow automatically.');
throw error;
}
} catch (authError) {
logger.error('OAuth authorization failed while waiting for callback.', authError);
throw authError;
await closeReplacementTransport();
throw markOAuthFlowError(authError);
}
}
}
Expand Down
63 changes: 51 additions & 12 deletions src/runtime/transport.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,53 @@
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { StreamableHTTPClientTransport, StreamableHTTPError } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import type { ServerDefinition } from '../config.js';
import { resolveEnvValue, withEnvOverrides } from '../env.js';
import { analyzeConnectionError } from '../error-classifier.js';
import type { Logger } from '../logging.js';
import { createOAuthSession, type OAuthSession } from '../oauth.js';
import { readCachedAccessToken } from '../oauth-persistence.js';
import { materializeHeaders } from '../runtime-header-utils.js';
import { isUnauthorizedError, maybeEnableOAuth } from '../runtime-oauth-support.js';
import { closeTransportAndWait } from '../runtime-process-utils.js';
import { connectWithAuth, OAuthTimeoutError } from './oauth.js';
import { connectWithAuth, isOAuthFlowError, isPostAuthConnectError, OAuthTimeoutError } from './oauth.js';
import { resolveCommandArgument, resolveCommandArguments } from './utils.js';

const STDIO_TRACE_ENABLED = process.env.MCPORTER_STDIO_TRACE === '1';

function extractTransportStatusCode(error: unknown): number | undefined {
if (!error || typeof error !== 'object') {
return undefined;
}
const record = error as Record<string, unknown>;
for (const candidate of [record.code, record.status, record.statusCode]) {
if (typeof candidate === 'number') {
return candidate;
}
if (typeof candidate === 'string') {
const parsed = Number.parseInt(candidate, 10);
if (Number.isFinite(parsed)) {
return parsed;
}
}
}
return undefined;
}

function isLegacySseTransportMismatch(error: unknown): boolean {
if (error instanceof StreamableHTTPError) {
return error.code === 404 || error.code === 405;
}
const directStatusCode = extractTransportStatusCode(error);
if (directStatusCode === 404 || directStatusCode === 405) {
return true;
}
const issue = analyzeConnectionError(error);
return issue.kind === 'http' && (issue.statusCode === 404 || issue.statusCode === 405);
}

function attachStdioTraceLogging(_transport: StdioClientTransport, _label?: string): void {
// STDIO instrumentation is handled via sdk-patches side effects. This helper remains
// so runtime callers can opt-in without sprinkling conditional checks everywhere.
Expand Down Expand Up @@ -125,13 +157,15 @@ export async function createClientContext(
};

const attemptConnect = async () => {
const streamableTransport = new StreamableHTTPClientTransport(command.url, baseOptions);
const createStreamableTransport = () => new StreamableHTTPClientTransport(command.url, baseOptions);
let streamableTransport = createStreamableTransport();
try {
await connectWithAuth(client, streamableTransport, oauthSession, logger, {
streamableTransport = (await connectWithAuth(client, streamableTransport, oauthSession, logger, {
serverName: activeDefinition.name,
maxAttempts: options.maxOAuthAttempts,
oauthTimeoutMs: options.oauthTimeoutMs,
});
recreateTransport: async () => createStreamableTransport(),
})) as StreamableHTTPClientTransport;
return {
client,
transport: streamableTransport,
Expand All @@ -147,6 +181,15 @@ export async function createClientContext(
try {
return await attemptConnect();
} catch (primaryError) {
if (isPostAuthConnectError(primaryError)) {
if (!isLegacySseTransportMismatch(primaryError)) {
await oauthSession?.close().catch(() => {});
throw primaryError;
}
} else if (isOAuthFlowError(primaryError) || primaryError instanceof OAuthTimeoutError) {
await oauthSession?.close().catch(() => {});
throw primaryError;
}
if (isUnauthorizedError(primaryError)) {
await oauthSession?.close().catch(() => {});
oauthSession = undefined;
Expand All @@ -159,23 +202,19 @@ export async function createClientContext(
}
}
}
if (primaryError instanceof OAuthTimeoutError) {
await oauthSession?.close().catch(() => {});
throw primaryError;
}
if (primaryError instanceof Error) {
logger.info(`Falling back to SSE transport for '${activeDefinition.name}': ${primaryError.message}`);
}
const sseTransport = new SSEClientTransport(command.url, {
...baseOptions,
});
try {
await connectWithAuth(client, sseTransport, oauthSession, logger, {
const connectedTransport = (await connectWithAuth(client, sseTransport, oauthSession, logger, {
serverName: activeDefinition.name,
maxAttempts: options.maxOAuthAttempts,
oauthTimeoutMs: options.oauthTimeoutMs,
});
return { client, transport: sseTransport, definition: activeDefinition, oauthSession };
})) as SSEClientTransport;
return { client, transport: connectedTransport, definition: activeDefinition, oauthSession };
} catch (sseError) {
await closeTransportAndWait(logger, sseTransport).catch(() => {});
await oauthSession?.close().catch(() => {});
Expand Down
Loading