From b3aa4be6dcd4f8ab19a937c867ff76762631d014 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Tue, 12 Nov 2024 09:21:06 -0500 Subject: [PATCH 1/5] Adds type hints to Hatchet --- hatchet/external/console.py | 33 ++- hatchet/frame.py | 27 +-- hatchet/graph.py | 46 ++-- hatchet/graphframe.py | 220 ++++++++++++-------- hatchet/node.py | 58 ++++-- hatchet/query/__init__.py | 46 +++- hatchet/query/compat.py | 76 ++++--- hatchet/query/compound.py | 33 ++- hatchet/query/engine.py | 33 ++- hatchet/query/object_dialect.py | 30 ++- hatchet/query/query.py | 25 ++- hatchet/query/string_dialect.py | 205 ++++++++---------- hatchet/readers/caliper_native_reader.py | 27 ++- hatchet/readers/caliper_reader.py | 12 +- hatchet/readers/dataframe_reader.py | 17 +- hatchet/readers/gprof_dot_reader.py | 7 +- hatchet/readers/hdf5_reader.py | 4 +- hatchet/readers/hpctoolkit_reader.py | 34 ++- hatchet/readers/hpctoolkit_reader_latest.py | 13 +- hatchet/readers/json_reader.py | 4 +- hatchet/readers/literal_reader.py | 15 +- hatchet/readers/pyinstrument_reader.py | 9 +- hatchet/readers/spotdb_reader.py | 30 ++- hatchet/readers/tau_reader.py | 39 ++-- hatchet/readers/timemory_reader.py | 61 ++++-- hatchet/util/colormaps.py | 4 +- hatchet/util/deprecated.py | 5 +- hatchet/util/dot.py | 29 ++- hatchet/util/executable.py | 2 +- hatchet/util/profiler.py | 17 +- hatchet/util/timer.py | 12 +- hatchet/writers/dataframe_writer.py | 12 +- hatchet/writers/hdf5_writer.py | 6 +- 33 files changed, 746 insertions(+), 445 deletions(-) diff --git a/hatchet/external/console.py b/hatchet/external/console.py index 77b01b35..e69177a9 100644 --- a/hatchet/external/console.py +++ b/hatchet/external/console.py @@ -34,16 +34,23 @@ import pandas as pd import numpy as np import warnings +from typing import Any, Dict, List, Tuple, Union from ..util.colormaps import ColorMaps +from ..node import Node class ConsoleRenderer: - def __init__(self, unicode=False, color=False): + def __init__(self, unicode: bool = False, color: bool = False) -> None: self.unicode = unicode self.color = color self.visited = [] - def render(self, roots, dataframe, **kwargs): + def render( + self, + roots: Union[List[Node], Tuple[Node, ...]], + dataframe: pd.DataFrame, + **kwargs, + ) -> str: self.render_header = kwargs["render_header"] if self.render_header: @@ -161,7 +168,7 @@ def render(self, roots, dataframe, **kwargs): return result.encode("utf-8") # pylint: disable=W1401 - def render_preamble(self): + def render_preamble(self) -> str: lines = [ r" __ __ __ __ ", r" / /_ ____ _/ /______/ /_ ___ / /_", @@ -174,7 +181,7 @@ def render_preamble(self): return "\n".join(lines) - def render_legend(self): + def render_legend(self) -> str: def render_label(index, low, high): metric_range = self.max_metric - self.min_metric @@ -247,7 +254,13 @@ def render_label(index, low, high): return legend - def render_frame(self, node, dataframe, indent="", child_indent=""): + def render_frame( + self, + node: Node, + dataframe: pd.DataFrame, + indent: str = "", + child_indent: str = "", + ) -> str: node_depth = node._depth if node_depth < self.depth: # set dataframe index based on whether rank and thread are part of @@ -288,8 +301,8 @@ def render_frame(self, node, dataframe, indent="", child_indent=""): "none": "", "constant": "\U00002192", "phased": "\U00002933", - "dynamic": "\U000021DD", - "sporadic": "\U0000219D", + "dynamic": "\U000021dd", + "sporadic": "\U0000219d", } pattern_metric = dataframe.loc[df_index, self.annotation_column] annotation_content = 
self.temporal_symbols[pattern_metric] @@ -395,7 +408,7 @@ def render_frame(self, node, dataframe, indent="", child_indent=""): return result - def _ansi_color_for_metric(self, metric): + def _ansi_color_for_metric(self, metric: float) -> str: metric_range = self.max_metric - self.min_metric if metric_range != 0: @@ -418,7 +431,7 @@ def _ansi_color_for_metric(self, metric): else: return self.colors.blue - def _ansi_color_for_name(self, node_name): + def _ansi_color_for_name(self, node_name: str) -> str: if self.highlight is False: return "" @@ -445,7 +458,7 @@ class colors_enabled: class colors_disabled: colormap = ["", "", "", "", "", "", ""] - def __getattr__(self, key): + def __getattr__(self, key: str) -> str: return "" colors_disabled = colors_disabled() diff --git a/hatchet/frame.py b/hatchet/frame.py index b090a963..f89e8f75 100644 --- a/hatchet/frame.py +++ b/hatchet/frame.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: MIT from functools import total_ordering +from typing import Any, Dict, List, Tuple, Union @total_ordering @@ -14,7 +15,7 @@ class Frame: attrs (dict): dictionary of attributes and values """ - def __init__(self, attrs=None, **kwargs): + def __init__(self, attrs: Dict[str, Any] = None, **kwargs) -> None: """Construct a frame from a dictionary, or from immediate kwargs. Arguments: @@ -48,42 +49,44 @@ def __init__(self, attrs=None, **kwargs): self._tuple_repr = None - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return self.tuple_repr == other.tuple_repr - def __lt__(self, other): + def __lt__(self, other: object) -> bool: return self.tuple_repr < other.tuple_repr - def __gt__(self, other): + def __gt__(self, other: object) -> bool: return self.tuple_repr > other.tuple_repr - def __hash__(self): + def __hash__(self) -> int: return hash(self.tuple_repr) - def __str__(self): + def __str__(self) -> str: """str() with sorted attributes, so output is deterministic.""" return "{%s}" % ", ".join("'%s': '%s'" % (k, v) for k, v in self.tuple_repr) - def __repr__(self): + def __repr__(self) -> str: return "Frame(%s)" % self @property - def tuple_repr(self): + def tuple_repr(self) -> Tuple[Tuple[str, Any], ...]: """Make a tuple of attributes and values based on reader.""" if not self._tuple_repr: self._tuple_repr = tuple(sorted((k, v) for k, v in self.attrs.items())) return self._tuple_repr - def copy(self): + def copy(self) -> "Frame": return Frame(self.attrs.copy()) - def __getitem__(self, name): + def __getitem__(self, name: str) -> Any: return self.attrs[name] - def get(self, name, default=None): + def get(self, name: str, default: Any = None): return self.attrs.get(name, default) - def values(self, names): + def values( + self, names: Union[List[str], Tuple[str, ...], str] + ) -> Union[Tuple[Any, ...], Any]: """Return a tuple of attribute values from this Frame.""" if isinstance(names, (list, tuple)): return tuple(self.attrs.get(name) for name in names) diff --git a/hatchet/graph.py b/hatchet/graph.py index 55194bd2..20d76b75 100644 --- a/hatchet/graph.py +++ b/hatchet/graph.py @@ -4,11 +4,13 @@ # SPDX-License-Identifier: MIT from collections import defaultdict +from collections.abc import Iterable +from typing import Any, Dict, List, Tuple, Union from .node import Node, traversal_order, node_traversal_order -def index_by(attr, objects): +def index_by(attr: str, objects: Union[List, Tuple]) -> Dict: """Put objects into lists based on the value of an attribute. 
Returns: @@ -23,12 +25,17 @@ def index_by(attr, objects): class Graph: """A possibly multi-rooted tree or graph from one input dataset.""" - def __init__(self, roots): + def __init__(self, roots: Union[List[Node], Tuple[Node, ...]]) -> None: assert roots is not None self.roots = roots self.node_ordering = False - def traverse(self, order="pre", attrs=None, visited=None): + def traverse( + self, + order: str = "pre", + attrs: Union[List[str], Tuple[str, ...], str] = None, + visited: Dict[int, int] = None, + ) -> Iterable[Union[Node, Union[Tuple[Any, ...], Any]]]: """Preorder traversal of all roots of this Graph. Arguments: @@ -47,7 +54,12 @@ def traverse(self, order="pre", attrs=None, visited=None): for value in root.traverse(order=order, attrs=attrs, visited=visited): yield value - def node_order_traverse(self, order="pre", attrs=None, visited=None): + def node_order_traverse( + self, + order: str = "pre", + attrs: Union[List[str], Tuple[str, ...], str] = None, + visited: Dict[int, int] = None, + ) -> Iterable[Union[Node, Union[Tuple[Any, ...], Any]]]: """Preorder traversal of all roots of this Graph, sorting by "node order" column. Arguments: @@ -69,7 +81,7 @@ def node_order_traverse(self, order="pre", attrs=None, visited=None): ): yield value - def is_tree(self): + def is_tree(self) -> bool: """True if this graph is a tree, false otherwise.""" if len(self.roots) > 1: return False @@ -78,7 +90,7 @@ def is_tree(self): list(self.traverse(visited=visited)) return all(v == 1 for v in visited.values()) - def find_merges(self): + def find_merges(self) -> Dict[Node, Node]: """Find nodes that have the same parent and frame. Find nodes that have the same parent and duplicate frame, and @@ -135,7 +147,7 @@ def _find_child_merges(node_list): return merges - def merge_nodes(self, merges): + def merge_nodes(self, merges: Dict[Node, Node]): """Merge some nodes in a graph into others. ``merges`` is a dictionary keyed by old nodes, with values equal @@ -159,12 +171,12 @@ def transform(node_list): child.parents = transform(child.parents) self.roots = transform(self.roots) - def normalize(self): + def normalize(self) -> Dict[Node, Node]: merges = self.find_merges() self.merge_nodes(merges) return merges - def copy(self, old_to_new=None): + def copy(self, old_to_new: Dict[Node, Node] = None) -> "Graph": """Create and return a copy of this graph. Arguments: @@ -192,7 +204,7 @@ def copy(self, old_to_new=None): return graph - def union(self, other, old_to_new=None): + def union(self, other: "Graph", old_to_new: Dict[Node, Node] = None) -> "Graph": """Create the union of self and other and return it as a new Graph. This creates a new graph and does not modify self or other. 
The @@ -342,7 +354,7 @@ def connect(parent, new_node): return graph - def enumerate_depth(self): + def enumerate_depth(self) -> None: def _iter_depth(node, visited): for child in node.children: if child not in visited: @@ -356,7 +368,7 @@ def _iter_depth(node, visited): root._depth = 0 # depth of root node is 0 _iter_depth(root, visited) - def enumerate_traverse(self): + def enumerate_traverse(self) -> None: if not self._check_enumerate_traverse(): # if "node order" column exists, we traverse sorting by _hatchet_nid if self.node_ordering: @@ -368,7 +380,7 @@ def enumerate_traverse(self): self.enumerate_depth() - def _check_enumerate_traverse(self): + def _check_enumerate_traverse(self) -> bool: # if "node order" column exists, we traverse sorting by _hatchet_nid if self.node_ordering: for i, node in enumerate(self.node_order_traverse()): @@ -379,11 +391,11 @@ def _check_enumerate_traverse(self): if i != node._hatchet_nid: return False - def __len__(self): + def __len__(self) -> int: """Size of the graph in terms of number of nodes.""" return sum(1 for _ in self.traverse()) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: """Check if two graphs have the same structure by comparing frame at each node. """ @@ -415,11 +427,11 @@ def __eq__(self, other): return True - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not (self == other) @staticmethod - def from_lists(*roots): + def from_lists(*roots) -> "Graph": """Convenience method to invoke Node.from_lists() on each root value.""" if not all(isinstance(r, (list, tuple)) for r in roots): raise ValueError( diff --git a/hatchet/graphframe.py b/hatchet/graphframe.py index d041597b..c2c92019 100644 --- a/hatchet/graphframe.py +++ b/hatchet/graphframe.py @@ -8,6 +8,9 @@ import sys import traceback from collections import defaultdict +from collections.abc import Callable +from typing import Any, Dict, List, Tuple, Union +from io import TextIOWrapper import multiprocess as mp import numpy as np @@ -19,6 +22,8 @@ from .node import Node from .query import ( AbstractQuery, + Query, + CompoundQuery, ObjectQuery, QueryEngine, is_hatchet_query, @@ -39,7 +44,21 @@ raise -def parallel_apply(filter_function, dataframe, queue): +PandasSingleGroupbyType = Union[Callable, Dict, str, Tuple[str, ...], pd.Grouper] +PandasMultipleGroupbyType = Union[ + PandasSingleGroupbyType, + List[PandasSingleGroupbyType], + Tuple[PandasSingleGroupbyType, ...], +] +PandasSingleGroupbyAggType = Union[Callable, str] +PandasMultipleGroupbyAggType = Union[ + List[PandasSingleGroupbyAggType], + Tuple[PandasSingleGroupbyAggType, ...], + Dict[Union[str, List[str], Tuple[str, ...]], PandasSingleGroupbyAggType], +] + + +def parallel_apply(filter_function: Callable, dataframe: pd.DataFrame, queue: mp.Queue): """A function called in parallel, which does a pandas apply on part of a dataframe and returns the results via multiprocessing queue function.""" filtered_rows = dataframe.apply(filter_function, axis=1) @@ -54,13 +73,13 @@ class GraphFrame: def __init__( self, - graph, - dataframe, - exc_metrics=None, - inc_metrics=None, - default_metric="time", - metadata={}, - ): + graph: Graph, + dataframe: pd.DataFrame, + exc_metrics: List[str] = None, + inc_metrics: List[str] = None, + default_metric: str = "time", + metadata: Dict[str, Any] = {}, + ) -> None: """Create a new GraphFrame from a graph and a dataframe. Likely, you do not want to use this function. 
@@ -94,7 +113,7 @@ def __init__( self.query_engine = QueryEngine() @staticmethod - def from_hpctoolkit(dirname): + def from_hpctoolkit(dirname: str) -> "GraphFrame": """Read an HPCToolkit database directory into a new GraphFrame. Arguments: @@ -115,7 +134,7 @@ def from_hpctoolkit_latest( max_depth: int = None, min_percentage_of_application_time: int = None, min_percentage_of_parent_time: int = None, - ): + ) -> "GraphFrame": """ Read an HPCToolkit database directory into a new GraphFrame @@ -139,7 +158,9 @@ def from_hpctoolkit_latest( ).read() @staticmethod - def from_caliper(filename_or_stream, query=None): + def from_caliper( + filename_or_stream: Union[str, TextIOWrapper], query: str = None + ) -> "GraphFrame": """Read in a Caliper .cali or .json file. Args: @@ -155,8 +176,10 @@ def from_caliper(filename_or_stream, query=None): @staticmethod def from_caliperreader( - filename_or_caliperreader, native=False, string_attributes=[] - ): + filename_or_caliperreader: Any, + native: bool = False, + string_attributes: Union[List[str], str] = [], + ) -> "GraphFrame": """Read in a native Caliper `cali` file using Caliper's python reader. Args: @@ -175,11 +198,11 @@ def from_caliperreader( @staticmethod def from_timeseries( - filename_or_caliperreader, - level="loop.start_iteration", - native=False, - string_attributes=[], - ): + filename_or_caliperreader: Any, + level: str = "loop.start_iteration", + native: bool = False, + string_attributes: Union[List[str], str] = [], + ) -> "GraphFrame": """Read in a native Caliper timeseries `cali` file using Caliper's python reader. Args: @@ -197,7 +220,7 @@ def from_timeseries( ).read_timeseries(level=level) @staticmethod - def from_spotdb(db_key, list_of_ids=None): + def from_spotdb(db_key: Any, list_of_ids: List = None) -> "GraphFrame": """Read multiple graph frames from a SpotDB instance Args: @@ -221,7 +244,7 @@ def from_spotdb(db_key, list_of_ids=None): return SpotDBReader(db_key, list_of_ids).read() @staticmethod - def from_gprof_dot(filename): + def from_gprof_dot(filename: str) -> "GraphFrame": """Read in a DOT file generated by gprof2dot.""" # import this lazily to avoid circular dependencies from .readers.gprof_dot_reader import GprofDotReader @@ -229,7 +252,7 @@ def from_gprof_dot(filename): return GprofDotReader(filename).read() @staticmethod - def from_cprofile(filename): + def from_cprofile(filename: str) -> "GraphFrame": """Read in a pstats/prof file generated using python's cProfile.""" # import this lazily to avoid circular dependencies from .readers.cprofile_reader import CProfileReader @@ -237,7 +260,7 @@ def from_cprofile(filename): return CProfileReader(filename).read() @staticmethod - def from_pyinstrument(filename): + def from_pyinstrument(filename: str) -> "GraphFrame": """Read in a JSON file generated using Pyinstrument.""" # import this lazily to avoid circular dependencies from .readers.pyinstrument_reader import PyinstrumentReader @@ -245,7 +268,7 @@ def from_pyinstrument(filename): return PyinstrumentReader(filename).read() @staticmethod - def from_tau(dirname): + def from_tau(dirname: str) -> "GraphFrame": """Read in a profile generated using TAU.""" # import this lazily to avoid circular dependencies from .readers.tau_reader import TAUReader @@ -253,7 +276,11 @@ def from_tau(dirname): return TAUReader(dirname).read() @staticmethod - def from_timemory(input=None, select=None, **_kwargs): + def from_timemory( + input: Union[str, TextIOWrapper, Dict[str, Any]] = None, + select: List[str] = None, + **_kwargs, + ) -> "GraphFrame": 
"""Read in timemory data. Links: @@ -337,7 +364,7 @@ def from_timemory(input=None, select=None, **_kwargs): raise @staticmethod - def from_literal(graph_dict): + def from_literal(graph_dict: List[Dict]) -> "GraphFrame": """Create a GraphFrame from a list of dictionaries.""" # import this lazily to avoid circular dependencies from .readers.literal_reader import LiteralReader @@ -345,7 +372,7 @@ def from_literal(graph_dict): return LiteralReader(graph_dict).read() @staticmethod - def from_lists(*lists): + def from_lists(*lists) -> "GraphFrame": """Make a simple GraphFrame from lists. This creates a Graph from lists (see ``Graph.from_lists()``) and uses @@ -367,25 +394,27 @@ def from_lists(*lists): return gf @staticmethod - def from_json(json_spec, **kwargs): + def from_json(json_spec: str, **kwargs) -> "GraphFrame": from .readers.json_reader import JsonReader return JsonReader(json_spec).read(**kwargs) @staticmethod - def from_hdf(filename, **kwargs): + def from_hdf(filename: str, **kwargs) -> "GraphFrame": # import this lazily to avoid circular dependencies from .readers.hdf5_reader import HDF5Reader return HDF5Reader(filename).read(**kwargs) - def to_hdf(self, filename, key="hatchet_graphframe", **kwargs): + def to_hdf( + self, filename: str, key: str = "hatchet_graphframe", **kwargs + ) -> "GraphFrame": # import this lazily to avoid circular dependencies from .writers.hdf5_writer import HDF5Writer HDF5Writer(filename).write(self, key=key, **kwargs) - def copy(self): + def copy(self) -> "GraphFrame": """Return a partially shallow copy of the graphframe. This copies the DataFrame object, but the data is comprised of references. The Graph is shared between self and the new GraphFrame. @@ -411,7 +440,7 @@ def copy(self): copy.copy(self.metadata), ) - def deepcopy(self): + def deepcopy(self) -> "GraphFrame": """Return a deep copy of the graphframe. Arguments: @@ -446,7 +475,7 @@ def deepcopy(self): copy.deepcopy(self.metadata), ) - def drop_index_levels(self, function=np.mean): + def drop_index_levels(self, function: Callable = np.mean): """Drop all index levels but `node`.""" index_names = list(self.dataframe.index.names) index_names.remove("node") @@ -467,13 +496,13 @@ def drop_index_levels(self, function=np.mean): def filter( self, - filter_obj, - squash=True, - update_inc_cols=True, - num_procs=mp.cpu_count(), - rec_limit=1000, - multi_index_mode="off", - ): + filter_obj: Union[Callable, List, str, Query, CompoundQuery, AbstractQuery], + squash: bool = True, + update_inc_cols: bool = True, + num_procs: int = mp.cpu_count(), + rec_limit: int = 1000, + multi_index_mode: str = "off", + ) -> "GraphFrame": """Filter the dataframe using a user-supplied function. Note: Operates in parallel on user-supplied lambda functions. @@ -571,7 +600,7 @@ def filter( return filtered_gf.squash(update_inc_cols) return filtered_gf - def squash(self, update_inc_cols=True): + def squash(self, update_inc_cols: bool = True) -> "GraphFrame": """Rewrite the Graph to include only nodes present in the DataFrame's rows. 
This can be used to simplify the Graph, or to normalize Graph @@ -676,7 +705,11 @@ def rewire(node, new_parent, visited): new_gf.update_inclusive_columns() return new_gf - def _init_sum_columns(self, columns, out_columns): + def _init_sum_columns( + self, + columns: List[str], + out_columns: List[str], + ) -> List[str]: """Helper function for subtree_sum and subgraph_sum.""" if out_columns is None: out_columns = columns @@ -691,7 +724,10 @@ def _init_sum_columns(self, columns, out_columns): return out_columns def subtree_sum( - self, columns, out_columns=None, function=lambda x: x.sum(min_count=1) + self, + columns: List[str], + out_columns: List[str] = None, + function: Callable = lambda x: x.sum(min_count=1), ): """Compute sum of elements in subtrees. Valid only for trees. @@ -751,7 +787,10 @@ def subtree_sum( ) def subgraph_sum( - self, columns, out_columns=None, function=lambda x: x.sum(min_count=1) + self, + columns: List[str], + out_columns: List[str] = None, + function: Callable = lambda x: x.sum(min_count=1), ): """Compute sum of elements in subgraphs. @@ -813,7 +852,7 @@ def subgraph_sum( function(self.dataframe.loc[(subgraph_nodes), columns]) ) - def generate_exclusive_columns(self, inc_metrics=None): + def generate_exclusive_columns(self, inc_metrics: Union[str, List[str]] = None): """Generates exclusive metrics from available inclusive metrics. Arguments: inc_metrics (str, list, optional): Instead of generating the exclusive time for each inclusive metric, it is possible to specify those metrics manually. Defaults to None. @@ -938,11 +977,11 @@ def update_inclusive_columns(self): self.subgraph_sum(self.exc_metrics, self.inc_metrics) self.inc_metrics = list(set(self.inc_metrics + old_inc_metrics)) - def show_metric_columns(self): + def show_metric_columns(self) -> List[str]: """Returns a list of dataframe column labels.""" return list(self.exc_metrics + self.inc_metrics) - def unify(self, other): + def unify(self, other: "GraphFrame"): """Returns a unified graphframe. Ensure self and other have the same graph and same node IDs. This may @@ -986,23 +1025,23 @@ def unify(self, other): ) def tree( self, - metric_column=None, - annotation_column=None, - precision=3, - name_column="name", - expand_name=False, - context_column="file", - rank=0, - thread=0, - depth=10000, - highlight_name=False, - colormap="RdYlGn", - invert_colormap=False, - colormap_annotations=None, - render_header=True, - min_value=None, - max_value=None, - ): + metric_column: str = None, + annotation_column: str = None, + precision: int = 3, + name_column: str = "name", + expand_name: bool = False, + context_column: str = "file", + rank: int = 0, + thread: int = 0, + depth: int = 10000, + highlight_name: bool = False, + colormap: str = "RdYlGn", + invert_colormap: bool = False, + colormap_annotations: Union[str, List, Dict] = None, + render_header: bool = True, + min_value: int = None, + max_value: int = None, + ) -> str: """Visualize the Hatchet graphframe as a tree Arguments: @@ -1068,7 +1107,14 @@ def tree( max_value=max_value, ) - def to_dot(self, metric=None, name="name", rank=0, thread=0, threshold=0.0): + def to_dot( + self, + metric: str = None, + name: str = "name", + rank: int = 0, + thread: int = 0, + threshold: float = 0.0, + ) -> str: """Write the graph in the graphviz dot format: https://www.graphviz.org/doc/info/lang.html """ @@ -1141,7 +1187,13 @@ def to_flamegraph(self, metric=None, name="name", rank=0, thread=0, threshold=0. 
return folded_stack - def to_literal(self, name="name", rank=0, thread=0, cat_columns=[]): + def to_literal( + self, + name: str = "name", + rank: int = 0, + thread: int = 0, + cat_columns: Union[List[str], Tuple[str, ...]] = [], + ) -> List[Dict]: """Format this graph as a list of dictionaries for Roundtrip visualizations. """ @@ -1223,7 +1275,7 @@ def add_nodes(hnode): return graph_literal - def to_dict(self): + def to_dict(self) -> Dict: hatchet_dict = {} """ @@ -1251,10 +1303,10 @@ def to_dict(self): return hatchet_dict - def to_json(self): + def to_json(self) -> str: return json.dumps(self.to_dict()) - def _operator(self, other, op): + def _operator(self, other: "GraphFrame", op: Callable) -> "GraphFrame": """Generic function to apply operator to two dataframes and store result in self. @@ -1277,7 +1329,7 @@ def _operator(self, other, op): return self - def _insert_missing_rows(self, other): + def _insert_missing_rows(self, other: "GraphFrame") -> "GraphFrame": """Helper function to add rows that exist in other, but not in self. This returns a graphframe with a modified dataframe. The new rows will @@ -1386,7 +1438,11 @@ def _insert_missing_rows(self, other): return self - def groupby_aggregate(self, groupby_function, agg_function): + def groupby_aggregate( + self, + groupby_function: Union[PandasSingleGroupbyType, PandasMultipleGroupbyType], + agg_function: Union[PandasSingleGroupbyAggType, PandasMultipleGroupbyAggType], + ) -> "GraphFrame": """Groupby-aggregate dataframe and reindex the Graph. Reindex the graph to match the groupby-aggregated dataframe. @@ -1503,7 +1559,7 @@ def reindex(node, parent, visited): new_gf.drop_index_levels() return new_gf - def add(self, other): + def add(self, other: "GraphFrame") -> "GraphFrame": """Returns the column-wise sum of two graphframes as a new graphframe. This graphframe is the union of self's and other's graphs, and does not @@ -1521,7 +1577,7 @@ def add(self, other): return self_copy._operator(other_copy, self_copy.dataframe.add) - def sub(self, other): + def sub(self, other: "GraphFrame") -> "GraphFrame": """Returns the column-wise difference of two graphframes as a new graphframe. @@ -1540,7 +1596,7 @@ def sub(self, other): return self_copy._operator(other_copy, self_copy.dataframe.sub) - def div(self, other): + def div(self, other: "GraphFrame") -> "GraphFrame": """Returns the column-wise float division of two graphframes as a new graphframe. This graphframe is the union of self's and other's graphs, and does not @@ -1558,7 +1614,7 @@ def div(self, other): return self_copy._operator(other_copy, self_copy.dataframe.divide) - def mul(self, other): + def mul(self, other: "GraphFrame") -> "GraphFrame": """Returns the column-wise float multiplication of two graphframes as a new graphframe. This graphframe is the union of self's and other's graphs, and does not @@ -1576,7 +1632,7 @@ def mul(self, other): return self_copy._operator(other_copy, self_copy.dataframe.multiply) - def __iadd__(self, other): + def __iadd__(self, other: "GraphFrame") -> "GraphFrame": """Computes column-wise sum of two graphframes and stores the result in self. @@ -1595,7 +1651,7 @@ def __iadd__(self, other): return self._operator(other_copy, self.dataframe.add) - def __add__(self, other): + def __add__(self, other: "GraphFrame") -> "GraphFrame": """Returns the column-wise sum of two graphframes as a new graphframe. 
This graphframe is the union of self's and other's graphs, and does not @@ -1606,7 +1662,7 @@ def __add__(self, other): """ return self.add(other) - def __mul__(self, other): + def __mul__(self, other: "GraphFrame") -> "GraphFrame": """Returns the column-wise multiplication of two graphframes as a new graphframe. This graphframe is the union of self's and other's graphs, and does not @@ -1617,7 +1673,7 @@ def __mul__(self, other): """ return self.mul(other) - def __isub__(self, other): + def __isub__(self, other: "GraphFrame") -> "GraphFrame": """Computes column-wise difference of two graphframes and stores the result in self. @@ -1636,7 +1692,7 @@ def __isub__(self, other): return self._operator(other_copy, self.dataframe.sub) - def __sub__(self, other): + def __sub__(self, other: "GraphFrame") -> "GraphFrame": """Returns the column-wise difference of two graphframes as a new graphframe. @@ -1648,7 +1704,7 @@ def __sub__(self, other): """ return self.sub(other) - def __idiv__(self, other): + def __idiv__(self, other: "GraphFrame") -> "GraphFrame": """Computes column-wise float division of two graphframes and stores the result in self. @@ -1667,7 +1723,7 @@ def __idiv__(self, other): return self._operator(other_copy, self.dataframe.div) - def __truediv__(self, other): + def __truediv__(self, other: "GraphFrame") -> "GraphFrame": """Returns the column-wise float division of two graphframes as a new graphframe. @@ -1679,7 +1735,7 @@ def __truediv__(self, other): """ return self.div(other) - def __imul__(self, other): + def __imul__(self, other: "GraphFrame") -> "GraphFrame": """Computes column-wise float multiplication of two graphframes and stores the result in self. diff --git a/hatchet/node.py b/hatchet/node.py index 672627a2..022ad532 100644 --- a/hatchet/node.py +++ b/hatchet/node.py @@ -4,16 +4,18 @@ # SPDX-License-Identifier: MIT from functools import total_ordering +from typing import Any, Dict, List, Set, Tuple, Union +from collections.abc import Iterable from .frame import Frame -def traversal_order(node): +def traversal_order(node: "Node") -> Tuple(Frame, int): """Deterministic key function for sorting nodes in traversals.""" return (node.frame, id(node)) -def node_traversal_order(node): +def node_traversal_order(node: "Node") -> int: """Deterministic key function for sorting nodes by specified "node order" (which gets assigned to _hatchet_nid) in traversals.""" return node._hatchet_nid @@ -23,7 +25,9 @@ def node_traversal_order(node): class Node: """A node in the graph. The node only stores its frame.""" - def __init__(self, frame_obj, parent=None, hnid=-1, depth=-1): + def __init__( + self, frame_obj: Frame, parent: "Node" = None, hnid: int = -1, depth: int = -1 + ) -> None: self.frame = frame_obj self._depth = depth self._hatchet_nid = hnid @@ -33,17 +37,17 @@ def __init__(self, frame_obj, parent=None, hnid=-1, depth=-1): self.add_parent(parent) self.children = [] - def add_parent(self, node): + def add_parent(self, node: "Node"): """Adds a parent to this node's list of parents.""" assert isinstance(node, Node) self.parents.append(node) - def add_child(self, node): + def add_child(self, node: "Node"): """Adds a child to this node's list of children.""" assert isinstance(node, Node) self.children.append(node) - def paths(self): + def paths(self) -> List[Tuple["Node", ...]]: """List of tuples, one for each path from this node to any root. Paths are tuples of node objects. 
@@ -58,7 +62,7 @@ def paths(self): paths.extend([path + node_value for path in parent_paths]) return paths - def path(self, attrs=None): + def path(self, attrs: Dict[str, Any] = None) -> Tuple["Node", ...]: """Path to this node from root. Raises if there are multiple paths. This is useful for trees (where each node only has one path), as @@ -71,7 +75,9 @@ def path(self, attrs=None): raise MultiplePathError("Node has more than one path: " % paths) return paths[0] - def dag_equal(self, other, vs=None, vo=None): + def dag_equal( + self, other: "Node", vs: Set[int] = None, vo: Set[int] = None + ) -> bool: """Check if DAG rooted at self has the same structure as that rooted at other. """ @@ -113,7 +119,12 @@ def dag_equal(self, other, vs=None, vo=None): return True - def traverse(self, order="pre", attrs=None, visited=None): + def traverse( + self, + order: str = "pre", + attrs: Union[List[str], Tuple[str, ...], str] = None, + visited: Dict[int, int] = None, + ) -> Iterable[Union["Node", Union[Tuple[Any, ...], Any]]]: """Traverse the tree depth-first and yield each node. Arguments: @@ -149,7 +160,12 @@ def value(node): if order == "post": yield value(self) - def node_order_traverse(self, order="pre", attrs=None, visited=None): + def node_order_traverse( + self, + order: str = "pre", + attrs: Union[List[str], Tuple[str, ...], str] = None, + visited: Dict[int, int] = None, + ) -> Iterable[Union["Node", Union[Tuple[Any, ...], Any]]]: """Traverse the tree depth-first and yield each node, sorting children by "node order". Arguments: @@ -187,28 +203,34 @@ def value(node): if order == "post": yield value(self) - def __hash__(self): + def __hash__(self) -> int: return self._hatchet_nid - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return self._hatchet_nid == other._hatchet_nid - def __lt__(self, other): + def __lt__(self, other: object) -> bool: return self._hatchet_nid < other._hatchet_nid - def __gt__(self, other): + def __gt__(self, other: object) -> bool: return self._hatchet_nid > other._hatchet_nid - def __str__(self): + def __str__(self) -> str: """Returns a string representation of the node.""" return str(self.frame) - def copy(self): + def copy(self) -> "Node": """Copy this node without preserving parents or children.""" return Node(frame_obj=self.frame.copy()) @classmethod - def from_lists(cls, lists): + def from_lists( + cls, + lists: Union[ + List[str, "Node", Union[List, Tuple]], + Tuple[str, "Node", Union[List, Tuple]], + ], + ) -> "Node": r"""Construct a hierarchy of nodes from recursive lists. 
For example, this will construct a simple tree: @@ -278,7 +300,7 @@ def _from_lists(lists, parent): return _from_lists(lists, None) - def __repr__(self): + def __repr__(self) -> str: return "Node({%s})" % ", ".join( "%s: %s" % (repr(k), repr(v)) for k, v in sorted(self.frame.attrs.items()) ) diff --git a/hatchet/query/__init__.py b/hatchet/query/__init__.py index 77e5bfdc..6e0da6a9 100644 --- a/hatchet/query/__init__.py +++ b/hatchet/query/__init__.py @@ -6,6 +6,8 @@ # Make flake8 ignore unused names in this file # flake8: noqa: F401 +from typing import Any, Union, List, TypeVar + from .query import Query from .compound import ( CompoundQuery, @@ -39,20 +41,54 @@ parse_cypher_query, ) +BaseQueryType = TypeVar("BaseQuery", Query, ObjectQuery, StringQuery, str, List) +CompoundQueryType = TypeVar( + "CompoundQuery", + CompoundQuery, + ConjunctionQuery, + DisjunctionQuery, + ExclusiveDisjunctionQuery, + NegationQuery, +) +LegacyQueryType = TypeVar( + "LegacyQuery", + AbstractQuery, + NaryQuery, + AndQuery, + IntersectionQuery, + OrQuery, + UnionQuery, + XorQuery, + SymDifferenceQuery, + NotQuery, + QueryMatcher, + CypherQuery, +) +AnyQueryType = TypeVar("AnyQuery", BaseQueryType, CompoundQueryType, LegacyQueryType) + -def combine_via_conjunction(query0, query1): +def combine_via_conjunction( + query0: Union[BaseQueryType, CompoundQueryType], + query1: Union[BaseQueryType, CompoundQueryType], +) -> ConjunctionQuery: return ConjunctionQuery(query0, query1) -def combine_via_disjunction(query0, query1): +def combine_via_disjunction( + query0: Union[BaseQueryType, CompoundQueryType], + query1: Union[BaseQueryType, CompoundQueryType], +) -> DisjunctionQuery: return DisjunctionQuery(query0, query1) -def combine_via_exclusive_disjunction(query0, query1): +def combine_via_exclusive_disjunction( + query0: Union[BaseQueryType, CompoundQueryType], + query1: Union[BaseQueryType, CompoundQueryType], +) -> ExclusiveDisjunctionQuery: return ExclusiveDisjunctionQuery(query0, query1) -def negate_query(query): +def negate_query(query: Union[BaseQueryType, CompoundQueryType]) -> NegationQuery: return NegationQuery(query) @@ -68,7 +104,7 @@ def negate_query(query): CompoundQuery.__not__ = negate_query -def is_hatchet_query(query_obj): +def is_hatchet_query(query_obj: Any) -> bool: return ( issubclass(type(query_obj), Query) or issubclass(type(query_obj), CompoundQuery) diff --git a/hatchet/query/compat.py b/hatchet/query/compat.py index d62a0c5c..2e71a682 100644 --- a/hatchet/query/compat.py +++ b/hatchet/query/compat.py @@ -13,7 +13,11 @@ ABC = ABCMeta("ABC", (object,), {"__slots__": ()}) import sys import warnings +from collections.abc import Callable +from typing import List, Union +from ..graphframe import GraphFrame +from ..node import Node from .query import Query from .compound import ( CompoundQuery, @@ -29,17 +33,17 @@ # QueryEngine object for running the legacy "apply" methods -COMPATABILITY_ENGINE = QueryEngine() +COMPATABILITY_ENGINE: QueryEngine = QueryEngine() class AbstractQuery(ABC): """Base class for all 'old-style' queries.""" @abstractmethod - def apply(self, gf): + def apply(self, gf: GraphFrame) -> List[Node]: pass - def __and__(self, other): + def __and__(self, other: "AbstractQuery") -> "AndQuery": """Create a new AndQuery using this query and another. Arguments: @@ -50,7 +54,7 @@ def __and__(self, other): """ return AndQuery(self, other) - def __or__(self, other): + def __or__(self, other: "AbstractQuery") -> "OrQuery": """Create a new OrQuery using this query and another. 
Arguments: @@ -61,7 +65,7 @@ def __or__(self, other): """ return OrQuery(self, other) - def __xor__(self, other): + def __xor__(self, other: "AbstractQuery") -> "XorQuery": """Create a new XorQuery using this query and another. Arguments: @@ -72,7 +76,7 @@ def __xor__(self, other): """ return XorQuery(self, other) - def __invert__(self): + def __invert__(self) -> "NegationQuery": """Create a new NotQuery using this query. Returns: @@ -81,7 +85,7 @@ def __invert__(self): return NotQuery(self) @abstractmethod - def _get_new_query(self): + def _get_new_query(self) -> Union[Query, CompoundQuery]: pass @@ -89,7 +93,7 @@ class NaryQuery(AbstractQuery): """Base class for all compound queries that act on and merged N separate subqueries.""" - def __init__(self, *args): + def __init__(self, *args) -> None: """Create a new NaryQuery object. Arguments: @@ -115,7 +119,7 @@ def __init__(self, *args): high-level query or a subclass of AbstractQuery" ) - def apply(self, gf): + def apply(self, gf: GraphFrame) -> List[Node]: """Applies the query to the specified GraphFrame. Arguments: @@ -127,7 +131,7 @@ def apply(self, gf): true_query = self._get_new_query() return COMPATABILITY_ENGINE.apply(true_query, gf.graph, gf.dataframe) - def _get_new_query(self): + def _get_new_query(self) -> CompoundQuery: """Gets all the underlying 'new-style' queries in this object. Returns: @@ -142,7 +146,9 @@ def _get_new_query(self): return self._convert_to_new_query(true_subqueries) @abstractmethod - def _convert_to_new_query(self, subqueries): + def _convert_to_new_query( + self, subqueries: List[Union[Query, CompoundQuery]] + ) -> CompoundQuery: pass @@ -150,7 +156,7 @@ class AndQuery(NaryQuery): """Compound query that returns the intersection of the results of the subqueries.""" - def __init__(self, *args): + def __init__(self, *args) -> None: """Create a new AndQuery object. Arguments: @@ -168,7 +174,9 @@ def __init__(self, *args): if len(self.compat_subqueries) < 2: raise BadNumberNaryQueryArgs("AndQuery requires 2 or more subqueries") - def _convert_to_new_query(self, subqueries): + def _convert_to_new_query( + self, subqueries: List[Union[Query, CompoundQuery]] + ) -> CompoundQuery: return ConjunctionQuery(*subqueries) @@ -180,7 +188,7 @@ class OrQuery(NaryQuery): """Compound query that returns the union of the results of the subqueries""" - def __init__(self, *args): + def __init__(self, *args) -> None: """Create a new OrQuery object. Arguments: @@ -198,7 +206,9 @@ def __init__(self, *args): if len(self.compat_subqueries) < 2: raise BadNumberNaryQueryArgs("OrQuery requires 2 or more subqueries") - def _convert_to_new_query(self, subqueries): + def _convert_to_new_query( + self, subqueries: List[Union[Query, CompoundQuery]] + ) -> CompoundQuery: return DisjunctionQuery(*subqueries) @@ -210,7 +220,7 @@ class XorQuery(NaryQuery): """Compound query that returns the symmetric difference (i.e., set-based XOR) of the results of the subqueries""" - def __init__(self, *args): + def __init__(self, *args) -> None: """Create a new XorQuery object. 
Arguments: @@ -228,7 +238,9 @@ def __init__(self, *args): if len(self.compat_subqueries) < 2: raise BadNumberNaryQueryArgs("XorQuery requires 2 or more subqueries") - def _convert_to_new_query(self, subqueries): + def _convert_to_new_query( + self, subqueries: List[Union[Query, CompoundQuery]] + ) -> CompoundQuery: return ExclusiveDisjunctionQuery(*subqueries) @@ -240,7 +252,7 @@ class NotQuery(NaryQuery): """Compound query that returns all nodes in the GraphFrame that are not returned from the subquery.""" - def __init__(self, *args): + def __init__(self, *args) -> None: """Create a new NotQuery object. Arguments: @@ -258,14 +270,16 @@ def __init__(self, *args): if len(self.compat_subqueries) < 1: raise BadNumberNaryQueryArgs("NotQuery requires exactly 1 subquery") - def _convert_to_new_query(self, subqueries): + def _convert_to_new_query( + self, subqueries: List[Union[Query, CompoundQuery]] + ) -> CompoundQuery: return NegationQuery(*subqueries) class QueryMatcher(AbstractQuery): """Processes and applies base syntax queries and Object-based queries to GraphFrames.""" - def __init__(self, query=None): + def __init__(self, query: Union[List, Query] = None) -> None: """Create a new QueryMatcher object. Arguments: @@ -285,7 +299,11 @@ def __init__(self, query=None): else: raise InvalidQueryPath("Provided query is not a valid object dialect query") - def match(self, wildcard_spec=".", filter_func=lambda row: True): + def match( + self, + wildcard_spec: Union[str, int] = ".", + filter_func: Callable = lambda row: True, + ) -> "QueryMatcher": """Start a query with a root node described by the arguments. Arguments: @@ -299,7 +317,11 @@ def match(self, wildcard_spec=".", filter_func=lambda row: True): self.true_query.match(wildcard_spec, filter_func) return self - def rel(self, wildcard_spec=".", filter_func=lambda row: True): + def rel( + self, + wildcard_spec: Union[str, int] = ".", + filter_func: Callable = lambda row: True, + ) -> "QueryMatcher": """Add another edge and node to the query. Arguments: @@ -313,7 +335,7 @@ def rel(self, wildcard_spec=".", filter_func=lambda row: True): self.true_query.rel(wildcard_spec, filter_func) return self - def apply(self, gf): + def apply(self, gf: GraphFrame) -> List[Node]: """Apply the query to a GraphFrame. Arguments: @@ -324,7 +346,7 @@ def apply(self, gf): """ return COMPATABILITY_ENGINE.apply(self.true_query, gf.graph, gf.dataframe) - def _get_new_query(self): + def _get_new_query(self) -> Union[Query, CompoundQuery]: """Get all the underlying 'new-style' query in this object. Returns: @@ -336,7 +358,7 @@ def _get_new_query(self): class CypherQuery(QueryMatcher): """Processes and applies Strinb-based queries to GraphFrames.""" - def __init__(self, cypher_query): + def __init__(self, cypher_query: str) -> None: """Create a new Cypher object. Arguments: @@ -349,7 +371,7 @@ def __init__(self, cypher_query): ) self.true_query = parse_string_dialect(cypher_query) - def _get_new_query(self): + def _get_new_query(self) -> Union[Query, CompoundQuery]: """Gets all the underlying 'new-style' queries in this object. Returns: @@ -358,7 +380,7 @@ def _get_new_query(self): return self.true_query -def parse_cypher_query(cypher_query): +def parse_cypher_query(cypher_query: str) -> CypherQuery: """Parse all types of String-based queries, including multi-queries that leverage the curly brace delimiters. 
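For reference, a minimal sketch of how the legacy query API annotated in the compat.py hunks above is typically driven; the literal graph, metric values, and variable names below are illustrative assumptions rather than code taken from the Hatchet repository or this patch.

from hatchet import GraphFrame
from hatchet.query import QueryMatcher

# Build a tiny GraphFrame from Hatchet's literal format (values are made up).
gf = GraphFrame.from_literal(
    [
        {
            "frame": {"name": "main", "type": "function"},
            "metrics": {"time (inc)": 10.0, "time": 1.0},
            "children": [
                {
                    "frame": {"name": "solve", "type": "function"},
                    "metrics": {"time (inc)": 9.0, "time": 9.0},
                    "children": [],
                }
            ],
        }
    ]
)

# Legacy object-dialect query: a root named "main" followed by zero or more
# descendants spending at least 5 time units exclusively.
query = (
    QueryMatcher()
    .match(".", lambda row: row["name"] == "main")
    .rel("*", lambda row: row["time"] >= 5.0)
)

matched_nodes = query.apply(gf)            # List[Node], per the annotation above
pruned_gf = gf.filter(query, squash=True)  # GraphFrame restricted to those nodes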
diff --git a/hatchet/query/compound.py b/hatchet/query/compound.py index 73a1ed1a..64c011f6 100644 --- a/hatchet/query/compound.py +++ b/hatchet/query/compound.py @@ -6,7 +6,10 @@ from abc import abstractmethod import sys +from typing import List +from ..node import Node +from ..graph import Graph from .query import Query from .string_dialect import parse_string_dialect from .object_dialect import ObjectQuery @@ -16,7 +19,7 @@ class CompoundQuery(object): """Base class for all types of compound queries.""" - def __init__(self, *queries): + def __init__(self, *queries) -> None: """Collect the provided queries into a list, constructing ObjectQuery and StringQuery objects as needed. Arguments: @@ -39,7 +42,9 @@ def __init__(self, *queries): ) @abstractmethod - def _apply_op_to_results(self, subquery_results): + def _apply_op_to_results( + self, subquery_results: List[List[Node]], graph: Graph + ) -> List[Node]: """Combines/Modifies the results of the subqueries based on the operation the subclass represents. """ @@ -51,7 +56,7 @@ class ConjunctionQuery(CompoundQuery): using set conjunction. """ - def __init__(self, *queries): + def __init__(self, *queries) -> None: """Create the ConjunctionQuery. Arguments: @@ -66,7 +71,9 @@ def __init__(self, *queries): "ConjunctionQuery requires 2 or more subqueries" ) - def _apply_op_to_results(self, subquery_results, graph): + def _apply_op_to_results( + self, subquery_results: List[List[Node]], graph: Graph + ) -> List[Node]: """Combines the results of the subqueries using set conjunction. Arguments: @@ -85,7 +92,7 @@ class DisjunctionQuery(CompoundQuery): using set disjunction. """ - def __init__(self, *queries): + def __init__(self, *queries) -> None: """Create the DisjunctionQuery. Arguments: @@ -100,7 +107,9 @@ def __init__(self, *queries): "DisjunctionQuery requires 2 or more subqueries" ) - def _apply_op_to_results(self, subquery_results, graph): + def _apply_op_to_results( + self, subquery_results: List[List[Node]], graph: Graph + ) -> List[Node]: """Combines the results of the subqueries using set disjunction. Arguments: @@ -119,7 +128,7 @@ class ExclusiveDisjunctionQuery(CompoundQuery): using exclusive set disjunction. """ - def __init__(self, *queries): + def __init__(self, *queries) -> None: """Create the ExclusiveDisjunctionQuery. Arguments: @@ -132,7 +141,9 @@ def __init__(self, *queries): if len(self.subqueries) < 2: raise BadNumberNaryQueryArgs("XorQuery requires 2 or more subqueries") - def _apply_op_to_results(self, subquery_results, graph): + def _apply_op_to_results( + self, subquery_results: List[List[Node]], graph: Graph + ) -> List[Node]: """Combines the results of the subqueries using exclusive set disjunction. Arguments: @@ -153,7 +164,7 @@ class NegationQuery(CompoundQuery): its single subquery. """ - def __init__(self, *queries): + def __init__(self, *queries) -> None: """Create the NegationQuery. Arguments: @@ -166,7 +177,9 @@ def __init__(self, *queries): if len(self.subqueries) != 1: raise BadNumberNaryQueryArgs("NotQuery requires exactly 1 subquery") - def _apply_op_to_results(self, subquery_results, graph): + def _apply_op_to_results( + self, subquery_results: List[List[Node]], graph: Graph + ) -> List[Node]: """Inverts the results of the subquery so that all nodes not in the results are returned. 
Arguments: diff --git a/hatchet/query/engine.py b/hatchet/query/engine.py index 9717e240..ecd7acd7 100644 --- a/hatchet/query/engine.py +++ b/hatchet/query/engine.py @@ -5,9 +5,11 @@ from itertools import groupby import pandas as pd +from typing import List, Set, Union from .errors import InvalidQueryFilter from ..node import Node, traversal_order +from ..graph import Graph from .query import Query from .compound import CompoundQuery from .object_dialect import ObjectQuery @@ -17,15 +19,17 @@ class QueryEngine: """Class for applying queries to GraphFrames.""" - def __init__(self): + def __init__(self) -> None: """Creates the QueryEngine.""" self.search_cache = {} - def reset_cache(self): + def reset_cache(self) -> None: """Resets the cache in the QueryEngine.""" self.search_cache = {} - def apply(self, query, graph, dframe): + def apply( + self, query: Union[Query, CompoundQuery], graph: Graph, dframe: pd.DataFrame + ) -> List[Node]: """Apply the query to a GraphFrame. Arguments: @@ -59,7 +63,7 @@ def apply(self, query, graph, dframe): else: raise TypeError("Invalid query data type ({})".format(str(type(query)))) - def _cache_node(self, node, query, dframe): + def _cache_node(self, node: Node, query: Query, dframe: pd.DataFrame) -> None: """Cache (Memoize) the parts of the query that the node matches. Arguments: @@ -82,7 +86,9 @@ def _cache_node(self, node, query, dframe): matches.append(i) self.search_cache[node._hatchet_nid] = matches - def _match_0_or_more(self, query, dframe, node, wcard_idx): + def _match_0_or_more( + self, query: Query, dframe: pd.DataFrame, node: Node, wcard_idx: int + ) -> List[List[Node]]: """Process a "*" predicate in the query on a subgraph. Arguments: @@ -128,7 +134,9 @@ def _match_0_or_more(self, query, dframe, node, wcard_idx): return [[]] return None - def _match_1(self, query, dframe, node, idx): + def _match_1( + self, query: Query, dframe: pd.DataFrame, node: Node, idx: int + ) -> List[List[Node]]: """Process a "." predicate in the query on a subgraph. Arguments: @@ -156,7 +164,9 @@ def _match_1(self, query, dframe, node, idx): return None return matches - def _match_pattern(self, query, dframe, pattern_root, match_idx): + def _match_pattern( + self, query: Query, dframe: pd.DataFrame, pattern_root: Node, match_idx: int + ) -> List[List[Node]]: """Try to match the query pattern starting at the provided root node. Arguments: @@ -221,7 +231,14 @@ def _match_pattern(self, query, dframe, pattern_root, match_idx): pattern_idx += 1 return matches - def _apply_impl(self, query, dframe, node, visited, matches): + def _apply_impl( + self, + query: Query, + dframe: pd.DataFrame, + node: Node, + visited: Set[Node], + matches: List[List[Node]], + ) -> None: """Traverse the subgraph with the specified root, and collect all paths that match the query. 
Arguments: diff --git a/hatchet/query/object_dialect.py b/hatchet/query/object_dialect.py index daf55c65..ecdc413b 100644 --- a/hatchet/query/object_dialect.py +++ b/hatchet/query/object_dialect.py @@ -12,12 +12,14 @@ ) import re import sys +from typing import Dict, List, Tuple, Union from .errors import InvalidQueryPath, InvalidQueryFilter, MultiIndexModeMismatch +from ..node import Node from .query import Query -def _process_multi_index_mode(apply_result, multi_index_mode): +def _process_multi_index_mode(apply_result: pd.Series, multi_index_mode: str): if multi_index_mode == "any": return apply_result.any() if multi_index_mode == "all": @@ -27,12 +29,19 @@ def _process_multi_index_mode(apply_result, multi_index_mode): ) -def _process_predicate(attr_filter, multi_index_mode): +def _process_predicate( + attr_filter: Dict[Union[str, Tuple[str, ...]], Union[str, Real]], + multi_index_mode: str, +) -> bool: """Converts high-level API attribute filter to a lambda""" compops = ("<", ">", "==", ">=", "<=", "<>", "!=") # , - def filter_series(df_row): - def filter_single_series(df_row, key, single_value): + def filter_series(df_row: pd.Series) -> bool: + def filter_single_series( + df_row: pd.Series, + key: Union[str, Tuple[str]], + single_value: Union[str, Real], + ) -> bool: if key == "depth": node = df_row.name if isinstance(single_value, str) and single_value.lower().startswith( @@ -125,14 +134,19 @@ def filter_single_series(df_row, key, single_value): ) return matches - def filter_dframe(df_row): + def filter_dframe(df_row: pd.DataFrame) -> bool: if multi_index_mode == "off": raise MultiIndexModeMismatch( "The ObjectQuery's 'multi_index_mode' argument \ cannot be set to 'off' when using multi-indexed data" ) - def filter_single_dframe(node, df_row, key, single_value): + def filter_single_dframe( + node: Node, + df_row: pd.Series, + key: Union[str, Tuple[str, ...]], + single_value: Union[str, Real], + ) -> bool: if key == "depth": if isinstance(single_value, str) and single_value.lower().startswith( compops @@ -207,7 +221,7 @@ def filter_single_dframe(node, df_row, key, single_value): ) return matches - def filter_choice(df_row): + def filter_choice(df_row: Union[pd.Series, pd.DataFrame]) -> bool: if isinstance(df_row, pd.DataFrame): return filter_dframe(df_row) return filter_series(df_row) @@ -218,7 +232,7 @@ def filter_choice(df_row): class ObjectQuery(Query): """Class for representing and parsing queries using the Object-based dialect.""" - def __init__(self, query, multi_index_mode="off"): + def __init__(self, query: List, multi_index_mode: str = "off") -> None: """Builds a new ObjectQuery from an instance of the Object-based dialect syntax. Arguments: diff --git a/hatchet/query/query.py b/hatchet/query/query.py index 39f17743..0cc6139e 100644 --- a/hatchet/query/query.py +++ b/hatchet/query/query.py @@ -3,17 +3,22 @@ # # SPDX-License-Identifier: MIT +from typing import Tuple +from collections.abc import Callable + from .errors import InvalidQueryPath class Query(object): """Class for representing and building Hatchet Call Path Queries""" - def __init__(self): + def __init__(self) -> None: """Create new Query""" self.query_pattern = [] - def match(self, quantifier=".", predicate=lambda row: True): + def match( + self, quantifier: str = ".", predicate: Callable = lambda row: True + ) -> "Query": """Start a query with a root node described by the arguments. 
Arguments: @@ -28,7 +33,9 @@ def match(self, quantifier=".", predicate=lambda row: True): self._add_node(quantifier, predicate) return self - def rel(self, quantifier=".", predicate=lambda row: True): + def rel( + self, quantifier: str = ".", predicate: Callable = lambda row: True + ) -> "Query": """Add a new node to the end of the query. Arguments: @@ -45,7 +52,9 @@ def rel(self, quantifier=".", predicate=lambda row: True): self._add_node(quantifier, predicate) return self - def relation(self, quantifer=".", predicate=lambda row: True): + def relation( + self, quantifer: str = ".", predicate: Callable = lambda row: True + ) -> "Query": """Alias to Query.rel. Add a new node to the end of the query. Arguments: @@ -57,15 +66,17 @@ def relation(self, quantifer=".", predicate=lambda row: True): """ return self.rel(quantifer, predicate) - def __len__(self): + def __len__(self) -> int: """Returns the length of the query.""" return len(self.query_pattern) - def __iter__(self): + def __iter__(self) -> Tuple[str, Callable]: """Allows users to iterate over the Query like a list.""" return iter(self.query_pattern) - def _add_node(self, quantifer=".", predicate=lambda row: True): + def _add_node( + self, quantifer: str = ".", predicate: Callable = lambda row: True + ) -> None: """Add a node to the query. Arguments: diff --git a/hatchet/query/string_dialect.py b/hatchet/query/string_dialect.py index 791128fe..6f83c8b1 100644 --- a/hatchet/query/string_dialect.py +++ b/hatchet/query/string_dialect.py @@ -6,6 +6,8 @@ from numbers import Real import re import sys +from collections.abc import Callable +from typing import Any, Tuple, Union import pandas as pd # noqa: F401 from pandas.api.types import is_numeric_dtype, is_string_dtype # noqa: F401 import numpy as np # noqa: F401 @@ -15,6 +17,7 @@ from .errors import InvalidQueryPath, InvalidQueryFilter, RedundantQueryFilterWarning from .query import Query +from .compound import CompoundQuery # PEG grammar for the String-based dialect @@ -60,12 +63,14 @@ cypher_query_mm = metamodel_from_str(CYPHER_GRAMMAR) -def cname(obj): +def cname(obj: Any) -> str: """Utility function to get the name of the rule represented by the input""" return obj.__class__.__name__ -def filter_check_types(type_check, df_row, filt_lambda): +def filter_check_types( + type_check: str, df_row: Union[pd.Series, pd.DataFrame], filt_lambda: Callable +) -> bool: """Utility function used in String-based predicates to make sure the node data used in the actual boolean predicate is of the correct type. @@ -97,7 +102,7 @@ def filter_check_types(type_check, df_row, filt_lambda): class StringQuery(Query): """Class for representing and parsing queries using the String-based dialect.""" - def __init__(self, cypher_query, multi_index_mode="off"): + def __init__(self, cypher_query: str, multi_index_mode: str = "off") -> None: """Builds a new StringQuery object representing a query in the String-based dialect. Arguments: @@ -128,7 +133,7 @@ def __init__(self, cypher_query, multi_index_mode="off"): self._build_lambdas() self._build_query() - def _build_query(self): + def _build_query(self) -> None: """Builds the entire query using 'match' and 'rel' using the pre-parsed quantifiers and predicates. """ @@ -149,7 +154,7 @@ def _build_query(self): else: self.rel(quantifier=wcard, predicate=eval(filt_str)) - def _build_lambdas(self): + def _build_lambdas(self) -> None: """Constructs the final predicate lambdas from the pre-parsed predicate information. 
""" @@ -175,7 +180,7 @@ def _build_lambdas(self): ) self.lambda_filters[i] = bool_expr - def _parse_path(self, path_obj): + def _parse_path(self, path_obj: Any) -> None: """Parses the MATCH statement of a String-based query.""" nodes = path_obj.path.nodes idx = len(self.wcards) @@ -188,7 +193,7 @@ def _parse_path(self, path_obj): self.wcard_pos[n.name] = idx idx += 1 - def _parse_conditions(self, cond_expr): + def _parse_conditions(self, cond_expr: Any) -> None: """Top level function for parsing the WHERE statement of a String-based query. """ @@ -209,7 +214,7 @@ def _parse_conditions(self, cond_expr): if self.filters[i][0][0] != "not": self.filters[i][0][0] = None - def _is_unary_cond(self, obj): + def _is_unary_cond(self, obj: Any) -> bool: """Detect whether a predicate is unary or not.""" if ( cname(obj) == "NotCond" @@ -220,13 +225,13 @@ def _is_unary_cond(self, obj): return True return False - def _is_binary_cond(self, obj): + def _is_binary_cond(self, obj: Any) -> bool: """Detect whether a predicate is binary or not.""" if cname(obj) in ["AndCond", "OrCond"]: return True return False - def _parse_binary_cond(self, obj): + def _parse_binary_cond(self, obj: Any) -> Tuple[str, str, str, str]: """Top level function for parsing binary predicates.""" if cname(obj) == "AndCond": return self._parse_and_cond(obj) @@ -234,38 +239,40 @@ def _parse_binary_cond(self, obj): return self._parse_or_cond(obj) raise RuntimeError("Bad Binary Condition") - def _parse_or_cond(self, obj): + def _parse_or_cond(self, obj: Any) -> Tuple[str, str, str, str]: """Top level function for parsing predicates combined with logical OR.""" converted_subcond = self._parse_unary_cond(obj.subcond) converted_subcond[0] = "or" return converted_subcond - def _parse_and_cond(self, obj): + def _parse_and_cond(self, obj: Any) -> Tuple[str, str, str, str]: """Top level function for parsing predicates combined with logical AND.""" converted_subcond = self._parse_unary_cond(obj.subcond) converted_subcond[0] = "and" return converted_subcond - def _parse_unary_cond(self, obj): + def _parse_unary_cond(self, obj: Any) -> Tuple[str, str, str, str]: """Top level function for parsing unary predicates.""" if cname(obj) == "NotCond": return self._parse_not_cond(obj) return self._parse_single_cond(obj) - def _parse_not_cond(self, obj): + def _parse_not_cond(self, obj: Any) -> Tuple[str, str, str, str]: """Parse predicates containing the logical NOT operator.""" converted_subcond = self._parse_single_cond(obj.subcond) converted_subcond[2] = "not {}".format(converted_subcond[2]) return converted_subcond - def _run_method_based_on_multi_idx_mode(self, method_name, obj): + def _run_method_based_on_multi_idx_mode( + self, method_name: str, obj: Any + ) -> Tuple[str, str, str, str]: real_method_name = method_name if self.multi_index_mode != "off": real_method_name = method_name + "_multi_idx" method = eval("StringQuery.{}".format(real_method_name)) return method(self, obj) - def _parse_single_cond(self, obj): + def _parse_single_cond(self, obj: Any) -> Tuple[str, str, str, str]: """Top level function for parsing individual numeric or string predicates.""" if self._is_str_cond(obj): return self._parse_str(obj) @@ -281,7 +288,7 @@ def _parse_single_cond(self, obj): return self._run_method_based_on_multi_idx_mode("_parse_not_leaf", obj) raise RuntimeError("Bad Single Condition") - def _parse_none(self, obj): + def _parse_none(self, obj: Any) -> Tuple[str, str, str, str]: """Parses 'property IS NONE'.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] 
== "depth": return [ @@ -308,12 +315,12 @@ def _parse_none(self, obj): None, ] - def _add_aggregation_call_to_multi_idx_predicate(self, predicate): + def _add_aggregation_call_to_multi_idx_predicate(self, predicate: str) -> str: if self.multi_index_mode == "any": return predicate + ".any()" return predicate + ".all()" - def _parse_none_multi_idx(self, obj): + def _parse_none_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -341,7 +348,7 @@ def _parse_none_multi_idx(self, obj): None, ] - def _parse_not_none(self, obj): + def _parse_not_none(self, obj: Any) -> Tuple[str, str, str, str]: """Parses 'property IS NOT NONE'.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -368,7 +375,7 @@ def _parse_not_none(self, obj): None, ] - def _parse_not_none_multi_idx(self, obj): + def _parse_not_none_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -396,7 +403,7 @@ def _parse_not_none_multi_idx(self, obj): None, ] - def _parse_leaf(self, obj): + def _parse_leaf(self, obj: Any) -> Tuple[str, str, str, str]: """Parses 'node IS LEAF'.""" return [ None, @@ -405,7 +412,7 @@ def _parse_leaf(self, obj): None, ] - def _parse_leaf_multi_idx(self, obj): + def _parse_leaf_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: return [ None, obj.name, @@ -413,7 +420,7 @@ def _parse_leaf_multi_idx(self, obj): None, ] - def _parse_not_leaf(self, obj): + def _parse_not_leaf(self, obj: Any) -> Tuple[str, str, str, str]: """Parses 'node IS NOT LEAF'.""" return [ None, @@ -422,7 +429,7 @@ def _parse_not_leaf(self, obj): None, ] - def _parse_not_leaf_multi_idx(self, obj): + def _parse_not_leaf_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: return [ None, obj.name, @@ -430,7 +437,7 @@ def _parse_not_leaf_multi_idx(self, obj): None, ] - def _is_str_cond(self, obj): + def _is_str_cond(self, obj: Any) -> bool: """Determines whether a predicate is for string data.""" if cname(obj) in [ "StringEq", @@ -442,7 +449,7 @@ def _is_str_cond(self, obj): return True return False - def _is_num_cond(self, obj): + def _is_num_cond(self, obj: Any) -> bool: """Determines whether a predicate is for numeric data.""" if cname(obj) in [ "NumEq", @@ -458,7 +465,7 @@ def _is_num_cond(self, obj): return True return False - def _parse_str(self, obj): + def _parse_str(self, obj: Any) -> Tuple[str, str, str, str]: """Function that redirects processing of string predicates to the correct function. 
""" @@ -476,7 +483,7 @@ def _parse_str(self, obj): return self._run_method_based_on_multi_idx_mode("_parse_str_match", obj) raise RuntimeError("Bad String Op Class") - def _parse_str_eq(self, obj): + def _parse_str_eq(self, obj: Any) -> Tuple[str, str, str, str]: """Processes string equivalence predicates.""" return [ None, @@ -496,7 +503,7 @@ def _parse_str_eq(self, obj): ), ] - def _parse_str_eq_multi_idx(self, obj): + def _parse_str_eq_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: return [ None, obj.name, @@ -517,7 +524,7 @@ def _parse_str_eq_multi_idx(self, obj): ), ] - def _parse_str_starts_with(self, obj): + def _parse_str_starts_with(self, obj: Any) -> Tuple[str, str, str, str]: """Processes string 'startswith' predicates.""" return [ None, @@ -537,7 +544,7 @@ def _parse_str_starts_with(self, obj): ), ] - def _parse_str_starts_with_multi_idx(self, obj): + def _parse_str_starts_with_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: return [ None, obj.name, @@ -558,7 +565,7 @@ def _parse_str_starts_with_multi_idx(self, obj): ), ] - def _parse_str_ends_with(self, obj): + def _parse_str_ends_with(self, obj: Any) -> Tuple[str, str, str, str]: """Processes string 'endswith' predicates.""" return [ None, @@ -578,7 +585,7 @@ def _parse_str_ends_with(self, obj): ), ] - def _parse_str_ends_with_multi_idx(self, obj): + def _parse_str_ends_with_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: return [ None, obj.name, @@ -599,7 +606,7 @@ def _parse_str_ends_with_multi_idx(self, obj): ), ] - def _parse_str_contains(self, obj): + def _parse_str_contains(self, obj: Any) -> Tuple[str, str, str, str]: """Processes string 'contains' predicates.""" return [ None, @@ -619,7 +626,7 @@ def _parse_str_contains(self, obj): ), ] - def _parse_str_contains_multi_idx(self, obj): + def _parse_str_contains_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: return [ None, obj.name, @@ -640,7 +647,7 @@ def _parse_str_contains_multi_idx(self, obj): ), ] - def _parse_str_match(self, obj): + def _parse_str_match(self, obj: Any) -> Tuple[str, str, str, str]: """Processes string regex match predicates.""" return [ None, @@ -660,7 +667,7 @@ def _parse_str_match(self, obj): ), ] - def _parse_str_match_multi_idx(self, obj): + def _parse_str_match_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: return [ None, obj.name, @@ -681,7 +688,7 @@ def _parse_str_match_multi_idx(self, obj): ), ] - def _parse_num(self, obj): + def _parse_num(self, obj: Any) -> Tuple[str, str, str, str]: """Function that redirects processing of numeric predicates to the correct function. """ @@ -705,7 +712,7 @@ def _parse_num(self, obj): return self._run_method_based_on_multi_idx_mode("_parse_num_not_inf", obj) raise RuntimeError("Bad Number Op Class") - def _parse_num_eq(self, obj): + def _parse_num_eq(self, obj: Any) -> Tuple[str, str, str, str]: """Processes numeric equivalence predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val == -1: @@ -722,9 +729,7 @@ def _parse_num_eq(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -747,9 +752,7 @@ def _parse_num_eq(self, obj): This condition will always be false. 
The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -782,7 +785,7 @@ def _parse_num_eq(self, obj): ), ] - def _parse_num_eq_multi_idx(self, obj): + def _parse_num_eq_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val == -1: return [ @@ -798,9 +801,7 @@ def _parse_num_eq_multi_idx(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -823,9 +824,7 @@ def _parse_num_eq_multi_idx(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -862,7 +861,7 @@ def _parse_num_eq_multi_idx(self, obj): ), ] - def _parse_num_lt(self, obj): + def _parse_num_lt(self, obj: Any) -> Tuple[str, str, str, str]: """Processes numeric less-than predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: @@ -872,9 +871,7 @@ def _parse_num_lt(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -897,9 +894,7 @@ def _parse_num_lt(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -932,7 +927,7 @@ def _parse_num_lt(self, obj): ), ] - def _parse_num_lt_multi_idx(self, obj): + def _parse_num_lt_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -941,9 +936,7 @@ def _parse_num_lt_multi_idx(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -966,9 +959,7 @@ def _parse_num_lt_multi_idx(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1005,7 +996,7 @@ def _parse_num_lt_multi_idx(self, obj): ), ] - def _parse_num_gt(self, obj): + def _parse_num_gt(self, obj: Any) -> Tuple[str, str, str, str]: """Processes numeric greater-than predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: @@ -1015,9 +1006,7 @@ def _parse_num_gt(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1040,9 +1029,7 @@ def _parse_num_gt(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1075,7 +1062,7 @@ def _parse_num_gt(self, obj): ), ] - def _parse_num_gt_multi_idx(self, obj): + def _parse_num_gt_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -1084,9 +1071,7 @@ def _parse_num_gt_multi_idx(self, obj): This condition will always be true. 
The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1109,9 +1094,7 @@ def _parse_num_gt_multi_idx(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1148,7 +1131,7 @@ def _parse_num_gt_multi_idx(self, obj): ), ] - def _parse_num_lte(self, obj): + def _parse_num_lte(self, obj: Any) -> Tuple[str, str, str, str]: """Processes numeric less-than-or-equal-to predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: @@ -1158,9 +1141,7 @@ def _parse_num_lte(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1183,9 +1164,7 @@ def _parse_num_lte(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1218,7 +1197,7 @@ def _parse_num_lte(self, obj): ), ] - def _parse_num_lte_multi_idx(self, obj): + def _parse_num_lte_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -1227,9 +1206,7 @@ def _parse_num_lte_multi_idx(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1252,9 +1229,7 @@ def _parse_num_lte_multi_idx(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1291,7 +1266,7 @@ def _parse_num_lte_multi_idx(self, obj): ), ] - def _parse_num_gte(self, obj): + def _parse_num_gte(self, obj: Any) -> Tuple[str, str, str, str]: """Processes numeric greater-than-or-equal-to predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: @@ -1301,9 +1276,7 @@ def _parse_num_gte(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1326,9 +1299,7 @@ def _parse_num_gte(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1361,7 +1332,7 @@ def _parse_num_gte(self, obj): ), ] - def _parse_num_gte_multi_idx(self, obj): + def _parse_num_gte_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -1370,9 +1341,7 @@ def _parse_num_gte_multi_idx(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1395,9 +1364,7 @@ def _parse_num_gte_multi_idx(self, obj): This condition will always be true. 
The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1434,7 +1401,7 @@ def _parse_num_gte_multi_idx(self, obj): ), ] - def _parse_num_nan(self, obj): + def _parse_num_nan(self, obj: Any) -> Tuple[str, str, str, str]: """Processes predicates that check for NaN.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -1465,7 +1432,7 @@ def _parse_num_nan(self, obj): ), ] - def _parse_num_nan_multi_idx(self, obj): + def _parse_num_nan_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -1497,7 +1464,7 @@ def _parse_num_nan_multi_idx(self, obj): ), ] - def _parse_num_not_nan(self, obj): + def _parse_num_not_nan(self, obj: Any) -> Tuple[str, str, str, str]: """Processes predicates that check for NaN.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -1528,7 +1495,7 @@ def _parse_num_not_nan(self, obj): ), ] - def _parse_num_not_nan_multi_idx(self, obj): + def _parse_num_not_nan_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -1560,7 +1527,7 @@ def _parse_num_not_nan_multi_idx(self, obj): ), ] - def _parse_num_inf(self, obj): + def _parse_num_inf(self, obj: Any) -> Tuple[str, str, str, str]: """Processes predicates that check for Infinity.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -1591,7 +1558,7 @@ def _parse_num_inf(self, obj): ), ] - def _parse_num_inf_multi_idx(self, obj): + def _parse_num_inf_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -1623,7 +1590,7 @@ def _parse_num_inf_multi_idx(self, obj): ), ] - def _parse_num_not_inf(self, obj): + def _parse_num_not_inf(self, obj: Any) -> Tuple[str, str, str, str]: """Processes predicates that check for not-Infinity.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -1654,7 +1621,7 @@ def _parse_num_not_inf(self, obj): ), ] - def _parse_num_not_inf_multi_idx(self, obj): + def _parse_num_not_inf_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -1687,7 +1654,9 @@ def _parse_num_not_inf_multi_idx(self, obj): ] -def parse_string_dialect(query_str, multi_index_mode="off"): +def parse_string_dialect( + query_str: str, multi_index_mode: str = "off" +) -> Union[StringQuery, CompoundQuery]: """Parse all types of String-based queries, including multi-queries that leverage the curly brace delimiters. diff --git a/hatchet/readers/caliper_native_reader.py b/hatchet/readers/caliper_native_reader.py index e4c8edf5..bfb95aa2 100644 --- a/hatchet/readers/caliper_native_reader.py +++ b/hatchet/readers/caliper_native_reader.py @@ -7,6 +7,7 @@ import pandas as pd import numpy as np import os +from typing import Any, Dict, List, Tuple, Union import caliperreader as cr @@ -45,7 +46,12 @@ class CaliperNativeReader: ), } - def __init__(self, filename_or_caliperreader, native, string_attributes): + def __init__( + self, + filename_or_caliperreader: Union[str, cr.CaliperReader], + native: bool, + string_attributes: Union[str, List[str]], + ) -> None: """Read in a native cali using Caliper's python reader. 
Args: @@ -81,7 +87,7 @@ def __init__(self, filename_or_caliperreader, native, string_attributes): if isinstance(self.string_attributes, str): self.string_attributes = [self.string_attributes] - def _create_metric_df(self, metrics): + def _create_metric_df(self, metrics: List[str]) -> pd.DataFrame: """Make a list of metric columns and create a dataframe, group by node""" for col in self.record_data_cols: if self.filename_or_caliperreader.attribute(col).is_value(): @@ -90,7 +96,7 @@ def _create_metric_df(self, metrics): df_new = df_metrics.groupby(df_metrics["nid"]).aggregate("first").reset_index() return df_new - def _reset_metrics(self, metrics): + def _reset_metrics(self, metrics: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Since the initial functions (i.e. main) are only called once, this keeps a small subset of the timeseries data and resets the rest so future iterations will be filled with nans """ @@ -106,7 +112,7 @@ def _reset_metrics(self, metrics): new_mets.append({k: node_dict.get(k, np.nan) for k in cols_to_keep}) return new_mets - def read_metrics(self, ctx="path"): + def read_metrics(self, ctx: str = "path") -> List[pd.DataFrame]: """append each metrics table to a list and return the list, split on timeseries_level if exists""" metric_dfs = [] all_metrics = [] @@ -192,10 +198,10 @@ def read_metrics(self, ctx="path"): # will return a list with only one element unless it is a timeseries return metric_dfs - def create_graph(self, ctx="path"): + def create_graph(self, ctx: str = "path") -> List[Node]: list_roots = [] - def _create_parent(child_node, parent_callpath): + def _create_parent(child_node: Node, parent_callpath: Any) -> None: """We may encounter a parent node in the callpath before we see it as a child node. In this case, we need to create a hatchet node for the parent. @@ -347,7 +353,7 @@ def _create_parent(child_node, parent_callpath): return list_roots - def _parse_metadata(self, mdata): + def _parse_metadata(self, mdata: Dict[str, str]) -> Dict[str, str]: """Convert Caliper Metadata values into correct Python objects. Args: @@ -385,7 +391,7 @@ def _parse_metadata(self, mdata): parsed_mdata[k] = v return parsed_mdata - def read(self): + def read(self) -> hatchet.graphframe.GraphFrame: """Read the caliper records to extract the calling context tree.""" if isinstance(self.filename_or_caliperreader, str): if self.filename_ext != ".cali": @@ -414,7 +420,6 @@ def read(self): # If not a timeseries there will just be one element in the list for df_fixed_data in metrics_list: - metrics = pd.DataFrame.from_dict(data=df_fixed_data) # add missing intermediate nodes to the df_fixed_data dataframe @@ -571,7 +576,9 @@ def read(self): # othewise we'll have populated the timeseries list of gfs attribute and can ignore the return value return self.gf_list[0] - def read_timeseries(self, level="loop.start_iteration"): + def read_timeseries( + self, level: str = "loop.start_iteration" + ) -> List[hatchet.graphframe.GraphFrame]: """Read in a timeseries Cali file. 
We need to intercept the read function so we can get a list of profiles for thicket diff --git a/hatchet/readers/caliper_reader.py b/hatchet/readers/caliper_reader.py index 4d2dacc8..a298f70f 100644 --- a/hatchet/readers/caliper_reader.py +++ b/hatchet/readers/caliper_reader.py @@ -9,6 +9,8 @@ import subprocess import os import math +from typing import List, Union +from io import TextIOWrapper import pandas as pd import numpy as np @@ -26,7 +28,9 @@ class CaliperReader: """Read in a Caliper file (`cali` or split JSON) or file-like object.""" - def __init__(self, filename_or_stream, query=""): + def __init__( + self, filename_or_stream: Union[str, TextIOWrapper], query: str = "" + ) -> None: """Read from Caliper files (`cali` or split JSON). Args: @@ -55,7 +59,7 @@ def __init__(self, filename_or_stream, query=""): if isinstance(self.filename_or_stream, str): _, self.filename_ext = os.path.splitext(filename_or_stream) - def read_json_sections(self): + def read_json_sections(self) -> None: # if cali-query exists, extract data from .cali to a file-like object if self.filename_ext == ".cali": cali_query = which("cali-query") @@ -140,7 +144,7 @@ def read_json_sections(self): if self.json_cols[idx] != "rank" and item["is_value"] is True: self.metric_columns.append(self.json_cols[idx]) - def create_graph(self): + def create_graph(self) -> List[Node]: list_roots = [] global unknown_label_counter @@ -189,7 +193,7 @@ def create_graph(self): return list_roots - def read(self): + def read(self) -> hatchet.graphframe.GraphFrame: """Read the caliper JSON file to extract the calling context tree.""" with self.timer.phase("read json"): self.read_json_sections() diff --git a/hatchet/readers/dataframe_reader.py b/hatchet/readers/dataframe_reader.py index 7b298c22..50d21e0d 100644 --- a/hatchet/readers/dataframe_reader.py +++ b/hatchet/readers/dataframe_reader.py @@ -7,7 +7,10 @@ from hatchet.node import Node from hatchet.graph import Graph +import pandas as pd + from abc import abstractmethod +from typing import Dict, List # TODO The ABC class was introduced in Python 3.4. 
# When support for earlier versions is (eventually) dropped, @@ -21,7 +24,7 @@ ABC = ABCMeta("ABC", (object,), {"__slots__": ()}) -def _get_node_from_df_iloc(df, ind): +def _get_node_from_df_iloc(df: pd.DataFrame, ind: int) -> Node: node = None if isinstance(df.iloc[ind].name, tuple): node = df.iloc[ind].name[0] @@ -34,7 +37,7 @@ def _get_node_from_df_iloc(df, ind): return node -def _get_parents_and_children(df): +def _get_parents_and_children(df: pd.DataFrame) -> Dict[Node, Dict[str, List[int]]]: rel_dict = {} for i in range(len(df)): node = _get_node_from_df_iloc(df, i) @@ -45,7 +48,9 @@ def _get_parents_and_children(df): return rel_dict -def _reconstruct_graph(df, rel_dict): +def _reconstruct_graph( + df: pd.DataFrame, rel_dict: Dict[Node, Dict[str, List[int]]] +) -> Graph: node_list = sorted(list(df.index.to_frame()["node"])) for i in range(len(df)): node = _get_node_from_df_iloc(df, i) @@ -60,14 +65,14 @@ def _reconstruct_graph(df, rel_dict): class DataframeReader(ABC): """Abstract Base Class for reading in checkpointing files.""" - def __init__(self, filename): + def __init__(self, filename: str) -> None: self.filename = filename @abstractmethod - def _read_dataframe_from_file(self, **kwargs): + def _read_dataframe_from_file(self, **kwargs) -> pd.DataFrame: pass - def read(self, **kwargs): + def read(self, **kwargs) -> hatchet.graphframe.GraphFrame: df = self._read_dataframe_from_file(**kwargs) rel_dict = _get_parents_and_children(df) graph = _reconstruct_graph(df, rel_dict) diff --git a/hatchet/readers/gprof_dot_reader.py b/hatchet/readers/gprof_dot_reader.py index 1ac4c502..981388cc 100644 --- a/hatchet/readers/gprof_dot_reader.py +++ b/hatchet/readers/gprof_dot_reader.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: MIT import re +from typing import List import pandas as pd import pydot @@ -19,7 +20,7 @@ class GprofDotReader: """Read in gprof/callgrind output in dot format generated by gprof2dot.""" - def __init__(self, filename): + def __init__(self, filename: str) -> None: self.dotfile = filename self.name_to_hnode = {} @@ -27,7 +28,7 @@ def __init__(self, filename): self.timer = Timer() - def create_graph(self): + def create_graph(self) -> List[Node]: """Read the DOT files to create a graph.""" graphs = pydot.graph_from_dot_file(self.dotfile, encoding="utf-8") @@ -97,7 +98,7 @@ def create_graph(self): return list_roots - def read(self): + def read(self) -> hatchet.graphframe.GraphFrame: """Read the DOT file generated by gprof2dot to create a graphframe. The DOT file contains a call graph. 
""" diff --git a/hatchet/readers/hdf5_reader.py b/hatchet/readers/hdf5_reader.py index b9a48544..f3286a9c 100644 --- a/hatchet/readers/hdf5_reader.py +++ b/hatchet/readers/hdf5_reader.py @@ -10,14 +10,14 @@ class HDF5Reader(DataframeReader): - def __init__(self, filename): + def __init__(self, filename: str) -> None: # TODO Remove Arguments when Python 2.7 support is dropped if sys.version_info[0] == 2: super(HDF5Reader, self).__init__(filename) else: super().__init__(filename) - def _read_dataframe_from_file(self, **kwargs): + def _read_dataframe_from_file(self, **kwargs) -> pd.DataFrame: df = None with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=Warning) diff --git a/hatchet/readers/hpctoolkit_reader.py b/hatchet/readers/hpctoolkit_reader.py index a0e0dae9..7d72ce82 100644 --- a/hatchet/readers/hpctoolkit_reader.py +++ b/hatchet/readers/hpctoolkit_reader.py @@ -8,6 +8,7 @@ import re import os import traceback +from typing import Any, Dict, Tuple, Union import numpy as np import pandas as pd @@ -42,13 +43,15 @@ src_file = 0 -def init_shared_array(buf_): +# TODO replace the "Any" type hint with numpy.typing.ArrayLike +# when our minimum supported version of numpy is 1.20 or higher +def init_shared_array(buf_: Any) -> None: """Initialize shared array.""" global shared_metrics shared_metrics = buf_ -def read_metricdb_file(args): +def read_metricdb_file(args: Tuple[str, int, int, int, int, Tuple[int, int]]) -> None: """Read a single metricdb file into a 1D array.""" ( filename, @@ -93,7 +96,7 @@ class HPCToolkitReader: metric-db files. """ - def __init__(self, dir_name): + def __init__(self, dir_name: str) -> None: # this is the name of the HPCToolkit database directory. The directory # contains an experiment.xml and some metric-db files self.dir_name = dir_name @@ -147,7 +150,7 @@ def __init__(self, dir_name): self.timer = Timer() - def fill_tables(self): + def fill_tables(self) -> Tuple[Dict, Dict, Dict, Dict]: """Read certain sections of the experiment.xml file to create dicts of load modules, src_files, procedure_names, and metric_names. """ @@ -171,7 +174,7 @@ def fill_tables(self): self.metric_names, ) - def read_all_metricdb_files(self): + def read_all_metricdb_files(self) -> None: """Read all the metric-db files and create a dataframe with num_nodes X num_metricdb_files rows and num_metrics columns. Three additional columns store the node id, MPI process rank, and thread id (if applicable). @@ -234,7 +237,7 @@ def read_all_metricdb_files(self): # subtract_exclusive_metric_vals/ num nodes is already calculated self.total_execution_threads = self.num_threads_per_rank * self.num_ranks - def read(self): + def read(self) -> hatchet.graphframe.GraphFrame: """Read the experiment.xml file to extract the calling context tree and create a dataframe out of it. Then merge the two dataframes to create the final dataframe. 
@@ -318,7 +321,7 @@ def read(self): return hatchet.graphframe.GraphFrame(graph, dataframe, exc_metrics, inc_metrics) - def parse_xml_children(self, xml_node, hnode): + def parse_xml_children(self, xml_node: Any, hnode: Node) -> None: """Parses all children of an XML node.""" for xml_child in xml_node: if xml_child.tag != "M": @@ -326,7 +329,9 @@ def parse_xml_children(self, xml_node, hnode): line = int(xml_node.get("l")) self.parse_xml_node(xml_child, nid, line, hnode) - def parse_xml_node(self, xml_node, parent_nid, parent_line, hparent): + def parse_xml_node( + self, xml_node: Any, parent_nid: int, parent_line: int, hparent: Node + ) -> None: """Parses an XML node and its children recursively.""" nid = int(xml_node.get("i")) @@ -414,7 +419,16 @@ def parse_xml_node(self, xml_node, parent_nid, parent_line, hparent): hparent.add_child(hnode) self.parse_xml_children(xml_node, hnode) - def create_node_dict(self, nid, hnode, name, node_type, src_file, line, module): + def create_node_dict( + self, + nid: int, + hnode: Node, + name: str, + node_type: str, + src_file: str, + line: int, + module: str, + ) -> Dict[str, Union[int, str, Node]]: """Create a dict with all the node attributes.""" node_dict = { "nid": nid, @@ -428,7 +442,7 @@ def create_node_dict(self, nid, hnode, name, node_type, src_file, line, module): return node_dict - def count_cpu_threads_per_rank(self): + def count_cpu_threads_per_rank(self) -> int: metricdb_files = glob.glob(self.dir_name + "/*.metric-db") cpu_thread_ids = set() diff --git a/hatchet/readers/hpctoolkit_reader_latest.py b/hatchet/readers/hpctoolkit_reader_latest.py index 237d47dd..3181fe84 100644 --- a/hatchet/readers/hpctoolkit_reader_latest.py +++ b/hatchet/readers/hpctoolkit_reader_latest.py @@ -6,7 +6,7 @@ import os import re import struct -from typing import Dict, Union +from typing import Dict, Tuple, Union import pandas as pd @@ -18,7 +18,7 @@ def safe_unpack( format: str, data: bytes, offset: int, index: int = None, index_length: int = None -) -> tuple: +) -> Tuple: length = struct.calcsize(format) if index: offset += index * (length if index_length is None else index_length) @@ -49,7 +49,6 @@ def read_string(data: bytes, offset: int) -> str: class HPCToolkitReaderLatest: - def __init__( self, dir_path: str, @@ -217,7 +216,7 @@ def _parse_function( return self._functions[pFunction] def _store_cct_node( - self, ctxId: int, frame: dict, parent: Node = None, depth: int = 0 + self, ctxId: int, frame: Dict, parent: Node = None, depth: int = 0 ) -> Node: node = Node(Frame(frame), parent=parent, hnid=ctxId, depth=depth) if parent is None: @@ -228,9 +227,7 @@ def _store_cct_node( "node": node, "name": ( # f"{frame['type']}: {frame['name']}" - frame["name"] - if frame["name"] != 1 - else "entry" + frame["name"] if frame["name"] != 1 else "entry" ), } @@ -249,7 +246,6 @@ def _parse_context( meta_db: bytes, parent_time: int, ) -> None: - final_offset = current_offset + total_size while current_offset < final_offset: @@ -311,7 +307,6 @@ def _parse_context( def _read_summary_profile( self, ) -> None: - with open(self._profile_file, "rb") as file: file.seek(FILE_HEADER_OFFSET) formatProfileInfos = " None: """Read from a json string specification of a graphframe json (string): Json specification of a graphframe. 
""" self.spec_dict = json.loads(json_spec) - def read(self): + def read(self) -> hatchet.graphframe.GraphFrame: roots = [] for graph_spec in self.spec_dict["graph"]: # turn frames into nodes diff --git a/hatchet/readers/literal_reader.py b/hatchet/readers/literal_reader.py index c3c95a01..f3e3fe1f 100644 --- a/hatchet/readers/literal_reader.py +++ b/hatchet/readers/literal_reader.py @@ -3,6 +3,8 @@ # # SPDX-License-Identifier: MIT +from typing import Any, Dict, List + import pandas as pd import hatchet.graphframe @@ -59,7 +61,7 @@ class LiteralReader: (GraphFrame): graphframe containing data from dictionaries """ - def __init__(self, graph_dict): + def __init__(self, graph_dict: Dict) -> None: """Read from list of dictionaries. graph_dict (dict): List of dictionaries encoding nodes. @@ -67,8 +69,13 @@ def __init__(self, graph_dict): self.graph_dict = graph_dict def parse_node_literal( - self, frame_to_node_dict, node_dicts, child_dict, hparent, seen_nids - ): + self, + frame_to_node_dict: Dict[Frame, Node], + node_dicts: List[Dict[str, Any]], + child_dict: Dict[str, Any], + hparent: Node, + seen_nids: List[int], + ) -> None: """Create node_dict for one node and then call the function recursively on all children. """ @@ -110,7 +117,7 @@ def parse_node_literal( frame_to_node_dict, node_dicts, child, hnode, seen_nids ) - def read(self): + def read(self) -> hatchet.graphframe.GraphFrame: list_roots = [] node_dicts = [] frame_to_node_dict = {} diff --git a/hatchet/readers/pyinstrument_reader.py b/hatchet/readers/pyinstrument_reader.py index 3ce00400..cc739a76 100644 --- a/hatchet/readers/pyinstrument_reader.py +++ b/hatchet/readers/pyinstrument_reader.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: MIT import json +from typing import Any, Dict import pandas as pd @@ -14,14 +15,14 @@ class PyinstrumentReader: - def __init__(self, filename): + def __init__(self, filename: str) -> None: self.pyinstrument_json_filename = filename self.graph_dict = {} self.list_roots = [] self.node_dicts = [] - def create_graph(self): - def parse_node_literal(child_dict, hparent): + def create_graph(self) -> Graph: + def parse_node_literal(child_dict: Dict[str, Any], hparent: Node) -> None: """Create node_dict for one node and then call the function recursively on all children.""" @@ -85,7 +86,7 @@ def parse_node_literal(child_dict, hparent): return graph - def read(self): + def read(self) -> hatchet.graphframe.GraphFrame: with open(self.pyinstrument_json_filename) as pyinstrument_json: self.graph_dict = json.load(pyinstrument_json) diff --git a/hatchet/readers/spotdb_reader.py b/hatchet/readers/spotdb_reader.py index 3913e1d3..bccee0e9 100644 --- a/hatchet/readers/spotdb_reader.py +++ b/hatchet/readers/spotdb_reader.py @@ -3,6 +3,8 @@ # # SPDX-License-Identifier: MIT +from typing import Any, Dict, List + import pandas as pd import hatchet.graphframe @@ -12,7 +14,7 @@ from hatchet.util.timer import Timer -def _find_child_node(node, name): +def _find_child_node(node: Node, name: str) -> Node: """Return child with given name from parent node""" for c in node.children: if c.frame.get("name") == name: @@ -23,7 +25,12 @@ def _find_child_node(node, name): class SpotDatasetReader: """Reads a (single-run) dataset from SpotDB""" - def __init__(self, regionprofile, metadata, attr_info): + def __init__( + self, + regionprofile: Dict[str, Any], + metadata: Dict[str, Any], + attr_info: Dict[str, Any], + ) -> None: """Initialize SpotDataset reader Args: @@ -49,7 +56,7 @@ def __init__(self, regionprofile, metadata, attr_info): 
self.timer = Timer() - def create_graph(self): + def create_graph(self) -> None: """Create the graph. Fills in df_data and metric_columns.""" self.df_data.clear() @@ -81,7 +88,9 @@ def create_graph(self): self.df_data.append(dict({"name": name, "node": node}, **metrics)) - def read(self, default_metric="Total time (inc)"): + def read( + self, default_metric: str = "Total time (inc)" + ) -> hatchet.graphframe.GraphFrame: """Create GraphFrame for the given Spot dataset.""" with self.timer.phase("graph construction"): @@ -116,7 +125,7 @@ def read(self, default_metric="Total time (inc)"): default_metric=default_metric, ) - def _create_node(self, path): + def _create_node(self, path: List[str]) -> Node: parent = self.roots.get(path[0], None) if parent is None: parent = Node(Frame(name=path[0])) @@ -136,7 +145,14 @@ def _create_node(self, path): class SpotDBReader: """Import multiple runs as graph frames from a SpotDB instance""" - def __init__(self, db_key, list_of_ids=None, default_metric="Total time (inc)"): + # TODO decide if we want to import spotdb elsewhere to correctly add a type + # hint to 'db_key' + def __init__( + self, + db_key: Any, + list_of_ids: List = None, + default_metric: str = "Total time (inc)", + ) -> None: """Initialize SpotDBReader Args: @@ -156,7 +172,7 @@ def __init__(self, db_key, list_of_ids=None, default_metric="Total time (inc)"): self.list_of_ids = list_of_ids self.default_metric = default_metric - def read(self): + def read(self) -> List[hatchet.graphframe.GraphFrame]: """Read given runs from SpotDB Returns: diff --git a/hatchet/readers/tau_reader.py b/hatchet/readers/tau_reader.py index aed2af11..4d192a0a 100644 --- a/hatchet/readers/tau_reader.py +++ b/hatchet/readers/tau_reader.py @@ -6,6 +6,7 @@ import re import os import glob +from typing import Any, Dict, List, Tuple import pandas as pd import hatchet.graphframe from hatchet.node import Node @@ -16,7 +17,7 @@ class TAUReader: """Read in a profile generated using TAU.""" - def __init__(self, dirname): + def __init__(self, dirname: str) -> None: self.dirname = dirname self.node_dicts = [] self.callpath_to_node = {} @@ -30,17 +31,17 @@ def __init__(self, dirname): def create_node_dict( self, - node, - columns, - metric_values, - name, - filename, - module, - start_line, - end_line, - rank, - thread, - ): + node: Node, + columns: List[str], + metric_values: Tuple[Any, ...], + name: str, + filename: str, + module: str, + start_line: int, + end_line: int, + rank: int, + thread: int, + ) -> Dict[str, Any]: node_dict = { "node": node, "rank": rank, @@ -55,8 +56,10 @@ def create_node_dict( node_dict[columns[i + 1]] = metric_values[i] return node_dict - def create_graph(self): - def _get_name_file_module(is_parent, node_info, symbol): + def create_graph(self) -> List[Node]: + def _get_name_file_module( + is_parent: bool, node_info: str, symbol: str + ) -> Tuple[str, str, str]: """This function gets the name, file and module information for a node using the corresponding line in the output file. Example line: [UNWIND] [@] [{} {}] @@ -134,7 +137,7 @@ def _get_name_file_module(is_parent, node_info, symbol): module = node_info[2] return [name, file, module] - def _get_line_numbers(node_info): + def _get_line_numbers(node_info: str) -> Tuple[str, str]: start_line, end_line = 0, 0 # There should be [{}] symbols if there is line number information. 
if "[{" in node_info: @@ -157,7 +160,7 @@ def _get_line_numbers(node_info): end_line = line_numbers.split(",")[1] return [start_line, end_line] - def _create_parent(child_node, parent_callpath): + def _create_parent(child_node: Node, parent_callpath: str) -> None: """In TAU output, sometimes we see a node as a parent in the callpath before we see it as a leaf node. In this case, we need to create a hatchet node for the parent. @@ -201,7 +204,7 @@ def _create_parent(child_node, parent_callpath): child_node.add_parent(parent_node) _create_parent(parent_node, grand_parent_callpath) - def _construct_column_list(first_rank_filenames): + def _construct_column_list(first_rank_filenames: List[str]) -> List[str]: """This function constructs columns, exc_metrics, and inc_metrics using all metric files of a rank. It gets the all metric files of a rank as a tuple and only loads the @@ -443,7 +446,7 @@ def _construct_column_list(first_rank_filenames): return list_roots - def read(self): + def read(self) -> hatchet.graphframe.GraphFrame: """Read the TAU profile file to extract the calling context tree.""" # Add all nodes and roots. roots = self.create_graph() diff --git a/hatchet/readers/timemory_reader.py b/hatchet/readers/timemory_reader.py index 4e0a7f70..d3a01c51 100644 --- a/hatchet/readers/timemory_reader.py +++ b/hatchet/readers/timemory_reader.py @@ -7,6 +7,9 @@ import pandas as pd import os import glob +import re +from io import TextIOWrapper +from typing import Any, Dict, List, Tuple, Union from hatchet.graphframe import GraphFrame from ..node import Node from ..graph import Graph @@ -17,7 +20,12 @@ class TimemoryReader: """Read in timemory JSON output""" - def __init__(self, input, select=None, **_kwargs): + def __init__( + self, + input: Union[str, TextIOWrapper, Dict], + select: List[str] = None, + **_kwargs, + ) -> None: """Arguments: input (str or file-stream or dict or None): Valid argument types are: @@ -77,11 +85,13 @@ def _get_select(val): else: raise TypeError("select must be None or list of string") - def create_graph(self): + def create_graph(self) -> GraphFrame: """Create graph and dataframe""" list_roots = [] - def remove_keys(_dict, _keys): + def remove_keys( + _dict: Dict[str, Any], _keys: Union[str, List[str]] + ) -> Dict[str, Any]: """Remove keys from dictionary""" if isinstance(_keys, str): if _keys in _dict: @@ -91,13 +101,13 @@ def remove_keys(_dict, _keys): _dict = remove_keys(_dict, _key) return _dict - def add_metrics(_dict): + def add_metrics(_dict: Dict[str, Any]) -> None: """Add any keys to metric_cols which don't already exist""" for key, itr in _dict.items(): if key not in self.metric_cols: self.metric_cols.append(key) - def process_regex(_data): + def process_regex(_data: re.Match) -> Dict[str, str]: """Process the regex data for func/file/line info""" _tmp = {} if _data is not None and len(_data.groups()) > 0: @@ -110,10 +120,8 @@ def process_regex(_data): pass return _tmp if _tmp else None - def perform_regex(_prefix): + def perform_regex(_prefix: str) -> Dict[str, str]: """Performs a search for standard configurations of function + file + line""" - import re - _tmp = None for _pattern in [ # [func][file] @@ -134,7 +142,7 @@ def perform_regex(_prefix): break return _tmp if _tmp else None - def get_name_line_file(_prefix): + def get_name_line_file(_prefix: str) -> Tuple[Dict[str, str], Dict[str, str]]: """Get the standard set of dictionary entries. 
Also, parses the prefix for func-file-line info which is typically in the form: @@ -164,7 +172,9 @@ def get_name_line_file(_prefix): _keys["name"] = "{}/{}".format(_keys["name"], _pdict["tail"]) return (_keys, _extra) - def format_labels(_labels): + def format_labels( + _labels: Union[str, Dict[str, Any], List[str], Tuple[str, ...]], + ) -> List[str]: """Formats multi dimensional metrics which refer to multiple metrics stored in a 1D list. @@ -186,7 +196,11 @@ def format_labels(_labels): _ret.append(_item.strip().replace(" ", "-").replace("_", "-")) return _ret - def match_labels_and_values(_metric_stats, _metric_label, _metric_type): + def match_labels_and_values( + _metric_stats: Dict[str, Any], + _metric_label: Union[str, List[str]], + _metric_type: str, + ) -> Dict[str, Any]: """Match metric labels with values and add '(inc)' if the metric type is inclusive. @@ -214,7 +228,9 @@ def match_labels_and_values(_metric_stats, _metric_label, _metric_type): _ret["{}.{}{}".format(_key, _metric_label, _metric_type)] = _item return _ret - def collapse_ids(_obj, _expect_scalar=False): + def collapse_ids( + _obj: Union[int, List[int]], _expect_scalar: bool = False + ) -> Union[str, int]: """node/rank/thread id may be int, array of ints, or None. When the entry is a list of integers (which happens when metric values are aggregates of multiple ranks/threads), this function generates a consistent @@ -250,7 +266,13 @@ def collapse_ids(_obj, _expect_scalar=False): return f"{_obj}" if _expect_scalar else int(_obj) return None - def parse_node(_metric_name, _node_data, _hparent, _rank, _parent_callpath): + def parse_node( + _metric_name: str, + _node_data: Dict[str, Any], + _hparent: Node, + _rank: int, + _parent_callpath: Tuple[str], + ) -> None: """Create callpath_to_node_dict for one node and then call the function recursively on all children. """ @@ -374,7 +396,9 @@ def parse_node(_metric_name, _node_data, _hparent, _rank, _parent_callpath): for _child in _node_data["children"]: parse_node(_metric_name, _child, _hnode, _rank, callpath) - def read_graph(_metric_name, ranks_data, _rank): + def read_graph( + _metric_name: str, ranks_data: List[List[Dict[str, Any]]], _rank: int + ) -> bool: """The layout of the graph at this stage is subject to slightly different structures based on whether distributed memory parallelism (DMP) @@ -413,7 +437,11 @@ def read_graph(_metric_name, ranks_data, _rank): return True return False - def read_properties(properties, _metric_name, _metric_data): + def read_properties( + properties: Dict[str, Dict[str, Any]], + _metric_name: str, + _metric_data: Dict[str, Any], + ) -> None: """Read in the properties for a component. This contains information on the type of the component, a description, a unit_value relative to the @@ -538,7 +566,6 @@ def read_properties(properties, _metric_name, _metric_data): if self.multiple_ranks or self.multiple_threads: dataframe = dataframe.unstack() for idx, row in dataframe.iterrows(): - # There is always a valid name for an index. # Take that valid name and assign to other ranks/rows. name = row["name"][row["name"].first_valid_index()] @@ -564,7 +591,7 @@ def read_properties(properties, _metric_name, _metric_data): graph, dataframe, exc_metrics, inc_metrics, self.default_metric ) - def read(self): + def read(self) -> GraphFrame: """Read timemory json.""" # check if the input is a dictionary. 
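
Several defaulted-to-None parameters in this first commit are annotated without Optional (for example, select: List[str] = None in TimemoryReader above). Under current MyPy defaults, where implicit Optional is disabled, such parameters are normally spelled with Optional[...], which is what the second commit in this series does. A minimal, self-contained sketch of the pattern (read_selected is a hypothetical helper, not part of Hatchet):

    from typing import List, Optional

    def read_selected(select: Optional[List[str]] = None) -> List[str]:
        # A None default means "no filter"; normalize it to an empty selection.
        return [] if select is None else list(select)
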
diff --git a/hatchet/util/colormaps.py b/hatchet/util/colormaps.py index b3b23ce8..c1036605 100644 --- a/hatchet/util/colormaps.py +++ b/hatchet/util/colormaps.py @@ -3,6 +3,8 @@ # # SPDX-License-Identifier: MIT +from typing import List + class ColorMaps: # RdYlGn (Default) color map @@ -98,7 +100,7 @@ class ColorMaps: def __init__(self): self.colors = [] - def get_colors(self, colormap, invert_colormap): + def get_colors(self, colormap: str, invert_colormap: bool) -> List[str]: """Returns a list of colors based on the colormap and invert_colormap arguments. """ diff --git a/hatchet/util/deprecated.py b/hatchet/util/deprecated.py index 4d39316d..46fda6b9 100644 --- a/hatchet/util/deprecated.py +++ b/hatchet/util/deprecated.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: MIT import functools +from typing import Any, Dict def deprecated_params(**old_to_new): @@ -18,7 +19,9 @@ def wrapper(*args, **kwargs): return deco -def rename_kwargs(fname, old_to_new, kwargs): +def rename_kwargs( + fname: str, old_to_new: Dict[str, str], kwargs: Dict[str, Any] +) -> None: for old, new in old_to_new.items(): if old in kwargs: if new in kwargs: diff --git a/hatchet/util/dot.py b/hatchet/util/dot.py index 84cd62a8..8811de48 100644 --- a/hatchet/util/dot.py +++ b/hatchet/util/dot.py @@ -3,11 +3,24 @@ # # SPDX-License-Identifier: MIT +from typing import List, Tuple, Union + import matplotlib.cm import matplotlib.colors +import pandas as pd + +from ..node import Node -def trees_to_dot(roots, dataframe, metric, name, rank, thread, threshold): +def trees_to_dot( + roots: List[Node], + dataframe: pd.DataFrame, + metric: Union[str, Tuple[str, ...]], + name: Union[str, Tuple[str, ...]], + rank: int, + thread: int, + threshold: float, +) -> str: """Calls to_dot in turn for each tree in the graph/forest.""" text = ( "strict digraph {\n" @@ -33,14 +46,22 @@ def trees_to_dot(roots, dataframe, metric, name, rank, thread, threshold): return text -def to_dot(hnode, dataframe, metric, name, rank, thread, threshold, visited): +def to_dot( + hnode: Node, + dataframe: pd.DataFrame, + metric: Union[str, Tuple[str, ...]], + name: Union[str, Tuple[str, ...]], + rank: int, + thread: int, + threshold: float, + visited: List[Node], +) -> Tuple[str, str]: """Write to graphviz dot format.""" colormap = matplotlib.cm.Reds min_time = dataframe[metric].min() max_time = dataframe[metric].max() - def add_nodes_and_edges(hnode): - + def add_nodes_and_edges(hnode: Node) -> Tuple[str, str]: # set dataframe index based on if rank is a part of the index if "rank" in dataframe.index.names and "thread" in dataframe.index.names: df_index = (hnode, rank, thread) diff --git a/hatchet/util/executable.py b/hatchet/util/executable.py index c370ce57..49ae04a4 100644 --- a/hatchet/util/executable.py +++ b/hatchet/util/executable.py @@ -6,7 +6,7 @@ import os -def which(executable): +def which(executable: str) -> str: """Finds an `executable` in the user's PATH like command-line which. Args: diff --git a/hatchet/util/profiler.py b/hatchet/util/profiler.py index 22ce80d4..ac8d4984 100644 --- a/hatchet/util/profiler.py +++ b/hatchet/util/profiler.py @@ -7,6 +7,7 @@ import traceback import sys import os +from typing import List from datetime import datetime @@ -17,7 +18,7 @@ import pstats -def print_incomptable_msg(stats_file): +def print_incomptable_msg(stats_file: str) -> None: """ Function which makes the syntax cleaner in Profiler.write_to_file(). """ @@ -36,12 +37,12 @@ class Profiler: Exports a pstats file to be read by the pstats reader. 
""" - def __init__(self): + def __init__(self) -> None: self._prf = cProfile.Profile() self._output = "hatchet-profile" self._active = False - def start(self): + def start(self) -> None: """ Description: Place before the block of code to be profiled. """ @@ -54,7 +55,7 @@ def start(self): self._active = True self._prf.enable() - def stop(self): + def stop(self) -> None: """ Description: Place at the end of the block of code being profiled. """ @@ -63,7 +64,7 @@ def stop(self): self._prf.disable() self.write_to_file() - def reset(self): + def reset(self) -> None: """ Description: Resets the profilier. """ @@ -75,7 +76,7 @@ def reset(self): self._prf = cProfile.Profile() - def __str__(self): + def __str__(self) -> str: """ Description: Writes stats object out as a string. """ @@ -83,7 +84,9 @@ def __str__(self): pstats.Stats(self._prf, stream=s).print_stats() return s.getvalue() - def write_to_file(self, filename="", add_pstats_files=[]): + def write_to_file( + self, filename: str = "", add_pstats_files: List[str] = [] + ) -> None: """ Description: Write the pstats object to a binary file to be read in by an appropriate source. diff --git a/hatchet/util/timer.py b/hatchet/util/timer.py index 03960a38..f40f3043 100644 --- a/hatchet/util/timer.py +++ b/hatchet/util/timer.py @@ -5,19 +5,19 @@ from collections import OrderedDict from contextlib import contextmanager -from datetime import datetime +from datetime import datetime, timedelta from io import StringIO class Timer(object): """Simple phase timer with a context manager.""" - def __init__(self): + def __init__(self) -> None: self._phase = None self._start_time = None self._times = OrderedDict() - def start_phase(self, phase): + def start_phase(self, phase: str) -> timedelta: now = datetime.now() delta = None @@ -29,7 +29,7 @@ def start_phase(self, phase): self._start_time = now return delta - def end_phase(self): + def end_phase(self) -> None: assert self._phase and self._start_time now = datetime.now() @@ -42,7 +42,7 @@ def end_phase(self): self._phase = None self._start_time = None - def __str__(self): + def __str__(self) -> str: out = StringIO() out.write("Times:\n") for phase, delta in self._times.items(): @@ -50,7 +50,7 @@ def __str__(self): return out.getvalue() @contextmanager - def phase(self, name): + def phase(self, name: str): self.start_phase(name) yield self.end_phase() diff --git a/hatchet/writers/dataframe_writer.py b/hatchet/writers/dataframe_writer.py index b23c6172..3adc37a0 100644 --- a/hatchet/writers/dataframe_writer.py +++ b/hatchet/writers/dataframe_writer.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: MIT from hatchet.node import Node +from hatchet.graphframe import GraphFrame +import pandas as pd from abc import abstractmethod @@ -19,7 +21,7 @@ ABC = ABCMeta("ABC", (object,), {"__slots__": ()}) -def _get_node_from_df_iloc(df, ind): +def _get_node_from_df_iloc(df: pd.DataFrame, ind: int) -> Node: node = None if isinstance(df.iloc[ind].name, tuple): node = df.iloc[ind].name[0] @@ -32,7 +34,7 @@ def _get_node_from_df_iloc(df, ind): return node -def _fill_children_and_parents(dump_df): +def _fill_children_and_parents(dump_df: pd.DataFrame) -> pd.DataFrame: dump_df["children"] = [[] for _ in range(len(dump_df))] dump_df["parents"] = [[] for _ in range(len(dump_df))] for i in range(len(dump_df)): @@ -49,14 +51,14 @@ def _fill_children_and_parents(dump_df): class DataframeWriter(ABC): - def __init__(self, filename): + def __init__(self, filename: str) -> None: self.filename = filename @abstractmethod - def 
_write_dataframe_to_file(self, df, **kwargs): + def _write_dataframe_to_file(self, df: pd.DataFrame, **kwargs) -> None: pass - def write(self, gf, **kwargs): + def write(self, gf: GraphFrame, **kwargs) -> None: gf_cpy = gf.deepcopy() dump_df = _fill_children_and_parents(gf_cpy.dataframe) dump_df["exc_metrics"] = None diff --git a/hatchet/writers/hdf5_writer.py b/hatchet/writers/hdf5_writer.py index 81489015..2027a07e 100644 --- a/hatchet/writers/hdf5_writer.py +++ b/hatchet/writers/hdf5_writer.py @@ -6,17 +6,19 @@ import warnings import sys +import pandas as pd + from .dataframe_writer import DataframeWriter class HDF5Writer(DataframeWriter): - def __init__(self, filename): + def __init__(self, filename: str) -> None: if sys.version_info[0] == 2: super(HDF5Writer, self).__init__(filename) else: super().__init__(filename) - def _write_dataframe_to_file(self, df, **kwargs): + def _write_dataframe_to_file(self, df: pd.DataFrame, **kwargs) -> None: if "key" not in kwargs: raise KeyError("Writing to HDF5 requires a user-supplied key") key = kwargs["key"] From 581aeb4edaa7d282fee8a4e81dfd20737b7af8f0 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Wed, 13 Nov 2024 15:13:10 -0500 Subject: [PATCH 2/5] Adds config for MyPy and adds Optional to several arguments --- hatchet/external/console.py | 4 +- hatchet/frame.py | 12 +- hatchet/graph.py | 20 +-- hatchet/graphframe.py | 49 ++++--- hatchet/node.py | 29 ++-- hatchet/query/compat.py | 6 +- hatchet/query/engine.py | 8 +- hatchet/query/string_dialect.py | 154 ++++++++++++++------ hatchet/readers/hpctoolkit_reader_latest.py | 16 +- hatchet/readers/spotdb_reader.py | 4 +- hatchet/readers/timemory_reader.py | 10 +- hatchet/util/executable.py | 3 +- pyproject.toml | 3 + 13 files changed, 203 insertions(+), 115 deletions(-) diff --git a/hatchet/external/console.py b/hatchet/external/console.py index e69177a9..1de5e145 100644 --- a/hatchet/external/console.py +++ b/hatchet/external/console.py @@ -34,7 +34,7 @@ import pandas as pd import numpy as np import warnings -from typing import Any, Dict, List, Tuple, Union +from typing import List, Optional, Tuple, Union from ..util.colormaps import ColorMaps from ..node import Node @@ -47,7 +47,7 @@ def __init__(self, unicode: bool = False, color: bool = False) -> None: def render( self, - roots: Union[List[Node], Tuple[Node, ...]], + roots: Optional[Union[List[Node], Tuple[Node, ...]]], dataframe: pd.DataFrame, **kwargs, ) -> str: diff --git a/hatchet/frame.py b/hatchet/frame.py index f89e8f75..4462640c 100644 --- a/hatchet/frame.py +++ b/hatchet/frame.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: MIT from functools import total_ordering -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union @total_ordering @@ -15,7 +15,7 @@ class Frame: attrs (dict): dictionary of attributes and values """ - def __init__(self, attrs: Dict[str, Any] = None, **kwargs) -> None: + def __init__(self, attrs: Optional[Dict[str, Any]] = None, **kwargs) -> None: """Construct a frame from a dictionary, or from immediate kwargs. 
Arguments: @@ -49,13 +49,13 @@ def __init__(self, attrs: Dict[str, Any] = None, **kwargs) -> None: self._tuple_repr = None - def __eq__(self, other: object) -> bool: + def __eq__(self, other: "Frame") -> bool: return self.tuple_repr == other.tuple_repr - def __lt__(self, other: object) -> bool: + def __lt__(self, other: "Frame") -> bool: return self.tuple_repr < other.tuple_repr - def __gt__(self, other: object) -> bool: + def __gt__(self, other: "Frame") -> bool: return self.tuple_repr > other.tuple_repr def __hash__(self) -> int: @@ -81,7 +81,7 @@ def copy(self) -> "Frame": def __getitem__(self, name: str) -> Any: return self.attrs[name] - def get(self, name: str, default: Any = None): + def get(self, name: str, default: Optional[Any] = None): return self.attrs.get(name, default) def values( diff --git a/hatchet/graph.py b/hatchet/graph.py index 20d76b75..78a1fae5 100644 --- a/hatchet/graph.py +++ b/hatchet/graph.py @@ -5,7 +5,7 @@ from collections import defaultdict from collections.abc import Iterable -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from .node import Node, traversal_order, node_traversal_order @@ -33,8 +33,8 @@ def __init__(self, roots: Union[List[Node], Tuple[Node, ...]]) -> None: def traverse( self, order: str = "pre", - attrs: Union[List[str], Tuple[str, ...], str] = None, - visited: Dict[int, int] = None, + attrs: Optional[Union[List[str], Tuple[str, ...], str]] = None, + visited: Optional[Dict[int, int]] = None, ) -> Iterable[Union[Node, Union[Tuple[Any, ...], Any]]]: """Preorder traversal of all roots of this Graph. @@ -57,8 +57,8 @@ def traverse( def node_order_traverse( self, order: str = "pre", - attrs: Union[List[str], Tuple[str, ...], str] = None, - visited: Dict[int, int] = None, + attrs: Optional[Union[List[str], Tuple[str, ...], str]] = None, + visited: Optional[Dict[int, int]] = None, ) -> Iterable[Union[Node, Union[Tuple[Any, ...], Any]]]: """Preorder traversal of all roots of this Graph, sorting by "node order" column. @@ -176,7 +176,7 @@ def normalize(self) -> Dict[Node, Node]: self.merge_nodes(merges) return merges - def copy(self, old_to_new: Dict[Node, Node] = None) -> "Graph": + def copy(self, old_to_new: Optional[Dict[Node, Node]] = None) -> "Graph": """Create and return a copy of this graph. Arguments: @@ -204,7 +204,9 @@ def copy(self, old_to_new: Dict[Node, Node] = None) -> "Graph": return graph - def union(self, other: "Graph", old_to_new: Dict[Node, Node] = None) -> "Graph": + def union( + self, other: "Graph", old_to_new: Optional[Dict[Node, Node]] = None + ) -> "Graph": """Create the union of self and other and return it as a new Graph. This creates a new graph and does not modify self or other. The @@ -395,7 +397,7 @@ def __len__(self) -> int: """Size of the graph in terms of number of nodes.""" return sum(1 for _ in self.traverse()) - def __eq__(self, other: object) -> bool: + def __eq__(self, other: "Graph") -> bool: """Check if two graphs have the same structure by comparing frame at each node. 
""" @@ -427,7 +429,7 @@ def __eq__(self, other: object) -> bool: return True - def __ne__(self, other: object) -> bool: + def __ne__(self, other: "Graph") -> bool: return not (self == other) @staticmethod diff --git a/hatchet/graphframe.py b/hatchet/graphframe.py index c2c92019..f350d465 100644 --- a/hatchet/graphframe.py +++ b/hatchet/graphframe.py @@ -9,7 +9,7 @@ import traceback from collections import defaultdict from collections.abc import Callable -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from io import TextIOWrapper import multiprocess as mp @@ -75,8 +75,8 @@ def __init__( self, graph: Graph, dataframe: pd.DataFrame, - exc_metrics: List[str] = None, - inc_metrics: List[str] = None, + exc_metrics: Optional[List[str]] = None, + inc_metrics: Optional[List[str]] = None, default_metric: str = "time", metadata: Dict[str, Any] = {}, ) -> None: @@ -131,9 +131,9 @@ def from_hpctoolkit(dirname: str) -> "GraphFrame": @staticmethod def from_hpctoolkit_latest( dirname: str, - max_depth: int = None, - min_percentage_of_application_time: int = None, - min_percentage_of_parent_time: int = None, + max_depth: Optional[int] = None, + min_percentage_of_application_time: Optional[int] = None, + min_percentage_of_parent_time: Optional[int] = None, ) -> "GraphFrame": """ Read an HPCToolkit database directory into a new GraphFrame @@ -159,7 +159,7 @@ def from_hpctoolkit_latest( @staticmethod def from_caliper( - filename_or_stream: Union[str, TextIOWrapper], query: str = None + filename_or_stream: Union[str, TextIOWrapper], query: Optional[str] = None ) -> "GraphFrame": """Read in a Caliper .cali or .json file. @@ -220,7 +220,7 @@ def from_timeseries( ).read_timeseries(level=level) @staticmethod - def from_spotdb(db_key: Any, list_of_ids: List = None) -> "GraphFrame": + def from_spotdb(db_key: Any, list_of_ids: Optional[List] = None) -> "GraphFrame": """Read multiple graph frames from a SpotDB instance Args: @@ -277,8 +277,8 @@ def from_tau(dirname: str) -> "GraphFrame": @staticmethod def from_timemory( - input: Union[str, TextIOWrapper, Dict[str, Any]] = None, - select: List[str] = None, + input: Optional[Union[str, TextIOWrapper, Dict[str, Any]]] = None, + select: Optional[List[str]] = None, **_kwargs, ) -> "GraphFrame": """Read in timemory data. @@ -726,7 +726,7 @@ def _init_sum_columns( def subtree_sum( self, columns: List[str], - out_columns: List[str] = None, + out_columns: Optional[List[str]] = None, function: Callable = lambda x: x.sum(min_count=1), ): """Compute sum of elements in subtrees. Valid only for trees. @@ -789,7 +789,7 @@ def subtree_sum( def subgraph_sum( self, columns: List[str], - out_columns: List[str] = None, + out_columns: Optional[List[str]] = None, function: Callable = lambda x: x.sum(min_count=1), ): """Compute sum of elements in subgraphs. @@ -852,7 +852,9 @@ def subgraph_sum( function(self.dataframe.loc[(subgraph_nodes), columns]) ) - def generate_exclusive_columns(self, inc_metrics: Union[str, List[str]] = None): + def generate_exclusive_columns( + self, inc_metrics: Optional[Union[str, List[str]]] = None + ): """Generates exclusive metrics from available inclusive metrics. Arguments: inc_metrics (str, list, optional): Instead of generating the exclusive time for each inclusive metric, it is possible to specify those metrics manually. Defaults to None. 
@@ -1025,8 +1027,8 @@ def unify(self, other: "GraphFrame"): ) def tree( self, - metric_column: str = None, - annotation_column: str = None, + metric_column: Optional[str] = None, + annotation_column: Optional[str] = None, precision: int = 3, name_column: str = "name", expand_name: bool = False, @@ -1037,10 +1039,10 @@ def tree( highlight_name: bool = False, colormap: str = "RdYlGn", invert_colormap: bool = False, - colormap_annotations: Union[str, List, Dict] = None, + colormap_annotations: Optional[Union[str, List, Dict]] = None, render_header: bool = True, - min_value: int = None, - max_value: int = None, + min_value: Optional[int] = None, + max_value: Optional[int] = None, ) -> str: """Visualize the Hatchet graphframe as a tree @@ -1109,7 +1111,7 @@ def tree( def to_dot( self, - metric: str = None, + metric: Optional[str] = None, name: str = "name", rank: int = 0, thread: int = 0, @@ -1124,7 +1126,14 @@ def to_dot( self.graph.roots, self.dataframe, metric, name, rank, thread, threshold ) - def to_flamegraph(self, metric=None, name="name", rank=0, thread=0, threshold=0.0): + def to_flamegraph( + self, + metric: Optional[Union[str, Tuple[str, ...]]] = None, + name: str = "name", + rank: int = 0, + thread: int = 0, + threshold: float = 0.0, + ) -> str: """Write the graph in the folded stack output required by FlameGraph http://www.brendangregg.com/flamegraphs.html """ diff --git a/hatchet/node.py b/hatchet/node.py index 022ad532..101dca45 100644 --- a/hatchet/node.py +++ b/hatchet/node.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: MIT from functools import total_ordering -from typing import Any, Dict, List, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from collections.abc import Iterable from .frame import Frame @@ -26,7 +26,11 @@ class Node: """A node in the graph. The node only stores its frame.""" def __init__( - self, frame_obj: Frame, parent: "Node" = None, hnid: int = -1, depth: int = -1 + self, + frame_obj: Frame, + parent: Optional["Node"] = None, + hnid: int = -1, + depth: int = -1, ) -> None: self.frame = frame_obj self._depth = depth @@ -62,7 +66,7 @@ def paths(self) -> List[Tuple["Node", ...]]: paths.extend([path + node_value for path in parent_paths]) return paths - def path(self, attrs: Dict[str, Any] = None) -> Tuple["Node", ...]: + def path(self, attrs: Optional[Dict[str, Any]] = None) -> Tuple["Node", ...]: """Path to this node from root. Raises if there are multiple paths. This is useful for trees (where each node only has one path), as @@ -76,7 +80,10 @@ def path(self, attrs: Dict[str, Any] = None) -> Tuple["Node", ...]: return paths[0] def dag_equal( - self, other: "Node", vs: Set[int] = None, vo: Set[int] = None + self, + other: "Node", + vs: Optional[Set[int]] = None, + vo: Optional[Set[int]] = None, ) -> bool: """Check if DAG rooted at self has the same structure as that rooted at other. @@ -122,8 +129,8 @@ def dag_equal( def traverse( self, order: str = "pre", - attrs: Union[List[str], Tuple[str, ...], str] = None, - visited: Dict[int, int] = None, + attrs: Optional[Union[List[str], Tuple[str, ...], str]] = None, + visited: Optional[Dict[int, int]] = None, ) -> Iterable[Union["Node", Union[Tuple[Any, ...], Any]]]: """Traverse the tree depth-first and yield each node. 
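
The quoted annotations in the hunks above (for example `Optional["Node"]` and the `-> "GraphFrame"` return types) are forward references: the class name is spelled as a string because it is not yet bound while the class or module that uses it is still being defined. A small, self-contained sketch of the idea with a toy `TreeNode` class (not Hatchet's `Node`):

    from typing import List, Optional

    class TreeNode:
        # Toy stand-in for a self-referential class such as hatchet.node.Node.
        # Quoting "TreeNode" defers evaluation of the name; type checkers
        # resolve it once the class actually exists.
        def __init__(self, name: str, parent: Optional["TreeNode"] = None) -> None:
            self.name = name
            self.parent = parent
            self.children: List["TreeNode"] = []

        def add_child(self, child: "TreeNode") -> "TreeNode":
            child.parent = self
            self.children.append(child)
            return child

    root = TreeNode("main")
    leaf = root.add_child(TreeNode("solve"))
    print(leaf.parent is root)  # True
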
@@ -163,8 +170,8 @@ def value(node): def node_order_traverse( self, order: str = "pre", - attrs: Union[List[str], Tuple[str, ...], str] = None, - visited: Dict[int, int] = None, + attrs: Optional[Union[List[str], Tuple[str, ...], str]] = None, + visited: Optional[Dict[int, int]] = None, ) -> Iterable[Union["Node", Union[Tuple[Any, ...], Any]]]: """Traverse the tree depth-first and yield each node, sorting children by "node order". @@ -206,13 +213,13 @@ def value(node): def __hash__(self) -> int: return self._hatchet_nid - def __eq__(self, other: object) -> bool: + def __eq__(self, other: "Node") -> bool: return self._hatchet_nid == other._hatchet_nid - def __lt__(self, other: object) -> bool: + def __lt__(self, other: "Node") -> bool: return self._hatchet_nid < other._hatchet_nid - def __gt__(self, other: object) -> bool: + def __gt__(self, other: "Node") -> bool: return self._hatchet_nid > other._hatchet_nid def __str__(self) -> str: diff --git a/hatchet/query/compat.py b/hatchet/query/compat.py index 2e71a682..7197ffc6 100644 --- a/hatchet/query/compat.py +++ b/hatchet/query/compat.py @@ -14,7 +14,7 @@ import sys import warnings from collections.abc import Callable -from typing import List, Union +from typing import List, Optional, Union from ..graphframe import GraphFrame from ..node import Node @@ -131,7 +131,7 @@ def apply(self, gf: GraphFrame) -> List[Node]: true_query = self._get_new_query() return COMPATABILITY_ENGINE.apply(true_query, gf.graph, gf.dataframe) - def _get_new_query(self) -> CompoundQuery: + def _get_new_query(self) -> Union[Query, CompoundQuery]: """Gets all the underlying 'new-style' queries in this object. Returns: @@ -279,7 +279,7 @@ def _convert_to_new_query( class QueryMatcher(AbstractQuery): """Processes and applies base syntax queries and Object-based queries to GraphFrames.""" - def __init__(self, query: Union[List, Query] = None) -> None: + def __init__(self, query: Optional[Union[List, Query]] = None) -> None: """Create a new QueryMatcher object. Arguments: diff --git a/hatchet/query/engine.py b/hatchet/query/engine.py index ecd7acd7..dfc5b383 100644 --- a/hatchet/query/engine.py +++ b/hatchet/query/engine.py @@ -5,7 +5,7 @@ from itertools import groupby import pandas as pd -from typing import List, Set, Union +from typing import List, Optional, Set, Union from .errors import InvalidQueryFilter from ..node import Node, traversal_order @@ -88,7 +88,7 @@ def _cache_node(self, node: Node, query: Query, dframe: pd.DataFrame) -> None: def _match_0_or_more( self, query: Query, dframe: pd.DataFrame, node: Node, wcard_idx: int - ) -> List[List[Node]]: + ) -> Optional[List[List[Node]]]: """Process a "*" predicate in the query on a subgraph. Arguments: @@ -136,7 +136,7 @@ def _match_0_or_more( def _match_1( self, query: Query, dframe: pd.DataFrame, node: Node, idx: int - ) -> List[List[Node]]: + ) -> Optional[List[List[Node]]]: """Process a "." predicate in the query on a subgraph. Arguments: @@ -166,7 +166,7 @@ def _match_1( def _match_pattern( self, query: Query, dframe: pd.DataFrame, pattern_root: Node, match_idx: int - ) -> List[List[Node]]: + ) -> Optional[List[List[Node]]]: """Try to match the query pattern starting at the provided root node. 
Arguments: diff --git a/hatchet/query/string_dialect.py b/hatchet/query/string_dialect.py index 6f83c8b1..d013e93d 100644 --- a/hatchet/query/string_dialect.py +++ b/hatchet/query/string_dialect.py @@ -7,7 +7,7 @@ import re import sys from collections.abc import Callable -from typing import Any, Tuple, Union +from typing import Any, Optional, Tuple, Union import pandas as pd # noqa: F401 from pandas.api.types import is_numeric_dtype, is_string_dtype # noqa: F401 import numpy as np # noqa: F401 @@ -231,7 +231,9 @@ def _is_binary_cond(self, obj: Any) -> bool: return True return False - def _parse_binary_cond(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_binary_cond( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: """Top level function for parsing binary predicates.""" if cname(obj) == "AndCond": return self._parse_and_cond(obj) @@ -239,25 +241,31 @@ def _parse_binary_cond(self, obj: Any) -> Tuple[str, str, str, str]: return self._parse_or_cond(obj) raise RuntimeError("Bad Binary Condition") - def _parse_or_cond(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_or_cond(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: """Top level function for parsing predicates combined with logical OR.""" converted_subcond = self._parse_unary_cond(obj.subcond) converted_subcond[0] = "or" return converted_subcond - def _parse_and_cond(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_and_cond( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: """Top level function for parsing predicates combined with logical AND.""" converted_subcond = self._parse_unary_cond(obj.subcond) converted_subcond[0] = "and" return converted_subcond - def _parse_unary_cond(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_unary_cond( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: """Top level function for parsing unary predicates.""" if cname(obj) == "NotCond": return self._parse_not_cond(obj) return self._parse_single_cond(obj) - def _parse_not_cond(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_not_cond( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: """Parse predicates containing the logical NOT operator.""" converted_subcond = self._parse_single_cond(obj.subcond) converted_subcond[2] = "not {}".format(converted_subcond[2]) @@ -265,14 +273,16 @@ def _parse_not_cond(self, obj: Any) -> Tuple[str, str, str, str]: def _run_method_based_on_multi_idx_mode( self, method_name: str, obj: Any - ) -> Tuple[str, str, str, str]: + ) -> Tuple[Optional[str], str, str, Optional[str]]: real_method_name = method_name if self.multi_index_mode != "off": real_method_name = method_name + "_multi_idx" method = eval("StringQuery.{}".format(real_method_name)) return method(self, obj) - def _parse_single_cond(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_single_cond( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: """Top level function for parsing individual numeric or string predicates.""" if self._is_str_cond(obj): return self._parse_str(obj) @@ -288,7 +298,7 @@ def _parse_single_cond(self, obj: Any) -> Tuple[str, str, str, str]: return self._run_method_based_on_multi_idx_mode("_parse_not_leaf", obj) raise RuntimeError("Bad Single Condition") - def _parse_none(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_none(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: """Parses 'property IS NONE'.""" if len(obj.prop.ids) == 1 and 
obj.prop.ids[0] == "depth": return [ @@ -320,7 +330,9 @@ def _add_aggregation_call_to_multi_idx_predicate(self, predicate: str) -> str: return predicate + ".any()" return predicate + ".all()" - def _parse_none_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_none_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -348,7 +360,9 @@ def _parse_none_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: None, ] - def _parse_not_none(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_not_none( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: """Parses 'property IS NOT NONE'.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -375,7 +389,9 @@ def _parse_not_none(self, obj: Any) -> Tuple[str, str, str, str]: None, ] - def _parse_not_none_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_not_none_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -403,7 +419,7 @@ def _parse_not_none_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: None, ] - def _parse_leaf(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_leaf(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: """Parses 'node IS LEAF'.""" return [ None, @@ -412,7 +428,9 @@ def _parse_leaf(self, obj: Any) -> Tuple[str, str, str, str]: None, ] - def _parse_leaf_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_leaf_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: return [ None, obj.name, @@ -420,7 +438,9 @@ def _parse_leaf_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: None, ] - def _parse_not_leaf(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_not_leaf( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: """Parses 'node IS NOT LEAF'.""" return [ None, @@ -429,7 +449,9 @@ def _parse_not_leaf(self, obj: Any) -> Tuple[str, str, str, str]: None, ] - def _parse_not_leaf_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_not_leaf_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: return [ None, obj.name, @@ -465,7 +487,7 @@ def _is_num_cond(self, obj: Any) -> bool: return True return False - def _parse_str(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_str(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: """Function that redirects processing of string predicates to the correct function. 
""" @@ -483,7 +505,7 @@ def _parse_str(self, obj: Any) -> Tuple[str, str, str, str]: return self._run_method_based_on_multi_idx_mode("_parse_str_match", obj) raise RuntimeError("Bad String Op Class") - def _parse_str_eq(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_str_eq(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: """Processes string equivalence predicates.""" return [ None, @@ -503,7 +525,9 @@ def _parse_str_eq(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_str_eq_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_str_eq_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: return [ None, obj.name, @@ -524,7 +548,9 @@ def _parse_str_eq_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_str_starts_with(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_str_starts_with( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: """Processes string 'startswith' predicates.""" return [ None, @@ -544,7 +570,9 @@ def _parse_str_starts_with(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_str_starts_with_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_str_starts_with_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: return [ None, obj.name, @@ -565,7 +593,9 @@ def _parse_str_starts_with_multi_idx(self, obj: Any) -> Tuple[str, str, str, str ), ] - def _parse_str_ends_with(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_str_ends_with( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: """Processes string 'endswith' predicates.""" return [ None, @@ -585,7 +615,9 @@ def _parse_str_ends_with(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_str_ends_with_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_str_ends_with_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: return [ None, obj.name, @@ -606,7 +638,9 @@ def _parse_str_ends_with_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_str_contains(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_str_contains( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: """Processes string 'contains' predicates.""" return [ None, @@ -626,7 +660,9 @@ def _parse_str_contains(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_str_contains_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_str_contains_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: return [ None, obj.name, @@ -647,7 +683,9 @@ def _parse_str_contains_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_str_match(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_str_match( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: """Processes string regex match predicates.""" return [ None, @@ -667,7 +705,9 @@ def _parse_str_match(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_str_match_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_str_match_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: return [ None, obj.name, @@ -688,7 +728,7 @@ def _parse_str_match_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: """Function 
that redirects processing of numeric predicates to the correct function. """ @@ -712,7 +752,7 @@ def _parse_num(self, obj: Any) -> Tuple[str, str, str, str]: return self._run_method_based_on_multi_idx_mode("_parse_num_not_inf", obj) raise RuntimeError("Bad Number Op Class") - def _parse_num_eq(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_eq(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: """Processes numeric equivalence predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val == -1: @@ -785,7 +825,9 @@ def _parse_num_eq(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_eq_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_eq_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val == -1: return [ @@ -861,7 +903,7 @@ def _parse_num_eq_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_lt(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_lt(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: """Processes numeric less-than predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: @@ -927,7 +969,9 @@ def _parse_num_lt(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_lt_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_lt_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -996,7 +1040,7 @@ def _parse_num_lt_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_gt(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_gt(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: """Processes numeric greater-than predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: @@ -1062,7 +1106,9 @@ def _parse_num_gt(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_gt_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_gt_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -1131,7 +1177,7 @@ def _parse_num_gt_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_lte(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_lte(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: """Processes numeric less-than-or-equal-to predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: @@ -1197,7 +1243,9 @@ def _parse_num_lte(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_lte_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_lte_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -1266,7 +1314,7 @@ def _parse_num_lte_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_gte(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_gte(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: """Processes numeric greater-than-or-equal-to predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: @@ -1332,7 +1380,9 @@ def _parse_num_gte(self, obj: 
Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_gte_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_gte_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -1401,7 +1451,7 @@ def _parse_num_gte_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_nan(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_nan(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: """Processes predicates that check for NaN.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -1432,7 +1482,9 @@ def _parse_num_nan(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_nan_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_nan_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -1464,7 +1516,9 @@ def _parse_num_nan_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_not_nan(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_not_nan( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: """Processes predicates that check for NaN.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -1495,7 +1549,9 @@ def _parse_num_not_nan(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_not_nan_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_not_nan_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -1527,7 +1583,7 @@ def _parse_num_not_nan_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_inf(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_inf(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: """Processes predicates that check for Infinity.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -1558,7 +1614,9 @@ def _parse_num_inf(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_inf_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_inf_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -1590,7 +1648,9 @@ def _parse_num_inf_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_not_inf(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_not_inf( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: """Processes predicates that check for not-Infinity.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -1621,7 +1681,9 @@ def _parse_num_not_inf(self, obj: Any) -> Tuple[str, str, str, str]: ), ] - def _parse_num_not_inf_multi_idx(self, obj: Any) -> Tuple[str, str, str, str]: + def _parse_num_not_inf_multi_idx( + self, obj: Any + ) -> Tuple[Optional[str], str, str, Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, diff --git a/hatchet/readers/hpctoolkit_reader_latest.py b/hatchet/readers/hpctoolkit_reader_latest.py index 3181fe84..97c1cd46 100644 --- a/hatchet/readers/hpctoolkit_reader_latest.py +++ b/hatchet/readers/hpctoolkit_reader_latest.py @@ -6,7 +6,7 @@ import os import re import struct -from typing import Dict, 
Tuple, Union +from typing import Dict, Optional, Tuple, Union import pandas as pd @@ -17,7 +17,11 @@ def safe_unpack( - format: str, data: bytes, offset: int, index: int = None, index_length: int = None + format: str, + data: bytes, + offset: int, + index: Optional[int] = None, + index_length: Optional[int] = None, ) -> Tuple: length = struct.calcsize(format) if index: @@ -52,9 +56,9 @@ class HPCToolkitReaderLatest: def __init__( self, dir_path: str, - max_depth: int = None, - min_application_percentage_time: int = None, - min_parent_percentage_time: int = None, + max_depth: Optional[int] = None, + min_application_percentage_time: Optional[int] = None, + min_parent_percentage_time: Optional[int] = None, ) -> None: self._dir_path = dir_path self._max_depth = max_depth @@ -216,7 +220,7 @@ def _parse_function( return self._functions[pFunction] def _store_cct_node( - self, ctxId: int, frame: Dict, parent: Node = None, depth: int = 0 + self, ctxId: int, frame: Dict, parent: Optional[Node] = None, depth: int = 0 ) -> Node: node = Node(Frame(frame), parent=parent, hnid=ctxId, depth=depth) if parent is None: diff --git a/hatchet/readers/spotdb_reader.py b/hatchet/readers/spotdb_reader.py index bccee0e9..5352019e 100644 --- a/hatchet/readers/spotdb_reader.py +++ b/hatchet/readers/spotdb_reader.py @@ -3,7 +3,7 @@ # # SPDX-License-Identifier: MIT -from typing import Any, Dict, List +from typing import Any, Dict, Optional, List import pandas as pd @@ -150,7 +150,7 @@ class SpotDBReader: def __init__( self, db_key: Any, - list_of_ids: List = None, + list_of_ids: Optional[List] = None, default_metric: str = "Total time (inc)", ) -> None: """Initialize SpotDBReader diff --git a/hatchet/readers/timemory_reader.py b/hatchet/readers/timemory_reader.py index d3a01c51..e4ff0d91 100644 --- a/hatchet/readers/timemory_reader.py +++ b/hatchet/readers/timemory_reader.py @@ -9,7 +9,7 @@ import glob import re from io import TextIOWrapper -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from hatchet.graphframe import GraphFrame from ..node import Node from ..graph import Graph @@ -23,7 +23,7 @@ class TimemoryReader: def __init__( self, input: Union[str, TextIOWrapper, Dict], - select: List[str] = None, + select: Optional[List[str]] = None, **_kwargs, ) -> None: """Arguments: @@ -107,7 +107,7 @@ def add_metrics(_dict: Dict[str, Any]) -> None: if key not in self.metric_cols: self.metric_cols.append(key) - def process_regex(_data: re.Match) -> Dict[str, str]: + def process_regex(_data: re.Match) -> Optional[Dict[str, str]]: """Process the regex data for func/file/line info""" _tmp = {} if _data is not None and len(_data.groups()) > 0: @@ -120,7 +120,7 @@ def process_regex(_data: re.Match) -> Dict[str, str]: pass return _tmp if _tmp else None - def perform_regex(_prefix: str) -> Dict[str, str]: + def perform_regex(_prefix: str) -> Optional[Dict[str, str]]: """Performs a search for standard configurations of function + file + line""" _tmp = None for _pattern in [ @@ -230,7 +230,7 @@ def match_labels_and_values( def collapse_ids( _obj: Union[int, List[int]], _expect_scalar: bool = False - ) -> Union[str, int]: + ) -> Optional[Union[str, int]]: """node/rank/thread id may be int, array of ints, or None. 
When the entry is a list of integers (which happens when metric values are aggregates of multiple ranks/threads), this function generates a consistent diff --git a/hatchet/util/executable.py b/hatchet/util/executable.py index 49ae04a4..2e6b42d3 100644 --- a/hatchet/util/executable.py +++ b/hatchet/util/executable.py @@ -4,9 +4,10 @@ # SPDX-License-Identifier: MIT import os +from typing import Optional -def which(executable: str) -> str: +def which(executable: str) -> Optional[str]: """Finds an `executable` in the user's PATH like command-line which. Args: diff --git a/pyproject.toml b/pyproject.toml index 1a2dfd6b..a1261f27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,9 @@ authors = [ ] license = "MIT" +[tool.mypy] +exclude = "hatchet/tests" + [tool.ruff] line-length = 88 target-version = 'py39' From db076f6cc4a966652af3f10aec6d06adf4464a87 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Wed, 13 Nov 2024 22:28:58 -0500 Subject: [PATCH 3/5] Fixes typing issues and bugs identified by mypy --- .gitignore | 165 +++++++++- hatchet/external/__init__.py | 33 +- hatchet/external/console.py | 26 +- hatchet/frame.py | 8 +- hatchet/graph.py | 29 +- hatchet/graphframe.py | 80 +++-- hatchet/node.py | 17 +- hatchet/query/__init__.py | 37 +-- hatchet/query/compat.py | 37 +-- hatchet/query/compound.py | 177 ++++++++++- hatchet/query/engine.py | 33 +- hatchet/query/object_dialect.py | 20 +- hatchet/query/query.py | 16 +- hatchet/query/string_dialect.py | 324 +++----------------- hatchet/readers/caliper_native_reader.py | 59 ++-- hatchet/readers/caliper_reader.py | 61 ++-- hatchet/readers/dataframe_reader.py | 15 +- hatchet/readers/gprof_dot_reader.py | 6 +- hatchet/readers/hpctoolkit_reader.py | 27 +- hatchet/readers/hpctoolkit_reader_latest.py | 31 +- hatchet/readers/literal_reader.py | 7 +- hatchet/readers/pyinstrument_reader.py | 8 +- hatchet/readers/spotdb_reader.py | 10 +- hatchet/readers/tau_reader.py | 81 ++--- hatchet/readers/timemory_reader.py | 36 ++- hatchet/util/colormaps.py | 2 +- hatchet/util/dot.py | 10 +- hatchet/util/executable.py | 4 +- hatchet/util/profiler.py | 5 +- hatchet/util/timer.py | 7 +- hatchet/writers/dataframe_writer.py | 13 +- pyproject.toml | 9 +- 32 files changed, 767 insertions(+), 626 deletions(-) diff --git a/.gitignore b/.gitignore index 7473c86b..cf4fafd1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1,5 @@ *.pyc -.cache -.pytest_cache -.ipynb_checkpoints -build -docs/_build hatchet/cython_modules/libs/graphframe_modules.*.so hatchet/cython_modules/libs/reader_modules.*.so hatchet/cython_modules/*.c @@ -12,5 +7,163 @@ hatchet/vis/*node_modules* hatchet/vis/static/*_bundle* *package-lock.json +############################################### +# Everything from here on comes from the GitHub +# gitignore template for Python projects +############################################### + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ dist/ -llnl_hatchet.egg-info/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ diff --git a/hatchet/external/__init__.py b/hatchet/external/__init__.py index 166a47b7..f06f71b2 100644 --- a/hatchet/external/__init__.py +++ b/hatchet/external/__init__.py @@ -3,6 +3,8 @@ # # SPDX-License-Identifier: MIT +from typing import TYPE_CHECKING + class VersionError(Exception): """ @@ -13,23 +15,24 @@ class VersionError(Exception): pass -try: - import IPython +if not TYPE_CHECKING: + try: + import IPython - # Testing IPython version - if int(IPython.__version__.split(".")[0]) > 7: - raise VersionError() + # Testing IPython version + if int(IPython.__version__.split(".")[0]) > 7: + raise VersionError() - from .roundtrip.roundtrip.manager import Roundtrip + from .roundtrip.roundtrip.manager import Roundtrip - # Refrencing Roundtrip here to resolve scope issues with import - Roundtrip + # Refrencing Roundtrip here to resolve scope issues with import + Roundtrip -except ImportError: - pass + except ImportError: + pass -except VersionError: - if IPython.get_ipython() is not None: - print( - "Warning: Roundtrip module could not be loaded. Requires jupyter notebook version <= 7.x." 
- ) + except VersionError: + if IPython.get_ipython() is not None: + print( + "Warning: Roundtrip module could not be loaded. Requires jupyter notebook version <= 7.x." + ) diff --git a/hatchet/external/console.py b/hatchet/external/console.py index 1de5e145..6b7de115 100644 --- a/hatchet/external/console.py +++ b/hatchet/external/console.py @@ -34,7 +34,7 @@ import pandas as pd import numpy as np import warnings -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from ..util.colormaps import ColorMaps from ..node import Node @@ -43,14 +43,19 @@ class ConsoleRenderer: def __init__(self, unicode: bool = False, color: bool = False) -> None: self.unicode = unicode self.color = color - self.visited = [] + self.visited: List[Node] = [] + self.colors_annotations_mapping: Optional[Union[List, Dict[str, Any]]] = None + self.colors: Optional[ + Union["ConsoleRenderer.colors_enabled", "ConsoleRenderer.colors_disabled"] + ] = None + self.temporal_symbols: Dict[str, str] = {} def render( self, roots: Optional[Union[List[Node], Tuple[Node, ...]]], dataframe: pd.DataFrame, **kwargs, - ) -> str: + ) -> Union[str, bytes]: self.render_header = kwargs["render_header"] if self.render_header: @@ -79,7 +84,7 @@ def render( self.max_value = kwargs["max_value"] if self.color: - self.colors = self.colors_enabled + self.colors = self.colors_enabled() # set the colormap based on user input self.colors.colormap = ColorMaps().get_colors( self.colormap, self.invert_colormap @@ -100,7 +105,7 @@ def render( elif isinstance(self.colormap_annotations, dict): self.colors_annotations_mapping = self.colormap_annotations else: - self.colors = self.colors_disabled + self.colors = self.colors_disabled() if isinstance(self.metric_columns, (str, tuple)): self.primary_metric = self.metric_columns @@ -265,14 +270,13 @@ def render_frame( if node_depth < self.depth: # set dataframe index based on whether rank and thread are part of # the MultiIndex + df_index: Union[Tuple[Node, int, int], Tuple[Node, int], Node] = node if "rank" in dataframe.index.names and "thread" in dataframe.index.names: df_index = (node, self.rank, self.thread) elif "rank" in dataframe.index.names: df_index = (node, self.rank) elif "thread" in dataframe.index.names: df_index = (node, self.thread) - else: - df_index = node node_metric = dataframe.loc[df_index, self.primary_metric] @@ -326,10 +330,12 @@ def render_frame( # no pattern column elif self.colormap_annotations: if isinstance(self.colormap_annotations, dict): + assert isinstance(self.colors_annotations_mapping, Dict) color_annotation = self.colors_annotations_mapping[ annotation_content ] else: + assert isinstance(self.colors_annotations_mapping, List) color_annotation = self.colors_annotations.colormap[ self.colors_annotations_mapping.index(annotation_content) % len(self.colors_annotations.colormap) @@ -441,7 +447,7 @@ def _ansi_color_for_name(self, node_name: str) -> str: return self.colors.bg_white_255 + self.colors.dark_gray_255 class colors_enabled: - colormap = [] + colormap: List[str] = [] blue = "\033[34m" cyan = "\033[36m" @@ -456,9 +462,7 @@ class colors_enabled: end = "\033[0m" class colors_disabled: - colormap = ["", "", "", "", "", "", ""] + colormap: List[str] = ["", "", "", "", "", "", ""] def __getattr__(self, key: str) -> str: return "" - - colors_disabled = colors_disabled() diff --git a/hatchet/frame.py b/hatchet/frame.py index 4462640c..affd0708 100644 --- a/hatchet/frame.py +++ b/hatchet/frame.py @@ -47,9 +47,11 @@ def 
__init__(self, attrs: Optional[Dict[str, Any]] = None, **kwargs) -> None: if "type" not in self.attrs: self.attrs["type"] = "None" - self._tuple_repr = None + self._tuple_repr: Optional[Tuple[Tuple[str, Any], ...]] = None - def __eq__(self, other: "Frame") -> bool: + def __eq__(self, other: object) -> bool: + if not isinstance(other, Frame): + return NotImplemented return self.tuple_repr == other.tuple_repr def __lt__(self, other: "Frame") -> bool: @@ -69,7 +71,7 @@ def __repr__(self) -> str: return "Frame(%s)" % self @property - def tuple_repr(self) -> Tuple[Tuple[str, Any], ...]: + def tuple_repr(self) -> Optional[Tuple[Tuple[str, Any], ...]]: """Make a tuple of attributes and values based on reader.""" if not self._tuple_repr: self._tuple_repr = tuple(sorted((k, v) for k, v in self.attrs.items())) diff --git a/hatchet/graph.py b/hatchet/graph.py index 78a1fae5..91524d29 100644 --- a/hatchet/graph.py +++ b/hatchet/graph.py @@ -5,7 +5,7 @@ from collections import defaultdict from collections.abc import Iterable -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from .node import Node, traversal_order, node_traversal_order @@ -86,7 +86,7 @@ def is_tree(self) -> bool: if len(self.roots) > 1: return False - visited = {} + visited: Dict[int, int] = {} list(self.traverse(visited=visited)) return all(v == 1 for v in visited.values()) @@ -101,7 +101,7 @@ def find_merges(self) -> Dict[Node, Node]: (dict): dictionary from nodes to their merge targets """ - merges = {} # old_node -> merged_node + merges: Dict[Node, Node] = {} # old_node -> merged_node inverted_merges = defaultdict( lambda: [] ) # merged_node -> list of corresponding old_nodes @@ -125,6 +125,7 @@ def _find_child_merges(node_list): _find_child_merges(self.roots) for node in self.traverse(): + assert isinstance(node, Node) if node in processed: continue nodes = None @@ -189,6 +190,7 @@ def copy(self, old_to_new: Optional[Dict[Node, Node]] = None) -> "Graph": # first pass creates new nodes for node in self.traverse(): + assert isinstance(node, Node) old_to_new[node] = node.copy() # second pass hooks up parents and children @@ -205,7 +207,7 @@ def copy(self, old_to_new: Optional[Dict[Node, Node]] = None) -> "Graph": return graph def union( - self, other: "Graph", old_to_new: Optional[Dict[Node, Node]] = None + self, other: "Graph", old_to_new: Optional[Dict[int, Node]] = None ) -> "Graph": """Create the union of self and other and return it as a new Graph. 
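
Because `traverse` is typed to yield a union of nodes and attribute values, the hunks above narrow each yielded item with `assert isinstance(node, Node)` before touching Node-only attributes; the assert documents the invariant and checks it at runtime, while `typing.cast`, used elsewhere in this series, narrows with no runtime cost but also no check. A minimal sketch of the narrowing pattern with a hypothetical `emit` generator (not Hatchet code):

    from typing import Iterator, Tuple, Union

    def emit(values_only: bool) -> Iterator[Union[int, Tuple[str, int]]]:
        # Toy generator whose declared yield type is a Union, mirroring a
        # traversal that may yield either nodes or attribute tuples.
        for i in range(3):
            yield i if values_only else ("item", i)

    total = 0
    for item in emit(values_only=True):
        # Without narrowing, mypy rejects `total += item` because item may be
        # a tuple; the isinstance assert narrows it to int for the checker.
        assert isinstance(item, int)
        total += item
    print(total)  # 3
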
@@ -365,7 +367,7 @@ def _iter_depth(node, visited): child._depth = node._depth + 1 _iter_depth(child, visited) - visited = set() + visited: Set[Node] = set() for root in self.roots: root._depth = 0 # depth of root node is 0 _iter_depth(root, visited) @@ -375,9 +377,11 @@ def enumerate_traverse(self) -> None: # if "node order" column exists, we traverse sorting by _hatchet_nid if self.node_ordering: for i, node in enumerate(self.node_order_traverse()): + assert isinstance(node, Node) node._hatchet_nid = i else: for i, node in enumerate(self.traverse()): + assert isinstance(node, Node) node._hatchet_nid = i self.enumerate_depth() @@ -386,23 +390,28 @@ def _check_enumerate_traverse(self) -> bool: # if "node order" column exists, we traverse sorting by _hatchet_nid if self.node_ordering: for i, node in enumerate(self.node_order_traverse()): + assert isinstance(node, Node) if i != node._hatchet_nid: return False else: for i, node in enumerate(self.traverse()): + assert isinstance(node, Node) if i != node._hatchet_nid: return False + return True def __len__(self) -> int: """Size of the graph in terms of number of nodes.""" return sum(1 for _ in self.traverse()) - def __eq__(self, other: "Graph") -> bool: + def __eq__(self, other: object) -> bool: """Check if two graphs have the same structure by comparing frame at each node. """ - vs = set() - vo = set() + if not isinstance(other, Graph): + return NotImplemented + vs: Set[int] = set() + vo: Set[int] = set() # if both graphs are pointing to the same object, then graphs are equal if self is other: @@ -429,7 +438,9 @@ def __eq__(self, other: "Graph") -> bool: return True - def __ne__(self, other: "Graph") -> bool: + def __ne__(self, other: object) -> bool: + if not isinstance(other, Graph): + return NotImplemented return not (self == other) @staticmethod diff --git a/hatchet/graphframe.py b/hatchet/graphframe.py index f350d465..5e1ff5e3 100644 --- a/hatchet/graphframe.py +++ b/hatchet/graphframe.py @@ -8,8 +8,8 @@ import sys import traceback from collections import defaultdict -from collections.abc import Callable -from typing import Any, Dict, List, Optional, Tuple, Union +from collections.abc import Callable, Iterable +from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast from io import TextIOWrapper import multiprocess as mp @@ -33,7 +33,7 @@ from .util.dot import trees_to_dot try: - from .cython_modules.libs import graphframe_modules as _gfm_cy + from .cython_modules.libs import graphframe_modules as _gfm_cy # type: ignore except ImportError: print("-" * 80) print( @@ -134,7 +134,7 @@ def from_hpctoolkit_latest( max_depth: Optional[int] = None, min_percentage_of_application_time: Optional[int] = None, min_percentage_of_parent_time: Optional[int] = None, - ) -> "GraphFrame": + ) -> Optional["GraphFrame"]: """ Read an HPCToolkit database directory into a new GraphFrame @@ -202,7 +202,7 @@ def from_timeseries( level: str = "loop.start_iteration", native: bool = False, string_attributes: Union[List[str], str] = [], - ) -> "GraphFrame": + ) -> List["GraphFrame"]: """Read in a native Caliper timeseries `cali` file using Caliper's python reader. 
Args: @@ -220,7 +220,9 @@ def from_timeseries( ).read_timeseries(level=level) @staticmethod - def from_spotdb(db_key: Any, list_of_ids: Optional[List] = None) -> "GraphFrame": + def from_spotdb( + db_key: Any, list_of_ids: Optional[List] = None + ) -> List["GraphFrame"]: """Read multiple graph frames from a SpotDB instance Args: @@ -280,7 +282,7 @@ def from_timemory( input: Optional[Union[str, TextIOWrapper, Dict[str, Any]]] = None, select: Optional[List[str]] = None, **_kwargs, - ) -> "GraphFrame": + ) -> Optional["GraphFrame"]: """Read in timemory data. Links: @@ -354,14 +356,17 @@ def from_timemory( pass else: try: - import timemory + import timemory # type: ignore[import-not-found] - TimemoryReader(timemory.get(hierarchy=True), select, **_kwargs).read() + return TimemoryReader( + timemory.get(hierarchy=True), select, **_kwargs + ).read() except ImportError: print( "Error! timemory could not be imported. Provide filename, file stream, or dict." ) raise + return None @staticmethod def from_literal(graph_dict: List[Dict]) -> "GraphFrame": @@ -385,7 +390,11 @@ def from_lists(*lists) -> "GraphFrame": df = pd.DataFrame({"node": list(graph.traverse())}) df["time"] = [1.0] * len(graph) - df["name"] = [n.frame["name"] for n in graph.traverse()] + name_col = [] + for n in graph.traverse(): + assert isinstance(n, Node) + name_col.append(n.frame["name"]) + df["name"] = name_col df.set_index(["node"], inplace=True) df.sort_index(inplace=True) @@ -406,9 +415,7 @@ def from_hdf(filename: str, **kwargs) -> "GraphFrame": return HDF5Reader(filename).read(**kwargs) - def to_hdf( - self, filename: str, key: str = "hatchet_graphframe", **kwargs - ) -> "GraphFrame": + def to_hdf(self, filename: str, key: str = "hatchet_graphframe", **kwargs) -> None: # import this lazily to avoid circular dependencies from .writers.hdf5_writer import HDF5Writer @@ -455,7 +462,7 @@ def deepcopy(self) -> "GraphFrame": default_metric (str): N/A metadata (dict): Copy of self's metadata """ - node_clone = {} + node_clone: Dict[Node, Node] = {} graph_copy = self.graph.copy(node_clone) dataframe_copy = self.dataframe.copy() @@ -562,7 +569,7 @@ def filter( elif isinstance(filter_obj, (list, str)) or is_hatchet_query(filter_obj): # use a callpath query to apply the filter - query = filter_obj + query: Union[Query, CompoundQuery] # If a raw Object-dialect query is provided (not already passed to ObjectQuery), # create a new ObjectQuery object. if isinstance(filter_obj, list): @@ -573,7 +580,10 @@ def filter( query = parse_string_dialect(filter_obj, multi_index_mode) # If an old-style query is provided, extract the underlying new-style query. elif issubclass(type(filter_obj), AbstractQuery): - query = filter_obj._get_new_query() + query = cast(AbstractQuery, filter_obj)._get_new_query() + else: + assert isinstance(filter_obj, (Query, CompoundQuery)) + query = filter_obj query_matches = self.query_engine.apply(query, self.graph, self.dataframe) # match_set = list(set().union(*query_matches)) # filtered_df = dataframe_copy.loc[dataframe_copy["node"].isin(match_set)] @@ -619,7 +629,7 @@ def squash(self, update_inc_cols: bool = True) -> "GraphFrame": # Maintain sets of connections to make for each old node. # Start with old -> new mapping and update as we traverse subgraphs. 
- connections = defaultdict(lambda: set()) + connections: Dict[Node, Set[Node]] = defaultdict(lambda: set()) connections.update({k: {v} for k, v in old_to_new.items()}) new_roots = [] # list of new roots @@ -657,7 +667,7 @@ def rewire(node, new_parent, visited): return connections[node] # run rewire for each root and make a new graph - visited = set() + visited: Set[Node] = set() for root in self.graph.roots: rewire(root, None, visited) graph = Graph(new_roots) @@ -750,7 +760,8 @@ def subtree_sum( out_columns = self._init_sum_columns(columns, out_columns) # sum over the output columns - for node in self.graph.traverse(order="post"): + for trav_node in self.graph.traverse(order="post"): + node = cast(Node, trav_node) if node.children: # TODO: need a better way of aggregating inclusive metrics when # TODO: there is a multi-index @@ -815,7 +826,8 @@ def subgraph_sum( return out_columns = self._init_sum_columns(columns, out_columns) - for node in self.graph.traverse(): + for trav_node in self.graph.traverse(): + node = cast(Node, trav_node) subgraph_nodes = list(node.traverse()) # TODO: need a better way of aggregating inclusive metrics when # TODO: there is a multi-index @@ -893,13 +905,15 @@ def generate_exclusive_columns( # suffix) to the generation list. else: generation_pairs.append((inc + " (exc)", inc)) + node: Node # Consider each new exclusive metric and its corresponding inclusive metric for exc, inc in generation_pairs: # Process of obtaining inclusive data for a node differs if the DataFrame has an Index vs a MultiIndex if isinstance(self.dataframe.index, pd.MultiIndex): - new_data = {} + new_data: Dict[Union[Tuple[Any, ...], Node], int] = {} # Traverse every node in the Graph - for node in self.graph.traverse(): + for trav_node in self.graph.traverse(): + node = cast(Node, trav_node) # Consider each unique portion of the MultiIndex corresponding to the current node for non_node_idx in self.dataframe.loc[(node)].index.unique(): # If there's only 1 index level besides "node", add it to a 1-element list to ensure consistent typing @@ -930,7 +944,7 @@ def generate_exclusive_columns( # Create a basic Node-metric dict for the new exclusive metric new_data = {n: -1 for n in self.dataframe.index.values} # Traverse the graph - for node in self.graph.traverse(): + for node in cast(Iterable[Node], self.graph.traverse()): # Sum up the inclusive metric values of the current node's children inc_sum = 0 for child in node.children: @@ -994,7 +1008,7 @@ def unify(self, other: "GraphFrame"): if self.graph is other.graph: return - node_map = {} + node_map: Dict[int, Node] = {} union_graph = self.graph.union(other.graph, node_map) self_index_names = self.dataframe.index.names @@ -1043,7 +1057,7 @@ def tree( render_header: bool = True, min_value: Optional[int] = None, max_value: Optional[int] = None, - ) -> str: + ) -> Union[str, bytes]: """Visualize the Hatchet graphframe as a tree Arguments: @@ -1122,8 +1136,9 @@ def to_dot( """ if metric is None: metric = self.default_metric + graph_roots = cast(List[Node], self.graph.roots) return trees_to_dot( - self.graph.roots, self.dataframe, metric, name, rank, thread, threshold + graph_roots, self.dataframe, metric, name, rank, thread, threshold ) def to_flamegraph( @@ -1142,9 +1157,10 @@ def to_flamegraph( metric = self.default_metric for root in self.graph.roots: - for hnode in root.traverse(): + for hnode in cast(Iterable[Node], root.traverse()): callpath = hnode.path() for i in range(0, len(callpath) - 1): + df_index: Union[Tuple[Node, int, int], 
Tuple[Node, int], Node] if ( "rank" in self.dataframe.index.names and "thread" in self.dataframe.index.names @@ -1285,7 +1301,9 @@ def add_nodes(hnode): return graph_literal def to_dict(self) -> Dict: - hatchet_dict = {} + hatchet_dict: Dict[ + str, Union[List[Dict[int, Dict[str, Any]]], List[str], Dict] + ] = {} """ Nodes: {hatchet_nid: {node data, children:[by-id]}} @@ -1293,7 +1311,7 @@ def to_dict(self) -> Dict: graphs = [] for root in self.graph.roots: formatted_graph_dict = {} - for n in root.traverse(): + for n in cast(Iterable[Node], root.traverse()): formatted_graph_dict[n._hatchet_nid] = { "data": n.frame.attrs, "children": [c._hatchet_nid for c in n.children], @@ -1468,7 +1486,7 @@ def groupby_aggregate( """ # create new nodes for each unique node in the old dataframe # length is equal to number of nodes in original graph - old_to_new = {} + old_to_new: Dict[Node, Node] = {} # list of new roots new_roots = [] @@ -1537,7 +1555,7 @@ def reindex(node, parent, visited): old_to_new[i] = super_node # reindex graph by traversing old graph - visited = set() + visited: Set[Node] = set() for root in self.graph.roots: reindex(root, None, visited) diff --git a/hatchet/node.py b/hatchet/node.py index 101dca45..f1205642 100644 --- a/hatchet/node.py +++ b/hatchet/node.py @@ -10,7 +10,7 @@ from .frame import Frame -def traversal_order(node: "Node") -> Tuple(Frame, int): +def traversal_order(node: "Node") -> Tuple[Frame, int]: """Deterministic key function for sorting nodes in traversals.""" return (node.frame, id(node)) @@ -36,10 +36,10 @@ def __init__( self._depth = depth self._hatchet_nid = hnid - self.parents = [] + self.parents: List["Node"] = [] if parent is not None: self.add_parent(parent) - self.children = [] + self.children: List["Node"] = [] def add_parent(self, node: "Node"): """Adds a parent to this node's list of parents.""" @@ -76,7 +76,7 @@ def path(self, attrs: Optional[Dict[str, Any]] = None) -> Tuple["Node", ...]: """ paths = self.paths() if len(paths) > 1: - raise MultiplePathError("Node has more than one path: " % paths) + raise MultiplePathError("Node has more than one path: " + str(paths)) return paths[0] def dag_equal( @@ -213,7 +213,9 @@ def value(node): def __hash__(self) -> int: return self._hatchet_nid - def __eq__(self, other: "Node") -> bool: + def __eq__(self, other: object) -> bool: + if not isinstance(other, Node): + return NotImplemented return self._hatchet_nid == other._hatchet_nid def __lt__(self, other: "Node") -> bool: @@ -233,10 +235,7 @@ def copy(self) -> "Node": @classmethod def from_lists( cls, - lists: Union[ - List[str, "Node", Union[List, Tuple]], - Tuple[str, "Node", Union[List, Tuple]], - ], + lists: Tuple[str, "Node", Union[List, Tuple]], ) -> "Node": r"""Construct a hierarchy of nodes from recursive lists. 
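
The `__eq__` changes to Frame, Graph, and Node above follow the usual convention for typed Python: accept `object`, test the concrete type, and return NotImplemented for unrelated operands so Python can try the reflected comparison rather than raising or wrongly returning False. A throwaway illustration (not Hatchet code):

    class Tag:
        def __init__(self, key: str) -> None:
            self.key = key

        def __eq__(self, other: object) -> bool:
            # Returning NotImplemented (instead of False) lets Python fall
            # back to other.__eq__(self) before giving up.
            if not isinstance(other, Tag):
                return NotImplemented
            return self.key == other.key

        def __hash__(self) -> int:
            return hash(self.key)

    print(Tag("a") == Tag("a"))  # True
    print(Tag("a") == "a")       # False, after the reflected attempt also fails
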
diff --git a/hatchet/query/__init__.py b/hatchet/query/__init__.py index 6e0da6a9..ba36e5db 100644 --- a/hatchet/query/__init__.py +++ b/hatchet/query/__init__.py @@ -6,7 +6,7 @@ # Make flake8 ignore unused names in this file # flake8: noqa: F401 -from typing import Any, Union, List, TypeVar +from typing import Any, Union, List from .query import Query from .compound import ( @@ -15,9 +15,10 @@ DisjunctionQuery, ExclusiveDisjunctionQuery, NegationQuery, + parse_string_dialect, ) from .object_dialect import ObjectQuery -from .string_dialect import StringQuery, parse_string_dialect +from .string_dialect import StringQuery from .engine import QueryEngine from .errors import ( InvalidQueryPath, @@ -41,17 +42,15 @@ parse_cypher_query, ) -BaseQueryType = TypeVar("BaseQuery", Query, ObjectQuery, StringQuery, str, List) -CompoundQueryType = TypeVar( - "CompoundQuery", +BaseQueryType = Union[Query, ObjectQuery, StringQuery, str, List] +CompoundQueryType = Union[ CompoundQuery, ConjunctionQuery, DisjunctionQuery, ExclusiveDisjunctionQuery, NegationQuery, -) -LegacyQueryType = TypeVar( - "LegacyQuery", +] +LegacyQueryType = Union[ AbstractQuery, NaryQuery, AndQuery, @@ -63,8 +62,7 @@ NotQuery, QueryMatcher, CypherQuery, -) -AnyQueryType = TypeVar("AnyQuery", BaseQueryType, CompoundQueryType, LegacyQueryType) +] def combine_via_conjunction( @@ -92,16 +90,19 @@ def negate_query(query: Union[BaseQueryType, CompoundQueryType]) -> NegationQuer return NegationQuery(query) -Query.__and__ = combine_via_conjunction -Query.__or__ = combine_via_disjunction -Query.__xor__ = combine_via_exclusive_disjunction -Query.__not__ = negate_query +# Note: skipping mypy checks here because we're monkey +# patching these operators. Per mypy Issue #2427, +# mypy doesn't like this +Query.__and__ = combine_via_conjunction # type: ignore +Query.__or__ = combine_via_disjunction # type: ignore +Query.__xor__ = combine_via_exclusive_disjunction # type: ignore +Query.__not__ = negate_query # type: ignore -CompoundQuery.__and__ = combine_via_conjunction -CompoundQuery.__or__ = combine_via_disjunction -CompoundQuery.__xor__ = combine_via_exclusive_disjunction -CompoundQuery.__not__ = negate_query +CompoundQuery.__and__ = combine_via_conjunction # type: ignore +CompoundQuery.__or__ = combine_via_disjunction # type: ignore +CompoundQuery.__xor__ = combine_via_exclusive_disjunction # type: ignore +CompoundQuery.__not__ = negate_query # type: ignore def is_hatchet_query(query_obj: Any) -> bool: diff --git a/hatchet/query/compat.py b/hatchet/query/compat.py index 7197ffc6..a2a96d6d 100644 --- a/hatchet/query/compat.py +++ b/hatchet/query/compat.py @@ -3,20 +3,13 @@ # # SPDX-License-Identifier: MIT -from abc import abstractmethod +from abc import abstractmethod, ABC -try: - from abc import ABC -except ImportError: - from abc import ABCMeta - - ABC = ABCMeta("ABC", (object,), {"__slots__": ()}) import sys import warnings from collections.abc import Callable -from typing import List, Optional, Union +from typing import List, Optional, Union, cast, TYPE_CHECKING -from ..graphframe import GraphFrame from ..node import Node from .query import Query from .compound import ( @@ -25,12 +18,15 @@ DisjunctionQuery, ExclusiveDisjunctionQuery, NegationQuery, + parse_string_dialect, ) from .object_dialect import ObjectQuery -from .string_dialect import parse_string_dialect from .engine import QueryEngine from .errors import BadNumberNaryQueryArgs, InvalidQueryPath +if TYPE_CHECKING: + from ..graphframe import GraphFrame + # QueryEngine object for 
running the legacy "apply" methods COMPATABILITY_ENGINE: QueryEngine = QueryEngine() @@ -40,7 +36,7 @@ class AbstractQuery(ABC): """Base class for all 'old-style' queries.""" @abstractmethod - def apply(self, gf: GraphFrame) -> List[Node]: + def apply(self, gf: "GraphFrame") -> List[Node]: pass def __and__(self, other: "AbstractQuery") -> "AndQuery": @@ -76,7 +72,7 @@ def __xor__(self, other: "AbstractQuery") -> "XorQuery": """ return XorQuery(self, other) - def __invert__(self) -> "NegationQuery": + def __invert__(self) -> "NotQuery": """Create a new NotQuery using this query. Returns: @@ -99,7 +95,9 @@ def __init__(self, *args) -> None: Arguments: *args (AbstractQuery, str, or list): the subqueries to be performed """ - self.compat_subqueries = [] + self.compat_subqueries: List[ + Union[QueryMatcher, CypherQuery, AbstractQuery, Query, CompoundQuery] + ] = [] if isinstance(args[0], tuple) and len(args) == 1: args = args[0] for query in args: @@ -119,7 +117,7 @@ def __init__(self, *args) -> None: high-level query or a subclass of AbstractQuery" ) - def apply(self, gf: GraphFrame) -> List[Node]: + def apply(self, gf: "GraphFrame") -> List[Node]: """Applies the query to the specified GraphFrame. Arguments: @@ -139,9 +137,10 @@ def _get_new_query(self) -> Union[Query, CompoundQuery]: """ true_subqueries = [] for subq in self.compat_subqueries: - true_subq = subq if issubclass(type(subq), AbstractQuery): - true_subq = subq._get_new_query() + true_subq = cast(AbstractQuery, subq)._get_new_query() + else: + true_subq = cast(Union[Query, CompoundQuery], subq) true_subqueries.append(true_subq) return self._convert_to_new_query(true_subqueries) @@ -291,7 +290,7 @@ def __init__(self, query: Optional[Union[List, Query]] = None) -> None: DeprecationWarning, stacklevel=2, ) - self.true_query = None + self.true_query: Optional[Union[Query, CompoundQuery]] = None if query is None: self.true_query = Query() elif isinstance(query, list): @@ -314,6 +313,7 @@ def match( Returns: (QueryMatcher): the instance of the class that called this function """ + assert isinstance(self.true_query, Query) self.true_query.match(wildcard_spec, filter_func) return self @@ -332,10 +332,11 @@ def rel( Returns: (QueryMatcher): the instance of the class that called this function """ + assert isinstance(self.true_query, Query) self.true_query.rel(wildcard_spec, filter_func) return self - def apply(self, gf: GraphFrame) -> List[Node]: + def apply(self, gf: "GraphFrame") -> List[Node]: """Apply the query to a GraphFrame. 
Arguments: diff --git a/hatchet/query/compound.py b/hatchet/query/compound.py index 64c011f6..dc35fd17 100644 --- a/hatchet/query/compound.py +++ b/hatchet/query/compound.py @@ -6,12 +6,13 @@ from abc import abstractmethod import sys -from typing import List +import re +from typing import List, Optional, Set, Union, cast from ..node import Node from ..graph import Graph from .query import Query -from .string_dialect import parse_string_dialect +from .string_dialect import StringQuery from .object_dialect import ObjectQuery from .errors import BadNumberNaryQueryArgs @@ -153,7 +154,7 @@ def _apply_op_to_results( Returns: (list): A list containing all the nodes satisfying the exclusive disjunction of the subqueries' results """ - xor_set = set() + xor_set: Set[Node] = set() for res in subquery_results: xor_set = xor_set.symmetric_difference(set(res)) return list(xor_set) @@ -189,6 +190,174 @@ def _apply_op_to_results( Returns: (list): A list containing all the nodes in the Graph not contained in the subquery's results """ - nodes = set(graph.traverse()) + trav_nodes = set(graph.traverse()) + nodes = cast(Set[Node], trav_nodes) query_nodes = set(subquery_results[0]) return list(nodes.difference(query_nodes)) + + +def parse_string_dialect( + query_str: str, multi_index_mode: str = "off" +) -> Union[StringQuery, CompoundQuery]: + """Parse all types of String-based queries, including multi-queries that leverage + the curly brace delimiters. + + Arguments: + query_str (str): the String-based query to be parsed + + Returns: + (Query or CompoundQuery): A Hatchet query object representing the String-based query + """ + # TODO Check if there's a way to prevent curly braces in a string + # from being captured + + # Find the number of curly brace-delimited regions in the query + query_str = query_str.strip() + curly_brace_elems = re.findall(r"\{(.*?)\}", query_str) + num_curly_brace_elems = len(curly_brace_elems) + # If there are no curly brace-delimited regions, just pass the query + # off to the CypherQuery constructor + if num_curly_brace_elems == 0: + if sys.version_info[0] == 2: + query_str = query_str.decode("utf-8") + return StringQuery(query_str, multi_index_mode) + # Create an iterator over the curly brace-delimited regions + curly_brace_iter = re.finditer(r"\{(.*?)\}", query_str) + # Will store curly brace-delimited regions in the WHERE clause + condition_list = None + # Will store curly brace-delimited regions that contain entire + # mid-level queries (MATCH clause and WHERE clause) + query_list = None + # If entire queries are in brace-delimited regions, store the indexes + # of the regions here so we don't consider brace-delimited regions + # within the already-captured region. + query_idxes = None + # Store which compound queries to apply to the curly brace-delimited regions + compound_ops = [] + for i, match in enumerate(curly_brace_iter): + # Get the substring within curly braces + substr = query_str[match.start() + 1 : match.end() - 1] + substr = substr.strip() + # If an entire query (MATCH + WHERE) is within curly braces, + # add the query to "query_list", and add the indexes corresponding + # to the query to "query_idxes" + if substr.startswith("MATCH"): + if query_list is None: + query_list = [] + if query_idxes is None: + query_idxes = [] + query_list.append(substr) + query_idxes.append((match.start(), match.end())) + # If the curly brace-delimited region contains only parts of a + # WHERE clause, first, check if the region is within another + # curly brace delimited region. 
If it is, do nothing (it will + # be handled later). Otherwise, add the region to "condition_list" + elif re.match(r"[a-zA-Z0-9_]+\..*", substr) is not None: + is_encapsulated_region = False + if query_idxes is not None: + for s, e in query_idxes: + if match.start() >= s or match.end() <= e: + is_encapsulated_region = True + break + if is_encapsulated_region: + continue + if condition_list is None: + condition_list = [] + condition_list.append(substr) + # If the curly brace-delimited region is neither a whole query + # or part of a WHERE clause, raise an error + else: + raise ValueError("Invalid grouping (with curly braces) within the query") + # If there is a compound operator directly after the curly brace-delimited region, + # capture the type of operator, and store the type in "compound_ops" + if i + 1 < num_curly_brace_elems: + rest_substr = query_str[match.end() :] + rest_substr = rest_substr.strip() + if rest_substr.startswith("AND"): + compound_ops.append("AND") + elif rest_substr.startswith("OR"): + compound_ops.append("OR") + elif rest_substr.startswith("XOR"): + compound_ops.append("XOR") + else: + raise ValueError("Invalid compound operator type found!") + # Each call to this function should only consider one of the full query or + # WHERE clause versions at a time. If both types were captured, raise an error + # because some type of internal logic issue occured. + if condition_list is not None and query_list is not None: + raise ValueError( + "Curly braces must be around either a full mid-level query or a set of conditions in a single mid-level query" + ) + # This branch is for the WHERE clause version + if condition_list is not None: + # Make sure you correctly gathered curly brace-delimited regions and + # compound operators + if len(condition_list) != len(compound_ops) + 1: + raise ValueError( + "Incompatible number of curly brace elements and compound operators" + ) + # Get the MATCH clause that will be shared across the subqueries + match_comp_obj = re.search(r"MATCH\s+(?P.*)\s+WHERE", query_str) + match_comp = match_comp_obj.group("match_field") + # Iterate over the compound operators + full_query: Optional[Union[StringQuery, CompoundQuery]] = None + for i, op in enumerate(compound_ops): + # If in the first iteration, set the initial query as a CypherQuery where + # the MATCH clause is the shared match clause and the WHERE clause is the + # first curly brace-delimited region + if i == 0: + query1 = "MATCH {} WHERE {}".format(match_comp, condition_list[i]) + if sys.version_info[0] == 2: + query1 = query1.decode("utf-8") + full_query = StringQuery(query1, multi_index_mode) + # Get the next query as a CypherQuery where + # the MATCH clause is the shared match clause and the WHERE clause is the + # next curly brace-delimited region + next_query = "MATCH {} WHERE {}".format(match_comp, condition_list[i + 1]) + if sys.version_info[0] == 2: + next_query = next_query.decode("utf-8") + next_string_query: Union[StringQuery, CompoundQuery] = StringQuery( + next_query, multi_index_mode + ) + # Add the next query to the full query using the compound operator + # currently being considered + if op == "AND": + assert full_query is not None + full_query = ConjunctionQuery(full_query, next_string_query) + elif op == "OR": + assert full_query is not None + full_query = DisjunctionQuery(full_query, next_string_query) + else: + assert full_query is not None + full_query = ExclusiveDisjunctionQuery(full_query, next_string_query) + return full_query + # This branch is for the full query 
version + else: + # Make sure you correctly gathered curly brace-delimited regions and + # compound operators + if len(query_list) != len(compound_ops) + 1: + raise ValueError( + "Incompatible number of curly brace elements and compound operators" + ) + # Iterate over the compound operators + full_query = None + for i, op in enumerate(compound_ops): + # If in the first iteration, set the initial query as the result + # of recursively calling this function on the first curly brace-delimited region + if i == 0: + full_query = parse_string_dialect(query_list[i]) + # Get the next query by recursively calling this function + # on the next curly brace-delimited region + next_string_query = parse_string_dialect(query_list[i + 1]) + # Add the next query to the full query using the compound operator + # currently being considered + if op == "AND": + assert full_query is not None + full_query = ConjunctionQuery(full_query, next_string_query) + elif op == "OR": + assert full_query is not None + full_query = DisjunctionQuery(full_query, next_string_query) + else: + assert full_query is not None + full_query = ExclusiveDisjunctionQuery(full_query, next_string_query) + return full_query diff --git a/hatchet/query/engine.py b/hatchet/query/engine.py index dfc5b383..26366ad7 100644 --- a/hatchet/query/engine.py +++ b/hatchet/query/engine.py @@ -5,15 +5,14 @@ from itertools import groupby import pandas as pd -from typing import List, Optional, Set, Union +from typing import Dict, List, Optional, Set, Union, cast from .errors import InvalidQueryFilter from ..node import Node, traversal_order from ..graph import Graph from .query import Query -from .compound import CompoundQuery +from .compound import CompoundQuery, parse_string_dialect from .object_dialect import ObjectQuery -from .string_dialect import parse_string_dialect class QueryEngine: @@ -21,7 +20,7 @@ class QueryEngine: def __init__(self) -> None: """Creates the QueryEngine.""" - self.search_cache = {} + self.search_cache: Dict[int, List[int]] = {} def reset_cache(self) -> None: """Resets the cache in the QueryEngine.""" @@ -42,24 +41,26 @@ def apply( """ if issubclass(type(query), Query): self.reset_cache() - matches = [] - visited = set() + matches: List[List[Node]] = [] + visited: Set[int] = set() + casted_query = cast(Query, query) for root in sorted(graph.roots, key=traversal_order): - self._apply_impl(query, dframe, root, visited, matches) + self._apply_impl(casted_query, dframe, root, visited, matches) assert len(visited) == len(graph) matched_node_set = list(set().union(*matches)) # return matches return matched_node_set elif issubclass(type(query), CompoundQuery): results = [] - for subq in query.subqueries: + compound_query = cast(CompoundQuery, query) + for subq in compound_query.subqueries: subq_obj = subq if isinstance(subq, list): subq_obj = ObjectQuery(subq) elif isinstance(subq, str): subq_obj = parse_string_dialect(subq) results.append(self.apply(subq_obj, graph, dframe)) - return query._apply_op_to_results(results, graph) + return compound_query._apply_op_to_results(results, graph) else: raise TypeError("Invalid query data type ({})".format(str(type(query)))) @@ -184,15 +185,15 @@ def _match_pattern( if query.query_pattern[match_idx][0] == "*": pattern_idx = 0 # Starting matching pattern - matches = [[pattern_root]] + matches: List[List[Node]] = [[pattern_root]] while pattern_idx < len(query): # Get the wildcard type wcard, _ = query.query_pattern[pattern_idx] - new_matches = [] + new_matches: List[List[Node]] = [] # Consider each 
existing match individually so that more # nodes can be added to them. for m in matches: - sub_match = [] + sub_match: List[Optional[List[Node]]] = [] # Get the portion of the subgraph that matches the next # part of the query. if wcard == ".": @@ -217,9 +218,9 @@ def _match_pattern( ) # Merge the next part of the match path with the # existing part. - for s in sub_match: - if s is not None: - new_matches.append(m + s) + for sm in sub_match: + if sm is not None: + new_matches.append(m + sm) new_matches = [uniq_match for uniq_match, _ in groupby(new_matches)] # Overwrite the old matches with the updated matches matches = new_matches @@ -236,7 +237,7 @@ def _apply_impl( query: Query, dframe: pd.DataFrame, node: Node, - visited: Set[Node], + visited: Set[int], matches: List[List[Node]], ) -> None: """Traverse the subgraph with the specified root, and collect all paths that match the query. diff --git a/hatchet/query/object_dialect.py b/hatchet/query/object_dialect.py index ecdc413b..cd0ed641 100644 --- a/hatchet/query/object_dialect.py +++ b/hatchet/query/object_dialect.py @@ -13,6 +13,7 @@ import re import sys from typing import Dict, List, Tuple, Union +from collections.abc import Callable, Iterable from .errors import InvalidQueryPath, InvalidQueryFilter, MultiIndexModeMismatch from ..node import Node @@ -32,14 +33,14 @@ def _process_multi_index_mode(apply_result: pd.Series, multi_index_mode: str): def _process_predicate( attr_filter: Dict[Union[str, Tuple[str, ...]], Union[str, Real]], multi_index_mode: str, -) -> bool: +) -> Callable[[Union[pd.Series, pd.DataFrame]], bool]: """Converts high-level API attribute filter to a lambda""" compops = ("<", ">", "==", ">=", "<=", "<>", "!=") # , def filter_series(df_row: pd.Series) -> bool: def filter_single_series( df_row: pd.Series, - key: Union[str, Tuple[str]], + key: Union[str, Tuple[str, ...]], single_value: Union[str, Real], ) -> bool: if key == "depth": @@ -118,14 +119,7 @@ def filter_single_series( metric_name = k if isinstance(k, (tuple, list)) and len(k) == 1: metric_name = k[0] - try: - _ = iter(v) - # Manually raise TypeError if v is a string so that - # the string is processed as a non-iterable - if isinstance(v, str): - raise TypeError - # Runs if v is not iterable (e.g., list, tuple, etc.) 
- except TypeError: + if isinstance(v, str) or not isinstance(v, Iterable): matches = matches and filter_single_series(df_row, metric_name, v) else: for single_value in v: @@ -208,11 +202,7 @@ def filter_single_dframe( metric_name = k if isinstance(k, (tuple, list)) and len(k) == 1: metric_name = k[0] - try: - _ = iter(v) - if isinstance(v, str): - raise TypeError - except TypeError: + if isinstance(v, str) or not isinstance(v, Iterable): matches = matches and filter_single_dframe(node, df_row, metric_name, v) else: for single_value in v: diff --git a/hatchet/query/query.py b/hatchet/query/query.py index 0cc6139e..d3be624e 100644 --- a/hatchet/query/query.py +++ b/hatchet/query/query.py @@ -3,8 +3,8 @@ # # SPDX-License-Identifier: MIT -from typing import Tuple -from collections.abc import Callable +from typing import List, Tuple, Union +from collections.abc import Callable, Iterator from .errors import InvalidQueryPath @@ -14,10 +14,10 @@ class Query(object): def __init__(self) -> None: """Create new Query""" - self.query_pattern = [] + self.query_pattern: List[Tuple[Union[str, int], Callable]] = [] def match( - self, quantifier: str = ".", predicate: Callable = lambda row: True + self, quantifier: Union[str, int] = ".", predicate: Callable = lambda row: True ) -> "Query": """Start a query with a root node described by the arguments. @@ -34,7 +34,7 @@ def match( return self def rel( - self, quantifier: str = ".", predicate: Callable = lambda row: True + self, quantifier: Union[str, int] = ".", predicate: Callable = lambda row: True ) -> "Query": """Add a new node to the end of the query. @@ -53,7 +53,7 @@ def rel( return self def relation( - self, quantifer: str = ".", predicate: Callable = lambda row: True + self, quantifer: Union[str, int] = ".", predicate: Callable = lambda row: True ) -> "Query": """Alias to Query.rel. Add a new node to the end of the query. @@ -70,12 +70,12 @@ def __len__(self) -> int: """Returns the length of the query.""" return len(self.query_pattern) - def __iter__(self) -> Tuple[str, Callable]: + def __iter__(self) -> Iterator[Tuple[Union[str, int], Callable]]: """Allows users to iterate over the Query like a list.""" return iter(self.query_pattern) def _add_node( - self, quantifer: str = ".", predicate: Callable = lambda row: True + self, quantifer: Union[str, int] = ".", predicate: Callable = lambda row: True ) -> None: """Add a node to the query. 
diff --git a/hatchet/query/string_dialect.py b/hatchet/query/string_dialect.py index d013e93d..6b5f5292 100644 --- a/hatchet/query/string_dialect.py +++ b/hatchet/query/string_dialect.py @@ -7,7 +7,7 @@ import re import sys from collections.abc import Callable -from typing import Any, Optional, Tuple, Union +from typing import Any, Dict, Optional, List, Union, TYPE_CHECKING import pandas as pd # noqa: F401 from pandas.api.types import is_numeric_dtype, is_string_dtype # noqa: F401 import numpy as np # noqa: F401 @@ -17,7 +17,6 @@ from .errors import InvalidQueryPath, InvalidQueryFilter, RedundantQueryFilterWarning from .query import Query -from .compound import CompoundQuery # PEG grammar for the String-based dialect @@ -124,12 +123,12 @@ def __init__(self, cypher_query: str, multi_index_mode: str = "off") -> None: e.message ) ) - self.wcards = [] - self.wcard_pos = {} + self.wcards: List[List[Any]] = [] + self.wcard_pos: Dict[str, int] = {} self._parse_path(model.path_expr) - self.filters = [[] for _ in self.wcards] + self.filters: List[List[Any]] = [[] for _ in self.wcards] self._parse_conditions(model.cond_expr) - self.lambda_filters = [None for _ in self.wcards] + self.lambda_filters: List[Optional[str]] = [None for _ in self.wcards] self._build_lambdas() self._build_query() @@ -185,7 +184,7 @@ def _parse_path(self, path_obj: Any) -> None: nodes = path_obj.path.nodes idx = len(self.wcards) for n in nodes: - new_node = [n.wcard, n.name] + new_node: List[Any] = [n.wcard, n.name] if n.wcard is None or n.wcard == "" or n.wcard == 0: new_node[0] = "." self.wcards.append(new_node) @@ -231,9 +230,7 @@ def _is_binary_cond(self, obj: Any) -> bool: return True return False - def _parse_binary_cond( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_binary_cond(self, obj: Any) -> List[Optional[str]]: """Top level function for parsing binary predicates.""" if cname(obj) == "AndCond": return self._parse_and_cond(obj) @@ -241,31 +238,25 @@ def _parse_binary_cond( return self._parse_or_cond(obj) raise RuntimeError("Bad Binary Condition") - def _parse_or_cond(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_or_cond(self, obj: Any) -> List[Optional[str]]: """Top level function for parsing predicates combined with logical OR.""" converted_subcond = self._parse_unary_cond(obj.subcond) converted_subcond[0] = "or" return converted_subcond - def _parse_and_cond( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_and_cond(self, obj: Any) -> List[Optional[str]]: """Top level function for parsing predicates combined with logical AND.""" converted_subcond = self._parse_unary_cond(obj.subcond) converted_subcond[0] = "and" return converted_subcond - def _parse_unary_cond( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_unary_cond(self, obj: Any) -> List[Optional[str]]: """Top level function for parsing unary predicates.""" if cname(obj) == "NotCond": return self._parse_not_cond(obj) return self._parse_single_cond(obj) - def _parse_not_cond( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_not_cond(self, obj: Any) -> List[Optional[str]]: """Parse predicates containing the logical NOT operator.""" converted_subcond = self._parse_single_cond(obj.subcond) converted_subcond[2] = "not {}".format(converted_subcond[2]) @@ -273,16 +264,14 @@ def _parse_not_cond( def _run_method_based_on_multi_idx_mode( self, method_name: str, obj: Any - ) -> 
Tuple[Optional[str], str, str, Optional[str]]: + ) -> List[Optional[str]]: real_method_name = method_name if self.multi_index_mode != "off": real_method_name = method_name + "_multi_idx" method = eval("StringQuery.{}".format(real_method_name)) return method(self, obj) - def _parse_single_cond( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_single_cond(self, obj: Any) -> List[Optional[str]]: """Top level function for parsing individual numeric or string predicates.""" if self._is_str_cond(obj): return self._parse_str(obj) @@ -298,7 +287,7 @@ def _parse_single_cond( return self._run_method_based_on_multi_idx_mode("_parse_not_leaf", obj) raise RuntimeError("Bad Single Condition") - def _parse_none(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_none(self, obj: Any) -> List[Optional[str]]: """Parses 'property IS NONE'.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -330,9 +319,7 @@ def _add_aggregation_call_to_multi_idx_predicate(self, predicate: str) -> str: return predicate + ".any()" return predicate + ".all()" - def _parse_none_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_none_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -360,9 +347,7 @@ def _parse_none_multi_idx( None, ] - def _parse_not_none( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_not_none(self, obj: Any) -> List[Optional[str]]: """Parses 'property IS NOT NONE'.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -389,9 +374,7 @@ def _parse_not_none( None, ] - def _parse_not_none_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_not_none_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -419,7 +402,7 @@ def _parse_not_none_multi_idx( None, ] - def _parse_leaf(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_leaf(self, obj: Any) -> List[Optional[str]]: """Parses 'node IS LEAF'.""" return [ None, @@ -428,9 +411,7 @@ def _parse_leaf(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]] None, ] - def _parse_leaf_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_leaf_multi_idx(self, obj: Any) -> List[Optional[str]]: return [ None, obj.name, @@ -438,9 +419,7 @@ def _parse_leaf_multi_idx( None, ] - def _parse_not_leaf( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_not_leaf(self, obj: Any) -> List[Optional[str]]: """Parses 'node IS NOT LEAF'.""" return [ None, @@ -449,9 +428,7 @@ def _parse_not_leaf( None, ] - def _parse_not_leaf_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_not_leaf_multi_idx(self, obj: Any) -> List[Optional[str]]: return [ None, obj.name, @@ -487,7 +464,7 @@ def _is_num_cond(self, obj: Any) -> bool: return True return False - def _parse_str(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str(self, obj: Any) -> List[Optional[str]]: """Function that redirects processing of string predicates to the correct function. 
""" @@ -505,7 +482,7 @@ def _parse_str(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: return self._run_method_based_on_multi_idx_mode("_parse_str_match", obj) raise RuntimeError("Bad String Op Class") - def _parse_str_eq(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_eq(self, obj: Any) -> List[Optional[str]]: """Processes string equivalence predicates.""" return [ None, @@ -525,9 +502,7 @@ def _parse_str_eq(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str ), ] - def _parse_str_eq_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_eq_multi_idx(self, obj: Any) -> List[Optional[str]]: return [ None, obj.name, @@ -548,9 +523,7 @@ def _parse_str_eq_multi_idx( ), ] - def _parse_str_starts_with( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_starts_with(self, obj: Any) -> List[Optional[str]]: """Processes string 'startswith' predicates.""" return [ None, @@ -570,9 +543,7 @@ def _parse_str_starts_with( ), ] - def _parse_str_starts_with_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_starts_with_multi_idx(self, obj: Any) -> List[Optional[str]]: return [ None, obj.name, @@ -593,9 +564,7 @@ def _parse_str_starts_with_multi_idx( ), ] - def _parse_str_ends_with( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_ends_with(self, obj: Any) -> List[Optional[str]]: """Processes string 'endswith' predicates.""" return [ None, @@ -615,9 +584,7 @@ def _parse_str_ends_with( ), ] - def _parse_str_ends_with_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_ends_with_multi_idx(self, obj: Any) -> List[Optional[str]]: return [ None, obj.name, @@ -638,9 +605,7 @@ def _parse_str_ends_with_multi_idx( ), ] - def _parse_str_contains( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_contains(self, obj: Any) -> List[Optional[str]]: """Processes string 'contains' predicates.""" return [ None, @@ -660,9 +625,7 @@ def _parse_str_contains( ), ] - def _parse_str_contains_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_contains_multi_idx(self, obj: Any) -> List[Optional[str]]: return [ None, obj.name, @@ -683,9 +646,7 @@ def _parse_str_contains_multi_idx( ), ] - def _parse_str_match( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_match(self, obj: Any) -> List[Optional[str]]: """Processes string regex match predicates.""" return [ None, @@ -705,9 +666,7 @@ def _parse_str_match( ), ] - def _parse_str_match_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_str_match_multi_idx(self, obj: Any) -> List[Optional[str]]: return [ None, obj.name, @@ -728,7 +687,7 @@ def _parse_str_match_multi_idx( ), ] - def _parse_num(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num(self, obj: Any) -> List[Optional[str]]: """Function that redirects processing of numeric predicates to the correct function. 
""" @@ -752,7 +711,7 @@ def _parse_num(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: return self._run_method_based_on_multi_idx_mode("_parse_num_not_inf", obj) raise RuntimeError("Bad Number Op Class") - def _parse_num_eq(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_eq(self, obj: Any) -> List[Optional[str]]: """Processes numeric equivalence predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val == -1: @@ -825,9 +784,7 @@ def _parse_num_eq(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str ), ] - def _parse_num_eq_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_eq_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val == -1: return [ @@ -903,7 +860,7 @@ def _parse_num_eq_multi_idx( ), ] - def _parse_num_lt(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_lt(self, obj: Any) -> List[Optional[str]]: """Processes numeric less-than predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: @@ -969,9 +926,7 @@ def _parse_num_lt(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str ), ] - def _parse_num_lt_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_lt_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -1040,7 +995,7 @@ def _parse_num_lt_multi_idx( ), ] - def _parse_num_gt(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_gt(self, obj: Any) -> List[Optional[str]]: """Processes numeric greater-than predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: @@ -1106,9 +1061,7 @@ def _parse_num_gt(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str ), ] - def _parse_num_gt_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_gt_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -1177,7 +1130,7 @@ def _parse_num_gt_multi_idx( ), ] - def _parse_num_lte(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_lte(self, obj: Any) -> List[Optional[str]]: """Processes numeric less-than-or-equal-to predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: @@ -1243,9 +1196,7 @@ def _parse_num_lte(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[st ), ] - def _parse_num_lte_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_lte_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -1314,7 +1265,7 @@ def _parse_num_lte_multi_idx( ), ] - def _parse_num_gte(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_gte(self, obj: Any) -> List[Optional[str]]: """Processes numeric greater-than-or-equal-to predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: @@ -1380,9 +1331,7 @@ def _parse_num_gte(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[st ), ] - def _parse_num_gte_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_gte_multi_idx(self, obj: Any) -> 
List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -1451,7 +1400,7 @@ def _parse_num_gte_multi_idx( ), ] - def _parse_num_nan(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_nan(self, obj: Any) -> List[Optional[str]]: """Processes predicates that check for NaN.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -1482,9 +1431,7 @@ def _parse_num_nan(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[st ), ] - def _parse_num_nan_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_nan_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -1516,9 +1463,7 @@ def _parse_num_nan_multi_idx( ), ] - def _parse_num_not_nan( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_not_nan(self, obj: Any) -> List[Optional[str]]: """Processes predicates that check for NaN.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -1549,9 +1494,7 @@ def _parse_num_not_nan( ), ] - def _parse_num_not_nan_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_not_nan_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -1583,7 +1526,7 @@ def _parse_num_not_nan_multi_idx( ), ] - def _parse_num_inf(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_inf(self, obj: Any) -> List[Optional[str]]: """Processes predicates that check for Infinity.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -1614,9 +1557,7 @@ def _parse_num_inf(self, obj: Any) -> Tuple[Optional[str], str, str, Optional[st ), ] - def _parse_num_inf_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_inf_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -1648,9 +1589,7 @@ def _parse_num_inf_multi_idx( ), ] - def _parse_num_not_inf( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_not_inf(self, obj: Any) -> List[Optional[str]]: """Processes predicates that check for not-Infinity.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ @@ -1681,9 +1620,7 @@ def _parse_num_not_inf( ), ] - def _parse_num_not_inf_multi_idx( - self, obj: Any - ) -> Tuple[Optional[str], str, str, Optional[str]]: + def _parse_num_not_inf_multi_idx(self, obj: Any) -> List[Optional[str]]: if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, @@ -1714,162 +1651,3 @@ def _parse_num_not_inf_multi_idx( else "'{}'".format(obj.prop.ids[0]) ), ] - - -def parse_string_dialect( - query_str: str, multi_index_mode: str = "off" -) -> Union[StringQuery, CompoundQuery]: - """Parse all types of String-based queries, including multi-queries that leverage - the curly brace delimiters. 
- - Arguments: - query_str (str): the String-based query to be parsed - - Returns: - (Query or CompoundQuery): A Hatchet query object representing the String-based query - """ - # TODO Check if there's a way to prevent curly braces in a string - # from being captured - - # Find the number of curly brace-delimited regions in the query - query_str = query_str.strip() - curly_brace_elems = re.findall(r"\{(.*?)\}", query_str) - num_curly_brace_elems = len(curly_brace_elems) - # If there are no curly brace-delimited regions, just pass the query - # off to the CypherQuery constructor - if num_curly_brace_elems == 0: - if sys.version_info[0] == 2: - query_str = query_str.decode("utf-8") - return StringQuery(query_str, multi_index_mode) - # Create an iterator over the curly brace-delimited regions - curly_brace_iter = re.finditer(r"\{(.*?)\}", query_str) - # Will store curly brace-delimited regions in the WHERE clause - condition_list = None - # Will store curly brace-delimited regions that contain entire - # mid-level queries (MATCH clause and WHERE clause) - query_list = None - # If entire queries are in brace-delimited regions, store the indexes - # of the regions here so we don't consider brace-delimited regions - # within the already-captured region. - query_idxes = None - # Store which compound queries to apply to the curly brace-delimited regions - compound_ops = [] - for i, match in enumerate(curly_brace_iter): - # Get the substring within curly braces - substr = query_str[match.start() + 1 : match.end() - 1] - substr = substr.strip() - # If an entire query (MATCH + WHERE) is within curly braces, - # add the query to "query_list", and add the indexes corresponding - # to the query to "query_idxes" - if substr.startswith("MATCH"): - if query_list is None: - query_list = [] - if query_idxes is None: - query_idxes = [] - query_list.append(substr) - query_idxes.append((match.start(), match.end())) - # If the curly brace-delimited region contains only parts of a - # WHERE clause, first, check if the region is within another - # curly brace delimited region. If it is, do nothing (it will - # be handled later). Otherwise, add the region to "condition_list" - elif re.match(r"[a-zA-Z0-9_]+\..*", substr) is not None: - is_encapsulated_region = False - if query_idxes is not None: - for s, e in query_idxes: - if match.start() >= s or match.end() <= e: - is_encapsulated_region = True - break - if is_encapsulated_region: - continue - if condition_list is None: - condition_list = [] - condition_list.append(substr) - # If the curly brace-delimited region is neither a whole query - # or part of a WHERE clause, raise an error - else: - raise ValueError("Invalid grouping (with curly braces) within the query") - # If there is a compound operator directly after the curly brace-delimited region, - # capture the type of operator, and store the type in "compound_ops" - if i + 1 < num_curly_brace_elems: - rest_substr = query_str[match.end() :] - rest_substr = rest_substr.strip() - if rest_substr.startswith("AND"): - compound_ops.append("AND") - elif rest_substr.startswith("OR"): - compound_ops.append("OR") - elif rest_substr.startswith("XOR"): - compound_ops.append("XOR") - else: - raise ValueError("Invalid compound operator type found!") - # Each call to this function should only consider one of the full query or - # WHERE clause versions at a time. If both types were captured, raise an error - # because some type of internal logic issue occured. 
- if condition_list is not None and query_list is not None: - raise ValueError( - "Curly braces must be around either a full mid-level query or a set of conditions in a single mid-level query" - ) - # This branch is for the WHERE clause version - if condition_list is not None: - # Make sure you correctly gathered curly brace-delimited regions and - # compound operators - if len(condition_list) != len(compound_ops) + 1: - raise ValueError( - "Incompatible number of curly brace elements and compound operators" - ) - # Get the MATCH clause that will be shared across the subqueries - match_comp_obj = re.search(r"MATCH\s+(?P.*)\s+WHERE", query_str) - match_comp = match_comp_obj.group("match_field") - # Iterate over the compound operators - full_query = None - for i, op in enumerate(compound_ops): - # If in the first iteration, set the initial query as a CypherQuery where - # the MATCH clause is the shared match clause and the WHERE clause is the - # first curly brace-delimited region - if i == 0: - query1 = "MATCH {} WHERE {}".format(match_comp, condition_list[i]) - if sys.version_info[0] == 2: - query1 = query1.decode("utf-8") - full_query = StringQuery(query1, multi_index_mode) - # Get the next query as a CypherQuery where - # the MATCH clause is the shared match clause and the WHERE clause is the - # next curly brace-delimited region - next_query = "MATCH {} WHERE {}".format(match_comp, condition_list[i + 1]) - if sys.version_info[0] == 2: - next_query = next_query.decode("utf-8") - next_query = StringQuery(next_query, multi_index_mode) - # Add the next query to the full query using the compound operator - # currently being considered - if op == "AND": - full_query = full_query & next_query - elif op == "OR": - full_query = full_query | next_query - else: - full_query = full_query ^ next_query - return full_query - # This branch is for the full query version - else: - # Make sure you correctly gathered curly brace-delimited regions and - # compound operators - if len(query_list) != len(compound_ops) + 1: - raise ValueError( - "Incompatible number of curly brace elements and compound operators" - ) - # Iterate over the compound operators - full_query = None - for i, op in enumerate(compound_ops): - # If in the first iteration, set the initial query as the result - # of recursively calling this function on the first curly brace-delimited region - if i == 0: - full_query = parse_string_dialect(query_list[i]) - # Get the next query by recursively calling this function - # on the next curly brace-delimited region - next_query = parse_string_dialect(query_list[i + 1]) - # Add the next query to the full query using the compound operator - # currently being considered - if op == "AND": - full_query = full_query & next_query - elif op == "OR": - full_query = full_query | next_query - else: - full_query = full_query ^ next_query - return full_query diff --git a/hatchet/readers/caliper_native_reader.py b/hatchet/readers/caliper_native_reader.py index bfb95aa2..04ae2e63 100644 --- a/hatchet/readers/caliper_native_reader.py +++ b/hatchet/readers/caliper_native_reader.py @@ -7,7 +7,8 @@ import pandas as pd import numpy as np import os -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, cast +from collections.abc import Callable import caliperreader as cr @@ -60,24 +61,26 @@ def __init__( native (bool): use native metric names or user-readable metric names string_attributes (str or list): Adds existing string attributes from within the caliper file 
to the dataframe """ - self.filename_or_caliperreader = filename_or_caliperreader + self.filename_or_caliperreader: Union[str, cr.CaliperReader] = ( + filename_or_caliperreader + ) self.filename_ext = "" self.use_native_metric_names = native self.string_attributes = string_attributes - self.df_nodes = {} - self.metric_cols = [] - self.record_data_cols = [] - self.node_dicts = [] - self.callpath_to_node = {} - self.idx_to_node = {} - self.callpath_to_idx = {} - self.global_nid = 0 - self.node_ordering = False - self.gf_list = [] - self.timeseries_level = None + self.df_nodes: Optional[pd.DataFrame] = None + self.metric_cols: List[str] = [] + self.record_data_cols: List[str] = [] + # self.node_dicts = [] + self.callpath_to_node: Dict[Tuple[str, ...], Node] = {} + self.idx_to_node: Dict[int, Dict[str, Any]] = {} + self.callpath_to_idx: Dict[Tuple[str, ...], int] = {} + self.global_nid: int = 0 + self.node_ordering: bool = False + self.gf_list: List[hatchet.graphframe.GraphFrame] = [] + self.timeseries_level: Optional[str] = None - self.default_metric = None + self.default_metric: Optional[str] = None self.timer = Timer() @@ -87,8 +90,9 @@ def __init__( if isinstance(self.string_attributes, str): self.string_attributes = [self.string_attributes] - def _create_metric_df(self, metrics: List[str]) -> pd.DataFrame: + def _create_metric_df(self, metrics: List[Dict[str, Any]]) -> pd.DataFrame: """Make a list of metric columns and create a dataframe, group by node""" + assert isinstance(self.filename_or_caliperreader, cr.CaliperReader) for col in self.record_data_cols: if self.filename_or_caliperreader.attribute(col).is_value(): self.metric_cols.append(col) @@ -114,8 +118,9 @@ def _reset_metrics(self, metrics: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def read_metrics(self, ctx: str = "path") -> List[pd.DataFrame]: """append each metrics table to a list and return the list, split on timeseries_level if exists""" + assert isinstance(self.filename_or_caliperreader, cr.CaliperReader) metric_dfs = [] - all_metrics = [] + all_metrics: List[Dict[str, Any]] = [] next_timestep = 0 cur_timestep = 0 records = self.filename_or_caliperreader.records @@ -175,9 +180,12 @@ def read_metrics(self, ctx: str = "path") -> List[pd.DataFrame]: or item in self.string_attributes ): try: - node_dict[item] = self.__cali_type_dict[ - attr_type - ](record[item]) + node_dict[item] = ( + cast( + Callable, + self.__cali_type_dict[attr_type], + )(record[item]), + ) if item not in self.record_data_cols: self.record_data_cols.append(item) except ValueError as e: @@ -199,9 +207,10 @@ def read_metrics(self, ctx: str = "path") -> List[pd.DataFrame]: return metric_dfs def create_graph(self, ctx: str = "path") -> List[Node]: + assert isinstance(self.filename_or_caliperreader, cr.CaliperReader) list_roots = [] - def _create_parent(child_node: Node, parent_callpath: Any) -> None: + def _create_parent(child_node: Node, parent_callpath: Tuple[str, ...]) -> None: """We may encounter a parent node in the callpath before we see it as a child node. In this case, we need to create a hatchet node for the parent. @@ -353,7 +362,9 @@ def _create_parent(child_node: Node, parent_callpath: Any) -> None: return list_roots - def _parse_metadata(self, mdata: Dict[str, str]) -> Dict[str, str]: + def _parse_metadata( + self, mdata: Dict[str, Union[List, str]] + ) -> Dict[str, Union[List, int, float, str]]: """Convert Caliper Metadata values into correct Python objects. 
Args: @@ -362,7 +373,7 @@ def _parse_metadata(self, mdata: Dict[str, str]) -> Dict[str, str]: Return: (dict[str: str]): modified metadata """ - parsed_mdata = {} + parsed_mdata: Dict[str, Union[List, int, float, str]] = {} for k, v in mdata.items(): # environment information service brings in different metadata types if isinstance(v, list): @@ -428,7 +439,7 @@ def read(self) -> hatchet.graphframe.GraphFrame: rank_list = range(0, num_ranks) # create a standard dict to be used for filling all missing rows - default_metric_dict = {} + default_metric_dict: Dict[str, Any] = {} for idx, col in enumerate(self.record_data_cols): if self.filename_or_caliperreader.attribute(col).is_value(): default_metric_dict[list(self.record_data_cols)[idx]] = 0 @@ -437,7 +448,7 @@ def read(self) -> hatchet.graphframe.GraphFrame: default_metric_dict["nid"] = np.nan # create a list of dicts, one dict for each missing row - missing_nodes = [] + missing_nodes: List[Dict[str, Any]] = [] for iteridx, row in self.df_nodes.iterrows(): # check if df_nodes row exists in df_fixed_data metric_rows = df_fixed_data.loc[metrics["nid"] == row["nid"]] diff --git a/hatchet/readers/caliper_reader.py b/hatchet/readers/caliper_reader.py index a298f70f..e8ae25f3 100644 --- a/hatchet/readers/caliper_reader.py +++ b/hatchet/readers/caliper_reader.py @@ -9,7 +9,8 @@ import subprocess import os import math -from typing import List, Union +from typing import Any, Dict, List, Union, cast +from collections.abc import Callable from io import TextIOWrapper import pandas as pd @@ -43,21 +44,21 @@ def __init__( self.query = query self.node_ordering = False - self.json_data = {} - self.json_cols = {} - self.json_cols_mdata = {} - self.json_nodes = {} + self.json_data: List[List[Union[int, float]]] = [] + self.json_cols: List[str] = [] + self.json_cols_mdata: List[Dict[str, Any]] = [] + self.json_nodes: List[Dict[str, Any]] = [] - self.metadata = {} + self.metadata: Dict[str, Any] = {} - self.idx_to_label = {} - self.idx_to_node = {} + self.idx_to_label: Dict[int, str] = {} + self.idx_to_node: Dict[int, Dict[str, Union[int, str, Node]]] = {} self.timer = Timer() self.nid_col_name = "nid" if isinstance(self.filename_or_stream, str): - _, self.filename_ext = os.path.splitext(filename_or_stream) + _, self.filename_ext = os.path.splitext(cast(str, filename_or_stream)) def read_json_sections(self) -> None: # if cali-query exists, extract data from .cali to a file-like object @@ -65,19 +66,20 @@ def read_json_sections(self) -> None: cali_query = which("cali-query") if not cali_query: raise ValueError("from_caliper() needs cali-query to query .cali file") - cali_json = subprocess.Popen( + assert isinstance(self.filename_or_stream, str) + cali_json_popen = subprocess.Popen( [cali_query, "-q", self.query, self.filename_or_stream], stdout=subprocess.PIPE, ) - self.filename_or_stream = cali_json.stdout + self.filename_or_stream = str(cali_json_popen.stdout) # if filename_or_stream is a str, then open the file, otherwise # directly load the file-like object if isinstance(self.filename_or_stream, str): - with open(self.filename_or_stream) as cali_json: + with open(cast(str, self.filename_or_stream)) as cali_json: json_obj = json.load(cali_json) else: - json_obj = json.loads(self.filename_or_stream.read().decode("utf-8")) + json_obj = json.loads(self.filename_or_stream.read()) # read various sections of the Caliper JSON file self.json_data = json_obj["data"] @@ -121,27 +123,30 @@ def read_json_sections(self) -> None: self.json_data.remove(i) # change column 
names - for idx, item in enumerate(self.json_cols): - if item == self.path_col_name: + for idx, col_item in enumerate(self.json_cols): + if col_item == self.path_col_name: # this column is just a pointer into the nodes section self.json_cols[idx] = self.nid_col_name # make other columns consistent with other readers - if item == "mpi.rank": + if col_item == "mpi.rank": self.json_cols[idx] = "rank" - if item == "module#cali.sampler.pc": + if col_item == "module#cali.sampler.pc": self.json_cols[idx] = "module" - if item == "sum#time.duration" or item == "sum#avg#sum#time.duration": + if ( + col_item == "sum#time.duration" + or col_item == "sum#avg#sum#time.duration" + ): self.json_cols[idx] = "time" if ( - item == "inclusive#sum#time.duration" - or item == "sum#avg#inclusive#sum#time.duration" + col_item == "inclusive#sum#time.duration" + or col_item == "sum#avg#inclusive#sum#time.duration" ): self.json_cols[idx] = "time (inc)" # make list of metric columns - self.metric_columns = [] - for idx, item in enumerate(self.json_cols_mdata): - if self.json_cols[idx] != "rank" and item["is_value"] is True: + self.metric_columns: List[str] = [] + for idx, col_mdata_item in enumerate(self.json_cols_mdata): + if self.json_cols[idx] != "rank" and col_mdata_item["is_value"] is True: self.metric_columns.append(self.json_cols[idx]) def create_graph(self) -> List[Node]: @@ -162,7 +167,7 @@ def create_graph(self) -> List[Node]: # If there is a node orderering, assign to the _hatchet_nid if "Node order" in self.json_cols: self.node_ordering = True - order = self.json_data[idx][0] + order = cast(int, self.json_data[idx][0]) if "parent" not in node: # since this node does not have a parent, this is a root graph_root = Node( @@ -177,7 +182,9 @@ def create_graph(self) -> List[Node]: } self.idx_to_node[idx] = node_dict else: - parent_hnode = (self.idx_to_node[node["parent"]])["node"] + parent_hnode = cast( + Node, (self.idx_to_node[node["parent"]])["node"] + ) hnode = Node( Frame({"type": self.node_type, "name": node_label}), hnid=order, @@ -214,7 +221,7 @@ def read(self) -> hatchet.graphframe.GraphFrame: if self.both_hierarchies is True: # create dict that stores aggregation function for each column - agg_dict = {} + agg_dict: Dict[str, Callable] = {} for idx, item in enumerate(self.json_cols_mdata): col = self.json_cols[idx] if col != "rank" and col != "nid": @@ -285,7 +292,7 @@ def read(self) -> hatchet.graphframe.GraphFrame: # only need to do something if there are more than one # file:line number entries for the node if len(line_groups.size()) > 1: - sn_hnode = self.idx_to_node[nid]["node"] + sn_hnode = cast(Node, self.idx_to_node[nid]["node"]) for line, line_group in line_groups: # create the node label diff --git a/hatchet/readers/dataframe_reader.py b/hatchet/readers/dataframe_reader.py index 50d21e0d..1caea64b 100644 --- a/hatchet/readers/dataframe_reader.py +++ b/hatchet/readers/dataframe_reader.py @@ -9,20 +9,9 @@ import pandas as pd -from abc import abstractmethod +from abc import abstractmethod, ABC from typing import Dict, List -# TODO The ABC class was introduced in Python 3.4. 
-# When support for earlier versions is (eventually) dropped, -# this entire "try-except" block can be reduced to: -# from abc import ABC -try: - from abc import ABC -except ImportError: - from abc import ABCMeta - - ABC = ABCMeta("ABC", (object,), {"__slots__": ()}) - def _get_node_from_df_iloc(df: pd.DataFrame, ind: int) -> Node: node = None @@ -38,7 +27,7 @@ def _get_node_from_df_iloc(df: pd.DataFrame, ind: int) -> Node: def _get_parents_and_children(df: pd.DataFrame) -> Dict[Node, Dict[str, List[int]]]: - rel_dict = {} + rel_dict: Dict[Node, Dict[str, List[int]]] = {} for i in range(len(df)): node = _get_node_from_df_iloc(df, i) if node not in rel_dict: diff --git a/hatchet/readers/gprof_dot_reader.py b/hatchet/readers/gprof_dot_reader.py index 981388cc..76ccadcb 100644 --- a/hatchet/readers/gprof_dot_reader.py +++ b/hatchet/readers/gprof_dot_reader.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: MIT import re -from typing import List +from typing import Dict, List, Union import pandas as pd import pydot @@ -23,8 +23,8 @@ class GprofDotReader: def __init__(self, filename: str) -> None: self.dotfile = filename - self.name_to_hnode = {} - self.name_to_dict = {} + self.name_to_hnode: Dict[str, Node] = {} + self.name_to_dict: Dict[str, Dict[str, Union[str, Node]]] = {} self.timer = Timer() diff --git a/hatchet/readers/hpctoolkit_reader.py b/hatchet/readers/hpctoolkit_reader.py index 7d72ce82..2ddf57b3 100644 --- a/hatchet/readers/hpctoolkit_reader.py +++ b/hatchet/readers/hpctoolkit_reader.py @@ -8,7 +8,7 @@ import re import os import traceback -from typing import Any, Dict, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -18,11 +18,11 @@ try: import xml.etree.cElementTree as ET except ImportError: - import xml.etree.ElementTree as ET + import xml.etree.ElementTree as ET # type: ignore[no-redef] # cython imports try: - import hatchet.cython_modules.libs.reader_modules as _crm + import hatchet.cython_modules.libs.reader_modules as _crm # type: ignore except ImportError: print("-" * 80) print( @@ -40,7 +40,8 @@ from hatchet.frame import Frame -src_file = 0 +src_file: Optional[str] = None +shared_metrics: Optional[Any] = None # TODO replace the "Any" type hint with numpy.typing.ArrayLike @@ -51,7 +52,7 @@ def init_shared_array(buf_: Any) -> None: shared_metrics = buf_ -def read_metricdb_file(args: Tuple[str, int, int, int, int, Tuple[int, int]]) -> None: +def read_metricdb_file(args: Tuple[str, int, int, int, int, List[int]]) -> None: """Read a single metricdb file into a 1D array.""" ( filename, @@ -85,7 +86,7 @@ def read_metricdb_file(args: Tuple[str, int, int, int, int, Tuple[int, int]]) -> rank * num_threads_per_rank + num_cpu_threads_per_rank + (thread - 500) ) * num_nodes - arr[rank_offset : rank_offset + num_nodes, :num_metrics].flat = arr1d.flat + arr[rank_offset : rank_offset + num_nodes, :num_metrics].flat = arr1d.flat # type: ignore[misc] arr[rank_offset : rank_offset + num_nodes, num_metrics] = range(1, num_nodes + 1) arr[rank_offset : rank_offset + num_nodes, num_metrics + 1] = rank arr[rank_offset : rank_offset + num_nodes, num_metrics + 2] = thread @@ -136,17 +137,17 @@ def __init__(self, dir_name: str) -> None: self.num_metrics = struct.unpack(">i", metricdb.read(4))[0] else: raise ValueError( - "HPCToolkitReader doesn't support endian '%s'" % endian + "HPCToolkitReader doesn't support endian '{:r}'".format(endian) ) - self.load_modules = {} - self.src_files = {} - self.procedure_names = {} - self.metric_names = 
{} + self.load_modules: Dict = {} + self.src_files: Dict = {} + self.procedure_names: Dict = {} + self.metric_names: Dict = {} # this list of dicts will hold all the node information such as # procedure name, load module, filename, etc. for all the nodes - self.node_dicts = [] + self.node_dicts: List[Dict[str, Union[int, str, Node]]] = [] self.timer = Timer() @@ -428,7 +429,7 @@ def create_node_dict( src_file: str, line: int, module: str, - ) -> Dict[str, Union[int, str, Node]]: + ) -> Dict[str, Any]: """Create a dict with all the node attributes.""" node_dict = { "nid": nid, diff --git a/hatchet/readers/hpctoolkit_reader_latest.py b/hatchet/readers/hpctoolkit_reader_latest.py index 97c1cd46..d201912b 100644 --- a/hatchet/readers/hpctoolkit_reader_latest.py +++ b/hatchet/readers/hpctoolkit_reader_latest.py @@ -6,7 +6,7 @@ import os import re import struct -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import pandas as pd @@ -68,18 +68,18 @@ def __init__( self._meta_file = None self._profile_file = None - self._functions = {} - self._source_files = {} - self._load_modules = {} - self._metric_descriptions = {} - self._summary_profile = {} + self._functions: Dict[int, Dict[str, Any]] = {} + self._source_files: Dict[int, Dict[str, Any]] = {} + self._load_modules: Dict[int, Dict[str, Any]] = {} + self._metric_descriptions: Dict = {} + self._summary_profile: Dict = {} - self._time_metric = None - self._inclusive_metrics = {} - self._exclusive_metrics = {} + self._time_metric: Optional[str] = None + self._inclusive_metrics: Dict = {} + self._exclusive_metrics: Dict = {} - self._cct_roots = [] - self._metrics_table = [] + self._cct_roots: List[Node] = [] + self._metrics_table: List[Dict[str, Any]] = [] for file_path in os.listdir(self._dir_path): if file_path.split(".")[-1] == "db": @@ -278,7 +278,7 @@ def _parse_context( ): continue - frame = {"type": NODE_TYPE_MAPPING[lexicalType]} + frame: Dict[str, Union[str, int]] = {"type": NODE_TYPE_MAPPING[lexicalType]} if nFlexWords: if lexicalType == 0: @@ -365,7 +365,7 @@ def _read_summary_profile( def _read_cct( self, - ) -> None: + ) -> Optional[GraphFrame]: with open(self._meta_file, "rb") as file: meta_db = file.read() @@ -410,7 +410,7 @@ def _read_cct( if im in table.columns.tolist(): inclusive_metrics.append(im) - for em in (list(self._exclusive_metrics.values()),): + for em in list(self._exclusive_metrics.values()): if em in table.columns.tolist(): exclusive_metrics.append(em) @@ -424,8 +424,9 @@ def _read_cct( print("DATA IMPORTED") return graphframe + return None - def read(self) -> GraphFrame: + def read(self) -> Optional[GraphFrame]: self._read_metric_descriptions() self._read_summary_profile() return self._read_cct() diff --git a/hatchet/readers/literal_reader.py b/hatchet/readers/literal_reader.py index f3e3fe1f..8f2b5f71 100644 --- a/hatchet/readers/literal_reader.py +++ b/hatchet/readers/literal_reader.py @@ -3,7 +3,8 @@ # # SPDX-License-Identifier: MIT -from typing import Any, Dict, List +from typing import Any, Dict, List, cast +from collections.abc import Iterable import pandas as pd @@ -61,7 +62,7 @@ class LiteralReader: (GraphFrame): graphframe containing data from dictionaries """ - def __init__(self, graph_dict: Dict) -> None: + def __init__(self, graph_dict: List[Dict]) -> None: """Read from list of dictionaries. graph_dict (dict): List of dictionaries encoding nodes. 
@@ -156,7 +157,7 @@ def read(self) -> hatchet.graphframe.GraphFrame: graph = Graph(list_roots) # test if nids are already loaded - if -1 in [n._hatchet_nid for n in graph.traverse()]: + if -1 in [n._hatchet_nid for n in cast(Iterable[Node], graph.traverse())]: graph.enumerate_traverse() else: graph.enumerate_depth() diff --git a/hatchet/readers/pyinstrument_reader.py b/hatchet/readers/pyinstrument_reader.py index cc739a76..621cea1b 100644 --- a/hatchet/readers/pyinstrument_reader.py +++ b/hatchet/readers/pyinstrument_reader.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: MIT import json -from typing import Any, Dict +from typing import Any, Dict, List import pandas as pd @@ -17,9 +17,9 @@ class PyinstrumentReader: def __init__(self, filename: str) -> None: self.pyinstrument_json_filename = filename - self.graph_dict = {} - self.list_roots = [] - self.node_dicts = [] + self.graph_dict: Dict[str, Any] = {} + self.list_roots: List[Node] = [] + self.node_dicts: List[Dict[str, Any]] = [] def create_graph(self) -> Graph: def parse_node_literal(child_dict: Dict[str, Any], hparent: Node) -> None: diff --git a/hatchet/readers/spotdb_reader.py b/hatchet/readers/spotdb_reader.py index 5352019e..dccd973a 100644 --- a/hatchet/readers/spotdb_reader.py +++ b/hatchet/readers/spotdb_reader.py @@ -3,7 +3,7 @@ # # SPDX-License-Identifier: MIT -from typing import Any, Dict, Optional, List +from typing import Any, Dict, Optional, List, Set import pandas as pd @@ -50,9 +50,9 @@ def __init__( self.regionprofile = regionprofile self.attr_info = attr_info self.metadata = metadata - self.df_data = [] - self.roots = {} - self.metric_columns = set() + self.df_data: List[Dict[str, Any]] = [] + self.roots: Dict[str, Node] = {} + self.metric_columns: Set[str] = set() self.timer = Timer() @@ -178,7 +178,7 @@ def read(self) -> List[hatchet.graphframe.GraphFrame]: Returns: List of GraphFrames, one for each entry that was found """ - import spotdb + import spotdb # type: ignore[import-not-found] if isinstance(self.db_key, str): db = spotdb.connect(self.db_key) diff --git a/hatchet/readers/tau_reader.py b/hatchet/readers/tau_reader.py index 4d192a0a..8cebab25 100644 --- a/hatchet/readers/tau_reader.py +++ b/hatchet/readers/tau_reader.py @@ -6,7 +6,8 @@ import re import os import glob -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Union, cast +from collections.abc import Iterable import pandas as pd import hatchet.graphframe from hatchet.node import Node @@ -19,13 +20,13 @@ class TAUReader: def __init__(self, dirname: str) -> None: self.dirname = dirname - self.node_dicts = [] - self.callpath_to_node = {} - self.rank_thread_to_data = {} - self.filepath_to_data = {} - self.inc_metrics = [] - self.exc_metrics = [] - self.columns = [] + self.node_dicts: List[Dict[str, Any]] = [] + self.callpath_to_node: Dict[Tuple[str, ...], Node] = {} + # self.rank_thread_to_data = {} + # self.filepath_to_data = {} + self.inc_metrics: List[str] = [] + self.exc_metrics: List[str] = [] + self.columns: List[str] = [] self.multiple_ranks = False self.multiple_threads = False @@ -33,7 +34,7 @@ def create_node_dict( self, node: Node, columns: List[str], - metric_values: Tuple[Any, ...], + metric_values: Union[List[Any], Tuple[Any, ...]], name: str, filename: str, module: str, @@ -59,7 +60,7 @@ def create_node_dict( def create_graph(self) -> List[Node]: def _get_name_file_module( is_parent: bool, node_info: str, symbol: str - ) -> Tuple[str, str, str]: + ) -> List[str]: """This function gets the name, file and 
module information for a node using the corresponding line in the output file.
             Example line: [UNWIND] <file> [@] <name> [{<module>} {<line>}]
@@ -74,18 +75,18 @@ def _get_name_file_module(
             # formats. Example formats are given in comments.
             if symbol == " [@] ":
                 # Check if there is a [@] symbol.
-                node_info = node_info.split(symbol)
+                split_node_info = node_info.split(symbol)
                 # We don't need file and module information if it's a parent node.
                 if not is_parent:
-                    file = node_info[0].split()[1]
-                    if "[{" in node_info[1]:
+                    file = split_node_info[0].split()[1]
+                    if "[{" in split_node_info[1]:
                         # Sometimes we see file and module information inside of [{}]
                         # Example: [UNWIND] <file> [@] <name> [{<module>} {<line>}]
-                        name_and_module = node_info[1].split(" [{")
+                        name_and_module = split_node_info[1].split(" [{")
                         module = name_and_module[1].split()[0].strip("}")
                     else:
                         # Example: [UNWIND] <file> [@] <name> <module>
-                        name_and_module = node_info[1].split()
+                        name_and_module = split_node_info[1].split()
                         module = name_and_module[1]
 
                 # Check if module is in file.
@@ -99,46 +100,46 @@ def _get_name_file_module(
                 name = "[UNWIND] " + name_and_module[0]
             else:
                 # We just need to take name if it is a parent
-                name = "[UNWIND] " + node_info[1].split()[0]
+                name = "[UNWIND] " + split_node_info[1].split()[0]
         elif symbol == " C ":
             # Check if there is a C symbol.
             # "C" symbol means it's a C function.
-            node_info = node_info.split(symbol)
-            name = node_info[0]
+            split_node_info = node_info.split(symbol)
+            name = split_node_info[0]
             # We don't need file and module information if it's a parent node.
             if not is_parent:
-                if "[{" in node_info[1]:
+                if "[{" in split_node_info[1]:
                     # Example: <name> C [{<file>} {<line>}]
-                    node_info = node_info[1].split()
-                    file = node_info[0].strip("}[{")
+                    split_node_info = split_node_info[1].split()
+                    file = split_node_info[0].strip("}[{")
         else:
             if "[{" in node_info:
                 # If there isn't C or [@]
                 # Example: [] [{} {}]
-                node_info = node_info.split(" [{")
-                name = node_info[0]
+                split_node_info = node_info.split(" [{")
+                name = split_node_info[0]
                 # We don't need file and module information if it's a parent node.
                 if not is_parent:
-                    file = node_info[1].split()[0].strip("}{")
+                    file = split_node_info[1].split()[0].strip("}{")
             else:
                 # Example 1: []
                 # Example 2: []
                 # Example 3:
                 name = node_info
-                node_info = node_info.split()
+                split_node_info = node_info.split()
                 # We need to take module information from the first example.
                 # Another example is "[CONTEXT] .TAU application" which contradicts
                 # with the first example. So we check if there is a "/" symbol which
                 # will show the module information in this case.
-                if len(node_info) == 3 and "/" in name:
-                    name = node_info[0] + " " + node_info[1]
+                if len(split_node_info) == 3 and "/" in name:
+                    name = split_node_info[0] + " " + split_node_info[1]
                     # We don't need file and module information if it's a parent node.
                     if not is_parent:
-                        module = node_info[2]
+                        module = split_node_info[2]
 
             return [name, file, module]
 
-        def _get_line_numbers(node_info: str) -> Tuple[str, str]:
-            start_line, end_line = 0, 0
+        def _get_line_numbers(node_info: str) -> List[str]:
+            start_line, end_line = "0", "0"
             # There should be [{}] symbols if there is line number information.
if "[{" in node_info: tmp_module_or_file_line = ( @@ -149,9 +150,9 @@ def _get_line_numbers(node_info: str) -> Tuple[str, str]: if "-" in line_numbers: # Sometimes there is "-" between start line and end line # Example: {341,1}-{396,1} - line_numbers = line_numbers.split("-") - start_line = line_numbers[0].split(",")[0] - end_line = line_numbers[1].split(",")[0] + split_line_numbers = line_numbers.split("-") + start_line = split_line_numbers[0].split(",")[0] + end_line = split_line_numbers[1].split(",")[0] else: if "," in line_numbers: # Sometimes we don't have "-". @@ -160,7 +161,7 @@ def _get_line_numbers(node_info: str) -> Tuple[str, str]: end_line = line_numbers.split(",")[1] return [start_line, end_line] - def _create_parent(child_node: Node, parent_callpath: str) -> None: + def _create_parent(child_node: Node, parent_callpath: Tuple[str, ...]) -> None: """In TAU output, sometimes we see a node as a parent in the callpath before we see it as a leaf node. In this case, we need to create a hatchet node for the parent. @@ -210,7 +211,7 @@ def _construct_column_list(first_rank_filenames: List[str]) -> List[str]: all metric files of a rank as a tuple and only loads the second line (metadata) of these files. """ - columns = [] + columns: List[str] = [] for file_index in range(len(first_rank_filenames)): with open(first_rank_filenames[file_index], "r") as f: # Skip the first line: "192 templated_functions_MULTI_TIME" @@ -252,7 +253,7 @@ def _construct_column_list(first_rank_filenames: List[str]) -> List[str]: # Each tuple stores all the metric files of a rank. # We process one rank at a time. # Example: [(metric1/profile.x.0.0, metric2/profile.x.0.0), ...] - profile_filenames = list(zip(*profile_filenames)) + profile_filenames = list(cast(Iterable[List[str]], zip(*profile_filenames))) # Get column information from the metric files of a rank. self.columns = _construct_column_list(profile_filenames[0]) @@ -281,7 +282,7 @@ def _construct_column_list(first_rank_filenames: List[str]) -> List[str]: root_line = re.match(r"\"(.*)\"\s(.*)\sG", file_data[0][0]) root_name = root_line.group(1).strip(" ") # convert it to a tuple to use it as a key in callpath_to_node dictionary - root_callpath = tuple([root_name]) + root_callpath: Tuple[str, ...] 
= tuple([root_name]) root_values = list(map(int, root_line.group(2).split(" ")[:-1])) # After first profile.0.0.0, only get Excl and Incl metric values @@ -343,7 +344,7 @@ def _construct_column_list(first_rank_filenames: List[str]) -> List[str]: # Example: ".TAU application => foo() => bar()" 31 0 155019 155019 0 GROUP="TAU_SAMPLE|TAU_CALLPATH" callpath_line_regex = re.match(r"\"(.*)\"\s(.*)\sG", line) # callpath: ".TAU application => foo() => bar()" - callpath = [ + callpath: Union[List[str], Tuple[str, ...]] = [ name.strip(" ") for name in callpath_line_regex.group(1).split("=>") ] @@ -435,9 +436,9 @@ def _construct_column_list(first_rank_filenames: List[str]) -> List[str]: # module leaf_name_file_module[2], # start line - leaf_line_numbers[0], + int(leaf_line_numbers[0]), # end line - leaf_line_numbers[1], + int(leaf_line_numbers[1]), rank, thread, ) diff --git a/hatchet/readers/timemory_reader.py b/hatchet/readers/timemory_reader.py index e4ff0d91..cf4ec115 100644 --- a/hatchet/readers/timemory_reader.py +++ b/hatchet/readers/timemory_reader.py @@ -9,7 +9,7 @@ import glob import re from io import TextIOWrapper -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, cast from hatchet.graphframe import GraphFrame from ..node import Node from ..graph import Graph @@ -22,7 +22,7 @@ class TimemoryReader: def __init__( self, - input: Union[str, TextIOWrapper, Dict], + timemory_input: Union[str, TextIOWrapper, Dict], select: Optional[List[str]] = None, **_kwargs, ) -> None: @@ -47,18 +47,20 @@ def __init__( identical name/file/line/etc. info but from different ranks are not combined """ - self.graph_dict = {"timemory": {}} - self.input = input - self.default_metric = None + self.graph_dict: Dict[str, Dict[str, Any]] = {"timemory": {}} + self.input = timemory_input + self.default_metric: Optional[str] = None self.timer = Timer() - self.metric_cols = [] - self.properties = {} + self.metric_cols: List[str] = [] + self.properties: Dict[str, Any] = {} self.include_tid = True self.include_nid = True self.multiple_ranks = False self.multiple_threads = False - self.callpath_to_node_dict = {} # (callpath, rank, thread): - self.callpath_to_node = {} # (callpath): + self.callpath_to_node_dict: Dict[ + Tuple, Dict[str, Any] + ] = {} # (callpath, rank, thread): + self.callpath_to_node: Dict[Tuple[str, ...], Node] = {} # (callpath): # the per_thread and per_rank settings make sure that # squashing doesn't collapse the threads/ranks @@ -271,7 +273,7 @@ def parse_node( _node_data: Dict[str, Any], _hparent: Node, _rank: int, - _parent_callpath: Tuple[str], + _parent_callpath: Tuple[str, ...], ) -> None: """Create callpath_to_node_dict for one node and then call the function recursively on all children. @@ -296,7 +298,7 @@ def parse_node( _prop = self.properties[_metric_name] _frame_attrs, _extra = get_name_line_file(_node_data["node"]["prefix"]) - callpath = _parent_callpath + (_frame_attrs["name"],) + callpath: Tuple[str, ...] = _parent_callpath + (_frame_attrs["name"],) # check if the node already exits. 
_hnode = self.callpath_to_node.get(callpath) @@ -315,7 +317,9 @@ def parse_node( # for the Frame(_keys) effectively circumvent Hatchet's # default behavior of combining similar thread/rank entries _tid_dict = _frame_attrs if self.per_thread else _extra - _rank_dict = _frame_attrs if self.per_rank else _extra + _rank_dict: Dict[str, Union[str, int]] = cast( + Dict[str, Union[str, int]], _frame_attrs if self.per_rank else _extra + ) # handle the rank _rank_dict["rank"] = collapse_ids(_rank, self.per_rank) @@ -324,10 +328,10 @@ def parse_node( self.include_nid = False # extract some relevant data - _tid_dict["thread"] = collapse_ids( - _node_data["node"]["tid"], self.per_thread + _tid_dict["thread"] = cast( + str, collapse_ids(_node_data["node"]["tid"], self.per_thread) ) - _extra["pid"] = collapse_ids(_node_data["node"]["pid"], False) + _extra["pid"] = cast(str, collapse_ids(_node_data["node"]["pid"], False)) _extra["count"] = _node_data["node"]["inclusive"]["entry"]["laps"] # check if there are multiple threads @@ -598,7 +602,7 @@ def read(self) -> GraphFrame: if isinstance(self.input, dict): self.graph_dict = self.input # check if the input is a directory and get '.tree.json' files if true. - elif os.path.isdir(self.input): + elif isinstance(self.input, str) and os.path.isdir(self.input): tree_files = glob.glob(self.input + "/*.tree.json") for file in tree_files: # read all files that end with .tree.json. diff --git a/hatchet/util/colormaps.py b/hatchet/util/colormaps.py index c1036605..fd3b9d49 100644 --- a/hatchet/util/colormaps.py +++ b/hatchet/util/colormaps.py @@ -127,7 +127,7 @@ def get_colors(self, colormap: str, invert_colormap: bool) -> List[str]: self.colors = self.Spectral.copy() else: raise ValueError( - self.colormap + colormap + " is an incorrect colormap. Select one BrBG, PiYg, PRGn," + " PuOr, RdBu, RdGy, RdYlBu, RdYlGn, or Spectral." ) diff --git a/hatchet/util/dot.py b/hatchet/util/dot.py index 8811de48..8f9797fa 100644 --- a/hatchet/util/dot.py +++ b/hatchet/util/dot.py @@ -33,7 +33,7 @@ def trees_to_dot( all_edges = "" # call to_dot for each root in the graph - visited = [] + visited: List[Node] = [] for root in roots: (nodes, edges) = to_dot( root, dataframe, metric, name, rank, thread, threshold, visited @@ -57,20 +57,22 @@ def to_dot( visited: List[Node], ) -> Tuple[str, str]: """Write to graphviz dot format.""" - colormap = matplotlib.cm.Reds + # Tell mypy to ignore Reds here because mpl.cm is + # dynamically generated. 
So, mypy cannot discover that + # Reds exists + colormap = matplotlib.cm.Reds # type: ignore[attr-defined] min_time = dataframe[metric].min() max_time = dataframe[metric].max() def add_nodes_and_edges(hnode: Node) -> Tuple[str, str]: # set dataframe index based on if rank is a part of the index + df_index: Union[Tuple[Node, int, int], Tuple[Node, int], Node] = hnode if "rank" in dataframe.index.names and "thread" in dataframe.index.names: df_index = (hnode, rank, thread) elif "rank" in dataframe.index.names: df_index = (hnode, rank) elif "thread" in dataframe.index.names: df_index = (hnode, thread) - else: - df_index = hnode node_time = dataframe.loc[df_index, metric] node_name = dataframe.loc[df_index, name] diff --git a/hatchet/util/executable.py b/hatchet/util/executable.py index 2e6b42d3..0ff7f2e5 100644 --- a/hatchet/util/executable.py +++ b/hatchet/util/executable.py @@ -14,9 +14,9 @@ def which(executable: str) -> Optional[str]: executable (str): executable to search for """ path = os.environ.get("PATH", "/usr/sbin:/usr/bin:/sbin:/bin") - path = path.split(os.pathsep) + split_path = path.split(os.pathsep) - for directory in path: + for directory in split_path: exe = os.path.join(directory, executable) if os.path.isfile(exe) and os.access(exe, os.X_OK): return exe diff --git a/hatchet/util/profiler.py b/hatchet/util/profiler.py index ac8d4984..b98c5b51 100644 --- a/hatchet/util/profiler.py +++ b/hatchet/util/profiler.py @@ -11,10 +11,7 @@ from datetime import datetime -try: - from StringIO import StringIO # python2 -except ImportError: - from io import StringIO # python3 +from io import StringIO # python3 import pstats diff --git a/hatchet/util/timer.py b/hatchet/util/timer.py index f40f3043..1720d3c1 100644 --- a/hatchet/util/timer.py +++ b/hatchet/util/timer.py @@ -7,15 +7,16 @@ from contextlib import contextmanager from datetime import datetime, timedelta from io import StringIO +from typing import Optional class Timer(object): """Simple phase timer with a context manager.""" def __init__(self) -> None: - self._phase = None - self._start_time = None - self._times = OrderedDict() + self._phase: Optional[str] = None + self._start_time: Optional[datetime] = None + self._times: OrderedDict = OrderedDict() def start_phase(self, phase: str) -> timedelta: now = datetime.now() diff --git a/hatchet/writers/dataframe_writer.py b/hatchet/writers/dataframe_writer.py index 3adc37a0..d75b0aca 100644 --- a/hatchet/writers/dataframe_writer.py +++ b/hatchet/writers/dataframe_writer.py @@ -7,18 +7,7 @@ from hatchet.graphframe import GraphFrame import pandas as pd -from abc import abstractmethod - -# TODO The ABC class was introduced in Python 3.4. 
-# When support for earlier versions is (eventually) dropped, -# this entire "try-except" block can be reduced to: -# from abc import ABC -try: - from abc import ABC -except ImportError: - from abc import ABCMeta - - ABC = ABCMeta("ABC", (object,), {"__slots__": ()}) +from abc import abstractmethod, ABC def _get_node_from_df_iloc(df: pd.DataFrame, ind: int) -> Node: diff --git a/pyproject.toml b/pyproject.toml index a1261f27..12214108 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,14 @@ authors = [ license = "MIT" [tool.mypy] -exclude = "hatchet/tests" +exclude = [ + "hatchet/tests", + "hatchet/vis", + "hatchet/external/roundtrip", + "setup.py", +] +strict_optional = false +disable_error_code = "import-untyped" [tool.ruff] line-length = 88 From 6179835479b7ca0c59142709964be890627b9c10 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Thu, 14 Nov 2024 20:15:14 +0000 Subject: [PATCH 4/5] Adds devcontainer to Hatchet to help with things like formatting --- .devcontainer/Dockerfile | 60 ++++++++++++++ .devcontainer/devcontainer.json | 81 +++++++++++++++++++ .github/dependabot.yml | 12 +++ hatchet/graph.py | 7 +- hatchet/graphframe.py | 6 +- hatchet/node.py | 7 +- hatchet/query/compat.py | 7 +- hatchet/query/object_dialect.py | 6 +- hatchet/query/query.py | 7 +- hatchet/query/string_dialect.py | 90 +++++++++++++++------ hatchet/readers/caliper_native_reader.py | 16 ++-- hatchet/readers/caliper_reader.py | 6 +- hatchet/readers/hpctoolkit_reader_latest.py | 4 +- hatchet/readers/literal_reader.py | 7 +- hatchet/readers/tau_reader.py | 8 +- hatchet/readers/timemory_reader.py | 6 +- 16 files changed, 287 insertions(+), 43 deletions(-) create mode 100644 .devcontainer/Dockerfile create mode 100644 .devcontainer/devcontainer.json create mode 100644 .github/dependabot.yml diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 00000000..bf27765f --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,60 @@ +FROM continuumio/miniconda3:24.9.2-0 + +USER root + +ARG USERNAME=vscode +ARG USER_UID=1000 +ARG USER_GID=1000 +ENV USERNAME=${USERNAME} +ENV USER_UID=${USER_UID} +ENV USER_GID=${USER_GID} + +RUN apt-get update -q \ + && apt-get install -q -y --no-install-recommends \ + build-essential \ + ripgrep \ + pandoc \ + adduser \ + git \ + grep \ + curl \ + wget \ + vim + +RUN conda install -y python=3.9 \ + && pip install --no-cache-dir pipx \ + && pipx reinstall-all + +RUN conda install -c conda-forge gh jupyterlab + +COPY requirements.txt /requirements.txt + +RUN python3 -m pip install -r requirements.txt + +RUN python3 -m pip install --upgrade flake8-pytest-importorskip click==8.0.4 black==24.4.2 flake8==4.0.1 + +RUN groupadd -g ${USER_GID} ${USERNAME} && \ + adduser --disabled-password --uid ${USER_UID} --gid ${USER_GID} --gecos "" ${USERNAME} && \ + echo "${USERNAME} ALL=(ALL) NOPASSWD: ALL" > /etc/sudoers + +USER $USERNAME + +# FROM mcr.microsoft.com/devcontainers/miniconda:1-3 +# +# # Copy environment.yml (if found) to a temp location so we update the environment. Also +# # copy "noop.txt" so the COPY instruction does not fail if no environment.yml exists. 
+# COPY environment.yml* .devcontainer/noop.txt /tmp/conda-tmp/ +# RUN if [ -f "/tmp/conda-tmp/environment.yml" ]; then umask 0002 && /opt/conda/bin/conda env update -n base -f /tmp/conda-tmp/environment.yml; fi \ +# && rm -rf /tmp/conda-tmp +# +# COPY requirements.txt /requirements.txt +# +# RUN conda install -y python=3.9 \ +# && pip install --no-cache-dir pipx \ +# && pipx reinstall-all +# +# +# # [Optional] Uncomment this section to install additional OS packages. +# # RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ +# # && apt-get -y install --no-install-recommends + \ No newline at end of file diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..d85b8169 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,81 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/miniconda +{ + "name": "Hatchet Python 3.9", + "build": { + "context": "..", + "dockerfile": "Dockerfile" + }, + // "features": { + // // "ghcr.io/devcontainers/features/git:1": { + // // "ppa": true, + // // "version": "os-provided" + // // }, + // // "ghcr.io/devcontainers/features/git-lfs:1": { + // // "autoPull": true, + // // "version": "latest" + // // }, + // "ghcr.io/devcontainers/features/github-cli:1": { + // "installDirectlyFromGitHubRelease": true, + // "version": "latest" + // } + // // "ghcr.io/devcontainers/features/python:1": { + // // "installTools": true, + // // "toolsToInstall": "autopep8,yapf,pydocstyle,pycodestyle,bandit,pytest,pylint", + // // "enableShared": true, + // // "installJupyterlab": true, + // // "version": "3.9" + // // }, + // // "ghcr.io/devcontainers-extra/features/act:1": { + // // "version": "latest" + // // }, + // // "ghcr.io/devcontainers-extra/features/curl-apt-get:1": {}, + // // "ghcr.io/devcontainers-extra/features/fd:1": { + // // "version": "latest" + // // }, + // // "ghcr.io/devcontainers-extra/features/fzf:1": { + // // "version": "latest" + // // }, + // // "ghcr.io/devcontainers-extra/features/ripgrep:1": { + // // "version": "latest" + // // }, + // // "ghcr.io/devcontainers-extra/features/wget-apt-get:1": {} + // }, + // Configure tool-specific properties. + "customizations": { + "vscode": { + "extensions": [ + "ms-python.python", + "ms-python.black-formatter", + "ms-python.flake8", + "dbaeumer.vscode-eslint", + "ms-python.debugpy", + "ms-toolsai.jupyter" + ], + "settings": { + "python.defaultInterpreterPath": "/opt/conda/bin/python3", + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnSave": true + }, + "python.formatting.provider": "black", + "python.formatting.blackPath": "/opt/conda/bin/black", + "black-formatter.path": [ + "/opt/conda/bin/black" + ] + } + } + }, + "remoteEnv": { + "PATH": "/opt/conda/bin:${containerEnv:PATH}", + "EDITOR": "/usr/bin/vim" + } + // Use 'postCreateCommand' to run commands after the container is created. + // "postCreateCommand": "echo \"export PATH=$CURRENT_PYTHON_BINDIR:\\$PATH\" >> ~/.bashrc" + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. 
+ // "remoteUser": "root" +} \ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..f33a02cd --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for more information: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +# https://containers.dev/guide/dependabot + +version: 2 +updates: + - package-ecosystem: "devcontainers" + directory: "/" + schedule: + interval: weekly diff --git a/hatchet/graph.py b/hatchet/graph.py index 91524d29..5555ccdc 100644 --- a/hatchet/graph.py +++ b/hatchet/graph.py @@ -3,10 +3,15 @@ # # SPDX-License-Identifier: MIT +import sys from collections import defaultdict -from collections.abc import Iterable from typing import Any, Dict, List, Optional, Set, Tuple, Union +if sys.version_info >= (3, 9): + from collections.abc import Iterable +else: + from typing import Iterable + from .node import Node, traversal_order, node_traversal_order diff --git a/hatchet/graphframe.py b/hatchet/graphframe.py index 5e1ff5e3..ffddcc8f 100644 --- a/hatchet/graphframe.py +++ b/hatchet/graphframe.py @@ -8,10 +8,14 @@ import sys import traceback from collections import defaultdict -from collections.abc import Callable, Iterable from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast from io import TextIOWrapper +if sys.version_info >= (3, 9): + from collections.abc import Callable, Iterable +else: + from typing import Callable, Iterable + import multiprocess as mp import numpy as np import pandas as pd diff --git a/hatchet/node.py b/hatchet/node.py index f1205642..7c76c3c9 100644 --- a/hatchet/node.py +++ b/hatchet/node.py @@ -3,9 +3,14 @@ # # SPDX-License-Identifier: MIT +import sys from functools import total_ordering from typing import Any, Dict, List, Optional, Set, Tuple, Union -from collections.abc import Iterable + +if sys.version_info >= (3, 9): + from collections.abc import Iterable +else: + from typing import Iterable from .frame import Frame diff --git a/hatchet/query/compat.py b/hatchet/query/compat.py index a2a96d6d..8408d20b 100644 --- a/hatchet/query/compat.py +++ b/hatchet/query/compat.py @@ -7,9 +7,14 @@ import sys import warnings -from collections.abc import Callable from typing import List, Optional, Union, cast, TYPE_CHECKING +if sys.version_info >= (3, 9): + from collections.abc import Callable +else: + from typing import Callable + + from ..node import Node from .query import Query from .compound import ( diff --git a/hatchet/query/object_dialect.py b/hatchet/query/object_dialect.py index cd0ed641..9e5e120d 100644 --- a/hatchet/query/object_dialect.py +++ b/hatchet/query/object_dialect.py @@ -13,7 +13,11 @@ import re import sys from typing import Dict, List, Tuple, Union -from collections.abc import Callable, Iterable + +if sys.version_info >= (3, 9): + from collections.abc import Callable, Iterable +else: + from typing import Callable, Iterable from .errors import InvalidQueryPath, InvalidQueryFilter, MultiIndexModeMismatch from ..node import Node diff --git a/hatchet/query/query.py b/hatchet/query/query.py index d3be624e..3eb8e2ea 100644 --- a/hatchet/query/query.py +++ b/hatchet/query/query.py @@ -3,8 +3,13 @@ # # SPDX-License-Identifier: MIT +import sys from typing import List, Tuple, Union -from collections.abc import Callable, 
Iterator + +if sys.version_info >= (3, 9): + from collections.abc import Callable, Iterator +else: + from typing import Callable, Iterator from .errors import InvalidQueryPath diff --git a/hatchet/query/string_dialect.py b/hatchet/query/string_dialect.py index 6b5f5292..c5d90b2f 100644 --- a/hatchet/query/string_dialect.py +++ b/hatchet/query/string_dialect.py @@ -4,10 +4,9 @@ # SPDX-License-Identifier: MIT from numbers import Real -import re +import re # noqa: F401 import sys -from collections.abc import Callable -from typing import Any, Dict, Optional, List, Union, TYPE_CHECKING +from typing import Any, Dict, Optional, List, Union import pandas as pd # noqa: F401 from pandas.api.types import is_numeric_dtype, is_string_dtype # noqa: F401 import numpy as np # noqa: F401 @@ -15,6 +14,11 @@ from textx.exceptions import TextXError import warnings +if sys.version_info >= (3, 9): + from collections.abc import Callable +else: + from typing import Callable + from .errors import InvalidQueryPath, InvalidQueryFilter, RedundantQueryFilterWarning from .query import Query @@ -728,7 +732,9 @@ def _parse_num_eq(self, obj: Any) -> List[Optional[str]]: This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -751,7 +757,9 @@ def _parse_num_eq(self, obj: Any) -> List[Optional[str]]: This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -800,7 +808,9 @@ def _parse_num_eq_multi_idx(self, obj: Any) -> List[Optional[str]]: This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -823,7 +833,9 @@ def _parse_num_eq_multi_idx(self, obj: Any) -> List[Optional[str]]: This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -870,7 +882,9 @@ def _parse_num_lt(self, obj: Any) -> List[Optional[str]]: This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -893,7 +907,9 @@ def _parse_num_lt(self, obj: Any) -> List[Optional[str]]: This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -935,7 +951,9 @@ def _parse_num_lt_multi_idx(self, obj: Any) -> List[Optional[str]]: This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -958,7 +976,9 @@ def _parse_num_lt_multi_idx(self, obj: Any) -> List[Optional[str]]: This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1005,7 +1025,9 @@ def _parse_num_gt(self, obj: Any) -> List[Optional[str]]: This condition will always be true. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1028,7 +1050,9 @@ def _parse_num_gt(self, obj: Any) -> List[Optional[str]]: This condition will always be true. 
The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1070,7 +1094,9 @@ def _parse_num_gt_multi_idx(self, obj: Any) -> List[Optional[str]]: This condition will always be true. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1093,7 +1119,9 @@ def _parse_num_gt_multi_idx(self, obj: Any) -> List[Optional[str]]: This condition will always be true. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1140,7 +1168,9 @@ def _parse_num_lte(self, obj: Any) -> List[Optional[str]]: This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1163,7 +1193,9 @@ def _parse_num_lte(self, obj: Any) -> List[Optional[str]]: This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1205,7 +1237,9 @@ def _parse_num_lte_multi_idx(self, obj: Any) -> List[Optional[str]]: This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1228,7 +1262,9 @@ def _parse_num_lte_multi_idx(self, obj: Any) -> List[Optional[str]]: This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1275,7 +1311,9 @@ def _parse_num_gte(self, obj: Any) -> List[Optional[str]]: This condition will always be true. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1298,7 +1336,9 @@ def _parse_num_gte(self, obj: Any) -> List[Optional[str]]: This condition will always be true. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1340,7 +1380,9 @@ def _parse_num_gte_multi_idx(self, obj: Any) -> List[Optional[str]]: This condition will always be true. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1363,7 +1405,9 @@ def _parse_num_gte_multi_idx(self, obj: Any) -> List[Optional[str]]: This condition will always be true. 
The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ diff --git a/hatchet/readers/caliper_native_reader.py b/hatchet/readers/caliper_native_reader.py index 04ae2e63..5e2bb969 100644 --- a/hatchet/readers/caliper_native_reader.py +++ b/hatchet/readers/caliper_native_reader.py @@ -7,8 +7,13 @@ import pandas as pd import numpy as np import os +import sys from typing import Any, Dict, List, Optional, Tuple, Union, cast -from collections.abc import Callable + +if sys.version_info > (3, 9): + from collections.abc import Callable +else: + from typing import Callable import caliperreader as cr @@ -180,12 +185,9 @@ def read_metrics(self, ctx: str = "path") -> List[pd.DataFrame]: or item in self.string_attributes ): try: - node_dict[item] = ( - cast( - Callable, - self.__cali_type_dict[attr_type], - )(record[item]), - ) + node_dict[item] = cast( + Callable, self.__cali_type_dict[attr_type] + )(record[item]) if item not in self.record_data_cols: self.record_data_cols.append(item) except ValueError as e: diff --git a/hatchet/readers/caliper_reader.py b/hatchet/readers/caliper_reader.py index e8ae25f3..6c616ccd 100644 --- a/hatchet/readers/caliper_reader.py +++ b/hatchet/readers/caliper_reader.py @@ -10,9 +10,13 @@ import os import math from typing import Any, Dict, List, Union, cast -from collections.abc import Callable from io import TextIOWrapper +if sys.version_info >= (3, 9): + from collections.abc import Callable +else: + from typing import Callable + import pandas as pd import numpy as np diff --git a/hatchet/readers/hpctoolkit_reader_latest.py b/hatchet/readers/hpctoolkit_reader_latest.py index d201912b..c76065e2 100644 --- a/hatchet/readers/hpctoolkit_reader_latest.py +++ b/hatchet/readers/hpctoolkit_reader_latest.py @@ -231,7 +231,9 @@ def _store_cct_node( "node": node, "name": ( # f"{frame['type']}: {frame['name']}" - frame["name"] if frame["name"] != 1 else "entry" + frame["name"] + if frame["name"] != 1 + else "entry" ), } diff --git a/hatchet/readers/literal_reader.py b/hatchet/readers/literal_reader.py index 8f2b5f71..76dfba2e 100644 --- a/hatchet/readers/literal_reader.py +++ b/hatchet/readers/literal_reader.py @@ -3,8 +3,13 @@ # # SPDX-License-Identifier: MIT +import sys from typing import Any, Dict, List, cast -from collections.abc import Iterable + +if sys.version_info >= (3, 9): + from collections.abc import Iterable +else: + from typing import Iterable import pandas as pd diff --git a/hatchet/readers/tau_reader.py b/hatchet/readers/tau_reader.py index 8cebab25..4d50cbe4 100644 --- a/hatchet/readers/tau_reader.py +++ b/hatchet/readers/tau_reader.py @@ -6,9 +6,15 @@ import re import os import glob +import sys from typing import Any, Dict, List, Tuple, Union, cast -from collections.abc import Iterable import pandas as pd + +if sys.version_info >= (3, 9): + from collections.abc import Iterable +else: + from typing import Iterable + import hatchet.graphframe from hatchet.node import Node from hatchet.graph import Graph diff --git a/hatchet/readers/timemory_reader.py b/hatchet/readers/timemory_reader.py index cf4ec115..4628a46d 100644 --- a/hatchet/readers/timemory_reader.py +++ b/hatchet/readers/timemory_reader.py @@ -57,9 +57,9 @@ def __init__( self.include_nid = True self.multiple_ranks = False self.multiple_threads = False - self.callpath_to_node_dict: Dict[ - Tuple, Dict[str, Any] - ] = {} # (callpath, rank, thread): + self.callpath_to_node_dict: Dict[Tuple, Dict[str, Any]] = ( + {} + ) # (callpath, 
rank, thread): self.callpath_to_node: Dict[Tuple[str, ...], Node] = {} # (callpath): # the per_thread and per_rank settings make sure that From 9e443736fd3a9370f9b29dab20f7169e29e21be5 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Thu, 14 Nov 2024 21:11:13 +0000 Subject: [PATCH 5/5] Adds type checking (through mypy) to CI --- .devcontainer/Dockerfile | 21 +-------------------- .github/workflows/unit-tests.yaml | 8 +++++++- hatchet/graphframe.py | 2 +- 3 files changed, 9 insertions(+), 22 deletions(-) diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index bf27765f..073251a4 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -31,30 +31,11 @@ COPY requirements.txt /requirements.txt RUN python3 -m pip install -r requirements.txt -RUN python3 -m pip install --upgrade flake8-pytest-importorskip click==8.0.4 black==24.4.2 flake8==4.0.1 +RUN python3 -m pip install --upgrade flake8-pytest-importorskip click==8.0.4 black==24.4.2 flake8==4.0.1 mypy==1.13.0 RUN groupadd -g ${USER_GID} ${USERNAME} && \ adduser --disabled-password --uid ${USER_UID} --gid ${USER_GID} --gecos "" ${USERNAME} && \ echo "${USERNAME} ALL=(ALL) NOPASSWD: ALL" > /etc/sudoers USER $USERNAME - -# FROM mcr.microsoft.com/devcontainers/miniconda:1-3 -# -# # Copy environment.yml (if found) to a temp location so we update the environment. Also -# # copy "noop.txt" so the COPY instruction does not fail if no environment.yml exists. -# COPY environment.yml* .devcontainer/noop.txt /tmp/conda-tmp/ -# RUN if [ -f "/tmp/conda-tmp/environment.yml" ]; then umask 0002 && /opt/conda/bin/conda env update -n base -f /tmp/conda-tmp/environment.yml; fi \ -# && rm -rf /tmp/conda-tmp -# -# COPY requirements.txt /requirements.txt -# -# RUN conda install -y python=3.9 \ -# && pip install --no-cache-dir pipx \ -# && pipx reinstall-all -# -# -# # [Optional] Uncomment this section to install additional OS packages. -# # RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ -# # && apt-get -y install --no-install-recommends \ No newline at end of file diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml index 93d1abfc..412c7101 100644 --- a/.github/workflows/unit-tests.yaml +++ b/.github/workflows/unit-tests.yaml @@ -37,19 +37,25 @@ jobs: python setup.py build_ext --inplace python -m pip list - - name: Update Black + - name: Update Black and mypy if: ${{ matrix.python-version == 3.9 }} run: | pip install flake8-pytest-importorskip pip install --upgrade click==8.0.4 pip install black==24.4.2 pip install flake8==4.0.1 + pip install mypy==1.13.0 - name: Lint and Format Check with Flake8 and Black if: ${{ matrix.python-version == 3.9 }} run: | black --diff --check . flake8 + + - name: Run type checking with mypy + if: ${{ matrix.python-version == 3.9 }} + run: | + mypy hatchet --pretty - name: Check License Headers run: | diff --git a/hatchet/graphframe.py b/hatchet/graphframe.py index ffddcc8f..0b5fb04a 100644 --- a/hatchet/graphframe.py +++ b/hatchet/graphframe.py @@ -1092,7 +1092,7 @@ def tree( if color is False: try: - import IPython + import IPython # type: ignore[import-not-found] shell = IPython.get_ipython().__class__.__name__ except ImportError:
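Taken together, these patches settle on a single pattern for the typing imports — take Callable, Iterable, and Iterator from collections.abc on Python 3.9 and newer, and fall back to the typing module on older interpreters — and then enforce the annotations in CI with mypy. A condensed sketch of that pattern, with a toy function invented here purely to show the gated names being used in annotations, is:

    import sys
    from typing import List

    # On 3.9+ the abstract collection types in collections.abc are subscriptable
    # (PEP 585); older interpreters still need the aliases from typing.
    if sys.version_info >= (3, 9):
        from collections.abc import Callable, Iterable
    else:
        from typing import Callable, Iterable


    def apply_to_all(fn: Callable[[int], int], values: Iterable[int]) -> List[int]:
        """Toy helper: the gated imports work as annotations on either branch."""
        return [fn(v) for v in values]


    if __name__ == "__main__":
        print(apply_to_all(lambda x: x * 2, range(3)))  # [0, 2, 4]

The CI step added above can be mirrored locally with the same invocation, `mypy hatchet --pretty`, which should pick up the [tool.mypy] settings from pyproject.toml automatically.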