diff --git a/hatchet/graphframe.py b/hatchet/graphframe.py index d041597b..d403ee8d 100644 --- a/hatchet/graphframe.py +++ b/hatchet/graphframe.py @@ -8,6 +8,7 @@ import sys import traceback from collections import defaultdict +import warnings import multiprocess as mp import numpy as np @@ -472,7 +473,8 @@ def filter( update_inc_cols=True, num_procs=mp.cpu_count(), rec_limit=1000, - multi_index_mode="off", + predicate_row_aggregator=None, + multi_index_mode=None, ): """Filter the dataframe using a user-supplied function. @@ -484,7 +486,20 @@ def filter( update_inc_cols (boolean, optional): if True, update inclusive columns when performing squash. rec_limit: set Python recursion limit, increase if running into recursion depth errors) (default: 1000). + predicate_row_aggregator (str or Callable, optional): function to use in Query Language + to merge multiple predicate results for each node into a single boolean. When providing + a string value, the following are accepted: "all" (equivalent to Python 'all'), "any" + (equivalent to Python 'any'), "off" (no aggregation) + multi_index_mode: deprecated alias for "predicate_row_aggregator" """ + if multi_index_mode is not None: + warnings.warn( + "'multi_index_mode' parameter is deprecated. Use 'predicate_row_aggregator' instead", + DeprecationWarning, + ) + if predicate_row_aggregator is None: + predicate_row_aggregator = multi_index_mode + sys.setrecursionlimit(rec_limit) dataframe_copy = self.dataframe.copy() @@ -537,15 +552,17 @@ def filter( # If a raw Object-dialect query is provided (not already passed to ObjectQuery), # create a new ObjectQuery object. if isinstance(filter_obj, list): - query = ObjectQuery(filter_obj, multi_index_mode) + query = ObjectQuery(filter_obj) # If a raw String-dialect query is provided (not already passed to StringQuery), # create a new StringQuery object. elif isinstance(filter_obj, str): - query = parse_string_dialect(filter_obj, multi_index_mode) + query = parse_string_dialect(filter_obj) # If an old-style query is provided, extract the underlying new-style query. elif issubclass(type(filter_obj), AbstractQuery): query = filter_obj._get_new_query() - query_matches = self.query_engine.apply(query, self.graph, self.dataframe) + query_matches = self.query_engine.apply( + query, self.graph, self.dataframe, predicate_row_aggregator + ) # match_set = list(set().union(*query_matches)) # filtered_df = dataframe_copy.loc[dataframe_copy["node"].isin(match_set)] filtered_df = dataframe_copy.loc[dataframe_copy["node"].isin(query_matches)] diff --git a/hatchet/query/compat.py b/hatchet/query/compat.py index d62a0c5c..3820060e 100644 --- a/hatchet/query/compat.py +++ b/hatchet/query/compat.py @@ -125,7 +125,7 @@ def apply(self, gf): (list): A list of nodes representing the result of the query """ true_query = self._get_new_query() - return COMPATABILITY_ENGINE.apply(true_query, gf.graph, gf.dataframe) + return COMPATABILITY_ENGINE.apply(true_query, gf.graph, gf.dataframe, "off") def _get_new_query(self): """Gets all the underlying 'new-style' queries in this object. @@ -322,7 +322,9 @@ def apply(self, gf): Returns: (list): A list representing the set of nodes from paths that match this query """ - return COMPATABILITY_ENGINE.apply(self.true_query, gf.graph, gf.dataframe) + return COMPATABILITY_ENGINE.apply( + self.true_query, gf.graph, gf.dataframe, "off" + ) def _get_new_query(self): """Get all the underlying 'new-style' query in this object. diff --git a/hatchet/query/engine.py b/hatchet/query/engine.py index 9717e240..50ce9410 100644 --- a/hatchet/query/engine.py +++ b/hatchet/query/engine.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: MIT from itertools import groupby +from collections.abc import Iterable import pandas as pd from .errors import InvalidQueryFilter @@ -14,6 +15,22 @@ from .string_dialect import parse_string_dialect +def _all_aggregator(pred_result): + if isinstance(pred_result, Iterable): + return all(pred_result) + elif isinstance(pred_result, pd.Series): + return pred_result.all() + return pred_result + + +def _any_aggregator(pred_result): + if isinstance(pred_result, Iterable): + return any(pred_result) + elif isinstance(pred_result, pd.Series): + return pred_result.any() + return pred_result + + class QueryEngine: """Class for applying queries to GraphFrames.""" @@ -25,7 +42,7 @@ def reset_cache(self): """Resets the cache in the QueryEngine.""" self.search_cache = {} - def apply(self, query, graph, dframe): + def apply(self, query, graph, dframe, predicate_row_aggregator): """Apply the query to a GraphFrame. Arguments: @@ -37,11 +54,28 @@ def apply(self, query, graph, dframe): (list): A list representing the set of nodes from paths that match the query """ if issubclass(type(query), Query): + aggregator = predicate_row_aggregator + if predicate_row_aggregator is None: + aggregator = query.default_aggregator + if aggregator == "all": + aggregator = _all_aggregator + elif aggregator == "any": + aggregator = _any_aggregator + elif aggregator == "off": + if isinstance(dframe.index, pd.MultiIndex): + raise ValueError( + "'predicate_row_aggregator' cannot be 'off' when the DataFrame has a row multi-index" + ) + aggregator = None + elif not callable(aggregator): + raise ValueError( + "Invalid value provided for 'predicate_row_aggregator'" + ) self.reset_cache() matches = [] visited = set() for root in sorted(graph.roots, key=traversal_order): - self._apply_impl(query, dframe, root, visited, matches) + self._apply_impl(query, dframe, aggregator, root, visited, matches) assert len(visited) == len(graph) matched_node_set = list(set().union(*matches)) # return matches @@ -54,12 +88,14 @@ def apply(self, query, graph, dframe): subq_obj = ObjectQuery(subq) elif isinstance(subq, str): subq_obj = parse_string_dialect(subq) - results.append(self.apply(subq_obj, graph, dframe)) + results.append( + self.apply(subq_obj, graph, dframe, predicate_row_aggregator) + ) return query._apply_op_to_results(results, graph) else: raise TypeError("Invalid query data type ({})".format(str(type(query)))) - def _cache_node(self, node, query, dframe): + def _cache_node(self, node, query, dframe, predicate_row_aggregator): """Cache (Memoize) the parts of the query that the node matches. Arguments: @@ -78,11 +114,19 @@ def _cache_node(self, node, query, dframe): row = dframe.xs(node, level="node", drop_level=False) else: row = dframe.loc[node] - if filter_func(row): + predicate_result = filter_func(row) + if ( + not isinstance(predicate_result, bool) + and predicate_row_aggregator is not None + ): + predicate_result = predicate_row_aggregator(predicate_result) + if predicate_result: matches.append(i) self.search_cache[node._hatchet_nid] = matches - def _match_0_or_more(self, query, dframe, node, wcard_idx): + def _match_0_or_more( + self, query, dframe, predicate_row_aggregator, node, wcard_idx + ): """Process a "*" predicate in the query on a subgraph. Arguments: @@ -98,7 +142,7 @@ def _match_0_or_more(self, query, dframe, node, wcard_idx): """ # Cache the node if it's not already cached if node._hatchet_nid not in self.search_cache: - self._cache_node(node, query, dframe) + self._cache_node(node, query, dframe, predicate_row_aggregator) # If the node matches with the next non-wildcard query node, # end the recursion and return the node. if wcard_idx + 1 in self.search_cache[node._hatchet_nid]: @@ -113,7 +157,9 @@ def _match_0_or_more(self, query, dframe, node, wcard_idx): return [[node]] return None for child in sorted(node.children, key=traversal_order): - sub_match = self._match_0_or_more(query, dframe, child, wcard_idx) + sub_match = self._match_0_or_more( + query, dframe, predicate_row_aggregator, child, wcard_idx + ) if sub_match is not None: matches.extend(sub_match) if len(matches) == 0: @@ -128,7 +174,7 @@ def _match_0_or_more(self, query, dframe, node, wcard_idx): return [[]] return None - def _match_1(self, query, dframe, node, idx): + def _match_1(self, query, dframe, predicate_row_aggregator, node, idx): """Process a "." predicate in the query on a subgraph. Arguments: @@ -142,12 +188,12 @@ def _match_1(self, query, dframe, node, idx): Will return None if there are no matches for the "." predicate. """ if node._hatchet_nid not in self.search_cache: - self._cache_node(node, query, dframe) + self._cache_node(node, query, dframe, predicate_row_aggregator) matches = [] for child in sorted(node.children, key=traversal_order): # Cache the node if it's not already cached if child._hatchet_nid not in self.search_cache: - self._cache_node(child, query, dframe) + self._cache_node(child, query, dframe, predicate_row_aggregator) if idx in self.search_cache[child._hatchet_nid]: matches.append([child]) # To be consistent with the other matching functions, return @@ -156,7 +202,9 @@ def _match_1(self, query, dframe, node, idx): return None return matches - def _match_pattern(self, query, dframe, pattern_root, match_idx): + def _match_pattern( + self, query, dframe, predicate_row_aggregator, pattern_root, match_idx + ): """Try to match the query pattern starting at the provided root node. Arguments: @@ -186,7 +234,9 @@ def _match_pattern(self, query, dframe, pattern_root, match_idx): # Get the portion of the subgraph that matches the next # part of the query. if wcard == ".": - s = self._match_1(query, dframe, m[-1], pattern_idx) + s = self._match_1( + query, dframe, predicate_row_aggregator, m[-1], pattern_idx + ) if s is None: sub_match.append(s) else: @@ -196,7 +246,13 @@ def _match_pattern(self, query, dframe, pattern_root, match_idx): sub_match.append([]) else: for child in sorted(m[-1].children, key=traversal_order): - s = self._match_0_or_more(query, dframe, child, pattern_idx) + s = self._match_0_or_more( + query, + dframe, + predicate_row_aggregator, + child, + pattern_idx, + ) if s is None: sub_match.append(s) else: @@ -221,7 +277,9 @@ def _match_pattern(self, query, dframe, pattern_root, match_idx): pattern_idx += 1 return matches - def _apply_impl(self, query, dframe, node, visited, matches): + def _apply_impl( + self, query, dframe, predicate_row_aggregator, node, visited, matches + ): """Traverse the subgraph with the specified root, and collect all paths that match the query. Arguments: @@ -237,21 +295,27 @@ def _apply_impl(self, query, dframe, node, visited, matches): return # Cache the node if it's not already cached if node._hatchet_nid not in self.search_cache: - self._cache_node(node, query, dframe) + self._cache_node(node, query, dframe, predicate_row_aggregator) # If the node matches the starting/root node of the query, # try to get all query matches in the subgraph rooted at # this node. if query.query_pattern[0][0] == "*": if 1 in self.search_cache[node._hatchet_nid]: - sub_match = self._match_pattern(query, dframe, node, 1) + sub_match = self._match_pattern( + query, dframe, predicate_row_aggregator, node, 1 + ) if sub_match is not None: matches.extend(sub_match) if 0 in self.search_cache[node._hatchet_nid]: - sub_match = self._match_pattern(query, dframe, node, 0) + sub_match = self._match_pattern( + query, dframe, predicate_row_aggregator, node, 0 + ) if sub_match is not None: matches.extend(sub_match) # Note that the node is now visited. visited.add(node._hatchet_nid) # Continue the Depth First Search. for child in sorted(node.children, key=traversal_order): - self._apply_impl(query, dframe, child, visited, matches) + self._apply_impl( + query, dframe, predicate_row_aggregator, child, visited, matches + ) diff --git a/hatchet/query/errors.py b/hatchet/query/errors.py index 248c6a28..ebb835f0 100644 --- a/hatchet/query/errors.py +++ b/hatchet/query/errors.py @@ -18,9 +18,3 @@ class RedundantQueryFilterWarning(Warning): class BadNumberNaryQueryArgs(Exception): """Raised when a query filter does not have a valid syntax""" - - -class MultiIndexModeMismatch(Exception): - """Raised when an ObjectQuery or StringQuery object - is set to use multi-indexed data, but no multi-indexed - data is provided""" diff --git a/hatchet/query/object_dialect.py b/hatchet/query/object_dialect.py index daf55c65..f4767565 100644 --- a/hatchet/query/object_dialect.py +++ b/hatchet/query/object_dialect.py @@ -13,21 +13,11 @@ import re import sys -from .errors import InvalidQueryPath, InvalidQueryFilter, MultiIndexModeMismatch +from .errors import InvalidQueryPath, InvalidQueryFilter from .query import Query -def _process_multi_index_mode(apply_result, multi_index_mode): - if multi_index_mode == "any": - return apply_result.any() - if multi_index_mode == "all": - return apply_result.all() - raise ValueError( - "Multi-Index Mode for the Object-based dialect must be either 'any' or 'all'" - ) - - -def _process_predicate(attr_filter, multi_index_mode): +def _process_predicate(attr_filter): """Converts high-level API attribute filter to a lambda""" compops = ("<", ">", "==", ">=", "<=", "<>", "!=") # , @@ -126,12 +116,6 @@ def filter_single_series(df_row, key, single_value): return matches def filter_dframe(df_row): - if multi_index_mode == "off": - raise MultiIndexModeMismatch( - "The ObjectQuery's 'multi_index_mode' argument \ - cannot be set to 'off' when using multi-indexed data" - ) - def filter_single_dframe(node, df_row, key, single_value): if key == "depth": if isinstance(single_value, str) and single_value.lower().startswith( @@ -164,21 +148,18 @@ def filter_single_dframe(node, df_row, key, single_value): raise InvalidQueryFilter( "Value for attribute {} must be a string.".format(key) ) - apply_ret = df_row[key].apply( + return df_row[key].apply( lambda x: re.match(single_value + r"\Z", x) is not None ) - return _process_multi_index_mode(apply_ret, multi_index_mode) if is_numeric_dtype(df_row[key]): if isinstance(single_value, str) and single_value.lower().startswith( compops ): - apply_ret = df_row[key].apply( + return df_row[key].apply( lambda x: eval("{} {}".format(x, single_value)) ) - return _process_multi_index_mode(apply_ret, multi_index_mode) if isinstance(single_value, Real): - apply_ret = df_row[key].apply(lambda x: x == single_value).any() - return _process_multi_index_mode(apply_ret, multi_index_mode) + return df_row[key].apply(lambda x: x == single_value).any() raise InvalidQueryFilter( "Attribute {} has a numeric type. Valid filters for this attribute are a string starting with a comparison operator or a real number.".format( key @@ -218,7 +199,7 @@ def filter_choice(df_row): class ObjectQuery(Query): """Class for representing and parsing queries using the Object-based dialect.""" - def __init__(self, query, multi_index_mode="off"): + def __init__(self, query): """Builds a new ObjectQuery from an instance of the Object-based dialect syntax. Arguments: @@ -229,18 +210,15 @@ def __init__(self, query, multi_index_mode="off"): else: super().__init__() assert isinstance(query, list) - assert multi_index_mode in ["off", "all", "any"] for qnode in query: if isinstance(qnode, dict): - self._add_node(predicate=_process_predicate(qnode, multi_index_mode)) + self._add_node(predicate=_process_predicate(qnode)) elif isinstance(qnode, str) or isinstance(qnode, int): self._add_node(quantifer=qnode) elif isinstance(qnode, tuple): assert isinstance(qnode[1], dict) if isinstance(qnode[0], str) or isinstance(qnode[0], int): - self._add_node( - qnode[0], _process_predicate(qnode[1], multi_index_mode) - ) + self._add_node(qnode[0], _process_predicate(qnode[1])) else: raise InvalidQueryPath( "The first value of a tuple entry in a path must be either a string or integer." @@ -249,3 +227,4 @@ def __init__(self, query, multi_index_mode="off"): raise InvalidQueryPath( "A query path must be a list containing String, Integer, Dict, or Tuple elements" ) + self.default_aggregator = "all" diff --git a/hatchet/query/query.py b/hatchet/query/query.py index 39f17743..28330311 100644 --- a/hatchet/query/query.py +++ b/hatchet/query/query.py @@ -12,6 +12,7 @@ class Query(object): def __init__(self): """Create new Query""" self.query_pattern = [] + self.default_aggregator = "off" def match(self, quantifier=".", predicate=lambda row: True): """Start a query with a root node described by the arguments. diff --git a/hatchet/query/string_dialect.py b/hatchet/query/string_dialect.py index 791128fe..6990d768 100644 --- a/hatchet/query/string_dialect.py +++ b/hatchet/query/string_dialect.py @@ -97,7 +97,7 @@ def filter_check_types(type_check, df_row, filt_lambda): class StringQuery(Query): """Class for representing and parsing queries using the String-based dialect.""" - def __init__(self, cypher_query, multi_index_mode="off"): + def __init__(self, cypher_query): """Builds a new StringQuery object representing a query in the String-based dialect. Arguments: @@ -107,11 +107,9 @@ def __init__(self, cypher_query, multi_index_mode="off"): super(StringQuery, self).__init__() else: super().__init__() - assert multi_index_mode in ["off", "all", "any"] - self.multi_index_mode = multi_index_mode - model = None + self.model = None try: - model = cypher_query_mm.model_from_str(cypher_query) + self.model = cypher_query_mm.model_from_str(cypher_query) except TextXError as e: # TODO Change to a "raise-from" expression when Python 2.7 support is dropped raise InvalidQueryPath( @@ -121,9 +119,13 @@ def __init__(self, cypher_query, multi_index_mode="off"): ) self.wcards = [] self.wcard_pos = {} - self._parse_path(model.path_expr) + self.default_aggregator = "all" + + def parse(self, dframe): + has_multi_index = isinstance(dframe.index, pd.MultiIndex) + self._parse_path(self.model.path_expr) self.filters = [[] for _ in self.wcards] - self._parse_conditions(model.cond_expr) + self._parse_conditions(self.model.cond_expr, has_multi_index) self.lambda_filters = [None for _ in self.wcards] self._build_lambdas() self._build_query() @@ -188,7 +190,7 @@ def _parse_path(self, path_obj): self.wcard_pos[n.name] = idx idx += 1 - def _parse_conditions(self, cond_expr): + def _parse_conditions(self, cond_expr, has_multi_index): """Top level function for parsing the WHERE statement of a String-based query. """ @@ -196,9 +198,9 @@ def _parse_conditions(self, cond_expr): for cond in conditions: converted_condition = None if self._is_unary_cond(cond): - converted_condition = self._parse_unary_cond(cond) + converted_condition = self._parse_unary_cond(cond, has_multi_index) elif self._is_binary_cond(cond): - converted_condition = self._parse_binary_cond(cond) + converted_condition = self._parse_binary_cond(cond, has_multi_index) else: raise RuntimeError("Bad Condition") self.filters[self.wcard_pos[converted_condition[1]]].append( @@ -226,59 +228,67 @@ def _is_binary_cond(self, obj): return True return False - def _parse_binary_cond(self, obj): + def _parse_binary_cond(self, obj, has_multi_index): """Top level function for parsing binary predicates.""" if cname(obj) == "AndCond": - return self._parse_and_cond(obj) + return self._parse_and_cond(obj, has_multi_index) if cname(obj) == "OrCond": - return self._parse_or_cond(obj) + return self._parse_or_cond(obj, has_multi_index) raise RuntimeError("Bad Binary Condition") - def _parse_or_cond(self, obj): + def _parse_or_cond(self, obj, has_multi_index): """Top level function for parsing predicates combined with logical OR.""" - converted_subcond = self._parse_unary_cond(obj.subcond) + converted_subcond = self._parse_unary_cond(obj.subcond, has_multi_index) converted_subcond[0] = "or" return converted_subcond - def _parse_and_cond(self, obj): + def _parse_and_cond(self, obj, has_multi_index): """Top level function for parsing predicates combined with logical AND.""" - converted_subcond = self._parse_unary_cond(obj.subcond) + converted_subcond = self._parse_unary_cond(obj.subcond, has_multi_index) converted_subcond[0] = "and" return converted_subcond - def _parse_unary_cond(self, obj): + def _parse_unary_cond(self, obj, has_multi_index): """Top level function for parsing unary predicates.""" if cname(obj) == "NotCond": - return self._parse_not_cond(obj) - return self._parse_single_cond(obj) + return self._parse_not_cond(obj, has_multi_index) + return self._parse_single_cond(obj, has_multi_index) - def _parse_not_cond(self, obj): + def _parse_not_cond(self, obj, has_multi_index): """Parse predicates containing the logical NOT operator.""" - converted_subcond = self._parse_single_cond(obj.subcond) + converted_subcond = self._parse_single_cond(obj.subcond, has_multi_index) converted_subcond[2] = "not {}".format(converted_subcond[2]) return converted_subcond - def _run_method_based_on_multi_idx_mode(self, method_name, obj): + def _run_method_based_on_multi_index(self, method_name, obj, has_multi_index): real_method_name = method_name - if self.multi_index_mode != "off": + if has_multi_index: real_method_name = method_name + "_multi_idx" method = eval("StringQuery.{}".format(real_method_name)) return method(self, obj) - def _parse_single_cond(self, obj): + def _parse_single_cond(self, obj, has_multi_index): """Top level function for parsing individual numeric or string predicates.""" if self._is_str_cond(obj): - return self._parse_str(obj) + return self._parse_str(obj, has_multi_index) if self._is_num_cond(obj): - return self._parse_num(obj) + return self._parse_num(obj, has_multi_index) if cname(obj) == "NoneCond": - return self._run_method_based_on_multi_idx_mode("_parse_none", obj) + return self._run_method_based_on_multi_index( + "_parse_none", obj, has_multi_index + ) if cname(obj) == "NotNoneCond": - return self._run_method_based_on_multi_idx_mode("_parse_not_none", obj) + return self._run_method_based_on_multi_index( + "_parse_not_none", obj, has_multi_index + ) if cname(obj) == "LeafCond": - return self._run_method_based_on_multi_idx_mode("_parse_leaf", obj) + return self._run_method_based_on_multi_index( + "_parse_leaf", obj, has_multi_index + ) if cname(obj) == "NotLeafCond": - return self._run_method_based_on_multi_idx_mode("_parse_not_leaf", obj) + return self._run_method_based_on_multi_index( + "_parse_not_leaf", obj, has_multi_index + ) raise RuntimeError("Bad Single Condition") def _parse_none(self, obj): @@ -308,11 +318,6 @@ def _parse_none(self, obj): None, ] - def _add_aggregation_call_to_multi_idx_predicate(self, predicate): - if self.multi_index_mode == "any": - return predicate + ".any()" - return predicate + ".all()" - def _parse_none_multi_idx(self, obj): if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -331,12 +336,10 @@ def _parse_none_multi_idx(self, obj): return [ None, obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "df_row[{}].apply(lambda elem: elem is None)".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ) + "df_row[{}].apply(lambda elem: elem is None)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) ), None, ] @@ -386,12 +389,10 @@ def _parse_not_none_multi_idx(self, obj): return [ None, obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "df_row[{}].apply(lambda elem: elem is not None)".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ) + "df_row[{}].apply(lambda elem: elem is not None)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) ), None, ] @@ -458,22 +459,30 @@ def _is_num_cond(self, obj): return True return False - def _parse_str(self, obj): + def _parse_str(self, obj, has_multi_index): """Function that redirects processing of string predicates to the correct function. """ if cname(obj) == "StringEq": - return self._run_method_based_on_multi_idx_mode("_parse_str_eq", obj) + return self._run_method_based_on_multi_index( + "_parse_str_eq", obj, has_multi_index + ) if cname(obj) == "StringStartsWith": - return self._run_method_based_on_multi_idx_mode( - "_parse_str_starts_with", obj + return self._run_method_based_on_multi_index( + "_parse_str_starts_with", obj, has_multi_index ) if cname(obj) == "StringEndsWith": - return self._run_method_based_on_multi_idx_mode("_parse_str_ends_with", obj) + return self._run_method_based_on_multi_index( + "_parse_str_ends_with", obj, has_multi_index + ) if cname(obj) == "StringContains": - return self._run_method_based_on_multi_idx_mode("_parse_str_contains", obj) + return self._run_method_based_on_multi_index( + "_parse_str_contains", obj, has_multi_index + ) if cname(obj) == "StringMatch": - return self._run_method_based_on_multi_idx_mode("_parse_str_match", obj) + return self._run_method_based_on_multi_index( + "_parse_str_match", obj, has_multi_index + ) raise RuntimeError("Bad String Op Class") def _parse_str_eq(self, obj): @@ -500,15 +509,13 @@ def _parse_str_eq_multi_idx(self, obj): return [ None, obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - 'df_row[{}].apply(lambda elem: elem == "{}")'.format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) + 'df_row[{}].apply(lambda elem: elem == "{}")'.format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, ), "is_string_dtype(df_row[{}])".format( str(tuple(obj.prop.ids)) @@ -541,15 +548,13 @@ def _parse_str_starts_with_multi_idx(self, obj): return [ None, obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - 'df_row[{}].apply(lambda elem: elem.startswith("{}"))'.format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) + 'df_row[{}].apply(lambda elem: elem.startswith("{}"))'.format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, ), "is_string_dtype(df_row[{}])".format( str(tuple(obj.prop.ids)) @@ -582,15 +587,13 @@ def _parse_str_ends_with_multi_idx(self, obj): return [ None, obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - 'df_row[{}].apply(lambda elem: elem.endswith("{}"))'.format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) + 'df_row[{}].apply(lambda elem: elem.endswith("{}"))'.format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, ), "is_string_dtype(df_row[{}])".format( str(tuple(obj.prop.ids)) @@ -623,15 +626,13 @@ def _parse_str_contains_multi_idx(self, obj): return [ None, obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - 'df_row[{}].apply(lambda elem: "{}" in elem)'.format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) + 'df_row[{}].apply(lambda elem: "{}" in elem)'.format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, ), "is_string_dtype(df_row[{}])".format( str(tuple(obj.prop.ids)) @@ -664,15 +665,13 @@ def _parse_str_match_multi_idx(self, obj): return [ None, obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - 'df_row[{}].apply(lambda elem: re.match("{}", elem) is not None)'.format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) + 'df_row[{}].apply(lambda elem: re.match("{}", elem) is not None)'.format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, ), "is_string_dtype(df_row[{}])".format( str(tuple(obj.prop.ids)) @@ -681,28 +680,46 @@ def _parse_str_match_multi_idx(self, obj): ), ] - def _parse_num(self, obj): + def _parse_num(self, obj, has_multi_index): """Function that redirects processing of numeric predicates to the correct function. """ if cname(obj) == "NumEq": - return self._run_method_based_on_multi_idx_mode("_parse_num_eq", obj) + return self._run_method_based_on_multi_index( + "_parse_num_eq", obj, has_multi_index + ) if cname(obj) == "NumLt": - return self._run_method_based_on_multi_idx_mode("_parse_num_lt", obj) + return self._run_method_based_on_multi_index( + "_parse_num_lt", obj, has_multi_index + ) if cname(obj) == "NumGt": - return self._run_method_based_on_multi_idx_mode("_parse_num_gt", obj) + return self._run_method_based_on_multi_index( + "_parse_num_gt", obj, has_multi_index + ) if cname(obj) == "NumLte": - return self._run_method_based_on_multi_idx_mode("_parse_num_lte", obj) + return self._run_method_based_on_multi_index( + "_parse_num_lte", obj, has_multi_index + ) if cname(obj) == "NumGte": - return self._run_method_based_on_multi_idx_mode("_parse_num_gte", obj) + return self._run_method_based_on_multi_index( + "_parse_num_gte", obj, has_multi_index + ) if cname(obj) == "NumNan": - return self._run_method_based_on_multi_idx_mode("_parse_num_nan", obj) + return self._run_method_based_on_multi_index( + "_parse_num_nan", obj, has_multi_index + ) if cname(obj) == "NumNotNan": - return self._run_method_based_on_multi_idx_mode("_parse_num_not_nan", obj) + return self._run_method_based_on_multi_index( + "_parse_num_not_nan", obj, has_multi_index + ) if cname(obj) == "NumInf": - return self._run_method_based_on_multi_idx_mode("_parse_num_inf", obj) + return self._run_method_based_on_multi_index( + "_parse_num_inf", obj, has_multi_index + ) if cname(obj) == "NumNotInf": - return self._run_method_based_on_multi_idx_mode("_parse_num_not_inf", obj) + return self._run_method_based_on_multi_index( + "_parse_num_not_inf", obj, has_multi_index + ) raise RuntimeError("Bad Number Op Class") def _parse_num_eq(self, obj): @@ -845,15 +862,13 @@ def _parse_num_eq_multi_idx(self, obj): return [ None, obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "df_row[{}].apply(lambda elem: elem == {})".format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) + "df_row[{}].apply(lambda elem: elem == {})".format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, ), "is_numeric_dtype(df_row[{}])".format( str(tuple(obj.prop.ids)) @@ -988,15 +1003,13 @@ def _parse_num_lt_multi_idx(self, obj): return [ None, obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "df_row[{}].apply(lambda elem: elem < {})".format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) + "df_row[{}].apply(lambda elem: elem < {})".format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, ), "is_numeric_dtype(df_row[{}])".format( str(tuple(obj.prop.ids)) @@ -1131,15 +1144,13 @@ def _parse_num_gt_multi_idx(self, obj): return [ None, obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "df_row[{}].apply(lambda elem: elem > {})".format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) + "df_row[{}].apply(lambda elem: elem > {})".format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, ), "is_numeric_dtype(df_row[{}])".format( str(tuple(obj.prop.ids)) @@ -1274,15 +1285,13 @@ def _parse_num_lte_multi_idx(self, obj): return [ None, obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "df_row[{}].apply(lambda elem: elem <= {})".format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) + "df_row[{}].apply(lambda elem: elem <= {})".format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, ), "is_numeric_dtype(df_row[{}])".format( str(tuple(obj.prop.ids)) @@ -1417,15 +1426,13 @@ def _parse_num_gte_multi_idx(self, obj): return [ None, obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "df_row[{}].apply(lambda elem: elem >= {})".format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) + "df_row[{}].apply(lambda elem: elem >= {})".format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, ), "is_numeric_dtype(df_row[{}])".format( str(tuple(obj.prop.ids)) @@ -1483,12 +1490,10 @@ def _parse_num_nan_multi_idx(self, obj): return [ None, obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "pd.isna(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ) + "pd.isna(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) ), "is_numeric_dtype(df_row[{}])".format( str(tuple(obj.prop.ids)) @@ -1546,12 +1551,10 @@ def _parse_num_not_nan_multi_idx(self, obj): return [ None, obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "not pd.isna(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ) + "not pd.isna(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) ), "is_numeric_dtype(df_row[{}])".format( str(tuple(obj.prop.ids)) @@ -1609,12 +1612,10 @@ def _parse_num_inf_multi_idx(self, obj): return [ None, obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "np.isinf(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ) + "np.isinf(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) ), "is_numeric_dtype(df_row[{}])".format( str(tuple(obj.prop.ids)) @@ -1672,12 +1673,10 @@ def _parse_num_not_inf_multi_idx(self, obj): return [ None, obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "not np.isinf(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ) + "not np.isinf(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) ), "is_numeric_dtype(df_row[{}])".format( str(tuple(obj.prop.ids)) @@ -1687,7 +1686,7 @@ def _parse_num_not_inf_multi_idx(self, obj): ] -def parse_string_dialect(query_str, multi_index_mode="off"): +def parse_string_dialect(query_str): """Parse all types of String-based queries, including multi-queries that leverage the curly brace delimiters. @@ -1709,7 +1708,7 @@ def parse_string_dialect(query_str, multi_index_mode="off"): if num_curly_brace_elems == 0: if sys.version_info[0] == 2: query_str = query_str.decode("utf-8") - return StringQuery(query_str, multi_index_mode) + return StringQuery(query_str) # Create an iterator over the curly brace-delimited regions curly_brace_iter = re.finditer(r"\{(.*?)\}", query_str) # Will store curly brace-delimited regions in the WHERE clause @@ -1798,14 +1797,14 @@ def parse_string_dialect(query_str, multi_index_mode="off"): query1 = "MATCH {} WHERE {}".format(match_comp, condition_list[i]) if sys.version_info[0] == 2: query1 = query1.decode("utf-8") - full_query = StringQuery(query1, multi_index_mode) + full_query = StringQuery(query1) # Get the next query as a CypherQuery where # the MATCH clause is the shared match clause and the WHERE clause is the # next curly brace-delimited region next_query = "MATCH {} WHERE {}".format(match_comp, condition_list[i + 1]) if sys.version_info[0] == 2: next_query = next_query.decode("utf-8") - next_query = StringQuery(next_query, multi_index_mode) + next_query = StringQuery(next_query) # Add the next query to the full query using the compound operator # currently being considered if op == "AND": diff --git a/hatchet/tests/query.py b/hatchet/tests/query.py index a8082dea..9f911ae7 100644 --- a/hatchet/tests/query.py +++ b/hatchet/tests/query.py @@ -25,7 +25,7 @@ ExclusiveDisjunctionQuery, NegationQuery, ) -from hatchet.query.errors import MultiIndexModeMismatch +from hatchet.query.engine import _all_aggregator, _any_aggregator def test_construct_object_dialect(): @@ -239,7 +239,7 @@ def test_node_caching(mock_graph_literal): query = ObjectQuery(path) engine = QueryEngine() - engine._cache_node(node, query, gf.dataframe) + engine._cache_node(node, query, gf.dataframe, None) assert 0 in engine.search_cache[node._hatchet_nid] assert 1 in engine.search_cache[node._hatchet_nid] @@ -270,12 +270,12 @@ def test_match_0_or_more_wildcard(mock_graph_literal): engine = QueryEngine() matched_paths = [] for child in sorted(node.children, key=traversal_order): - match = engine._match_0_or_more(query, gf.dataframe, child, 1) + match = engine._match_0_or_more(query, gf.dataframe, None, child, 1) if match is not None: matched_paths.extend(match) assert sorted(matched_paths, key=len) == sorted(correct_paths, key=len) - assert engine._match_0_or_more(query, gf.dataframe, none_node, 1) is None + assert engine._match_0_or_more(query, gf.dataframe, None, none_node, 1) is None def test_match_1(mock_graph_literal): @@ -288,10 +288,10 @@ def test_match_1(mock_graph_literal): query = ObjectQuery(path) engine = QueryEngine() - assert engine._match_1(query, gf.dataframe, gf.graph.roots[0].children[0], 2) == [ - [gf.graph.roots[0].children[0].children[1]] - ] - assert engine._match_1(query, gf.dataframe, gf.graph.roots[0], 2) is None + assert engine._match_1( + query, gf.dataframe, None, gf.graph.roots[0].children[0], 2 + ) == [[gf.graph.roots[0].children[0].children[1]]] + assert engine._match_1(query, gf.dataframe, None, gf.graph.roots[0], 2) is None def test_match(mock_graph_literal): @@ -316,7 +316,9 @@ def test_match(mock_graph_literal): ] query0 = ObjectQuery(path0) engine = QueryEngine() - assert engine._match_pattern(query0, gf.dataframe, root, 0) == match0 + assert ( + engine._match_pattern(query0, gf.dataframe, _all_aggregator, root, 0) == match0 + ) engine.reset_cache() @@ -328,7 +330,7 @@ def test_match(mock_graph_literal): {"time (inc)": 7.5, "time": 7.5}, ] query1 = ObjectQuery(path1) - assert engine._match_pattern(query1, gf.dataframe, root, 0) is None + assert engine._match_pattern(query1, gf.dataframe, _all_aggregator, root, 0) is None def test_apply(mock_graph_literal): @@ -350,7 +352,7 @@ def test_apply(mock_graph_literal): query = ObjectQuery(path) engine = QueryEngine() - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = [{"time (inc)": ">= 30.0"}, ".", {"name": "bar"}, "*"] match = [ @@ -361,12 +363,12 @@ def test_apply(mock_graph_literal): root.children[1].children[0].children[0].children[0].children[1], ] query = ObjectQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = [{"name": "foo"}, {"name": "bar"}, {"time": 5.0}] match = [root, root.children[0], root.children[0].children[0]] query = ObjectQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = [{"name": "foo"}, {"name": "qux"}, ("+", {"time (inc)": "> 15.0"})] match = [ @@ -377,22 +379,22 @@ def test_apply(mock_graph_literal): root.children[1].children[0].children[0].children[0], ] query = ObjectQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = [{"name": "this"}, ("*", {"name": "is"}), {"name": "nonsense"}] query = ObjectQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == [] + assert engine.apply(query, gf.graph, gf.dataframe, None) == [] path = [{"name": 5}, "*", {"name": "whatever"}] query = ObjectQuery(path) with pytest.raises(InvalidQueryFilter): - engine.apply(query, gf.graph, gf.dataframe) + engine.apply(query, gf.graph, gf.dataframe, None) path = [{"time": "badstring"}, "*", {"name": "whatever"}] query = ObjectQuery(path) with pytest.raises(InvalidQueryFilter): - engine.apply(query, gf.graph, gf.dataframe) + engine.apply(query, gf.graph, gf.dataframe, None) class DummyType: def __init__(self): @@ -408,7 +410,7 @@ def __init__(self): path = [{"name": "foo"}, {"name": "bar"}, {"list": DummyType()}] query = ObjectQuery(path) with pytest.raises(InvalidQueryFilter): - engine.apply(query, gf.graph, gf.dataframe) + engine.apply(query, gf.graph, gf.dataframe, None) path = ["*", {"name": "bar"}, {"name": "grault"}, "*"] match = [ @@ -453,11 +455,11 @@ def __init__(self): ] match = list(set().union(*match)) query = ObjectQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = ["*", {"name": "bar"}, {"name": "grault"}, "+"] query = ObjectQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == [] + assert engine.apply(query, gf.graph, gf.dataframe, None) == [] # Test a former edge case with the + quantifier/wildcard match = [ @@ -487,7 +489,7 @@ def __init__(self): match = list(set().union(*match)) path = [("+", {"name": "ba.*"})] query = ObjectQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) def test_apply_indices(calc_pi_hpct_db): @@ -512,12 +514,14 @@ def test_apply_indices(calc_pi_hpct_db): ], ] matches = list(set().union(*matches)) - query = ObjectQuery(path, multi_index_mode="all") + query = ObjectQuery(path) engine = QueryEngine() - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") + ) == sorted(matches) gf.drop_index_levels() - assert engine.apply(query, gf.graph, gf.dataframe) == matches + assert engine.apply(query, gf.graph, gf.dataframe, None) == matches def test_object_dialect_depth(mock_graph_literal): @@ -526,7 +530,7 @@ def test_object_dialect_depth(mock_graph_literal): engine = QueryEngine() roots = gf.graph.roots matches = [c for r in roots for c in r.children] - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(matches) query = ObjectQuery([("*", {"depth": "<= 2"})]) matches = [ @@ -553,11 +557,11 @@ def test_object_dialect_depth(mock_graph_literal): [roots[1].children[0].children[1]], ] matches = list(set().union(*matches)) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(matches) with pytest.raises(InvalidQueryFilter): query = ObjectQuery([{"depth": "hello"}]) - engine.apply(query, gf.graph, gf.dataframe) + engine.apply(query, gf.graph, gf.dataframe, None) def test_object_dialect_hatchet_nid(mock_graph_literal): @@ -574,21 +578,21 @@ def test_object_dialect_hatchet_nid(mock_graph_literal): [root.children[0].children[1]], ] matches = list(set().union(*matches)) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(matches) query = ObjectQuery([{"node_id": 0}]) - assert engine.apply(query, gf.graph, gf.dataframe) == [gf.graph.roots[0]] + assert engine.apply(query, gf.graph, gf.dataframe, None) == [gf.graph.roots[0]] with pytest.raises(InvalidQueryFilter): query = ObjectQuery([{"node_id": "hello"}]) - engine.apply(query, gf.graph, gf.dataframe) + engine.apply(query, gf.graph, gf.dataframe, None) def test_object_dialect_depth_index_levels(calc_pi_hpct_db): gf = GraphFrame.from_hpctoolkit(str(calc_pi_hpct_db)) root = gf.graph.roots[0] - query = ObjectQuery([("*", {"depth": "<= 2"})], multi_index_mode="all") + query = ObjectQuery([("*", {"depth": "<= 2"})]) engine = QueryEngine() matches = [ [root, root.children[0], root.children[0].children[0]], @@ -599,22 +603,27 @@ def test_object_dialect_depth_index_levels(calc_pi_hpct_db): [root.children[0].children[1]], ] matches = list(set().union(*matches)) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") + ) == sorted(matches) - query = ObjectQuery([("*", {"depth": 0})], multi_index_mode="all") + query = ObjectQuery([("*", {"depth": 0})]) matches = [root] - assert engine.apply(query, gf.graph, gf.dataframe) == matches + assert ( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") + == matches + ) with pytest.raises(InvalidQueryFilter): - query = ObjectQuery([{"depth": "hello"}], multi_index_mode="all") - engine.apply(query, gf.graph, gf.dataframe) + query = ObjectQuery([{"depth": "hello"}]) + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") def test_object_dialect_node_id_index_levels(calc_pi_hpct_db): gf = GraphFrame.from_hpctoolkit(str(calc_pi_hpct_db)) root = gf.graph.roots[0] - query = ObjectQuery([("*", {"node_id": "<= 2"})], multi_index_mode="all") + query = ObjectQuery([("*", {"node_id": "<= 2"})]) engine = QueryEngine() matches = [ [root, root.children[0]], @@ -624,15 +633,20 @@ def test_object_dialect_node_id_index_levels(calc_pi_hpct_db): [root.children[0].children[0]], ] matches = list(set().union(*matches)) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") + ) == sorted(matches) - query = ObjectQuery([("*", {"node_id": 0})], multi_index_mode="all") + query = ObjectQuery([("*", {"node_id": 0})]) matches = [root] - assert engine.apply(query, gf.graph, gf.dataframe) == matches + assert ( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") + == matches + ) with pytest.raises(InvalidQueryFilter): - query = ObjectQuery([{"node_id": "hello"}], multi_index_mode="all") - engine.apply(query, gf.graph, gf.dataframe) + query = ObjectQuery([{"node_id": "hello"}]) + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") def test_object_dialect_multi_condition_one_attribute(mock_graph_literal): @@ -691,7 +705,7 @@ def test_object_dialect_multi_condition_one_attribute(mock_graph_literal): [roots[1].children[0]], ] matches = list(set().union(*matches)) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(matches) def test_obj_query_is_query(): @@ -788,7 +802,7 @@ def test_conjunction_query(mock_graph_literal): roots[0].children[1], roots[0].children[1].children[0], ] - assert sorted(engine.apply(compound_query, gf.graph, gf.dataframe)) == sorted( + assert sorted(engine.apply(compound_query, gf.graph, gf.dataframe, None)) == sorted( matches ) @@ -811,7 +825,7 @@ def test_disjunction_query(mock_graph_literal): roots[1].children[0].children[0], roots[1].children[0].children[1], ] - assert sorted(engine.apply(compound_query, gf.graph, gf.dataframe)) == sorted( + assert sorted(engine.apply(compound_query, gf.graph, gf.dataframe, None)) == sorted( matches ) @@ -830,7 +844,7 @@ def test_exc_disjunction_query(mock_graph_literal): roots[0].children[2].children[0].children[1].children[0].children[0], roots[1].children[0].children[0], ] - assert sorted(engine.apply(compound_query, gf.graph, gf.dataframe)) == sorted( + assert sorted(engine.apply(compound_query, gf.graph, gf.dataframe, None)) == sorted( matches ) @@ -952,7 +966,7 @@ def test_apply_string_dialect(mock_graph_literal): query = StringQuery(path) engine = QueryEngine() - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = """MATCH (p)->(".")->(q)->("*") WHERE p."time (inc)" >= 30.0 AND q."name" = "bar" @@ -965,14 +979,14 @@ def test_apply_string_dialect(mock_graph_literal): root.children[1].children[0].children[0].children[0].children[1], ] query = StringQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = """MATCH (p)->(q)->(r) WHERE p."name" = "foo" AND q."name" = "bar" AND r."time" = 5.0 """ match = [root, root.children[0], root.children[0].children[0]] query = StringQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = """MATCH (p)->(q)->("+", r) WHERE p."name" = "foo" AND q."name" = "qux" AND r."time (inc)" > 15.0 @@ -985,7 +999,7 @@ def test_apply_string_dialect(mock_graph_literal): root.children[1].children[0].children[0].children[0], ] query = StringQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = """MATCH (p)->(q) WHERE p."time (inc)" > 100 OR p."time (inc)" <= 30 AND q."time (inc)" = 20 @@ -998,28 +1012,28 @@ def test_apply_string_dialect(mock_graph_literal): roots[1].children[0], ] query = StringQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = """MATCH (p)->("*", q)->(r) WHERE p."name" = "this" AND q."name" = "is" AND r."name" = "nonsense" """ query = StringQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == [] + assert engine.apply(query, gf.graph, gf.dataframe, None) == [] path = """MATCH (p)->("*")->(q) WHERE p."name" = 5 AND q."name" = "whatever" """ with pytest.raises(InvalidQueryFilter): query = StringQuery(path) - engine.apply(query, gf.graph, gf.dataframe) + engine.apply(query, gf.graph, gf.dataframe, None) path = """MATCH (p)->("*")->(q) WHERE p."time" = "badstring" AND q."name" = "whatever" """ query = StringQuery(path) with pytest.raises(InvalidQueryFilter): - engine.apply(query, gf.graph, gf.dataframe) + engine.apply(query, gf.graph, gf.dataframe, None) class DummyType: def __init__(self): @@ -1037,7 +1051,7 @@ def __init__(self): """ with pytest.raises(InvalidQueryPath): query = StringQuery(path) - engine.apply(query, gf.graph, gf.dataframe) + engine.apply(query, gf.graph, gf.dataframe, None) path = """MATCH ("*")->(p)->(q)->("*") WHERE p."name" = "bar" AND q."name" = "grault" @@ -1084,13 +1098,13 @@ def __init__(self): ] match = list(set().union(*match)) query = StringQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = """MATCH ("*")->(p)->(q)->("+") WHERE p."name" = "bar" AND q."name" = "grault" """ query = StringQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == [] + assert engine.apply(query, gf.graph, gf.dataframe, None) == [] gf.dataframe["time"] = np.NaN gf.dataframe.at[gf.graph.roots[0], "time"] = 5.0 @@ -1098,7 +1112,7 @@ def __init__(self): WHERE p."time" IS NOT NAN""" match = [gf.graph.roots[0]] query = StringQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == match + assert engine.apply(query, gf.graph, gf.dataframe, None) == match gf.dataframe["time"] = 5.0 gf.dataframe.at[gf.graph.roots[0], "time"] = np.NaN @@ -1106,7 +1120,7 @@ def __init__(self): WHERE p."time" IS NAN""" match = [gf.graph.roots[0]] query = StringQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == match + assert engine.apply(query, gf.graph, gf.dataframe, None) == match gf.dataframe["time"] = np.Inf gf.dataframe.at[gf.graph.roots[0], "time"] = 5.0 @@ -1114,7 +1128,7 @@ def __init__(self): WHERE p."time" IS NOT INF""" match = [gf.graph.roots[0]] query = StringQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == match + assert engine.apply(query, gf.graph, gf.dataframe, None) == match gf.dataframe["time"] = 5.0 gf.dataframe.at[gf.graph.roots[0], "time"] = np.Inf @@ -1122,7 +1136,7 @@ def __init__(self): WHERE p."time" IS INF""" match = [gf.graph.roots[0]] query = StringQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == match + assert engine.apply(query, gf.graph, gf.dataframe, None) == match names = gf.dataframe["name"].copy() gf.dataframe["name"] = None @@ -1131,7 +1145,7 @@ def __init__(self): WHERE p."name" IS NOT NONE""" match = [gf.graph.roots[0]] query = StringQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == match + assert engine.apply(query, gf.graph, gf.dataframe, None) == match gf.dataframe["name"] = names gf.dataframe.at[gf.graph.roots[0], "name"] = None @@ -1139,7 +1153,7 @@ def __init__(self): WHERE p."name" IS NONE""" match = [gf.graph.roots[0]] query = StringQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == match + assert engine.apply(query, gf.graph, gf.dataframe, None) == match def test_string_conj_compound_query(mock_graph_literal): @@ -1162,12 +1176,12 @@ def test_string_conj_compound_query(mock_graph_literal): roots[0].children[1], roots[0].children[1].children[0], ] - assert sorted(engine.apply(compound_query1, gf.graph, gf.dataframe)) == sorted( - matches - ) - assert sorted(engine.apply(compound_query2, gf.graph, gf.dataframe)) == sorted( - matches - ) + assert sorted( + engine.apply(compound_query1, gf.graph, gf.dataframe, None) + ) == sorted(matches) + assert sorted( + engine.apply(compound_query2, gf.graph, gf.dataframe, None) + ) == sorted(matches) def test_string_disj_compound_query(mock_graph_literal): @@ -1197,12 +1211,12 @@ def test_string_disj_compound_query(mock_graph_literal): roots[1].children[0].children[0], roots[1].children[0].children[1], ] - assert sorted(engine.apply(compound_query1, gf.graph, gf.dataframe)) == sorted( - matches - ) - assert sorted(engine.apply(compound_query2, gf.graph, gf.dataframe)) == sorted( - matches - ) + assert sorted( + engine.apply(compound_query1, gf.graph, gf.dataframe, None) + ) == sorted(matches) + assert sorted( + engine.apply(compound_query2, gf.graph, gf.dataframe, None) + ) == sorted(matches) def test_cypher_exc_disj_compound_query(mock_graph_literal): @@ -1228,12 +1242,12 @@ def test_cypher_exc_disj_compound_query(mock_graph_literal): roots[0].children[2].children[0].children[1].children[0].children[0], roots[1].children[0].children[0], ] - assert sorted(engine.apply(compound_query1, gf.graph, gf.dataframe)) == sorted( - matches - ) - assert sorted(engine.apply(compound_query2, gf.graph, gf.dataframe)) == sorted( - matches - ) + assert sorted( + engine.apply(compound_query1, gf.graph, gf.dataframe, None) + ) == sorted(matches) + assert sorted( + engine.apply(compound_query2, gf.graph, gf.dataframe, None) + ) == sorted(matches) def test_leaf_query(small_mock2): @@ -1267,24 +1281,24 @@ def test_leaf_query(small_mock2): """ ) engine = QueryEngine() - assert sorted(engine.apply(obj_query, gf.graph, gf.dataframe)) == sorted(matches) - assert sorted(engine.apply(str_query_numeric, gf.graph, gf.dataframe)) == sorted( - matches - ) - assert sorted(engine.apply(str_query_is_leaf, gf.graph, gf.dataframe)) == sorted( + assert sorted(engine.apply(obj_query, gf.graph, gf.dataframe, None)) == sorted( matches ) assert sorted( - engine.apply(str_query_is_not_leaf, gf.graph, gf.dataframe) + engine.apply(str_query_numeric, gf.graph, gf.dataframe, None) + ) == sorted(matches) + assert sorted( + engine.apply(str_query_is_leaf, gf.graph, gf.dataframe, None) + ) == sorted(matches) + assert sorted( + engine.apply(str_query_is_not_leaf, gf.graph, gf.dataframe, None) ) == sorted(nonleaves) def test_object_dialect_all_mode(tau_profile_dir): gf = GraphFrame.from_tau(tau_profile_dir) engine = QueryEngine() - query = ObjectQuery( - [".", ("+", {"time (inc)": ">= 17983.0"})], multi_index_mode="all" - ) + query = ObjectQuery([".", ("+", {"time (inc)": ">= 17983.0"})]) roots = gf.graph.roots matches = [ roots[0], @@ -1292,7 +1306,9 @@ def test_object_dialect_all_mode(tau_profile_dir): roots[0].children[6].children[1], roots[0].children[0], ] - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") + ) == sorted(matches) def test_string_dialect_all_mode(tau_profile_dir): @@ -1301,8 +1317,7 @@ def test_string_dialect_all_mode(tau_profile_dir): query = StringQuery( """MATCH (".")->("+", p) WHERE p."time (inc)" >= 17983.0 - """, - multi_index_mode="all", + """ ) roots = gf.graph.roots matches = [ @@ -1311,19 +1326,23 @@ def test_string_dialect_all_mode(tau_profile_dir): roots[0].children[6].children[1], roots[0].children[0], ] - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") + ) == sorted(matches) def test_object_dialect_any_mode(tau_profile_dir): gf = GraphFrame.from_tau(tau_profile_dir) engine = QueryEngine() - query = ObjectQuery([{"time": "< 24.0"}], multi_index_mode="any") + query = ObjectQuery([{"time": "< 24.0"}]) roots = gf.graph.roots matches = [ roots[0].children[2], roots[0].children[6].children[3], ] - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="any") + ) == sorted(matches) def test_string_dialect_any_mode(tau_profile_dir): @@ -1332,31 +1351,24 @@ def test_string_dialect_any_mode(tau_profile_dir): query = StringQuery( """MATCH (".", p) WHERE p."time" < 24.0 - """, - multi_index_mode="any", + """ ) roots = gf.graph.roots matches = [ roots[0].children[2], roots[0].children[6].children[3], ] - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) - - -def test_multi_index_mode_assertion_error(tau_profile_dir): - with pytest.raises(AssertionError): - _ = ObjectQuery([".", ("*", {"name": "test"})], multi_index_mode="foo") - with pytest.raises(AssertionError): - _ = StringQuery( - """ MATCH (".")->("*", p) - WHERE p."name" = "test" - """, - multi_index_mode="foo", - ) + assert sorted( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="any") + ) == sorted(matches) + + +def test_predicate_row_aggregator_assertion_error(tau_profile_dir): gf = GraphFrame.from_tau(tau_profile_dir) - query = ObjectQuery( - [".", ("*", {"time (inc)": "> 17983.0"})], multi_index_mode="off" - ) engine = QueryEngine() - with pytest.raises(MultiIndexModeMismatch): - engine.apply(query, gf.graph, gf.dataframe) + query = ObjectQuery([".", ("*", {"name": "test"})]) + with pytest.raises(ValueError): + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="foo") + query = ObjectQuery([".", ("*", {"time (inc)": "> 17983.0"})]) + with pytest.raises(ValueError): + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="off")