Funnel-Transformer #1156

Merged: 9 commits, Nov 20, 2020
Changes from 7 commits
2 changes: 1 addition & 1 deletion trax/layers/combinators.py
@@ -151,7 +151,7 @@ class Parallel(base.Layer):
Then Parallel(F, G, H) will take 6 inputs and give 4 outputs:

- inputs: a, b, c, d, e, f
- outputs: F(a), G(b, c, d), h1, h2
- outputs: F(a), G(b, c, d), h1, h2 where h1, h2 = H(e, f)

As an important special case, a None argument to Parallel acts as if it takes
one argument, which it leaves unchanged. (It acts as a one-arg no-op.) For
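As a quick illustration of the behaviour described in the updated docstring line (the toy Fn layers below are hypothetical and not part of this PR):

import numpy as np
from trax import layers as tl
from trax import shapes

# Toy layers: F takes 1 input, G takes 3 inputs, H takes 2 inputs and
# returns 2 outputs (h1, h2).
F = tl.Fn('F', lambda x: x + 1.0)
G = tl.Fn('G', lambda x, y, z: x + y + z)
H = tl.Fn('H', lambda x, y: (x + y, x - y), n_out=2)

layer = tl.Parallel(F, G, H)  # takes 6 inputs, gives 4 outputs
inputs = tuple(np.full((2, 3), float(i)) for i in range(6))  # a, b, c, d, e, f
layer.init(shapes.signature(inputs))
outs = layer(inputs)  # (F(a), G(b, c, d), h1, h2) where h1, h2 = H(e, f)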
5 changes: 5 additions & 0 deletions trax/models/__init__.py
@@ -29,6 +29,7 @@
from trax.models.research import layerdrop_transformer
from trax.models.research import rezero
from trax.models.research import transformer2
from trax.models.research import funnel_transformer


# Ginify
@@ -86,3 +87,7 @@ def model_configure(*args, **kwargs):
RNNLM = model_configure(rnn.RNNLM)
GRULM = model_configure(rnn.GRULM)
LSTMSeq2SeqAttn = model_configure(rnn.LSTMSeq2SeqAttn)
FunnelTransformerEncoder = model_configure(
    funnel_transformer.FunnelTransformerEncoder)
FunnelTransformer = model_configure(
    funnel_transformer.FunnelTransformer)
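A quick smoke check that the newly registered factories are reachable through trax.models (the hyperparameter values below are illustrative only):

from trax import models

# Construct a tiny encoder through the gin-configurable wrapper added above.
model = models.FunnelTransformerEncoder(
    vocab_size=32, n_classes=2, d_model=16, d_ff=32,
    encoder_segment_lengths=(1, 1), n_heads=2, max_len=64)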
284 changes: 284 additions & 0 deletions trax/models/research/funnel_transformer.py
@@ -0,0 +1,284 @@
# coding=utf-8
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Funnel Transformer model.

Funnel-Transformer: Filtering out Sequential Redundancy for Efficient
Language Processing https://arxiv.org/abs/2006.03236 """
from trax import layers as tl
from trax.layers.assert_shape import assert_shape
from trax.models.transformer import _EncoderBlock, _FeedForwardBlock


@assert_shape('bld->bSd')
def PoolLayer(pool_layer=tl.AvgPool,
              pool_size=(2,),
              strides=(2,),
              separate_cls=True):
  if separate_cls:
    cls_selection = tl.Fn('select_cls_token', lambda x: x[:, :1, :])
    tokens_after_cls = tl.Fn('rest_tokens', lambda x: x[:, 1:, :])

    return tl.Serial(
        tl.Branch(
            cls_selection,
            tl.Serial(
                tokens_after_cls,
                pool_layer(pool_size, strides)
            )
        ),
        tl.Concatenate(axis=1)
    )
  else:
    return pool_layer(pool_size, strides)
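A small shape check of the pooling behaviour with and without a separate [CLS] token (sizes are illustrative; the layer is exercised the way Trax layers usually are in tests):

import numpy as np
from trax import layers as tl
from trax import shapes

x = np.random.uniform(size=(2, 9, 16)).astype(np.float32)  # (batch, length, d_model)

keep_cls = PoolLayer(tl.AvgPool, pool_size=(2,), strides=(2,), separate_cls=True)
keep_cls.init(shapes.signature(x))
print(keep_cls(x).shape)  # (2, 5, 16): [CLS] kept as-is, the other 8 tokens pooled to 4

plain = PoolLayer(tl.AvgPool, pool_size=(2,), strides=(2,), separate_cls=False)
plain.init(shapes.signature(x))
print(plain(x).shape)  # (2, 4, 16): all 9 tokens pooled; the last is dropped by 'VALID' padding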


@assert_shape('b11l->b11S')
def MaskPool(pool_size=(2,), strides=(2,), separate_cls=True):
  return tl.Serial(
      tl.Fn('reshape', lambda x: x.swapaxes(1, -1).squeeze(axis=-1)),
      PoolLayer(tl.MaxPool, pool_size, strides, separate_cls),
      tl.Fn('reshape_back', lambda x: x[..., None].swapaxes(1, -1))
  )
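The mask follows the activations through the same pooling; a minimal illustration with a 0/1 float mask (the models below feed it the output of tl.PaddingMask; the simple hand-built mask here is only for illustration):

import numpy as np
from trax import shapes

mask = np.array([[[[1., 1., 1., 1., 1., 1., 0., 0.]]]], dtype=np.float32)  # (1, 1, 1, 8)
mask_pool = MaskPool(pool_size=(2,), strides=(2,), separate_cls=False)
mask_pool.init(shapes.signature(mask))
print(mask_pool(mask))  # [[[[1. 1. 1. 0.]]]]  (pairwise max over the length axis)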


@assert_shape('bld->bd')
def SelectFirst():
  return tl.Fn('select_first', lambda x: x[:, 0, :])


def _Upsample(short, long):
  factor = -(-long.shape[1] // short.shape[1])  # ceil division
Collaborator:

I don't think this works as upsampling; short counterexample below. I think it would be reasonable just to assert that the input length is always divisible by pool_size wherever we use pooling; the input length in training/evaluation is usually a power of two anyway.

Counterexample: input_size=31, pool_size=2, single decoding block. AvgPool with the default padding ("valid", not "same") will produce a tensor of length 15. After this downsampling block we will have upsampling with factor = 3, which is out of sync with pool_size: tokens will be mapped to wrong positions by "short.repeat".
While this counterexample would be fixed by changing the padding to "same", I have a feeling that we should just assert that the input length at every stage is divisible by pool_size. What do you think?

Contributor:

Asserting that the input length is divisible by pool_size would be fine for the FunnelTransformer with upsampling, but we use the same pooler for the version without it. We could modify the pooler to consider whether its output will be upsampled later, but I'm not sure it's worth it.

Collaborator:

What could be done fairly easily is to insert into the _Upsample function a check like "if long.shape[1] % short.shape[1] != 0: raise ValueError('message')". Then you don't need to modify anything but the _Upsample function; downsampling isn't touched at all, and we can be sure that upsampling works correctly.
The current issue I have with this _Upsample function is that it computes wrong results (see my counterexample), and I think that throwing an exception is much better than silently returning wrong results.

Adding this check is also much easier than correcting the implementation, which would involve passing a pool_size/stride to _Upsample to be used in place of 'factor', plus some padding (see the counterexample in my previous comment). I think that correcting the implementation may not be worth the effort, but adding an assert is.

(Also, while it isn't necessary, I would consider adding this kind of assert even during downsampling. Consider the case where we have only downsampling, with pool_size=2, and an input length not divisible by 2: the current downsampling simply throws away the last token (due to the default "valid" padding), which is rather strange behaviour.)

  new_vecs = long + short.repeat(factor, axis=1)[:, :long.shape[1], :]
  return new_vecs
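Following the review thread above, a sketch of the divisibility check the reviewer suggests; the function name and error message here are illustrative, not the PR's final code:

def _UpsampleWithCheck(short, long):
  # Same contract as _Upsample above, but fail loudly when the two lengths are
  # out of sync (e.g. 31 pooled to 15) instead of silently misaligning tokens.
  if long.shape[1] % short.shape[1] != 0:
    raise ValueError(f'Long length {long.shape[1]} must be divisible by short '
                     f'length {short.shape[1]} for upsampling to stay aligned '
                     f'with pooling.')
  factor = long.shape[1] // short.shape[1]
  return long + short.repeat(factor, axis=1)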


def _Upsampler():
  return tl.Fn('Upsampler', _Upsample)


def _FunnelBlock(d_model, d_ff, n_heads,
                 dropout, dropout_shared_axes, mode, ff_activation,
                 pool_layer, pool_size, strides, separate_cls):
  """Internal funnel block. On input it takes (activations, masks).

  Args:
Collaborator:

Can you include arguments "pool_layer", "separate_cls" in the description?

Contributor:

done

    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
      Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
      a useful way to save memory and apply consistent masks to activation
      vectors at different sequence positions.
    mode: If `'train'`, each block will include dropout; else, it will
      pass all values through unaltered.
    ff_activation: Type of activation function at the end of each block; must
      be an activation-type subclass of `Layer`.
    pool_size: Shape of window that gets reduced to a single vector value.
      If the layer inputs are :math:`n`-dimensional arrays, then `pool_size`
      must be a tuple of length :math:`n-2`.
    strides: Offsets from the location of one window to the locations of
      neighboring windows along each axis. If specified, must be a tuple of
      the same length as `pool_size`. If None, then offsets of 1 along each
      window axis, :math:`(1, ..., 1)`, will be used.

  Returns:
    A Serial layer that maps (activations, mask) to (activations', mask').
  """
  attention = tl.AttentionQKV(
      d_feature=d_model, n_heads=n_heads, dropout=dropout, mode=mode)
  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)
  pooling = PoolLayer(pool_layer, pool_size, strides, separate_cls)
  mask_pooling = MaskPool(pool_size, strides, separate_cls)

  return tl.Serial(                    # h, mask
      tl.Branch(pooling, None, None),  # h', h, h, mask
      tl.Dup(),                        # h', h', h, h, mask
      tl.Parallel(
          None,
          attention
      ),                               # h', attention(...), mask
      tl.Add(),                        # h'+attention(...), mask
      tl.LayerNorm(),                  # funnel_activations, mask
      tl.Parallel(
          None,
          mask_pooling
      ),                               # funnel_activations, mask'
      feed_forward
Collaborator:

Why don't we have a residual connection here? The paper doesn't mention anything like that, and their PyTorch code seems to include it. If this is intended, can you write a brief comment about it?

Contributor:

The paper authors are indeed not entirely clear on this, but they mention:

To inherit the high capacity and optimization advantages of the Transformer architecture, the proposed model keeps the same overall skeleton of interleaved S-Attn and P-FFN sub-modules wrapped by residual connection and layer normalization.

We decided to merge _FunnelBlock and _FunnelResidualBlock into one, with two residuals and two layer norms, similarly to the original PyTorch code (the only difference is that they have an LN at the end of the P-FFN, while the Trax P-FFN seems to have one at the beginning, so we make up for it with an LN at the very beginning of the funnel block).

  )


def FunnelTransformerEncoder(vocab_size,
Collaborator:

Can you add documentation to these models, along with argument descriptions like in _FunnelBlock? This also applies to _FunnelResidualBlock and FunnelTransformer.

Contributor:

The docstrings have been added.

                             n_classes=10,
                             d_model=512,
                             d_ff=2048,
                             encoder_segment_lengths=(2, 2, 2),
                             n_heads=8,
                             max_len=2048,
                             dropout=0.1,
                             dropout_shared_axes=None,
                             mode='train',
                             ff_activation=tl.Relu,
                             pool_layer=tl.AvgPool,
                             pool_size=(2,),
                             strides=(2,),
                             separate_cls=True):
"""Returns a Funnel Encoder.
"""
assert encoder_segment_lengths

positional_encoder = [
tl.Embedding(vocab_size, d_model),
tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
tl.PositionalEncoding(max_len=max_len)]

encoder_blocks = []
n_encoder_segments = len(encoder_segment_lengths)

for i in range(n_encoder_segments):
# Building i'th segment
for _ in range(encoder_segment_lengths[i]):
# Create segment_size encoder blocks
encoder_blocks.append(
_EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
mode, ff_activation))

# If not last segment, add funnel block
if i != n_encoder_segments - 1:
encoder_blocks.append(_FunnelBlock(d_model, d_ff, n_heads, dropout,
dropout_shared_axes, mode,
ff_activation, pool_layer, pool_size,
strides, separate_cls))

cls_pooling = SelectFirst() if separate_cls else tl.Mean(axis=1)

# Assemble and return the model.
return tl.Serial( # toks
# Encode.
tl.Branch(
positional_encoder, tl.PaddingMask()), # vecs masks
encoder_blocks, # vecs masks
tl.Select([0], n_in=2), # vecs
tl.LayerNorm(), # vecs

# Map to output categories.
cls_pooling, # cls
tl.Dense(n_classes), # cls
tl.LogSoftmax(), # cls
)
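A minimal forward-shape check of the encoder, in the style of a typical Trax model test (all sizes are small illustrative values):

import numpy as np
from trax import shapes

model = FunnelTransformerEncoder(vocab_size=100, n_classes=2, d_model=16,
                                 d_ff=32, encoder_segment_lengths=(1, 1),
                                 n_heads=2, max_len=64)
tokens = np.ones((3, 16), dtype=np.int32)  # (batch, sequence length), no padding
model.init(shapes.signature(tokens))
logits = model(tokens)
print(logits.shape)  # (3, 2): one log-probability vector per example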


def _FunnelResidualBlock(d_model, d_ff, n_heads,
                         dropout, dropout_shared_axes, mode, ff_activation,
                         pool_layer, pool_size, strides):
  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)

  dropout_ = tl.Dropout(
Collaborator:

Why do the variables have a "_" suffix here?

Contributor:

Fixed. The dropout_ layer variable name was taken from the original TransformerEncoder to avoid shadowing the dropout rate argument; it has been replaced by the more meaningful hidden_dropout.

      rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

  attn_ = tl.AttentionQKV(d_model, n_heads=n_heads, dropout=dropout, mode=mode)

  pooling_ = PoolLayer(pool_layer, pool_size, strides)

  return [
      tl.Parallel(tl.Branch(pooling_, None), None),
Collaborator:

It's not very important, but can this be replaced by Select + pooling? I think it will be clearer.

Contributor:

This Parallel actually looks like a no-op, so I removed it, but I would prefer not to apply pooling inside the residual with attention (Select is used there to split into Q, K, V).

      tl.Residual(
          tl.Parallel(tl.LayerNorm(), tl.LayerNorm()),
          tl.Select([0, 1, 1, 2]),
          attn_,
          tl.Parallel(None, MaskPool()),
          dropout_
      ),
      tl.Residual(
          feed_forward
      )
  ]


def FunnelTransformer(vocab_size,
Collaborator:

Can we rename it to FunnelTransformerDecoder, to keep the naming consistent with models/transformer.py? It seems closer to TransformerDecoder than Transformer, since the former outputs an embedding per token (like this Funnel class) and the latter predicts a class per token.

Contributor:

As per our previous discussion, we changed FunnelTransformer to output a token-level categorical distribution over the vocab instead of embeddings, which makes it useful, for example, as a BERT-style model.

                      d_model=512,
                      d_ff=2048,
                      encoder_segment_lengths=(2, 2, 2),
                      n_decoder_blocks=2,
                      n_heads=8,
                      max_len=2048,
                      dropout=0.1,
                      dropout_shared_axes=None,
                      mode='train',
                      ff_activation=tl.Relu,
                      pool_layer=tl.AvgPool,
                      pool_size=(2,),
                      separate_cls=True):
"""Returns a Full Funnel Transformer.
"""
assert encoder_segment_lengths

positional_encoder = [
tl.Embedding(vocab_size, d_model),
tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
tl.PositionalEncoding(max_len=max_len)]

n_encoder_segments = len(encoder_segment_lengths)

encoder_blocks_before_first_pooling = [
_EncoderBlock(d_model, d_ff, n_heads, dropout,
dropout_shared_axes, mode, ff_activation)
for _ in range(encoder_segment_lengths[0])]
encoder_blocks_from_first_pooling = []

for i in range(1, n_encoder_segments):
# Building i'th segment

# Add funnel block between segments
encoder_blocks_from_first_pooling.append(
_FunnelBlock(d_model, d_ff, n_heads, dropout,
dropout_shared_axes, mode,
ff_activation, pool_layer,
pool_size=pool_size, strides=pool_size,
separate_cls=separate_cls))

for _ in range(encoder_segment_lengths[i]):
# Create segment_size encoder blocks
encoder_blocks_from_first_pooling.append(
_EncoderBlock(d_model, d_ff, n_heads, dropout,
dropout_shared_axes, mode, ff_activation))

decoder_blocks = [_EncoderBlock(d_model, d_ff, n_heads, dropout,
dropout_shared_axes, mode, ff_activation)
for _ in range(n_decoder_blocks)]

# Assemble and return the model.
return tl.Serial( # toks
tl.Branch(
positional_encoder, tl.PaddingMask()), # vecs masks
encoder_blocks_before_first_pooling, # vecs masks
tl.Select([0, 1, 0, 1]),
# vecs masks residual = vecs old_masks
encoder_blocks_from_first_pooling, # vecs masks residual masks
tl.Select([0, 2, 3]), # vecs residual masks
tl.Parallel(
# residual from first segment is taken before
# normalization, so apply it now
None, tl.LayerNorm(), None), # vecs norm(residual) masks
_Upsampler(), # vecs masks
decoder_blocks,
tl.Select([0], n_in=2), # vecs
tl.LayerNorm(),
)
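And a similar shape check for the full model as it stands in this commit: it still returns one d_model-sized vector per input token here (later commits, per the discussion above, switch the output to a distribution over the vocabulary). Sizes are illustrative:

import numpy as np
from trax import shapes

model = FunnelTransformer(vocab_size=100, d_model=16, d_ff=32,
                          encoder_segment_lengths=(1, 1), n_decoder_blocks=1,
                          n_heads=2, max_len=64)
tokens = np.ones((3, 16), dtype=np.int32)
model.init(shapes.signature(tokens))
vecs = model(tokens)
print(vecs.shape)  # (3, 16, 16): upsampled back to the full sequence length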