deepspeed train flux1 dreambooth lora can not save model #9393

Closed
ldtgodlike opened this issue Sep 9, 2024 · 16 comments · Fixed by #9473
Labels
bug Something isn't working

Comments

@ldtgodlike

Describe the bug

When I run the script train_dreambooth_lora_flux.py, it raises ValueError: unexpected save model: <class 'deepspeed.runtime.engine.DeepSpeedEngine'>. Is there a bug in save_model_hook?

Reproduction

accelerate launch train_dreambooth_lora_flux_custom.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --instance_prompt="bedroom, YF_CN style" \
  --resolution=1024 \
  --train_batch_size=1 \
  --guidance_scale=1 \
  --gradient_accumulation_steps=4 \
  --optimizer="prodigy" \
  --learning_rate=1. \
  --report_to="tensorboard" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_train_epochs=30 \
  --validation_prompt="bedroom, YF_CN style" \
  --validation_epochs=80 \
  --checkpointing_steps=500 \
  --seed="0" \
  --gradient_checkpointing \
  --use_8bit_adam \
  --rank=4

Logs

No response

System Info

torch==2.3.1
accelerate==0.34.2
deepspeed==0.15.1+8ac42ed7
diffusers==0.31.0.dev0

default_config.yaml is as follows:

compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Who can help?

@sayakpaul

@ldtgodlike ldtgodlike added the bug Something isn't working label Sep 9, 2024
@sayakpaul
Member

Refer to huggingface/accelerate#2787 to get an idea of the adjustments needed to make it work.

@ldtgodlike
Author

ldtgodlike commented Sep 11, 2024

Refer to huggingface/accelerate#2787 to get an idea of the adjustments needed to make it work.

Replacing

if isinstance(model, type(unwrap_model(transformer)))

with

if isinstance(unwrap_model(model), type(unwrap_model(transformer)))

lets the checkpoint be saved.
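
For reference, a minimal sketch of what that change amounts to inside save_model_hook. The surrounding names (models, transformer, text_encoder_one, unwrap_model, get_peft_model_state_dict) are the ones used in the training script; the hook body here is an approximation, not the merged fix:

for model in models:
    # Under DeepSpeed, the entries in `models` are DeepSpeedEngine wrappers, so the
    # type check has to be done on the unwrapped module.
    if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
    elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
        text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
    else:
        raise ValueError(f"unexpected save model: {model.__class__}")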
However, this script still has many errors, including but not limited to: the LoRA trained with DeepSpeed cannot be loaded (it works only when trained with plain PyTorch), activating train_text_encoder, and preparing multiple models with accelerate. I also don't know why the index starts from 1 instead of 0, which raises a list index out of range error:

params_to_optimize[1]["lr"] = args.learning_rate
params_to_optimize[2]["lr"] = args.learning_rate
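
For context, a sketch of how params_to_optimize appears to be built when --train_text_encoder is set (the construction below is an assumption based on the script's structure, not a verbatim copy): Flux trains only one CLIP text encoder, so the list has two entries and index 2 is out of range; the [1]/[2] indexing looks carried over from the SD3 script, which trains two text encoders.

transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
if args.train_text_encoder:
    text_parameters_one_with_lr = {
        "params": text_lora_parameters_one,
        "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
    }
    # Only two parameter groups: index 0 (transformer) and index 1 (CLIP text encoder).
    params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
else:
    params_to_optimize = [transformer_parameters_with_lr]

# So in the Prodigy branch quoted above:
# params_to_optimize[1]["lr"] = args.learning_rate   # CLIP text encoder group
# params_to_optimize[2]["lr"] = args.learning_rate   # IndexError: list index out of range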

@sayakpaul
(two error screenshots attached)

@sayakpaul
Member

However, this script still has many errors, including but not limited to: the LoRA trained with DeepSpeed cannot be loaded (it works only when trained with plain PyTorch), activating train_text_encoder, and preparing multiple models with accelerate. I also don't know why the index starts from 1 instead of 0, which raises a list index out of range error:

It would be more helpful if you provided more information on how you're launching the training experiments, etc. We already test whether we're able to resume training:

def test_dreambooth_checkpointing(self):

activating train_text_encoder, and preparing multiple models with accelerate. I also don't know why the index starts from 1 instead of 0, which raises a list index out of range error:

This I don't understand. Please elaborate so that we can provide further suggestions.

@ldtgodlike
Author

ldtgodlike commented Sep 11, 2024

It seems that --train_text_encoder is not covered in:
diffusers/examples/dreambooth/test_dreambooth_flux.py

My script is as follows:

accelerate launch train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --instance_prompt="bedroom, YF_CN style" \
  --resolution=1024 \
  --train_batch_size=1 \
  --guidance_scale=1 \
  --gradient_accumulation_steps=4 \
  --optimizer="prodigy" \
  --learning_rate=1. \
  --report_to="tensorboard" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_train_epochs=20 \
  --validation_prompt="bedroom, YF_CN style" \
  --validation_epochs=80 \
  --checkpointing_steps=10 \
  --seed="0" \
  --gradient_checkpointing \
  --use_8bit_adam \
  --dataloader_num_workers=1 \
  --train_text_encoder \
  --rank=4

@sayakpaul
and the inference script:

import numpy as np
import torch
from diffusers import FluxPipeline
from PIL import Image

# Load the base model and the trained LoRA weights.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)

lora_path = "lora_path"
pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
pipe.fuse_lora()
pipe.to(torch.device("cuda"))

width = 1024
height = 800

prompt = "a bedroom with a bed, 2 night stand and a wardrobe,a bay window on the right, YF_CN style"

# Generate ten images with fixed seeds and stack them vertically for inspection.
images = []
for i in range(10):
    generator = torch.manual_seed(i)
    image = pipe(
        prompt=prompt, num_inference_steps=20, width=width, height=height, generator=generator
    ).images[0]
    images.append(np.asarray(image))
image = Image.fromarray(np.vstack(images))
image.save("test.jpg")

pytorch_lora_weights.zip
deepspeed config:

compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

@sayakpaul
Member

Okay, so does it fail with train_text_encoder, or does it fail without train_text_encoder as well?

@ldtgodlike
Author

Okay, so does it fail with train_text_encoder, or does it fail without train_text_encoder as well?

It only fails with train_text_encoder.

@sayakpaul
Member

Okay that is helpful.

The error you posted in #9393 (comment) seems easy to solve. We should just filter out the "module" keys in the state dict, and it should work. Can you try that out first?

What errors do you see in the text encoder training?

@ldtgodlike
Author

Errors caused by accelerate and deepspeed, such as:
(error screenshot attached)

@sayakpaul
Member

Oh that I am not sure about then. Ccing @muellerzr for advice.

@ldtgodlike
Author

Okay that is helpful.

The error you posted in #9393 (comment) seems easy to solve. We should just filter out the "module" keys in the state dict, and it should work. Can you try that out first?

For the LoRA trained with DeepSpeed, I filtered out "module." from the keys, and it works the same as one trained without DeepSpeed:

from safetensors.torch import save_file
from diffusers import FluxPipeline

def convert(input_dir, save_path):
    # Keep only the transformer LoRA weights and strip the DeepSpeed "module." prefix from the keys.
    lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
    transformer_state_dict = {
        k.replace("module.", ""): v for k, v in lora_state_dict.items() if k.startswith("transformer.")
    }
    save_file(transformer_state_dict, save_path)

input_dir = "pytorch_lora_weights.safetensors"
save_path = "out.safetensors"
convert(input_dir, save_path)
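
After the conversion, the cleaned file should load as usual with the pipeline from the inference script above, for example pipe.load_lora_weights(".", weight_name="out.safetensors") if out.safetensors was written to the current directory.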

@sayakpaul
Member

Yeah, of course, that is why I suggested it. Usually, you would always want to call unwrap_model() here

transformer_lora_layers_to_save = get_peft_model_state_dict(model)

and here

text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
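
i.e., roughly the following (a sketch of the suggested change; unwrap_model is the helper defined in the training script):

# Extract the PEFT state dict from the unwrapped module so the keys do not carry
# the DeepSpeed "module." prefix.
transformer_lora_layers_to_save = get_peft_model_state_dict(unwrap_model(model))
# and likewise for the text encoder:
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(unwrap_model(model))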


github-actions bot commented Oct 9, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Oct 9, 2024
@sayakpaul
Member

Is this still a problem?

@ldtgodlike
Author

No, it's not a problem anymore.

@github-actions github-actions bot removed the stale Issues that haven't received updates label Oct 12, 2024
@LianShuaiLong

Same error, and I have tried this modification, replacing

if isinstance(model, type(unwrap_model(transformer)))

with

if isinstance(unwrap_model(model), type(unwrap_model(transformer)))

but it does not work.
