diff --git a/CHANGELOG.md b/CHANGELOG.md index c7d7677b..a7567110 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,3 +24,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### 🐛 Bug fixes - Default to selecting all for event stream ([#194](https://github.com/roostorg/osprey/pull/194) by [@chimosky](https://github.com/chimosky)) +- Fix failed UDF query ([#233](https://github.com/roostorg/osprey/pull/233) by [@chimosky](https://github.com/chimosky)) diff --git a/osprey_worker/src/osprey/engine/ast_validator/validation_context.py b/osprey_worker/src/osprey/engine/ast_validator/validation_context.py index 1fc2ce95..2f33b176 100644 --- a/osprey_worker/src/osprey/engine/ast_validator/validation_context.py +++ b/osprey_worker/src/osprey/engine/ast_validator/validation_context.py @@ -62,8 +62,8 @@ def __init__( self._validation_results = {} self._errors = [] self._warnings = [] - self._validator_registry = validator_registry or ValidatorRegistry.get_instance() self._udf_registry = udf_registry + self._validator_registry = validator_registry or ValidatorRegistry.get_instance() self._validator_stack = [] self._validator_inputs = {} self._warning_as_error = warning_as_error diff --git a/osprey_worker/src/osprey/engine/query_language/__init__.py b/osprey_worker/src/osprey/engine/query_language/__init__.py index c21cbc47..ca1f4e7c 100644 --- a/osprey_worker/src/osprey/engine/query_language/__init__.py +++ b/osprey_worker/src/osprey/engine/query_language/__init__.py @@ -1,5 +1,6 @@ from osprey.engine.ast.sources import SOURCE_ENTRY_POINT_PATH, Sources from osprey.engine.ast_validator.validation_context import ValidatedSources, ValidationContext, ValidationFailed +from osprey.engine.ast_validator.validator_registry import ValidatorRegistry from osprey.engine.ast_validator.validators.unique_stored_names import UniqueStoredNames from osprey.engine.ast_validator.validators.validate_static_types import ValidateStaticTypes from osprey.engine.ast_validator.validators.variables_must_be_defined import VariablesMustBeDefined @@ -9,6 +10,19 @@ from osprey.engine.utils.imports import import_all_direct_children +def _consolidate_registry(validator_registry: ValidatorRegistry) -> ValidatorRegistry: + validators = { + validator + for validator in ValidatorRegistry.get_instance().get_validators() + if not validator.exclude_from_query_validation + } + registry_validators = set() + for validator in validator_registry.get_validators(): + registry_validators.add(validator) + + return ValidatorRegistry.from_validator_classes(validators.union(registry_validators)) + + def parse_query_to_validated_ast(query: str, rules_sources: ValidatedSources) -> ValidatedSources: """ Takes a string query (e.g. 'A == B or C == D', 'C <= 3 and D not in [4, 5, 6]') @@ -16,9 +30,10 @@ def parse_query_to_validated_ast(query: str, rules_sources: ValidatedSources) -> """ try: + consolidated_registry = _consolidate_registry(REGISTRY) sources = Sources.from_dict({SOURCE_ENTRY_POINT_PATH: 'Query = ' + query}) validation_context = ( - ValidationContext(sources=sources, udf_registry=UDF_REGISTRY, validator_registry=REGISTRY) + ValidationContext(sources=sources, udf_registry=UDF_REGISTRY, validator_registry=consolidated_registry) .set_validator_input( VariablesMustBeDefined, set(rules_sources.get_validator_result(UniqueStoredNames).keys()) ) diff --git a/osprey_worker/src/osprey/engine/query_language/tests/test_ast_druid_translator.py b/osprey_worker/src/osprey/engine/query_language/tests/test_ast_druid_translator.py index 0f06bb03..5c50458d 100644 --- a/osprey_worker/src/osprey/engine/query_language/tests/test_ast_druid_translator.py +++ b/osprey_worker/src/osprey/engine/query_language/tests/test_ast_druid_translator.py @@ -77,7 +77,7 @@ def test_parses_query_with_regex( make_rules_sources: MakeRulesSourcesFunction, check_json_output: CheckJsonOutputFunction ) -> None: validated_sources = parse_query_to_validated_ast( - "RegexMatch(item=A, regex='^foo$') and C == D", + "RegexMatch(target=A, pattern='^foo$') and C == D", make_rules_sources([('A', '"hello"'), 'C', 'D']), ) transformed_query = DruidQueryTransformer(validated_sources=validated_sources).transform() diff --git a/osprey_worker/src/osprey/engine/query_language/tests/test_regex_match.py b/osprey_worker/src/osprey/engine/query_language/tests/test_regex_match.py index 622f4bef..19791c78 100644 --- a/osprey_worker/src/osprey/engine/query_language/tests/test_regex_match.py +++ b/osprey_worker/src/osprey/engine/query_language/tests/test_regex_match.py @@ -17,18 +17,18 @@ def test_regex_match_accepts_valid_call(run_validation: RunValidationFunction) -> None: - run_validation("RegexMatch(item=A, regex='^foo$')") + run_validation("RegexMatch(target=A, pattern='^foo$')") def test_regex_match_fails_with_invalid_regex( run_validation: RunValidationFunction, check_failure: CheckFailureFunction ) -> None: with check_failure(): - run_validation("RegexMatch(item=A, regex='[')") + run_validation("RegexMatch(target=A, pattern='[')") def test_regex_match_fails_with_invalid_item_node( run_validation: RunValidationFunction, check_failure: CheckFailureFunction ) -> None: with check_failure(): - run_validation("RegexMatch(item='Jake', regex='^foo$')") + run_validation("RegexMatch(target='Jake', pattern='^foo$')") diff --git a/osprey_worker/src/osprey/engine/query_language/udfs/regex_match.py b/osprey_worker/src/osprey/engine/query_language/udfs/regex_match.py index c8cce62d..d56b0cd0 100644 --- a/osprey_worker/src/osprey/engine/query_language/udfs/regex_match.py +++ b/osprey_worker/src/osprey/engine/query_language/udfs/regex_match.py @@ -9,8 +9,8 @@ class Arguments(ArgumentsBase): - item: str - regex: ConstExpr[str] + target: str + pattern: ConstExpr[str] @register @@ -20,23 +20,23 @@ class RegexMatch(QueryUdfBase[Arguments, bool]): # Examples - `RegexMatch(item=UserName, regex='^jake')` + `RegexMatch(target=UserName, pattern='^jake')` """ def __init__(self, validation_context: ValidationContext, arguments: Arguments): super().__init__(validation_context, arguments) - regex = arguments.regex + regex = arguments.pattern with regex.attribute_errors(): re.compile(regex.value) self.regex = regex.value - item_node = arguments.get_argument_ast('item') + item_node = arguments.get_argument_ast('target') if isinstance(item_node, grammar.Name): self.item = item_node.identifier else: self.item = '' validation_context.add_error( - message='expected variable', span=item_node.span, hint='argument `item` must be a variable' + message='expected variable', span=item_node.span, hint='argument `target` must be a variable' ) def to_druid_query(self) -> Dict[str, object]: diff --git a/osprey_worker/src/osprey/engine/stdlib/udfs/string.py b/osprey_worker/src/osprey/engine/stdlib/udfs/string.py index 56e5a945..34b872a1 100644 --- a/osprey_worker/src/osprey/engine/stdlib/udfs/string.py +++ b/osprey_worker/src/osprey/engine/stdlib/udfs/string.py @@ -410,7 +410,7 @@ def execute(self, execution_context: ExecutionContext, arguments: StringArgument # split the message into individual tokens as based on a modified URL regex from messages_common. # should capture space based links and markdown based links without duplication. potential_urls: Iterator[Optional[ParseResult]] = ( - _safe_urlparse(token) for token in re.findall('(https?:\/\/[^\/\s][^\s\)>]+)', arguments.s) + _safe_urlparse(token) for token in re.findall(r'(https?:\/\/[^\/\s][^\s\)>]+)', arguments.s) ) # filter out any tokens that do not have a scheme or a domain (or failed to parse) @@ -444,7 +444,7 @@ def execute(self, execution_context: ExecutionContext, arguments: StringArgument # split the message into individual tokens as based on a modified URL regex from messages_common. # should capture space based links and markdown based links without duplication. potential_urls: Iterator[Optional[ParseResult]] = ( - _safe_urlparse(token) for token in re.findall('(https?:\/\/[^\/\s][^\s\)>]+)', arguments.s) + _safe_urlparse(token) for token in re.findall(r'(https?:\/\/[^\/\s][^\s\)>]+)', arguments.s) ) # filter out any tokens that do not have a scheme or a domain (or failed to parse)