import datetime
import glob
import html
import os
import sys
import traceback

import torch
import tqdm
from torch import einsum
from einops import rearrange, repeat

from ldm.util import default
from modules import devices, shared, processing, sd_models
import modules.textual_inversion.dataset


class HypernetworkModule(torch.nn.Module):
    """A small residual MLP applied to cross-attention context vectors."""

    def __init__(self, dim, state_dict=None):
        super().__init__()

        self.linear1 = torch.nn.Linear(dim, dim * 2)
        self.linear2 = torch.nn.Linear(dim * 2, dim)

        if state_dict is not None:
            self.load_state_dict(state_dict, strict=True)
        else:
            # near-zero init keeps an untrained hypernetwork close to an identity map
            self.linear1.weight.data.fill_(0.0001)
            self.linear1.bias.data.fill_(0.0001)
            self.linear2.weight.data.fill_(0.0001)
            self.linear2.bias.data.fill_(0.0001)

        self.to(devices.device)

    def forward(self, x):
        # residual connection: output stays near x while the weights are tiny
        return x + self.linear2(self.linear1(x))

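# Shape sketch (illustrative, not part of the original file): for the 768-wide
# text-conditioning stream, a module maps (batch, tokens, 768) -> (batch, tokens, 768):
#
#   module = HypernetworkModule(768)
#   out = module(torch.randn(1, 77, 768, device=devices.device))
#   assert out.shape == (1, 77, 768)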

class Hypernetwork:
    filename = None
    name = None

    def __init__(self, name=None):
        self.filename = None
        self.name = name
        self.layers = {}
        self.step = 0
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None

        # one (k, v) module pair per cross-attention context width used by the model
        for size in [320, 640, 768, 1280]:
            self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))

    def weights(self):
        res = []

        for k, layers in self.layers.items():
            for layer in layers:
                layer.train()  # side effect: puts every module into training mode
                res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]

        return res

    def save(self, filename):
        state_dict = {}

        for k, v in self.layers.items():
            state_dict[k] = (v[0].state_dict(), v[1].state_dict())

        state_dict['step'] = self.step
        state_dict['name'] = self.name
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name

        torch.save(state_dict, filename)

    def load(self, filename):
        self.filename = filename
        if self.name is None:
            self.name = os.path.splitext(os.path.basename(filename))[0]

        state_dict = torch.load(filename, map_location='cpu')

        # integer keys hold per-size module pairs; string keys hold metadata
        for size, sd in state_dict.items():
            if isinstance(size, int):
                self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))

        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)

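# Round-trip sketch (hypothetical filename, purely illustrative of the format above):
#
#   hn = Hypernetwork(name="example")
#   hn.save("example.pt")    # int keys -> (k, v) state dicts, string keys -> metadata
#   hn2 = Hypernetwork()
#   hn2.load("example.pt")   # metadata keys are skipped by the isinstance(size, int) check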

def load_hypernetworks(path):
    res = {}

    # os.path.join ensures a separator even when `path` lacks a trailing slash
    for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
        try:
            hn = Hypernetwork()
            hn.load(filename)
            res[hn.name] = hn
        except Exception:
            print(f"Error loading hypernetwork {filename}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    return res

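# Typical usage (a sketch; the directory option is the one referenced later in this
# file, while "anime" is a hypothetical example name):
#
#   shared.hypernetworks = load_hypernetworks(shared.cmd_opts.hypernetwork_dir)
#   shared.hypernetwork = shared.hypernetworks.get("anime")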

def attention_CrossAttention_forward(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    # pick the hypernetwork module pair matching this layer's context width, if any
    hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None)

    if hypernetwork_layers is not None:
        hypernetwork_k, hypernetwork_v = hypernetwork_layers

        self.hypernetwork_k = hypernetwork_k
        self.hypernetwork_v = hypernetwork_v

        # transform the context separately for keys and values
        context_k = hypernetwork_k(context)
        context_v = hypernetwork_v(context)
    else:
        context_k = context
        context_v = context

    k = self.to_k(context_k)
    v = self.to_v(context_v)

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    attn = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', attn, v)
    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(out)

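# This function is written to replace CrossAttention.forward. A sketch of the
# expected wiring (the actual hijack lives elsewhere in the codebase, so treat
# this as an assumption rather than the canonical hook):
#
#   import ldm.modules.attention
#   ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward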

def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
    assert hypernetwork_name, 'hypernetwork not selected'

    shared.hypernetwork = shared.hypernetworks[hypernetwork_name]

    shared.state.textinfo = "Initializing hypernetwork training..."
    shared.state.job_count = steps

    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)

    if save_hypernetwork_every > 0:
        hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
        os.makedirs(hypernetwork_dir, exist_ok=True)
    else:
        hypernetwork_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    cond_model = shared.sd_model.cond_stage_model

    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)

    hypernetwork = shared.hypernetworks[hypernetwork_name]
    weights = hypernetwork.weights()
    for weight in weights:
        weight.requires_grad = True

    optimizer = torch.optim.AdamW(weights, lr=learn_rate)

    # rolling buffer over the most recent 32 loss values
    losses = torch.zeros((32,))

    last_saved_file = "<none>"
    last_saved_image = "<none>"

    initial_step = hypernetwork.step or 0
    if initial_step > steps:
        return hypernetwork, filename

    pbar = tqdm.tqdm(enumerate(ds), total=steps - initial_step)
    for i, (x, text) in pbar:
        hypernetwork.step = i + initial_step

        if hypernetwork.step > steps:
            break

        if shared.state.interrupted:
            break

        with torch.autocast("cuda"):
            c = cond_model([text])

            x = x.to(devices.device)
            loss = shared.sd_model(x.unsqueeze(0), c)[0]  # forward returns (loss, loss_dict)
            del x

            losses[hypernetwork.step % losses.shape[0]] = loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        pbar.set_description(f"loss: {losses.mean():.7f}")
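        # Note (editorial, not in the original): `losses` is a fixed 32-slot ring
        # buffer, so the mean shown above is a moving average of roughly the last
        # 32 steps; during the first 32 steps it still contains zeros and reads low.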

        if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
            hypernetwork.save(last_saved_file)

        if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
            last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')

            preview_text = text if preview_image_prompt == "" else preview_image_prompt

            p = processing.StableDiffusionProcessingTxt2Img(
                sd_model=shared.sd_model,
                prompt=preview_text,
                steps=20,
                do_not_save_grid=True,
                do_not_save_samples=True,
            )

            processed = processing.process_images(p)
            image = processed.images[0]

            shared.state.current_image = image
            image.save(last_saved_image)

            last_saved_image += f", prompt: {preview_text}"

        shared.state.job_no = hypernetwork.step

        shared.state.textinfo = f"""
<p>
Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(text)}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""

    checkpoint = sd_models.select_checkpoint()

    hypernetwork.sd_checkpoint = checkpoint.hash
    hypernetwork.sd_checkpoint_name = checkpoint.model_name
    hypernetwork.save(filename)

    return hypernetwork, filename
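
# A hypothetical invocation (all argument values below are illustrative, not from
# the original file; the template path in particular is an assumption):
#
#   hypernetwork, fn = train_hypernetwork(
#       "my-hypernet", learn_rate=5e-6, data_root="train/images",
#       log_directory="textual_inversion", steps=10000,
#       create_image_every=500, save_hypernetwork_every=500,
#       template_file="hypernetwork.txt", preview_image_prompt="",
#   )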