Fix bugs when training SDXL #234
Merged · 5 commits · Jan 15, 2025
131 changes: 92 additions & 39 deletions lycoris/kohya.py
@@ -63,6 +63,12 @@ def create_network(
    rs_lora = str_bool(kwargs.get("rs_lora", False))
    unbalanced_factorization = str_bool(kwargs.get("unbalanced_factorization", False))
    train_t5xxl = str_bool(kwargs.get("train_t5xxl", False))
    # LoRA+: optional learning-rate ratios for the lora_up weights (None = disabled)
    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

    if unbalanced_factorization:
        logger.info("Unbalanced factorization for LoKr is enabled")
@@ -157,22 +163,23 @@ def create_network_from_weights(
        if lora_name in unet_loras:
            unet_loras[lora_name] = modules

    if isinstance(text_encoder, list):
        text_encoders = text_encoder
        use_index = True
    else:
        text_encoders = [text_encoder]
        use_index = False

    for idx, te in enumerate(text_encoders):
        if use_index:
            prefix = f"{LycorisNetworkKohya.LORA_PREFIX_TEXT_ENCODER}{idx+1}"
        else:
            prefix = LycorisNetworkKohya.LORA_PREFIX_TEXT_ENCODER
        for name, modules in te.named_modules():
            lora_name = f"{prefix}_{name}".replace(".", "_")
            if lora_name in te_loras:
                te_loras[lora_name] = modules

    if text_encoder:
        if isinstance(text_encoder, list):
            text_encoders = text_encoder
            use_index = True
        else:
            text_encoders = [text_encoder]
            use_index = False

        for idx, te in enumerate(text_encoders):
            if use_index:
                prefix = f"{LycorisNetworkKohya.LORA_PREFIX_TEXT_ENCODER}{idx+1}"
            else:
                prefix = LycorisNetworkKohya.LORA_PREFIX_TEXT_ENCODER
            for name, modules in te.named_modules():
                lora_name = f"{prefix}_{name}".replace(".", "_")
                if lora_name in te_loras:
                    te_loras[lora_name] = modules

    original_level = logger.level
    logger.setLevel(logging.ERROR)
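A short sketch of the argument shapes the rewritten block now accepts for text_encoder. The prefix value "lora_te" is an assumption mirroring LycorisNetworkKohya.LORA_PREFIX_TEXT_ENCODER, and te_prefixes is a hypothetical helper written only for this note, not part of the PR:

LORA_PREFIX_TEXT_ENCODER = "lora_te"  # assumed value of LycorisNetworkKohya.LORA_PREFIX_TEXT_ENCODER

def te_prefixes(text_encoder):
    # None / empty -> text-encoder handling is skipped entirely (UNet-only weights)
    if not text_encoder:
        return []
    use_index = isinstance(text_encoder, list)
    text_encoders = text_encoder if use_index else [text_encoder]
    return [
        f"{LORA_PREFIX_TEXT_ENCODER}{i + 1}" if use_index else LORA_PREFIX_TEXT_ENCODER
        for i in range(len(text_encoders))
    ]

# Strings stand in for encoder objects purely for illustration.
print(te_prefixes(None))                  # []
print(te_prefixes("clip_l"))              # ['lora_te']               (single text encoder)
print(te_prefixes(["clip_l", "clip_g"]))  # ['lora_te1', 'lora_te2']  (SDXL pair)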
@@ -192,14 +199,16 @@ def create_network_from_weights(
logger.info(f"{len(network.unet_loras)} Modules Loaded")

logger.info("Loading TE Modules from state dict...")
for lora_name, orig_modules in te_loras.items():
if orig_modules is None:
continue
lyco_type, params = get_module(weights_sd, lora_name)
module = make_module(lyco_type, params, lora_name, orig_modules)
if module is not None:
network.text_encoder_loras.append(module)
logger.info(f"{len(network.text_encoder_loras)} Modules Loaded")

if text_encoder:
for lora_name, orig_modules in te_loras.items():
if orig_modules is None:
continue
lyco_type, params = get_module(weights_sd, lora_name)
module = make_module(lyco_type, params, lora_name, orig_modules)
if module is not None:
network.text_encoder_loras.append(module)
logger.info(f"{len(network.text_encoder_loras)} Modules Loaded")

for lora in network.unet_loras + network.text_encoder_loras:
lora.multiplier = multiplier
@@ -292,6 +301,11 @@ def __init__(
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.train_t5xxl = train_t5xxl

        # Initialize the LoRA+ ratio attributes (stay None until set_loraplus_lr_ratio is called)
        self.loraplus_lr_ratio = None
        self.loraplus_unet_lr_ratio = None
        self.loraplus_text_encoder_lr_ratio = None

        if not self.ENABLE_CONV:
            conv_lora_dim = 0
@@ -607,30 +621,69 @@ def apply_max_norm_regularization(self, max_norm_value, device):

        return key_scaled, sum(norms) / len(norms), max(norms)

    def prepare_optimizer_params(self, text_encoder_lr=None, unet_lr: float = 1e-4, learning_rate=None):
        def enumerate_params(loras):
            params = []
            for lora in loras:
                params.extend(lora.parameters())
            return params
    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
        self.loraplus_lr_ratio = loraplus_lr_ratio
        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio

        logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
        logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")

    def prepare_optimizer_params(self, text_encoder_lr=None, unet_lr: float = 1e-4, learning_rate=None):
        self.requires_grad_(True)
        # fall back to the base learning_rate when a per-module LR is not given
        default_lr = learning_rate

        all_params = []
        lr_descriptions = []

        def assemble_params(loras, lr, ratio):
            param_groups = {"lora": {}, "plus": {}}
            for lora in loras:
                for name, param in lora.named_parameters():
                    if ratio is not None and "lora_up" in name:
                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                    else:
                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param

            params = []
            descriptions = []
            for key in param_groups.keys():
                param_data = {"params": param_groups[key].values()}

                if len(param_data["params"]) == 0:
                    continue

                if lr is not None:
                    if key == "plus":
                        param_data["lr"] = lr * ratio
                    else:
                        param_data["lr"] = lr

                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
                    logger.info("NO LR skipping!")
                    continue

                params.append(param_data)
                descriptions.append("plus" if key == "plus" else "")

            return params, descriptions

        if self.text_encoder_loras:
            param_data = {"params": enumerate_params(self.text_encoder_loras)}
            if text_encoder_lr is not None:
                param_data["lr"] = text_encoder_lr
            all_params.append(param_data)
            lr_descriptions.append("text_encoder")
            params, descriptions = assemble_params(
                self.text_encoder_loras,
                text_encoder_lr if text_encoder_lr is not None else default_lr,
                self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
            )
            all_params.extend(params)
            lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions])

        if self.unet_loras:
            param_data = {"params": enumerate_params(self.unet_loras)}
            if unet_lr is not None:
                param_data["lr"] = unet_lr
            all_params.append(param_data)
            lr_descriptions.append("unet")
            params, descriptions = assemble_params(
                self.unet_loras,
                unet_lr if unet_lr is not None else default_lr,
                self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
            )
            all_params.extend(params)
            lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])

        return all_params, lr_descriptions
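To show how the new return value would be consumed, a hedged sketch of trainer-side usage follows; the network variable and the learning rates are assumptions for illustration, not code from this PR:

import torch

# Assume `network` is a LycorisNetworkKohya on which set_loraplus_lr_ratio() was already called.
params, lr_descriptions = network.prepare_optimizer_params(
    text_encoder_lr=5e-5, unet_lr=1e-4, learning_rate=1e-4
)

# Each entry in `params` is an optimizer param-group dict; the "plus" groups
# (lora_up weights) carry lr * ratio, everything else keeps its base LR.
optimizer = torch.optim.AdamW(params, lr=1e-4)

for desc, group in zip(lr_descriptions, optimizer.param_groups):
    print(desc, group["lr"])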
