Skip to content

Commit

Permalink
nsys-jax: add example of command-line post-processing/summarisation (#936)
Browse files Browse the repository at this point in the history

`nsys-jax --nsys-jax-analysis summary python your_program.py` will now
execute the "summary" analysis script after profile collection, giving
lower-latency feedback. Outputs of these analysis scripts will be
included in the `.zip` archive for convenience.

Many lines of the diff are due to changing the "natural" timestamp
format from nanoseconds to milliseconds, allowing many factors of `1e-6`
to be removed.

Add extra logic to `nsys-jax` to take advantage of openxla/xla#14092.

Also add a CI job that tests execution of `nsys-jax-combine`.
  • Loading branch information
olupton authored Jul 9, 2024
1 parent 5e8afad commit 2f1f572
Show file tree
Hide file tree
Showing 18 changed files with 863 additions and 297 deletions.
2 changes: 1 addition & 1 deletion .github/container/Dockerfile.base
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ COPY check-shm.sh /opt/nvidia/entrypoint.d/
## written by nsys-jax, while the nsys-jax wrapper is used inside the container.
###############################################################################

ADD nsys-jax /usr/local/bin
ADD nsys-jax nsys-jax-combine /usr/local/bin/
ADD jax_nsys/ /opt/jax_nsys
ADD requirements-nsys-jax.in /opt/pip-tools.d/
RUN ln -s /opt/jax_nsys/install-protoc /usr/local/bin/
Expand Down
293 changes: 148 additions & 145 deletions .github/container/jax_nsys/Analysis.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
)
from .data_loaders import load_profiler_data
from .protobuf import xla_module_metadata
from .protobuf_utils import compile_protos
from .utils import remove_child_ranges
from .protobuf_utils import compile_protos, ensure_compiled_protos_are_importable
from .utils import remove_autotuning_detail, remove_child_ranges
from .visualization import create_flamegraph, display_flamegraph

__all__ = [
Expand All @@ -17,8 +17,10 @@
"compile_protos",
"create_flamegraph",
"display_flamegraph",
"ensure_compiled_protos_are_importable",
"generate_compilation_statistics",
"load_profiler_data",
"remove_autotuning_detail",
"remove_child_ranges",
"xla_module_metadata",
]
98 changes: 60 additions & 38 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import math
import numpy as np
import pandas as pd # type: ignore
import pathlib
from typing import Any

from .protobuf import xla_module_metadata
from .protobuf import HloProto, xla_module_metadata
from .utils import make_child_mask, ProfilerData

pd.options.mode.copy_on_write = True
Expand Down Expand Up @@ -38,7 +39,7 @@ def align_profiler_data_timestamps(
align_df = comm_df[comm_df["CollectiveSize"] == max_collective_size]
# Calculate the collectives' end times
end_times = (
align_df["ProjStartNs"] + align_df["ProjDurNs"] + align_df["ProjDurHiddenNs"]
align_df["ProjStartMs"] + align_df["ProjDurMs"] + align_df["ProjDurHiddenMs"]
)
# For each collective, calculate the mean end time of each collective across devices
mean_end_times = end_times.groupby(
Expand All @@ -51,10 +52,10 @@ def align_profiler_data_timestamps(
# Apply these corrections to the device-side timestamps
for k in ["communication", "module", "thunk"]:
df = getattr(frames, k)
df["ProjStartNs"] -= median_device_skews
df["ProjStartMs"] -= median_device_skews
setattr(frames, k, df)
return frames, {
"collective_end_time_skews_ns": end_time_skews,
"collective_end_time_skews_ms": end_time_skews,
"device_corrections": median_device_skews,
"collective_size": max_collective_size,
}
Expand Down Expand Up @@ -100,7 +101,9 @@ def apply_warmup_heuristics(frames: ProfilerData) -> tuple[ProfilerData, Profile
prog_exec_values = df.index.get_level_values("ProgramExecution")
init_mask = compile_mask & (prog_exec_values == 0)
steady_mask = ~compile_mask | (prog_exec_values > 1)
assert steady_mask.any(), "No steady-state executions identified, profile collection may have been too short"
assert (
len(df) == 0 or steady_mask.any()
), "No steady-state executions identified, profile collection may have been too short"
assert (prog_exec_values[~init_mask & ~steady_mask] == 1).all()
setattr(init, k, df[init_mask])
setattr(steady, k, df[steady_mask])
Expand Down Expand Up @@ -155,16 +158,9 @@ def _collective_correction(kind: str, size: int) -> tuple[float, float]:
assert False, f"Unknown collective kind {kind}"


@functools.lru_cache
def get_message_size(program_id: int, instruction: str) -> pd.Series:
"""
Given the name of a collective instruction (e.g. all-gather-start.N), calculate the
message size in bytes. See https://openxla.org/xla/operation_semantics#allgather,
https://openxla.org/xla/operation_semantics#allreduce and so on for more explanation
of the semantics. This implementation aims to follow the same conventions that NCCL
uses in its NVTX payloads and tests.
"""
module_proto = xla_module_metadata(program_id)
def _get_message_size(
module_proto: HloProto, instruction: str
) -> tuple[int, str, int, float, float]:
_, inst = module_proto.find_instruction(instruction)
assert (
inst.opcode
Expand Down Expand Up @@ -206,8 +202,24 @@ def get_message_size(program_id: int, instruction: str) -> pd.Series:

collective = inst.opcode.removesuffix("-start")
bw_correction, bus_correction = _collective_correction(collective, collective_size)
return (total_msg_size, collective, collective_size, bw_correction, bus_correction)


@functools.lru_cache
def get_message_size(
program_id: int, instruction: str, prefix: pathlib.Path
) -> pd.Series:
"""
Given the name of a collective instruction (e.g. all-gather-start.N), calculate the
message size in bytes. See https://openxla.org/xla/operation_semantics#allgather,
https://openxla.org/xla/operation_semantics#allreduce and so on for more explanation
of the semantics. This implementation aims to follow the same conventions that NCCL
uses in its NVTX payloads and tests.
"""
return pd.Series(
[total_msg_size, collective, collective_size, bw_correction, bus_correction],
xla_module_metadata(program_id, prefix=prefix, policy="all").unique_result(
lambda proto: _get_message_size(proto, instruction)
),
index=[
"MessageSize",
"Collective",
Expand All @@ -218,25 +230,33 @@ def get_message_size(program_id: int, instruction: str) -> pd.Series:
)


def calculate_collective_metrics(thunk_df: pd.DataFrame) -> pd.DataFrame:
def calculate_collective_metrics(
thunk_df: pd.DataFrame, prefix: pathlib.Path
) -> pd.DataFrame:
"""
Given a "thunk" data frame from `load_profiler_data`, produce a new data frame that
contains one row per communication thunk and contains extra metrics such as the
message size, algorithm bandwidth, bus bandwidth, and collective operation.
"""
comm_df = thunk_df[thunk_df["Communication"]].drop(columns=["Communication"])
if len(comm_df) == 0:
return comm_df
comm_df = pd.concat(
[
comm_df,
comm_df.apply(lambda row: get_message_size(row.name[0], row.Name), axis=1),
comm_df.apply(
lambda row: get_message_size(row.name[0], row.Name, prefix=prefix),
axis=1,
),
],
axis=1,
)
# Note that this is decimal GB not binary GiB; GB/s == B/ns
# Note that this is decimal GB not binary GiB; GB/s == B/ns == 1e-6 * B / ms
comm_df["AlgorithmBandwidthGBPerSec"] = (
comm_df["BandwidthCorrection"]
1e-6
* comm_df["BandwidthCorrection"]
* comm_df["MessageSize"]
/ (comm_df["ProjDurNs"] + comm_df["ProjDurHiddenNs"])
/ (comm_df["ProjDurMs"] + comm_df["ProjDurHiddenMs"])
)
comm_df["BusBandwidthGBPerSec"] = (
comm_df["AlgorithmBandwidthGBPerSec"] * comm_df["BusBandwidthCorrection"]
Expand All @@ -262,7 +282,7 @@ def generate_compilation_statistics(compile_df: pd.DataFrame) -> pd.DataFrame:
statistics.
"""
# Aggregate compilation stats in here
compile_time_ns: dict[str, np.ndarray] = defaultdict(lambda: np.zeros(2))
compile_time_ms: dict[str, np.ndarray] = defaultdict(lambda: np.zeros(2))
for profile_name, profile_df in compile_df.groupby("ProfileName"):
# Identify the main thread
main_thread = profile_df.loc[compile_df["Name"] == "XlaCompile", "TID"].unique()
Expand All @@ -275,7 +295,9 @@ def generate_compilation_statistics(compile_df: pd.DataFrame) -> pd.DataFrame:
profile_df["TID"].ne(main_thread), "ParentId"
].astype(np.int32)
# These are the main-thread ranges that directly contain parallel workers
launcher_mask = profile_df.loc[worker_parent_ids, "TID"].eq(main_thread)
launcher_mask = profile_df.loc[(profile_name, worker_parent_ids), "TID"].eq(
main_thread
)
launcher_ids = launcher_mask[launcher_mask].index.unique()
# Loop over the main-thread ranges that launched parallel work
for launcher_row in profile_df.loc[launcher_ids, :].itertuples():
Expand All @@ -290,41 +312,41 @@ def generate_compilation_statistics(compile_df: pd.DataFrame) -> pd.DataFrame:
# could be relaxed if needed.
child_df = profile_df[make_child_mask(profile_df, launcher_row.Index)]
is_main = child_df["TID"] == launcher_row.TID
child_ends = child_df["StartNs"] + child_df["DurNs"]
child_ends = child_df["StartMs"] + child_df["DurMs"]
# Assuming there's only one parallel region inside `launcher_row`
parallel_start = child_df.loc[~is_main, "StartNs"].min()
parallel_start = child_df.loc[~is_main, "StartMs"].min()
parallel_end = child_ends[~is_main].max()
# Assert that there are no main-thread tasks during this period
main_before = is_main & (child_ends < parallel_start)
main_after = is_main & (child_df["StartNs"] > parallel_end)
main_after = is_main & (child_df["StartMs"] > parallel_end)
assert ((main_before | main_after) == is_main).all()
# Aggregate statistics for how the worker threads spend their time and use that
# distribution to divide up the [parallel_start, parallel_end] range of the overall
# compilation time.
parallel_dur = parallel_end - parallel_start
total_worker_time = child_df.loc[~is_main, "DurNonChildNs"].sum()
total_worker_time = child_df.loc[~is_main, "DurNonChildMs"].sum()

def attribute_parallel_time(row):
compile_time_ns[row.Name] += (
parallel_dur * row.DurNonChildNs / total_worker_time,
parallel_dur * row.DurChildNs / total_worker_time,
compile_time_ms[row.Name] += (
parallel_dur * row.DurNonChildMs / total_worker_time,
parallel_dur * row.DurChildMs / total_worker_time,
)

child_df[~is_main].apply(attribute_parallel_time, axis="columns")
# Easy to update these given the simplifying assumptions above; they are set to
# np.nan when worker ranges are spliced in by `_load_nvtx_pushpop_trace`
profile_df.loc[launcher_row.Index, "DurChildNs"] = (
child_df.loc[is_main, "DurNs"].sum() + parallel_dur
compile_df.loc[launcher_row.Index, "DurChildMs"] = (
child_df.loc[is_main, "DurMs"].sum() + parallel_dur
)
profile_df.loc[launcher_row.Index, "DurNonChildNs"] = (
launcher_row.DurNs - compile_df.loc[launcher_row.Index, "DurChildNs"]
compile_df.loc[launcher_row.Index, "DurNonChildMs"] = (
launcher_row.DurMs - compile_df.loc[launcher_row.Index, "DurChildMs"]
)

# `compile_time_ns` now accounts for parallel compilation worker threads, but not
# `compile_time_ms` now accounts for parallel compilation worker threads, but not
# the work from the main thread. Add that too.
for row in compile_df[compile_df["TID"] == main_thread].itertuples():
compile_time_ns[row.Name] += (row.DurNonChildNs, row.DurChildNs)
compile_time_ms[row.Name] += (row.DurNonChildMs, row.DurChildMs)

return pd.DataFrame.from_dict(
compile_time_ns, columns=["DurNonChildNs", "DurChildNs"], orient="index"
).sort_values(by=["DurNonChildNs"], ascending=False)
compile_time_ms, columns=["DurNonChildMs", "DurChildMs"], orient="index"
).sort_values(by=["DurNonChildMs"], ascending=False)
Loading

0 comments on commit 2f1f572

Please sign in to comment.