diff --git a/src/runtime/oauth.ts b/src/runtime/oauth.ts index 8128cc2..ca1e465 100644 --- a/src/runtime/oauth.ts +++ b/src/runtime/oauth.ts @@ -19,6 +19,13 @@ export class OAuthTimeoutError extends Error { } } +export class OAuthCompletedError extends Error { + constructor(serverName: string) { + super(`OAuth completed for '${serverName}'; reconnect with a fresh transport.`); + this.name = 'OAuthCompletedError'; + } +} + export async function connectWithAuth( client: Client, transport: Transport & { @@ -53,12 +60,16 @@ export async function connectWithAuth( ); if (typeof transport.finishAuth === 'function') { await transport.finishAuth(code); - logger.info('Authorization code accepted. Retrying connection...'); + logger.info('Authorization code accepted. Reconnecting with fresh transport...'); + throw new OAuthCompletedError(serverName ?? 'unknown'); } else { logger.warn('Transport does not support finishAuth; cannot complete OAuth flow automatically.'); throw error; } } catch (authError) { + if (authError instanceof OAuthCompletedError) { + throw authError; + } logger.error('OAuth authorization failed while waiting for callback.', authError); throw authError; } diff --git a/src/runtime/transport.ts b/src/runtime/transport.ts index 983c3be..7b3c6bb 100644 --- a/src/runtime/transport.ts +++ b/src/runtime/transport.ts @@ -11,7 +11,7 @@ import { readCachedAccessToken } from '../oauth-persistence.js'; import { materializeHeaders } from '../runtime-header-utils.js'; import { isUnauthorizedError, maybeEnableOAuth } from '../runtime-oauth-support.js'; import { closeTransportAndWait } from '../runtime-process-utils.js'; -import { connectWithAuth, OAuthTimeoutError } from './oauth.js'; +import { connectWithAuth, OAuthCompletedError, OAuthTimeoutError } from './oauth.js'; import { resolveCommandArgument, resolveCommandArguments } from './utils.js'; const STDIO_TRACE_ENABLED = process.env.MCPORTER_STDIO_TRACE === '1'; @@ -147,6 +147,13 @@ export async function createClientContext( try { return await attemptConnect(); } catch (primaryError) { + if (primaryError instanceof OAuthCompletedError) { + // OAuth succeeded but the transport is already started; retry with a fresh transport. + // Close the current session's callback server to free the port before the next iteration. + await oauthSession?.close().catch(() => {}); + logger.info(`OAuth complete for '${activeDefinition.name}'. Reconnecting...`); + continue; + } if (isUnauthorizedError(primaryError)) { await oauthSession?.close().catch(() => {}); oauthSession = undefined; diff --git a/tests/runtime-oauth-connect.test.ts b/tests/runtime-oauth-connect.test.ts index cba76a8..7da4901 100644 --- a/tests/runtime-oauth-connect.test.ts +++ b/tests/runtime-oauth-connect.test.ts @@ -5,7 +5,7 @@ import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; import { describe, expect, it, vi } from 'vitest'; import type { Logger } from '../src/logging.js'; import type { OAuthSession } from '../src/oauth.js'; -import { connectWithAuth } from '../src/runtime/oauth.js'; +import { connectWithAuth, OAuthCompletedError } from '../src/runtime/oauth.js'; // Minimal mock transport that records finishAuth calls. class MockTransport implements Transport { @@ -23,12 +23,9 @@ class MockTransport implements Transport { } describe('connectWithAuth', () => { - it('waits for authorization code and retries connection', async () => { - // First connect throws Unauthorized, second succeeds after finishAuth. - const connect = vi - .fn() - .mockRejectedValueOnce(new UnauthorizedError('auth needed')) - .mockResolvedValueOnce(undefined); + it('throws OAuthCompletedError after successful finishAuth so caller can reconnect', async () => { + // connect throws Unauthorized, finishAuth succeeds, then OAuthCompletedError is thrown. + const connect = vi.fn().mockRejectedValueOnce(new UnauthorizedError('auth needed')); const client = { connect } as unknown as Client; let resolveCode: (code: string) => void = () => {}; @@ -63,10 +60,10 @@ describe('connectWithAuth', () => { // Simulate browser callback arrival. resolveCode('oauth-code-123'); - await promise; + await expect(promise).rejects.toThrow(OAuthCompletedError); expect(waitForAuthorizationCode).toHaveBeenCalledTimes(1); expect(transport.calls).toEqual(['oauth-code-123']); - expect(connect).toHaveBeenCalledTimes(2); + expect(connect).toHaveBeenCalledTimes(1); }); });