
Issue with Flux LoRAs trained with SimpleTuner #9134

Closed
sayakpaul opened this issue Aug 9, 2024 · 34 comments

@sayakpaul
Member

sayakpaul commented Aug 9, 2024

@bghira we're seeing issues when doing:

import torch
from diffusers import DiffusionPipeline

base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16).to("cuda")

pipe.load_lora_weights("pzc163/flux-lora-littletinies")
generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(
    prompt=f"a dog in the park",
    num_inference_steps=28,
    guidance_scale=3.5,
    width=1024,
    height=1024,
    generator=generator,
).images[0]

But it outputs gibberish. The inference code from https://huggingface.co/pzc163/flux-lora-littletinies shows many unsupported arguments such as negative_prompt and guidance_rescale.

Have you found any bugs with inference?

Cc: @apolinario @asomoza

@AmericanPresidentJimmyCarter
Contributor

+1, no idea how to use it with the quantized model either.

@sayakpaul
Member Author

@apolinario says Comfy can run it perfectly fine. So, I wonder what we're missing.

@sayakpaul
Member Author

Okay, it's the LoRA alpha that is the culprit:

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index f612cc0c6..f432175d0 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -1686,6 +1686,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                     )
                 else:
                     lora_config_kwargs.pop("use_dora")
+            lora_config_kwargs['lora_alpha'] = 16
             lora_config = LoraConfig(**lora_config_kwargs)
 
             # adapter_name

With that hack and the following code,

import torch
from diffusers import DiffusionPipeline

base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16).to("cuda")

pipe.load_lora_weights("pzc163/flux-lora-littletinies")
generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(
    prompt=f"ethnographic photography of teddy bear at a picnic",
    num_inference_steps=50,
    guidance_scale=7.5,
    max_sequence_length=512,
    width=1152,
    height=768,
    generator=generator,
).images[0]
image.save("flux_dev_lora.png")

I am able to generate:
[image: flux_dev_lora]

@sayakpaul
Member Author

I also see SimpleTuner allows training with a lora_alpha that differs from lora_rank. For diffusers LoRAs, we always set lora_alpha equal to the rank for easier loading:

https://github.com/bghira/SimpleTuner/blob/513b71b65b85af3bfc6c8ce1ee8abaafef1f4152/train.py#L932C28-L932C43

I wonder how lora_alpha is parsed for diffusers LoRA params then, because we parse alphas only for non-diffusers checkpoints:

state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

The problem is a bit tricky because we now have to either use some heuristic to determine the alpha or allow saving the LoRA config in the metadata of the LoRA checkpoint. The latter is only possible if the checkpoint is serialized as safetensors, which should not be a problem.
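For context, here is a rough sketch of what round-tripping the config through safetensors metadata could look like (the lora_config metadata key and the values below are illustrative, not something diffusers currently writes):

import json
import torch
from safetensors.torch import save_file
from safetensors import safe_open

# Toy state dict standing in for the real LoRA weights.
state_dict = {"transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(64, 32)}

# Saving: stash the LoRA config next to the weights (metadata values must be strings).
lora_config = {"r": 64, "lora_alpha": 16}
save_file(state_dict, "lora.safetensors", metadata={"lora_config": json.dumps(lora_config)})

# Loading: read the config back before constructing the LoraConfig.
with safe_open("lora.safetensors", framework="pt") as f:
    loaded_config = json.loads(f.metadata()["lora_config"])
print(loaded_config["lora_alpha"])  # 16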

@sayakpaul
Member Author

We likely cannot obtain the lora_alpha info from the input state dict alone. To verify, I did the following:

from diffusers import FluxTransformer2DModel 
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

init_dict = {
    "patch_size": 1,
    "in_channels": 4,
    "num_layers": 1,
    "num_single_layers": 1,
    "attention_head_dim": 16,
    "num_attention_heads": 2,
    "joint_attention_dim": 32,
    "pooled_projection_dim": 32,
    "axes_dims_rope": [4, 4, 8],
} 
model = FluxTransformer2DModel(**init_dict)

denoiser_lora_config = LoraConfig(
    r=4,
    lora_alpha=2,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    init_lora_weights=False
)
model.add_adapter(denoiser_lora_config)
peft_sd = get_peft_model_state_dict(model)
for k, v in peft_sd.items():
    print(k, v.shape)

Even though lora_alpha is set to 2 (different from the rank), it's nowhere to be seen in the state dict, and that is expected; the shapes only reveal the rank:

transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([4, 32])
transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([32, 4])
transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([4, 32])
transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([32, 4])
transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([4, 32])
transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([32, 4])
transformer_blocks.0.attn.to_out.0.lora_A.weight torch.Size([4, 32])
transformer_blocks.0.attn.to_out.0.lora_B.weight torch.Size([32, 4])
single_transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([4, 32])
single_transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([32, 4])
single_transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([4, 32])
single_transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([32, 4])
single_transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([4, 32])
single_transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([32, 4])

So, we probably have to save the LoRA config in the safetensors metadata as I explained here:
#9134 (comment). I did a PoC for this back in the day: #6135.

@BenjaminBossan do you have any other idea here?

@sayakpaul
Member Author

Okay, another interesting finding from @asomoza: this LoRA works as expected with schnell, while with dev the results are widely different.

@sayakpaul
Member Author

sayakpaul commented Aug 9, 2024

@AmericanPresidentJimmyCarter I did some brute-forcing (not a long-term solution) to perform inference with quantization:

import torch
from diffusers import DiffusionPipeline
from optimum.quanto import quantize, qfloat8, freeze

base_model = "black-forest-labs/FLUX.1-schnell"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16).to("cuda")

pipe.load_lora_weights("pzc163/flux-lora-littletinies")
# generator = torch.Generator(device="cuda").manual_seed(0)

pipe.fuse_lora(lora_scale=0.75)
pipe.unload_lora_weights()

quantize(pipe.transformer, qfloat8)
freeze(pipe.transformer)

# quantize(pipe.text_encoder, qfloat8)
# freeze(pipe.text_encoder)

# quantize(pipe.text_encoder_2, qfloat8)
# freeze(pipe.text_encoder_2)

image = pipe(
    prompt=f"a girl is walking in the forest,meeting a big tiger",
    num_inference_steps=4,
    guidance_scale=0.0,
    max_sequence_length=256,
    width=1024,
    height=1024,
    # joint_attention_kwargs={"scale": 1.0},
    generator=torch.manual_seed(0),
).images[0]
image.save("flux_dev_lora.png")

But it still doesn't seem to match the ComfyUI output, even with lora_alpha=16, which is the default alpha value SimpleTuner sets:
https://github.com/bghira/SimpleTuner/blob/aceaf3e0b69de83f5a5864a5c40676982f4497cc/helpers/arguments.py#L240

Interestingly, if we keep lora_alpha in the range [32, 42], the results somewhat resemble ComfyUI's, but not entirely.

@bghira
Contributor

bghira commented Aug 9, 2024

i also received a report that the lora alpha is in the state dict for other loras, e.g. kohya, which makes it hard to know how to load diffusers loras that don't include it

@bghira
Contributor

bghira commented Aug 9, 2024

i wonder if it needs to be scaled in fp32 and not bf16

@sayakpaul
Member Author

Yeah, for other LoRAs we parse the alpha, but for diffusers LoRAs we don't. This is why, in our training code, it's always lora_alpha=args.rank. To allow a custom alpha, we would need something like this: #6135.

We never serialize alpha information in our LoRA checkpoints because, historically, we never supported custom alphas in our training scripts.

i wonder if it needs to be scaled in fp32 and not bf16

You mean the LoRA layers? Regardless, we need to change lora_alpha here to 16 (I am assuming 16 was used to train the said LoRA, as that's the default in SimpleTuner).

lora_config = LoraConfig(**lora_config_kwargs)

Still trying to identify the potential suspects.

@bghira
Contributor

bghira commented Aug 9, 2024

the rank and alpha are listed in the model card too fwiw

@sayakpaul
Member Author

Yes, it is listed. Rank is 64 and alpha is 16.

@BenjaminBossan
Member

Hmm, apart from storing the alpha value somewhere in the metadata to access it later, there is one thing that comes to mind: In the end, alpha is just a scale for the LoRA output: output = base_output + lora_output * alpha / rank. IIUC we assume by default that alpha = rank. Let's say we have an example where alpha = 2 * rank, the lora_output would be off by a factor of 2. Therefore, we could achieve the same results by multiplying 2 onto one of the LoRA weights, say lora_A.

To make this work would require an export script that is aware of alpha and knows that the output is intended to be loaded into diffusers. Then it could go through the state_dict and multiply each lora_A by alpha / rank.

This is not ideal and I would prefer if alpha was known through the metadata, but maybe it's an idea.
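A rough sketch of that export-time folding (the helper name and the state dict keys are made up for illustration):

import torch

def fold_alpha_into_lora_A(state_dict, alpha, rank):
    # Pre-multiply every lora_A weight by alpha / rank so that a loader which
    # assumes alpha == rank still produces the intended output scale.
    scale = alpha / rank
    return {k: (v * scale if "lora_A" in k else v) for k, v in state_dict.items()}

# Example with the values from this thread (rank 64, alpha 16):
sd = {
    "transformer_blocks.0.attn.to_q.lora_A.weight": torch.randn(64, 32),
    "transformer_blocks.0.attn.to_q.lora_B.weight": torch.randn(32, 64),
}
sd = fold_alpha_into_lora_A(sd, alpha=16, rank=64)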

@sayakpaul
Member Author

Thanks for providing your suggestion. However, alpha should be determined during loading, just like we determine the rank. IIUC, to follow your suggestion, we would need to apply that scaling during serialization, right?

@bghira
Contributor

bghira commented Aug 9, 2024

is there anything stopping us from adding alpha to the safetensors data? no one uses pth files anymore but even if they did it would probably be doable there too

@sayakpaul
Member Author

sayakpaul commented Aug 9, 2024

Nothing, actually. In fact, I would definitely prefer serializing the LoRA config to the metadata. Or we could try what @BenjaminBossan suggested. I can get a PR ready soon.

@bghira
Contributor

bghira commented Aug 9, 2024

i think the lora config is more flexible, as we don't permanently scale things with potential precision issues that we can't undo later

@BenjaminBossan
Member

Yes, as I said, providing alpha as metadata is strongly preferred; my suggestion was just a last resort in case that is not possible. When storing alpha, we need to remember that it could be different for each layer.
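For what it's worth, peft's LoraConfig already supports per-module overrides via rank_pattern and alpha_pattern, so whatever metadata format is chosen should be able to round-trip something like the following (the pattern keys below are purely illustrative):

from peft import LoraConfig

config = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    # Per-module overrides; keys are matched against module names.
    rank_pattern={"single_transformer_blocks.*.attn.to_q": 32},
    alpha_pattern={"single_transformer_blocks.*.attn.to_q": 8},
)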

@bghira
Contributor

bghira commented Aug 9, 2024

i think sayak was suggesting it become part of the layer shapes, or maybe i misread that, but maybe we can reuse the config metadata layout from e.g. kohya for this type of issue and meet the world halfway

@sayakpaul
Member Author

#6135 serializes the entire LoraConfig as metadata into the safetensors file and reads it back accordingly. What edge cases are we talking about here?

@bghira
Contributor

bghira commented Aug 9, 2024

we could handle the most obvious cases first and then go back later and add adjustments for the edge cases.

i think this one is important to start with for now, but if layerwise alpha is something we support already somehow, it should probably be made to work as well. if it's not currently supported, i think it's a nice future addition since nothing of value will be lost with the change.

i can take a look at kohya's trainer and see if there's a trivial way to discover the kind of config options we have to take into account. but that's definitely a separate list of improvements for the future

@bghira
Contributor

bghira commented Aug 9, 2024

ok, kohya doesn't store anything but the alpha. it serialises its config into the safetensors metadata, but that's the entire training config, e.g. global_step and the learning_rate, yada yada

their trainer adds .alpha as a scalar to each layer
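for illustration, pulling those per-layer alpha scalars out of a kohya-style checkpoint could look roughly like this (the file name is a placeholder; the key layout is the usual lora_down / lora_up / .alpha convention):

from safetensors.torch import load_file

sd = load_file("kohya_lora.safetensors")

# one scalar tensor per LoRA module, e.g. "lora_unet_..._attn_qkv.alpha"
alphas = {k.removesuffix(".alpha"): v.item() for k, v in sd.items() if k.endswith(".alpha")}

# the rank can be read off the lora_down weight shapes: (rank, in_features)
ranks = {k: v.shape[0] for k, v in sd.items() if k.endswith(".lora_down.weight")}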

@sayakpaul
Member Author

#9143

@AmericanPresidentJimmyCarter
Contributor

@AmericanPresidentJimmyCarter I did some brute-forcing (it's not long-term) to perform inference with quantization:

But it still doesn't apparently match the ComfyUI output even with lora_alpha=16 which is the default alpha value SimpleTuner sets: https://github.com/bghira/SimpleTuner/blob/aceaf3e0b69de83f5a5864a5c40676982f4497cc/helpers/arguments.py#L240

Interestingly, if we keep the lora_alpha between [32, 42], the results somewhat resemble ComfyUI but not entirely.

I tried this and just got noise, but I'm unsure if it's the LoRA itself or not.

@AmericanPresidentJimmyCarter
Contributor

Nevermind, got it working. Writeup here.

https://gist.github.com/AmericanPresidentJimmyCarter/0134fc7848aac7025d0c967c6c4df53b

@sayakpaul
Member Author

Thanks for the detailed write-up! Did you also try out the "pzc163/flux-lora-littletinies" LoRA?

@bghira
Contributor

bghira commented Aug 11, 2024

it goes a bit further: the guidance_scale value really needs to be 1.0 at training time to preserve the model's distillation. annnnnd that's all you need to do.

but training cfg in seems to actually improve the model's variability and creativity, so i left both choices up to the user

@sayakpaul
Member Author

Thanks!

If there are any pipeline-level changes needed to accommodate those, we can discuss them and have them integrated into core diffusers. I saw @AmericanPresidentJimmyCarter had a pipeline implementation and SimpleTuner has one too.

We want to figure out if that can be eliminated and better maintained within diffusers itself :)

Also, LMK if you need reviews for the changes being introduced in #9143.

@bghira
Contributor

bghira commented Aug 11, 2024

currently the main reason i've got copies of the pipelines is the vae decode, whose input dtype routinely desyncs from the vae dtype itself - i think the pipeline assumes everything will be in the same precision. i would love to get rid of those copies forever.

@sayakpaul
Member Author

I was mainly referring to the CFG-related changes implemented in @AmericanPresidentJimmyCarter’s script.

currently the main reason i've got copies of the pipelines is for the vae decode whose input dtype routinely desyncs from the vae dtype itself

Oh okay. Do we know when it happens (e.g., on an MPS device, during training)? Or is it very random? Because just for inference, it shouldn’t happen.

@bghira
Contributor

bghira commented Aug 11, 2024

it doesn't happen on MPS, actually. it's ironic... but it does happen on CUDA devices. i'm really not sure why that is. validations will succeed locally, and then on cuda it complains about the input dtype being != the vae dtype.

@bghira
Contributor

bghira commented Aug 11, 2024

example: #8391

@bghira
Contributor

bghira commented Aug 11, 2024

running the scheduler step forward on nvidia / cuda results in the latents becoming fp32. I tried to dig into where and why that's happening, but it's inside Accelerate and is closely tied to PyTorch internals. It is beyond me.
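a sketch of the kind of workaround involved (flux-specific latent unpacking and scaling omitted) is just casting back to the vae dtype right before decode:

# `pipe` is a loaded Flux pipeline; `latents` come out of the denoising loop
# and may have been silently upcast to fp32 by the scheduler step on CUDA.
latents = latents.to(dtype=pipe.vae.dtype)
image = pipe.vae.decode(latents, return_dict=False)[0]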

@sayakpaul
Member Author

Ah okay, then maybe some autocast voodoo. In any case, for the CFG stuff, if you think we should have those changes inside our implementation, let us know. But for now, I will close this issue since it seems resolved in my eyes.
