diff --git a/src/tools/vision-find.ts b/src/tools/vision-find.ts
index 66db5420..f1d9c80c 100644
--- a/src/tools/vision-find.ts
+++ b/src/tools/vision-find.ts
@@ -13,8 +13,10 @@ import { MCPToolDefinition, MCPResult, ToolHandler, ToolContext, hasBudget } fro
 import { getSessionManager } from '../session-manager';
 import { formatElementMapAsText } from '../vision/screenshot-analyzer';
 import { formatPerceptionSnapshotAsText } from '../vision/perception-provider';
+import type { AnnotatedScreenshotResult, PerceptionSnapshot } from '../vision/types';
 import { DomAnnotatorPerceptionProvider } from '../vision/providers/dom-annotator-provider';
-import { trackVisionUsage } from '../vision/config';
+import { OmniParserHttpProvider } from '../vision/providers/omniparser-http-provider';
+import { getOmniParserProviderConfig, trackVisionUsage } from '../vision/config';
 
 const definition: MCPToolDefinition = {
   name: 'vision_find',
@@ -130,21 +132,54 @@ const handler: ToolHandler = async (
   const modeArg = args.mode;
   const mode: 'viewport' | 'tiled' = modeArg === 'tiled' ? 'tiled' : 'viewport';
 
-  const provider = new DomAnnotatorPerceptionProvider(page);
-  const { result, snapshot } = await provider.captureAnnotated(tabId, page.url(), {
-    showGrid,
-    showBoundingBoxes,
-    interactiveOnly,
-    occlusionFilter,
-    iframes,
-    mode,
-  });
-
-  trackVisionUsage(result.annotationTimeMs);
-  const textMap = formatElementMapAsText(result.elementMap);
-  console.error(`[vision_find] Analyzed tab ${tabId}: ${result.elementCount} elements in ${result.annotationTimeMs}ms`);
-
-  const tiles = mode === 'tiled' ? (result.tiling?.tiles ?? []) : [];
+  const domProvider = new DomAnnotatorPerceptionProvider(page);
+  const providerConfig = getOmniParserProviderConfig();
+  const wantsSnapshot = format === 'snapshot' || format === 'both';
+  const needsDomResult = format === 'legacy' || format === 'both' || includeImage;
+  let result: AnnotatedScreenshotResult | undefined;
+  let snapshot: PerceptionSnapshot | undefined;
+  const fallbackWarnings: string[] = [];
+
+  if (wantsSnapshot && providerConfig.provider === 'omniparser-http') {
+    if (!providerConfig.endpointUrl) {
+      fallbackWarnings.push('OmniParser HTTP provider requested but OPENCHROME_OMNIPARSER_URL is not set; using dom-annotator fallback.');
+    } else {
+      try {
+        const provider = new OmniParserHttpProvider(page, {
+          endpointUrl: providerConfig.endpointUrl,
+          timeoutMs: providerConfig.timeoutMs,
+          maxElements: providerConfig.maxElements,
+          context,
+        });
+        snapshot = await provider.capture(tabId, page.url());
+        trackVisionUsage(snapshot.latencyMs);
+      } catch (error) {
+        fallbackWarnings.push(`OmniParser HTTP provider failed: ${error instanceof Error ? error.message : String(error)}; using dom-annotator fallback.`);
+      }
+    }
+  }
+
+  if (needsDomResult || !snapshot) {
+    const captured = await domProvider.captureAnnotated(tabId, page.url(), {
+      showGrid,
+      showBoundingBoxes,
+      interactiveOnly,
+      occlusionFilter,
+      iframes,
+      mode,
+    });
+    result = captured.result;
+    if (!snapshot) snapshot = captured.snapshot;
+    if (fallbackWarnings.length > 0) {
+      snapshot.warnings.unshift(...fallbackWarnings);
+    }
+    trackVisionUsage(result.annotationTimeMs);
+    console.error(`[vision_find] Analyzed tab ${tabId}: ${result.elementCount} DOM elements in ${result.annotationTimeMs}ms`);
+  } else {
+    console.error(`[vision_find] Analyzed tab ${tabId}: ${snapshot.elements.length} ${snapshot.provider} elements in ${snapshot.latencyMs}ms`);
+  }
+
+  const tiles = mode === 'tiled' ? (result?.tiling?.tiles ?? []) : [];
   const tileNote = mode === 'tiled' && tiles.length > 0
     ? `
 
@@ -154,10 +189,25 @@ Tiled mode: ${tiles.length} tile screenshot(s) attached below in document-Y orde
 
   const imageBlocks = tiles.length > 0
     ? tiles.map((t) => ({ type: 'image' as const, data: t.imageBase64, mimeType: t.mimeType }))
-    : [{ type: 'image' as const, data: result.screenshot, mimeType: result.mimeType }];
+    : result
+      ? [{ type: 'image' as const, data: result.screenshot, mimeType: result.mimeType }]
+      : [];
 
   const content: MCPResult['content'] = [];
   if (format === 'legacy' || format === 'both') {
+    if (!result) {
+      // Defensive re-capture: pass the same options as the primary capture so both paths agree.
+      const captured = await domProvider.captureAnnotated(tabId, page.url(), {
+        showGrid,
+        showBoundingBoxes,
+        interactiveOnly,
+        occlusionFilter,
+        iframes,
+        mode,
+      });
+      result = captured.result;
+    }
+    const textMap = formatElementMapAsText(result.elementMap);
     content.push({
       type: 'text',
       text: `Vision analysis: ${result.elementCount} elements found (${result.viewport.width}x${result.viewport.height} viewport, ${result.annotationTimeMs}ms)${tileNote}
@@ -165,7 +215,7 @@ Tiled mode: ${tiles.length} tile screenshot(s) attached below in document-Y orde
 ${textMap}`,
     });
   }
-  if (format === 'snapshot' || format === 'both') {
+  if (wantsSnapshot) {
     content.push({
       type: 'text',
       text: formatPerceptionSnapshotAsText(snapshot),
diff --git a/src/vision/config.ts b/src/vision/config.ts
index dff93581..39ce2c61 100644
--- a/src/vision/config.ts
+++ b/src/vision/config.ts
@@ -37,3 +37,43 @@ export function resetVisionStats(): void {
   visionCallCount = 0;
   totalVisionTimeMs = 0;
 }
+
+
+// ─── Perception Provider Configuration ───
+
+export type VisionProviderName = 'dom' | 'omniparser-http';
+
+/** Environment-derived settings for the optional OmniParser HTTP provider. */
+export interface OmniParserProviderConfig {
+  provider: VisionProviderName;
+  endpointUrl?: string;
+  timeoutMs: number;
+  maxElements: number;
+}
+
+/**
+ * Parses a strictly positive decimal integer, returning `fallback` for unset,
+ * non-numeric, partially numeric (e.g. "3s", "1e4") or non-positive input.
+ */
+function parsePositiveInt(value: string | undefined, fallback: number): number {
+  const trimmed = value?.trim();
+  // parseInt alone accepts a numeric prefix ("3s" -> 3), silently turning a typo into a tiny bound.
+  if (!trimmed || !/^\d+$/.test(trimmed)) return fallback;
+  const parsed = Number.parseInt(trimmed, 10);
+  return Number.isSafeInteger(parsed) && parsed > 0 ? parsed : fallback;
+}
+
+/**
+ * Reads the vision provider configuration from the environment.
+ * The provider stays 'dom' unless OPENCHROME_VISION_PROVIDER equals 'omniparser-http' (case-insensitive).
+ */
+export function getOmniParserProviderConfig(): OmniParserProviderConfig {
+  const rawProvider = (process.env.OPENCHROME_VISION_PROVIDER || 'dom').toLowerCase();
+  const provider: VisionProviderName = rawProvider === 'omniparser-http' ? 'omniparser-http' : 'dom';
+  return {
+    provider,
+    endpointUrl: process.env.OPENCHROME_OMNIPARSER_URL,
+    timeoutMs: parsePositiveInt(process.env.OPENCHROME_OMNIPARSER_TIMEOUT_MS, 3000),
+    maxElements: parsePositiveInt(process.env.OPENCHROME_OMNIPARSER_MAX_ELEMENTS, 200),
+  };
+}
diff --git a/src/vision/providers/omniparser-http-provider.ts b/src/vision/providers/omniparser-http-provider.ts
new file mode 100644
index 00000000..f8a91d6b
--- /dev/null
+++ b/src/vision/providers/omniparser-http-provider.ts
@@ -0,0 +1,303 @@
+import type { Page } from 'puppeteer-core';
+import { DEFAULT_SCREENSHOT_QUALITY, DEFAULT_SCREENSHOT_TIMEOUT_MS } from '../../config/defaults';
+import { getRemainingBudget, type ToolContext } from '../../types/mcp';
+import { bufferToBase64WithPayloadGuard, resolveViewportDimensions, validateCaptureArea } from '../../utils/screenshot-guards';
+import type { PerceptionElement, PerceptionSnapshot } from '../types';
+import { sanitizePerceptionLabel, type PerceptionProviderOptions } from '../perception-provider';
+
+/** Options for {@link OmniParserHttpProvider}. */
+export interface OmniParserHttpProviderOptions extends PerceptionProviderOptions {
+  /** URL the screenshot is POSTed to. */
+  endpointUrl: string;
+  /** Upper bound for one capture in milliseconds; defaults to DEFAULT_TIMEOUT_MS. */
+  timeoutMs?: number;
+  /** Tool context whose abort signal and remaining budget bound the request. */
+  context?: ToolContext;
+}
+
+/** One untyped entry of the response's `parsed_content_list`. */
+type ParsedContent = Record<string, unknown>;
+
+/** Shape of the endpoint's JSON reply; every field is untyped until checked at the point of use. */
+interface OmniParserResponse {
+  parsed_content_list?: unknown;
+  som_image_base64?: unknown;
+  latency?: unknown;
+}
+
+const DEFAULT_TIMEOUT_MS = 3000;
+const DEFAULT_MAX_ELEMENTS = 200;
+const DEFAULT_MAX_LABEL_LENGTH = 160;
+
+/** Type guard: true for non-null, non-array objects. */
+function isRecord(value: unknown): value is Record<string, unknown> {
+  return typeof value === 'object' && value !== null && !Array.isArray(value);
+}
+
+/** Returns `value` when it is a finite number, otherwise `undefined`. */
+function finiteNumber(value: unknown): number | undefined {
+  if (typeof value !== 'number' || !Number.isFinite(value)) return undefined;
+  return value;
+}
+
+/** Clamps `n` into [min, max]; non-finite input collapses to `min`. */
+function clamp(n: number, min: number, max: number): number {
+  if (!Number.isFinite(n)) return min;
+  return Math.max(min, Math.min(max, n));
+}
+
+/** Effective timeout: at least 1ms, capped by the context's remaining budget when a context is given. */
+function normalizeTimeoutMs(timeoutMs: number | undefined, context?: ToolContext): number {
+  const configured = Math.max(1, timeoutMs ?? DEFAULT_TIMEOUT_MS);
+  if (!context) return configured;
+  return Math.max(1, Math.min(configured, getRemainingBudget(context)));
+}
+
+/** First non-blank string among the known label keys, falling back to `type`, then ''. */
+function resolveLabel(item: ParsedContent): string {
+  for (const key of ['content', 'label', 'text', 'caption', 'description']) {
+    const value = item[key];
+    if (typeof value === 'string' && value.trim()) return value;
+  }
+  return typeof item.type === 'string' ? item.type : '';
+}
+
+/** First boolean among the known interactivity keys, or 'unknown' when none is present. */
+function resolveInteractive(item: ParsedContent): boolean | 'unknown' {
+  for (const key of ['interactive', 'interactivity', 'is_interactive']) {
+    // Read once into a local so the typeof check narrows the value being returned.
+    const value = item[key];
+    if (typeof value === 'boolean') return value;
+  }
+  return 'unknown';
+}
+
+/** Maps the raw `type`/`element_type` string (plus interactivity) onto a perception element type. */
+function resolveType(item: ParsedContent, interactive: boolean | 'unknown'): PerceptionElement['type'] {
+  const raw = String(item.type || item.element_type || '').toLowerCase();
+  if (raw.includes('text')) return 'text';
+  if (raw.includes('icon')) return interactive === true ? 'control' : 'icon';
+  if (raw.includes('image') || raw.includes('img')) return 'image';
+  if (interactive === true || raw.includes('button') || raw.includes('control')) return 'control';
+  return 'unknown';
+}
+
+/**
+ * Converts an entry's bounding box into viewport pixels.
+ *
+ * Accepts `[x1, y1, x2, y2]` arrays and `{x, y, width|w, height|h}` or
+ * `{x1|left, y1|top, x2|right, y2|bottom}` objects. When every coordinate lies
+ * in [0, 1] the values are treated as ratios of the viewport.
+ *
+ * @returns the box clipped to the viewport, or `undefined` when no usable box is present.
+ */
+function resolveBBox(item: ParsedContent, viewport: { width: number; height: number }): PerceptionElement['bbox'] | undefined {
+  const raw = item.bbox ?? item.box ?? item.bounding_box;
+  let x: number | undefined;
+  let y: number | undefined;
+  let width: number | undefined;
+  let height: number | undefined;
+
+  if (Array.isArray(raw) && raw.length >= 4) {
+    const nums = raw.slice(0, 4).map(finiteNumber);
+    if (nums.every((n): n is number => n !== undefined)) {
+      const [a, b, c, d] = nums;
+      const ratio = nums.every(n => n >= 0 && n <= 1);
+      x = ratio ? a * viewport.width : a;
+      y = ratio ? b * viewport.height : b;
+      width = ratio ? (c - a) * viewport.width : c - a;
+      height = ratio ? (d - b) * viewport.height : d - b;
+    }
+  } else if (isRecord(raw)) {
+    const rx = finiteNumber(raw.x);
+    const ry = finiteNumber(raw.y);
+    const rw = finiteNumber(raw.width ?? raw.w);
+    const rh = finiteNumber(raw.height ?? raw.h);
+    const x1 = finiteNumber(raw.x1 ?? raw.left);
+    const y1 = finiteNumber(raw.y1 ?? raw.top);
+    const x2 = finiteNumber(raw.x2 ?? raw.right);
+    const y2 = finiteNumber(raw.y2 ?? raw.bottom);
+    const values = [rx, ry, rw, rh, x1, y1, x2, y2].filter((n): n is number => n !== undefined);
+    const ratio = values.length > 0 && values.every(n => n >= 0 && n <= 1);
+    if (rx !== undefined && ry !== undefined && rw !== undefined && rh !== undefined) {
+      x = ratio ? rx * viewport.width : rx;
+      y = ratio ? ry * viewport.height : ry;
+      width = ratio ? rw * viewport.width : rw;
+      height = ratio ? rh * viewport.height : rh;
+    } else if (x1 !== undefined && y1 !== undefined && x2 !== undefined && y2 !== undefined) {
+      x = ratio ? x1 * viewport.width : x1;
+      y = ratio ? y1 * viewport.height : y1;
+      width = ratio ? (x2 - x1) * viewport.width : x2 - x1;
+      height = ratio ? (y2 - y1) * viewport.height : y2 - y1;
+    }
+  }
+
+  if (x === undefined || y === undefined || width === undefined || height === undefined) return undefined;
+  const cx = clamp(x, 0, viewport.width);
+  const cy = clamp(y, 0, viewport.height);
+  // When the origin is pulled into the viewport, shrink the extent by the same
+  // amount so the far edge stays put instead of the whole box sliding.
+  return {
+    x: cx,
+    y: cy,
+    width: clamp(width - (cx - x), 0, viewport.width - cx),
+    height: clamp(height - (cy - y), 0, viewport.height - cy),
+  };
+}
+
+/** Expresses a pixel box as fractions of the viewport, each clamped to [0, 1]. */
+function bboxRatio(bbox: PerceptionElement['bbox'], viewport: { width: number; height: number }): PerceptionElement['bboxRatio'] {
+  const width = Math.max(1, viewport.width);
+  const height = Math.max(1, viewport.height);
+  return {
+    x: clamp(bbox.x / width, 0, 1),
+    y: clamp(bbox.y / height, 0, 1),
+    width: clamp(bbox.width / width, 0, 1),
+    height: clamp(bbox.height / height, 0, 1),
+  };
+}
+
+/**
+ * Perception provider that screenshots the page, POSTs it to an OmniParser-style
+ * HTTP endpoint and converts the reply into a {@link PerceptionSnapshot}.
+ */
+export class OmniParserHttpProvider {
+  readonly name = 'omniparser-http';
+
+  constructor(private readonly page: Page, private readonly options: OmniParserHttpProviderOptions) {}
+
+  /**
+   * Captures the current viewport and returns the parsed snapshot.
+   *
+   * @throws when the capture area or payload guard rejects the screenshot, the
+   *   endpoint answers with a non-2xx status or a malformed body, or the call
+   *   is aborted or times out.
+   */
+  async capture(tabId: string, url: string): Promise<PerceptionSnapshot> {
+    const started = Date.now();
+    const viewport = await resolveViewportDimensions(this.page);
+    const areaError = validateCaptureArea(viewport, 'OmniParser screenshot');
+    if (areaError) throw new Error(areaError);
+
+    const timeoutMs = normalizeTimeoutMs(this.options.timeoutMs, this.options.context);
+    const controller = new AbortController();
+    const signal = this.options.context?.signal;
+    const onAbort = () => controller.abort(signal?.reason ?? new Error('Tool call aborted'));
+    // A signal that is already aborted never dispatches 'abort' again, so mirror it eagerly.
+    if (signal?.aborted) {
+      onAbort();
+    } else {
+      signal?.addEventListener('abort', onAbort, { once: true });
+    }
+    const timer = setTimeout(() => controller.abort(new Error(`OmniParser HTTP provider timed out after ${timeoutMs}ms`)), timeoutMs);
+
+    try {
+      // Skip the screenshot entirely when the caller aborted before we started;
+      // the catch block below rethrows the recorded abort reason.
+      if (controller.signal.aborted) throw new Error('OmniParser HTTP provider aborted');
+      const screenshot = await this.captureScreenshot(Math.min(timeoutMs, DEFAULT_SCREENSHOT_TIMEOUT_MS));
+      const response = await fetch(this.options.endpointUrl, {
+        method: 'POST',
+        headers: { 'content-type': 'application/json' },
+        body: JSON.stringify({ base64_image: screenshot.data }),
+        signal: controller.signal,
+      });
+      if (!response.ok) {
+        throw new Error(`OmniParser HTTP provider returned ${response.status}`);
+      }
+
+      const body = await response.json() as OmniParserResponse;
+      return this.toSnapshot(body, tabId, url, viewport, started, screenshot.mimeType);
+    } catch (error) {
+      if (controller.signal.aborted) {
+        const reason = controller.signal.reason;
+        throw reason instanceof Error ? reason : new Error(String(reason || 'OmniParser HTTP provider aborted'));
+      }
+      throw error;
+    } finally {
+      clearTimeout(timer);
+      signal?.removeEventListener('abort', onAbort);
+    }
+  }
+
+  /** Takes a WebP viewport screenshot, failing when it exceeds `timeoutMs` or the payload guard rejects it. */
+  private async captureScreenshot(timeoutMs: number): Promise<{ data: string; mimeType: string }> {
+    let timer: ReturnType<typeof setTimeout> | undefined;
+    const buffer = await Promise.race([
+      this.page.screenshot({ type: 'webp', quality: DEFAULT_SCREENSHOT_QUALITY, fullPage: false }),
+      new Promise<never>((_, reject) => {
+        timer = setTimeout(() => reject(new Error('OmniParser screenshot timed out')), timeoutMs);
+      }),
+    ]).finally(() => { if (timer) clearTimeout(timer); });
+
+    const encoded = bufferToBase64WithPayloadGuard(Buffer.from(buffer), 'OmniParser screenshot');
+    if ('error' in encoded) throw new Error(encoded.error);
+    return { data: encoded.data, mimeType: 'image/webp' };
+  }
+
+  /**
+   * Validates the response body and maps its entries onto perception elements.
+   * Non-object entries and entries without a usable bbox are dropped with a warning;
+   * at most `maxElements` entries are considered.
+   */
+  private toSnapshot(
+    body: OmniParserResponse,
+    tabId: string,
+    url: string,
+    viewport: { width: number; height: number },
+    started: number,
+    screenshotMimeType: string
+  ): PerceptionSnapshot {
+    if (!Array.isArray(body.parsed_content_list)) {
+      throw new Error('Malformed OmniParser response: parsed_content_list must be an array');
+    }
+
+    const warnings: string[] = [];
+    const maxElements = Math.max(0, this.options.maxElements ?? DEFAULT_MAX_ELEMENTS);
+    const maxLabelLength = this.options.maxLabelLength ?? DEFAULT_MAX_LABEL_LENGTH;
+    const entries = body.parsed_content_list.filter(isRecord);
+    if (body.parsed_content_list.length !== entries.length) {
+      warnings.push('Malformed OmniParser entries were ignored.');
+    }
+    if (entries.length > maxElements) {
+      warnings.push(`OmniParser snapshot truncated from ${entries.length} to ${maxElements} elements.`);
+    }
+
+    const elements: PerceptionElement[] = [];
+    for (const item of entries.slice(0, maxElements)) {
+      const bbox = resolveBBox(item, viewport);
+      if (!bbox) {
+        warnings.push('OmniParser entry without a valid bbox was ignored.');
+        continue;
+      }
+      const interactive = resolveInteractive(item);
+      const type = resolveType(item, interactive);
+      elements.push({
+        id: `op${elements.length + 1}`,
+        type,
+        label: sanitizePerceptionLabel(resolveLabel(item), maxLabelLength),
+        role: typeof item.type === 'string' ? item.type : undefined,
+        interactive,
+        bbox,
+        bboxRatio: bboxRatio(bbox, viewport),
+        confidence: finiteNumber(item.confidence ?? item.score),
+        source: this.name,
+        metadata: typeof item.source === 'string' ? { upstreamSource: item.source } : undefined,
+      });
+    }
+
+    return {
+      version: 1,
+      provider: this.name,
+      tabId,
+      url,
+      capturedAt: started,
+      viewport,
+      screenshotMimeType,
+      elements,
+      warnings,
+      // NOTE(review): `latency` is forwarded unscaled; confirm the server reports milliseconds, not seconds.
+      latencyMs: finiteNumber(body.latency) ?? Date.now() - started,
+    };
+  }
+}
diff --git a/tests/vision/omniparser-http-provider.test.ts b/tests/vision/omniparser-http-provider.test.ts
new file mode 100644
index 00000000..7028730c
--- /dev/null
+++ b/tests/vision/omniparser-http-provider.test.ts
@@ -0,0 +1,146 @@
+/// <reference types="jest" />
+
+import { OmniParserHttpProvider } from '../../src/vision/providers/omniparser-http-provider';
+import { getOmniParserProviderConfig } from '../../src/vision/config';
+
+type MockPage = {
+  viewport: jest.Mock;
+  screenshot: jest.Mock;
+};
+
+function page(): MockPage {
+  return {
+    viewport: jest.fn(() => ({ width: 1000, height: 500 })),
+    screenshot: jest.fn(async () => Buffer.from('image')),
+  };
+}
+
+const originalFetch = global.fetch;
+const originalEnv = { ...process.env };
+
+afterEach(() => {
+  global.fetch = originalFetch;
+  process.env = { ...originalEnv };
+  jest.useRealTimers();
+});
+
+describe('OmniParserHttpProvider', () => {
+  test('posts a guarded screenshot and converts ratio bboxes into perception elements', async () => {
+    const p = page();
+    global.fetch = jest.fn(async () => new Response(JSON.stringify({
+      latency: 25,
+      parsed_content_list: [
+        { type: 'text', content: 'Search label', bbox: [0.1, 0.2, 0.3, 0.4] },
+        { type: 'icon', content: 'Continue button', bbox: [0.4, 0.1, 0.6, 0.2], interactive: true, confidence: 0.93 },
+      ],
+    }), { status: 200, headers: { 'content-type': 'application/json' } })) as typeof fetch;
+
+    const snapshot = await new OmniParserHttpProvider(p as never, {
+      endpointUrl: 'http://127.0.0.1:9901/parse/',
+      timeoutMs: 1000,
+    }).capture('tab-1', 'https://example.test');
+
+    expect(global.fetch).toHaveBeenCalledWith('http://127.0.0.1:9901/parse/', expect.objectContaining({
+      method: 'POST',
+      headers: { 'content-type': 'application/json' },
+    }));
+    expect(JSON.parse((global.fetch as jest.Mock).mock.calls[0][1].body)).toEqual({
+      base64_image: Buffer.from('image').toString('base64'),
+    });
+    expect(snapshot.provider).toBe('omniparser-http');
+    expect(snapshot.elements).toHaveLength(2);
+    expect(snapshot.elements[0]).toMatchObject({
+      id: 'op1',
+      type: 'text',
+      label: 'Search label',
+      interactive: 'unknown',
+      source: 'omniparser-http',
+    });
+    expect(snapshot.elements[0].bbox).toMatchObject({ x: 100, y: 100, height: 100 });
+    expect(snapshot.elements[0].bbox.width).toBeCloseTo(200);
+    expect(snapshot.elements[0].bboxRatio).toMatchObject({ x: 0.1, y: 0.2, height: 0.2 });
+    expect(snapshot.elements[0].bboxRatio.width).toBeCloseTo(0.2);
+    expect(snapshot.elements[1]).toMatchObject({ type: 'control', interactive: true, confidence: 0.93 });
+  });
+
+  test('bounds labels/elements and ignores malformed entries with warnings', async () => {
+    const p = page();
+    global.fetch = jest.fn(async () => new Response(JSON.stringify({
+      parsed_content_list: [
+        { type: 'text', content: `password=super-secret-fixture-password ${'x'.repeat(50)}`, bbox: { x: 0, y: 0, width: 0.2, height: 0.1 } },
+        { type: 'text', content: 'missing bbox' },
+        { type: 'text', content: 'third', bbox: [0.2, 0.2, 0.3, 0.3] },
+      ],
+    }), { status: 200 })) as typeof fetch;
+
+    const snapshot = await new OmniParserHttpProvider(p as never, {
+      endpointUrl: 'http://local/parse',
+      maxElements: 2,
+      maxLabelLength: 12,
+    }).capture('tab', 'https://example.test');
+
+    expect(snapshot.elements).toHaveLength(1);
+    expect(snapshot.elements[0].label).toBe('[REDACTED]…');
+    expect(snapshot.warnings.join('\n')).toContain('truncated from 3 to 2');
+    expect(snapshot.warnings.join('\n')).toContain('without a valid bbox');
+  });
+
+  test('throws bounded error on malformed response', async () => {
+    const p = page();
+    global.fetch = jest.fn(async () => new Response(JSON.stringify({ parsed_content_list: 'bad' }), { status: 200 })) as typeof fetch;
+
+    await expect(new OmniParserHttpProvider(p as never, {
+      endpointUrl: 'http://local/parse',
+    }).capture('tab', 'https://example.test')).rejects.toThrow('parsed_content_list must be an array');
+  });
+
+  test('keeps the far edge fixed when a pixel bbox starts outside the viewport', async () => {
+    const p = page();
+    global.fetch = jest.fn(async () => new Response(JSON.stringify({
+      parsed_content_list: [
+        { type: 'icon', content: 'Partly off-screen', bbox: [-50, -20, 150, 80] },
+      ],
+    }), { status: 200 })) as typeof fetch;
+
+    const snapshot = await new OmniParserHttpProvider(p as never, {
+      endpointUrl: 'http://local/parse',
+    }).capture('tab', 'https://example.test');
+
+    expect(snapshot.elements[0].bbox).toEqual({ x: 0, y: 0, width: 150, height: 80 });
+  });
+
+  test('respects timeout with abort controller', async () => {
+    jest.useFakeTimers();
+    const p = page();
+    global.fetch = jest.fn((_url, init) => new Promise((_resolve, reject) => {
+      (init as RequestInit).signal?.addEventListener('abort', () => reject((init as RequestInit).signal?.reason));
+    })) as typeof fetch;
+
+    const promise = new OmniParserHttpProvider(p as never, {
+      endpointUrl: 'http://local/parse',
+      timeoutMs: 10,
+    }).capture('tab', 'https://example.test');
+    const expectation = expect(promise).rejects.toThrow('timed out after 10ms');
+    await jest.advanceTimersByTimeAsync(20);
+
+    await expectation;
+  });
+});
+
+describe('getOmniParserProviderConfig', () => {
+  test('keeps provider opt-in and parses bounds', () => {
+    delete process.env.OPENCHROME_VISION_PROVIDER;
+    expect(getOmniParserProviderConfig().provider).toBe('dom');
+
+    process.env.OPENCHROME_VISION_PROVIDER = 'omniparser-http';
+    process.env.OPENCHROME_OMNIPARSER_URL = 'http://127.0.0.1:9901/parse/';
+    process.env.OPENCHROME_OMNIPARSER_TIMEOUT_MS = '1234';
+    process.env.OPENCHROME_OMNIPARSER_MAX_ELEMENTS = '42';
+    expect(getOmniParserProviderConfig()).toMatchObject({
+      provider: 'omniparser-http',
+      endpointUrl: 'http://127.0.0.1:9901/parse/',
+      timeoutMs: 1234,
+      maxElements: 42,
+    });
+  });
+});