Skip to content

Commit acf2ea4

Browse files
bagel897 and eacodegen authored
refactor: Split out file IO (#404)
# Motivation - Helpful for LSP support # Content - Move all file I/O into its own class --------- Co-authored-by: eacodegen <[email protected]>
1 parent 54d1809 commit acf2ea4

File tree

8 files changed

+197
-73
lines changed

8 files changed

+197
-73
lines changed

src/codegen/sdk/codebase/codebase_context.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import os
44
from collections import Counter, defaultdict
5-
from concurrent.futures import ThreadPoolExecutor
65
from contextlib import contextmanager
76
from enum import IntEnum, auto, unique
87
from functools import lru_cache
@@ -16,14 +15,14 @@
1615
from codegen.sdk.codebase.config_parser import ConfigParser, get_config_parser_for_language
1716
from codegen.sdk.codebase.diff_lite import ChangeType, DiffLite
1817
from codegen.sdk.codebase.flagging.flags import Flags
18+
from codegen.sdk.codebase.io.file_io import FileIO
1919
from codegen.sdk.codebase.transaction_manager import TransactionManager
2020
from codegen.sdk.codebase.validation import get_edges, post_reset_validation
2121
from codegen.sdk.core.autocommit import AutoCommit, commiter
2222
from codegen.sdk.core.directory import Directory
2323
from codegen.sdk.core.external.dependency_manager import DependencyManager, get_dependency_manager
2424
from codegen.sdk.core.external.language_engine import LanguageEngine, get_language_engine
2525
from codegen.sdk.enums import Edge, EdgeType, NodeType, ProgrammingLanguage
26-
from codegen.sdk.extensions.io import write_changes
2726
from codegen.sdk.extensions.sort import sort_editables
2827
from codegen.sdk.extensions.utils import uncache_all
2928
from codegen.sdk.typescript.external.ts_declassify.ts_declassify import TSDeclassify
@@ -37,6 +36,7 @@
3736
from git import Commit as GitCommit
3837

3938
from codegen.git.repo_operator.repo_operator import RepoOperator
39+
from codegen.sdk.codebase.io.io import IO
4040
from codegen.sdk.codebase.node_classes.node_classes import NodeClasses
4141
from codegen.sdk.core.dataclasses.usage import Usage
4242
from codegen.sdk.core.expressions import Expression
@@ -92,7 +92,6 @@ class CodebaseContext:
9292
pending_syncs: list[DiffLite] # Diffs that have been applied to disk, but not the graph (to be used for sync graph)
9393
all_syncs: list[DiffLite] # All diffs that have been applied to the graph (to be used for graph reset)
9494
_autocommit: AutoCommit
95-
pending_files: set[SourceFile]
9695
generation: int
9796
parser: Parser[Expression]
9897
synced_commit: GitCommit | None
@@ -110,6 +109,7 @@ class CodebaseContext:
110109
session_options: SessionOptions = SessionOptions()
111110
projects: list[ProjectConfig]
112111
unapplied_diffs: list[DiffLite]
112+
io: IO
113113

114114
def __init__(
115115
self,
@@ -134,6 +134,7 @@ def __init__(
134134

135135
# =====[ __init__ attributes ]=====
136136
self.projects = projects
137+
self.io = FileIO()
137138
context = projects[0]
138139
self.node_classes = get_node_classes(context.programming_language)
139140
self.config = config
@@ -165,7 +166,6 @@ def __init__(
165166
self.pending_syncs = []
166167
self.all_syncs = []
167168
self.unapplied_diffs = []
168-
self.pending_files = set()
169169
self.flags = Flags()
170170

171171
def __repr__(self):
@@ -259,7 +259,13 @@ def _reset_files(self, syncs: list[DiffLite]) -> None:
259259
files_to_remove.append(sync.path)
260260
modified_files.add(sync.path)
261261
logger.info(f"Writing {len(files_to_write)} files to disk and removing {len(files_to_remove)} files")
262-
write_changes(files_to_remove, files_to_write)
262+
for file in files_to_remove:
263+
self.io.delete_file(file)
264+
to_save = set()
265+
for file, content in files_to_write:
266+
self.io.write_file(file, content)
267+
to_save.add(file)
268+
self.io.save_files(to_save)
263269

264270
@stopwatch
265271
def reset_codebase(self) -> None:
@@ -270,7 +276,7 @@ def reset_codebase(self) -> None:
270276
def undo_applied_diffs(self) -> None:
271277
self.transaction_manager.clear_transactions()
272278
self.reset_codebase()
273-
self.check_changes()
279+
self.io.check_changes()
274280
self.pending_syncs.clear() # Discard pending changes
275281
if len(self.all_syncs) > 0:
276282
logger.info(f"Unapplying {len(self.all_syncs)} diffs to graph. Current graph commit: {self.synced_commit}")
@@ -432,7 +438,7 @@ def _process_diff_files(self, files_to_sync: Mapping[SyncType, list[Path]], incr
432438

433439
# Step 5: Add new files as nodes to graph (does not yet add edges)
434440
for filepath in files_to_sync[SyncType.ADD]:
435-
content = filepath.read_text(errors="ignore")
441+
content = self.io.read_text(filepath)
436442
# TODO: this is wrong with context changes
437443
if filepath.suffix in self.extensions:
438444
file_cls = self.node_classes.file_cls
@@ -634,17 +640,6 @@ def remove_edge(self, u: NodeId, v: NodeId, *, edge_type: EdgeType | None = None
634640
continue
635641
self._graph.remove_edge_from_index(edge)
636642

637-
def check_changes(self) -> None:
638-
for file in self.pending_files:
639-
file.check_changes()
640-
self.pending_files.clear()
641-
642-
def write_files(self, files: set[Path] | None = None) -> None:
643-
to_write = set(filter(lambda f: f.filepath in files, self.pending_files)) if files is not None else self.pending_files
644-
with ThreadPoolExecutor() as exec:
645-
exec.map(lambda f: f.write_pending_content(), to_write)
646-
self.pending_files.difference_update(to_write)
647-
648643
@lru_cache(maxsize=10000)
649644
def to_absolute(self, filepath: PathLike | str) -> Path:
650645
path = Path(filepath)
@@ -684,7 +679,7 @@ def commit_transactions(self, sync_graph: bool = True, sync_file: bool = True, f
684679

685680
# Write files if requested
686681
if sync_file:
687-
self.write_files(files)
682+
self.io.save_files(files)
688683

689684
# Sync the graph if requested
690685
if sync_graph and len(self.pending_syncs) > 0:
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import logging
2+
from concurrent.futures import ThreadPoolExecutor
3+
from pathlib import Path
4+
5+
from codegen.sdk.codebase.io.io import IO, BadWriteError
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
class FileIO(IO):
    """IO implementation that writes files to disk, and tracks pending changes.

    Writes are staged in ``files`` (path -> pending bytes) and only reach the
    disk when ``save_files`` is called. Reads prefer staged content over what
    is currently on disk.
    """

    files: dict[Path, bytes]  # staged, not-yet-flushed content keyed by path

    def __init__(self):
        self.files = {}

    def write_bytes(self, path: Path, content: bytes) -> None:
        """Stage ``content`` for ``path``; nothing is written to disk yet."""
        self.files[path] = content

    def read_bytes(self, path: Path) -> bytes:
        """Return the staged content for ``path`` if present, else read it from disk."""
        try:
            return self.files[path]
        except KeyError:
            return path.read_bytes()

    def save_files(self, files: set[Path] | None = None) -> None:
        """Flush staged content to disk.

        Args:
            files: Paths to flush. ``None`` flushes everything that is staged.
                Paths with nothing staged are ignored.

        Raises:
            Exception: The first error raised by a failed ``Path.write_bytes``
                call (typically an ``OSError``). Paths that failed stay staged
                so the content is not lost; paths that succeeded are unstaged.
        """
        # Snapshot the selection so the dict can be mutated safely afterwards.
        to_save = [path for path in self.files if files is None or path in files]
        with ThreadPoolExecutor() as executor:
            # submit() instead of map(): map() only surfaces a worker exception
            # when its result iterator is consumed, which silently hid failed writes.
            futures = {path: executor.submit(path.write_bytes, self.files[path]) for path in to_save}
        # Leaving the ``with`` block waits for every write to finish.
        first_error = None
        for path, future in futures.items():
            error = future.exception()
            if error is None:
                del self.files[path]
            elif first_error is None:
                first_error = error
        if first_error is not None:
            raise first_error

    def check_changes(self) -> None:
        """Log an error and discard staged content if anything was left unflushed."""
        if self.files:
            logger.error(BadWriteError("Directly called file write without calling commit_transactions"))
        self.files.clear()

    def delete_file(self, path: Path) -> None:
        """Drop any staged content for ``path`` and remove the file from disk if it exists."""
        self.untrack_file(path)
        if path.exists():
            path.unlink()

    def untrack_file(self, path: Path) -> None:
        """Discard staged content for ``path`` without touching the disk."""
        self.files.pop(path, None)

    def file_exists(self, path: Path) -> bool:
        """Report whether ``path`` exists on disk (staged-only content does not count)."""
        return path.exists()

src/codegen/sdk/codebase/io/io.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from abc import ABC, abstractmethod
2+
from pathlib import Path
3+
4+
5+
class BadWriteError(Exception):
    """Signals a file write that was made directly, without calling commit_transactions."""
7+
8+
9+
class IO(ABC):
10+
def write_file(self, path: Path, content: str | bytes | None) -> None:
11+
if content is None:
12+
self.untrack_file(path)
13+
elif isinstance(content, str):
14+
self.write_text(path, content)
15+
else:
16+
self.write_bytes(path, content)
17+
18+
def write_text(self, path: Path, content: str) -> None:
19+
self.write_bytes(path, content.encode("utf-8"))
20+
21+
@abstractmethod
22+
def untrack_file(self, path: Path) -> None:
23+
pass
24+
25+
@abstractmethod
26+
def write_bytes(self, path: Path, content: bytes) -> None:
27+
pass
28+
29+
@abstractmethod
30+
def read_bytes(self, path: Path) -> bytes:
31+
pass
32+
33+
def read_text(self, path: Path) -> str:
34+
return self.read_bytes(path).decode("utf-8")
35+
36+
@abstractmethod
37+
def save_files(self, files: set[Path] | None = None) -> None:
38+
pass
39+
40+
@abstractmethod
41+
def check_changes(self) -> None:
42+
pass
43+
44+
@abstractmethod
45+
def delete_file(self, path: Path) -> None:
46+
pass
47+
48+
@abstractmethod
49+
def file_exists(self, path: Path) -> bool:
50+
pass

src/codegen/sdk/codebase/transactions.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
from collections.abc import Callable
32
from difflib import unified_diff
43
from enum import IntEnum
@@ -267,7 +266,7 @@ def __init__(
267266

268267
def execute(self) -> None:
269268
"""Renames the file"""
270-
self.file.write_pending_content()
269+
self.file.ctx.io.save_files({self.file.path})
271270
self.file_path.rename(self.new_file_path)
272271

273272
def get_diff(self) -> DiffLite:
@@ -292,8 +291,7 @@ def __init__(
292291

293292
def execute(self) -> None:
294293
"""Removes the file"""
295-
os.remove(self.file_path)
296-
self.file._pending_content_bytes = None
294+
self.file.ctx.io.delete_file(self.file.path)
297295

298296
def get_diff(self) -> DiffLite:
299297
"""Gets the diff produced by this transaction"""

src/codegen/sdk/core/codebase.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -481,17 +481,17 @@ def get_file(self, filepath: str, *, optional: bool = False, ignore_case: bool =
481481

482482
def get_file_from_path(path: Path) -> File | None:
483483
try:
484-
return File.from_content(path, path.read_text(), self.ctx, sync=False)
484+
return File.from_content(path, self.ctx.io.read_text(path), self.ctx, sync=False)
485485
except UnicodeDecodeError:
486486
# Handle when file is a binary file
487-
return File.from_content(path, path.read_bytes(), self.ctx, sync=False, binary=True)
487+
return File.from_content(path, self.ctx.io.read_bytes(path), self.ctx, sync=False, binary=True)
488488

489489
# Try to get the file from the graph first
490490
file = self.ctx.get_file(filepath, ignore_case=ignore_case)
491491
if file is not None:
492492
return file
493493
absolute_path = self.ctx.to_absolute(filepath)
494-
if absolute_path.exists():
494+
if self.ctx.io.file_exists(absolute_path):
495495
return get_file_from_path(absolute_path)
496496
elif ignore_case:
497497
parent = absolute_path.parent

src/codegen/sdk/core/file.py

Lines changed: 14 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import TYPE_CHECKING, Generic, Literal, Self, TypeVar, override
1212

1313
from tree_sitter import Node as TSNode
14+
from typing_extensions import deprecated
1415

1516
from codegen.sdk._proxy import proxy_property
1617
from codegen.sdk.codebase.codebase_context import CodebaseContext
@@ -45,10 +46,6 @@
4546
logger = logging.getLogger(__name__)
4647

4748

48-
class BadWriteError(Exception):
49-
pass
50-
51-
5249
@apidoc
5350
class File(Editable[None]):
5451
"""Represents a generic file.
@@ -66,7 +63,6 @@ class File(Editable[None]):
6663
file_path: str
6764
path: Path
6865
node_type: Literal[NodeType.FILE] = NodeType.FILE
69-
_pending_content_bytes: bytes | None = None
7066
_directory: Directory | None
7167
_pending_imports: set[str]
7268
_binary: bool = False
@@ -117,10 +113,8 @@ def from_content(cls, filepath: str | Path, content: str | bytes, ctx: CodebaseC
117113
if not path.exists():
118114
update_graph = True
119115
path.parent.mkdir(parents=True, exist_ok=True)
120-
if not binary:
121-
path.write_text(content)
122-
else:
123-
path.write_bytes(content)
116+
ctx.io.write_file(path, content)
117+
ctx.io.save_files({path})
124118

125119
new_file = cls(filepath, ctx, ts_node=None, binary=binary)
126120
return new_file
@@ -133,10 +127,7 @@ def content_bytes(self) -> bytes:
133127
134128
TODO: move rest of graph sitter to operate in bytes to prevent multi byte character issues?
135129
"""
136-
# Check against None due to possibility of empty byte
137-
if self._pending_content_bytes is None:
138-
return self.path.read_bytes()
139-
return self._pending_content_bytes
130+
return self.ctx.io.read_bytes(self.path)
140131

141132
@property
142133
@reader
@@ -162,31 +153,18 @@ def content(self) -> str:
162153

163154
@noapidoc
164155
def write(self, content: str | bytes, to_disk: bool = False) -> None:
165-
"""Writes string contents to the file."""
166-
self.write_bytes(content.encode("utf-8") if isinstance(content, str) else content, to_disk=to_disk)
167-
168-
@noapidoc
169-
def write_bytes(self, content_bytes: bytes, to_disk: bool = False) -> None:
170-
self._pending_content_bytes = content_bytes
171-
self.ctx.pending_files.add(self)
156+
"""Writes contents to the file."""
157+
self.ctx.io.write_file(self.path, content)
172158
if to_disk:
173-
self.write_pending_content()
159+
self.ctx.io.save_files({self.path})
174160
if self.ts_node.start_byte == self.ts_node.end_byte:
175161
# TS didn't parse anything, register a write to make sure the transaction manager can restore the file later.
176162
self.edit("")
177163

178164
@noapidoc
179-
def write_pending_content(self) -> None:
180-
if self._pending_content_bytes is not None:
181-
self.path.write_bytes(self._pending_content_bytes)
182-
self._pending_content_bytes = None
183-
logger.debug("Finished write_pending_content")
184-
185-
@noapidoc
186-
@writer
187-
def check_changes(self) -> None:
188-
if self._pending_content_bytes is not None:
189-
logger.error(BadWriteError("Directly called file write without calling commit_transactions"))
165+
@deprecated("Use write instead")
166+
def write_bytes(self, content_bytes: bytes, to_disk: bool = False) -> None:
167+
self.write(content_bytes, to_disk=to_disk)
190168

191169
@property
192170
@reader
@@ -272,7 +250,7 @@ def remove(self) -> None:
272250
None
273251
"""
274252
self.transaction_manager.add_file_remove_transaction(self)
275-
self._pending_content_bytes = None
253+
self.ctx.io.write_file(self.path, None)
276254

277255
@property
278256
def filepath(self) -> str:
@@ -596,10 +574,11 @@ def from_content(cls, filepath: str | PathLike | Path, content: str, ctx: Codeba
596574
return None
597575

598576
update_graph = False
599-
if not path.exists():
577+
if not ctx.io.file_exists(path):
600578
update_graph = True
601579
path.parent.mkdir(parents=True, exist_ok=True)
602-
path.write_text(content)
580+
ctx.io.write_file(path, content)
581+
ctx.io.save_files({path})
603582

604583
if update_graph and sync:
605584
ctx.add_single_file(path)

src/codegen/sdk/extensions/io.pyx

Lines changed: 0 additions & 12 deletions
This file was deleted.

0 commit comments

Comments
 (0)