 from utils import *
 from glob import glob
 import time
-from tensorflow.contrib.data import batch_and_drop_remainder
+from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch

 class GDWCT(object) :
     def __init__(self, sess, args):
@@ -270,8 +270,18 @@ def build_model(self):
         trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
         trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

-        trainA = trainA.prefetch(self.batch_size).shuffle(self.dataset_num).map(Image_Data_Class.image_processing, num_parallel_calls=8).apply(batch_and_drop_remainder(self.batch_size)).repeat()
-        trainB = trainB.prefetch(self.batch_size).shuffle(self.dataset_num).map(Image_Data_Class.image_processing, num_parallel_calls=8).apply(batch_and_drop_remainder(self.batch_size)).repeat()
+        gpu_device = '/gpu:0'
+
+        trainA = trainA. \
+            apply(shuffle_and_repeat(self.dataset_num)). \
+            apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)). \
+            apply(prefetch_to_device(gpu_device, None))
+
+        trainB = trainB. \
+            apply(shuffle_and_repeat(self.dataset_num)). \
+            apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)). \
+            apply(prefetch_to_device(gpu_device, None))
+        # buffer_size=None lets prefetch_to_device choose the prefetch buffer size automatically

         trainA_iterator = trainA.make_one_shot_iterator()
         trainB_iterator = trainB.make_one_shot_iterator()
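
For reference (not part of the commit): the tf.contrib.data helpers imported above were later absorbed into core tf.data, so the same pipeline can be written with plain Dataset methods. The sketch below assumes TensorFlow 1.10+ (where Dataset.batch accepts drop_remainder); the function name and arguments are illustrative, and it keeps prefetched batches on the host rather than staging them on the GPU as prefetch_to_device does.

import tensorflow as tf

def build_input_pipeline(file_list, image_processing_fn, batch_size, dataset_num):
    # Illustrative sketch of an equivalent pipeline using only core tf.data ops.
    ds = tf.data.Dataset.from_tensor_slices(file_list)
    ds = ds.shuffle(dataset_num).repeat()                    # same effect as shuffle_and_repeat
    ds = ds.map(image_processing_fn, num_parallel_calls=16)  # map half of map_and_batch
    ds = ds.batch(batch_size, drop_remainder=True)           # batch half, dropping partial batches
    ds = ds.prefetch(1)                                      # host-side prefetch; prefetch_to_device stages batches on the GPU instead
    return ds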