Skip to content

Commit c4b2e20

Browse files
committed
wip
1 parent 217c59f commit c4b2e20

File tree

2 files changed

+412
-19
lines changed

2 files changed

+412
-19
lines changed

item.py

Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
#!/usr/bin/env python3
2+
"""
3+
RustPython Test Decorator Tool
4+
5+
This tool automatically adds @expectedFailure decorators and TODO comments
6+
to failing test methods in CPython test suites being ported to RustPython.
7+
It handles inheritance properly by placing decorators at the appropriate level.
8+
"""
9+
10+
import libcst as cst
11+
from typing import Dict, Set, List, Tuple, Optional
12+
import subprocess
13+
import re
14+
from pathlib import Path
15+
import argparse
16+
17+
18+
class TestFailureAnalyzer:
19+
"""Analyzes test failures and determines where to place decorators"""
20+
21+
def __init__(self):
22+
self.test_results: Dict[str, Dict[str, bool]] = {} # class.method -> pass/fail
23+
self.class_hierarchy: Dict[str, Set[str]] = {} # parent -> children
24+
self.method_definitions: Dict[str, str] = {} # method -> defining class
25+
26+
def run_tests(self, test_file: str) -> Dict[str, bool]:
27+
"""Run tests and collect failure information"""
28+
try:
29+
# Run pytest with verbose output to get individual test results
30+
result = subprocess.run(
31+
["cargo", "run", "--", test_file, "-v", "-b"],
32+
capture_output=True,
33+
text=True
34+
)
35+
36+
# Parse test results
37+
test_results = {}
38+
for line in result.stderr.split('\n'):
39+
# Match test result lines (e.g., "ERROR: test_unicode (__main__.ArrayReconstructorTest.test_unicode)")
40+
match = re.match(r'(PASSED|FAIL|ERROR):\s+(\w+)\s+\(__main__\.(\w+)\.(\w+)\)', line)
41+
if match:
42+
status, method_name, class_name, _ = match.groups()
43+
test_key = f"{class_name}.{method_name}"
44+
test_results[test_key] = status == "PASSED"
45+
46+
return test_results
47+
48+
except Exception as e:
49+
print(f"Error running tests: {e}")
50+
return {}
51+
52+
53+
class ClassHierarchyVisitor(cst.CSTVisitor):
54+
"""Visitor to build class hierarchy and method definitions"""
55+
56+
def __init__(self):
57+
self.current_class: Optional[str] = None
58+
self.class_hierarchy: Dict[str, Set[str]] = {}
59+
self.method_definitions: Dict[str, str] = {}
60+
self.class_bases: Dict[str, List[str]] = {}
61+
62+
def visit_ClassDef(self, node: cst.ClassDef) -> None:
63+
class_name = node.name.value
64+
self.current_class = class_name
65+
66+
# Extract base classes
67+
bases = []
68+
for arg in node.bases:
69+
if isinstance(arg.value, cst.Name):
70+
bases.append(arg.value.value)
71+
elif isinstance(arg.value, cst.Attribute):
72+
# Handle cases like unittest.TestCase
73+
bases.append(self._get_full_name(arg.value))
74+
75+
self.class_bases[class_name] = bases
76+
77+
# Build hierarchy (reverse mapping)
78+
for base in bases:
79+
if base not in self.class_hierarchy:
80+
self.class_hierarchy[base] = set()
81+
self.class_hierarchy[base].add(class_name)
82+
83+
def leave_ClassDef(self, node: cst.ClassDef) -> None:
84+
self.current_class = None
85+
86+
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
87+
if self.current_class and node.name.value.startswith('test_'):
88+
method_key = f"{self.current_class}.{node.name.value}"
89+
self.method_definitions[method_key] = self.current_class
90+
91+
def _get_full_name(self, node: cst.Attribute) -> str:
92+
"""Get full name from attribute node"""
93+
parts = []
94+
current: cst.BaseExpression = node
95+
while isinstance(current, cst.Attribute):
96+
parts.append(current.attr.value)
97+
current = current.value
98+
if isinstance(current, cst.Name):
99+
parts.append(current.value)
100+
return '.'.join(reversed(parts))
101+
102+
103+
class TestDecoratorTransformer(cst.CSTTransformer):
104+
"""Transformer to add decorators and comments to failing tests"""
105+
106+
def __init__(self, methods_to_decorate: Dict[str, Set[str]]):
107+
self.methods_to_decorate = methods_to_decorate
108+
self.current_class: Optional[str] = None
109+
self.added_decorators: Set[str] = set()
110+
111+
def visit_ClassDef(self, node: cst.ClassDef) -> None:
112+
self.current_class = node.name.value
113+
114+
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
115+
self.current_class = None
116+
return updated_node
117+
118+
def leave_FunctionDef(self, original_node: cst.FunctionDef, node: cst.FunctionDef) -> cst.FunctionDef:
119+
if not self.current_class:
120+
return node
121+
122+
method_name = node.name.value
123+
if not method_name.startswith('test_'):
124+
return node
125+
126+
# Check if this method needs decoration
127+
if method_name in self.methods_to_decorate.get(self.current_class, set()):
128+
# Check if already has expectedFailure decorator
129+
has_decorator = any(
130+
isinstance(d.decorator, cst.Name) and d.decorator.value == "expectedFailure"
131+
or isinstance(d.decorator, cst.Attribute) and d.decorator.attr.value == "expectedFailure"
132+
for d in node.decorators
133+
)
134+
135+
if not has_decorator:
136+
# Add TODO comment
137+
todo_comment = cst.EmptyLine(
138+
comment=cst.Comment("# TODO: RUSTPYTHON")
139+
)
140+
141+
# Add expectedFailure decorator
142+
decorator = cst.Decorator(
143+
decorator=cst.Attribute(
144+
value=cst.Name("unittest"),
145+
attr=cst.Name("expectedFailure")
146+
)
147+
)
148+
149+
# Create new decorators list
150+
new_decorators = list(node.decorators) + [decorator]
151+
152+
# Add comment before the method
153+
new_leading_lines = [cst.EmptyLine(), todo_comment] + list(line for line in node.leading_lines if not isinstance(line, cst.EmptyLine))
154+
155+
return node.with_changes(
156+
decorators=new_decorators,
157+
leading_lines=new_leading_lines
158+
)
159+
160+
return node
161+
162+
163+
class OverrideMethodTransformer(cst.CSTTransformer):
164+
"""Transformer to add override methods in child classes"""
165+
166+
def __init__(self, overrides_to_add: Dict[str, Set[str]]):
167+
self.overrides_to_add = overrides_to_add
168+
self.current_class: Optional[str] = None
169+
170+
def visit_ClassDef(self, node: cst.ClassDef) -> None:
171+
self.current_class = node.name.value
172+
173+
def leave_ClassDef(self, original_node: cst.ClassDef, node: cst.ClassDef) -> cst.ClassDef:
174+
if self.current_class in self.overrides_to_add:
175+
methods_to_override = self.overrides_to_add[self.current_class]
176+
177+
# Create override methods
178+
new_methods = []
179+
for method_name in methods_to_override:
180+
# Create a simple override method that calls super()
181+
override_method = cst.FunctionDef(
182+
name=cst.Name(method_name),
183+
params=cst.Parameters([
184+
cst.Param(name=cst.Name("self"))
185+
]),
186+
body=cst.IndentedBlock([
187+
cst.SimpleStatementLine([
188+
cst.Expr(
189+
cst.Call(
190+
func=cst.Attribute(
191+
value=cst.Call(
192+
func=cst.Name("super")
193+
),
194+
attr=cst.Name(method_name)
195+
)
196+
)
197+
)
198+
])
199+
]),
200+
decorators=[
201+
cst.Decorator(
202+
decorator=cst.Attribute(
203+
value=cst.Name("unittest"),
204+
attr=cst.Name("expectedFailure")
205+
)
206+
)
207+
],
208+
leading_lines=[
209+
cst.EmptyLine(
210+
whitespace=cst.SimpleWhitespace(" "),
211+
comment=cst.Comment("# TODO RUSTPYTHON: Fix this test")
212+
)
213+
]
214+
)
215+
new_methods.append(override_method)
216+
217+
# Add new methods to the class body
218+
if new_methods:
219+
new_body = list(node.body.body) + new_methods
220+
return node.with_changes(
221+
body=node.body.with_changes(body=new_body)
222+
)
223+
224+
self.current_class = None
225+
return node
226+
227+
228+
def analyze_test_failures(test_file: str, test_results: Dict[str, bool],
229+
hierarchy_visitor: ClassHierarchyVisitor) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
230+
"""
231+
Analyze test failures and determine where to place decorators.
232+
Returns: (methods_to_decorate, overrides_to_add)
233+
"""
234+
methods_to_decorate: Dict[str, Set[str]] = {}
235+
overrides_to_add: Dict[str, Set[str]] = {}
236+
237+
# Group failures by method name
238+
method_failures: Dict[str, Set[str]] = {} # method -> set of failing classes
239+
for test_key, passed in test_results.items():
240+
if not passed:
241+
class_name, method_name = test_key.split('.')
242+
if method_name not in method_failures:
243+
method_failures[method_name] = set()
244+
method_failures[method_name].add(class_name)
245+
246+
# For each failing method, determine where to place the decorator
247+
for method_name, failing_classes in method_failures.items():
248+
# Find all classes that have this method
249+
all_classes_with_method = set()
250+
for test_key in hierarchy_visitor.method_definitions:
251+
if test_key.endswith(f".{method_name}"):
252+
class_name = test_key.split('.')[0]
253+
all_classes_with_method.add(class_name)
254+
255+
# Check if all child classes of a parent fail
256+
for parent_class in hierarchy_visitor.class_hierarchy:
257+
children = hierarchy_visitor.class_hierarchy[parent_class]
258+
children_with_method = children & all_classes_with_method
259+
260+
if children_with_method and children_with_method.issubset(failing_classes):
261+
# All children with this method fail - decorate parent
262+
if parent_class not in methods_to_decorate:
263+
methods_to_decorate[parent_class] = set()
264+
methods_to_decorate[parent_class].add(method_name)
265+
266+
# Remove from failing_classes as we've handled them
267+
failing_classes -= children_with_method
268+
269+
# Handle remaining failures - need overrides in specific child classes
270+
for class_name in failing_classes:
271+
# Check if method is defined in this class or inherited
272+
test_key = f"{class_name}.{method_name}"
273+
if test_key in hierarchy_visitor.method_definitions:
274+
defining_class = hierarchy_visitor.method_definitions[test_key]
275+
if defining_class == class_name:
276+
# Method defined in this class - decorate it
277+
if class_name not in methods_to_decorate:
278+
methods_to_decorate[class_name] = set()
279+
methods_to_decorate[class_name].add(method_name)
280+
else:
281+
# Method inherited - need override
282+
if class_name not in overrides_to_add:
283+
overrides_to_add[class_name] = set()
284+
overrides_to_add[class_name].add(method_name)
285+
286+
return methods_to_decorate, overrides_to_add
287+
288+
289+
def process_test_file(file_path: str, dry_run: bool = False) -> None:
290+
"""Process a single test file"""
291+
print(f"Processing {file_path}...")
292+
293+
# Read the file
294+
with open(file_path, 'r') as f:
295+
source_code = f.read()
296+
297+
# Parse with libcst
298+
module = cst.parse_module(source_code)
299+
300+
# Build class hierarchy
301+
hierarchy_visitor = ClassHierarchyVisitor()
302+
module.visit(hierarchy_visitor)
303+
304+
# Run tests and collect failures
305+
analyzer = TestFailureAnalyzer()
306+
test_results = analyzer.run_tests(file_path)
307+
308+
if not test_results:
309+
print("No test results found. Make sure the file contains valid unittest tests.")
310+
return
311+
312+
# Analyze failures and determine decorator placement
313+
methods_to_decorate, overrides_to_add = analyze_test_failures(
314+
file_path, test_results, hierarchy_visitor
315+
)
316+
317+
# Apply transformations
318+
if methods_to_decorate:
319+
module = module.visit(TestDecoratorTransformer(methods_to_decorate))
320+
321+
if overrides_to_add:
322+
module = module.visit(OverrideMethodTransformer(overrides_to_add))
323+
324+
# Write back the modified code
325+
if dry_run:
326+
print("Dry run - would make the following changes:")
327+
print(f"Methods to decorate: {methods_to_decorate}")
328+
print(f"Override methods to add: {overrides_to_add}")
329+
else:
330+
with open(file_path, 'w') as f:
331+
f.write(module.code)
332+
print(f"Updated {file_path}")
333+
334+
335+
def main():
336+
parser = argparse.ArgumentParser(
337+
description="Add @expectedFailure decorators to failing RustPython tests"
338+
)
339+
parser.add_argument(
340+
"files",
341+
nargs="+",
342+
help="Test files to process"
343+
)
344+
parser.add_argument(
345+
"--dry-run",
346+
action="store_true",
347+
help="Show what would be changed without modifying files"
348+
)
349+
350+
args = parser.parse_args()
351+
352+
for file_path in args.files:
353+
if Path(file_path).exists():
354+
process_test_file(file_path, args.dry_run)
355+
else:
356+
print(f"File not found: {file_path}")
357+
358+
359+
if __name__ == "__main__":
360+
main()

0 commit comments

Comments
 (0)