Issue with Flux LoRAs trained with SimpleTuner #9134
Comments
+1, no idea how to use it with the quantized model either. |
@apolinario says Comfy can run it perfectly fine. So, I wonder what we're missing. |
Okay, it's the LoRA alpha that is the culprit:
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index f612cc0c6..f432175d0 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -1686,6 +1686,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 )
             else:
                 lora_config_kwargs.pop("use_dora")
+            lora_config_kwargs['lora_alpha'] = 16
             lora_config = LoraConfig(**lora_config_kwargs)
             # adapter_name
With that hack and the following code,
import torch
from diffusers import DiffusionPipeline
base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("pzc163/flux-lora-littletinies")
generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(
prompt=f"ethnographic photography of teddy bear at a picnic",
num_inference_steps=50,
guidance_scale=7.5,
max_sequence_length=512,
width=1152,
height=768,
generator=generator,
).images[0]
image.save("flux_dev_lora.png") |
I also see SimpleTuner allows training with a different alpha. I wonder how that should be handled.
The problem is a little bit tricky because now we have to either apply some kind of heuristic to determine the alpha or allow saving LoRA configs in the metadata of the LoRA checkpoint. The latter is only possible if the checkpoint is serialized in safetensors, which should not be a problem. |
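As a rough illustration of the metadata route (a minimal sketch, not the eventual diffusers API; the file name and config fields below are placeholders), safetensors already allows attaching a string-to-string metadata dict at save time and reading it back without loading the tensors:

```python
import json

import torch
from safetensors.torch import safe_open, save_file

# Toy LoRA state dict standing in for a real checkpoint.
state_dict = {
    "transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 32),
    "transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(32, 4),
}

# Hypothetical config values; in practice these would come from the LoraConfig used for training.
lora_config = {"r": 4, "lora_alpha": 16}

# safetensors metadata must be Dict[str, str], so the config is JSON-encoded.
save_file(state_dict, "lora.safetensors", metadata={"lora_config": json.dumps(lora_config)})

# At load time the metadata is available without materializing the tensors.
with safe_open("lora.safetensors", framework="pt") as f:
    loaded_config = json.loads(f.metadata()["lora_config"])
print(loaded_config["lora_alpha"])  # 16
```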
We likely cannot obtain the alpha from the state dict:
from diffusers import FluxTransformer2DModel
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
init_dict = {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"pooled_projection_dim": 32,
"axes_dims_rope": [4, 4, 8],
}
model = FluxTransformer2DModel(**init_dict)
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=2,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False
)
model.add_adapter(denoiser_lora_config)
peft_sd = get_peft_model_state_dict(model)
for k, v in peft_sd.items():
    print(k, v.shape)
Despite setting lora_alpha=2 above, the printed state dict contains only the LoRA A and B weights, with no trace of the alpha:
transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([4, 32])
transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([32, 4])
transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([4, 32])
transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([32, 4])
transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([4, 32])
transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([32, 4])
transformer_blocks.0.attn.to_out.0.lora_A.weight torch.Size([4, 32])
transformer_blocks.0.attn.to_out.0.lora_B.weight torch.Size([32, 4])
single_transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([4, 32])
single_transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([32, 4])
single_transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([4, 32])
single_transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([32, 4])
single_transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([4, 32])
single_transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([32, 4]) So, we probably have to save the LoRA config in the safetensors metadata as I explained here: @BenjaminBossan do you have any other idea here? |
Okay, another interesting fact discovered by @asomoza is that this LoRA should work as expected with |
@AmericanPresidentJimmyCarter I did some brute-forcing (it's not a long-term fix) to perform inference with quantization:
import torch
from diffusers import DiffusionPipeline
from optimum.quanto import quantize, qfloat8, freeze
base_model = "black-forest-labs/FLUX.1-schnell"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("pzc163/flux-lora-littletinies")
# generator = torch.Generator(device="cuda").manual_seed(0)
pipe.fuse_lora(lora_scale=0.75)
pipe.unload_lora_weights()
quantize(pipe.transformer, qfloat8)
freeze(pipe.transformer)
# quantize(pipe.text_encoder, qfloat8)
# freeze(pipe.text_encoder)
# quantize(pipe.text_encoder_2, qfloat8)
# freeze(pipe.text_encoder_2)
image = pipe(
prompt=f"a girl is walking in the forest,meeting a big tiger",
num_inference_steps=4,
guidance_scale=0.0,
max_sequence_length=256,
width=1024,
height=1024,
# joint_attention_kwargs={"scale": 1.0},
generator=torch.manual_seed(0),
).images[0]
image.save("flux_dev_lora.png") But it still doesn't apparently match the ComfyUI output even with Interestingly, if we keep the |
i also received a report that the lora alpha is in the state dict for other loras, e.g. kohya, which makes it hard to know how to load diffusers loras, which do not include it |
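For context on that convention (a minimal sketch; the key names below are only illustrative of the kohya-ss format, where each module carries a scalar `.alpha` tensor alongside its `lora_down`/`lora_up` weights):

```python
import torch

# Illustrative kohya-style state dict for a single module.
kohya_sd = {
    "lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight": torch.zeros(16, 3072),
    "lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight": torch.zeros(9216, 16),
    "lora_unet_double_blocks_0_img_attn_qkv.alpha": torch.tensor(8.0),
}

# Collect per-module alphas so a loader could compute scaling = alpha / rank for each layer.
scalings = {}
for key, value in kohya_sd.items():
    if key.endswith(".alpha"):
        module = key[: -len(".alpha")]
        rank = kohya_sd[f"{module}.lora_down.weight"].shape[0]
        scalings[module] = float(value) / rank

print(scalings)  # {'lora_unet_double_blocks_0_img_attn_qkv': 0.5}
```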
i wonder if it needs to be scaled in fp32 and not bf16 |
Yeah, for other LoRAs we parse the alpha, but for diffusers-format LoRAs we don't. This is why, in our training code, we never serialize the alpha.
You mean the LoRA layers? Despite that, we need to change this line: diffusers/src/diffusers/loaders/lora_pipeline.py, line 1689 (at 65e3090).
Still trying to identify the potential suspects. |
the rank and alpha should be, and are, listed in the model card too fwiw |
Yes, it is listed. Rank is 64 and alpha is 16. |
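For reference, those numbers are consistent with the hack above (a quick sanity check using the standard PEFT scaling formula):

```python
# PEFT scales the LoRA delta by lora_alpha / r.
r, alpha = 64, 16
print(alpha / r)  # 0.25 -> the scale this LoRA was trained with
print(r / r)      # 1.00 -> the scale applied when alpha silently defaults to the rank
# Applying the update at 4x its intended strength is what garbles the outputs.
```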
Hmm, apart from storing the alpha value somewhere in the metadata to access it later, there is one thing that comes to mind. In the end, alpha is just a scale on the LoRA output (the delta is lora_B(lora_A(x)) * alpha / r), so the scale could be baked into the LoRA weights themselves before export. To make this work would require an export script that is aware of alpha and knows that the output is intended to be loaded into diffusers. Then it could go through the LoRA weights and pre-multiply them accordingly. This is not ideal and I would prefer if alpha was known through the metadata, but maybe it's an idea. |
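A minimal sketch of what such an export-time rescaling could look like (assuming a PEFT-style state dict and a loader that will later treat alpha as equal to the rank; the helper below is hypothetical, not an existing diffusers or PEFT API):

```python
import torch


def bake_alpha_into_lora(state_dict, alpha, rank):
    """Pre-multiply every lora_B weight by alpha / rank so the checkpoint
    behaves correctly when loaded with an implicit alpha equal to the rank."""
    scaling = alpha / rank
    return {
        key: value * scaling if ".lora_B." in key else value
        for key, value in state_dict.items()
    }


# Toy example: one LoRA pair with rank 4, trained with alpha 2.
sd = {
    "transformer_blocks.0.attn.to_q.lora_A.weight": torch.randn(4, 32),
    "transformer_blocks.0.attn.to_q.lora_B.weight": torch.randn(32, 4),
}
baked = bake_alpha_into_lora(sd, alpha=2, rank=4)  # lora_B is now scaled by 0.5
```

As the following comments point out, this permanently bakes the scale into the weights, which is the irreversibility and precision concern raised below.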
Thanks for providing your suggestion. However, |
is there anything stopping us from adding alpha to the safetensors data? no one uses pth files anymore but even if they did it would probably be doable there too |
Nothing actually. In fact, I would definitely prefer serializing the lora config to metadata. Or we could try what @BenjaminBossan ideated. Can get a PR ready soon. |
i think the lora config is more flexible as we don't permanently scale things, with potential precision issues that we can't undo later |
Yes, as I said, providing alpha as metadata is strongly preferred, my suggestion was just a last resort if the former is not possible. When storing alpha, we need to remember that it could be different for each layer. |
i think sayak was suggesting it become a part of the layer shapes or maybe i misread that, but maybe we can reuse config metadata layout from eg. kohya for this type of issue, and meet the world halfway |
#6135 serializes the entire LoraConfig as metadata into the safetensors file and reads it back accordingly. What edge cases are we talking about here? |
we could handle the most obvious cases first and then can go back later and add adjustments for edge cases. i think this one is important to start with for now, but if layerwise alpha is something we support already somehow, it should probably be made to work as well. if it's not currently supported, i think it's a nice future addition since nothing of value will be lost with the change. i can take a look at kohya's trainer and see if there's a trivial way to discover the kind of config options we have to take into account. but that's definitely a separate list of improvements for the future |
ok, kohya doesn't store anything but the alpha. it serialises its config into the safetensors but that's the entire training config, eg. global_step and the learning_rate yada yada their trainer adds |
I tried this and just got noise, but I'm unsure if it's the LoRA itself or not. |
Nevermind, got it working. Writeup here. https://gist.github.com/AmericanPresidentJimmyCarter/0134fc7848aac7025d0c967c6c4df53b |
Thanks for the detailed write-up! Did you also try out the "pzc163/flux-lora-littletinies" LoRA? |
it goes a bit further, but training cfg in seems to actually improve the model's variability and creativity. so, i left both choices up to the user |
Thanks! If there are any pipeline-level changes needed to accommodate those, we can discuss and have them integrated into the core pipelines. We want to figure out if the copied pipelines can be eliminated and better maintained within diffusers. Also, LMK if you need reviews for the changes being introduced in #9143. |
currently the main reason i've got copies of the pipelines is for the vae decode whose input dtype routinely desyncs from the vae dtype itself - i think the pipeline assumes everything will be in the same precision. i would love to get rid of those forever. |
I was mainly referring to the CFG related changes instrumented in @AmericanPresidentJimmyCarter’s script.
Oh okay. Do we know when it happens ( e.g. MPS device, during training)? Or is it very random? Because just for inference, it shouldn’t happen. |
it doesn't happen on MPS, actually. it's ironic... but it does happen on CUDA devices. i'm really not sure why that is. validations will succeed locally and then, on cuda, it complains about the input dtype being != the vae dtype. |
example: #8391 |
running the scheduler step forward on nvidia / cuda results in the latents becoming fp32. I tried to dig into where and why that's happening, but it's inside Accelerate, and links closely with Pytorch internals. It is beyond me. |
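For reference, the workaround in SimpleTuner's pipeline copies amounts to a defensive cast; a minimal sketch of the idea (the function and variable names are illustrative, and Flux-specific latent unpacking/scaling is omitted):

```python
import torch


def decode_latents(vae, latents: torch.Tensor) -> torch.Tensor:
    # After the scheduler step, the latents may have been silently promoted
    # to fp32 on CUDA even though the VAE weights are bf16. Cast them back
    # to the VAE's own dtype right before decoding so the decode never sees
    # a mismatched input dtype.
    latents = latents.to(dtype=next(vae.parameters()).dtype)
    return vae.decode(latents).sample
```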
Ah okay, then some autocast voodoo, maybe. In any case, for the CFG stuff, if you think we should have those changes inside our implementation, let us know. But for now, I will close this issue since it seems to be resolved in my eyes. |
@bghira we're seeing issues when doing:
But it outputs gibberish. The inference code from https://huggingface.co/pzc163/flux-lora-littletinies shows many unsupported arguments such as negative_prompt and guidance_rescale. Have you found any bugs with inference?
Cc: @apolinario @asomoza