|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +RustPython Test Decorator Tool |
| 4 | +
|
| 5 | +This tool automatically adds @expectedFailure decorators and TODO comments |
| 6 | +to failing test methods in CPython test suites being ported to RustPython. |
| 7 | +It handles inheritance properly by placing decorators at the appropriate level. |
| 8 | +""" |
| 9 | + |
| 10 | +import libcst as cst |
| 11 | +from typing import Dict, Set, List, Tuple, Optional |
| 12 | +import subprocess |
| 13 | +import re |
| 14 | +from pathlib import Path |
| 15 | +import argparse |
| 16 | + |
| 17 | + |
| 18 | +class TestFailureAnalyzer: |
| 19 | + """Analyzes test failures and determines where to place decorators""" |
| 20 | + |
| 21 | + def __init__(self): |
| 22 | + self.test_results: Dict[str, Dict[str, bool]] = {} # class.method -> pass/fail |
| 23 | + self.class_hierarchy: Dict[str, Set[str]] = {} # parent -> children |
| 24 | + self.method_definitions: Dict[str, str] = {} # method -> defining class |
| 25 | + |
| 26 | + def run_tests(self, test_file: str) -> Dict[str, bool]: |
| 27 | + """Run tests and collect failure information""" |
| 28 | + try: |
| 29 | + # Run pytest with verbose output to get individual test results |
| 30 | + result = subprocess.run( |
| 31 | + ["cargo", "run", "--", test_file, "-v", "-b"], |
| 32 | + capture_output=True, |
| 33 | + text=True |
| 34 | + ) |
| 35 | + |
| 36 | + # Parse test results |
| 37 | + test_results = {} |
| 38 | + for line in result.stderr.split('\n'): |
| 39 | + # Match test result lines (e.g., "ERROR: test_unicode (__main__.ArrayReconstructorTest.test_unicode)") |
| 40 | + match = re.match(r'(PASSED|FAIL|ERROR):\s+(\w+)\s+\(__main__\.(\w+)\.(\w+)\)', line) |
| 41 | + if match: |
| 42 | + status, method_name, class_name, _ = match.groups() |
| 43 | + test_key = f"{class_name}.{method_name}" |
| 44 | + test_results[test_key] = status == "PASSED" |
| 45 | + |
| 46 | + return test_results |
| 47 | + |
| 48 | + except Exception as e: |
| 49 | + print(f"Error running tests: {e}") |
| 50 | + return {} |
| 51 | + |
| 52 | + |
| 53 | +class ClassHierarchyVisitor(cst.CSTVisitor): |
| 54 | + """Visitor to build class hierarchy and method definitions""" |
| 55 | + |
| 56 | + def __init__(self): |
| 57 | + self.current_class: Optional[str] = None |
| 58 | + self.class_hierarchy: Dict[str, Set[str]] = {} |
| 59 | + self.method_definitions: Dict[str, str] = {} |
| 60 | + self.class_bases: Dict[str, List[str]] = {} |
| 61 | + |
| 62 | + def visit_ClassDef(self, node: cst.ClassDef) -> None: |
| 63 | + class_name = node.name.value |
| 64 | + self.current_class = class_name |
| 65 | + |
| 66 | + # Extract base classes |
| 67 | + bases = [] |
| 68 | + for arg in node.bases: |
| 69 | + if isinstance(arg.value, cst.Name): |
| 70 | + bases.append(arg.value.value) |
| 71 | + elif isinstance(arg.value, cst.Attribute): |
| 72 | + # Handle cases like unittest.TestCase |
| 73 | + bases.append(self._get_full_name(arg.value)) |
| 74 | + |
| 75 | + self.class_bases[class_name] = bases |
| 76 | + |
| 77 | + # Build hierarchy (reverse mapping) |
| 78 | + for base in bases: |
| 79 | + if base not in self.class_hierarchy: |
| 80 | + self.class_hierarchy[base] = set() |
| 81 | + self.class_hierarchy[base].add(class_name) |
| 82 | + |
| 83 | + def leave_ClassDef(self, node: cst.ClassDef) -> None: |
| 84 | + self.current_class = None |
| 85 | + |
| 86 | + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: |
| 87 | + if self.current_class and node.name.value.startswith('test_'): |
| 88 | + method_key = f"{self.current_class}.{node.name.value}" |
| 89 | + self.method_definitions[method_key] = self.current_class |
| 90 | + |
| 91 | + def _get_full_name(self, node: cst.Attribute) -> str: |
| 92 | + """Get full name from attribute node""" |
| 93 | + parts = [] |
| 94 | + current: cst.BaseExpression = node |
| 95 | + while isinstance(current, cst.Attribute): |
| 96 | + parts.append(current.attr.value) |
| 97 | + current = current.value |
| 98 | + if isinstance(current, cst.Name): |
| 99 | + parts.append(current.value) |
| 100 | + return '.'.join(reversed(parts)) |
| 101 | + |
| 102 | + |
| 103 | +class TestDecoratorTransformer(cst.CSTTransformer): |
| 104 | + """Transformer to add decorators and comments to failing tests""" |
| 105 | + |
| 106 | + def __init__(self, methods_to_decorate: Dict[str, Set[str]]): |
| 107 | + self.methods_to_decorate = methods_to_decorate |
| 108 | + self.current_class: Optional[str] = None |
| 109 | + self.added_decorators: Set[str] = set() |
| 110 | + |
| 111 | + def visit_ClassDef(self, node: cst.ClassDef) -> None: |
| 112 | + self.current_class = node.name.value |
| 113 | + |
| 114 | + def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: |
| 115 | + self.current_class = None |
| 116 | + return updated_node |
| 117 | + |
| 118 | + def leave_FunctionDef(self, original_node: cst.FunctionDef, node: cst.FunctionDef) -> cst.FunctionDef: |
| 119 | + if not self.current_class: |
| 120 | + return node |
| 121 | + |
| 122 | + method_name = node.name.value |
| 123 | + if not method_name.startswith('test_'): |
| 124 | + return node |
| 125 | + |
| 126 | + # Check if this method needs decoration |
| 127 | + if method_name in self.methods_to_decorate.get(self.current_class, set()): |
| 128 | + # Check if already has expectedFailure decorator |
| 129 | + has_decorator = any( |
| 130 | + isinstance(d.decorator, cst.Name) and d.decorator.value == "expectedFailure" |
| 131 | + or isinstance(d.decorator, cst.Attribute) and d.decorator.attr.value == "expectedFailure" |
| 132 | + for d in node.decorators |
| 133 | + ) |
| 134 | + |
| 135 | + if not has_decorator: |
| 136 | + # Add TODO comment |
| 137 | + todo_comment = cst.EmptyLine( |
| 138 | + comment=cst.Comment("# TODO: RUSTPYTHON") |
| 139 | + ) |
| 140 | + |
| 141 | + # Add expectedFailure decorator |
| 142 | + decorator = cst.Decorator( |
| 143 | + decorator=cst.Attribute( |
| 144 | + value=cst.Name("unittest"), |
| 145 | + attr=cst.Name("expectedFailure") |
| 146 | + ) |
| 147 | + ) |
| 148 | + |
| 149 | + # Create new decorators list |
| 150 | + new_decorators = list(node.decorators) + [decorator] |
| 151 | + |
| 152 | + # Add comment before the method |
| 153 | + new_leading_lines = [cst.EmptyLine(), todo_comment] + list(line for line in node.leading_lines if not isinstance(line, cst.EmptyLine)) |
| 154 | + |
| 155 | + return node.with_changes( |
| 156 | + decorators=new_decorators, |
| 157 | + leading_lines=new_leading_lines |
| 158 | + ) |
| 159 | + |
| 160 | + return node |
| 161 | + |
| 162 | + |
| 163 | +class OverrideMethodTransformer(cst.CSTTransformer): |
| 164 | + """Transformer to add override methods in child classes""" |
| 165 | + |
| 166 | + def __init__(self, overrides_to_add: Dict[str, Set[str]]): |
| 167 | + self.overrides_to_add = overrides_to_add |
| 168 | + self.current_class: Optional[str] = None |
| 169 | + |
| 170 | + def visit_ClassDef(self, node: cst.ClassDef) -> None: |
| 171 | + self.current_class = node.name.value |
| 172 | + |
| 173 | + def leave_ClassDef(self, original_node: cst.ClassDef, node: cst.ClassDef) -> cst.ClassDef: |
| 174 | + if self.current_class in self.overrides_to_add: |
| 175 | + methods_to_override = self.overrides_to_add[self.current_class] |
| 176 | + |
| 177 | + # Create override methods |
| 178 | + new_methods = [] |
| 179 | + for method_name in methods_to_override: |
| 180 | + # Create a simple override method that calls super() |
| 181 | + override_method = cst.FunctionDef( |
| 182 | + name=cst.Name(method_name), |
| 183 | + params=cst.Parameters([ |
| 184 | + cst.Param(name=cst.Name("self")) |
| 185 | + ]), |
| 186 | + body=cst.IndentedBlock([ |
| 187 | + cst.SimpleStatementLine([ |
| 188 | + cst.Expr( |
| 189 | + cst.Call( |
| 190 | + func=cst.Attribute( |
| 191 | + value=cst.Call( |
| 192 | + func=cst.Name("super") |
| 193 | + ), |
| 194 | + attr=cst.Name(method_name) |
| 195 | + ) |
| 196 | + ) |
| 197 | + ) |
| 198 | + ]) |
| 199 | + ]), |
| 200 | + decorators=[ |
| 201 | + cst.Decorator( |
| 202 | + decorator=cst.Attribute( |
| 203 | + value=cst.Name("unittest"), |
| 204 | + attr=cst.Name("expectedFailure") |
| 205 | + ) |
| 206 | + ) |
| 207 | + ], |
| 208 | + leading_lines=[ |
| 209 | + cst.EmptyLine( |
| 210 | + whitespace=cst.SimpleWhitespace(" "), |
| 211 | + comment=cst.Comment("# TODO RUSTPYTHON: Fix this test") |
| 212 | + ) |
| 213 | + ] |
| 214 | + ) |
| 215 | + new_methods.append(override_method) |
| 216 | + |
| 217 | + # Add new methods to the class body |
| 218 | + if new_methods: |
| 219 | + new_body = list(node.body.body) + new_methods |
| 220 | + return node.with_changes( |
| 221 | + body=node.body.with_changes(body=new_body) |
| 222 | + ) |
| 223 | + |
| 224 | + self.current_class = None |
| 225 | + return node |
| 226 | + |
| 227 | + |
| 228 | +def analyze_test_failures(test_file: str, test_results: Dict[str, bool], |
| 229 | + hierarchy_visitor: ClassHierarchyVisitor) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]: |
| 230 | + """ |
| 231 | + Analyze test failures and determine where to place decorators. |
| 232 | + Returns: (methods_to_decorate, overrides_to_add) |
| 233 | + """ |
| 234 | + methods_to_decorate: Dict[str, Set[str]] = {} |
| 235 | + overrides_to_add: Dict[str, Set[str]] = {} |
| 236 | + |
| 237 | + # Group failures by method name |
| 238 | + method_failures: Dict[str, Set[str]] = {} # method -> set of failing classes |
| 239 | + for test_key, passed in test_results.items(): |
| 240 | + if not passed: |
| 241 | + class_name, method_name = test_key.split('.') |
| 242 | + if method_name not in method_failures: |
| 243 | + method_failures[method_name] = set() |
| 244 | + method_failures[method_name].add(class_name) |
| 245 | + |
| 246 | + # For each failing method, determine where to place the decorator |
| 247 | + for method_name, failing_classes in method_failures.items(): |
| 248 | + # Find all classes that have this method |
| 249 | + all_classes_with_method = set() |
| 250 | + for test_key in hierarchy_visitor.method_definitions: |
| 251 | + if test_key.endswith(f".{method_name}"): |
| 252 | + class_name = test_key.split('.')[0] |
| 253 | + all_classes_with_method.add(class_name) |
| 254 | + |
| 255 | + # Check if all child classes of a parent fail |
| 256 | + for parent_class in hierarchy_visitor.class_hierarchy: |
| 257 | + children = hierarchy_visitor.class_hierarchy[parent_class] |
| 258 | + children_with_method = children & all_classes_with_method |
| 259 | + |
| 260 | + if children_with_method and children_with_method.issubset(failing_classes): |
| 261 | + # All children with this method fail - decorate parent |
| 262 | + if parent_class not in methods_to_decorate: |
| 263 | + methods_to_decorate[parent_class] = set() |
| 264 | + methods_to_decorate[parent_class].add(method_name) |
| 265 | + |
| 266 | + # Remove from failing_classes as we've handled them |
| 267 | + failing_classes -= children_with_method |
| 268 | + |
| 269 | + # Handle remaining failures - need overrides in specific child classes |
| 270 | + for class_name in failing_classes: |
| 271 | + # Check if method is defined in this class or inherited |
| 272 | + test_key = f"{class_name}.{method_name}" |
| 273 | + if test_key in hierarchy_visitor.method_definitions: |
| 274 | + defining_class = hierarchy_visitor.method_definitions[test_key] |
| 275 | + if defining_class == class_name: |
| 276 | + # Method defined in this class - decorate it |
| 277 | + if class_name not in methods_to_decorate: |
| 278 | + methods_to_decorate[class_name] = set() |
| 279 | + methods_to_decorate[class_name].add(method_name) |
| 280 | + else: |
| 281 | + # Method inherited - need override |
| 282 | + if class_name not in overrides_to_add: |
| 283 | + overrides_to_add[class_name] = set() |
| 284 | + overrides_to_add[class_name].add(method_name) |
| 285 | + |
| 286 | + return methods_to_decorate, overrides_to_add |
| 287 | + |
| 288 | + |
| 289 | +def process_test_file(file_path: str, dry_run: bool = False) -> None: |
| 290 | + """Process a single test file""" |
| 291 | + print(f"Processing {file_path}...") |
| 292 | + |
| 293 | + # Read the file |
| 294 | + with open(file_path, 'r') as f: |
| 295 | + source_code = f.read() |
| 296 | + |
| 297 | + # Parse with libcst |
| 298 | + module = cst.parse_module(source_code) |
| 299 | + |
| 300 | + # Build class hierarchy |
| 301 | + hierarchy_visitor = ClassHierarchyVisitor() |
| 302 | + module.visit(hierarchy_visitor) |
| 303 | + |
| 304 | + # Run tests and collect failures |
| 305 | + analyzer = TestFailureAnalyzer() |
| 306 | + test_results = analyzer.run_tests(file_path) |
| 307 | + |
| 308 | + if not test_results: |
| 309 | + print("No test results found. Make sure the file contains valid unittest tests.") |
| 310 | + return |
| 311 | + |
| 312 | + # Analyze failures and determine decorator placement |
| 313 | + methods_to_decorate, overrides_to_add = analyze_test_failures( |
| 314 | + file_path, test_results, hierarchy_visitor |
| 315 | + ) |
| 316 | + |
| 317 | + # Apply transformations |
| 318 | + if methods_to_decorate: |
| 319 | + module = module.visit(TestDecoratorTransformer(methods_to_decorate)) |
| 320 | + |
| 321 | + if overrides_to_add: |
| 322 | + module = module.visit(OverrideMethodTransformer(overrides_to_add)) |
| 323 | + |
| 324 | + # Write back the modified code |
| 325 | + if dry_run: |
| 326 | + print("Dry run - would make the following changes:") |
| 327 | + print(f"Methods to decorate: {methods_to_decorate}") |
| 328 | + print(f"Override methods to add: {overrides_to_add}") |
| 329 | + else: |
| 330 | + with open(file_path, 'w') as f: |
| 331 | + f.write(module.code) |
| 332 | + print(f"Updated {file_path}") |
| 333 | + |
| 334 | + |
| 335 | +def main(): |
| 336 | + parser = argparse.ArgumentParser( |
| 337 | + description="Add @expectedFailure decorators to failing RustPython tests" |
| 338 | + ) |
| 339 | + parser.add_argument( |
| 340 | + "files", |
| 341 | + nargs="+", |
| 342 | + help="Test files to process" |
| 343 | + ) |
| 344 | + parser.add_argument( |
| 345 | + "--dry-run", |
| 346 | + action="store_true", |
| 347 | + help="Show what would be changed without modifying files" |
| 348 | + ) |
| 349 | + |
| 350 | + args = parser.parse_args() |
| 351 | + |
| 352 | + for file_path in args.files: |
| 353 | + if Path(file_path).exists(): |
| 354 | + process_test_file(file_path, args.dry_run) |
| 355 | + else: |
| 356 | + print(f"File not found: {file_path}") |
| 357 | + |
| 358 | + |
| 359 | +if __name__ == "__main__": |
| 360 | + main() |
0 commit comments