Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Doing the loss reduction in foundry instead of in the loss functions. #1079

Merged
17 changes: 13 additions & 4 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,8 @@ def __init__(
from flash_attn.losses.cross_entropy import \
CrossEntropyLoss as FusedCrossEntropyLoss

self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100,
vchiley marked this conversation as resolved.
Show resolved Hide resolved
reduction='none')
except:
raise ValueError(
'Fused Cross Entropy is not installed. Either (1) have a CUDA-compatible GPU '
Expand All @@ -994,7 +995,8 @@ def __init__(
'if installing from pypi, or (2) set your config model.loss_fn=torch_crossentropy.'
)
elif loss_fn_config == 'torch_crossentropy':
self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100,
reduction='none')
else:
raise ValueError(
f'Specified loss_fn={self.loss_fn} not recognized. `loss_fn` must be one of [`fused_crossentropy`, `torch_crossentropy`].'
Expand All @@ -1016,8 +1018,15 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
def loss(self, outputs: CausalLMOutputWithPast,
batch: Mapping) -> torch.Tensor:
targets = self.get_targets(batch)
return self.loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)),
targets.view(-1))
losses = self.loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)),
targets.view(-1))

if torch.all(targets == self.loss_fn.ignore_index):
loss = losses.sum()
else:
loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum()
vchiley marked this conversation as resolved.
Show resolved Hide resolved

return loss

def flops_per_batch(self, batch: Mapping) -> int:
# Note: this computation does not take into account padding, and assumes
Expand Down
101 changes: 100 additions & 1 deletion tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def test_loss_fn():
model_2.to(test_cfg.device)

assert isinstance(model_1.loss_fn, torch.nn.CrossEntropyLoss)
model_2.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
model_2.loss_fn = FusedCrossEntropyLoss(ignore_index=-100, reduction='none')

optimizer_1 = DecoupledAdamW(model_1.parameters(),
lr=test_cfg.optimizer.lr,
Expand Down Expand Up @@ -517,6 +517,105 @@ def test_loss_fn():
atol=1e-4), f'differed at step {i}'


@pytest.mark.gpu
@pytest.mark.parametrize('loss_fn_config',
['torch_crossentropy', 'fused_crossentropy'])
def test_loss_reduction(loss_fn_config: str):
"""Tests the Fused CrossEntropy vs torch.nn.CrossEntropy loss function.

We provide non-zero tolerances to account for small numerics differences
between the two loss implementations.
"""
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss # type: ignore # isort: skip
except:
pytest.skip('Fused cross entropy was not installed')

# run numerical test in pure fp32
from torch.backends import cuda, cudnn
cuda.matmul.allow_tf32 = False
cudnn.allow_tf32 = False

conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
with open(conf_path) as f:
test_cfg = om.load(f)

assert isinstance(test_cfg, DictConfig)

test_cfg.model.loss_fn = loss_fn_config

test_cfg.device = 'cuda:0'
test_cfg.model.init_device = 'cpu'
test_cfg.model.init_config = {
'name': 'baseline_',
'init_std': 0.02,
}

tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer)
tokenizer = build_tokenizer(test_cfg.tokenizer.name,
tokenizer_cfg.get('kwargs', {}))

model_1 = build_composer_model(
name=test_cfg.model.name,
cfg=test_cfg.model,
tokenizer=tokenizer,
)
model_2 = copy.deepcopy(model_1)

model_1.to(test_cfg.device)
model_2.to(test_cfg.device)

# Reduce the loss in FusedCrossEntropyLoss
if loss_fn_config == 'fused_crossentropy':
assert isinstance(model_1.loss_fn, FusedCrossEntropyLoss)
model_2.loss_fn = FusedCrossEntropyLoss(ignore_index=-100,
reduction='mean')
else:
assert isinstance(model_1.loss_fn, torch.nn.CrossEntropyLoss)
model_2.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100,
reduction='mean')

optimizer_1 = DecoupledAdamW(model_1.parameters(),
lr=test_cfg.optimizer.lr,
betas=test_cfg.optimizer.betas,
eps=test_cfg.optimizer.eps,
weight_decay=test_cfg.optimizer.weight_decay)
optimizer_2 = DecoupledAdamW(model_2.parameters(),
lr=test_cfg.optimizer.lr,
betas=test_cfg.optimizer.betas,
eps=test_cfg.optimizer.eps,
weight_decay=test_cfg.optimizer.weight_decay)

for i in range(3):
batch = gen_random_batch(2, test_cfg)
output_1 = model_1(batch)
output_2 = model_2(batch)
assert output_1.logits.allclose(output_2.logits, rtol=1e-4,
atol=1e-4), f'differed at step {i}'

loss_1 = model_1.loss(output_1, batch)

# Loss for model_2 gets reduced within the loss_fn, so we handle it separately
targets = model_2.get_targets(batch)
loss_2 = model_2.loss_fn(
output_2.logits.view(-1, output_2.logits.size(-1)),
targets.view(-1))

assert isinstance(loss_1, torch.Tensor)
assert isinstance(loss_2, torch.Tensor)
assert loss_1.allclose(loss_2, rtol=1e-3,
atol=1e-3), f'differed at step {i}'
loss_1.backward()
loss_2.backward()
optimizer_1.step()
optimizer_2.step()

for p1, p2 in zip(model_1.parameters(), model_2.parameters()):
assert p1.data.shape == p2.data.shape
assert p1.data.allclose(p2.data, rtol=1e-5,
atol=1e-4), f'differed at step {i}'


@pytest.mark.parametrize('peft_config', [
None,
{
Expand Down
Loading