diff --git a/pyproject.toml b/pyproject.toml index eef0cfc..f8c1bbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ classifiers = [ ] dependencies = [ "pyyaml == 6.0.*", - "pydantic >= 1.10.17", + "pydantic >= 2.10, < 3", ] [project.urls] @@ -99,7 +99,7 @@ mypy_path = "src" # See: https://docs.pydantic.dev/mypy_plugin/ # - Helps mypy understand pydantic typing. -plugins = "pydantic.v1.mypy" +plugins = "pydantic.mypy" [tool.ruff] line-length = 100 @@ -111,15 +111,6 @@ ignore = [ "E731", ] -[tool.ruff.lint.pep8-naming] -classmethod-decorators = [ - "classmethod", - # pydantic decorators are classmethod decorators - # suppress N805 errors on classes decorated with them - "pydantic.validator", - "pydantic.root_validator", -] - [tool.ruff.lint.isort] known-first-party = [ "openjd", diff --git a/src/openjd/model/_convert_pydantic_error.py b/src/openjd/model/_convert_pydantic_error.py index ac97e92..b424bc5 100644 --- a/src/openjd/model/_convert_pydantic_error.py +++ b/src/openjd/model/_convert_pydantic_error.py @@ -1,37 +1,33 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -from typing import TypedDict, Union, Type -from pydantic.v1 import BaseModel +from typing import Type, Union +from pydantic import BaseModel +from pydantic_core import ErrorDetails from inspect import getmodule -# Calling pydantic's ValidationError.errors() returns a list[ErrorDict], but -# pydantic doesn't export the ErrorDict type publicly. So, we create it here for -# type checking. -# Note that we ignore the 'ctx' key since we don't use it. -# See: https://github.com/pydantic/pydantic/blob/d9c2af3a701ca982945a590de1a1da98b3fb4003/pydantic/error_wrappers.py#L50 -Loc = tuple[Union[int, str], ...] 
- -class ErrorDict(TypedDict): - loc: Loc - msg: str - type: str - - -def pydantic_validationerrors_to_str(root_model: Type[BaseModel], errors: list[ErrorDict]) -> str: +def pydantic_validationerrors_to_str( + root_model: Type[BaseModel], errors: list[ErrorDetails] +) -> str: """This is our own custom stringification of the Pydantic ValidationError to use in place of str(). Pydantic's default stringification too verbose for our purpose, and contains information that we don't want. """ results = list[str]() - for error in errors: - results.append(_error_dict_to_str(root_model, error)) + for error_details in errors: + results.append(_error_dict_to_str(root_model, error_details)) return f"{len(errors)} validation errors for {root_model.__name__}\n" + "\n".join(results) -def _error_dict_to_str(root_model: Type[BaseModel], error: ErrorDict) -> str: - loc = error["loc"] - msg = error["msg"] +def _error_dict_to_str(root_model: Type[BaseModel], error_details: ErrorDetails) -> str: + error_type = error_details["type"] + loc = error_details["loc"] + # Skip the "Value error," prefix by getting the exception message directly. + # This preserves the message formatting created when Pydantic V1 was in use. + if error_type == "value_error": + msg = str(error_details["ctx"]["error"]) + else: + msg = error_details["msg"] # When a model's root_validator raises an error other than a ValidationError # (i.e. 
raises something like a ValueError or a TypeError) then pydantic @@ -54,7 +50,7 @@ def _error_dict_to_str(root_model: Type[BaseModel], error: ErrorDict) -> str: return f"{_loc_to_str(root_model, loc)}:\n\t{msg}" -def _loc_to_str(root_model: Type[BaseModel], loc: Loc) -> str: +def _loc_to_str(root_model: Type[BaseModel], loc: tuple[Union[int, str], ...]) -> str: model_module = getmodule(root_model) # If a nested error is from a root validator, then just report the error as being diff --git a/src/openjd/model/_create_job.py b/src/openjd/model/_create_job.py index f334757..c4ea3f6 100644 --- a/src/openjd/model/_create_job.py +++ b/src/openjd/model/_create_job.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Optional, cast -from pydantic.v1 import ValidationError +from pydantic import ValidationError from ._errors import CompatibilityError, DecodeValidationError from ._symbol_table import SymbolTable @@ -22,7 +22,7 @@ SpecificationRevision, TemplateSpecificationVersion, ) -from ._convert_pydantic_error import pydantic_validationerrors_to_str, ErrorDict +from ._convert_pydantic_error import pydantic_validationerrors_to_str __all__ = ("preprocess_job_parameters",) @@ -330,9 +330,7 @@ def create_job( job = instantiate_model(job_template, symtab) except ValidationError as exc: raise DecodeValidationError( - pydantic_validationerrors_to_str( - job_template.__class__, cast(list[ErrorDict], exc.errors()) - ) + pydantic_validationerrors_to_str(job_template.__class__, exc.errors()) ) return cast(Job, job) diff --git a/src/openjd/model/_format_strings/_dyn_constrained_str.py b/src/openjd/model/_format_strings/_dyn_constrained_str.py index 26a2eb1..20c32a5 100644 --- a/src/openjd/model/_format_strings/_dyn_constrained_str.py +++ b/src/openjd/model/_format_strings/_dyn_constrained_str.py @@ -1,14 +1,11 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-import re -from typing import TYPE_CHECKING, Any, Callable, Optional, Pattern, Union - -from pydantic.v1.errors import AnyStrMaxLengthError, AnyStrMinLengthError, StrRegexError -from pydantic.v1.utils import update_not_none -from pydantic.v1.validators import strict_str_validator +from typing import Any, Callable, Optional, Pattern, Union -if TYPE_CHECKING: - from pydantic.v1.typing import CallableGenerator +from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import core_schema +import re class DynamicConstrainedStr(str): @@ -33,37 +30,43 @@ def _get_max_length(cls) -> Optional[int]: return cls._max_length @classmethod - def __modify_schema__(cls, field_schema: dict[str, Any]) -> None: - update_not_none( - field_schema, - minLength=cls._min_length, - maxLength=cls._get_max_length(), - ) - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - yield strict_str_validator # Always strict string. 
- yield cls._validate_min_length - yield cls._validate_max_length - yield cls._validate_regex + def _validate(cls, value: str) -> Any: + if not isinstance(value, str): + raise ValueError("String required") - @classmethod - def _validate_min_length(cls, value: str) -> str: if cls._min_length is not None and len(value) < cls._min_length: - raise AnyStrMinLengthError(limit_value=cls._min_length) - return value - - @classmethod - def _validate_max_length(cls, value: str) -> str: + raise ValueError(f"String must be at least {cls._min_length} characters long") max_length = cls._get_max_length() + if max_length is not None and len(value) > max_length: - raise AnyStrMaxLengthError(limit_value=max_length) - return value + raise ValueError(f"String must be at most {max_length} characters long") - @classmethod - def _validate_regex(cls, value: str) -> str: if cls._regex is not None: if not re.match(cls._regex, value): pattern: str = cls._regex if isinstance(cls._regex, str) else cls._regex.pattern - raise StrRegexError(pattern=pattern) - return value + raise ValueError(f"String does not match the required pattern: {pattern}") + + return cls(value) + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + return core_schema.no_info_plain_validator_function(cls._validate) + + @classmethod + def __get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + json_schema: dict[str, Any] = {"type": "string"} + if cls._min_length is not None: + json_schema["minLength"] = cls._min_length + max_length = cls._get_max_length() + if max_length is not None: + json_schema["maxLength"] = max_length + if cls._regex is not None: + json_schema["pattern"] = ( + cls._regex if isinstance(cls._regex, str) else cls._regex.pattern + ) + + return json_schema diff --git a/src/openjd/model/_format_strings/_format_string.py 
b/src/openjd/model/_format_strings/_format_string.py index 5ebf738..42f35b9 100644 --- a/src/openjd/model/_format_strings/_format_string.py +++ b/src/openjd/model/_format_strings/_format_string.py @@ -2,16 +2,13 @@ from dataclasses import dataclass from numbers import Real -from typing import TYPE_CHECKING, Optional, Union +from typing import Optional, Union from .._errors import ExpressionError, TokenError from .._symbol_table import SymbolTable from ._dyn_constrained_str import DynamicConstrainedStr from ._expression import InterpolationExpression -if TYPE_CHECKING: - from pydantic.v1.typing import CallableGenerator - @dataclass class ExpressionInfo: @@ -21,8 +18,9 @@ class ExpressionInfo: resolved_value: Optional[Union[Real, str]] = None -class FormatStringError(Exception): +class FormatStringError(ValueError): def __init__(self, *, string: str, start: int, end: int, expr: str = "", details: str = ""): + self.input = string expression = f"Expression: {expr}. " if expr else "" reason = f"Reason: {details}." if details else "" msg = ( @@ -202,24 +200,3 @@ def _preprocess(self) -> list[Union[str, ExpressionInfo]]: result_list.append(expression_info) return result_list - - # Pydantic datamodel interfaces - # ================================ - # Reference: https://pydantic-docs.helpmanual.io/usage/types/#custom-data-types - - @classmethod - def __get_validators__(cls) -> "CallableGenerator": - for validator in super().__get_validators__(): - yield validator - yield cls._validate - - @classmethod - def _validate(cls, value: str) -> "FormatString": - # Reference: https://pydantic-docs.helpmanual.io/usage/validators/ - # Class constructor will raise validation errors on the value contents. 
- try: - return cls(value) - except FormatStringError as e: - # Pydantic validators must return a ValueError or AssertionError - # Convert the FormatStringError into a ValueError - raise ValueError(str(e)) diff --git a/src/openjd/model/_internal/_create_job.py b/src/openjd/model/_internal/_create_job.py index 1b90222..3f1f5ac 100644 --- a/src/openjd/model/_internal/_create_job.py +++ b/src/openjd/model/_internal/_create_job.py @@ -1,9 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -from typing import Any, Union +from typing import cast, Any, Union -from pydantic.v1 import ValidationError -from pydantic.v1.error_wrappers import ErrorWrapper +from pydantic import ValidationError +from pydantic_core import InitErrorDetails from .._symbol_table import SymbolTable from .._format_strings import FormatString, FormatStringError @@ -37,11 +37,10 @@ def instantiate_model( # noqa: C901 Returns: OpenJDModel: The transformed model. """ - - errors = list[ErrorWrapper]() + errors = list[InitErrorDetails]() instantiated_fields = dict[str, Any]() - for field_name in model.__fields__.keys(): + for field_name in model.model_fields.keys(): target_field_name = field_name if field_name in model._job_creation_metadata.rename_fields: target_field_name = model._job_creation_metadata.rename_fields[field_name] @@ -66,14 +65,30 @@ def instantiate_model( # noqa: C901 else: # Raises: ValidationError, FormatStringError instantiated = _instantiate_noncollection_value( - model, field_name, field, symtab, loc + (field_name,) + model, field_name, field, symtab, (*loc, field_name) ) instantiated_fields[target_field_name] = instantiated - except (ValidationError, FormatStringError) as exc: - errors.append(ErrorWrapper(exc, loc)) + except ValidationError as exc: + # Convert the ErrorDetails to InitErrorDetails by excluding the 'msg' + for error_details in exc.errors(): + init_error_details = { + key: value for key, value in error_details.items() if key != "msg" + } + 
errors.append(cast(InitErrorDetails, init_error_details)) + except FormatStringError as exc: + errors.append( + InitErrorDetails( + type="value_error", + loc=loc, + ctx={"error": ValueError(str(exc))}, + input=exc.input, + ) + ) if errors: - raise ValidationError(errors, model.__class__) + raise ValidationError.from_exception_data( + title=model.__class__.__name__, line_errors=errors + ) if model._job_creation_metadata.adds_fields is not None: new_fields = model._job_creation_metadata.adds_fields(within_field, model, symtab) @@ -89,7 +104,18 @@ def instantiate_model( # noqa: C901 return create_as_class(**instantiated_fields) return model.__class__(**instantiated_fields) except ValidationError as exc: - raise ValidationError([ErrorWrapper(exc, loc)], model.__class__) + # Convert the ErrorDetails to InitErrorDetails by concatenating the 'loc' values and excluding the 'msg' + for error_details in exc.errors(): + init_error_details = {} + for key, value in error_details.items(): + if key == "loc": + init_error_details["loc"] = loc + cast(tuple, value) + elif key != "msg": + init_error_details[key] = value + errors.append(cast(InitErrorDetails, init_error_details)) + raise ValidationError.from_exception_data( + title=model.__class__.__name__, line_errors=errors + ) def _instantiate_noncollection_value( @@ -142,7 +168,7 @@ def _instantiate_list_field( # noqa: C901 symtab (SymbolTable): Symbol table for format string value lookups. loc (tuple[Union[str,int], ...]): Path to this value. 
""" - errors = list[ErrorWrapper]() + errors = list[InitErrorDetails]() result: Union[list[Any], dict[str, Any]] if field_name in within_model._job_creation_metadata.reshape_field_to_dict: key_field = within_model._job_creation_metadata.reshape_field_to_dict[field_name] @@ -156,14 +182,24 @@ def _instantiate_list_field( # noqa: C901 field_name, item, symtab, - loc - + ( - field_name, - idx, - ), + (*loc, field_name, idx), + ) + except ValidationError as exc: + # Convert the ErrorDetails to InitErrorDetails by excluding the 'msg' + for error_details in exc.errors(): + init_error_details = { + key: value for key, value in error_details.items() if key != "msg" + } + errors.append(cast(InitErrorDetails, init_error_details)) + except FormatStringError as exc: + errors.append( + InitErrorDetails( + type="value_error", + loc=loc, + ctx={"error": ValueError(str(exc))}, + input=exc.input, + ) ) - except (ValidationError, FormatStringError) as exc: - errors.append(ErrorWrapper(exc, loc)) else: result = list[Any]() for idx, item in enumerate(value): @@ -175,18 +211,30 @@ def _instantiate_list_field( # noqa: C901 field_name, item, symtab, - loc - + ( - field_name, - idx, - ), + (*loc, field_name, idx), + ) + ) + except ValidationError as exc: + # Convert the ErrorDetails to InitErrorDetails by excluding the 'msg' + for error_details in exc.errors(): + init_error_details = { + key: value for key, value in error_details.items() if key != "msg" + } + errors.append(cast(InitErrorDetails, init_error_details)) + except FormatStringError as exc: + errors.append( + InitErrorDetails( + type="value_error", + loc=loc, + ctx={"error": ValueError(str(exc))}, + input=exc.input, ) ) - except (ValidationError, FormatStringError) as exc: - errors.append(ErrorWrapper(exc, loc)) if errors: - raise ValidationError(errors, within_model.__class__) + raise ValidationError.from_exception_data( + title=within_model.__class__.__name__, line_errors=errors + ) return result @@ -207,7 +255,7 @@ def 
_instantiate_dict_field( symtab (SymbolTable): Symbol table for format string value lookups. loc (tuple[Union[str,int], ...]): Path to this value. """ - errors = list[ErrorWrapper]() + errors = list[InitErrorDetails]() result = dict[str, Any]() for key, item in value.items(): try: @@ -223,10 +271,26 @@ def _instantiate_dict_field( key, ), ) - except (ValidationError, FormatStringError) as exc: - errors.append(ErrorWrapper(exc, loc)) + except ValidationError as exc: + # Convert the ErrorDetails to InitErrorDetails by excluding the 'msg' + for error_details in exc.errors(): + init_error_details = { + key: value for key, value in error_details.items() if key != "msg" + } + errors.append(cast(InitErrorDetails, init_error_details)) + except FormatStringError as exc: + errors.append( + InitErrorDetails( + type="value_error", + loc=loc, + ctx={"error": ValueError(str(exc))}, + input=exc.input, + ) + ) if errors: - raise ValidationError(errors, within_model.__class__) + raise ValidationError.from_exception_data( + title=within_model.__class__.__name__, line_errors=errors + ) return result diff --git a/src/openjd/model/_internal/_variable_reference_validation.py b/src/openjd/model/_internal/_variable_reference_validation.py index 40e733b..7b23a81 100644 --- a/src/openjd/model/_internal/_variable_reference_validation.py +++ b/src/openjd/model/_internal/_variable_reference_validation.py @@ -1,12 +1,16 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
from collections import defaultdict -from typing import Any, Optional, Type +import typing +from typing import cast, Any, Optional, Type, Literal, Union from inspect import isclass -from pydantic.v1.error_wrappers import ErrorWrapper -import pydantic.v1.fields -from pydantic.v1.typing import is_literal_type +from pydantic import Discriminator +from pydantic_core import InitErrorDetails +from pydantic.fields import FieldInfo, ModelPrivateAttr + +# Workaround for Python 3.9 where issubclass raises an error "TypeError: issubclass() arg 1 must be a class" +from pydantic.v1.utils import lenient_issubclass from .._types import OpenJDModel, ResolutionScope from .._format_strings import FormatString, FormatStringError @@ -102,32 +106,17 @@ # - Information on this is encoded in the model's `_template_variable_sources` field. See the comment for this field in the # OpenJDModel base class for information on this property # 4. Since this validation is a pre-validator, we basically have to re-implement a fragment of Pydantic's model parser for this -# depth first traversal. Thus, you'll need to know the following about Pydantic v1.x's data model and parser to understand this +# depth first traversal. Thus, you'll need to know the following about Pydantic v2's data model and parser to understand this # implementation: -# a) All models are derived from pydantic.v1.BaseModel -# b) pydantic.BaseModel.__fields__: dict[str, pydantic.ModelField] is injected into all BaseModels by pydantic's BaseModel metaclass. +# a) All models are derived from pydantic.BaseModel +# b) pydantic.BaseModel.model_fields: dict[str, FieldInfo] is injected into all BaseModels by pydantic's BaseModel metaclass. # This member is what gives pydantic information about each of the fields defined in the model class. The key of the dict is the # name of the field in the model. 
-# c) pydantic.ModelField describes the type information about a model's field: -# i) pydantic.ModelField.shape is an integer that defines the shape of the field -# SHAPE_SINGLETON means that it's a singleton type. -# SHAPE_LIST means that it's a list type. -# SHAPE_DICT means that it's a dict type. -# etc. -# ii) pydantic.ModelField.type_ gives you the type of the field; this is only useful for scalar singleton fields. -# iii) pydantic.ModelField.sub_fields: Optional[list[pydantic.ModelField]] exists for list, dictionary, and union-typed singleton -# fields: -# 1. For SHAPE_LIST: sub_fields has length 1, and its element is the ModelField for the elements of the list. -# 2. For SHAPE_DICT: sub_fields has length 1, and its element is the ModelField for the value-type of the dict. -# 3. For SHAPE_SINGLETON: -# a) For scalar-typed fields: sub_fields is None -# b) For union-typed fields: sub_fields is a list of all of the types in the union -# iv) For discriminated unions: -# 1. pydantic.ModelField.discriminator_key: Optional[str] exists and it gives the name of the submodel field used to -# determine which type of the union a given data value is. -# 2. pydantic.sub_fields_mapping: Optional[dict[str,pydantic.ModelField]] exists and can be used to find the unioned type -# for a given discriminator value. -# +# c) pydantic.FieldInfo describes the type information about a model's field: +# i) pydantic.FieldInfo.annotation gives you the type of the field; The structure of the field is contained +# in this type, including typing.Annotated values for discriminated unions. Both the definition collection and +# validation recursively unwraps these types along with the values. The pydantic.FieldInfo also includes a discriminator +# value, so the code handles both cases. 
class ScopedSymtabs(defaultdict): @@ -189,13 +178,14 @@ def _internal_deepcopy(value: Any) -> Any: def _validate_model_template_variable_references( - cls: Type[OpenJDModel], - values: dict[str, Any], + model: Type, + value: Any, current_scope: ResolutionScope, symbol_prefix: str, symbols: ScopedSymtabs, loc: tuple, -) -> list[ErrorWrapper]: + discriminator: Union[str, Discriminator, None] = None, +) -> list[InitErrorDetails]: """Inner implementation of prevalidate_model_template_variable_references(). Arguments: @@ -206,17 +196,117 @@ def _validate_model_template_variable_references( symbols - The variable symbols that have been defined in each reference scope. loc - The path of fields taken from the root of the model to the current recursive level """ + # The errors that we're collecting for this node in the traversal, and will return from the function call. - errors: list[ErrorWrapper] = [] + errors: list[InitErrorDetails] = [] + + model_origin = typing.get_origin(model) + + # Unwrap the Optional types + if model_origin is typing.Optional: + return _validate_model_template_variable_references( + typing.get_args(model)[0], + value, + current_scope, + symbol_prefix, + symbols, + loc, + discriminator=discriminator, + ) + + # Unwrap the Annotated type, and get the discriminator while doing so + if model_origin is typing.Annotated: + model_args = typing.get_args(model) + for annotation in model_args[1:]: + if isinstance(annotation, FieldInfo): + discriminator = annotation.discriminator + return _validate_model_template_variable_references( + model_args[0], + value, + current_scope, + symbol_prefix, + symbols, + loc, + discriminator=discriminator, + ) + + # Validate all the items of a list + if model_origin is list: + # If the shape expects a list, but the value isn't one then we have a validation error. + # The error will get flagged by subsequent passes of the model validation. 
+ if isinstance(value, list): + item_model = typing.get_args(model)[0] + for i, item in enumerate(value): + errors.extend( + _validate_model_template_variable_references( + item_model, item, current_scope, symbol_prefix, symbols, (*loc, i) + ) + ) + return errors + + if model_origin is dict: + # If the shape expects a dict, but the value isn't one then we have a validation error. + # The error will get flagged by subsequent passes of the model validation. + if isinstance(value, dict): + item_model = typing.get_args(model)[1] + for key, item in value.items(): + if not isinstance(key, str): + continue + errors.extend( + _validate_model_template_variable_references( + item_model, + item, + current_scope, + symbol_prefix, + symbols, + (*loc, key), + ) + ) + return errors + + # Validate all the variables from a non-discriminated union + if model_origin is Union and discriminator is None: + for sub_type in typing.get_args(model): + errors.extend( + _validate_model_template_variable_references( + sub_type, value, current_scope, symbol_prefix, symbols, loc + ) + ) + return errors + + # Unwrap a discriminated union to the selected type + if model_origin is Union and discriminator is not None: + unioned_model = _get_model_for_singleton_value(model, value, discriminator) + if unioned_model is not None: + return _validate_model_template_variable_references( + unioned_model, + value, + current_scope, + symbol_prefix, + symbols, + loc, + ) + else: + return [] + + if isclass(model) and lenient_issubclass(model, FormatString): + if isinstance(value, str): + return _check_format_string(value, current_scope, symbols, loc) + return [] + + # Return an empty error list if it's not an OpenJDModel, or if it's not a dict + if not (isclass(model) and lenient_issubclass(model, OpenJDModel) and isinstance(value, dict)): + return [] # Does this cls change the variable reference scope for itself and its children? If so, then update # our scope. 
- if cls._template_variable_scope is not None: - current_scope = cls._template_variable_scope + model_override_scope = cast(ModelPrivateAttr, model._template_variable_scope).get_default() + if model_override_scope is not None: + current_scope = model_override_scope # Apply any changes that this node makes to the template variable prefix. # e.g. It may change "Env." to "Env.File." - variable_defs = cls._template_variable_definitions + variable_defs = model._template_variable_definitions if variable_defs.symbol_prefix.startswith("|"): # The "|" character resets the nesting. symbol_prefix = variable_defs.symbol_prefix[1:] @@ -225,15 +315,17 @@ def _validate_model_template_variable_references( symbol_prefix += variable_defs.symbol_prefix # Recursively collect all of the variable definitions at this node and its child nodes. - value_symbols = _collect_variable_definitions(cls, values, current_scope, symbol_prefix) + value_symbols = _collect_variable_definitions( + model, value, current_scope, symbol_prefix, recursive_pruning=False + ) # Recursively validate the contents of FormatStrings within the model. - # Note: cls.__fields__: dict[str, pydantic.v1.fields.ModelField] - for field_name, field_model in cls.__fields__.items(): - field_value = values.get(field_name, None) - if field_value is None: + for field_name, field_info in model.model_fields.items(): + field_value = value.get(field_name) + field_model = field_info.annotation + if field_value is None or field_model is None: continue - if is_literal_type(field_model.type_): + if typing.get_origin(field_model) is Literal: # Literals aren't format strings and cannot be recursed in to; skip them. 
continue @@ -242,161 +334,33 @@ def _validate_model_template_variable_references( # If field_name is in _template_variable_sources, then the value tells us which # source fields from the current model/cls we need to propagate down into the # recursion for validating field_name - for source in cls._template_variable_sources.get(field_name, set()): + for source in model._template_variable_sources.get(field_name, set()): validation_symbols.update_self(value_symbols.get(source, ScopedSymtabs())) # Add in all of the symbols passed down from the parent. validation_symbols.update_self(symbols) - if field_model.shape == pydantic.v1.fields.SHAPE_SINGLETON: - _validate_singleton( - errors, + errors.extend( + _validate_model_template_variable_references( field_model, field_value, current_scope, symbol_prefix, validation_symbols, (*loc, field_name), + field_info.discriminator, ) - elif field_model.shape == pydantic.v1.fields.SHAPE_LIST: - if not isinstance(field_value, list): - continue - assert field_model.sub_fields is not None # For the type checker - item_model = field_model.sub_fields[0] - for i, item in enumerate(field_value): - _validate_singleton( - errors, - item_model, - item, - current_scope, - symbol_prefix, - validation_symbols, - (*loc, field_name, i), - ) - elif field_model.shape == pydantic.v1.fields.SHAPE_DICT: - if not isinstance(field_value, dict): - continue - assert field_model.sub_fields is not None # For the type checker - item_model = field_model.sub_fields[0] - for key, item in field_value.items(): - if not isinstance(key, str): - continue - _validate_singleton( - errors, - item_model, - item, - current_scope, - symbol_prefix, - validation_symbols, - (*loc, field_name, key), - ) - else: - raise NotImplementedError( - "You have hit an unimplemented code path. Please report this as a bug." 
- ) - - return errors - - -def _validate_singleton( - errors: list[ErrorWrapper], - field_model: pydantic.v1.fields.ModelField, - field_value: Any, - current_scope: ResolutionScope, - symbol_prefix: str, - symbols: ScopedSymtabs, - loc: tuple, -) -> None: - # Note: ModelField.sub_fields is populated if (otherwise it's None): - # a) field is a list type => sub_fields has 1 element, and its type is the element type of the list - # - this is handled *before* calling this function. - # b) field is a union => sub_fields' elements are the model types in the union - # c) field is a discriminated union => sub_fields has 1 element, and it is a ModelField with info about the union. - - if ( - field_model.discriminator_key is None - and field_model.sub_fields - and len(field_model.sub_fields) > 1 - ): - # The field is a union without a discriminator. - # e.g. Union[ list[Union[int,FormatString]], FormatString ] - _validate_general_union( - errors, field_model, field_value, current_scope, symbol_prefix, symbols, loc ) - return - - if field_model.discriminator_key: - # Discriminated union case - figure out what the actual model type is. - if not isinstance(field_value, dict): - # Validation error -- discriminated unions are always discriminating models, and so - # must by a dict. - return - model = _get_model_for_singleton_value(field_model, field_value) - if model is None: - # Validation error - will be flagged by a subsequent validation stage. 
- return - field_model = model - - if isclass(field_model.type_) and issubclass(field_model.type_, FormatString): - if isinstance(field_value, str): - errors.extend(_check_format_string(field_value, current_scope, symbols, loc)) - elif isclass(field_model.type_) and issubclass(field_model.type_, OpenJDModel): - if isinstance(field_value, dict): - errors.extend( - _validate_model_template_variable_references( - field_model.type_, - field_value, - current_scope, - symbol_prefix, - symbols, - loc, - ) - ) - -def _validate_general_union( - errors: list[ErrorWrapper], - field_model: pydantic.v1.fields.ModelField, - field_value: Any, - current_scope: ResolutionScope, - symbol_prefix: str, - symbols: ScopedSymtabs, - loc: tuple, -) -> None: - # Notes: - # - We narrowly only handle the kinds of unions that are present in the current model. - # - We rely on additions to the model being well tested w.r.t. evaluation of format strings, and - # such new tests being added signaling that this code needs to be enhanced. - # - Unions of model types are not currently present in the model so we do not handle/test that case. - # - The only union type that we have looks like: Union[ list[Union[int,FormatString]], FormatString ] - # - It's in the range field of task parameter definitions - - # We have to consider that the value may be any one of the types in the union, so we have to look at each possible type - # and attempt to process the value as that type. - assert field_model.sub_fields is not None # For the type checker - for sub_field in field_model.sub_fields: - if sub_field.shape == pydantic.v1.fields.SHAPE_SINGLETON: - _validate_singleton( - errors, sub_field, field_value, current_scope, symbol_prefix, symbols, loc - ) - elif sub_field.shape == pydantic.v1.fields.SHAPE_LIST: - if not isinstance(field_value, list): - # The given value must be a list in this case. 
- continue - assert sub_field.sub_fields is not None - item_model = sub_field.sub_fields[0] # For the type checker - for item in field_value: - _validate_singleton( - errors, item_model, item, current_scope, symbol_prefix, symbols, loc - ) + return errors def _check_format_string( value: str, current_scope: ResolutionScope, symbols: ScopedSymtabs, loc: tuple -) -> list[ErrorWrapper]: +) -> list[InitErrorDetails]: # Collect the variable reference errors, if any, from the given FormatString value. - errors = list[ErrorWrapper]() + errors = list[InitErrorDetails]() scoped_symbols = symbols[current_scope] try: f_value = FormatString(value) @@ -409,41 +373,61 @@ def _check_format_string( try: expr.expression.validate_symbol_refs(symbols=scoped_symbols) except ValueError as exc: - errors.append(ErrorWrapper(exc, loc)) + errors.append( + InitErrorDetails(type="value_error", loc=loc, ctx={"error": exc}, input=value) + ) return errors def _get_model_for_singleton_value( - field_model: pydantic.v1.fields.ModelField, value: Any -) -> Optional[pydantic.v1.fields.ModelField]: - """Given a ModelField and the value that we're given for that field, determine - the actual ModelField for the value in the event that the ModelField may be for + model: Any, value: Any, discriminator: Union[str, Discriminator, None] = None +) -> Optional[Type]: + """Given a FieldInfo and the value that we're given for that field, determine + the actual Model for the value in the event that the FieldInfo may be for a discriminated union.""" - # Precondition: value is a dict - assert isinstance(value, dict) - - if field_model.discriminator_key is None: - # If it's not a discriminated union, then the type_ of the field is the expected type of the value. 
- return field_model + # Unpack the annotated type, extracting the discriminator if provided + if typing.get_origin(model) is typing.Annotated: + for annotation in typing.get_args(model)[1:]: + if isinstance(annotation, FieldInfo): + discriminator = annotation.discriminator + model = typing.get_args(model)[0] + + if discriminator is None or typing.get_origin(model) is not typing.Union: + # If it's not a discriminated union, then pass through the type + return model + elif not isinstance(discriminator, str): + # This code only supports a field name discriminator, not a callable + raise NotImplementedError( + "You have hit an unimplemented code path. Please report this as a bug." + ) + elif not isinstance(value, dict): + # Validation error - will be flagged by a subsequent validation stage. + return None # The field is a discriminated union. Use the discriminator key to figure out which model # this specific value is. - key_value = value.get(field_model.discriminator_key, None) - if not key_value: + discr_value = value.get(discriminator) + if not discr_value: # key didn't have a value. This is a validation error that a later phase of validation # will flag. return None - if not isinstance(key_value, str): + if not isinstance(discr_value, str): # Keys must be strings. return None - assert field_model.sub_fields_mapping is not None # For the type checker - sub_model = field_model.sub_fields_mapping.get(key_value) - if not sub_model: - # The key value that we were given is not valid. - return None - return sub_model + # Find the correct model for the discriminator value by unwrapping the Union and then the discriminator Literals + assert typing.get_origin(model) is typing.Union # For the type checker + for sub_model in typing.get_args(model): + sub_model_discr_value = sub_model.model_fields[discriminator].annotation + if typing.get_origin(sub_model_discr_value) is not typing.Literal: + raise NotImplementedError( + "You have hit an unimplemented code path. 
Please report this as a bug." + ) + if typing.get_args(sub_model_discr_value)[0] == discr_value: + return sub_model + + return None ## ============================================= @@ -458,94 +442,151 @@ def _get_model_for_singleton_value( def _collect_variable_definitions( # noqa: C901 (suppress: too complex) - cls: Type[OpenJDModel], - values: dict[str, Any], + model: Type, + value: Any, current_scope: ResolutionScope, symbol_prefix: str, + recursive_pruning: bool = True, + discriminator: Union[str, Discriminator, None] = None, ) -> dict[str, ScopedSymtabs]: """Collects the names of variables that each field of this model object provides. The return value is a dictionary with a set of symbols for each field, - "__self__" for the model itself, and "__exports__" for the symbols that it + "__self__" for the model itself, and "__export__" for the symbols that it exports to its parent in the data model. + + When the model is not an OpenJDModel, it only populates the "__export__". """ # NOTE: This is not written to be super generic and handle all possible OpenJD models going # forward forever. It handles the subset of the general Pydantic data model that OpenJD is # currently using, and will be extended as we use additional features of Pydantic's data model. - symbols: dict[str, ScopedSymtabs] = {"__self__": ScopedSymtabs()} - - defs = cls._template_variable_definitions - - if defs.field: - # defs.field being defined means that the cls defines a template variable. - - # Figure out the name of the variable. - name: str = "" - if (def_field_value := values.get(defs.field, None)) is not None: - # The name of the variable is in a field and the field has a value - # in the given data. - if isinstance(def_field_value, str): - # The field can only be a name if its value is a string; otherwise, - # this will get flagged as a validation error later. - name = def_field_value - - # Define the symbols that are defined in the appropriate scopes if we have the name. 
- if name: - for vardef in defs.defines: - if vardef.prefix.startswith("|"): - symbol_name = f"{vardef.prefix[1:]}{name}" - else: - symbol_name = f"{symbol_prefix}{vardef.prefix}{name}" - _add_symbol(symbols["__self__"], vardef.resolves, symbol_name) - - # If this object injects any template variables then those are injected at the - # current model's scope. - for symbol in defs.inject: - if symbol.startswith("|"): - symbol_name = symbol[1:] + model_origin = typing.get_origin(model) + + # Unwrap the Optional types + if model_origin is typing.Optional: + return _collect_variable_definitions( + typing.get_args(model)[0], + value, + current_scope, + symbol_prefix, + discriminator=discriminator, + ) + + # Unwrap the Annotated type, and get the discriminator while doing so + if model_origin is typing.Annotated: + model_args = typing.get_args(model) + for annotation in model_args[1:]: + if isinstance(annotation, FieldInfo): + discriminator = annotation.discriminator + return _collect_variable_definitions( + model_args[0], value, current_scope, symbol_prefix, discriminator=discriminator + ) + + # Aggregate all the collected variable definitions from a list + if model_origin is list: + # If the shape expects a list, but the value isn't one then we have a validation error. + # The error will get flagged by subsequent passes of the model validation. 
+ symtab = ScopedSymtabs() + if isinstance(value, list): + item_model = typing.get_args(model)[0] + for item in value: + symtab.update_self( + _collect_variable_definitions(item_model, item, current_scope, symbol_prefix)[ + "__export__" + ] + ) + return {"__export__": symtab} + + # Aggregate all the matching variables from a non-discriminated union + if model_origin is Union and discriminator is None: + symtab = ScopedSymtabs() + for sub_type in typing.get_args(model): + symtab.update_self( + _collect_variable_definitions(sub_type, value, current_scope, symbol_prefix)[ + "__export__" + ] + ) + return {"__export__": symtab} + + # Unwrap a discriminated union to the selected type + if model_origin is Union and discriminator is not None: + unioned_model = _get_model_for_singleton_value(model, value, discriminator) + if unioned_model is not None: + return _collect_variable_definitions( + unioned_model, + value, + current_scope, + symbol_prefix, + ) else: - symbol_name = f"{symbol_prefix}{symbol}" - _add_symbol(symbols["__self__"], current_scope, symbol_name) + return {"__export__": ScopedSymtabs()} - # Note: cls.__fields__: dict[str, pydantic.v1.fields.ModelField] - for field_name, field_model in cls.__fields__.items(): - field_value = values.get(field_name, None) - if field_value is None: - continue + # Anything except for an OpenJDModel returns an empty result + if not isclass(model) or not lenient_issubclass(model, OpenJDModel): + return {"__export__": ScopedSymtabs()} - if is_literal_type(field_model.type_): - # Literals cannot define variables, so skip this field. 
- continue + # If the model has no exported variable definitions, prune it + if recursive_pruning and "__export__" not in model._template_variable_sources: + return {"__export__": ScopedSymtabs()} - if field_model.shape == pydantic.v1.fields.SHAPE_SINGLETON: - result = _collect_singleton(field_model, field_value, current_scope, symbol_prefix) - if result: - symbols[field_name] = result - elif field_model.shape == pydantic.v1.fields.SHAPE_LIST: - # If the shape expects a list, but the value isn't one then we have a validation error. - # The error will get flagged by subsequent passes of the model validation. - if not isinstance(field_value, list): - continue - assert field_model.sub_fields is not None - item_model = field_model.sub_fields[0] - symbols[field_name] = ScopedSymtabs() - for item in field_value: - result = _collect_singleton(item_model, item, current_scope, symbol_prefix) - if result: - symbols[field_name].update_self(result) - elif field_model.shape == pydantic.v1.fields.SHAPE_DICT: - # dict[] fields can't define symbols. + # If the value is not a dict, then it's a validation error. We'll flag that error later. + if not isinstance(value, dict): + return {"__export__": ScopedSymtabs()} + + symbols: dict[str, ScopedSymtabs] = {"__self__": ScopedSymtabs(), "__export__": ScopedSymtabs()} + + # Process the variable definitions defined by this model + defs = getattr(model, "_template_variable_definitions", None) + + if defs: + if defs.field: + # defs.field being defined means that the cls defines a template variable. + + # Figure out the name of the variable. + name: str = "" + if (def_field_value := value.get(defs.field)) is not None: + # The name of the variable is in a field and the field has a value + # in the given data. + if isinstance(def_field_value, str): + # The field can only be a name if its value is a string; otherwise, + # this will get flagged as a validation error later. 
+ name = def_field_value + + # Define the symbols that are defined in the appropriate scopes if we have the name. + if name: + for vardef in defs.defines: + if vardef.prefix.startswith("|"): + symbol_name = f"{vardef.prefix[1:]}{name}" + else: + symbol_name = f"{symbol_prefix}{vardef.prefix}{name}" + _add_symbol(symbols["__self__"], vardef.resolves, symbol_name) + + # If this object injects any template variables then those are injected at the + # current model's scope. + for symbol in defs.inject: + if symbol.startswith("|"): + symbol_name = symbol[1:] + else: + symbol_name = f"{symbol_prefix}{symbol}" + _add_symbol(symbols["__self__"], current_scope, symbol_name) + + # Collect the variable definitions exported by the fields of the model + for field_name, field_info in model.model_fields.items(): + field_value = value.get(field_name) + field_model = field_info.annotation + if field_value is None or field_model is None: continue - else: - raise NotImplementedError( - "You have hit an unimplemented code path. Please report this as a bug." 
- ) + + discriminator = field_info.discriminator + + symbols[field_name] = _collect_variable_definitions( + field_model, field_value, current_scope, symbol_prefix, discriminator=discriminator + )["__export__"] # Collect the exported symbols as specified by the metadata - symbols["__export__"] = ScopedSymtabs() - for source in cls._template_variable_sources.get("__export__", set()): + for source in model._template_variable_sources.get("__export__", set()): symbols["__export__"].update_self(symbols.get(source, ScopedSymtabs())) return symbols @@ -563,53 +604,3 @@ def _add_symbol(into: ScopedSymtabs, scope: ResolutionScope, symbol_name: str) - into[ResolutionScope.TASK].add(symbol_name) else: into[ResolutionScope.TASK].add(symbol_name) - - -def _collect_singleton( - model: pydantic.v1.fields.ModelField, - value: Any, - current_scope: ResolutionScope, - symbol_prefix: str, -) -> Optional[ScopedSymtabs]: - # Singletons that we recurse in to must all be OpenJDModels, so that means that - # the value must be a dictionary. hen the provided field value must be a dictionary - # to have a chance of being valid. If it's not valid, then we just skip it and - # let subsequent validation passes in the model itself flag those. - - # Note: ModelField.sub_fields is populated if (otherwise it's None): - # a) field is a list type => sub_fields has 1 element, and its type is the element type of the list - # b) field is a union => sub_fields' elements are the model types in the union - # c) field is a discriminated union => sub_fields has 1 element, and it is a ModelField with info about the union. - if not isinstance(value, dict): - return None - - if ( - model.discriminator_key is None - and model.sub_fields is not None - and len(model.sub_fields) > 1 - ): - # The only cases like this in our *current* model are the range field of IntTaskParameterDefinitions; they - # are non-discriminated unions of types that do not contain variable definitions, so we skip them. 
- return None - - if isclass(model.type_) and not issubclass(model.type_, OpenJDModel): - # The field is something like str, int, etc. These can't define variables, so skip it. - return None - - value_model = _get_model_for_singleton_value(model, value) - if value_model is None: - return None - if not isclass(value_model.type_) or ( - isclass(value_model.type_) and not issubclass(value_model.type_, OpenJDModel) - ): - # We only recursively collect from OpenJDModel typed values. - return None - if "__export__" not in value_model.type_._template_variable_sources: - # If the model doesn't export symbols, then there's no point to recursing in to this value. - return None - return _collect_variable_definitions( - value_model.type_, - value, - current_scope, - symbol_prefix, - )["__export__"] diff --git a/src/openjd/model/_parse.py b/src/openjd/model/_parse.py index c8e651f..66f5ec5 100644 --- a/src/openjd/model/_parse.py +++ b/src/openjd/model/_parse.py @@ -7,13 +7,13 @@ from typing import Any, ClassVar, Optional, Type, TypeVar, Union, cast import yaml -from pydantic.v1 import BaseModel -from pydantic.v1 import ValidationError as PydanticValidationError -from pydantic.v1.error_wrappers import ErrorWrapper +from pydantic import BaseModel +from pydantic import ValidationError as PydanticValidationError +from pydantic_core import ErrorDetails, InitErrorDetails from ._errors import DecodeValidationError from ._types import EnvironmentTemplate, JobTemplate, OpenJDModel, TemplateSpecificationVersion -from ._convert_pydantic_error import pydantic_validationerrors_to_str, ErrorDict +from ._convert_pydantic_error import pydantic_validationerrors_to_str from .v2023_09 import JobTemplate as JobTemplate_2023_09 from .v2023_09 import EnvironmentTemplate as EnvironmentTemplate_2023_09 @@ -33,7 +33,7 @@ class DocumentType(str, Enum): # Pydantic injects a __pydantic_model__ attribute into all dataclasses. 
To be able to parse -# dataclass models we need to be able to invoke Model.__pydantic_model__.parse_obj(), but +# dataclass models we need to be able to invoke Model.__pydantic_model__.model_validate(), but # type checkers do not realize that pydantic dataclasses have a __pydantic_model__ attribute. # So, we type-cast into this class to invoke that method. class PydanticDataclass: @@ -46,7 +46,7 @@ class PydanticDataclass: def _parse_model(*, model: Type[T], obj: Any) -> T: if is_dataclass(model): - return cast(T, cast(PydanticDataclass, model).__pydantic_model__.parse_obj(obj)) + return cast(T, cast(PydanticDataclass, model).__pydantic_model__.model_validate(obj)) else: prevalidator_error: Optional[PydanticValidationError] = None if hasattr(model, "_root_template_prevalidator"): @@ -55,12 +55,20 @@ def _parse_model(*, model: Type[T], obj: Any) -> T: except PydanticValidationError as exc: prevalidator_error = exc try: - result = cast(T, cast(BaseModel, model).parse_obj(obj)) + result = cast(T, cast(BaseModel, model).model_validate(obj)) except PydanticValidationError as exc: - errors: list[ErrorWrapper] = cast(list[ErrorWrapper], exc.raw_errors) if prevalidator_error is not None: - errors.extend(cast(list[ErrorWrapper], prevalidator_error.raw_errors)) - raise PydanticValidationError(errors, model) + errors = list[InitErrorDetails]() + for error_details in exc.errors() + prevalidator_error.errors(): + init_error_details = { + key: value for key, value in error_details.items() if key != "msg" + } + errors.append(cast(InitErrorDetails, init_error_details)) + raise PydanticValidationError.from_exception_data( + title=exc.title, line_errors=errors + ) + else: + raise if prevalidator_error is not None: raise prevalidator_error return result @@ -70,7 +78,7 @@ def parse_model(*, model: Type[T], obj: Any) -> T: try: return _parse_model(model=model, obj=obj) except PydanticValidationError as exc: - errors: list[ErrorDict] = cast(list[ErrorDict], exc.errors()) + errors: 
list[ErrorDetails] = exc.errors() raise DecodeValidationError(pydantic_validationerrors_to_str(model, errors)) @@ -105,7 +113,7 @@ def model_to_object(*, model: OpenJDModel) -> dict[str, Any]: """Given a model from this package, encode it as a dictionary such that it could be written to a JSON/YAML document.""" - as_dict = model.dict() + as_dict = model.model_dump() # Some of the values in the model can be type 'Decimal', which doesn't # encode into json/yaml without special handling. So, we convert those in to diff --git a/src/openjd/model/_symbol_table.py b/src/openjd/model/_symbol_table.py index 09ee6b2..5c5fbcf 100644 --- a/src/openjd/model/_symbol_table.py +++ b/src/openjd/model/_symbol_table.py @@ -31,6 +31,9 @@ def __init__(self, *, source: Optional[Union[SymbolTable, dict[str, Any]]] = Non else: raise TypeError(f"Cannot initialize with type {type(source)}") + def __repr__(self) -> str: + return f"SymbolTable({self._table})" + def __contains__(self, symbol: str) -> bool: return symbol in self._table diff --git a/src/openjd/model/_types.py b/src/openjd/model/_types.py index b47dbc6..af13dc2 100644 --- a/src/openjd/model/_types.py +++ b/src/openjd/model/_types.py @@ -8,7 +8,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Type, Union -from pydantic.v1 import BaseModel, Extra +from pydantic import ConfigDict, BaseModel from ._symbol_table import SymbolTable @@ -269,13 +269,7 @@ class JobCreationMetadata: class OpenJDModel(BaseModel): - # See: https://docs.pydantic.dev/usage/model_config/ - class Config: - # Forbid extra fields in the input - extra = Extra.forbid - - # Make the model instances immutable - frozen = True + model_config = ConfigDict(extra="forbid", frozen=True) # The specific schema revision that the model implements. 
revision: ClassVar[SpecificationRevision] diff --git a/src/openjd/model/v2023_09/_model.py b/src/openjd/model/v2023_09/_model.py index 97f6cb5..7dac58d 100644 --- a/src/openjd/model/v2023_09/_model.py +++ b/src/openjd/model/v2023_09/_model.py @@ -6,23 +6,23 @@ from decimal import Decimal, InvalidOperation from enum import Enum from graphlib import CycleError, TopologicalSorter -from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Type, Union, cast -from typing_extensions import Annotated +from typing import Any, ClassVar, Literal, Optional, Type, Union, cast +from typing_extensions import Annotated, Self -from pydantic.v1 import ( +from pydantic import ( + field_validator, + model_validator, + StringConstraints, Field, PositiveInt, PositiveFloat, StrictBool, StrictInt, ValidationError, - conint, - conlist, - constr, - root_validator, - validator, + ValidationInfo, ) -from pydantic.v1.error_wrappers import ErrorWrapper +from pydantic_core import InitErrorDetails +from pydantic.fields import ModelPrivateAttr from .._format_strings import FormatString from .._errors import ExpressionError, TokenError @@ -119,10 +119,10 @@ class ValueReferenceConstants(Enum): # C1 = 0x80-0x9F # https://www.unicode.org/charts/PDF/U0080.pdf _Cc_characters = r"\u0000-\u001F\u007F-\u009F" -_standard_string_regex = rf"(?-m:^[^{_Cc_characters}]+\Z)" +_standard_string_regex = rf"(?-m:^[^{_Cc_characters}]+\z)" # Latin alphanumeric, starting with a letter -_identifier_regex = r"(?-m:^[A-Za-z_][A-Za-z0-9_]*\Z)" +_identifier_regex = r"(?-m:^[A-Za-z_][A-Za-z0-9_]*\z)" # Regex for defining file filter patterns allowed for use in file dialogs. # 1. Allowable values: "*", "*.*", and "*.[:file-extension-chars:]+". @@ -136,7 +136,7 @@ class ValueReferenceConstants(Enum): rf"(?-m:^(?:\*|\*\.\*|\*\." 
rf"[^{_Cc_characters}\\/\*" rf"\?\[\]#%&\{{\}}<>\$\!'" - rf"\\\":@`|=]+)\Z)" + rf"\\\":@`|=]+)\z)" ) @@ -144,27 +144,33 @@ class JobTemplateName(FormatString): _min_length = 1 -if TYPE_CHECKING: - JobName = str - Identifier = str - Description = str - EnvironmentName = str - StepName = str - ParameterStringValue = str -else: - JobName = constr(min_length=1, max_length=128, strict=True, regex=_standard_string_regex) - Identifier = constr(min_length=1, max_length=64, strict=True, regex=_identifier_regex) - Description = constr( +JobName = Annotated[ + str, + StringConstraints(min_length=1, max_length=128, strict=True, pattern=_standard_string_regex), +] +Identifier = Annotated[ + str, StringConstraints(min_length=1, max_length=64, strict=True, pattern=_identifier_regex) +] +Description = Annotated[ + str, + StringConstraints( min_length=1, max_length=2048, strict=True, # All unicode except the [Cc] (control characters) category # Allow CR, LF, and TAB. - regex=f"(?-m:^(?:[^{_Cc_characters}]|[\r\n\t])+\Z)", - ) - EnvironmentName = constr(min_length=1, max_length=64, strict=True, regex=_standard_string_regex) - StepName = constr(min_length=1, max_length=64, strict=True, regex=_standard_string_regex) - ParameterStringValue = constr(min_length=0, max_length=1024, strict=True) + pattern=f"(?-m:^(?:[^{_Cc_characters}]|[\r\n\t])+\\z)", + ), +] +EnvironmentName = Annotated[ + str, + StringConstraints(min_length=1, max_length=64, strict=True, pattern=_standard_string_regex), +] +StepName = Annotated[ + str, + StringConstraints(min_length=1, max_length=64, strict=True, pattern=_standard_string_regex), +] +ParameterStringValue = Annotated[str, StringConstraints(min_length=0, max_length=1024, strict=True)] # ================================================================== # ============================= Script types ======================= @@ -177,13 +183,13 @@ class JobTemplateName(FormatString): class CommandString(FormatString): _min_length = 1 # All unicode except the [Cc] 
(control characters) category - _regex = f"(?-m:^[^{_Cc_characters}]+\Z)" + _regex = f"(?-m:^[^{_Cc_characters}]+\\Z)" class ArgString(FormatString): # All unicode except the [Cc] (control characters) category # Allow CR, LF, and TAB. - _regex = f"(?-m:^[^{_Cc_characters}]*\Z)" + _regex = f"(?-m:^[^{_Cc_characters}]*\\Z)" class CancelationMode(str, Enum): @@ -191,10 +197,7 @@ class CancelationMode(str, Enum): TERMINATE = "TERMINATE" -if TYPE_CHECKING: - NotifyPeriodType = int -else: - NotifyPeriodType = conint(ge=1, le=600) +NotifyPeriodType = Annotated[int, Field(ge=1, le=600)] class CancelationMethodNotifyThenTerminate(OpenJDModel_v2023_09): @@ -245,10 +248,7 @@ class CancelationMethodTerminate(OpenJDModel_v2023_09): mode: Literal[CancelationMode.TERMINATE] -if TYPE_CHECKING: - ArgListType = list[ArgString] -else: - ArgListType = conlist(ArgString, min_items=1) +ArgListType = Annotated[list[ArgString], Field(min_length=1)] class Action(OpenJDModel_v2023_09): @@ -298,7 +298,8 @@ class EnvironmentActions(OpenJDModel_v2023_09): onEnter: Optional[Action] = Field(None) # noqa: N815 onExit: Optional[Action] = Field(None) # noqa: N815 - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def _requires_oneof(cls, values: dict[str, Any]) -> dict[str, Any]: """A validator that runs on the model data before parsing.""" on_enter = values.get("onEnter") @@ -315,11 +316,8 @@ class EmbeddedFileTypes(str, Enum): TEXT = "TEXT" -if TYPE_CHECKING: - Filename = str -else: - # TODO - regex of allowable filename characters - Filename = constr(min_length=1, max_length=64, strict=True) +# TODO - regex of allowable filename characters +Filename = Annotated[str, StringConstraints(min_length=1, max_length=64, strict=True)] class DataString(FormatString): @@ -360,10 +358,7 @@ class EmbeddedFileText(OpenJDModel_v2023_09): # --------------------- Script types ---------------------------- -if TYPE_CHECKING: - EmbeddedFiles = list[EmbeddedFileText] -else: - EmbeddedFiles 
= conlist(EmbeddedFileText, min_items=1) +EmbeddedFiles = Annotated[list[EmbeddedFileText], Field(min_length=1)] class StepScript(OpenJDModel_v2023_09): @@ -394,7 +389,8 @@ class StepScript(OpenJDModel_v2023_09): "embeddedFiles": {"embeddedFiles", "__self__"}, } - @validator("embeddedFiles") + @field_validator("embeddedFiles") + @classmethod def _unique_names(cls, v: Optional[EmbeddedFiles]) -> Optional[EmbeddedFiles]: if v is not None: return validate_unique_elements(v, item_value=lambda v: v.name, property="name") @@ -429,7 +425,8 @@ class EnvironmentScript(OpenJDModel_v2023_09): "embeddedFiles": {"embeddedFiles", "__self__"}, } - @validator("embeddedFiles") + @field_validator("embeddedFiles") + @classmethod def _unique_names(cls, v: Optional[EmbeddedFiles]) -> Optional[EmbeddedFiles]: if v is not None: return validate_unique_elements(v, item_value=lambda v: v.name, property="name") @@ -462,18 +459,16 @@ class RangeString(FormatString): _min_length = 1 -if TYPE_CHECKING: - # Note: Ordering within the Unions is important. Pydantic will try to match in - # the order given. - IntRangeList = list[Union[int, TaskParameterStringValue]] - FloatRangeList = list[Union[Decimal, TaskParameterStringValue]] - StringRangeList = list[TaskParameterStringValue] - TaskParameterStringValueAsJob = str -else: - IntRangeList = conlist(Union[int, TaskParameterStringValue], min_items=1, max_items=1024) - FloatRangeList = conlist(Union[Decimal, TaskParameterStringValue], min_items=1, max_items=1024) - StringRangeList = conlist(TaskParameterStringValue, min_items=1, max_items=1024) - TaskParameterStringValueAsJob = constr(min_length=0, max_length=1024) +# Note: Ordering within the Unions is important. Pydantic will try to match in +# the order given. 
+IntRangeList = Annotated[ + list[Union[int, TaskParameterStringValue]], Field(min_length=1, max_length=1024) +] +FloatRangeList = Annotated[ + list[Union[Decimal, TaskParameterStringValue]], Field(min_length=1, max_length=1024) +] +StringRangeList = Annotated[list[TaskParameterStringValue], Field(min_length=1, max_length=1024)] +TaskParameterStringValueAsJob = Annotated[str, StringConstraints(min_length=0, max_length=1024)] TaskRangeList = list[TaskParameterStringValueAsJob] TaskRangeExpression = RangeString @@ -483,15 +478,28 @@ class RangeString(FormatString): class RangeListTaskParameterDefinition(OpenJDModel_v2023_09): # element type of items in the range type: TaskParameterType + # NOTE: Pydantic V1 was allowing non-string values in this range, V2 is enforcing that type. range: TaskRangeList + @field_validator("range", mode="before") + @classmethod + def _coerce_to_string(cls, value: Any) -> Any: + # Coerce any int, float, or Decimal values into str + def coerce(v: Any) -> Any: + if isinstance(v, (int, float, Decimal)): + return str(v) + return v + + return [coerce(item) for item in value] + class RangeExpressionTaskParameterDefinition(OpenJDModel_v2023_09): # element type of items in the range type: TaskParameterType range: TaskRangeExpression - @validator("range") + @field_validator("range") + @classmethod def _validate_range_expression(cls, value: Any) -> Any: """At this point, the format expressions have been resolved and we can determine if it's a valid RangeExpression""" @@ -538,48 +546,66 @@ def _get_range_task_param_type(model: Any) -> Type[OpenJDModel]: exclude_fields={"name"}, ) - @validator("range", pre=True) + @field_validator("range", mode="before") + @classmethod def _validate_range_element_type(cls, value: Any) -> Any: - # pydandic will automatically type coerse values into integers. We explicitly + # pydantic will automatically type coerse values into integers. 
We explicitly # want to reject non-integer values, so this *pre* validator validates the # value *before* pydantic tries to type coerse it. # We do allow coersion from a string since we want to allow "1", and # "1.2" or "a" will fail the type coersion if isinstance(value, list): - errors = list[ErrorWrapper]() - for v in value: - if isinstance(v, bool) or not isinstance(v, (int, str)): + errors = list[InitErrorDetails]() + for i, item in enumerate(value): + if isinstance(item, bool) or not isinstance(item, (int, str)): errors.append( - ErrorWrapper( - ValueError("Value must be an integer or integer string."), ("range", v) + InitErrorDetails( + type="value_error", + loc=(i,), + ctx={ + "error": ValueError("Value must be an integer or integer string.") + }, + input=item, ) ) if errors: - raise ValidationError(errors, IntTaskParameterDefinition) + raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) elif isinstance(value, RangeString): # TODO: nothing to do - it's guaranteed to be a format string at this point pass return value - @validator("range") + @field_validator("range") + @classmethod def _validate_range_elements(cls, value: Any) -> Any: if isinstance(value, list): - errors = list[ErrorWrapper]() - for v in value: - if isinstance(v, TaskParameterStringValue): + errors = list[InitErrorDetails]() + for i, item in enumerate(value): + if isinstance(item, TaskParameterStringValue): # A TaskParameterStringValue is a FormatString. # FormatString.expressions is the list of all expressions in the format string # ( e.g. "{{ Param.Foo }}"). - # Reject the string if it contains any expressions. - if len(v.expressions) == 0: - errors.append( - ErrorWrapper( - ValueError("String literal must contain an integer."), ("range", v) + # Validate the string if it does not contain any expressions, in order to catch + # errors earlier when possible. 
+ if len(item.expressions) == 0: + try: + int(item) + except ValueError: + errors.append( + InitErrorDetails( + type="value_error", + loc=(i,), + ctx={ + "error": ValueError( + "String literal must contain an integer." + ) + }, + input=item, + ) ) - ) if errors: - raise ValidationError(errors, IntTaskParameterDefinition) + raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) else: # If there are no format expressions, we can validate the range expression. # otherwise we defer to the RangeExressionTaskParameter model when @@ -619,29 +645,58 @@ class FloatTaskParameterDefinition(OpenJDModel_v2023_09): exclude_fields={"name"}, ) - @validator("range", each_item=True, pre=True) - def _validate_range_element_type(cls, v: Any) -> Any: - # pydandic will automatically type coerse values into floats. We explicitly + @field_validator("range", mode="before") + @classmethod + def _validate_range_element_type(cls, value: Any) -> Any: + # pydantic will automatically type coerce values into floats. We explicitly # want to reject non-integer values, so this *pre* validator validates the - # value *before* pydantic tries to type coerse it. + # value *before* pydantic tries to type coerce it. 
# We do allow coersion from a string since we want to allow "1", and # "1.2" or "a" will fail the type coersion - if isinstance(v, bool) or not isinstance(v, (int, float, str)): - raise ValueError("Item must be a float, int, or float string.") - return v + if isinstance(value, list): + errors = list[InitErrorDetails]() + for i, item in enumerate(value): + if isinstance(item, bool) or not isinstance(item, (int, float, str)): + errors.append( + InitErrorDetails( + type="value_error", + loc=("range", i), + ctx={"error": ValueError("Value must be a float or float string.")}, + input=item, + ) + ) + if errors: + raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) + return value - @validator("range", each_item=True) + @field_validator("range") + @classmethod def _validate_range_elements( - cls, v: Union[Decimal, TaskParameterStringValue] - ) -> Union[Decimal, TaskParameterStringValue]: - if isinstance(v, TaskParameterStringValue): - # A TaskParameterStringValue is a FormatString. - # FormatString.expressions is the list of all expressions in the format string - # ( e.g. "{{ Param.Foo }}"). - # Reject the string if it contains any expressions. - if len(v.expressions) == 0: - raise ValueError("String literal must contain an integer or float.") - return v + cls, value: list[Union[Decimal, TaskParameterStringValue]] + ) -> list[Union[Decimal, TaskParameterStringValue]]: + errors = list[InitErrorDetails]() + for i, item in enumerate(value): + if isinstance(item, TaskParameterStringValue): + # A TaskParameterStringValue is a FormatString. + # FormatString.expressions is the list of all expressions in the format string + # ( e.g. "{{ Param.Foo }}"). + # Validate the string if it does not contain any expressions, in order to catch + # errors earlier when possible. 
+ if len(item.expressions) == 0: + try: + float(item) + except ValueError: + errors.append( + InitErrorDetails( + type="value_error", + loc=(i,), + ctx={"error": ValueError("String literal must contain a float.")}, + input=item, + ) + ) + if errors: + raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) + return value class StringTaskParameterDefinition(OpenJDModel_v2023_09): @@ -707,20 +762,21 @@ class PathTaskParameterDefinition(OpenJDModel_v2023_09): PathTaskParameterDefinition, ] -if TYPE_CHECKING: - TaskParameterList = list[TaskParameterDefinition] - CombinationExpr = str -else: - TaskParameterList = conlist( - Annotated[TaskParameterDefinition, Field(..., discriminator="type")], - min_items=1, - max_items=16, - ) - # Limit the CombinationExpr to characters allowed in an Identifier plus whitespace - # and the operator characters. - CombinationExpr = constr( - min_length=1, max_length=1280, strict=True, regex=r"(?-m:^[A-Za-z0-9\*\(\), ]+\Z)" - ) +TaskParameterList = Annotated[ + list[Annotated[TaskParameterDefinition, Field(..., discriminator="type")]], + Field( + min_length=1, + max_length=16, + ), +] +# Limit the CombinationExpr to characters allowed in an Identifier plus whitespace +# and the operator characters. 
+CombinationExpr = Annotated[ + str, + StringConstraints( + min_length=1, max_length=1280, strict=True, pattern=r"(?-m:^[A-Za-z0-9\*\(\), ]+\z)" + ), +] TaskRangeParameter = Union[RangeListTaskParameterDefinition, RangeExpressionTaskParameterDefinition] @@ -733,11 +789,14 @@ class StepParameterSpace(OpenJDModel_v2023_09): taskParameterDefinitions: dict[Identifier, TaskRangeParameter] combination: Optional[CombinationExpr] = None - @validator("combination") - def _validate_parameter_space(cls, v: str, values: dict[str, Any]) -> str: + @field_validator("combination") + @classmethod + def _validate_parameter_space(cls, v: str, info: ValidationInfo) -> str: if v is None: return v - param_defs = cast(dict[Identifier, TaskRangeParameter], values["taskParameterDefinitions"]) + param_defs = cast( + dict[Identifier, TaskRangeParameter], info.data["taskParameterDefinitions"] + ) parameter_range_lengths = { id: ( len(param.range) @@ -772,22 +831,19 @@ class StepParameterSpaceDefinition(OpenJDModel_v2023_09): reshape_field_to_dict={"taskParameterDefinitions": "name"}, ) - @validator("taskParameterDefinitions") + @field_validator("taskParameterDefinitions") + @classmethod def _validate_parameters(cls, v: TaskParameterList) -> TaskParameterList: # Must have a unique name for each Task parameter return validate_unique_elements(v, item_value=lambda v: v.name, property="name") - @root_validator - def _validate_combination(cls, values: dict[str, Any]) -> dict[str, Any]: - if values.get("combination") is None: - return values - if values.get("taskParameterDefinitions") is None: - return values + @model_validator(mode="after") + def _validate_combination(self) -> Self: + if self.combination is None or self.taskParameterDefinitions is None: + return self - parameter_list: TaskParameterList = cast( - TaskParameterList, values["taskParameterDefinitions"] - ) - combination: CombinationExpr = cast(CombinationExpr, values["combination"]) + parameter_def_list: TaskParameterList = 
self.taskParameterDefinitions + combination: CombinationExpr = self.combination # Ensure that the 'combination' string: # a) is a properly formed combination expression; and @@ -801,26 +857,36 @@ def _validate_combination(cls, values: dict[str, Any]) -> dict[str, Any]: expr_identifiers = list[str]() parse_tree.collect_identifiers(expr_identifiers) unique_expr_identifiers = set(expr_identifiers) - parameter_names = [param.name for param in parameter_list] + parameter_names = [param.name for param in parameter_def_list] unique_parameter_names = set(parameter_names) - errors = list[ErrorWrapper]() + errors = list[InitErrorDetails]() if len(unique_expr_identifiers) < len(unique_parameter_names): # Missing some parameter identifiers in the expression missing = sorted(list(unique_parameter_names - unique_expr_identifiers)) errors.append( - ErrorWrapper( - ValueError(f"Expression missing parameters: {','.join(missing)}"), - ("combination",), + InitErrorDetails( + type="value_error", + loc=("combination",), + ctx={ + "error": ValueError(f"Expression missing parameters: {','.join(missing)}") + }, + input=combination, ) ) if len(unique_parameter_names) < len(unique_expr_identifiers): # Have some extra parameters referenced in the expression extra = sorted(list(unique_expr_identifiers - unique_parameter_names)) errors.append( - ErrorWrapper( - ValueError(f"Expression references undefined parameters: {','.join(extra)}"), - ("combination",), + InitErrorDetails( + type="value_error", + loc=("combination",), + ctx={ + "error": ValueError( + f"Expression references undefined parameters: {','.join(extra)}" + ) + }, + input=combination, ) ) if len(expr_identifiers) != len(unique_expr_identifiers): @@ -829,30 +895,32 @@ def _validate_combination(cls, values: dict[str, Any]) -> dict[str, Any]: [id for id in expr_identifiers if id not in unique_expr_identifiers] ) errors.append( - ErrorWrapper( - ValueError( - f"Expression can only reference each parameter once: {','.join(duplicates)} 
" - ), - ("combination",), + InitErrorDetails( + type="value_error", + loc=("combination",), + ctx={ + "error": ValueError( + f"Expression can only reference each parameter once: {','.join(duplicates)} " + ) + }, + input=combination, ) ) if errors: - raise ValidationError(errors, StepParameterSpaceDefinition) + raise ValidationError.from_exception_data(self.__class__.__name__, errors) - return values + return self # ================================================================== # ====================== Environments Variables ==================== # ================================================================== -if TYPE_CHECKING: - EnvironmentVariableNameString = str -else: - EnvironmentVariableNameString = constr( - min_length=1, max_length=256, regex=r"(?-m:^[a-zA-Z_][a-zA-Z0-9_]*\Z)" - ) +EnvironmentVariableNameString = Annotated[ + str, + StringConstraints(min_length=1, max_length=256, pattern=r"(?-m:^[a-zA-Z_][a-zA-Z0-9_]*\z)"), +] class EnvironmentVariableValueString(FormatString): @@ -889,13 +957,15 @@ class Environment(OpenJDModel_v2023_09): _template_variable_scope = ResolutionScope.SESSION - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def _validate_has_script_or_variables(cls, values: dict[str, Any]) -> dict[str, Any]: if values.get("script") is None and values.get("variables") is None: raise ValueError("Environment must have either a script or variables.") return values - @validator("variables") + @field_validator("variables") + @classmethod def _validate_variables( cls, variables: Optional[EnvironmentVariableObject] ) -> Optional[EnvironmentVariableObject]: @@ -918,26 +988,22 @@ class JobParameterType(str, Enum): FLOAT = "FLOAT" -if TYPE_CHECKING: - AllowedParameterStringValueList = list[ParameterStringValue] - AllowedIntParameterList = list[int] - AllowedFloatParameterList = list[Decimal] - UserInterfaceLabelStringValue = str - FileDialogFilterPatternStringValue = str - FileDialogFilterPatternStringValueList = 
list[FileDialogFilterPatternStringValue] -else: - AllowedParameterStringValueList = conlist(ParameterStringValue, min_items=1) - AllowedIntParameterList = conlist(int, min_items=1) - AllowedFloatParameterList = conlist(Decimal, min_items=1) - UserInterfaceLabelStringValue = constr( - min_length=1, max_length=64, strict=True, regex=_standard_string_regex - ) - FileDialogFilterPatternStringValue = constr( - min_length=1, max_length=20, strict=True, regex=_file_dialog_filter_pattern_regex - ) - FileDialogFilterPatternStringValueList = conlist( - FileDialogFilterPatternStringValue, min_items=1, max_items=20 - ) +AllowedParameterStringValueList = Annotated[list[ParameterStringValue], Field(min_length=1)] +AllowedIntParameterList = Annotated[list[int], Field(min_length=1)] +AllowedFloatParameterList = Annotated[list[Decimal], Field(min_length=1)] +UserInterfaceLabelStringValue = Annotated[ + str, + StringConstraints(min_length=1, max_length=64, strict=True, pattern=_standard_string_regex), +] +FileDialogFilterPatternStringValue = Annotated[ + str, + StringConstraints( + min_length=1, max_length=20, strict=True, pattern=_file_dialog_filter_pattern_regex + ), +] +FileDialogFilterPatternStringValueList = Annotated[ + list[FileDialogFilterPatternStringValue], Field(min_length=1, max_length=20) +] # Target model for a job parameter when instantiating a job. 
@@ -972,8 +1038,8 @@ class JobStringParameterDefinitionUserInterface(OpenJDModel_v2023_09): """ control: StringUserInterfaceControl - label: Optional[UserInterfaceLabelStringValue] - groupLabel: Optional[UserInterfaceLabelStringValue] + label: Optional[UserInterfaceLabelStringValue] = None + groupLabel: Optional[UserInterfaceLabelStringValue] = None class JobStringParameterDefinition(OpenJDModel_v2023_09, JobParameterInterface): @@ -996,7 +1062,7 @@ class JobStringParameterDefinition(OpenJDModel_v2023_09, JobParameterInterface): name: Identifier type: Literal[JobParameterType.STRING] - userInterface: Optional[JobStringParameterDefinitionUserInterface] + userInterface: Optional[JobStringParameterDefinitionUserInterface] = None description: Optional[Description] = None # Note: Ordering of the following fields is essential for the validators to work correctly. minLength: Optional[StrictInt] = None # noqa: N815 @@ -1027,66 +1093,88 @@ class JobStringParameterDefinition(OpenJDModel_v2023_09, JobParameterInterface): }, ) - @validator("minLength") - def _validate_min_length(cls, v: Optional[int]) -> Optional[int]: - if v is None: - return v - if v <= 0: + @field_validator("minLength") + @classmethod + def _validate_min_length(cls, value: Optional[int]) -> Optional[int]: + if value is None: + return value + if value <= 0: raise ValueError("Required: 0 < minLength.") - return v + return value - @validator("maxLength") - def _validate_max_length(cls, v: Optional[int], values: dict[str, Any]) -> Optional[int]: - if v is None: - return v - if v <= 0: + @field_validator("maxLength") + @classmethod + def _validate_max_length(cls, value: Optional[int], info: ValidationInfo) -> Optional[int]: + if value is None: + return value + if value <= 0: raise ValueError("Required: 0 < maxLength.") - min_length = values.get("minLength") + min_length = info.data.get("minLength") if min_length is None: - return v - if min_length > v: + return value + if min_length > value: raise 
ValueError("Required: minLength <= maxLength.") - return v + return value - @validator("allowedValues", each_item=True) + @field_validator("allowedValues") + @classmethod def _validate_allowed_values_item( - cls, v: ParameterStringValue, values: dict[str, Any] - ) -> ParameterStringValue: - min_length = values.get("minLength") - if min_length is not None: - if len(v) < min_length: - raise ValueError("Value is shorter than minLength.") - max_length = values.get("maxLength") - if max_length is not None: - if len(v) > max_length: - raise ValueError("Value is longer than maxLength.") - return v + cls, value: AllowedParameterStringValueList, info: ValidationInfo + ) -> AllowedParameterStringValueList: + min_length = info.data.get("minLength") + max_length = info.data.get("maxLength") + errors = list[InitErrorDetails]() + for i, item in enumerate(value): + if min_length is not None: + if len(item) < min_length: + errors.append( + InitErrorDetails( + type="value_error", + loc=(i,), + ctx={"error": ValueError("Value is shorter than minLength.")}, + input=item, + ) + ) + if max_length is not None: + if len(item) > max_length: + errors.append( + InitErrorDetails( + type="value_error", + loc=(i,), + ctx={"error": ValueError("Value is longer than maxLength.")}, + input=item, + ) + ) + if errors: + raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) + return value - @validator("default") + @field_validator("default") + @classmethod def _validate_default( - cls, v: ParameterStringValue, values: dict[str, Any] + cls, value: ParameterStringValue, info: ValidationInfo ) -> ParameterStringValue: - min_length = values.get("minLength") + min_length = info.data.get("minLength") if min_length is not None: - if len(v) < min_length: + if len(value) < min_length: raise ValueError("Value is shorter than minLength.") - max_length = values.get("maxLength") + max_length = info.data.get("maxLength") if max_length is not None: - if len(v) > max_length: + if len(value) > 
max_length: raise ValueError("Value is longer than maxLength.") - allowed_values = values.get("allowedValues") + allowed_values = info.data.get("allowedValues") if allowed_values is not None: - if v not in allowed_values: + if value not in allowed_values: raise ValueError("Must be an allowed value.") - return v + return value - @root_validator - def _validate_user_interface_compatibility(cls, values: dict[str, Any]) -> dict[str, Any]: + @model_validator(mode="after") + def _validate_user_interface_compatibility(self) -> Self: # validate that the user interface control is compatible with the value constraints - if values.get("userInterface"): - user_interface_control = values["userInterface"].control - if values.get("allowedValues") and user_interface_control in ( + if self.userInterface: + user_interface_control = self.userInterface.control + if self.allowedValues and user_interface_control in ( StringUserInterfaceControl.LINE_EDIT, StringUserInterfaceControl.MULTILINE_EDIT, ): @@ -1094,20 +1182,20 @@ def _validate_user_interface_compatibility(cls, values: dict[str, Any]) -> dict[ f"User interface control {user_interface_control.name} cannot be used when 'allowedValues' is provided" ) if ( - not values.get("allowedValues") + not self.allowedValues and user_interface_control == StringUserInterfaceControl.DROPDOWN_LIST ): raise ValueError( f"User interface control {user_interface_control.name} requires that 'allowedValues' be provided" ) if user_interface_control == StringUserInterfaceControl.CHECK_BOX: - allowed_values = set(v.upper() for v in values.get("allowedValues", [])) + allowed_values = set(v.upper() for v in self.allowedValues or []) if allowed_values not in ALLOWED_VALUES_FOR_CHECK_BOX: raise ValueError( f"User interface control {user_interface_control.name} requires that 'allowedValues' be " + f"one of {ALLOWED_VALUES_FOR_CHECK_BOX} (case and order insensitive)" ) - return values + return self # override def _check_constraints(self, value: Any) -> None: 
@@ -1160,12 +1248,9 @@ class JobPathParameterDefinitionFileFilter(OpenJDModel_v2023_09): patterns: FileDialogFilterPatternStringValueList -if TYPE_CHECKING: - JobPathParameterDefinitionFileFilterList = list[JobPathParameterDefinitionFileFilter] -else: - JobPathParameterDefinitionFileFilterList = conlist( - JobPathParameterDefinitionFileFilter, min_items=1, max_items=20 - ) +JobPathParameterDefinitionFileFilterList = Annotated[ + list[JobPathParameterDefinitionFileFilter], Field(min_length=1, max_length=20) +] class JobPathParameterDefinitionUserInterface(OpenJDModel_v2023_09): @@ -1185,10 +1270,10 @@ class JobPathParameterDefinitionUserInterface(OpenJDModel_v2023_09): """ control: PathUserInterfaceControl - label: Optional[UserInterfaceLabelStringValue] - groupLabel: Optional[UserInterfaceLabelStringValue] - fileFilters: Optional[JobPathParameterDefinitionFileFilterList] - fileFilterDefault: Optional[JobPathParameterDefinitionFileFilter] + label: Optional[UserInterfaceLabelStringValue] = None + groupLabel: Optional[UserInterfaceLabelStringValue] = None + fileFilters: Optional[JobPathParameterDefinitionFileFilterList] = None + fileFilterDefault: Optional[JobPathParameterDefinitionFileFilter] = None class JobPathParameterDefinition(OpenJDModel_v2023_09, JobParameterInterface): @@ -1215,9 +1300,9 @@ class JobPathParameterDefinition(OpenJDModel_v2023_09, JobParameterInterface): name: Identifier type: Literal[JobParameterType.PATH] - objectType: Optional[JobPathParameterDefinitionObjectType] - dataFlow: Optional[JobPathParameterDefinitionDataFlow] - userInterface: Optional[JobPathParameterDefinitionUserInterface] + objectType: Optional[JobPathParameterDefinitionObjectType] = None + dataFlow: Optional[JobPathParameterDefinitionDataFlow] = None + userInterface: Optional[JobPathParameterDefinitionUserInterface] = None description: Optional[Description] = None # Note: Ordering of the following fields is essential for the validators to work correctly. 
minLength: Optional[StrictInt] = None # noqa: N815 @@ -1250,66 +1335,88 @@ class JobPathParameterDefinition(OpenJDModel_v2023_09, JobParameterInterface): }, ) - @validator("minLength") - def _validate_min_length(cls, v: Optional[int]) -> Optional[int]: - if v is None: - return v - if v <= 0: + @field_validator("minLength") + @classmethod + def _validate_min_length(cls, value: Optional[int]) -> Optional[int]: + if value is None: + return value + if value <= 0: raise ValueError("Required: 0 < minLength.") - return v + return value - @validator("maxLength") - def _validate_max_length(cls, v: Optional[int], values: dict[str, Any]) -> Optional[int]: - if v is None: - return v - if v <= 0: + @field_validator("maxLength") + @classmethod + def _validate_max_length(cls, value: Optional[int], info: ValidationInfo) -> Optional[int]: + if value is None: + return value + if value <= 0: raise ValueError("Required: 0 < maxLength.") - min_length = values.get("minLength") + min_length = info.data.get("minLength") if min_length is None: - return v - if min_length > v: + return value + if min_length > value: raise ValueError("Required: minLength <= maxLength.") - return v + return value - @validator("allowedValues", each_item=True) + @field_validator("allowedValues") + @classmethod def _validate_allowed_values_item( - cls, v: ParameterStringValue, values: dict[str, Any] + cls, value: ParameterStringValue, info: ValidationInfo ) -> ParameterStringValue: - min_length = values.get("minLength") - if min_length is not None: - if len(v) < min_length: - raise ValueError("Value is shorter than minLength.") - max_length = values.get("maxLength") - if max_length is not None: - if len(v) > max_length: - raise ValueError("Value is longer than maxLength.") - return v + min_length = info.data.get("minLength") + max_length = info.data.get("maxLength") + errors = list[InitErrorDetails]() + for i, item in enumerate(value): + if min_length is not None: + if len(item) < min_length: + errors.append( + 
InitErrorDetails( + type="value_error", + loc=("allowedValues", i), + ctx={"error": ValueError("Value is shorter than minLength.")}, + input=item, + ) + ) + if max_length is not None: + if len(item) > max_length: + errors.append( + InitErrorDetails( + type="value_error", + loc=("allowedValues", i), + ctx={"error": ValueError("Value is longer than maxLength.")}, + input=item, + ) + ) + if errors: + raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) + return value - @validator("default") + @field_validator("default") + @classmethod def _validate_default( - cls, v: ParameterStringValue, values: dict[str, Any] + cls, value: ParameterStringValue, info: ValidationInfo ) -> ParameterStringValue: - min_length = values.get("minLength") + min_length = info.data.get("minLength") if min_length is not None: - if len(v) < min_length: + if len(value) < min_length: raise ValueError("Value is shorter than minLength.") - max_length = values.get("maxLength") + max_length = info.data.get("maxLength") if max_length is not None: - if len(v) > max_length: + if len(value) > max_length: raise ValueError("Value is longer than maxLength.") - allowed_values = values.get("allowedValues") + allowed_values = info.data.get("allowedValues") if allowed_values is not None: - if v not in allowed_values: + if value not in allowed_values: raise ValueError("Must be an allowed value.") - return v + return value - @root_validator - def _validate_user_interface_compatibility(cls, values: dict[str, Any]) -> dict[str, Any]: + @model_validator(mode="after") + def _validate_user_interface_compatibility(self) -> Self: # validate that the user interface control is compatible with the value constraints - if values.get("userInterface"): - user_interface_control = values["userInterface"].control - if values.get("allowedValues") and user_interface_control in ( + if self.userInterface: + user_interface_control = self.userInterface.control + if self.allowedValues and user_interface_control in 
( PathUserInterfaceControl.CHOOSE_INPUT_FILE, PathUserInterfaceControl.CHOOSE_OUTPUT_FILE, PathUserInterfaceControl.CHOOSE_DIRECTORY, @@ -1318,14 +1425,14 @@ def _validate_user_interface_compatibility(cls, values: dict[str, Any]) -> dict[ f"User interface control {user_interface_control.name} cannot be used when 'allowedValues' is provided" ) if ( - not values.get("allowedValues") + not self.allowedValues and user_interface_control == PathUserInterfaceControl.DROPDOWN_LIST ): raise ValueError( f"User interface control {user_interface_control.name} requires that 'allowedValues' be provided" ) if ( - values["userInterface"].fileFilters or values["userInterface"].fileFilterDefault + self.userInterface.fileFilters or self.userInterface.fileFilterDefault ) and user_interface_control not in [ PathUserInterfaceControl.CHOOSE_INPUT_FILE, PathUserInterfaceControl.CHOOSE_OUTPUT_FILE, @@ -1335,23 +1442,25 @@ def _validate_user_interface_compatibility(cls, values: dict[str, Any]) -> dict[ + " or 'fileFilterDefault is provided" ) if ( - values.get("objectType") == JobPathParameterDefinitionObjectType.FILE + self.objectType == JobPathParameterDefinitionObjectType.FILE and user_interface_control == PathUserInterfaceControl.CHOOSE_DIRECTORY ): raise ValueError( f"User interface control {user_interface_control.name} cannot be used with 'objectType' of FILE" ) - if values.get( - "objectType" - ) == JobPathParameterDefinitionObjectType.DIRECTORY and user_interface_control in [ - PathUserInterfaceControl.CHOOSE_INPUT_FILE, - PathUserInterfaceControl.CHOOSE_OUTPUT_FILE, - ]: + if ( + self.objectType == JobPathParameterDefinitionObjectType.DIRECTORY + and user_interface_control + in [ + PathUserInterfaceControl.CHOOSE_INPUT_FILE, + PathUserInterfaceControl.CHOOSE_OUTPUT_FILE, + ] + ): raise ValueError( f"User interface control {user_interface_control.name} cannot be used with 'objectType' of DIRECTORY" ) - return values + return self # override def _check_constraints(self, value: Any) 
-> None: @@ -1391,9 +1500,9 @@ class JobIntParameterDefinitionUserInterface(OpenJDModel_v2023_09): """ control: IntUserInterfaceControl - label: Optional[UserInterfaceLabelStringValue] - groupLabel: Optional[UserInterfaceLabelStringValue] - singleStepDelta: Optional[PositiveInt] + label: Optional[UserInterfaceLabelStringValue] = None + groupLabel: Optional[UserInterfaceLabelStringValue] = None + singleStepDelta: Optional[PositiveInt] = None class JobIntParameterDefinition(OpenJDModel_v2023_09): @@ -1416,7 +1525,7 @@ class JobIntParameterDefinition(OpenJDModel_v2023_09): name: Identifier type: Literal[JobParameterType.INT] - userInterface: Optional[JobIntParameterDefinitionUserInterface] + userInterface: Optional[JobIntParameterDefinitionUserInterface] = None description: Optional[Description] = None # Note: Ordering of the following fields is essential for the validators to work correctly. minValue: Optional[int] = None # noqa: N815 @@ -1448,106 +1557,143 @@ class JobIntParameterDefinition(OpenJDModel_v2023_09): ) @classmethod - def _precheck_is_int_type(cls, v: Any) -> None: + def _precheck_is_int_type(cls, value: Any) -> None: # prevent floats, bools, and other types from coercing into an int. # strings that contain floats are handled by pydantic's checks. 
- if not isinstance(v, (int, str)) or isinstance(v, bool): + if not isinstance(value, (int, str)) or isinstance(value, bool): raise ValueError("Value must be an integer or integer string.")

-    @validator("minValue", pre=True)
-    def _validate_min_value_type(cls, v: Optional[Any]) -> Optional[Any]:
-        if v is None:
-            return v
-        cls._precheck_is_int_type(v)
-        return v
+    @field_validator("minValue", mode="before")
+    @classmethod
+    def _validate_min_value_type(cls, value: Optional[Any]) -> Optional[Any]:
+        if value is None:
+            return value
+        cls._precheck_is_int_type(value)
+        return value

-    @validator("maxValue", pre=True)
-    def _validate_max_value_type(cls, v: Optional[Any]) -> Optional[Any]:
-        if v is None:
-            return v
-        cls._precheck_is_int_type(v)
-        return v
+    @field_validator("maxValue", mode="before")
+    @classmethod
+    def _validate_max_value_type(cls, value: Optional[Any]) -> Optional[Any]:
+        if value is None:
+            return value
+        cls._precheck_is_int_type(value)
+        return value

-    @validator("allowedValues", each_item=True, pre=True)
-    def _validate_allowed_values_item_type(cls, v: Any) -> Optional[Any]:
-        cls._precheck_is_int_type(v)
-        return v
+    @field_validator("allowedValues", mode="before")
+    @classmethod
+    def _validate_allowed_values_item_type(cls, value: Any) -> Any:
+        errors = list[InitErrorDetails]()
+        for i, item in enumerate(value):
+            if isinstance(item, bool) or not isinstance(item, (int, str)):
+                try:
+                    cls._precheck_is_int_type(item)
+                except ValueError as e:
+                    errors.append(
+                        InitErrorDetails(
+                            type="value_error",
+                            loc=(i,),
+                            ctx={"error": e},
+                            input=item,
+                        )
+                    )
+        if errors:
+            raise ValidationError.from_exception_data(cls.__name__, line_errors=errors)
+        return value

-    @validator("default", pre=True)
-    def _validate_default_value_type(cls, v: Optional[Any]) -> Optional[Any]:
-        if v is None:
-            return v
-        cls._precheck_is_int_type(v)
-        return v
+    @field_validator("default", mode="before")
+    @classmethod
+    def _validate_default_value_type(cls, value: 
Optional[Any]) -> Optional[Any]:
+        if value is None:
+            return value
+        cls._precheck_is_int_type(value)
+        return value

-    @validator("maxValue")
-    def _validate_max_value(cls, v: Optional[int], values: dict[str, Any]) -> Optional[int]:
-        if v is None:
-            return v
-        min_value = values.get("minValue")
+    @field_validator("maxValue")
+    @classmethod
+    def _validate_max_value(cls, value: Optional[int], info: ValidationInfo) -> Optional[int]:
+        if value is None:
+            return value
+        min_value = info.data.get("minValue")
         if min_value is None:
-            return v
-        if min_value > v:
+            return value
+        if min_value > value:
             raise ValueError("Required: minValue <= maxValue.")
-        return v
+        return value

-    @validator("allowedValues", each_item=True)
-    def _validate_allowed_values_item(cls, v: int, values: dict[str, Any]) -> int:
-        min_value = values.get("minValue")
-        if min_value is not None:
-            if v < min_value:
-                raise ValueError("Value less than minValue.")
-        max_value = values.get("maxValue")
-        if max_value is not None:
-            if v > max_value:
-                raise ValueError("Value larger than maxValue.")
-        return v
+    @field_validator("allowedValues")
+    @classmethod
+    def _validate_allowed_values_item(cls, value: list[int], info: ValidationInfo) -> list[int]:
+        min_value = info.data.get("minValue")
+        max_value = info.data.get("maxValue")
+        errors = list[InitErrorDetails]()
+        for i, item in enumerate(value):
+            if min_value is not None:
+                if item < min_value:
+                    errors.append(
+                        InitErrorDetails(
+                            type="value_error",
+                            loc=(i,),
+                            ctx={"error": ValueError("Value less than minValue.")},
+                            input=item,
+                        )
+                    )
+            if max_value is not None:
+                if item > max_value:
+                    errors.append(
+                        InitErrorDetails(
+                            type="value_error",
+                            loc=(i,),
+                            ctx={"error": ValueError("Value larger than maxValue.")},
+                            input=item,
+                        )
+                    )
+        if errors:
+            raise ValidationError.from_exception_data(cls.__name__, line_errors=errors)
+        return value

-    @validator("default")
-    def _validate_default(cls, v: int, values: dict[str, Any]) -> int:
-        
min_value = values.get("minValue") + @field_validator("default") + @classmethod + def _validate_default(cls, value: int, info: ValidationInfo) -> int: + min_value = info.data.get("minValue") if min_value is not None: - if v < min_value: + if value < min_value: raise ValueError("Value less than minValue.") - max_value = values.get("maxValue") + max_value = info.data.get("maxValue") if max_value is not None: - if v > max_value: + if value > max_value: raise ValueError("Value larger than maxValue.") - allowed_values = values.get("allowedValues") + allowed_values = info.data.get("allowedValues") if allowed_values is not None: - if v not in allowed_values: + if value not in allowed_values: raise ValueError("Must be an allowed value.") - return v + return value - @root_validator - def _validate_user_interface_compatibility(cls, values: dict[str, Any]) -> dict[str, Any]: + @model_validator(mode="after") + def _validate_user_interface_compatibility(self) -> Self: # validate that the user interface control is compatible with the value constraints - if values.get("userInterface"): - user_interface_control = values["userInterface"].control - if ( - values.get("allowedValues") - and user_interface_control == IntUserInterfaceControl.SPIN_BOX - ): + if self.userInterface: + user_interface_control = self.userInterface.control + if self.allowedValues and user_interface_control == IntUserInterfaceControl.SPIN_BOX: raise ValueError( f"User interface control {user_interface_control.name} cannot be used when 'allowedValues' is provided" ) if ( - not values.get("allowedValues") + not self.allowedValues and user_interface_control == IntUserInterfaceControl.DROPDOWN_LIST ): raise ValueError( f"User interface control {user_interface_control.name} requires that 'allowedValues' be provided" ) if ( - values["userInterface"].singleStepDelta + self.userInterface.singleStepDelta and user_interface_control != IntUserInterfaceControl.SPIN_BOX ): raise ValueError( f"User interface control 
{user_interface_control.name} cannot be used when 'singleStepDelta' is provided" ) - return values + return self # override def _check_constraints(self, value: Any) -> None: @@ -1599,10 +1745,10 @@ class JobFloatParameterDefinitionUserInterface(OpenJDModel_v2023_09): """ control: FloatUserInterfaceControl - label: Optional[UserInterfaceLabelStringValue] - groupLabel: Optional[UserInterfaceLabelStringValue] - decimals: Optional[PositiveInt] - singleStepDelta: Optional[PositiveFloat] + label: Optional[UserInterfaceLabelStringValue] = None + groupLabel: Optional[UserInterfaceLabelStringValue] = None + decimals: Optional[PositiveInt] = None + singleStepDelta: Optional[PositiveFloat] = None class JobFloatParameterDefinition(OpenJDModel_v2023_09): @@ -1625,7 +1771,7 @@ class JobFloatParameterDefinition(OpenJDModel_v2023_09): name: Identifier type: Literal[JobParameterType.FLOAT] - userInterface: Optional[JobFloatParameterDefinitionUserInterface] + userInterface: Optional[JobFloatParameterDefinitionUserInterface] = None description: Optional[Description] = None # Note: Ordering of the following fields is essential for the validators to work correctly. 
minValue: Optional[Decimal] = None # noqa: N815 @@ -1656,74 +1802,96 @@ class JobFloatParameterDefinition(OpenJDModel_v2023_09): }, ) - @validator("maxValue") - def _validate_max_value(cls, v: Optional[Decimal], values: dict[str, Any]) -> Optional[Decimal]: - if v is None: - return v - min_value = values.get("minValue") + @field_validator("maxValue") + @classmethod + def _validate_max_value( + cls, value: Optional[Decimal], info: ValidationInfo + ) -> Optional[Decimal]: + if value is None: + return value + min_value = info.data.get("minValue") if min_value is None: - return v - if min_value > v: + return value + if min_value > value: raise ValueError("Required: minValue <= maxValue.") - return v + return value - @validator("allowedValues", each_item=True) - def _validate_allowed_values_item(cls, v: Decimal, values: dict[str, Any]) -> Decimal: - min_value = values.get("minValue") - if min_value is not None: - if v < min_value: - raise ValueError("Value less than minValue.") - max_value = values.get("maxValue") - if max_value is not None: - if v > max_value: - raise ValueError("Value larger than maxValue.") - return v + @field_validator("allowedValues") + @classmethod + def _validate_allowed_values_item( + cls, value: list[Decimal], info: ValidationInfo + ) -> list[Decimal]: + min_value = info.data.get("minValue") + max_value = info.data.get("maxValue") + errors = list[InitErrorDetails]() + for i, item in enumerate(value): + if min_value is not None: + if item < min_value: + errors.append( + InitErrorDetails( + type="value_error", + loc=(i,), + ctx={"error": ValueError("Value less than minValue.")}, + input=item, + ) + ) + if max_value is not None: + if item > max_value: + errors.append( + InitErrorDetails( + type="value_error", + loc=(i,), + ctx={"error": ValueError("Value larger than maxValue.")}, + input=item, + ) + ) + if errors: + raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) + return value - @validator("default") - def 
_validate_default(cls, v: Decimal, values: dict[str, Any]) -> Decimal: - min_value = values.get("minValue") + @field_validator("default") + @classmethod + def _validate_default(cls, value: Decimal, info: ValidationInfo) -> Decimal: + min_value = info.data.get("minValue") if min_value is not None: - if v < min_value: + if value < min_value: raise ValueError("Value less than minValue.") - max_value = values.get("maxValue") + max_value = info.data.get("maxValue") if max_value is not None: - if v > max_value: + if value > max_value: raise ValueError("Value larger than maxValue.") - allowed_values = values.get("allowedValues") + allowed_values = info.data.get("allowedValues") if allowed_values is not None: - if v not in allowed_values: + if value not in allowed_values: raise ValueError("Must be an allowed value.") - return v + return value - @root_validator - def _validate_user_interface_compatibility(cls, values: dict[str, Any]) -> dict[str, Any]: + @model_validator(mode="after") + def _validate_user_interface_compatibility(self) -> Self: # validate that the user interface control is compatible with the value constraints - if values.get("userInterface"): - user_interface_control = values["userInterface"].control - if ( - values.get("allowedValues") - and user_interface_control == FloatUserInterfaceControl.SPIN_BOX - ): + if self.userInterface: + user_interface_control = self.userInterface.control + if self.allowedValues and user_interface_control == FloatUserInterfaceControl.SPIN_BOX: raise ValueError( f"User interface control {user_interface_control.name} cannot be used when 'allowedValues' is provided" ) if ( - not values.get("allowedValues") + not self.allowedValues and user_interface_control == FloatUserInterfaceControl.DROPDOWN_LIST ): raise ValueError( f"User interface control {user_interface_control.name} requires that 'allowedValues' be provided" ) if ( - values["userInterface"].singleStepDelta + self.userInterface.singleStepDelta and user_interface_control != 
FloatUserInterfaceControl.SPIN_BOX ): raise ValueError( f"User interface control {user_interface_control.name} cannot be used when 'singleStepDelta' is provided" ) - return values + return self # override def _check_constraints(self, value: Any) -> None: @@ -1787,10 +1955,9 @@ class AttributeCapabilityValue(FormatString): _min_length = 1 -if TYPE_CHECKING: - AttributeCapabilityList = list[AttributeCapabilityValue] -else: - AttributeCapabilityList = conlist(AttributeCapabilityValue, min_items=1, max_items=50) +AttributeCapabilityList = Annotated[ + list[AttributeCapabilityValue], Field(min_length=1, max_length=50) +] class AmountRequirement(OpenJDModel_v2023_09): @@ -1810,15 +1977,16 @@ class AmountRequirement(OpenJDModel_v2023_09): """ name: str - min: Optional[Decimal] - max: Optional[Decimal] + min: Optional[Decimal] = None + max: Optional[Decimal] = None - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def validate_concrete_model(cls, values: dict[str, Any]) -> dict[str, Any]: # Reuse the AmountRequirementTemplate validation. Because all the template # variables have been substituted, it will now run validation it couldn't # before. 
- AmountRequirementTemplate.parse_obj(values) + AmountRequirementTemplate.model_validate(values) return values @@ -1837,8 +2005,8 @@ class AmountRequirementTemplate(OpenJDModel_v2023_09): """ name: AmountCapabilityName - min: Optional[Decimal] - max: Optional[Decimal] + min: Optional[Decimal] = None + max: Optional[Decimal] = None _job_creation_metadata = JobCreationMetadata( create_as=JobCreateAsMetadata(model=AmountRequirement), @@ -1847,33 +2015,37 @@ class AmountRequirementTemplate(OpenJDModel_v2023_09): }, ) - @validator("name") + @field_validator("name") + @classmethod def _validate_name(cls, v: str) -> str: validate_amount_capability_name( capability_name=v, standard_capabilities=_STANDARD_AMOUNT_CAPABILITIES_NAMES ) return v - @validator("min") - def _validate_min(cls, v: Optional[Decimal], values: dict[str, Any]) -> Optional[Decimal]: + @field_validator("min") + @classmethod + def _validate_min(cls, v: Optional[Decimal]) -> Optional[Decimal]: if v is None: return v if v < 0: raise ValueError(f"Value {v} must be zero or greater") return v - @validator("max") - def _validate_max(cls, v: Optional[Decimal], values: dict[str, Any]) -> Optional[Decimal]: + @field_validator("max") + @classmethod + def _validate_max(cls, v: Optional[Decimal], info: ValidationInfo) -> Optional[Decimal]: if v is None: return v if v <= 0: raise ValueError("Value must be greater than 0") - v_min = values.get("min") + v_min = info.data.get("min") if v_min is not None and v_min > v: raise ValueError("Value for 'max' must be greater or equal to 'min'") return v - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def _validate_has_one_optional(cls, values: dict[str, Any]) -> dict[str, Any]: if not ("min" in values or "max" in values): raise ValueError("At least one of 'min' or 'max' must be defined.") @@ -1892,15 +2064,16 @@ class AttributeRequirement(OpenJDModel_v2023_09): """ name: str - anyOf: Optional[list[str]] - allOf: Optional[list[str]] + anyOf: 
Optional[list[str]] = None + allOf: Optional[list[str]] = None - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def validate_concrete_model(cls, values: dict[str, Any]) -> dict[str, Any]: # Reuse the AttributeRequirementTemplate validation. Because all the template # variables have been substituted, it will now run validation it couldn't # before. - AttributeRequirementTemplate.parse_obj(values) + AttributeRequirementTemplate.model_validate(values) return values @@ -1914,8 +2087,8 @@ class AttributeRequirementTemplate(OpenJDModel_v2023_09): """ name: AttributeCapabilityName - anyOf: Optional[AttributeCapabilityList] # noqa: N815 - allOf: Optional[AttributeCapabilityList] # noqa: N815 + anyOf: Optional[AttributeCapabilityList] = None # noqa: N815 + allOf: Optional[AttributeCapabilityList] = None # noqa: N815 _job_creation_metadata = JobCreationMetadata( create_as=JobCreateAsMetadata(model=AttributeRequirement), @@ -1927,7 +2100,8 @@ class AttributeRequirementTemplate(OpenJDModel_v2023_09): ) _attribute_capability_value_max_length: int = 100 - @validator("name") + @field_validator("name") + @classmethod def _validate_name(cls, v: str) -> str: validate_attribute_capability_name( capability_name=v, standard_capabilities=_STANDARD_ATTRIBUTE_CAPABILITIES_NAMES @@ -1936,10 +2110,10 @@ def _validate_name(cls, v: str) -> str: @classmethod def _validate_attribute_list( - cls, v: AttributeCapabilityList, values: dict[str, Any], is_allof: bool + cls, v: AttributeCapabilityList, info: ValidationInfo, is_allof: bool ) -> None: try: - capability_name = values["name"].lower() + capability_name = info.data["name"].lower() except KeyError: # Just return as though there is no error. 
The missing name field # will be reported by the validation of 'name' @@ -1967,30 +2141,36 @@ def _validate_attribute_list( continue if not cls._attribute_capability_value_regex.match(item): raise ValueError(f"Value {item} is not a valid attribute capability value.") - if len(item) > cls._attribute_capability_value_max_length: + attribute_capability_value_max_length = cast( + ModelPrivateAttr, cls._attribute_capability_value_max_length + ).get_default() + if len(item) > attribute_capability_value_max_length: raise ValueError( - f"Value {item} exceeds {cls._attribute_capability_value_max_length} character length limit." + f"Value {item} exceeds {attribute_capability_value_max_length} character length limit." ) - @validator("allOf") + @field_validator("allOf") + @classmethod def _validate_allof( - cls, v: Optional[AttributeCapabilityList], values: dict[str, Any] + cls, v: Optional[AttributeCapabilityList], info: ValidationInfo ) -> Optional[AttributeCapabilityList]: if v is None: return v - cls._validate_attribute_list(v, values, True) + cls._validate_attribute_list(v, info, True) return v - @validator("anyOf") + @field_validator("anyOf") + @classmethod def _validate_anyof( - cls, v: Optional[AttributeCapabilityList], values: dict[str, Any] + cls, v: Optional[AttributeCapabilityList], info: ValidationInfo ) -> Optional[AttributeCapabilityList]: if v is None: return v - cls._validate_attribute_list(v, values, False) + cls._validate_attribute_list(v, info, False) return v - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def _validate_has_one_optional(cls, values: dict[str, Any]) -> dict[str, Any]: if not ("anyOf" in values or "allOf" in values): raise ValueError("At least one of 'anyOf' or 'allOf' must be defined.") @@ -2012,7 +2192,8 @@ class HostRequirementsTemplate(OpenJDModel_v2023_09): _max_allowed_requirements: int = 50 - @validator("amounts") + @field_validator("amounts") + @classmethod def _validate_amounts( cls, v: 
Optional[list[AmountRequirementTemplate]] ) -> Optional[list[AmountRequirementTemplate]]: @@ -2022,7 +2203,8 @@ def _validate_amounts( raise ValueError("List must contain at least one element or not be defined.") return v - @validator("attributes") + @field_validator("attributes") + @classmethod def _validate_attributes( cls, v: Optional[list[AttributeRequirementTemplate]] ) -> Optional[list[AttributeRequirementTemplate]]: @@ -2032,22 +2214,22 @@ def _validate_attributes( raise ValueError("List must contain at least one element or not be defined.") return v - @root_validator - def _validate(cls, values: dict[str, Any]) -> dict[str, Any]: - if not ("amounts" in values or "attributes" in values): + @model_validator(mode="after") + def _validate(self) -> Self: + amounts = self.amounts + attributes = self.attributes + if amounts is None and attributes is None: raise ValueError( "Must define at least one of 'amounts' or 'attributes' if defining this property." ) - amounts = values.get("amounts") - attributes = values.get("attributes") total_amounts = len(amounts) if amounts is not None else 0 total_attributes = len(attributes) if attributes is not None else 0 total = total_amounts + total_attributes - if total > cls._max_allowed_requirements: + if total > self._max_allowed_requirements: raise ValueError( - f"The total number of requirements must not exceed {cls._max_allowed_requirements}. {total} requirements defined." + f"The total number of requirements must not exceed {self._max_allowed_requirements}. {total} requirements defined." 
) - return values + return self # ================================================================== @@ -2059,12 +2241,8 @@ class StepDependency(OpenJDModel_v2023_09): dependsOn: StepName -if TYPE_CHECKING: - StepEnvironmentList = list[Environment] - StepDependenciesList = list[StepDependency] -else: - StepEnvironmentList = conlist(Environment, min_items=1) - StepDependenciesList = conlist(StepDependency, min_items=1) +StepEnvironmentList = Annotated[list[Environment], Field(min_length=1)] +StepDependenciesList = Annotated[list[StepDependency], Field(min_length=1)] # Target model for a StepTemplate when instantiating a job. @@ -2110,7 +2288,8 @@ class StepTemplate(OpenJDModel_v2023_09): } _job_creation_metadata = JobCreationMetadata(create_as=JobCreateAsMetadata(model=Step)) - @validator("dependencies") + @field_validator("dependencies") + @classmethod def _validate_no_duplicate_deps( cls, v: Optional[StepDependenciesList] ) -> Optional[StepDependenciesList]: @@ -2121,7 +2300,8 @@ def _validate_no_duplicate_deps( raise ValueError("Duplicate dependencies are not allowed.") return v - @validator("stepEnvironments") + @field_validator("stepEnvironments") + @classmethod def _unique_environment_names( cls, v: Optional[StepEnvironmentList] ) -> Optional[StepEnvironmentList]: @@ -2129,32 +2309,21 @@ def _unique_environment_names( return validate_unique_elements(v, item_value=lambda v: v.name, property="name") return v - @root_validator - def _validate_no_self_dependency(cls, values: dict[str, Any]) -> dict[str, Any]: + @model_validator(mode="after") + def _validate_no_self_dependency(self) -> Self: # Dependency of the step upon itself is not allowed. 
- deps: StepDependenciesList = values.get("dependencies", []) + deps: StepDependenciesList = self.dependencies or [] if not deps: - return values - stepname = values.get("name") + return self + stepname = self.name if any(dep.dependsOn == stepname for dep in deps): raise ValueError("A step cannot depend upon itself.") - return values + return self -if TYPE_CHECKING: - StepTemplateList = list[StepTemplate] - JobParameterDefinitionList = list[ - Union[ - JobIntParameterDefinition, - JobFloatParameterDefinition, - JobStringParameterDefinition, - JobPathParameterDefinition, - ] - ] - JobEnvironmentsList = list[Environment] -else: - StepTemplateList = conlist(StepTemplate, min_items=1) - JobParameterDefinitionList = conlist( +StepTemplateList = Annotated[list[StepTemplate], Field(min_length=1)] +JobParameterDefinitionList = Annotated[ + list[ Annotated[ Union[ JobIntParameterDefinition, @@ -2163,12 +2332,14 @@ def _validate_no_self_dependency(cls, values: dict[str, Any]) -> dict[str, Any]: JobPathParameterDefinition, ], Field(..., discriminator="type"), - ], - min_items=1, - max_items=50, - ) - JobEnvironmentsList = conlist(Environment, min_items=1) - + ] + ], + Field( + min_length=1, + max_length=50, + ), +] +JobEnvironmentsList = Annotated[list[Environment], Field(min_length=1)] JobParameters = dict[Identifier, JobParameter] @@ -2222,11 +2393,13 @@ class JobTemplate(OpenJDModel_v2023_09): rename_fields={"parameterDefinitions": "parameters"}, ) - @validator("steps") + @field_validator("steps") + @classmethod def _unique_step_names(cls, v: StepTemplateList) -> StepTemplateList: return validate_unique_elements(v, item_value=lambda v: v.name, property="name") - @validator("parameterDefinitions") + @field_validator("parameterDefinitions") + @classmethod def _unique_parameter_names( cls, v: Optional[JobParameterDefinitionList] ) -> Optional[JobParameterDefinitionList]: @@ -2234,7 +2407,8 @@ def _unique_parameter_names( return validate_unique_elements(v, item_value=lambda v: 
v.name, property="name") return v - @validator("jobEnvironments") + @field_validator("jobEnvironments") + @classmethod def _unique_environment_names( cls, v: Optional[JobEnvironmentsList] ) -> Optional[JobEnvironmentsList]: @@ -2251,13 +2425,13 @@ def _root_template_prevalidator(cls, values: dict[str, Any]) -> dict[str, Any]: cast(Type[OpenJDModel], cls), values ) if errors: - raise ValidationError(errors, JobTemplate) + raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) return values - @root_validator - def _validate_no_step_dependency_cycles(cls, values: dict[str, Any]) -> dict[str, Any]: + @model_validator(mode="after") + def _validate_no_step_dependency_cycles(self) -> Self: depgraph = dict[str, set[str]]() - steplist = values.get("steps", []) + steplist = self.steps or [] for step in steplist: if step.dependencies is not None: dependsOn = set[str](dep.dependsOn for dep in step.dependencies) @@ -2271,72 +2445,76 @@ def _validate_no_step_dependency_cycles(cls, values: dict[str, Any]) -> dict[str cycle = " -> ".join(exc.args[1]) raise ValueError(f"Step dependencies form a cycle: {cycle}") from None - return values + return self - @root_validator - def _validate_step_deps_exist(cls, values: dict[str, Any]) -> dict[str, Any]: + @model_validator(mode="after") + def _validate_step_deps_exist(self) -> Self: # Check that the deps referenced by all steps actually exist - steplist = values.get("steps", []) + steplist = self.steps or [] if not steplist: - return values + return self - errors = list[ErrorWrapper]() + errors = list[InitErrorDetails]() stepnames = set[str](step.name for step in steplist) for i, step in enumerate(steplist): if step.dependencies is not None: for j, dep in enumerate(step.dependencies): if dep.dependsOn not in stepnames: errors.append( - ErrorWrapper( - ValueError(f"Unknown step '{dep.dependsOn}'"), + InitErrorDetails( + type="value_error", # The path to the problematic dependsOn value - ("step", i, "dependencies", j, 
"dependsOn"), + loc=("step", i, "dependencies", j, "dependsOn"), + ctx={"error": ValueError(f"Unknown step '{dep.dependsOn}'")}, + input=dep.dependsOn, ) ) if errors: - raise ValidationError(errors, JobTemplate) + raise ValidationError.from_exception_data(self.__class__.__name__, errors) - return values + return self - @root_validator - def _validate_env_names_dont_match_step_env_names( - cls, values: dict[str, Any] - ) -> dict[str, Any]: + @model_validator(mode="after") + def _validate_env_names_dont_match_step_env_names(self) -> Self: # Check that if we have job-level Environments defined that none of the defined Step-level # environments have the same name. # Names must be unique between Steps & Jobs. - steplist = values.get("steps", []) + steplist = self.steps or [] if not steplist: - return values + return self - envlist = values.get("jobEnvironments", []) + envlist = self.jobEnvironments or [] if not envlist: - return values + return self job_env_names = set(env.name for env in cast(JobEnvironmentsList, envlist)) - errors = list[ErrorWrapper]() + errors = list[InitErrorDetails]() for i, step in enumerate(steplist): if step.stepEnvironments is not None: for j, env in enumerate(step.stepEnvironments): if env.name in job_env_names: errors.append( - ErrorWrapper( - ValueError( - f"Name {env.name} must differ from the names of Environments defined at the root of the template." - ), + InitErrorDetails( + type="value_error", # The path to the problematic environment name - ("step", i, "stepEnvironments", j, "name"), + loc=("step", i, "stepEnvironments", j, "name"), + ctx={ + "error": ValueError( + f"Name {env.name} must differ from the names of Environments defined at the root of the template." 
+ ) + }, + input=env.name, ) ) if errors: - raise ValidationError(errors, JobTemplate) + raise ValidationError.from_exception_data(self.__class__.__name__, errors) - return values + return self class EnvironmentTemplate(OpenJDModel_v2023_09): @@ -2360,7 +2538,8 @@ class EnvironmentTemplate(OpenJDModel_v2023_09): "environment": {"parameterDefinitions"}, } - @validator("parameterDefinitions") + @field_validator("parameterDefinitions") + @classmethod def _unique_parameter_names( cls, v: Optional[JobParameterDefinitionList] ) -> Optional[JobParameterDefinitionList]: @@ -2377,5 +2556,5 @@ def _root_template_prevalidator(cls, values: dict[str, Any]) -> dict[str, Any]: cast(Type[OpenJDModel], cls), values ) if errors: - raise ValidationError(errors, EnvironmentTemplate) + raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) return values diff --git a/test/openjd/model/_internal/test_create_job.py b/test/openjd/model/_internal/test_create_job.py index 0ad75ee..7160f93 100644 --- a/test/openjd/model/_internal/test_create_job.py +++ b/test/openjd/model/_internal/test_create_job.py @@ -5,7 +5,7 @@ from typing import Optional, Union, cast import pytest -from pydantic.v1 import PositiveInt, ValidationError +from pydantic import PositiveInt, ValidationError from openjd.model import SymbolTable from openjd.model._format_strings import FormatString diff --git a/test/openjd/model/_internal/test_variable_reference_validation.py b/test/openjd/model/_internal/test_variable_reference_validation.py index e6c06a9..6e1ba95 100644 --- a/test/openjd/model/_internal/test_variable_reference_validation.py +++ b/test/openjd/model/_internal/test_variable_reference_validation.py @@ -3,7 +3,7 @@ from typing import Any, Literal, Union from enum import Enum from typing_extensions import Annotated -from pydantic.v1 import Field +from pydantic import Field import pytest diff --git a/test/openjd/model/format_strings/test_dyn_constrained_str.py 
b/test/openjd/model/format_strings/test_dyn_constrained_str.py index d6b955a..ff324f0 100644 --- a/test/openjd/model/format_strings/test_dyn_constrained_str.py +++ b/test/openjd/model/format_strings/test_dyn_constrained_str.py @@ -3,7 +3,7 @@ import re import pytest -from pydantic.v1 import BaseModel, ValidationError +from pydantic import BaseModel, ValidationError from openjd.model._format_strings._dyn_constrained_str import DynamicConstrainedStr @@ -18,7 +18,7 @@ class Model(BaseModel): s: DynamicConstrainedStr # WHEN - Model.parse_obj({"s": "123"}) + Model.model_validate({"s": "123"}) # THEN # raised no error @@ -34,7 +34,7 @@ class Model(BaseModel): model = Model(s="12") # WHEN - as_dict = model.dict() + as_dict = model.model_dump() # THEN assert as_dict == {"s": "12"} @@ -48,7 +48,7 @@ class Model(BaseModel): # WHEN with pytest.raises(ValidationError) as excinfo: - Model.parse_obj({"s": 123}) + Model.model_validate({"s": 123}) # THEN assert len(excinfo.value.errors()) == 1 @@ -64,7 +64,7 @@ class Model(BaseModel): s: StrType # WHEN - Model.parse_obj({"s": "0" * 10}) + Model.model_validate({"s": "0" * 10}) # THEN # raised no error @@ -81,7 +81,7 @@ class Model(BaseModel): # WHEN with pytest.raises(ValidationError) as excinfo: - Model.parse_obj({"s": "0" * 9}) + Model.model_validate({"s": "0" * 9}) # THEN assert len(excinfo.value.errors()) == 1 @@ -97,7 +97,7 @@ class Model(BaseModel): s: StrType # WHEN - Model.parse_obj({"s": "0" * 10}) + Model.model_validate({"s": "0" * 10}) # THEN # raised no error @@ -114,7 +114,7 @@ class Model(BaseModel): # WHEN with pytest.raises(ValidationError) as excinfo: - Model.parse_obj({"s": "0" * 11}) + Model.model_validate({"s": "0" * 11}) # THEN assert len(excinfo.value.errors()) == 1 @@ -130,7 +130,7 @@ class Model(BaseModel): s: StrType # WHEN - Model.parse_obj({"s": "0" * 10}) + Model.model_validate({"s": "0" * 10}) # THEN # no errors raised @@ -147,7 +147,7 @@ class Model(BaseModel): # WHEN with 
pytest.raises(ValidationError) as excinfo: - Model.parse_obj({"s": "1" * 10}) + Model.model_validate({"s": "1" * 10}) # THEN assert len(excinfo.value.errors()) == 1 @@ -163,7 +163,7 @@ class Model(BaseModel): s: StrType # WHEN - Model.parse_obj({"s": "0" * 10}) + Model.model_validate({"s": "0" * 10}) # THEN # no errors raised @@ -180,7 +180,7 @@ class Model(BaseModel): # WHEN with pytest.raises(ValidationError) as excinfo: - Model.parse_obj({"s": "1" * 10}) + Model.model_validate({"s": "1" * 10}) # THEN assert len(excinfo.value.errors()) == 1 diff --git a/test/openjd/model/test_convert_pydantic_error.py b/test/openjd/model/test_convert_pydantic_error.py index 24789e3..1e666b3 100644 --- a/test/openjd/model/test_convert_pydantic_error.py +++ b/test/openjd/model/test_convert_pydantic_error.py @@ -1,10 +1,11 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +from typing_extensions import Self -from pydantic.v1 import BaseModel, root_validator, Field +from pydantic import BaseModel, Field, model_validator +from pydantic_core import ErrorDetails from typing import Literal, Union from openjd.model._convert_pydantic_error import ( - ErrorDict, pydantic_validationerrors_to_str, _error_dict_to_str, ) @@ -20,9 +21,19 @@ class Model(BaseModel): f1: str f2: int - errors: list[ErrorDict] = [ - {"loc": ("f1",), "msg": "error message1", "type": "error-type"}, - {"loc": ("f2",), "msg": "error message2", "type": "error-type"}, + errors: list[ErrorDetails] = [ + { + "loc": ("f1",), + "msg": "error message1", + "type": "error-type", + "input": "input-value1", + }, + { + "loc": ("f2",), + "msg": "error message2", + "type": "error-type", + "input": "input-value2", + }, ] expected = "2 validation errors for Model\nf1:\n\terror message1\nf2:\n\terror message2" @@ -47,7 +58,12 @@ class Model(BaseModel): f1: str f2: int - error: ErrorDict = {"loc": ("f2",), "msg": "error message", "type": "error-type"} + error: ErrorDetails = { + "loc": ("f2",), + "msg": "error 
message", + "type": "error-type", + "input": "input-value", + } expected = "f2:\n\terror message" # WHEN @@ -67,13 +83,14 @@ class Inner(BaseModel): class Model(BaseModel): inner: Inner - error: ErrorDict = { + error: ErrorDetails = { "loc": ( "inner", "ff", ), "msg": "error message", "type": "error-type", + "input": "input-value", } expected = "inner -> ff:\n\terror message" @@ -83,7 +100,7 @@ class Model(BaseModel): # THEN assert result == expected - def test_base_root_validator_error(self) -> None: + def test_base_model_validator_error(self) -> None: # Make sure that our path to error is correct for validation error # at the base level's root validator # This is a special case where we do not want the error message to be indented @@ -94,11 +111,16 @@ def test_base_root_validator_error(self) -> None: class Model(BaseModel): ff: str - @root_validator - def _validate(cls, values): + @model_validator(mode="after") + def _validate(self) -> Self: raise ValueError("error message") - error: ErrorDict = {"loc": ("__root__",), "msg": "error message", "type": "error-type"} + error: ErrorDetails = { + "loc": ("__root__",), + "msg": "error message", + "type": "error-type", + "input": "input-value", + } expected = "Model: error message" # WHEN @@ -107,7 +129,7 @@ def _validate(cls, values): # THEN assert result == expected - def test_inner_root_validator_error(self) -> None: + def test_inner_model_validator_error(self) -> None: # Make sure that our path to error is correct for validation error # at a nested level's root validator. 
# In this case we drop the '__root__' field at the end and report the error @@ -118,20 +140,21 @@ def test_inner_root_validator_error(self) -> None: class Inner(BaseModel): ff: str - @root_validator - def _validate(cls, values): + @model_validator(mode="after") + def _validate(self) -> Self: raise ValueError("error message") class Model(BaseModel): inner: Inner - error: ErrorDict = { + error: ErrorDetails = { "loc": ( "inner", "__root__", ), "msg": "error message", "type": "error-type", + "input": "input-value", } expected = "inner:\n\terror message" @@ -155,13 +178,14 @@ def test_scalar(self) -> None: class Model(BaseModel): field: list[int] - error: ErrorDict = { + error: ErrorDetails = { "loc": ( "field", 2, ), "msg": "error message", "type": "error-type", + "input": "input-value", } expected = "field[2]:\n\terror message" @@ -182,7 +206,7 @@ class Inner(BaseModel): class Model(BaseModel): inner: list[Inner] - error: ErrorDict = { + error: ErrorDetails = { "loc": ( "inner", 2, @@ -190,6 +214,7 @@ class Model(BaseModel): ), "msg": "error message", "type": "error-type", + "input": "input-value", } expected = "inner[2] -> ff:\n\terror message" @@ -224,7 +249,7 @@ class TestDiscriminatedUnion: def test(self) -> None: # GIVEN - error: ErrorDict = { + error: ErrorDetails = { "loc": ( "inner", 2, @@ -233,6 +258,7 @@ def test(self) -> None: ), "msg": "error message", "type": "error-type", + "input": "input-value", } expected = "inner[2] -> ff:\n\terror message" diff --git a/test/openjd/model/test_create_job.py b/test/openjd/model/test_create_job.py index ba453fd..1519a27 100644 --- a/test/openjd/model/test_create_job.py +++ b/test/openjd/model/test_create_job.py @@ -735,7 +735,7 @@ def test_fails_to_instantiate(self) -> None: # THEN assert ( - "1 validation errors for JobTemplate\nname:\n\tensure this value has at most 128 characters" + "1 validation errors for JobTemplate\nname:\n\tString should have at most 128 characters" in str(excinfo.value) ) @@ -781,6 +781,6 @@ 
def test_uneven_parameter_space_association(self) -> None: # THEN assert ( - "1 validation errors for JobTemplate\nsteps[0] -> steps[0] -> parameterSpace -> combination:\n\tAssociative expressions must have arguments with identical ranges. Expression (A, B) has argument lengths (10, 2)." + "1 validation errors for JobTemplate\nsteps[0] -> parameterSpace -> combination:\n\tAssociative expressions must have arguments with identical ranges. Expression (A, B) has argument lengths (10, 2)." in str(excinfo.value) ) diff --git a/test/openjd/model/test_importable.py b/test/openjd/model/test_importable.py index f1f500f..637e58a 100644 --- a/test/openjd/model/test_importable.py +++ b/test/openjd/model/test_importable.py @@ -5,5 +5,5 @@ def test_openjd_importable(): import openjd # noqa: F401 -def test_importable(): +def test_openjd_model_importable(): import openjd.model # noqa: F401 diff --git a/test/openjd/model/test_step_param_space_iter.py b/test/openjd/model/test_step_param_space_iter.py index a15148f..be06bc1 100644 --- a/test/openjd/model/test_step_param_space_iter.py +++ b/test/openjd/model/test_step_param_space_iter.py @@ -32,7 +32,7 @@ class TestStepParameterSpaceIterator_2023_09: # noqa: N801 @pytest.mark.parametrize( "range_int_param", [ - RangeListTaskParameterDefinition_2023_09(type=ParameterValueType.INT, range=[1, 2]), + RangeListTaskParameterDefinition_2023_09(type=ParameterValueType.INT, range=["1", "2"]), RangeExpressionTaskParameterDefinition_2023_09( type=ParameterValueType.INT, range="1-2" ), @@ -104,7 +104,7 @@ def test_no_param_getelem(self): @pytest.mark.parametrize( "range_int_param", [ - RangeListTaskParameterDefinition_2023_09(type=ParameterValueType.INT, range=[1, 2]), + RangeListTaskParameterDefinition_2023_09(type=ParameterValueType.INT, range=["1", "2"]), RangeExpressionTaskParameterDefinition_2023_09( type=ParameterValueType.INT, range="1-2" ), @@ -131,7 +131,7 @@ def test_single_param_iteration(self, range_int_param): with 
pytest.raises(StopIteration): next(it) - @pytest.mark.parametrize("param_range", [[10], [10, 11, 12, 13, 14, 15]]) + @pytest.mark.parametrize("param_range", [["10"], ["10", "11", "12", "13", "14", "15"]]) def test_single_param_getelem(self, param_range): # GIVEN space = StepParameterSpace_2023_09( @@ -163,7 +163,7 @@ def test_single_param_getelem(self, param_range): @pytest.mark.parametrize( "given, expected", [ - ([1, 2, 3], 3), + (["1", "2", "3"], 3), ("1-5", 5), (["a", "b", "c", "d"], 4), ], @@ -197,7 +197,7 @@ def test_single_param_len(self, given, expected) -> None: @pytest.mark.parametrize( "range_int_param", [ - RangeListTaskParameterDefinition_2023_09(type=ParameterValueType.INT, range=[1, 2]), + RangeListTaskParameterDefinition_2023_09(type=ParameterValueType.INT, range=["1", "2"]), RangeExpressionTaskParameterDefinition_2023_09( type=ParameterValueType.INT, range="1-2" ), @@ -239,7 +239,7 @@ def test_product_iteration(self) -> None: space = StepParameterSpace_2023_09( taskParameterDefinitions={ "Param1": RangeListTaskParameterDefinition_2023_09( - type=ParameterValueType.INT, range=[1, 2] + type=ParameterValueType.INT, range=["1", "2"] ), "Param2": RangeListTaskParameterDefinition_2023_09( type=ParameterValueType.STRING, range=["a", "b", "c"] @@ -287,7 +287,7 @@ def test_product_len(self): type=ParameterValueType.STRING, range=["a", "b", "c"] ), "Param3": RangeListTaskParameterDefinition_2023_09( - type=ParameterValueType.INT, range=[-1, -2] + type=ParameterValueType.INT, range=["-1", "-2"] ), }, combination="Param1 * Param2 * Param3", @@ -306,7 +306,7 @@ def test_product_getitem(self) -> None: space = StepParameterSpace_2023_09( taskParameterDefinitions={ "Param1": RangeListTaskParameterDefinition_2023_09( - type=ParameterValueType.INT, range=[1, 2] + type=ParameterValueType.INT, range=["1", "2"] ), "Param2": RangeListTaskParameterDefinition_2023_09( type=ParameterValueType.STRING, range=["a", "b", "c"] @@ -361,7 +361,7 @@ def 
test_associate_iteration(self) -> None: type=ParameterValueType.STRING, range=["a", "b", "c", "d"] ), "Param3": RangeListTaskParameterDefinition_2023_09( - type=ParameterValueType.INT, range=[-1, -2, -3, -4] + type=ParameterValueType.INT, range=["-1", "-2", "-3", "-4"] ), }, combination="(Param1, Param2, Param3)", @@ -389,7 +389,7 @@ def test_associate_len(self) -> None: space = StepParameterSpace_2023_09( taskParameterDefinitions={ "Param1": RangeListTaskParameterDefinition_2023_09( - type=ParameterValueType.INT, range=[1, 2, 3, 4] + type=ParameterValueType.INT, range=["1", "2", "3", "4"] ), "Param2": RangeListTaskParameterDefinition_2023_09( type=ParameterValueType.STRING, range=["a", "b", "c", "d"] @@ -420,14 +420,14 @@ def test_associate_getitem(self) -> None: type=ParameterValueType.STRING, range=["a", "b", "c", "d"] ), "Param3": RangeListTaskParameterDefinition_2023_09( - type=ParameterValueType.INT, range=[-1, -2, -3, -4] + type=ParameterValueType.INT, range=["-1", "-2", "-3", "-4"] ), }, combination="(Param1, Param2, Param3)", ) # WHEN - result = StepParameterSpaceIterator(space=space) + space_iter = StepParameterSpaceIterator(space=space) # THEN element: Callable[[int, str, int], dict[str, ParameterValue]] = lambda p1, p2, p3: { @@ -442,13 +442,13 @@ def test_associate_getitem(self) -> None: element(4, "d", -4), ] with pytest.raises(IndexError): - result[len(expected_values)] + space_iter[len(expected_values)] with pytest.raises(IndexError): - result[-len(expected_values) - 1] - assert expected_values == [result[i] for i in range(0, len(expected_values))] + space_iter[-len(expected_values) - 1] + assert expected_values == [space_iter[i] for i in range(0, len(expected_values))] expected_reversed = expected_values.copy() expected_reversed.reverse() - assert expected_reversed == [result[-i - 1] for i in range(0, len(expected_values))] + assert expected_reversed == [space_iter[-i - 1] for i in range(0, len(expected_values))] def 
test_nested_expr_iteration(self) -> None: # A more deeply nested test to hit all of the recursive edge cases. @@ -458,7 +458,7 @@ def test_nested_expr_iteration(self) -> None: space = StepParameterSpace_2023_09( taskParameterDefinitions={ "Param1": RangeListTaskParameterDefinition_2023_09( - type=ParameterValueType.INT, range=[1, 2] + type=ParameterValueType.INT, range=["1", "2"] ), "Param2": RangeListTaskParameterDefinition_2023_09( type=ParameterValueType.STRING, range=["a", "b", "c", "d"] @@ -467,7 +467,7 @@ def test_nested_expr_iteration(self) -> None: type=ParameterValueType.INT, range="10-11" ), "Param4": RangeListTaskParameterDefinition_2023_09( - type=ParameterValueType.INT, range=[20, 21] + type=ParameterValueType.INT, range=["20", "21"] ), }, combination="Param1 * ( Param2, Param3 * Param4 )", diff --git a/test/openjd/model/v2023_09/test_action.py b/test/openjd/model/v2023_09/test_action.py index c1955a3..e4c150b 100644 --- a/test/openjd/model/v2023_09/test_action.py +++ b/test/openjd/model/v2023_09/test_action.py @@ -3,7 +3,7 @@ from typing import Any import pytest -from pydantic.v1 import ValidationError +from pydantic import ValidationError from openjd.model._parse import _parse_model from openjd.model.v2023_09 import Action, EnvironmentActions, StepActions diff --git a/test/openjd/model/v2023_09/test_create.py b/test/openjd/model/v2023_09/test_create.py index b10e98c..96b7c7a 100644 --- a/test/openjd/model/v2023_09/test_create.py +++ b/test/openjd/model/v2023_09/test_create.py @@ -391,6 +391,6 @@ def test(self) -> None: # Note: The dict compare generates an easier to read diff if there's a test failure. # It is not essential to the test. - assert result.dict() == expected.dict() + assert result.model_dump() == expected.model_dump() # This is the important assertion. 
assert result == expected diff --git a/test/openjd/model/v2023_09/test_definitions.py b/test/openjd/model/v2023_09/test_definitions.py index 873e5c5..82f2a80 100644 --- a/test/openjd/model/v2023_09/test_definitions.py +++ b/test/openjd/model/v2023_09/test_definitions.py @@ -1,7 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. import pytest -from pydantic.v1 import BaseModel +from pydantic import BaseModel from typing import Type import openjd.model.v2023_09 as mod from inspect import getmembers, getmodule, isclass diff --git a/test/openjd/model/v2023_09/test_embedded.py b/test/openjd/model/v2023_09/test_embedded.py index 65b9be4..24f2dda 100644 --- a/test/openjd/model/v2023_09/test_embedded.py +++ b/test/openjd/model/v2023_09/test_embedded.py @@ -3,7 +3,7 @@ from typing import Any import pytest -from pydantic.v1 import ValidationError +from pydantic import ValidationError from openjd.model._parse import _parse_model from openjd.model.v2023_09 import EmbeddedFileText diff --git a/test/openjd/model/v2023_09/test_environment_template.py b/test/openjd/model/v2023_09/test_environment_template.py index 2bd9910..edef1b0 100644 --- a/test/openjd/model/v2023_09/test_environment_template.py +++ b/test/openjd/model/v2023_09/test_environment_template.py @@ -3,7 +3,7 @@ from typing import Any import pytest -from pydantic.v1 import ValidationError +from pydantic import ValidationError from openjd.model._parse import _parse_model from openjd.model.v2023_09 import EnvironmentTemplate diff --git a/test/openjd/model/v2023_09/test_environments.py b/test/openjd/model/v2023_09/test_environments.py index a498b48..24350cc 100644 --- a/test/openjd/model/v2023_09/test_environments.py +++ b/test/openjd/model/v2023_09/test_environments.py @@ -3,7 +3,7 @@ from typing import Any import pytest -from pydantic.v1 import ValidationError +from pydantic import ValidationError from openjd.model._parse import _parse_model from openjd.model.v2023_09 import Environment diff 
--git a/test/openjd/model/v2023_09/test_job_parameters.py b/test/openjd/model/v2023_09/test_job_parameters.py index 459c22e..4759f59 100644 --- a/test/openjd/model/v2023_09/test_job_parameters.py +++ b/test/openjd/model/v2023_09/test_job_parameters.py @@ -4,7 +4,7 @@ from typing import Any import pytest -from pydantic.v1 import ValidationError +from pydantic import ValidationError from openjd.model._parse import _parse_model from openjd.model.v2023_09 import ( diff --git a/test/openjd/model/v2023_09/test_job_template.py b/test/openjd/model/v2023_09/test_job_template.py index a4ed7f2..94df84f 100644 --- a/test/openjd/model/v2023_09/test_job_template.py +++ b/test/openjd/model/v2023_09/test_job_template.py @@ -3,7 +3,7 @@ from typing import Any import pytest -from pydantic.v1 import ValidationError +from pydantic import ValidationError from openjd.model._parse import _parse_model from openjd.model.v2023_09 import JobTemplate diff --git a/test/openjd/model/v2023_09/test_parameter_space.py b/test/openjd/model/v2023_09/test_parameter_space.py index b9d5d0b..cf07fa2 100644 --- a/test/openjd/model/v2023_09/test_parameter_space.py +++ b/test/openjd/model/v2023_09/test_parameter_space.py @@ -3,7 +3,7 @@ from typing import Any import pytest -from pydantic.v1 import ValidationError +from pydantic import ValidationError from openjd.model._parse import _parse_model from openjd.model.v2023_09 import ( diff --git a/test/openjd/model/v2023_09/test_scripts.py b/test/openjd/model/v2023_09/test_scripts.py index 7acd2a1..aa1207e 100644 --- a/test/openjd/model/v2023_09/test_scripts.py +++ b/test/openjd/model/v2023_09/test_scripts.py @@ -3,7 +3,7 @@ from typing import Any import pytest -from pydantic.v1 import ValidationError +from pydantic import ValidationError from openjd.model._parse import _parse_model from openjd.model.v2023_09 import EnvironmentScript, StepScript diff --git a/test/openjd/model/v2023_09/test_step_host_requirements.py 
b/test/openjd/model/v2023_09/test_step_host_requirements.py index 241afcb..a80ba12 100644 --- a/test/openjd/model/v2023_09/test_step_host_requirements.py +++ b/test/openjd/model/v2023_09/test_step_host_requirements.py @@ -4,7 +4,7 @@ import string import pytest -from pydantic.v1 import ValidationError +from pydantic import ValidationError from openjd.model._parse import _parse_model from openjd.model.v2023_09 import ( diff --git a/test/openjd/model/v2023_09/test_step_template.py b/test/openjd/model/v2023_09/test_step_template.py index dc0770e..5bfde38 100644 --- a/test/openjd/model/v2023_09/test_step_template.py +++ b/test/openjd/model/v2023_09/test_step_template.py @@ -3,7 +3,7 @@ from typing import Any import pytest -from pydantic.v1 import ValidationError +from pydantic import ValidationError from openjd.model._parse import _parse_model from openjd.model.v2023_09 import StepTemplate diff --git a/test/openjd/model/v2023_09/test_strings.py b/test/openjd/model/v2023_09/test_strings.py index 91369df..260277c 100644 --- a/test/openjd/model/v2023_09/test_strings.py +++ b/test/openjd/model/v2023_09/test_strings.py @@ -4,7 +4,7 @@ import string import pytest -from pydantic.v1 import BaseModel, ValidationError +from pydantic import BaseModel, ValidationError from openjd.model.v2023_09 import ( AmountCapabilityName, @@ -105,7 +105,7 @@ def test_parse_success(self, value: str) -> None: data = {"name": value} # WHEN - JobTemplateNameModel.parse_obj(data) + JobTemplateNameModel.model_validate(data) # THEN # no exceptions raised @@ -119,7 +119,7 @@ def test_parse_fails(self, data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - JobTemplateNameModel.parse_obj(data) + JobTemplateNameModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 @@ -139,7 +139,7 @@ def test_parse_success(self, value: str) -> None: data = {"name": value} # WHEN - JobNameModel.parse_obj(data) + JobNameModel.model_validate(data) # THEN # no exceptions 
raised @@ -163,7 +163,7 @@ def test_parse_fails(self, data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - JobNameModel.parse_obj(data) + JobNameModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 @@ -183,7 +183,7 @@ def test_parse_success(self, value: str) -> None: data = {"name": value} # WHEN - StepNameModel.parse_obj(data) + StepNameModel.model_validate(data) # THEN # no exceptions raised @@ -207,7 +207,7 @@ def test_parse_fails(self, data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - StepNameModel.parse_obj(data) + StepNameModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 @@ -227,7 +227,7 @@ def test_parse_success(self, value: str) -> None: data = {"name": value} # WHEN - EnvironmentNameModel.parse_obj(data) + EnvironmentNameModel.model_validate(data) # THEN # no exceptions raised @@ -251,7 +251,7 @@ def test_parse_fails(self, data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - EnvironmentNameModel.parse_obj(data) + EnvironmentNameModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 @@ -280,7 +280,7 @@ def test_parse_success(self, value: str) -> None: data = {"name": value} # WHEN - EnvironmentVariableNameStringModel.parse_obj(data) + EnvironmentVariableNameStringModel.model_validate(data) # THEN # no exceptions raised @@ -315,7 +315,7 @@ def test_parse_fails(self, data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - EnvironmentVariableNameStringModel.parse_obj(data) + EnvironmentVariableNameStringModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 @@ -335,7 +335,7 @@ def test_parse_success(self, value: str) -> None: data = {"value": value} # WHEN - EnvironmentVariableValueStringModel.parse_obj(data) + EnvironmentVariableValueStringModel.model_validate(data) # THEN # no exceptions raised @@ -351,7 +351,7 @@ def test_parse_fails(self, 
data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - EnvironmentVariableValueStringModel.parse_obj(data) + EnvironmentVariableValueStringModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 @@ -379,7 +379,7 @@ def test_parse_success(self, value: str) -> None: data = {"id": value} # WHEN - IdentifierModel.parse_obj(data) + IdentifierModel.model_validate(data) # THEN # no exceptions raised @@ -414,7 +414,7 @@ def test_parse_fails(self, data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - IdentifierModel.parse_obj(data) + IdentifierModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 @@ -437,7 +437,7 @@ def test_parse_success(self, value: str) -> None: data = {"desc": value} # WHEN - DescriptionModel.parse_obj(data) + DescriptionModel.model_validate(data) # THEN # no exceptions raised @@ -460,7 +460,7 @@ def test_parse_fails(self, data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - DescriptionModel.parse_obj(data) + DescriptionModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 @@ -479,7 +479,7 @@ def test_parse_success(self, value: str) -> None: data = {"str": value} # WHEN - ParameterStringModel.parse_obj(data) + ParameterStringModel.model_validate(data) # THEN # no exceptions raised @@ -496,7 +496,7 @@ def test_parse_fails(self, data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - ParameterStringModel.parse_obj(data) + ParameterStringModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 @@ -519,7 +519,7 @@ def test_parse_success(self, value: str) -> None: data = {"arg": value} # WHEN - ArgStringModel.parse_obj(data) + ArgStringModel.model_validate(data) # THEN # no exceptions raised @@ -542,7 +542,7 @@ def test_parse_fails(self, data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - ArgStringModel.parse_obj(data) + 
ArgStringModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 @@ -565,7 +565,7 @@ def test_parse_success(self, value: str) -> None: data = {"cmd": value} # WHEN - CommandStringModel.parse_obj(data) + CommandStringModel.model_validate(data) # THEN # no exceptions raised @@ -589,7 +589,7 @@ def test_parse_fails(self, data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - CommandStringModel.parse_obj(data) + CommandStringModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 @@ -611,7 +611,7 @@ def test_parse_success(self, value: str) -> None: data = {"expr": value} # WHEN - CombinationExprModel.parse_obj(data) + CombinationExprModel.model_validate(data) # THEN # no exceptions raised @@ -639,7 +639,7 @@ def test_parse_fails(self, data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - CombinationExprModel.parse_obj(data) + CombinationExprModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 @@ -654,7 +654,7 @@ def test_parse_success(self, value: str) -> None: data = {"str": value} # WHEN - TaskParameterStringValueAsJobModel.parse_obj(data) + TaskParameterStringValueAsJobModel.model_validate(data) # THEN # no exceptions raised @@ -668,7 +668,7 @@ def test_parse_fails(self, data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - TaskParameterStringValueAsJobModel.parse_obj(data) + TaskParameterStringValueAsJobModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 @@ -685,7 +685,7 @@ def test_parse_success(self, value: str) -> None: data = {"str": value} # WHEN - AmountCapabilityNameModel.parse_obj(data) + AmountCapabilityNameModel.model_validate(data) # THEN # no exceptions raised @@ -702,7 +702,7 @@ def test_parse_fails(self, data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - AmountCapabilityNameModel.parse_obj(data) + 
AmountCapabilityNameModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 @@ -719,7 +719,7 @@ def test_parse_success(self, value: str) -> None: data = {"str": value} # WHEN - AttributeCapabilityNameModel.parse_obj(data) + AttributeCapabilityNameModel.model_validate(data) # THEN # no exceptions raised @@ -736,7 +736,7 @@ def test_parse_fails(self, data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - AttributeCapabilityNameModel.parse_obj(data) + AttributeCapabilityNameModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 @@ -756,7 +756,7 @@ def test_parse_success(self, value: str) -> None: data = {"str": value} # WHEN - UserInterfaceLabelStringValueModel.parse_obj(data) + UserInterfaceLabelStringValueModel.model_validate(data) # THEN # no exceptions raised @@ -780,7 +780,7 @@ def test_parse_fails(self, data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - UserInterfaceLabelStringValueModel.parse_obj(data) + UserInterfaceLabelStringValueModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 @@ -801,7 +801,7 @@ def test_parse_success(self, value: str) -> None: data = {"str": value} # WHEN - FileDialogFilterPatternStringValueModel.parse_obj(data) + FileDialogFilterPatternStringValueModel.model_validate(data) # THEN # no exceptions raised @@ -854,7 +854,7 @@ def test_parse_fails(self, data: dict[str, Any]) -> None: # WHEN with pytest.raises(ValidationError) as excinfo: - FileDialogFilterPatternStringValueModel.parse_obj(data) + FileDialogFilterPatternStringValueModel.model_validate(data) # THEN assert len(excinfo.value.errors()) > 0 diff --git a/test/openjd/model/v2023_09/test_template_variables.py b/test/openjd/model/v2023_09/test_template_variables.py index 4714ea0..da4c1cd 100644 --- a/test/openjd/model/v2023_09/test_template_variables.py +++ b/test/openjd/model/v2023_09/test_template_variables.py @@ -3,7 +3,7 @@ from typing import Any import 
pytest -from pydantic.v1 import ValidationError +from pydantic import ValidationError from openjd.model._parse import _parse_model from openjd.model.v2023_09 import JobTemplate