
train_custom_diffusion.py does not support fp16 #5502

Closed
jiaqiw09 opened this issue Oct 24, 2023 · 4 comments

jiaqiw09 (Contributor) commented Oct 24, 2023

When I run train_custom_diffusion.py with the following args, I get the error below:

  python3 custom_diffusion/train_custom_diffusion.py \
  --pretrained_model_name_or_path CompVis/stable-diffusion-v1-4 \
  --instance_data_dir diffusers/cat_toy_example \
  --prior_loss_weight 1.0 \
  --class_prompt "cat" \
  --num_class_images 200 \
  --instance_prompt "photo of a <new1> cat" \
  --resolution 512 \
  --train_batch_size 1 \
  --learning_rate 1e-5 \
  --lr_warmup_steps 0 \
  --max_train_steps 10 \
  --scale_lr \
  --hflip \
  --modifier_token "<new1>" \
  --validation_prompt "<new1> cat sitting in a bucket" \
  --no_safe_serialization \
  --mixed_precision "fp16"
Traceback (most recent call last):
  File "custom_diffusion/train_custom_diffusion.py", line 1341, in <module>
    main(args)
  File "custom_diffusion/train_custom_diffusion.py", line 1123, in main
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
  File "/root/anaconda3/envs/wjq_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/wjq_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wjq/diffusers_accuracy/src/diffusers/models/unet_2d_condition.py", line 1010, in forward
    sample, res_samples = downsample_block(
  File "/root/anaconda3/envs/wjq_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/wjq_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wjq/diffusers_accuracy/src/diffusers/models/unet_2d_blocks.py", line 1103, in forward
    hidden_states = attn(
  File "/root/anaconda3/envs/wjq_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/wjq_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wjq/diffusers_accuracy/src/diffusers/models/transformer_2d.py", line 323, in forward
    hidden_states = block(
  File "/root/anaconda3/envs/wjq_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/wjq_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wjq/diffusers_accuracy/src/diffusers/models/attention.py", line 218, in forward
    attn_output = self.attn2(
  File "/root/anaconda3/envs/wjq_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/wjq_test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wjq/diffusers_accuracy/src/diffusers/models/attention_processor.py", line 423, in forward
    return self.processor(
  File "/home/wjq/diffusers_accuracy/src/diffusers/models/attention_processor.py", line 1246, in __call__
    hidden_states = F.scaled_dot_product_attention(
RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: c10::Half key.dtype: float and value.dtype: float instead.
jiaqiw09 (Contributor, Author) commented Oct 24, 2023

@DN6 here is some basic analysis.

I checked the code with a hook on each layer, and the problem seems to happen in down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.

If I narrow it down further, the problem seems to be in src/diffusers/models/attention_processor.py:

hidden_states = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

The query is:

tensor([[[ 0.3604,  0.3496,  1.6914,  ..., -0.0689, -0.0560, -0.1388],
         [ 0.5840,  0.3242,  1.5635,  ...,  0.0389,  0.0034,  0.0142],
         [ 0.5352,  0.5337,  1.5850,  ..., -0.0407, -0.1105, -0.1510],
         ...,
         [ 0.2068,  0.4807,  1.7617,  ..., -0.0231,  0.0940, -0.0759],
         [ 0.0457,  0.6743,  1.6045,  ..., -0.0941,  0.0852, -0.1387],
         [-0.2017,  0.5015,  1.6875,  ..., -0.1460,  0.0304, -0.0388]]],
       device='npu:0', dtype=torch.float16)

and the key and value both look like the following:

tensor([[[-0.3885,  0.0230, -0.0522,  ..., -0.4901, -0.3066,  0.0674],
         [-1.0019, -0.1030,  1.9585,  ..., -0.3306, -0.2973, -1.8877],
         [ 0.1654, -0.1679, -0.6554,  ...,  0.4813,  1.1950, -1.8424],
         ...,
         [-0.2540,  0.1089, -0.0968,  ...,  0.4131,  0.9504, -0.7500],
         [-0.2531,  0.1275, -0.1063,  ...,  0.4039,  0.9594, -0.7297],
         [-0.2598,  0.1148, -0.0549,  ...,  0.4198,  0.9326, -0.7927]]],
       device='npu:0', grad_fn=<NativeLayerNormBackward0>)

Both key and value come from:

if self.train_kv:
    key = self.to_k_custom_diffusion(encoder_hidden_states)
    value = self.to_v_custom_diffusion(encoder_hidden_states)

and the encoder_hidden_states also has grad_fn=<NativeLayerNormBackward0> (and, like key and value, no fp16 dtype).
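
For reference, here is a minimal standalone reproduction of the mismatch (made-up shapes, not the actual model tensors); casting key/value to the query dtype removes the error, though that is only a local workaround, not necessarily the right fix for the script:

import torch
import torch.nn.functional as F

# Made-up shapes, just to reproduce the dtype mismatch from the traceback.
query = torch.randn(1, 8, 4096, 40, dtype=torch.float16)  # comes out of the fp16 UNet path
key = torch.randn(1, 8, 77, 40)                            # float32, like to_k_custom_diffusion's output
value = torch.randn(1, 8, 77, 40)                          # float32, like to_v_custom_diffusion's output

try:
    F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
except RuntimeError as e:
    print(e)  # Expected query, key, and value to have the same dtype ...

# Casting key/value to the query dtype keeps all three consistent (local workaround only):
out = F.scaled_dot_product_attention(
    query, key.to(query.dtype), value.to(query.dtype), dropout_p=0.0, is_causal=False
)
print(out.dtype)  # torch.float16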

Has anyone run into the same problem before?

DN6 (Collaborator) commented Oct 30, 2023

Hi @jiaqiw09 I was able to reproduce the issue. It seems like the text encoder isn't producing the encoder_hidden_states with the right dtype. Taking a look into why this is happening.
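
To illustrate the diagnosis, the tiny sketch below uses stand-in modules (the names, shapes, and nn.Embedding stand-in are assumptions, not the script's actual objects): the text encoder stays in float32 because its embeddings are being trained, so its output does not match the fp16 compute dtype of the UNet.

import torch
from torch import nn

# Illustrative stand-ins only, not the real models from the script.
unet_dtype = torch.float16                       # the UNet is cast to fp16 under --mixed_precision "fp16"
text_encoder = nn.Embedding(49408, 768).float()  # stays float32 because its embeddings are trained
input_ids = torch.randint(0, 49408, (1, 77))

encoder_hidden_states = text_encoder(input_ids)
print(encoder_hidden_states.dtype, "vs UNet compute dtype:", unet_dtype)
# -> torch.float32 vs UNet compute dtype: torch.float16, which is exactly the mismatch in the traceback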

DN6 self-assigned this Oct 30, 2023
DN6 (Collaborator) commented Nov 7, 2023

@jiaqiw09 This should be fixed with PR: #5663

github-actions bot commented Dec 2, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label (Issues that haven't received updates) on Dec 2, 2023
DN6 closed this as completed on Dec 4, 2023