Skip to content

Commit

Permalink
add state_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
FrostML committed Oct 28, 2022
1 parent 8377703 commit a924b77
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 32 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import numpy as np
Expand Down Expand Up @@ -178,6 +192,11 @@ def do_predict(args):
transformer.load(init_from_params=os.path.join(args.init_from_params,
"transformer.pdparams"))

# Providing model_dict still works.
# state_dict = paddle.load(os.path.join(args.init_from_params,
# "transformer.pdparams"))
# transformer.load(state_dict=state_dict)

f = open(args.output_file, "w")
with paddle.no_grad():
if args.profile:
Expand Down
5 changes: 5 additions & 0 deletions examples/machine_translation/transformer/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ def do_predict(args):
transformer.load(os.path.join(args.init_from_params,
"transformer.pdparams"))

# Providing model_dict still works.
# state_dict = paddle.load(os.path.join(args.init_from_params,
# "transformer.pdparams"))
# transformer.load(state_dict=state_dict)

# Set evaluate mode
transformer.eval()

Expand Down
71 changes: 39 additions & 32 deletions paddlenlp/ops/faster_transformer/transformer/faster_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,45 +249,46 @@ def forward(self, src_word, trg_word=None):

return ids

def load(self, init_from_params):
def load(self, init_from_params=None, state_dict=None):
# Load the trained model
assert init_from_params, (
"Please set init_from_params to load the infer model.")
if init_from_params is None and state_dict is None:
raise ValueError(
"Either init_from_params or state_dict must be given to load the infer model. "
)

model_dict = paddle.load(init_from_params, return_numpy=True)
if state_dict is None:
state_dict = paddle.load(init_from_params, return_numpy=True)
else:
for state in state_dict:
# NOTE: This API only used in dygraph, so paddle.Tensor is enough.
if isinstance(state_dict[state], paddle.Tensor):
state_dict[state] = state_dict[state].numpy()

# To set weight[padding_idx] to 0.
model_dict["trg_word_embedding.word_embedding.weight"][
state_dict["trg_word_embedding.word_embedding.weight"][
self.bos_id] = [0] * self.d_model

# Dealing with weight sharing.
if self.weight_sharing:
model_dict["decoding_linear.weight"] = np.transpose(
model_dict["trg_word_embedding.word_embedding.weight"])
state_dict["decoding_linear.weight"] = np.transpose(
state_dict["trg_word_embedding.word_embedding.weight"])
else:
model_dict["decoding_linear.weight"] = model_dict["linear.weight"]

# To avoid a longer length than training, reset the size of position
# encoding to max_length
model_dict["encoder.pos_encoder.weight"] = position_encoding_init(
self.max_length, self.d_model)
model_dict["decoder.pos_encoder.weight"] = position_encoding_init(
self.max_length, self.d_model)
state_dict["decoding_linear.weight"] = state_dict["linear.weight"]

if self.decoding._fuse_qkv:
for item in self.state_dict():
if "decoder" in item and "self_attn.q_proj" in item:
num_layer = item.split(".")[3]
param_type = item.split(".")[-1]

model_dict["decoding.slf_q_" + param_type + "_" +
state_dict["decoding.slf_q_" + param_type + "_" +
num_layer] = np.concatenate(
(model_dict[item],
model_dict["transformer.decoder.layers." +
(state_dict[item],
state_dict["transformer.decoder.layers." +
num_layer +
".self_attn.k_proj." +
param_type],
model_dict["transformer.decoder.layers." +
state_dict["transformer.decoder.layers." +
num_layer +
".self_attn.v_proj." +
param_type]),
Expand All @@ -296,17 +297,17 @@ def load(self, init_from_params):
if self.use_fp16_decoding:
for item in self.state_dict():
if "decoder" in item or "decoding.slf" in item:
model_dict[item] = np.float16(model_dict[item])
model_dict["decoding_linear.weight"] = np.float16(
model_dict["decoding_linear.weight"])
model_dict["trg_word_embedding.word_embedding.weight"] = np.float16(
model_dict["trg_word_embedding.word_embedding.weight"])
model_dict["trg_pos_embedding.pos_encoder.weight"] = np.float16(
model_dict["trg_pos_embedding.pos_encoder.weight"])
model_dict["decoding_linear.bias"] = np.zeros([self.trg_vocab_size],
state_dict[item] = np.float16(state_dict[item])
state_dict["decoding_linear.weight"] = np.float16(
state_dict["decoding_linear.weight"])
state_dict["trg_word_embedding.word_embedding.weight"] = np.float16(
state_dict["trg_word_embedding.word_embedding.weight"])
state_dict["trg_pos_embedding.pos_encoder.weight"] = np.float16(
state_dict["trg_pos_embedding.pos_encoder.weight"])
state_dict["decoding_linear.bias"] = np.zeros([self.trg_vocab_size],
dtype="float16")

self.load_dict(model_dict)
self.load_dict(state_dict)

if self.enable_faster_encoder:
self = enable_faster_encoder(self, use_fp16=self.use_fp16_encoder)
Expand Down Expand Up @@ -695,12 +696,18 @@ def forward(self, src_word, trg_word=None):
out = paddle.transpose(out, [1, 0, 2])
return out

def load(self, path):
def load(self, path=None, state_dict=None):
if path is None and state_dict is None:
raise ValueError(
"Either path or state_dict must be given to load the infer model. "
)

if isinstance(self.transformer, FasterTransformer):
self.transformer.load(path)
self.transformer.load(path, state_dict)
else:
model_dict = paddle.load(path)
self.transformer.load_dict(model_dict)
if state_dict is None:
state_dict = paddle.load(path)
self.transformer.load_dict(state_dict)


class FasterOPT(OPTPretrainedModel):
Expand Down

0 comments on commit a924b77

Please sign in to comment.