Fix an incorrect assert on frame count for PT_XLA_DEBUG=1 #6466

Merged 1 commit into master on Feb 21, 2024

Conversation

JackCaoG (Collaborator) commented on Feb 5, 2024

I discovered this while doing a repro for a customer:

  File "test.py", line 43, in main
    trainer.train()
  File "/src/repo/magvit2-pytorch/magvit2_pytorch/trainer.py", line 517, in train
    self.train_step(dl_iter)
  File "/src/repo/magvit2-pytorch/magvit2_pytorch/trainer.py", line 351, in train_step
    loss, loss_breakdown = self.model(
  File "/src/pytorch/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/src/pytorch/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "<@beartype(magvit2_pytorch.magvit2_pytorch.VideoTokenizer.forward) at 0x7f962217a160>", line 55, in forward
  File "/src/repo/magvit2-pytorch/magvit2_pytorch/magvit2_pytorch.py", line 1826, in forward
    norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
  File "/src/pytorch/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "<@beartype(magvit2_pytorch.magvit2_pytorch.grad_layer_wrt_loss) at 0x7f9622152ca0>", line 52, in grad_layer_wrt_loss
  File "/src/repo/magvit2-pytorch/magvit2_pytorch/magvit2_pytorch.py", line 127, in grad_layer_wrt_loss
    return torch_grad(
  File "/src/pytorch/torch/autograd/__init__.py", line 412, in grad
    result = _engine_run_backward(
  File "/src/pytorch/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: torch_xla/csrc/debug_util.cpp:265 : Check failed: frames.size() >= 1 (1 vs. 0)
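
For context, here is a minimal sketch (assuming torch_xla is installed and an XLA device is available) of the code path involved: with PT_XLA_DEBUG=1 set, calling torch.autograd.grad on XLA tensors, as grad_layer_wrt_loss does in the trace above, runs through the debug frame-capture logic whose check failed. This is only an illustration of the setup, not the exact customer repro, and it may not trigger the failed check on every build.

```python
# Hedged sketch: PT_XLA_DEBUG=1 plus an explicit torch.autograd.grad call,
# mirroring the grad_layer_wrt_loss pattern in the trace; not the exact repro.
import os
os.environ["PT_XLA_DEBUG"] = "1"  # must be set before torch_xla initializes

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

layer = torch.nn.Linear(4, 4).to(device)
x = torch.randn(4, 4, device=device, requires_grad=True)

loss = layer(x).pow(2).mean()

# Take the gradient explicitly instead of calling loss.backward(),
# as the code in the trace does.
(grad,) = torch.autograd.grad(loss, x, retain_graph=True)
print(grad.norm(p=2))
```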

JackCaoG requested a review from will-cromar on February 5, 2024 at 03:47
JackCaoG merged commit c08ae21 into master on Feb 21, 2024
18 checks passed