Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/core/analyzer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import type { Tree } from "web-tree-sitter"
import { logError } from "../utils/logger"
import {
collectRecognizedNames,
collectStringVariables,
decoratorExtractor,
findNodesByType,
Expand Down Expand Up @@ -44,7 +45,10 @@ export function analyzeTree(tree: Tree, filePath: string): FileAnalysis {

// Get all router assignments
const assignments = findNodesByType(rootNode, "assignment")
const routers = assignments.map(routerExtractor).filter(notNull)
const { fastAPINames, apiRouterNames } = collectRecognizedNames(rootNode)
const routers = assignments
.map((node) => routerExtractor(node, apiRouterNames, fastAPINames))
.filter(notNull)

// Get all include_router and mount calls
const callNodes = findNodesByType(rootNode, "call")
Expand Down
113 changes: 105 additions & 8 deletions src/core/extractors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,27 @@ function collectNodesByType(node: Node, type: string, results: Node[]): void {
* Collects string variable assignments from the AST for path resolution.
* Handles simple assignments like `WEBHOOK_PATH = "/webhook"`.
*
* Only module-level assignments are collected — function/class-local variables
* are skipped to prevent shadowing module-level constants with the same name.
*
* Examples:
* WEBHOOK_PATH = "/webhook" -> Map { "WEBHOOK_PATH" => "/webhook" }
* BASE = "/api" -> Map { "BASE" => "/api" }
* settings.PREFIX = "/api" -> (skipped, not a simple identifier)
* def f(): BASE = "/local" -> (skipped, inside function)
*/
export function collectStringVariables(rootNode: Node): Map<string, string> {
const variables = new Map<string, string>()
const assignmentNodes = findNodesByType(rootNode, "assignment")

for (const assign of assignmentNodes) {
if (
hasAncestor(assign, "function_definition") ||
hasAncestor(assign, "class_definition")
) {
continue
}

const left = assign.childForFieldName("left")
const right = assign.childForFieldName("right")
if (
Expand Down Expand Up @@ -237,7 +248,11 @@ function extractTags(listNode: Node): string[] {
.filter((v): v is string => v !== null)
}

export function routerExtractor(node: Node): RouterInfo | null {
export function routerExtractor(
node: Node,
apiRouterNames?: Set<string>,
fastAPINames?: Set<string>,
): RouterInfo | null {
if (node.type !== "assignment") {
return null
}
Expand All @@ -250,9 +265,17 @@ export function routerExtractor(node: Node): RouterInfo | null {

const funcName = valueNode.childForFieldName("function")?.text
let type: RouterType
if (funcName === "APIRouter" || funcName === "fastapi.APIRouter") {
if (
funcName !== undefined &&
(apiRouterNames?.has(funcName) ??
(funcName === "APIRouter" || funcName === "fastapi.APIRouter"))
) {
type = "APIRouter"
} else if (funcName === "FastAPI" || funcName === "fastapi.FastAPI") {
} else if (
funcName !== undefined &&
(fastAPINames?.has(funcName) ??
(funcName === "FastAPI" || funcName === "fastapi.FastAPI"))
) {
type = "FastAPI"
} else {
return null
Expand Down Expand Up @@ -326,13 +349,28 @@ export function importExtractor(node: Node): ImportInfo | null {
const namedImports: ImportedName[] = []

if (node.type === "import_statement") {
let modulePath = ""
// Handle aliased imports: "import fastapi as f"
for (const aliased of findNodesByType(node, "aliased_import")) {
const nameNode = aliased.childForFieldName("name")
const aliasNode = aliased.childForFieldName("alias")
if (nameNode) {
if (!modulePath) modulePath = nameNode.text // preserve full dotted path
const alias = aliasNode?.text ?? null
names.push(alias ?? nameNode.text)
namedImports.push({ name: nameNode.text, alias })
}
}
// Non-aliased: "import fastapi" or "import fastapi.routing"
const nameNodes = findNodesByType(node, "dotted_name")
for (const nameNode of nameNodes) {
const firstName = nameNode.text.split(".")[0]
names.push(firstName)
namedImports.push({ name: firstName, alias: null })
if (!hasAncestor(nameNode, "aliased_import")) {
if (!modulePath) modulePath = nameNode.text // preserve full dotted path
const firstName = nameNode.text.split(".")[0]
names.push(firstName)
namedImports.push({ name: firstName, alias: null })
}
}
const modulePath = nameNodes[0]?.text ?? ""
return {
modulePath,
names,
Expand Down Expand Up @@ -372,14 +410,73 @@ export function importExtractor(node: Node): ImportInfo | null {
return { modulePath, names, namedImports, isRelative, relativeDots }
}

/**
* Extracts recognized FastAPI and APIRouter names from imports and class definitions.
* This allows routerExtractor to handle user-defined aliases and subclasses.
*
* For example, if the code has:
* from fastapi import FastAPI as MyApp
* from fastapi import APIRouter as MyRouter
* class CustomRouter(MyRouter): ...
*
* Then this function will return:
* fastAPINames = Set { "FastAPI", "fastapi.FastAPI", "MyApp" }
* apiRouterNames = Set { "APIRouter", "fastapi.APIRouter", "MyRouter", "CustomRouter" }
*/
export function collectRecognizedNames(rootNode: Node): {
fastAPINames: Set<string>
apiRouterNames: Set<string>
} {
const fastAPINames = new Set<string>(["FastAPI", "fastapi.FastAPI"])
const apiRouterNames = new Set<string>(["APIRouter", "fastapi.APIRouter"])

// Add aliases from "from fastapi import X as Y" imports
for (const node of findNodesByType(rootNode, "import_from_statement")) {
const info = importExtractor(node)
if (!info || info.modulePath !== "fastapi") continue
for (const named of info.namedImports) {
if (named.alias === null) continue
if (named.name === "FastAPI") fastAPINames.add(named.alias)
else if (named.name === "APIRouter") apiRouterNames.add(named.alias)
}
}

// Add module aliases from "import fastapi as f" → recognizes f.FastAPI, f.APIRouter
for (const node of findNodesByType(rootNode, "import_statement")) {
const info = importExtractor(node)
if (!info) continue
for (const named of info.namedImports) {
if (named.alias === null) continue
if (named.name === "fastapi") {
fastAPINames.add(`${named.alias}.FastAPI`)
apiRouterNames.add(`${named.alias}.APIRouter`)
}
}
}

// Add subclasses, checking against the already-accumulated alias sets so
// "class MyRouter(AR)" works when AR is an alias for APIRouter
for (const cls of findNodesByType(rootNode, "class_definition")) {
const nameNode = cls.childForFieldName("name")
const superclassesNode = cls.childForFieldName("superclasses")
if (!nameNode || !superclassesNode) continue
for (const parent of superclassesNode.namedChildren) {
if (apiRouterNames.has(parent.text)) apiRouterNames.add(nameNode.text)
else if (fastAPINames.has(parent.text)) fastAPINames.add(nameNode.text)
}
}

return { fastAPINames, apiRouterNames }
}

/**
* Resolves a function argument value node by positional index or keyword name.
*
* Examples:
* app.get("/users", response_model=List[User]) → position 0 = string node "/users"
* app.get(path="/users", response_model=List[User]) → keyword "path" = string node "/users"
*/
function resolveArgNode(
export function resolveArgNode(
args: Node[],
position: number,
keywordName: string,
Expand Down
Loading
Loading