Skip to content

Commit 467d29d

Browse files
committed
fix: fix crash on invalid discriminator in Optional discriminated unions
Signed-off-by: Stephen Crowe <[email protected]>
1 parent 1c0263a commit 467d29d

File tree

3 files changed

+98
-1
lines changed

3 files changed

+98
-1
lines changed

src/openjd/model/_internal/_variable_reference_validation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,8 @@ def _get_model_for_singleton_value(
459459
# Find the correct model for the discriminator value by unwrapping the Union and then the discriminator Literals
460460
assert typing.get_origin(model) is typing.Union # For the type checker
461461
for sub_model in typing.get_args(model):
462+
if sub_model is type(None):
463+
continue
462464
sub_model_discr_value = sub_model.model_fields[discriminator].annotation
463465
if typing.get_origin(sub_model_discr_value) is not typing.Literal:
464466
raise NotImplementedError(

test/openjd/model/_internal/test_variable_reference_validation.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22

3-
from typing import Any, Literal, Union
3+
from typing import Any, Literal, Optional, Union
44
from enum import Enum
55
from typing_extensions import Annotated
66
from pydantic import Field
@@ -1145,6 +1145,54 @@ class BaseModel(OpenJDModel):
11451145
# THEN
11461146
assert len(errors) == 0
11471147

1148+
@pytest.mark.parametrize(
1149+
"data",
1150+
[
1151+
pytest.param({"name": "Foo", "sub": {"kind": "INVALID"}}, id="invalid discriminator"),
1152+
pytest.param({"name": "Foo", "sub": {"kind": ""}}, id="empty discriminator"),
1153+
pytest.param({"name": "Foo", "sub": None}, id="None value"),
1154+
],
1155+
)
1156+
def test_optional_discriminated_union_invalid_discriminator(self, data: dict[str, Any]) -> None:
1157+
# Test that Optional discriminated unions with invalid discriminator values
1158+
# don't crash with AttributeError on NoneType.
1159+
1160+
# GIVEN
1161+
class Kind(str, Enum):
1162+
ONE = "ONE"
1163+
TWO = "TWO"
1164+
1165+
class SubModel1(OpenJDModel):
1166+
kind: Literal[Kind.ONE]
1167+
field1: FormatString
1168+
_template_variable_scope = ResolutionScope.TEMPLATE
1169+
1170+
class SubModel2(OpenJDModel):
1171+
kind: Literal[Kind.TWO]
1172+
field2: FormatString
1173+
_template_variable_scope = ResolutionScope.TEMPLATE
1174+
1175+
class BaseModel(OpenJDModel):
1176+
name: str
1177+
sub: Optional[
1178+
Annotated[Union[SubModel1, SubModel2], Field(..., discriminator="kind")]
1179+
] = None
1180+
_template_variable_definitions = DefinesTemplateVariables(
1181+
defines={TemplateVariableDef(prefix="|Param.", resolves=ResolutionScope.TEMPLATE)},
1182+
field="name",
1183+
)
1184+
_template_variable_sources = {
1185+
"sub": {"__self__"},
1186+
}
1187+
1188+
# WHEN
1189+
errors = prevalidate_model_template_variable_references(
1190+
BaseModel, data, context=ModelParsingContext_v2023_09()
1191+
)
1192+
1193+
# THEN
1194+
assert len(errors) == 0
1195+
11481196

11491197
class TestNonDiscriminatedUnion:
11501198
"""Test that if we have unions in the model that isn't a discriminated union then we handle them correctly.

test/openjd/model/test_fuzz.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,42 @@ def fuzz_format_string():
7171
)
7272

7373

74+
def fuzz_cancelation():
75+
return random.choice(
76+
[
77+
None,
78+
{},
79+
{"mode": "NOTIFY_THEN_TERMINATE"},
80+
{"mode": "TERMINATE"},
81+
{"mode": "INVALID_MODE"},
82+
{"mode": random_string(1, 20)},
83+
{"mode": None},
84+
{"mode": 123},
85+
{"mode": "NOTIFY_THEN_TERMINATE", "notifyPeriodInSeconds": 30},
86+
{"mode": "NOTIFY_THEN_TERMINATE", "notifyPeriodInSeconds": "invalid"},
87+
random_string(1, 10),
88+
]
89+
)
90+
91+
92+
def fuzz_task_parameter():
93+
return random.choice(
94+
[
95+
{},
96+
{"name": "P", "type": "INT", "range": "1-10"},
97+
{"name": "P", "type": "FLOAT", "range": [1.0, 2.0]},
98+
{"name": "P", "type": "STRING", "values": ["a", "b"]},
99+
{"name": "P", "type": "PATH", "values": ["/tmp"]},
100+
{"name": "P", "type": "INVALID_TYPE", "range": "1-10"},
101+
{"name": "P", "type": random_string(1, 15), "range": "1-10"},
102+
{"name": "P", "type": None},
103+
{"name": "P", "type": 123},
104+
{"type": "INT", "range": "1-10"},
105+
random_string(1, 10),
106+
]
107+
)
108+
109+
74110
def fuzz_step():
75111
base = {"name": "Step1", "script": {"actions": {"onRun": {"command": "echo"}}}}
76112
mutations = [
@@ -81,8 +117,15 @@ def fuzz_step():
81117
{"name": "Step1", "script": {"actions": {}}},
82118
{"name": "Step1", "script": {"actions": {"onRun": {}}}},
83119
{"name": "Step1", "script": {"actions": {"onRun": {"command": ""}}}},
120+
{
121+
"name": "Step1",
122+
"script": {
123+
"actions": {"onRun": {"command": "echo", "cancelation": fuzz_cancelation()}}
124+
},
125+
},
84126
{**base, "parameterSpace": {}},
85127
{**base, "parameterSpace": {"taskParameterDefinitions": []}},
128+
{**base, "parameterSpace": {"taskParameterDefinitions": [fuzz_task_parameter()]}},
86129
{
87130
**base,
88131
"parameterSpace": {
@@ -106,12 +149,16 @@ def fuzz_job_parameter():
106149
{"name": "Param1", "type": "FLOAT"},
107150
{"name": "Param1", "type": "PATH"},
108151
{"name": "Param1", "type": "INVALID"},
152+
{"name": "Param1", "type": random_string(1, 15)},
153+
{"name": "Param1", "type": None},
154+
{"name": "Param1", "type": 123},
109155
{"name": "", "type": "STRING"},
110156
{"name": 123, "type": "STRING"},
111157
{"type": "STRING"},
112158
{"name": "P", "type": "INT", "minValue": 10, "maxValue": 5},
113159
{"name": "P", "type": "INT", "default": "notanint"},
114160
{"name": "P", "type": "STRING", "allowedValues": []},
161+
random_string(1, 10),
115162
]
116163
)
117164

0 commit comments

Comments
 (0)