4-bit Adam should support non-constant lr #730
Comments
Wait a moment, what does "convert learning rate to a CUDA tensor" mean @gau-nernst? That should only be done once during initialization, not for every step. It is true that compiled optimizers do not performantly support a Python float for lr, but they should support Tensors without incurring more recompiles. That is, the user should be able to init an optimizer with a Tensor lr and have the compiled optimizer run without recompiles.
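As a toy illustration of the distinction (this isn't the optimizer code; it just sketches how torch.compile handles a Python float argument versus a 0-dim Tensor):

```python
import torch

@torch.compile
def scale(x, lr):
    return x * lr

x = torch.randn(8)

# A Python float gets specialized into the compiled graph as a constant,
# so a new value can trigger a recompile (the known issue quoted from the
# README further down in this thread).
scale(x, 1e-3)
scale(x, 5e-4)

# A 0-dim Tensor is traced as a real input, so its value can change freely
# without recompiling.
lr = torch.tensor(1e-3)
scale(x, lr)
lr.fill_(5e-4)
scale(x, lr)
```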
@msaroufim @winglian can you provide a repro plus the slowdown that is occurring? There is an expected slowdown (I would imagine it being very small) since we are running more ops, but I'm surprised that it would even get close to being slower than eager and prevent adoption of a compiled optimizer. Also, an unrelated question: are we able to compile the 4-bit optimizer today?
For a repro you can run https://github.com/pytorch/ao/blob/main/benchmarks/benchmark_low_bit_adam.py#L109 and enable the cosine LR schedule. This is the learning rate conversion logic: https://github.com/pytorch/ao/blob/main/torchao/prototype/low_bit_optim/adamw.py#L80-L84
Can we avoid this if the user passes in a Tensor LR, and if we maintain a device -> lr dictionary in the single_param implementation instead of casting during every step? https://github.com/pytorch/ao/blob/main/torchao/prototype/low_bit_optim/adamw.py#L14
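If I'm reading the dictionary idea right, a rough sketch would be something like the following (a hypothetical helper, not the current adamw.py code): cache one scalar lr tensor per device and reuse its storage instead of allocating a fresh CUDA tensor every step.

```python
import torch

# Hypothetical cache: one 0-dim lr tensor per device, allocated once and
# reused for every subsequent step.
_lr_cache: dict[torch.device, torch.Tensor] = {}

def lr_on_device(lr: torch.Tensor, device: torch.device) -> torch.Tensor:
    cached = _lr_cache.get(device)
    if cached is None:
        # First time we see this device: do the one-time allocation/transfer.
        cached = lr.to(device)
        _lr_cache[device] = cached
    else:
        # Later steps: copy the latest value into the existing tensor, so the
        # compiled optimizer always sees the same tensor object.
        cached.copy_(lr)
    return cached
```

The compiled step would then always see the same tensor object per device, so there is no per-step allocation and no reason to recompile.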
Agreed with Jane, I don't think we want to be allocating a tensor on every step. I'm surprised that allocating a scalar tensor results in a 3-4% slowdown, but I believe you. IIRC, if you have a scalar tensor LR, wouldn't it get moved to the correct device at kernel launch regardless? Do we need a full dict?
Oh right, I'd try the steps in order then:
@msaroufim I don't have a reproducer offhand, but the last time I tried the 4-bit optimizer was a couple of weeks ago in the context of torchtune, and IIRC it was an order of magnitude slower than a regular optimizer.
Yes. (Though there are some memory issues with torch.compile() which forced me to do a workaround; see ao/torchao/prototype/low_bit_optim/adamw.py, lines 221 to 222 in c0b0731.)
I'm a bit lost here. How would this help? Do you mean I can do

    def f(lr_tensor_scalar_cpu, param_tensor_cuda, exp_avg_tensor_cuda, ...):
        ...

    torch.compile(f)(...)

In other words, if the CPU tensor is a scalar, I don't need to move it to CUDA? I thought everything needs to be on the same device. (I can try it later, not with my machine at the moment.)
I'm also not following why having a dictionary would help 😢. Ultimately, it depends on how the LR is updated by the user. Personally, this is how I normally update the LR (see ao/benchmarks/benchmark_low_bit_adam.py, lines 266 to 268 in b09307a; note: the LR from the schedule is a Python float).
This is inspired by timm here and here. I have tried with […]. I tend to avoid LR schedulers from […]. Also note that casting the Python float LR to a CUDA tensor is done in-place, so if the LR is not updated by the user, there is no casting in subsequent steps (see ao/torchao/prototype/low_bit_optim/adamw.py, lines 83 to 84 in b09307a).
I notice that axolotl uses […]. The workaround I would recommend is to only update the LR every xx steps, so that the slowdown is amortized over a longer period.
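A self-contained sketch of that amortization workaround (the schedule, model, and interval here are made up for illustration; the real benchmark/axolotl code differs):

```python
import math
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

def cosine_lr(step: int, total_steps: int = 1000, base_lr: float = 1e-3) -> float:
    # Stand-in schedule that returns a plain Python float, like .get_lr(step).
    return base_lr * 0.5 * (1 + math.cos(math.pi * step / total_steps))

LR_UPDATE_INTERVAL = 100  # illustrative value; pick whatever amortizes well

for step in range(1000):
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Only push the new LR into the optimizer every N steps, so the
    # float -> CUDA tensor conversion (and any recompile) happens rarely.
    if step % LR_UPDATE_INTERVAL == 0:
        new_lr = cosine_lr(step)
        for group in optimizer.param_groups:
            group["lr"] = new_lr
```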
@winglian when you observed the slowdown, did you use an LR scheduler? That would help isolate whether it is the LR schedule issue or whether the 4-bit optimizer itself is slow.
@gau-nernst The LR does not need to be on CUDA, since it will get special treatment as a scalar Tensor, and it probably should not be on CUDA. For now, let's ignore the lr dictionary. The proposal is for the LR to be allowed as a CPU scalar Tensor from the start (from the constructor), and for all LR scheduling to keep it a Tensor so that it is dynamic enough for compiling. If the LR from .get_lr(step) is a Python float, you can do […]. I agree with your point about isolating which portion is slow. Our comments above are directly related to the LR schedule issue.
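If I understand the proposal, a sketch would look roughly like this (assumptions: a PyTorch version whose stock AdamW accepts a Tensor lr, and a made-up cosine formula standing in for .get_lr(step); the 4-bit optimizer's own API may differ):

```python
import math
import torch

model = torch.nn.Linear(8, 1)

# Keep the LR as a CPU scalar Tensor from the very start, so torch.compile
# never sees a Python float that it would specialize as a constant.
lr = torch.tensor(1e-3)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

@torch.compile
def compiled_step():
    optimizer.step()

for step in range(100):
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    compiled_step()
    optimizer.zero_grad()

    # The schedule still produces a Python float, but we write it into the
    # existing Tensor in place instead of replacing the lr with a float.
    new_lr = 1e-3 * 0.5 * (1 + math.cos(math.pi * step / 100))
    lr.fill_(new_lr)
```

The compiled graph keeps seeing the same 0-dim tensor every step; only its value changes.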
@janeyx99 Thank you for your explanation, the proposed solution makes sense. I will test it out.
I was not aware of this. Just curious, does this work in eager also? |
Yes, it works in eager! I'm also curious if you find it untrue for any op; if so, we can fix it in pytorch/pytorch. The trick we do is to call .item() on CPU scalar tensors and pass the value as a scalar to the CUDA kernel, to avoid needing it to be a Tensor at all for the kernel (which would then have the device restrictions).
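For example, this kind of device mixing already works in eager (assuming a CUDA device is available):

```python
import torch

if torch.cuda.is_available():
    x = torch.randn(4, device="cuda")
    lr = torch.tensor(1e-3)   # 0-dim tensor that stays on the CPU

    # The 0-dim CPU tensor is treated as a scalar, so there is no device
    # mismatch error and no explicit .to("cuda") copy is needed.
    y = x * lr
    print(y.device)           # prints: cuda:0
```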
Our low-bit optimizers were merged in HF huggingface/transformers#31865, but we have a known limitation: the 4-bit optimizer is not great when the learning rate is not constant.
This is mentioned in the README: https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim
Known issue: When learning rate is updated every step (e.g. using cosine learning rate scheduler), training speed is slower. This is because we have to convert learning rate to a CUDA tensor (which incurs expensive memory transfer cost), since torch.compile() will treat a Python float as a constant and trigger recompile whenever the value is changed.
However, this is preventing @winglian from adopting this work.
cc @gau-nernst @mlazos @janeyx99