[New Model]add t5-encoder-model (#3168)
* add t5-encoder-model

* update t5model

* update t5encoder & test modeling

* update t5

* update type hinting

* update cache type annotation
wj-Mcat authored Sep 21, 2022
1 parent 68d7946 commit c64ed99
Showing 3 changed files with 111 additions and 6 deletions.
109 changes: 106 additions & 3 deletions paddlenlp/transformers/t5/modeling.py
@@ -12,11 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import math
from typing import Optional, Tuple, Union, List

import numpy as np
import paddle
from paddle import Tensor

import paddle.nn as nn
import paddle.nn.functional as F
@@ -25,9 +28,8 @@
from ..nezha.modeling import ACT2FN

__all__ = [
-    'T5Model',
-    "T5PretrainedModel",
-    'T5ForConditionalGeneration',
+    'T5Model', "T5PretrainedModel", 'T5ForConditionalGeneration',
+    'T5EncoderModel'
]

T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -1730,3 +1732,104 @@ def __getattr__(self, name):
return getattr(self, self.base_model_prefix).config[name]
except KeyError:
raise e


class T5EncoderModel(T5PretrainedModel):
base_model_class = None

def __init__(self,
vocab_size=32128,
d_model=768,
d_kv=64,
d_ff=3072,
num_layers=12,
num_heads=12,
relative_attention_num_buckets=32,
dropout_rate=0.1,
layer_norm_epsilon=1e-06,
feed_forward_proj="relu",
is_decoder: bool = False,
**kwargs):
super().__init__()
self.config = {
"vocab_size": vocab_size,
"d_model": d_model,
"d_kv": d_kv,
"d_ff": d_ff,
"num_layers": num_layers,
"num_heads": num_heads,
"relative_attention_num_buckets": relative_attention_num_buckets,
"dropout_rate": dropout_rate,
"layer_norm_epsilon": layer_norm_epsilon,
"feed_forward_proj": feed_forward_proj,
"is_decoder": is_decoder,
}
self.config.update(kwargs)
self.shared = nn.Embedding(vocab_size, d_model)

self.use_cache = False
self.is_encoder_decoder = False
self.encoder = T5Stack(d_model,
num_layers,
layer_norm_epsilon,
dropout_rate,
relative_attention_num_buckets,
d_kv,
num_heads,
feed_forward_proj,
d_ff,
embed_tokens=self.shared,
is_decoder=is_decoder)

# Initialize weights and apply final processing
self.init_weights()

def _post_init(self, *args, **kwargs):
"""
**Prevent the `config` attribute from being overwritten.**
The base hook runs after `__init__` and would attach a dict of the `__init__`
arguments to the pretrained model instance as an attribute named `config`;
this override skips that step because `__init__` builds `config` itself.
"""
pass

@property
def t5(self):
return self

(Minimized inline review comment from @ZHUI, Collaborator, Jun 15, 2023: "??")

def get_input_embeddings(self):
return self.shared

def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)

def get_encoder(self):
return self.encoder

def forward(
self,
input_ids: Tensor = None,
attention_mask: Optional[Tensor] = None,
encoder_hidden_states: Optional[Tuple[Tensor]] = None,
encoder_attention_mask: Optional[Tensor] = None,
cache=None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
):
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
cache=cache,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

return encoder_outputs


T5EncoderModel.base_model_class = T5EncoderModel
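
For context, here is a minimal usage sketch of the new encoder-only model (not part of this commit). The checkpoint name "t5-small" and the tuple indexing of the output are assumptions, based on the weights already registered for T5Model and on T5Stack's tuple return convention:

import paddle
from paddlenlp.transformers import T5EncoderModel, T5Tokenizer

# Assumed checkpoint name; any T5 checkpoint registered in PaddleNLP
# (see T5_PRETRAINED_MODEL_ARCHIVE_LIST) should work the same way.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small")
model.eval()

# Tokenize a sentence and run only the encoder stack.
encoded = tokenizer("PaddleNLP now exposes a standalone T5 encoder.")
input_ids = paddle.to_tensor([encoded["input_ids"]])

with paddle.no_grad():
    encoder_outputs = model(input_ids=input_ids)

# Assuming T5Stack's tuple return convention, the first element holds the
# final hidden states with shape [batch_size, sequence_length, d_model].
sequence_output = encoder_outputs[0]
print(sequence_output.shape)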
7 changes: 4 additions & 3 deletions tests/transformers/t5/test_modeling.py
@@ -25,7 +25,7 @@
from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor

import paddle
-from paddlenlp.transformers import T5ForConditionalGeneration, T5Model, T5Tokenizer
+from paddlenlp.transformers import T5ForConditionalGeneration, T5Model, T5Tokenizer, T5EncoderModel
from paddlenlp.transformers.t5.modeling import T5_PRETRAINED_MODEL_ARCHIVE_LIST


@@ -500,9 +500,10 @@ def prepare_config_and_inputs_for_common(self):
class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
base_model_class = T5Model

-all_model_classes = (T5Model, T5ForConditionalGeneration)
+all_model_classes = (T5Model, T5ForConditionalGeneration, T5EncoderModel)
all_generative_model_classes = {T5ForConditionalGeneration: (T5Model, "t5")}
-all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration)
+all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration,
+                                    T5EncoderModel)
fx_compatible = True
test_pruning = False
test_resize_embeddings = True
1 change: 1 addition & 0 deletions tests/transformers/test_generation_utils.py
@@ -498,6 +498,7 @@ def test_sample_generate(self):
output_generate[0].tolist())

def test_beam_search_generate(self):
paddle.seed(100)
for model_class in self.all_generative_model_classes.keys():
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(
)
