Skip to content

Commit 6f49201

Browse files
igerberclaude
andcommitted
Fix import expansion: keep full module paths, expand relative aliases
P1: parse_imports() no longer truncates to 2 components. Full module paths like diff_diff.visualization._common are preserved. For 'from . import foo' style relative imports, each alias is appended to the resolved base package (e.g., diff_diff.visualization._event_study instead of diff_diff.visualization). P2: Add 3 regression tests: submodule import truncation check, relative import alias expansion, visualization __init__.py expansion includes submodule files. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a354a2f commit 6f49201

2 files changed

Lines changed: 50 additions & 13 deletions

File tree

.claude/scripts/openai_review.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -243,27 +243,26 @@ def parse_imports(file_path: str) -> "set[str]":
243243
for node in ast.walk(tree):
244244
if isinstance(node, ast.Import):
245245
for alias in node.names:
246-
if alias.name.startswith("diff_diff"):
247-
# Extract the top-level module: diff_diff.linalg.foo -> diff_diff.linalg
248-
parts = alias.name.split(".")
249-
if len(parts) >= 2:
250-
imports.add(f"{parts[0]}.{parts[1]}")
246+
if alias.name.startswith("diff_diff."):
247+
imports.add(alias.name)
251248
elif isinstance(node, ast.ImportFrom):
252249
if node.module and node.level == 0:
253250
# Absolute import: from diff_diff.linalg import ...
254-
if node.module.startswith("diff_diff"):
255-
parts = node.module.split(".")
256-
if len(parts) >= 2:
257-
imports.add(f"{parts[0]}.{parts[1]}")
251+
if node.module.startswith("diff_diff."):
252+
imports.add(node.module)
258253
elif node.level > 0 and package:
259-
# Relative import: from . import utils, from .linalg import solve_ols
254+
# Relative import: from .foo import bar, or from . import foo
260255
resolved = _resolve_relative_import(
261256
package, node.module, node.level
262257
)
263258
if resolved and resolved.startswith("diff_diff"):
264-
parts = resolved.split(".")
265-
if len(parts) >= 2:
266-
imports.add(f"{parts[0]}.{parts[1]}")
259+
if node.module:
260+
# from .linalg import solve_ols → resolved = "diff_diff.linalg"
261+
imports.add(resolved)
262+
else:
263+
# from . import _event_study, _common → append each alias
264+
for alias in node.names:
265+
imports.add(f"{resolved}.{alias.name}")
267266
return imports
268267

269268

tests/test_openai_review.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,32 @@ def test_ignores_non_diff_diff_imports(self, review_mod, tmp_path):
542542
imports = review_mod.parse_imports(str(test_file))
543543
assert imports == set()
544544

545+
def test_submodule_imports_not_truncated(self, review_mod, repo_root):
546+
"""Submodule imports should keep full path, not truncate to 2 components."""
547+
path = os.path.join(repo_root, "diff_diff", "visualization", "_staggered.py")
548+
if not os.path.isfile(path):
549+
pytest.skip("diff_diff/visualization/_staggered.py not found")
550+
imports = review_mod.parse_imports(path)
551+
# Should include full submodule paths like diff_diff.visualization._common
552+
has_submodule = any(
553+
m.count(".") >= 2 for m in imports # at least 3 components
554+
)
555+
assert has_submodule, (
556+
f"Expected submodule imports (3+ components) but got: {imports}"
557+
)
558+
559+
def test_relative_import_aliases_expanded(self, review_mod, repo_root):
560+
"""from . import _event_study should resolve to diff_diff.visualization._event_study."""
561+
path = os.path.join(repo_root, "diff_diff", "visualization", "__init__.py")
562+
if not os.path.isfile(path):
563+
pytest.skip("diff_diff/visualization/__init__.py not found")
564+
imports = review_mod.parse_imports(path)
565+
# Should include individual submodule names, not just the package
566+
submodules = [m for m in imports if m.startswith("diff_diff.visualization._")]
567+
assert len(submodules) > 0, (
568+
f"Expected visualization submodule imports but got: {imports}"
569+
)
570+
545571
def test_handles_syntax_error(self, review_mod, tmp_path, capsys):
546572
test_file = tmp_path / "bad.py"
547573
test_file.write_text("def foo(:\n pass\n")
@@ -583,6 +609,18 @@ def test_deduplicates_against_changed_set(self, review_mod, repo_root):
583609
result = review_mod.expand_import_graph([bacon, linalg], repo_root)
584610
assert linalg not in [os.path.normpath(p) for p in result]
585611

612+
def test_visualization_init_includes_submodules(self, review_mod, repo_root):
613+
"""expand_import_graph on visualization/__init__.py should include submodules."""
614+
path = os.path.join(repo_root, "diff_diff", "visualization", "__init__.py")
615+
if not os.path.isfile(path):
616+
pytest.skip("diff_diff/visualization/__init__.py not found")
617+
result = review_mod.expand_import_graph([path], repo_root)
618+
filenames = [os.path.basename(p) for p in result]
619+
# Should include visualization submodules like _event_study.py, _staggered.py
620+
assert any(f.startswith("_") and f.endswith(".py") for f in filenames), (
621+
f"Expected visualization submodule files but got: {filenames}"
622+
)
623+
586624
def test_empty_input(self, review_mod, repo_root):
587625
assert review_mod.expand_import_graph([], repo_root) == []
588626

0 commit comments

Comments
 (0)