-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path eval.py
48 lines (35 loc) · 1.42 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from network import *
from data_utils import *
import hyperparams as hp
import librosa
class Graph:
    """Inference graph wrapping ``network`` for evaluation.

    Every op is created inside a dedicated ``tf.Graph`` stored on
    ``self.graph`` rather than in TensorFlow's default graph.

    Attributes:
        graph: the ``tf.Graph`` that owns all ops built here.
        x: float32 input placeholder named 'X' with shape
            ``[None, hp.timestep, 1]`` (batch, time, channel).
        prediction: tensor derived from the network output; how it is
            derived depends on ``hp.use_mulaw`` (see ``__init__``).
    """

    def __init__(self):
        mulaw = hp.use_mulaw
        self.graph = tf.Graph()
        with self.graph.as_default():
            input_shape = [None, hp.timestep, 1]
            self.x = tf.placeholder(tf.float32, input_shape, name='X')
            net_out = network(self.x, use_mulaw=mulaw)
            if not mulaw:
                # Drop the trailing axis of the raw network output.
                # NOTE(review): assumes that axis has size 1 -- confirm in network().
                self.prediction = tf.squeeze(net_out, -1)
            else:
                # Take the arg-max along axis 2 (presumably the mu-law class
                # axis -- confirm in network()) and hand the indices to
                # mu_law_decode, which presumably maps them back to amplitudes.
                self.prediction = mu_law_decode(tf.argmax(net_out, axis=2))
def main():
    """Run the trained model over the test mixture and write the result.

    Loads ``'./data/' + hp.test_data`` resampled to ``hp.sample_rate``,
    splits it into consecutive ``hp.timestep``-sample segments (zero-padding
    the last one), feeds each segment through ``Graph.prediction`` using the
    newest checkpoint in ``hp.save_dir``, and writes the concatenated output,
    trimmed back to the input length, to ``./data/result.wav``.

    Raises:
        ValueError: if the test audio has no samples, or if ``hp.save_dir``
            contains no checkpoint.
    """
    mixture = librosa.load('./data/' + hp.test_data, sr=hp.sample_rate)[0]
    num_samples = len(mixture)
    if num_samples == 0:
        # Without this guard np.vstack([]) below fails with an opaque error.
        raise ValueError('Test audio %r contains no samples' % (hp.test_data,))

    timestep = hp.timestep
    # Zero-pad up to a whole number of segments (ceil division) so the final
    # partial segment is processed instead of silently dropped; the padding is
    # trimmed off the output again below.
    num_segments = -(-num_samples // timestep)
    padded = np.zeros(num_segments * timestep, dtype=mixture.dtype)
    padded[:num_samples] = mixture
    # Shape (num_segments, timestep, 1) to match the placeholder's trailing dims.
    segments = np.expand_dims(padded.reshape([-1, timestep]), -1)
    print(num_segments)

    g = Graph()
    with g.graph.as_default():
        with tf.Session() as sess:
            saver = tf.train.Saver()
            checkpoint = tf.train.latest_checkpoint(hp.save_dir)
            if checkpoint is None:
                # Fail with a clear message instead of passing None to restore().
                raise ValueError('No checkpoint found in %r' % (hp.save_dir,))
            saver.restore(sess, checkpoint)
            print("restore successfully!")
            outputs = []
            for segment in segments:
                # Add a batch axis of size 1: (timestep, 1) -> (1, timestep, 1).
                batch = np.expand_dims(segment, axis=0)
                outputs.append(sess.run(g.prediction, feed_dict={g.x: batch}))

    # Flatten all per-segment predictions into one signal and drop the padding.
    result = np.vstack(outputs).reshape(-1)[:num_samples]
    # NOTE: librosa.output.write_wav was removed in librosa 0.8; this requires
    # an older librosa (or a switch to e.g. soundfile.write).
    librosa.output.write_wav("./data/result.wav", result, sr=hp.sample_rate)


if __name__ == '__main__':
    # Run evaluation only when executed as a script, not on import.
    main()