This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Ray spill out of disk error when using alpa to auto-parallelize llama #969

Open
zigzagcai opened this issue Nov 21, 2023 · 2 comments

zigzagcai commented Nov 21, 2023

Please describe the bug
When I use Alpa to parallelize the llama-7b model on a Ray cluster (one node with 8 GPUs), disk usage grows without bound because of Ray object spilling. Eventually the program fails with an out-of-disk error.
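
For a sense of scale, here is a rough back-of-the-envelope parameter count for the model, using the values from the model config printed in the error logs (vocab_size 32000, hidden_size 4096, intermediate_size 11008, num_hidden_layers 32, untied embeddings). The layer layout (four attention projections, a three-matrix gated MLP, two norms per layer, no biases) and the fp32 sizing are my assumptions about a standard LLaMA implementation, not something taken from the Alpa code.

```python
def llama_param_count(vocab_size, hidden_size, intermediate_size,
                      num_layers, tie_embeddings=False):
    """Estimate the parameter count of a LLaMA-style decoder (assumed layout)."""
    embed = vocab_size * hidden_size                # token embedding table
    attn = 4 * hidden_size * hidden_size            # q, k, v, o projections
    mlp = 3 * hidden_size * intermediate_size       # gate, up, down projections
    norms = 2 * hidden_size                         # two RMSNorm weights per layer
    per_layer = attn + mlp + norms
    final_norm = hidden_size
    lm_head = 0 if tie_embeddings else vocab_size * hidden_size
    return embed + num_layers * per_layer + final_norm + lm_head


# Values copied from the model config in the logs.
n_params = llama_param_count(32000, 4096, 11008, 32)
gib_fp32 = n_params * 4 / 2**30  # assumes 4 bytes per parameter

print(f"parameters: {n_params:,}")
print(f"fp32 weights: {gib_fp32:.1f} GiB")
```

This only sizes one copy of the weights. Optimizer state and any per-device copies made while sharding would add to what passes through the Ray object store.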

Please describe the expected behavior
Alpa training should run normally.

System information and environment

  • OS Platform and Distribution: Ubuntu 20.04 docker
  • Python version: 3.10.13
  • CUDA version: 11.8
  • NCCL version: 2.16.2
  • cupy version: cupy-cuda11x==12.2.0
  • GPU model and memory: NVIDIA A800 80GB
  • Alpa version: 1.0.0.dev0, build from source (alpa main branch)
  • TensorFlow version: 2.11.0
  • JAX version: 0.3.22
  • Ray version:
>>> print(ray.__version__)
2.1.0
>>> print(ray.__commit__)
be49bde7ee4f6adb3f8710aee0665c27f9f0bb62

To Reproduce
Steps to reproduce the behavior:

  1. LLaMa model used: https://github.com/young-geng/EasyLM/tree/main/EasyLM/models/llama
  2. ray start --head
  3. cd examples/llama_finetune
  4. bash run_llama.sh

Error Logs

cd examples/llama_finetune && bash run_llama.sh

2023-11-21 11:15:40.798832: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /fs/llm/zigzagcai/gcc_10.2.0/lib64:/usr/local/nccl-rdma-sharp-plugins/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-11-21 11:15:40.798902: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /fs/llm/zigzagcai/gcc_10.2.0/lib64:/usr/local/nccl-rdma-sharp-plugins/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-11-21 11:15:40.798911: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2023-11-21 11:15:42,074 INFO worker.py:1342 -- Connecting to existing Ray cluster at address: xxx.xx.x.xxx:6379...
2023-11-21 11:15:42,080 INFO worker.py:1519 -- Connected to Ray cluster. View the dashboard at 127.0.0.1:8265
INFO:__main__:Training/evaluation parameters TrainingArguments(output_dir='./output', overwrite_output_dir=True, do_train=True, do_eval=False, per_device_train_batch_size=32, per_device_eval_batch_size=16, num_micro_batches=32, operator_parallel=1, pipeline_parallel=1, use_remat=True, learning_rate=0.0005, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, adafactor=False, num_train_epochs=3.0, warmup_ratio=0.03, logging_steps=1, save_steps=3000, eval_steps=1000, seed=42, push_to_hub=False, hub_model_id=None, hub_token=None)
Model config LLaMAConfig {
  "attn_pdrop": 0.0,
  "bos_token_id": 0,
  "embd_pdrop": 0.0,
  "eos_token_id": 1,
  "fcm_max_ratio": 0.0,
  "fcm_min_ratio": 0.0,
  "gradient_checkpointing": true,
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_sequence_length": 2048,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "resid_pdrop": 0.0,
  "rms_norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32000
}

loading file tokenizer.model
loading file added_tokens.json
loading file special_tokens_map.json
loading file tokenizer_config.json
Generate config GenerationConfig {
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "transformers_version": "4.28.1"
}

loading configuration file /root/llama/llama-7b/config.json
Model config LlamaConfig {
  "_name_or_path": "/root/llama/llama-7b",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 2048,
  "max_sequence_length": 2048,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.28.1",
  "use_cache": true,
  "vocab_size": 32000
}

loading weights file /root/llama/llama-7b/model.safetensors.index.json
Generate config GenerationConfig {
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "pad_token_id": 0,
  "transformers_version": "4.28.1"
}

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:04<00:00,  2.44s/it]
All model checkpoint weights were used when initializing LlamaForCausalLM.

All the weights of LlamaForCausalLM were initialized from the model checkpoint at /root/llama/llama-7b.
If your task is similar to the task the model of the checkpoint was trained on, you can already use LlamaForCausalLM for predictions without further training.
loading configuration file /root/llama/llama-7b/generation_config.json
Generate config GenerationConfig {
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "pad_token_id": 0,
  "transformers_version": "4.28.1"
}

Loading data...
#train 44425, #eval 907
Formatting inputs...Skip in lazy mode
Formatting inputs...Skip in lazy mode
INFO:__main__:***** Build dataset *****
INFO:__main__:***** Running training *****
INFO:__main__:  Num examples = 44425
INFO:__main__:  Num Epochs = 3
INFO:__main__:  Batch size per device (w. accumulation) = 32
INFO:__main__:  Global train batch size (w. parallel & distributed) = 256
INFO:__main__:  Total optimization steps = 519
Initial compilation. This might take some minutes...
Epoch ... :   0%|          | 0/3 [00:00<?, ?it/s
(raylet) Spilled 1049654 MiB, 1571 objects, write throughput 663 MiB/s.
          | 0/173 [00:00<?, ?it/s]
Epoch ... :   0%|          | 0/3 [16:44<?, ?it/s]
Traceback (most recent call last):
  File "/fs/llm/zigzagcai/alpa/examples/llama_finetune/run_easylm_flax.py", line 886, in <module>
    main()
  File "/fs/llm/zigzagcai/alpa/examples/llama_finetune/run_easylm_flax.py", line 752, in main
    state, train_metric = p_train_step(state, batch)
  File "/root/miniconda3/envs/alpa/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/fs/llm/zigzagcai/alpa/alpa/api.py", line 130, in __call__
    out = executable.launch_on_driver(*args_flat)
  File "/fs/llm/zigzagcai/alpa/alpa/mesh_executable.py", line 665, in launch_on_driver
    input_bufs = physical_mesh.shard_args_to_bufs(
  File "/fs/llm/zigzagcai/alpa/alpa/device_mesh.py", line 1325, in shard_args_to_bufs
    ref = shard_arg_handlers[type(arg)](arg, self, indices)[0]
  File "/fs/llm/zigzagcai/alpa/alpa/device_mesh.py", line 2484, in _shard_device_array
    return _shard_array(np.asarray(array), device_mesh, indices, num_batch,
  File "/fs/llm/zigzagcai/alpa/alpa/device_mesh.py", line 2477, in _shard_array
    return _device_mesh_put(device_mesh, datas, num_batch, batch_dim)
  File "/fs/llm/zigzagcai/alpa/alpa/device_mesh.py", line 2434, in _device_mesh_put
    device_mesh.workers[host_id].put_buffers.remote(
  File "/root/miniconda3/envs/alpa/lib/python3.10/site-packages/ray/actor.py", line 138, in remote
    return self._remote(args, kwargs)
  File "/root/miniconda3/envs/alpa/lib/python3.10/site-packages/ray/util/tracing/tracing_helper.py", line 425, in _start_span
    return method(self, args, kwargs, *_args, **_kwargs)
  File "/root/miniconda3/envs/alpa/lib/python3.10/site-packages/ray/actor.py", line 184, in _remote
    return invocation(args, kwargs)
  File "/root/miniconda3/envs/alpa/lib/python3.10/site-packages/ray/actor.py", line 171, in invocation
    return actor._actor_method_call(
  File "/root/miniconda3/envs/alpa/lib/python3.10/site-packages/ray/actor.py", line 1170, in _actor_method_call
    object_refs = worker.core_worker.submit_actor_task(
  File "python/ray/_raylet.pyx", line 1982, in ray._raylet.CoreWorker.submit_actor_task
  File "python/ray/_raylet.pyx", line 1987, in ray._raylet.CoreWorker.submit_actor_task
  File "python/ray/_raylet.pyx", line 402, in ray._raylet.prepare_args_and_increment_put_refs
  File "python/ray/_raylet.pyx", line 393, in ray._raylet.prepare_args_and_increment_put_refs
  File "python/ray/_raylet.pyx", line 482, in ray._raylet.prepare_args_internal
  File "python/ray/_raylet.pyx", line 1599, in ray._raylet.CoreWorker.put_serialized_object_and_increment_local_ref
  File "python/ray/_raylet.pyx", line 1488, in ray._raylet.CoreWorker._create_put_buffer
  File "python/ray/_raylet.pyx", line 188, in ray._raylet.check_status
jax._src.traceback_util.UnfilteredStackTrace: ray.exceptions.OutOfDiskError: Local disk is full
The object cannot be created because the local object store is full and the local disk's utilization is over capacity (95% by default).Tip: Use `df` on this node to check disk usage and `ray memory` to check object store memory usage.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/fs/llm/zigzagcai/alpa/examples/llama_finetune/run_easylm_flax.py", line 886, in <module>
    main()
  File "/fs/llm/zigzagcai/alpa/examples/llama_finetune/run_easylm_flax.py", line 752, in main
    state, train_metric = p_train_step(state, batch)
  File "/fs/llm/zigzagcai/alpa/alpa/mesh_executable.py", line 665, in launch_on_driver
    input_bufs = physical_mesh.shard_args_to_bufs(
  File "/fs/llm/zigzagcai/alpa/alpa/device_mesh.py", line 1325, in shard_args_to_bufs
    ref = shard_arg_handlers[type(arg)](arg, self, indices)[0]
  File "/fs/llm/zigzagcai/alpa/alpa/device_mesh.py", line 2484, in _shard_device_array
    return _shard_array(np.asarray(array), device_mesh, indices, num_batch,
  File "/fs/llm/zigzagcai/alpa/alpa/device_mesh.py", line 2477, in _shard_array
    return _device_mesh_put(device_mesh, datas, num_batch, batch_dim)
  File "/fs/llm/zigzagcai/alpa/alpa/device_mesh.py", line 2434, in _device_mesh_put
    device_mesh.workers[host_id].put_buffers.remote(
  File "/root/miniconda3/envs/alpa/lib/python3.10/site-packages/ray/actor.py", line 138, in remote
    return self._remote(args, kwargs)
  File "/root/miniconda3/envs/alpa/lib/python3.10/site-packages/ray/util/tracing/tracing_helper.py", line 425, in _start_span
    return method(self, args, kwargs, *_args, **_kwargs)
  File "/root/miniconda3/envs/alpa/lib/python3.10/site-packages/ray/actor.py", line 184, in _remote
    return invocation(args, kwargs)
  File "/root/miniconda3/envs/alpa/lib/python3.10/site-packages/ray/actor.py", line 171, in invocation
    return actor._actor_method_call(
  File "/root/miniconda3/envs/alpa/lib/python3.10/site-packages/ray/actor.py", line 1170, in _actor_method_call
    object_refs = worker.core_worker.submit_actor_task(
  File "python/ray/_raylet.pyx", line 1982, in ray._raylet.CoreWorker.submit_actor_task
  File "python/ray/_raylet.pyx", line 1987, in ray._raylet.CoreWorker.submit_actor_task
  File "python/ray/_raylet.pyx", line 402, in ray._raylet.prepare_args_and_increment_put_refs
  File "python/ray/_raylet.pyx", line 393, in ray._raylet.prepare_args_and_increment_put_refs
  File "python/ray/_raylet.pyx", line 482, in ray._raylet.prepare_args_internal
  File "python/ray/_raylet.pyx", line 1599, in ray._raylet.CoreWorker.put_serialized_object_and_increment_local_ref
  File "python/ray/_raylet.pyx", line 1488, in ray._raylet.CoreWorker._create_put_buffer
  File "python/ray/_raylet.pyx", line 188, in ray._raylet.check_status
ray.exceptions.OutOfDiskError: Local disk is full
The object cannot be created because the local object store is full and the local disk's utilization is over capacity (95% by default).Tip: Use `df` on this node to check disk usage and `ray memory` to check object store memory usage.
(raylet) [2023-11-21 11:36:54,055 E 446708 446738] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2023-11-21_10-23-31_768034_446624 is over 95% full, available space: 42212503552; capacity: 844367142912. Object creation will fail if spilling is required.
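
The last raylet line can be checked by hand: with 42212503552 bytes available out of 844367142912, the disk is only just past the 95% mark that the error message names as the default. Below is a minimal sketch of that arithmetic, plus a live check with `shutil.disk_usage`. The helper names are hypothetical, and whether Ray compares with `>` or `>=` is my assumption.

```python
import shutil
import tempfile


def disk_used_fraction(available_bytes, capacity_bytes):
    """Fraction of the disk in use, given available and total bytes."""
    return 1.0 - available_bytes / capacity_bytes


def is_over_capacity(available_bytes, capacity_bytes, threshold=0.95):
    """True if usage exceeds the threshold (0.95 per the error message)."""
    return disk_used_fraction(available_bytes, capacity_bytes) > threshold


# Numbers copied from the file_system_monitor.cc log line.
avail, cap = 42212503552, 844367142912
print(f"used: {disk_used_fraction(avail, cap):.4%}")
print("over capacity:", is_over_capacity(avail, cap))

# Live check on the local temp directory (the Ray session dir in the log is under /tmp/ray).
usage = shutil.disk_usage(tempfile.gettempdir())
print("temp dir over capacity:", is_over_capacity(usage.free, usage.total))
```

Running a check like this before and during training shows how quickly spilling eats into the remaining space.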
@zigzagcai zigzagcai changed the title Ray spill out of disk error when using alpa to auto-parallelize llama [Bug] Ray spill out of disk error when using alpa to auto-parallelize llama Nov 21, 2023

zigzagcai commented Nov 21, 2023

Also, nvidia-smi shows that GPU memory is reserved while GPU utilization stays at 0%. Memory appears to keep leaking and the spilled Ray objects keep growing until the out-of-disk error is thrown.

Tue Nov 21 12:26:30 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A800-SXM...  On   | 00000000:4D:00.0 Off |                    0 |
| N/A   30C    P0    67W / 400W |  73279MiB / 81251MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A800-SXM...  On   | 00000000:52:00.0 Off |                    0 |
| N/A   28C    P0    65W / 400W |  73279MiB / 81251MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A800-SXM...  On   | 00000000:69:00.0 Off |                    0 |
| N/A   26C    P0    64W / 400W |  73279MiB / 81251MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A800-SXM...  On   | 00000000:6F:00.0 Off |                    0 |
| N/A   28C    P0    67W / 400W |  73279MiB / 81251MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   4  NVIDIA A800-SXM...  On   | 00000000:B3:00.0 Off |                    0 |
| N/A   28C    P0    64W / 400W |  73279MiB / 81251MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   5  NVIDIA A800-SXM...  On   | 00000000:B7:00.0 Off |                    0 |
| N/A   27C    P0    65W / 400W |  73279MiB / 81251MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   6  NVIDIA A800-SXM...  On   | 00000000:D5:00.0 Off |                    0 |
| N/A   28C    P0    63W / 400W |  73279MiB / 81251MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   7  NVIDIA A800-SXM...  On   | 00000000:D8:00.0 Off |                    0 |
| N/A   29C    P0    66W / 400W |  73279MiB / 81251MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
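
The same observation (memory reserved, 0% utilization) can be captured programmatically instead of by reading the table. This is a small sketch that parses the CSV output of `nvidia-smi --query-gpu=... --format=csv,noheader,nounits`. The function names and the 50% memory cutoff are hypothetical choices for illustration. Note that high reserved memory on its own may just reflect JAX preallocating GPU memory at startup, so this flags idle GPUs rather than proving a leak.

```python
import subprocess

QUERY = "index,memory.used,memory.total,utilization.gpu"


def parse_gpu_csv(text):
    """Parse 'index, used_mib, total_mib, util_pct' lines into dicts."""
    rows = []
    for line in text.strip().splitlines():
        idx, used, total, util = (int(field.strip()) for field in line.split(","))
        rows.append({"index": idx, "mem_used_mib": used,
                     "mem_total_mib": total, "util_pct": util})
    return rows


def idle_but_reserved(rows, mem_fraction=0.5):
    """Indices of GPUs with 0% utilization but at least mem_fraction of memory in use."""
    return [r["index"] for r in rows
            if r["util_pct"] == 0
            and r["mem_used_mib"] / r["mem_total_mib"] >= mem_fraction]


def query_gpus():
    """Run nvidia-smi and return parsed rows (requires an NVIDIA driver)."""
    out = subprocess.run(
        ["nvidia-smi", f"--query-gpu={QUERY}", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_gpu_csv(out)


# Sample mirroring two rows of the table above; call query_gpus() on a real node.
sample = "0, 73279, 81251, 0\n1, 73279, 81251, 0\n"
print(idle_but_reserved(parse_gpu_csv(sample)))
```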

@zigzagcai zigzagcai changed the title [Bug] Ray spill out of disk error when using alpa to auto-parallelize llama Ray spill out of disk error when using alpa to auto-parallelize llama Nov 21, 2023

zigzagcai commented Nov 27, 2023

Update:

I solved this issue by starting Ray with a higher object spilling threshold:
ray start --head --system-config='{"object_spilling_threshold":0.99}'
For anyone who hits the same error, FYI:
https://github.com/zigzagcai/alpa
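
If the cluster is launched from a script, the workaround command can be generated rather than typed by hand, which avoids shell-quoting mistakes in the JSON argument. A minimal sketch follows. The helper `ray_start_command` is hypothetical, and only the `--head` and `--system-config` flags from the command above are used. As I understand it, `object_spilling_threshold` is the fraction of object store usage at which Ray starts spilling, so 0.99 delays spilling until the store is nearly full.

```python
import json
import shlex


def ray_start_command(system_config, head=True):
    """Build a shell-safe `ray start` command line with a JSON system config."""
    args = ["ray", "start"]
    if head:
        args.append("--head")
    # Compact JSON so the flag matches the form used in the workaround.
    args.append("--system-config=" + json.dumps(system_config, separators=(",", ":")))
    return shlex.join(args)


cmd = ray_start_command({"object_spilling_threshold": 0.99})
print(cmd)
```

`shlex.join` quotes the whole `--system-config=...` argument rather than only the JSON part. The shell treats both forms the same way.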
