Skip to content

Commit c00d28f

Browse files
authored
Add MBart unit tests (PaddlePaddle#3323)
1 parent 90c6f1d commit c00d28f

15 files changed

+1388
-193
lines changed

faster_generation/samples/mbart_sample.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import paddle
15-
from paddlenlp.transformers import MBartForConditionalGeneration, MBartTokenizer
15+
from paddlenlp.transformers import MBartForConditionalGeneration, MBart50Tokenizer
1616

1717
model_name = "mbart-large-50-many-to-many-mmt"
1818

19-
tokenizer = MBartTokenizer.from_pretrained(model_name, src_lang="en_XX")
19+
tokenizer = MBart50Tokenizer.from_pretrained(model_name, src_lang="en_XX")
2020
model = MBartForConditionalGeneration.from_pretrained(model_name)
2121
model.eval()
2222

paddlenlp/transformers/auto/tokenizer.py

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
("LayoutLMTokenizer", "layoutlm"),
5555
("LukeTokenizer", "luke"),
5656
("MBartTokenizer", "mbart"),
57+
("MBart50Tokenizer", "mbart"),
5758
("MegatronBertTokenizer", "megatronbert"),
5859
("MobileBertTokenizer", "mobilebert"),
5960
("MPNetTokenizer", "mpnet"),

paddlenlp/transformers/mbart/modeling.py

+6
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,12 @@ def get_encoder(self):
501501
def get_decoder(self):
502502
return self.decoder
503503

504+
def get_input_embeddings(self):
505+
return self.shared
506+
507+
def set_input_embeddings(self, value):
508+
self.shared = value
509+
504510
def forward(self,
505511
input_ids,
506512
attention_mask=None,

0 commit comments

Comments
 (0)