68 changes: 44 additions & 24 deletions .github/container/nsys_jax/nsys_jax/analysis.py
100755 → 100644
@@ -28,9 +28,9 @@ def align_profiler_data_timestamps(
# Error if the communication frame doesn't exist at all, but not if it is empty.
# Calling this on a profile that does not contain any communication should
# gracefully yield empty results.
assert frames.communication is not None, (
"align_profiler_data_timestamps requires a communication frame"
)
assert (
frames.communication is not None
), "align_profiler_data_timestamps requires a communication frame"
if not len(frames.communication):
# Nothing to be done, return an empty result
return frames, {}
@@ -43,9 +43,9 @@ def align_profiler_data_timestamps(
f"WARNING: cannot align {num_profiled_devices} devices because max collective size is 1"
)
return frames, {}
assert num_profiled_devices == max_collective_size, (
f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
)
assert (
num_profiled_devices == max_collective_size
), f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
# Find the collectives that will be used
align_df = comm_df[comm_df["CollectiveSize"] == max_collective_size]
# Calculate the collectives' end times
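For context, a minimal sketch of the alignment idea follows: a collective that spans all profiled devices finishes at essentially the same moment everywhere, so per-device differences in the observed end time of the same collective estimate each device's clock skew. The column names ("StartMs", "DurMs", "Collective", "Device") and the median/mean choices are illustrative assumptions, not necessarily what this function does.

import pandas as pd

def estimate_device_offsets(align_df: pd.DataFrame) -> pd.Series:
    # Observed end time of each collective occurrence on each device.
    end = align_df["StartMs"] + align_df["DurMs"]
    # One row per collective occurrence, one column per device.
    ends = end.groupby([align_df["Collective"], align_df["Device"]]).first().unstack("Device")
    # A full-size collective ends simultaneously on every device, so the per-device
    # deviation from the row-wise mean estimates that device's clock offset.
    return ends.sub(ends.mean(axis=1), axis=0).median()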
@@ -189,18 +189,19 @@ def _get_message_size(
) -> tuple[int, str, int, float, float]:
_, inst = module_proto.find_instruction(instruction)
comm_inst = inst.communication_proto()
assert comm_inst.opcode in {
"all-gather-start",
"all-reduce-start",
"all-to-all",
"collective-broadcast",
"collective-permute-start",
"dynamic-slice",
"dynamic-update-slice",
"reduce-scatter",
}, (
f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"
)
assert (
comm_inst.opcode
in {
"all-gather-start",
"all-reduce-start",
"all-to-all",
"collective-broadcast",
"collective-permute-start",
"dynamic-slice",
"dynamic-update-slice",
"reduce-scatter",
}
), f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"

def _byte_size(inst) -> int:
size_bits = math.prod(
@@ -254,9 +255,9 @@ def _byte_size(inst) -> int:
collective_size = iota_group_list.num_devices_per_group
else:
collective_sizes = set(len(group.replica_ids) for group in replica_groups)
assert len(collective_sizes) == 1, (
f"Heterogeneous collective {comm_inst} could not be interpreted"
)
assert (
len(collective_sizes) == 1
), f"Heterogeneous collective {comm_inst} could not be interpreted"
collective_size = next(iter(collective_sizes))
total_msg_size = 0
for operand_id in comm_inst.operand_ids:
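As background for the byte-size bookkeeping above, a standalone sketch of how a collective's payload and bus-traffic factor are commonly computed follows. The bandwidth factors are the standard NCCL-style ones; whether this function applies the same factors is not visible in the hunk shown here.

import math

def byte_size(dims: list[int], element_bits: int) -> int:
    # Payload of one operand: product of dimensions times element width, in bytes.
    return math.prod(dims) * element_bits // 8

def bus_traffic_factor(opcode: str, collective_size: int) -> float:
    # Standard bus-bandwidth correction factors (illustrative assumption only).
    n = collective_size
    if opcode == "all-reduce-start":
        return 2 * (n - 1) / n
    if opcode in {"all-gather-start", "reduce-scatter"}:
        return (n - 1) / n
    return 1.0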
@@ -391,10 +392,29 @@ def generate_compilation_statistics(compile_df: pd.DataFrame) -> pd.DataFrame:
# Assuming there's only one parallel region inside `launcher_row`
parallel_start = child_df.loc[~is_main, "StartMs"].min()
parallel_end = child_ends[~is_main].max()
# Assert that there are no main-thread tasks during this period
main_before = is_main & (child_ends < parallel_start)
main_after = is_main & (child_df["StartMs"] > parallel_end)
# Check for main-thread tasks that don't overlap with the parallel period
main_before = is_main & (child_ends <= parallel_start)
main_after = is_main & (child_df["StartMs"] >= parallel_end)
# Identify any main-thread tasks that overlap with the parallel period
main_overlap = is_main & ~(main_before | main_after)
if main_overlap.any():
# Some main-thread tasks overlap with the parallel region.
# This can happen with NCCL operations or other intermediate operations.
# Classify them based on where most of their duration falls.
overlap_tasks = child_df.loc[main_overlap]
for idx in overlap_tasks.index:
task_start = child_df.loc[idx, "StartMs"]
task_end = child_ends.loc[idx]
# Calculate how much time is before, during, and after the parallel region
before_time = max(0, min(task_end, parallel_start) - task_start)
after_time = max(0, task_end - max(task_start, parallel_end))
# Classify based on which is larger
if before_time > after_time:
main_before.loc[idx] = True
else:
main_after.loc[idx] = True
assert ((main_before | main_after) == is_main).all()

# Aggregate statistics for how the worker threads spend their time and use that
# distribution to divide up the [parallel_start, parallel_end] range of the overall
# compilation time.
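A toy run of the classification rule added above, with made-up timings, may help illustrate the intent: a main-thread task that straddles the parallel region is attributed to whichever side (before or after) holds more of its duration.

import pandas as pd

child_df = pd.DataFrame({"StartMs": [0.0, 9.0, 10.0, 11.0], "DurMs": [8.0, 3.0, 5.0, 4.0]})
child_ends = child_df["StartMs"] + child_df["DurMs"]
is_main = pd.Series([True, True, False, False])
parallel_start = child_df.loc[~is_main, "StartMs"].min()  # 10.0
parallel_end = child_ends[~is_main].max()                 # 15.0

main_before = is_main & (child_ends <= parallel_start)
main_after = is_main & (child_df["StartMs"] >= parallel_end)
main_overlap = is_main & ~(main_before | main_after)
overlap_tasks = child_df.loc[main_overlap]
for idx in overlap_tasks.index:
    task_start, task_end = child_df.loc[idx, "StartMs"], child_ends.loc[idx]
    before_time = max(0, min(task_end, parallel_start) - task_start)
    after_time = max(0, task_end - max(task_start, parallel_end))
    # The task ending at 12.0 has 1 ms before the parallel region and 0 ms after,
    # so it is counted as "before".
    (main_before if before_time > after_time else main_after).loc[idx] = True

assert ((main_before | main_after) == is_main).all()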
51 changes: 46 additions & 5 deletions .github/container/nsys_jax/nsys_jax/data_loaders.py
@@ -249,17 +249,50 @@ def _load_nvtx_gpu_proj_trace_single(
# get the names of the ranges referred to by ModuleId
mod_id_names = df.loc[mod_ids, "Name"]
assert mod_ids.shape == mod_id_names.shape

# Get a mask in mod_id_names of entries where ModuleId in the original
# Thunk is not referring to a Module yet. Intermediate levels of the
# hierarchy can be other thunks (e.g. an individual graph node may
# have a thunk representing the whole graph as a parent).
# have a thunk representing the whole graph as a parent) or other
# range types like NCCL operations.
mask = ~mod_id_names.str.startswith(module_prefix)
assert (mask == mod_id_names.str.startswith(thunk_prefix)).all()

# Identify which non-module entries are thunks vs other range types
is_thunk = mod_id_names.str.startswith(thunk_prefix)
is_module = mod_id_names.str.startswith(module_prefix)

# Assert that we only have modules, thunks, or known intermediate ranges
# This catches unexpected range types in the hierarchy
is_recognized = (
is_thunk | is_module | mod_id_names.str.contains("nccl", case=False)
)
assert is_recognized.all(), f"Found unrecognized range types in hierarchy: {mod_id_names[~is_recognized].unique()}"

# Assert that mask is consistent with module detection
assert (mask == ~is_module).all(), "Mask inconsistency with module detection"

# Convert to numpy arrays for cross-indexing operations
# (mod_ids and mask have different pandas indices)
mod_ids_array = mod_ids.values
mask_array = mask.values
is_thunk_array = is_thunk.values

# Assert that all non-module entries have valid parent IDs to continue navigation
non_module_mod_ids = mod_ids_array[mask_array]
assert (
df.loc[non_module_mod_ids, "ParentId"].notna().all()
), "Found non-module entries without valid parent IDs, cannot navigate up hierarchy"

assert mask.shape == mod_ids.shape

# We want to end up without all_thunks containing thunks with child
# thunks, as noted above.
thunk_ids_with_child_thunks = mod_ids.array[mask]
all_thunks[thunk_ids_with_child_thunks] = False
# thunks, as noted above. Only filter out thunks, not other range types.
thunk_ids_with_child_thunks = mod_ids_array[mask_array & is_thunk_array]
# Only update indices that actually exist in all_thunks to prevent reindexing
existing_indices = all_thunks.index.intersection(thunk_ids_with_child_thunks)
if len(existing_indices) > 0:
all_thunks[existing_indices] = False

# Set thunk_ids to be the (shorter) list of indices (in df) of the
# Thunks whose ModuleId values need to be updated
thunk_ids = thunk_ids[mask]
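The intersection guard added above matters because assigning with labels that are absent from a Series index can enlarge (reindex) the Series, or raise, depending on the pandas version. A toy example with made-up labels:

import pandas as pd

all_thunks = pd.Series(True, index=[10, 11, 12])
candidate_ids = pd.Index([11, 99])  # 99 is not in all_thunks
existing = all_thunks.index.intersection(candidate_ids)
if len(existing) > 0:
    all_thunks[existing] = False
assert list(all_thunks.index) == [10, 11, 12]  # no new label 99 was introduced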
@@ -271,6 +304,10 @@ def _load_nvtx_gpu_proj_trace_single(

# Now all the Thunks should have ModuleId pointing to an XlaModule range.
mod_ids = sorted(set(df.loc[all_thunks, "ModuleId"].astype(np.int32)))

# Ensure all_thunks only contains indices that exist in df
all_thunks = all_thunks.reindex(df.index, fill_value=False)

assert df.loc[all_thunks, "Name"].str.startswith(thunk_prefix).all()
assert df.loc[mod_ids, "Name"].str.startswith(module_prefix).all()
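A minimal sketch of what the reindex guard achieves: a boolean mask built from a subset of rows is expanded to cover every row of df, with missing entries filled as False, so it can safely be used as a row selector. The data here is invented for illustration.

import pandas as pd

df = pd.DataFrame({"Name": ["a", "b", "c"]}, index=[0, 1, 2])
all_thunks = pd.Series([True], index=[1])
all_thunks = all_thunks.reindex(df.index, fill_value=False)
assert list(all_thunks) == [False, True, False]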

@@ -397,6 +434,10 @@ def clean_data_frame(d):
if "thunk" in frames:
# At this point there should be no need to look beyond the rows for individual
# thunks + the protobuf data, and we can further clean up the data.

# Ensure all_thunks is aligned with df one final time before using it
all_thunks = all_thunks.reindex(df.index, fill_value=False)

thunk_df = clean_data_frame(df[all_thunks])
thunk_df["Name"] = thunk_df["Name"].str.replace(
pat=f"^{tsl_prefix}Thunk:#(?:name=.*?,|)hlo_op=([a-z0-9._-]+)#$",