BUG in W4A8_awq-kv-FP8 and W-fp8-A-fp8-kv-fp8 in 0.17.0.post1 #2810
Comments
Hi @white-wolf-tech, the warning info you provided is caused by a recent optimization over
Thank you for your reply. I ran the same experiment using FP8 quantization.
The model I'm using is Qwen2.5-3B. When I used the code in the "examples" folder and built the weights directly in FP16 or SmoothQuant (W8A8 INT8) format, the model ran normally.
The build script is as follows:

python quantize.py --model_dir $hf_input_dir \
--calib_dataset $calib_dataset_path \
--dtype 'auto' \
--qformat "fp8" \
--awq_block_size 128 \
--batch_size 32 \
--output_dir $input_temp_dir \
--kv_cache_dtype "fp8"

trtllm-build --checkpoint_dir $input_temp_dir \
--output_dir $output_dir \
--max_batch_size 512 \
--max_input_len 1024 \
--max_seq_len 2048 \
--max_beam_width 1 \
--max_num_tokens 16384 \
--gemm_plugin auto \
--kv_cache_type paged \
--remove_input_padding enable \
--context_fmha enable \
--use_paged_context_fmha enable \
--use_fp8_context_fmha enable \
--tokens_per_block 32 \
--use_fused_mlp enable \
--multiple_profiles enable \
--reduce_fusion enable \
--user_buffer enable \
--workers 4 \
--log_level info | tee build.log

The build log starts with: [TensorRT-LLM] TensorRT-LLM version: 0.17.0.post1. When using FP8 quantization, that warning does not appear. I then used the LLM API for a preliminary test:

from transformers import AutoTokenizer
from tqdm import tqdm
import tensorrt_llm
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (KvCacheConfig,
LookaheadDecodingConfig,
MedusaDecodingConfig,
QuantAlgo,
QuantConfig,
SchedulerConfig
)
def main():
    hf_model_dir = "/data/models/Qwen2.5-3B"
    trt_engine_path = "/data/v0.17.0/code/models/qwen3b-trt/1gpu-w8a8kv8/"
    prompt = "Please help me calculate the factorial of 29"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_dir, trust_remote_code=True)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=128, end_id=151643)
    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9, enable_block_reuse=True)
    llm = LLM(model=trt_engine_path,
              tokenizer=tokenizer,
              kv_cache_config=kv_cache_config,
              enable_chunked_prefill=True)
    outputs = llm.generate([prompt], sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

if __name__ == '__main__':
    main()

I built a container from the image nvcr.io/nvidia/tritonserver:25.01-trtllm-python-py3 and ran the test above; the output is normal. However, when I deploy the compiled engine with Triton Server and the TensorRT-LLM backend, the output is incorrect. The template-filling commands are as follows:

ENGINE_DIR=/code/models/qwen3b-trt/1gpu-w8a8kv8
TOKENIZER_DIR=/code/models/qwen_2_5_tokenizer
MODEL_FOLDER=code/triton_deploy/inflight_batcher_llm
TRITON_MAX_BATCH_SIZE=512
INSTANCE_COUNT=64
BLS_INSTANCE_COUNT=256
MAX_QUEUE_DELAY_MS=0
MAX_QUEUE_SIZE=0
FILL_SCRIPT=./fill_template.py
DECOUPLED_MODE=false
python3 ${FILL_SCRIPT} -i ${MODEL_FOLDER}/ensemble/config.pbtxt \
triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},\
logits_datatype:TYPE_FP32
python3 ${FILL_SCRIPT} -i ${MODEL_FOLDER}/preprocessing/config.pbtxt \
tokenizer_dir:${TOKENIZER_DIR},\
triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},\
preprocessing_instance_count:${INSTANCE_COUNT}
python3 ${FILL_SCRIPT} -i ${MODEL_FOLDER}/tensorrt_llm/config.pbtxt \
triton_backend:tensorrtllm,\
triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},\
decoupled_mode:${DECOUPLED_MODE},\
max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MS},\
max_queue_size:${MAX_QUEUE_SIZE},\
encoder_input_features_data_type:TYPE_FP16,\
logits_datatype:TYPE_FP32,\
engine_dir:${ENGINE_DIR},\
batching_strategy:inflight_fused_batching,\
batch_scheduler_policy:guaranteed_no_evict,\
kv_cache_free_gpu_mem_fraction:0.9,\
enable_kv_cache_reuse:true,\
enable_chunked_context:true,\
enable_context_fmha_fp32_acc:true,\
multi_block_mode:true,\
cuda_graph_mode:true
python3 ${FILL_SCRIPT} -i ${MODEL_FOLDER}/postprocessing/config.pbtxt \
tokenizer_dir:${TOKENIZER_DIR},\
triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},\
postprocessing_instance_count:${INSTANCE_COUNT},\
max_queue_size:${MAX_QUEUE_SIZE}
python3 ${FILL_SCRIPT} -i ${MODEL_FOLDER}/tensorrt_llm_bls/config.pbtxt \
triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},\
decoupled_mode:${DECOUPLED_MODE},\
logits_datatype:TYPE_FP32,\
bls_instance_count:${BLS_INSTANCE_COUNT}

The output is a string of meaningless tokens, for example: "xx.Componentlocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklocklock". My current conclusion is that the TensorRT-LLM inference itself is fine, so does something go wrong when the engine is served through Triton Server? What could the problem be?
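For reference, the garbage output above is reproduced with a request like the following (a minimal sketch only, assuming Triton's default HTTP port 8000 and the ensemble generate route shown in the tensorrtllm_backend examples; the host, port, and model name are assumptions about my deployment):

```python
# Minimal sketch of one request to the deployed ensemble via Triton's generate
# endpoint, mirroring the curl example in the tensorrtllm_backend docs.
# Adjust host/port/model name to match the actual deployment.
import requests

payload = {
    "text_input": "Please help me calculate the factorial of 29",
    "max_tokens": 128,
    "bad_words": "",
    "stop_words": "",
}
resp = requests.post("http://localhost:8000/v2/models/ensemble/generate", json=payload)
resp.raise_for_status()
print(resp.json().get("text_output"))
```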
I added a log print at the very beginning of the execute function in the model.py file within the tensorrt_llm directory.
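The added log print is roughly the following (a minimal sketch only; it assumes the Python-backend flavor of model.py and that the model exposes an "input_ids" input, and the rest of execute() is elided):

```python
# Minimal sketch of a debug log at the top of execute() in the tensorrt_llm
# model's model.py. triton_python_backend_utils is only available inside the
# Triton Python backend runtime; the original body of execute() is elided.
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def execute(self, requests):
        pb_utils.Logger.log_info(f"execute() called with {len(requests)} request(s)")
        for request in requests:
            input_ids = pb_utils.get_input_tensor_by_name(request, "input_ids")
            if input_ids is not None:
                pb_utils.Logger.log_info(
                    f"execute(): input_ids shape = {input_ids.as_numpy().shape}")
        # ... original execute() body continues unchanged ...
```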
System Info
GPU: L20
TensorRT-LLM: v0.17.0.post1
modelopt: 0.23.0
Who can help?
No response
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)
Reproduction
I used the following code to perform model quantization, and everything seemed fine.
After calibration on the test set finished, I got the following warning during the tensorrt-llm conversion, and the output of the finally compiled model is incorrect.
The warning says these tensors were provided but not required, yet they were produced during the quantization process.
There appear to be some differences between the tensorrt-llm conversion code and the modelopt implementation that lead to this.
Is there any solution to this problem?
warning info:
[TRT-LLM] [W] Provided but not required tensors: {'transformer.layers.22.mlp.fc.activation_scaling_factor', 'transformer.layers.19.mlp.fc.activation_scaling_factor', 'transformer.layers.14.mlp.fc.activation_scaling_factor', 'transformer.layers.25.mlp.fc.activation_scaling_factor', 'transformer.layers.25.mlp.gate.activation_scaling_factor', 'transformer.layers.5.mlp.fc.activation_scaling_factor', 'transformer.layers.5.attention.qkv.activation_scaling_factor', 'transformer.layers.8.mlp.gate.activation_scaling_factor', 'transformer.layers.32.mlp.fc.activation_scaling_factor', 'transformer.layers.23.mlp.fc.activation_scaling_factor', 'transformer.layers.31.attention.qkv.activation_scaling_factor', 'transformer.layers.33.mlp.gate.activation_scaling_factor', 'transformer.layers.2.mlp.gate.activation_scaling_factor', 'transformer.layers.14.attention.qkv.activation_scaling_factor', 'transformer.layers.18.mlp.gate.activation_scaling_factor', 'transformer.layers.19.attention.qkv.activation_scaling_factor', 'transformer.layers.20.attention.qkv.activation_scaling_factor', 'transformer.layers.24.attention.qkv.activation_scaling_factor', 'transformer.layers.7.mlp.fc.activation_scaling_factor', 'transformer.layers.17.attention.qkv.activation_scaling_factor', 'transformer.layers.28.mlp.gate.activation_scaling_factor', 'transformer.layers.6.mlp.fc.activation_scaling_factor', 'transformer.layers.15.attention.qkv.activation_scaling_factor', 'transformer.layers.10.attention.qkv.activation_scaling_factor', 'transformer.layers.21.mlp.gate.activation_scaling_factor', 'transformer.layers.17.mlp.fc.activation_scaling_factor', 'transformer.layers.23.attention.qkv.activation_scaling_factor', 'transformer.layers.27.mlp.fc.activation_scaling_factor', 'transformer.layers.28.attention.qkv.activation_scaling_factor', 'transformer.layers.26.attention.qkv.activation_scaling_factor', 'transformer.layers.13.attention.qkv.activation_scaling_factor', 'transformer.layers.5.mlp.gate.activation_scaling_factor', 'transformer.layers.11.mlp.fc.activation_scaling_factor', 'transformer.layers.13.mlp.fc.activation_scaling_factor', 'transformer.layers.8.attention.qkv.activation_scaling_factor', 'transformer.layers.9.attention.qkv.activation_scaling_factor', 'transformer.layers.26.mlp.fc.activation_scaling_factor', 'transformer.layers.0.mlp.fc.activation_scaling_factor', 'transformer.layers.18.mlp.fc.activation_scaling_factor', 'transformer.layers.24.mlp.gate.activation_scaling_factor', 'transformer.layers.34.mlp.fc.activation_scaling_factor', 'transformer.layers.12.attention.qkv.activation_scaling_factor', 'transformer.layers.4.mlp.gate.activation_scaling_factor', 'transformer.layers.2.mlp.fc.activation_scaling_factor', 'transformer.layers.4.attention.qkv.activation_scaling_factor', 'transformer.layers.25.attention.qkv.activation_scaling_factor', 'transformer.layers.0.mlp.gate.activation_scaling_factor', 'transformer.layers.2.attention.qkv.activation_scaling_factor', 'transformer.layers.4.mlp.fc.activation_scaling_factor', 'transformer.layers.16.attention.qkv.activation_scaling_factor', 'transformer.layers.16.mlp.gate.activation_scaling_factor', 'transformer.layers.32.attention.qkv.activation_scaling_factor', 'transformer.layers.33.mlp.fc.activation_scaling_factor', 'transformer.layers.24.mlp.fc.activation_scaling_factor', 'transformer.layers.1.mlp.fc.activation_scaling_factor', 'transformer.layers.17.mlp.gate.activation_scaling_factor', 'transformer.layers.15.mlp.gate.activation_scaling_factor', 
'transformer.layers.34.mlp.gate.activation_scaling_factor', 'transformer.layers.35.attention.qkv.activation_scaling_factor', 'transformer.layers.33.attention.qkv.activation_scaling_factor', 'transformer.layers.12.mlp.gate.activation_scaling_factor', 'transformer.layers.29.mlp.fc.activation_scaling_factor', 'transformer.layers.32.mlp.gate.activation_scaling_factor', 'transformer.layers.21.mlp.fc.activation_scaling_factor', 'transformer.layers.20.mlp.fc.activation_scaling_factor', 'transformer.layers.15.mlp.fc.activation_scaling_factor', 'transformer.layers.9.mlp.gate.activation_scaling_factor', 'transformer.layers.0.attention.qkv.activation_scaling_factor', 'transformer.layers.19.mlp.gate.activation_scaling_factor', 'transformer.layers.3.attention.qkv.activation_scaling_factor', 'transformer.layers.3.mlp.fc.activation_scaling_factor', 'transformer.layers.1.attention.qkv.activation_scaling_factor', 'transformer.layers.26.mlp.gate.activation_scaling_factor', 'transformer.layers.11.attention.qkv.activation_scaling_factor', 'transformer.layers.34.attention.qkv.activation_scaling_factor', 'transformer.layers.20.mlp.gate.activation_scaling_factor', 'transformer.layers.1.mlp.gate.activation_scaling_factor', 'transformer.layers.9.mlp.fc.activation_scaling_factor', 'transformer.layers.11.mlp.gate.activation_scaling_factor', 'transformer.layers.30.mlp.fc.activation_scaling_factor', 'transformer.layers.27.mlp.gate.activation_scaling_factor', 'transformer.layers.18.attention.qkv.activation_scaling_factor', 'transformer.layers.35.mlp.gate.activation_scaling_factor', 'transformer.layers.10.mlp.gate.activation_scaling_factor', 'transformer.layers.31.mlp.fc.activation_scaling_factor', 'transformer.layers.13.mlp.gate.activation_scaling_factor', 'transformer.layers.30.attention.qkv.activation_scaling_factor', 'transformer.layers.31.mlp.gate.activation_scaling_factor', 'transformer.layers.22.mlp.gate.activation_scaling_factor', 'transformer.layers.16.mlp.fc.activation_scaling_factor', 'transformer.layers.29.attention.qkv.activation_scaling_factor', 'transformer.layers.30.mlp.gate.activation_scaling_factor', 'transformer.layers.7.mlp.gate.activation_scaling_factor', 'transformer.layers.23.mlp.gate.activation_scaling_factor', 'transformer.layers.6.mlp.gate.activation_scaling_factor', 'transformer.layers.6.attention.qkv.activation_scaling_factor', 'transformer.layers.22.attention.qkv.activation_scaling_factor', 'transformer.layers.7.attention.qkv.activation_scaling_factor', 'transformer.layers.28.mlp.fc.activation_scaling_factor', 'transformer.layers.3.mlp.gate.activation_scaling_factor', 'transformer.layers.8.mlp.fc.activation_scaling_factor', 'transformer.layers.10.mlp.fc.activation_scaling_factor', 'transformer.layers.29.mlp.gate.activation_scaling_factor', 'transformer.layers.27.attention.qkv.activation_scaling_factor', 'transformer.layers.14.mlp.gate.activation_scaling_factor', 'transformer.layers.12.mlp.fc.activation_scaling_factor', 'transformer.layers.35.mlp.fc.activation_scaling_factor', 'transformer.layers.21.attention.qkv.activation_scaling_factor'}
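To compare the warning above against what the quantized checkpoint actually contains, the scaling-factor tensors can be listed directly (a minimal sketch; it assumes the checkpoint written by quantize.py stores its weights in rank0.safetensors, and the path is a placeholder):

```python
# Minimal sketch: list the activation_scaling_factor tensors in the checkpoint
# produced by quantize.py, to compare against the "provided but not required"
# set reported by trtllm-build. The path below is a placeholder.
from safetensors import safe_open

ckpt_file = "/path/to/quantized_checkpoint/rank0.safetensors"  # placeholder path
with safe_open(ckpt_file, framework="pt") as f:
    scale_keys = sorted(k for k in f.keys() if k.endswith("activation_scaling_factor"))

print(f"{len(scale_keys)} activation_scaling_factor tensors found in the checkpoint")
for key in scale_keys[:10]:
    print(key)
```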
Expected behavior
The model's output is normal.
actual behavior
The model's output is wrong.
additional notes
NULL