Skip to content

Commit ec17019

Browse files
committed
Fix fix_test.py
1 parent 5a5b721 commit ec17019

2 files changed

Lines changed: 188 additions & 161 deletions

File tree

scripts/fix_test.py

Lines changed: 0 additions & 161 deletions
This file was deleted.

scripts/mark_test_failures.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
"""
2+
An automated script to mark failures in python test suite.
3+
It adds @unittest.expectedFailure to the test functions that are failing in RustPython, but not in CPython.
4+
As well as marking the test with a TODO comment.
5+
6+
How to use:
7+
1. Copy a specific test from the CPython repository to the RustPython repository.
8+
2. Remove all unexpected failures from the test and skip the tests that hang.
9+
3. Build RustPython: cargo build --release
10+
4. Run from the project root:
11+
- For single-file tests: python ./scripts/fix_test.py --path ./Lib/test/test_venv.py
12+
- For package tests: python ./scripts/fix_test.py --path ./Lib/test/test_inspect/test_inspect.py
13+
5. Verify: cargo run --release -- -m test test_venv (should pass with expected failures)
14+
6. Actually fix the tests marked with # TODO: RUSTPYTHON
15+
"""
16+
17+
import argparse
18+
import ast
19+
import itertools
20+
import platform
21+
import sys
22+
from pathlib import Path
23+
24+
25+
def parse_args():
26+
parser = argparse.ArgumentParser(description="Fix test.")
27+
parser.add_argument("--path", type=Path, help="Path to test file")
28+
parser.add_argument("--force", action="store_true", help="Force modification")
29+
parser.add_argument(
30+
"--platform", action="store_true", help="Platform specific failure"
31+
)
32+
33+
args = parser.parse_args()
34+
return args
35+
36+
37+
class Test:
38+
name: str = ""
39+
path: str = ""
40+
result: str = ""
41+
42+
def __str__(self):
43+
return f"Test(name={self.name}, path={self.path}, result={self.result})"
44+
45+
46+
class TestResult:
47+
tests_result: str = ""
48+
tests = []
49+
stdout = ""
50+
51+
def __str__(self):
52+
return f"TestResult(tests_result={self.tests_result},tests={len(self.tests)})"
53+
54+
55+
def parse_results(result):
56+
lines = result.stdout.splitlines()
57+
test_results = TestResult()
58+
test_results.stdout = result.stdout
59+
in_test_results = False
60+
for line in lines:
61+
if line == "Run tests sequentially":
62+
in_test_results = True
63+
elif line.startswith("-----------"):
64+
in_test_results = False
65+
if in_test_results and " ... " in line:
66+
line = line.strip()
67+
# Skip lines that don't look like test results
68+
if line.startswith("tests") or line.startswith("["):
69+
continue
70+
# Parse: "test_name (path) [subtest] ... RESULT"
71+
parts = line.split(" ... ")
72+
if len(parts) >= 2:
73+
test_info = parts[0]
74+
result_str = parts[-1].lower()
75+
# Only process FAIL or ERROR
76+
if result_str not in ("fail", "error"):
77+
continue
78+
# Extract test name (first word)
79+
first_space = test_info.find(" ")
80+
if first_space > 0:
81+
test = Test()
82+
test.name = test_info[:first_space]
83+
# Extract path from (path)
84+
rest = test_info[first_space:].strip()
85+
if rest.startswith("("):
86+
end_paren = rest.find(")")
87+
if end_paren > 0:
88+
test.path = rest[1:end_paren]
89+
test.result = result_str
90+
test_results.tests.append(test)
91+
elif "== Tests result: " in line:
92+
res = line.split("== Tests result: ")[1]
93+
res = res.split(" ")[0]
94+
test_results.tests_result = res
95+
return test_results
96+
97+
98+
def path_to_test(path) -> list[str]:
99+
# path format: test.module_name[.submodule].ClassName.test_method
100+
# We need [ClassName, test_method] - always the last 2 elements
101+
parts = path.split(".")
102+
return parts[-2:] # Get class name and method name
103+
104+
105+
def find_test_lineno(file: str, test: list[str]) -> tuple[int, int] | None:
106+
"""Find the line number and column offset of a test function.
107+
Returns (lineno, col_offset) or None if not found.
108+
"""
109+
a = ast.parse(file)
110+
for key, node in ast.iter_fields(a):
111+
if key == "body":
112+
for n in node:
113+
match n:
114+
case ast.ClassDef():
115+
if len(test) == 2 and test[0] == n.name:
116+
for fn in n.body:
117+
match fn:
118+
case ast.FunctionDef() | ast.AsyncFunctionDef():
119+
if fn.name == test[-1]:
120+
return (fn.lineno, fn.col_offset)
121+
case ast.FunctionDef() | ast.AsyncFunctionDef():
122+
if n.name == test[0] and len(test) == 1:
123+
return (n.lineno, n.col_offset)
124+
return None
125+
126+
127+
def apply_modifications(file: str, modifications: list[tuple[int, int]]) -> str:
128+
"""Apply all modifications in reverse order to avoid line number offset issues."""
129+
lines = file.splitlines()
130+
fixture = "@unittest.expectedFailure"
131+
# Sort by line number in descending order
132+
modifications.sort(key=lambda x: x[0], reverse=True)
133+
for lineno, col_offset in modifications:
134+
indent = " " * col_offset
135+
lines.insert(lineno - 1, indent + fixture)
136+
lines.insert(lineno - 1, indent + "# TODO: RUSTPYTHON")
137+
return "\n".join(lines)
138+
139+
140+
def run_test(test_name):
141+
print(f"Running test: {test_name}")
142+
rustpython_location = "./target/release/rustpython"
143+
if sys.platform == "win32":
144+
rustpython_location += ".exe"
145+
146+
import subprocess
147+
148+
result = subprocess.run(
149+
[rustpython_location, "-m", "test", "-v", test_name],
150+
capture_output=True,
151+
text=True,
152+
)
153+
return parse_results(result)
154+
155+
156+
if __name__ == "__main__":
157+
args = parse_args()
158+
test_path = args.path.resolve()
159+
if not test_path.exists():
160+
print(f"Error: File not found: {test_path}")
161+
sys.exit(1)
162+
test_name = test_path.stem
163+
tests = run_test(test_name)
164+
f = test_path.read_text(encoding="utf-8")
165+
166+
# Collect all modifications first (with deduplication for subtests)
167+
modifications = []
168+
seen_tests = set() # Track (class_name, method_name) to avoid duplicates
169+
for test in tests.tests:
170+
if test.result == "fail" or test.result == "error":
171+
test_parts = path_to_test(test.path)
172+
test_key = tuple(test_parts)
173+
if test_key in seen_tests:
174+
continue # Skip duplicate (same test, different subtest)
175+
seen_tests.add(test_key)
176+
location = find_test_lineno(f, test_parts)
177+
if location:
178+
print(f"Modifying test: {test.name} at line {location[0]}")
179+
modifications.append(location)
180+
else:
181+
print(f"Warning: Could not find test: {test.name} ({test_parts})")
182+
183+
# Apply all modifications in reverse order
184+
if modifications:
185+
f = apply_modifications(f, modifications)
186+
test_path.write_text(f, encoding="utf-8")
187+
188+
print(f"Modified {len(modifications)} tests")

0 commit comments

Comments
 (0)