
Disable PEFT input autocast when using fp8 layerwise casting #10685

Merged · 9 commits merged into main on Feb 13, 2025

Conversation

@a-r-r-o-w (Member) commented Jan 29, 2025

Context: https://huggingface.slack.com/archives/C04L3MWLE6B/p1738076520963309

TL;DR: PEFT casts inputs to the dtype of the weights in its layers, which can lead to lower-quality generation (for example, when keeping the lora layers in fp8 precision), if we were to handle it by force-casting the inputs back to compute_dtype. However, we resort to casting the inputs for the reasons mentioned in the docstring. With Benjamin's help, we can now opt out of the input casting, which seems like the ideal way to handle it.

Tracking in PEFT here: huggingface/peft#2353
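
For illustration only (this is simplified toy code, not PEFT's or diffusers' implementation), the precision concern is that an input which follows the storage dtype of an fp8 LoRA weight gets truncated:

import torch

# bf16 activations, LoRA weight stored in fp8 (as fp8 layerwise casting would keep it)
x = torch.randn(2, 16, dtype=torch.bfloat16)
lora_weight = torch.randn(16, 16, dtype=torch.bfloat16).to(torch.float8_e4m3fn)

# Casting the input to the weight dtype, as PEFT's input autocast does by default,
# drops the activations to fp8; casting back to bf16 afterwards cannot recover
# the precision that was already lost.
x_cast = x.to(lora_weight.dtype)
print(x_cast.dtype)  # torch.float8_e4m3fn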

Reproducer:
import gc
from typing import Any

import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_layerwise_casting
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug
from peft.tuners.lora.layer import Linear as LoRALinear

set_verbosity_debug()


# Debugging helper (not used in main() below): temporarily makes torch.Tensor.to a no-op.
class DisableTensorTo:
    def __enter__(self):
        self.original_to = torch.Tensor.to
        
        def noop_to(self, *args, **kwargs):
            return self
    
        torch.Tensor.to = noop_to
    
    def __exit__(self, exc_type, exc_value, traceback):
        torch.Tensor.to = self.original_to


def main():
    # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    pipe.load_lora_weights("Cseti/CogVideoX-LoRA-Wallace_and_Gromit", weight_name="walgro1-3000.safetensors", adapter_name="cogvideox-lora")
    pipe.set_adapters(["cogvideox-lora"], [1.0])
    
    apply_layerwise_casting(
        pipe.transformer,
        storage_dtype=torch.float8_e4m3fn,
        compute_dtype=torch.bfloat16,
        skip_modules_pattern=["patch_embed", "norm", "^proj_out$"]
    )
    
    for name, parameter in pipe.transformer.named_parameters():
        if "lora" in name:
            assert parameter.dtype == torch.float8_e4m3fn

    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    print(f"Model memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")

    prompt = "walgro1. The scene begins with a close-up of Gromit's face, his expressive eyes filling the frame. His brow furrows slightly, ears perked forward in concentration. The soft lighting highlights the subtle details of his fur, every strand catching the warm sunlight filtering in from a nearby window. His dark, round nose twitches ever so slightly, sensing something in the air, and his gaze darts to the side, following an unseen movement. The camera lingers on Gromit’s face, capturing the subtleties of his expression—a quirked eyebrow and a knowing look that suggests he’s piecing together something clever. His silent, thoughtful demeanor speaks volumes as he watches the scene unfold with quiet intensity. The background remains out of focus, drawing all attention to the sharp intelligence in his eyes and the slight tilt of his head. In the claymation style of Wallace and Gromit."

    with torch.no_grad():
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt=prompt,
            negative_prompt="",
            do_classifier_free_guidance=True,
            num_videos_per_prompt=1,
            device="cuda",
            dtype=torch.bfloat16,
        )

    video = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        guidance_scale=6,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(42)
    ).frames[0]
    export_to_video(video, "output.mp4", fps=8)

    print(f"Inference memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")

main()

The example code errors out with diffusers:main and peft:main, but passes when using the respective PR branches.

Comment on lines 235 to 241
def _disable_peft_input_autocast(module: torch.nn.Module) -> None:
    for submodule in module.modules():
        if not isinstance(submodule, BaseTunerLayer):
            continue
        registry = HookRegistry.check_if_exists_or_initialize(submodule)
        hook = PeftInputAutocastDisableHook()
        registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK)
a-r-r-o-w (PR author):

Given that layerwise casting can be enabled at modeling levels other than the top-level module, some lora weights can be in fp8 while others may be in bf16 (for example). While it's an edge case, I think this hook should be added only if any submodule of the BaseTunerLayer contains a _diffusers_hook registry with a _LAYERWISE_UPCASTING_HOOK hook. LMK if I'm overthinking this, or I'll implement it as described.

Member:

if any submodule of the BaseTunerLayer contains a _diffusers_hook registry with a _LAYERWISE_UPCASTING_HOOK hook

Yeah, I'm with you. Just some clarification questions on the following LoC:

registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = PeftInputAutocastDisableHook()
registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK)
  • registry = HookRegistry.check_if_exists_or_initialize(submodule) checks whether submodule already has a certain type of registry? If so, can registry be a bool?
  • Shouldn't registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK) be conditioned on that, something like if isinstance(registry, bool) and registry: ..., or am I reading it wrong?

I haven't gone through the implementation of the above-used methods, so apologies for that in advance.

a-r-r-o-w (PR author):

registry = HookRegistry.check_if_exists_or_initialize(submodule) checks whether submodule already has a certain type of registry? If so, can registry be a bool?
Shouldn't registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK) be conditioned on that, something like if isinstance(registry, bool) and registry: ..., or am I reading it wrong?

The check_if_exists_or_initialize method always returns a HookRegistry object, not a bool, so we don't need a condition. It could maybe use a better name; if you have suggestions, we can rename it in a follow-up PR, since none of the related changes have gone into a release yet.
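
For illustration, a rough sketch of the conditional registration discussed in this thread could look like the following. This is a hypothetical sketch, not the merged implementation: the dict-like hooks attribute on the registry and reading it off _diffusers_hook are assumptions, and it reuses the imports of the module shown above (torch, BaseTunerLayer, HookRegistry, PeftInputAutocastDisableHook).

def _disable_peft_input_autocast(module: torch.nn.Module) -> None:
    for submodule in module.modules():
        if not isinstance(submodule, BaseTunerLayer):
            continue
        # Only disable input casting for tuner layers that actually have a
        # layerwise-casting hook somewhere below them (assumed attribute names).
        has_layerwise_hook = any(
            _LAYERWISE_UPCASTING_HOOK in getattr(getattr(m, "_diffusers_hook", None), "hooks", {})
            for m in submodule.modules()
        )
        if not has_layerwise_hook:
            continue
        registry = HookRegistry.check_if_exists_or_initialize(submodule)
        registry.register_hook(PeftInputAutocastDisableHook(), _PEFT_AUTOCAST_DISABLE_HOOK)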

@sayakpaul (Member) left a comment:

@a-r-r-o-w thanks! Design-wise this looks good to me! I think the major todos are:


from .hooks import HookRegistry, ModelHook


if is_peft_available():
    from peft.helpers import disable_lora_input_dtype_casting
BenjaminBossan (Member):

The PR is now merged in PEFT. I renamed the function to disable_input_dtype_casting. Let's also add a guard for the existence of the function, since users may have an older PEFT version installed.
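
A minimal sketch of such a guard, assuming it keys off the existence of the new helper rather than a hard-coded minimum version (the actual check used in this PR may differ):

from diffusers.utils import is_peft_available

_CAN_DISABLE_PEFT_INPUT_CASTING = False
if is_peft_available():
    import peft.helpers

    # Older PEFT releases do not ship this helper, so probe for it explicitly.
    if hasattr(peft.helpers, "disable_input_dtype_casting"):
        from peft.helpers import disable_input_dtype_casting

        _CAN_DISABLE_PEFT_INPUT_CASTING = True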

sayakpaul (Member):

Thanks for the heads-up.

@a-r-r-o-w LMK if you would like me to take over the PR.

a-r-r-o-w (PR author):

Thanks @BenjaminBossan! I've updated the code to use the latest name and added a version guard. I will document and add tests once CI is green and the dependency tests don't fail like before.

@sayakpaul I can wrap it up, thanks for offering to help! On a separate note, please do not merge the main branch or push changes arbitrarily. It causes unexpected push conflicts on my own work branches. I know the fix is as simple as git pull --rebase or similar, but I don't fancy dealing with it, as it's an extra unnecessary step on a branch that is supposed to be mine to work on. I then need to force-push, otherwise there can sometimes be merge conflicts (several instances in the past). If you need the changes from main on a certain branch for testing, please feel free to create your own branch from main and cherry-pick the changes in, or make sure to get a green light from the branch owner first.


I don't know how others feel about this, but I can imagine it being an unnecessary disruption to deal with in general.

@sayakpaul (Member) commented Feb 7, 2025:

Noted. Cc: @DN6 @hlky too, as I think we do this (merging main into a PR branch) quite regularly.

push changes arbitrarily

Not sure if I have pushed changes arbitrarily without explicit permission first.

@a-r-r-o-w force-pushed the peft/disable-input-autocast branch from ed791c4 to 020e374 on February 7, 2025 04:13
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member) left a comment:

Looking fantastic! Thanks, Aryan.

I just have one comment for tests.

A hook that disables the casting of inputs to the module weight dtype during the forward pass. By default, PEFT
casts the inputs to the weight dtype of the module, which can lead to precision loss.

The reasons for needing this are:
Member:

Thanks a LOT for writing this! Really, thanks!
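
For context, a hedged sketch of what such a hook's forward override could look like, assuming disable_input_dtype_casting is a context manager that takes the module and that the hook base class exposes the original forward via fn_ref (the merged diffusers code may differ):

class PeftInputAutocastDisableHook(ModelHook):
    # Wrap the module's original forward so that, inside this scope, PEFT skips
    # casting inputs to the (possibly fp8) weight dtype of the LoRA layers.
    def new_forward(self, module, *args, **kwargs):
        with disable_input_dtype_casting(module):
            return self.fn_ref.original_forward(*args, **kwargs)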

@@ -134,6 +167,7 @@ def apply_layerwise_casting(
         skip_modules_classes,
         non_blocking,
     )
+    _disable_peft_input_autocast(module)
Member:

The method is already version-guarded. So, no worries.

(For other reviewers).

@BenjaminBossan (Member) left a comment:

Thanks for adding this integration. I don't have anything to add on top of what Sayak already wrote.

@a-r-r-o-w requested a review from sayakpaul on February 13, 2025 13:10
@sayakpaul (Member) left a comment:

👨‍🍳 ❤️

@sayakpaul merged commit a0c2299 into main on Feb 13, 2025 · 15 checks passed
@sayakpaul deleted the peft/disable-input-autocast branch on February 13, 2025 17:42