Commit a7ab0da

Llama develop (speedup 2.x) (#504)
* llama develop
* format
* llama rms layer norm
* finish llama kernel develop
* llama op develop
* llama develop
* format
* llama develop
* llama export
* llama export develop
* llama develop
* llama develop
* llama rotary fuse kernel
* llama fuse transpose and rotary position emb
* llama develop
* llama develop
* llama develop
* llama develop
* llama develop
* llama develop
* adapt export
* llama develop
* llama develop
* format
1 parent 7e5bed6 commit a7ab0da
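
For context, the commit message above mentions an RMS layer norm kernel and a fused rotary position embedding kernel. The sketch below shows the reference math those kernels implement, in plain PyTorch; it is written for this note and is not the fused CUDA code added by the commit, and the names rms_norm / rotary_embedding are illustrative.

# Reference semantics only (not the fused CUDA kernels in this commit):
# LLaMA-style RMS layer norm and rotary position embedding in plain PyTorch.
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension, then scale.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

def rotary_embedding(q: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # q: [seq_len, num_heads, head_dim]; rotate channel pairs by a
    # position-dependent angle (the rotary embedding that the commit fuses
    # with the QKV transpose).
    head_dim = q.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = positions[:, None].float() * inv_freq[None, :]        # [seq_len, head_dim/2]
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]  # broadcast over heads
    q1, q2 = q[..., 0::2], q[..., 1::2]
    out = torch.empty_like(q)
    out[..., 0::2] = q1 * cos - q2 * sin
    out[..., 1::2] = q1 * sin + q2 * cos
    return out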

48 files changed: +3312 −69 lines

build.sh (+1 −1)

@@ -2,6 +2,6 @@ if [ ! -d 'build' ]; then
   mkdir build
 fi
 # DEVICE_ARCH could be cuda/x86/arm
-cd build && cmake -DUSE_NEW_ARCH=OFF -DDEVICE_ARCH=cuda -DUSE_TRITONBACKEND=OFF -DDEBUG_MODE=OFF -DFP16_MODE=OFF -DMEM_DEBUG=OFF .. && make -j${nproc}
+cd build && cmake -DUSE_NEW_ARCH=ON -DDEVICE_ARCH=cuda -DUSE_TRITONBACKEND=OFF -DDEBUG_MODE=ON -DFP16_MODE=ON -DMEM_DEBUG=OFF .. && make -j${nproc}
 # you can use comand like below to compile lightseq with pybind interface:
 # sudo PATH=$PATH:/usr/local/hdf5 CUDACXX=/usr/local/cuda/bin/nvcc DEVICE_ARCH=cuda ENABLE_FP32=0 ENABLE_DEBUG=0 ENABLE_NEW_ARCH=1 python3 setup.py install

lightseq/csrc/example/CMakeLists.txt (+3)

@@ -8,3 +8,6 @@ target_link_libraries(transformer_example PUBLIC liblightseq)
 
 add_executable(gpt_example gpt_example.cc)
 target_link_libraries(gpt_example PUBLIC liblightseq)
+
+add_executable(llama_example llama_example.cc)
+target_link_libraries(llama_example PUBLIC liblightseq)

lightseq/csrc/example/llama_example.cc (new file, +94)

#include "model_base.h"
#include "llama.h"

/**
  @file
  Example of how to run llama inference using our implementation.
*/

int main(int argc, char* argv[]) {
  std::string model_weights_path = argv[1];
  std::vector<int> example_input = {1, 21784, 26539, 338,
                                    263, 4933, 6509, 6890};
  int eg_seq_len = example_input.size();

  int batch_size = 1;
  int batch_seq_len = eg_seq_len;

  if (argc == 4) {
    batch_size = atoi(argv[2]);
    batch_seq_len = atoi(argv[3]);
  }

  int max_batch_size = std::max(8, batch_size);

  std::vector<int> host_input;
  for (int i = 0; i < batch_size; ++i) {
    for (int j = 0; j < batch_seq_len; ++j) {
      host_input.push_back(example_input[j % eg_seq_len]);
    }
  }

  auto model = lightseq::cuda::LSModelFactory::GetInstance().CreateModel(
      "Llama", model_weights_path, 1);

  void* d_input;
  CHECK_GPU_ERROR(
      cudaMalloc(&d_input, sizeof(int) * batch_size * batch_seq_len));
  CHECK_GPU_ERROR(cudaMemcpy(d_input, host_input.data(),
                             sizeof(int) * batch_size * batch_seq_len,
                             cudaMemcpyHostToDevice));

  printf("example step.1\n");

  model->set_input_ptr(0, d_input);
  model->set_input_shape(0, {batch_size, batch_seq_len});

  for (int i = 0; i < model->get_output_size(); i++) {
    void* d_output;
    std::vector<int> shape = model->get_output_max_shape(i);
    int total_size = 1;
    for (int j = 0; j < shape.size(); j++) {
      total_size *= shape[j];
    }
    CHECK_GPU_ERROR(cudaMalloc(&d_output, total_size * sizeof(int)));
    model->set_output_ptr(i, d_output);
  }
  printf("example step.2\n");
  CHECK_GPU_ERROR(cudaStreamSynchronize(0));
  std::cout << "infer preprocessing finished" << std::endl;
  printf("example step.2-1\n");
  std::cout << "infer preprocessing finished 2" << std::endl;

  std::chrono::duration<double> elapsed;
  int iter = 0;
  /* ---step5. infer and log--- */
  for (int i = 0; i < 5; i++) {
    auto start = std::chrono::high_resolution_clock::now();
    model->Infer();
    auto finish = std::chrono::high_resolution_clock::now();
    if (i) {
      iter++;
      elapsed += finish - start;
    }
  }

  std::cout << "lightseq inference latency: " << elapsed.count() * 1000 / iter
            << " ms" << std::endl;

  for (int i = 0; i < model->get_output_size(); i++) {
    const int* d_output;
    d_output = static_cast<const int*>(model->get_output_ptr(i));
    std::vector<int> shape = model->get_output_shape(i);
    std::cout << "output shape: ";
    for (int j = 0; j < shape.size(); j++) {
      std::cout << shape[j] << " ";
    }
    std::cout << std::endl;
    if (i == 0) {
      lightseq::print_vec(d_output, "d_output", shape[2]);
    }
  }

  return 0;
}
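
The hard-coded example_input above looks like LLaMA-tokenizer output with a leading BOS id of 1. A hedged sketch of producing such ids with the Hugging Face tokenizer; the prompt string and checkpoint path are illustrative and not taken from the commit:

# Illustrative only: producing input ids like the hard-coded example_input.
from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("/path/to/llama-7b")  # hypothetical local path
ids = tokenizer.encode("LightSeq is a machine learning library")  # BOS id 1 is prepended by default
print(ids)  # the C++ example above hard-codes a similar id list as its input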

lightseq/csrc/export/__init__.py

Whitespace-only changes.

LLaMA export script (new file, +153)

"""
Export Hugging Face LLaMA models to hdf5 format.
"""
import __init__
import os
import h5py
import numpy as np
from collections import OrderedDict
from util import parse_args, check_arguements, ModelArguements, fill_hdf5_layer
import torch

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


"""
For the mapping dictionary: the key is the name of the proto parameter; the value is a
rule whose parts are separated by && — each part is either a tensor-name matching path or an expression.

The sub-patterns of a path are separated by spaces, and an expression starts with expression_.
Each tensor can be operated on separately, and multiple expressions are supported. The tensors from
multiple matching paths, with the expression applied, are finally concatenated on axis = -1.
"""


"""
'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.rotary_emb.inv_freq', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.post_attention_layernorm.weight'
"""

dec_layer_mapping_dict = OrderedDict(
    {
        "attention_norm_scale": "input_layernorm weight",
        "attention_project_qkv": "self_attn q_proj weight&&self_attn k_proj weight&&self_attn v_proj weight&&expression_.transpose(0, 1)",
        "attention_output": "self_attn o_proj weight&&expression_.transpose(0, 1)",
        "ffn_norm_scale": "post_attention_layernorm weight",
        "gate_up_project_weight": "mlp gate_proj weight&&mlp up_proj weight&&expression_.transpose(0, 1)",
        "down_project_weight": "mlp down_proj weight&&expression_.transpose(0, 1)",
    }
)

src_emb_mapping_dict = OrderedDict(
    {
        "post_norm_scale": "norm weight",
        "token_embedding": "embed_tokens weight",
        "logits_linear_weight": "lm_head weight&&expression_.transpose(0, 1)",
    }
)


def extract_llama_weights(
    output_file: str,
    arguments: ModelArguements,
):
    # load var names
    state_dict = torch.load(arguments.model_file)

    head_num = arguments.head_num
    enc_var_name_list = list(state_dict.keys())

    # initialize output file
    output_file += ".hdf5"
    print("Saving model to hdf5...")
    print("Writing to {0}".format(output_file))

    # exit(0)
    hdf5_file = h5py.File(output_file, "w")

    # fill each encoder layer's params
    enc_tensor_names = {}
    for name in enc_var_name_list:
        name_split = name.split(".")
        if len(name_split) <= 2 or not name_split[2].isdigit():
            continue
        layer_id = int(name_split[2])
        enc_tensor_names.setdefault(layer_id, []).append(name)

    # fill encoder_stack
    for layer_id in sorted(enc_tensor_names.keys()):
        fill_hdf5_layer(
            enc_tensor_names[layer_id],
            state_dict,
            hdf5_file,
            f"decoder_layers/{layer_id}/",
            dec_layer_mapping_dict,
        )

    # fill src_embedding - except for position embedding
    fill_hdf5_layer(
        enc_var_name_list,
        state_dict,
        hdf5_file,
        "src_embedding/",
        src_emb_mapping_dict,
    )

    # save number of layers metadata
    hdf5_file.create_dataset(
        "model_conf/hidden_size", data=arguments.hidden_size, dtype="i4"
    )
    hdf5_file.create_dataset(
        "model_conf/inner_size", data=arguments.inner_size, dtype="i4"
    )
    hdf5_file.create_dataset("model_conf/max_step", data=arguments.max_step, dtype="i4")
    hdf5_file.create_dataset("model_conf/head_num", data=arguments.head_num, dtype="i4")
    hdf5_file.create_dataset(
        "model_conf/layer_num", data=arguments.layer_num, dtype="i4"
    )
    hdf5_file.create_dataset(
        "model_conf/src_padding_id", data=arguments.padding_id, dtype="i4"
    )
    hdf5_file.create_dataset(
        "model_conf/generate_method",
        data=np.array([ord(c) for c in arguments.generation_method]).astype(np.int8),
        dtype="i1",
    )
    hdf5_file.create_dataset("model_conf/topp", data=arguments.topp, dtype="f4")
    hdf5_file.create_dataset("model_conf/topk", data=arguments.topk, dtype="i4")
    hdf5_file.create_dataset("model_conf/eos_id", data=arguments.eos_id, dtype="i4")
    hdf5_file.create_dataset(
        "model_conf/extra_decode_length", data=arguments.extra_decode_length, dtype="i4"
    )
    hdf5_file.create_dataset(
        "model_conf/src_vocab_size", data=arguments.vocab_size, dtype="i4"
    )

    hdf5_file.close()
    # read-in again to double check
    hdf5_file = h5py.File(output_file, "r")

    def _print_pair(key, value):
        if key == "generate_method":
            value = "".join(map(chr, value[()]))
        else:
            value = value[()]
        print(f"{key}: {value}")

    list(map(lambda x: _print_pair(*x), hdf5_file["model_conf"].items()))


if __name__ == "__main__":
    args = parse_args()

    arguments = ModelArguements(args)
    basename = os.path.basename(arguments.model_repo)
    output_lightseq_model_name = "_".join(["lightseq_llama", basename, "7b"])
    # default eos_id from https://huggingface.co/transformers/model_doc/gpt2.html#gpt2lmheadmodel

    arguments.eos_id = 2  # need to set
    arguments.padding_id = 0  # need to set

    if not check_arguements(arguments):
        exit(0)

    extract_llama_weights(output_lightseq_model_name, arguments)
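
The mapping dictionaries in the script above use the small rule syntax described in its docstring: space-separated name sub-patterns, parts joined by &&, and transforms prefixed with expression_. Below is a rough sketch of how such a rule could be resolved against a checkpoint; resolve_rule is a hypothetical helper written for this note, and the real fill_hdf5_layer imported from util may differ in detail.

# Rough illustration of the rule syntax used by dec_layer_mapping_dict /
# src_emb_mapping_dict; resolve_rule is hypothetical, not the fill_hdf5_layer
# implementation from util.
import torch

def resolve_rule(rule: str, named_tensors: dict) -> torch.Tensor:
    matched, expressions = [], []
    for part in rule.split("&&"):
        if part.startswith("expression_"):
            expressions.append(part[len("expression_"):])  # e.g. ".transpose(0, 1)"
            continue
        sub_patterns = part.split()
        # keep tensors whose dotted name contains every sub-pattern, e.g.
        # "self_attn q_proj weight" matches model.layers.0.self_attn.q_proj.weight
        matched += [t for name, t in named_tensors.items()
                    if all(p in name.split(".") for p in sub_patterns)]
    tensor = torch.cat(matched, dim=-1) if len(matched) > 1 else matched[0]
    for expr in expressions:
        tensor = eval("tensor" + expr)  # apply the transform spelled in the rule
    return tensor

# e.g. layer-0 QKV weight, concatenated then transposed as the rule requests:
# layer0 = {n: p for n, p in state_dict.items() if n.startswith("model.layers.0.")}
# qkv = resolve_rule(dec_layer_mapping_dict["attention_project_qkv"], layer0)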
