diff --git a/thicket/tests/test_filter_stats.py b/thicket/tests/test_filter_stats.py index 8f1e45d0..00f6114f 100644 --- a/thicket/tests/test_filter_stats.py +++ b/thicket/tests/test_filter_stats.py @@ -39,6 +39,9 @@ def check_filter_stats(th, columns_values): # We can't check th.graph because of squash in filter_stats assert th.statsframe.graph is not new_th.statsframe.graph + # Check Thicket and Statsframe graph in sync + assert len(new_th.graph) == len(new_th.statsframe.graph) + # filtered nodes in aggregated statistics table stats_nodes = sorted( new_th.statsframe.dataframe.index.drop_duplicates().tolist() diff --git a/thicket/thicket.py b/thicket/thicket.py index 95824448..ec2a74e7 100644 --- a/thicket/thicket.py +++ b/thicket/thicket.py @@ -1662,7 +1662,10 @@ def filter_stats(self, filter_function): ] # filter nodes in the graphframe based on the dataframe nodes + # We want to preserve columns, so "new_statsframe=False", + # however to match graph filter we update statsframe.graph new_thicket = new_thicket.squash(new_statsframe=False) + new_thicket.statsframe.graph = new_thicket.graph return new_thicket