
Commit

Merge remote-tracking branch 'qinziang/fuse' into dev-fuse-qkv
DrownFish19 committed Mar 28, 2024
2 parents 4d49a3e + 05507a7 commit b8b828b
Showing 11 changed files with 440 additions and 33 deletions.
39 changes: 39 additions & 0 deletions paddlenlp/transformers/bert/modeling.py
@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations

# import re
import warnings
from typing import Optional, Tuple

@@ -191,6 +192,14 @@ def _get_name_mappings(cls, config: BertConfig) -> list[StateDictNameMapping]:
f"encoder.layer.{layer_index}.attention.self.value.bias",
f"encoder.layers.{layer_index}.self_attn.v_proj.bias",
],
[
f"encoder.layer.{layer_index}.attention.self.query-key-value.weight",
f"encoder.layers.{layer_index}.self_attn.qkv_proj.weight",
],
[
f"encoder.layer.{layer_index}.attention.self.query-key-value.bias",
f"encoder.layers.{layer_index}.self_attn.qkv_proj.bias",
],
[
f"encoder.layer.{layer_index}.attention.output.dense.weight",
f"encoder.layers.{layer_index}.self_attn.out_proj.weight",
@@ -248,6 +257,36 @@ def _get_name_mappings(cls, config: BertConfig) -> list[StateDictNameMapping]:
mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)]
return mappings

# @classmethod
# def _get_fused_param_mappings(cls):
# # return parameter fuse utils
# from paddlenlp.transformers.conversion_utils import (
# merged_as_tensor_parallel_qkv,
# )

# # attention: q,k,v -> qkv, ffn: gate, up -> gate_up
# mappings = {
# "fuse_action": [merged_as_tensor_parallel_qkv, None],
# "split_action": [None, None],
# "attn_param_names": {
# "qkv_proj": lambda layer_id: re.sub(
# r"\d+", str(layer_id), "bert.encoder.layers.0.self_attn.qkv_proj.weight"
# ),
# "q_proj": lambda layer_id: re.sub(
# r"\d+", str(layer_id), "bert.encoder.layers.0.self_attn.q_proj.weight"
# ),
# "k_proj": lambda layer_id: re.sub(
# r"\d+", str(layer_id), "bert.encoder.layers.0.self_attn.k_proj.weight"
# ),
# "v_proj": lambda layer_id: re.sub(
# r"\d+", str(layer_id), "bert.encoder.layers.0.self_attn.v_proj.weight"
# ),
# },
# "ffn_param_names": {"gate_up_proj": None, "gate_proj": None, "up_proj": None},
# }

# return mappings

def _init_weights(self, layer):
"""Initialization hook"""
if isinstance(layer, (nn.Linear, nn.Embedding)):
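The commented-out mapping block above builds per-layer parameter names from a "layers.0" template via re.sub. A minimal standalone sketch of that naming trick (illustrative only; the template string mirrors the BERT names in the diff):

import re

def make_name_fn(template):
    # Rewrite every digit run in the template with the requested layer id;
    # in these templates the only digit run is the layer index ("layers.0").
    return lambda layer_id: re.sub(r"\d+", str(layer_id), template)

qkv_name = make_name_fn("bert.encoder.layers.0.self_attn.qkv_proj.weight")
print(qkv_name(3))   # bert.encoder.layers.3.self_attn.qkv_proj.weight
print(qkv_name(11))  # bert.encoder.layers.11.self_attn.qkv_proj.weight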
4 changes: 4 additions & 0 deletions paddlenlp/transformers/bloom/configuration.py
@@ -129,6 +129,8 @@ def __init__(
long_sequence_strategy_name=None,
long_sequence_init_args=None,
use_long_sequence_strategies=False,
fuse_attention_qkv: bool = False,
fuse_attention_ffn: bool = False,
**kwargs,
):

@@ -159,3 +161,5 @@ def __init__(
self.long_sequence_strategy_name = long_sequence_strategy_name
self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args
self.use_long_sequence_strategies = use_long_sequence_strategies
self.fuse_attention_qkv = fuse_attention_qkv
self.fuse_attention_ffn = fuse_attention_ffn
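A hedged usage sketch for the two flags added above, assuming BloomConfig keeps accepting keyword overrides as other PaddleNLP configs do; both flags default to False per this diff:

from paddlenlp.transformers import BloomConfig

config = BloomConfig(fuse_attention_qkv=True, fuse_attention_ffn=False)
print(config.fuse_attention_qkv)  # True  -> request a fused qkv_proj parameter layout
print(config.fuse_attention_ffn)  # False -> keep separate gate/up projections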
97 changes: 83 additions & 14 deletions paddlenlp/transformers/conversion_utils.py
@@ -17,6 +17,7 @@
import inspect
import json
import os
import re
from copy import deepcopy
from dataclasses import dataclass
from typing import (
@@ -489,6 +490,24 @@ def splited_qkv_to_tensor_parallel_qkv(weight_list, num_attention_heads):
return naive_merged_qkv_to_tensor_parallel_qkv(weight)


def merged_as_tensor_parallel_qkv(state_dict, q_name, k_name, v_name, num_hidden_layers):
q = state_dict[q_name]
k = state_dict[k_name]
v = state_dict[v_name]

naive_merged_qkv = np.concatenate((q, k, v), axis=-1)

return naive_merged_qkv_to_tensor_parallel_qkv(naive_merged_qkv, num_hidden_layers)


def merge_as_naive_merged_qkv():
pass


def merge_as_splited_qkv():
pass


def get_tensor_parallel_merge_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None):
def fn(
x,
@@ -1082,10 +1101,19 @@ def _get_name_mappings(cls, config: PretrainedConfig) -> List[StateDictNameMapping]:

@classmethod
def get_tensor_parallel_convert_actions(
cls, config: PretrainedConfig, loaded_state_dict_keys, is_split=True, ignore_error=False
cls, config: PretrainedConfig, loaded_state_dict_keys, is_split=True, ignore_error=False, ignore_params=[]
):
name_action_mappings = cls._get_tensor_parallel_mappings(config, is_split=is_split)

# avoid acting on fused parameters (qkv/gate-up); they are not consistent between config and loaded_state_dict_keys
name_map_list = cls._get_name_mappings(config)
for key in ignore_params:
for name_map in name_map_list:
if name_map.target_name == key:
name_action_mappings.pop(name_map.source_name.split("model.")[-1], None)

state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), loaded_state_dict_keys, ignore_error)

for k, v in state_keys_map.items():
name_action_mappings[v] = name_action_mappings.pop(k)
return name_action_mappings
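A toy sketch of the ignore_params filtering above: tensor-parallel actions whose mapped target name is listed in ignore_params are dropped before prefix resolution. The dictionaries and the "split_columns" placeholder are illustrative, not PaddleNLP API:

name_action_mappings = {
    "layers.0.self_attn.q_proj.weight": "split_columns",
    "layers.0.self_attn.k_proj.weight": "split_columns",
    "layers.0.mlp.gate_proj.weight": "split_columns",
}
# (source_name, target_name) pairs, in the spirit of StateDictNameMapping.
name_map_list = [
    ("model.layers.0.self_attn.q_proj.weight", "llama.layers.0.self_attn.q_proj.weight"),
    ("model.layers.0.self_attn.k_proj.weight", "llama.layers.0.self_attn.k_proj.weight"),
]
ignore_params = ["llama.layers.0.self_attn.q_proj.weight"]

for key in ignore_params:
    for source_name, target_name in name_map_list:
        if target_name == key:
            # same normalization as above: strip the "model." prefix from the source name
            name_action_mappings.pop(source_name.split("model.")[-1], None)

print(sorted(name_action_mappings))  # q_proj action removed; k_proj and gate_proj remain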
@@ -1100,26 +1128,67 @@ def convert_tensor_parallel(
weight_file (str | None): the weight file path of `model_state.pdparams` file
config (PretrainedConfig): the PretrainedConfig instance of model
"""
name_action_mappings = cls._get_tensor_parallel_mappings(config)

def _apply_tp_action(name_action_mappings):
state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), state_dict.keys(), ignore_error)

for k, v in state_keys_map.items():
name_action_mappings[v] = name_action_mappings.pop(k)

for name, action in name_action_mappings.items():
if name not in state_dict:
if not ignore_error:
logger.warning(f"Key <{name}> not in the model state weight file.")
continue
tensor = state_dict.pop(name)
new_tensor = action(tensor)
with device_guard("cpu"):
state_dict[name] = paddle.Tensor(new_tensor, zero_copy=True)

if state_dict is None:
with device_guard("cpu"):
state_dict = paddle.load(weight_file, return_numpy=False)
logger.info("Starting to convert orignal state_dict to tensor parallel state_dict.")

state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), state_dict.keys(), ignore_error)
from paddlenlp.transformers.model_utils import select_fuse_parameter

for k, v in state_keys_map.items():
name_action_mappings[v] = name_action_mappings.pop(k)
do_fuse_parameter_list, do_separate_parameter_list = select_fuse_parameter(cls, state_dict.keys(), config)
if "attention_qkv_proj" in do_fuse_parameter_list:
state_dict, fuse_success = cls.fuse_attention_parameters(
state_dict, ["attention_qkv_proj"], config
) # design: q, k, v => qkv

for name, action in name_action_mappings.items():
if name not in state_dict:
if not ignore_error:
logger.warning(f"Key <{name}> not in the model state weight file.")
continue
tensor = state_dict.pop(name)
new_tensor = action(tensor)
with device_guard("cpu"):
state_dict[name] = paddle.Tensor(new_tensor, zero_copy=True)
name_action_mappings = cls._get_tensor_parallel_mappings(config)

# avoid acting on fused parameters (qkv/gate-up); they are not consistent between config and loaded_state_dict_keys
# pop the qkv tp actions and apply the remaining actions
if "attention_qkv_proj" in do_fuse_parameter_list:

Check warning on line 1165 in paddlenlp/transformers/conversion_utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/conversion_utils.py#L1165

Added line #L1165 was not covered by tests

name_map_list = [
lambda layer_id: re.sub(r"\d+", str(layer_id), "layers.0.self_attn.q_proj.weight"),
lambda layer_id: re.sub(r"\d+", str(layer_id), "layers.0.self_attn.k_proj.weight"),
lambda layer_id: re.sub(r"\d+", str(layer_id), "layers.0.self_attn.v_proj.weight"),
lambda layer_id: re.sub(r"\d+", str(layer_id), "layers.0.self_attn.qkv_proj.weight"),
]
tp_action_keys = list(name_action_mappings.keys())
poped_param_names = []
for key in tp_action_keys:
for name_map in name_map_list:
if re.sub(r"\d+", "0", key) == name_map(0):
name_action_mappings.pop(key, None)
poped_param_names.append(key)

_apply_tp_action(name_action_mappings)

# tail processing qkv parameters
if "attention_qkv_proj" in do_fuse_parameter_list:
name_action_mappings_fuse = cls._get_tensor_parallel_mappings(config)
tp_action_fuse_keys = list(name_action_mappings_fuse.keys())
for key in tp_action_fuse_keys:
if key not in poped_param_names:
name_action_mappings_fuse.pop(key, None)

_apply_tp_action(name_action_mappings_fuse)

return state_dict
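A minimal NumPy sketch of the fuse-then-split flow above: q/k/v weights are concatenated into a single qkv parameter before the tensor-parallel actions run. Key names and shapes are illustrative; the real path additionally reorders the merged tensor via naive_merged_qkv_to_tensor_parallel_qkv:

import numpy as np

hidden = 8
state_dict = {
    "layers.0.self_attn.q_proj.weight": np.random.rand(hidden, hidden),
    "layers.0.self_attn.k_proj.weight": np.random.rand(hidden, hidden),
    "layers.0.self_attn.v_proj.weight": np.random.rand(hidden, hidden),
}

# Fuse step: concatenate along the output dimension, as merged_as_tensor_parallel_qkv
# does with np.concatenate((q, k, v), axis=-1).
q = state_dict.pop("layers.0.self_attn.q_proj.weight")
k = state_dict.pop("layers.0.self_attn.k_proj.weight")
v = state_dict.pop("layers.0.self_attn.v_proj.weight")
state_dict["layers.0.self_attn.qkv_proj.weight"] = np.concatenate((q, k, v), axis=-1)

print(state_dict["layers.0.self_attn.qkv_proj.weight"].shape)  # (8, 24)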

4 changes: 4 additions & 0 deletions paddlenlp/transformers/glm/configuration.py
@@ -226,6 +226,8 @@ def __init__(
pool_token="cls",
layernorm_epsilon=1e-5,
use_scaled_init_for_output_weights=False,
fuse_attention_qkv=False,
fuse_attention_ffn=False,
**kwargs
):
super().__init__(**kwargs)
@@ -250,3 +252,5 @@
self.layernorm_epsilon = layernorm_epsilon
self.use_scaled_init_for_output_weights = use_scaled_init_for_output_weights
self._fast_entry = None
self.fuse_attention_qkv = fuse_attention_qkv
self.fuse_attention_ffn = fuse_attention_ffn
31 changes: 31 additions & 0 deletions paddlenlp/transformers/gpt/modeling.py
@@ -18,6 +18,7 @@
import collections
import contextlib
import math
import re
from functools import partial

import numpy as np
@@ -803,6 +804,36 @@ def get_tensor_parallel_split_mappings(num_layers):

return mappings

@classmethod
def _get_fused_param_mappings(cls):
# return parameter fuse utils
from paddlenlp.transformers.conversion_utils import (
merged_as_tensor_parallel_qkv,
)

# attention: q,k,v -> qkv, ffn: gate, up -> gate_up
mappings = {
"fuse_action": [merged_as_tensor_parallel_qkv, None],
"split_action": [None, None],
"attn_param_names": {
"qkv_proj": lambda layer_id: re.sub(
r"\d+", str(layer_id), "gpt.decoder.layers.0.self_attn.qkv_proj.weight"
),
"q_proj": lambda layer_id: re.sub(
r"\d+", str(layer_id), "gpt.decoder.layers.0.self_attn.q_proj.weight"
),
"k_proj": lambda layer_id: re.sub(
r"\d+", str(layer_id), "gpt.decoder.layers.0.self_attn.k_proj.weight"
),
"v_proj": lambda layer_id: re.sub(
r"\d+", str(layer_id), "gpt.decoder.layers.0.self_attn.v_proj.weight"
),
},
"ffn_param_names": {"gate_up_proj": None, "gate_proj": None, "up_proj": None},
}

return mappings

@classmethod
def _get_name_mappings(cls, config: GPTConfig) -> list[StateDictNameMapping]:
mappings: list[StateDictNameMapping] = []
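A hedged sketch of consuming the mapping table above; it assumes the new classmethod lands on GPTPretrainedModel (the class holding the surrounding classmethods) at this commit:

from paddlenlp.transformers.gpt.modeling import GPTPretrainedModel

mappings = GPTPretrainedModel._get_fused_param_mappings()
attn = mappings["attn_param_names"]
for layer_id in range(2):
    # expand the per-layer templates for the separate and fused projections
    print(attn["q_proj"](layer_id), "->", attn["qkv_proj"](layer_id))
# gpt.decoder.layers.0.self_attn.q_proj.weight -> gpt.decoder.layers.0.self_attn.qkv_proj.weight
# gpt.decoder.layers.1.self_attn.q_proj.weight -> gpt.decoder.layers.1.self_attn.qkv_proj.weight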
33 changes: 32 additions & 1 deletion paddlenlp/transformers/llama/modeling.py
@@ -16,6 +16,7 @@
from __future__ import annotations

import math
import re
import warnings
from functools import partial
from typing import Optional, Tuple
@@ -1198,11 +1199,13 @@ def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
[f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"],
[f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"],
[f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"],
[f"layers.{layer_index}.self_attn.qkv_proj.weight", None, "transpose"],
[f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"],
[f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"],
[f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"],
[f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"],
[f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"],
[f"layers.{layer_index}.mlp.gate_up_proj.weight", None, "transpose"],
[f"layers.{layer_index}.input_layernorm.weight"],
[f"layers.{layer_index}.post_attention_layernorm.weight"],
]
@@ -1246,7 +1249,7 @@ def get_tensor_parallel_split_mappings(num_layers):
base_actions.pop("lm_head.weight")
base_actions.pop("embed_tokens.weight")
# Column Linear
if config.fuse_attention_qkv:
if config.fuse_attention_qkv:  # note: this flag prescribes the fused qkv layout for the target model; it does not describe the loaded checkpoint
base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True)
else:
base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
@@ -1275,6 +1278,34 @@ def get_tensor_parallel_split_mappings(num_layers):

return mappings

@classmethod
def _get_fused_param_mappings(cls):
# return parameter fuse utils
from paddlenlp.transformers.conversion_utils import (
merged_as_tensor_parallel_qkv,
)

# attention: q,k,v -> qkv, ffn: gate, up -> gate_up
mappings = {
"fuse_action": [merged_as_tensor_parallel_qkv, None],
"split_action": [None, None],
"attn_param_names": {
"qkv_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.self_attn.qkv_proj.weight"),
"q_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.self_attn.q_proj.weight"),
"k_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.self_attn.k_proj.weight"),
"v_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.self_attn.v_proj.weight"),
},
"ffn_param_names": {
"gate_up_proj": lambda layer_id: re.sub(
r"\d+", str(layer_id), "llama.layers.0.mlp.gate_up_proj.weight"
),
"gate_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.mlp.gate_proj.weight"),
"up_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.mlp.up_proj.weight"),
},
}

return mappings

def _init_weights(self, layer):
"""Initialization hook"""
if self.config.tensor_parallel_degree > 1:
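The llama mapping above also names an FFN counterpart (gate, up -> gate_up), but its fuse_action is still None in this commit. A NumPy sketch of the presumable layout, with the concatenation axis assumed to mirror the qkv case:

import numpy as np

hidden, intermediate = 8, 32
gate = np.random.rand(hidden, intermediate)  # llama.layers.N.mlp.gate_proj.weight
up = np.random.rand(hidden, intermediate)    # llama.layers.N.mlp.up_proj.weight

# Assumed fused layout for llama.layers.N.mlp.gate_up_proj.weight; not wired up
# in this commit (the ffn fuse_action is None above).
gate_up = np.concatenate((gate, up), axis=-1)
print(gate_up.shape)  # (8, 64)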
Diff for the remaining 5 changed files not shown.
