 25 |  25 | )
 26 |  26 | from torchtune import config
 27 |  27 |
    |  28 | +from torchtune.training.checkpointing._utils import (
    |  29 | +    ADAPTER_MODEL_FNAME,
    |  30 | +    get_largest_iter_folder,
    |  31 | +    RECIPE_STATE_DIRNAME,
    |  32 | +    safe_torch_load,
    |  33 | +    SHARD_FNAME,
    |  34 | +)
    |  35 | +
 28 |  36 |
 29 |  37 | class TestLoRADPOSingleDeviceRecipe:
 30 |  38 |     def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
@@ -99,18 +107,21 @@ def test_training_state_on_resume(
 99 | 107 |
100 | 108 |         resumed_log_dir = (tmpdir / "resumed/").mkdir()
101 | 109 |         resumed_log_file = gen_log_file_name(resumed_log_dir)
    | 110 | +
102 | 111 |         # Resume training
    | 112 | +        epoch_folder = get_largest_iter_folder(tmpdir)
    | 113 | +        epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
103 | 114 |         cmd_2 = f"""
104 | 115 |         tune run lora_dpo_single_device \
105 | 116 |             --config llama2/7B_lora_dpo_single_device \
106 | 117 |             output_dir={tmpdir} \
107 | 118 |             model.lora_attn_modules=['q_proj','v_proj'] \
108 | 119 |             model.apply_lora_to_mlp=False \
109 | 120 |             checkpointer=torchtune.training.FullModelHFCheckpointer \
110 |     | -            checkpointer.checkpoint_dir={tmpdir} \
    | 121 | +            checkpointer.checkpoint_dir={ckpt_dir} \
111 | 122 |             checkpointer.checkpoint_files=[{ckpt_path}]\
112 |     | -            checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
113 |     | -            checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
    | 123 | +            checkpointer.adapter_checkpoint={os.path.join(tmpdir, epoch_folder_minus_one, f"{ADAPTER_MODEL_FNAME}.pt")}
    | 124 | +            checkpointer.recipe_checkpoint={os.path.join(tmpdir, RECIPE_STATE_DIRNAME, "recipe_state.pt")}
114 | 125 |             checkpointer.output_dir={tmpdir} \
115 | 126 |             checkpointer.model_type=LLAMA2 \
116 | 127 |             resume_from_checkpoint=True \
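Aside (not part of the diff): a minimal sketch of how the resume overrides above are assembled, assuming the per-epoch output layout this test implies, with `epoch_<N>` subfolders for weights and a separate recipe-state folder under the output directory. `output_dir` below is a hypothetical stand-in for the test's `tmpdir`.

```python
import os

from torchtune.training.checkpointing._utils import (
    ADAPTER_MODEL_FNAME,
    get_largest_iter_folder,
    RECIPE_STATE_DIRNAME,
)

output_dir = "/tmp/lora_dpo_output"  # hypothetical; the test passes pytest's tmpdir

# Latest epoch folder written by the first training run, e.g. "epoch_1".
epoch_folder = get_largest_iter_folder(output_dir)
# Resume from the epoch before the last one, e.g. "epoch_0".
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"

# Adapter weights saved at the end of that earlier epoch.
adapter_ckpt = os.path.join(
    output_dir, epoch_folder_minus_one, f"{ADAPTER_MODEL_FNAME}.pt"
)
# Recipe state (optimizer/training progress) used for resuming lives in its own subfolder.
recipe_ckpt = os.path.join(output_dir, RECIPE_STATE_DIRNAME, "recipe_state.pt")
```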
@@ -177,17 +188,24 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
177 | 188 |         )
178 | 189 |
179 | 190 |         # Load base model and trained adapter weights into LoRA model and call fwd
180 |     | -        with open(f"{tmpdir}/adapter_1.pt", "rb") as f:
181 |     | -            lora_sd = torch.load(f, weights_only=True)
    | 191 | +        epoch_folder = get_largest_iter_folder(tmpdir)
    | 192 | +        adpt_path = os.path.join(tmpdir, epoch_folder, f"{ADAPTER_MODEL_FNAME}.pt")
    | 193 | +        lora_sd = safe_torch_load(adpt_path, weights_only=True)
    | 194 | +
182 | 195 |         with open(ckpt_path, "rb") as f:
183 | 196 |             base_model_sd = torch.load(f, weights_only=True)
184 | 197 |         lora_model.load_state_dict(lora_sd, strict=False)
185 | 198 |         lora_model.load_state_dict(base_model_sd, strict=False)
186 | 199 |         baseline_out = lora_model(inputs)
187 | 200 |
188 | 201 |         # Load merged final ckpt directly into llama2 and call fwd
189 |     | -        with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f:
190 |     | -            sd = torch.load(f, weights_only=True)
    | 202 | +        suffix = ".bin"
    | 203 | +        model_ckpt_fname = (
    | 204 | +            SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
    | 205 | +        )
    | 206 | +        model_path = os.path.join(tmpdir, epoch_folder, model_ckpt_fname)
    | 207 | +        sd = safe_torch_load(model_path, weights_only=True)
    | 208 | +
191 | 209 |         llama2_model.load_state_dict(sd)
192 | 210 |         merged_ckpt_out = llama2_model(inputs)
193 | 211 |         torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5)
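Aside (not part of the diff): a minimal sketch of how the updated test resolves the merged checkpoint file, assuming a single-shard HF-style checkpoint saved with a `.bin` suffix under the latest `epoch_<N>` folder. `output_dir` is again a hypothetical stand-in for the test's `tmpdir`.

```python
import os

from torchtune.training.checkpointing._utils import (
    ADAPTER_MODEL_FNAME,
    get_largest_iter_folder,
    safe_torch_load,
    SHARD_FNAME,
)

output_dir = "/tmp/lora_dpo_output"  # hypothetical; the test passes pytest's tmpdir
epoch_folder = get_largest_iter_folder(output_dir)  # e.g. "epoch_1"

# Trained adapter weights for the latest epoch.
adapter_path = os.path.join(output_dir, epoch_folder, f"{ADAPTER_MODEL_FNAME}.pt")
lora_sd = safe_torch_load(adapter_path, weights_only=True)

# Merged model shard: SHARD_FNAME is a filename template over shard index and
# shard count, so a single-shard ".bin" checkpoint resolves to something like
# "model-00001-of-00001.bin".
shard_fname = SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + ".bin"
model_path = os.path.join(output_dir, epoch_folder, shard_fname)
merged_sd = safe_torch_load(model_path, weights_only=True)
```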