1+ import  logging 
12import  os 
2- from  itertools  import  chain 
3+ from  collections . abc  import  Iterator 
34from  pathlib  import  Path 
4- from  typing  import  TYPE_CHECKING , Generic , Self , TypeVar 
5- 
6- from  codegen .shared .decorators .docs  import  apidoc , py_noapidoc 
7- 
8- if  TYPE_CHECKING :
9-     from  codegen .sdk .core .assignment  import  Assignment 
10-     from  codegen .sdk .core .class_definition  import  Class 
11-     from  codegen .sdk .core .file  import  File 
12-     from  codegen .sdk .core .function  import  Function 
13-     from  codegen .sdk .core .import_resolution  import  Import , ImportStatement 
14-     from  codegen .sdk .core .symbol  import  Symbol 
15-     from  codegen .sdk .typescript .class_definition  import  TSClass 
16-     from  codegen .sdk .typescript .export  import  TSExport 
17-     from  codegen .sdk .typescript .file  import  TSFile 
18-     from  codegen .sdk .typescript .function  import  TSFunction 
19-     from  codegen .sdk .typescript .import_resolution  import  TSImport 
20-     from  codegen .sdk .typescript .statements .import_statement  import  TSImportStatement 
21-     from  codegen .sdk .typescript .symbol  import  TSSymbol 
22- 
23- import  logging 
5+ from  typing  import  Generic , Self 
6+ 
7+ from  codegen .sdk .core .interfaces .has_symbols  import  (
8+     HasSymbols ,
9+     TClass ,
10+     TFile ,
11+     TFunction ,
12+     TGlobalVar ,
13+     TImport ,
14+     TImportStatement ,
15+     TSymbol ,
16+ )
17+ from  codegen .sdk .core .utils .cache_utils  import  cached_generator 
18+ from  codegen .shared .decorators .docs  import  apidoc , noapidoc 
2419
2520logger  =  logging .getLogger (__name__ )
2621
2722
28- TFile  =  TypeVar ("TFile" , bound = "File" )
29- TSymbol  =  TypeVar ("TSymbol" , bound = "Symbol" )
30- TImportStatement  =  TypeVar ("TImportStatement" , bound = "ImportStatement" )
31- TGlobalVar  =  TypeVar ("TGlobalVar" , bound = "Assignment" )
32- TClass  =  TypeVar ("TClass" , bound = "Class" )
33- TFunction  =  TypeVar ("TFunction" , bound = "Function" )
34- TImport  =  TypeVar ("TImport" , bound = "Import" )
35- 
36- TSGlobalVar  =  TypeVar ("TSGlobalVar" , bound = "Assignment" )
37- 
38- 
3923@apidoc  
40- class  Directory (Generic [TFile , TSymbol , TImportStatement , TGlobalVar , TClass , TFunction , TImport ]):
24+ class  Directory (
25+     HasSymbols [TFile , TSymbol , TImportStatement , TGlobalVar , TClass , TFunction , TImport ],
26+     Generic [TFile , TSymbol , TImportStatement , TGlobalVar , TClass , TFunction , TImport ],
27+ ):
4128    """Directory representation for codebase. 
4229
4330    GraphSitter abstraction of a file directory that can be used to look for files and symbols within a specific directory. 
@@ -58,7 +45,7 @@ def __init__(self, path: Path, dirpath: str, parent: Self | None):
5845        self .path  =  path 
5946        self .dirpath  =  dirpath 
6047        self .parent  =  parent 
61-         self .items  =  dict () 
48+         self .items  =  {} 
6249
6350    def  __iter__ (self ):
6451        return  iter (self .items .values ())
@@ -126,62 +113,13 @@ def _get_subdirectories(directory: Directory):
126113        _get_subdirectories (self )
127114        return  subdirectories 
128115
129-     @property  
130-     def  symbols (self ) ->  list [TSymbol ]:
131-         """Get a recursive list of all symbols in the directory and its subdirectories.""" 
132-         return  list (chain .from_iterable (f .symbols  for  f  in  self .files ))
133- 
134-     @property  
135-     def  import_statements (self ) ->  list [TImportStatement ]:
136-         """Get a recursive list of all import statements in the directory and its subdirectories.""" 
137-         return  list (chain .from_iterable (f .import_statements  for  f  in  self .files ))
138- 
139-     @property  
140-     def  global_vars (self ) ->  list [TGlobalVar ]:
141-         """Get a recursive list of all global variables in the directory and its subdirectories.""" 
142-         return  list (chain .from_iterable (f .global_vars  for  f  in  self .files ))
143- 
144-     @property  
145-     def  classes (self ) ->  list [TClass ]:
146-         """Get a recursive list of all classes in the directory and its subdirectories.""" 
147-         return  list (chain .from_iterable (f .classes  for  f  in  self .files ))
148- 
149-     @property  
150-     def  functions (self ) ->  list [TFunction ]:
151-         """Get a recursive list of all functions in the directory and its subdirectories.""" 
152-         return  list (chain .from_iterable (f .functions  for  f  in  self .files ))
153- 
154-     @property  
155-     @py_noapidoc  
156-     def  exports (self : "Directory[TSFile, TSSymbol, TSImportStatement, TSGlobalVar, TSClass, TSFunction, TSImport]" ) ->  "list[TSExport]" :
157-         """Get a recursive list of all exports in the directory and its subdirectories.""" 
158-         return  list (chain .from_iterable (f .exports  for  f  in  self .files ))
159- 
160-     @property  
161-     def  imports (self ) ->  list [TImport ]:
162-         """Get a recursive list of all imports in the directory and its subdirectories.""" 
163-         return  list (chain .from_iterable (f .imports  for  f  in  self .files ))
164- 
165-     def  get_symbol (self , name : str ) ->  TSymbol  |  None :
166-         """Get a symbol by name in the directory and its subdirectories.""" 
167-         return  next ((s  for  s  in  self .symbols  if  s .name  ==  name ), None )
168- 
169-     def  get_import_statement (self , name : str ) ->  TImportStatement  |  None :
170-         """Get an import statement by name in the directory and its subdirectories.""" 
171-         return  next ((s  for  s  in  self .import_statements  if  s .name  ==  name ), None )
172- 
173-     def  get_global_var (self , name : str ) ->  TGlobalVar  |  None :
174-         """Get a global variable by name in the directory and its subdirectories.""" 
175-         return  next ((s  for  s  in  self .global_vars  if  s .name  ==  name ), None )
176- 
177-     def  get_class (self , name : str ) ->  TClass  |  None :
178-         """Get a class by name in the directory and its subdirectories.""" 
179-         return  next ((s  for  s  in  self .classes  if  s .name  ==  name ), None )
180- 
181-     def  get_function (self , name : str ) ->  TFunction  |  None :
182-         """Get a function by name in the directory and its subdirectories.""" 
183-         return  next ((s  for  s  in  self .functions  if  s .name  ==  name ), None )
116+     @noapidoc  
117+     @cached_generator () 
118+     def  files_generator (self ) ->  Iterator [TFile ]:
119+         """Yield files recursively from the directory.""" 
120+         yield  from  self .files 
184121
122+     # Directory-specific methods 
185123    def  add_file (self , file : TFile ) ->  None :
186124        """Add a file to the directory.""" 
187125        rel_path  =  os .path .relpath (file .file_path , self .dirpath )
@@ -202,18 +140,12 @@ def get_file(self, filename: str, ignore_case: bool = False) -> TFile | None:
202140        from  codegen .sdk .core .file  import  File 
203141
204142        if  ignore_case :
205-             return  next ((f  for  name , f  in  self .items .items () if  name .lower () ==  filename .lower () and  isinstance (f , File )), None )
143+             return  next (
144+                 (f  for  name , f  in  self .items .items () if  name .lower () ==  filename .lower () and  isinstance (f , File )),
145+                 None ,
146+             )
206147        return  self .items .get (filename , None )
207148
208-     @py_noapidoc  
209-     def  get_export (self : "Directory[TSFile, TSSymbol, TSImportStatement, TSGlobalVar, TSClass, TSFunction, TSImport]" , name : str ) ->  "TSExport | None" :
210-         """Get an export by name in the directory and its subdirectories (supports only typescript).""" 
211-         return  next ((s  for  s  in  self .exports  if  s .name  ==  name ), None )
212- 
213-     def  get_import (self , name : str ) ->  TImport  |  None :
214-         """Get an import by name in the directory and its subdirectories.""" 
215-         return  next ((s  for  s  in  self .imports  if  s .name  ==  name ), None )
216- 
217149    def  add_subdirectory (self , subdirectory : Self ) ->  None :
218150        """Add a subdirectory to the directory.""" 
219151        rel_path  =  os .path .relpath (subdirectory .dirpath , self .dirpath )
@@ -230,23 +162,22 @@ def remove_subdirectory_by_path(self, subdirectory_path: str) -> None:
230162        del  self .items [rel_path ]
231163
232164    def  get_subdirectory (self , subdirectory_name : str ) ->  Self  |  None :
233-         """Get a subdirectory by its path  relative to the directory.""" 
165+         """Get a subdirectory by its name ( relative to the directory) .""" 
234166        return  self .items .get (subdirectory_name , None )
235167
236-     def  remove (self ) ->  None :
237-         """Remove the directory and all its files and subdirectories.""" 
238-         for  f  in  self .files :
239-             f .remove ()
240- 
241168    def  update_filepath (self , new_filepath : str ) ->  None :
242-         """Update the filepath of the directory.""" 
169+         """Update the filepath of the directory and its contained files .""" 
243170        old_path  =  self .dirpath 
244171        new_path  =  new_filepath 
245- 
246172        for  file  in  self .files :
247173            new_file_path  =  os .path .join (new_path , os .path .relpath (file .file_path , old_path ))
248174            file .update_filepath (new_file_path )
249175
176+     def  remove (self ) ->  None :
177+         """Remove all the files in the files container.""" 
178+         for  f  in  self .files :
179+             f .remove ()
180+ 
250181    def  rename (self , new_name : str ) ->  None :
251182        """Rename the directory.""" 
252183        parent_dir , _  =  os .path .split (self .dirpath )
0 commit comments