diff --git a/src/openjd/model/_internal/__init__.py b/src/openjd/model/_internal/__init__.py index b0dee5e..efe0f42 100644 --- a/src/openjd/model/_internal/__init__.py +++ b/src/openjd/model/_internal/__init__.py @@ -13,6 +13,11 @@ validate_step_parameter_space_dimensions, ) from ._variable_reference_validation import prevalidate_model_template_variable_references +from ._validator_functions import ( + validate_int_fmtstring_field, + validate_list_field, + validate_float_fmtstring_field, +) __all__ = ( "instantiate_model", @@ -20,6 +25,9 @@ "validate_step_parameter_space_chunk_constraint", "validate_step_parameter_space_dimensions", "validate_unique_elements", + "validate_float_fmtstring_field", + "validate_int_fmtstring_field", + "validate_list_field", "CombinationExpressionAssociationNode", "CombinationExpressionIdentifierNode", "CombinationExpressionNode", diff --git a/src/openjd/model/_internal/_validator_functions.py b/src/openjd/model/_internal/_validator_functions.py new file mode 100644 index 0000000..b7abd44 --- /dev/null +++ b/src/openjd/model/_internal/_validator_functions.py @@ -0,0 +1,119 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from typing import cast, Any, Callable, Union, Optional +from decimal import Decimal, InvalidOperation + +from pydantic_core import PydanticKnownError, ValidationError, InitErrorDetails + +from .._format_strings import FormatString + + +def validate_int_fmtstring_field( + value: Union[int, float, Decimal, str, FormatString], ge: Optional[int] = None +) -> Union[int, float, Decimal, str, FormatString]: + """Validates a field that is allowed to be either an integer, a string containing an integer, + or a string containing expressions that resolve to an integer.""" + value_type_wrong_msg = "Value must be an integer or a string containing an integer." + # Validate the type + if isinstance(value, str): + if not isinstance(value, FormatString): + value = FormatString(value) + # If the string value has no expressions, we can validate the value now. + if len(value.expressions) == 0: + try: + int_value = int(value) + except ValueError: + raise ValueError(value_type_wrong_msg) + else: + # In this case, we cannot validate now. We need to validate this later when the template + # is instantiated into a job and the expressions have been resolved. + return value + elif isinstance(value, (int, float, Decimal)) and not isinstance(value, bool): + int_value = int(value) + if int_value != value: + raise ValueError(value_type_wrong_msg) + else: + raise ValueError(value_type_wrong_msg) + + # Validate the value constraints + if ge is not None and not (int_value >= ge): + raise PydanticKnownError("greater_than_equal", {"ge": ge}) + + return int_value + + +def validate_float_fmtstring_field( + value: Union[int, float, Decimal, str, FormatString], ge: Optional[Decimal] = None +) -> Union[int, float, Decimal, str, FormatString]: + """Validates a field that is allowed to be either an float, a string containing an float, + or a string containing expressions that resolve to a float.""" + value_type_wrong_msg = "Value must be a float or a string containing a float." + # Validate the type + if isinstance(value, str): + if not isinstance(value, FormatString): + value = FormatString(value) + # If the string value has no expressions, we can validate the value now. + if len(value.expressions) == 0: + try: + float_value = Decimal(str(value)) + except InvalidOperation: + raise ValueError(value_type_wrong_msg) + else: + # In this case, we cannot validate now. We need to validate this later when the template + # is instantiated into a job and the expressions have been resolved. + return value + elif isinstance(value, (int, float, Decimal)) and not isinstance(value, bool): + try: + float_value = Decimal(str(value)) + except InvalidOperation: + raise ValueError(value_type_wrong_msg) + else: + raise ValueError(value_type_wrong_msg) + + # Validate the value constraints + if ge is not None and not (float_value >= ge): + raise PydanticKnownError("greater_than_equal", {"ge": ge}) + + return float_value + + +def validate_list_field(value: list, validator: Callable) -> list: + """Validates a list of values using the provided validator function.""" + errors = list[InitErrorDetails]() + for i, item in enumerate(value): + try: + validator(item) + except PydanticKnownError as exc: + # Copy known errors verbatim with added location data + errors.append( + InitErrorDetails( + type=exc.type, + loc=(i,), + ctx=exc.context or {}, + input=item, + ) + ) + except ValidationError as exc: + # Convert the ErrorDetails to InitErrorDetails by extending the 'loc' and excluding the 'msg' + for error_details in exc.errors(): + init_error_details: dict[str, Any] = {} + for err_key, err_value in error_details.items(): + if err_key == "loc": + init_error_details["loc"] = (i,) + cast(tuple, err_value) + elif err_key != "msg": + init_error_details[err_key] = err_value + errors.append(cast(InitErrorDetails, init_error_details)) + except Exception as exc: + # Copy other errors as value_error + errors.append( + InitErrorDetails( + type="value_error", + loc=(i,), + ctx={"error": exc}, + input=item, + ) + ) + if errors: + # Raise the list of all the individual errors as a new ValidationError + raise ValidationError.from_exception_data("list", line_errors=errors) + return value diff --git a/src/openjd/model/v2023_09/_model.py b/src/openjd/model/v2023_09/_model.py index b8f2204..8de45dd 100644 --- a/src/openjd/model/v2023_09/_model.py +++ b/src/openjd/model/v2023_09/_model.py @@ -8,7 +8,6 @@ from graphlib import CycleError, TopologicalSorter from typing import Any, ClassVar, Literal, Optional, Type, Union, cast, Iterable from typing_extensions import Annotated, Self -import annotated_types from pydantic import ( field_validator, @@ -17,13 +16,12 @@ Field, PositiveInt, PositiveFloat, - Strict, StrictBool, StrictInt, ValidationError, ValidationInfo, ) -from pydantic_core import InitErrorDetails, PydanticKnownError +from pydantic_core import InitErrorDetails from pydantic.fields import ModelPrivateAttr from .._format_strings import FormatString @@ -37,6 +35,9 @@ validate_step_parameter_space_dimensions, validate_step_parameter_space_chunk_constraint, validate_unique_elements, + validate_int_fmtstring_field, + validate_float_fmtstring_field, + validate_list_field, ) from .._internal._variable_reference_validation import ( prevalidate_model_template_variable_references, @@ -567,44 +568,23 @@ class TaskChunksRangeConstraint(str, Enum): class TaskChunksDefinition(OpenJDModel_v2023_09): - defaultTaskCount: Union[Annotated[int, annotated_types.Ge(1), Strict()], FormatString] - targetRuntimeSeconds: Optional[ - Union[Annotated[int, annotated_types.Ge(0), Strict()], FormatString] - ] = None + defaultTaskCount: Union[int, FormatString] + targetRuntimeSeconds: Optional[Union[int, FormatString]] = None rangeConstraint: TaskChunksRangeConstraint _job_creation_metadata = JobCreationMetadata( resolve_fields={"defaultTaskCount", "targetRuntimeSeconds"}, ) - @field_validator("defaultTaskCount") + @field_validator("defaultTaskCount", mode="before") @classmethod def _validate_default_task_count(cls, value: Any) -> Any: - if isinstance(value, FormatString): - # If the string value has no expressions, can validate the value now. - # Otherwise will validate when - if len(value.expressions) == 0: - try: - int_value = int(value) - except ValueError: - raise ValueError("String literal must contain an integer.") - if int_value < 1: - raise PydanticKnownError("greater_than_equal", {"ge": 1}) - return value + return validate_int_fmtstring_field(value, ge=1) - @field_validator("targetRuntimeSeconds") + @field_validator("targetRuntimeSeconds", mode="before") @classmethod def _validate_target_runtime_seconds(cls, value: Any) -> Any: - if isinstance(value, FormatString): - # If the string value has no expressions, can validate it now - if len(value.expressions) == 0: - try: - int_value = int(value) - except ValueError: - raise ValueError("String literal must contain an integer.") - if int_value < 0: - raise PydanticKnownError("greater_than_equal", {"ge": 0}) - return value + return validate_int_fmtstring_field(value, ge=0) class IntTaskParameterDefinition(OpenJDModel_v2023_09): @@ -651,21 +631,7 @@ def _validate_range_element_type(cls, value: Any) -> Any: # We do allow coersion from a string since we want to allow "1", and # "1.2" or "a" will fail the type coersion if isinstance(value, list): - errors = list[InitErrorDetails]() - for i, item in enumerate(value): - if isinstance(item, bool) or not isinstance(item, (int, str)): - errors.append( - InitErrorDetails( - type="value_error", - loc=(i,), - ctx={ - "error": ValueError("Value must be an integer or integer string.") - }, - input=item, - ) - ) - if errors: - raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) + return validate_list_field(value, validate_int_fmtstring_field) elif isinstance(value, RangeString): # Nothing to do - it's guaranteed to be a format string at this point pass @@ -675,34 +641,7 @@ def _validate_range_element_type(cls, value: Any) -> Any: @field_validator("range") @classmethod def _validate_range_elements(cls, value: Any) -> Any: - if isinstance(value, list): - errors = list[InitErrorDetails]() - for i, item in enumerate(value): - if isinstance(item, TaskParameterStringValue): - # A TaskParameterStringValue is a FormatString. - # FormatString.expressions is the list of all expressions in the format string - # ( e.g. "{{ Param.Foo }}"). - # Validate the string if it does not contain any expressions, in order to catch - # errors earlier when possible. - if len(item.expressions) == 0: - try: - int(item) - except ValueError: - errors.append( - InitErrorDetails( - type="value_error", - loc=(i,), - ctx={ - "error": ValueError( - "String literal must contain an integer." - ) - }, - input=item, - ) - ) - if errors: - raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) - else: + if isinstance(value, FormatString): # If there are no format expressions, we can validate the range expression. # otherwise we defer to the RangeExressionTaskParameter model when # they've all been evaluated @@ -747,51 +686,8 @@ def _validate_range_element_type(cls, value: Any) -> Any: # pydantic will automatically type coerce values into floats. We explicitly # want to reject non-integer values, so this *pre* validator validates the # value *before* pydantic tries to type coerce it. - # We do allow coersion from a string since we want to allow "1", and - # "1.2" or "a" will fail the type coersion if isinstance(value, list): - errors = list[InitErrorDetails]() - for i, item in enumerate(value): - if isinstance(item, bool) or not isinstance(item, (int, float, str)): - errors.append( - InitErrorDetails( - type="value_error", - loc=("range", i), - ctx={"error": ValueError("Value must be a float or float string.")}, - input=item, - ) - ) - if errors: - raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) - return value - - @field_validator("range") - @classmethod - def _validate_range_elements( - cls, value: list[Union[Decimal, TaskParameterStringValue]] - ) -> list[Union[Decimal, TaskParameterStringValue]]: - errors = list[InitErrorDetails]() - for i, item in enumerate(value): - if isinstance(item, TaskParameterStringValue): - # A TaskParameterStringValue is a FormatString. - # FormatString.expressions is the list of all expressions in the format string - # ( e.g. "{{ Param.Foo }}"). - # Validate the string if it does not contain any expressions, in order to catch - # errors earlier when possible. - if len(item.expressions) == 0: - try: - float(item) - except ValueError: - errors.append( - InitErrorDetails( - type="value_error", - loc=(i,), - ctx={"error": ValueError("String literal must contain a float.")}, - input=item, - ) - ) - if errors: - raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) + return validate_list_field(value, validate_float_fmtstring_field) return value @@ -911,21 +807,7 @@ def _validate_range_element_type(cls, value: Any) -> Any: # We do allow coersion from a string since we want to allow "1", and # "1.2" or "a" will fail the type coersion if isinstance(value, list): - errors = list[InitErrorDetails]() - for i, item in enumerate(value): - if isinstance(item, bool) or not isinstance(item, (int, str)): - errors.append( - InitErrorDetails( - type="value_error", - loc=(i,), - ctx={ - "error": ValueError("Value must be an integer or integer string.") - }, - input=item, - ) - ) - if errors: - raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) + return validate_list_field(value, validate_int_fmtstring_field) elif isinstance(value, RangeString): # Nothing to do - it's guaranteed to be a format string at this point pass @@ -935,33 +817,7 @@ def _validate_range_element_type(cls, value: Any) -> Any: @field_validator("range") @classmethod def _validate_range_elements(cls, value: Any) -> Any: - if isinstance(value, list): - errors = list[InitErrorDetails]() - for i, item in enumerate(value): - if isinstance(item, TaskParameterStringValue): - # A TaskParameterStringValue is a FormatString. - # FormatString.expressions is the list of all expressions in the format string - # ( e.g. "{{ Param.Foo }}"). - # Reject the string if it contains any expressions. - if len(item.expressions) == 0: - try: - int(item) - except ValueError: - errors.append( - InitErrorDetails( - type="value_error", - loc=(i,), - ctx={ - "error": ValueError( - "String literal must contain an integer." - ) - }, - input=item, - ) - ) - if errors: - raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) - else: + if isinstance(value, FormatString): # If there are no format expressions, we can validate the range expression. # otherwise we defer to the RangeExressionTaskParameter model when # they've all been evaluated diff --git a/test/openjd/model/v2023_09/test_chunk_int_task_parameter_type.py b/test/openjd/model/v2023_09/test_chunk_int_task_parameter_type.py index 23a7669..3b5ab67 100644 --- a/test/openjd/model/v2023_09/test_chunk_int_task_parameter_type.py +++ b/test/openjd/model/v2023_09/test_chunk_int_task_parameter_type.py @@ -13,7 +13,7 @@ ) -@pytest.mark.parametrize( +PARAMETRIZE_CASES: tuple = ( "data", ( pytest.param( @@ -65,10 +65,10 @@ { "name": "foo", "type": "CHUNK[INT]", - "range": [1, "2", "{{Param.Value}}"], + "range": [-1, 0, 1, "2", "{{Param.Value}}"], "chunks": {"defaultTaskCount": 1, "rangeConstraint": "CONTIGUOUS"}, }, - id="mix of item types", + id="mix of item types and values", ), pytest.param( { @@ -159,6 +159,9 @@ ), ), ) + + +@pytest.mark.parametrize(*PARAMETRIZE_CASES) def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: # It parses successfully when the TASK_CHUNKING extension is requested _parse_model( @@ -176,7 +179,7 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: assert excinfo.value.error_count() == 1 -@pytest.mark.parametrize( +PARAMETRIZE_CASES = ( "data,error_message,error_count", ( pytest.param({}, "Field required", 4, id="empty object"), @@ -304,9 +307,39 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: "rangeConstraint": "NONCONTIGUOUS", }, }, - "Value must be an integer or integer string.", + "Value must be an integer or a string containing an integer.", + 1, + id="disallow floats in range", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [1, 2], + "chunks": { + "defaultTaskCount": 10.1, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Value must be an integer or a string containing an integer.", 1, - id="disallow floats", + id="disallow floats in defaultTaskCount", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [1, 2], + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": 1000.01, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Value must be an integer or a string containing an integer.", + 1, + id="disallow floats in targetRuntimeSeconds", ), pytest.param( { @@ -319,9 +352,39 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: "rangeConstraint": "NONCONTIGUOUS", }, }, - "Value must be an integer or integer string.", + "Value must be an integer or a string containing an integer.", 1, - id="disallow bool", + id="disallow bool in range", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [1], + "chunks": { + "defaultTaskCount": True, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Value must be an integer or a string containing an integer.", + 1, + id="disallow bool in defaultTaskCount", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [1], + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": True, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Value must be an integer or a string containing an integer.", + 1, + id="disallow bool in targetRuntimeSeconds", ), pytest.param( { @@ -334,9 +397,39 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: "rangeConstraint": "NONCONTIGUOUS", }, }, - "String literal must contain an integer.", + "Value must be an integer or a string containing an integer.", + 1, + id="disallow float strings in range", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": ["1"], + "chunks": { + "defaultTaskCount": "1.1", + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Value must be an integer or a string containing an integer.", + 1, + id="disallow float strings in defaultTaskCount", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": ["1"], + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": "1000.1", + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Value must be an integer or a string containing an integer.", 1, - id="disallow float strings", + id="disallow float strings in targetRuntimeSeconds", ), pytest.param( { @@ -350,8 +443,8 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: }, }, "Failed to parse interpolation expression at [0, 20]. Reason: Braces mismatch.", - 3, - id="malformed format string", + 1, + id="malformed format string in range", ), pytest.param( { @@ -364,9 +457,9 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: "rangeConstraint": "NONCONTIGUOUS", }, }, - "String literal must contain an integer.", + "Value must be an integer or a string containing an integer.", 1, - id="literal string not an int", + id="literal string not an int in range", ), pytest.param( { @@ -408,7 +501,7 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: }, }, "Input should be greater than or equal to 1", - 2, + 1, id="defaultTaskCount 0 (too small)", ), pytest.param( @@ -423,7 +516,7 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: }, }, "Input should be greater than or equal to 0", - 2, + 1, id="targetRuntimeSeconds -1 (too small)", ), pytest.param( @@ -467,7 +560,7 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: "rangeConstraint": "CONTIGUOUS", }, }, - "String literal must contain an integer.", + "Value must be an integer or a string containing an integer.", 1, id="defaultTaskCount is str with non-integer value", ), @@ -483,7 +576,7 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: }, }, "Failed to parse interpolation expression at [0, 18]. Reason: Braces mismatch.", - 2, + 1, id="defaultTaskCount is str with incorrect expression", ), pytest.param( @@ -497,7 +590,7 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: "rangeConstraint": "CONTIGUOUS", }, }, - "String literal must contain an integer.", + "Value must be an integer or a string containing an integer.", 1, id="targetRuntimeSeconds is str with non-integer value", ), @@ -528,11 +621,14 @@ def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: }, }, "Failed to parse interpolation expression at [0, 27]. Reason: Braces mismatch.", - 2, + 1, id="targetRuntimeSeconds is str with incorrect expression", ), ), ) + + +@pytest.mark.parametrize(*PARAMETRIZE_CASES) def test_chunk_int_task_parameter_parse_fails( data: dict[str, Any], error_message: str, error_count: int ) -> None: