Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/core/analyzer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
collectRecognizedNames,
collectStringVariables,
decoratorExtractor,
findNodesByType,
getNodesByType,
importExtractor,
includeRouterExtractor,
mountExtractor,
Expand Down Expand Up @@ -37,32 +37,32 @@ function resolveVariables(

/** Analyze a syntax tree and extract FastAPI-related information */
export function analyzeTree(tree: Tree, filePath: string): FileAnalysis {
const rootNode = tree.rootNode
const nodesByType = getNodesByType(tree.rootNode)

// Get all decorated definitions (functions and classes with decorators)
const decoratedDefs = findNodesByType(rootNode, "decorated_definition")
const decoratedDefs = nodesByType.get("decorated_definition") ?? []
const routes = decoratedDefs.map(decoratorExtractor).filter(notNull)

// Get all router assignments
const assignments = findNodesByType(rootNode, "assignment")
const { fastAPINames, apiRouterNames } = collectRecognizedNames(rootNode)
const assignments = nodesByType.get("assignment") ?? []
const { fastAPINames, apiRouterNames } = collectRecognizedNames(nodesByType)
const routers = assignments
.map((node) => routerExtractor(node, apiRouterNames, fastAPINames))
.filter(notNull)

// Get all include_router and mount calls
const callNodes = findNodesByType(rootNode, "call")
const callNodes = nodesByType.get("call") ?? []
const includeRouters = callNodes.map(includeRouterExtractor).filter(notNull)
const mounts = callNodes.map(mountExtractor).filter(notNull)

// Get all import statements
const importNodes = findNodesByType(rootNode, "import_statement")
const importFromNodes = findNodesByType(rootNode, "import_from_statement")
const importNodes = nodesByType.get("import_statement") ?? []
const importFromNodes = nodesByType.get("import_from_statement") ?? []
const imports = [...importNodes, ...importFromNodes]
.map(importExtractor)
.filter(notNull)

const stringVariables = collectStringVariables(rootNode)
const stringVariables = collectStringVariables(nodesByType)

for (const route of routes) {
route.path = resolveVariables(route.path, stringVariables)
Expand Down
61 changes: 35 additions & 26 deletions src/core/extractors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,6 @@ import type {
} from "./internal"
import { ROUTE_METHODS } from "./internal"

/** Recursively finds all nodes of a given type within a subtree */
export function findNodesByType(node: Node, type: string): Node[] {
const results: Node[] = []
collectNodesByType(node, type, results)
return results
}

function stripDocstring(raw: string): string {
let content: string
if (
Expand Down Expand Up @@ -49,16 +42,24 @@ function stripDocstring(raw: string): string {
return dedented.join("\n").trim()
}

function collectNodesByType(node: Node, type: string, results: Node[]): void {
if (node.type === type) {
results.push(node)
}
for (let i = 0; i < node.childCount; i++) {
const child = node.child(i)
if (child) {
collectNodesByType(child, type, results)
export function getNodesByType(root: Node): Map<string, Node[]> {
const results = new Map<string, Node[]>()

function collectNodesByType(node: Node, results: Map<string, Node[]>): void {
if (!results.has(node.type)) {
results.set(node.type, [])
}
results.get(node.type)!.push(node)

for (let i = 0; i < node.childCount; i++) {
const child = node.child(i)
if (child) {
collectNodesByType(child, results)
}
}
}
collectNodesByType(root, results)
return results
}

/**
Expand All @@ -74,9 +75,11 @@ function collectNodesByType(node: Node, type: string, results: Node[]): void {
* settings.PREFIX = "/api" -> (skipped, not a simple identifier)
* def f(): BASE = "/local" -> (skipped, inside function)
*/
export function collectStringVariables(rootNode: Node): Map<string, string> {
export function collectStringVariables(
nodesByType: Map<string, Node[]>,
): Map<string, string> {
const variables = new Map<string, string>()
const assignmentNodes = findNodesByType(rootNode, "assignment")
const assignmentNodes = nodesByType.get("assignment") ?? []

for (const assign of assignmentNodes) {
if (
Expand Down Expand Up @@ -172,7 +175,11 @@ export function decoratorExtractor(node: Node): RouteInfo | null {
// Grammar guarantees: decorated_definition always has a first child (the decorator)
const decoratorNode = node.firstNamedChild!

const callNode = findNodesByType(decoratorNode, "call")[0]
const callNode =
decoratorNode.firstNamedChild?.type === "call"
? decoratorNode.firstNamedChild
: null

const functionNode = callNode?.childForFieldName("function")
const argumentsNode = callNode?.childForFieldName("arguments")
const objectNode = functionNode?.childForFieldName("object")
Expand Down Expand Up @@ -351,7 +358,8 @@ export function importExtractor(node: Node): ImportInfo | null {
if (node.type === "import_statement") {
let modulePath = ""
// Handle aliased imports: "import fastapi as f"
for (const aliased of findNodesByType(node, "aliased_import")) {
const aliasedImports = getNodesByType(node).get("aliased_import") ?? []
for (const aliased of aliasedImports) {
const nameNode = aliased.childForFieldName("name")
const aliasNode = aliased.childForFieldName("alias")
if (nameNode) {
Expand All @@ -362,7 +370,7 @@ export function importExtractor(node: Node): ImportInfo | null {
}
}
// Non-aliased: "import fastapi" or "import fastapi.routing"
const nameNodes = findNodesByType(node, "dotted_name")
const nameNodes = getNodesByType(node).get("dotted_name") ?? []
for (const nameNode of nameNodes) {
if (!hasAncestor(nameNode, "aliased_import")) {
if (!modulePath) modulePath = nameNode.text // preserve full dotted path
Expand All @@ -387,7 +395,8 @@ export function importExtractor(node: Node): ImportInfo | null {
)

// Aliased imports (e.g., "router as users_router")
for (const aliased of findNodesByType(node, "aliased_import")) {
const aliasedImports = getNodesByType(node).get("aliased_import") ?? []
for (const aliased of aliasedImports) {
const nameNode = aliased.childForFieldName("name")
const aliasNode = aliased.childForFieldName("alias")
if (nameNode) {
Expand All @@ -398,7 +407,7 @@ export function importExtractor(node: Node): ImportInfo | null {
}

// Non-aliased imports (skip first dotted_name which is the module path)
const nameNodes = findNodesByType(node, "dotted_name")
const nameNodes = getNodesByType(node).get("dotted_name") ?? []
for (let i = 1; i < nameNodes.length; i++) {
const nameNode = nameNodes[i]
if (!hasAncestor(nameNode, "aliased_import")) {
Expand All @@ -423,15 +432,15 @@ export function importExtractor(node: Node): ImportInfo | null {
* fastAPINames = Set { "FastAPI", "fastapi.FastAPI", "MyApp" }
* apiRouterNames = Set { "APIRouter", "fastapi.APIRouter", "MyRouter", "CustomRouter" }
*/
export function collectRecognizedNames(rootNode: Node): {
export function collectRecognizedNames(nodesByType: Map<string, Node[]>): {
fastAPINames: Set<string>
apiRouterNames: Set<string>
} {
const fastAPINames = new Set<string>(["FastAPI", "fastapi.FastAPI"])
const apiRouterNames = new Set<string>(["APIRouter", "fastapi.APIRouter"])

// Add aliases from "from fastapi import X as Y" imports
for (const node of findNodesByType(rootNode, "import_from_statement")) {
for (const node of nodesByType.get("import_from_statement") ?? []) {
const info = importExtractor(node)
if (!info || info.modulePath !== "fastapi") continue
for (const named of info.namedImports) {
Expand All @@ -442,7 +451,7 @@ export function collectRecognizedNames(rootNode: Node): {
}

// Add module aliases from "import fastapi as f" → recognizes f.FastAPI, f.APIRouter
for (const node of findNodesByType(rootNode, "import_statement")) {
for (const node of nodesByType.get("import_statement") ?? []) {
const info = importExtractor(node)
if (!info) continue
for (const named of info.namedImports) {
Expand All @@ -456,7 +465,7 @@ export function collectRecognizedNames(rootNode: Node): {

// Add subclasses, checking against the already-accumulated alias sets so
// "class MyRouter(AR)" works when AR is an alias for APIRouter
for (const cls of findNodesByType(rootNode, "class_definition")) {
for (const cls of nodesByType.get("class_definition") ?? []) {
const nameNode = cls.childForFieldName("name")
const superclassesNode = cls.childForFieldName("superclasses")
if (!nameNode || !superclassesNode) continue
Expand Down
Loading
Loading