diff --git a/src/core/analyzer.ts b/src/core/analyzer.ts index 3390b15..754593b 100644 --- a/src/core/analyzer.ts +++ b/src/core/analyzer.ts @@ -5,6 +5,7 @@ import type { Tree } from "web-tree-sitter" import { logError } from "../utils/logger" import { + collectRecognizedNames, collectStringVariables, decoratorExtractor, findNodesByType, @@ -44,7 +45,10 @@ export function analyzeTree(tree: Tree, filePath: string): FileAnalysis { // Get all router assignments const assignments = findNodesByType(rootNode, "assignment") - const routers = assignments.map(routerExtractor).filter(notNull) + const { fastAPINames, apiRouterNames } = collectRecognizedNames(rootNode) + const routers = assignments + .map((node) => routerExtractor(node, apiRouterNames, fastAPINames)) + .filter(notNull) // Get all include_router and mount calls const callNodes = findNodesByType(rootNode, "call") diff --git a/src/core/extractors.ts b/src/core/extractors.ts index ccd127b..3421472 100644 --- a/src/core/extractors.ts +++ b/src/core/extractors.ts @@ -65,16 +65,27 @@ function collectNodesByType(node: Node, type: string, results: Node[]): void { * Collects string variable assignments from the AST for path resolution. * Handles simple assignments like `WEBHOOK_PATH = "/webhook"`. * + * Only module-level assignments are collected — function/class-local variables + * are skipped to prevent shadowing module-level constants with the same name. + * * Examples: * WEBHOOK_PATH = "/webhook" -> Map { "WEBHOOK_PATH" => "/webhook" } * BASE = "/api" -> Map { "BASE" => "/api" } * settings.PREFIX = "/api" -> (skipped, not a simple identifier) + * def f(): BASE = "/local" -> (skipped, inside function) */ export function collectStringVariables(rootNode: Node): Map { const variables = new Map() const assignmentNodes = findNodesByType(rootNode, "assignment") for (const assign of assignmentNodes) { + if ( + hasAncestor(assign, "function_definition") || + hasAncestor(assign, "class_definition") + ) { + continue + } + const left = assign.childForFieldName("left") const right = assign.childForFieldName("right") if ( @@ -237,7 +248,11 @@ function extractTags(listNode: Node): string[] { .filter((v): v is string => v !== null) } -export function routerExtractor(node: Node): RouterInfo | null { +export function routerExtractor( + node: Node, + apiRouterNames?: Set, + fastAPINames?: Set, +): RouterInfo | null { if (node.type !== "assignment") { return null } @@ -250,9 +265,17 @@ export function routerExtractor(node: Node): RouterInfo | null { const funcName = valueNode.childForFieldName("function")?.text let type: RouterType - if (funcName === "APIRouter" || funcName === "fastapi.APIRouter") { + if ( + funcName !== undefined && + (apiRouterNames?.has(funcName) ?? + (funcName === "APIRouter" || funcName === "fastapi.APIRouter")) + ) { type = "APIRouter" - } else if (funcName === "FastAPI" || funcName === "fastapi.FastAPI") { + } else if ( + funcName !== undefined && + (fastAPINames?.has(funcName) ?? + (funcName === "FastAPI" || funcName === "fastapi.FastAPI")) + ) { type = "FastAPI" } else { return null @@ -326,13 +349,28 @@ export function importExtractor(node: Node): ImportInfo | null { const namedImports: ImportedName[] = [] if (node.type === "import_statement") { + let modulePath = "" + // Handle aliased imports: "import fastapi as f" + for (const aliased of findNodesByType(node, "aliased_import")) { + const nameNode = aliased.childForFieldName("name") + const aliasNode = aliased.childForFieldName("alias") + if (nameNode) { + if (!modulePath) modulePath = nameNode.text // preserve full dotted path + const alias = aliasNode?.text ?? null + names.push(alias ?? nameNode.text) + namedImports.push({ name: nameNode.text, alias }) + } + } + // Non-aliased: "import fastapi" or "import fastapi.routing" const nameNodes = findNodesByType(node, "dotted_name") for (const nameNode of nameNodes) { - const firstName = nameNode.text.split(".")[0] - names.push(firstName) - namedImports.push({ name: firstName, alias: null }) + if (!hasAncestor(nameNode, "aliased_import")) { + if (!modulePath) modulePath = nameNode.text // preserve full dotted path + const firstName = nameNode.text.split(".")[0] + names.push(firstName) + namedImports.push({ name: firstName, alias: null }) + } } - const modulePath = nameNodes[0]?.text ?? "" return { modulePath, names, @@ -372,6 +410,65 @@ export function importExtractor(node: Node): ImportInfo | null { return { modulePath, names, namedImports, isRelative, relativeDots } } +/** + * Extracts recognized FastAPI and APIRouter names from imports and class definitions. + * This allows routerExtractor to handle user-defined aliases and subclasses. + * + * For example, if the code has: + * from fastapi import FastAPI as MyApp + * from fastapi import APIRouter as MyRouter + * class CustomRouter(MyRouter): ... + * + * Then this function will return: + * fastAPINames = Set { "FastAPI", "fastapi.FastAPI", "MyApp" } + * apiRouterNames = Set { "APIRouter", "fastapi.APIRouter", "MyRouter", "CustomRouter" } + */ +export function collectRecognizedNames(rootNode: Node): { + fastAPINames: Set + apiRouterNames: Set +} { + const fastAPINames = new Set(["FastAPI", "fastapi.FastAPI"]) + const apiRouterNames = new Set(["APIRouter", "fastapi.APIRouter"]) + + // Add aliases from "from fastapi import X as Y" imports + for (const node of findNodesByType(rootNode, "import_from_statement")) { + const info = importExtractor(node) + if (!info || info.modulePath !== "fastapi") continue + for (const named of info.namedImports) { + if (named.alias === null) continue + if (named.name === "FastAPI") fastAPINames.add(named.alias) + else if (named.name === "APIRouter") apiRouterNames.add(named.alias) + } + } + + // Add module aliases from "import fastapi as f" → recognizes f.FastAPI, f.APIRouter + for (const node of findNodesByType(rootNode, "import_statement")) { + const info = importExtractor(node) + if (!info) continue + for (const named of info.namedImports) { + if (named.alias === null) continue + if (named.name === "fastapi") { + fastAPINames.add(`${named.alias}.FastAPI`) + apiRouterNames.add(`${named.alias}.APIRouter`) + } + } + } + + // Add subclasses, checking against the already-accumulated alias sets so + // "class MyRouter(AR)" works when AR is an alias for APIRouter + for (const cls of findNodesByType(rootNode, "class_definition")) { + const nameNode = cls.childForFieldName("name") + const superclassesNode = cls.childForFieldName("superclasses") + if (!nameNode || !superclassesNode) continue + for (const parent of superclassesNode.namedChildren) { + if (apiRouterNames.has(parent.text)) apiRouterNames.add(nameNode.text) + else if (fastAPINames.has(parent.text)) fastAPINames.add(nameNode.text) + } + } + + return { fastAPINames, apiRouterNames } +} + /** * Resolves a function argument value node by positional index or keyword name. * @@ -379,7 +476,7 @@ export function importExtractor(node: Node): ImportInfo | null { * app.get("/users", response_model=List[User]) → position 0 = string node "/users" * app.get(path="/users", response_model=List[User]) → keyword "path" = string node "/users" */ -function resolveArgNode( +export function resolveArgNode( args: Node[], position: number, keywordName: string, diff --git a/src/test/core/extractors.test.ts b/src/test/core/extractors.test.ts index a449f2b..1cc63db 100644 --- a/src/test/core/extractors.test.ts +++ b/src/test/core/extractors.test.ts @@ -1,5 +1,7 @@ import * as assert from "node:assert" import { + collectRecognizedNames, + collectStringVariables, decoratorExtractor, extractPathFromNode, extractStringValue, @@ -465,6 +467,171 @@ def list_users(): assert.strictEqual(result.type, "APIRouter") assert.strictEqual(result.prefix, "/api") }) + + test("returns null for custom subclass without subclasses set", () => { + const code = "admin_router = AdminAPIRouter(prefix='/admin')" + const tree = parse(code) + const assignments = findNodesByType(tree.rootNode, "assignment") + const result = routerExtractor(assignments[0]) + + assert.strictEqual(result, null) + }) + + test("recognizes custom APIRouter subclass", () => { + const code = ` +class AdminAPIRouter(APIRouter): + pass + +admin_router = AdminAPIRouter(prefix="/admin") +` + const tree = parse(code) + const { apiRouterNames } = collectRecognizedNames(tree.rootNode) + assert.ok(apiRouterNames.has("AdminAPIRouter")) + + const assignments = findNodesByType(tree.rootNode, "assignment") + const result = routerExtractor(assignments[0], apiRouterNames) + + assert.ok(result) + assert.strictEqual(result.variableName, "admin_router") + assert.strictEqual(result.type, "APIRouter") + assert.strictEqual(result.prefix, "/admin") + }) + + test("recognizes FastAPI subclass", () => { + const code = ` +class MyApp(FastAPI): + pass + +app = MyApp() +` + const tree = parse(code) + const { fastAPINames, apiRouterNames } = collectRecognizedNames( + tree.rootNode, + ) + assert.ok(fastAPINames.has("MyApp")) + assert.ok(!apiRouterNames.has("MyApp")) + + const assignments = findNodesByType(tree.rootNode, "assignment") + const result = routerExtractor( + assignments[0], + apiRouterNames, + fastAPINames, + ) + + assert.ok(result) + assert.strictEqual(result.variableName, "app") + assert.strictEqual(result.type, "FastAPI") + }) + + test("recognizes aliased FastAPI import (FastAPI as FA)", () => { + const code = ` +from fastapi import FastAPI as FA + +app = FA() +` + const tree = parse(code) + const { fastAPINames, apiRouterNames } = collectRecognizedNames( + tree.rootNode, + ) + assert.ok(fastAPINames.has("FA")) + + const assignments = findNodesByType(tree.rootNode, "assignment") + const result = routerExtractor( + assignments[0], + apiRouterNames, + fastAPINames, + ) + + assert.ok(result) + assert.strictEqual(result.variableName, "app") + assert.strictEqual(result.type, "FastAPI") + }) + + test("recognizes aliased APIRouter import (APIRouter as AR)", () => { + const code = ` +from fastapi import APIRouter as AR + +router = AR(prefix="/items") +` + const tree = parse(code) + const { apiRouterNames } = collectRecognizedNames(tree.rootNode) + assert.ok(apiRouterNames.has("AR")) + + const assignments = findNodesByType(tree.rootNode, "assignment") + const result = routerExtractor(assignments[0], apiRouterNames) + + assert.ok(result) + assert.strictEqual(result.variableName, "router") + assert.strictEqual(result.type, "APIRouter") + assert.strictEqual(result.prefix, "/items") + }) + + test("recognizes subclass of aliased APIRouter (class MyRouter(AR))", () => { + const code = ` +from fastapi import APIRouter as AR + +class MyRouter(AR): + pass + +router = MyRouter(prefix="/items") +` + const tree = parse(code) + const { apiRouterNames } = collectRecognizedNames(tree.rootNode) + assert.ok(apiRouterNames.has("AR")) + assert.ok(apiRouterNames.has("MyRouter")) + + const assignments = findNodesByType(tree.rootNode, "assignment") + const result = routerExtractor(assignments[0], apiRouterNames) + + assert.ok(result) + assert.strictEqual(result.type, "APIRouter") + }) + + test("collectRecognizedNames ignores non-aliased imports", () => { + const code = "from fastapi import FastAPI, APIRouter" + const tree = parse(code) + const { fastAPINames, apiRouterNames } = collectRecognizedNames( + tree.rootNode, + ) + // Only the defaults — no extras from non-aliased imports + assert.strictEqual(fastAPINames.size, 2) // "FastAPI", "fastapi.FastAPI" + assert.strictEqual(apiRouterNames.size, 2) // "APIRouter", "fastapi.APIRouter" + }) + + test("recognizes module alias (import fastapi as f)", () => { + const code = ` +import fastapi as f + +app = f.FastAPI() +router = f.APIRouter(prefix="/items") +` + const tree = parse(code) + const { fastAPINames, apiRouterNames } = collectRecognizedNames( + tree.rootNode, + ) + assert.ok(fastAPINames.has("f.FastAPI")) + assert.ok(apiRouterNames.has("f.APIRouter")) + + const assignments = findNodesByType(tree.rootNode, "assignment") + const appResult = routerExtractor( + assignments[0], + apiRouterNames, + fastAPINames, + ) + assert.ok(appResult) + assert.strictEqual(appResult.variableName, "app") + assert.strictEqual(appResult.type, "FastAPI") + + const routerResult = routerExtractor( + assignments[1], + apiRouterNames, + fastAPINames, + ) + assert.ok(routerResult) + assert.strictEqual(routerResult.variableName, "router") + assert.strictEqual(routerResult.type, "APIRouter") + assert.strictEqual(routerResult.prefix, "/items") + }) }) suite("importExtractor", () => { @@ -477,6 +644,38 @@ def list_users(): assert.ok(result) assert.strictEqual(result.modulePath, "fastapi") assert.deepStrictEqual(result.names, ["fastapi"]) + assert.deepStrictEqual(result.namedImports, [ + { name: "fastapi", alias: null }, + ]) + assert.strictEqual(result.isRelative, false) + }) + + test("preserves full dotted modulePath for import fastapi.routing", () => { + const code = "import fastapi.routing" + const tree = parse(code) + const imports = findNodesByType(tree.rootNode, "import_statement") + const result = importExtractor(imports[0]) + + assert.ok(result) + assert.strictEqual(result.modulePath, "fastapi.routing") + assert.deepStrictEqual(result.names, ["fastapi"]) + assert.deepStrictEqual(result.namedImports, [ + { name: "fastapi", alias: null }, + ]) + }) + + test("extracts aliased module import (import fastapi as f)", () => { + const code = "import fastapi as f" + const tree = parse(code) + const imports = findNodesByType(tree.rootNode, "import_statement") + const result = importExtractor(imports[0]) + + assert.ok(result) + assert.strictEqual(result.modulePath, "fastapi") + assert.deepStrictEqual(result.names, ["f"]) + assert.deepStrictEqual(result.namedImports, [ + { name: "fastapi", alias: "f" }, + ]) assert.strictEqual(result.isRelative, false) }) @@ -665,6 +864,53 @@ def list_users(): }) }) + suite("collectStringVariables", () => { + test("collects module-level string assignments", () => { + const code = ` +PREFIX = "/api" +VERSION = "/v1" +` + const tree = parse(code) + const vars = collectStringVariables(tree.rootNode) + assert.strictEqual(vars.get("PREFIX"), "/api") + assert.strictEqual(vars.get("VERSION"), "/v1") + }) + + test("ignores function-local variables", () => { + const code = ` +PREFIX = "/api" + +def handler(): + PREFIX = "/local" +` + const tree = parse(code) + const vars = collectStringVariables(tree.rootNode) + assert.strictEqual(vars.get("PREFIX"), "/api") + }) + + test("ignores class-level variables", () => { + const code = ` +PREFIX = "/api" + +class Config: + PREFIX = "/class-level" +` + const tree = parse(code) + const vars = collectStringVariables(tree.rootNode) + assert.strictEqual(vars.get("PREFIX"), "/api") + }) + + test("ignores non-string assignments", () => { + const code = ` +COUNT = 42 +FLAG = True +` + const tree = parse(code) + const vars = collectStringVariables(tree.rootNode) + assert.strictEqual(vars.size, 0) + }) + }) + suite("extractStringValue", () => { test("returns null for non-string node", () => { const code = "x = 42" diff --git a/src/test/core/routerResolver.test.ts b/src/test/core/routerResolver.test.ts index b37d4c0..3e92675 100644 --- a/src/test/core/routerResolver.test.ts +++ b/src/test/core/routerResolver.test.ts @@ -570,5 +570,80 @@ suite("routerResolver", () => { assert.strictEqual(result, null) }) + + test("resolves custom APIRouter subclass as child router", async () => { + const result = await buildRouterGraph( + fixtures.customSubclass.mainPy, + parser, + fixtures.customSubclass.root, + nodeFileSystem, + ) + + assert.ok(result) + assert.strictEqual(result.type, "FastAPI") + assert.strictEqual(result.variableName, "app") + + assert.strictEqual( + result.children.length, + 1, + "Should have one child router", + ) + + const adminRouter = result.children[0].router + assert.strictEqual(adminRouter.type, "APIRouter") + assert.strictEqual(adminRouter.prefix, "/admin") + assert.strictEqual(adminRouter.routes.length, 2) + + const paths = adminRouter.routes.map((r) => r.path) + assert.ok(paths.includes("/users")) + + const methods = adminRouter.routes.map((r) => r.method.toLowerCase()) + assert.ok(methods.includes("get")) + assert.ok(methods.includes("post")) + }) + + test("resolves aliased FastAPI and APIRouter class imports", async () => { + const result = await buildRouterGraph( + fixtures.aliasedClass.mainPy, + parser, + fixtures.aliasedClass.root, + nodeFileSystem, + ) + + assert.ok(result) + assert.strictEqual(result.type, "FastAPI") + assert.strictEqual(result.variableName, "app") + + assert.strictEqual( + result.children.length, + 1, + "Should have one child router", + ) + + const usersRouter = result.children[0].router + assert.strictEqual(usersRouter.type, "APIRouter") + assert.strictEqual(usersRouter.prefix, "/users") + assert.strictEqual(usersRouter.routes.length, 2) + }) + + test("resolves module-aliased fastapi import (import fastapi as f)", async () => { + const result = await buildRouterGraph( + fixtures.aliasedModule.mainPy, + parser, + fixtures.aliasedModule.root, + nodeFileSystem, + ) + + assert.ok(result) + assert.strictEqual(result.type, "FastAPI") + assert.strictEqual(result.variableName, "app") + + assert.strictEqual(result.children.length, 1) + + const usersRouter = result.children[0].router + assert.strictEqual(usersRouter.type, "APIRouter") + assert.strictEqual(usersRouter.prefix, "/users") + assert.strictEqual(usersRouter.routes.length, 2) + }) }) }) diff --git a/src/test/fixtures/aliased-class/main.py b/src/test/fixtures/aliased-class/main.py new file mode 100644 index 0000000..616a5b2 --- /dev/null +++ b/src/test/fixtures/aliased-class/main.py @@ -0,0 +1,18 @@ +from fastapi import FastAPI as FA +from fastapi import APIRouter as AR + +app = FA() +router = AR(prefix="/users") + + +@router.get("/") +def list_users(): + return [] + + +@router.post("/") +def create_user(): + return {} + + +app.include_router(router) diff --git a/src/test/fixtures/aliased-module/main.py b/src/test/fixtures/aliased-module/main.py new file mode 100644 index 0000000..f10be75 --- /dev/null +++ b/src/test/fixtures/aliased-module/main.py @@ -0,0 +1,17 @@ +import fastapi as f + +app = f.FastAPI() +router = f.APIRouter(prefix="/users") + + +@router.get("/") +def list_users(): + return [] + + +@router.post("/") +def create_user(): + return {} + + +app.include_router(router) diff --git a/src/test/fixtures/custom-subclass/main.py b/src/test/fixtures/custom-subclass/main.py new file mode 100644 index 0000000..2093db3 --- /dev/null +++ b/src/test/fixtures/custom-subclass/main.py @@ -0,0 +1,5 @@ +from fastapi import FastAPI +from routers import admin_router + +app = FastAPI() +app.include_router(admin_router) diff --git a/src/test/fixtures/custom-subclass/routers.py b/src/test/fixtures/custom-subclass/routers.py new file mode 100644 index 0000000..b8217df --- /dev/null +++ b/src/test/fixtures/custom-subclass/routers.py @@ -0,0 +1,18 @@ +from fastapi import APIRouter + + +class AdminAPIRouter(APIRouter): + pass + + +admin_router = AdminAPIRouter(prefix="/admin") + + +@admin_router.get("/users") +def list_users(): + return [] + + +@admin_router.post("/users") +def create_user(): + return {} diff --git a/src/test/providers/testCodeLensProvider.test.ts b/src/test/providers/testCodeLensProvider.test.ts index adefffb..5d067d5 100644 --- a/src/test/providers/testCodeLensProvider.test.ts +++ b/src/test/providers/testCodeLensProvider.test.ts @@ -287,6 +287,22 @@ def test_something(): assert.strictEqual(lenses.length, 0) }) + test("creates CodeLens for url= keyword argument", async () => { + const app = createMockApp([createRoute("GET", "/users")]) + provider.setApps([app]) + + const doc = await vscode.workspace.openTextDocument({ + content: ` +def test_get_users(): + response = client.get(url="/users") +`, + language: "python", + }) + const lenses = provider.provideCodeLenses(doc) + assert.strictEqual(lenses.length, 1) + assert.ok(lenses[0].command?.title.includes("/users")) + }) + test("ignores calls with no arguments", async () => { const app = createMockApp([createRoute("GET", "/users")]) provider.setApps([app]) diff --git a/src/test/testUtils.ts b/src/test/testUtils.ts index 02751c2..8724b59 100644 --- a/src/test/testUtils.ts +++ b/src/test/testUtils.ts @@ -87,6 +87,18 @@ export const fixtures = { root: uri(join(fixturesPath, "factory-func")), mainPy: uri(join(fixturesPath, "factory-func", "main.py")), }, + customSubclass: { + root: uri(join(fixturesPath, "custom-subclass")), + mainPy: uri(join(fixturesPath, "custom-subclass", "main.py")), + }, + aliasedClass: { + root: uri(join(fixturesPath, "aliased-class")), + mainPy: uri(join(fixturesPath, "aliased-class", "main.py")), + }, + aliasedModule: { + root: uri(join(fixturesPath, "aliased-module")), + mainPy: uri(join(fixturesPath, "aliased-module", "main.py")), + }, } /** diff --git a/src/vscode/testCodeLensProvider.ts b/src/vscode/testCodeLensProvider.ts index a7c89c7..c793a59 100644 --- a/src/vscode/testCodeLensProvider.ts +++ b/src/vscode/testCodeLensProvider.ts @@ -14,7 +14,11 @@ import { Uri, } from "vscode" import type { Node } from "web-tree-sitter" -import { extractPathFromNode, findNodesByType } from "../core/extractors" +import { + extractPathFromNode, + findNodesByType, + resolveArgNode, +} from "../core/extractors" import { ROUTE_METHODS } from "../core/internal" import type { Parser } from "../core/parser" import { @@ -134,8 +138,11 @@ export class TestCodeLensProvider implements CodeLensProvider { continue } - const pathArg = args[0] - // extractPathFromNode always returns a non-empty string for valid AST nodes + const pathArg = resolveArgNode(args, 0, "url") + + if (!pathArg) { + continue + } const path = extractPathFromNode(pathArg) calls.push({