Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions osprey_worker/src/osprey/engine/ast/ast_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Iterator, Optional, Sequence, Set, Tuple, Type, TypeVar, Union

# from osprey.engine.utils.periodic_execution_yielder import maybe_periodic_yield
from .grammar import ASTNode, Root, Statement
Expand Down Expand Up @@ -41,7 +41,7 @@ def traverse_mro(klass: Any) -> None:
def _make_memoized_field_values_iterator() -> Callable[
['ASTNode'], Iterator[Tuple[str, Union['ASTNode', Sequence['ASTNode']]]]
]:
_field_cache: Dict[Type['ASTNode'], List[str]] = {}
_field_cache: dict[Type['ASTNode'], list[str]] = {}

def _iter_field_values(node: ASTNode) -> Iterator[Tuple[str, Union['ASTNode', Sequence['ASTNode']]]]:
# To avoid the cost of iterating fields over known node classes,
Expand Down
20 changes: 10 additions & 10 deletions osprey_worker/src/osprey/engine/ast/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from hashlib import sha256
from itertools import chain
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Union
from typing import Any, Iterator, Optional, Sequence, Set, Union

import deepmerge
import yaml
Expand All @@ -19,7 +19,7 @@ class Sources:
"""A collection of sources, and an arbitrary configuration which describes a set of imported rules and perhaps
additional configuration that will be executed by the engine."""

def __init__(self, sources: Dict[str, Source], config: Optional['SourcesConfig'] = None):
def __init__(self, sources: dict[str, Source], config: Optional['SourcesConfig'] = None):
assert SOURCE_ENTRY_POINT_PATH in sources, (
"Sources requires a file with the `path` 'main.sml' to be present as the entry-point"
)
Expand Down Expand Up @@ -49,12 +49,12 @@ def paths(self) -> Set[str]:
"""Returns the set of paths within this sources collection."""
return set(self._sources.keys())

def glob(self, pattern: str) -> List[Source]:
def glob(self, pattern: str) -> list[Source]:
"""Returns a list of sources that match the given pattern."""
matches = fnmatch.filter(self._sources.keys(), pattern)
return [self._sources[k] for k in matches]

def to_dict(self) -> Dict[str, str]:
def to_dict(self) -> dict[str, str]:
"""The inverse of from_dict, serializes this Sources collection to a dictionary."""
sources = list(self)
if self._config.source.contents:
Expand All @@ -63,7 +63,7 @@ def to_dict(self) -> Dict[str, str]:
return {source.path: source.contents for source in sources}

@staticmethod
def from_dict(sources_dict: Dict[str, str]) -> 'Sources':
def from_dict(sources_dict: dict[str, str]) -> 'Sources':
"""Creates a Sources object from a dict of path -> contents."""
builder = SourcesBuilder()

Expand Down Expand Up @@ -123,9 +123,9 @@ class SourcesBuilder:
at which you can mutate sources."""

def __init__(self) -> None:
self._sources: Dict[str, Source] = {}
self._sources: dict[str, Source] = {}
self._config: Optional['SourcesConfig'] = None
self._config_sources: Dict[str, Source] = {}
self._config_sources: dict[str, Source] = {}

def add_source(self, source: Source) -> 'SourcesBuilder':
"""Adds a source to the sources collection."""
Expand All @@ -150,7 +150,7 @@ def build(self) -> Sources:
return Sources(self._sources, config=self._config)


class SourcesConfig(Dict[str, Any]):
class SourcesConfig(dict[str, Any]):
"""Wraps a configuration provided by a source, performing preliminary validation of its parsed
contents, but not interpreting the content. This lets us side-load a "config" within the sources, that
will generally have additional meaning outside of the rules engine, for example, doing event sampling
Expand All @@ -174,7 +174,7 @@ def __init__(self, *sources: Source):
if not all(isinstance(config, dict) for config in configs):
raise TypeError('Config is not a yaml serialized dictionary.')

config: Dict[str, Any] = reduce(deepmerge.merge_or_raise.merge, configs, {})
config: dict[str, Any] = reduce(deepmerge.merge_or_raise.merge, configs, {})

if not isinstance(config, dict):
raise TypeError('Config is not a yaml serialized dictionary.')
Expand All @@ -193,7 +193,7 @@ def source(self) -> Source:
return self._source

@property
def sources(self) -> List[Source]:
def sources(self) -> list[Source]:
return self._sources

def closest_span_for_location(self, location: Sequence[Union[int, str]], key_only: bool) -> Span:
Expand Down
6 changes: 3 additions & 3 deletions osprey_worker/src/osprey/engine/ast/yaml.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar, Dict, Hashable, Iterator, List, Type, TypeVar
from typing import TYPE_CHECKING, Any, ClassVar, Hashable, Iterator, Type, TypeVar

import yaml
from osprey.engine.ast.grammar import Source
Expand Down Expand Up @@ -37,11 +37,11 @@ def from_node(cls: Type[_T], orig: Any, node: yaml.Node, source: Source) -> '_T'
return instance


class _ListWithLineAndCol(List[object], WithLineAndCol):
class _ListWithLineAndCol(list[object], WithLineAndCol):
__slots__ = ('line_num', 'column_num', 'source')


class _DictWithLineAndCol(Dict[Hashable, Any], WithLineAndCol):
class _DictWithLineAndCol(dict[Hashable, Any], WithLineAndCol):
__slots__ = ('line_num', 'column_num', 'source')


Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence, Tuple, Type, Union, cast

from osprey.engine.ast.error_utils import SpanWithHint, render_span_context_with_message
from osprey.engine.ast.errors import OspreySyntaxError
Expand All @@ -26,16 +26,16 @@ class ValidationContext:
sources: Sources
"""The sources this validator is validating."""

_validation_results: Dict[Union[Type[BaseValidator], Type[HasResult[Any]]], Any]
_validation_results: dict[Union[Type[BaseValidator], Type[HasResult[Any]]], Any]
"""The stored results of a given validator."""

_errors: List['ValidationError']
_errors: list['ValidationError']
"""Errors emitted by the validators that have run."""

_warnings: List['ValidationWarning']
_warnings: list['ValidationWarning']
"""Warnings emitted by the validators that have run."""

_validator_stack: List[Type[BaseValidator]]
_validator_stack: list[Type[BaseValidator]]
"""Stack of the currently running validators."""

_validator_registry: ValidatorRegistry
Expand All @@ -44,7 +44,7 @@ class ValidationContext:
_udf_registry: 'UDFRegistry'
"""The registry holding the user defined functions that may be used by validators."""

_validator_inputs: Dict[Type[HasInput[Any]], Any]
_validator_inputs: dict[Type[HasInput[Any]], Any]
"""Holds any dynamic inputs that the validators might need."""

_warning_as_error: bool
Expand Down Expand Up @@ -173,7 +173,7 @@ def run(self) -> 'ValidatedSources':
if self._errors or (self._warning_as_error and self._warnings):
raise ValidationFailed(self._errors, self._warnings)

validation_results: Dict[Type[HasResult[Any]], Any] = {}
validation_results: dict[Type[HasResult[Any]], Any] = {}
for k, v in self._validation_results.items():
if issubclass(k, HasResult):
validation_results[k] = v
Expand Down Expand Up @@ -228,7 +228,7 @@ def get_validator_result(self, validator_class: Type[HasResult[T_co]]) -> T_co:

return cast(T_co, result)

def validator_depends_on(self, validator_classes: List[Type[BaseValidator]]) -> None:
def validator_depends_on(self, validator_classes: list[Type[BaseValidator]]) -> None:
"""Call from a validator's `.run()` or `__init__()` function, marking it as being dependent on the provided
`validator_classes` to have run first."""

Expand Down Expand Up @@ -367,15 +367,15 @@ class ValidatedSources:
def __init__(
self,
sources: Sources,
validation_results: Dict[Type[HasResult[Any]], Any],
validation_results: dict[Type[HasResult[Any]], Any],
warnings: Sequence[ValidationWarning],
):
self.sources = sources
self.warnings = warnings
self._validation_results = validation_results

@property
def validation_results(self) -> Dict[Type[HasResult[Any]], Any]:
def validation_results(self) -> dict[Type[HasResult[Any]], Any]:
return self._validation_results

def get_validator_result(self, validator_class: Type[HasResult[T_co]]) -> T_co:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Sequence, Tuple, cast
from typing import TYPE_CHECKING, Sequence, Tuple, cast

from osprey.engine.ast.ast_utils import filter_nodes
from osprey.engine.ast.grammar import Call, Source, Span
Expand Down Expand Up @@ -33,7 +33,7 @@ class ImportsMustNotHaveCycles(BaseValidator, HasResult[ImportGraphResult]):
_udf_node_mapping: UDFNodeMapping
"""Cached result of ValidateCallKwargs"""

_pairs: Dict[Tuple[Source, Source], Span]
_pairs: dict[Tuple[Source, Source], Span]
"""A mapping of XSource imported YSource -> Span, that will be used to build the error message if a cyclic
dependency is found."""

Expand All @@ -55,14 +55,14 @@ def run(self) -> None:
except CyclicDependencyError as e:
# We have our cycle path, and we're going to transform that into a coherent
# error that can then be displayed to the end user.
cycle_path = cast(List[Source], list(e.path))
cycle_path = cast(list[Source], list(e.path))

# In order to do that, we need to figure out all the spans we're going to point to
# in the error message. This means that we need to map the cycle back to the spans
# which the import happened. The mapping of "x imported y" -> span is maintained in `_pairs` and
# is built as we're iterating over the sources.

spans: List[Span] = []
spans: list[Span] = []

# So, we're going to loop over the cycle. Assuming we have the sources "foo", "bar" and "baz", and the
# path is: foo -> bar -> baz, we want to get the spans that show where:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Callable, Dict, List, Type
from typing import Any, Callable, Type

import pytest
from osprey.engine.ast_validator.validation_context import ValidationContext
Expand All @@ -12,7 +12,7 @@
from osprey.engine.udf.base import UDFBase
from osprey.engine.udf.registry import UDFRegistry

pytestmark: List[Callable[[Any], Any]] = [
pytestmark: list[Callable[[Any], Any]] = [
pytest.mark.use_validators([ValidateCallKwargs, UniqueStoredNames]),
pytest.mark.use_osprey_stdlib,
]
Expand Down Expand Up @@ -113,11 +113,11 @@ def test_missing_keyword_argument(run_validation: RunValidationFunction, check_f
class UnexpectedArgsArguments(ArgumentsBase):
required: str
optional: str = 'hello'
extra_arguments: Dict[str, str]
extra_arguments: dict[str, str]


class UnexpectedArgsUdfBase:
def execute(self, execution_context: ExecutionContext, arguments: UnexpectedArgsArguments) -> List[str]:
def execute(self, execution_context: ExecutionContext, arguments: UnexpectedArgsArguments) -> list[str]:
return [
i
for kv in {
Expand All @@ -134,10 +134,10 @@ def execute(self, execution_context: ExecutionContext, arguments: UnexpectedArgs
[{'required': 'hi'}, {'required': 'hi', 'optional': 'hi'}, {'required': 'hi', 'unexpected': 'world'}],
)
def test_allow_unexpected_args(
extra_arguments: Dict[str, str], execute: ExecuteFunction, udf_registry: UDFRegistry
extra_arguments: dict[str, str], execute: ExecuteFunction, udf_registry: UDFRegistry
) -> None:
@udf_registry.register
class UnexpectedArgsUdf(UnexpectedArgsUdfBase, UDFBase[UnexpectedArgsArguments, List[str]]):
class UnexpectedArgsUdf(UnexpectedArgsUdfBase, UDFBase[UnexpectedArgsArguments, list[str]]):
pass

extra_arguments_str = ', '.join(f'{k}={v!r}' for k, v in extra_arguments.items())
Expand All @@ -160,13 +160,13 @@ class UnexpectedArgsUdf(UnexpectedArgsUdfBase, UDFBase[UnexpectedArgsArguments,
],
)
def test_allow_unexpected_respects_type(
extra_arguments: Dict[str, str],
extra_arguments: dict[str, str],
execute: ExecuteFunction,
check_failure: CheckFailureFunction,
udf_registry: UDFRegistry,
) -> None:
@udf_registry.register
class ValidatingUnexpectedArgsUdf(UnexpectedArgsUdfBase, UDFBase[UnexpectedArgsArguments, List[str]]):
class ValidatingUnexpectedArgsUdf(UnexpectedArgsUdfBase, UDFBase[UnexpectedArgsArguments, list[str]]):
def __init__(self, validation_context: ValidationContext, arguments: UnexpectedArgsArguments):
assert set(arguments.get_extra_arguments_ast().keys()) == {'unexpected'}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ error: argument `unexpected` to `ValidatingUnexpectedArgsUdf` has incompatible t
--> main.sml:1:60
|
1 | Ret = ValidatingUnexpectedArgsUdf(required='hi', unexpected=['hello'])
| ^ has type `List[str]`, expected `str`
| ^ has type `list[str]`, expected `str`
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
error: unexpected additional variants to `List[...]`
error: unexpected additional variants to `list[...]`
--> main.sml:1:15
|
1 | Foo: List[str, int, float] = JsonData(path='$.foo')
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
error: unexpected additional variants to `List[...]`
error: unexpected additional variants to `list[...]`
--> main.sml:1:15
|
1 | Foo: List[str, str] = JsonData(path='$.foo')
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import TYPE_CHECKING, DefaultDict, Dict, List
from typing import TYPE_CHECKING, DefaultDict

from osprey.engine.ast.ast_utils import filter_nodes
from osprey.engine.ast.grammar import Name, Span, Store
Expand All @@ -10,7 +10,7 @@
from ..validation_context import ValidationContext


IdentifierIndex = Dict[str, Span]
IdentifierIndex = dict[str, Span]


class UniqueStoredNames(BaseValidator, HasResult[IdentifierIndex]):
Expand All @@ -23,8 +23,8 @@ def __init__(self, context: 'ValidationContext'):
self.identifier_index: IdentifierIndex = {}

def run(self) -> None:
stored_global_names: DefaultDict[str, List[Span]] = defaultdict(list)
stored_local_names_by_file: DefaultDict[str, DefaultDict[str, List[Span]]] = defaultdict(
stored_global_names: DefaultDict[str, list[Span]] = defaultdict(list)
stored_local_names_by_file: DefaultDict[str, DefaultDict[str, list[Span]]] = defaultdict(
lambda: defaultdict(list)
)
# Iterate over the ast, finding name nodes that are using the Store context.
Expand Down
Loading
Loading