diff --git a/README.md b/README.md index 23e049354..52c963392 100644 --- a/README.md +++ b/README.md @@ -161,6 +161,10 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Example: `--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` or `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` etc. - `network_module` `networks.lora` and `networks.dylora` are available. +- The feature to use the transparency (alpha channel) of the image as a mask in the loss calculation has been added. PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) Thanks to u-haru! + - The transparent part is ignored during training. Specify the `--alpha_mask` option in the training script or specify `alpha_mask = true` in the dataset configuration file. + - See [About masked loss](./docs/masked_loss_README.md) for details. + - LoRA training in SDXL now supports block-wise learning rates and block-wise dim (rank). PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331) - Specify the learning rate and dim (rank) for each block. - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only). @@ -214,6 +218,10 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! 
- 例:`--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` または `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` など - `network_module` の `networks.lora` および `networks.dylora` で使用可能です。 +- 画像の透明度(アルファチャネル)をロス計算時のマスクとして使用する機能が追加されました。PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) u-haru 氏に感謝します。 + - 透明部分が学習時に無視されるようになります。学習スクリプトに `--alpha_mask` オプションを指定するか、データセット設定ファイルに `alpha_mask = true` を指定してください。 + - 詳細は [マスクロスについて](./docs/masked_loss_README-ja.md) をご覧ください。 + - SDXL の LoRA で階層別学習率、階層別 dim (rank) をサポートしました。PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331) - ブロックごとに学習率および dim (rank) を指定することができます。 - 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。 diff --git a/docs/masked_loss_README-ja.md b/docs/masked_loss_README-ja.md new file mode 100644 index 000000000..58f042c3b --- /dev/null +++ b/docs/masked_loss_README-ja.md @@ -0,0 +1,57 @@ +## マスクロスについて + +マスクロスは、入力画像のマスクで指定された部分だけ損失計算することで、画像の一部分だけを学習することができる機能です。 +たとえばキャラクタを学習したい場合、キャラクタ部分だけをマスクして学習することで、背景を無視して学習することができます。 + +マスクロスのマスクには、二種類の指定方法があります。 + +- マスク画像を用いる方法 +- 透明度(アルファチャネル)を使用する方法 + +なお、サンプルは [ずんずんPJイラスト/3Dデータ](https://zunko.jp/con_illust.html) の「AI画像モデル用学習データ」を使用しています。 + +### マスク画像を用いる方法 + +学習画像それぞれに対応するマスク画像を用意する方法です。学習画像と同じファイル名のマスク画像を用意し、それを学習画像と別のディレクトリに保存します。 + +- 学習画像 + ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/607c5116-5f62-47de-8b66-9c4a597f0441) +- マスク画像 + ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450) + +```.toml +[[datasets.subsets]] +image_dir = "/path/to/a_zundamon" +caption_extension = ".txt" +conditioning_data_dir = "/path/to/a_zundamon_mask" +num_repeats = 8 +``` + +マスク画像は、学習画像と同じサイズで、学習する部分を白、無視する部分を黒で描画します。グレースケールにも対応しています(127 ならロス重みが 0.5 になります)。なお、正確にはマスク画像の R チャネルが用いられます。 + +DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにマスク画像を保存してください。ControlNet のデータセットと同じですので、詳細は 
[ControlNet-LLLite](train_lllite_README-ja.md#データセットの準備) を参照してください。 + +### 透明度(アルファチャネル)を使用する方法 + +学習画像の透明度(アルファチャネル)がマスクとして使用されます。透明度が 0 の部分は無視され、255 の部分は学習されます。半透明の場合は、その透明度に応じてロス重みが変化します(127 ならおおむね 0.5)。 + +![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/0baa129b-446a-4aac-b98c-7208efb0e75e) + +※それぞれの画像は透過PNG + +学習時のスクリプトのオプションに `--alpha_mask` を指定するか、dataset の設定ファイルの subset で、`alpha_mask` を指定してください。たとえば、以下のようになります。 + +```toml +[[datasets.subsets]] +image_dir = "/path/to/image/dir" +caption_extension = ".txt" +num_repeats = 8 +alpha_mask = true +``` + +## 学習時の注意事項 + +- 現時点では DreamBooth 方式の dataset のみ対応しています。 +- マスクは latents のサイズ、つまり 1/8 に縮小されてから適用されます。そのため、細かい部分(たとえばアホ毛やイヤリングなど)はうまく学習できない可能性があります。マスクをわずかに拡張するなどの工夫が必要かもしれません。 +- マスクロスを用いる場合、学習対象外の部分をキャプションに含める必要はないかもしれません。(要検証) +- `alpha_mask` の場合、マスクの有無を切り替えると latents キャッシュが自動的に再生成されます。 diff --git a/docs/masked_loss_README.md b/docs/masked_loss_README.md new file mode 100644 index 000000000..3ac5ad211 --- /dev/null +++ b/docs/masked_loss_README.md @@ -0,0 +1,56 @@ +## Masked Loss + +Masked loss is a feature that allows you to train only part of an image by calculating the loss only for the part specified by the mask of the input image. For example, if you want to train a character, you can train only the character part by masking it, ignoring the background. + +There are two ways to specify the mask for masked loss. + +- Using a mask image +- Using transparency (alpha channel) of the image + +The sample uses the "AI image model training data" from [ZunZunPJ Illustration/3D Data](https://zunko.jp/con_illust.html). + +### Using a mask image + +This is a method of preparing a mask image corresponding to each training image. Prepare a mask image with the same file name as the training image and save it in a different directory from the training image. 
+ +- Training image + ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/607c5116-5f62-47de-8b66-9c4a597f0441) +- Mask image + ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450) + +```toml +[[datasets.subsets]] +image_dir = "/path/to/a_zundamon" +caption_extension = ".txt" +conditioning_data_dir = "/path/to/a_zundamon_mask" +num_repeats = 8 +``` + +The mask image must be the same size as the training image; draw the region to train in white and the region to ignore in black. Grayscale values are also supported (127 gives a loss weight of about 0.5). Currently, only the R channel of the mask image is used. + +Use a DreamBooth-style dataset, and save the mask images in the directory specified by `conditioning_data_dir`. The layout is the same as for a ControlNet dataset, so please refer to [ControlNet-LLLite](train_lllite_README.md#Preparing-the-dataset) for details. + +### Using transparency (alpha channel) of the image + +The alpha channel of the training image is used as the mask. Pixels with alpha 0 are ignored and pixels with alpha 255 are fully trained. For semi-transparent pixels, the loss weight scales with the alpha value (127 gives a weight of about 0.5). + +![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/0baa129b-446a-4aac-b98c-7208efb0e75e) + +Note: each image is a transparent PNG. + +Specify `--alpha_mask` in the training script options, or set `alpha_mask` in the subset of the dataset configuration file. For example: + +```toml +[[datasets.subsets]] +image_dir = "/path/to/image/dir" +caption_extension = ".txt" +num_repeats = 8 +alpha_mask = true +``` + +## Notes on training + +- At the moment, only DreamBooth-style datasets are supported. +- The mask is reduced to 1/8 of its size, matching the latents, before it is applied. As a result, fine details (such as ahoge or earrings) may not be learned well.
Slightly dilating the mask may help. +- When using masked loss, it may not be necessary to describe the untrained regions in the caption. (To be verified) +- With `alpha_mask`, the latents cache is regenerated automatically whenever the mask is enabled or disabled. diff --git a/docs/train_network_README-ja.md b/docs/train_network_README-ja.md index 46085117c..55c80c4b0 100644 --- a/docs/train_network_README-ja.md +++ b/docs/train_network_README-ja.md @@ -102,6 +102,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py * Text Encoderに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率(5e-5など)にしたほうが良い、という話もあるようです。 * `--network_args` * 複数の引数を指定できます。後述します。 +* `--alpha_mask` + * 画像のアルファ値をマスクとして使用します。透過画像を学習する際に使用します。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223) `--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。 diff --git a/docs/train_network_README-zh.md b/docs/train_network_README-zh.md index ed7a0c4ef..830014f72 100644 --- a/docs/train_network_README-zh.md +++ b/docs/train_network_README-zh.md @@ -101,6 +101,8 @@ LoRA的模型将会被保存在通过`--output_dir`选项指定的文件夹中 * 当在Text Encoder相关的LoRA模块中使用与常规学习率(由`--learning_rate`选项指定)不同的学习率时,应指定此选项。可能最好将Text Encoder的学习率稍微降低(例如5e-5)。 * `--network_args` * 可以指定多个参数。将在下面详细说明。 +* `--alpha_mask` + * 使用图像的 Alpha 值作为遮罩。这在学习透明图像时使用。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223) 当未指定`--network_train_unet_only`和`--network_train_text_encoder_only`时(默认情况),将启用Text Encoder和U-Net的两个LoRA模块。 diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 0389da388..019c737a6 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -11,6 +11,7 @@ import torch from library.device_utils import init_ipex, get_preferred_device + init_ipex() from torchvision import transforms @@ -18,8 +19,10 @@ import
library.model_util as model_util import library.train_util as train_util from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) DEVICE = get_preferred_device() @@ -89,7 +92,9 @@ def main(args): # bucketのサイズを計算する max_reso = tuple([int(t) for t in args.max_resolution.split(",")]) - assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" + assert ( + len(max_reso) == 2 + ), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" bucket_manager = train_util.BucketManager( args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps @@ -107,7 +112,7 @@ def main(args): def process_batch(is_last): for bucket in bucket_manager.buckets: if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: - train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False) + train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, args.alpha_mask, False) bucket.clear() # 読み込みの高速化のためにDataLoaderを使うオプション @@ -208,7 +213,9 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") - parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)") + parser.add_argument( + "--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)" + ) parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument( "--max_data_loader_n_workers", @@ -231,10 +238,16 @@ def setup_parser() -> argparse.ArgumentParser: help="steps of resolution for buckets, 
divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", ) parser.add_argument( - "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" + "--bucket_no_upscale", + action="store_true", + help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します", ) parser.add_argument( - "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help="use mixed precision / 混合精度を使う場合、その精度", ) parser.add_argument( "--full_path", @@ -242,7 +255,15 @@ def setup_parser() -> argparse.ArgumentParser: help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", ) parser.add_argument( - "--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する" + "--flip_aug", + action="store_true", + help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する", + ) + parser.add_argument( + "--alpha_mask", + type=str, + default="", + help="save alpha mask for images for loss calculation / 損失計算用に画像のアルファマスクを保存する", ) parser.add_argument( "--skip_existing", diff --git a/library/config_util.py b/library/config_util.py index 59f5f86d2..10b2457f3 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -86,11 +86,13 @@ class DreamBoothSubsetParams(BaseSubsetParams): class_tokens: Optional[str] = None caption_extension: str = ".caption" cache_info: bool = False + alpha_mask: bool = False @dataclass class FineTuningSubsetParams(BaseSubsetParams): metadata_file: Optional[str] = None + alpha_mask: bool = False @dataclass @@ -213,11 +215,13 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] DB_SUBSET_DISTINCT_SCHEMA = { Required("image_dir"): str, "is_reg": bool, + "alpha_mask": bool, } # FT means FineTuning 
FT_SUBSET_DISTINCT_SCHEMA = { Required("metadata_file"): str, "image_dir": str, + "alpha_mask": bool, } CN_SUBSET_ASCENDABLE_SCHEMA = { "caption_extension": str, @@ -538,6 +542,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu random_crop: {subset.random_crop} token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, + alpha_mask: {subset.alpha_mask}, """ ), " ", diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 406e0e36e..2a513dc5b 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -480,12 +480,20 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): def apply_masked_loss(loss, batch): - # mask image is -1 to 1. we need to convert it to 0 to 1 - mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + if "conditioning_images" in batch: + # conditioning image is -1 to 1. we need to convert it to 0 to 1 + mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + mask_image = mask_image / 2 + 0.5 + # print(f"conditioning_image: {mask_image.shape}") + elif "alpha_masks" in batch and batch["alpha_masks"] is not None: + # alpha mask is 0 to 1 + mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension + # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}") + else: + return loss # resize to the same size as the loss mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") - mask_image = mask_image / 2 + 0.5 loss = loss * mask_image return loss diff --git a/library/train_util.py b/library/train_util.py index 410471470..1f9f3c5df 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -159,6 +159,7 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, self.text_encoder_outputs1: Optional[torch.Tensor] = 
None self.text_encoder_outputs2: Optional[torch.Tensor] = None self.text_encoder_pool2: Optional[torch.Tensor] = None + self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime class BucketManager: @@ -361,6 +362,7 @@ class BaseSubset: def __init__( self, image_dir: Optional[str], + alpha_mask: Optional[bool], num_repeats: int, shuffle_caption: bool, caption_separator: str, @@ -381,6 +383,7 @@ def __init__( token_warmup_step: Union[float, int], ) -> None: self.image_dir = image_dir + self.alpha_mask = alpha_mask if alpha_mask is not None else False self.num_repeats = num_repeats self.shuffle_caption = shuffle_caption self.caption_separator = caption_separator @@ -412,6 +415,7 @@ def __init__( class_tokens: Optional[str], caption_extension: str, cache_info: bool, + alpha_mask: bool, num_repeats, shuffle_caption, caption_separator: str, @@ -435,6 +439,7 @@ def __init__( super().__init__( image_dir, + alpha_mask, num_repeats, shuffle_caption, caption_separator, @@ -473,6 +478,7 @@ def __init__( self, image_dir, metadata_file: str, + alpha_mask: bool, num_repeats, shuffle_caption, caption_separator, @@ -496,6 +502,7 @@ def __init__( super().__init__( image_dir, + alpha_mask, num_repeats, shuffle_caption, caption_separator, @@ -554,6 +561,7 @@ def __init__( super().__init__( image_dir, + False, # alpha_mask num_repeats, shuffle_caption, caption_separator, @@ -915,7 +923,7 @@ def make_buckets(self): logger.info(f"mean ar error (without repeats): {mean_img_ar_error}") # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる - self.buckets_indices: List(BucketBatchIndex) = [] + self.buckets_indices: List[BucketBatchIndex] = [] for bucket_index, bucket in enumerate(self.bucket_manager.buckets): batch_count = int(math.ceil(len(bucket) / self.batch_size)) for batch_index in range(batch_count): @@ -994,7 +1002,9 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc if not is_main_process: # store to info only continue - 
cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug) + cache_available = is_disk_cached_latents_is_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) if cache_available: # do not add to batch continue @@ -1020,7 +1030,7 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded logger.info("caching latents...") for batch in tqdm(batches, smoothing=1, total=len(batches)): - cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop) + cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する @@ -1088,8 +1098,8 @@ def cache_text_encoder_outputs( def get_image_size(self, image_path): return imagesize.get(image_path) - def load_image_with_face_info(self, subset: BaseSubset, image_path: str): - img = load_image(image_path) + def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False): + img = load_image(image_path, alpha_mask) face_cx = face_cy = face_w = face_h = 0 if subset.face_crop_aug_range is not None: @@ -1166,6 +1176,7 @@ def __getitem__(self, index): input_ids_list = [] input_ids2_list = [] latents_list = [] + alpha_mask_list = [] images = [] original_sizes_hw = [] crop_top_lefts = [] @@ -1190,21 +1201,28 @@ def __getitem__(self, index): crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped if not flipped: latents = image_info.latents + alpha_mask = image_info.alpha_mask else: latents = image_info.latents_flipped + alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1]) image = None elif image_info.latents_npz is not None: # 
FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz) + latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz) if flipped: latents = flipped_latents + alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem del flipped_latents latents = torch.FloatTensor(latents) + if alpha_mask is not None: + alpha_mask = torch.FloatTensor(alpha_mask) image = None else: # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path) + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info( + subset, image_info.absolute_path, subset.alpha_mask + ) im_h, im_w = img.shape[0:2] if self.enable_bucket: @@ -1236,16 +1254,32 @@ def __getitem__(self, index): # augmentation aug = self.aug_helper.get_augmentor(subset.color_aug) if aug is not None: - img = aug(image=img)["image"] + # augment RGB channels only + img_rgb = img[:, :, :3] + img_rgb = aug(image=img_rgb)["image"] + img[:, :, :3] = img_rgb if flipped: img = img[:, ::-1, :].copy() # copy to avoid negative stride problem + if subset.alpha_mask: + if img.shape[2] == 4: + alpha_mask = img[:, :, 3] # [H,W] + alpha_mask = transforms.ToTensor()(alpha_mask) # 0-255 -> 0-1 + else: + alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32) + else: + alpha_mask = None + + img = img[:, :, :3] # remove alpha channel + latents = None image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + del img images.append(image) latents_list.append(latents) + alpha_mask_list.append(alpha_mask) target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) @@ -1331,6 +1365,23 @@ def __getitem__(self, index): example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list) 
example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list) + # if one of alpha_masks is not None, we need to replace None with ones + none_or_not = [x is None for x in alpha_mask_list] + if all(none_or_not): + example["alpha_masks"] = None + elif any(none_or_not): + for i in range(len(alpha_mask_list)): + if alpha_mask_list[i] is None: + if images[i] is not None: + alpha_mask_list[i] = torch.ones((images[i].shape[1], images[i].shape[2]), dtype=torch.float32) + else: + alpha_mask_list[i] = torch.ones( + (latents_list[i].shape[1] * 8, latents_list[i].shape[2] * 8), dtype=torch.float32 + ) + example["alpha_masks"] = torch.stack(alpha_mask_list) + else: + example["alpha_masks"] = torch.stack(alpha_mask_list) + if images[0] is not None: images = torch.stack(images) images = images.to(memory_format=torch.contiguous_format).float() @@ -1361,6 +1412,7 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): resized_sizes = [] bucket_reso = None flip_aug = None + alpha_mask = None random_crop = None for image_key in bucket[image_index : image_index + bucket_batch_size]: @@ -1369,10 +1421,13 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): if flip_aug is None: flip_aug = subset.flip_aug + alpha_mask = subset.alpha_mask random_crop = subset.random_crop bucket_reso = image_info.bucket_reso else: + # TODO そもそも混在してても動くようにしたほうがいい assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch" + assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch" assert random_crop == subset.random_crop, "random_crop must be same in a batch" assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch" @@ -1409,6 +1464,7 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): example["absolute_paths"] = absolute_paths example["resized_sizes"] = resized_sizes example["flip_aug"] = flip_aug + example["alpha_mask"] = alpha_mask example["random_crop"] = random_crop 
example["bucket_reso"] = bucket_reso return example @@ -1892,6 +1948,7 @@ def __init__( None, subset.caption_extension, subset.cache_info, + False, subset.num_repeats, subset.shuffle_caption, subset.caption_separator, @@ -2117,7 +2174,7 @@ def disable_token_padding(self): dataset.disable_token_padding() -def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): +def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool): expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 if not os.path.exists(npz_path): @@ -2135,6 +2192,15 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): return False if npz["latents_flipped"].shape[1:3] != expected_latents_size: return False + + if alpha_mask: + if "alpha_mask" not in npz: + return False + if npz["alpha_mask"].shape[0:2] != reso: # HxW + return False + else: + if "alpha_mask" in npz: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -2145,7 +2211,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) def load_latents_from_disk( npz_path, -) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]: +) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: npz = np.load(npz_path) if "latents" not in npz: raise ValueError(f"error: npz is old format. 
please re-generate {npz_path}") @@ -2154,13 +2220,16 @@ def load_latents_from_disk( original_size = npz["original_size"].tolist() crop_ltrb = npz["crop_ltrb"].tolist() flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None - return latents, original_size, crop_ltrb, flipped_latents + alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None): +def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): kwargs = {} if flipped_latents_tensor is not None: kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() + if alpha_mask is not None: + kwargs["alpha_mask"] = alpha_mask # ndarray np.savez( npz_path, latents=latents_tensor.float().cpu().numpy(), @@ -2232,6 +2301,13 @@ def debug_dataset(train_dataset, show_input_ids=False): if os.name == "nt": cv2.imshow("cond_img", cond_img) + if "alpha_masks" in example and example["alpha_masks"] is not None: + alpha_mask = example["alpha_masks"][j] + logger.info(f"alpha mask size: {alpha_mask.size()}") + alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8) + if os.name == "nt": + cv2.imshow("alpha_mask", alpha_mask) + if os.name == "nt": # only windows cv2.imshow("img", im) k = cv2.waitKey() @@ -2349,17 +2425,21 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: return train_dataset_group -def load_image(image_path): +def load_image(image_path, alpha=False): image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") + if alpha: + if not image.mode == "RGBA": + image = image.convert("RGBA") + else: + if not image.mode == "RGB": + image = image.convert("RGB") img = np.array(image, np.uint8) return img # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop 
right, crop bottom) def trim_and_resize_if_required( - random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int] + random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int] ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: image_height, image_width = image.shape[0:2] original_size = (image_width, image_height) # size before resize @@ -2391,7 +2471,7 @@ def trim_and_resize_if_required( def cache_batch_latents( - vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool + vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool ) -> None: r""" requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz @@ -2403,16 +2483,29 @@ def cache_batch_latents( latents_original_size and latents_crop_ltrb are also set """ images = [] + alpha_masks: List[np.ndarray] = [] for info in image_infos: - image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8) + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) - image = IMAGE_TRANSFORMS(image) - images.append(image) info.latents_original_size = original_size info.latents_crop_ltrb = crop_ltrb + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + else: + alpha_mask = np.ones_like(image[:, :, 0], dtype=np.float32) + else: + alpha_mask = None + alpha_masks.append(alpha_mask) + + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + images.append(image) + img_tensors = torch.stack(images, dim=0) img_tensors = img_tensors.to(device=vae.device, 
dtype=vae.dtype) @@ -2426,17 +2519,25 @@ def cache_batch_latents( else: flipped_latents = [None] * len(latents) - for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents): + for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks): # check NaN if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") if cache_to_disk: - save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent) + save_latents_to_disk( + info.latents_npz, + latent, + info.latents_original_size, + info.latents_crop_ltrb, + flipped_latent, + alpha_mask, + ) else: info.latents = latent if flip_aug: info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask if not HIGH_VRAM: clean_memory_on_device(vae.device) @@ -3683,6 +3784,11 @@ def add_dataset_arguments( default=0, help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", ) + parser.add_argument( + "--alpha_mask", + action="store_true", + help="use alpha channel as mask for training / 画像のアルファチャンネルをlossのマスクに使用する", + ) parser.add_argument( "--dataset_class", diff --git a/sdxl_train.py b/sdxl_train.py index 7c71a5133..9e20c60ca 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -711,7 +711,7 @@ def optimizer_hook(parameter: torch.Tensor): loss = train_util.conditional_loss( noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) - if args.masked_loss: + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 347db27f7..b7c88121e 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -17,10 +17,13 @@ BlueprintGenerator, ) 
from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + def cache_to_disk(args: argparse.Namespace) -> None: train_util.prepare_dataset_args(args, True) @@ -107,7 +110,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: else: _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える vae.set_use_memory_efficient_attention_xformers(args.xformers) vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) @@ -136,6 +139,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: b_size = len(batch["images"]) vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size flip_aug = batch["flip_aug"] + alpha_mask = batch["alpha_mask"] random_crop = batch["random_crop"] bucket_reso = batch["bucket_reso"] @@ -154,14 +158,16 @@ def cache_to_disk(args: argparse.Namespace) -> None: image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz" if args.skip_existing: - if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug): + if train_util.is_disk_cached_latents_is_expected( + image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask + ): logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") continue image_infos.append(image_info) if len(image_infos) > 0: - train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop) + train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop) accelerator.wait_for_everyone() accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") diff --git a/train_db.py b/train_db.py index a5408cd3d..39d8ea6ed 100644 --- a/train_db.py +++ b/train_db.py @@ -359,7 +359,7 @@ def train(args): target = noise loss = 
train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) diff --git a/train_network.py b/train_network.py index 38e4888e8..b272a6e1a 100644 --- a/train_network.py +++ b/train_network.py @@ -774,7 +774,9 @@ def load_model_hook(models, input_dir): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "network_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -902,7 +904,7 @@ def remove_model(old_ckpt_name): loss = train_util.conditional_loss( noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) - if args.masked_loss: + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 184607d1d..ade077c36 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -589,7 +589,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 8eed00fa1..efb59137b 100644 --- 
a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -474,7 +474,7 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3])
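
Review note: the docs added in this diff say the mask is shrunk to the latent size (1/8) before it weights the loss, and `apply_masked_loss` does that with `torch.nn.functional.interpolate(..., mode="area")`. The sketch below is a pure-Python stand-in for that step, not the code from this PR. It assumes a single-channel mask whose sides are exact multiples of the factor, in which case area interpolation reduces to a plain block average. The names `alpha_to_weights`, `downscale_area`, and `apply_mask` are hypothetical and exist only for illustration.

```python
from typing import List

Grid = List[List[float]]


def alpha_to_weights(alpha: List[List[int]]) -> Grid:
    """Map 8-bit alpha values (0-255) to loss weights in [0, 1]."""
    return [[a / 255.0 for a in row] for row in alpha]


def downscale_area(mask: Grid, factor: int = 8) -> Grid:
    """Average each factor x factor block (assumes both sides divide evenly)."""
    height, width = len(mask), len(mask[0])
    if height % factor or width % factor:
        raise ValueError("mask size must be a multiple of factor")
    out: Grid = []
    for top in range(0, height, factor):
        row = []
        for left in range(0, width, factor):
            block = [
                mask[y][x]
                for y in range(top, top + factor)
                for x in range(left, left + factor)
            ]
            row.append(sum(block) / len(block))
        out.append(row)
    return out


def apply_mask(loss: Grid, mask: Grid) -> Grid:
    """Multiply each per-position loss value by the mask weight at that position."""
    return [[l * m for l, m in zip(loss_row, mask_row)] for loss_row, mask_row in zip(loss, mask)]


if __name__ == "__main__":
    # 16x16 alpha channel: left half opaque, right half fully transparent.
    alpha = [[255] * 8 + [0] * 8 for _ in range(16)]
    latent_mask = downscale_area(alpha_to_weights(alpha), factor=8)
    print(apply_mask([[2.0, 2.0], [2.0, 2.0]], latent_mask))
```

Under the same assumptions, a one-pixel-wide opaque stroke crossing an otherwise transparent 8x8 block averages to a weight of 8/64 = 0.125, which is why the training notes warn that thin details may be learned poorly unless the mask is dilated.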