
Unloading multiple loras: norms do not return to their original values #10745

Open
christopher5106 opened this issue Feb 7, 2025 · 25 comments

@christopher5106

When unloading after loading multiple LoRAs on a Flux pipeline, I believe the norm layers are not restored here.

Shouldn't we have:

        if len(transformer_norm_state_dict) > 0:
            original_norm_layers_state_dict = self._load_norm_into_transformer(
                transformer_norm_state_dict,
                transformer=transformer,
                discard_original_layers=False,
            )
            if not hasattr(transformer, "_transformer_norm_layers"):
                transformer._transformer_norm_layers = original_norm_layers_state_dict
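
For context, here is a minimal repro sketch of the scenario, assuming two hypothetical local LoRA files that both ship norm layer params:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# keep a copy of one norm param from the base transformer for comparison
norm_key = next(k for k in pipe.transformer.state_dict() if "norm" in k and k.endswith("weight"))
original_norm = pipe.transformer.state_dict()[norm_key].clone()

# hypothetical local files; both are assumed to carry norm layer params
pipe.load_lora_weights("lora1_with_norms.safetensors", adapter_name="lora1")
pipe.load_lora_weights("lora2_with_norms.safetensors", adapter_name="lora2")

pipe.unload_lora_weights()

# per this report, the norms can come back as lora1's values instead of the base model's,
# so this comparison can fail
restored_norm = pipe.transformer.state_dict()[norm_key]
torch.testing.assert_close(restored_norm, original_norm)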
@christopher5106 christopher5106 changed the title Loading multiple loras: norms do not return to their original values Unloading multiple loras: norms do not return to their original values Feb 7, 2025
@sayakpaul
Member

Should it not already take care of it?

transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)

What am I missing?

Additionally, the following test does ensure its effectiveness:

def test_lora_unload_with_parameter_expanded_shapes(self):

@christopher5106
Author

Ah, is it possible to call load_lora_weights() multiple times on a pipeline to load multiple sets of weights? Does it unload in between to restore the original weights?

@sayakpaul
Member

If you don’t call unload_lora_weights() it won’t be called automatically.

@christopher5106
Author

So in the case of multiple calls to load_lora_weights(), the attribute _transformer_norm_layers becomes overwritten by the norms of the previously loaded LoRA?

@sayakpaul
Member

If you're loading a Control LoRA and want to keep it alongside others, the norm layer values that came with the Control LoRA will remain as they are. Otherwise, the Control LoRA won't be fully effective.

Or am I misinterpreting the core use case here?

@christopher5106
Author

From what I understand, you are making the assumption that only Control LoRAs have trained norms, right?

What I see in my code is:

for local_weights_cache, adapter_name in loras:
    pipe.load_lora_weights(local_weights_cache, adapter_name=adapter_name)

That means if two LoRAs come with trained norm layers, we lose the original weights. I believe we should make it generic, because we might forget this assumption.

@sayakpaul
Member

From what I understand, you are making the assumption that only Control LoRAs have trained norms, right?

Well, Control LoRA is a bit of a special case in that its state dicts have the exact same norm params as their non-LoRA variants. More specifically, these norm layer params differ from the base Flux.1 Dev model and were taken from the non-LoRA control variants of Flux (so for the Depth Control LoRA, that would be https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev and for the Canny Control LoRA, that would be https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev/). The norm params are not LoRA params.

But on the other hand, it's totally possible to also target the norm layers for applying LoRA (which is something we see often in the community).

So, I don't think there's a need to change anything here.

@christopher5106
Author

But my question is about the following:

I load one LoRA with norm layers, let's say lora1_norms:

  • it saves the original layer weights (let's say original_norms) in _transformer_norm_layers

I load a second LoRA with norm layers, let's say lora2_norms:

  • it saves the overwritten layer weights in _transformer_norm_layers

So now in _transformer_norm_layers I have lora1_norms.

When I call unload_lora_weights(), does it revert to original_norms or to lora1_norms? What did I miss?

@sayakpaul
Member

When I call unload_lora_weights(), does it revert to original_norms or to lora1_norms? What did I miss?

It unloads all the LoRA-overwritten params, including lora1 and lora2.

More notes:
https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#note-about-unloadloraweights-when-using-flux-loras

@christopher5106
Author

christopher5106 commented Feb 15, 2025

Sorry, I don't see what I'm missing.

If I call load_lora_weights twice with two different LoRAs, lora1 and lora2, that both have norm layers, then _load_norm_into_transformer() is called twice and returns the overwritten layers. But it takes these values from the current transformer state dict:

overwritten_layers_state_dict[key] = transformer_state_dict[key].clone()

and if these norms have already been overwritten by lora1, then when you call unload_lora_weights(), it will restore lora1's norms with

transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)

but not the original norms of the model.
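
To make the bookkeeping concrete, here is a simplified sketch of the behavior described above (not the exact diffusers source; the signature is simplified):

def _load_norm_into_transformer(norm_state_dict, transformer, discard_original_layers=False):
    # Simplified sketch: the "original" values that get saved are whatever is
    # currently in the transformer, which may already be a previously loaded
    # LoRA's norm values rather than the base model's.
    transformer_state_dict = transformer.state_dict()
    overwritten_layers_state_dict = {}
    if not discard_original_layers:
        for key in norm_state_dict:
            overwritten_layers_state_dict[key] = transformer_state_dict[key].clone()
    # overwrite the norms with the values shipped in the LoRA state dict
    transformer.load_state_dict(norm_state_dict, strict=False)
    return overwritten_layers_state_dict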

@sayakpaul
Member

Yeah, you're right! But it's also not very common for LoRA state dicts to ship norm layer params. We can prioritize that as and when such a LoRA comes in. For now, we can add a comment to make ourselves aware of it.

@christopher5106
Author

As you want ;)
But you see my point now :-)
I believe it's a wider problem than just the norm layers, since load_state_dict() and state_dict() can load arbitrary tensors into arbitrary layers of our models.

@sayakpaul
Member

But it's also weird for LoRA state dicts to carry arbitrary non-LoRA params. So, there's that.

@christopher5106
Author

BFL could start a new trend where LoRAs with bias and norm params become a way to get better LoRAs, no?

@sayakpaul
Member

And we support that already. My point was about putting arbitrary keys into a LoRA state dict. As and when those things come, we can support them. We cannot speculate about those.

@christopher5106
Author

christopher5106 commented Feb 15, 2025

Some companies run diffusers code in production, and there is now an easy way to attack them: submit two LoRAs, the first with norms that are either infinite or close to zero, and the second also with norms. If these norms are never unloaded, their production will output either black-and-white images or images full of artifacts, no?

@christopher5106
Author

Let me summarize what I understand:

Case 1: A user adds the BFL Canny LoRA and the BFL Depth LoRA to the same pipe: unloading does not restore the original norm layers of Flux dev.

Case 2: An attacker creates a LoRA with norm layers and submits it to different diffusers-based platforms that heavily load/unload LoRAs: unloading does not restore the model's norm layers for Flux.

Is that correct?

@a-r-r-o-w
Member

What you mention is indeed a problem.

We can:

  • either raise an error if a second LoRA with norms is loaded, mentioning that loading only one is supported
  • or only keep the original copy of the norms from the base model and revert to it when unload_lora_weights is called; if any further LoRAs with norms are loaded, we raise a warning and ignore them (a sketch follows this list)
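
A minimal sketch of the second option, reusing the names from the snippet in the issue description (transformer_norm_state_dict, _load_norm_into_transformer) and assuming a module-level logger; this is not the actual diffusers implementation:

    if len(transformer_norm_state_dict) > 0:
        if getattr(transformer, "_transformer_norm_layers", None) is None:
            # first LoRA that ships norm params: load them and remember the base
            # model's values so unload_lora_weights() can revert to them
            transformer._transformer_norm_layers = self._load_norm_into_transformer(
                transformer_norm_state_dict,
                transformer=transformer,
                discard_original_layers=False,
            )
        else:
            # norm params from an earlier LoRA are already in place: warn and ignore the new ones
            logger.warning(
                "Norm layer params were already loaded from a previous LoRA; "
                "ignoring the norm params shipped with this one."
            )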

Anything apart from this is probably too much effort for something not yet done that commonly in the community. Folks that use diffusers in production this way should be mindful of the implications of allowing arbitrary state dicts to be loaded and develop their own countermeasures for such scenarios. WDYT?

@christopher5106
Author

christopher5106 commented Feb 19, 2025

It's only a few lines to add so that already stored norm values are not overridden:

            if hasattr(transformer, "_transformer_norm_layers"):
                for key in original_norm_layers_state_dict:
                    if key not in transformer._transformer_norm_layers:
                        transformer._transformer_norm_layers[key] = original_norm_layers_state_dict[key]
            else:
                transformer._transformer_norm_layers = original_norm_layers_state_dict

Other codebases do this; for example, here only the norm layers that have not already been stored are cloned for later restoration.

@christopher5106
Author

There is another thing that's a bit surprising to me:

        # Flux Control LoRAs also have norm keys
        has_norm_keys = any(
            norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys
        )
        if not (has_lora_keys or has_norm_keys):
            raise ValueError("Invalid LoRA checkpoint.")

Shouldn't it simply be if not has_lora_keys? Why is it allowed to load a state dict made of norm keys but no LoRA keys?

Thanks for your clarification.

The less we allow loading layers this way, by overriding other layers, the better.

@christopher5106
Author

Lastly, I'm wondering why only weights for text_encoder_1 are loaded. Anyway, we haven't had any LoRAs submitted with trained text encoder weights so far.

@sayakpaul
Member

Because we don't have T5-trained LoRAs yet.

I have a PR open to warn users if there are unused keys in the state dict: #10187.

We recently added support to load Flux community LoRAs that have text encoder (CLIP) too: #10810.

@christopher5106
Author

christopher5106 commented Feb 21, 2025

Let me share an extract of one LoRA we got:

"diffusion_model.double_blocks.0.img_attn.proj.diff_b","[3072]","torch.float16"
"diffusion_model.double_blocks.0.img_attn.proj.lora_down.weight","[32, 3072]","torch.float16"
"diffusion_model.double_blocks.0.img_attn.proj.lora_up.weight","[3072, 32]","torch.float16"
"diffusion_model.double_blocks.0.img_attn.qkv.diff_b","[9216]","torch.float16""diffusion_model.double_blocks.0.img_attn.qkv.lora_down.weight","[32, 3072]","torch.float16"
"diffusion_model.double_blocks.0.img_attn.qkv.lora_up.weight","[9216, 32]","torch.float16""diffusion_model.double_blocks.0.img_mlp.0.diff_b","[12288]","torch.float16"
"diffusion_model.double_blocks.0.img_mlp.0.lora_down.weight","[32, 3072]","torch.float16""diffusion_model.double_blocks.0.img_mlp.0.lora_up.weight","[12288, 32]","torch.float16"
"diffusion_model.double_blocks.0.img_mlp.2.diff_b","[3072]","torch.float16""diffusion_model.double_blocks.0.img_mlp.2.lora_down.weight","[32, 12288]","torch.float16"
"diffusion_model.double_blocks.0.img_mlp.2.lora_up.weight","[3072, 32]","torch.float16""diffusion_model.double_blocks.0.img_mod.lin.diff_b","[18432]","torch.float16"
"diffusion_model.double_blocks.0.img_mod.lin.lora_down.weight","[32, 3072]","torch.float16""diffusion_model.double_blocks.0.img_mod.lin.lora_up.weight","[18432, 32]","torch.float16"
...
"diffusion_model.single_blocks.0.modulation.lin.diff_b","[9216]","torch.float16""diffusion_model.single_blocks.0.modulation.lin.lora_down.weight","[32, 3072]","torch.float16""diffusion_model.single_blocks.0.modulation.lin.lora_up.weight","[9216, 32]","torch.float16"
"diffusion_model.single_blocks.1.linear1.diff_b","[21504]","torch.float16"
...
"diffusion_model.vector_in.in_layer.diff_b","[3072]","torch.float16"
"diffusion_model.vector_in.in_layer.lora_down.weight","[32, 768]","torch.float16"
"diffusion_model.vector_in.in_layer.lora_up.weight","[3072, 32]","torch.float16"
"diffusion_model.vector_in.out_layer.diff_b","[3072]","torch.float16"
"diffusion_model.vector_in.out_layer.lora_down.weight","[32, 3072]","torch.float16"
"diffusion_model.vector_in.out_layer.lora_up.weight","[3072, 32]","torch.float16""text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.lora_down.weight","[32, 768]","torch.float16"
"text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.lora_up.weight","[77, 32]","torch.float16"
"text_encoders.clip_l.transformer.text_model.encoder.layers.0.layer_norm1.diff","[768]","torch.float16"
"text_encoders.clip_l.transformer.text_model.encoder.layers.0.layer_norm1.diff_b","[768]","torch.float16"
"text_encoders.clip_l.transformer.text_model.encoder.layers.0.layer_norm2.diff","[768]","torch.float16"
...
"text_encoders.t5xxl.transformer.encoder.block.0.layer.0.layer_norm.diff","[4096]","torch.float16"
"text_encoders.t5xxl.transformer.encoder.block.0.layer.1.DenseReluDense.wi_0.lora_down.weight","[32, 4096]","torch.float16"
"text_encoders.t5xxl.transformer.encoder.block.0.layer.1.DenseReluDense.wi_0.lora_up.weight","[10240, 32]","torch.float16"
"text_encoders.t5xxl.transformer.encoder.block.0.layer.1.DenseReluDense.wi_1.lora_down.weight","[32, 4096]","torch.float16"
"text_encoders.t5xxl.transformer.encoder.block.0.layer.1.DenseReluDense.wi_1.lora_up.weight","[10240, 32]","torch.float16"
"text_encoders.t5xxl.transformer.encoder.block.0.layer.1.DenseReluDense.wo.lora_down.weight","[32, 10240]","torch.float16"
"text_encoders.t5xxl.transformer.encoder.block.0.layer.1.DenseReluDense.wo.lora_up.weight","[4096, 32]","torch.float16"
"text_encoders.t5xxl.transformer.encoder.block.0.layer.1.layer_norm.diff","[4096]","torch.float16"

Looks like users train Flux with T5, or is it another model?
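
For reference, an extract like the one above can be produced with a short script along these lines (the file path here is a hypothetical placeholder):

from safetensors import safe_open

# print key, shape, and dtype for every tensor stored in the LoRA file
with safe_open("community_flux_lora.safetensors", framework="pt") as f:
    for key in f.keys():
        tensor = f.get_tensor(key)
        print(f'"{key}","{list(tensor.shape)}","{tensor.dtype}"')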

@sayakpaul
Member

Nice, it's the first time I've seen a T5-trained LoRA. Is it Flux? By the way, we are deviating from the original thread, so let's please move this discussion to a new one (as a feature request) and I will add support.

@christopher5106
Author

#10862 yes
