From 41c4bfb8225d0b09d0f62addbba5c8725cbfcb58 Mon Sep 17 00:00:00 2001 From: Mark Wiebe <399551+mwiebe@users.noreply.github.com> Date: Thu, 13 Feb 2025 15:45:45 -0800 Subject: [PATCH] refactor: Refactor to improve comments and reduce code duplication * Comments are clearer about the difference between model parsing, when values may not be available, and job instantiation, when all values must be available for format strings. * Created a new file to hold general validator functions for some commonly used patterns. * Adjusted the model to reduce the number of confusing, spurious errors, such as seemingly contradictory "must be int" and "must be string" errors reported for the exact same field. Signed-off-by: Mark Wiebe <399551+mwiebe@users.noreply.github.com> --- src/openjd/model/_internal/__init__.py | 8 + .../model/_internal/_validator_functions.py | 119 ++++++++++++ src/openjd/model/v2023_09/_model.py | 174 ++---------------- .../test_chunk_int_task_parameter_type.py | 136 ++++++++++++-- 4 files changed, 258 insertions(+), 179 deletions(-) create mode 100644 src/openjd/model/_internal/_validator_functions.py diff --git a/src/openjd/model/_internal/__init__.py b/src/openjd/model/_internal/__init__.py index b0dee5e..efe0f42 100644 --- a/src/openjd/model/_internal/__init__.py +++ b/src/openjd/model/_internal/__init__.py @@ -13,6 +13,11 @@ validate_step_parameter_space_dimensions, ) from ._variable_reference_validation import prevalidate_model_template_variable_references +from ._validator_functions import ( + validate_int_fmtstring_field, + validate_list_field, + validate_float_fmtstring_field, +) __all__ = ( "instantiate_model", @@ -20,6 +25,9 @@ "validate_step_parameter_space_chunk_constraint", "validate_step_parameter_space_dimensions", "validate_unique_elements", + "validate_float_fmtstring_field", + "validate_int_fmtstring_field", + "validate_list_field", "CombinationExpressionAssociationNode", "CombinationExpressionIdentifierNode", 
"CombinationExpressionNode", diff --git a/src/openjd/model/_internal/_validator_functions.py b/src/openjd/model/_internal/_validator_functions.py new file mode 100644 index 0000000..b7abd44 --- /dev/null +++ b/src/openjd/model/_internal/_validator_functions.py @@ -0,0 +1,119 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from typing import cast, Any, Callable, Union, Optional +from decimal import Decimal, InvalidOperation + +from pydantic_core import PydanticKnownError, ValidationError, InitErrorDetails + +from .._format_strings import FormatString + + +def validate_int_fmtstring_field( + value: Union[int, float, Decimal, str, FormatString], ge: Optional[int] = None +) -> Union[int, float, Decimal, str, FormatString]: + """Validates a field that is allowed to be either an integer, a string containing an integer, + or a string containing expressions that resolve to an integer.""" + value_type_wrong_msg = "Value must be an integer or a string containing an integer." + # Validate the type + if isinstance(value, str): + if not isinstance(value, FormatString): + value = FormatString(value) + # If the string value has no expressions, we can validate the value now. + if len(value.expressions) == 0: + try: + int_value = int(value) + except ValueError: + raise ValueError(value_type_wrong_msg) + else: + # In this case, we cannot validate now. We need to validate this later when the template + # is instantiated into a job and the expressions have been resolved. 
+ return value + elif isinstance(value, (int, float, Decimal)) and not isinstance(value, bool): + int_value = int(value) + if int_value != value: + raise ValueError(value_type_wrong_msg) + else: + raise ValueError(value_type_wrong_msg) + + # Validate the value constraints + if ge is not None and not (int_value >= ge): + raise PydanticKnownError("greater_than_equal", {"ge": ge}) + + return int_value + + +def validate_float_fmtstring_field( + value: Union[int, float, Decimal, str, FormatString], ge: Optional[Decimal] = None +) -> Union[int, float, Decimal, str, FormatString]: + """Validates a field that is allowed to be either a float, a string containing a float, + or a string containing expressions that resolve to a float.""" + value_type_wrong_msg = "Value must be a float or a string containing a float." + # Validate the type + if isinstance(value, str): + if not isinstance(value, FormatString): + value = FormatString(value) + # If the string value has no expressions, we can validate the value now. + if len(value.expressions) == 0: + try: + float_value = Decimal(str(value)) + except InvalidOperation: + raise ValueError(value_type_wrong_msg) + else: + # In this case, we cannot validate now. We need to validate this later when the template + # is instantiated into a job and the expressions have been resolved. 
+ return value + elif isinstance(value, (int, float, Decimal)) and not isinstance(value, bool): + try: + float_value = Decimal(str(value)) + except InvalidOperation: + raise ValueError(value_type_wrong_msg) + else: + raise ValueError(value_type_wrong_msg) + + # Validate the value constraints + if ge is not None and not (float_value >= ge): + raise PydanticKnownError("greater_than_equal", {"ge": ge}) + + return float_value + + +def validate_list_field(value: list, validator: Callable) -> list: + """Validates a list of values using the provided validator function.""" + errors = list[InitErrorDetails]() + for i, item in enumerate(value): + try: + validator(item) + except PydanticKnownError as exc: + # Copy known errors verbatim with added location data + errors.append( + InitErrorDetails( + type=exc.type, + loc=(i,), + ctx=exc.context or {}, + input=item, + ) + ) + except ValidationError as exc: + # Convert the ErrorDetails to InitErrorDetails by extending the 'loc' and excluding the 'msg' + for error_details in exc.errors(): + init_error_details: dict[str, Any] = {} + for err_key, err_value in error_details.items(): + if err_key == "loc": + init_error_details["loc"] = (i,) + cast(tuple, err_value) + elif err_key != "msg": + init_error_details[err_key] = err_value + errors.append(cast(InitErrorDetails, init_error_details)) + except Exception as exc: + # Copy other errors as value_error + errors.append( + InitErrorDetails( + type="value_error", + loc=(i,), + ctx={"error": exc}, + input=item, + ) + ) + if errors: + # Raise the list of all the individual errors as a new ValidationError + raise ValidationError.from_exception_data("list", line_errors=errors) + return value diff --git a/src/openjd/model/v2023_09/_model.py b/src/openjd/model/v2023_09/_model.py index b8f2204..8de45dd 100644 --- a/src/openjd/model/v2023_09/_model.py +++ b/src/openjd/model/v2023_09/_model.py @@ -8,7 +8,6 @@ from graphlib import CycleError, TopologicalSorter from typing import Any, ClassVar, 
Literal, Optional, Type, Union, cast, Iterable from typing_extensions import Annotated, Self -import annotated_types from pydantic import ( field_validator, @@ -17,13 +16,12 @@ Field, PositiveInt, PositiveFloat, - Strict, StrictBool, StrictInt, ValidationError, ValidationInfo, ) -from pydantic_core import InitErrorDetails, PydanticKnownError +from pydantic_core import InitErrorDetails from pydantic.fields import ModelPrivateAttr from .._format_strings import FormatString @@ -37,6 +35,9 @@ validate_step_parameter_space_dimensions, validate_step_parameter_space_chunk_constraint, validate_unique_elements, + validate_int_fmtstring_field, + validate_float_fmtstring_field, + validate_list_field, ) from .._internal._variable_reference_validation import ( prevalidate_model_template_variable_references, @@ -567,44 +568,23 @@ class TaskChunksRangeConstraint(str, Enum): class TaskChunksDefinition(OpenJDModel_v2023_09): - defaultTaskCount: Union[Annotated[int, annotated_types.Ge(1), Strict()], FormatString] - targetRuntimeSeconds: Optional[ - Union[Annotated[int, annotated_types.Ge(0), Strict()], FormatString] - ] = None + defaultTaskCount: Union[int, FormatString] + targetRuntimeSeconds: Optional[Union[int, FormatString]] = None rangeConstraint: TaskChunksRangeConstraint _job_creation_metadata = JobCreationMetadata( resolve_fields={"defaultTaskCount", "targetRuntimeSeconds"}, ) - @field_validator("defaultTaskCount") + @field_validator("defaultTaskCount", mode="before") @classmethod def _validate_default_task_count(cls, value: Any) -> Any: - if isinstance(value, FormatString): - # If the string value has no expressions, can validate the value now. 
- # Otherwise will validate when - if len(value.expressions) == 0: - try: - int_value = int(value) - except ValueError: - raise ValueError("String literal must contain an integer.") - if int_value < 1: - raise PydanticKnownError("greater_than_equal", {"ge": 1}) - return value + return validate_int_fmtstring_field(value, ge=1) - @field_validator("targetRuntimeSeconds") + @field_validator("targetRuntimeSeconds", mode="before") @classmethod def _validate_target_runtime_seconds(cls, value: Any) -> Any: - if isinstance(value, FormatString): - # If the string value has no expressions, can validate it now - if len(value.expressions) == 0: - try: - int_value = int(value) - except ValueError: - raise ValueError("String literal must contain an integer.") - if int_value < 0: - raise PydanticKnownError("greater_than_equal", {"ge": 0}) - return value + return validate_int_fmtstring_field(value, ge=0) class IntTaskParameterDefinition(OpenJDModel_v2023_09): @@ -651,21 +631,7 @@ def _validate_range_element_type(cls, value: Any) -> Any: # We do allow coersion from a string since we want to allow "1", and # "1.2" or "a" will fail the type coersion if isinstance(value, list): - errors = list[InitErrorDetails]() - for i, item in enumerate(value): - if isinstance(item, bool) or not isinstance(item, (int, str)): - errors.append( - InitErrorDetails( - type="value_error", - loc=(i,), - ctx={ - "error": ValueError("Value must be an integer or integer string.") - }, - input=item, - ) - ) - if errors: - raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) + return validate_list_field(value, validate_int_fmtstring_field) elif isinstance(value, RangeString): # Nothing to do - it's guaranteed to be a format string at this point pass @@ -675,34 +641,7 @@ def _validate_range_element_type(cls, value: Any) -> Any: @field_validator("range") @classmethod def _validate_range_elements(cls, value: Any) -> Any: - if isinstance(value, list): - errors = list[InitErrorDetails]() - for 
i, item in enumerate(value): - if isinstance(item, TaskParameterStringValue): - # A TaskParameterStringValue is a FormatString. - # FormatString.expressions is the list of all expressions in the format string - # ( e.g. "{{ Param.Foo }}"). - # Validate the string if it does not contain any expressions, in order to catch - # errors earlier when possible. - if len(item.expressions) == 0: - try: - int(item) - except ValueError: - errors.append( - InitErrorDetails( - type="value_error", - loc=(i,), - ctx={ - "error": ValueError( - "String literal must contain an integer." - ) - }, - input=item, - ) - ) - if errors: - raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) - else: + if isinstance(value, FormatString): # If there are no format expressions, we can validate the range expression. # otherwise we defer to the RangeExressionTaskParameter model when # they've all been evaluated @@ -747,51 +686,8 @@ def _validate_range_element_type(cls, value: Any) -> Any: # pydantic will automatically type coerce values into floats. We explicitly # want to reject non-integer values, so this *pre* validator validates the # value *before* pydantic tries to type coerce it. 
- # We do allow coersion from a string since we want to allow "1", and - # "1.2" or "a" will fail the type coersion if isinstance(value, list): - errors = list[InitErrorDetails]() - for i, item in enumerate(value): - if isinstance(item, bool) or not isinstance(item, (int, float, str)): - errors.append( - InitErrorDetails( - type="value_error", - loc=("range", i), - ctx={"error": ValueError("Value must be a float or float string.")}, - input=item, - ) - ) - if errors: - raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) - return value - - @field_validator("range") - @classmethod - def _validate_range_elements( - cls, value: list[Union[Decimal, TaskParameterStringValue]] - ) -> list[Union[Decimal, TaskParameterStringValue]]: - errors = list[InitErrorDetails]() - for i, item in enumerate(value): - if isinstance(item, TaskParameterStringValue): - # A TaskParameterStringValue is a FormatString. - # FormatString.expressions is the list of all expressions in the format string - # ( e.g. "{{ Param.Foo }}"). - # Validate the string if it does not contain any expressions, in order to catch - # errors earlier when possible. 
- if len(item.expressions) == 0: - try: - float(item) - except ValueError: - errors.append( - InitErrorDetails( - type="value_error", - loc=(i,), - ctx={"error": ValueError("String literal must contain a float.")}, - input=item, - ) - ) - if errors: - raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) + return validate_list_field(value, validate_float_fmtstring_field) return value @@ -911,21 +807,7 @@ def _validate_range_element_type(cls, value: Any) -> Any: # We do allow coersion from a string since we want to allow "1", and # "1.2" or "a" will fail the type coersion if isinstance(value, list): - errors = list[InitErrorDetails]() - for i, item in enumerate(value): - if isinstance(item, bool) or not isinstance(item, (int, str)): - errors.append( - InitErrorDetails( - type="value_error", - loc=(i,), - ctx={ - "error": ValueError("Value must be an integer or integer string.") - }, - input=item, - ) - ) - if errors: - raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) + return validate_list_field(value, validate_int_fmtstring_field) elif isinstance(value, RangeString): # Nothing to do - it's guaranteed to be a format string at this point pass @@ -935,33 +817,7 @@ def _validate_range_element_type(cls, value: Any) -> Any: @field_validator("range") @classmethod def _validate_range_elements(cls, value: Any) -> Any: - if isinstance(value, list): - errors = list[InitErrorDetails]() - for i, item in enumerate(value): - if isinstance(item, TaskParameterStringValue): - # A TaskParameterStringValue is a FormatString. - # FormatString.expressions is the list of all expressions in the format string - # ( e.g. "{{ Param.Foo }}"). - # Reject the string if it contains any expressions. - if len(item.expressions) == 0: - try: - int(item) - except ValueError: - errors.append( - InitErrorDetails( - type="value_error", - loc=(i,), - ctx={ - "error": ValueError( - "String literal must contain an integer." 
- ) - }, - input=item, - ) - ) - if errors: - raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) - else: + if isinstance(value, FormatString): # If there are no format expressions, we can validate the range expression. # otherwise we defer to the RangeExressionTaskParameter model when # they've all been evaluated diff --git a/test/openjd/model/v2023_09/test_chunk_int_task_parameter_type.py b/test/openjd/model/v2023_09/test_chunk_int_task_parameter_type.py index 23a7669..3b5ab67 100644 --- a/test/openjd/model/v2023_09/test_chunk_int_task_parameter_type.py +++ b/test/openjd/model/v2023_09/test_chunk_int_task_parameter_type.py @@ -13,7 +13,7 @@ ) -@pytest.mark.parametrize( +PARAMETRIZE_CASES: tuple = ( "data", ( pytest.param( @@ -65,10 +65,10 @@ { "name": "foo", "type": "CHUNK[INT]", - "range": [1, "2", "{{Param.Value}}"], + "range": [-1, 0, 1, "2", "{{Param.Value}}"], "chunks": {"defaultTaskCount": 1, "rangeConstraint": "CONTIGUOUS"}, }, - id="mix of item types", + id="mix of item types and values", ), pytest.param( { @@ -159,6 +159,9 @@ ), ), ) + + +@pytest.mark.parametrize(*PARAMETRIZE_CASES) def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: # It parses successfully when the TASK_CHUNKING extension is requested _parse_model( @@ -176,7 +179,7 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: assert excinfo.value.error_count() == 1 -@pytest.mark.parametrize( +PARAMETRIZE_CASES = ( "data,error_message,error_count", ( pytest.param({}, "Field required", 4, id="empty object"), @@ -304,9 +307,39 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: "rangeConstraint": "NONCONTIGUOUS", }, }, - "Value must be an integer or integer string.", + "Value must be an integer or a string containing an integer.", + 1, + id="disallow floats in range", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [1, 2], + "chunks": { + "defaultTaskCount": 
10.1, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Value must be an integer or a string containing an integer.", 1, - id="disallow floats", + id="disallow floats in defaultTaskCount", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [1, 2], + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": 1000.01, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Value must be an integer or a string containing an integer.", + 1, + id="disallow floats in targetRuntimeSeconds", ), pytest.param( { @@ -319,9 +352,39 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: "rangeConstraint": "NONCONTIGUOUS", }, }, - "Value must be an integer or integer string.", + "Value must be an integer or a string containing an integer.", 1, - id="disallow bool", + id="disallow bool in range", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [1], + "chunks": { + "defaultTaskCount": True, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Value must be an integer or a string containing an integer.", + 1, + id="disallow bool in defaultTaskCount", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [1], + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": True, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Value must be an integer or a string containing an integer.", + 1, + id="disallow bool in targetRuntimeSeconds", ), pytest.param( { @@ -334,9 +397,39 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: "rangeConstraint": "NONCONTIGUOUS", }, }, - "String literal must contain an integer.", + "Value must be an integer or a string containing an integer.", + 1, + id="disallow float strings in range", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": ["1"], + "chunks": { + "defaultTaskCount": "1.1", + "targetRuntimeSeconds": 1000, + 
"rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Value must be an integer or a string containing an integer.", + 1, + id="disallow float strings in defaultTaskCount", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": ["1"], + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": "1000.1", + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Value must be an integer or a string containing an integer.", 1, - id="disallow float strings", + id="disallow float strings in targetRuntimeSeconds", ), pytest.param( { @@ -350,8 +443,8 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: }, }, "Failed to parse interpolation expression at [0, 20]. Reason: Braces mismatch.", - 3, - id="malformed format string", + 1, + id="malformed format string in range", ), pytest.param( { @@ -364,9 +457,9 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: "rangeConstraint": "NONCONTIGUOUS", }, }, - "String literal must contain an integer.", + "Value must be an integer or a string containing an integer.", 1, - id="literal string not an int", + id="literal string not an int in range", ), pytest.param( { @@ -408,7 +501,7 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: }, }, "Input should be greater than or equal to 1", - 2, + 1, id="defaultTaskCount 0 (too small)", ), pytest.param( @@ -423,7 +516,7 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: }, }, "Input should be greater than or equal to 0", - 2, + 1, id="targetRuntimeSeconds -1 (too small)", ), pytest.param( @@ -467,7 +560,7 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: "rangeConstraint": "CONTIGUOUS", }, }, - "String literal must contain an integer.", + "Value must be an integer or a string containing an integer.", 1, id="defaultTaskCount is str with non-integer value", ), @@ -483,7 +576,7 @@ def 
test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: }, }, "Failed to parse interpolation expression at [0, 18]. Reason: Braces mismatch.", - 2, + 1, id="defaultTaskCount is str with incorrect expression", ), pytest.param( @@ -497,7 +590,7 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: "rangeConstraint": "CONTIGUOUS", }, }, - "String literal must contain an integer.", + "Value must be an integer or a string containing an integer.", 1, id="targetRuntimeSeconds is str with non-integer value", ), @@ -528,11 +621,14 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: }, }, "Failed to parse interpolation expression at [0, 27]. Reason: Braces mismatch.", - 2, + 1, id="targetRuntimeSeconds is str with incorrect expression", ), ), ) + + +@pytest.mark.parametrize(*PARAMETRIZE_CASES) def test_chunk_int_task_parameter_parse_fails( data: dict[str, Any], error_message: str, error_count: int ) -> None: