Cannot restart training after training tenc 2 AND using fused_backward_pass #1369

Open
araleza opened this issue Jun 11, 2024 · 14 comments

@araleza

araleza commented Jun 11, 2024

If you finetune SDXL base with:

--train_text_encoder --learning_rate_te1 1e-10 --learning_rate_te2 1e-10 --fused_backward_pass

Then it will train fine. But if you stop training and restart from a saved checkpoint (e.g. the <whatever>-step00001000.safetensors file), you get this error message:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1280, 1280]], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

This doesn't happen if you only train te1 and the unet. It also only happens when you use --fused_backward_pass.
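For context on why an in-place modification can show up here at all: as I understand it, --fused_backward_pass applies each parameter's optimizer update (and frees its gradient) from a hook that fires as soon as that gradient has been accumulated, i.e. the weights are modified in place while backward() is still running. The sketch below is only my own minimal illustration of that pattern (plain SGD step, made-up sizes), not the actual sd-scripts implementation, and I don't know if this is exactly where the version mismatch comes from, but it shows why the combination with gradient checkpointing is delicate:

    import torch
    from torch import nn

    # Minimal sketch of the fused-backward-pass idea (illustration only, not sd-scripts code):
    # each parameter is updated as soon as its own gradient has been accumulated,
    # so the weights change in place while loss.backward() is still running.
    model = nn.Linear(1280, 1280)

    def fused_step(param: torch.Tensor) -> None:
        with torch.no_grad():
            param.add_(param.grad, alpha=-1e-10)  # stand-in for the real optimizer step
        param.grad = None  # free the gradient immediately to save memory

    for p in model.parameters():
        p.register_post_accumulate_grad_hook(fused_step)  # available in PyTorch 2.1+

    x = torch.randn(4, 1280)
    loss = model(x).pow(2).mean()
    loss.backward()  # the parameter updates happen inside this call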

Full call stack:

Traceback (most recent call last):
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/./sdxl_train.py", line 944, in <module>
    train(args)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/./sdxl_train.py", line 733, in train
    accelerator.backward(loss)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1905, in backward
    loss.backward(**kwargs)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 319, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

(I also mentioned this bug back when the original pull request occurred: see #1259)

@kohya-ss
Owner

Thank you for opening this! Unfortunately, I cannot reproduce this issue. I think it may be caused by a difference in PyTorch versions. Which version are you using? I'm using 2.1.2.

@araleza
Author

araleza commented Jun 13, 2024

Sorry, I should have provided reproduction steps for this. Here they are now.

The bug reproduces even on a clean checkout of the dev branch. (I haven't used main recently.)

git clone https://github.com/kohya-ss/sd-scripts
cd sd-scripts
git switch dev
python -m venv venv
source venv/bin/activate
pip install torch torchvision -r requirements.txt
pip install xformers

(I have to install xformers with a second pip command, rather than adding it to the first 'pip install' line, due to version incompatibilities with torch. I still get "torchvision 0.18.1 requires torch==2.3.1, but you have torch 2.3.0 which is incompatible", but this doesn't seem to cause a problem when actually running.)

You asked about my torch version. Here it is (from pip list), along with my xformers version:

torch                     2.3.0
xformers                  0.0.26.post1

Then I run training:
accelerate launch --num_cpu_threads_per_process=2 "./sdxl_train.py" \
  --pretrained_model_name_or_path="/home/ara/Documents/sdxl/sd_xl_base_1.0.safetensors" \
  --enable_bucket --min_bucket_reso=64 --max_bucket_reso=1024 \
  --train_data_dir="/home/ara/Documents/sdxl/img" --resolution="1024,1024" \
  --output_dir="/home/ara/Documents/sdxl/dreambooth" --logging_dir="/home/ara/Documents/sdxl/log" \
  --save_model_as=safetensors --vae="/home/ara/Documents/sdxl/sdxl_vae.safetensors" \
  --output_name="earthscape" --lr_scheduler_num_cycles="20000" \
  --max_token_length=150 --max_data_loader_n_workers="0" \
  --lr_scheduler="constant_with_warmup" --lr_warmup_steps="200" --max_train_steps="16000" \
  --caption_extension=".txt" --optimizer_type="Adafactor" \
  --optimizer_args scale_parameter=False relative_step=False warmup_init=False \
  --max_data_loader_n_workers="0" --max_token_length=150 \
  --bucket_reso_steps=32 --save_every_n_steps="10" --save_last_n_steps="20" \
  --min_snr_gamma=5 --gradient_checkpointing --xformers --bucket_no_upscale \
  --noise_offset=0.0357 --sample_sampler=k_dpm_2 --fused_backward_pass --cache_latents \
  --train_batch_size="4" --train_text_encoder --learning_rate_te1 1e-10 --learning_rate="2e-7"

I set up this training to write a .safetensors file almost immediately, at step 10. After step 10, I stop training, and run the same command line again, but this time changing sd_xl_base_1.0.safetensors to dreambooth/earthscape-step00000010.safetensors. This produces the error message:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1280, 1280]], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

This should hopefully reproduce the issue for you. Thank you for your attention so far. :)

If I re-run these steps, but instead of --fused_backward_pass I use --full_bf16 --mixed_precision="bf16", then training restarts with no error message.

If I re-run the steps, keeping --fused_backward_pass but this time changing --learning_rate_te2 1e-10 to --learning_rate_te2 0 (stopping tenc 2 from training), then the training process is again able to restart with no error message. So the error only occurs when both --fused_backward_pass is enabled and tenc 2 is being trained.

This seems to be an important bug to fix for SDXL training, as I am seeing amazing results from training tenc 2 at the very low rate of 1e-10. That learning rate has not been usable with bf16 training, as it is below the precision that bf16 can represent. But with the fp32 training made possible by --fused_backward_pass and tenc 2 being trained, I see impressive image quality changes. I just cannot restart training if I stop!
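Just to illustrate the precision gap I mean (this is only about what bf16 can represent, not a claim about what Adafactor does internally):

    import torch

    # bf16 keeps 8 mantissa bits, so its relative precision is about 8e-3;
    # fp32 keeps 23 bits, about 1.2e-7.
    print(torch.finfo(torch.bfloat16).eps)  # 0.0078125
    print(torch.finfo(torch.float32).eps)   # 1.1920928955078125e-07

    # A tiny update simply vanishes when the weight is stored in bf16:
    w = torch.tensor(1.0, dtype=torch.bfloat16)
    print(bool(w + 1e-10 == w))  # True: the update is lost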

By the way, I saw a new warning now that I've reinstalled sd-scripts to make these reproduction instructions:

sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  return F.conv2d(input, weight, bias, self.stride,

I didn't have that warning before, so I don't think it's relevant to this bug report.

@araleza
Author

araleza commented Jun 13, 2024

Update to my long reproduction steps above:

If I remove --save_model_as=safetensors then training is able to restart! (If I run without that option from the start)

So to reproduce the issue, all three of these options need to be set:

--save_model_as=safetensors
--fused_backward_pass
--learning_rate_te2 1e-10

Being able to use .ckpt instead of .safetensors to continue training is great news, as it provides a workaround for restarting training even before this bug is fixed.

Edit: I just got the error message again, even with .ckpt being used. :-/ Not sure why it worked for that one test run, but it seems that the bug does not need safetensors after all.

@kohya-ss
Owner

Thank you for the detailed steps! The dev branch recommends torch==2.1.2 and xformers==0.0.23.post1, as written in README.md. So I may need a new venv to reproduce the issue.

In addition, I don't think the format of the file (.ckpt or .safetensors) affects the issue. So the issue may depend on something special...

@araleza
Author

araleza commented Jun 13, 2024

Okay, so I took my build and installed those versions:

pip install torch==2.1.2 xformers==0.0.23.post1 torchvision

which got me:

torch                     2.1.2
xformers                  0.0.23.post1

(I had to include torchvision on that installation line to get a version that works with torch 2.1.2.)

I made sure to regenerate a fresh .ckpt file, and didn't pick up the one that I'd already made with the later torch version I previously had. But the same error message still reproduces, even with a new .ckpt being written out by sd-scripts and torch/xformers set to these older versions.

@araleza
Author

araleza commented Jun 17, 2024

I started trying to get more information about this. Since the error message suggested adding torch.autograd.set_detect_anomaly(True), I did that, and got the following debug trace associated with the error, in case it rings any bells for anyone:

warnings.warn(
/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in AddmmBackward0. Traceback of forward call that caused the error:
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
return user_fn(self, *args)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 271, in backward
outputs = ctx.run_function(*detached_inputs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 372, in forward
hidden_states, attn_weights = self.self_attn(
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 262, in forward
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/__init__.py:251: UserWarning:

Previous calculation was induced by CheckpointFunctionBackward. Traceback of forward call that induced the previous calculation:
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/main.py", line 39, in
cli.main()
File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
run()
File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
runpy.run_path(target, run_name="main")
File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
return _run_module_code(code, init_globals, run_name,
File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
_run_code(code, mod_globals, init_globals,
File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
exec(code, run_globals)
File "sdxl_train.py", line 963, in
train(args)
File "sdxl_train.py", line 659, in train
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
File "/home/ara/m.2/Dev/sdxl/sd-scripts/library/train_util.py", line 4701, in get_hidden_states_sdxl
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 1207, in forward
text_outputs = self.text_model(
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 703, in forward
encoder_outputs = self.encoder(
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 622, in forward
layer_outputs = self._gradient_checkpointing_func(
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 451, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:121.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in CheckpointFunctionBackward. Traceback of forward call that caused the error:
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/main.py", line 39, in
cli.main()
File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
run()
File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
runpy.run_path(target, run_name="main")
File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
return _run_module_code(code, init_globals, run_name,
File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
_run_code(code, mod_globals, init_globals,
File "/home/ara/.vscode/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
exec(code, run_globals)
File "sdxl_train.py", line 963, in
train(args)
File "sdxl_train.py", line 659, in train
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
File "/home/ara/m.2/Dev/sdxl/sd-scripts/library/train_util.py", line 4701, in get_hidden_states_sdxl
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 1207, in forward
text_outputs = self.text_model(
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 703, in forward
encoder_outputs = self.encoder(
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 622, in forward
layer_outputs = self._gradient_checkpointing_func(
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 451, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/home/ara/m.2/Dev/sdxl/sd-scripts/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass

Tenc2 is mentioned in that trace:

enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)

Any thoughts?

@araleza
Author

araleza commented Jun 17, 2024

Okay, I haven't fixed it, but I've now found a workaround that allows training of tenc2 to continue, and it also indicates roughly where the trouble is likely to be coming from.

The workaround is to edit the torch library file venv/lib/python3.10/site-packages/torch/utils/checkpoint.py, which is in that subdirectory of your sd-scripts checkout, assuming you're using a venv (which you probably should be). Add the line indicated here:

[screenshot: the edit to torch/utils/checkpoint.py showing the added line]

and you can successfully continue training a .safetensors or .ckpt that was written out by sd-scripts while training.

The issue seems to be related to this warning, which is seen even when using the recommended torch version 2.1.2:

sd-scripts/venv/lib/python3.10/site-packages/torch/utils/checkpoint.py:430: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.

I'll let someone who actually knows what they're doing with this figure out the correct fix. And I'm still not sure why this issue doesn't trigger when starting training from sd_xl_base_1.0.safetensors, but does trigger when starting from a .safetensors written out by sd-scripts.
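For reference, the warning is about the default at the torch.utils.checkpoint.checkpoint() call site itself. Here is a minimal sketch of passing the flag explicitly (block and x are placeholders I made up, not anything from sd-scripts):

    import torch
    from torch.utils.checkpoint import checkpoint

    def block(x):
        # stand-in for a transformer layer that would normally be checkpointed
        return torch.nn.functional.gelu(x @ x.T)

    x = torch.randn(8, 8, requires_grad=True)

    y_reentrant = checkpoint(block, x, use_reentrant=True)       # current default behaviour
    y_non_reentrant = checkpoint(block, x, use_reentrant=False)  # what the warning recommends
    (y_reentrant.sum() + y_non_reentrant.sum()).backward()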

@Jannchie

I encountered a similar issue. When enabling optimizer_type = "AdamW" and train_text_encoder, a warning about use_reentrant appears. At this point, if I try to save the training state, the training process freezes.

@Jannchie

Add {"use_reentrant": True} to sdxl_train.py could fix the use_reentrant problem.. But it still freezes on saved states. May be it is another issue.

if args.gradient_checkpointing:
    text_encoder1.gradient_checkpointing_enable({"use_reentrant": False})
    text_encoder2.gradient_checkpointing_enable({"use_reentrant": False})
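For what it's worth, more recent transformers versions also accept this as a named argument, e.g. text_encoder2.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}); either way the dict gets forwarded to torch.utils.checkpoint.checkpoint. I haven't checked exactly which transformers version introduced the keyword.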

@araleza
Author

araleza commented Jun 25, 2024

@Jannchie, I didn't know use_reentrant could be passed into gradient_checkpointing_enable() like that. That's great news, as it lets my issue be fixed with an sd-scripts change, rather than by hacking the library function like I was doing.

As for your hang, is it specific to AdamW? I've only been using Adafactor. Is there some advantage to using AdamW by the way? I haven't tried that.

@Jannchie

I’m a beginner and not quite sure about the specific effects, but I am attempting to replicate the settings from https://huggingface.co/cagliostrolab/animagine-xl-3.1.

Regarding the freeze issue, by referring to this issue, I found that specifying the saving format as safetensors (instead of the default diffusers format) can resolve the problem.

@TopSalad3530

TopSalad3530 commented Jun 28, 2024

I've run into the same problem. In my case, the issue only started to appear after I switched to save_precision=float from FP16. Converting the problematic checkpoint down to FP16 seemed to resolve the issue for me. I don't think this necessarily has anything to do with the precision itself, however: it might just be that the conversion process happened to clear whatever problematic metadata ("version") the tensors carried, as a side effect.
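In case it is useful to anyone, a minimal sketch of that kind of downcast using the safetensors library (the file names are placeholders, this is not the exact tool I used, and any metadata sd-scripts stored in the file is not copied over):

    import torch
    from safetensors.torch import load_file, save_file

    # Load the float32 checkpoint, downcast float tensors to fp16, and save a copy.
    state = load_file("earthscape-step00000010.safetensors")
    state = {k: (v.half() if v.dtype == torch.float32 else v) for k, v in state.items()}
    save_file(state, "earthscape-step00000010-fp16.safetensors")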

I did also try the {"use_reentrant": False} option, but for some reason it increased VRAM consumption so much that training was no longer possible on 24GB, even at batch size 1, so I don't believe it's a one-size-fits-all solution to this problem. Mistake; see below.

@araleza
Author

araleza commented Jun 28, 2024

Interesting. I'm also on 24GB, and batch size 4 works great for me. Have you:

  1. passed in --gradient_checkpointing? If I forget to pass that in, it's usually what makes me run out of memory unexpectedly.
  2. passed in --cache_latents, so it doesn't have to run the VAE repeatedly? I keep features like color augmentation / random crop switched off to allow this, but you can keep flip augmentation on.

@TopSalad3530

My bad -- it turned out that I had based my modifications on the sdxl_train.py from main, which didn't have fused_backward_pass at all, instead of dev. I tried again and this time everything went fine.
