Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/openjd/model/_internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,21 @@
validate_step_parameter_space_dimensions,
)
from ._variable_reference_validation import prevalidate_model_template_variable_references
from ._validator_functions import (
validate_int_fmtstring_field,
validate_list_field,
validate_float_fmtstring_field,
)

__all__ = (
"instantiate_model",
"prevalidate_model_template_variable_references",
"validate_step_parameter_space_chunk_constraint",
"validate_step_parameter_space_dimensions",
"validate_unique_elements",
"validate_float_fmtstring_field",
"validate_int_fmtstring_field",
"validate_list_field",
"CombinationExpressionAssociationNode",
"CombinationExpressionIdentifierNode",
"CombinationExpressionNode",
Expand Down
119 changes: 119 additions & 0 deletions src/openjd/model/_internal/_validator_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

from typing import cast, Any, Callable, Union, Optional
from decimal import Decimal, InvalidOperation

from pydantic_core import PydanticKnownError, ValidationError, InitErrorDetails

from .._format_strings import FormatString


def validate_int_fmtstring_field(
value: Union[int, float, Decimal, str, FormatString], ge: Optional[int] = None
) -> Union[int, float, Decimal, str, FormatString]:
"""Validates a field that is allowed to be either an integer, a string containing an integer,
or a string containing expressions that resolve to an integer."""
value_type_wrong_msg = "Value must be an integer or a string containing an integer."
# Validate the type
if isinstance(value, str):
if not isinstance(value, FormatString):
value = FormatString(value)
# If the string value has no expressions, we can validate the value now.
if len(value.expressions) == 0:
try:
int_value = int(value)
except ValueError:
raise ValueError(value_type_wrong_msg)
else:
# In this case, we cannot validate now. We need to validate this later when the template
# is instantiated into a job and the expressions have been resolved.
return value
elif isinstance(value, (int, float, Decimal)) and not isinstance(value, bool):
int_value = int(value)
if int_value != value:
raise ValueError(value_type_wrong_msg)
else:
raise ValueError(value_type_wrong_msg)

# Validate the value constraints
if ge is not None and not (int_value >= ge):
raise PydanticKnownError("greater_than_equal", {"ge": ge})

return int_value


def validate_float_fmtstring_field(
value: Union[int, float, Decimal, str, FormatString], ge: Optional[Decimal] = None
) -> Union[int, float, Decimal, str, FormatString]:
"""Validates a field that is allowed to be either an float, a string containing an float,
or a string containing expressions that resolve to a float."""
value_type_wrong_msg = "Value must be a float or a string containing a float."
# Validate the type
if isinstance(value, str):
if not isinstance(value, FormatString):
value = FormatString(value)
# If the string value has no expressions, we can validate the value now.
if len(value.expressions) == 0:
try:
float_value = Decimal(str(value))
except InvalidOperation:
raise ValueError(value_type_wrong_msg)
else:
# In this case, we cannot validate now. We need to validate this later when the template
# is instantiated into a job and the expressions have been resolved.
return value
elif isinstance(value, (int, float, Decimal)) and not isinstance(value, bool):
try:
float_value = Decimal(str(value))
except InvalidOperation:
raise ValueError(value_type_wrong_msg)
else:
raise ValueError(value_type_wrong_msg)

# Validate the value constraints
if ge is not None and not (float_value >= ge):
raise PydanticKnownError("greater_than_equal", {"ge": ge})

return float_value


def validate_list_field(value: list, validator: Callable) -> list:
"""Validates a list of values using the provided validator function."""
errors = list[InitErrorDetails]()
for i, item in enumerate(value):
try:
validator(item)
except PydanticKnownError as exc:
# Copy known errors verbatim with added location data
errors.append(
InitErrorDetails(
type=exc.type,
loc=(i,),
ctx=exc.context or {},
input=item,
)
)
except ValidationError as exc:
# Convert the ErrorDetails to InitErrorDetails by extending the 'loc' and excluding the 'msg'
for error_details in exc.errors():
init_error_details: dict[str, Any] = {}
for err_key, err_value in error_details.items():
if err_key == "loc":
init_error_details["loc"] = (i,) + cast(tuple, err_value)
elif err_key != "msg":
init_error_details[err_key] = err_value
errors.append(cast(InitErrorDetails, init_error_details))
except Exception as exc:
# Copy other errors as value_error
errors.append(
InitErrorDetails(
type="value_error",
loc=(i,),
ctx={"error": exc},
input=item,
)
)
if errors:
# Raise the list of all the individual errors as a new ValidationError
raise ValidationError.from_exception_data("list", line_errors=errors)
return value
174 changes: 15 additions & 159 deletions src/openjd/model/v2023_09/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from graphlib import CycleError, TopologicalSorter
from typing import Any, ClassVar, Literal, Optional, Type, Union, cast, Iterable
from typing_extensions import Annotated, Self
import annotated_types

from pydantic import (
field_validator,
Expand All @@ -17,13 +16,12 @@
Field,
PositiveInt,
PositiveFloat,
Strict,
StrictBool,
StrictInt,
ValidationError,
ValidationInfo,
)
from pydantic_core import InitErrorDetails, PydanticKnownError
from pydantic_core import InitErrorDetails
from pydantic.fields import ModelPrivateAttr

from .._format_strings import FormatString
Expand All @@ -37,6 +35,9 @@
validate_step_parameter_space_dimensions,
validate_step_parameter_space_chunk_constraint,
validate_unique_elements,
validate_int_fmtstring_field,
validate_float_fmtstring_field,
validate_list_field,
)
from .._internal._variable_reference_validation import (
prevalidate_model_template_variable_references,
Expand Down Expand Up @@ -567,44 +568,23 @@ class TaskChunksRangeConstraint(str, Enum):


class TaskChunksDefinition(OpenJDModel_v2023_09):
defaultTaskCount: Union[Annotated[int, annotated_types.Ge(1), Strict()], FormatString]
targetRuntimeSeconds: Optional[
Union[Annotated[int, annotated_types.Ge(0), Strict()], FormatString]
] = None
defaultTaskCount: Union[int, FormatString]
targetRuntimeSeconds: Optional[Union[int, FormatString]] = None
rangeConstraint: TaskChunksRangeConstraint

_job_creation_metadata = JobCreationMetadata(
resolve_fields={"defaultTaskCount", "targetRuntimeSeconds"},
)

@field_validator("defaultTaskCount")
@field_validator("defaultTaskCount", mode="before")
@classmethod
def _validate_default_task_count(cls, value: Any) -> Any:
if isinstance(value, FormatString):
# If the string value has no expressions, can validate the value now.
# Otherwise will validate when
if len(value.expressions) == 0:
try:
int_value = int(value)
except ValueError:
raise ValueError("String literal must contain an integer.")
if int_value < 1:
raise PydanticKnownError("greater_than_equal", {"ge": 1})
return value
return validate_int_fmtstring_field(value, ge=1)

@field_validator("targetRuntimeSeconds")
@field_validator("targetRuntimeSeconds", mode="before")
@classmethod
def _validate_target_runtime_seconds(cls, value: Any) -> Any:
if isinstance(value, FormatString):
# If the string value has no expressions, can validate it now
if len(value.expressions) == 0:
try:
int_value = int(value)
except ValueError:
raise ValueError("String literal must contain an integer.")
if int_value < 0:
raise PydanticKnownError("greater_than_equal", {"ge": 0})
return value
return validate_int_fmtstring_field(value, ge=0)


class IntTaskParameterDefinition(OpenJDModel_v2023_09):
Expand Down Expand Up @@ -651,21 +631,7 @@ def _validate_range_element_type(cls, value: Any) -> Any:
# We do allow coersion from a string since we want to allow "1", and
# "1.2" or "a" will fail the type coersion
if isinstance(value, list):
errors = list[InitErrorDetails]()
for i, item in enumerate(value):
if isinstance(item, bool) or not isinstance(item, (int, str)):
errors.append(
InitErrorDetails(
type="value_error",
loc=(i,),
ctx={
"error": ValueError("Value must be an integer or integer string.")
},
input=item,
)
)
if errors:
raise ValidationError.from_exception_data(cls.__name__, line_errors=errors)
return validate_list_field(value, validate_int_fmtstring_field)
elif isinstance(value, RangeString):
# Nothing to do - it's guaranteed to be a format string at this point
pass
Expand All @@ -675,34 +641,7 @@ def _validate_range_element_type(cls, value: Any) -> Any:
@field_validator("range")
@classmethod
def _validate_range_elements(cls, value: Any) -> Any:
if isinstance(value, list):
errors = list[InitErrorDetails]()
for i, item in enumerate(value):
if isinstance(item, TaskParameterStringValue):
# A TaskParameterStringValue is a FormatString.
# FormatString.expressions is the list of all expressions in the format string
# ( e.g. "{{ Param.Foo }}").
# Validate the string if it does not contain any expressions, in order to catch
# errors earlier when possible.
if len(item.expressions) == 0:
try:
int(item)
except ValueError:
errors.append(
InitErrorDetails(
type="value_error",
loc=(i,),
ctx={
"error": ValueError(
"String literal must contain an integer."
)
},
input=item,
)
)
if errors:
raise ValidationError.from_exception_data(cls.__name__, line_errors=errors)
else:
if isinstance(value, FormatString):
# If there are no format expressions, we can validate the range expression.
# otherwise we defer to the RangeExressionTaskParameter model when
# they've all been evaluated
Expand Down Expand Up @@ -747,51 +686,8 @@ def _validate_range_element_type(cls, value: Any) -> Any:
# pydantic will automatically type coerce values into floats. We explicitly
# want to reject non-integer values, so this *pre* validator validates the
# value *before* pydantic tries to type coerce it.
# We do allow coersion from a string since we want to allow "1", and
# "1.2" or "a" will fail the type coersion
if isinstance(value, list):
errors = list[InitErrorDetails]()
for i, item in enumerate(value):
if isinstance(item, bool) or not isinstance(item, (int, float, str)):
errors.append(
InitErrorDetails(
type="value_error",
loc=("range", i),
ctx={"error": ValueError("Value must be a float or float string.")},
input=item,
)
)
if errors:
raise ValidationError.from_exception_data(cls.__name__, line_errors=errors)
return value

@field_validator("range")
@classmethod
def _validate_range_elements(
cls, value: list[Union[Decimal, TaskParameterStringValue]]
) -> list[Union[Decimal, TaskParameterStringValue]]:
errors = list[InitErrorDetails]()
for i, item in enumerate(value):
if isinstance(item, TaskParameterStringValue):
# A TaskParameterStringValue is a FormatString.
# FormatString.expressions is the list of all expressions in the format string
# ( e.g. "{{ Param.Foo }}").
# Validate the string if it does not contain any expressions, in order to catch
# errors earlier when possible.
if len(item.expressions) == 0:
try:
float(item)
except ValueError:
errors.append(
InitErrorDetails(
type="value_error",
loc=(i,),
ctx={"error": ValueError("String literal must contain a float.")},
input=item,
)
)
if errors:
raise ValidationError.from_exception_data(cls.__name__, line_errors=errors)
return validate_list_field(value, validate_float_fmtstring_field)
return value


Expand Down Expand Up @@ -911,21 +807,7 @@ def _validate_range_element_type(cls, value: Any) -> Any:
# We do allow coersion from a string since we want to allow "1", and
# "1.2" or "a" will fail the type coersion
if isinstance(value, list):
errors = list[InitErrorDetails]()
for i, item in enumerate(value):
if isinstance(item, bool) or not isinstance(item, (int, str)):
errors.append(
InitErrorDetails(
type="value_error",
loc=(i,),
ctx={
"error": ValueError("Value must be an integer or integer string.")
},
input=item,
)
)
if errors:
raise ValidationError.from_exception_data(cls.__name__, line_errors=errors)
return validate_list_field(value, validate_int_fmtstring_field)
elif isinstance(value, RangeString):
# Nothing to do - it's guaranteed to be a format string at this point
pass
Expand All @@ -935,33 +817,7 @@ def _validate_range_element_type(cls, value: Any) -> Any:
@field_validator("range")
@classmethod
def _validate_range_elements(cls, value: Any) -> Any:
if isinstance(value, list):
errors = list[InitErrorDetails]()
for i, item in enumerate(value):
if isinstance(item, TaskParameterStringValue):
# A TaskParameterStringValue is a FormatString.
# FormatString.expressions is the list of all expressions in the format string
# ( e.g. "{{ Param.Foo }}").
# Reject the string if it contains any expressions.
if len(item.expressions) == 0:
try:
int(item)
except ValueError:
errors.append(
InitErrorDetails(
type="value_error",
loc=(i,),
ctx={
"error": ValueError(
"String literal must contain an integer."
)
},
input=item,
)
)
if errors:
raise ValidationError.from_exception_data(cls.__name__, line_errors=errors)
else:
if isinstance(value, FormatString):
# If there are no format expressions, we can validate the range expression.
# otherwise we defer to the RangeExressionTaskParameter model when
# they've all been evaluated
Expand Down
Loading