From 632cb739d01ba18849ead5f7c5923b21d836ad4f Mon Sep 17 00:00:00 2001
From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com>
Date: Tue, 2 Apr 2024 10:05:35 -0700
Subject: [PATCH] Doing the loss reduction in foundry instead of in the loss
 functions. (#1079)

* setting loss_fn reduction to None

* fixing a unit test

* add error message

* adding test to check reduction

* adding test to check reduction

* Update llmfoundry/models/mpt/modeling_mpt.py

Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com>

* preserving batch dimension of targets

* minor change

---------

Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com>
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
---
 llmfoundry/models/mpt/modeling_mpt.py |  17 ++++-
 tests/models/test_model.py            | 101 +++++++++++++++++++++++++-
 2 files changed, 113 insertions(+), 5 deletions(-)

diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index e0a666f62c..016473195a 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -984,7 +984,8 @@ def __init__(
                 from flash_attn.losses.cross_entropy import \
                     CrossEntropyLoss as FusedCrossEntropyLoss
 
-                self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
+                self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100,
+                                                     reduction='none')
             except:
                 raise ValueError(
                     'Fused Cross Entropy is not installed. Either (1) have a CUDA-compatible GPU '
@@ -994,7 +995,8 @@ def __init__(
                     'if installing from pypi, or (2) set your config model.loss_fn=torch_crossentropy.'
                 )
         elif loss_fn_config == 'torch_crossentropy':
-            self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
+            self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100,
+                                               reduction='none')
         else:
             raise ValueError(
                 f'Specified loss_fn={self.loss_fn} not recognized. `loss_fn` must be one of [`fused_crossentropy`, `torch_crossentropy`].'
@@ -1016,8 +1018,15 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
     def loss(self, outputs: CausalLMOutputWithPast,
              batch: Mapping) -> torch.Tensor:
         targets = self.get_targets(batch)
-        return self.loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)),
-                            targets.view(-1))
+        losses = self.loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)),
+                              targets.view(-1))
+
+        if torch.all(targets == self.loss_fn.ignore_index):
+            loss = losses.sum()
+        else:
+            loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum()
+
+        return loss
 
     def flops_per_batch(self, batch: Mapping) -> int:
         # Note: this computation does not take into account padding, and assumes
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 464423f512..c5f6062b0e 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -480,7 +480,7 @@ def test_loss_fn():
     model_2.to(test_cfg.device)
 
     assert isinstance(model_1.loss_fn, torch.nn.CrossEntropyLoss)
-    model_2.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
+    model_2.loss_fn = FusedCrossEntropyLoss(ignore_index=-100, reduction='none')
 
     optimizer_1 = DecoupledAdamW(model_1.parameters(),
                                  lr=test_cfg.optimizer.lr,
@@ -517,6 +517,105 @@ def test_loss_fn():
                                     atol=1e-4), f'differed at step {i}'
 
 
+@pytest.mark.gpu
+@pytest.mark.parametrize('loss_fn_config',
+                         ['torch_crossentropy', 'fused_crossentropy'])
+def test_loss_reduction(loss_fn_config: str):
+    """Tests that foundry's manual loss reduction matches the loss_fn's reduction.
+
+    We provide non-zero tolerances to account for small numerics differences
+    between the two reduction paths.
+    """
+    try:
+        from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
+    except ImportError:
+        pytest.skip('Fused cross entropy was not installed')
+
+    # run numerical test in pure fp32
+    from torch.backends import cuda, cudnn
+    cuda.matmul.allow_tf32 = False
+    cudnn.allow_tf32 = False
+
+    conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
+    with open(conf_path) as f:
+        test_cfg = om.load(f)
+
+    assert isinstance(test_cfg, DictConfig)
+
+    test_cfg.model.loss_fn = loss_fn_config
+
+    test_cfg.device = 'cuda:0'
+    test_cfg.model.init_device = 'cpu'
+    test_cfg.model.init_config = {
+        'name': 'baseline_',
+        'init_std': 0.02,
+    }
+
+    tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer)
+    tokenizer = build_tokenizer(test_cfg.tokenizer.name,
+                                tokenizer_cfg.get('kwargs', {}))
+
+    model_1 = build_composer_model(
+        name=test_cfg.model.name,
+        cfg=test_cfg.model,
+        tokenizer=tokenizer,
+    )
+    model_2 = copy.deepcopy(model_1)
+
+    model_1.to(test_cfg.device)
+    model_2.to(test_cfg.device)
+
+    # Give model_2 a loss_fn that reduces internally (reduction='mean')
+    if loss_fn_config == 'fused_crossentropy':
+        assert isinstance(model_1.loss_fn, FusedCrossEntropyLoss)
+        model_2.loss_fn = FusedCrossEntropyLoss(ignore_index=-100,
+                                                reduction='mean')
+    else:
+        assert isinstance(model_1.loss_fn, torch.nn.CrossEntropyLoss)
+        model_2.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100,
+                                                    reduction='mean')
+
+    optimizer_1 = DecoupledAdamW(model_1.parameters(),
+                                 lr=test_cfg.optimizer.lr,
+                                 betas=test_cfg.optimizer.betas,
+                                 eps=test_cfg.optimizer.eps,
+                                 weight_decay=test_cfg.optimizer.weight_decay)
+    optimizer_2 = DecoupledAdamW(model_2.parameters(),
+                                 lr=test_cfg.optimizer.lr,
+                                 betas=test_cfg.optimizer.betas,
+                                 eps=test_cfg.optimizer.eps,
+                                 weight_decay=test_cfg.optimizer.weight_decay)
+
+    for i in range(3):
+        batch = gen_random_batch(2, test_cfg)
+        output_1 = model_1(batch)
+        output_2 = model_2(batch)
+        assert output_1.logits.allclose(output_2.logits, rtol=1e-4,
+                                        atol=1e-4), f'differed at step {i}'
+
+        loss_1 = model_1.loss(output_1, batch)
+
+        # Loss for model_2 gets reduced within the loss_fn, so we call it directly
+        targets = model_2.get_targets(batch)
+        loss_2 = model_2.loss_fn(
+            output_2.logits.view(-1, output_2.logits.size(-1)),
+            targets.view(-1))
+
+        assert isinstance(loss_1, torch.Tensor)
+        assert isinstance(loss_2, torch.Tensor)
+        assert loss_1.allclose(loss_2, rtol=1e-3,
+                               atol=1e-3), f'differed at step {i}'
+        loss_1.backward()
+        loss_2.backward()
+        optimizer_1.step()
+        optimizer_2.step()
+
+        for p1, p2 in zip(model_1.parameters(), model_2.parameters()):
+            assert p1.data.shape == p2.data.shape
+            assert p1.data.allclose(p2.data, rtol=1e-5,
+                                    atol=1e-4), f'differed at step {i}'
+
+
 @pytest.mark.parametrize('peft_config', [
     None,
     {
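For reference, a minimal standalone sketch of the equivalence the new loss() relies on (a plain-torch illustration, not code from this PR; all names here are illustrative): with ignore_index=-100, reduction='none' yields a zero loss at ignored positions, so summing the per-token losses and dividing by the number of non-ignored targets reproduces reduction='mean'. The torch.all(...) branch guards the division when every target is ignored, in which case the sum is exactly zero.

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    logits = torch.randn(8, 50)            # (num_tokens, vocab_size)
    targets = torch.randint(0, 50, (8,))
    targets[:3] = -100                     # e.g. padding / prompt tokens

    loss_none = nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
    loss_mean = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')

    losses = loss_none(logits, targets)    # per-token; zero at ignored positions
    if torch.all(targets == -100):
        loss = losses.sum()                # avoids 0/0; the sum is exactly zero
    else:
        loss = losses.sum() / (targets != -100).sum()

    assert torch.allclose(loss, loss_mean(logits, targets))

The fused flash-attn CrossEntropyLoss is expected to follow the same reduction semantics as torch's, which is why the patched loss() can treat both loss functions uniformly.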