reduce_scatter_tensor raises ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY in multi-node usage #65

Open
garrett361 opened this issue May 29, 2024 · 0 comments

Cross-posting from this ipex issue.

Repeated calls to torch.distributed.reduce_scatter_tensor eventually raise a
ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY error in multi-node setups. Similar behavior occurs when
using Fully Sharded Data Parallel (FSDP), which calls reduce_scatter_tensor internally.
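
For context, a minimal sketch of the FSDP path that exercises the same collective, assuming the
same ccl/XPU environment as the full repro script below (the model and sizes are placeholders);
FSDP reduce-scatters gradient shards during the backward pass:

import os

import intel_extension_for_pytorch as ipex  # noqa
import oneccl_bindings_for_pytorch  # noqa
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("ccl")
device = torch.device(f"xpu:{int(os.environ['LOCAL_RANK'])}")
torch.xpu.set_device(device)

# Placeholder model; FSDP shards its parameters across ranks
model = FSDP(torch.nn.Linear(4096, 4096).to(device))
loss = model(torch.randn(8, 4096, device=device)).sum()
loss.backward()  # the sharded gradients are reduce-scattered here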

A script to reproduce is below. Steps:

  1. Create source and destination tensors on all ranks in a multi-node setup.
  2. Repeatedly call reduce_scatter_tensor, printing memory readings at each step.
  3. Eventually the above error is raised, without any corresponding jump in the memory readings.
import argparse
import os

import intel_extension_for_pytorch as ipex  # noqa
import oneccl_bindings_for_pytorch  # noqa
import torch
import torch.distributed as dist


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dim",
        type=int,
        default=2**30,
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=100,
    )
    args = parser.parse_args()
    return args


def main(dim: int, dtype: str, max_steps: int) -> None:
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"xpu:{local_rank}")
    torch.xpu.set_device(device)

    # Force dim to be divisible by the world size
    new_dim = world_size * (dim // world_size)
    if new_dim != dim:
        if not rank:
            print(
                f"Adjusting original {dim=} to {new_dim} in order to be divisible by {world_size=}",
                flush=True,
            )
        dim = new_dim

    try:
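        # The "ccl" backend is registered by the oneccl_bindings_for_pytorch import above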
        dist.init_process_group("ccl")

        t_in = torch.randn(dim, dtype=getattr(torch, dtype), device=device)
        t_out = torch.empty(dim // world_size, dtype=getattr(torch, dtype), device=device)

        for step in range(1, max_steps + 1):
            dist.reduce_scatter_tensor(t_out, t_in, op=dist.ReduceOp.SUM)
            torch.xpu.synchronize()
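            # Readings reflect only the PyTorch XPU caching allocator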
            stats = torch.xpu.memory_stats()
            peak_mem_gib = stats["allocated_bytes.all.peak"] / 2**30
            current_mem_gib = stats["allocated_bytes.all.current"] / 2**30
            print(f"[{rank=}]: {step=} memory {peak_mem_gib=}, {current_mem_gib=}", flush=True)

    finally:
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    args = get_args()
    main(**vars(args))
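
The script assumes WORLD_SIZE, RANK, and LOCAL_RANK (plus MASTER_ADDR/MASTER_PORT for
init_process_group) are exported in the environment. One possible launch for the 24-rank run
shown in the logs, assuming 12 ranks per node and with illustrative names
(set_torch_dist_env.sh standing in for whatever wrapper maps the scheduler's variables to
torch.distributed's):

mpiexec -n 24 -ppn 12 ./set_torch_dist_env.sh python reduce_scatter_repro.py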

Example logs:

[... snip ...]
[rank=14]: step=27 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=13]: step=27 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=17]: step=27 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=6]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=4]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=2]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=8]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=10]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=7]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=1]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=11]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=3]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=9]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=20]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=0]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=5]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=19]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=21]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=22]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=23]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=16]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=15]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=18]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=12]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=14]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=13]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=17]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=6]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=11]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=2]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=8]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=10]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=0]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=4]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=1]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=3]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=9]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=20]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=7]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=5]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=23]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=22]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=12]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=18]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=15]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=14]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=16]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=21]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=13]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=19]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=17]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
2024:05:29-19:16:18:(202165) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
2024:05:29-19:16:18:(202162) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
2024:05:29-19:16:18:(202164) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
2024:05:29-19:16:18:(202173) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
2024:05:29-19:16:18:(202167) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
2024:05:29-19:16:18:(202166) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
2024:05:29-19:16:18:(149693) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
2024:05:29-19:16:18:(202168) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
2024:05:29-19:16:18:(202163) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
/lus/gila/projects/Aurora_deployment/mk/decoders/alcf/set_torch_dist_env.sh: line 25: 200400 Aborted                 $@
x1921c5s2b0n0.hostmgmt2000.cm.americas.sgi.com: rank 6 exited with code 134
x1921c5s2b0n0.hostmgmt2000.cm.americas.sgi.com: rank 0 died from signal 15
2024:05:29-19:16:18:(149692) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
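
As a sanity check on the flat readings above, assuming the default dim of 2**30 bfloat16
elements and the 24 ranks seen in the logs, t_in and t_out alone account (up to allocator
rounding) for the steady ~2.084 GiB reported on every step:

elem_bytes = 2                         # bfloat16
t_in_gib = 2**30 * elem_bytes / 2**30  # source tensor: 2.0 GiB
t_out_gib = t_in_gib / 24              # scatter output: ~0.083 GiB
print(t_in_gib + t_out_gib)            # ~2.083 GiB, matching the logs

So the allocator-visible footprint never grows; whatever exhausts device memory appears to be
allocated outside the PyTorch caching allocator (presumably inside oneCCL / Level Zero).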

The behavior seems specific to multi-node setups. I have not seen the same error raised on a single
node.
