Skip to content

Commit

Permalink
implement masked loss for LoRA, Textual Inversion, Dreambooth & others
Browse files Browse the repository at this point in the history
implement mask loading from mask folder
  • Loading branch information
zapp authored and recris committed Oct 9, 2023
1 parent 2d87bb6 commit 7de0550
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 4 deletions.
6 changes: 6 additions & 0 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
get_latent_masks
)


Expand Down Expand Up @@ -339,6 +340,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
Expand Down
15 changes: 15 additions & 0 deletions library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,21 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
return noise


def get_latent_masks(image_masks, latent_shape, device):
    """Resize per-image loss masks to the spatial size of the latents.

    Args:
        image_masks: tensor holding one mask per batch item; its last two
            dimensions are the mask height and width (e.g. ``(B, H, W)`` or
            ``(B, 1, H, W)``).
        latent_shape: shape of the latent / noise prediction tensor; index 0
            is used as the batch size and the last two entries as the target
            height and width.
        device: device the returned mask should live on.

    Returns:
        A float tensor of shape ``(B, 1, latent_h, latent_w)`` that can be
        multiplied with a ``(B, C, latent_h, latent_w)`` tensor via
        broadcasting.
    """
    batch_size = latent_shape[0]
    # Take the spatial size from the mask itself instead of assuming it is
    # exactly 8x the latent size, so masks of any resolution can be resized.
    mask_height, mask_width = image_masks.shape[-2:]
    masks = image_masks.to(device).reshape(batch_size, 1, mask_height, mask_width)
    # resize to match latent; nearest-neighbour sampling never blends mask values
    masks = torch.nn.functional.interpolate(
        masks.float(),
        size=tuple(latent_shape[-2:]),
        mode="nearest",
    )
    return masks


"""
##########################################
# Perlin Noise
Expand Down
62 changes: 58 additions & 4 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,
self.text_encoder_outputs1: Optional[torch.Tensor] = None
self.text_encoder_outputs2: Optional[torch.Tensor] = None
self.text_encoder_pool2: Optional[torch.Tensor] = None
# Masked Loss
self.mask: np.ndarray = None
self.mask_flipped: np.ndarray = None


class BucketManager:
Expand Down Expand Up @@ -1050,6 +1053,7 @@ def __getitem__(self, index):
input_ids2_list = []
latents_list = []
images = []
masks = []
original_sizes_hw = []
crop_top_lefts = []
target_sizes_hw = []
Expand All @@ -1071,8 +1075,10 @@ def __getitem__(self, index):
crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped
if not flipped:
latents = image_info.latents
mask = image_info.mask
else:
latents = image_info.latents_flipped
mask = image_info.mask_flipped

image = None
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
Expand All @@ -1087,6 +1093,9 @@ def __getitem__(self, index):
# 画像を読み込み、必要ならcropする
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
im_h, im_w = img.shape[0:2]
# loss mask is alpha channel, separate it
mask = img[:, :, -1] / 255
img = img[:, :, :3]

if self.enable_bucket:
img, original_size, crop_ltrb = trim_and_resize_if_required(
Expand Down Expand Up @@ -1124,9 +1133,11 @@ def __getitem__(self, index):

latents = None
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
mask = torch.from_numpy(mask)

images.append(image)
latents_list.append(latents)
masks.append(torch.tensor(mask))

target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)

Expand Down Expand Up @@ -1218,7 +1229,7 @@ def __getitem__(self, index):
else:
images = None
example["images"] = images

example["masks"] = torch.stack(masks) if masks[0] is not None else None
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
example["captions"] = captions

Expand Down Expand Up @@ -2132,12 +2143,44 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:

def load_image(image_path):
    """Load an image as an ``(H, W, 4)`` uint8 array with the loss mask as alpha.

    Args:
        image_path: path of the image file to read.

    Returns:
        A numpy uint8 array in RGBA channel order. The last channel is
        overwritten with the result of ``load_mask`` for this image, so any
        alpha stored in the file itself is discarded.
    """
    # Image.open is lazy; the context manager releases the file handle as
    # soon as the pixel data has been copied into the numpy array.
    with Image.open(image_path) as image:
        # Convert straight to RGBA; an intermediate RGB conversion is redundant.
        rgba = image if image.mode == "RGBA" else image.convert("RGBA")
        img = np.array(rgba, np.uint8)
    # the alpha channel carries the loss mask
    img[..., -1] = load_mask(image_path, img.shape[:2])
    return img


def load_mask(image_path, target_shape):
    """Load the loss mask that belongs to an image.

    The mask is looked up at ``<image dir>/mask/<image stem>.png``. When that
    file is missing, unreadable or in an unsupported format, a mask filled
    with 255 is returned instead.

    Args:
        image_path: path of the training image the mask belongs to.
        target_shape: ``(height, width)`` the returned mask must have.

    Returns:
        A 2-D numpy array of shape ``target_shape`` with values in 0..255.
    """
    target_shape = tuple(target_shape)
    p = pathlib.Path(image_path)
    mask_path = os.path.join(p.parent, 'mask', p.stem + '.png')
    result = None

    if os.path.exists(mask_path):
        try:
            # open once and close the file handle deterministically
            with Image.open(mask_path) as mask_image:
                mask = np.array(mask_image)
                if mask.ndim > 2 and mask.max() <= 255:
                    # multi-channel 8-bit image: collapse to a single grayscale channel
                    result = np.array(mask_image.convert("L"))
                elif mask.ndim == 2 and mask.max() > 255:
                    # 16-bit grayscale: rescale to 0..255 and store as uint8
                    scaled = mask // (((2 ** 16) - 1) // 255)
                    result = np.clip(scaled, 0, 255).astype(np.uint8)
                elif mask.ndim == 2 and mask.max() <= 255:
                    result = mask
                else:
                    print(f"{mask_path} has invalid mask format: using default mask")
        except Exception as e:
            # best effort: fall back to the default mask, but report why and
            # do not swallow KeyboardInterrupt / SystemExit like a bare except
            print(f"failed to load mask: {mask_path} ({e})")

    # use default when mask file is unavailable
    if result is None:
        result = np.full(target_shape, 255, np.uint8)

    # stretch mask to image shape
    if tuple(result.shape) != target_shape:
        print(f"{mask_path} does not match image dimensions, resizing")
        height, width = target_shape
        # cv2.resize expects dsize as (width, height), the reverse of a numpy shape
        result = cv2.resize(result, dsize=(width, height), interpolation=cv2.INTER_NEAREST)

    return result


# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
def trim_and_resize_if_required(
random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int]
Expand Down Expand Up @@ -2184,12 +2227,17 @@ def cache_batch_latents(
latents_original_size and latents_crop_ltrb are also set
"""
images = []
masks = []
for info in image_infos:
image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
# alpha channel contains loss mask, separate it
mask = image[:, :, -1] / 255
image = image[:, :, :3]
image = IMAGE_TRANSFORMS(image)
images.append(image)
masks.append(mask)

info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
Expand All @@ -2207,7 +2255,7 @@ def cache_batch_latents(
else:
flipped_latents = [None] * len(latents)

for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
for info, latent, flipped_latent, mask in zip(image_infos, latents, flipped_latents, masks):
# check NaN
if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
Expand All @@ -2216,8 +2264,10 @@ def cache_batch_latents(
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
else:
info.latents = latent
info.mask = mask
if flip_aug:
info.latents_flipped = flipped_latent
info.mask_flipped = mask.flip(mask, dims=[3])

# FIXME this slows down caching a lot, specify this as an option
if torch.cuda.is_available():
Expand Down Expand Up @@ -3159,6 +3209,10 @@ def add_dataset_arguments(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
)

parser.add_argument(
"--masked_loss", action="store_true", help="Enable masking of latent loss using grayscale mask images"
)

parser.add_argument(
"--token_warmup_min",
type=int,
Expand Down
6 changes: 6 additions & 0 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
get_latent_masks
)
from library.sdxl_original_unet import SdxlUNet2DConditionModel

Expand Down Expand Up @@ -548,6 +549,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
Expand Down
6 changes: 6 additions & 0 deletions train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
get_latent_masks
)

# perlin_noise,
Expand Down Expand Up @@ -326,6 +327,11 @@ def train(args):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

Expand Down
6 changes: 6 additions & 0 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
get_latent_masks
)


Expand Down Expand Up @@ -796,6 +797,11 @@ def remove_model(old_ckpt_name):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

Expand Down
6 changes: 6 additions & 0 deletions train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
get_latent_masks
)

imagenet_templates_small = [
Expand Down Expand Up @@ -570,6 +571,11 @@ def remove_model(old_ckpt_name):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

Expand Down
6 changes: 6 additions & 0 deletions train_textual_inversion_XTI.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
get_latent_masks
)
import library.original_unet as original_unet
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
Expand Down Expand Up @@ -461,6 +462,11 @@ def remove_model(old_ckpt_name):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

Expand Down

0 comments on commit 7de0550

Please sign in to comment.