diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml index a59b3664..d395f0ca 100644 --- a/.github/workflows/unit-tests.yaml +++ b/.github/workflows/unit-tests.yaml @@ -12,9 +12,9 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11"] include: - - python-version: "3.6" + - python-version: "3.7" os: ubuntu-20.04 steps: @@ -26,7 +26,7 @@ jobs: - name: Install Python packages run: | pip install --upgrade pip - pip install --upgrade numpy pandas pytest otf2 + pip install --upgrade numpy pandas pytest otf2 bokeh datashader - name: Lint and format check with flake8 and black if: ${{ matrix.python-version == 3.9 }} @@ -54,7 +54,7 @@ jobs: - name: Install Python packages run: | pip install --upgrade pip - pip install --upgrade numpy pandas pytest otf2 + pip install --upgrade numpy pandas pytest otf2 bokeh datashader - name: Basic test with pytest run: | diff --git a/.gitignore b/.gitignore index 5499117a..c7fd0484 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ .cache .pytest_cache .ipynb_checkpoints +env \ No newline at end of file diff --git a/docs/examples/vis.ipynb b/docs/examples/vis.ipynb new file mode 100644 index 00000000..bb867b2d --- /dev/null +++ b/docs/examples/vis.ipynb @@ -0,0 +1,90 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "e5841460", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "import sys\n", + "\n", + "sys.path.append(\"../../\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da5ebec2", + "metadata": {}, + "outputs": [], + "source": [ + "import pipit as pp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7070057d", + "metadata": {}, + "outputs": [], + "source": [ + "ping_pong = pp.Trace.from_otf2(\"../../pipit/tests/data/ping-pong-otf2\")\n", + "ping_pong" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2476f1f3-a6cf-49f2-8e3e-0e004f088504", + "metadata": {}, + "outputs": [], + "source": [ + "ping_pong.plot_comm_matrix()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4aff933", + "metadata": {}, + "outputs": [], + "source": [ + "ping_pong.plot_message_histogram()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7318b31-b9bf-4a28-ac44-b75614d2ddf4", + "metadata": {}, + "outputs": [], + "source": [ + "ping_pong.plot_timeline()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pipit/readers/otf2_reader.py b/pipit/readers/otf2_reader.py index b1fae889..b685c3a7 100644 --- a/pipit/readers/otf2_reader.py +++ b/pipit/readers/otf2_reader.py @@ -316,9 +316,9 @@ def events_reader(self, rank_size): if value is not None and key != "time": # uses field_to_val to convert all data types # and ensure that there are no pickling errors - attributes_dict[ - self.field_to_val(key) - ] = self.handle_data(value) + attributes_dict[self.field_to_val(key)] = ( + self.handle_data(value) + ) event_attributes.append(attributes_dict) else: # nan attributes for 
leave rows
diff --git a/pipit/readers/projections_reader.py b/pipit/readers/projections_reader.py
index 38153e8f..dd59ffe2 100644
--- a/pipit/readers/projections_reader.py
+++ b/pipit/readers/projections_reader.py
@@ -513,10 +513,10 @@ def _read_log_file(self, rank_size) -> pd.DataFrame:
                 pe = int(line_arr[5])
                 msglen = int(line_arr[6])
                 send_time = int(line_arr[7]) * 1000
-                numPEs = int(line_arr[8])
-                destPEs = []
-                for i in (0, numPEs):
-                    destPEs.append(int(line_arr[9 + i]))
+                num_procs = int(line_arr[8])
+                dest_procs = []
+                for i in range(num_procs):
+                    dest_procs.append(int(line_arr[9 + i]))

                 details = {
                     "From PE": pe,
@@ -525,7 +525,7 @@ def _read_log_file(self, rank_size) -> pd.DataFrame:
                     "Message Length": msglen,
                     "Event ID": event,
                     "Send Time": send_time,
-                    "Destinatopn PEs": destPEs,
+                    "Destination PEs": dest_procs,
                 }

                 _add_to_trace_dict(
@@ -534,7 +534,7 @@ def _read_log_file(self, rank_size) -> pd.DataFrame:
                     "Instant",
                     time,
                     pe_num,
-                    "To " + str(numPEs) + "processors",
+                    "To " + str(num_procs) + " processors",
                 )

             # Processing of chare (i.e. execution) ?
diff --git a/pipit/trace.py b/pipit/trace.py
index 2b4f111a..ac4b493f 100644
--- a/pipit/trace.py
+++ b/pipit/trace.py
@@ -291,6 +291,89 @@ def _match_caller_callee(self):
             {"_depth": "category", "_parent": "category"}
         )

+    def _match_messages(self):
+        """
+        Matches corresponding MpiSend/MpiRecv and MpiIsend/MpiIrecv instant events.
+        """
+        # Skip if we've already matched, i.e., the _matching_event column
+        # exists and is populated for the MPI send/receive instant events
+        if (
+            "_matching_event" in self.events.columns
+            and not self.events[
+                self.events["Name"].isin(["MpiSend", "MpiRecv", "MpiIsend", "MpiIrecv"])
+            ]["_matching_event"]
+            .isnull()
+            .all()
+        ):
+            return
+
+        if "_matching_event" not in self.events.columns:
+            self.events["_matching_event"] = None
+
+        if "_matching_timestamp" not in self.events.columns:
+            self.events["_matching_timestamp"] = np.nan
+
+        matching_events = list(self.events["_matching_event"])
+        matching_times = list(self.events["_matching_timestamp"])
+
+        mpi_events = self.events[
+            self.events["Name"].isin(["MpiSend", "MpiRecv", "MpiIsend", "MpiIrecv"])
+        ]
+
+        # One FIFO queue of pending sends per receiving process
+        queue = [[] for _ in range(len(self.events["Process"].unique()))]
+
+        df_indices = list(mpi_events.index)
+        timestamps = list(mpi_events["Timestamp (ns)"])
+        names = list(mpi_events["Name"])
+        attrs = list(mpi_events["Attributes"])
+        processes = list(mpi_events["Process"])
+
+        # Iterate through all events
+        for i in range(len(mpi_events)):
+            curr_df_index = df_indices[i]
+            curr_timestamp = timestamps[i]
+            curr_name = names[i]
+            curr_attrs = attrs[i]
+            curr_process = processes[i]
+
+            if curr_name == "MpiSend" or curr_name == "MpiIsend":
+                # Add current dataframe index, timestamp, and process
+                # to the receiver's queue
+                if "receiver" in curr_attrs:
+                    queue[curr_attrs["receiver"]].append(
+                        (curr_df_index, curr_timestamp, curr_name, curr_process)
+                    )
+            elif curr_name == "MpiRecv" or curr_name == "MpiIrecv":
+                if "sender" in curr_attrs:
+                    send_process = None
+                    j = 0
+
+                    # we want to iterate through the queue in order
+                    # until we find the corresponding "send" event
+                    while send_process != curr_attrs["sender"] and j < len(
+                        queue[curr_process]
+                    ):
+                        send_df_index, send_timestamp, send_name, send_process = queue[
+                            curr_process
+                        ][j]
+                        j += 1
+
+                    if send_process == curr_attrs["sender"] and j <= len(
+                        queue[curr_process]
+                    ):
+                        # remove matched event from queue
+                        del queue[curr_process][j - 1]
+
+                        # Fill in the lists with the matching values if event found
+                        matching_events[send_df_index] =
curr_df_index + matching_events[curr_df_index] = send_df_index + + matching_times[send_df_index] = curr_timestamp + matching_times[curr_df_index] = send_timestamp + + self.events["_matching_event"] = matching_events + self.events["_matching_timestamp"] = matching_times + + self.events = self.events.astype({"_matching_event": "Int32"}) + def calc_inc_metrics(self, columns=None): # if no columns are specified by the user, then we calculate # inclusive metrics for all the numeric columns in the trace @@ -548,7 +631,7 @@ def flat_profile( self.events.loc[self.events["Event Type"] == "Enter"] .groupby([groupby_column, "Process"], observed=True)[metrics] .sum() - .groupby(groupby_column) + .groupby(groupby_column, observed=True) .mean() ) @@ -597,7 +680,7 @@ def load_imbalance(self, metric="time.exc", num_processes=1): return imbalance_df - def idle_time(self, idle_functions=["Idle"], MPI_events=False): + def idle_time(self, idle_functions=["Idle"], mpi_events=False): # dict for creating a new dataframe idle_times = {"Process": [], "Idle Time": []} @@ -605,19 +688,19 @@ def idle_time(self, idle_functions=["Idle"], MPI_events=False): idle_times["Process"].append(process) idle_times["Idle Time"].append( self._calculate_idle_time_for_process( - process, idle_functions, MPI_events + process, idle_functions, mpi_events ) ) return pd.DataFrame(idle_times) def _calculate_idle_time_for_process( - self, process, idle_functions=["Idle"], MPI_events=False + self, process, idle_functions=["Idle"], mpi_events=False ): # calculate inclusive metrics if "time.inc" not in self.events.columns: self.calc_inc_metrics() - if MPI_events: + if mpi_events: idle_functions += ["MPI_Wait", "MPI_Waitall", "MPI_Recv"] # filter the dataframe to include only 'Enter' events within the specified # process with the specified function names @@ -861,3 +944,210 @@ def detect_pattern( patterns.append(match_original) return patterns + + def plot_comm_matrix(self, output="size", *args, **kwargs): + from .vis import plot_comm_matrix + + # Generate the data + data = self.comm_matrix(output=output) + + # Return the Bokeh plot + return plot_comm_matrix(data, output=output, *args, **kwargs) + + def plot_message_histogram(self, bins=20, *args, **kwargs): + from .vis import plot_message_histogram + + # Generate the data + data = self.message_histogram(bins=bins) + + # Return the Bokeh plot + return plot_message_histogram(data, *args, **kwargs) + + def plot_comm_over_time(self, output="size", message_type="send", *args, **kwargs): + from .vis import plot_comm_over_time + + # Generate the data + data = self.comm_over_time( + output=output, message_type=message_type, *args, **kwargs + ) + + # Return the Bokeh plot + return plot_comm_over_time(data, message_type=message_type, output=output) + + def plot_comm_by_process(self, output="size", *args, **kwargs): + from .vis import plot_comm_by_process + + # Generate the data + data = self.comm_by_process(output=output) + + # Return the Bokeh plot + return plot_comm_by_process(data, output=output, *args, **kwargs) + + def plot_timeline(self, *args, **kwargs): + from .vis import plot_timeline + + # Return the Bokeh plot + return plot_timeline(self, *args, **kwargs) + + def critical_path_analysis(self): + self._match_events() + + instant_recv_events = ["MpiRecv", "MpiIrecv"] + recv_events = ["MPI_Recv", "MPI_Irecv"] + send_events = ["MPI_Send", "MPI_Isend"] + collective_functions = [ + "MPI_Allgather", + "MPI_Allgatherv", + "MPI_Allreduce", + "MPI_Alltoall", + "MPI_Alltoallv", + "MPI_Alltoallw", + 
"MPI_Barrier", + "MPI_Bcast", + "MPI_Gather", + "MPI_Gatherv", + "MPI_Iallgather", + "MPI_Iallreduce", + "MPI_Ibarrier", + "MPI_Ibcast", + "MPI_Igather", + "MPI_Igatherv", + "MPI_Ireduce", + "MPI_Iscatter", + "MPI_Iscatterv", + "MPI_Reduce", + "MPI_Scatter", + "MPI_Scatterv", + "MPI_Exscan", + "MPI_Op_create", + "MPI_Op_free", + "MPI_Reduce_local", + "MPI_Reduce_scatter", + "MPI_Scan", + "MPI_File_iread_at_all", + "MPI_File_iwrite_at_all", + "MPI_File_iread_all", + "MPI_File_iwrite_all", + "MPI_File_read_all_begin", + "MPI_File_write_all_begin", + "MPI_File_write_all_end", + "MPI_File_close", + ] + + critical_paths = [] + critical_path = [] + last_event = None + leave_events = self.events[(self.events["Event Type"] == "Leave")] + num_of_processes = leave_events["Process"].astype(int).max() + + if "MPI_Finalize" in self.events["Name"].values: + last_event = self.events[ + (self.events["Event Type"] == "Leave") + & (self.events["Name"] == "MPI_Finalize") + ].iloc[-1] + else: + last_event = self.events[self.events["Event Type"] == "Leave"].iloc[-1] + last_name = last_event["Name"] + last_process = last_event["Process"] + last_timestamp = last_event["Timestamp (ns)"] + critical_path.append(last_event) + after_recieve = False + after_collective = False + # Main loop to trace back + while True: + # Filter for events from the same process before the last timestamp + candidate_events = leave_events[ + (leave_events["Process"] == last_event["Process"]) + & (leave_events["Timestamp (ns)"] < last_event["Timestamp (ns)"]) + ] + + # obtain the latest function after the collective function call. + # we basically do the something similar to starting with MPI_Finalize + # but this time we use a different function. + if after_collective: + candidate_name = candidate_events.iloc[-1]["Name"] + candidate_events = leave_events[ + (leave_events["Name"] == candidate_name) + & (leave_events["Timestamp (ns)"] < last_event["Timestamp (ns)"]) + ] + + # No more events to trace back from. + if candidate_events.empty: + break + + # Select the last event from the candidates if + # we the previous event is not a receive. + if not after_recieve: + last_event = candidate_events.iloc[-1] + critical_path.append(last_event) + + # If we continue after a receive function. + if last_event["Name"] in recv_events: + # Get the corresponding instant event for the recv function. + last_instant_event = self.events[ + (self.events["Event Type"] == "Instant") + & (self.events["Timestamp (ns)"] < last_timestamp) + & (self.events["Process"] == last_process) + & (self.events["Name"].isin(instant_recv_events)) + ] + # Sometimes recv function have some instant events which + # do not include the sender information. We ignore them. + if last_instant_event.empty: + continue + else: + last_instant_event = last_instant_event.iloc[-1] + + # Get the corresponding send event. + last_event = self.events[ + (self.events["Timestamp (ns)"] < last_timestamp) + & ( + self.events["Process"] + == last_instant_event["Attributes"]["sender"] + ) + & ( + self.events["Name"].isin(send_events) + & (self.events["Event Type"] == "Leave") + ) + ].iloc[-1] + last_timestamp = last_event["Timestamp (ns)"] + last_process = last_event["Process"] + + after_receive = True + after_collective = False + critical_path.append(last_event) + pass + + # Restart the detection after a collective function. 
+            if last_event["Name"] in collective_functions:
+                critical_paths.append(critical_path)
+                critical_path = []
+                after_collective = True
+                after_receive = False
+
+            start_times = []
+            check_if_done = leave_events[leave_events["Name"] == last_event["Name"]]
+            for start_time in check_if_done.iloc[0 : num_of_processes + 1][
+                "Timestamp (ns)"
+            ]:
+                start_times.append(start_time)
+
+            # Exit if we have traced back to the beginning
+            leave_events = leave_events.reset_index(drop=True)
+            if (
+                leave_events[
+                    (leave_events["Timestamp (ns)"] == last_event["Timestamp (ns)"])
+                ].index
+                <= num_of_processes + 1
+            ):
+                if (
+                    last_event["Name"] == leave_events.iloc[0]["Name"]
+                    and last_event["Timestamp (ns)"] in start_times
+                ):
+                    critical_paths.append(critical_path)
+                    break
+
+        critical_dfs = []
+        for critical_path in critical_paths:
+            if len(critical_path) > 1:
+                critical_dfs.append(pd.DataFrame(critical_path))
+        return critical_dfs
diff --git a/pipit/util/cct.py b/pipit/util/cct.py
index 6557a588..b4992271 100644
--- a/pipit/util/cct.py
+++ b/pipit/util/cct.py
@@ -86,9 +86,11 @@ def create_cct(events):
             # add node as root or child of its
             # parent depending on current depth
-            graph.add_root(
-                curr_node
-            ) if curr_depth == 0 else parent_node.add_child(curr_node)
+            if curr_depth == 0:
+                graph.add_root(curr_node)
+            else:
+                parent_node.add_child(curr_node)

             # Update nodes stack, column, and current depth
             nodes_stack.append(curr_node)
diff --git a/pipit/util/config.py b/pipit/util/config.py
index f71bb53f..c3f77882 100644
--- a/pipit/util/config.py
+++ b/pipit/util/config.py
@@ -83,6 +83,23 @@ def url_validator(key, value):
     )

+# Validator to check that the theme is valid YAML
+def theme_validator(key, value):
+    import yaml
+
+    try:
+        yaml.safe_load(value)
+    except yaml.YAMLError:
+        raise ValueError(
+            (
+                'Error loading configuration: The Value "{}" for Configuration '
+                + '"{}" must be valid YAML'
+            ).format(value, key)
+        )
+    else:
+        return True
+
+
 registered_options = {
     "log_level": {
         "default": "INFO",
@@ -92,6 +109,34 @@ registered_options = {
         "default": "http://localhost:8888",
         "validator": url_validator,
     },
+    "theme": {
+        "default": """
+            attrs:
+                Plot:
+                    height: 350
+                    width: 700
+                    background_fill_color: "#fafafa"
+                Axis:
+                    axis_label_text_font_style: "bold"
+                    minor_tick_line_color: null
+                Toolbar:
+                    autohide: true
+                    logo: null
+                HoverTool:
+                    point_policy: "follow_mouse"
+                Legend:
+                    label_text_font_size: "8.5pt"
+                    spacing: 6
+                    border_line_color: null
+                    glyph_width: 16
+                    glyph_height: 16
+                Scatter:
+                    size: 9
+                DataRange1d:
+                    range_padding: 0.05
+        """,
+        "validator": theme_validator,
+    },
 }

 global_config = {key: registered_options[key]["default"] for key in registered_options}
diff --git a/pipit/vis/__init__.py b/pipit/vis/__init__.py
new file mode 100644
index 00000000..03f443fc
--- /dev/null
+++ b/pipit/vis/__init__.py
@@ -0,0 +1,7 @@
+from .core import (
+    plot_comm_matrix,
+    plot_message_histogram,
+    plot_comm_over_time,
+    plot_comm_by_process,
+)  # noqa: F401
+from .timeline import plot_timeline  # noqa: F401
diff --git a/pipit/vis/core.py b/pipit/vis/core.py
new file mode 100644
index 00000000..5c37665a
--- /dev/null
+++ b/pipit/vis/core.py
@@ -0,0 +1,289 @@
+import numpy as np
+import pandas as pd
+from bokeh.models import (
+    BasicTicker,
+    ColorBar,
+    HoverTool,
+    LinearColorMapper,
+    LogColorMapper,
+    NumeralTickFormatter,
+)
+from bokeh.plotting import figure
+from bokeh.transform import dodge
+
+from .util import (
+    clamp,
+    get_process_ticker,
get_size_hover_formatter, + get_size_tick_formatter, + show, + get_time_tick_formatter, + get_time_hover_formatter, +) + + +def plot_comm_matrix( + data, output="size", cmap="log", palette="Viridis256", return_fig=False +): + """Plots the trace's communication matrix. + + Args: + data (numpy.ndarray): a 2D numpy array of shape (N, N) containing the + communication matrix between N processes. + output (str, optional): Specifies whether the matrix contains "size" + or "count" values. Defaults to "size". + cmap (str, optional): Specifies the color mapping. Options are "log", + "linear", and "any". Defaults to "log". + palette (str, optional): Name of Bokeh color palette to use. Defaults to + "Viridis256". + return_fig (bool, optional): Specifies whether to return the Bokeh figure + object. Defaults to False, which displays the result and returns nothing. + + Returns: + Bokeh figure object if return_fig, None otherwise + """ + nranks = data.shape[0] + + # Define color mapper + if cmap == "linear": + color_mapper = LinearColorMapper(palette=palette, low=0, high=np.amax(data)) + elif cmap == "log": + color_mapper = LogColorMapper( + palette=palette, low=max(np.amin(data), 1), high=np.amax(data) + ) + elif cmap == "any": + color_mapper = LinearColorMapper(palette=palette, low=1, high=1) + + # Create bokeh plot + p = figure( + x_axis_label="Receiver", + y_axis_label="Sender", + x_range=(-0.5, nranks - 0.5), + y_range=(nranks - 0.5, -0.5), + x_axis_location="above", + tools="hover,pan,reset,wheel_zoom,save", + width=90 + clamp(nranks * 30, 200, 500), + height=10 + clamp(nranks * 30, 200, 500), + toolbar_location="below", + ) + + # Add glyphs and layouts + p.image( + image=[np.flipud(data)], + x=-0.5, + y=-0.5, + dw=nranks, + dh=nranks, + color_mapper=color_mapper, + origin="top_left", + ) + + color_bar = ColorBar( + color_mapper=color_mapper, + formatter=( + get_size_tick_formatter(ignore_range=cmap == "log") + if output == "size" + else NumeralTickFormatter() + ), + width=15, + ) + p.add_layout(color_bar, "right") + + # Customize plot + p.axis.ticker = get_process_ticker(nranks=nranks) + p.grid.visible = False + + # Configure hover + hover = p.select(HoverTool) + hover.tooltips = [ + ("Sender", "$y{0.}"), + ("Receiver", "$x{0.}"), + ("Count", "@image") if output == "count" else ("Volume", "@image{custom}"), + ] + hover.formatters = {"@image": get_size_hover_formatter()} + + # Return plot + return show(p, return_fig=return_fig) + + +def plot_message_histogram( + data, + return_fig=False, +): + """Plots the trace's message size histogram. + + Args: + data (hist, edges): Histogram and edges + return_fig (bool, optional): Specifies whether to return the Bokeh figure + object. Defaults to False, which displays the result and returns nothing. 
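+
+    Example (mirrors the Trace.plot_message_histogram wrapper in this PR;
+        `trace` is assumed to be a pipit Trace):
+        hist, edges = trace.message_histogram(bins=20)
+        plot_message_histogram((hist, edges))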
+
+    Returns:
+        Bokeh figure object if return_fig, None otherwise
+    """
+    hist, edges = data
+
+    # Create bokeh plot
+    p = figure(
+        x_axis_label="Message size",
+        y_axis_label="Number of messages",
+        tools="hover,save",
+    )
+    p.y_range.start = 0
+
+    # Add glyphs and layouts
+    p.quad(top=hist, bottom=0, left=edges[:-1], right=edges[1:])
+
+    # Customize plot
+    p.xaxis.formatter = get_size_tick_formatter()
+    p.yaxis.formatter = NumeralTickFormatter()
+    p.xgrid.visible = False
+
+    # Configure hover
+    hover = p.select(HoverTool)
+    hover.tooltips = [
+        ("Bin", "@left{custom} - @right{custom}"),
+        ("Count", "@top"),
+    ]
+    hover.formatters = {
+        "@left": get_size_hover_formatter(),
+        "@right": get_size_hover_formatter(),
+    }
+
+    # Return plot
+    return show(p, return_fig=return_fig)
+
+
+def plot_comm_over_time(
+    data: tuple, output: str, message_type: str, return_fig: bool = False
+):
+    """Plots the trace's communication over time.
+
+    Args:
+        data (hist, edges): Histogram and edges
+        output (str): Specifies whether the data contains "size" or "count" values.
+        message_type (str): Specifies whether the messages are "send" or "receive"
+            events.
+        return_fig (bool, optional): Specifies whether to return the Bokeh figure
+            object. Defaults to False, which displays the result and returns nothing.
+
+    Returns:
+        Bokeh figure object if return_fig, None otherwise
+    """
+
+    hist, edges = data
+    is_size = output == "size"
+
+    p = figure(
+        x_axis_label="Time",
+        y_axis_label="Total volume sent" if is_size else "Number of messages",
+        tools="hover,save",
+    )
+    p.y_range.start = 0
+    p.xaxis.formatter = get_time_tick_formatter()
+    p.yaxis.formatter = (
+        get_size_tick_formatter() if is_size else NumeralTickFormatter()
+    )
+
+    p.quad(top=hist, bottom=0, left=edges[:-1], right=edges[1:], line_color="white")
+
+    hover = p.select(HoverTool)
+    hover.tooltips = (
+        [
+            ("Bin", "@left{custom} - @right{custom}"),
+            ("Total volume sent", "@top{custom}"),
+        ]
+        if is_size
+        else [
+            ("Bin", "@left{custom} - @right{custom}"),
+            ("Number of messages", "@top"),
+        ]
+    )
+    hover.formatters = {
+        "@left": get_time_hover_formatter(),
+        "@right": get_time_hover_formatter(),
+        "@top": get_size_hover_formatter(),
+    }
+
+    return show(p, return_fig=return_fig)
+
+
+def plot_comm_by_process(
+    data: pd.DataFrame,
+    output: str,
+    return_fig: bool = False,
+    width: int = None,
+    height: int = 600,
+):
+    """
+    Plots the trace's communication by process.
+
+    Args:
+        data (pd.DataFrame): DataFrame containing the communication data.
+        output (str): Specifies whether the data contains "size" or "count" values.
+        return_fig (bool, optional): Specifies whether to return the Bokeh figure
+            object. Defaults to False, which displays the result and returns nothing.
+        width: The width of the plot. Default is None, which makes the plot full
+            width.
+        height: The height of the plot. Default is 600.
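+
+    Example (mirrors the Trace.plot_comm_by_process wrapper in this PR;
+        `trace` is assumed to be a pipit Trace):
+        data = trace.comm_by_process(output="size")
+        plot_comm_by_process(data, output="size")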
+
+    Returns:
+        Bokeh figure object if return_fig, None otherwise
+    """
+    data = data.reset_index()
+    is_size = output == "size"
+
+    p = figure(
+        x_range=(-0.5, len(data) - 0.5),
+        y_axis_label="Volume",
+        x_axis_label="Process",
+        tools="hover,save",
+        width=width,
+        sizing_mode="fixed" if width is not None else "stretch_width",
+        height=height,
+    )
+    p.y_range.start = 0
+    p.y_range.range_padding = 0.5
+
+    p.xgrid.visible = False
+    p.yaxis.formatter = get_size_tick_formatter()
+    p.xaxis.ticker = BasicTicker(
+        base=2,
+        desired_num_ticks=min(len(data), 16),
+        min_interval=1,
+        num_minor_ticks=0,
+    )
+
+    p.vbar(
+        x=dodge("Process", -0.1667, range=p.x_range),
+        top="Sent",
+        width=0.2,
+        source=data,
+        color="#1f77b4",
+        legend_label="Total sent",
+    )
+    p.vbar(
+        x=dodge("Process", 0.1667, range=p.x_range),
+        top="Received",
+        width=0.2,
+        source=data,
+        color="#d62728",
+        legend_label="Total received",
+    )
+    p.add_layout(p.legend[0], "below")
+    p.legend.orientation = "horizontal"
+    p.legend.location = "center"
+
+    hover = p.select(HoverTool)
+    hover.tooltips = [
+        ("Process", "@Process"),
+        ("Sent", "@Sent{custom}" if is_size else "@Sent"),
+        ("Received", "@Received{custom}" if is_size else "@Received"),
+    ]
+    hover.formatters = {
+        "@Sent": get_size_hover_formatter(),
+        "@Received": get_size_hover_formatter(),
+    }
+
+    return show(p, return_fig=return_fig)
diff --git a/pipit/vis/timeline.py b/pipit/vis/timeline.py
new file mode 100644
index 00000000..0e4ece5e
--- /dev/null
+++ b/pipit/vis/timeline.py
@@ -0,0 +1,454 @@
+import numpy as np
+import pandas as pd
+from bokeh.models import (
+    Arrow,
+    ColumnDataSource,
+    CustomJS,
+    CustomJSTickFormatter,
+    FixedTicker,
+    Grid,
+    HoverTool,
+    OpenHead,
+    WheelZoomTool,
+)
+from bokeh.events import RangesUpdate, Tap
+from bokeh.plotting import figure
+from bokeh.transform import dodge
+
+import pipit as pp
+from pipit.vis.util import (
+    factorize_tuples,
+    get_factor_cmap,
+    get_html_tooltips,
+    get_time_hover_formatter,
+    get_time_tick_formatter,
+    show,
+    trimmed,
+)
+
+
+def prepare_data(trace: pp.Trace, show_depth: bool, instant_events: bool):
+    """Prepare data for plotting the timeline."""
+    # Generate necessary metrics
+    trace.calc_exc_metrics(["Timestamp (ns)"])
+    trace._match_events()
+    trace._match_caller_callee()
+    trace._match_messages()
+
+    # Prepare data for plotting
+    events = (
+        trace.events[trace.events["Event Type"].isin(["Enter", "Instant"])]
+        .sort_values(by="time.inc", ascending=False)
+        .copy(deep=False)
+    )
+
+    # Determine y-coordinates from process and depth
+    y_tuples = (
+        list(zip(events["Process"], events["_depth"]))
+        if show_depth
+        else list(zip(events["Process"]))
+    )
+
+    codes, y_tuples = factorize_tuples(y_tuples)
+    events["y"] = codes
+    num_ys = len(y_tuples)
+
+    events["_depth"] = events["_depth"].astype(float).fillna("")
+    events["name_trimmed"] = trimmed(events["Name"])
+    events["_matching_event"] = events["_matching_event"].fillna(-1)
+
+    # Only select a subset of columns for plotting
+    events = events[
+        [
+            "Timestamp (ns)",
+            "_matching_timestamp",
+            "_matching_event",
+            "y",
+            "Name",
+            "time.inc",
+            "Process",
+            "time.exc",
+            "name_trimmed",
+            "Event Type",
+        ]
+    ]
+
+    # Short labels used to mark instant events on the timeline
+    events["first_letter"] = ""
+    events.loc[events["Name"] == "MpiSend", "first_letter"] = "S"
+    events.loc[events["Name"] == "MpiRecv", "first_letter"] = "R"
+    events.loc[events["Name"] == "MpiIsend", "first_letter"] = "IS"
+    events.loc[events["Name"] == "MpiIrecv", "first_letter"] = "IR"
+    events.loc[events["Name"] == "MpiIrecvRequest", "first_letter"] = "IRR"
+    events.loc[events["Name"] == "MpiIsendComplete", "first_letter"] = "ISC"
+    events.loc[events["Name"] == "MpiCollectiveBegin", "first_letter"] = "CB"
+    events.loc[events["Name"] == "MpiCollectiveEnd", "first_letter"] = "CE"
+
+    return events, y_tuples, num_ys
+
+
+def update_cds(
+    x0: float,
+    x1: float,
+    events: pd.DataFrame,
+    instant_events: bool,
+    hbar_source: ColumnDataSource,
+    scatter_source: ColumnDataSource,
+) -> None:
+    """
+    Callback function that updates the two data sources (hbar_source and
+    scatter_source) based on the new range.
+
+    Called when the user zooms or pans the timeline (and once initially).
+    """
+    x0 = x0 - (x1 - x0) * 0.25
+    x1 = x1 + (x1 - x0) * 0.25
+
+    # Remove events that are out of bounds
+    in_bounds = events[
+        (
+            (events["Event Type"] == "Instant")
+            & (events["Timestamp (ns)"] > x0)
+            & (events["Timestamp (ns)"] < x1)
+        )
+        | (
+            (events["Event Type"] == "Enter")
+            & (events["_matching_timestamp"] > x0)
+            & (events["Timestamp (ns)"] < x1)
+        )
+    ].copy(deep=False)
+
+    # Update hbar_source with the in-bounds functions
+    # (already sorted largest-first by prepare_data)
+    func = in_bounds[in_bounds["Event Type"] == "Enter"]
+    hbar_source.data = func
+
+    # Update scatter_source to keep sampled instant events
+    if instant_events:
+        inst = in_bounds[in_bounds["Event Type"] == "Instant"].copy(deep=False)
+
+        if len(inst) > 500:
+            inst["bin"] = pd.cut(x=inst["Timestamp (ns)"], bins=1000, labels=False)
+
+            grouped = inst.groupby(["bin", "y"])
+            samples = grouped.first().reset_index()
+            samples = samples[~samples["Timestamp (ns)"].isna()]
+
+            scatter_source.data = samples
+        else:
+            scatter_source.data = inst
+
+
+def tap_callback(
+    event: Tap,
+    events: pd.DataFrame,
+    trace: pp.Trace,
+    show_depth: bool,
+    p: figure,
+) -> None:
+    """
+    Callback function that adds an MPI message arrow when the user clicks
+    on a send or receive event.
+    """
+    x = event.x
+    y = event.y
+
+    candidates = events[
+        (events["Event Type"] == "Instant")
+        & (events["Name"].isin(["MpiSend", "MpiRecv", "MpiIsend", "MpiIrecv"]))
+        & (events["y"] == round(y))
+    ]
+
+    # Pick the candidate event closest to the click position
+    dx = candidates["Timestamp (ns)"] - x
+    distance = pd.Series(dx * dx)
+
+    selected = candidates.iloc[distance.argsort().values]
+
+    if len(selected) >= 1:
+        selected = selected.iloc[0]
+
+        match = trace._get_matching_p2p_event(selected.name)
+        send = (
+            selected
+            if selected["Name"] in ["MpiSend", "MpiIsend"]
+            else events.loc[match]
+        )
+        recv = (
+            selected
+            if selected["Name"] in ["MpiRecv", "MpiIrecv"]
+            else events.loc[match]
+        )
+
+        arrow = Arrow(
+            end=OpenHead(line_color="#28282B", line_width=1.5, size=8),
+            line_color="#28282B",
+            line_width=1.5,
+            x_start=send["Timestamp (ns)"],
+            y_start=send["y"] - 0.2 if show_depth else send["y"],
+            x_end=recv["Timestamp (ns)"],
+            y_end=recv["y"] - 0.2 if show_depth else recv["y"],
+            level="overlay",
+        )
+        p.add_layout(arrow)
+
+
+def plot_timeline(
+    trace: pp.Trace,
+    show_depth: bool = False,
+    instant_events: bool = False,
+    critical_path: bool = False,
+    messages: str = "click",
+    x_start: float = None,
+    x_end: float = None,
+    width: int = None,
+    height: int = None,
+    legend_nrows: int = None,
+):
+    """
+    Displays the events of a trace on a timeline.
+
+    Instant events are drawn as points, function calls are drawn as horizontal bars,
+    and MPI messages are drawn as arrows.
+
+    Args:
+        trace: The trace to be visualized.
+        show_depth: Whether to show the depth of the function calls.
+        instant_events: Whether to show instant events.
+        critical_path: Whether to show the critical path. NOTE: critical_path
+            currently only works when show_depth is False. TODO: make it work
+            with show_depth=True.
+        messages: Whether to show MPI messages. Can be "click" (default), or "all".
+        x_start: The start time of the x-axis range.
+        x_end: The end time of the x-axis range.
+        width: The width of the plot. Default is None, which makes the plot full
+            width.
+        height: The height of the plot. Default is None, which makes the plot adapt
+            to the number of ticks on the y-axis.
+        legend_nrows: The number of rows in the legend. Default is None, which makes
+            the legend adapt to the number of items.
+
+    Returns:
+        The Bokeh plot.
+    """
+
+    # Prepare data to be plotted
+    events, y_tuples, num_ys = prepare_data(trace, show_depth, instant_events)
+
+    # Define the 3 data sources (Bokeh ColumnDataSource)
+    hbar_source = ColumnDataSource(events.head(0))
+    scatter_source = ColumnDataSource(events.head(0))
+    image_source = ColumnDataSource(
+        data=dict(
+            image=[np.zeros((50, 16), dtype=np.uint32)], x=[0], y=[0], dw=[0], dh=[0]
+        )
+    )
+
+    # Create Bokeh plot
+    if x_start is None:
+        x_start = events["Timestamp (ns)"].min()
+    if x_end is None:
+        x_end = (
+            events["Timestamp (ns)"].max()
+            + (events["Timestamp (ns)"].max() - events["Timestamp (ns)"].min()) * 0.05
+        )
+
+    height = height if height is not None else 150 + 30 * num_ys
+    p = figure(
+        x_range=(x_start, x_end),
+        y_range=(num_ys - 0.5, -0.5),
+        x_axis_location="above",
+        tools="hover,xpan,reset,xbox_zoom,xwheel_zoom,save",
+        output_backend="webgl",
+        height=min(500, height),
+        sizing_mode="stretch_width" if width is None else "fixed",
+        width=width,
+        toolbar_location=None,
+        x_axis_label="Time",
+    )
+
+    # Define color mappings
+    fill_cmap = get_factor_cmap("Name", trace)
+    line_cmap = get_factor_cmap("Name", trace, scale=0.7)
+
+    # Add glyphs
+    # Bars for "large" functions
+    hbar = p.hbar(
+        left="Timestamp (ns)",
+        right="_matching_timestamp",
+        y="y",
+        height=0.8,
+        source=hbar_source,
+        fill_color=fill_cmap,
+        line_color=line_cmap,
+        line_width=1,
+        line_alpha=0.5,
+        legend_field="name_trimmed",
+    )
+
+    # Image for small functions
+    p.image_rgba(source=image_source)
+
+    # Scatter for instant events
+    if instant_events:
+        scatter = p.scatter(
+            x="Timestamp (ns)",
+            y=dodge("y", -0.2 if show_depth else 0),
+            line_color="#0868ac",
+            alpha=1,
+            color="#ccebc5",
+            line_width=0.8,
+            marker="diamond",
+            source=scatter_source,
+            legend_label="Instant event",
+        )
+
+    # Arrows for MPI messages
+    if messages == "all":
+        sends = events[events["Name"].isin(["MpiSend", "MpiIsend"])]
+        for i in range(len(sends)):
+            p.add_layout(
+                Arrow(
+                    end=OpenHead(),
+                    x_start=sends["Timestamp (ns)"].iloc[i],
+                    y_start=(
+                        sends["y"].iloc[i] - 0.2 if show_depth else sends["y"].iloc[i]
+                    ),
+                    x_end=events.loc[sends["_matching_event"].iloc[i]][
+                        "Timestamp (ns)"
+                    ],
+                    y_end=(
+                        events.loc[sends["_matching_event"].iloc[i]]["y"] - 0.2
+                        if show_depth
+                        else events.loc[sends["_matching_event"].iloc[i]]["y"]
+                    ),
+                    level="annotation",
+                )
+            )
+
+    # Arrows for critical path
+    if critical_path:
+        critical_dfs = trace.critical_path_analysis()
+        for df in critical_dfs:
+            # Draw hatch pattern
+            p.hbar(
+                left="Timestamp (ns)",
+                right="_matching_timestamp",
+                y="Process",
+                height=0.8,
+                source=df,
+                fill_color=None,
+                line_color=None,
+                hatch_color="white",
+                hatch_pattern="right_diagonal_line",
+            )
+
+            # Draw arrows
+            # TODO: can we vectorize this?
+            for i in range(len(df) - 1):
+                p.add_layout(
+                    Arrow(
+                        end=OpenHead(line_color="black", line_width=2, size=9),
+                        line_color="black",
+                        line_width=2,
+                        x_start=df["Timestamp (ns)"].iloc[i],
+                        y_start=df["Process"].iloc[i],
+                        x_end=df["Timestamp (ns)"].iloc[i + 1],
+                        y_end=df["Process"].iloc[i + 1],
+                        level="overlay",
+                    )
+                )
+
+    # Additional plot config
+    p.toolbar.active_scroll = p.select(dict(type=WheelZoomTool))[0]
+
+    # Grid config
+    depth_ticks = np.arange(0, num_ys)
+    process_ticks = np.array(
+        [i for i, v in enumerate(y_tuples) if len(v) == 1 or v[1] == 0]
+    )
+    p.ygrid.visible = False
+    g1 = Grid(
+        dimension=1,
+        grid_line_color="white",
+        grid_line_width=2,
+        ticker=FixedTicker(
+            ticks=np.concatenate([depth_ticks - 0.49, depth_ticks + 0.49])
+        ),
+        level="glyph",
+    )
+    g2 = Grid(
+        dimension=1,
+        grid_line_width=2,
+        band_fill_color="gray",
+        band_fill_alpha=0.1,
+        ticker=FixedTicker(ticks=process_ticks - 0.5),
+        level="glyph",
+    )
+    p.add_layout(g1)
+    p.add_layout(g2)
+
+    # Axis config
+    p.xaxis.formatter = get_time_tick_formatter()
+    p.yaxis.formatter = CustomJSTickFormatter(
+        args={
+            "y_tuples": y_tuples,
+        },
+        code="""
+            return "Process " + y_tuples[Math.floor(tick)][0];
+        """,
+    )
+    p.yaxis.ticker = FixedTicker(ticks=process_ticks + 0.1)
+    p.yaxis.major_tick_line_color = None
+
+    # Legend config
+    p.add_layout(p.legend[0], "below")
+    p.legend.orientation = "horizontal"
+    p.legend.location = "center"
+    p.legend.nrows = legend_nrows if legend_nrows is not None else "auto"
+
+    # Hover config
+    hover = p.select(HoverTool)
+    hover.tooltips = get_html_tooltips(
+        {
+            "Name": "@Name",
+            "Enter": "@{Timestamp (ns)}{custom} [@{index}]",
+            "Leave": "@{_matching_timestamp}{custom} [@{_matching_event}]",
+            "Time (Inc)": "@{time.inc}{custom}",
+            "Time (Exc)": "@{time.exc}{custom}",
+        }
+    )
+    hover.formatters = {
+        "@{Timestamp (ns)}": get_time_hover_formatter(),
+        "@{_matching_timestamp}": get_time_hover_formatter(),
+        "@{time.inc}": get_time_hover_formatter(),
+        "@{time.exc}": get_time_hover_formatter(),
+    }
+    hover.renderers = [hbar, scatter] if instant_events else [hbar]
+    hover.callback = CustomJS(
+        code="""
+            let hbar_tooltip = document.querySelector('.bk-tooltip');
+            let scatter_tooltip = hbar_tooltip.nextElementSibling;
+
+            if (hbar_tooltip && scatter_tooltip &&
+                hbar_tooltip.style.display != 'none' &&
+                scatter_tooltip.style.display != 'none')
+            {
+                hbar_tooltip.style.display = 'none';
+            }
+        """
+    )
+
+    # Add interactive callbacks (these happen on the Python side)
+    p.on_event(
+        RangesUpdate,
+        lambda event: update_cds(
+            event.x0, event.x1, events, instant_events, hbar_source, scatter_source
+        ),
+    )
+
+    if messages == "click":
+        p.on_event(Tap, lambda event: tap_callback(event, events, trace, show_depth, p))
+
+    # Make initial call to the callback
+    update_cds(x_start, x_end, events, instant_events, hbar_source, scatter_source)
+
+    # Return plot
+    return show(p)
diff --git a/pipit/vis/util.py b/pipit/vis/util.py
new file mode 100644
index 00000000..e60df9b0
--- /dev/null
+++ b/pipit/vis/util.py
@@ -0,0 +1,454 @@
+# Copyright 2022 Parallel Software and Systems Group, University of Maryland.
+# See the top-level LICENSE file for details.
+# +# SPDX-License-Identifier: MIT + +from bokeh.io import output_notebook, show as bk_show +from bokeh.models import ( + CustomJSHover, + CustomJSTickFormatter, + PrintfTickFormatter, + NumeralTickFormatter, + BasicTicker, +) + +# from bokeh.palettes import Category20_20 +from bokeh.transform import factor_cmap +from bokeh.themes import Theme +import yaml + +import math +import numpy as np +import pandas as pd + +import pipit as pp + + +# Formatters +def get_process_ticker(nranks): + return BasicTicker( + base=2, desired_num_ticks=min(nranks, 16), min_interval=1, num_minor_ticks=0 + ) + + +def get_process_tick_formatter(): + return PrintfTickFormatter(format="Process %d") + + +def format_time(n: float) -> str: + """Converts timestamp/timedelta from ns to human-readable time""" + # Adapted from https://github.com/dask/dask/blob/main/dask/utils.py + + if n >= 1e9 * 24 * 60 * 60 * 2: + d = int(n / 1e9 / 3600 / 24) + h = int((n / 1e9 - d * 3600 * 24) / 3600) + return f"{d}d {h}hr" + + if n >= 1e9 * 60 * 60 * 2: + h = int(n / 1e9 / 3600) + m = int((n / 1e9 - h * 3600) / 60) + return f"{h}hr {m}m" + + if n >= 1e9 * 60 * 10: + m = int(n / 1e9 / 60) + s = int(n / 1e9 - m * 60) + return f"{m}m {s}s" + + if n >= 1e9: + return "%.2f s" % (n / 1e9) + + if n >= 1e6: + return "%.2f ms" % (n / 1e6) + + if n >= 1e3: + return "%.2f us" % (n / 1e3) + + return "%.2f ns" % n + + +# JS expression equivalent to `format_time` function above; assumes: +# - `x` is the value (in ns) being compared to determine units +# - `y` is the value (in ns) actually being formatted +JS_FORMAT_TIME = """ + if (x >= 1e9 * 24 * 60 * 60 * 2) { + d = Math.round(y / 1e9 / 3600 / 24) + h = Math.round((y / 1e9 - d * 3600 * 24) / 3600) + return `${d}d ${h}hr` + } + + if (x >= 1e9 * 60 * 60 * 2) { + h = Math.round(y / 1e9 / 3600) + m = Math.round((y / 1e9 - h * 3600) / 60) + return `${h}hr ${m}m` + } + + if (x >= 1e9 * 60 * 10) { + m = Math.round(y / 1e9 / 60) + s = Math.round(y / 1e9 - m * 60) + return `${m}m ${s}s` + } + + if (x >= 1e9) + return (y / 1e9).toFixed(2) + "s" + + if (x >= 1e6) + return (y / 1e6).toFixed(2) + "ms" + + if (x >= 1e3) { + var ms = Math.floor(y / 1e6); + var us = ((y - ms * 1e6) / 1e3); + + var str = ""; + if (ms) str += ms + "ms "; + if (us) str += Math.round(us) + "us"; + else str += "0us"; + + return str; + } + + var ms = Math.floor(y / 1e6); + var us = Math.floor((y - ms * 1e6) / 1e3); + var ns = Math.round(y % 1000); + + var str = ""; + + if (ms) str += ms + "ms "; + if (us) str += us + "us "; + if (ns) str += ns + "ns"; + else if (!us) str += "0ns"; + + return str; +""" + + +# Used to format ticks for time-based axes +def get_time_tick_formatter(): + return CustomJSTickFormatter( + code=f""" + let x = Math.max(...ticks) - Math.min(...ticks); + let y = tick; + {JS_FORMAT_TIME} + """ + ) + + +# Used to format tooltips for time-based values +def get_time_hover_formatter(): + return CustomJSHover( + code=f""" + let x = value; + let y = value; + {JS_FORMAT_TIME} + """, + ) + + +def format_size(b): + """Converts bytes to something more readable""" + + if b < 1e3: # Less than 1 kB -> byte + return f"{b:.2f} B" + if b < 1e6: # Less than 1 MB -> kB + return f"{(b / 1e3):.2f} kB" + if b < 1e9: # Less than 1 GB -> MB + return f"{(b / 1e6):.2f} MB" + if b < 1e12: # Less than 1 TB -> GB + return f"{(b / 1e9):.2f} GB" + if b < 1e15: # Less than 1 PB -> TB + return f"{(b / 1e12):.2f} TB" + else: + return f"{(b / 1e15):.2f} PB" + + +# JS expression equivalent to `format_size` function above; assumes: +# - `x` is the 
value (in bytes) being compared to determine units
+# - `y` is the value (in bytes) actually being formatted
+JS_FORMAT_SIZE = """
+    if(x < 1e3)
+        return (y).toFixed(2) + " B";
+    if(x < 1e6)
+        return (y / 1e3).toFixed(2) + " kB";
+    if(x < 1e9)
+        return (y / 1e6).toFixed(2) + " MB";
+    if(x < 1e12)
+        return (y / 1e9).toFixed(2) + " GB";
+    if(x < 1e15)
+        return (y / 1e12).toFixed(2) + " TB";
+    else
+        return (y / 1e15).toFixed(2) + " PB";
+"""
+
+
+# Used to format ticks for size-based axes
+def get_size_tick_formatter(ignore_range=False):
+    x = "tick" if ignore_range else "Math.max(...ticks) - Math.min(...ticks)"
+    return CustomJSTickFormatter(
+        code=f"""
+            let x = {x};
+            let y = tick;
+            {JS_FORMAT_SIZE}
+        """
+    )
+
+
+# Used to format tooltips for size-based values
+def get_size_hover_formatter():
+    return CustomJSHover(
+        code=f"""
+            let x = value;
+            let y = value;
+            {JS_FORMAT_SIZE}
+        """,
+    )
+
+
+def get_percent_tick_formatter():
+    return NumeralTickFormatter(format="0.0%")
+
+
+def get_percent_hover_formatter():
+    return CustomJSHover(
+        code="""
+            return parseFloat(value * 100).toFixed(2) + "%"
+        """
+    )
+
+
+# TODO: maybe do this client side with transform
+def trimmed(names: pd.Series) -> pd.Series:
+    return np.where(
+        names.str.len() < 30, names, names.str[0:10] + "..." + names.str[-5:]
+    )
+
+
+def get_trimmed_tick_formatter():
+    return CustomJSTickFormatter(
+        code="""
+            if (tick.length < 30) {
+                return tick;
+            } else {
+                return tick.substr(0, 20) + "..." +
+                    tick.substr(tick.length - 5, tick.length);
+            }
+        """
+    )
+
+
+# Helper functions
+
+
+def in_notebook():
+    """Returns True if we are in a notebook environment, False otherwise"""
+    try:
+        from IPython import get_ipython
+
+        if "IPKernelApp" not in get_ipython().config:  # pragma: no cover
+            return False
+    except ImportError:
+        return False
+    except AttributeError:
+        return False
+    return True
+
+
+def show(p, return_fig=False):
+    """Used to wrap return values of plotting functions.
+
+    If return_fig is True, then just returns the figure object; otherwise starts a
+    Bokeh server containing the figure. If we are in a notebook, displays the
+    figure in the output cell, otherwise shows the figure in a new browser tab.
+
+    See https://docs.bokeh.org/en/latest/docs/user_guide/output/jupyter.html#bokeh-server-applications,  # noqa E501
+    https://docs.bokeh.org/en/latest/docs/user_guide/server/library.html.
+    """
+    if return_fig:
+        return p
+
+    # Create a Bokeh app containing the figure
+    def bkapp(doc):
+        doc.clear()
+        doc.add_root(p)
+        doc.theme = Theme(
+            json=yaml.load(
+                pp.get_option("theme"),
+                Loader=yaml.FullLoader,
+            )
+        )
+
+    if in_notebook():
+        # If notebook, show it in output cell
+        output_notebook(hide_banner=True)
+        bk_show(bkapp, notebook_url=pp.get_option("notebook_url"))
+    else:
+        # If standalone, start HTTP server and show in browser
+        from bokeh.server.server import Server
+
+        server = Server({"/": bkapp}, port=0, allow_websocket_origin=["*"])
+        server.start()
+        server.io_loop.add_callback(server.show, "/")
+        server.io_loop.start()
+
+
+def clamp(value, min_val, max_val):
+    """Clamps value to min and max bounds"""
+
+    if value < min_val:
+        return min_val
+    if value > max_val:
+        return max_val
+    return value
+
+
+def get_html_tooltips(tooltips_dict):
+    html = """
+        <div>
+    """
+    for k, v in tooltips_dict.items():
+        html += f"""
+            <div>
+                <span style="font-weight: bold;">
+                    {k}:
+                </span>
+                <span>
+                    {v}
+                </span>
+            </div>
+        """
+    html += """
+        </div>
+    """
+    return html
+
+
+def factorize_tuples(tuples_list):
+    unique_values = sorted(set(tuples_list))
+    value_to_index = {value: i for i, value in enumerate(unique_values)}
+    codes = [value_to_index[value] for value in tuples_list]
+    return codes, list(unique_values)
+
+
+def hex_to_rgb(hex):
+    hex = hex.strip("#")
+
+    r, g, b = int(hex[:2], 16), int(hex[2:4], 16), int(hex[4:], 16)
+    return (r, g, b)
+
+
+def rgb_to_hex(rgb):
+    r, g, b = rgb
+    return "#%02x%02x%02x" % (int(r), int(g), int(b))
+
+
+def average_hex(*hex):
+    """Averages any number of hex colors, returns result in hex"""
+    colors = [hex_to_rgb(h) for h in hex]
+    return rgb_to_hex(np.mean(colors, axis=0))
+
+
+def scale_hex(hex, scale):
+    """Multiplies a hex color by a scalar, returns result in hex"""
+    if scale < 0 or len(hex) != 7:
+        return hex
+
+    r, g, b = hex_to_rgb(hex)
+
+    r = int(clamp(r * scale, 0, 255))
+    g = int(clamp(g * scale, 0, 255))
+    b = int(clamp(b * scale, 0, 255))
+
+    return rgb_to_hex((r, g, b))
+
+
+def get_height(num_yticks, height_per_tick=400):
+    """Calculates ideal plot height based on number of y ticks"""
+    return clamp(int(math.log10(num_yticks) * height_per_tick + 50), 200, 900)
+
+
+LIGHT = [
+    "#aec7e8",
+    # "#ffbb78",  # reserved for sim_life_1d
+    "#98df8a",
+    "#ff9896",
+    "#c5b0d5",
+    "#c49c94",
+    "#f7b6d2",
+    # "#c7c7c7",  # reserved for idle
+    "#dbdb8d",
+    "#9edae5",
+]
+DARK = [
+    # "#1f77b4",  # reserved for send
+    "#ff7f0e",
+    "#2ca02c",
+    # "#d62728",  # reserved for recv
+    "#9467bd",
+    "#8c564b",
+    "#e377c2",
+    # "#7f7f7f",  # take out
+    "#bcbd22",
+    "#17becf",
+]
+
+
+def get_palette(trace, scale=None):
+    funcs = trace.events[trace.events["Event Type"] == "Enter"]
+    names = reversed(trace.flat_profile(["time.exc"]).index.tolist())
+
+    depths = (
+        funcs.groupby("Name", observed=True)["_depth"]
+        .agg(lambda x: x.value_counts().index[0])
+        .to_dict()
+    )
+
+    palette = {}
+
+    palette["MPI_Send"] = "#1f77b4"
+    palette["MPI_Isend"] = "#1f77b4"
+
+    palette["MPI_Recv"] = "#d62728"
+    palette["MPI_Irecv"] = "#d62728"
+
+    palette["MPI_Wait"] = "#c7c7c7"
+    palette["MPI_Waitany"] = "#c7c7c7"
+    palette["MPI_Waitall"] = "#c7c7c7"
+    palette["Idle"] = "#c7c7c7"
+
+    palette["sim_life_1d"] = "#ffbb78"
+
+    dark_index = 0
+    light_index = 0
+
+    for f in names:
+        if f not in palette:
+            if depths[f] % 2 == 0:
+                palette[f] = LIGHT[light_index % len(LIGHT)]
+                light_index += 1
+            else:
+                palette[f] = DARK[dark_index % len(DARK)]
+                dark_index += 1
+
+    # apply multiplier
+    if scale:
+        for k, v in palette.items():
+            palette[k] = scale_hex(v, scale)
+
+    return palette
+
+
+def get_factor_cmap(field_name, trace, **kwargs):
+    palette = get_palette(trace, **kwargs)
+    return factor_cmap(field_name, list(palette.values()), list(palette.keys()))
diff --git a/requirements.txt b/requirements.txt
index 1f317793..7ad07110 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 numpy
 otf2
 pandas
+bokeh
\ No newline at end of file