Disable PEFT input autocast when using fp8 layerwise casting #10685
Conversation
def _disable_peft_input_autocast(module: torch.nn.Module) -> None:
    for submodule in module.modules():
        if not isinstance(submodule, BaseTunerLayer):
            continue
        registry = HookRegistry.check_if_exists_or_initialize(submodule)
        hook = PeftInputAutocastDisableHook()
        registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK)
Given that layerwise casting can be enabled on modeling levels aside from the top-level module, some lora weights can be in fp8 while others may be in bf16 (for example). While it's an edge case, I think handling this hook addition should be done only if any submodule of the BaseTunerLayer contains a _diffusers_hook registry with a _LAYERWISE_UPCASTING_HOOK hook. LMK if I'm overthinking, or I'll implement it as described.
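For illustration, a minimal sketch of the conditional registration described above. It reuses the names from the diff and assumes the registry is attached to modules as a _diffusers_hook attribute (mentioned above) and exposes its named hooks via a dict-like hooks attribute; neither detail is verified against the actual diffusers hooks API.

# Sketch only: assumes HookRegistry attaches itself to modules as `_diffusers_hook`
# and stores named hooks in a dict-like `hooks` attribute.
def _disable_peft_input_autocast(module: torch.nn.Module) -> None:
    for submodule in module.modules():
        if not isinstance(submodule, BaseTunerLayer):
            continue
        # Only register the autocast-disable hook if some submodule of this
        # BaseTunerLayer actually carries a layerwise-upcasting hook.
        is_layerwise_cast = any(
            hasattr(m, "_diffusers_hook") and _LAYERWISE_UPCASTING_HOOK in m._diffusers_hook.hooks
            for m in submodule.modules()
        )
        if not is_layerwise_cast:
            continue
        registry = HookRegistry.check_if_exists_or_initialize(submodule)
        hook = PeftInputAutocastDisableHook()
        registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK)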
if any submodule of the BaseTunerLayer contains a _diffusers_hook registry with a _LAYERWISE_UPCASTING_HOOK hook
Yeah with you. Just some clarification questions on the following LoC:
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = PeftInputAutocastDisableHook()
registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK)
- registry = HookRegistry.check_if_exists_or_initialize(submodule) runs the check if submodule has a certain type of registry? If so, registry can be a bool?
- Shouldn't this be conditioned -- registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK)? Something like if isinstance(registry, bool) and registry: ... or am I reading it wrong?
I haven't gone through the implementation of the above-used methods, so apologies for that in advance.
registry = HookRegistry.check_if_exists_or_initialize(submodule) runs the check if submodule has a certain type of registry? If so, registry can be a bool?
Shouldn't this be conditioned -- registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK)? Something like if isinstance(registry, bool) and registry: ... or am I reading it wrong?
The check_if_exists_or_initialize method always returns a HookRegistry object and not a bool, so we don't need a condition. Maybe it could use a better name; feel free to suggest one in a follow-up PR, since any changes related to it haven't gone into a release yet.
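A rough, simplified illustration of the get-or-create semantics described above. The _diffusers_hook attribute name is taken from the earlier comment; the rest is assumed and is not the actual diffusers implementation.

# Simplified illustration, not the actual diffusers implementation.
class HookRegistry:
    def __init__(self) -> None:
        self.hooks = {}  # maps hook name -> hook instance

    @classmethod
    def check_if_exists_or_initialize(cls, module) -> "HookRegistry":
        # Get-or-create: reuse the registry already attached to the module,
        # otherwise create one and attach it. Always returns a HookRegistry,
        # never a bool, so the caller needs no condition.
        if not hasattr(module, "_diffusers_hook"):
            module._diffusers_hook = cls()
        return module._diffusers_hook

    def register_hook(self, hook, name: str) -> None:
        self.hooks[name] = hook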
@a-r-r-o-w thanks! Design-wise this looks good to me! I think the major todos are:
- Version-guard the imports (i.e., disable_lora_input_dtype_casting) and implementation.
- Docs (I think having a note in https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference is sufficient).
- Tests
from .hooks import HookRegistry, ModelHook


if is_peft_available():
    from peft.helpers import disable_lora_input_dtype_casting
The PR is now merged in PEFT. I renamed the function to disable_input_dtype_casting. Let's also add a guard for the existence of the function, since users may have an older PEFT version installed.
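A sketch of such a guard, assuming the rename ships in a PEFT release and using a placeholder minimum version plus a hypothetical flag name; the actual pin and naming would need to match the real release and codebase.

# Sketch of a version guard; the minimum PEFT version below is a placeholder.
from packaging import version

from diffusers.utils import is_peft_available

_MIN_PEFT_VERSION_FOR_DTYPE_CASTING = "0.14.1"  # placeholder, not the real pin

_CAN_DISABLE_PEFT_INPUT_AUTOCAST = False
if is_peft_available():
    import peft

    if version.parse(peft.__version__) >= version.parse(_MIN_PEFT_VERSION_FOR_DTYPE_CASTING):
        from peft.helpers import disable_input_dtype_casting

        _CAN_DISABLE_PEFT_INPUT_AUTOCAST = True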
Thanks for the heads-up.
@a-r-r-o-w LMK if you would like me to take over the PR.
Thanks @BenjaminBossan! I've updated the code to use the latest name and added a version guard. Will document and add tests once CI is green and dependency tests don't fail like before.
@sayakpaul I can wrap it up, thanks for offering to help! On a separate note, please do not merge the main branch or push changes arbitrarily. It causes push conflicts unexpectedly on my own work branches. I know the fix is as simple as git pull --rebase or similar, but I don't fancy dealing with it, as it's an extra unnecessary step on a branch that is supposed to be mine to work on. I then need to force-push, otherwise there can sometimes be merge conflicts (several instances in the past). If you need the changes from main on a certain branch for testing, please feel free to create your own branch from main and cherry-pick the changes in, or make sure to get a green light from the branch owner first.
I don't know how others feel about this, but I can imagine it being an unnecessary disruption to deal with in general.
Force-pushed from ed791c4 to 020e374.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Looking fantastic! Thanks, Aryan.
I just have one comment for tests.
A hook that disables the casting of inputs to the module weight dtype during the forward pass. By default, PEFT
casts the inputs to the weight dtype of the module, which can lead to precision loss.

The reasons for needing this are:
Thanks a LOT for writing this! Really, thanks!
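For readers without the full diff, here is a standalone sketch of what the hook amounts to conceptually: wrapping a PEFT-tuned module's forward so that PEFT's input dtype casting is disabled for the duration of the call. This is not the diffusers ModelHook implementation, and it assumes PEFT's disable_input_dtype_casting context manager is available.

# Standalone sketch, not the diffusers ModelHook implementation.
import functools

import torch
from peft.helpers import disable_input_dtype_casting


def _wrap_forward_without_input_casting(module: torch.nn.Module) -> None:
    original_forward = module.forward

    @functools.wraps(original_forward)
    def forward(*args, **kwargs):
        # While this context manager is active, PEFT leaves the inputs in their
        # original dtype instead of casting them to the (possibly fp8) weight dtype.
        with disable_input_dtype_casting(module):
            return original_forward(*args, **kwargs)

    module.forward = forward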
@@ -134,6 +167,7 @@ def apply_layerwise_casting(
        skip_modules_classes,
        non_blocking,
    )
    _disable_peft_input_autocast(module)
The method is already version-guarded. So, no worries.
(For other reviewers).
Thanks for adding this integration. I don't have anything to add on top of what Sayak already wrote.
👨🍳 ❤️
Context: https://huggingface.slack.com/archives/C04L3MWLE6B/p1738076520963309
TL;DR: Input casting based on the dtype of the weights in PEFT layers can lead to lower-quality generation (for example, when keeping LoRA layers in fp8 precision), unless we were to handle force-casting the inputs back to compute_dtype ourselves. However, PEFT resorts to this input casting for the reasons mentioned in the docstring. With the help of Benjamin, we can now opt out of the input casting, which seems like the ideal way to handle it.
Tracking in PEFT here: huggingface/peft#2353
Reproducer
The example code errors out with diffusers:main and peft:main, but passes when using the respective PR branches.
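The reproducer itself is not included above; the following is a rough, hypothetical sketch of the kind of setup this PR targets. The model id and LoRA repo below are placeholders, not the ones from the original reproducer.

# Hypothetical sketch; model and LoRA ids are placeholders.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("some-user/some-flux-lora")  # placeholder LoRA repo

# Store transformer (and LoRA) weights in fp8, upcasting to bf16 for compute.
# Without this PR, PEFT would cast the inputs to the fp8 weight dtype.
pipe.transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
pipe.to("cuda")

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
image.save("output.png")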