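"""Training and inference entry point for a sequence-to-sequence transformer
translator (German -> English by default).

Run with no flags to train; pass --inference together with --model_path to
translate sentences interactively.
"""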
from time import time  # Track how long an epoch takes
import os  # Creating and finding files/directories
import logging  # Logging tools
from datetime import date  # Logging the date for model versioning
from argparse import ArgumentParser  # For command-line arguments

import torch  # For ML
from tqdm import tqdm  # For fancy progress bars

from src.model import Translator  # Our model
from src.data import get_data, create_mask, generate_square_subsequent_mask  # Data loading and preprocessing

# Train on the GPU if possible; overridden by the --backend flag when run as a script
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generate an output sequence from the source using greedy decoding
def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol):
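    """Greedily decode a translation of `src`.

    Encode the source once, then repeatedly run the decoder on the tokens
    generated so far and append the highest-scoring next token, stopping at
    `end_symbol` or after `max_len` tokens.
    """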

    # Move everything to the same device as the model
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    # Encode the input once; the decoder reuses this representation at every step
    memory = model.encode(src, src_mask)

    # The output sequence starts with the start symbol; shape (1, 1), sequence-first
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)

    # Generate up to max_len - 1 further tokens
    for _ in range(max_len - 1):

        # Decode the encoded representation of the input, masking future positions
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0), DEVICE).type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)

        # Reshape to batch-first so we can index the last generated position
        out = out.transpose(0, 1)

        # Project to logits over the vocabulary and take the most likely next token
        prob = model.ff(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        # Append the predicted token to the output sequence; stop once we emit the end symbol
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == end_symbol:
            break

    return ys

# Open an interactive prompt where the user can translate arbitrary sentences
def inference(opts):

    # Get the tokenizer, vocab and text-transform objects, as well as any special
    # symbols we added to our dataset (the training data itself is not needed here)
    _, _, src_vocab, tgt_vocab, src_transform, _, special_symbols = get_data(opts)

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    # Create model
    model = Translator(
        num_encoder_layers=opts.enc_layers,
        num_decoder_layers=opts.dec_layers,
        embed_size=opts.embed_size,
        num_heads=opts.attn_heads,
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        dim_feedforward=opts.dim_feedforward,
        dropout=opts.dropout
    ).to(DEVICE)

    # Load in the trained weights (map_location handles checkpoints saved on another device)
    model.load_state_dict(torch.load(opts.model_path, map_location=DEVICE))

    # Set to inference mode
    model.eval()

    # Accept input and keep translating until the user quits (Ctrl-C)
    while True:
        sentence = input("> ")

        # Convert to a column of token ids, shape (num_tokens, 1)
        src = src_transform(sentence).view(-1, 1)
        num_tokens = src.shape[0]

        # No masking is needed over the source sequence
        src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)

        # Decode
        tgt_tokens = greedy_decode(
            model, src, src_mask, max_len=num_tokens + 5,
            start_symbol=special_symbols["<bos>"], end_symbol=special_symbols["<eos>"]
        ).flatten()

        # Convert to a list of token ids
        output_as_list = list(tgt_tokens.cpu().numpy())

        # Convert token ids back to words
        output_list_words = tgt_vocab.lookup_tokens(output_as_list)

        # Remove special tokens and convert to a string
        translation = " ".join(output_list_words).replace("<bos>", "").replace("<eos>", "").strip()

        print(translation)

# Train the model for one epoch
def train(model, train_dl, loss_fn, optim, special_symbols, opts):
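    """Run one epoch of teacher-forced training and return the average loss.

    The decoder input is the target shifted right by one token (tgt[:-1]) and
    the training target is the target shifted left by one (tgt[1:]).
    """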

    # Accumulate losses (and count batches) for reporting
    losses = 0
    num_batches = 0

    # Put the model into training mode (enables dropout etc.)
    model.train()
    for src, tgt in tqdm(train_dl, ascii=True):

        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        # The decoder input is the target shifted right by one token
        tgt_input = tgt[:-1, :]

        # Create masks
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, special_symbols["<pad>"], DEVICE)

        # Pass into model, get logits over the vocab out
        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        # Reset gradients before computing the gradients of the loss
        optim.zero_grad()

        # The training target is the target shifted left by one token
        tgt_out = tgt[1:, :]

        # Compute the loss and its gradients
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        # Step the weights
        optim.step()

        # Accumulate a running loss for reporting
        losses += loss.item()
        num_batches += 1

        if opts.dry_run:
            break

    # Return the average loss over the batches actually seen
    return losses / num_batches

# Measure the model's loss on the validation dataset
def validate(model, valid_dl, loss_fn, special_symbols):
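    """Compute the average teacher-forced loss over the validation set."""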

    # Accumulate losses (and count batches) for reporting
    losses = 0
    num_batches = 0

    # Put the model into evaluation mode (disables dropout etc.)
    model.eval()

    # No gradients are needed for validation
    with torch.no_grad():
        for src, tgt in tqdm(valid_dl):

            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)

            # The decoder input is the target shifted right by one token
            tgt_input = tgt[:-1, :]

            # Create masks
            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, special_symbols["<pad>"], DEVICE)

            # Pass into model, get logits over the vocab out
            logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

            # Shift the target left by one token, compute the loss, accumulate it
            tgt_out = tgt[1:, :]
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            losses += loss.item()
            num_batches += 1

    # Return the average loss
    return losses / num_batches

# Set everything up and train the model
def main(opts):

    # Set up logging
    os.makedirs(opts.logging_dir, exist_ok=True)
    logger = logging.getLogger(__name__)
    logging.basicConfig(filename=os.path.join(opts.logging_dir, "log.txt"), level=logging.INFO)

    # Also print log messages to the screen
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger().addHandler(console)

    logging.info(f"Translation task: {opts.src} -> {opts.tgt}")
    logging.info(f"Using device: {DEVICE}")

    # Get training data, tokenizer and vocab
    # objects as well as any special symbols we added to our dataset
    train_dl, valid_dl, src_vocab, tgt_vocab, _, _, special_symbols = get_data(opts)

    logging.info("Loaded data")

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    logging.info(f"{opts.src} vocab size: {src_vocab_size}")
    logging.info(f"{opts.tgt} vocab size: {tgt_vocab_size}")

    # Create model
    model = Translator(
        num_encoder_layers=opts.enc_layers,
        num_decoder_layers=opts.dec_layers,
        embed_size=opts.embed_size,
        num_heads=opts.attn_heads,
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        dim_feedforward=opts.dim_feedforward,
        dropout=opts.dropout
    ).to(DEVICE)

    logging.info("Model created... starting training!")

    # Set up our learning tools
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=special_symbols["<pad>"])

    # These optimizer hyperparameters are from the "Attention Is All You Need" paper
    optim = torch.optim.Adam(model.parameters(), lr=opts.lr, betas=(0.9, 0.98), eps=1e-9)

    best_val_loss = float("inf")

    for epoch in range(1, opts.epochs + 1):

        start_time = time()
        train_loss = train(model, train_dl, loss_fn, optim, special_symbols, opts)
        epoch_time = time() - start_time
        val_loss = validate(model, valid_dl, loss_fn, special_symbols)

        # Save a checkpoint whenever the validation loss improves
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            logging.info("New best model, saving...")
            torch.save(model.state_dict(), os.path.join(opts.logging_dir, "best.pt"))

        torch.save(model.state_dict(), os.path.join(opts.logging_dir, "last.pt"))

        logger.info(
            f"Epoch: {epoch}\n"
            f"\tTrain loss: {train_loss:.3f}\n"
            f"\tVal loss: {val_loss:.3f}\n"
            f"\tEpoch time = {epoch_time:.1f} seconds\n"
            f"\tETA = {epoch_time * (opts.epochs - epoch):.1f} seconds"
        )

if __name__ == "__main__":

    parser = ArgumentParser(
        prog="Machine Translator training and inference",
    )

    # Inference mode
    parser.add_argument("--inference", action="store_true",
                        help="Set true to run inference")
    parser.add_argument("--model_path", type=str,
                        help="Path to the model to run inference on")

    # Translation settings
    parser.add_argument("--src", type=str, default="de",
                        help="Source language (translating FROM this language)")
    parser.add_argument("--tgt", type=str, default="en",
                        help="Target language (translating TO this language)")

    # Training settings
    parser.add_argument("-e", "--epochs", type=int, default=30,
                        help="Number of epochs")
    parser.add_argument("--lr", type=float, default=1e-4,
                        help="Learning rate")
    parser.add_argument("--batch", type=int, default=128,
                        help="Batch size")
    parser.add_argument("--backend", type=str, default="cpu",
                        help="Compute backend: 'gpu' or 'cpu'")

    # Transformer settings
    parser.add_argument("--attn_heads", type=int, default=8,
                        help="Number of attention heads")
    parser.add_argument("--enc_layers", type=int, default=5,
                        help="Number of encoder layers")
    parser.add_argument("--dec_layers", type=int, default=5,
                        help="Number of decoder layers")
    parser.add_argument("--embed_size", type=int, default=512,
                        help="Size of the language embedding")
    parser.add_argument("--dim_feedforward", type=int, default=512,
                        help="Feedforward dimensionality")
    parser.add_argument("--dropout", type=float, default=0.1,
                        help="Transformer dropout")

    # Logging settings
    parser.add_argument("--logging_dir", type=str, default="./" + str(date.today()) + "/",
                        help="Where the output of this program should be placed")

    # Just for continuous integration
    parser.add_argument("--dry_run", action="store_true")

    args = parser.parse_args()

    # Honor the requested backend, falling back to CPU if CUDA is unavailable
    DEVICE = torch.device("cuda" if args.backend == "gpu" and torch.cuda.is_available() else "cpu")

    if args.inference:
        inference(args)
    else:
        main(args)
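
# Example invocations (the source does not name this script, so substitute the
# actual file name from your repo):
#
#   python train.py --backend gpu -e 30
#   python train.py --inference --model_path ./<training-date>/best.pt
#
# At the "> " prompt, type a sentence in the source language to translate it.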