diff --git a/src/core/analyzer.ts b/src/core/analyzer.ts index 754593b..2bbc4ac 100644 --- a/src/core/analyzer.ts +++ b/src/core/analyzer.ts @@ -8,7 +8,7 @@ import { collectRecognizedNames, collectStringVariables, decoratorExtractor, - findNodesByType, + getNodesByType, importExtractor, includeRouterExtractor, mountExtractor, @@ -37,32 +37,32 @@ function resolveVariables( /** Analyze a syntax tree and extract FastAPI-related information */ export function analyzeTree(tree: Tree, filePath: string): FileAnalysis { - const rootNode = tree.rootNode + const nodesByType = getNodesByType(tree.rootNode) // Get all decorated definitions (functions and classes with decorators) - const decoratedDefs = findNodesByType(rootNode, "decorated_definition") + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const routes = decoratedDefs.map(decoratorExtractor).filter(notNull) // Get all router assignments - const assignments = findNodesByType(rootNode, "assignment") - const { fastAPINames, apiRouterNames } = collectRecognizedNames(rootNode) + const assignments = nodesByType.get("assignment") ?? [] + const { fastAPINames, apiRouterNames } = collectRecognizedNames(nodesByType) const routers = assignments .map((node) => routerExtractor(node, apiRouterNames, fastAPINames)) .filter(notNull) // Get all include_router and mount calls - const callNodes = findNodesByType(rootNode, "call") + const callNodes = nodesByType.get("call") ?? [] const includeRouters = callNodes.map(includeRouterExtractor).filter(notNull) const mounts = callNodes.map(mountExtractor).filter(notNull) // Get all import statements - const importNodes = findNodesByType(rootNode, "import_statement") - const importFromNodes = findNodesByType(rootNode, "import_from_statement") + const importNodes = nodesByType.get("import_statement") ?? [] + const importFromNodes = nodesByType.get("import_from_statement") ?? [] const imports = [...importNodes, ...importFromNodes] .map(importExtractor) .filter(notNull) - const stringVariables = collectStringVariables(rootNode) + const stringVariables = collectStringVariables(nodesByType) for (const route of routes) { route.path = resolveVariables(route.path, stringVariables) diff --git a/src/core/extractors.ts b/src/core/extractors.ts index 3421472..c306b6d 100644 --- a/src/core/extractors.ts +++ b/src/core/extractors.ts @@ -14,13 +14,6 @@ import type { } from "./internal" import { ROUTE_METHODS } from "./internal" -/** Recursively finds all nodes of a given type within a subtree */ -export function findNodesByType(node: Node, type: string): Node[] { - const results: Node[] = [] - collectNodesByType(node, type, results) - return results -} - function stripDocstring(raw: string): string { let content: string if ( @@ -49,16 +42,24 @@ function stripDocstring(raw: string): string { return dedented.join("\n").trim() } -function collectNodesByType(node: Node, type: string, results: Node[]): void { - if (node.type === type) { - results.push(node) - } - for (let i = 0; i < node.childCount; i++) { - const child = node.child(i) - if (child) { - collectNodesByType(child, type, results) +export function getNodesByType(root: Node): Map { + const results = new Map() + + function collectNodesByType(node: Node, results: Map): void { + if (!results.has(node.type)) { + results.set(node.type, []) + } + results.get(node.type)!.push(node) + + for (let i = 0; i < node.childCount; i++) { + const child = node.child(i) + if (child) { + collectNodesByType(child, results) + } } } + collectNodesByType(root, results) + return results } /** @@ -74,9 +75,11 @@ function collectNodesByType(node: Node, type: string, results: Node[]): void { * settings.PREFIX = "/api" -> (skipped, not a simple identifier) * def f(): BASE = "/local" -> (skipped, inside function) */ -export function collectStringVariables(rootNode: Node): Map { +export function collectStringVariables( + nodesByType: Map, +): Map { const variables = new Map() - const assignmentNodes = findNodesByType(rootNode, "assignment") + const assignmentNodes = nodesByType.get("assignment") ?? [] for (const assign of assignmentNodes) { if ( @@ -172,7 +175,11 @@ export function decoratorExtractor(node: Node): RouteInfo | null { // Grammar guarantees: decorated_definition always has a first child (the decorator) const decoratorNode = node.firstNamedChild! - const callNode = findNodesByType(decoratorNode, "call")[0] + const callNode = + decoratorNode.firstNamedChild?.type === "call" + ? decoratorNode.firstNamedChild + : null + const functionNode = callNode?.childForFieldName("function") const argumentsNode = callNode?.childForFieldName("arguments") const objectNode = functionNode?.childForFieldName("object") @@ -351,7 +358,8 @@ export function importExtractor(node: Node): ImportInfo | null { if (node.type === "import_statement") { let modulePath = "" // Handle aliased imports: "import fastapi as f" - for (const aliased of findNodesByType(node, "aliased_import")) { + const aliasedImports = getNodesByType(node).get("aliased_import") ?? [] + for (const aliased of aliasedImports) { const nameNode = aliased.childForFieldName("name") const aliasNode = aliased.childForFieldName("alias") if (nameNode) { @@ -362,7 +370,7 @@ export function importExtractor(node: Node): ImportInfo | null { } } // Non-aliased: "import fastapi" or "import fastapi.routing" - const nameNodes = findNodesByType(node, "dotted_name") + const nameNodes = getNodesByType(node).get("dotted_name") ?? [] for (const nameNode of nameNodes) { if (!hasAncestor(nameNode, "aliased_import")) { if (!modulePath) modulePath = nameNode.text // preserve full dotted path @@ -387,7 +395,8 @@ export function importExtractor(node: Node): ImportInfo | null { ) // Aliased imports (e.g., "router as users_router") - for (const aliased of findNodesByType(node, "aliased_import")) { + const aliasedImports = getNodesByType(node).get("aliased_import") ?? [] + for (const aliased of aliasedImports) { const nameNode = aliased.childForFieldName("name") const aliasNode = aliased.childForFieldName("alias") if (nameNode) { @@ -398,7 +407,7 @@ export function importExtractor(node: Node): ImportInfo | null { } // Non-aliased imports (skip first dotted_name which is the module path) - const nameNodes = findNodesByType(node, "dotted_name") + const nameNodes = getNodesByType(node).get("dotted_name") ?? [] for (let i = 1; i < nameNodes.length; i++) { const nameNode = nameNodes[i] if (!hasAncestor(nameNode, "aliased_import")) { @@ -423,7 +432,7 @@ export function importExtractor(node: Node): ImportInfo | null { * fastAPINames = Set { "FastAPI", "fastapi.FastAPI", "MyApp" } * apiRouterNames = Set { "APIRouter", "fastapi.APIRouter", "MyRouter", "CustomRouter" } */ -export function collectRecognizedNames(rootNode: Node): { +export function collectRecognizedNames(nodesByType: Map): { fastAPINames: Set apiRouterNames: Set } { @@ -431,7 +440,7 @@ export function collectRecognizedNames(rootNode: Node): { const apiRouterNames = new Set(["APIRouter", "fastapi.APIRouter"]) // Add aliases from "from fastapi import X as Y" imports - for (const node of findNodesByType(rootNode, "import_from_statement")) { + for (const node of nodesByType.get("import_from_statement") ?? []) { const info = importExtractor(node) if (!info || info.modulePath !== "fastapi") continue for (const named of info.namedImports) { @@ -442,7 +451,7 @@ export function collectRecognizedNames(rootNode: Node): { } // Add module aliases from "import fastapi as f" → recognizes f.FastAPI, f.APIRouter - for (const node of findNodesByType(rootNode, "import_statement")) { + for (const node of nodesByType.get("import_statement") ?? []) { const info = importExtractor(node) if (!info) continue for (const named of info.namedImports) { @@ -456,7 +465,7 @@ export function collectRecognizedNames(rootNode: Node): { // Add subclasses, checking against the already-accumulated alias sets so // "class MyRouter(AR)" works when AR is an alias for APIRouter - for (const cls of findNodesByType(rootNode, "class_definition")) { + for (const cls of nodesByType.get("class_definition") ?? []) { const nameNode = cls.childForFieldName("name") const superclassesNode = cls.childForFieldName("superclasses") if (!nameNode || !superclassesNode) continue diff --git a/src/test/core/extractors.test.ts b/src/test/core/extractors.test.ts index 1cc63db..e85bd48 100644 --- a/src/test/core/extractors.test.ts +++ b/src/test/core/extractors.test.ts @@ -5,7 +5,7 @@ import { decoratorExtractor, extractPathFromNode, extractStringValue, - findNodesByType, + getNodesByType, importExtractor, includeRouterExtractor, mountExtractor, @@ -41,10 +41,8 @@ def list_users(): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] assert.strictEqual(decoratedDefs.length, 1) const result = decoratorExtractor(decoratedDefs[0]) @@ -62,10 +60,8 @@ def get_user(user_id: int): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.ok(result) @@ -79,10 +75,8 @@ def create_item(): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.ok(result) @@ -98,10 +92,8 @@ def websocket_handler(websocket: WebSocket): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.ok(result) @@ -116,10 +108,8 @@ def handler(): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.ok(result) @@ -133,10 +123,8 @@ def handler(): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.ok(result) @@ -150,10 +138,8 @@ def handler(): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.ok(result) @@ -167,10 +153,8 @@ def handler(): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.strictEqual(result, null) }) @@ -182,10 +166,8 @@ def not_found(request, exc): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.strictEqual(result, null) }) @@ -197,10 +179,8 @@ def handle_items(): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.ok(result) @@ -215,10 +195,8 @@ def handle_items(): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.ok(result) @@ -232,10 +210,8 @@ def get_user(user_id: int): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.ok(result) @@ -248,7 +224,8 @@ def regular_function(): pass ` const tree = parse(code) - const funcDefs = findNodesByType(tree.rootNode, "function_definition") + const nodesByType = getNodesByType(tree.rootNode) + const funcDefs = nodesByType.get("function_definition") ?? [] const result = decoratorExtractor(funcDefs[0]) assert.strictEqual(result, null) @@ -261,10 +238,8 @@ def handler(): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.ok(result) @@ -280,10 +255,8 @@ def list_users(): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.ok(result) @@ -302,10 +275,8 @@ def list_users(): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.ok(result) @@ -323,10 +294,8 @@ def list_users(): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.ok(result) @@ -340,10 +309,8 @@ def list_users(): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.ok(result) @@ -355,7 +322,8 @@ def list_users(): test("extracts FastAPI app instantiation", () => { const code = "app = FastAPI()" const tree = parse(code) - const assignments = findNodesByType(tree.rootNode, "assignment") + const nodesByType = getNodesByType(tree.rootNode) + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor(assignments[0]) assert.ok(result) @@ -367,7 +335,8 @@ def list_users(): test("extracts APIRouter instantiation", () => { const code = "router = APIRouter()" const tree = parse(code) - const assignments = findNodesByType(tree.rootNode, "assignment") + const nodesByType = getNodesByType(tree.rootNode) + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor(assignments[0]) assert.ok(result) @@ -378,7 +347,8 @@ def list_users(): test("extracts APIRouter with prefix", () => { const code = `router = APIRouter(prefix="/users")` const tree = parse(code) - const assignments = findNodesByType(tree.rootNode, "assignment") + const nodesByType = getNodesByType(tree.rootNode) + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor(assignments[0]) assert.ok(result) @@ -388,7 +358,8 @@ def list_users(): test("extracts APIRouter with tags", () => { const code = `router = APIRouter(tags=["users", "admin"])` const tree = parse(code) - const assignments = findNodesByType(tree.rootNode, "assignment") + const nodesByType = getNodesByType(tree.rootNode) + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor(assignments[0]) assert.ok(result) @@ -398,7 +369,8 @@ def list_users(): test("extracts APIRouter with prefix and tags", () => { const code = `router = APIRouter(prefix="/api", tags=["api"])` const tree = parse(code) - const assignments = findNodesByType(tree.rootNode, "assignment") + const nodesByType = getNodesByType(tree.rootNode) + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor(assignments[0]) assert.ok(result) @@ -409,7 +381,8 @@ def list_users(): test("ignores positional arguments in router constructor", () => { const code = "router = APIRouter(some_config)" const tree = parse(code) - const assignments = findNodesByType(tree.rootNode, "assignment") + const nodesByType = getNodesByType(tree.rootNode) + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor(assignments[0]) assert.ok(result) @@ -420,7 +393,8 @@ def list_users(): test("handles dynamic prefix", () => { const code = "router = APIRouter(prefix=settings.API_PREFIX)" const tree = parse(code) - const assignments = findNodesByType(tree.rootNode, "assignment") + const nodesByType = getNodesByType(tree.rootNode) + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor(assignments[0]) assert.ok(result) @@ -430,7 +404,8 @@ def list_users(): test("returns null for non-router assignment", () => { const code = "x = 5" const tree = parse(code) - const assignments = findNodesByType(tree.rootNode, "assignment") + const nodesByType = getNodesByType(tree.rootNode) + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor(assignments[0]) assert.strictEqual(result, null) @@ -439,7 +414,8 @@ def list_users(): test("returns null for other function call", () => { const code = "result = some_function()" const tree = parse(code) - const assignments = findNodesByType(tree.rootNode, "assignment") + const nodesByType = getNodesByType(tree.rootNode) + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor(assignments[0]) assert.strictEqual(result, null) @@ -448,7 +424,8 @@ def list_users(): test("extracts qualified fastapi.FastAPI() call", () => { const code = "app = fastapi.FastAPI()" const tree = parse(code) - const assignments = findNodesByType(tree.rootNode, "assignment") + const nodesByType = getNodesByType(tree.rootNode) + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor(assignments[0]) assert.ok(result) @@ -459,7 +436,8 @@ def list_users(): test("extracts qualified fastapi.APIRouter() call", () => { const code = "router = fastapi.APIRouter(prefix='/api')" const tree = parse(code) - const assignments = findNodesByType(tree.rootNode, "assignment") + const nodesByType = getNodesByType(tree.rootNode) + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor(assignments[0]) assert.ok(result) @@ -471,7 +449,8 @@ def list_users(): test("returns null for custom subclass without subclasses set", () => { const code = "admin_router = AdminAPIRouter(prefix='/admin')" const tree = parse(code) - const assignments = findNodesByType(tree.rootNode, "assignment") + const nodesByType = getNodesByType(tree.rootNode) + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor(assignments[0]) assert.strictEqual(result, null) @@ -485,10 +464,11 @@ class AdminAPIRouter(APIRouter): admin_router = AdminAPIRouter(prefix="/admin") ` const tree = parse(code) - const { apiRouterNames } = collectRecognizedNames(tree.rootNode) + const nodesByType = getNodesByType(tree.rootNode) + const { apiRouterNames } = collectRecognizedNames(nodesByType) assert.ok(apiRouterNames.has("AdminAPIRouter")) - const assignments = findNodesByType(tree.rootNode, "assignment") + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor(assignments[0], apiRouterNames) assert.ok(result) @@ -505,13 +485,13 @@ class MyApp(FastAPI): app = MyApp() ` const tree = parse(code) - const { fastAPINames, apiRouterNames } = collectRecognizedNames( - tree.rootNode, - ) + const nodesByType = getNodesByType(tree.rootNode) + const { fastAPINames, apiRouterNames } = + collectRecognizedNames(nodesByType) assert.ok(fastAPINames.has("MyApp")) assert.ok(!apiRouterNames.has("MyApp")) - const assignments = findNodesByType(tree.rootNode, "assignment") + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor( assignments[0], apiRouterNames, @@ -530,12 +510,12 @@ from fastapi import FastAPI as FA app = FA() ` const tree = parse(code) - const { fastAPINames, apiRouterNames } = collectRecognizedNames( - tree.rootNode, - ) + const nodesByType = getNodesByType(tree.rootNode) + const { fastAPINames, apiRouterNames } = + collectRecognizedNames(nodesByType) assert.ok(fastAPINames.has("FA")) - const assignments = findNodesByType(tree.rootNode, "assignment") + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor( assignments[0], apiRouterNames, @@ -554,10 +534,11 @@ from fastapi import APIRouter as AR router = AR(prefix="/items") ` const tree = parse(code) - const { apiRouterNames } = collectRecognizedNames(tree.rootNode) + const nodesByType = getNodesByType(tree.rootNode) + const { apiRouterNames } = collectRecognizedNames(nodesByType) assert.ok(apiRouterNames.has("AR")) - const assignments = findNodesByType(tree.rootNode, "assignment") + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor(assignments[0], apiRouterNames) assert.ok(result) @@ -576,11 +557,13 @@ class MyRouter(AR): router = MyRouter(prefix="/items") ` const tree = parse(code) - const { apiRouterNames } = collectRecognizedNames(tree.rootNode) + + const nodesByType = getNodesByType(tree.rootNode) + const { apiRouterNames } = collectRecognizedNames(nodesByType) assert.ok(apiRouterNames.has("AR")) assert.ok(apiRouterNames.has("MyRouter")) - const assignments = findNodesByType(tree.rootNode, "assignment") + const assignments = nodesByType.get("assignment") ?? [] const result = routerExtractor(assignments[0], apiRouterNames) assert.ok(result) @@ -590,9 +573,9 @@ router = MyRouter(prefix="/items") test("collectRecognizedNames ignores non-aliased imports", () => { const code = "from fastapi import FastAPI, APIRouter" const tree = parse(code) - const { fastAPINames, apiRouterNames } = collectRecognizedNames( - tree.rootNode, - ) + const nodesByType = getNodesByType(tree.rootNode) + const { fastAPINames, apiRouterNames } = + collectRecognizedNames(nodesByType) // Only the defaults — no extras from non-aliased imports assert.strictEqual(fastAPINames.size, 2) // "FastAPI", "fastapi.FastAPI" assert.strictEqual(apiRouterNames.size, 2) // "APIRouter", "fastapi.APIRouter" @@ -606,13 +589,13 @@ app = f.FastAPI() router = f.APIRouter(prefix="/items") ` const tree = parse(code) - const { fastAPINames, apiRouterNames } = collectRecognizedNames( - tree.rootNode, - ) + const nodesByType = getNodesByType(tree.rootNode) + const { fastAPINames, apiRouterNames } = + collectRecognizedNames(nodesByType) assert.ok(fastAPINames.has("f.FastAPI")) assert.ok(apiRouterNames.has("f.APIRouter")) - const assignments = findNodesByType(tree.rootNode, "assignment") + const assignments = nodesByType.get("assignment") ?? [] const appResult = routerExtractor( assignments[0], apiRouterNames, @@ -638,7 +621,8 @@ router = f.APIRouter(prefix="/items") test("extracts simple import", () => { const code = "import fastapi" const tree = parse(code) - const imports = findNodesByType(tree.rootNode, "import_statement") + const nodesByType = getNodesByType(tree.rootNode) + const imports = nodesByType.get("import_statement") ?? [] const result = importExtractor(imports[0]) assert.ok(result) @@ -653,7 +637,8 @@ router = f.APIRouter(prefix="/items") test("preserves full dotted modulePath for import fastapi.routing", () => { const code = "import fastapi.routing" const tree = parse(code) - const imports = findNodesByType(tree.rootNode, "import_statement") + const nodesByType = getNodesByType(tree.rootNode) + const imports = nodesByType.get("import_statement") ?? [] const result = importExtractor(imports[0]) assert.ok(result) @@ -667,7 +652,8 @@ router = f.APIRouter(prefix="/items") test("extracts aliased module import (import fastapi as f)", () => { const code = "import fastapi as f" const tree = parse(code) - const imports = findNodesByType(tree.rootNode, "import_statement") + const nodesByType = getNodesByType(tree.rootNode) + const imports = nodesByType.get("import_statement") ?? [] const result = importExtractor(imports[0]) assert.ok(result) @@ -682,7 +668,8 @@ router = f.APIRouter(prefix="/items") test("extracts from import", () => { const code = "from fastapi import FastAPI" const tree = parse(code) - const imports = findNodesByType(tree.rootNode, "import_from_statement") + const nodesByType = getNodesByType(tree.rootNode) + const imports = nodesByType.get("import_from_statement") ?? [] const result = importExtractor(imports[0]) assert.ok(result) @@ -694,7 +681,8 @@ router = f.APIRouter(prefix="/items") test("extracts relative import with single dot", () => { const code = "from .routes import users" const tree = parse(code) - const imports = findNodesByType(tree.rootNode, "import_from_statement") + const nodesByType = getNodesByType(tree.rootNode) + const imports = nodesByType.get("import_from_statement") ?? [] const result = importExtractor(imports[0]) assert.ok(result) @@ -706,7 +694,8 @@ router = f.APIRouter(prefix="/items") test("extracts relative import with double dot", () => { const code = "from ..api import router" const tree = parse(code) - const imports = findNodesByType(tree.rootNode, "import_from_statement") + const nodesByType = getNodesByType(tree.rootNode) + const imports = nodesByType.get("import_from_statement") ?? [] const result = importExtractor(imports[0]) assert.ok(result) @@ -718,7 +707,8 @@ router = f.APIRouter(prefix="/items") test("extracts import with alias", () => { const code = "from .users import router as users_router" const tree = parse(code) - const imports = findNodesByType(tree.rootNode, "import_from_statement") + const nodesByType = getNodesByType(tree.rootNode) + const imports = nodesByType.get("import_from_statement") ?? [] const result = importExtractor(imports[0]) assert.ok(result) @@ -731,7 +721,8 @@ router = f.APIRouter(prefix="/items") test("extracts multiple imports", () => { const code = "from fastapi import FastAPI, APIRouter" const tree = parse(code) - const imports = findNodesByType(tree.rootNode, "import_from_statement") + const nodesByType = getNodesByType(tree.rootNode) + const imports = nodesByType.get("import_from_statement") ?? [] const result = importExtractor(imports[0]) assert.ok(result) @@ -742,7 +733,8 @@ router = f.APIRouter(prefix="/items") test("returns null for non-import node", () => { const code = "x = 5" const tree = parse(code) - const assignments = findNodesByType(tree.rootNode, "assignment") + const nodesByType = getNodesByType(tree.rootNode) + const assignments = nodesByType.get("assignment") ?? [] const result = importExtractor(assignments[0]) assert.strictEqual(result, null) @@ -753,7 +745,8 @@ router = f.APIRouter(prefix="/items") test("extracts include_router call", () => { const code = "app.include_router(users.router)" const tree = parse(code) - const calls = findNodesByType(tree.rootNode, "call") + const nodesByType = getNodesByType(tree.rootNode) + const calls = nodesByType.get("call") ?? [] const result = includeRouterExtractor(calls[0]) assert.ok(result) @@ -765,7 +758,8 @@ router = f.APIRouter(prefix="/items") test("extracts include_router with prefix", () => { const code = `app.include_router(users.router, prefix="/users")` const tree = parse(code) - const calls = findNodesByType(tree.rootNode, "call") + const nodesByType = getNodesByType(tree.rootNode) + const calls = nodesByType.get("call") ?? [] const result = includeRouterExtractor(calls[0]) assert.ok(result) @@ -775,7 +769,8 @@ router = f.APIRouter(prefix="/items") test("extracts include_router with dynamic prefix", () => { const code = "app.include_router(router, prefix=settings.PREFIX)" const tree = parse(code) - const calls = findNodesByType(tree.rootNode, "call") + const nodesByType = getNodesByType(tree.rootNode) + const calls = nodesByType.get("call") ?? [] const result = includeRouterExtractor(calls[0]) assert.ok(result) @@ -785,7 +780,8 @@ router = f.APIRouter(prefix="/items") test("extracts include_router with tags", () => { const code = `app.include_router(router, tags=["users", "admin"])` const tree = parse(code) - const calls = findNodesByType(tree.rootNode, "call") + const nodesByType = getNodesByType(tree.rootNode) + const calls = nodesByType.get("call") ?? [] const result = includeRouterExtractor(calls[0]) assert.ok(result) @@ -795,7 +791,8 @@ router = f.APIRouter(prefix="/items") test("extracts include_router with 'router' keyword argument", () => { const code = `app.include_router(router=users_router, prefix="/api")` const tree = parse(code) - const calls = findNodesByType(tree.rootNode, "call") + const nodesByType = getNodesByType(tree.rootNode) + const calls = nodesByType.get("call") ?? [] const result = includeRouterExtractor(calls[0]) assert.ok(result) @@ -806,7 +803,8 @@ router = f.APIRouter(prefix="/items") test("returns null for non-include_router call", () => { const code = "app.some_method(arg)" const tree = parse(code) - const calls = findNodesByType(tree.rootNode, "call") + const nodesByType = getNodesByType(tree.rootNode) + const calls = nodesByType.get("call") ?? [] const result = includeRouterExtractor(calls[0]) assert.strictEqual(result, null) @@ -815,7 +813,8 @@ router = f.APIRouter(prefix="/items") test("returns null for function call (not method)", () => { const code = "include_router(router)" const tree = parse(code) - const calls = findNodesByType(tree.rootNode, "call") + const nodesByType = getNodesByType(tree.rootNode) + const calls = nodesByType.get("call") ?? [] const result = includeRouterExtractor(calls[0]) assert.strictEqual(result, null) @@ -826,7 +825,8 @@ router = f.APIRouter(prefix="/items") test("extracts mount call", () => { const code = `app.mount("/static", static_app)` const tree = parse(code) - const calls = findNodesByType(tree.rootNode, "call") + const nodesByType = getNodesByType(tree.rootNode) + const calls = nodesByType.get("call") ?? [] const result = mountExtractor(calls[0]) assert.ok(result) @@ -838,7 +838,8 @@ router = f.APIRouter(prefix="/items") test("extracts mount with dynamic path", () => { const code = "app.mount(settings.STATIC_PATH, static_app)" const tree = parse(code) - const calls = findNodesByType(tree.rootNode, "call") + const nodesByType = getNodesByType(tree.rootNode) + const calls = nodesByType.get("call") ?? [] const result = mountExtractor(calls[0]) assert.ok(result) @@ -848,7 +849,8 @@ router = f.APIRouter(prefix="/items") test("returns null for non-mount call", () => { const code = "app.some_method(arg1, arg2)" const tree = parse(code) - const calls = findNodesByType(tree.rootNode, "call") + const nodesByType = getNodesByType(tree.rootNode) + const calls = nodesByType.get("call") ?? [] const result = mountExtractor(calls[0]) assert.strictEqual(result, null) @@ -857,7 +859,8 @@ router = f.APIRouter(prefix="/items") test("returns null for mount with missing arguments", () => { const code = `app.mount("/static")` const tree = parse(code) - const calls = findNodesByType(tree.rootNode, "call") + const nodesByType = getNodesByType(tree.rootNode) + const calls = nodesByType.get("call") ?? [] const result = mountExtractor(calls[0]) assert.strictEqual(result, null) @@ -871,7 +874,8 @@ PREFIX = "/api" VERSION = "/v1" ` const tree = parse(code) - const vars = collectStringVariables(tree.rootNode) + const nodesByType = getNodesByType(tree.rootNode) + const vars = collectStringVariables(nodesByType) assert.strictEqual(vars.get("PREFIX"), "/api") assert.strictEqual(vars.get("VERSION"), "/v1") }) @@ -884,7 +888,8 @@ def handler(): PREFIX = "/local" ` const tree = parse(code) - const vars = collectStringVariables(tree.rootNode) + const nodesByType = getNodesByType(tree.rootNode) + const vars = collectStringVariables(nodesByType) assert.strictEqual(vars.get("PREFIX"), "/api") }) @@ -896,7 +901,8 @@ class Config: PREFIX = "/class-level" ` const tree = parse(code) - const vars = collectStringVariables(tree.rootNode) + const nodesByType = getNodesByType(tree.rootNode) + const vars = collectStringVariables(nodesByType) assert.strictEqual(vars.get("PREFIX"), "/api") }) @@ -906,7 +912,8 @@ COUNT = 42 FLAG = True ` const tree = parse(code) - const vars = collectStringVariables(tree.rootNode) + const nodesByType = getNodesByType(tree.rootNode) + const vars = collectStringVariables(nodesByType) assert.strictEqual(vars.size, 0) }) }) @@ -915,7 +922,8 @@ FLAG = True test("returns null for non-string node", () => { const code = "x = 42" const tree = parse(code) - const nodes = findNodesByType(tree.rootNode, "integer") + const nodesByType = getNodesByType(tree.rootNode) + const nodes = nodesByType.get("integer") ?? [] assert.strictEqual(extractStringValue(nodes[0]), null) }) }) @@ -924,7 +932,8 @@ FLAG = True test("returns dynamic placeholder for non-plus binary operator", () => { const code = "x = a - b" const tree = parse(code) - const ops = findNodesByType(tree.rootNode, "binary_operator") + const nodesByType = getNodesByType(tree.rootNode) + const ops = nodesByType.get("binary_operator") ?? [] const result = extractPathFromNode(ops[0]) assert.strictEqual(result, "\uE000a - b\uE000") }) @@ -938,10 +947,8 @@ def handler(): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.ok(result) @@ -955,10 +962,8 @@ def handler(): pass ` const tree = parse(code) - const decoratedDefs = findNodesByType( - tree.rootNode, - "decorated_definition", - ) + const nodesByType = getNodesByType(tree.rootNode) + const decoratedDefs = nodesByType.get("decorated_definition") ?? [] const result = decoratorExtractor(decoratedDefs[0]) assert.ok(result) diff --git a/src/vscode/testCodeLensProvider.ts b/src/vscode/testCodeLensProvider.ts index c793a59..04d7411 100644 --- a/src/vscode/testCodeLensProvider.ts +++ b/src/vscode/testCodeLensProvider.ts @@ -16,7 +16,7 @@ import { import type { Node } from "web-tree-sitter" import { extractPathFromNode, - findNodesByType, + getNodesByType, resolveArgNode, } from "../core/extractors" import { ROUTE_METHODS } from "../core/internal" @@ -110,7 +110,8 @@ export class TestCodeLensProvider implements CodeLensProvider { private findTestClientCalls(rootNode: Node): TestClientCall[] { const calls: TestClientCall[] = [] - const callNodes = findNodesByType(rootNode, "call") + const nodesByType = getNodesByType(rootNode) + const callNodes = nodesByType.get("call") ?? [] for (const callNode of callNodes) { // Grammar guarantees: call nodes always have a function field