
Unnecessarily scaling gradients when gradient_accumulation_steps is 1 #2515


Open
shunting314 opened this issue Mar 19, 2025 · 1 comment
Labels: best practice (Things we should be doing but aren't), enhancement (New feature or request)


shunting314 commented Mar 19, 2025

The scale_grads call here needs to read and write each gradient element once. For an 8B-parameter model in bfloat16, that is 8B * 2 bytes * 2 (read + write) = 32 GB of memory traffic. Assuming 2 TB/s of memory bandwidth, this translates to 16 ms of latency. This estimate is a lower bound; in practice, due to kernel launch overhead, the latency can be as large as 42.5 ms, as measured with a Llama 3.1 8B model on an H100:

[Image: profiler trace of the H100 run; check the long blue bar]
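The lower-bound estimate above is simple arithmetic: an in-place scale reads and writes every gradient element once, so memory traffic is 2 * (number of parameters) * (bytes per element), and time is that traffic divided by memory bandwidth. A minimal sketch, assuming the issue's figures (8B parameters, 2-byte bfloat16, 2 TB/s) and decimal units; the function name is hypothetical:

```python
def scale_grads_lower_bound_ms(
    num_params: float, bytes_per_elem: int, bandwidth_bytes_per_s: float
) -> float:
    """Bandwidth-bound lower bound for one in-place gradient scale, in milliseconds."""
    # Each gradient element is read once and written once.
    traffic_bytes = num_params * bytes_per_elem * 2
    return traffic_bytes / bandwidth_bytes_per_s * 1e3


# 8B parameters, bfloat16 (2 bytes per element), 2 TB/s of memory bandwidth
latency_ms = scale_grads_lower_bound_ms(8e9, 2, 2e12)
print(f"{latency_ms:.1f} ms")  # 16.0 ms
```

The measured 42.5 ms is well above this bound, which is consistent with the issue's point that launch overhead dominates when many per-parameter kernels are issued.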

A simple change like

```diff
diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py
index 39e596ae..83007343 100644
--- a/recipes/full_finetune_single_device.py
+++ b/recipes/full_finetune_single_device.py
@@ -692,14 +692,21 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface):

                 # Loss is normalized by default so we multiply by the number of tokens
                 # This way we can normalize by the total number of tokens if we're accumulating gradients
-                current_loss = self._loss_step(batch) * current_num_tokens
-                running_loss += current_loss
-                current_loss.backward()
+
+                if self._gradient_accumulation_steps > 1:
+                    current_loss = self._loss_step(batch) * current_num_tokens
+                    running_loss += current_loss
+                    current_loss.backward()
+                else:
+                    current_loss = self._loss_step(batch)
+                    running_loss += current_loss * current_num_tokens
+                    current_loss.backward()

                 # Step with optimizer
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
                     if not self._optimizer_in_bwd:
-                        training.scale_grads(self._model, 1 / num_tokens)
+                        if self._gradient_accumulation_steps > 1:
+                            training.scale_grads(self._model, 1 / num_tokens)
                         if self._clip_grad_norm is not None:
                             grad_norm = torch.nn.utils.clip_grad_norm_(
                                 self._model.parameters(),
```

can skip this scale_grads call when gradient_accumulation_steps is 1.
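Why skipping is safe: with a single micro-batch per step, the num_tokens used for the final scale equals that micro-batch's own token count, so multiplying the loss by n and then the gradients by 1/n cancels out. A minimal pure-Python sketch, assuming a hypothetical one-parameter model with a per-token squared-error loss and a hand-derived gradient (not the recipe's actual loss or autograd), compares the two paths:

```python
from typing import List, Tuple

Token = Tuple[float, float]  # hypothetical (input, target) pair for one loss token


def mean_loss_grad(w: float, tokens: List[Token]) -> float:
    """Gradient of the mean per-token loss (w*x - y)**2 with respect to w."""
    return sum(2 * x * (w * x - y) for x, y in tokens) / len(tokens)


def grad_with_rescale(w: float, tokens: List[Token]) -> float:
    # Current path: backward on (mean loss * n), then scale grads by 1 / num_tokens.
    n = len(tokens)
    grad = mean_loss_grad(w, tokens) * n  # gradient of the token-scaled loss
    num_tokens = n  # with gradient_accumulation_steps == 1 this is the same count
    return grad * (1 / num_tokens)


def grad_without_rescale(w: float, tokens: List[Token]) -> float:
    # Proposed path: backward on the normalized loss and skip the extra scaling.
    return mean_loss_grad(w, tokens)


tokens = [(1.0, 2.0), (0.5, -1.0), (3.0, 0.25)]
a = grad_with_rescale(0.3, tokens)
b = grad_without_rescale(0.3, tokens)
print(abs(a - b) < 1e-12)  # True
```

The two results differ only by floating-point rounding of n * (1/n), which is why the diff keeps the token-weighted running_loss for logging but drops the gradient pass.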

With batch size 16, each step takes roughly 800+ ms, so skipping the ~42.5 ms scale_grads call gives about a 5% speedup. For larger batch sizes, the absolute latency saving should stay the same, but the relative speedup will be smaller.
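The 5% figure is the fixed ~42.5 ms saving divided by the ~800 ms step time. Because the saving does not depend on batch size, the relative gain shrinks as the step gets longer. A small sketch; the 1600 ms and 3200 ms step times are hypothetical, not measurements:

```python
def relative_saving(saved_ms: float, step_ms: float) -> float:
    """Fraction of step time removed by skipping a fixed-cost operation."""
    return saved_ms / step_ms


saved_ms = 42.5  # measured scale_grads latency from the H100 run
# 800 ms is the batch-size-16 step time from the issue; the others are hypothetical.
for step_ms in (800.0, 1600.0, 3200.0):
    print(f"step {step_ms:.0f} ms -> {relative_saving(saved_ms, step_ms):.1%} saved")
```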

cc @ebsmothers @IvanKobzarev

felipemello1 added the enhancement (New feature or request) and best practice (Things we should be doing but aren't) labels on Mar 24, 2025
Contributor

nathan-az commented Mar 31, 2025

Just a note, if there's appetite for this enhancement: in the gradient accumulation case, if we know the token count for each update step ahead of time, I think the loss can be scaled correctly before each backward step, completely avoiding the separate scaling of the gradients.
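The reasoning relies on gradients being linear in the loss: scaling each micro-batch loss by n_i / N before its backward pass yields the same accumulated gradient as backpropagating n_i-scaled losses and dividing by N afterwards, and both equal the gradient of the mean loss over all N tokens. A minimal pure-Python check, assuming a hypothetical one-parameter model with per-token squared error and hand-derived gradients rather than autograd:

```python
from typing import List, Tuple

Token = Tuple[float, float]  # hypothetical (input, target) pair for one loss token


def mean_loss_grad(w: float, tokens: List[Token]) -> float:
    """Gradient of the mean per-token loss (w*x - y)**2 with respect to w."""
    return sum(2 * x * (w * x - y) for x, y in tokens) / len(tokens)


def grad_scale_after(w: float, micro_batches: List[List[Token]]) -> float:
    # Current approach: accumulate grads of (mean loss * n_i), then scale by 1 / N.
    total_tokens = sum(len(mb) for mb in micro_batches)
    acc = sum(mean_loss_grad(w, mb) * len(mb) for mb in micro_batches)
    return acc * (1 / total_tokens)


def grad_scale_before(w: float, micro_batches: List[List[Token]]) -> float:
    # Proposed approach: N is known up front, so each loss is weighted by n_i / N
    # before its backward pass and no separate gradient scaling is needed.
    total_tokens = sum(len(mb) for mb in micro_batches)
    return sum(mean_loss_grad(w, mb) * (len(mb) / total_tokens) for mb in micro_batches)


micro_batches = [
    [(1.0, 2.0), (0.5, -1.0)],
    [(3.0, 0.25)],
    [(2.0, 1.0), (-1.0, 0.5), (0.0, 3.0)],
]
all_tokens = [t for mb in micro_batches for t in mb]
w = 0.3
print(abs(grad_scale_after(w, micro_batches) - grad_scale_before(w, micro_batches)) < 1e-12)
print(abs(grad_scale_before(w, micro_batches) - mean_loss_grad(w, all_tokens)) < 1e-12)
```

Note that N must be known before the first backward pass, which is why the approach needs the micro-batches to be pre-fetched.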

I don't think this would be straightforward without a fairly significant refactor of the training loop. Below is an example.

```python
# module-level imports used by the loop below
import torch
from torch.distributed.tensor import DTensor
from torchtune import utils

for idx in range(self._steps_per_epoch):
    # note that below uses full update steps to determine if profiler should begin
    if (
        self._is_rank_zero
        and curr_epoch == 0
        and self.profiler_profile_memory
        and idx == self.profiler_wait_steps + self.profiler_warmup_steps
        and self._device.type == "cuda"
    ):
        torch.cuda.memory._record_memory_history()

    # pre-fetch examples and move to device
    batches = []
    token_counts = []
    # reset the token count for this update step
    num_tokens = torch.tensor(0, device=self._device)

    # track fwd_step for hybrid sharding option
    for fwd_step in range(self._gradient_accumulation_steps):
        try:
            batch = next(self._dataloader)
        except StopIteration:
            break
        utils.batch_to_device(batch, self._device)
        current_num_tokens = (
            batch["labels"] != self._loss_fn.ignore_index
        ).sum()
        num_tokens += current_num_tokens
        batches.append(batch)
        token_counts.append(current_num_tokens.item())

    if len(batches) == 0:
        # not even a partial group of batches left
        break

    # total number of loss tokens in this update step, summed over all data-parallel ranks
    torch.distributed.all_reduce(num_tokens)

    for token_count, batch in zip(token_counts, batches):
        labels = batch.pop("labels")
        # prepare inputs as in the existing recipe, then run the forward pass
        logits = self._model(**batch)
        # assuming gradients are averaged across data-parallel ranks, dp_size turns that
        # average back into a sum; token_count / num_tokens weights this micro-batch's
        # mean loss by its share of the global token count
        current_loss = self._loss_fn(logits, labels) * (self.dp_size * token_count / num_tokens)
        del logits
        running_loss += current_loss
        # backward once per micro-batch so every micro-batch contributes gradients
        current_loss.backward()

    if not self._optimizer_in_bwd:
        if self._clip_grad_norm is not None:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self._model.parameters(),
                max_norm=float(self._clip_grad_norm),
            )
            # If sharded, collect the DTensor here
            if isinstance(grad_norm, DTensor):
                grad_norm = grad_norm.full_tensor()
        self._optimizer.step()
        self._optimizer.zero_grad(set_to_none=True)
```

I'm not certain the above is correct, but I think it's on the right track, and it removes the scaling requirement completely. In addition, there might be a slight precision benefit to scaling the loss this way, since token_count / num_tokens (1/128 in a fairly large examples-per-update case) is likely less extreme than scaling by the token count (potentially over 100k).

IMO it also reads a bit more nicely with less nesting in if statements, but that may be personal preference.

If there's interest in the above, I'm happy to test that it's equivalent and make a PR for the full_finetune_distributed recipe.

EDIT: I got curious about this and made a PR!
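One way to sanity-check the dp_size factor in the loop sketch: if the data-parallel wrapper averages gradients across ranks (an assumption here, though averaging is the usual DDP/FSDP default), each rank's contribution is divided by dp_size, so multiplying the loss by dp_size restores a sum over ranks, and dividing by the all-reduced token count gives the gradient of the mean loss over every token in the global step. A minimal pure-Python simulation, assuming two hypothetical ranks, the same toy one-parameter loss, and hand-derived gradients:

```python
from typing import List, Tuple

Token = Tuple[float, float]  # hypothetical (input, target) pair for one loss token


def mean_loss_grad(w: float, tokens: List[Token]) -> float:
    """Gradient of the mean per-token loss (w*x - y)**2 with respect to w."""
    return sum(2 * x * (w * x - y) for x, y in tokens) / len(tokens)


def local_grad(w: float, micro_batches: List[List[Token]], dp_size: int, global_tokens: int) -> float:
    # Per-rank accumulation: each micro-batch loss is scaled by dp_size * n_i / N_global.
    return sum(
        mean_loss_grad(w, mb) * (dp_size * len(mb) / global_tokens) for mb in micro_batches
    )


# Hypothetical data: each rank holds two micro-batches of unequal length.
ranks = [
    [[(1.0, 2.0), (0.5, -1.0)], [(3.0, 0.25)]],
    [[(2.0, 1.0)], [(-1.0, 0.5), (0.0, 3.0), (1.5, 1.5)]],
]
w = 0.3
dp_size = len(ranks)
# what all_reduce(num_tokens) would produce: total loss tokens across all ranks
global_tokens = sum(len(mb) for rank in ranks for mb in rank)

# Mean reduction across ranks, as the data-parallel wrapper is assumed to do.
synced = sum(local_grad(w, rank, dp_size, global_tokens) for rank in ranks) / dp_size
reference = mean_loss_grad(w, [t for rank in ranks for mb in rank for t in mb])
print(abs(synced - reference) < 1e-12)  # True
```

If a given setup sums rather than averages gradients across ranks, the dp_size factor would need to be dropped.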
