
Commit

remove debug comment
DrownFish19 committed Mar 29, 2024
1 parent b8b828b commit 28ed30f
Showing 1 changed file with 13 additions and 21 deletions.
34 changes: 13 additions & 21 deletions paddlenlp/transformers/model_utils.py
@@ -354,7 +354,7 @@ def load_state_dict(

return state_dict

-state_dict = paddlenlp_load(checkpoint_file, map_location="cpu") # debug: sub-branch pdparams loading
+state_dict = paddlenlp_load(checkpoint_file, map_location="cpu")
return state_dict
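
For reference, a minimal sketch of what this fallback amounts to when the checkpoint is a plain .pdparams file. Here paddle.load stands in for the internal paddlenlp_load helper, and the filename is only an example:

    import paddle

    # Load the whole state dict into host memory; return_numpy avoids
    # allocating device tensors, mirroring map_location="cpu" above.
    state_dict = paddle.load("model_state.pdparams", return_numpy=True)
    print(list(state_dict.keys())[:5])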


@@ -444,7 +444,7 @@ def resolve_weight_file_from_hf_hub(
for fn in file_name_list:
resolved_file = cached_file_for_hf_hub(
repo_id, fn, cache_dir, subfolder, _raise_exceptions_for_missing_entries=False
-) # debug: "linly-ai/chinese-llama-2-7b-hf" actually, missing "-hf"
+)
if resolved_file is not None:
if resolved_file.endswith(".json"):
is_sharded = True
@@ -714,7 +714,7 @@ def load_sharded_checkpoint(model, folder, variant=None, strict=True, prefer_saf
loader = safe_load_file if load_safe else partial(paddlenlp_load, map_location="cpu")

for shard_file in shard_files:
-state_dict = loader(os.path.join(folder, shard_file)) # debug: sub-branch sharding
+state_dict = loader(os.path.join(folder, shard_file))
with warnings.catch_warnings():
warnings.resetwarnings()
warnings.filterwarnings("ignore", message=r".*is not found in the provided dict.*")
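
A hedged sketch of the shard-loading loop above, outside the model-loading machinery. The folder layout and file names are assumptions about a typical sharded PaddleNLP checkpoint, not taken from this file:

    import os
    import paddle

    folder = "./my_sharded_ckpt"  # assumed path
    # Hypothetical layout: an index JSON plus shards such as
    # model_state-00001-of-00002.pdparams, model_state-00002-of-00002.pdparams
    shard_files = sorted(
        f for f in os.listdir(folder) if f.endswith(".pdparams") and "-of-" in f
    )
    for shard_file in shard_files:
        # Each shard holds only a subset of the parameters; loading them one
        # at a time keeps peak host memory bounded by the largest shard.
        shard = paddle.load(os.path.join(folder, shard_file), return_numpy=True)
        print(shard_file, len(shard), "tensors")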
@@ -1790,15 +1790,15 @@ def _load_pretrained_model(
except ImportError:
raise ImportError("Quantization features require `paddlepaddle >= 2.5.2`")
if state_dict is not None:
-state_dict = convert_to_quantize_state_dict( # debug: sub-branch quantize
+state_dict = convert_to_quantize_state_dict(
state_dict,
quantization_linear_list,
config.quantization_config,
dtype,
)
loaded_keys = [k for k in state_dict.keys()]
else:
-loaded_keys = update_loaded_state_dict_keys( # debug: branch-6.1 low_cpu_mem_usage
+loaded_keys = update_loaded_state_dict_keys(
loaded_keys, quantization_linear_list, config.quantization_config
)
if keep_in_fp32_modules is None:
@@ -1885,9 +1885,7 @@ def _find_mismatched_keys(
# To avoid recursive import temporarily.
import paddlenlp.ops.fast_transformer.transformer.decoding as ft_decoding

-state_dict = ft_decoding.get_ft_para_conf().fit_partial_model(
-    model_to_load, state_dict
-) # debug: sub-branch MP
+state_dict = ft_decoding.get_ft_para_conf().fit_partial_model(model_to_load, state_dict)

mismatched_keys = _find_mismatched_keys(
state_dict,
@@ -1933,18 +1931,15 @@ def _find_mismatched_keys(
):
pre_tensor_parallel_split = True
assert loaded_keys is not None, "loaded_keys is not None."

tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)

# Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
-state_dict = load_state_dict( # debug: branch-5(6.5) load from sharding files // low_cpu_memory(state_dict is None)
+state_dict = load_state_dict(
shard_file,
tp_actions if pre_tensor_parallel_split else None,
None if config.quantization_config.is_weight_quantize() else set(expected_keys),
)

if config.quantization_config.is_weight_quantize():
-state_dict = convert_to_quantize_state_dict( # debug: sub-branch quantize
+state_dict = convert_to_quantize_state_dict(
state_dict,
quantization_linear_list,
config.quantization_config,
@@ -1965,7 +1960,7 @@ def _find_mismatched_keys(
if config.tensor_parallel_degree > 1 and ".tp" not in shard_file and not pre_tensor_parallel_split:
logger.info("Converting state_dict to Tensor Parallel Format")
# ignore error for multi shard, since only parts of data
-state_dict = cls.convert_tensor_parallel( # debug: sub-branch TP
+state_dict = cls.convert_tensor_parallel(
None, config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1
)
logger.info("Converted state_dict to Tensor Parallel Format")
@@ -2295,7 +2290,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
f"Starting to convert pytorch weight file<{resolved_archive_file}> to "
f"paddle weight file<{convert_dir}> ..."
)
-state_dict = cls.convert( # debug: branch-4 torch
+state_dict = cls.convert(
resolved_archive_file,
config,
# cache_dir=os.path.join(cache_dir, pretrained_model_name_or_path, subfolder),
@@ -2314,9 +2309,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
if not is_sharded and state_dict is None:
# 4. loading non-sharded ckpt from the state dict
if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model_state.pdparams"):
-state_dict = cls.convert_tensor_parallel(
-    resolved_archive_file, config
-) # debug: branch-3 pdparams conversion + TP
+state_dict = cls.convert_tensor_parallel(resolved_archive_file, config)
elif config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model.safetensors"):
with safe_open(resolved_archive_file, framework="np", device="cpu") as f:
loaded_keys = f.keys()
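
The safe_open call above is the standard safetensors API; a minimal standalone usage looks roughly like this (the file name is only an example):

    from safetensors import safe_open

    with safe_open("model.safetensors", framework="np", device="cpu") as f:
        loaded_keys = list(f.keys())          # metadata only, tensors are not read yet
        first = f.get_tensor(loaded_keys[0])  # lazily load a single tensor as numpy
        print(loaded_keys[0], first.shape, first.dtype)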
@@ -2341,7 +2334,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
tp_actions = cls.get_tensor_parallel_convert_actions(
config, loaded_keys, ignore_params=separate_keys | fused_keys
)
-state_dict = load_state_dict(resolved_archive_file, tp_actions) # debug: branch-2 TP
+state_dict = load_state_dict(resolved_archive_file, tp_actions)

# apply qkv/gate_up fuse action and tensor-parallel action sequentially
if "attention_qkv_proj" in do_fuse_parameter_list:
@@ -2354,7 +2347,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
for qkv_name in tp_actions.keys():
state_dict[qkv_name] = tp_actions[qkv_name](state_dict[qkv_name]) # apply tp-action for qkv
else:
-state_dict = load_state_dict(resolved_archive_file) # debug: branch-1 normal
+state_dict = load_state_dict(resolved_archive_file)

do_fuse_parameter_list, do_separate_parameter_list = select_fuse_parameter(model, state_dict.keys())
if do_fuse_parameter_list:
@@ -2375,7 +2368,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
loaded_state_dict_keys = [k for k in state_dict.keys()]

if low_cpu_mem_usage: # or use_keep_in_fp32_modules:
-# debug: branch-6 low_cpu_mem_usage on
state_dict = None

# will only support load paddle.Tensor to model.
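
From the user side, the low_cpu_mem_usage branch above is reached through the from_pretrained keyword of the same name; the model identifier and dtype below are examples only, a sketch rather than a prescribed configuration:

    from paddlenlp.transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "facebook/llama-7b",     # example model identifier
        dtype="float16",
        low_cpu_mem_usage=True,  # defer weight materialization to reduce host RAM
    )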