Skip to content

Commit

Permalink
nsys-jax: track XLA collective annotation changes (#953)
Browse files Browse the repository at this point in the history
- openxla/xla#14344 changed how some
asynchronous collective launches were annotated
- Add a pure-Python wrapper for the `HloInstruction` proto, like the one
that already exists for `HloModule`, and use it as the place to add the
new logic required to identify collective operations
- Track other changes to the protobuf dump introduced by
openxla/xla@72ee1e0
- Optimise data loading time: when loading `nsys-jax-combine` output,
use multiprocessing to parallelise the first loading phase
  • Loading branch information
olupton authored Jul 22, 2024
1 parent 08903d4 commit 1a629a7
Show file tree
Hide file tree
Showing 6 changed files with 371 additions and 230 deletions.
22 changes: 8 additions & 14 deletions .github/container/jax_nsys/Analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -359,17 +359,17 @@
"gpu_idle_between_modules = [\"[GPU idle between module executions]\"]\n",
"\n",
"\n",
"@functools.lru_cache\n",
"@functools.cache\n",
"def instructions_and_frames(hlo_module, instruction_name):\n",
" _, hlo_inst = hlo_module.find_instruction(instruction_name)\n",
" instructions = [hlo_inst] + [\n",
" instructions = [hlo_inst.proto()] + [\n",
" called_inst\n",
" for called_comp_id in hlo_inst.called_computation_ids\n",
" for called_comp_id in hlo_inst.proto().called_computation_ids\n",
" for called_inst in hlo_module.find_computation(called_comp_id).instructions\n",
" ]\n",
" metadata = [inst.metadata for inst in instructions]\n",
" frames = [hlo_module.get_stack_frames(meta.stack_frame_id) for meta in metadata]\n",
" return hlo_inst.opcode, metadata, frames\n",
" return hlo_inst.proto().opcode, metadata, frames\n",
"\n",
"\n",
"for thunk_row in thunk_summary.itertuples():\n",
Expand Down Expand Up @@ -580,7 +580,9 @@
"detailed_mask = (compute_duration_rel_stds > var_threshold) & (\n",
" compute_duration_means > mean_threshold\n",
")\n",
"assert detailed_mask.sum() == detailed_limit\n",
"assert (\n",
" detailed_mask.sum() <= detailed_limit\n",
"), f\"Aimed for {detailed_limit} and got {detailed_mask.sum()}\"\n",
"\n",
"fig, axs = plt.subplots(\n",
" ncols=2, width_ratios=[1, 2], figsize=[15, 5], tight_layout=True\n",
Expand Down Expand Up @@ -732,7 +734,7 @@
" )\n",
"\n",
" ax.set_title(\n",
" f\"{steady_state.module.loc[program_id, 'Name'].iloc[0]} ({program_id}), {devices_to_show} most extreme devices\"\n",
" f\"{steady_state.module.loc[program_id, 'Name'].iloc[0]} ({program_id}), {min(outlier_devices.size, devices_to_show)} most extreme devices\"\n",
" )\n",
" ax.set_xlabel(\"Mean time within module [ms]\")\n",
" ax.set_ylabel(\"Mean(executions) bias from mean(executions&devices) [ms]\")\n",
Expand Down Expand Up @@ -853,14 +855,6 @@
" f\"Peak heap memory usage in module ID {module_id} is {max(heap_usage) / 1e9:.3f} GB\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1c80b631-cf10-43fd-837a-5bd9dded72bc",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
32 changes: 19 additions & 13 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,9 @@ def _get_message_size(
module_proto: HloProto, instruction: str
) -> tuple[int, str, int, float, float]:
_, inst = module_proto.find_instruction(instruction)
comm_inst = inst.communication_proto()
assert (
inst.opcode
comm_inst.opcode
in {
"all-gather-start",
"all-reduce-start",
Expand All @@ -172,26 +173,31 @@ def _get_message_size(
"collective-permute-start",
"reduce-scatter",
}
), f"{instruction}: message size calculation for {inst.opcode} has not yet been validated"
if inst.opcode == "collective-permute-start":
), f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"
if comm_inst.opcode == "collective-permute-start":
# See https://openxla.org/xla/operation_semantics#collectivepermute, which
# generates pair-wise send+recv between devices
collective_size = 2
else:
# replica_groups is something like {{0,1},{4,5},{2,3},{6,7}}, if there are 8
# devices that are doing pair-wise collectives
collective_size = len(inst.replica_groups[0].replica_ids)
assert all(
len(group.replica_ids) == collective_size for group in inst.replica_groups
), f"Heterogeneous collective {inst.replica_groups} could not be interpreted"
try:
replica_groups = comm_inst.collective_device_list.replica_groups
except AttributeError:
replica_groups = comm_inst.replica_groups
collective_sizes = set(len(group.replica_ids) for group in replica_groups)
assert (
len(collective_sizes) == 1
), f"Heterogeneous collective {comm_inst} could not be interpreted"
collective_size = next(iter(collective_sizes))
total_msg_size = 0
for operand_id in inst.operand_ids:
for operand_id in comm_inst.operand_ids:
_, operand = module_proto.find_instruction_by_id(operand_id)
msg_size_bits = math.prod(
operand.shape.dimensions,
start=element_type_width(operand.shape.element_type),
operand.proto().shape.dimensions,
start=element_type_width(operand.proto().shape.element_type),
)
if inst.opcode == "reduce-scatter":
if comm_inst.opcode == "reduce-scatter":
# NCCL's convention is that the message size of a reduce-scatter is the size of output buffer:
# https://github.com/NVIDIA/nccl/blob/ab2b89c4c339bd7f816fbc114a4b05d386b66290/src/collectives.cc#L122
msg_size_bits, rem = divmod(msg_size_bits, collective_size)
Expand All @@ -200,12 +206,12 @@ def _get_message_size(
assert rem == 0
total_msg_size += msg_size_bytes

collective = inst.opcode.removesuffix("-start")
collective = comm_inst.opcode.removesuffix("-start")
bw_correction, bus_correction = _collective_correction(collective, collective_size)
return (total_msg_size, collective, collective_size, bw_correction, bus_correction)


@functools.lru_cache
@functools.cache
def get_message_size(
program_id: int, instruction: str, prefix: pathlib.Path
) -> pd.Series:
Expand Down
Loading

0 comments on commit 1a629a7

Please sign in to comment.