diff --git a/.changeset/mcp-client-implementation.md b/.changeset/mcp-client-implementation.md new file mode 100644 index 0000000..0da47e1 --- /dev/null +++ b/.changeset/mcp-client-implementation.md @@ -0,0 +1,17 @@ +--- +"mcp-lite": minor +--- + +Add MCP client implementation with session management and bidirectional communication support. + +The client provides: +- McpClient class for connecting to MCP servers +- Connection interface for calling tools, prompts, and resources +- ClientSessionAdapter for optional session persistence +- StreamableHttpClientTransport for HTTP/SSE communication +- Handler registration for server-initiated requests (sampling, elicitation) +- Middleware support for request interception +- ToolAdapter interface for SDK integration +- Full TypeScript support with zero runtime dependencies + +Includes 29 integration tests validating stateless operations, session management, server-initiated requests, and end-to-end workflows. diff --git a/README.client.md b/README.client.md new file mode 100644 index 0000000..958cb7d --- /dev/null +++ b/README.client.md @@ -0,0 +1,1713 @@ +# MCP Client + +A lightweight client for connecting to and interacting with Model Context Protocol (MCP) servers. + +The `mcp-lite` client provides a simple, type-safe API for calling tools, prompts, and resources on MCP servers, with support for server-initiated requests like elicitation and sampling. + +## Quick Start + +```bash +npm install mcp-lite +``` + +Connect to an MCP server and call a tool: + +```typescript +import { McpClient, StreamableHttpClientTransport } from "mcp-lite"; + +// Create client instance +const client = new McpClient({ + name: "my-client", + version: "1.0.0" +}); + +// Connect to server +const transport = new StreamableHttpClientTransport(); +const connect = transport.bind(client); +const connection = await connect("http://localhost:3000/mcp"); + +// Discover available tools +const { tools } = await connection.listTools(); +console.log(`Found ${tools.length} tools:`, tools.map(t => t.name)); + +// Call a tool +const result = await connection.callTool("echo", { + message: "Hello!" +}); + +console.log(result.content[0].text); +``` + +## Features + +- **Simple API** - Connect to MCP servers and call tools, prompts, and resources +- **Type-safe** - Full TypeScript support with inferred types +- **Stateless or Stateful** - Start without sessions, add them when you need server-initiated requests +- **Multi-server** - Connect to multiple servers simultaneously with independent authentication +- **Custom Headers** - Pass authentication tokens, API keys, or custom headers per-connection +- **Server Requests** - Handle elicitation and sampling requests from servers +- **SSE Streaming** - Receive server notifications and progress updates via Server-Sent Events +- **OAuth 2.1** - Built-in OAuth support with PKCE, token refresh, and discovery +- **Error Handling** - Clear error messages with RpcError support + +## Client Setup + +### Basic Client + +Create a client with minimal configuration: + +```typescript +import { McpClient } from "mcp-lite"; + +const client = new McpClient({ + name: "my-app", + version: "1.0.0" +}); +``` + +### Client with Capabilities + +Advertise support for elicitation and sampling: + +```typescript +const client = new McpClient({ + name: "my-app", + version: "1.0.0", + capabilities: { + elicitation: {}, // Support elicitation requests from server + sampling: {} // Support sampling requests from server + } +}); +``` + +### Client with Custom Logger + +Provide your own logger for debugging: + +```typescript +const client = new McpClient({ + name: "my-app", + version: "1.0.0", + logger: { + info: (msg) => console.log(`[INFO] ${msg}`), + error: (msg) => console.error(`[ERROR] ${msg}`), + warn: (msg) => console.warn(`[WARN] ${msg}`) + } +}); +``` + +## Connecting to Servers + +### Stateless Connection + +Connect without sessions for simple request/response: + +```typescript +import { StreamableHttpClientTransport } from "mcp-lite"; + +const transport = new StreamableHttpClientTransport(); +const connect = transport.bind(client); + +const connection = await connect("http://localhost:3000/mcp"); + +// Server information is available on the connection +console.log(connection.serverInfo.name); // "my-server" +console.log(connection.serverInfo.version); // "1.0.0" +console.log(connection.serverCapabilities); // { tools: {...}, prompts: {...} } +``` + +### Session-Based Connection + +Enable sessions for SSE streaming and server-initiated requests: + +```typescript +import { + StreamableHttpClientTransport, + InMemoryClientSessionAdapter +} from "mcp-lite"; + +const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter() +}); + +const connect = transport.bind(client); +const connection = await connect("http://localhost:3000/mcp"); + +// Session ID is available when using session adapter +console.log(connection.sessionId); // "abc123..." + +// Open SSE stream to receive server notifications +const stream = await connection.openSessionStream(); +``` + +### Connection with Custom Headers + +Pass custom headers for authentication, API keys, or request tracking: + +```typescript +const transport = new StreamableHttpClientTransport(); +const connect = transport.bind(client); + +// Connect with authentication headers +const connection = await connect("http://localhost:3000/mcp", { + headers: { + 'Authorization': 'Bearer my-secret-token', + 'X-API-Key': 'my-api-key', + 'X-Request-ID': 'req-123' + } +}); + +// Headers are included in all requests for this connection +await connection.callTool("listRepos", {}); +``` + +Custom headers are merged with protocol-required headers and included in: +- The initial `initialize` request +- All tool, prompt, and resource calls +- SSE streams (when using session mode) +- Response sends (for server-initiated requests) + +**Note**: Custom headers work alongside OAuth authentication. If both are configured, the OAuth `Authorization` header takes precedence, but other custom headers are still applied + +### Multiple Server Connections + +A single client can connect to multiple servers with different headers: + +```typescript +const client = new McpClient({ + name: "multi-client", + version: "1.0.0" +}); + +const transport = new StreamableHttpClientTransport(); +const connect = transport.bind(client); + +// Connect to multiple servers with different authentication +const githubConn = await connect("http://localhost:3000/github", { + headers: { + 'Authorization': 'Bearer github-token-123' + } +}); + +const slackConn = await connect("http://localhost:3001/slack", { + headers: { + 'Authorization': 'Bearer slack-token-456' + } +}); + +const dbConn = await connect("http://localhost:3002/db", { + headers: { + 'Authorization': 'Bearer db-token-789', + 'X-Database': 'production' + } +}); + +// Each connection uses its own headers independently +const repos = await githubConn.callTool("listRepos", {}); +const message = await slackConn.callTool("postMessage", { + channel: "#dev", + text: "New issue created" +}); +const records = await dbConn.callTool("query", { + sql: "SELECT * FROM users" +}); +``` + +## Calling Tools + +### Basic Tool Call + +Call a tool with arguments: + +```typescript +const result = await connection.callTool("echo", { + message: "Hello World" +}); + +console.log(result.content[0].text); // "Hello World" +``` + +### Tool Call with Structured Output + +Access both human-readable and structured content: + +```typescript +const result = await connection.callTool("getWeather", { + location: "San Francisco" +}); + +// Human-readable content +console.log(result.content[0].text); +// "Weather in San Francisco: 22°C, sunny" + +// Structured content (if provided by server) +if (result.structuredContent) { + console.log(result.structuredContent.temperature); // 22 + console.log(result.structuredContent.conditions); // "sunny" +} +``` + +### Listing Available Tools + +Discover what tools are available: + +```typescript +const { tools } = await connection.listTools(); + +for (const tool of tools) { + console.log(`${tool.name}: ${tool.description}`); + console.log(`Input schema:`, tool.inputSchema); + console.log(`Output schema:`, tool.outputSchema); +} +``` + +### Concurrent Tool Calls + +Execute multiple tools in parallel: + +```typescript +const results = await Promise.all([ + connection.callTool("echo", { message: "First" }), + connection.callTool("echo", { message: "Second" }), + connection.callTool("add", { a: 1, b: 2 }) +]); + +console.log(results[0].content[0].text); // "First" +console.log(results[1].content[0].text); // "Second" +console.log(results[2].content[0].text); // "3" +``` + +## Working with Prompts + +### List Prompts + +Get all available prompts from the server: + +```typescript +const { prompts } = await connection.listPrompts(); + +for (const prompt of prompts) { + console.log(`${prompt.name}: ${prompt.description}`); + console.log(`Arguments:`, prompt.arguments); +} +``` + +### Get a Prompt + +Retrieve a prompt with arguments: + +```typescript +const result = await connection.getPrompt("summarize", { + text: "Long article text...", + length: "short" +}); + +// Prompt returns messages for LLM +for (const message of result.messages) { + console.log(`${message.role}:`, message.content.text); +} +``` + +### Basic Prompt (No Arguments) + +```typescript +const result = await connection.getPrompt("greet"); + +console.log(result.messages[0].content.text); +// "Hello, how are you?" +``` + +## Working with Resources + +### List Resources + +Get all available resources: + +```typescript +const { resources } = await connection.listResources(); + +for (const resource of resources) { + console.log(`${resource.uri}: ${resource.description}`); + console.log(`MIME type:`, resource.mimeType); +} +``` + +### Read a Resource + +Fetch resource contents: + +```typescript +const result = await connection.readResource("file://config.json"); + +for (const content of result.contents) { + console.log(`URI: ${content.uri}`); + console.log(`Type: ${content.type}`); + console.log(`Content: ${content.text}`); + console.log(`MIME: ${content.mimeType}`); +} +``` + +### List Resource Templates + +Discover templated resources: + +```typescript +const { resourceTemplates } = await connection.listResourceTemplates(); + +for (const template of resourceTemplates) { + console.log(`Template: ${template.uriTemplate}`); + console.log(`Description: ${template.description}`); +} + +// Use a template with parameters +const result = await connection.readResource( + "github://repos/owner/repo" +); +``` + +## Server-Initiated Requests + +When using session-based connections, servers can send requests to the client for elicitation (user prompts) and sampling (LLM completions). + +### Handling Elicitation + +Register a handler for elicitation requests: + +```typescript +const client = new McpClient({ + name: "my-client", + version: "1.0.0", + capabilities: { elicitation: {} } +}); + +// Register elicitation handler +client.onElicit(async (params, connection) => { + // params.message: "What is your name?" + // params.requestedSchema: { type: "object", properties: {...} } + + // Prompt user and return response + const userInput = await promptUser(params.message); + + return { + action: "accept", + content: { name: userInput } + }; +}); + +// Connect with session support +const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter() +}); +const connect = transport.bind(client); +const connection = await connect(serverUrl); + +// Open SSE stream to receive elicitation requests +await connection.openSessionStream(); + +// Call a tool that triggers elicitation +const result = await connection.callTool("getUserInfo", {}); +// Server will send elicitation request, handler will be called +``` + +### Elicitation Response Actions + +The handler can return different actions: + +```typescript +// Accept with data +client.onElicit(async (params) => { + return { + action: "accept", + content: { confirmed: true } + }; +}); + +// Decline (user says no) +client.onElicit(async (params) => { + return { action: "decline" }; +}); + +// Cancel (user aborts) +client.onElicit(async (params) => { + return { action: "cancel" }; +}); +``` + +### Multiple Elicitations + +A single tool call can trigger multiple sequential elicitations: + +```typescript +client.onElicit(async (params) => { + if (params.message.includes("name")) { + return { action: "accept", content: { name: "Alice" } }; + } else if (params.message.includes("age")) { + return { action: "accept", content: { age: 30 } }; + } + return { action: "decline" }; +}); + +// Server tool might ask for name, then age +const result = await connection.callTool("getUserInfo", {}); +// Elicitation handler called twice in sequence +``` + +### Handling Sampling + +Register a handler for sampling requests (LLM completions): + +```typescript +const client = new McpClient({ + name: "my-client", + version: "1.0.0", + capabilities: { sampling: {} } +}); + +// Register sampling handler +client.onSample(async (params, connection) => { + // params.messages: Array of messages for LLM + // params.modelPreferences: { hints, costPriority, speedPriority, ... } + // params.systemPrompt: Optional system prompt + // params.maxTokens: Maximum tokens to generate + + // Call your LLM + const response = await callLLM({ + messages: params.messages, + systemPrompt: params.systemPrompt, + maxTokens: params.maxTokens + }); + + return { + role: "assistant", + content: { + type: "text", + text: response.text + }, + model: "gpt-4", + stopReason: "endTurn" + }; +}); +``` + +### Connection Info in Handlers + +Handlers receive connection info as the second parameter: + +```typescript +client.onSample(async (params, connectionInfo) => { + console.log(`Server: ${connectionInfo?.serverInfo.name}`); + console.log(`Protocol: ${connectionInfo?.protocolVersion}`); + + // Use connection info to route to appropriate LLM + const llmEndpoint = getLLMEndpoint(connectionInfo?.serverInfo.name); + + return { + role: "assistant", + content: { type: "text", text: "..." }, + model: "gpt-4", + stopReason: "endTurn" + }; +}); +``` + +## Session Management + +### Opening Session Streams + +Open an SSE stream to receive server events: + +```typescript +const stream = await connection.openSessionStream(); + +// Read events manually if needed +const reader = stream.getReader(); +const decoder = new TextDecoder(); + +while (true) { + const { done, value } = await reader.read(); + if (done) break; + + const text = decoder.decode(value); + console.log("SSE event:", text); +} +``` + +### Progress Notifications + +Receive progress updates during long-running operations: + +```typescript +async function readStreamWithProgress(stream: ReadableStream) { + const reader = stream.getReader(); + const decoder = new TextDecoder(); + const progressEvents = []; + let buffer = ""; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() || ""; + + for (const line of lines) { + if (line.startsWith("data: ")) { + const data = JSON.parse(line.slice(6)); + if (data.method === "notifications/progress") { + progressEvents.push(data.params); + console.log(`Progress: ${data.params.progress}/${data.params.total}`); + console.log(`Message: ${data.params.message}`); + } + } + } + } + + return progressEvents; +} + +// Open stream and read in background +const stream = await connection.openSessionStream(); +const progressPromise = readStreamWithProgress(stream); + +// Execute long-running tool +const result = await connection.callTool("processLargeFile", { + filename: "data.csv" +}); + +// Get all progress events +const progressEvents = await progressPromise; +console.log(`Received ${progressEvents.length} progress updates`); +``` + +### Event Replay + +Resume from a specific event ID after reconnection: + +```typescript +// First connection +const stream = await connection.openSessionStream(); +// ... connection drops ... + +// Reconnect with last event ID +const resumedStream = await connection.openSessionStream("event-id-123"); +// Server replays events from that point +``` + +### Closing Sessions + +Close the connection and optionally delete the session: + +```typescript +// Close stream only (session persists on server) +connection.closeSessionStream(); + +// Close connection without deleting session +await connection.close(false); + +// Close connection and delete session from server +await connection.close(true); +``` + +## Ping + +Send a ping to verify the connection: + +```typescript +try { + await connection.ping(); + console.log("Server is alive"); +} catch (error) { + console.log("Server is not responding"); +} +``` + +## OAuth Authentication + +MCP clients can connect to OAuth 2.1-protected MCP servers using the built-in OAuth support. The client handles PKCE (Proof Key for Code Exchange), token storage, automatic token refresh, and multiple server authentication. + +### Basic OAuth Setup + +Connect to an OAuth-protected MCP server: + +```typescript +import { + McpClient, + StreamableHttpClientTransport, + InMemoryOAuthAdapter, + StandardOAuthProvider +} from "mcp-lite"; + +// Create OAuth adapter for token storage +const oauthAdapter = new InMemoryOAuthAdapter(); + +// Create OAuth provider for handling OAuth flows +const oauthProvider = new StandardOAuthProvider(); + +// Configure OAuth settings +const oauthConfig = { + clientId: "your-client-id", // Optional - can use DCR instead + redirectUri: "http://localhost:3000/callback", + onAuthorizationRequired: (authorizationUrl) => { + // Redirect user to authorization URL + console.log("Please authorize at:", authorizationUrl); + // In a web app: window.location.href = authorizationUrl; + // In a CLI app: open(authorizationUrl); + } +}; + +// Create client with OAuth support +const client = new McpClient({ + name: "oauth-client", + version: "1.0.0" +}); + +const transport = new StreamableHttpClientTransport({ + oauthAdapter, + oauthProvider, + oauthConfig +}); + +const connect = transport.bind(client); + +try { + // First connection attempt may fail with 401 + const connection = await connect("https://api.example.com/mcp"); +} catch (error) { + // If authentication required, user is redirected to OAuth server + console.log(error.message); // "Authentication required. Authorization flow started..." +} +``` + +### Dynamic Client Registration (DCR) + +If you don't have a pre-registered `clientId`, mcp-lite can automatically register with the authorization server using RFC 7591 Dynamic Client Registration: + +```typescript +const oauthConfig = { + // No clientId needed - will use DCR automatically + redirectUri: "http://localhost:3000/callback", + clientName: "My MCP Client", // Optional, used for DCR + onAuthorizationRequired: (authorizationUrl) => { + console.log("Please authorize at:", authorizationUrl); + } +}; + +const transport = new StreamableHttpClientTransport({ + oauthAdapter, + oauthProvider, + oauthConfig +}); + +// Client automatically registers on first connection +const connection = await connect("https://api.example.com/mcp"); +// Registered client credentials are stored and reused for future connections +``` + +**How DCR works:** + +1. Client attempts connection to OAuth-protected server +2. Discovers authorization server supports DCR (via `registration_endpoint`) +3. Checks if client credentials already exist for this authorization server +4. If not, automatically registers using `registerOAuthClient()` +5. Stores client credentials (clientId, clientSecret) per authorization server +6. Reuses credentials for subsequent connections to servers using same auth server + +**Fallback behavior:** +- If DCR is not supported and no `clientId` is provided, an error is thrown +- You can manually register a client using `registerOAuthClient()`: + +```typescript +import { registerOAuthClient, discoverOAuthEndpoints } from "mcp-lite"; + +// Discover endpoints +const endpoints = await discoverOAuthEndpoints("https://api.example.com/mcp"); + +// Manually register client +if (endpoints.registrationEndpoint) { + const credentials = await registerOAuthClient( + endpoints.registrationEndpoint, + { + clientName: "My MCP Client", + redirectUris: ["http://localhost:3000/callback"], + grantTypes: ["authorization_code", "refresh_token"], + tokenEndpointAuthMethod: "none" + } + ); + + console.log("Client ID:", credentials.clientId); + // Store and use this clientId in your OAuthConfig +} +``` + +### Completing Authorization Flow + +After the user authorizes and is redirected back to your redirect URI: + +```typescript +// Parse authorization code and state from callback URL +// Example: http://localhost:3000/callback?code=abc123&state=xyz789 +const urlParams = new URLSearchParams(window.location.search); +const code = urlParams.get("code"); +const state = urlParams.get("state"); + +// Complete the authorization flow +await transport.completeAuthorizationFlow( + "https://api.example.com/mcp", + code, + state +); + +// Now connect successfully with stored token +const connection = await connect("https://api.example.com/mcp"); +console.log("Connected:", connection.serverInfo.name); +``` + +### Persistent Token Storage + +The `InMemoryOAuthAdapter` stores tokens in memory, which are lost when the process exits. For production use, implement a persistent adapter: + +```typescript +import { OAuthAdapter, OAuthTokens, StoredClientCredentials } from "mcp-lite"; +import fs from "fs/promises"; + +class FileOAuthAdapter implements OAuthAdapter { + constructor( + private tokenFile: string, + private credentialsFile: string + ) {} + + async storeTokens(resource: string, tokens: OAuthTokens): Promise { + const allTokens = await this.loadAllTokens(); + allTokens[resource] = tokens; + await fs.writeFile(this.tokenFile, JSON.stringify(allTokens, null, 2)); + } + + async getTokens(resource: string): Promise { + const allTokens = await this.loadAllTokens(); + return allTokens[resource]; + } + + async deleteTokens(resource: string): Promise { + const allTokens = await this.loadAllTokens(); + delete allTokens[resource]; + await fs.writeFile(this.tokenFile, JSON.stringify(allTokens, null, 2)); + } + + async hasValidToken(resource: string): Promise { + const tokens = await this.getTokens(resource); + if (!tokens) return false; + + const now = Math.floor(Date.now() / 1000); + const BUFFER_SECONDS = 5 * 60; // 5 minute buffer + return tokens.expiresAt > now + BUFFER_SECONDS; + } + + async storeClientCredentials( + authorizationServer: string, + credentials: StoredClientCredentials + ): Promise { + const allCredentials = await this.loadAllCredentials(); + allCredentials[authorizationServer] = credentials; + await fs.writeFile(this.credentialsFile, JSON.stringify(allCredentials, null, 2)); + } + + async getClientCredentials( + authorizationServer: string + ): Promise { + const allCredentials = await this.loadAllCredentials(); + return allCredentials[authorizationServer]; + } + + async deleteClientCredentials(authorizationServer: string): Promise { + const allCredentials = await this.loadAllCredentials(); + delete allCredentials[authorizationServer]; + await fs.writeFile(this.credentialsFile, JSON.stringify(allCredentials, null, 2)); + } + + private async loadAllTokens(): Promise> { + try { + const data = await fs.readFile(this.tokenFile, "utf-8"); + return JSON.parse(data); + } catch { + return {}; + } + } + + private async loadAllCredentials(): Promise> { + try { + const data = await fs.readFile(this.credentialsFile, "utf-8"); + return JSON.parse(data); + } catch { + return {}; + } + } +} + +// Use file-based storage +const oauthAdapter = new FileOAuthAdapter( + "./oauth-tokens.json", + "./oauth-credentials.json" +); +``` + +### Database Token Storage + +For multi-user applications, store tokens in a database: + +```typescript +import { OAuthAdapter, OAuthTokens } from "mcp-lite"; + +class PostgresOAuthAdapter implements OAuthAdapter { + constructor( + private db: DatabaseConnection, + private userId: string + ) {} + + async storeTokens(resource: string, tokens: OAuthTokens): Promise { + await this.db.query( + `INSERT INTO oauth_tokens (user_id, resource, tokens, expires_at) + VALUES ($1, $2, $3, $4) + ON CONFLICT (user_id, resource) + DO UPDATE SET tokens = $3, expires_at = $4`, + [this.userId, resource, JSON.stringify(tokens), tokens.expiresAt] + ); + } + + async getTokens(resource: string): Promise { + const result = await this.db.query( + `SELECT tokens FROM oauth_tokens + WHERE user_id = $1 AND resource = $2`, + [this.userId, resource] + ); + return result.rows[0]?.tokens; + } + + async deleteTokens(resource: string): Promise { + await this.db.query( + `DELETE FROM oauth_tokens + WHERE user_id = $1 AND resource = $2`, + [this.userId, resource] + ); + } + + async hasValidToken(resource: string): Promise { + const result = await this.db.query( + `SELECT expires_at FROM oauth_tokens + WHERE user_id = $1 AND resource = $2`, + [this.userId, resource] + ); + + if (!result.rows[0]) return false; + + const now = Math.floor(Date.now() / 1000); + const BUFFER_SECONDS = 5 * 60; + return result.rows[0].expires_at > now + BUFFER_SECONDS; + } +} + +// Use database storage +const oauthAdapter = new PostgresOAuthAdapter(db, currentUserId); +``` + +### Automatic Token Refresh + +Tokens are automatically refreshed when they expire: + +```typescript +const transport = new StreamableHttpClientTransport({ + oauthAdapter, + oauthProvider, + oauthConfig +}); + +const connect = transport.bind(client); + +// First connection (uses existing token) +const connection1 = await connect("https://api.example.com/mcp"); +await connection1.callTool("echo", { message: "Hello" }); + +// Wait for token to expire... +// (Tokens are checked with 5-minute buffer before expiry) + +// Next connection automatically refreshes the token +const connection2 = await connect("https://api.example.com/mcp"); +await connection2.callTool("echo", { message: "Still works!" }); +``` + +### Multiple OAuth Providers + +Connect to multiple OAuth-protected servers: + +```typescript +// Each server can have its own tokens +const adapter = new InMemoryOAuthAdapter(); +const provider = new StandardOAuthProvider(); + +const config = { + clientId: "my-client-id", + redirectUri: "http://localhost:3000/callback", + onAuthorizationRequired: (url) => console.log("Authorize:", url) +}; + +const transport = new StreamableHttpClientTransport({ + oauthAdapter: adapter, + oauthProvider: provider, + oauthConfig: config +}); + +const connect = transport.bind(client); + +// Connect to multiple servers with different OAuth tokens +const github = await connect("https://github-mcp.example.com"); +const slack = await connect("https://slack-mcp.example.com"); +const gdrive = await connect("https://drive-mcp.example.com"); + +// Each connection uses its own OAuth token +await github.callTool("listRepos", {}); +await slack.callTool("postMessage", { channel: "#dev", text: "Hi" }); +await gdrive.callTool("listFiles", {}); +``` + +### OAuth Discovery + +MCP servers advertise their OAuth endpoints using RFC 8707 (Resource Indicators) and RFC 8414 (Authorization Server Metadata): + +```typescript +import { discoverOAuthEndpoints } from "mcp-lite"; + +const endpoints = await discoverOAuthEndpoints("https://api.example.com/mcp"); + +console.log(endpoints.authorizationServer); // OAuth server URL +console.log(endpoints.authorizationEndpoint); // Where to send users +console.log(endpoints.tokenEndpoint); // Where to exchange codes +console.log(endpoints.registrationEndpoint); // DCR endpoint (if supported) +console.log(endpoints.scopes); // Required scopes +``` + +The discovery process: +1. Extracts origin from MCP server URL (per RFC 8707 Section 3) +2. Fetches `/.well-known/oauth-protected-resource` from **origin** (not sub-path) + - Example: `https://example.com/mcp` → `https://example.com/.well-known/oauth-protected-resource` +3. Extracts authorization server URL from resource metadata +4. Fetches authorization server metadata +5. Verifies PKCE S256 support (required for OAuth 2.1) +6. Returns endpoint information including DCR endpoint if available + +**RFC 8707 Compliance:** +Per RFC 8707, `.well-known` endpoints MUST be at the origin, not at sub-paths. If you have an MCP server at `https://greentea.fiberplane.io/mcp`, discovery will correctly query `https://greentea.fiberplane.io/.well-known/oauth-protected-resource`. + +**Fallback mechanism:** +If origin-based discovery fails, the client attempts to read the `WWW-Authenticate` header's `as_uri` parameter as a hint for the authorization server metadata URL. + +### Error Handling + +Handle OAuth-specific errors: + +```typescript +try { + const connection = await connect("https://api.example.com/mcp"); +} catch (error) { + if (error.message.includes("Authentication required")) { + // User needs to authorize - wait for callback + console.log("Waiting for user authorization..."); + } else { + console.error("Connection failed:", error); + } +} + +// After callback +try { + await transport.completeAuthorizationFlow(serverUrl, code, state); +} catch (error) { + if (error.message.includes("State parameter mismatch")) { + // Possible CSRF attack + console.error("Security error: invalid state parameter"); + } else if (error.message.includes("Token exchange failed")) { + // OAuth server rejected the code + console.error("Authorization failed:", error.message); + } else { + console.error("Unexpected error:", error); + } +} +``` + +### Security Best Practices + +1. **Always use HTTPS** - OAuth flows must use HTTPS in production +2. **PKCE is mandatory** - The client automatically uses PKCE S256 method +3. **State validation** - State parameters are automatically validated to prevent CSRF +4. **Secure token storage** - Use encrypted storage for production token adapters +5. **Token expiry buffer** - Tokens are refreshed 5 minutes before expiry to prevent race conditions +6. **Resource parameter** - RFC 8707 resource parameter is included in all OAuth requests + +### OAuth Configuration Summary + +Required configuration for OAuth: + +```typescript +interface OAuthConfig { + clientId?: string; // OAuth client ID (optional, uses DCR if not provided) + redirectUri: string; // Callback URL for authorization + clientName?: string; // Client name for DCR (defaults to "MCP Client") + onAuthorizationRequired: (url: string) => void; // Redirect handler +} + +interface OAuthAdapter { + // Token storage + storeTokens(resource: string, tokens: OAuthTokens): Promise | void; + getTokens(resource: string): Promise | OAuthTokens | undefined; + deleteTokens(resource: string): Promise | void; + hasValidToken(resource: string): Promise | boolean; + + // Client credentials storage (for DCR) + storeClientCredentials(authorizationServer: string, credentials: StoredClientCredentials): Promise | void; + getClientCredentials(authorizationServer: string): Promise | StoredClientCredentials | undefined; + deleteClientCredentials(authorizationServer: string): Promise | void; +} +``` + +Built-in adapters: +- `InMemoryOAuthAdapter` - In-memory storage for tokens and client credentials (for testing) +- Implement `OAuthAdapter` for custom storage (files, database, etc.) + +Built-in providers: +- `StandardOAuthProvider` - OAuth 2.1 with PKCE S256 +- Implement `OAuthProvider` for custom OAuth flows + +## Error Handling + +### RpcError + +All JSON-RPC errors are thrown as `RpcError` instances: + +```typescript +import { RpcError } from "mcp-lite"; + +try { + await connection.callTool("nonexistent", {}); +} catch (error) { + if (error instanceof RpcError) { + console.log(`Code: ${error.code}`); + console.log(`Message: ${error.message}`); + console.log(`Data:`, error.data); + } +} +``` + +### HTTP Errors + +Network and HTTP errors are thrown as standard `Error` instances: + +```typescript +try { + const connection = await connect("http://invalid:9999"); +} catch (error) { + console.error("Connection failed:", error.message); +} +``` + +### Tool Call Errors + +Handle errors from tool execution: + +```typescript +try { + const result = await connection.callTool("divide", { + a: 10, + b: 0 + }); +} catch (error) { + if (error instanceof RpcError) { + // Server returned an error (e.g., division by zero) + console.error("Tool error:", error.message); + } else { + // Network or other error + console.error("Request failed:", error); + } +} +``` + +### Retry Pattern + +Implement retry logic for transient failures: + +```typescript +async function callWithRetry( + connection: Connection, + toolName: string, + args: unknown, + maxRetries = 3 +) { + let lastError; + + for (let i = 0; i < maxRetries; i++) { + try { + return await connection.callTool(toolName, args); + } catch (error) { + lastError = error; + if (i < maxRetries - 1) { + await new Promise(resolve => setTimeout(resolve, 1000 * (i + 1))); + } + } + } + + throw lastError; +} + +// Usage +const result = await callWithRetry(connection, "flaky-tool", {}); +``` + +## Advanced Patterns + +### Tool Discovery Pattern + +Always discover available tools before calling them: + +```typescript +async function callToolSafely( + connection: Connection, + toolName: string, + args: unknown +) { + // First, check if tool exists + const { tools } = await connection.listTools(); + const tool = tools.find(t => t.name === toolName); + + if (!tool) { + throw new Error(`Tool '${toolName}' not found. Available: ${tools.map(t => t.name).join(", ")}`); + } + + // Validate args against schema if needed + console.log(`Calling ${toolName} with schema:`, tool.inputSchema); + + // Call the tool + return await connection.callTool(toolName, args); +} + +// Usage +const result = await callToolSafely(connection, "calculate", { a: 5, b: 3 }); +``` + +### Tool Adapter Pattern + +Adapt MCP tools to other SDK formats: + +```typescript +class SDKAdapter { + private toolsCache?: any[]; + + constructor(private connection: Connection) {} + + // Convert MCP tool to SDK tool format + async getTools() { + if (this.toolsCache) { + return this.toolsCache; + } + + const { tools } = await this.connection.listTools(); + + this.toolsCache = tools.map(tool => ({ + type: "function", + function: { + name: tool.name, + description: tool.description, + parameters: tool.inputSchema + } + })); + + return this.toolsCache; + } + + // Execute tool and convert result + async execute(toolName: string, args: unknown) { + const result = await this.connection.callTool(toolName, args); + + // Return structured content if available, otherwise text + if (result.structuredContent) { + return result.structuredContent; + } + + return result.content[0]?.text; + } +} + +// Usage +const adapter = new SDKAdapter(connection); +const sdkTools = await adapter.getTools(); +const result = await adapter.execute("calculate", { a: 5, b: 3 }); +``` + +### Connection Pool with Persistence + +Manage multiple connections efficiently with persistence and auto-reconnect: + +```typescript +interface ConnectionMetadata { + name: string; + url: string; + sessionId?: string; + serverInfo?: { name: string; version: string }; + lastPing?: number; +} + +class PersistentConnectionPool { + private connections = new Map(); + private metadata = new Map(); + private transport: StreamableHttpClientTransport; + + constructor( + private client: McpClient, + private persistencePath?: string + ) { + this.transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter() + }); + } + + /** + * Connect to a server, or return existing connection + */ + async connect(name: string, url: string): Promise { + // Return cached connection if healthy + const existing = this.connections.get(name); + if (existing && await this.isHealthy(existing)) { + return existing; + } + + // Remove stale connection + if (existing) { + await this.disconnect(name); + } + + // Create new connection + const connect = this.transport.bind(this.client); + const connection = await connect(url); + + this.connections.set(name, connection); + this.metadata.set(name, { + name, + url, + sessionId: connection.sessionId, + serverInfo: connection.serverInfo, + lastPing: Date.now() + }); + + await this.saveMetadata(); + + return connection; + } + + /** + * Get existing connection without creating new one + */ + get(name: string): Connection | undefined { + return this.connections.get(name); + } + + /** + * Check if connection is healthy + */ + async isHealthy(connection: Connection): Promise { + try { + await connection.ping(); + return true; + } catch { + return false; + } + } + + /** + * Disconnect and remove a specific connection + */ + async disconnect(name: string): Promise { + const connection = this.connections.get(name); + if (connection) { + await connection.close(true); + this.connections.delete(name); + this.metadata.delete(name); + await this.saveMetadata(); + } + } + + /** + * Reconnect to a server using saved metadata + */ + async reconnect(name: string): Promise { + const meta = this.metadata.get(name); + if (!meta) { + throw new Error(`No metadata found for connection: ${name}`); + } + + return await this.connect(name, meta.url); + } + + /** + * Save connection metadata for restoration + */ + private async saveMetadata(): Promise { + if (!this.persistencePath) return; + + const data = Array.from(this.metadata.entries()).map(([name, meta]) => ({ + name, + url: meta.url, + sessionId: meta.sessionId, + serverInfo: meta.serverInfo + })); + + // Save to file, localStorage, or database + // Example: await fs.writeFile(this.persistencePath, JSON.stringify(data)); + } + + /** + * Restore connections from saved metadata + */ + async restore(): Promise { + if (!this.persistencePath) return; + + // Load from file, localStorage, or database + // Example: const data = JSON.parse(await fs.readFile(this.persistencePath)); + const data: ConnectionMetadata[] = []; // Load your data here + + for (const meta of data) { + this.metadata.set(meta.name, meta); + // Optionally reconnect immediately + // await this.reconnect(meta.name); + } + } + + /** + * List all connection metadata + */ + list(): ConnectionMetadata[] { + return Array.from(this.metadata.values()); + } + + /** + * Close all connections + */ + async closeAll(deleteRemoteSessions = false): Promise { + await Promise.all( + Array.from(this.connections.values()).map(conn => + conn.close(deleteRemoteSessions) + ) + ); + this.connections.clear(); + + if (deleteRemoteSessions) { + this.metadata.clear(); + await this.saveMetadata(); + } + } + + /** + * Health check all connections + */ + async healthCheck(): Promise> { + const results = new Map(); + + for (const [name, connection] of this.connections.entries()) { + const healthy = await this.isHealthy(connection); + results.set(name, healthy); + + if (healthy) { + const meta = this.metadata.get(name); + if (meta) { + meta.lastPing = Date.now(); + } + } + } + + await this.saveMetadata(); + return results; + } +} + +// Usage +const pool = new PersistentConnectionPool(client, "./connections.json"); + +// Restore previous connections +await pool.restore(); + +// Connect to servers +const github = await pool.connect("github", "http://localhost:3000"); +const slack = await pool.connect("slack", "http://localhost:3001"); + +// List available tools from each server +const githubTools = await github.listTools(); +console.log("GitHub tools:", githubTools.tools.map(t => t.name)); + +// Use connections +await github.callTool("listRepos", {}); +await slack.callTool("postMessage", { channel: "#dev", text: "test" }); + +// Health check +const health = await pool.healthCheck(); +console.log("Connection health:", Object.fromEntries(health)); + +// Reconnect if needed +if (!health.get("github")) { + await pool.reconnect("github"); +} + +// List all connections +console.log("Active connections:", pool.list()); + +// Clean up (keeps metadata for future restore) +await pool.closeAll(false); +``` + +### Workflow Orchestration + +Coordinate operations across multiple servers: + +```typescript +async function createAndNotifyIssue( + githubConn: Connection, + slackConn: Connection, + dbConn: Connection +) { + // First, verify all required tools are available + const [githubTools, slackTools, dbTools] = await Promise.all([ + githubConn.listTools(), + slackConn.listTools(), + dbConn.listTools() + ]); + + const hasRequired = + githubTools.tools.some(t => t.name === "createIssue") && + slackTools.tools.some(t => t.name === "postMessage") && + dbTools.tools.some(t => t.name === "insert"); + + if (!hasRequired) { + throw new Error("Required tools not available"); + } + + // Create issue + const issue = await githubConn.callTool("createIssue", { + repo: "my-repo", + title: "Bug found", + body: "Critical bug" + }); + + // Log and notify in parallel + await Promise.all([ + // Log to database + dbConn.callTool("insert", { + table: "issues", + data: { + source: "github", + id: issue.structuredContent?.id, + title: "Bug found" + } + }), + + // Notify team + slackConn.callTool("postMessage", { + channel: "#dev", + text: `New issue created: ${issue.structuredContent?.url}` + }) + ]); + + return issue; +} +``` + +## TypeScript Types + +### Connection Types + +```typescript +import type { + Connection, + ToolCallResult, + ListToolsResult, + ListPromptsResult, + PromptGetResult, + ListResourcesResult, + ResourceReadResult +} from "mcp-lite"; + +const connection: Connection = await connect(url); + +const toolResult: ToolCallResult = await connection.callTool("echo", {}); +const tools: ListToolsResult = await connection.listTools(); +const prompts: ListPromptsResult = await connection.listPrompts(); +const prompt: PromptGetResult = await connection.getPrompt("greet"); +const resources: ListResourcesResult = await connection.listResources(); +const resource: ResourceReadResult = await connection.readResource("file://test"); +``` + +### Handler Types + +```typescript +import type { + SampleHandler, + ElicitHandler, + SamplingParams, + SamplingResult, + ElicitationParams, + ElicitationResult, + ClientConnectionInfo +} from "mcp-lite"; + +const sampleHandler: SampleHandler = async ( + params: SamplingParams, + connection?: ClientConnectionInfo +): Promise => { + return { + role: "assistant", + content: { type: "text", text: "..." }, + model: "gpt-4", + stopReason: "endTurn" + }; +}; + +const elicitHandler: ElicitHandler = async ( + params: ElicitationParams, + connection?: ClientConnectionInfo +): Promise => { + return { + action: "accept", + content: { answer: "..." } + }; +}; + +client.onSample(sampleHandler); +client.onElicit(elicitHandler); +``` + +### Client Capabilities Types + +```typescript +import type { ClientCapabilities } from "mcp-lite"; + +const capabilities: ClientCapabilities = { + elicitation: {}, + sampling: {}, + roots: {}, + // Custom capabilities + customFeature: { enabled: true } +}; + +const client = new McpClient({ + name: "my-client", + version: "1.0.0", + capabilities +}); +``` + +## Examples + +The `packages/core/tests/integration/` directory contains comprehensive examples: + +- **`client-stateless.test.ts`** - Basic stateless operations (tools, prompts, resources) +- **`client-server-requests.test.ts`** - Elicitation and sampling handlers +- **`client-e2e-full.test.ts`** - Multi-server workflows, progress notifications, error recovery + +## Best Practices + +### 1. Connection Lifecycle + +```typescript +// Good: Reuse connection for multiple requests +const connection = await connect(serverUrl); +await connection.callTool("tool1", {}); +await connection.callTool("tool2", {}); +await connection.close(true); + +// Bad: Creating new connection for each request +await (await connect(serverUrl)).callTool("tool1", {}); +await (await connect(serverUrl)).callTool("tool2", {}); // Wasteful +``` + +### 2. Error Handling + +```typescript +// Good: Handle specific error types +try { + await connection.callTool("tool", {}); +} catch (error) { + if (error instanceof RpcError && error.code === -32601) { + console.error("Tool not found"); + } else { + console.error("Other error:", error); + } +} + +// Bad: Generic catch +try { + await connection.callTool("tool", {}); +} catch (error) { + console.error("Error:", error); // Too broad +} +``` + +### 3. Session Management + +```typescript +// Good: Open stream once, reuse for multiple operations +const stream = await connection.openSessionStream(); +// ... multiple tool calls that may send notifications ... +await connection.close(true); + +// Bad: Opening/closing stream repeatedly +await connection.openSessionStream(); +await connection.closeSessionStream(); +await connection.openSessionStream(); // Inefficient +``` + +### 4. Concurrent Operations + +```typescript +// Good: Parallel execution when possible +const [tools, prompts, resources] = await Promise.all([ + connection.listTools(), + connection.listPrompts(), + connection.listResources() +]); + +// Bad: Sequential execution when not needed +const tools = await connection.listTools(); +const prompts = await connection.listPrompts(); +const resources = await connection.listResources(); +``` + +### 5. Capability Declaration + +```typescript +// Good: Only declare capabilities you actually implement +const client = new McpClient({ + name: "my-client", + version: "1.0.0", + capabilities: { + elicitation: {} // Only if you register onElicit handler + } +}); + +client.onElicit(async (params) => { + // Handler implementation +}); + +// Bad: Declaring capabilities without handlers +const client = new McpClient({ + capabilities: { + elicitation: {}, + sampling: {} + } +}); +// No handlers registered - will fail when server tries to use them +``` + +## Protocol Support + +The client supports MCP protocol versions: + +- **`2025-06-18`** (default) - Current version with full elicitation and structured output support +- **`2025-03-26`** - Backward compatible version + +The client automatically negotiates the protocol version during initialization and stores it per connection. + +For more details on protocol versioning, see the [MCP Specification](https://modelcontextprotocol.io/specification). + diff --git a/packages/core/src/client/adapters/index.ts b/packages/core/src/client/adapters/index.ts new file mode 100644 index 0000000..beabae8 --- /dev/null +++ b/packages/core/src/client/adapters/index.ts @@ -0,0 +1 @@ +export type { ToolAdapter } from "./tool-adapter.js"; diff --git a/packages/core/src/client/adapters/tool-adapter.ts b/packages/core/src/client/adapters/tool-adapter.ts new file mode 100644 index 0000000..a4864c1 --- /dev/null +++ b/packages/core/src/client/adapters/tool-adapter.ts @@ -0,0 +1,46 @@ +import type { Tool, ToolCallResult } from "../../types.js"; + +/** + * Adapter interface for converting MCP tools to SDK-specific tool formats. + * + * This allows MCP clients to easily integrate with various AI SDKs (Vercel AI SDK, + * Anthropic SDK, etc.) by providing adapters that convert MCP tool formats. + * + * @template TSDKTool - The SDK-specific tool type + * + * @example Vercel AI SDK adapter + * ```typescript + * const vercelAdapter: ToolAdapter = { + * toSDK: (mcpTool) => ({ + * type: "function", + * function: { + * name: mcpTool.name, + * description: mcpTool.description, + * parameters: mcpTool.inputSchema + * } + * }), + * resultToSDK: (mcpResult) => ({ + * toolCallId: "...", + * toolName: "...", + * result: mcpResult.content[0].text + * }) + * }; + * ``` + */ +export interface ToolAdapter { + /** + * Convert an MCP tool definition to SDK-specific format + * + * @param mcpTool - MCP tool metadata + * @returns SDK-specific tool definition + */ + toSDK(mcpTool: Tool): TSDKTool; + + /** + * Convert an MCP tool call result to SDK-specific format + * + * @param mcpResult - MCP tool call result + * @returns SDK-specific result format + */ + resultToSDK(mcpResult: ToolCallResult): unknown; +} diff --git a/packages/core/src/client/client.ts b/packages/core/src/client/client.ts new file mode 100644 index 0000000..059cc75 --- /dev/null +++ b/packages/core/src/client/client.ts @@ -0,0 +1,216 @@ +import { METHODS } from "../constants.js"; +import type { Logger } from "../core.js"; +import { RpcError } from "../errors.js"; +import { + createJsonRpcError, + createJsonRpcResponse, + JSON_RPC_ERROR_CODES, + type JsonRpcReq, + type JsonRpcRes, +} from "../types.js"; +import type { + ClientConnectionInfo, + ElicitationParams, + ElicitHandler, + SampleHandler, + SamplingParams, +} from "./types.js"; + +/** + * Client capabilities to advertise to the server + */ +export interface ClientCapabilities { + elicitation?: Record; + roots?: Record; + sampling?: Record; + [key: string]: unknown; +} + +/** + * Options for creating an MCP client + */ +export interface McpClientOptions { + /** Client name (included in client info during initialize) */ + name: string; + /** Client version (included in client info during initialize) */ + version: string; + /** Optional capabilities to advertise to server */ + capabilities?: ClientCapabilities; + /** Optional logger for client messages */ + logger?: Logger; +} + +/** + * MCP Client implementation. + * + * Provides a framework for building MCP-compliant clients that can connect to + * MCP servers and call tools, prompts, and resources. + * + * @example Basic client setup + * ```typescript + * import { McpClient, StreamableHttpClientTransport } from "mcp-lite"; + * + * // Create client instance + * const client = new McpClient({ + * name: "my-client", + * version: "1.0.0" + * }); + * + * // Create HTTP transport and connect + * const transport = new StreamableHttpClientTransport(); + * const connect = transport.bind(client); + * const connection = await connect("http://localhost:3000"); + * + * // Call a tool + * const result = await connection.callTool("echo", { message: "Hello!" }); + * ``` + */ +export class McpClient { + public readonly clientInfo: { name: string; version: string }; + public readonly capabilities?: ClientCapabilities; + private logger: Logger; + + // Handlers for server-initiated requests + private sampleHandler?: SampleHandler; + private elicitHandler?: ElicitHandler; + + // Connection info set by transport after initialize + private connectionInfo?: ClientConnectionInfo; + + /** + * Create a new MCP client instance. + * + * @param options - Client configuration options + */ + constructor(options: McpClientOptions) { + this.clientInfo = { + name: options.name, + version: options.version, + }; + this.capabilities = options.capabilities; + this.logger = options.logger || console; + } + + /** + * Register handler for server sampling requests. + * + * When the server needs the client to call an LLM, it will send a sampling + * request that will be handled by this function. + * + * @param handler - Sampling handler function + * @returns This client instance for chaining + * + * @example + * ```typescript + * client.onSample(async (params, connection) => { + * const response = await callLLM(params.messages, params.modelPreferences); + * return { + * role: "assistant", + * content: { type: "text", text: response }, + * model: "gpt-4", + * stopReason: "endTurn" + * }; + * }); + * ``` + */ + onSample(handler: SampleHandler): this { + this.sampleHandler = handler; + return this; + } + + /** + * Register handler for server elicitation requests. + * + * When the server needs the client to prompt the user for structured data, + * it will send an elicitation request that will be handled by this function. + * + * @param handler - Elicitation handler function + * @returns This client instance for chaining + * + * @example + * ```typescript + * client.onElicit(async (params, connection) => { + * const userInput = await promptUser(params.message, params.requestedSchema); + * return { + * action: "accept", + * content: userInput + * }; + * }); + * ``` + */ + onElicit(handler: ElicitHandler): this { + this.elicitHandler = handler; + return this; + } + + /** + * Set connection info after successful initialization. + * @internal + */ + _setConnectionInfo(info: ClientConnectionInfo): void { + this.connectionInfo = info; + } + + /** + * Internal dispatcher for server-initiated requests. + * Called by transport when SSE stream receives a request. + * + * @internal + */ + async _dispatch(message: JsonRpcReq): Promise { + const requestId = message.id; + + try { + let result: unknown; + + switch (message.method) { + case METHODS.ELICITATION.CREATE: + if (!this.elicitHandler) { + throw new RpcError( + JSON_RPC_ERROR_CODES.METHOD_NOT_FOUND, + "No elicitation handler registered", + ); + } + result = await this.elicitHandler( + message.params as ElicitationParams, + this.connectionInfo, + ); + break; + + case METHODS.SAMPLING.CREATE: + if (!this.sampleHandler) { + throw new RpcError( + JSON_RPC_ERROR_CODES.METHOD_NOT_FOUND, + "No sampling handler registered", + ); + } + result = await this.sampleHandler( + message.params as SamplingParams, + this.connectionInfo, + ); + break; + + default: + throw new RpcError( + JSON_RPC_ERROR_CODES.METHOD_NOT_FOUND, + `Unknown method: ${message.method}`, + ); + } + + return createJsonRpcResponse(requestId, result); + } catch (error) { + // Default error handling + if (error instanceof RpcError) { + return createJsonRpcError(requestId, error.toJson()); + } + + return createJsonRpcError( + requestId, + new RpcError( + JSON_RPC_ERROR_CODES.INTERNAL_ERROR, + error instanceof Error ? error.message : "Unknown error", + ).toJson(), + ); + } + } +} diff --git a/packages/core/src/client/connection.ts b/packages/core/src/client/connection.ts new file mode 100644 index 0000000..d633a37 --- /dev/null +++ b/packages/core/src/client/connection.ts @@ -0,0 +1,383 @@ +import { + JSON_RPC_VERSION, + MCP_PROTOCOL_HEADER, + MCP_SESSION_ID_HEADER, + METHODS, + SSE_ACCEPT_HEADER, + SUPPORTED_MCP_PROTOCOL_VERSIONS, +} from "../constants.js"; +import type { Logger } from "../core.js"; +import { RpcError } from "../errors.js"; +import { + createJsonRpcError, + type InitializeResult, + isJsonRpcRequest, + JSON_RPC_ERROR_CODES, + type JsonRpcReq, + type JsonRpcRes, + type ListPromptsResult, + type ListResourcesResult, + type ListResourceTemplatesResult, + type ListToolsResult, + type PromptGetResult, + type ResourceReadResult, + type ToolCallResult, +} from "../types.js"; +import type { McpClient } from "./client.js"; + +/** + * Options for creating a Connection + */ +export interface ConnectionOptions { + baseUrl: string; + serverInfo: { name: string; version: string }; + serverCapabilities: InitializeResult["capabilities"]; + sessionId?: string; + responseSender?: (response: JsonRpcRes) => Promise; + logger?: Logger; + headers?: Record; +} + +/** + * Connection to an MCP server. + * + * Provides methods to interact with the server's tools, prompts, and resources. + */ +export class Connection { + private baseUrl: string; + public readonly sessionId?: string; + public readonly serverInfo: { name: string; version: string }; + public readonly serverCapabilities: InitializeResult["capabilities"]; + + // SSE stream management + private sessionStreamAbortController?: AbortController; + private responseSender?: (response: JsonRpcRes) => Promise; + private client?: McpClient; + private logger?: Logger; + private customHeaders?: Record; + + constructor(options: ConnectionOptions) { + this.baseUrl = options.baseUrl; + this.sessionId = options.sessionId; + this.serverInfo = options.serverInfo; + this.serverCapabilities = options.serverCapabilities; + this.responseSender = options.responseSender; + this.logger = options.logger; + this.customHeaders = options.headers; + } + + /** + * Set the client instance for handling server requests + * Called by transport after creating connection + * @internal + */ + _setClient(client: McpClient): void { + this.client = client; + } + + /** + * Call a tool on the server + * + * @param name - Tool name + * @param args - Tool arguments + * @returns Tool call result + */ + async callTool(name: string, args?: unknown): Promise { + const response = await this._request(METHODS.TOOLS.CALL, { + name, + arguments: args, + }); + return response as ToolCallResult; + } + + /** + * List all available tools from the server + * + * @returns List of tools + */ + async listTools(): Promise { + const response = await this._request(METHODS.TOOLS.LIST); + return response as ListToolsResult; + } + + /** + * List all available prompts from the server + * + * @returns List of prompts + */ + async listPrompts(): Promise { + const response = await this._request(METHODS.PROMPTS.LIST); + return response as ListPromptsResult; + } + + /** + * Get a prompt from the server + * + * @param name - Prompt name + * @param args - Prompt arguments + * @returns Prompt result + */ + async getPrompt(name: string, args?: unknown): Promise { + const response = await this._request(METHODS.PROMPTS.GET, { + name, + arguments: args, + }); + return response as PromptGetResult; + } + + /** + * List all available resources from the server + * + * @returns List of resources + */ + async listResources(): Promise { + const response = await this._request(METHODS.RESOURCES.LIST); + return response as ListResourcesResult; + } + + /** + * List all available resource templates from the server + * + * @returns List of resource templates + */ + async listResourceTemplates(): Promise { + const response = await this._request(METHODS.RESOURCES.TEMPLATES_LIST); + return response as ListResourceTemplatesResult; + } + + /** + * Read a resource from the server + * + * @param uri - Resource URI + * @returns Resource contents + */ + async readResource(uri: string): Promise { + const response = await this._request(METHODS.RESOURCES.READ, { uri }); + return response as ResourceReadResult; + } + + /** + * Send a ping to the server + * + * @returns Empty response + */ + async ping(): Promise> { + const response = await this._request(METHODS.PING); + return response as Record; + } + + /** + * Internal method to send requests to the server + * + * @private + */ + private async _request(method: string, params?: unknown): Promise { + const headers: Record = { + "Content-Type": "application/json", + "MCP-Protocol-Version": SUPPORTED_MCP_PROTOCOL_VERSIONS.V2025_06_18, + }; + + if (this.sessionId) { + headers[MCP_SESSION_ID_HEADER] = this.sessionId; + } + + // Merge custom headers (these override defaults if there are conflicts) + if (this.customHeaders) { + Object.assign(headers, this.customHeaders); + } + + const requestBody: JsonRpcReq = { + jsonrpc: JSON_RPC_VERSION, + id: Math.random().toString(36).substring(7), + method, + params, + }; + + const response = await fetch(this.baseUrl, { + method: "POST", + headers, + body: JSON.stringify(requestBody), + }); + + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`); + } + + const result = (await response.json()) as JsonRpcRes; + + if (result.error) { + throw new RpcError( + result.error.code, + result.error.message, + result.error.data, + ); + } + + return result.result; + } + + /** + * Open a GET SSE stream to receive server notifications. + * Only available when using session-based transport. + * + * The stream will automatically be processed in the background to handle + * server-initiated requests (like elicitation and sampling). + * + * @param lastEventId - Optional Last-Event-ID for replay from a specific event + * @throws Error if connection does not have a session ID + */ + async openSessionStream(lastEventId?: string): Promise { + if (!this.sessionId) { + throw new Error("Cannot open session stream without session ID"); + } + + // Close any existing stream + this.closeSessionStream(); + + const headers: Record = { + Accept: SSE_ACCEPT_HEADER, + [MCP_PROTOCOL_HEADER]: SUPPORTED_MCP_PROTOCOL_VERSIONS.V2025_06_18, + [MCP_SESSION_ID_HEADER]: this.sessionId, + }; + + if (lastEventId) { + headers["Last-Event-ID"] = lastEventId; + } + + // Merge custom headers (these override defaults if there are conflicts) + if (this.customHeaders) { + Object.assign(headers, this.customHeaders); + } + + this.sessionStreamAbortController = new AbortController(); + + const response = await fetch(this.baseUrl, { + method: "GET", + headers, + signal: this.sessionStreamAbortController.signal, + }); + + if (!response.ok) { + throw new Error( + `Failed to open session stream: ${response.status} ${response.statusText}`, + ); + } + + if (!response.body) { + throw new Error("No response body for SSE stream"); + } + + // Process the stream in the background to handle server requests + this.processSessionStream(response.body); + } + + /** + * Process incoming SSE events from the stream + * @private + */ + private async processSessionStream( + stream: ReadableStream, + ): Promise { + const reader = stream.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() || ""; + + for (const line of lines) { + if (line.startsWith("data: ")) { + try { + const data = JSON.parse(line.slice(6)); + + // Check if this is a JSON-RPC request (has method and id) + if (isJsonRpcRequest(data)) { + await this.handleServerRequest(data); + } + } catch (error) { + this.logger?.error?.("Failed to parse SSE data:", error); + } + } + } + } + } catch (error) { + if ((error as Error).name !== "AbortError") { + this.logger?.error?.("SSE stream error:", error); + } + } finally { + reader.releaseLock(); + } + } + + /** + * Handle a server-initiated JSON-RPC request + * @private + */ + private async handleServerRequest(request: JsonRpcReq): Promise { + if (!this.client) { + this.logger?.error?.("Cannot handle server request: no client instance"); + return; + } + + if (!this.responseSender) { + this.logger?.error?.("Cannot handle server request: no response sender"); + return; + } + + try { + // Dispatch to client handlers + const response = await this.client._dispatch(request); + + // Send response back to server + await this.responseSender(response); + } catch (error) { + this.logger?.error?.("Error handling server request:", error); + + // Send error response + const errorResponse = createJsonRpcError( + request.id, + new RpcError( + JSON_RPC_ERROR_CODES.INTERNAL_ERROR, + error instanceof Error ? error.message : "Unknown error", + ).toJson(), + ); + + await this.responseSender(errorResponse); + } + } + + /** + * Close the session stream + */ + closeSessionStream(): void { + if (this.sessionStreamAbortController) { + this.sessionStreamAbortController.abort(); + this.sessionStreamAbortController = undefined; + } + } + + /** + * Close the connection and optionally delete the session + * + * @param deleteSession - If true, sends a DELETE request to remove the session from the server + */ + async close(deleteSession = false): Promise { + this.closeSessionStream(); + + if (deleteSession && this.sessionId) { + // Send DELETE request to close session + await fetch(this.baseUrl, { + method: "DELETE", + headers: { + [MCP_SESSION_ID_HEADER]: this.sessionId, + }, + }); + } + } +} diff --git a/packages/core/src/client/index.ts b/packages/core/src/client/index.ts new file mode 100644 index 0000000..96a4003 --- /dev/null +++ b/packages/core/src/client/index.ts @@ -0,0 +1,50 @@ +export type { ToolAdapter } from "./adapters/index.js"; +export { + type ClientCapabilities, + McpClient, + type McpClientOptions, +} from "./client.js"; +export { Connection, type ConnectionOptions } from "./connection.js"; +export { + InMemoryOAuthAdapter, + type OAuthAdapter, + type OAuthTokens, + type StoredClientCredentials, +} from "./oauth-adapter.js"; +export { + type ClientCredentials, + type ClientMetadata, + registerOAuthClient, +} from "./oauth-dcr.js"; +export { + discoverOAuthEndpoints, + type OAuthEndpoints, +} from "./oauth-discovery.js"; +export { + type AuthorizationFlowResult, + type ExchangeCodeParams, + type OAuthProvider, + type RefreshTokenParams, + StandardOAuthProvider, + type StartAuthorizationFlowParams, +} from "./oauth-provider.js"; +export { + type ClientSessionAdapter, + type ClientSessionData, + InMemoryClientSessionAdapter, +} from "./session-adapter.js"; +export { + type ConnectOptions, + type OAuthConfig, + StreamableHttpClientTransport, + type StreamableHttpClientTransportOptions, +} from "./transport-http.js"; +export type { + ClientConnectionInfo, + ElicitationParams, + ElicitationResult, + ElicitHandler, + SampleHandler, + SamplingParams, + SamplingResult, +} from "./types.js"; diff --git a/packages/core/src/client/oauth-adapter.ts b/packages/core/src/client/oauth-adapter.ts new file mode 100644 index 0000000..b7f8457 --- /dev/null +++ b/packages/core/src/client/oauth-adapter.ts @@ -0,0 +1,156 @@ +/** + * OAuth tokens stored for a specific resource server + */ +export interface OAuthTokens { + /** Access token for API authorization */ + accessToken: string; + /** Optional refresh token for obtaining new access tokens */ + refreshToken?: string; + /** Unix timestamp in seconds when the access token expires */ + expiresAt: number; + /** Scopes granted for this token */ + scopes: string[]; + /** Token type, always "Bearer" for OAuth 2.1 */ + tokenType: "Bearer"; +} + +/** + * OAuth client credentials from Dynamic Client Registration (RFC 7591) + */ +export interface StoredClientCredentials { + /** OAuth client identifier */ + clientId: string; + /** Client secret (optional for public clients) */ + clientSecret?: string; + /** Registration access token for updating client metadata */ + registrationAccessToken?: string; + /** Registration client URI for updating/deleting client */ + registrationClientUri?: string; +} + +/** + * Adapter interface for OAuth token persistence + * + * Implementations can store tokens in memory, localStorage, secure storage, database, etc. + * Each resource server (MCP server) has its own set of tokens identified by the resource URL. + */ +export interface OAuthAdapter { + /** + * Store OAuth tokens for a specific resource server + * + * @param resource - Resource server URL (MCP server base URL) + * @param tokens - OAuth tokens to store + */ + storeTokens(resource: string, tokens: OAuthTokens): Promise | void; + + /** + * Retrieve OAuth tokens for a specific resource server + * + * @param resource - Resource server URL (MCP server base URL) + * @returns OAuth tokens if found, undefined otherwise + */ + getTokens( + resource: string, + ): Promise | OAuthTokens | undefined; + + /** + * Delete OAuth tokens for a specific resource server + * + * @param resource - Resource server URL (MCP server base URL) + */ + deleteTokens(resource: string): Promise | void; + + /** + * Check if a valid (non-expired) token exists for a resource server + * + * @param resource - Resource server URL (MCP server base URL) + * @returns True if a valid token exists, false otherwise + */ + hasValidToken(resource: string): Promise | boolean; + + /** + * Store OAuth client credentials for a specific authorization server + * + * @param authorizationServer - Authorization server URL + * @param credentials - Client credentials to store + */ + storeClientCredentials( + authorizationServer: string, + credentials: StoredClientCredentials, + ): Promise | void; + + /** + * Retrieve OAuth client credentials for a specific authorization server + * + * @param authorizationServer - Authorization server URL + * @returns Client credentials if found, undefined otherwise + */ + getClientCredentials( + authorizationServer: string, + ): + | Promise + | StoredClientCredentials + | undefined; + + /** + * Delete OAuth client credentials for a specific authorization server + * + * @param authorizationServer - Authorization server URL + */ + deleteClientCredentials(authorizationServer: string): Promise | void; +} + +/** + * In-memory OAuth token adapter + * + * Stores tokens and client credentials in memory. Data is lost when the process exits. + * Suitable for testing and short-lived clients. + * + * For production use, implement a persistent adapter that stores tokens + * in secure storage (e.g., encrypted file, secure database, keychain). + */ +export class InMemoryOAuthAdapter implements OAuthAdapter { + private tokens = new Map(); + private clientCredentials = new Map(); + + storeTokens(resource: string, tokens: OAuthTokens): void { + this.tokens.set(resource, tokens); + } + + getTokens(resource: string): OAuthTokens | undefined { + return this.tokens.get(resource); + } + + deleteTokens(resource: string): void { + this.tokens.delete(resource); + } + + hasValidToken(resource: string): boolean { + const tokens = this.tokens.get(resource); + if (!tokens) { + return false; + } + + // Check if token is expired (with 5 minute buffer) + const now = Math.floor(Date.now() / 1000); + const BUFFER_SECONDS = 5 * 60; + return tokens.expiresAt > now + BUFFER_SECONDS; + } + + storeClientCredentials( + authorizationServer: string, + credentials: StoredClientCredentials, + ): void { + this.clientCredentials.set(authorizationServer, credentials); + } + + getClientCredentials( + authorizationServer: string, + ): StoredClientCredentials | undefined { + return this.clientCredentials.get(authorizationServer); + } + + deleteClientCredentials(authorizationServer: string): void { + this.clientCredentials.delete(authorizationServer); + } +} diff --git a/packages/core/src/client/oauth-dcr.ts b/packages/core/src/client/oauth-dcr.ts new file mode 100644 index 0000000..80fa414 --- /dev/null +++ b/packages/core/src/client/oauth-dcr.ts @@ -0,0 +1,103 @@ +/** + * Dynamic Client Registration for OAuth 2.0 (RFC 7591) + */ + +/** + * Client credentials returned from Dynamic Client Registration + */ +export interface ClientCredentials { + /** OAuth client identifier */ + clientId: string; + /** Client secret (optional for public clients using PKCE) */ + clientSecret?: string; + /** Registration access token for updating client metadata */ + registrationAccessToken?: string; + /** Registration client URI for updating/deleting client */ + registrationClientUri?: string; +} + +/** + * Client metadata for Dynamic Client Registration + */ +export interface ClientMetadata { + /** Human-readable client name */ + clientName: string; + /** Redirect URIs for OAuth callbacks */ + redirectUris: string[]; + /** Grant types supported */ + grantTypes?: string[]; + /** Token endpoint authentication method */ + tokenEndpointAuthMethod?: string; + /** Scopes to request */ + scope?: string; +} + +/** + * Register a new OAuth client dynamically per RFC 7591 + * + * @param registrationEndpoint - Client registration endpoint URL + * @param metadata - Client metadata to register + * @returns Client credentials including client_id + * @throws Error if registration fails + * + * @example + * ```typescript + * const credentials = await registerOAuthClient( + * "https://auth.example.com/register", + * { + * clientName: "MCP Client", + * redirectUris: ["http://localhost:3000/callback"], + * } + * ); + * console.log(credentials.clientId); + * ``` + */ +export async function registerOAuthClient( + registrationEndpoint: string, + metadata: ClientMetadata, +): Promise { + // Prepare registration request per RFC 7591 + const registrationRequest = { + client_name: metadata.clientName, + redirect_uris: metadata.redirectUris, + grant_types: metadata.grantTypes || ["authorization_code", "refresh_token"], + token_endpoint_auth_method: metadata.tokenEndpointAuthMethod || "none", + ...(metadata.scope && { scope: metadata.scope }), + }; + + const response = await fetch(registrationEndpoint, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json", + }, + body: JSON.stringify(registrationRequest), + }); + + if (!response.ok) { + const errorBody = await response.text(); + throw new Error( + `Dynamic Client Registration failed: HTTP ${response.status} - ${errorBody}`, + ); + } + + const registrationResponse = (await response.json()) as { + client_id: string; + client_secret?: string; + registration_access_token?: string; + registration_client_uri?: string; + }; + + if (!registrationResponse.client_id) { + throw new Error( + "Dynamic Client Registration response missing client_id field", + ); + } + + return { + clientId: registrationResponse.client_id, + clientSecret: registrationResponse.client_secret, + registrationAccessToken: registrationResponse.registration_access_token, + registrationClientUri: registrationResponse.registration_client_uri, + }; +} diff --git a/packages/core/src/client/oauth-discovery.ts b/packages/core/src/client/oauth-discovery.ts new file mode 100644 index 0000000..bfd5479 --- /dev/null +++ b/packages/core/src/client/oauth-discovery.ts @@ -0,0 +1,206 @@ +import { SUPPORTED_MCP_PROTOCOL_VERSIONS } from "../constants.js"; + +/** + * OAuth endpoint information discovered from an MCP server + */ +export interface OAuthEndpoints { + /** Authorization server URL */ + authorizationServer: string; + /** Authorization endpoint URL for starting OAuth flows */ + authorizationEndpoint: string; + /** Token endpoint URL for exchanging codes and refreshing tokens */ + tokenEndpoint: string; + /** Dynamic Client Registration endpoint (RFC 7591) */ + registrationEndpoint?: string; + /** Scopes required for accessing this resource server */ + scopes: string[]; +} + +/** + * Helper function to validate and return OAuth endpoints + */ +function validateAndReturnEndpoints( + authorizationServer: string, + authorizationEndpoint: string, + tokenEndpoint: string, + codeChallenges: string[], + scopes: string[], + registrationEndpoint?: string, +): OAuthEndpoints { + if (!authorizationEndpoint || !tokenEndpoint) { + throw new Error( + "Authorization server metadata missing required endpoints (authorization_endpoint, token_endpoint)", + ); + } + + if (!codeChallenges.includes("S256")) { + throw new Error( + "Authorization server does not support PKCE S256 method (required for OAuth 2.1)", + ); + } + + return { + authorizationServer, + authorizationEndpoint, + tokenEndpoint, + registrationEndpoint, + scopes, + }; +} + +/** + * Discover OAuth endpoints for an MCP server + * + * Follows RFC 8414 (OAuth 2.0 Authorization Server Metadata) and + * RFC 8707 (Resource Indicators) to discover OAuth configuration. + * + * Per RFC 8707 Section 3, .well-known endpoints MUST be at the origin, + * not sub-paths. For example, if the MCP endpoint is at + * https://example.com/mcp, discovery uses https://example.com/.well-known/oauth-protected-resource + * + * Steps: + * 1. Extract origin from baseUrl + * 2. Fetch /.well-known/oauth-protected-resource from origin + * 3. If origin discovery fails, try fetching the MCP endpoint to get WWW-Authenticate header with as_uri + * 4. Retrieve authorization server URL from resource metadata or as_uri + * 5. Fetch authorization server metadata + * 6. Extract and validate OAuth endpoints + * 7. Verify PKCE S256 support (mandatory for OAuth 2.1) + * + * @param baseUrl - MCP server base URL + * @returns OAuth endpoint information + * @throws Error if discovery fails or server doesn't support required features + * + * @example + * ```typescript + * const endpoints = await discoverOAuthEndpoints("https://api.example.com/mcp"); + * console.log(endpoints.authorizationEndpoint); + * console.log(endpoints.tokenEndpoint); + * ``` + */ +export async function discoverOAuthEndpoints( + baseUrl: string, +): Promise { + // Extract origin for RFC 8707 compliant discovery + const url = new URL(baseUrl); + const origin = url.origin; + + // Step 1: Fetch resource server metadata (RFC 8707) + // Per RFC 8707, .well-known endpoints MUST be at the origin, not sub-paths + const resourceMetadataUrl = `${origin}/.well-known/oauth-protected-resource`; + const resourceResponse = await fetch(resourceMetadataUrl); + + // Step 1.5: If origin discovery fails, try the actual endpoint to get WWW-Authenticate + if (!resourceResponse.ok) { + // Make a request to the actual MCP endpoint to get auth headers + const endpointResponse = await fetch(baseUrl, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + jsonrpc: "2.0", + id: "discovery", + method: "initialize", + params: { + protocolVersion: SUPPORTED_MCP_PROTOCOL_VERSIONS.V2025_06_18, + clientInfo: { name: "discovery", version: "1.0.0" }, + capabilities: {}, + }, + }), + }); + + // Check for WWW-Authenticate header with as_uri + const wwwAuth = endpointResponse.headers.get("www-authenticate"); + if (wwwAuth && endpointResponse.status === 401) { + const asUriMatch = wwwAuth.match(/as_uri="([^"]+)"/); + if (asUriMatch?.[1]) { + const authServerMetadataUrl = asUriMatch[1]; + // Skip resource metadata, go directly to authorization server + const serverResponse = await fetch(authServerMetadataUrl); + + if (!serverResponse.ok) { + throw new Error( + `Failed to fetch authorization server metadata from ${authServerMetadataUrl}: HTTP ${serverResponse.status}`, + ); + } + + const serverMetadata = (await serverResponse.json()) as { + authorization_endpoint?: string; + token_endpoint?: string; + code_challenge_methods_supported?: string[]; + registration_endpoint?: string; + issuer?: string; + }; + + // Extract issuer as authorization server + const authorizationServer = serverMetadata.issuer || origin; + + return validateAndReturnEndpoints( + authorizationServer, + serverMetadata.authorization_endpoint ?? "", + serverMetadata.token_endpoint ?? "", + serverMetadata.code_challenge_methods_supported || [], + [], // scopes from resource metadata not available + serverMetadata.registration_endpoint, + ); + } + } + + // If still no success, throw original error + throw new Error( + `Failed to fetch resource metadata from ${resourceMetadataUrl}: HTTP ${resourceResponse.status}`, + ); + } + + const resourceMetadata = (await resourceResponse.json()) as { + authorization_servers?: string[]; + authorization_server?: string; + scopes_supported?: string[]; + }; + + // Extract authorization server URL + const authorizationServer = + resourceMetadata.authorization_servers?.[0] || + resourceMetadata.authorization_server; + + if (!authorizationServer) { + throw new Error( + "Resource metadata missing authorization_server or authorization_servers field", + ); + } + + // Extract required scopes + const scopes: string[] = resourceMetadata.scopes_supported || []; + + // Step 2: Fetch authorization server metadata (RFC 8414) + const serverMetadataUrl = `${authorizationServer}/.well-known/oauth-authorization-server`; + const serverResponse = await fetch(serverMetadataUrl); + + if (!serverResponse.ok) { + throw new Error( + `Failed to fetch authorization server metadata from ${serverMetadataUrl}: HTTP ${serverResponse.status}`, + ); + } + + const serverMetadata = (await serverResponse.json()) as { + authorization_endpoint?: string; + token_endpoint?: string; + code_challenge_methods_supported?: string[]; + registration_endpoint?: string; + }; + + // Extract endpoints + const authorizationEndpoint = serverMetadata.authorization_endpoint ?? ""; + const tokenEndpoint = serverMetadata.token_endpoint ?? ""; + const registrationEndpoint = serverMetadata.registration_endpoint; + const supportedChallengeMethods: string[] = + serverMetadata.code_challenge_methods_supported || []; + + return validateAndReturnEndpoints( + authorizationServer, + authorizationEndpoint, + tokenEndpoint, + supportedChallengeMethods, + scopes, + registrationEndpoint, + ); +} diff --git a/packages/core/src/client/oauth-provider.ts b/packages/core/src/client/oauth-provider.ts new file mode 100644 index 0000000..62d9b89 --- /dev/null +++ b/packages/core/src/client/oauth-provider.ts @@ -0,0 +1,295 @@ +import type { OAuthTokens } from "./oauth-adapter.js"; + +/** + * Parameters for starting an OAuth authorization flow + */ +export interface StartAuthorizationFlowParams { + /** OAuth authorization endpoint URL */ + authorizationEndpoint: string; + /** OAuth client ID */ + clientId: string; + /** Redirect URI for the authorization callback */ + redirectUri: string; + /** Requested OAuth scopes */ + scopes: string[]; + /** Resource server URL (RFC 8707) */ + resource: string; + /** Optional state parameter for CSRF protection */ + state?: string; +} + +/** + * Result of starting an authorization flow + */ +export interface AuthorizationFlowResult { + /** Complete authorization URL to redirect user to */ + authorizationUrl: string; + /** PKCE code verifier to use in token exchange */ + codeVerifier: string; + /** State parameter for CSRF validation */ + state: string; +} + +/** + * Parameters for exchanging authorization code for tokens + */ +export interface ExchangeCodeParams { + /** OAuth token endpoint URL */ + tokenEndpoint: string; + /** Authorization code from callback */ + code: string; + /** PKCE code verifier from authorization flow */ + codeVerifier: string; + /** OAuth client ID */ + clientId: string; + /** Redirect URI used in authorization request */ + redirectUri: string; + /** Resource server URL (RFC 8707) */ + resource: string; +} + +/** + * Parameters for refreshing an access token + */ +export interface RefreshTokenParams { + /** OAuth token endpoint URL */ + tokenEndpoint: string; + /** Refresh token */ + refreshToken: string; + /** OAuth client ID */ + clientId: string; + /** Resource server URL (RFC 8707) */ + resource: string; +} + +/** + * OAuth provider interface for handling OAuth 2.1 flows + * + * Implementations handle PKCE generation, authorization URL construction, + * token exchange, and token refresh. + */ +export interface OAuthProvider { + /** + * Start an OAuth authorization flow with PKCE + * + * Generates a PKCE code verifier and challenge, constructs the authorization URL, + * and returns everything needed to complete the flow. + * + * @param params - Authorization flow parameters + * @returns Authorization URL and flow state (verifier, state) + */ + startAuthorizationFlow( + params: StartAuthorizationFlowParams, + ): Promise; + + /** + * Exchange an authorization code for OAuth tokens + * + * Sends a token request to the OAuth server with the authorization code + * and PKCE code verifier. + * + * @param params - Token exchange parameters + * @returns OAuth tokens + */ + exchangeCodeForTokens(params: ExchangeCodeParams): Promise; + + /** + * Refresh an expired access token + * + * Uses a refresh token to obtain a new access token without user interaction. + * + * @param params - Token refresh parameters + * @returns New OAuth tokens + */ + refreshAccessToken(params: RefreshTokenParams): Promise; +} + +/** + * Standard OAuth 2.1 provider implementation + * + * Implements OAuth 2.1 authorization code flow with PKCE (RFC 7636). + * Uses Web Crypto API for secure random generation and SHA-256 hashing. + * + * PKCE is mandatory for all OAuth 2.1 flows to prevent authorization code + * interception attacks. + */ +export class StandardOAuthProvider implements OAuthProvider { + /** + * Generate a cryptographically random code verifier for PKCE + * + * Generates a 43-character base64url-encoded random string as specified in RFC 7636. + * + * @returns Base64url-encoded code verifier + */ + private generateCodeVerifier(): string { + const randomBytes = new Uint8Array(32); + crypto.getRandomValues(randomBytes); + return this.base64UrlEncode(randomBytes); + } + + /** + * Generate a code challenge from a code verifier using S256 method + * + * Creates a SHA-256 hash of the code verifier and base64url-encodes it. + * + * @param verifier - Code verifier to hash + * @returns Base64url-encoded code challenge + */ + private async generateCodeChallenge(verifier: string): Promise { + const encoder = new TextEncoder(); + const data = encoder.encode(verifier); + const hash = await crypto.subtle.digest("SHA-256", data); + return this.base64UrlEncode(new Uint8Array(hash)); + } + + /** + * Generate a cryptographically random state parameter + * + * @returns Random state string for CSRF protection + */ + private generateState(): string { + const randomBytes = new Uint8Array(16); + crypto.getRandomValues(randomBytes); + return this.base64UrlEncode(randomBytes); + } + + /** + * Base64url encode a byte array (RFC 4648 Section 5) + * + * @param bytes - Bytes to encode + * @returns Base64url-encoded string + */ + private base64UrlEncode(bytes: Uint8Array): string { + // Convert bytes to base64 + const base64 = btoa(String.fromCharCode(...bytes)); + + // Convert base64 to base64url (replace + with -, / with _, remove =) + return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=/g, ""); + } + + async startAuthorizationFlow( + params: StartAuthorizationFlowParams, + ): Promise { + const codeVerifier = this.generateCodeVerifier(); + const codeChallenge = await this.generateCodeChallenge(codeVerifier); + const state = params.state || this.generateState(); + + // Build authorization URL with all required parameters + const url = new URL(params.authorizationEndpoint); + url.searchParams.set("response_type", "code"); + url.searchParams.set("client_id", params.clientId); + url.searchParams.set("redirect_uri", params.redirectUri); + url.searchParams.set("code_challenge", codeChallenge); + url.searchParams.set("code_challenge_method", "S256"); + url.searchParams.set("state", state); + url.searchParams.set("scope", params.scopes.join(" ")); + url.searchParams.set("resource", params.resource); + + return { + authorizationUrl: url.toString(), + codeVerifier, + state, + }; + } + + async exchangeCodeForTokens( + params: ExchangeCodeParams, + ): Promise { + const body = new URLSearchParams({ + grant_type: "authorization_code", + code: params.code, + redirect_uri: params.redirectUri, + client_id: params.clientId, + code_verifier: params.codeVerifier, + resource: params.resource, + }); + + const response = await fetch(params.tokenEndpoint, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: body.toString(), + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error( + `Token exchange failed: HTTP ${response.status} ${response.statusText} - ${errorText}`, + ); + } + + const data = (await response.json()) as { + access_token: string; + refresh_token?: string; + expires_in: number; + scope?: string; + token_type: string; + }; + + return this.parseTokenResponse(data); + } + + async refreshAccessToken(params: RefreshTokenParams): Promise { + const body = new URLSearchParams({ + grant_type: "refresh_token", + refresh_token: params.refreshToken, + client_id: params.clientId, + resource: params.resource, + }); + + const response = await fetch(params.tokenEndpoint, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: body.toString(), + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error( + `Token refresh failed: HTTP ${response.status} ${response.statusText} - ${errorText}`, + ); + } + + const data = (await response.json()) as { + access_token: string; + refresh_token?: string; + expires_in: number; + scope?: string; + token_type: string; + }; + + return this.parseTokenResponse(data); + } + + /** + * Parse OAuth token response into OAuthTokens structure + * + * @param data - Token response from OAuth server + * @returns Parsed OAuth tokens + */ + private parseTokenResponse(data: { + access_token: string; + refresh_token?: string; + expires_in: number; + scope?: string; + token_type: string; + }): OAuthTokens { + // Calculate expiry timestamp + const now = Math.floor(Date.now() / 1000); + const expiresAt = now + data.expires_in; + + // Parse scopes + const scopes = data.scope ? data.scope.split(" ") : []; + + return { + accessToken: data.access_token, + refreshToken: data.refresh_token, + expiresAt, + scopes, + tokenType: "Bearer", + }; + } +} diff --git a/packages/core/src/client/session-adapter.ts b/packages/core/src/client/session-adapter.ts new file mode 100644 index 0000000..84431b4 --- /dev/null +++ b/packages/core/src/client/session-adapter.ts @@ -0,0 +1,66 @@ +import type { InitializeResult } from "../types.js"; + +/** + * Client-side session data stored for each session + */ +export interface ClientSessionData { + sessionId: string; + protocolVersion: string; + serverInfo: { name: string; version: string }; + serverCapabilities: InitializeResult["capabilities"]; + createdAt: number; +} + +/** + * Adapter interface for client-side session persistence + * + * Implementations can store sessions in memory, localStorage, IndexedDB, etc. + */ +export interface ClientSessionAdapter { + /** + * Create and store a new session + * + * @param sessionId - Unique session identifier + * @param data - Session data to store + */ + create(sessionId: string, data: ClientSessionData): Promise | void; + + /** + * Retrieve session data by ID + * + * @param sessionId - Session identifier + * @returns Session data if found, undefined otherwise + */ + get( + sessionId: string, + ): Promise | ClientSessionData | undefined; + + /** + * Delete a session + * + * @param sessionId - Session identifier + */ + delete(sessionId: string): Promise | void; +} + +/** + * In-memory client session adapter + * + * Stores sessions in memory. Sessions are lost when the process exits. + * Suitable for testing and short-lived clients. + */ +export class InMemoryClientSessionAdapter implements ClientSessionAdapter { + private sessions = new Map(); + + create(sessionId: string, data: ClientSessionData): void { + this.sessions.set(sessionId, data); + } + + get(sessionId: string): ClientSessionData | undefined { + return this.sessions.get(sessionId); + } + + delete(sessionId: string): void { + this.sessions.delete(sessionId); + } +} diff --git a/packages/core/src/client/transport-http.ts b/packages/core/src/client/transport-http.ts new file mode 100644 index 0000000..f8f95db --- /dev/null +++ b/packages/core/src/client/transport-http.ts @@ -0,0 +1,492 @@ +import { + JSON_RPC_VERSION, + MCP_PROTOCOL_HEADER, + MCP_SESSION_ID_HEADER, + SUPPORTED_MCP_PROTOCOL_VERSIONS, +} from "../constants.js"; +import { RpcError } from "../errors.js"; +import type { InitializeResult, JsonRpcRes } from "../types.js"; +import type { McpClient } from "./client.js"; +import { Connection } from "./connection.js"; +import type { OAuthAdapter } from "./oauth-adapter.js"; +import { registerOAuthClient } from "./oauth-dcr.js"; +import { discoverOAuthEndpoints } from "./oauth-discovery.js"; +import type { OAuthProvider } from "./oauth-provider.js"; +import type { ClientSessionAdapter } from "./session-adapter.js"; + +/** + * OAuth configuration for authenticated MCP servers + */ +export interface OAuthConfig { + /** + * OAuth client ID (optional). + * If not provided, Dynamic Client Registration (RFC 7591) will be used + * to obtain a client ID automatically if the server supports it. + */ + clientId?: string; + /** OAuth redirect URI for authorization callback */ + redirectUri: string; + /** + * Client name for Dynamic Client Registration. + * Used when clientId is not provided and DCR is performed. + * Defaults to "MCP Client" if not specified. + */ + clientName?: string; + /** + * Callback invoked when user authorization is required. + * Implementation should redirect user to the authorization URL. + * After user authorizes, call completeAuthorizationFlow() with the code and state. + */ + onAuthorizationRequired: (authorizationUrl: string) => void; +} + +/** + * OAuth flow state stored during authorization + */ +interface PendingAuthState { + codeVerifier: string; + state: string; + tokenEndpoint: string; + clientId: string; +} + +/** + * Options for creating an HTTP client transport + */ +export interface StreamableHttpClientTransportOptions { + /** + * Optional session adapter for persisting session state. + * If provided, the transport will enable session-based mode. + */ + sessionAdapter?: ClientSessionAdapter; + + /** + * Optional OAuth adapter for token storage. + * Required if connecting to OAuth-protected MCP servers. + */ + oauthAdapter?: OAuthAdapter; + + /** + * Optional OAuth provider for handling OAuth flows. + * Required if connecting to OAuth-protected MCP servers. + */ + oauthProvider?: OAuthProvider; + + /** + * Optional OAuth configuration. + * Required if connecting to OAuth-protected MCP servers. + */ + oauthConfig?: OAuthConfig; +} + +/** + * Options for connecting to a server + */ +export interface ConnectOptions { + /** + * Optional custom headers to include in all requests to this server. + * These will be merged with protocol-required headers. + * Can be used for authentication tokens, API keys, etc. + * + * @example + * ```typescript + * { + * headers: { + * 'Authorization': 'Bearer my-token', + * 'X-API-Key': 'my-key' + * } + * } + * ``` + */ + headers?: Record; +} + +/** + * HTTP transport for MCP clients. + * + * Handles initialization and request/response communication with MCP servers + * over HTTP. + * + * @example + * ```typescript + * const client = new McpClient({ name: "my-client", version: "1.0.0" }); + * const transport = new StreamableHttpClientTransport(); + * const connect = transport.bind(client); + * + * // Connect without custom headers + * const connection = await connect("http://localhost:3000"); + * + * // Connect with custom headers (e.g., authentication) + * const connection = await connect("http://localhost:3000", { + * headers: { + * 'Authorization': 'Bearer my-token', + * 'X-API-Key': 'my-key' + * } + * }); + * ``` + */ +export class StreamableHttpClientTransport { + private client?: McpClient; + private sessionAdapter?: ClientSessionAdapter; + private oauthAdapter?: OAuthAdapter; + private oauthProvider?: OAuthProvider; + private oauthConfig?: OAuthConfig; + private pendingAuthFlows = new Map(); + + constructor(options?: StreamableHttpClientTransportOptions) { + this.sessionAdapter = options?.sessionAdapter; + this.oauthAdapter = options?.oauthAdapter; + this.oauthProvider = options?.oauthProvider; + this.oauthConfig = options?.oauthConfig; + + // Validate OAuth configuration consistency + if (this.oauthAdapter || this.oauthProvider || this.oauthConfig) { + if (!this.oauthAdapter || !this.oauthProvider || !this.oauthConfig) { + throw new Error( + "OAuth configuration incomplete: oauthAdapter, oauthProvider, and oauthConfig must all be provided together", + ); + } + + // Validate that redirectUri is always provided + if (!this.oauthConfig.redirectUri) { + throw new Error("OAuth configuration missing redirectUri"); + } + } + } + + /** + * Bind the transport to a client instance. + * + * @param client - The MCP client instance + * @returns A connect function that initializes connections to servers + */ + bind( + client: McpClient, + ): (baseUrl: string, options?: ConnectOptions) => Promise { + this.client = client; + + return async (baseUrl: string, options?: ConnectOptions) => { + if (!this.client) { + throw new Error("Transport not bound to a client"); + } + + const customHeaders = options?.headers; + + // Try to get existing valid token if OAuth is configured + const accessToken = await this.ensureValidToken(baseUrl); + + // Send initialize request + const initRequest = { + jsonrpc: JSON_RPC_VERSION, + id: "init", + method: "initialize", + params: { + protocolVersion: SUPPORTED_MCP_PROTOCOL_VERSIONS.V2025_06_18, + clientInfo: this.client.clientInfo, + capabilities: this.client.capabilities || {}, + }, + }; + + const headers: Record = { + "Content-Type": "application/json", + [MCP_PROTOCOL_HEADER]: SUPPORTED_MCP_PROTOCOL_VERSIONS.V2025_06_18, + }; + + // Add Authorization header if we have a token + if (accessToken) { + headers.Authorization = `Bearer ${accessToken}`; + } + + // Merge custom headers (these override defaults if there are conflicts) + if (customHeaders) { + Object.assign(headers, customHeaders); + } + + const response = await fetch(baseUrl, { + method: "POST", + headers, + body: JSON.stringify(initRequest), + }); + + // Handle 401 Unauthorized - start OAuth flow + if ( + response.status === 401 && + this.oauthAdapter && + this.oauthProvider && + this.oauthConfig + ) { + await this.handleAuthenticationRequired(baseUrl); + throw new Error( + "Authentication required. Authorization flow started. Please complete authorization and retry connection.", + ); + } + + if (!response.ok) { + throw new Error( + `Failed to initialize: HTTP ${response.status} ${response.statusText}`, + ); + } + + const result = (await response.json()) as JsonRpcRes; + + if (result.error) { + throw new RpcError( + result.error.code, + result.error.message, + result.error.data, + ); + } + + const initResult = result.result as InitializeResult; + + // Get session ID from header if session adapter is configured + const sessionId = this.sessionAdapter + ? response.headers.get(MCP_SESSION_ID_HEADER) || undefined + : undefined; + + // Store session data if we have an adapter and session ID + if (sessionId && this.sessionAdapter) { + await this.sessionAdapter.create(sessionId, { + sessionId, + protocolVersion: SUPPORTED_MCP_PROTOCOL_VERSIONS.V2025_06_18, + serverInfo: initResult.serverInfo, + serverCapabilities: initResult.capabilities, + createdAt: Date.now(), + }); + } + + // Set connection info on client after successful initialization + this.client._setConnectionInfo({ + serverInfo: initResult.serverInfo, + protocolVersion: SUPPORTED_MCP_PROTOCOL_VERSIONS.V2025_06_18, + }); + + // Create connection with server info and capabilities + const connection = new Connection({ + baseUrl, + serverInfo: initResult.serverInfo, + serverCapabilities: initResult.capabilities, + sessionId, + responseSender: sessionId + ? this.createResponseSender(baseUrl, sessionId, customHeaders) + : undefined, + headers: customHeaders, + }); + + // Set client instance for handling server requests + connection._setClient(this.client); + + return connection; + }; + } + + /** + * Complete an OAuth authorization flow after user has authorized + * + * Call this method after the user is redirected back from the OAuth server + * with an authorization code and state parameter. + * + * @param baseUrl - MCP server base URL (must match the one used in initialization) + * @param code - Authorization code from OAuth callback + * @param state - State parameter from OAuth callback + * + * @example + * ```typescript + * // After user is redirected to your redirect_uri with ?code=...&state=... + * await transport.completeAuthorizationFlow( + * "https://api.example.com", + * code, + * state + * ); + * ``` + */ + async completeAuthorizationFlow( + baseUrl: string, + code: string, + state: string, + ): Promise { + if (!this.oauthAdapter || !this.oauthProvider || !this.oauthConfig) { + throw new Error("OAuth not configured for this transport"); + } + + // Retrieve and validate pending auth state + const pendingAuth = this.pendingAuthFlows.get(baseUrl); + if (!pendingAuth) { + throw new Error( + "No pending authorization flow found for this server. Authorization may have expired or already been completed.", + ); + } + + // Validate state parameter (CSRF protection) + if (pendingAuth.state !== state) { + this.pendingAuthFlows.delete(baseUrl); + throw new Error("State parameter mismatch. Possible CSRF attack."); + } + + // Exchange authorization code for tokens + const tokens = await this.oauthProvider.exchangeCodeForTokens({ + tokenEndpoint: pendingAuth.tokenEndpoint, + code, + codeVerifier: pendingAuth.codeVerifier, + clientId: pendingAuth.clientId, + redirectUri: this.oauthConfig.redirectUri, + resource: baseUrl, + }); + + // Store tokens + await this.oauthAdapter.storeTokens(baseUrl, tokens); + + // Clean up pending auth state + this.pendingAuthFlows.delete(baseUrl); + } + + /** + * Ensure a valid access token exists for the given resource server. + * Automatically refreshes the token if it's expired. + * + * @param resource - Resource server URL (MCP server base URL) + * @returns Valid access token, or undefined if no token exists or OAuth not configured + * @private + */ + private async ensureValidToken( + resource: string, + ): Promise { + if (!this.oauthAdapter || !this.oauthProvider || !this.oauthConfig) { + return undefined; + } + + // Check if we have a valid token + const hasValid = await this.oauthAdapter.hasValidToken(resource); + if (hasValid) { + const tokens = await this.oauthAdapter.getTokens(resource); + return tokens?.accessToken; + } + + // Try to refresh if we have a refresh token + const tokens = await this.oauthAdapter.getTokens(resource); + if (tokens?.refreshToken) { + // Token exists but expired - try to refresh + const pendingAuth = this.pendingAuthFlows.get(resource); + if (!pendingAuth) { + // No token endpoint available - can't refresh + return undefined; + } + + const newTokens = await this.oauthProvider.refreshAccessToken({ + tokenEndpoint: pendingAuth.tokenEndpoint, + refreshToken: tokens.refreshToken, + clientId: pendingAuth.clientId, + resource, + }); + + await this.oauthAdapter.storeTokens(resource, newTokens); + return newTokens.accessToken; + } + + return undefined; + } + + /** + * Handle authentication required (401) response by starting OAuth flow + * + * @param baseUrl - MCP server base URL + * @private + */ + private async handleAuthenticationRequired(baseUrl: string): Promise { + if (!this.oauthAdapter || !this.oauthProvider || !this.oauthConfig) { + throw new Error("OAuth not configured for this transport"); + } + + // Discover OAuth endpoints + const endpoints = await discoverOAuthEndpoints(baseUrl); + + // Determine client ID to use (from config or DCR) + let clientId = this.oauthConfig.clientId; + + // If no client ID provided, try Dynamic Client Registration + if (!clientId) { + // Check if we already have registered client credentials for this authorization server + const storedCredentials = await this.oauthAdapter.getClientCredentials( + endpoints.authorizationServer, + ); + + if (storedCredentials) { + clientId = storedCredentials.clientId; + } else if (endpoints.registrationEndpoint) { + // Perform Dynamic Client Registration + const credentials = await registerOAuthClient( + endpoints.registrationEndpoint, + { + clientName: this.oauthConfig.clientName || "MCP Client", + redirectUris: [this.oauthConfig.redirectUri], + grantTypes: ["authorization_code", "refresh_token"], + tokenEndpointAuthMethod: "none", + scope: endpoints.scopes.join(" "), + }, + ); + + // Store registered client credentials + await this.oauthAdapter.storeClientCredentials( + endpoints.authorizationServer, + credentials, + ); + + clientId = credentials.clientId; + } else { + throw new Error( + "No client ID provided and Dynamic Client Registration not available. " + + "Either provide a clientId in OAuthConfig or ensure the server supports DCR.", + ); + } + } + + // Start authorization flow + const flowResult = await this.oauthProvider.startAuthorizationFlow({ + authorizationEndpoint: endpoints.authorizationEndpoint, + clientId, + redirectUri: this.oauthConfig.redirectUri, + scopes: endpoints.scopes, + resource: baseUrl, + }); + + // Store pending auth state for later validation + this.pendingAuthFlows.set(baseUrl, { + codeVerifier: flowResult.codeVerifier, + state: flowResult.state, + tokenEndpoint: endpoints.tokenEndpoint, + clientId, + }); + + // Notify application to redirect user + this.oauthConfig.onAuthorizationRequired(flowResult.authorizationUrl); + } + + /** + * Create a function that sends JSON-RPC responses back to the server + * @private + */ + private createResponseSender( + baseUrl: string, + sessionId: string, + customHeaders?: Record, + ): (response: JsonRpcRes) => Promise { + return async (response: JsonRpcRes) => { + const headers: Record = { + "Content-Type": "application/json", + [MCP_PROTOCOL_HEADER]: SUPPORTED_MCP_PROTOCOL_VERSIONS.V2025_06_18, + [MCP_SESSION_ID_HEADER]: sessionId, + }; + + // Merge custom headers (these override defaults if there are conflicts) + if (customHeaders) { + Object.assign(headers, customHeaders); + } + + await fetch(baseUrl, { + method: "POST", + headers, + body: JSON.stringify(response), + }); + }; + } +} diff --git a/packages/core/src/client/types.ts b/packages/core/src/client/types.ts new file mode 100644 index 0000000..4925db5 --- /dev/null +++ b/packages/core/src/client/types.ts @@ -0,0 +1,73 @@ +import type { JsonRpcReq, JsonRpcRes } from "../types.js"; + +/** + * Connection information provided to handlers for context about the connected server + */ +export interface ClientConnectionInfo { + serverInfo: { name: string; version: string }; + protocolVersion: string; +} + +/** + * Sampling request parameters sent from server to client + */ +export interface SamplingParams { + messages: unknown[]; + modelPreferences?: { + hints?: Array<{ name?: string }>; + costPriority?: number; + speedPriority?: number; + intelligencePriority?: number; + }; + systemPrompt?: string; + includeContext?: "none" | "thisServer" | "allServers"; + maxTokens: number; + temperature?: number; + stopSequences?: string[]; + metadata?: Record; +} + +/** + * Sampling result returned from client to server + */ +export interface SamplingResult { + role: "assistant"; + content: { + type: "text"; + text: string; + }; + model: string; + stopReason?: "endTurn" | "stopSequence" | "maxTokens" | string; +} + +/** + * Handler function for sampling requests from server + */ +export type SampleHandler = ( + params: SamplingParams, + connection?: ClientConnectionInfo, +) => Promise | SamplingResult; + +/** + * Elicitation request parameters sent from server to client + */ +export interface ElicitationParams { + message: string; + requestedSchema: unknown; // JSON Schema +} + +/** + * Elicitation result returned from client to server + */ +export interface ElicitationResult { + action: "accept" | "decline" | "cancel"; + content?: Record; // Present on "accept" +} + +/** + * Handler function for elicitation requests from server + */ +export type ElicitHandler = ( + params: ElicitationParams, + connection?: ClientConnectionInfo, +) => Promise | ElicitationResult; diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 3057e98..834e48f 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -1,3 +1,41 @@ +export type { ToolAdapter } from "./client/adapters/index.js"; +// Client exports +export { + type AuthorizationFlowResult, + type ClientCapabilities, + type ClientConnectionInfo, + type ClientCredentials, + type ClientMetadata, + type ClientSessionAdapter, + type ClientSessionData, + Connection, + type ConnectionOptions, + type ConnectOptions, + discoverOAuthEndpoints, + type ElicitationParams, + type ElicitationResult, + type ElicitHandler, + type ExchangeCodeParams, + InMemoryClientSessionAdapter, + InMemoryOAuthAdapter, + McpClient, + type McpClientOptions, + type OAuthAdapter, + type OAuthConfig, + type OAuthEndpoints, + type OAuthProvider, + type OAuthTokens, + type RefreshTokenParams, + registerOAuthClient, + type SampleHandler, + type SamplingParams, + type SamplingResult, + StandardOAuthProvider, + type StartAuthorizationFlowParams, + type StoredClientCredentials, + StreamableHttpClientTransport, + type StreamableHttpClientTransportOptions, +} from "./client/index.js"; export type { ClientRequestAdapter } from "./client-request-adapter.js"; export { InMemoryClientRequestAdapter } from "./client-request-adapter.js"; export { diff --git a/packages/core/tests/integration/client-custom-headers.test.ts b/packages/core/tests/integration/client-custom-headers.test.ts new file mode 100644 index 0000000..6784f58 --- /dev/null +++ b/packages/core/tests/integration/client-custom-headers.test.ts @@ -0,0 +1,210 @@ +/** biome-ignore-all lint/style/noNonNullAssertion: tests */ +import { afterEach, beforeEach, describe, expect, it } from "bun:test"; +import { createTestHarness, type TestServer } from "@internal/test-utils"; +import { + InMemoryClientSessionAdapter, + InMemorySessionAdapter, + McpClient, + McpServer, + StreamableHttpClientTransport, +} from "../../src/index.js"; + +describe("MCP Client - Custom Headers", () => { + let testServer: TestServer; + let mcpServer: McpServer; + let serverUrl: string; + + beforeEach(async () => { + mcpServer = new McpServer({ + name: "test-server", + version: "1.0.0", + }); + + mcpServer.tool("echo", { + description: "Echoes input", + handler: (args: { message: string }) => ({ + content: [{ type: "text", text: args.message }], + }), + }); + + testServer = await createTestHarness(mcpServer, { + sessionAdapter: new InMemorySessionAdapter({ maxEventBufferSize: 1024 }), + }); + serverUrl = testServer.url; + }); + + afterEach(async () => { + await testServer.stop(); + }); + + it("should successfully connect with custom Authorization header", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + }); + + const connect = transport.bind(client); + const connection = await connect(serverUrl, { + headers: { + Authorization: "Bearer test-token-123", + "X-API-Key": "my-api-key", + }, + }); + + // Should connect successfully even with custom headers + expect(connection.serverInfo.name).toBe("test-server"); + + // Should be able to make requests with headers included + const result = await connection.callTool("echo", { message: "test" }); + expect(result.content[0].text).toBe("test"); + + await connection.close(true); + }); + + it("should successfully connect with multiple custom headers", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + }); + + const connect = transport.bind(client); + const connection = await connect(serverUrl, { + headers: { + Authorization: "Bearer test-token-456", + "X-Custom-Header-1": "value1", + "X-Custom-Header-2": "value2", + "X-Request-ID": "req-123", + }, + }); + + // Connection should work + expect(connection.serverInfo.name).toBe("test-server"); + + // Should be able to list tools + const tools = await connection.listTools(); + expect(tools.tools.length).toBeGreaterThan(0); + + await connection.close(true); + }); + + it("should work without custom headers", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + // No headers specified + }); + + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + // Should still work normally + const result = await connection.callTool("echo", { message: "no headers" }); + expect(result.content[0].text).toBe("no headers"); + + await connection.close(true); + }); + + it("should include headers in all request types", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + }); + + const connect = transport.bind(client); + const connection = await connect(serverUrl, { + headers: { + Authorization: "Bearer comprehensive-test", + }, + }); + + // All these operations should work with headers + await connection.listTools(); + await connection.callTool("echo", { message: "test" }); + + // List other resources + await connection.listPrompts(); + await connection.listResources(); + + await connection.close(true); + }); + + it("should support different headers for multiple servers", async () => { + // Create second server + const server2 = new McpServer({ + name: "test-server-2", + version: "1.0.0", + }); + + server2.tool("echo2", { + description: "Echoes input", + handler: (args: { message: string }) => ({ + content: [{ type: "text", text: args.message }], + }), + }); + + const testServer2 = await createTestHarness(server2, { + sessionAdapter: new InMemorySessionAdapter({ maxEventBufferSize: 1024 }), + }); + const serverUrl2 = testServer2.url; + + try { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + }); + + const connect = transport.bind(client); + + // Connect to first server with one set of headers + const connection1 = await connect(serverUrl, { + headers: { + Authorization: "Bearer server-1-token", + "X-Server": "server-1", + }, + }); + + // Connect to second server with different headers + const connection2 = await connect(serverUrl2, { + headers: { + Authorization: "Bearer server-2-token", + "X-Server": "server-2", + }, + }); + + // Both connections should work independently + expect(connection1.serverInfo.name).toBe("test-server"); + expect(connection2.serverInfo.name).toBe("test-server-2"); + + const result1 = await connection1.callTool("echo", { message: "test1" }); + const result2 = await connection2.callTool("echo2", { message: "test2" }); + + expect(result1.content[0].text).toBe("test1"); + expect(result2.content[0].text).toBe("test2"); + + await connection1.close(true); + await connection2.close(true); + } finally { + await testServer2.stop(); + } + }); +}); diff --git a/packages/core/tests/integration/client-e2e-full.test.ts b/packages/core/tests/integration/client-e2e-full.test.ts new file mode 100644 index 0000000..c5b74d0 --- /dev/null +++ b/packages/core/tests/integration/client-e2e-full.test.ts @@ -0,0 +1,509 @@ +/** biome-ignore-all lint/style/noNonNullAssertion: tests */ +import { afterEach, beforeEach, describe, expect, it } from "bun:test"; +import { + createTestHarness, + openSessionStream, + type TestServer, +} from "../../../test-utils/src/index.js"; +import { + InMemoryClientRequestAdapter, + InMemoryClientSessionAdapter, + InMemorySessionAdapter, + McpClient, + McpServer, + StreamableHttpClientTransport, +} from "../../src/index.js"; + +describe("MCP Client - End-to-End Full Workflows", () => { + describe("Multi-server workflow", () => { + let githubServer: TestServer; + let slackServer: TestServer; + let dbServer: TestServer; + + beforeEach(async () => { + // Create GitHub server + const github = new McpServer({ name: "github-server", version: "1.0.0" }); + github.tool("listRepos", { + description: "List repositories", + handler: () => ({ + content: [{ type: "text", text: "repo1, repo2, repo3" }], + }), + }); + github.tool("createIssue", { + description: "Create issue", + handler: (args: { title: string }) => ({ + content: [{ type: "text", text: `Issue created: ${args.title}` }], + }), + }); + githubServer = await createTestHarness(github); + + // Create Slack server + const slack = new McpServer({ name: "slack-server", version: "1.0.0" }); + slack.tool("postMessage", { + description: "Post message", + handler: (args: { channel: string; text: string }) => ({ + content: [ + { type: "text", text: `Posted to ${args.channel}: ${args.text}` }, + ], + }), + }); + slackServer = await createTestHarness(slack); + + // Create DB server + const db = new McpServer({ name: "db-server", version: "1.0.0" }); + db.tool("query", { + description: "Run query", + handler: (args: { sql: string }) => ({ + content: [{ type: "text", text: `Query result: 5 rows` }], + }), + }); + dbServer = await createTestHarness(db); + }); + + afterEach(async () => { + await githubServer.stop(); + await slackServer.stop(); + await dbServer.stop(); + }); + + it("should connect to multiple servers and use tools from each", async () => { + const client = new McpClient({ + name: "multi-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport(); + const connect = transport.bind(client); + + // Connect to all three servers + const githubConn = await connect(githubServer.url); + const slackConn = await connect(slackServer.url); + const dbConn = await connect(dbServer.url); + + // Verify server info + expect(githubConn.serverInfo.name).toBe("github-server"); + expect(slackConn.serverInfo.name).toBe("slack-server"); + expect(dbConn.serverInfo.name).toBe("db-server"); + + // Get tools from each server + const githubTools = await githubConn.listTools(); + const slackTools = await slackConn.listTools(); + const dbTools = await dbConn.listTools(); + + expect(githubTools.tools).toHaveLength(2); + expect(slackTools.tools).toHaveLength(1); + expect(dbTools.tools).toHaveLength(1); + + // Execute a workflow using tools from all servers + const repos = await githubConn.callTool("listRepos", {}); + expect(repos.content[0].text).toContain("repo1"); + + const issue = await githubConn.callTool("createIssue", { + title: "Bug in repo1", + }); + expect(issue.content[0].text).toContain("Issue created"); + + const message = await slackConn.callTool("postMessage", { + channel: "#dev", + text: "New issue created", + }); + expect(message.content[0].text).toContain("Posted to #dev"); + + const dbResult = await dbConn.callTool("query", { + sql: "SELECT * FROM issues", + }); + expect(dbResult.content[0].text).toContain("5 rows"); + }); + + it("should handle concurrent operations across multiple servers", async () => { + const client = new McpClient({ + name: "multi-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport(); + const connect = transport.bind(client); + + const githubConn = await connect(githubServer.url); + const slackConn = await connect(slackServer.url); + const dbConn = await connect(dbServer.url); + + // Execute operations in parallel + const results = await Promise.all([ + githubConn.callTool("listRepos", {}), + slackConn.callTool("postMessage", { channel: "#dev", text: "test" }), + dbConn.callTool("query", { sql: "SELECT 1" }), + githubConn.callTool("createIssue", { title: "Test" }), + ]); + + expect(results).toHaveLength(4); + expect(results[0].content[0].text).toContain("repo1"); + expect(results[1].content[0].text).toContain("Posted"); + expect(results[2].content[0].text).toContain("Query result"); + expect(results[3].content[0].text).toContain("Issue created"); + }); + }); + + describe("Session with progress notifications", () => { + let testServer: TestServer; + let mcpServer: McpServer; + + beforeEach(async () => { + mcpServer = new McpServer({ + name: "progress-server", + version: "1.0.0", + }); + + mcpServer.tool("longRunning", { + description: "Long running task", + handler: async (args: { steps: number }, ctx) => { + for (let i = 1; i <= args.steps; i++) { + await ctx.progress?.({ + progress: i, + total: args.steps, + message: `Processing step ${i}`, + }); + await new Promise((resolve) => setTimeout(resolve, 50)); + } + return { content: [{ type: "text", text: "Completed" }] }; + }, + }); + + testServer = await createTestHarness(mcpServer, { + sessionAdapter: new InMemorySessionAdapter({ + maxEventBufferSize: 1024, + }), + }); + }); + + afterEach(async () => { + await testServer.stop(); + }); + + it("should receive progress notifications during tool execution", async () => { + const client = new McpClient({ + name: "progress-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + }); + const connect = transport.bind(client); + const connection = await connect(testServer.url); + + // Track progress events + const progressEvents: any[] = []; + + // For this test, we want to observe progress events, so use test-utils helper + // (don't use connection.openSessionStream() - server only allows one stream) + const stream = await openSessionStream( + testServer.url, + connection.sessionId!, + ); + const reader = stream.getReader(); + const decoder = new TextDecoder(); + + // Start reading in background + const readPromise = (async () => { + try { + let buffer = ""; + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() || ""; + + for (const line of lines) { + if (line.startsWith("data: ")) { + const data = JSON.parse(line.slice(6)); + if (data.method === "notifications/progress") { + progressEvents.push(data.params); + } + } + } + + // Stop after getting some events + if (progressEvents.length >= 3) { + break; + } + } + } catch (error) { + // Stream cancelled + } + })(); + + // Execute tool with progress token + const resultPromise = fetch(testServer.url, { + method: "POST", + headers: { + "Content-Type": "application/json", + "MCP-Protocol-Version": "2025-06-18", + "MCP-Session-Id": connection.sessionId!, + }, + body: JSON.stringify({ + jsonrpc: "2.0", + id: "progress-test", + method: "tools/call", + params: { + _meta: { progressToken: "test-123" }, + name: "longRunning", + arguments: { steps: 5 }, + }, + }), + }); + + await readPromise; + reader.cancel(); + + expect(progressEvents.length).toBeGreaterThanOrEqual(3); + expect(progressEvents[0].progressToken).toBe("test-123"); + expect(progressEvents[0].message).toContain("step 1"); + + await resultPromise; + await connection.close(true); + }); + }); + + describe("Elicitation workflow", () => { + let testServer: TestServer; + let mcpServer: McpServer; + + beforeEach(async () => { + mcpServer = new McpServer({ + name: "elicit-server", + version: "1.0.0", + }); + + mcpServer.tool("getUserInfo", { + description: "Get user information", + handler: async (_, ctx) => { + if (!ctx.client.supports("elicitation")) { + return { content: [{ type: "text", text: "No elicitation" }] }; + } + + const nameResult = await ctx.elicit({ + message: "What is your name?", + schema: { + type: "object", + properties: { name: { type: "string" } }, + required: ["name"], + }, + }); + + if (nameResult.action !== "accept") { + return { content: [{ type: "text", text: "User declined" }] }; + } + + const ageResult = await ctx.elicit({ + message: "What is your age?", + schema: { + type: "object", + properties: { age: { type: "number" } }, + required: ["age"], + }, + }); + + if (ageResult.action !== "accept") { + return { content: [{ type: "text", text: "User declined age" }] }; + } + + return { + content: [ + { + type: "text", + text: `Name: ${nameResult.content?.name}, Age: ${ageResult.content?.age}`, + }, + ], + }; + }, + }); + + testServer = await createTestHarness(mcpServer, { + sessionAdapter: new InMemorySessionAdapter({ + maxEventBufferSize: 1024, + }), + clientRequestAdapter: new InMemoryClientRequestAdapter(), + }); + }); + + afterEach(async () => { + await testServer.stop(); + }); + + it("should handle multiple elicitations in sequence", async () => { + const client = new McpClient({ + name: "elicit-client", + version: "1.0.0", + capabilities: { elicitation: {} }, + }); + + let elicitCount = 0; + client.onElicit(async (params) => { + elicitCount++; + + if (params.message.includes("name")) { + return { action: "accept", content: { name: "Alice" } }; + } else if (params.message.includes("age")) { + return { action: "accept", content: { age: 30 } }; + } + + return { action: "decline" }; + }); + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + }); + const connect = transport.bind(client); + const connection = await connect(testServer.url); + + await connection.openSessionStream(); + + const result = await connection.callTool("getUserInfo", {}); + + expect(elicitCount).toBe(2); // Both elicitations called + expect(result.content[0].text).toBe("Name: Alice, Age: 30"); + + await connection.close(true); + }); + }); + + describe("Error recovery", () => { + let testServer: TestServer; + let mcpServer: McpServer; + + beforeEach(async () => { + mcpServer = new McpServer({ + name: "error-server", + version: "1.0.0", + }); + + let callCount = 0; + mcpServer.tool("flaky", { + description: "Sometimes fails", + handler: () => { + callCount++; + if (callCount === 1) { + throw new Error("Temporary failure"); + } + return { content: [{ type: "text", text: "Success" }] }; + }, + }); + + testServer = await createTestHarness(mcpServer); + }); + + afterEach(async () => { + await testServer.stop(); + }); + + it("should handle tool errors and allow retry", async () => { + const client = new McpClient({ + name: "retry-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport(); + const connect = transport.bind(client); + const connection = await connect(testServer.url); + + // First call fails + await expect(connection.callTool("flaky", {})).rejects.toThrow(); + + // Second call succeeds + const result = await connection.callTool("flaky", {}); + expect(result.content[0].text).toBe("Success"); + }); + }); + + describe("Tool adapter interface", () => { + let testServer: TestServer; + + beforeEach(async () => { + const server = new McpServer({ + name: "adapter-server", + version: "1.0.0", + }); + + server.tool("calculate", { + description: "Calculate something", + inputSchema: { + type: "object", + properties: { + operation: { type: "string" }, + a: { type: "number" }, + b: { type: "number" }, + }, + required: ["operation", "a", "b"], + }, + handler: (args: { operation: string; a: number; b: number }) => { + let result = 0; + if (args.operation === "add") result = args.a + args.b; + else if (args.operation === "multiply") result = args.a * args.b; + + return { + content: [{ type: "text", text: `Result: ${result}` }], + structuredContent: { result }, + }; + }, + }); + + testServer = await createTestHarness(server); + }); + + afterEach(async () => { + await testServer.stop(); + }); + + it("should demonstrate tool adapter pattern", async () => { + const client = new McpClient({ + name: "adapter-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport(); + const connect = transport.bind(client); + const connection = await connect(testServer.url); + + // Get tools + const { tools } = await connection.listTools(); + expect(tools).toHaveLength(1); + + // Example adapter pattern (user would implement this) + class SimpleAdapter { + toSDK(mcpTool: any) { + return { + name: mcpTool.name, + description: mcpTool.description, + parameters: mcpTool.inputSchema, + }; + } + + async execute(connection: any, toolName: string, args: any) { + const result = await connection.callTool(toolName, args); + // Convert MCP result to SDK format + if (result.structuredContent) { + return result.structuredContent; + } + return result.content[0]?.text; + } + } + + const adapter = new SimpleAdapter(); + const sdkTool = adapter.toSDK(tools[0]); + + expect(sdkTool.name).toBe("calculate"); + expect(sdkTool.parameters).toBeDefined(); + + const result = await adapter.execute(connection, "calculate", { + operation: "add", + a: 5, + b: 3, + }); + + expect(result).toEqual({ result: 8 }); + }); + }); +}); diff --git a/packages/core/tests/integration/client-oauth.test.ts b/packages/core/tests/integration/client-oauth.test.ts new file mode 100644 index 0000000..7d14216 --- /dev/null +++ b/packages/core/tests/integration/client-oauth.test.ts @@ -0,0 +1,653 @@ +import { afterEach, beforeEach, describe, expect, it, mock } from "bun:test"; +import { type Server, serve } from "bun"; +import { + discoverOAuthEndpoints, + InMemoryOAuthAdapter, + McpClient, + type OAuthConfig, + StandardOAuthProvider, + StreamableHttpClientTransport, +} from "../../src/index.js"; + +describe("MCP Client - OAuth Integration", () => { + let oauthServer: Server; + let mcpServer: Server; + let oauthServerUrl: string; + let mcpServerUrl: string; + let authorizationCallbackUrl: string | null = null; + + // Mock OAuth token response + const mockTokenResponse = { + access_token: "mock_access_token_12345", + refresh_token: "mock_refresh_token_67890", + expires_in: 3600, + scope: "mcp:access", + token_type: "Bearer", + }; + + beforeEach(async () => { + // Create OAuth authorization server + oauthServer = serve({ + port: 0, // random port + async fetch(request) { + const url = new URL(request.url); + + // OAuth discovery endpoint (RFC 8414) + if (url.pathname === "/.well-known/oauth-authorization-server") { + return new Response( + JSON.stringify({ + issuer: oauthServerUrl, + authorization_endpoint: `${oauthServerUrl}/authorize`, + token_endpoint: `${oauthServerUrl}/token`, + code_challenge_methods_supported: ["S256"], + grant_types_supported: ["authorization_code", "refresh_token"], + }), + { headers: { "Content-Type": "application/json" } }, + ); + } + + // Token endpoint + if (url.pathname === "/token" && request.method === "POST") { + const body = await request.text(); + const params = new URLSearchParams(body); + const grantType = params.get("grant_type"); + + if (grantType === "authorization_code") { + // Verify PKCE parameters + const codeVerifier = params.get("code_verifier"); + const resource = params.get("resource"); + + if (!codeVerifier) { + return new Response( + JSON.stringify({ error: "missing_code_verifier" }), + { + status: 400, + headers: { "Content-Type": "application/json" }, + }, + ); + } + + if (!resource) { + return new Response( + JSON.stringify({ error: "missing_resource" }), + { + status: 400, + headers: { "Content-Type": "application/json" }, + }, + ); + } + + return new Response(JSON.stringify(mockTokenResponse), { + headers: { "Content-Type": "application/json" }, + }); + } + + if (grantType === "refresh_token") { + const refreshToken = params.get("refresh_token"); + const resource = params.get("resource"); + + if (!refreshToken) { + return new Response( + JSON.stringify({ error: "missing_refresh_token" }), + { + status: 400, + headers: { "Content-Type": "application/json" }, + }, + ); + } + + if (!resource) { + return new Response( + JSON.stringify({ error: "missing_resource" }), + { + status: 400, + headers: { "Content-Type": "application/json" }, + }, + ); + } + + // Return new tokens with updated expiry + return new Response( + JSON.stringify({ + ...mockTokenResponse, + access_token: "refreshed_access_token", + }), + { headers: { "Content-Type": "application/json" } }, + ); + } + + return new Response( + JSON.stringify({ error: "unsupported_grant_type" }), + { + status: 400, + headers: { "Content-Type": "application/json" }, + }, + ); + } + + return new Response("Not Found", { status: 404 }); + }, + }); + + oauthServerUrl = `http://localhost:${oauthServer.port}`; + + // Create OAuth-protected MCP server + mcpServer = serve({ + port: 0, + async fetch(request) { + const url = new URL(request.url); + + // Resource server discovery endpoint (RFC 8707) + if (url.pathname === "/.well-known/oauth-protected-resource") { + return new Response( + JSON.stringify({ + resource: mcpServerUrl, + authorization_servers: [oauthServerUrl], + scopes_supported: ["mcp:access"], + }), + { headers: { "Content-Type": "application/json" } }, + ); + } + + // Check authorization header + const authHeader = request.headers.get("Authorization"); + if (!authHeader || !authHeader.startsWith("Bearer ")) { + return new Response("Unauthorized", { + status: 401, + headers: { "WWW-Authenticate": 'Bearer realm="MCP Server"' }, + }); + } + + // MCP initialize endpoint + if (request.method === "POST") { + const body = await request.json(); + if (body.method === "initialize") { + return new Response( + JSON.stringify({ + jsonrpc: "2.0", + id: body.id, + result: { + protocolVersion: "2025-06-18", + serverInfo: { name: "oauth-test-server", version: "1.0.0" }, + capabilities: { tools: {} }, + }, + }), + { headers: { "Content-Type": "application/json" } }, + ); + } + } + + return new Response("Not Found", { status: 404 }); + }, + }); + + mcpServerUrl = `http://localhost:${mcpServer.port}`; + }); + + afterEach(() => { + oauthServer?.stop(); + mcpServer?.stop(); + authorizationCallbackUrl = null; + }); + + it("should discover OAuth endpoints from MCP server", async () => { + const endpoints = await discoverOAuthEndpoints(mcpServerUrl); + + expect(endpoints.authorizationServer).toBe(oauthServerUrl); + expect(endpoints.authorizationEndpoint).toBe(`${oauthServerUrl}/authorize`); + expect(endpoints.tokenEndpoint).toBe(`${oauthServerUrl}/token`); + expect(endpoints.scopes).toEqual(["mcp:access"]); + }); + + it("should start OAuth authorization flow with correct parameters", async () => { + const provider = new StandardOAuthProvider(); + const endpoints = await discoverOAuthEndpoints(mcpServerUrl); + + const result = await provider.startAuthorizationFlow({ + authorizationEndpoint: endpoints.authorizationEndpoint, + clientId: "test-client-id", + redirectUri: "http://localhost:3000/callback", + scopes: endpoints.scopes, + resource: mcpServerUrl, + }); + + expect(result.authorizationUrl).toContain(endpoints.authorizationEndpoint); + expect(result.authorizationUrl).toContain("client_id=test-client-id"); + expect(result.authorizationUrl).toContain("redirect_uri="); + expect(result.authorizationUrl).toContain("code_challenge="); + expect(result.authorizationUrl).toContain("code_challenge_method=S256"); + expect(result.authorizationUrl).toContain( + `resource=${encodeURIComponent(mcpServerUrl)}`, + ); + expect(result.authorizationUrl).toContain("scope=mcp%3Aaccess"); + expect(result.codeVerifier).toHaveLength(43); + expect(result.state).toBeTruthy(); + }); + + it("should exchange authorization code for tokens with PKCE", async () => { + const provider = new StandardOAuthProvider(); + const endpoints = await discoverOAuthEndpoints(mcpServerUrl); + + const flowResult = await provider.startAuthorizationFlow({ + authorizationEndpoint: endpoints.authorizationEndpoint, + clientId: "test-client-id", + redirectUri: "http://localhost:3000/callback", + scopes: endpoints.scopes, + resource: mcpServerUrl, + }); + + const tokens = await provider.exchangeCodeForTokens({ + tokenEndpoint: endpoints.tokenEndpoint, + code: "mock_authorization_code", + codeVerifier: flowResult.codeVerifier, + clientId: "test-client-id", + redirectUri: "http://localhost:3000/callback", + resource: mcpServerUrl, + }); + + expect(tokens.accessToken).toBe("mock_access_token_12345"); + expect(tokens.refreshToken).toBe("mock_refresh_token_67890"); + expect(tokens.tokenType).toBe("Bearer"); + expect(tokens.scopes).toEqual(["mcp:access"]); + expect(tokens.expiresAt).toBeGreaterThan(Date.now() / 1000); + }); + + it("should refresh expired tokens", async () => { + const provider = new StandardOAuthProvider(); + const endpoints = await discoverOAuthEndpoints(mcpServerUrl); + + const newTokens = await provider.refreshAccessToken({ + tokenEndpoint: endpoints.tokenEndpoint, + refreshToken: "mock_refresh_token_67890", + clientId: "test-client-id", + resource: mcpServerUrl, + }); + + expect(newTokens.accessToken).toBe("refreshed_access_token"); + expect(newTokens.tokenType).toBe("Bearer"); + }); + + it("should handle 401 response and start OAuth flow", async () => { + const adapter = new InMemoryOAuthAdapter(); + const provider = new StandardOAuthProvider(); + + const onAuthorizationRequired = mock((url: string) => { + authorizationCallbackUrl = url; + }); + + const oauthConfig: OAuthConfig = { + clientId: "test-client-id", + redirectUri: "http://localhost:3000/callback", + onAuthorizationRequired, + }; + + const client = new McpClient({ + name: "oauth-test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport({ + oauthAdapter: adapter, + oauthProvider: provider, + oauthConfig, + }); + + const connect = transport.bind(client); + + // First connection attempt should fail with 401 and start OAuth flow + await expect(connect(mcpServerUrl)).rejects.toThrow( + "Authentication required", + ); + + // Verify authorization callback was invoked + expect(onAuthorizationRequired).toHaveBeenCalledTimes(1); + expect(authorizationCallbackUrl).toContain(`${oauthServerUrl}/authorize`); + expect(authorizationCallbackUrl).toContain("client_id=test-client-id"); + }); + + it("should complete authorization flow and store tokens", async () => { + const adapter = new InMemoryOAuthAdapter(); + const provider = new StandardOAuthProvider(); + + let capturedAuthUrl: string | null = null; + + const oauthConfig: OAuthConfig = { + clientId: "test-client-id", + redirectUri: "http://localhost:3000/callback", + onAuthorizationRequired: (url: string) => { + capturedAuthUrl = url; + }, + }; + + const client = new McpClient({ + name: "oauth-test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport({ + oauthAdapter: adapter, + oauthProvider: provider, + oauthConfig, + }); + + const connect = transport.bind(client); + + // Start OAuth flow + await expect(connect(mcpServerUrl)).rejects.toThrow(); + + // Extract state from authorization URL + const authUrl = new URL(capturedAuthUrl!); + const state = authUrl.searchParams.get("state")!; + + // Complete authorization flow + await transport.completeAuthorizationFlow( + mcpServerUrl, + "mock_authorization_code", + state, + ); + + // Verify tokens were stored + const tokens = await adapter.getTokens(mcpServerUrl); + expect(tokens).toBeDefined(); + expect(tokens?.accessToken).toBe("mock_access_token_12345"); + expect(tokens?.refreshToken).toBe("mock_refresh_token_67890"); + }); + + it("should reject authorization flow with invalid state", async () => { + const adapter = new InMemoryOAuthAdapter(); + const provider = new StandardOAuthProvider(); + + const oauthConfig: OAuthConfig = { + clientId: "test-client-id", + redirectUri: "http://localhost:3000/callback", + onAuthorizationRequired: () => {}, + }; + + const client = new McpClient({ + name: "oauth-test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport({ + oauthAdapter: adapter, + oauthProvider: provider, + oauthConfig, + }); + + const connect = transport.bind(client); + + // Start OAuth flow + await expect(connect(mcpServerUrl)).rejects.toThrow(); + + // Try to complete with wrong state + await expect( + transport.completeAuthorizationFlow( + mcpServerUrl, + "mock_authorization_code", + "wrong_state_value", + ), + ).rejects.toThrow("State parameter mismatch"); + }); + + it("should successfully connect with valid OAuth tokens", async () => { + const adapter = new InMemoryOAuthAdapter(); + const provider = new StandardOAuthProvider(); + + let capturedAuthUrl: string | null = null; + + const oauthConfig: OAuthConfig = { + clientId: "test-client-id", + redirectUri: "http://localhost:3000/callback", + onAuthorizationRequired: (url: string) => { + capturedAuthUrl = url; + }, + }; + + const client = new McpClient({ + name: "oauth-test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport({ + oauthAdapter: adapter, + oauthProvider: provider, + oauthConfig, + }); + + const connect = transport.bind(client); + + // Start OAuth flow + await expect(connect(mcpServerUrl)).rejects.toThrow(); + + // Complete authorization + const authUrl = new URL(capturedAuthUrl!); + const state = authUrl.searchParams.get("state")!; + await transport.completeAuthorizationFlow( + mcpServerUrl, + "mock_authorization_code", + state, + ); + + // Now connection should succeed with stored token + const connection = await connect(mcpServerUrl); + expect(connection.serverInfo.name).toBe("oauth-test-server"); + }); + + it("should support multiple MCP servers with different tokens", async () => { + // Create second OAuth-protected MCP server + const mcpServer2 = serve({ + port: 0, + async fetch(request) { + const url = new URL(request.url); + + if (url.pathname === "/.well-known/oauth-protected-resource") { + return new Response( + JSON.stringify({ + resource: `http://localhost:${mcpServer2.port}`, + authorization_servers: [oauthServerUrl], + scopes_supported: ["mcp:access"], + }), + { headers: { "Content-Type": "application/json" } }, + ); + } + + const authHeader = request.headers.get("Authorization"); + if (!authHeader?.startsWith("Bearer ")) { + return new Response("Unauthorized", { status: 401 }); + } + + if (request.method === "POST") { + const body = await request.json(); + if (body.method === "initialize") { + return new Response( + JSON.stringify({ + jsonrpc: "2.0", + id: body.id, + result: { + protocolVersion: "2025-06-18", + serverInfo: { name: "oauth-test-server-2", version: "1.0.0" }, + capabilities: { tools: {} }, + }, + }), + { headers: { "Content-Type": "application/json" } }, + ); + } + } + + return new Response("Not Found", { status: 404 }); + }, + }); + + const mcpServer2Url = `http://localhost:${mcpServer2.port}`; + + const adapter = new InMemoryOAuthAdapter(); + const provider = new StandardOAuthProvider(); + + // Store tokens for first server + await adapter.storeTokens(mcpServerUrl, { + accessToken: "token_for_server_1", + tokenType: "Bearer", + expiresAt: Math.floor(Date.now() / 1000) + 3600, + scopes: ["mcp:access"], + }); + + // Store tokens for second server + await adapter.storeTokens(mcpServer2Url, { + accessToken: "token_for_server_2", + tokenType: "Bearer", + expiresAt: Math.floor(Date.now() / 1000) + 3600, + scopes: ["mcp:access"], + }); + + // Verify tokens are stored separately + const tokens1 = await adapter.getTokens(mcpServerUrl); + const tokens2 = await adapter.getTokens(mcpServer2Url); + + expect(tokens1?.accessToken).toBe("token_for_server_1"); + expect(tokens2?.accessToken).toBe("token_for_server_2"); + expect(tokens1?.accessToken).not.toBe(tokens2?.accessToken); + + mcpServer2.stop(); + }); + + it("should validate that token has not expired with buffer", () => { + const adapter = new InMemoryOAuthAdapter(); + const now = Math.floor(Date.now() / 1000); + + // Store token that expires in 10 minutes + adapter.storeTokens(mcpServerUrl, { + accessToken: "valid_token", + tokenType: "Bearer", + expiresAt: now + 600, // 10 minutes from now + scopes: ["mcp:access"], + }); + + expect(adapter.hasValidToken(mcpServerUrl)).toBe(true); + + // Store token that expires in 2 minutes (within 5-minute buffer) + adapter.storeTokens(mcpServerUrl, { + accessToken: "expiring_soon_token", + tokenType: "Bearer", + expiresAt: now + 120, // 2 minutes from now + scopes: ["mcp:access"], + }); + + expect(adapter.hasValidToken(mcpServerUrl)).toBe(false); + }); + + it("should throw error if OAuth config is incomplete", () => { + const adapter = new InMemoryOAuthAdapter(); + + // Missing provider and config + expect(() => { + new StreamableHttpClientTransport({ + oauthAdapter: adapter, + }); + }).toThrow("OAuth configuration incomplete"); + + // Missing adapter and config + expect(() => { + new StreamableHttpClientTransport({ + oauthProvider: new StandardOAuthProvider(), + }); + }).toThrow("OAuth configuration incomplete"); + }); + + it("should use origin for discovery when baseUrl has a path (RFC 8707)", async () => { + // Create a server that handles both /mcp endpoint and origin-based discovery + const serverWithPath = serve({ + port: 0, + async fetch(request) { + const url = new URL(request.url); + + // Discovery MUST be at origin, not at /mcp/.well-known/... + if (url.pathname === "/.well-known/oauth-protected-resource") { + return new Response( + JSON.stringify({ + resource: `http://localhost:${serverWithPath.port}/mcp`, + authorization_servers: [oauthServerUrl], + scopes_supported: ["mcp:access"], + }), + { headers: { "Content-Type": "application/json" } }, + ); + } + + // This should NOT be called for discovery + if (url.pathname === "/mcp/.well-known/oauth-protected-resource") { + return new Response("Wrong path - should use origin", { + status: 404, + }); + } + + return new Response("Not Found", { status: 404 }); + }, + }); + + const serverWithPathUrl = `http://localhost:${serverWithPath.port}/mcp`; + + const endpoints = await discoverOAuthEndpoints(serverWithPathUrl); + + expect(endpoints.authorizationServer).toBe(oauthServerUrl); + expect(endpoints.authorizationEndpoint).toBe(`${oauthServerUrl}/authorize`); + expect(endpoints.tokenEndpoint).toBe(`${oauthServerUrl}/token`); + + serverWithPath.stop(); + }); + + it("should fallback to WWW-Authenticate header when origin discovery fails", async () => { + // Create a server that doesn't have origin-based discovery + // but provides as_uri in WWW-Authenticate header + const serverWithoutDiscovery = serve({ + port: 0, + async fetch(request) { + const url = new URL(request.url); + + // No resource metadata at origin + if (url.pathname === "/.well-known/oauth-protected-resource") { + return new Response("Not Found", { status: 404 }); + } + + // MCP endpoint returns 401 with WWW-Authenticate header + if (url.pathname === "/mcp" && request.method === "POST") { + return new Response("Unauthorized", { + status: 401, + headers: { + "WWW-Authenticate": `Bearer realm="MCP Server", as_uri="${oauthServerUrl}/.well-known/oauth-authorization-server"`, + }, + }); + } + + return new Response("Not Found", { status: 404 }); + }, + }); + + const serverWithoutDiscoveryUrl = `http://localhost:${serverWithoutDiscovery.port}/mcp`; + + const endpoints = await discoverOAuthEndpoints(serverWithoutDiscoveryUrl); + + expect(endpoints.authorizationServer).toBe(oauthServerUrl); + expect(endpoints.authorizationEndpoint).toBe(`${oauthServerUrl}/authorize`); + expect(endpoints.tokenEndpoint).toBe(`${oauthServerUrl}/token`); + expect(endpoints.scopes).toEqual([]); // No scopes from resource metadata + + serverWithoutDiscovery.stop(); + }); + + it("should fail gracefully when neither discovery method works", async () => { + // Create a server that has no discovery mechanism + const serverWithNoDiscovery = serve({ + port: 0, + async fetch(_request) { + return new Response("Not Found", { status: 404 }); + }, + }); + + const serverWithNoDiscoveryUrl = `http://localhost:${serverWithNoDiscovery.port}/mcp`; + + await expect( + discoverOAuthEndpoints(serverWithNoDiscoveryUrl), + ).rejects.toThrow("Failed to fetch resource metadata"); + + serverWithNoDiscovery.stop(); + }); +}); diff --git a/packages/core/tests/integration/client-server-requests.test.ts b/packages/core/tests/integration/client-server-requests.test.ts new file mode 100644 index 0000000..b75c32d --- /dev/null +++ b/packages/core/tests/integration/client-server-requests.test.ts @@ -0,0 +1,328 @@ +/** biome-ignore-all lint/style/noNonNullAssertion: tests */ +import { afterEach, beforeEach, describe, expect, test } from "bun:test"; +import { + collectSseEventsCount, + createTestHarness, + openSessionStream, + type TestServer, +} from "@internal/test-utils"; +import { + InMemoryClientRequestAdapter, + InMemoryClientSessionAdapter, + InMemorySessionAdapter, + McpClient, + McpServer, + StreamableHttpClientTransport, + StreamableHttpTransport, +} from "../../src/index.js"; + +describe("MCP Client - Server-Initiated Requests", () => { + let testServer: TestServer; + let mcpServer: McpServer; + let serverUrl: string; + + beforeEach(async () => { + // Create server with elicitation support + mcpServer = new McpServer({ + name: "test-server", + version: "1.0.0", + }); + + testServer = await createTestHarness(mcpServer, { + sessionAdapter: new InMemorySessionAdapter({ maxEventBufferSize: 1024 }), + clientRequestAdapter: new InMemoryClientRequestAdapter(), + }); + serverUrl = testServer.url; + }); + + afterEach(async () => { + await testServer.stop(); + }); + + test("should handle elicitation request from server", async () => { + // Server tool that requests elicitation + mcpServer.tool("ask-user", { + description: "Asks user for input", + handler: async (_, ctx) => { + if (!ctx.client.supports("elicitation")) { + return { + content: [{ type: "text", text: "No elicitation support" }], + }; + } + + const result = await ctx.elicit({ + message: "What is your name?", + schema: { type: "object", properties: { name: { type: "string" } } }, + }); + + if (result.action === "accept") { + return { + content: [ + { + type: "text", + text: `Hello, ${result.content?.name}!`, + }, + ], + }; + } + + return { + content: [{ type: "text", text: `Action: ${result.action}` }], + }; + }, + }); + + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + capabilities: { + elicitation: {}, + }, + }); + + // Register elicitation handler + client.onElicit(async (params, _ctx) => { + expect(params.message).toBe("What is your name?"); + expect(params.requestedSchema).toBeDefined(); + + return { + action: "accept", + content: { name: "Alice" }, + }; + }); + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + }); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + // Open SSE stream to enable server request handling + await connection.openSessionStream(); + + // Call tool that will trigger elicitation + const result = await connection.callTool("ask-user", {}); + expect(result.content[0].text).toBe("Hello, Alice!"); + + await connection.close(true); + }); + + test("should handle elicitation decline", async () => { + mcpServer.tool("ask-user", { + handler: async (_, ctx) => { + const result = await ctx.elicit({ + message: "What is your age?", + schema: { type: "object", properties: { age: { type: "number" } } }, + }); + + return { + content: [ + { + type: "text", + text: `Action: ${result.action}`, + }, + ], + }; + }, + }); + + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + capabilities: { elicitation: {} }, + }); + + client.onElicit(async () => ({ + action: "decline", + })); + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + }); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + await connection.openSessionStream(); + + const result = await connection.callTool("ask-user", {}); + expect(result.content[0].text).toBe("Action: decline"); + + await connection.close(true); + }); + + test("should handle elicitation cancel", async () => { + mcpServer.tool("ask-user", { + handler: async (_, ctx) => { + const result = await ctx.elicit({ + message: "Confirm action?", + schema: { type: "object", properties: {} }, + }); + + return { + content: [ + { + type: "text", + text: `Action: ${result.action}`, + }, + ], + }; + }, + }); + + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + capabilities: { elicitation: {} }, + }); + + client.onElicit(async () => ({ + action: "cancel", + })); + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + }); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + await connection.openSessionStream(); + + const result = await connection.callTool("ask-user", {}); + expect(result.content[0].text).toBe("Action: cancel"); + + await connection.close(true); + }); + + test("should run middleware for server requests", async () => { + const log: string[] = []; + + mcpServer.tool("ask-user", { + handler: async (_, ctx) => { + const result = await ctx.elicit({ + message: "Test", + schema: { type: "object", properties: {} }, + }); + return { content: [{ type: "text", text: "done" }] }; + }, + }); + + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + capabilities: { elicitation: {} }, + }); + + client.onElicit(async () => { + log.push("handler"); + return { action: "accept", content: {} }; + }); + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + }); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + await connection.openSessionStream(); + await connection.callTool("ask-user", {}); + + // Wait a bit for async processing + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(log).toEqual(["handler"]); + + await connection.close(true); + }); + + test("should handle error in elicitation handler", async () => { + mcpServer.tool("ask-user", { + handler: async (_, ctx) => { + try { + await ctx.elicit({ + message: "Test", + schema: { type: "object" }, + }); + return { content: [{ type: "text", text: "should not reach" }] }; + } catch (error) { + return { + content: [ + { + type: "text", + text: `Error: ${error instanceof Error ? error.message : "Unknown"}`, + }, + ], + }; + } + }, + }); + + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + capabilities: { elicitation: {} }, + }); + + client.onElicit(async () => { + throw new Error("Handler failed"); + }); + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + }); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + await connection.openSessionStream(); + + const result = await connection.callTool("ask-user", {}); + expect(result.content[0].text).toContain("Handler failed"); + + await connection.close(true); + }); + + test("should handle missing handler gracefully", async () => { + mcpServer.tool("ask-user", { + handler: async (_, ctx) => { + try { + await ctx.elicit({ + message: "Test", + schema: { type: "object" }, + }); + return { content: [{ type: "text", text: "should not reach" }] }; + } catch (error) { + return { + content: [ + { + type: "text", + text: `Error: ${error instanceof Error ? error.message : "Unknown"}`, + }, + ], + }; + } + }, + }); + + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + capabilities: { elicitation: {} }, + }); + + // No handler registered! + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + }); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + await connection.openSessionStream(); + + const result = await connection.callTool("ask-user", {}); + expect(result.content[0].text).toContain( + "No elicitation handler registered", + ); + + await connection.close(true); + }); +}); diff --git a/packages/core/tests/integration/client-session.test.ts b/packages/core/tests/integration/client-session.test.ts new file mode 100644 index 0000000..b3e3ecd --- /dev/null +++ b/packages/core/tests/integration/client-session.test.ts @@ -0,0 +1,381 @@ +import { afterEach, beforeEach, describe, expect, it } from "bun:test"; +import { + collectSseEvents, + collectSseEventsCount, + createTestHarness, + openSessionStream, + type TestServer, +} from "@internal/test-utils"; +import { + InMemoryClientSessionAdapter, + InMemorySessionAdapter, + McpClient, + McpServer, + StreamableHttpClientTransport, +} from "../../src/index.js"; + +describe("MCP Client - Session Management", () => { + let testServer: TestServer; + let mcpServer: McpServer; + let serverUrl: string; + + beforeEach(async () => { + // Create server with session support + mcpServer = new McpServer({ + name: "test-server", + version: "1.0.0", + }); + + mcpServer.tool("echo", { + description: "Echoes input", + handler: (args: { message: string }) => ({ + content: [{ type: "text", text: args.message }], + }), + }); + + // Tool that sends progress notifications + mcpServer.tool("longTask", { + description: "Task with progress", + handler: async (args: { count: number }, ctx) => { + for (let i = 1; i <= args.count; i++) { + await ctx.progress?.({ + progress: i, + total: args.count, + message: `step ${i}`, + }); + } + return { content: [{ type: "text", text: `done ${args.count}` }] }; + }, + }); + + testServer = await createTestHarness(mcpServer, { + sessionAdapter: new InMemorySessionAdapter({ maxEventBufferSize: 1024 }), + }); + serverUrl = testServer.url; + }); + + afterEach(async () => { + await testServer.stop(); + }); + + it("should initialize session with server", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const clientSessionAdapter = new InMemoryClientSessionAdapter(); + const transport = new StreamableHttpClientTransport({ + sessionAdapter: clientSessionAdapter, + }); + const connect = transport.bind(client); + + const connection = await connect(serverUrl); + + // Should have session ID + expect(connection.sessionId).toBeDefined(); + expect(connection.serverInfo.name).toBe("test-server"); + + // Session should be stored in adapter + const sessionId = connection.sessionId; + expect(sessionId).toBeDefined(); + const sessionData = await clientSessionAdapter.get(sessionId!); + expect(sessionData).toBeDefined(); + expect(sessionData?.serverInfo.name).toBe("test-server"); + expect(sessionData?.protocolVersion).toBe("2025-06-18"); + + await connection.close(true); + }); + + it("should open and receive notifications via session stream", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + }); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + // For this test, we want to observe events, so use test-utils helper + // (don't use connection.openSessionStream() - server only allows one stream) + const sessionId = connection.sessionId; + expect(sessionId).toBeDefined(); + const stream = await openSessionStream(serverUrl, sessionId!); + + // Start collecting events with timeout + // Will collect all events that arrive within 1000ms + const eventsPromise = collectSseEvents(stream, 1000); + + // Small delay to ensure stream is ready + await new Promise((resolve) => setTimeout(resolve, 50)); + + // Make a tool call with progress token + await fetch(serverUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + "MCP-Protocol-Version": "2025-06-18", + "MCP-Session-Id": sessionId!, + }, + body: JSON.stringify({ + jsonrpc: "2.0", + id: "call1", + method: "tools/call", + params: { + _meta: { progressToken: "test-token" }, + name: "longTask", + arguments: { count: 3 }, + }, + }), + }); + + // Wait for events to be collected + const events = await eventsPromise; + + // Should have at least 3 progress notifications + // May also have initial ping event + expect(events.length).toBeGreaterThanOrEqual(3); + + // Find progress notifications (filtering out ping if present) + const progressEvents = events.filter( + (e) => e.data.method === "notifications/progress", + ); + expect(progressEvents).toHaveLength(3); + + expect(progressEvents[0].data.params.progressToken).toBe("test-token"); + expect(progressEvents[0].data.params.progress).toBe(1); + expect(progressEvents[1].data.params.progress).toBe(2); + expect(progressEvents[2].data.params.progress).toBe(3); + + await connection.close(true); + }); + + it("should reconnect with Last-Event-ID for replay", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + }); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + // Generate some events first + const sessionId = connection.sessionId; + expect(sessionId).toBeDefined(); + + await fetch(serverUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + "MCP-Protocol-Version": "2025-06-18", + "MCP-Session-Id": sessionId!, + }, + body: JSON.stringify({ + jsonrpc: "2.0", + id: "setup", + method: "tools/call", + params: { + _meta: { progressToken: "replay-test" }, + name: "longTask", + arguments: { count: 3 }, + }, + }), + }); + + // Wait a bit for events to be stored + await new Promise((resolve) => setTimeout(resolve, 100)); + + // Reconnect asking for replay from event 1 (using test-utils for observation) + const stream = await openSessionStream( + serverUrl, + sessionId!, + "1#_GET_stream", + ); + const events = await collectSseEvents(stream, 1000); + + // Should receive events 2 and 3 (after event 1) + expect(events.length).toBeGreaterThan(0); + expect(events[0].id).toBe("2#_GET_stream"); + expect(events[1].id).toBe("3#_GET_stream"); + + await connection.close(true); + }); + + it("should close and delete session", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const clientSessionAdapter = new InMemoryClientSessionAdapter(); + const transport = new StreamableHttpClientTransport({ + sessionAdapter: clientSessionAdapter, + }); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + const sessionId = connection.sessionId; + expect(sessionId).toBeDefined(); + + // Verify session exists on client side + expect(await clientSessionAdapter.get(sessionId!)).toBeDefined(); + + // Close with delete + await connection.close(true); + + // Try to open a GET stream with deleted session - should fail + // (POST requests don't validate session existence, but GET does) + const response = await fetch(serverUrl, { + method: "GET", + headers: { + Accept: "text/event-stream", + "MCP-Protocol-Version": "2025-06-18", + "MCP-Session-Id": sessionId!, + }, + }); + + expect(response.status).toBe(400); // Session no longer exists + const errorText = await response.text(); + expect(errorText).toContain("Invalid or missing session ID"); + }); + + it("should support multiple concurrent sessions", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + }); + const connect = transport.bind(client); + + // Create 3 separate sessions + const conn1 = await connect(serverUrl); + const conn2 = await connect(serverUrl); + const conn3 = await connect(serverUrl); + + expect(conn1.sessionId).toBeDefined(); + expect(conn2.sessionId).toBeDefined(); + expect(conn3.sessionId).toBeDefined(); + + // All should be different + expect(conn1.sessionId).not.toBe(conn2.sessionId); + expect(conn2.sessionId).not.toBe(conn3.sessionId); + + // All should work independently + const result1 = await conn1.callTool("echo", { message: "First" }); + const result2 = await conn2.callTool("echo", { message: "Second" }); + const result3 = await conn3.callTool("echo", { message: "Third" }); + + expect(result1.content[0].text).toBe("First"); + expect(result2.content[0].text).toBe("Second"); + expect(result3.content[0].text).toBe("Third"); + + await conn1.close(true); + await conn2.close(true); + await conn3.close(true); + }); + + it("should handle session stream closure gracefully", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport({ + sessionAdapter: new InMemoryClientSessionAdapter(), + }); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + await connection.openSessionStream(); + + // Close the stream + connection.closeSessionStream(); + + // Should be able to open a new one + await connection.openSessionStream(); + + await connection.close(true); + }); + + it("should work in stateless mode (no session adapter)", async () => { + // Create a separate server WITHOUT session adapter for stateless testing + const statelessServer = new McpServer({ + name: "stateless-server", + version: "1.0.0", + }); + + statelessServer.tool("echo", { + description: "Echoes input", + handler: (args: { message: string }) => ({ + content: [{ type: "text", text: args.message }], + }), + }); + + // Create test harness WITHOUT sessionAdapter (stateless mode) + const statelessTestServer = await createTestHarness(statelessServer, {}); + const statelessUrl = statelessTestServer.url; + + try { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + // No session adapter = stateless + const transport = new StreamableHttpClientTransport(); + const connect = transport.bind(client); + const connection = await connect(statelessUrl); + + // Should not have session ID + expect(connection.sessionId).toBeUndefined(); + + // Should still work for basic operations + const result = await connection.callTool("echo", { message: "Test" }); + expect(result.content[0].text).toBe("Test"); + + // Should fail to open session stream + await expect(connection.openSessionStream()).rejects.toThrow( + "Cannot open session stream without session ID", + ); + } finally { + await statelessTestServer.stop(); + } + }); + + it("should retrieve stored session data", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const clientSessionAdapter = new InMemoryClientSessionAdapter(); + const transport = new StreamableHttpClientTransport({ + sessionAdapter: clientSessionAdapter, + }); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + const sessionId = connection.sessionId; + expect(sessionId).toBeDefined(); + const sessionData = await clientSessionAdapter.get(sessionId!); + + expect(sessionData).toBeDefined(); + expect(sessionData?.serverInfo).toEqual({ + name: "test-server", + version: "1.0.0", + }); + expect(sessionData?.serverCapabilities.tools).toBeDefined(); + expect(sessionData?.createdAt).toBeGreaterThan(0); + + await connection.close(true); + }); +}); diff --git a/packages/core/tests/integration/client-stateless.test.ts b/packages/core/tests/integration/client-stateless.test.ts new file mode 100644 index 0000000..577f945 --- /dev/null +++ b/packages/core/tests/integration/client-stateless.test.ts @@ -0,0 +1,184 @@ +import { afterEach, beforeEach, describe, expect, it } from "bun:test"; +import { createTestHarness, type TestServer } from "@internal/test-utils"; +import { + McpClient, + McpServer, + StreamableHttpClientTransport, +} from "../../src/index.js"; + +describe("MCP Client - Stateless Operations", () => { + let testServer: TestServer; + let mcpServer: McpServer; + let serverUrl: string; + + beforeEach(async () => { + // Create a real MCP server with tools + mcpServer = new McpServer({ + name: "test-server", + version: "1.0.0", + }); + + mcpServer.tool("echo", { + description: "Echoes input", + handler: (args: { message: string }) => ({ + content: [{ type: "text", text: args.message }], + }), + }); + + mcpServer.tool("add", { + description: "Adds two numbers", + handler: (args: { a: number; b: number }) => ({ + content: [{ type: "text", text: String(args.a + args.b) }], + }), + }); + + mcpServer.prompt("greet", { + description: "Greeting prompt", + handler: () => ({ + messages: [{ role: "user", content: { type: "text", text: "Hello!" } }], + }), + }); + + mcpServer.resource( + "file://test.txt", + { + description: "Test file", + }, + async () => ({ + contents: [ + { uri: "file://test.txt", type: "text", text: "Test content" }, + ], + }), + ); + + testServer = await createTestHarness(mcpServer); + serverUrl = testServer.url; + }); + + afterEach(async () => { + await testServer.stop(); + }); + + it("should initialize connection to server", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport(); + const connect = transport.bind(client); + + const connection = await connect(serverUrl); + + expect(connection.serverInfo.name).toBe("test-server"); + expect(connection.serverInfo.version).toBe("1.0.0"); + expect(connection.serverCapabilities.tools).toBeDefined(); + }); + + it("should list tools from server", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport(); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + const { tools } = await connection.listTools(); + + expect(tools).toHaveLength(2); + expect(tools[0].name).toBe("echo"); + expect(tools[1].name).toBe("add"); + }); + + it("should call a tool successfully", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport(); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + const result = await connection.callTool("echo", { + message: "Hello World", + }); + + expect(result.content).toHaveLength(1); + expect(result.content[0].type).toBe("text"); + expect(result.content[0].text).toBe("Hello World"); + }); + + it("should handle tool call errors", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport(); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + await expect(connection.callTool("nonexistent", {})).rejects.toThrow(); + }); + + it("should list and get prompts", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport(); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + const { prompts } = await connection.listPrompts(); + expect(prompts).toHaveLength(1); + expect(prompts[0].name).toBe("greet"); + + const result = await connection.getPrompt("greet"); + expect(result.messages).toHaveLength(1); + }); + + it("should list and read resources", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport(); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + const { resources } = await connection.listResources(); + expect(resources).toHaveLength(1); + expect(resources[0].uri).toBe("file://test.txt"); + + const result = await connection.readResource("file://test.txt"); + expect(result.contents).toHaveLength(1); + expect(result.contents[0].text).toBe("Test content"); + }); + + it("should handle multiple concurrent tool calls", async () => { + const client = new McpClient({ + name: "test-client", + version: "1.0.0", + }); + + const transport = new StreamableHttpClientTransport(); + const connect = transport.bind(client); + const connection = await connect(serverUrl); + + const results = await Promise.all([ + connection.callTool("echo", { message: "First" }), + connection.callTool("echo", { message: "Second" }), + connection.callTool("add", { a: 1, b: 2 }), + ]); + + expect(results[0].content[0].text).toBe("First"); + expect(results[1].content[0].text).toBe("Second"); + expect(results[2].content[0].text).toBe("3"); + }); +}); diff --git a/packages/test-utils/src/harness.ts b/packages/test-utils/src/harness.ts index 52e7880..b93dbfa 100644 --- a/packages/test-utils/src/harness.ts +++ b/packages/test-utils/src/harness.ts @@ -2,7 +2,7 @@ * Optional in-process server harness for testing */ -import type { McpServer, SessionAdapter } from "mcp-lite"; +import type { ClientRequestAdapter, McpServer, SessionAdapter } from "mcp-lite"; import { InMemorySessionAdapter, StreamableHttpTransport } from "mcp-lite"; import type { TestServer } from "./types.js"; @@ -11,6 +11,8 @@ export interface TestHarnessOptions { sessionId?: string; /** Session adapter instance */ sessionAdapter?: SessionAdapter; + /** Client request adapter instance */ + clientRequestAdapter?: ClientRequestAdapter; /** Port for server (defaults to 0 for random) */ port?: number; } @@ -22,7 +24,7 @@ export async function createTestHarness( server: McpServer, options: TestHarnessOptions = {}, ): Promise { - const { sessionId, sessionAdapter, port = 0 } = options; + const { sessionId, sessionAdapter, clientRequestAdapter, port = 0 } = options; const transportOptions: ConstructorParameters< typeof StreamableHttpTransport @@ -42,6 +44,11 @@ export async function createTestHarness( transportOptions.sessionAdapter = sessionAdapter; } + // Add client request adapter if provided + if (clientRequestAdapter !== undefined) { + transportOptions.clientRequestAdapter = clientRequestAdapter; + } + // If neither sessionId nor sessionAdapter are provided, create stateless transport const transport = new StreamableHttpTransport(transportOptions); diff --git a/templates/starter-mcp-supabase/supabase/functions/mcp-server/index.ts b/templates/starter-mcp-supabase/supabase/functions/mcp-server/index.ts index e60632a..2325054 100644 --- a/templates/starter-mcp-supabase/supabase/functions/mcp-server/index.ts +++ b/templates/starter-mcp-supabase/supabase/functions/mcp-server/index.ts @@ -5,7 +5,6 @@ // Setup type definitions for built-in Supabase Runtime APIs /// - import { Hono } from "hono"; import { McpServer, StreamableHttpTransport } from "mcp-lite"; import { z } from "zod";