Update predict_generation.py
for chatglm2
lxp521125 authored Jul 25, 2023
1 parent 9eb3cd7 commit 2224ca3
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions llm/chatglm/predict_generation.py
@@ -21,7 +21,7 @@
     chatglm_pad_attention_mask,
     chatglm_postprocess_past_key_value,
 )
-from paddlenlp.transformers import ChatGLMConfig, ChatGLMForCausalLM, ChatGLMTokenizer
+from paddlenlp.transformers import ChatGLMConfig, AutoTokenizer, AutoModelForCausalLM
 
 
 def parse_arguments():
@@ -62,7 +62,7 @@ def __init__(self, args=None, tokenizer=None, model=None, **kwargs):
             self.src_length = kwargs["src_length"]
             self.tgt_length = kwargs["tgt_length"]
         else:
-            self.tokenizer = ChatGLMTokenizer.from_pretrained(args.model_name_or_path)
+            self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
             self.batch_size = args.batch_size
             self.args = args
             self.src_length = self.args.src_length
@@ -92,7 +92,7 @@ def __init__(self, args=None, tokenizer=None, model=None, **kwargs):
             config = ChatGLMConfig.from_pretrained(args.model_name_or_path)
             dtype = config.dtype if config.dtype is not None else config.paddle_dtype
 
-            self.model = ChatGLMForCausalLM.from_pretrained(
+            self.model = AutoModelForCausalLM.from_pretrained(
                 args.model_name_or_path,
                 tensor_parallel_degree=tensor_parallel_degree,
                 tensor_parallel_rank=tensor_parallel_rank,
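The substance of the change: the ChatGLM-specific ChatGLMTokenizer and ChatGLMForCausalLM classes are replaced with paddlenlp's AutoTokenizer and AutoModelForCausalLM, which resolve the concrete tokenizer and model classes from the checkpoint itself, so the same predictor script can also load ChatGLM2 weights. A minimal sketch of that loading pattern follows, assuming PaddleNLP is installed; the checkpoint name is an illustrative assumption, not something this commit specifies.

    # Minimal sketch of the loading pattern this commit switches to.
    # The checkpoint name is a hypothetical example, not part of the diff.
    from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

    model_name_or_path = "THUDM/chatglm2-6b"  # assumed example checkpoint

    # The Auto classes dispatch to the matching ChatGLM/ChatGLM2 implementation
    # based on the checkpoint's config, so no model-specific import is needed.
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    model.eval()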
