Skip to content

Commit

Permalink
Final funnel transformer (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
shadowatyyy authored Oct 27, 2020
1 parent 2325781 commit 09bdb52
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 10 deletions.
2 changes: 2 additions & 0 deletions trax/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,5 @@ def model_configure(*args, **kwargs):
LSTMSeq2SeqAttn = model_configure(rnn.LSTMSeq2SeqAttn)
FunnelTransformerEncoder = model_configure(
funnel_transformer.FunnelTransformerEncoder)
FunnelTransformer = model_configure(
funnel_transformer.FunnelTransformer)
100 changes: 92 additions & 8 deletions trax/models/research/funnel_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ def PoolLayer(pool_layer=tl.AvgPool,
) if separate_cls else pool_layer(pool_size, strides)


def _upsample(short, masks, long):
factor = -(-long.shape[1] // short.shape[1]) # ceil division
new_vecs = long + short.repeat(factor, axis=1)[:, :long.shape[1], :]
new_masks = masks.repeat(factor, axis=-1)[:, :, :, :long.shape[1]]
return new_vecs, new_masks


def _Upsampler():
return tl.Fn('Upsampler', _upsample, n_out=2)


def _FunnelBlock(d_model, d_ff, n_heads,
dropout, dropout_shared_axes, mode, ff_activation,
pool_layer, pool_size, strides, separate_cls):
Expand Down Expand Up @@ -83,19 +94,20 @@ def _FunnelBlock(d_model, d_ff, n_heads,
d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)
pooling = PoolLayer(pool_layer, pool_size, strides, separate_cls)

return tl.Serial( # h, mask
tl.Branch(pooling, None, None), # h', h, h, mask
tl.Dup(), # h', h', h, h, mask
return tl.Serial( # h, mask
tl.Branch(pooling, None, None), # h', h, h, mask
tl.Dup(), # h', h', h, h, mask
tl.Parallel(
None,
attention
), # h', attention(...), mask
tl.Add(), # h'+attention(...), mask
tl.LayerNorm(), # funnel_activations, mask
), # h', attention(...), mask
tl.Add(), # h'+attention(...), mask
tl.LayerNorm(), # funnel_activations, mask
tl.Parallel(
None,
tl.Fn('max pool experiment', _InternalMaxPool),
), # funnel_activations, mask'
tl.Fn('max pool experiment',
_InternalMaxPool),
), # funnel_activations, mask'
feed_forward
)

Expand Down Expand Up @@ -187,3 +199,75 @@ def _FunnelResidualBlock(d_model, d_ff, n_heads,
feed_forward
)
]


def FunnelTransformer(vocab_size,
d_model=512, # start
d_ff=2048,
encoder_segment_lengths=(2, 2, 2),
n_decoder_blocks=2,
n_heads=8,
max_len=2048,
dropout=0.1,
dropout_shared_axes=None,
mode='train',
ff_activation=tl.Relu,
pool_layer=tl.AvgPool,
pool_size=(2,),
strides=(2,),
separate_cls=True):
"""Returns a Full Funnel Transformer.
"""
segments = len(encoder_segment_lengths)
funnels = segments - 1
assert (funnels >= 0)

positional_encoder = [
tl.Embedding(vocab_size, d_model),
tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
tl.PositionalEncoding(max_len=max_len)]

n_encoder_segments = len(encoder_segment_lengths)

encoder_blocks_before_first_pooling = [
_EncoderBlock(d_model, d_ff, n_heads, dropout,
dropout_shared_axes, mode, ff_activation)
for _ in range(encoder_segment_lengths[0])]
encoder_blocks_from_first_pooling = []

for i in range(1, n_encoder_segments):
# Building i'th segment

# add funnel block between segments
encoder_blocks_from_first_pooling.append(
_FunnelBlock(d_model, d_ff, n_heads, dropout,
dropout_shared_axes, mode,
ff_activation, pool_layer, pool_size,
strides, separate_cls))

for _ in range(encoder_segment_lengths[i]):
# segment_size encoder blocks
encoder_blocks_from_first_pooling.append(
_EncoderBlock(d_model, d_ff, n_heads, dropout,
dropout_shared_axes, mode, ff_activation))

decoder_blocks = [_EncoderBlock(d_model, d_ff, n_heads, dropout,
dropout_shared_axes, mode, ff_activation)
for _ in range(n_decoder_blocks)]

# Assemble and return the model.
return tl.Serial( # toks
tl.Branch(
positional_encoder, tl.PaddingMask()), # vecs masks
encoder_blocks_before_first_pooling, # vecs masks
tl.Select([0, 1, 0]), # vecs masks residual = vecs
encoder_blocks_from_first_pooling, # vecs masks residual
tl.Parallel(
# residual from first segment is taken before
# normalization, so apply it now
None, None, tl.LayerNorm()), # vecs masks norm(residual)
_Upsampler(), # vecs masks
decoder_blocks,
tl.Select([0], n_in=2), # vecs
tl.LayerNorm(),
)
13 changes: 11 additions & 2 deletions trax/models/research/funnel_transformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from trax import layers as tl, shapes
from trax.models.research.funnel_transformer import PoolLayer, \
_FunnelResidualBlock, \
FunnelTransformerEncoder
FunnelTransformerEncoder, \
FunnelTransformer


class FunnelTransformerTest(parameterized.TestCase):
Expand Down Expand Up @@ -59,7 +60,7 @@ def test_funnel_block_forward_shape(self):

self.assertEqual(y.shape, (1, n_even // 2, d_model))

def test_funnel_transformer_forward_shape(self):
def test_funnel_transformer_encoder_forward_shape(self):
n_classes = 2
model = FunnelTransformerEncoder(10, n_classes)

Expand All @@ -69,6 +70,14 @@ def test_funnel_transformer_forward_shape(self):

self.assertEqual(y.shape, (3, n_classes))

def test_funnel_transformer_forward_shape(self):
model = FunnelTransformer(10)

x = np.ones((3, 64), dtype=np.int32)
_ = model.init(shapes.signature(x))
y = model(x)

self.assertEqual(y.shape, (3, 64, 512))

if __name__ == '__main__':
absltest.main()

0 comments on commit 09bdb52

Please sign in to comment.