diff --git a/awq/models/starcoder2.py b/awq/models/starcoder2.py index 2e493514..be79506b 100644 --- a/awq/models/starcoder2.py +++ b/awq/models/starcoder2.py @@ -110,13 +110,9 @@ def fuse_transformer(self): module.self_attn.k_proj, module.self_attn.v_proj, ) - norm_1 = FasterTransformerRMSNorm( - module.input_layernorm.weight, module.input_layernorm.eps - ) - norm_2 = FasterTransformerRMSNorm( - module.post_attention_layernorm.weight, - module.post_attention_layernorm.eps, - ) + # StarCoder2 uses standard LayerNorm, not RMSNorm, so reuse the existing modules + norm_1 = module.input_layernorm + norm_2 = module.post_attention_layernorm blocks.append( LlamaLikeBlock( hidden_size=self.model.config.hidden_size,