25 changes: 21 additions & 4 deletions hatchet/graphframe.py
@@ -8,6 +8,7 @@
import sys
import traceback
from collections import defaultdict
import warnings

import multiprocess as mp
import numpy as np
@@ -472,7 +473,8 @@ def filter(
update_inc_cols=True,
num_procs=mp.cpu_count(),
rec_limit=1000,
multi_index_mode="off",
predicate_row_aggregator=None,
multi_index_mode=None,
):
"""Filter the dataframe using a user-supplied function.

@@ -484,7 +486,20 @@
update_inc_cols (boolean, optional): if True, update inclusive columns when performing squash.
rec_limit: set the Python recursion limit; increase if running into
recursion depth errors (default: 1000).
predicate_row_aggregator (str or Callable, optional): function used by the Query Language
to merge the multiple predicate results for each node into a single boolean. When providing
a string value, the following are accepted: "all" (equivalent to Python's 'all'), "any"
(equivalent to Python's 'any'), and "off" (no aggregation).
multi_index_mode: deprecated alias for 'predicate_row_aggregator'.
"""
if multi_index_mode is not None:
warnings.warn(
"'multi_index_mode' parameter is deprecated. Use 'predicate_row_aggregator' instead",
DeprecationWarning,
)
if predicate_row_aggregator is None:
predicate_row_aggregator = multi_index_mode

sys.setrecursionlimit(rec_limit)

dataframe_copy = self.dataframe.copy()
@@ -537,15 +552,17 @@ def filter(
# If a raw Object-dialect query is provided (not already passed to ObjectQuery),
# create a new ObjectQuery object.
if isinstance(filter_obj, list):
query = ObjectQuery(filter_obj, multi_index_mode)
query = ObjectQuery(filter_obj)
# If a raw String-dialect query is provided (not already passed to StringQuery),
# create a new StringQuery object.
elif isinstance(filter_obj, str):
query = parse_string_dialect(filter_obj, multi_index_mode)
query = parse_string_dialect(filter_obj)
# If an old-style query is provided, extract the underlying new-style query.
elif issubclass(type(filter_obj), AbstractQuery):
query = filter_obj._get_new_query()
query_matches = self.query_engine.apply(query, self.graph, self.dataframe)
query_matches = self.query_engine.apply(
query, self.graph, self.dataframe, predicate_row_aggregator
)
# match_set = list(set().union(*query_matches))
# filtered_df = dataframe_copy.loc[dataframe_copy["node"].isin(match_set)]
filtered_df = dataframe_copy.loc[dataframe_copy["node"].isin(query_matches)]
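For reference, a minimal usage sketch of the renamed parameter (the profile file and the query below are illustrative assumptions, not part of this diff):

import hatchet as ht

gf = ht.GraphFrame.from_caliper("profile.json")  # hypothetical input file

# Object-dialect query; for data with a row multi-index (e.g., per-rank
# metrics), each node's per-row predicate results are merged with any().
filtered = gf.filter([{"name": "MPI_.*"}], predicate_row_aggregator="any")

# The old keyword still works, but now emits a DeprecationWarning and is
# forwarded to predicate_row_aggregator.
filtered = gf.filter([{"name": "MPI_.*"}], multi_index_mode="any")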
6 changes: 4 additions & 2 deletions hatchet/query/compat.py
@@ -125,7 +125,7 @@ def apply(self, gf):
(list): A list of nodes representing the result of the query
"""
true_query = self._get_new_query()
return COMPATABILITY_ENGINE.apply(true_query, gf.graph, gf.dataframe)
return COMPATABILITY_ENGINE.apply(true_query, gf.graph, gf.dataframe, "off")

def _get_new_query(self):
"""Gets all the underlying 'new-style' queries in this object.
@@ -322,7 +322,9 @@ def apply(self, gf):
Returns:
(list): A list representing the set of nodes from paths that match this query
"""
return COMPATABILITY_ENGINE.apply(self.true_query, gf.graph, gf.dataframe)
return COMPATABILITY_ENGINE.apply(
self.true_query, gf.graph, gf.dataframe, "off"
)

def _get_new_query(self):
"""Get all the underlying 'new-style' query in this object.
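A hedged sketch of what the pinned "off" value means for old-style queries (QueryMatcher and its import path are assumptions about the surrounding compat API, not shown in this diff; gf is the GraphFrame from the earlier sketch):

from hatchet.query import QueryMatcher

# Old-style query: match any single node whose name is "main".
old_q = QueryMatcher().match(".", lambda row: row["name"] == "main")
nodes = old_q.apply(gf)  # internally: COMPATABILITY_ENGINE.apply(..., "off")

# Because "off" disables aggregation, the check added in engine.py makes
# this path raise ValueError when gf.dataframe carries a row MultiIndex.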
102 changes: 83 additions & 19 deletions hatchet/query/engine.py
@@ -4,6 +4,7 @@
# SPDX-License-Identifier: MIT

from itertools import groupby
from collections.abc import Iterable
import pandas as pd

from .errors import InvalidQueryFilter
@@ -14,6 +15,22 @@
from .string_dialect import parse_string_dialect


def _all_aggregator(pred_result):
    # Check pd.Series before the generic Iterable case: a Series is itself
    # iterable, so the Iterable branch would otherwise shadow this one.
    if isinstance(pred_result, pd.Series):
        return pred_result.all()
    if isinstance(pred_result, Iterable):
        return all(pred_result)
    return pred_result


def _any_aggregator(pred_result):
    # Same ordering rationale as _all_aggregator.
    if isinstance(pred_result, pd.Series):
        return pred_result.any()
    if isinstance(pred_result, Iterable):
        return any(pred_result)
    return pred_result


class QueryEngine:
"""Class for applying queries to GraphFrames."""

@@ -25,7 +42,7 @@ def reset_cache(self):
"""Resets the cache in the QueryEngine."""
self.search_cache = {}

def apply(self, query, graph, dframe):
def apply(self, query, graph, dframe, predicate_row_aggregator):
"""Apply the query to a GraphFrame.

Arguments:
@@ -37,11 +54,28 @@ def apply(self, query, graph, dframe,
(list): A list representing the set of nodes from paths that match the query
"""
if issubclass(type(query), Query):
aggregator = predicate_row_aggregator
if predicate_row_aggregator is None:
aggregator = query.default_aggregator
if aggregator == "all":
aggregator = _all_aggregator
elif aggregator == "any":
aggregator = _any_aggregator
elif aggregator == "off":
if isinstance(dframe.index, pd.MultiIndex):
raise ValueError(
"'predicate_row_aggregator' cannot be 'off' when the DataFrame has a row multi-index"
)
aggregator = None
elif not callable(aggregator):
raise ValueError(
"Invalid value provided for 'predicate_row_aggregator'"
)
self.reset_cache()
matches = []
visited = set()
for root in sorted(graph.roots, key=traversal_order):
self._apply_impl(query, dframe, root, visited, matches)
self._apply_impl(query, dframe, aggregator, root, visited, matches)
assert len(visited) == len(graph)
matched_node_set = list(set().union(*matches))
# return matches
@@ -54,12 +88,14 @@
subq_obj = ObjectQuery(subq)
elif isinstance(subq, str):
subq_obj = parse_string_dialect(subq)
results.append(self.apply(subq_obj, graph, dframe))
results.append(
self.apply(subq_obj, graph, dframe, predicate_row_aggregator)
)
return query._apply_op_to_results(results, graph)
else:
raise TypeError("Invalid query data type ({})".format(str(type(query))))

def _cache_node(self, node, query, dframe):
def _cache_node(self, node, query, dframe, predicate_row_aggregator):
"""Cache (Memoize) the parts of the query that the node matches.

Arguments:
@@ -78,11 +114,19 @@ def _cache_node(self, node, query, dframe):
row = dframe.xs(node, level="node", drop_level=False)
else:
row = dframe.loc[node]
if filter_func(row):
predicate_result = filter_func(row)
if (
not isinstance(predicate_result, bool)
and predicate_row_aggregator is not None
):
predicate_result = predicate_row_aggregator(predicate_result)
if predicate_result:
matches.append(i)
self.search_cache[node._hatchet_nid] = matches

def _match_0_or_more(self, query, dframe, node, wcard_idx):
def _match_0_or_more(
self, query, dframe, predicate_row_aggregator, node, wcard_idx
):
"""Process a "*" predicate in the query on a subgraph.

Arguments:
@@ -98,7 +142,7 @@ def _match_0_or_more(self, query, dframe, node, wcard_idx):
"""
# Cache the node if it's not already cached
if node._hatchet_nid not in self.search_cache:
self._cache_node(node, query, dframe)
self._cache_node(node, query, dframe, predicate_row_aggregator)
# If the node matches with the next non-wildcard query node,
# end the recursion and return the node.
if wcard_idx + 1 in self.search_cache[node._hatchet_nid]:
@@ -113,7 +157,9 @@ def _match_0_or_more(self, query, dframe, node, wcard_idx):
return [[node]]
return None
for child in sorted(node.children, key=traversal_order):
sub_match = self._match_0_or_more(query, dframe, child, wcard_idx)
sub_match = self._match_0_or_more(
query, dframe, predicate_row_aggregator, child, wcard_idx
)
if sub_match is not None:
matches.extend(sub_match)
if len(matches) == 0:
Expand All @@ -128,7 +174,7 @@ def _match_0_or_more(self, query, dframe, node, wcard_idx):
return [[]]
return None

def _match_1(self, query, dframe, node, idx):
def _match_1(self, query, dframe, predicate_row_aggregator, node, idx):
"""Process a "." predicate in the query on a subgraph.

Arguments:
@@ -142,12 +188,12 @@ def _match_1(self, query, dframe, node, idx):
Will return None if there are no matches for the "." predicate.
"""
if node._hatchet_nid not in self.search_cache:
self._cache_node(node, query, dframe)
self._cache_node(node, query, dframe, predicate_row_aggregator)
matches = []
for child in sorted(node.children, key=traversal_order):
# Cache the node if it's not already cached
if child._hatchet_nid not in self.search_cache:
self._cache_node(child, query, dframe)
self._cache_node(child, query, dframe, predicate_row_aggregator)
if idx in self.search_cache[child._hatchet_nid]:
matches.append([child])
# To be consistent with the other matching functions, return
@@ -156,7 +202,9 @@ def _match_1(self, query, dframe, node, idx):
return None
return matches

def _match_pattern(self, query, dframe, pattern_root, match_idx):
def _match_pattern(
self, query, dframe, predicate_row_aggregator, pattern_root, match_idx
):
"""Try to match the query pattern starting at the provided root node.

Arguments:
@@ -186,7 +234,9 @@ def _match_pattern(self, query, dframe, pattern_root, match_idx):
# Get the portion of the subgraph that matches the next
# part of the query.
if wcard == ".":
s = self._match_1(query, dframe, m[-1], pattern_idx)
s = self._match_1(
query, dframe, predicate_row_aggregator, m[-1], pattern_idx
)
if s is None:
sub_match.append(s)
else:
@@ -196,7 +246,13 @@ def _match_pattern(self, query, dframe, pattern_root, match_idx):
sub_match.append([])
else:
for child in sorted(m[-1].children, key=traversal_order):
s = self._match_0_or_more(query, dframe, child, pattern_idx)
s = self._match_0_or_more(
query,
dframe,
predicate_row_aggregator,
child,
pattern_idx,
)
if s is None:
sub_match.append(s)
else:
@@ -221,7 +277,9 @@ def _match_pattern(self, query, dframe, pattern_root, match_idx):
pattern_idx += 1
return matches

def _apply_impl(self, query, dframe, node, visited, matches):
def _apply_impl(
self, query, dframe, predicate_row_aggregator, node, visited, matches
):
"""Traverse the subgraph with the specified root, and collect all paths that match the query.

Arguments:
@@ -237,21 +295,27 @@ def _apply_impl(self, query, dframe, node, visited, matches):
return
# Cache the node if it's not already cached
if node._hatchet_nid not in self.search_cache:
self._cache_node(node, query, dframe)
self._cache_node(node, query, dframe, predicate_row_aggregator)
# If the node matches the starting/root node of the query,
# try to get all query matches in the subgraph rooted at
# this node.
if query.query_pattern[0][0] == "*":
if 1 in self.search_cache[node._hatchet_nid]:
sub_match = self._match_pattern(query, dframe, node, 1)
sub_match = self._match_pattern(
query, dframe, predicate_row_aggregator, node, 1
)
if sub_match is not None:
matches.extend(sub_match)
if 0 in self.search_cache[node._hatchet_nid]:
sub_match = self._match_pattern(query, dframe, node, 0)
sub_match = self._match_pattern(
query, dframe, predicate_row_aggregator, node, 0
)
if sub_match is not None:
matches.extend(sub_match)
# Note that the node is now visited.
visited.add(node._hatchet_nid)
# Continue the Depth First Search.
for child in sorted(node.children, key=traversal_order):
self._apply_impl(query, dframe, child, visited, matches)
self._apply_impl(
query, dframe, predicate_row_aggregator, child, visited, matches
)
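Because the dispatch in apply accepts any callable, aggregation policies beyond "all"/"any" are possible. A hedged sketch of a majority-vote aggregator (the policy is illustrative; only the contract — raw per-node predicate result in, single boolean out — comes from this diff, and gf is the GraphFrame from the earlier sketch):

import pandas as pd

def majority_aggregator(pred_result):
    # The engine passes the raw predicate result for one node: a pd.Series
    # when rows are multi-indexed, otherwise a list or a scalar.
    if isinstance(pred_result, pd.Series):
        return pred_result.mean() > 0.5  # more than half of the rows match
    if isinstance(pred_result, (list, tuple)):
        return sum(bool(r) for r in pred_result) > len(pred_result) / 2
    return bool(pred_result)

filtered = gf.filter([{"name": "solve.*"}], predicate_row_aggregator=majority_aggregator)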
6 changes: 0 additions & 6 deletions hatchet/query/errors.py
@@ -18,9 +18,3 @@ class RedundantQueryFilterWarning(Warning):

class BadNumberNaryQueryArgs(Exception):
"""Raised when a query filter does not have a valid syntax"""


class MultiIndexModeMismatch(Exception):
"""Raised when an ObjectQuery or StringQuery object
is set to use multi-indexed data, but no multi-indexed
data is provided"""