RuntimeError: shape '[616, 1, 40]' is invalid for input of size 49280 #18

Open
biasnhbi opened this issue Jun 11, 2023 · 0 comments

biasnhbi commented Jun 11, 2023

```
Traceback (most recent call last)

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/runpy.py:196 in _run_module_as_main
  ❱ 196 return _run_code(code, main_globals, None, "__main__", mod_spec)

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/runpy.py:86 in _run_code
  ❱  86 exec(code, run_globals)

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/hcpdiff/train_ac_single.py:105 in <module>
  ❱ 105 trainer.train()

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/hcpdiff/train_ac.py:409 in train
  ❱ 409 loss = self.train_one_step(data_list)

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/hcpdiff/train_ac.py:501 in train_one_step
  ❱ 501 model_pred, target, timesteps = self.forward(laten…

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/hcpdiff/train_ac.py:479 in forward
  ❱ 479 model_pred = self.encode_decode(prompt_ids, noisy_latents,…

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/hcpdiff/train_ac_single.py:78 in encode_decode
  ❱  78 model_pred = self.unet(noisy_latents, timesteps, encoder_h…

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 in _call_impl
  ❱ 1130 return forward_call(*input, **kwargs)

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/accelerate/utils/operations.py:553 in forward
  ❱ 553 return model_forward(*args, **kwargs)

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/accelerate/utils/operations.py:541 in __call__
  ❱ 541 return convert_to_fp32(self.model_forward(*args, **kwargs)…

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/amp/autocast_mode.py:12 in decorate_autocast
  ❱  12 return func(*args, **kwargs)

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.py:481 in forward
  ❱ 481 sample, res_samples = downsample_block(

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 in _call_impl
  ❱ 1130 return forward_call(*input, **kwargs)

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/diffusers/models/unet_2d_blocks.py:781 in forward
  ❱ 781 hidden_states = torch.utils.checkpoint.checkpoint…

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/hcpdiff/train_ac.py:48 in checkpoint_fix
  ❱  48 return checkpoint_raw(function, *args, use_reentrant=use_reent…

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/utils/checkpoint.py:237 in checkpoint
  ❱ 237 return _checkpoint_without_reentrant(

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/utils/checkpoint.py:383 in _checkpoint_without_reentrant
  ❱ 383 output = function(*args)

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/diffusers/models/unet_2d_blocks.py:774 in custom_forward
  ❱ 774 return module(*inputs, return_dict=re…

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 in _call_impl
  ❱ 1130 return forward_call(*input, **kwargs)

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/diffusers/models/transformer_2d.py:265 in forward
  ❱ 265 hidden_states = block(

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 in _call_impl
  ❱ 1130 return forward_call(*input, **kwargs)

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/diffusers/models/attention.py:307 in forward
  ❱ 307 attn_output = self.attn2(

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 in _call_impl
  ❱ 1130 return forward_call(*input, **kwargs)

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/diffusers/models/cross_attention.py:160 in forward
  ❱ 160 return self.processor(

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/diffusers/models/cross_attention.py:374 in __call__
  ❱ 374 hidden_states = xformers.ops.memory_efficient_attention(

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:192 in memory_efficient_attention
  ❱ 192 return _memory_efficient_attention(

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:295 in _memory_efficient_attention
  ❱ 295 return _fMHA.apply(

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:41 in forward
  ❱  41 out, op_ctx = _memory_efficient_attention_forward_requires…

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:323 in _memory_efficient_attention_forward_requires_grad
  ❱ 323 out = op.apply(inp, needs_gradient=True)

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:235 in apply
  ❱ 235 ) = _convert_input_format(inp)

/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:177 in _convert_input_format
  ❱ 177 key=key.reshape([batch * seqlen_kv, num_heads, head_dim_q]…

RuntimeError: shape '[616, 1, 40]' is invalid for input of size 49280
```
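
For what it is worth, the numbers in the error are self-consistent: the target view `[616, 1, 40]` holds 616 × 1 × 40 = 24,640 elements, while the key tensor handed to xformers holds 49,280, exactly twice as many. In the flash-attention path the target shape is derived from the query, so a clean factor of two usually means the query batch and the key/value batch (built from `encoder_hidden_states`) disagree, e.g. when the prompt embeddings are duplicated for the DreamArtist/CFG branch but the latents are not, or vice versa. The snippet below is only a minimal sketch of that arithmetic, not hcpdiff or xformers code; the token length of 77 and the folded batch-times-heads value of 8 are assumptions chosen to reproduce the reported sizes.

```python
import torch

# Assumed numbers that reproduce the sizes in the error message:
# 616 = 8 (batch * heads, folded together by head_to_batch_dim) * 77 (text token length),
# head_dim = 40, so the target view holds 616 * 1 * 40 = 24,640 elements.
batch_x_heads = 8
seqlen_kv = 77
head_dim = 40

key_matched = torch.randn(batch_x_heads, seqlen_kv, head_dim)      # 24,640 elements
key_doubled = torch.randn(2 * batch_x_heads, seqlen_kv, head_dim)  # 49,280 elements

key_matched.reshape([616, 1, head_dim])  # fine: element counts agree

try:
    key_doubled.reshape([616, 1, head_dim])  # same reshape xformers attempts
except RuntimeError as e:
    print(e)  # shape '[616, 1, 40]' is invalid for input of size 49280
```

If that is indeed the cause, printing the shapes of `noisy_latents` and `encoder_hidden_states` just before the `self.unet(...)` call in `encode_decode` (or temporarily disabling xformers attention) should confirm which side carries the doubled batch.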