diff --git a/deepspeed/sequence/fpdt_layer.py b/deepspeed/sequence/fpdt_layer.py index 4fab768ce63ca..4fa2cc988a194 100644 --- a/deepspeed/sequence/fpdt_layer.py +++ b/deepspeed/sequence/fpdt_layer.py @@ -47,7 +47,7 @@ def _update_out_and_lse( block_out = block_out.to(torch.float32) block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) - new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + new_lse = lse + torch.log1p(torch.exp(block_lse - lse)) out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out