Skip to content

Commit 6613c5f

Browse files
committed
fix data api
1 parent a741cee commit 6613c5f

File tree

1 file changed: +13 −3 lines changed

GDWCT.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from utils import *
33
from glob import glob
44
import time
5-
from tensorflow.contrib.data import batch_and_drop_remainder
5+
from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
66

77
class GDWCT(object) :
88
def __init__(self, sess, args):
@@ -270,8 +270,18 @@ def build_model(self):
270270
trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
271271
trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
272272

273-
trainA = trainA.prefetch(self.batch_size).shuffle(self.dataset_num).map(Image_Data_Class.image_processing, num_parallel_calls=8).apply(batch_and_drop_remainder(self.batch_size)).repeat()
274-
trainB = trainB.prefetch(self.batch_size).shuffle(self.dataset_num).map(Image_Data_Class.image_processing, num_parallel_calls=8).apply(batch_and_drop_remainder(self.batch_size)).repeat()
273+
gpu_device = '/gpu:0'
274+
275+
trainA = trainA.\
276+
apply(shuffle_and_repeat(self.dataset_num)). \
277+
apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)). \
278+
apply(prefetch_to_device(gpu_device, None))
279+
280+
trainB = trainB. \
281+
apply(shuffle_and_repeat(self.dataset_num)). \
282+
apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)). \
283+
apply(prefetch_to_device(gpu_device, None))
284+
# When using dataset.prefetch, use buffer_size=None to let it detect optimal buffer size
275285

276286
trainA_iterator = trainA.make_one_shot_iterator()
277287
trainB_iterator = trainB.make_one_shot_iterator()

0 commit comments

Comments (0)