
Commit bc70665: Initial commit
1 parent 3ab028c

30 files changed: 1787 additions, 0 deletions

.gitignore

+5
@@ -0,0 +1,5 @@
__pycache__/
.cache/
*.pyc
.DS_Store
run*.sh

LICENSE

+19
@@ -0,0 +1,19 @@
Copyright (c) 2017 Keith Ito

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

README.md

+110
@@ -1,2 +1,112 @@
# tacotron

An implementation of Google's Tacotron speech synthesis model in TensorFlow.


## Overview

Earlier this year, Google published a paper, [Tacotron: A Fully End-to-End Text-To-Speech Synthesis Model](https://arxiv.org/pdf/1703.10135.pdf),
in which they present a neural text-to-speech model that learns to synthesize speech directly from
(text, audio) pairs.

Google [released](https://google.github.io/tacotron) some nice audio samples that their model
generated, but didn't provide their source code or training data. This is an attempt to
implement the model described in their paper.

Output after training for 185K steps (~2 days):

  * [Audio Samples](https://keithito.github.io/audio-samples/)

The quality isn't as good as what Google demoed, but hopefully it will get there someday :-).


## Quick Start

### Installing dependencies
```
pip install -r requirements.txt
```


### Using a pre-trained model

1. Download and unpack a model:
   ```
   curl http://data.keithito.com/data/speech/tacotron-20170708.tar.bz2 | tar xjC /tmp
   ```

2. Run the demo server:
   ```
   python3 demo_server.py --checkpoint /tmp/tacotron-20170708/model.ckpt
   ```

3. Point your browser at [localhost:9000](http://localhost:9000) and type!

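You can also query the server from a script instead of the browser. The snippet below is a
rough sketch that assumes `demo_server.py` (not shown in this diff) serves synthesized audio
from a `/synthesize` endpoint with a `text` query parameter; check the server code for the
actual route before relying on it.

```
# Hypothetical client for the demo server; the /synthesize route and 'text'
# parameter are assumptions, not confirmed by this diff.
import urllib.parse
import urllib.request

params = urllib.parse.urlencode({'text': 'Scientists at the CERN laboratory say they have discovered a new particle.'})
with urllib.request.urlopen('http://localhost:9000/synthesize?%s' % params) as response:
  with open('output.wav', 'wb') as f:
    f.write(response.read())  # save the returned WAV audio
```
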
### Training

1. Download a speech dataset. The following are supported out of the box:
   * [LJ Speech](https://keithito.com/LJ-Speech-Dataset) (Public Domain)
   * [Blizzard 2012](http://www.cstr.ed.ac.uk/projects/blizzard/2012/phase_one) (Creative Commons Attribution Share-Alike)

   You can use other datasets if you convert them to the right format. See
   [ljspeech.py](datasets/ljspeech.py) for an example; a minimal sketch of the expected
   module interface appears at the end of this section.

2. Unpack the dataset into `~/tacotron`. After unpacking, your tree should look like this for
   LJ Speech:
   ```
   tacotron
     |- LJSpeech-1.0
         |- metadata.csv
         |- wavs
   ```

   or like this for Blizzard 2012:
   ```
   tacotron
     |- Blizzard2012
         |- ATrampAbroad
         |   |- sentence_index.txt
         |   |- lab
         |   |- wav
         |- TheManThatCorruptedHadleyburg
             |- sentence_index.txt
             |- lab
             |- wav
   ```

3. Preprocess the data:
   ```
   python3 preprocess.py --dataset ljspeech
   ```
   *Use `--dataset blizzard` for Blizzard data.*

4. Train:
   ```
   python3 train.py
   ```
   *Note: using [TCMalloc](http://goog-perftools.sourceforge.net/doc/tcmalloc.html) seems to
   improve training performance.*

5. Monitor with Tensorboard (optional):
   ```
   tensorboard --logdir ~/tacotron/logs-tacotron
   ```

   The trainer dumps audio and alignments every 1000 steps. You can find these in
   `~/tacotron/logs-tacotron`. You can also pass a Slack webhook URL as the `--slack_url`
   flag, and it will send you progress updates.
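
If you want to train on a dataset that isn't supported out of the box (see step 1 above), the
sketch below shows the interface a dataset module is expected to implement. It mirrors
`datasets/ljspeech.py` from this commit; the module name, `transcripts.txt` filename, and
one-line-per-utterance `wav_filename|text` format are hypothetical placeholders for your corpus.

```
# datasets/mydataset.py (hypothetical): follows the build_from_path() interface
# used by datasets/ljspeech.py. Assumes one "wav_filename|text" line per utterance
# in transcripts.txt; adjust the parsing for your corpus.
from concurrent.futures import ProcessPoolExecutor
from functools import partial
import numpy as np
import os
from util import audio


def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x):
  executor = ProcessPoolExecutor(max_workers=num_workers)
  futures = []
  with open(os.path.join(in_dir, 'transcripts.txt')) as f:
    for index, line in enumerate(f, start=1):
      wav_name, text = line.strip().split('|')
      wav_path = os.path.join(in_dir, 'wavs', wav_name)
      futures.append(executor.submit(partial(_process_utterance, out_dir, index, wav_path, text)))
  return [future.result() for future in tqdm(futures)]


def _process_utterance(out_dir, index, wav_path, text):
  # Same recipe as ljspeech.py: save linear and mel spectrograms, then return a
  # metadata tuple of (spectrogram_filename, mel_filename, n_frames, text).
  wav = audio.load_wav(wav_path)
  spectrogram = audio.spectrogram(wav).astype(np.float32)
  mel_spectrogram = audio.melspectrogram(wav).astype(np.float32)
  spec_filename = 'mydataset-spec-%05d.npy' % index
  mel_filename = 'mydataset-mel-%05d.npy' % index
  np.save(os.path.join(out_dir, spec_filename), spectrogram.T, allow_pickle=False)
  np.save(os.path.join(out_dir, mel_filename), mel_spectrogram.T, allow_pickle=False)
  return (spec_filename, mel_filename, spectrogram.shape[1], text)
```

You would also need to register the new module wherever `preprocess.py` selects datasets by
name (the `--dataset` flag shown in step 3).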


## Other Implementations

  * Alex Barron has some nice results from his implementation trained on the
    [Nancy Corpus](http://www.cstr.ed.ac.uk/projects/blizzard/2011/lessac_blizzard2011):
    https://github.com/barronalex/Tacotron

  * Kyubyong Park has a very promising implementation trained on the World English Bible:
    https://github.com/Kyubyong/tacotron

datasets/__init__.py

Whitespace-only changes.

datasets/blizzard.py

+73
@@ -0,0 +1,73 @@
from concurrent.futures import ProcessPoolExecutor
from functools import partial
import numpy as np
import os
from hparams import hparams
from util import audio


_max_out_length = 700
_end_buffer = 0.05
_min_confidence = 90

# Note: "A Tramp Abroad" & "The Man That Corrupted Hadleyburg" are higher quality than the others.
books = [
  'ATrampAbroad',
  'TheManThatCorruptedHadleyburg',
  # 'LifeOnTheMississippi',
  # 'TheAdventuresOfTomSawyer',
]


def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x):
  executor = ProcessPoolExecutor(max_workers=num_workers)
  futures = []
  index = 1
  for book in books:
    with open(os.path.join(in_dir, book, 'sentence_index.txt')) as f:
      for line in f:
        # Skip comment lines; keep rows with all 8 tab-separated fields whose
        # confidence score (field 3) is high enough. Field 5 is the text.
        parts = line.strip().split('\t')
        if line[0] != '#' and len(parts) == 8 and float(parts[3]) > _min_confidence:
          wav_path = os.path.join(in_dir, book, 'wav', '%s.wav' % parts[0])
          labels_path = os.path.join(in_dir, book, 'lab', '%s.lab' % parts[0])
          text = parts[5]
          task = partial(_process_utterance, out_dir, index, wav_path, labels_path, text)
          futures.append(executor.submit(task))
          index += 1
  results = [future.result() for future in tqdm(futures)]
  return [r for r in results if r is not None]


def _process_utterance(out_dir, index, wav_path, labels_path, text):
  # Load the wav file and trim silence from the ends:
  wav = audio.load_wav(wav_path)
  start_offset, end_offset = _parse_labels(labels_path)
  start = int(start_offset * hparams.sample_rate)
  end = int(end_offset * hparams.sample_rate) if end_offset is not None else -1
  wav = wav[start:end]
  max_samples = _max_out_length * hparams.frame_shift_ms / 1000 * hparams.sample_rate
  if len(wav) > max_samples:
    return None
  spectrogram = audio.spectrogram(wav).astype(np.float32)
  n_frames = spectrogram.shape[1]
  mel_spectrogram = audio.melspectrogram(wav).astype(np.float32)
  spectrogram_filename = 'blizzard-spec-%05d.npy' % index
  mel_filename = 'blizzard-mel-%05d.npy' % index
  np.save(os.path.join(out_dir, spectrogram_filename), spectrogram.T, allow_pickle=False)
  np.save(os.path.join(out_dir, mel_filename), mel_spectrogram.T, allow_pickle=False)
  return (spectrogram_filename, mel_filename, n_frames, text)


def _parse_labels(path):
  # Returns (start, end) offsets in seconds, trimming leading and trailing silence ('sil').
  labels = []
  with open(path) as f:
    for line in f:
      parts = line.strip().split(' ')
      if len(parts) >= 3:
        labels.append((float(parts[0]), ' '.join(parts[2:])))
  start = 0
  end = None
  if labels[0][1] == 'sil':
    start = labels[0][0]
  if labels[-1][1] == 'sil':
    end = labels[-2][0] + _end_buffer
  return (start, end)

datasets/datafeeder.py

+150
@@ -0,0 +1,150 @@
import numpy as np
import os
import random
import tensorflow as tf
import threading
import time
import traceback
from util import cmudict, textinput
from util.infolog import log


_batches_per_group = 32
_p_cmudict = 0.5
_pad = 0


class DataFeeder(threading.Thread):
  '''Feeds batches of data into a queue on a background thread.'''

  def __init__(self, coordinator, metadata_filename, hparams):
    super(DataFeeder, self).__init__()
    self._coord = coordinator
    self._hparams = hparams
    self._offset = 0

    # Load metadata:
    self._datadir = os.path.dirname(metadata_filename)
    with open(metadata_filename) as f:
      self._metadata = [line.strip().split('|') for line in f]
      hours = sum((int(x[2]) for x in self._metadata)) * hparams.frame_shift_ms / (3600 * 1000)
      log('Loaded metadata for %d examples (%.2f hours)' % (len(self._metadata), hours))

    # Create placeholders for inputs and targets. Don't specify batch size because we want to
    # be able to feed different sized batches at eval time.
    self._placeholders = [
      tf.placeholder(tf.int32, [None, None], 'inputs'),
      tf.placeholder(tf.int32, [None], 'input_lengths'),
      tf.placeholder(tf.float32, [None, None, hparams.num_mels], 'mel_targets'),
      tf.placeholder(tf.float32, [None, None, hparams.num_freq], 'linear_targets')
    ]

    # Create queue for buffering data:
    queue = tf.FIFOQueue(8, [tf.int32, tf.int32, tf.float32, tf.float32], name='input_queue')
    self._enqueue_op = queue.enqueue(self._placeholders)
    self.inputs, self.input_lengths, self.mel_targets, self.linear_targets = queue.dequeue()
    self.inputs.set_shape(self._placeholders[0].shape)
    self.input_lengths.set_shape(self._placeholders[1].shape)
    self.mel_targets.set_shape(self._placeholders[2].shape)
    self.linear_targets.set_shape(self._placeholders[3].shape)

    # Load CMUDict: If enabled, this will randomly substitute some words in the training data with
    # their ARPABet equivalents, which will allow you to also pass ARPABet to the model for
    # synthesis (useful for proper nouns, etc.)
    if hparams.use_cmudict:
      cmudict_path = os.path.join(self._datadir, 'cmudict-0.7b')
      if not os.path.isfile(cmudict_path):
        raise Exception('If use_cmudict=True, you must download ' +
          'http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b to %s' % cmudict_path)
      self._cmudict = cmudict.CMUDict(cmudict_path, keep_ambiguous=False)
      log('Loaded CMUDict with %d unambiguous entries' % len(self._cmudict))
    else:
      self._cmudict = None


  def start_in_session(self, session):
    self._session = session
    self.start()


  def run(self):
    try:
      while not self._coord.should_stop():
        self._enqueue_next_group()
    except Exception as e:
      traceback.print_exc()
      self._coord.request_stop(e)


  def _enqueue_next_group(self):
    start = time.time()

    # Read a group of examples:
    n = self._hparams.batch_size
    r = self._hparams.outputs_per_step
    examples = [self._get_next_example() for i in range(n * _batches_per_group)]

    # Bucket examples based on similar output sequence length for efficiency:
    examples.sort(key=lambda x: x[-1])
    batches = [examples[i:i+n] for i in range(0, len(examples), n)]
    random.shuffle(batches)

    log('Generated %d batches of size %d in %.03f sec' % (len(batches), n, time.time() - start))
    for batch in batches:
      feed_dict = dict(zip(self._placeholders, _prepare_batch(batch, r)))
      self._session.run(self._enqueue_op, feed_dict=feed_dict)


  def _get_next_example(self):
    '''Loads a single example (input, mel_target, linear_target, cost) from disk'''
    if self._offset >= len(self._metadata):
      self._offset = 0
      random.shuffle(self._metadata)
    meta = self._metadata[self._offset]
    self._offset += 1

    text = meta[3]
    if self._cmudict and random.random() < _p_cmudict:
      text = ' '.join([self._maybe_get_arpabet(word) for word in text.split(' ')])

    input_data = np.asarray(textinput.to_sequence(text), dtype=np.int32)
    linear_target = np.load(os.path.join(self._datadir, meta[0]))
    mel_target = np.load(os.path.join(self._datadir, meta[1]))
    return (input_data, mel_target, linear_target, len(linear_target))


  def _maybe_get_arpabet(self, word):
    pron = self._cmudict.lookup(word)
    return '{%s}' % pron[0] if pron is not None and random.random() < 0.5 else word


def _prepare_batch(batch, outputs_per_step):
  random.shuffle(batch)
  inputs = _prepare_inputs([x[0] for x in batch])
  input_lengths = np.asarray([len(x[0]) for x in batch], dtype=np.int32)
  mel_targets = _prepare_targets([x[1] for x in batch], outputs_per_step)
  linear_targets = _prepare_targets([x[2] for x in batch], outputs_per_step)
  return (inputs, input_lengths, mel_targets, linear_targets)


def _prepare_inputs(inputs):
  max_len = max((len(x) for x in inputs))
  return np.stack([_pad_input(x, max_len) for x in inputs])


def _prepare_targets(targets, alignment):
  # Pad all targets in the batch to the same length, rounded up to a multiple of
  # outputs_per_step (the reduction factor r).
  max_len = max((len(t) for t in targets)) + 1
  return np.stack([_pad_target(t, _round_up(max_len, alignment)) for t in targets])


def _pad_input(x, length):
  return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)


def _pad_target(t, length):
  return np.pad(t, [(0, length - t.shape[0]), (0,0)], mode='constant', constant_values=_pad)


def _round_up(x, multiple):
  remainder = x % multiple
  return x if remainder == 0 else x + multiple - remainder

datasets/ljspeech.py

+31
@@ -0,0 +1,31 @@
from concurrent.futures import ProcessPoolExecutor
from functools import partial
import numpy as np
import os
from util import audio


def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x):
  executor = ProcessPoolExecutor(max_workers=num_workers)
  futures = []
  index = 1
  # metadata.csv columns: <id>|<transcription>|<normalized transcription>
  with open(os.path.join(in_dir, 'metadata.csv')) as f:
    for line in f:
      parts = line.strip().split('|')
      wav_path = os.path.join(in_dir, 'wavs', '%s.wav' % parts[0])
      text = parts[2]  # use the normalized transcription
      futures.append(executor.submit(partial(_process_utterance, out_dir, index, wav_path, text)))
      index += 1
  return [future.result() for future in tqdm(futures)]


def _process_utterance(out_dir, index, wav_path, text):
  wav = audio.load_wav(wav_path)
  spectrogram = audio.spectrogram(wav).astype(np.float32)
  n_frames = spectrogram.shape[1]
  mel_spectrogram = audio.melspectrogram(wav).astype(np.float32)
  spectrogram_filename = 'ljspeech-spec-%05d.npy' % index
  mel_filename = 'ljspeech-mel-%05d.npy' % index
  np.save(os.path.join(out_dir, spectrogram_filename), spectrogram.T, allow_pickle=False)
  np.save(os.path.join(out_dir, mel_filename), mel_spectrogram.T, allow_pickle=False)
  return (spectrogram_filename, mel_filename, n_frames, text)
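
For context on how these dataset modules are consumed: the tuples returned by `build_from_path`
end up, one per line, in the metadata file that `DataFeeder` (datasets/datafeeder.py above)
splits on `|`. The driver below only illustrates that flow; `preprocess.py` is part of this
commit but not shown in this excerpt, so the output filename and call pattern are assumptions.

```
# Hypothetical driver illustrating how build_from_path() output maps to the
# metadata lines that DataFeeder reads; not the actual contents of preprocess.py.
import os
from datasets import ljspeech

in_dir = os.path.expanduser('~/tacotron/LJSpeech-1.0')
out_dir = os.path.expanduser('~/tacotron/training')
os.makedirs(out_dir, exist_ok=True)

# Each entry is (spectrogram_filename, mel_filename, n_frames, text).
metadata = ljspeech.build_from_path(in_dir, out_dir, num_workers=4)

# DataFeeder splits each line on '|' and uses: meta[0]=linear spectrogram file,
# meta[1]=mel spectrogram file, meta[2]=n_frames (for the hours estimate), meta[3]=text.
with open(os.path.join(out_dir, 'train.txt'), 'w') as f:
  for spec_filename, mel_filename, n_frames, text in metadata:
    f.write('%s|%s|%d|%s\n' % (spec_filename, mel_filename, n_frames, text))
```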
