
loss.detach().clone().mean() * (microbatch_size / current_batch_size #1596

Open
YixinSong-e opened this issue Oct 17, 2024 · 4 comments

@YixinSong-e

When I set moe_loss_weight:0

[rank7]:   File "/home/syx/miniconda3/envs/lmf/lib/python3.11/site-packages/composer/trainer/trainer.py", line 2907, in <lambda>
[rank7]:     **kwargs: self._train_microbatches(microbatches, loss_dict, **kwargs).item(),
[rank7]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/syx/miniconda3/envs/lmf/lib/python3.11/site-packages/composer/trainer/trainer.py", line 3075, in _train_microbatches
[rank7]:     microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_batch_size, is_final_microbatch)
[rank7]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/syx/miniconda3/envs/lmf/lib/python3.11/site-packages/composer/trainer/trainer.py", line 3209, in _train_microbatch
[rank7]:     microbatch_loss_dict[k] = loss.detach().clone().mean() * (microbatch_size / current_batch_size)
[rank7]:                               ^^^^^^^^^^^
[rank7]: AttributeError: 'float' object has no attribute 'detach'
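
(For reference, the same error can be reproduced outside Composer; it is simply what Python does when .detach() is called on a plain float instead of a torch.Tensor. A minimal sketch, assuming only the built-in float type:)

    # Minimal repro of the error above, independent of Composer/MegaBlocks:
    # a plain Python float has no .detach() method, unlike a torch.Tensor.
    lbl = 0.0        # what the load-balancing loss degenerates to when its weight is 0
    lbl.detach()     # AttributeError: 'float' object has no attribute 'detach'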
YixinSong-e added the bug label on Oct 17, 2024
@XiaohanZhangCMU
Contributor

Hi @YixinSong-e, thanks for bringing up the issue.
Can you send me the YAML you are using? Specifically, which image are you running off of?

@YixinSong-e
Author

I found the reason. When moe_loss_weight is set to 0, MegaBlocks returns a plain float, not a tensor. I fixed the issue by bypassing loss['lbl'].

@XiaohanZhangCMU
Contributor

Hi @YixinSong-e, can you explain in a bit more detail what you mean by "bypass the loss['lbl']"?

@YixinSong-e
Author

YixinSong-e commented Oct 19, 2024

In the llmfoundry/models/mpt/modeling_mpt.py file, the loss computation method is shown below:

    def loss(self, outputs: CausalLMOutputWithPast,
             batch: Mapping) -> Union[dict, torch.Tensor]:
        loss = compute_loss_from_logits(
            outputs,
            self.shift_labels,
            batch['labels'],
            self.loss_fn,
            batch.get('sample_weighing_factor', None),
        )

        if self.config.ffn_config['ffn_type'] in ffns_with_megablocks:
            # MegaBlocks MoE load balancing loss
            try:  # Add try/catch to avoid transformers complaining and raising errors
                from megablocks.layers.moe import batched_load_balancing_loss
            except:
                raise RuntimeError(
                    'Requirements for MegaBlocks not installed; see install instructions in `README.md`.',
                )
            lbl = batched_load_balancing_loss(self.model.transformer.mb_args)
            return {
                'total': loss + lbl,
                'loss': loss,
                'lbl': lbl,
            }
        return loss

When I set moe_loss_weight to 0, the value returned from batched_load_balancing_loss is the plain float 0.0, not a tensor, so the returned dict contains {'lbl': 0.0}.
And in the Composer trainer (composer/trainer/trainer.py, around line 3289), the loss dict is processed as:

    for k, loss in microbatch_loss_dict.items():
        microbatch_loss_dict[k] = loss.detach().clone().mean() * (microbatch_size / current_batch_size)

But 0.0 is not a tensor, so it does not have a detach() method.
So I just made the loss function return only the 'total' loss to bypass this error.
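
A minimal sketch of that workaround (a local edit to the loss method above, not an upstream llm-foundry fix); it also shows an alternative guard that promotes the float to a tensor instead of dropping the per-key dict:

    # Sketch of the workaround described above: keep the MegaBlocks branch but make
    # sure Composer never iterates over a non-tensor entry in the loss dict.
    if self.config.ffn_config['ffn_type'] in ffns_with_megablocks:
        lbl = batched_load_balancing_loss(self.model.transformer.mb_args)
        if not torch.is_tensor(lbl):
            # moe_loss_weight == 0 makes batched_load_balancing_loss return the plain float 0.0
            lbl = torch.tensor(lbl, device=loss.device, dtype=loss.dtype)
        # Returning only the combined scalar ('total') bypasses the per-key
        # loss.detach() call in the Composer trainer.
        return loss + lbl
    return loss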
