
4 bit Adam should support non constant lr #730

Closed
msaroufim opened this issue Aug 22, 2024 · 12 comments · Fixed by #736

Comments

@msaroufim
Member

msaroufim commented Aug 22, 2024

Our low bit optimizers were merged in HF huggingface/transformers#31865, but we have a known limitation: the 4-bit optimizer is not great when we don't have a constant learning rate.

This is mentioned in the README https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim

Known issue: When learning rate is updated every step (e.g. using cosine learning rate scheduler), training speed is slower. This is because we have to convert learning rate to a CUDA tensor (which incurs expensive memory transfer cost), since torch.compile() will treat a Python float as a constant and trigger recompile whenever the value is changed.

However, this is preventing @winglian from adopting this work.

cc @gau-nernst @mlazos @janeyx99

@janeyx99
Contributor

janeyx99 commented Aug 22, 2024

Wait a moment, what does "convert learning rate to a CUDA tensor" mean @gau-nernst? That should only be done once during initialization, and not for every step.

It is true that compiled optimizers do not performantly support python float for lr, but they should support Tensors without incurring more recompiles. As in, the user should be able to init an optimizer with a Tensor lr and have the compiled optimizer run without recompiles.
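
For illustration, a minimal sketch of that pattern with a stock optimizer (assuming a recent PyTorch where torch.optim.AdamW accepts a Tensor lr; the model here is a placeholder):

import torch

model = torch.nn.Linear(16, 16, device="cuda")
# Initialize with a scalar Tensor lr so torch.compile does not bake in a Python constant.
opt = torch.optim.AdamW(model.parameters(), lr=torch.tensor(1e-3))

@torch.compile
def opt_step():
    opt.step()

for step in range(3):
    model(torch.randn(4, 16, device="cuda")).sum().backward()
    opt_step()
    opt.zero_grad()
    # Mutating the Tensor in place changes the value without triggering a recompile.
    opt.param_groups[0]["lr"].fill_(1e-3 * 0.5**step)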

@mlazos

mlazos commented Aug 22, 2024

@msaroufim @winglian can you provide a repro + the slowdown that is occurring? There is an expected slowdown (I would imagine it being very small) since we are running more ops, but I'm surprised that it would even get close to being slower than eager and prevent adoption of a compiled optimizer. Also, an unrelated question: are we able to compile the 4-bit optimizer today?

@msaroufim
Member Author

For a repro you can run https://github.com/pytorch/ao/blob/main/benchmarks/benchmark_low_bit_adam.py#L109 and enable cosine_lr_scheduler

And this is the learning rate conversion logic https://github.com/pytorch/ao/blob/main/torchao/prototype/low_bit_optim/adamw.py#L80-L84

@msaroufim msaroufim changed the title Feedback on 4 bit Adam optimizer 4 bit Adam optimizer should support non constant lr Aug 22, 2024
@msaroufim msaroufim changed the title 4 bit Adam optimizer should support non constant lr 4 bit Adam should support non constant lr Aug 22, 2024
@janeyx99
Contributor

And this is the learning rate conversion logic main/torchao/prototype/low_bit_optim/adamw.py#L80-L84

Can we avoid this if the user passes in a Tensor LR + if we maintain a device -> lr dictionary in the single_param implementation instead of casting during every step? https://github.com/pytorch/ao/blob/main/torchao/prototype/low_bit_optim/adamw.py#L14
Similar to how we do it in fused_adam https://github.com/pytorch/pytorch/blob/main/torch/optim/adamw.py#L653
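
To illustrate the idea (a rough sketch with made-up names, not the torchao or fused_adam code), the per-step, per-parameter casting could be replaced by a cache that moves the scalar lr to each device at most once per step:

import torch

def step_over_groups(param_groups, lr_cpu: torch.Tensor):
    # Rebuild the cache each step so an lr mutated in place on CPU is still picked up.
    lr_dict = {lr_cpu.device: lr_cpu}
    for group in param_groups:
        for p in group["params"]:
            # One copy per device per step, instead of one new CUDA tensor per parameter.
            if p.device not in lr_dict:
                lr_dict[p.device] = lr_cpu.to(p.device, non_blocking=True)
            lr = lr_dict[p.device]
            ...  # pass `lr` to the per-parameter update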

@mlazos

mlazos commented Aug 22, 2024

And this is the learning rate conversion logic main/torchao/prototype/low_bit_optim/adamw.py#L80-L84

Can we avoid this if the user passes in a Tensor LR + if we maintain a device -> lr dictionary in the single_param implementation instead of casting during every step? main/torchao/prototype/low_bit_optim/adamw.py#L14 Similar to how we do it in fused_adam pytorch/pytorch@main/torch/optim/adamw.py#L653

Agreed with Jane, I don't think we want to be allocating a tensor on every step. Although I'm surprised that allocating a scalar tensor results in a 3-4% slowdown, I believe you. IIRC, if you have a scalar tensor LR, wouldn't it get moved to the correct device at kernel launch regardless? Do we need a full dict?

@janeyx99
Contributor

janeyx99 commented Aug 22, 2024

Oh right, I'd try the steps in order then:

  1. Don't auto-convert lr inside the optimizer, but rather have the default be a torch.tensor(1e-3), on CPU!!
  2. If the user inputs a Tensor LR on CUDA (they shouldn't), only then would we need a full dictionary.

@winglian

@msaroufim I don't have a reproducer offhand, but the last time I tried the 4-bit optimizer was a couple of weeks ago in the context of torchtune, and IIRC it was an order of magnitude slower than a regular optimizer.

@gau-nernst
Collaborator

Also, unrelated question are we able to compile the 4 bit optimizer today?

Yes. (Though there are some memory issues with torch.compile() which forced me to do a workaround:

# NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim.
# thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param.

Opened an issue here but no follow-up 😢: pytorch/pytorch#131294)
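
As a self-contained toy version of that workaround (this is not the actual torchao single_param_adam, just the same compile-per-parameter structure):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

@torch.compile(fullgraph=True)
def single_param_step(p, grad, exp_avg, exp_avg_sq, step, lr, beta1=0.9, beta2=0.999, eps=1e-8):
    # A plain Adam-style update for one parameter tensor.
    exp_avg.lerp_(grad, 1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = (exp_avg_sq / (1 - beta2**step)).sqrt_().add_(eps)
    p.addcdiv_(exp_avg / (1 - beta1**step), denom, value=-lr)

params = [torch.randn(8, device=device) for _ in range(3)]
state = [{"exp_avg": torch.zeros_like(p), "exp_avg_sq": torch.zeros_like(p)} for p in params]
grads = [torch.randn_like(p) for p in params]

# Compile once, then call the compiled update for each parameter,
# instead of compiling one function over all param groups.
for p, s, g in zip(params, state, grads):
    single_param_step(p, g, s["exp_avg"], s["exp_avg_sq"], step=1, lr=1e-3)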

Don't auto-convert lr inside the optimizer, but rather have the default be a torch.tensor(1e-3). on CPU!!

I'm a bit lost here. How would this help? Do you mean I can do

def f(lr_tensor_scalar_cpu, param_tensor_cuda, exp_avg_tensor_cuda...):
    ...

torch.compile(f)(...)

In other words, if the CPU tensor is a scalar, I don't need to move it to CUDA? I thought everything needed to be on the same device. (I can try it later, I'm not at my machine atm.)

If the user inputs a Tensor LR on CUDA (they shouldn't), only then would we need a full dictionary.

I'm also not following why having a dictionary would help 😢.

Ultimately, it depends on how the LR is updated by the user. Personally this is how I normally update the LR (note: the LR from .get_lr(step) is a Python float)

lr = lr_schedule.get_lr(step)
for param_group in optim.param_groups:
    param_group["lr"] = lr

This is inspired by timm here and here. I have tried with param_group["lr"].copy_(lr_python_float) but didn't see a difference.

I tend to avoid LR scheduler from torch.optim.lr_scheduler since they are stateful -> requiring saving and loading the LR scheduler state when resuming training (and I think they are generally overly complex 😅).

Also note that casting the Python float LR to a CUDA tensor is done in place, so if the LR is not updated by the user, there is no casting in subsequent steps.

if not isinstance(group["lr"], Tensor):
group["lr"] = torch.tensor(group["lr"], device=p.device)

I notice that axolotl uses LambdaLR https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/schedulers.py. Perhaps I can add it to my test script too, since it might be a popular way to update LR?

The workaround I would recommend is to only update the LR every xx steps, so that the slowdown is amortized over a longer period.
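
A small sketch of that amortization, reusing the lr_schedule/optim names from the snippet above (update_every is an illustrative knob):

update_every = 50
if step % update_every == 0:
    lr = lr_schedule.get_lr(step)
    for param_group in optim.param_groups:
        param_group["lr"] = lr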

@gau-nernst
Collaborator

@winglian when you observed the slowdown, did you use an LR scheduler? That would isolate whether it is the LR-schedule issue or the 4-bit optimizer itself being slow.

@janeyx99
Contributor

@gau-nernst LR does not need to be on CUDA as it will get special treatment as a scalar Tensor, and probably should not be on CUDA. For now, let's ignore the lr dictionary.

The proposal is for the LR to be allowed as a CPU scalar Tensor from the start (from the constructor), and for all LR scheduling to keep it a Tensor so that it is dynamic enough for compiling. If the LR from .get_lr(step) is a Python float, you can do lr.fill_(lr_schedule.get_lr(step)) if lr is a Tensor. The idea is to avoid creating CUDA tensors at all, as that is bound to be more expensive. Even from a UX perspective, it is not straightforward to debug when the Tensor-ness of lr is changed under the hood.

I agree with your point of isolating which portion is slow. Our comments above are directly related to the LRSchedule issue.
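
A minimal sketch of that schedule-side change, reusing the lr_schedule/optim names from the earlier snippet (the optimizer is assumed to have been constructed with lr=torch.tensor(1e-3) on CPU):

lr_value = lr_schedule.get_lr(step)   # Python float from the custom schedule
for param_group in optim.param_groups:
    # Keep the CPU scalar Tensor and mutate it in place, never replacing it with a float.
    param_group["lr"].fill_(lr_value)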

@gau-nernst
Collaborator

@janeyx99 Thank you for your explanation, the proposed solution makes sense. I will test it out

LR does not need to be on CUDA as it will get special treatment as a scalar Tensor

I was not aware of this. Just curious, does this work in eager also?

@janeyx99
Contributor

Yes, it works in eager! I'm also curious if you find it untrue for any op; if so, we can fix it in pytorch/pytorch. The trick we use is to call .item() on CPU scalar tensors and pass the value as a scalar to the CUDA kernel, avoiding the need for it to be a Tensor at all for the kernel (which would then have the device restrictions).
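
A quick eager-mode check of that behavior (requires a CUDA device; the tensors are just placeholders):

import torch

lr = torch.tensor(1e-3)                 # 0-dim scalar Tensor on CPU
p = torch.randn(8, device="cuda")
update = torch.randn(8, device="cuda")
p.add_(update * -lr)                    # the CPU scalar mixes with CUDA tensors in eager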
