From b5b4134b0306d63a974ad6be8eee81fe789dd954 Mon Sep 17 00:00:00 2001
From: Sergii Dymchenko
Date: Tue, 7 Jan 2025 15:46:32 -0800
Subject: [PATCH] Use `torch.log1p`

This function provides greater precision than `log(1 + x)` for small
values of `x`.

Found with TorchFix https://github.com/pytorch-labs/torchfix/
---
 deepspeed/sequence/fpdt_layer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py
index 4fab768ce63c..4fa2cc988a19 100644
--- a/deepspeed/sequence/fpdt_layer.py
+++ b/deepspeed/sequence/fpdt_layer.py
@@ -47,7 +47,7 @@ def _update_out_and_lse(
     block_out = block_out.to(torch.float32)
     block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
 
-    new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
+    new_lse = lse + torch.log1p(torch.exp(block_lse - lse))
 
     out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
 