-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathdefault_positive_tolerance_interpretation.py
More file actions
79 lines (59 loc) · 2.32 KB
/
Copy pathdefault_positive_tolerance_interpretation.py
File metadata and controls
79 lines (59 loc) · 2.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from enum import IntEnum
from pass_bench.positive_tolerance_interpretation import (
PositiveToleranceInterpretation,
)
class DefaultErrorEnum(IntEnum):
    """
    Error levels (errno codes) used by the default positive-tolerance interpretation.

    The integer values are the errno codes returned for each error class. They
    are NOT the tolerance thresholds themselves: the tolerance required to
    accept each level is defined separately by
    ``DefaultPositiveToleranceInterpretation.get_tolerance_mapping`` (for
    example, a runtime failure has errno 2 but is only tolerated at t >= 3).
    """

    kAccuracyViolation = 1  # Accuracy
    kRuntimeFailure = 2  # Includes Runtime, NaN, Inf, TypeMismatch, etc. (catch-all)
    kCompilationFailed = 3  # Compile Failure

    @classmethod
    def get_error_enum(cls, base_error_type: str) -> "DefaultErrorEnum":
        """
        Classify an error-type string into a DefaultErrorEnum member.

        Matching is case-insensitive and substring-based, checked in order:
        a string containing "accuracy" maps to kAccuracyViolation; otherwise a
        string containing "compile_fail" maps to kCompilationFailed. Anything
        else -- including an empty string or None -- maps to kRuntimeFailure.

        Args:
            base_error_type: Error-type label to classify (may be empty/None).

        Returns:
            The matching DefaultErrorEnum member.
        """
        # No label at all: fall back to the catch-all runtime bucket.
        if not base_error_type:
            return cls.kRuntimeFailure
        etype = base_error_type.lower()
        # "accuracy" is checked first, so it wins if both keywords appear.
        if "accuracy" in etype:
            return cls.kAccuracyViolation
        if "compile_fail" in etype:
            return cls.kCompilationFailed
        return cls.kRuntimeFailure
class DefaultPositiveToleranceInterpretation(PositiveToleranceInterpretation):
    """
    Legacy ("default") positive-tolerance interpretation.

    Error labels are classified into three errno levels by ``DefaultErrorEnum``
    and accepted according to ``get_tolerance_mapping``:

    - t >= 1: accuracy errors are tolerated.
    - t >= 3: runtime and compilation errors are tolerated as well.

    The label "correct" is always accepted; "eager_fail" and "reference_fail"
    are never accepted, whatever the tolerance.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def type_name(self) -> str:
        """Return the identifier of this interpretation ("default")."""
        return "default"

    def get_errno(self, error_type: str) -> int:
        """
        Map an error-type label to its errno code.

        Uses ``DefaultErrorEnum.get_error_enum`` (case-insensitive substring
        match; unknown or empty labels map to the runtime-failure errno).
        """
        return DefaultErrorEnum.get_error_enum(error_type).value

    def get_error_type(self, errno: int) -> str:
        """
        Map an errno code back to its canonical error-type label.

        Returns "unknown_error" for any errno outside the defined levels.
        """
        # Keyed by the enum values (rather than literal 1/2/3) so this stays
        # in sync with get_errno / DefaultErrorEnum.
        mapping = {
            DefaultErrorEnum.kAccuracyViolation.value: "accuracy",
            DefaultErrorEnum.kRuntimeFailure.value: "runtime_fail",
            DefaultErrorEnum.kCompilationFailed.value: "compile_fail",
        }
        return mapping.get(errno, "unknown_error")

    def get_tolerance_mapping(self) -> dict[int, int]:
        """Return errno -> minimum tolerance at which that errno is accepted."""
        return {
            DefaultErrorEnum.kAccuracyViolation.value: 1,
            DefaultErrorEnum.kRuntimeFailure.value: 3,
            DefaultErrorEnum.kCompilationFailed.value: 3,
        }

    def is_error_tolerated(self, tolerance: int, base_error_code: str) -> bool:
        """
        Decide whether ``base_error_code`` is acceptable at ``tolerance``.

        Args:
            tolerance: Tolerance level t being evaluated.
            base_error_code: Error label for the sample; compared
                case-insensitively, may be empty/None.

        Returns:
            True for "correct"; False for "eager_fail"/"reference_fail";
            otherwise True iff ``tolerance`` meets the threshold that
            ``get_tolerance_mapping`` assigns to the classified errno.
        """
        # Normalize like DefaultErrorEnum.get_error_enum (which lower-cases),
        # so e.g. "CORRECT" is not misread as a runtime failure and
        # "EAGER_FAIL" is not tolerated at t >= 3. Guard falsy input so None
        # still falls through to the runtime-failure classification.
        normalized = base_error_code.lower() if base_error_code else ""
        if normalized == "correct":
            return True
        if normalized in ("eager_fail", "reference_fail"):
            return False
        error_enum = DefaultErrorEnum.get_error_enum(base_error_code)
        mapping = self.get_tolerance_mapping()
        # 999 only matters if an override of get_tolerance_mapping omits a
        # level; it makes that level effectively never tolerated.
        required_threshold = mapping.get(error_enum.value, 999)
        return tolerance >= required_threshold

    def num_errno_enum_values(self) -> int:
        """
        Default mode defines 3 levels of errors:
        1: Accuracy
        2: Runtime (Generic)
        3: Compilation
        """
        return len(DefaultErrorEnum)