Commit

initial ORPO implementation

feffy380 committed May 13, 2024
1 parent 4e5db9a commit 2accb06

Showing 3 changed files with 74 additions and 15 deletions.
library/custom_train_functions.py (2 additions, 1 deletion)
@@ -476,6 +476,7 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
     if noise_offset is None:
         return noise
     if adaptive_noise_scale is not None:
+        raise NotImplementedError
         # latent shape: (batch_size, channels, height, width)
         # abs mean value for each channel
         latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
@@ -484,7 +485,7 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
         noise_offset = noise_offset + adaptive_noise_scale * latent_mean
         noise_offset = torch.clamp(noise_offset, 0.0, None)  # in case adaptive_noise_scale is negative

-    noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
+    noise = noise + noise_offset * torch.randn((latents.shape[0] // 2, latents.shape[1], 1, 1), device=latents.device).repeat(2, 1, 1, 1)
     return noise
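ORPO trains on (preferred, non-preferred) image pairs packed into one batch, so the offset noise above is drawn once per pair and tiled with `repeat(2, 1, 1, 1)`; the preferred and non-preferred halves of the batch then share identical offsets. A minimal standalone sketch of that shape trick (toy shapes, not the repository's code):

```python
import torch

batch, channels = 4, 4  # batch holds 2 pairs, laid out [w0, w1, l0, l1]
# draw one offset per pair, then tile along the batch dim: [a, b] -> [a, b, a, b]
offset = torch.randn((batch // 2, channels, 1, 1)).repeat(2, 1, 1, 1)
# each preferred sample shares its offset with its non-preferred partner
assert torch.equal(offset[0], offset[2]) and torch.equal(offset[1], offset[3])
```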
library/train_util.py (52 additions, 11 deletions)
@@ -146,8 +146,11 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,
         self.caption: str = caption
         self.is_reg: bool = is_reg
         self.absolute_path: str = absolute_path
+        self.absolute_path_l: str = None
         self.image_size: Tuple[int, int] = None
+        self.image_size_l: Tuple[int, int] = None
         self.resized_size: Tuple[int, int] = None
+        self.resized_size_l: Tuple[int, int] = None
         self.bucket_reso: Tuple[int, int] = None
         self.latents: torch.Tensor = None
         self.latents_flipped: torch.Tensor = None
Expand All @@ -156,6 +159,7 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,
self.latents_crop_ltrb: Tuple[int, int] = None # crop left top right bottom in original pixel size, not latents size
self.cond_img_path: str = None
self.image: Optional[Image.Image] = None # optional, original PIL Image
self.image_l: Optional[Image.Image] = None
# SDXL, optional
self.text_encoder_outputs_npz: Optional[str] = None
self.text_encoder_outputs1: Optional[torch.Tensor] = None
@@ -912,6 +916,8 @@ def make_buckets(self):
         for info in tqdm(self.image_data.values()):
             if info.image_size is None:
                 info.image_size = self.get_image_size(info.absolute_path)
+            if info.absolute_path_l and info.image_size_l is None:
+                info.image_size_l = self.get_image_size(info.absolute_path_l)

         if self.enable_bucket:
             logger.info("make buckets")
@@ -941,6 +947,11 @@ def make_buckets(self):
                 image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(
                     image_width, image_height
                 )
+                if image_info.absolute_path_l is not None:
+                    bucket_reso_l, image_info.resized_size_l, _ = self.bucket_manager.select_bucket(
+                        *image_info.image_size_l
+                    )
+                    assert bucket_reso_l == image_info.bucket_reso, (f"{image_info.absolute_path} buckets don't match", bucket_reso_l, image_info.bucket_reso)

                 # logger.info(image_info.image_key, image_info.bucket_reso)
                 img_ar_errors.append(abs(ar_error))
@@ -1092,18 +1103,26 @@ def __len__(self):
     def __getitem__(self, idx):
         image_infos = self.batches[idx]
         images = []
+        images_l = []
         for info in image_infos:
             image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
+            image_l = load_image(info.absolute_path_l) if info.image_l is None else np.array(info.image_l, np.uint8)
             # TODO: image metadata may be corrupted, so the bucket assigned from metadata can disagree with the actual image size; a check should be added
             image, original_size, crop_ltrb = trim_and_resize_if_required(self.random_crop, image, info.bucket_reso, info.resized_size)
             image = IMAGE_TRANSFORMS(image)
             images.append(image)

+            # prepare non-preferred image
+            image_l, _, _ = trim_and_resize_if_required(self.random_crop, image_l, info.bucket_reso, info.resized_size_l)
+            image_l = IMAGE_TRANSFORMS(image_l)
+            images_l.append(image_l)
+
             info.latents_original_size = original_size
             info.latents_crop_ltrb = crop_ltrb

         img_tensors = torch.stack(images, dim=0)
-        return image_infos, img_tensors
+        img_tensors_l = torch.stack(images_l, dim=0)
+        return image_infos, (img_tensors, img_tensors_l)

 def custom_collate(batch):
     return batch

@@ -1696,6 +1715,8 @@ def load_dreambooth_dir(subset: DreamBoothSubset):

             for img_path, caption, size in zip(img_paths, captions, sizes):
                 info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
+                l_path = pathlib.Path(info.absolute_path)
+                info.absolute_path_l = str((l_path.parent.with_name(l_path.parent.name + "_l") / l_path.name).with_suffix(".png"))
                 if size is not None:
                     info.image_size = size
                 if subset.is_reg:
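The non-preferred ("loser") image is resolved by convention: for `.../img/foo.jpg` the loader looks in the sibling directory `.../img_l/foo.png`, i.e. the parent directory name plus `_l`, with the extension forced to `.png`. A standalone sketch of the same derivation (the example path is hypothetical):

```python
import pathlib

p = pathlib.Path("/data/train/1_concept/foo.jpg")
p_l = (p.parent.with_name(p.parent.name + "_l") / p.name).with_suffix(".png")
print(p_l)  # /data/train/1_concept_l/foo.png
```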
@@ -2214,7 +2235,7 @@ def disable_token_padding(self):


 def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
-    expected_latents_size = (reso[1] // 8, reso[0] // 8)  # note that bucket_reso is WxH
+    expected_latents_size = (8, reso[1] // 8, reso[0] // 8)  # note that bucket_reso is WxH; 8 channels = 4 + 4 (preferred + non-preferred)

     if not os.path.exists(npz_path):
         return False
@@ -2226,13 +2247,13 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
         return False
     if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz:  # old ver?
         return False
-    if npz["latents"].shape[1:3] != expected_latents_size:
+    if npz["latents"].shape[:3] != expected_latents_size:
         return False

     if flip_aug:
         if "latents_flipped" not in npz:
             return False
-        if npz["latents_flipped"].shape[1:3] != expected_latents_size:
+        if npz["latents_flipped"].shape[:3] != expected_latents_size:
             return False

     return True
@@ -2498,21 +2519,30 @@ def cache_batch_latents(
     latents_flipped is also set if flip_aug is True
     latents_original_size and latents_crop_ltrb are also set
     """
+    img_tensors, img_tensors_l = img_tensors
     img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
+    img_tensors_l = img_tensors_l.to(device=vae.device, dtype=vae.dtype)
     with torch.no_grad():
         latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
+        latents_l = vae.encode(img_tensors_l).latent_dist.sample().to("cpu")

     if flip_aug:
         img_tensors = torch.flip(img_tensors, dims=[3])
+        img_tensors_l = torch.flip(img_tensors_l, dims=[3])
         with torch.no_grad():
             flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
+            flipped_latents_l = vae.encode(img_tensors_l).latent_dist.sample().to("cpu")
+        flipped_latents = torch.cat((flipped_latents, flipped_latents_l), dim=1)
     else:
-        flipped_latents = [None] * len(latents)
+        flipped_latents = [None] * len(latents) * 2

-    for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
-        # check NaN
-        if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
-            raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
+    for info, latent, latent_l, flipped_latent in zip(image_infos, latents, latents_l, flipped_latents):
+        # NOTE: NaN check disabled for speed
+        # if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
+        #     raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
+
+        latent = torch.cat((latent, latent_l), dim=0)

         if cache_to_disk:
             save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
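Each cached sample now stores the preferred and non-preferred latents stacked along the channel axis, which is why the disk-cache check above expects 8 channels (4 + 4, assuming standard 4-channel SD latents). A toy illustration:

```python
import torch

latent_w = torch.randn(4, 64, 64)  # preferred-image latent
latent_l = torch.randn(4, 64, 64)  # non-preferred-image latent
cached = torch.cat((latent_w, latent_l), dim=0)  # (8, 64, 64), written to the .npz
assert cached.shape[0] == 8
```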
@@ -3417,6 +3447,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=0.1,
         help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
     )
+    parser.add_argument(
+        "--beta_orpo",
+        type=float,
+        default=0.1,
+        help="weight of the ORPO odds-ratio term. default is 0.1",
+    )

     parser.add_argument(
         "--lowram",

@@ -5146,27 +5182,32 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler,

 def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     # Sample noise that we'll add to the latents
-    noise = torch.randn_like(latents, device=latents.device)
+    noise = torch.randn_like(latents, device=latents.device).chunk(2)[0].repeat(2, 1, 1, 1)
     if args.noise_offset:
         if args.noise_offset_random_strength:
             noise_offset = torch.rand(1, device=latents.device) * args.noise_offset
         else:
             noise_offset = args.noise_offset
         noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale)
     if args.multires_noise_iterations:
+        raise NotImplementedError
         noise = custom_train_functions.pyramid_noise_like(
             noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
         )

     # Sample a random timestep for each image
-    b_size = latents.shape[0]
+    b_size = latents.shape[0] // 2
     min_timestep = 0 if args.min_timestep is None else args.min_timestep
     max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
     timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)
+    timesteps = timesteps.repeat(2)
+    if isinstance(huber_c, torch.Tensor):
+        huber_c = huber_c.repeat(2)

     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)
     if args.ip_noise_gamma:
+        raise NotImplementedError
         if args.ip_noise_gamma_random_strength:
             strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma
         else:
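Since the batch is laid out as [preferred half; non-preferred half], noise and timesteps are sampled for the first half only and repeated, so each pair is diffused identically and any loss difference comes from the images alone. A standalone sketch of that sampling pattern (toy shapes):

```python
import torch

latents = torch.randn(4, 4, 64, 64)  # [w0, w1, l0, l1]

# identical noise for each (w, l) pair
noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)
assert torch.equal(noise[0], noise[2]) and torch.equal(noise[1], noise[3])

# identical timestep for each pair
b_size = latents.shape[0] // 2
timesteps = torch.randint(0, 1000, (b_size,)).repeat(2)  # [t0, t1, t0, t1]
```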
train_network.py (20 additions, 3 deletions)
@@ -1060,6 +1060,8 @@ def update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_

                 if "latents" in batch and batch["latents"] is not None:
                     latents = batch["latents"].to(device=accelerator.device, dtype=weight_dtype, non_blocking=True)
+                    # (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)
+                    latents = torch.cat(latents.chunk(2, dim=1))
                 else:
                     with torch.no_grad():
                         # convert to latents
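Cached latents arrive as (batch_size, 2*channels, h, w); splitting on the channel axis and concatenating on the batch axis yields (2*batch_size, channels, h, w) with all preferred samples first, which is the layout the ORPO loss below relies on when it calls `chunk(2)`. A quick shape check of that reshuffle:

```python
import torch

B, C, H, W = 2, 4, 8, 8
cached = torch.randn(B, 2 * C, H, W)         # per sample: [w channels | l channels]
latents = torch.cat(cached.chunk(2, dim=1))  # (2*B, C, H, W): [w0, w1, l0, l1]
assert torch.equal(latents[0], cached[0, :C])  # preferred half first
assert torch.equal(latents[B], cached[0, C:])  # non-preferred half second
```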
@@ -1097,6 +1099,7 @@ def update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_
                 text_encoder_conds = self.get_text_cond(
                     args, accelerator, batch, tokenizers, text_encoders, weight_dtype
                 )
+                text_encoder_conds = text_encoder_conds.repeat(2, *([1] * len(text_encoder_conds.shape[1:])))

                 with torch.no_grad():
                     latents = latents * self.vae_scale_factor
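Both images in a pair share one caption, so the text conditioning is tiled along the batch dimension; the `*([1] * ...)` spread keeps the repeat rank-agnostic. A minimal sketch (the (2, 77, 768) shape is illustrative):

```python
import torch

conds = torch.randn(2, 77, 768)  # (batch, tokens, dim)
tiled = conds.repeat(2, *([1] * len(conds.shape[1:])))  # (4, 77, 768)
assert torch.equal(tiled[0], tiled[2])  # each pair sees the same conditioning
```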
@@ -1136,15 +1139,29 @@ def update_embeddings_map(accelerator, text_encoders, embeddings_map, embedding_
                 else:
                     target = noise

-                loss = train_util.conditional_loss(
+                model_losses = train_util.conditional_loss(
                     noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
                 )
                 if args.masked_loss:
-                    loss = apply_masked_loss(loss, batch)
-                loss = loss.mean([1, 2, 3])
+                    model_losses = apply_masked_loss(model_losses, batch)

+                # odds ratio loss
+                # in the diffusion formulation, we assume the MSE loss approximates the logp
+                model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
+                model_losses_w, model_losses_l = model_losses.chunk(2)
+                log_odds = model_losses_w - model_losses_l
+
+                # ratio loss
+                ratio = torch.nn.functional.logsigmoid(log_odds)
+                ratio_losses = args.beta_orpo * ratio
+
+                # full ORPO loss
+                loss = model_losses_w - ratio_losses

                 loss_weights = batch["loss_weights"]  # per-sample weight
                 loss = loss * loss_weights
+                timesteps = timesteps.chunk(2)[0]

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
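To summarize the loss: per-sample diffusion losses for the two halves are compared through a log-sigmoid odds term, and `--beta_orpo` scales how strongly the preference signal offsets the plain diffusion loss on the preferred images. A self-contained sketch mirroring the commit's formulation, including its stated assumption that the per-sample MSE stands in for log p (the helper name `orpo_loss` is ours, not the repository's):

```python
import torch
import torch.nn.functional as F

def orpo_loss(model_losses: torch.Tensor, beta_orpo: float = 0.1) -> torch.Tensor:
    # model_losses: per-sample losses for a [preferred; non-preferred] batch, shape (2*B,)
    model_losses_w, model_losses_l = model_losses.chunk(2)
    log_odds = model_losses_w - model_losses_l  # commit's convention: loss stands in for logp
    ratio_losses = beta_orpo * F.logsigmoid(log_odds)
    return model_losses_w - ratio_losses  # shape (B,), weighted and reduced by the trainer

losses = torch.tensor([0.10, 0.20, 0.30, 0.25])  # toy batch: [w0, w1, l0, l1]
print(orpo_loss(losses))
```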
