Commit 7df10c2

Language translation example added (#1131) (#1240)
* Initial model setup
* Training works
* Added inference
* Code clean up and commenting
* Update to README.md
* Add requirements.txt
* Updated top level README, added example to CI
* Potentially fixed testing (maybe not enough memory?)
* Update requirements.txt

Co-authored-by: Mark Saroufim <marksaroufim@meta.com>
1 parent 2d725b6 commit 7df10c2

7 files changed (+608, -1 lines changed)

README.md (+2, -1)

```diff
@@ -29,7 +29,8 @@ https://pytorch.org/examples/
 - [PyTorch Module Transformations using fx](./fx/README.md)
 - Distributed PyTorch examples with [Distributed Data Parallel](./distributed/ddp/README.md) and [RPC](./distributed/rpc)
 - [Several examples illustrating the C++ Frontend](cpp)
-- [Image Classification Using Forward-Forward ](./mnist_forward_forward/README.md)
+- [Image Classification Using Forward-Forward](./mnist_forward_forward/README.md)
+- [Language Translation using Transformers](./language_translation/README.md)
```

language_translation/README.md (+49, new file)

# Language Translation

This example shows how one might use transformers for language translation. In particular, the implementation is loosely based on the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).

## Requirements

We will need a tokenizer for each of our languages. Torchtext includes a tokenizer for English, but unfortunately we need more languages than that. We can get these tokenizers via `spacy`:

```bash
python3 -m spacy download <language>
python3 -m spacy download en
python3 -m spacy download de
```

Spacy supports many languages; for a full accounting of the available models, see [the spaCy models page](https://spacy.io/usage/models). This example defaults to translating from German to English.

Torchtext is also required:

```bash
pip install torchtext
```

In short, these commands will get you started:

```bash
pip install -r requirements.txt
python3 -m spacy download <language-you-want>
```
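
The data pipeline in `src/data.py` (part of this commit, but not shown in this excerpt) presumably wraps these spaCy models through torchtext's `get_tokenizer`. A minimal sketch of that wiring; the pipeline name `de_core_news_sm` is an assumption and depends on which model your spaCy version actually installs:

```python
from torchtext.data.utils import get_tokenizer

# Wrap an installed spaCy pipeline as a torchtext tokenizer.
# "de_core_news_sm" is assumed here; substitute whatever model
# `python3 -m spacy download de` installed on your setup.
de_tokenizer = get_tokenizer("spacy", language="de_core_news_sm")

print(de_tokenizer("Ein Mann spielt Gitarre."))
# e.g. ['Ein', 'Mann', 'spielt', 'Gitarre', '.']
```
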
## Usage

This example exposes many flags that change the behavior and training of the model. You can see all of them by running:

```bash
python3 main.py -h
```

All of the settings have sensible defaults; by default, the model translates from German to English. To *train* the model, the first command below is all you need; the second shows how to choose any language pair you want:

```bash
python3 main.py
python3 main.py --src en --tgt fr # For English-to-French translation
```

Training writes `best.pt` and `last.pt` checkpoints (along with `log.txt`) to `--logging_dir`, which defaults to a directory named after today's date. For model inference, point `--model_path` at one of those checkpoints:

```bash
python3 main.py --inference --model_path <path-to-model>
```

After some loading time, this will open an interactive interface where you can type in whatever sentence you are interested in translating.
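
For illustration only (the `> ` prompt comes from `inference()` in `main.py`; tokens are joined with spaces, and the actual translation depends entirely on your trained checkpoint), a session might look like:

```
> Ein Mann spielt Gitarre .
 A man plays guitar .
```
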

language_translation/main.py (+306, new file)

```python
from time import time # Track how long an epoch takes
import os # Creating and finding files/directories
import logging # Logging tools
from datetime import date # Logging the date for model versioning

import torch # For ML
from tqdm import tqdm # For fancy progress bars

from src.model import Translator # Our model
from src.data import get_data, create_mask, generate_square_subsequent_mask # Loading data and data preprocessing
from argparse import ArgumentParser # For args

# Train on the GPU if possible
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol):

    # Move to device
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    # Encode input
    memory = model.encode(src, src_mask)

    # Output will be stored here
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)

    # For each element in our translation (which could range up to the maximum translation length)
    for _ in range(max_len-1):

        # Decode the encoded representation of the input
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0), DEVICE).type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)

        # Reshape
        out = out.transpose(0, 1)

        # Convert to probabilities and take the max of these probabilities
        prob = model.ff(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        # Now we have an output which is the vector representation of the translation
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == end_symbol:
            break

    return ys

# Opens a user interface where users can translate an arbitrary sentence
def inference(opts):

    # Get training data, tokenizer and vocab
    # objects as well as any special symbols we added to our dataset
    _, _, src_vocab, tgt_vocab, src_transform, _, special_symbols = get_data(opts)

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    # Create model
    model = Translator(
        num_encoder_layers=opts.enc_layers,
        num_decoder_layers=opts.dec_layers,
        embed_size=opts.embed_size,
        num_heads=opts.attn_heads,
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        dim_feedforward=opts.dim_feedforward,
        dropout=opts.dropout
    ).to(DEVICE)

    # Load in weights
    model.load_state_dict(torch.load(opts.model_path))

    # Set to inference mode
    model.eval()

    # Accept input and keep translating until they quit
    while True:
        print("> ", end="")

        sentence = input()

        # Convert to tokens
        src = src_transform(sentence).view(-1, 1)
        num_tokens = src.shape[0]

        src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)

        # Decode
        tgt_tokens = greedy_decode(
            model, src, src_mask, max_len=num_tokens+5, start_symbol=special_symbols["<bos>"], end_symbol=special_symbols["<eos>"]
        ).flatten()

        # Convert to list of tokens
        output_as_list = list(tgt_tokens.cpu().numpy())

        # Convert tokens to words
        output_list_words = tgt_vocab.lookup_tokens(output_as_list)

        # Remove special tokens and convert to string
        translation = " ".join(output_list_words).replace("<bos>", "").replace("<eos>", "")

        print(translation)

# Train the model for 1 epoch
def train(model, train_dl, loss_fn, optim, special_symbols, opts):

    # Object for accumulating losses
    losses = 0

    # Put model into training mode
    model.train()
    for src, tgt in tqdm(train_dl, ascii=True):

        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        # We need to reshape the input slightly to fit into the transformer
        tgt_input = tgt[:-1, :]

        # Create masks
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, special_symbols["<pad>"], DEVICE)

        # Pass into model, get probability over the vocab out
        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        # Reset gradients before we try to compute the gradients over the loss
        optim.zero_grad()

        # Get original shape back
        tgt_out = tgt[1:, :]

        # Compute loss and gradient over that loss
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        # Step weights
        optim.step()

        # Accumulate a running loss for reporting
        losses += loss.item()

        if opts.dry_run:
            break

    # Return the average loss
    return losses / len(list(train_dl))

# Check the model accuracy on the validation dataset
def validate(model, valid_dl, loss_fn, special_symbols):

    # Object for accumulating losses
    losses = 0

    # Put the model in evaluation mode (disables dropout)
    model.eval()

    for src, tgt in tqdm(valid_dl):

        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        # We need to reshape the input slightly to fit into the transformer
        tgt_input = tgt[:-1, :]

        # Create masks
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, special_symbols["<pad>"], DEVICE)

        # Pass into model, get probability over the vocab out
        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        # Get original shape back, compute loss, accumulate that loss
        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()

    # Return the average loss
    return losses / len(list(valid_dl))

# Train the model
def main(opts):

    # Set up logging
    os.makedirs(opts.logging_dir, exist_ok=True)
    logger = logging.getLogger(__name__)
    logging.basicConfig(filename=opts.logging_dir + "log.txt", level=logging.INFO)

    # This prints it to the screen as well
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger().addHandler(console)

    logging.info(f"Translation task: {opts.src} -> {opts.tgt}")
    logging.info(f"Using device: {DEVICE}")

    # Get training data, tokenizer and vocab
    # objects as well as any special symbols we added to our dataset
    train_dl, valid_dl, src_vocab, tgt_vocab, _, _, special_symbols = get_data(opts)

    logging.info("Loaded data")

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    logging.info(f"{opts.src} vocab size: {src_vocab_size}")
    logging.info(f"{opts.tgt} vocab size: {tgt_vocab_size}")

    # Create model
    model = Translator(
        num_encoder_layers=opts.enc_layers,
        num_decoder_layers=opts.dec_layers,
        embed_size=opts.embed_size,
        num_heads=opts.attn_heads,
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        dim_feedforward=opts.dim_feedforward,
        dropout=opts.dropout
    ).to(DEVICE)

    logging.info("Model created... starting training!")

    # Set up our learning tools
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=special_symbols["<pad>"])

    # These special values are from the "Attention is all you need" paper
    optim = torch.optim.Adam(model.parameters(), lr=opts.lr, betas=(0.9, 0.98), eps=1e-9)

    best_val_loss = 1e6

    for idx, epoch in enumerate(range(1, opts.epochs+1)):

        start_time = time()
        train_loss = train(model, train_dl, loss_fn, optim, special_symbols, opts)
        epoch_time = time() - start_time
        val_loss = validate(model, valid_dl, loss_fn, special_symbols)

        # Save the best-performing model as we go, plus the most recent weights
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            logging.info("New best model, saving...")
            torch.save(model.state_dict(), opts.logging_dir + "best.pt")

        torch.save(model.state_dict(), opts.logging_dir + "last.pt")

        logger.info(f"Epoch: {epoch}\n\tTrain loss: {train_loss:.3f}\n\tVal loss: {val_loss:.3f}\n\tEpoch time = {epoch_time:.1f} seconds\n\tETA = {epoch_time*(opts.epochs-idx-1):.1f} seconds")

if __name__ == "__main__":

    parser = ArgumentParser(
        prog="Machine Translator training and inference",
    )

    # Inference mode
    parser.add_argument("--inference", action="store_true",
                        help="Set true to run inference")
    parser.add_argument("--model_path", type=str,
                        help="Path to the model to run inference on")

    # Translation settings
    parser.add_argument("--src", type=str, default="de",
                        help="Source language (translating FROM this language)")
    parser.add_argument("--tgt", type=str, default="en",
                        help="Target language (translating TO this language)")

    # Training settings
    parser.add_argument("-e", "--epochs", type=int, default=30,
                        help="Epochs")
    parser.add_argument("--lr", type=float, default=1e-4,
                        help="Default learning rate")
    parser.add_argument("--batch", type=int, default=128,
                        help="Batch size")
    parser.add_argument("--backend", type=str, default="cpu",
                        help="Backend to run on: 'cpu' or 'gpu'")

    # Transformer settings
    parser.add_argument("--attn_heads", type=int, default=8,
                        help="Number of attention heads")
    parser.add_argument("--enc_layers", type=int, default=5,
                        help="Number of encoder layers")
    parser.add_argument("--dec_layers", type=int, default=5,
                        help="Number of decoder layers")
    parser.add_argument("--embed_size", type=int, default=512,
                        help="Size of the language embedding")
    parser.add_argument("--dim_feedforward", type=int, default=512,
                        help="Feedforward dimensionality")
    parser.add_argument("--dropout", type=float, default=0.1,
                        help="Transformer dropout")

    # Logging settings
    parser.add_argument("--logging_dir", type=str, default="./" + str(date.today()) + "/",
                        help="Where the output of this program should be placed")

    # Just for continuous integration
    parser.add_argument("--dry_run", action="store_true")

    args = parser.parse_args()

    DEVICE = torch.device("cuda" if args.backend == "gpu" and torch.cuda.is_available() else "cpu")

    if args.inference:
        inference(args)
    else:
        main(args)
```
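
`main.py` relies on two helpers imported from `src/data.py` (`create_mask` and `generate_square_subsequent_mask`) and on the `Translator` model from `src/model.py`; those files are among the 7 changed but are not shown in this excerpt. As a reference point, here is a minimal sketch of what the causal-mask helper typically looks like in PyTorch translation examples. This is an assumption about the commit's actual implementation, not a copy of it:

```python
import torch

def generate_square_subsequent_mask(size, device):
    # Lower-triangular boolean matrix: True where attention is allowed
    mask = (torch.triu(torch.ones((size, size), device=device)) == 1).transpose(0, 1)
    # Additive float mask: 0.0 on visible positions, -inf on future positions,
    # so each target token can only attend to itself and earlier tokens
    mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)
    return mask
```

Note that `greedy_decode` above converts this mask with `.type(torch.bool)`, which maps `0.0` to `False` (attend) and `-inf` to `True` (masked), matching the convention `torch.nn.Transformer` expects for boolean masks.
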

language_translation/requirements.txt (+5, new file)

```
torch
torchtext
torchdata
spacy
portalocker
```
