|
| 1 | +""" |
| 2 | +Export Hugging Face LLaMA models to hdf5 format. |
| 3 | +""" |
| 4 | +import __init__ |
| 5 | +import os |
| 6 | +import h5py |
| 7 | +import numpy as np |
| 8 | +from collections import OrderedDict |
| 9 | +from util import parse_args, check_arguements, ModelArguements, fill_hdf5_layer |
| 10 | +import torch |
| 11 | + |
| 12 | +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" |
| 13 | + |
| 14 | + |
| 15 | +""" |
| 16 | +For the mapping dictionary: key is the value of the proto parameter, |
| 17 | +value is an expression string split by "&&", where each part is either a tensor-name matching path or an expression. |
| 18 | +
|
| 19 | +The sub-patterns of a path are separated by spaces, and an expression starts with expression_. |
| 20 | +You can operate separately on each tensor and support multiple expressions. Multiple matching paths |
| 21 | +and the expression will finally be concatenated on axis = -1. |
| 22 | +""" |
| 23 | + |
| 24 | + |
| 25 | +""" |
| 26 | +'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.rotary_emb.inv_freq', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.post_attention_layernorm.weight' |
| 27 | +""" |
| 28 | + |
# Per-layer mapping handed to fill_hdf5_layer for every decoder layer.
# Each entry pairs an output parameter name with its "&&"-separated matching
# rule: tensor-name paths first, then an optional "expression_" suffix.
dec_layer_mapping_dict = OrderedDict(
    [
        ("attention_norm_scale", "input_layernorm weight"),
        (
            "attention_project_qkv",
            "self_attn q_proj weight"
            "&&self_attn k_proj weight"
            "&&self_attn v_proj weight"
            "&&expression_.transpose(0, 1)",
        ),
        (
            "attention_output",
            "self_attn o_proj weight" "&&expression_.transpose(0, 1)",
        ),
        ("ffn_norm_scale", "post_attention_layernorm weight"),
        (
            "gate_up_project_weight",
            "mlp gate_proj weight"
            "&&mlp up_proj weight"
            "&&expression_.transpose(0, 1)",
        ),
        (
            "down_project_weight",
            "mlp down_proj weight" "&&expression_.transpose(0, 1)",
        ),
    ]
)
| 39 | + |
# Mapping handed to fill_hdf5_layer for the non-layer tensors (final norm,
# token embedding, output projection); same rule syntax as the layer mapping.
src_emb_mapping_dict = OrderedDict(
    [
        ("post_norm_scale", "norm weight"),
        ("token_embedding", "embed_tokens weight"),
        (
            "logits_linear_weight",
            "lm_head weight" "&&expression_.transpose(0, 1)",
        ),
    ]
)
| 47 | + |
| 48 | + |
def extract_llama_weights(
    output_file: str,
    arguments: ModelArguements,
):
    """Export a LLaMA state dict and its configuration to an hdf5 file.

    The checkpoint at ``arguments.model_file`` is loaded with ``torch.load``.
    Tensors whose dotted name carries a numeric third component (for example
    ``model.layers.0.self_attn.q_proj.weight``) are grouped by that layer id
    and written under ``decoder_layers/<layer_id>/``; the full tensor list is
    then matched against the embedding mapping and written under
    ``src_embedding/``. Scalar settings are stored under ``model_conf/``.
    Finally the file is reopened read-only and ``model_conf`` is printed as a
    sanity check.

    Args:
        output_file: Output path without extension; ``".hdf5"`` is appended.
        arguments: Parsed model arguments. This function reads ``model_file``,
            ``hidden_size``, ``inner_size``, ``max_step``, ``head_num``,
            ``layer_num``, ``padding_id``, ``generation_method``, ``topp``,
            ``topk``, ``eos_id``, ``extra_decode_length`` and ``vocab_size``.
    """
    # This script hides all CUDA devices, so map tensors to the CPU explicitly;
    # otherwise a checkpoint saved from a GPU cannot be deserialized.
    state_dict = torch.load(arguments.model_file, map_location="cpu")
    var_name_list = list(state_dict.keys())

    # Group tensor names by decoder layer index (third dotted component).
    layer_tensor_names = {}
    for name in var_name_list:
        name_split = name.split(".")
        if len(name_split) <= 2 or not name_split[2].isdigit():
            continue
        layer_tensor_names.setdefault(int(name_split[2]), []).append(name)

    output_file += ".hdf5"
    print("Saving model to hdf5...")
    print("Writing to {0}".format(output_file))

    # The context manager closes the file even if a fill/create call raises.
    with h5py.File(output_file, "w") as hdf5_file:
        # fill every decoder layer's params, in ascending layer order
        for layer_id in sorted(layer_tensor_names):
            fill_hdf5_layer(
                layer_tensor_names[layer_id],
                state_dict,
                hdf5_file,
                f"decoder_layers/{layer_id}/",
                dec_layer_mapping_dict,
            )

        # fill src_embedding - except for position embedding
        fill_hdf5_layer(
            var_name_list,
            state_dict,
            hdf5_file,
            "src_embedding/",
            src_emb_mapping_dict,
        )

        # save model configuration metadata
        hdf5_file.create_dataset(
            "model_conf/hidden_size", data=arguments.hidden_size, dtype="i4"
        )
        hdf5_file.create_dataset(
            "model_conf/inner_size", data=arguments.inner_size, dtype="i4"
        )
        hdf5_file.create_dataset(
            "model_conf/max_step", data=arguments.max_step, dtype="i4"
        )
        hdf5_file.create_dataset(
            "model_conf/head_num", data=arguments.head_num, dtype="i4"
        )
        hdf5_file.create_dataset(
            "model_conf/layer_num", data=arguments.layer_num, dtype="i4"
        )
        hdf5_file.create_dataset(
            "model_conf/src_padding_id", data=arguments.padding_id, dtype="i4"
        )
        # The generation method name is stored as an int8 array of char codes.
        hdf5_file.create_dataset(
            "model_conf/generate_method",
            data=np.array([ord(c) for c in arguments.generation_method]).astype(
                np.int8
            ),
            dtype="i1",
        )
        hdf5_file.create_dataset("model_conf/topp", data=arguments.topp, dtype="f4")
        hdf5_file.create_dataset("model_conf/topk", data=arguments.topk, dtype="i4")
        hdf5_file.create_dataset(
            "model_conf/eos_id", data=arguments.eos_id, dtype="i4"
        )
        hdf5_file.create_dataset(
            "model_conf/extra_decode_length",
            data=arguments.extra_decode_length,
            dtype="i4",
        )
        hdf5_file.create_dataset(
            "model_conf/src_vocab_size", data=arguments.vocab_size, dtype="i4"
        )

    # read-in again to double check
    with h5py.File(output_file, "r") as hdf5_file:
        for key, value in hdf5_file["model_conf"].items():
            if key == "generate_method":
                # decode the stored char codes back into text
                value = "".join(map(chr, value[()]))
            else:
                value = value[()]
            print(f"{key}: {value}")
| 137 | + |
| 138 | + |
if __name__ == "__main__":
    # Script entry point: parse CLI arguments, validate them, run the export.
    args = parse_args()

    arguments = ModelArguements(args)
    basename = os.path.basename(arguments.model_repo)
    output_lightseq_model_name = "_".join(["lightseq_llama", basename, "7b"])

    # Special-token ids are hard-coded here; adjust them to match the
    # tokenizer of the checkpoint being exported.
    arguments.eos_id = 2  # need to set
    arguments.padding_id = 0  # need to set

    if not check_arguements(arguments):
        # Invalid arguments are a failure: exit with a non-zero status so
        # calling scripts do not mistake this for a successful export.
        raise SystemExit(1)

    extract_llama_weights(output_lightseq_model_name, arguments)
0 commit comments