
Only write to HBM at the last iteration. #8393

Merged: 2 commits merged into master from xiowei/notWriteToHBMEverytime on Nov 20, 2024

Conversation

vanbasten23
Collaborator

@vanbasten23 vanbasten23 commented Nov 19, 2024

Test plan: root@t1v-n-f3643994-w-0:/workspaces/persist# python pytorch/xla/test/test_tpu_paged_attention_kernel.py 2>&1 | tee ~/out.txt
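For context, the optimization in the PR title, accumulating partial results on-chip and writing the output back to HBM only on the last grid iteration, follows a common Pallas TPU pattern. The sketch below is a minimal, hypothetical illustration of that pattern (the kernel name, shapes, and block_sum wrapper are made up for this example and it requires a TPU backend; it is not the actual paged-attention kernel change):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def accum_kernel(x_ref, o_ref, acc_ref):
  i = pl.program_id(0)

  @pl.when(i == 0)
  def _init():
    # Zero the VMEM accumulator on the first grid iteration.
    acc_ref[...] = jnp.zeros_like(acc_ref[...])

  # Accumulate this iteration's input block on-chip instead of writing it out.
  acc_ref[...] += x_ref[...]

  # Write the result to the HBM-backed output only on the last iteration.
  @pl.when(i == pl.num_programs(0) - 1)
  def _flush():
    o_ref[...] = acc_ref[...]


def block_sum(x, block=8):
  # x has shape (num_blocks * block, d); returns the elementwise sum of the row blocks.
  n, d = x.shape
  return pl.pallas_call(
      accum_kernel,
      grid=(n // block,),
      in_specs=[pl.BlockSpec((block, d), lambda i: (i, 0))],
      out_specs=pl.BlockSpec((block, d), lambda i: (0, 0)),
      out_shape=jax.ShapeDtypeStruct((block, d), x.dtype),
      scratch_shapes=[pltpu.VMEM((block, d), x.dtype)],
  )(x)


# Example: sum 16 row blocks of shape (8, 128).
y = block_sum(jnp.ones((16 * 8, 128), dtype=jnp.float32))
```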

@vanbasten23 vanbasten23 marked this pull request as ready for review November 19, 2024 00:01
@vanbasten23
Collaborator Author

vanbasten23 commented Nov 19, 2024

The TPU CI failure seems to be unrelated to this PR:

+ python3 test/spmd/test_xla_distributed_checkpoint.py
E1119 00:37:42.757418399  122932 server_chttp2.cc:40]        {"created":"@1731976662.757394348","description":"Only 1 addresses added out of total 2 resolved","file":"external/com_github_grpc_grpc/src/core/ext/transport/chttp2/server/chttp2_server.cc","file_line":404,"referenced_errors":[{"created":"@1731976662.757391588","description":"Address family not supported by protocol","errno":97,"file":"external/com_github_grpc_grpc/src/core/lib/iomgr/socket_utils_common_posix.cc","file_line":420,"os_error":"Address family not supported by protocol","syscall":"socket","target_address":"[::1]:8547"}]}
WARNING:root:Preemption sync point reached at step 10. Triggering a checkpoint.
/home/runner/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_saver.py:290: UserWarning: The function definition for SavePlanner.set_up_planner has been updated to include the storage_meta argument. Please update your implementation to include this parameter.
  warnings.warn(
/home/runner/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py:116: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if tensor.storage().size() != tensor.numel():
.......2024-11-19 00:37:44.764514: W external/xla/xla/tsl/distributed_runtime/preemption/preemption_notifier.cc:89] SIGTERM caught at 2024-11-19T00:37:44.764454699+00:00
./home/runner/.local/lib/python3.10/site-packages/torch_xla/runtime.py:242: UserWarning: Replicating tensors already initialized on non-virtual XLA device for SPMD to force SPMD mode. This is one-time overhead to setup, and to minimize such, please set SPMD mode before initializting tensors (i.e., call use_spmd() in the beginning of the program).
  warnings.warn(
.s/home/runner/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_saver.py:143: UserWarning: torch.distributed is unavailable or uninitialized, assuming the intent is to save in a single process.
  warnings.warn(
/home/runner/.local/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py:144: UserWarning: torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process.
  warnings.warn(
.....E1119 00:37:45.690169  124197 preemption_sync_manager.cc:247] Preemption sync failed - could not inform service of current call counter: ALREADY_EXISTS: Config key PREEMPTION_CURRENT_COUNTER//job:jax_worker/task:0 already exists.
Additional GRPC error information from remote target coordination_service while calling /tensorflow.CoordinationService/InsertKeyValue:
:{"created":"@1731976665.690086867","description":"Error received from peer ipv4:127.0.0.1:8547","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Config key PREEMPTION_CURRENT_COUNTER//job:jax_worker/task:0 already exists.","grpc_status":6} [type.googleapis.com/tensorflow.CoordinationServiceError='']
E1119 00:37:45.690533  123874 preemption_sync_manager.cc:303] Failed to cancel preemption barrier: FAILED_PRECONDITION: Barrier (PREEMPTION_SYNC_BARRIER) has already been passed with status code: 0
Additional GRPC error information from remote target coordination_service while calling /tensorflow.CoordinationService/CancelBarrier:
:{"created":"@1731976665.690513157","description":"Error received from peer ipv4:127.0.0.1:8547","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Barrier (PREEMPTION_SYNC_BARRIER) has already been passed with status code: 0","grpc_status":9} [type.googleapis.com/tensorflow.CoordinationServiceError='']
FF........
======================================================================
FAIL: test_adamw (__main__.OptimizerCheckpointTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/runner/_work/xla/xla/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py", line 38, in run
    f(*args, **kwargs)
  File "/home/runner/_work/xla/xla/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py", line 684, in test_adamw
    self._test_optimizer(tmpdir, torch.optim.AdamW)
  File "/home/runner/_work/xla/xla/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py", line 671, in _test_optimizer
    self._assert_same_state_dict(state_dict, new_state_dict)
  File "/home/runner/_work/xla/xla/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py", line 90, in _assert_same_state_dict
    self._assert_same_state_dict(
  File "/home/runner/_work/xla/xla/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py", line 90, in _assert_same_state_dict
    self._assert_same_state_dict(
  File "/home/runner/_work/xla/xla/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py", line 82, in _assert_same_state_dict
    assert sharding1 == sharding2, f"Different sharding on tensors at {keypath}: {sharding1} vs {sharding2}"
AssertionError: Different sharding on tensors at .model.fc1.bias: {replicated} vs 

======================================================================
FAIL: test_sgd (__main__.OptimizerCheckpointTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/runner/_work/xla/xla/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py", line 38, in run
    f(*args, **kwargs)
  File "/home/runner/_work/xla/xla/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py", line 680, in test_sgd
    self._test_optimizer(tmpdir, torch.optim.SGD)
  File "/home/runner/_work/xla/xla/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py", line 671, in _test_optimizer
    self._assert_same_state_dict(state_dict, new_state_dict)
  File "/home/runner/_work/xla/xla/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py", line 90, in _assert_same_state_dict
    self._assert_same_state_dict(
  File "/home/runner/_work/xla/xla/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py", line 90, in _assert_same_state_dict
    self._assert_same_state_dict(
  File "/home/runner/_work/xla/xla/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py", line 82, in _assert_same_state_dict
    assert sharding1 == sharding2, f"Different sharding on tensors at {keypath}: {sharding1} vs {sharding2}"
AssertionError: Different sharding on tensors at .model.fc1.bias: {replicated} vs 

----------------------------------------------------------------------
Ran 25 tests in 3.498s

FAILED (failures=2, skipped=1)

I ran the test root@t1v-n-f3643994-w-0:/workspaces/persist# python pytorch/xla/test/test_tpu_paged_attention_kernel.py 2>&1 | tee ~/out.txt locally and it passed.

@JackCaoG
Collaborator

TPU CI failure should be resolved if you rebase, I disabled that test for now

@vanbasten23 vanbasten23 force-pushed the xiowei/notWriteToHBMEverytime branch from 1c17a71 to d52c6f2 on November 19, 2024 03:20
@vanbasten23
Collaborator Author

TPU CI failure should be resolved if you rebase, I disabled that test for now

Thanks Jack for the info!

@vanbasten23
Collaborator Author

The TPU test failure is very strange. On my TPU v4, PagedAttentionKernelTest.test_paged_attention_with_query_padding127 succeeded but PagedAttentionKernelTest.test_paged_attention_with_query_padding128 failed with jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Error loading program: Fatal error at line kv_len = lengths[i] (full error)

The failing test succeeded on my v5e VM, though, which uses older versions of torch and torch_xla:

torch: 33191bb664fef338d94586a05bf1f14d17a00340
torch_xla: 0c3d54bf0c1e19ad4a139e0fe753bdde6b6e7dd8
libtpu_nightly==0.1.dev20241020+nightly
jax==0.4.35.dev20241020
jaxlib==0.4.35.dev20241020

@vanbasten23
Collaborator Author

It seems the error is due to OOM, despite the confusing error message, even with jax.block_until_ready(actual_output): #8356 (comment). The error goes away when I reduce the size.
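For reference, the reason the test wraps the result in jax.block_until_ready is that JAX dispatches work asynchronously, so a runtime failure such as RESOURCE_EXHAUSTED may only surface once the output is materialized. A generic, self-contained illustration of that pattern (the matmul and shapes here are made up, not taken from the paged-attention test):

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
  # Any jitted computation; dispatch returns immediately with a not-yet-materialized array.
  return jnp.einsum('ij,jk->ik', x, x)


x = jnp.ones((1024, 1024), dtype=jnp.float32)
out = f(x)                  # asynchronous dispatch: runtime errors may not raise here
jax.block_until_ready(out)  # materializes the result, so deferred errors (e.g. OOM) raise at this line
```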

@JackCaoG JackCaoG merged commit d572aeb into master Nov 20, 2024
12 checks passed