diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index ef08850b1d97..aee4c17b77e4 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -354,7 +354,7 @@ def load_state_dict(
 
         return state_dict
 
-    state_dict = paddlenlp_load(checkpoint_file, map_location="cpu") # debug: sub-branch pdparams读取
+    state_dict = paddlenlp_load(checkpoint_file, map_location="cpu")
     return state_dict
 
 
@@ -444,7 +444,7 @@ def resolve_weight_file_from_hf_hub(
     for fn in file_name_list:
         resolved_file = cached_file_for_hf_hub(
             repo_id, fn, cache_dir, subfolder, _raise_exceptions_for_missing_entries=False
-        ) # debug: "linly-ai/chinese-llama-2-7b-hf" actually, missing "-hf"
+        )
         if resolved_file is not None:
             if resolved_file.endswith(".json"):
                 is_sharded = True
@@ -714,7 +714,7 @@ def load_sharded_checkpoint(model, folder, variant=None, strict=True, prefer_saf
     loader = safe_load_file if load_safe else partial(paddlenlp_load, map_location="cpu")
 
     for shard_file in shard_files:
-        state_dict = loader(os.path.join(folder, shard_file)) # debug: sub-branch sharding
+        state_dict = loader(os.path.join(folder, shard_file))
         with warnings.catch_warnings():
             warnings.resetwarnings()
             warnings.filterwarnings("ignore", message=r".*is not found in the provided dict.*")
@@ -1790,7 +1790,7 @@ def _load_pretrained_model(
             except ImportError:
                 raise ImportError("Quantization features require `paddlepaddle >= 2.5.2`")
             if state_dict is not None:
-                state_dict = convert_to_quantize_state_dict( # debug: sub-branch quantize
+                state_dict = convert_to_quantize_state_dict(
                     state_dict,
                     quantization_linear_list,
                     config.quantization_config,
@@ -1798,7 +1798,7 @@ def _load_pretrained_model(
                 )
                 loaded_keys = [k for k in state_dict.keys()]
             else:
-                loaded_keys = update_loaded_state_dict_keys( # debug: branch-6.1 low_cpu_mem_usage
+                loaded_keys = update_loaded_state_dict_keys(
                     loaded_keys, quantization_linear_list, config.quantization_config
                 )
         if keep_in_fp32_modules is None:
@@ -1885,9 +1885,7 @@ def _find_mismatched_keys(
             # To avoid recursive import temporarily.
             import paddlenlp.ops.fast_transformer.transformer.decoding as ft_decoding
 
-            state_dict = ft_decoding.get_ft_para_conf().fit_partial_model(
-                model_to_load, state_dict
-            ) # debug: sub-branch MP
+            state_dict = ft_decoding.get_ft_para_conf().fit_partial_model(model_to_load, state_dict)
 
             mismatched_keys = _find_mismatched_keys(
                 state_dict,
@@ -1933,18 +1931,15 @@ def _find_mismatched_keys(
             ):
                 pre_tensor_parallel_split = True
                 assert loaded_keys is not None, "loaded_keys is not None."
-
                 tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
-
             # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
-            state_dict = load_state_dict( # debug: branch-5(6.5) load from sharding files // low_cpu_memory(state_dict is None)
+            state_dict = load_state_dict(
                 shard_file,
                 tp_actions if pre_tensor_parallel_split else None,
                 None if config.quantization_config.is_weight_quantize() else set(expected_keys),
             )
-
             if config.quantization_config.is_weight_quantize():
-                state_dict = convert_to_quantize_state_dict( # debug: sub-branch quantize
+                state_dict = convert_to_quantize_state_dict(
                     state_dict,
                     quantization_linear_list,
                     config.quantization_config,
@@ -1965,7 +1960,7 @@ def _find_mismatched_keys(
             if config.tensor_parallel_degree > 1 and ".tp" not in shard_file and not pre_tensor_parallel_split:
                 logger.info("Converting state_dict to Tensor Parallel Format")
                 # ignore error for multi shard, since only parts of data
-                state_dict = cls.convert_tensor_parallel( # debug: sub-branch TP
+                state_dict = cls.convert_tensor_parallel(
                     None, config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1
                 )
                 logger.info("Converted state_dict to Tensor Parallel Format")
@@ -2295,7 +2290,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
                         f"Starting to convert pytorch weight file<{resolved_archive_file}> to "
                         f"paddle weight file<{convert_dir}> ..."
                     )
-                    state_dict = cls.convert( # debug: branch-4 torch
+                    state_dict = cls.convert(
                         resolved_archive_file,
                         config,
                         # cache_dir=os.path.join(cache_dir, pretrained_model_name_or_path, subfolder),
@@ -2314,9 +2309,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         if not is_sharded and state_dict is None:
             # 4. loading non-sharded ckpt from the state dict
             if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model_state.pdparams"):
-                state_dict = cls.convert_tensor_parallel(
-                    resolved_archive_file, config
-                ) # debug: branch-3 pdparams转换 + TP
+                state_dict = cls.convert_tensor_parallel(resolved_archive_file, config)
             elif config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model.safetensors"):
                 with safe_open(resolved_archive_file, framework="np", device="cpu") as f:
                     loaded_keys = f.keys()
@@ -2341,7 +2334,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
                 tp_actions = cls.get_tensor_parallel_convert_actions(
                     config, loaded_keys, ignore_params=separate_keys | fused_keys
                 )
-                state_dict = load_state_dict(resolved_archive_file, tp_actions) # debug: branch-2 TP
+                state_dict = load_state_dict(resolved_archive_file, tp_actions)
 
                 # apply qkv/gate_up fuse action and tensor-parallel action sequentially
                 if "attention_qkv_proj" in do_fuse_parameter_list:
@@ -2354,7 +2347,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
                     for qkv_name in tp_actions.keys():
                         state_dict[qkv_name] = tp_actions[qkv_name](state_dict[qkv_name]) # apply tp-action for qkv
             else:
                state_dict = load_state_dict(resolved_archive_file) # debug: branch-1 normal
+                state_dict = load_state_dict(resolved_archive_file)
 
                 do_fuse_parameter_list, do_separate_parameter_list = select_fuse_parameter(model, state_dict.keys())
                 if do_fuse_parameter_list:
@@ -2375,7 +2368,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        loaded_state_dict_keys = [k for k in state_dict.keys()]

        if low_cpu_mem_usage: # or use_keep_in_fp32_modules:
-            # debug: branch-6 low_cpu_mem_usage on
            state_dict = None
            # will only support load paddle.Tensor to model.