Skip to content

Commit

Permalink
Pad dynamically when preparing the dataset for batch size autotuning (#…
Browse files Browse the repository at this point in the history
…962)

* Pad dynamically when preparing the dataset for batch size autotuning

* Fix structure
  • Loading branch information
guillaumekln authored Jul 29, 2022
1 parent 8f4bd61 commit 576ec37
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
20 changes: 19 additions & 1 deletion opennmt/inputters/inputter.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,14 +817,32 @@ def make_training_dataset(
padded_shapes = self.get_padded_shapes(
dataset.element_spec, maximum_length=maximum_length
)

# Dynamically pad each sequence to the maximum length.
def _pad_to_shape(tensor, padded_shape):
    """Pad the trailing side of each dimension of ``tensor`` up to ``padded_shape``.

    Args:
      tensor: The tensor to pad. A rank-0 tensor is returned untouched.
      padded_shape: An iterable of target sizes, one per dimension of
        ``tensor``. A ``None`` entry leaves that dimension unpadded.

    Returns:
      ``tensor`` itself when it is a scalar, otherwise the result of
      ``tf.pad`` with the computed paddings.
    """
    # Scalars have no dimension to pad.
    if tensor.shape.rank == 0:
        return tensor

    # NOTE(review): misc.shape_list presumably yields Python ints for
    # statically known dimensions and tensors for dynamic ones -- confirm
    # against its definition. Only tensor-valued dimensions are padded here.
    paddings = []
    for current_dim, target_dim in zip(misc.shape_list(tensor), padded_shape):
        if not tf.is_tensor(current_dim) or target_dim is None:
            # No target size, or the dimension is not a tensor: leave as is.
            amount = [0, 0]
        else:
            # Append just enough elements to reach the target size.
            amount = [0, target_dim - current_dim]
        paddings.append(amount)

    return tf.pad(tensor, paddings)

dataset = dataset.map(
lambda *arg: tf.nest.map_structure(
_pad_to_shape, misc.item_or_tuple(arg), padded_shapes
)
)
dataset = dataset.apply(
dataset_util.batch_sequence_dataset(
batch_size,
batch_type=batch_type,
batch_multiplier=batch_multiplier,
length_bucket_width=1,
length_fn=constant_length_fn,
padded_shapes=padded_shapes,
)
)
return dataset
Expand Down
3 changes: 3 additions & 0 deletions opennmt/tests/inputter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,9 @@ def testBatchAutotuneDataset(self):
batch_autotune_mode=True,
)

source_spec, target_spec = dataset.element_spec
self.assertListEqual(source_spec["ids"].shape.as_list(), [None, None])

source, target = next(iter(dataset))
self.assertListEqual(source["ids"].shape.as_list(), [8, 100])
self.assertListEqual(target["ids"].shape.as_list(), [8, 120])
Expand Down

0 comments on commit 576ec37

Please sign in to comment.