Doing the loss reduction in foundry instead of in the loss functions. (#1079)

* setting loss_fn reduction to None

* fixing a unit test

* add error message

* adding test to check reduction

* adding test to check reduction

* Update llmfoundry/models/mpt/modeling_mpt.py

Co-authored-by: Vitaliy Chiley <[email protected]>

* preserving batch dimension of targets

* minor change

---------

Co-authored-by: Vitaliy Chiley <[email protected]>
Co-authored-by: Daniel King <[email protected]>
3 people authored Apr 2, 2024
1 parent caf7fda commit 632cb73
Showing 2 changed files with 113 additions and 5 deletions.
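
In short, this is the pattern the commit adopts; the sketch below is a minimal standalone version with illustrative shapes, not the real foundry model: the loss function is now built with reduction='none' so it returns one loss per token, and the reduction to a scalar (a mean over non-ignored tokens) happens in the model's loss() method.

import torch
import torch.nn as nn

# Build the loss with no reduction: it returns one loss value per token.
loss_fn = nn.CrossEntropyLoss(ignore_index=-100, reduction='none')

logits = torch.randn(2, 5, 32)          # (batch, seq, vocab) -- illustrative shapes
targets = torch.randint(0, 32, (2, 5))
targets[0, :2] = -100                   # pretend these positions are padding

losses = loss_fn(logits.view(-1, logits.size(-1)), targets.view(-1))  # shape (batch * seq,)

# Reduce in the model code instead of inside the loss function:
# mean over the tokens that are not ignored.
if torch.all(targets == loss_fn.ignore_index):
    loss = losses.sum()   # everything ignored -> 0.0, but still differentiable
else:
    loss = losses.sum() / (targets != loss_fn.ignore_index).sum()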
17 changes: 13 additions & 4 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -984,7 +984,8 @@ def __init__(
from flash_attn.losses.cross_entropy import \
CrossEntropyLoss as FusedCrossEntropyLoss

- self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
+ self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100,
+                                      reduction='none')
except:
raise ValueError(
'Fused Cross Entropy is not installed. Either (1) have a CUDA-compatible GPU '
@@ -994,7 +995,8 @@ def __init__(
'if installing from pypi, or (2) set your config model.loss_fn=torch_crossentropy.'
)
elif loss_fn_config == 'torch_crossentropy':
- self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
+ self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100,
+                                    reduction='none')
else:
raise ValueError(
f'Specified loss_fn={self.loss_fn} not recognized. `loss_fn` must be one of [`fused_crossentropy`, `torch_crossentropy`].'
@@ -1016,8 +1018,15 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
def loss(self, outputs: CausalLMOutputWithPast,
batch: Mapping) -> torch.Tensor:
targets = self.get_targets(batch)
- return self.loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)),
-                     targets.view(-1))
+ losses = self.loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)),
+                       targets.view(-1))
+
+ if torch.all(targets == self.loss_fn.ignore_index):
+     loss = losses.sum()
+ else:
+     loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum()
+
+ return loss

def flops_per_batch(self, batch: Mapping) -> int:
# Note: this computation does not take into account padding, and assumes
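The torch.all(targets == ignore_index) branch above guards against a division by zero: with reduction='none', positions whose target equals ignore_index contribute a loss of 0, so if every target in a batch is ignored the denominator would be 0. A small check of that behavior (assumed shapes; not part of the commit):

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss(ignore_index=-100, reduction='none')

logits = torch.randn(4, 32, requires_grad=True)        # (tokens, vocab)
targets = torch.full((4,), -100, dtype=torch.long)     # every position ignored

losses = loss_fn(logits, targets)                       # zeros at ignored positions
num_valid = (targets != loss_fn.ignore_index).sum()

print(losses.sum())              # tensor(0., grad_fn=...), safe to call backward() on
print(losses.sum() / num_valid)  # nan -- why the all-ignored case is special-cased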
101 changes: 100 additions & 1 deletion tests/models/test_model.py
@@ -480,7 +480,7 @@ def test_loss_fn():
model_2.to(test_cfg.device)

assert isinstance(model_1.loss_fn, torch.nn.CrossEntropyLoss)
- model_2.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
+ model_2.loss_fn = FusedCrossEntropyLoss(ignore_index=-100, reduction='none')

optimizer_1 = DecoupledAdamW(model_1.parameters(),
lr=test_cfg.optimizer.lr,
@@ -517,6 +517,105 @@ def test_loss_fn():
atol=1e-4), f'differed at step {i}'


@pytest.mark.gpu
@pytest.mark.parametrize('loss_fn_config',
['torch_crossentropy', 'fused_crossentropy'])
def test_loss_reduction(loss_fn_config: str):
"""Tests the Fused CrossEntropy vs torch.nn.CrossEntropy loss function.
We provide non-zero tolerances to account for small numerics differences
between the two loss implementations.
"""
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss # type: ignore # isort: skip
except:
pytest.skip('Fused cross entropy was not installed')

# run numerical test in pure fp32
from torch.backends import cuda, cudnn
cuda.matmul.allow_tf32 = False
cudnn.allow_tf32 = False

conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
with open(conf_path) as f:
test_cfg = om.load(f)

assert isinstance(test_cfg, DictConfig)

test_cfg.model.loss_fn = loss_fn_config

test_cfg.device = 'cuda:0'
test_cfg.model.init_device = 'cpu'
test_cfg.model.init_config = {
'name': 'baseline_',
'init_std': 0.02,
}

tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer)
tokenizer = build_tokenizer(test_cfg.tokenizer.name,
tokenizer_cfg.get('kwargs', {}))

model_1 = build_composer_model(
name=test_cfg.model.name,
cfg=test_cfg.model,
tokenizer=tokenizer,
)
model_2 = copy.deepcopy(model_1)

model_1.to(test_cfg.device)
model_2.to(test_cfg.device)

# Use the loss function's built-in 'mean' reduction for model_2
if loss_fn_config == 'fused_crossentropy':
assert isinstance(model_1.loss_fn, FusedCrossEntropyLoss)
model_2.loss_fn = FusedCrossEntropyLoss(ignore_index=-100,
reduction='mean')
else:
assert isinstance(model_1.loss_fn, torch.nn.CrossEntropyLoss)
model_2.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100,
reduction='mean')

optimizer_1 = DecoupledAdamW(model_1.parameters(),
lr=test_cfg.optimizer.lr,
betas=test_cfg.optimizer.betas,
eps=test_cfg.optimizer.eps,
weight_decay=test_cfg.optimizer.weight_decay)
optimizer_2 = DecoupledAdamW(model_2.parameters(),
lr=test_cfg.optimizer.lr,
betas=test_cfg.optimizer.betas,
eps=test_cfg.optimizer.eps,
weight_decay=test_cfg.optimizer.weight_decay)

for i in range(3):
batch = gen_random_batch(2, test_cfg)
output_1 = model_1(batch)
output_2 = model_2(batch)
assert output_1.logits.allclose(output_2.logits, rtol=1e-4,
atol=1e-4), f'differed at step {i}'

loss_1 = model_1.loss(output_1, batch)

# Loss for model_2 gets reduced within the loss_fn, so we handle it separately
targets = model_2.get_targets(batch)
loss_2 = model_2.loss_fn(
output_2.logits.view(-1, output_2.logits.size(-1)),
targets.view(-1))

assert isinstance(loss_1, torch.Tensor)
assert isinstance(loss_2, torch.Tensor)
assert loss_1.allclose(loss_2, rtol=1e-3,
atol=1e-3), f'differed at step {i}'
loss_1.backward()
loss_2.backward()
optimizer_1.step()
optimizer_2.step()

for p1, p2 in zip(model_1.parameters(), model_2.parameters()):
assert p1.data.shape == p2.data.shape
assert p1.data.allclose(p2.data, rtol=1e-5,
atol=1e-4), f'differed at step {i}'


@pytest.mark.parametrize('peft_config', [
None,
{
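For reference, the equivalence the new test relies on can be checked in isolation (random tensors instead of a real MPT model; shapes and tolerances here are arbitrary): summing the unreduced losses and dividing by the count of non-ignored targets matches reduction='mean' with the same ignore_index.

import torch
import torch.nn as nn

torch.manual_seed(0)

logits = torch.randn(2, 8, 32)                    # (batch, seq, vocab)
targets = torch.randint(0, 32, (2, 8))
targets[:, -3:] = -100                            # treat the tail as padding

loss_none = nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
loss_mean = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')

losses = loss_none(logits.view(-1, logits.size(-1)), targets.view(-1))
manual = losses.sum() / (targets != -100).sum()   # reduction done "in foundry"
reference = loss_mean(logits.view(-1, logits.size(-1)), targets.view(-1))

assert torch.allclose(manual, reference, rtol=1e-6, atol=1e-6)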
