
Commit 424ffc3

Authored by felipemello1 (Felipe Mello) and vancoykendall

Update checkpointing directory (pytorch#2074)

Co-authored-by: Felipe Mello <felipemello@fb.com>
Co-authored-by: vancoyendall <vancoykendall@gmail.com>
1 parent f8563dd commit 424ffc3

19 files changed: +869 -256 lines
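The test changes below all follow from the new on-disk layout: instead of flat files such as torchtune_model_0.pt, adapter_0.pt, and recipe_state.pt sitting directly under output_dir, checkpoints are grouped into per-epoch folders, with recipe state kept in its own directory. As a rough orientation, here is a small Python sketch that composes the paths the updated tests look for; the folder and file name templates come from the constants imported in the diffs, while the example output_dir and epoch number are made up for illustration and are not part of the commit.

import os

from torchtune.training.checkpointing._utils import (
    ADAPTER_MODEL_FNAME,   # base name for saved adapter weights
    RECIPE_STATE_DIRNAME,  # directory that holds recipe_state.pt
    SHARD_FNAME,           # template for model shard file names
)

output_dir = "/tmp/torchtune/llama3_2_1B/lora_single_device"  # hypothetical run directory
epoch = 1  # each epoch's artifacts land in an "epoch_{n}" subfolder

# Single-shard model file; the suffix depends on the checkpointer (.safetensors for HF, .bin otherwise)
shard = SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + ".safetensors"

print(os.path.join(output_dir, f"epoch_{epoch}", shard))                        # model weights
print(os.path.join(output_dir, f"epoch_{epoch}", f"{ADAPTER_MODEL_FNAME}.pt"))  # LoRA adapter weights
print(os.path.join(output_dir, RECIPE_STATE_DIRNAME, "recipe_state.pt"))        # optimizer/recipe state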

recipes/configs/llama3_2/1B_full_single_device.yaml

+1 -5

@@ -25,7 +25,7 @@ output_dir: /tmp/torchtune/llama3_2_1B/full_single_device # /tmp may be deleted
 # Tokenizer
 tokenizer:
   _component_: torchtune.models.llama3.llama3_tokenizer
-  path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
+  path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
   max_seq_len: null

 # Dataset
@@ -35,10 +35,6 @@ dataset:
 seed: null
 shuffle: True

-# Model Arguments
-model:
-  _component_: torchtune.models.llama3_2.llama3_2_1b
-
 checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
   checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/

recipes/configs/llama3_2/1B_lora_single_device.yaml

+1 -2

@@ -15,7 +15,6 @@
 #
 # This config works only for training on single device.

-
 output_dir: /tmp/torchtune/llama3_2_1B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.

 # Model Arguments
@@ -37,7 +36,7 @@ checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
   checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
   checkpoint_files: [
-    model.safetensors
+    model.safetensors
   ]
   recipe_checkpoint: null
   output_dir: ${output_dir}

recipes/lora_finetune_single_device.py

+1

@@ -457,6 +457,7 @@ def _setup_model(
             )
         else:
             lora_missing, lora_unexpected = None, None
+
         validate_missing_and_unexpected_for_lora(
             lora_attn_modules=self._lora_attn_modules,
             apply_lora_to_mlp=self._apply_lora_to_mlp,

tests/recipes/test_full_finetune_distributed.py

+15 -4

@@ -27,6 +27,12 @@
     TOKENIZER_PATHS,
 )

+from torchtune.training.checkpointing._utils import (
+    get_largest_iter_folder,
+    RECIPE_STATE_DIRNAME,
+    SHARD_FNAME,
+)
+

 class TestFullFinetuneDistributedRecipe:
     def _get_test_config_overrides(self):
@@ -141,7 +147,6 @@ def test_training_state_on_resume(
         tokenizer_path = Path(TOKENIZER_PATHS[model_type])
         ckpt_dir = ckpt_path.parent
         log_file = gen_log_file_name(tmpdir)
-
         # Config file needed for model conversion.
         # Create a second copy for training resume
         write_hf_ckpt_config(ckpt_dir)
@@ -171,16 +176,22 @@ def test_training_state_on_resume(
             runpy.run_path(TUNE_PATH, run_name="__main__")

         # Resume training
+        epoch_folder = get_largest_iter_folder(tmpdir)
+        epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
+        suffix = ".safetensors" if ckpt_type == "hf" else ".bin"
+        model_ckpt_fname = (
+            SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
+        )
         cmd_2 = f"""
         tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \
             --config {config} \
             batch_size={micro_batch_size} \
             gradient_accumulation_steps={gradient_accumulation_steps} \
             output_dir={tmpdir} \
             checkpointer._component_={ckpt_component} \
-            checkpointer.checkpoint_dir='{tmpdir}' \
-            checkpointer.checkpoint_files=[{os.path.join(tmpdir, "torchtune_model_0.pt")}]\
-            checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{os.path.join(epoch_folder_minus_one, model_ckpt_fname)}]\
+            checkpointer.recipe_checkpoint={os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt")}\
             checkpointer.output_dir={tmpdir} \
             checkpointer.model_type={model_type.upper()} \
             tokenizer.path='{tokenizer_path}' \
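For the resumed run above, checkpointer.checkpoint_dir stays pointed at the original model directory, while checkpointer.checkpoint_files and checkpointer.recipe_checkpoint reference the new output layout. A minimal sketch of that lookup, assuming get_largest_iter_folder returns a name of the form epoch_{n} (which is what the string parsing in the test implies); previous_epoch_model_file is a hypothetical helper written only for this note:

import os

from torchtune.training.checkpointing._utils import get_largest_iter_folder, SHARD_FNAME


def previous_epoch_model_file(output_dir: str, ckpt_type: str) -> str:
    """Return the output_dir-relative path of the previous epoch's model checkpoint."""
    epoch_folder = get_largest_iter_folder(output_dir)             # e.g. "epoch_1"
    prev_folder = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"  # e.g. "epoch_0"
    suffix = ".safetensors" if ckpt_type == "hf" else ".bin"
    fname = SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
    return os.path.join(prev_folder, fname)

The value returned here is what the test splices into checkpointer.checkpoint_files, with the recipe state read from os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt").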

tests/recipes/test_full_finetune_single_device.py

+15 -3

@@ -29,6 +29,12 @@
     TOKENIZER_PATHS,
 )

+from torchtune.training.checkpointing._utils import (
+    get_largest_iter_folder,
+    RECIPE_STATE_DIRNAME,
+    SHARD_FNAME,
+)
+

 class TestFullFinetuneSingleDeviceRecipe:
     def _get_test_config_overrides(self):
@@ -173,15 +179,21 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
             runpy.run_path(TUNE_PATH, run_name="__main__")

         # Resume training
+        epoch_folder = get_largest_iter_folder(tmpdir)
+        epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
+        suffix = ".safetensors"
+        model_ckpt_fname = (
+            SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
+        )
         cmd_2 = f"""
         tune run full_finetune_single_device \
             --config llama2/7B_full_low_memory \
             batch_size=8 \
             output_dir={tmpdir} \
             checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
-            checkpointer.checkpoint_dir={tmpdir} \
-            checkpointer.checkpoint_files=[{os.path.join(tmpdir, "hf_model_0001_0.pt")}]\
-            checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\
+            checkpointer.checkpoint_dir={ckpt_dir} \
+            checkpointer.checkpoint_files=[{os.path.join(epoch_folder_minus_one, model_ckpt_fname)}]\
+            checkpointer.recipe_checkpoint={os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt")}\
             checkpointer.output_dir={tmpdir} \
             checkpointer.model_type=LLAMA2 \
             tokenizer.path=/tmp/test-artifacts/tokenizer.model \

tests/recipes/test_knowledge_distillation_distributed.py

+24 -7

@@ -28,6 +28,14 @@
 )
 from torchtune import config

+from torchtune.training.checkpointing._utils import (
+    ADAPTER_MODEL_FNAME,
+    get_largest_iter_folder,
+    RECIPE_STATE_DIRNAME,
+    safe_torch_load,
+    SHARD_FNAME,
+)
+

 class TestKDDistributedRecipe:
     def _get_test_config_overrides(self, epochs: int = 2):
@@ -146,15 +154,17 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
             runpy.run_path(TUNE_PATH, run_name="__main__")

         # Resume training
+        epoch_folder = get_largest_iter_folder(tmpdir)
+        epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
         cmd_2 = f"""
         tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \
             --config llama3_2/8B_to_1B_KD_lora_distributed \
             output_dir={tmpdir} \
             checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
-            checkpointer.checkpoint_dir={tmpdir} \
+            checkpointer.checkpoint_dir={ckpt_dir} \
             checkpointer.checkpoint_files=[{ckpt_path}]\
-            checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
-            checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
+            checkpointer.adapter_checkpoint={os.path.join(epoch_folder_minus_one, f"{ADAPTER_MODEL_FNAME}.pt")}
+            checkpointer.recipe_checkpoint={os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt")}
             checkpointer.output_dir={tmpdir} \
             teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
             teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \
@@ -238,17 +248,24 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
         )

         # Load base model and trained adapter weights into LoRA model and call fwd
-        with open(f"{tmpdir}/adapter_1.pt", "rb") as f:
-            lora_sd = torch.load(f, weights_only=True)
+        epoch_folder = get_largest_iter_folder(tmpdir)
+        adpt_path = os.path.join(tmpdir, epoch_folder, f"{ADAPTER_MODEL_FNAME}.pt")
+        lora_sd = safe_torch_load(adpt_path, weights_only=True)
+
         with open(ckpt_path, "rb") as f:
             base_model_sd = torch.load(f, weights_only=True)
         lora_model.load_state_dict(lora_sd, strict=False)
         lora_model.load_state_dict(base_model_sd, strict=False)
         baseline_out = lora_model(inputs)

         # Load merged final ckpt directly into 3 and call fwd
-        with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f:
-            sd = torch.load(f, weights_only=True)
+        suffix = ".safetensors" if ckpt_type == "hf" else ".bin"
+        model_ckpt_fname = (
+            SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
+        )
+        model_path = os.path.join(tmpdir, epoch_folder, model_ckpt_fname)
+        sd = safe_torch_load(model_path, weights_only=True)
+
         llama3_model.load_state_dict(sd)
         merged_ckpt_out = llama3_model(inputs)
         torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5)
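The save-and-load checks above now read saved weights back through safe_torch_load rather than opening the files by hand. A short sketch of that pattern, assuming a finished run has written its checkpoints under tmpdir; the literal paths and the ".bin" suffix are illustrative only:

import os

from torchtune.training.checkpointing._utils import (
    ADAPTER_MODEL_FNAME,
    get_largest_iter_folder,
    safe_torch_load,
    SHARD_FNAME,
)

tmpdir = "/tmp/kd_run"  # hypothetical output dir of a completed run
epoch_folder = get_largest_iter_folder(tmpdir)  # latest "epoch_{n}" folder

# Trained adapter weights saved for that epoch
adapter_sd = safe_torch_load(
    os.path.join(tmpdir, epoch_folder, f"{ADAPTER_MODEL_FNAME}.pt"), weights_only=True
)

# Merged full-model checkpoint for the same epoch (suffix matches the checkpointer type)
shard = SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + ".bin"
merged_sd = safe_torch_load(os.path.join(tmpdir, epoch_folder, shard), weights_only=True)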

tests/recipes/test_knowledge_distillation_single_device.py

+24 -7

@@ -27,6 +27,14 @@
 )
 from torchtune import config

+from torchtune.training.checkpointing._utils import (
+    ADAPTER_MODEL_FNAME,
+    get_largest_iter_folder,
+    RECIPE_STATE_DIRNAME,
+    safe_torch_load,
+    SHARD_FNAME,
+)
+

 class TestKDSingleDeviceRecipe:
     def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
@@ -184,15 +192,17 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
             runpy.run_path(TUNE_PATH, run_name="__main__")

         # Resume training
+        epoch_folder = get_largest_iter_folder(tmpdir)
+        epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
         cmd_2 = f"""
         tune run knowledge_distillation_single_device \
             --config qwen2/1.5_to_0.5B_KD_lora_single_device \
             output_dir={tmpdir} \
             checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
-            checkpointer.checkpoint_dir={tmpdir} \
+            checkpointer.checkpoint_dir={ckpt_dir} \
             checkpointer.checkpoint_files=[{ckpt_path}]\
-            checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
-            checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
+            checkpointer.adapter_checkpoint={os.path.join(epoch_folder_minus_one, f"{ADAPTER_MODEL_FNAME}.pt")}
+            checkpointer.recipe_checkpoint={os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt")}
             checkpointer.output_dir={tmpdir} \
             checkpointer.model_type=LLAMA3 \
             teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
@@ -292,17 +302,24 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
         )

         # Load base model and trained adapter weights into LoRA model and call fwd
-        with open(f"{tmpdir}/adapter_1.pt", "rb") as f:
-            lora_sd = torch.load(f, weights_only=True)
+        epoch_folder = get_largest_iter_folder(tmpdir)
+        adpt_path = os.path.join(tmpdir, epoch_folder, f"{ADAPTER_MODEL_FNAME}.pt")
+        lora_sd = safe_torch_load(adpt_path, weights_only=True)
+
         with open(ckpt_path, "rb") as f:
             base_model_sd = torch.load(f, weights_only=True)
         lora_model.load_state_dict(lora_sd, strict=False)
         lora_model.load_state_dict(base_model_sd, strict=False)
         baseline_out = lora_model(inputs)

         # Load merged final ckpt directly into 3 and call fwd
-        with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f:
-            sd = torch.load(f, weights_only=True)
+        suffix = ".safetensors" if ckpt_type == "hf" else ".bin"
+        model_ckpt_fname = (
+            SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
+        )
+        model_path = os.path.join(tmpdir, epoch_folder, model_ckpt_fname)
+        sd = safe_torch_load(model_path, weights_only=True)
+
         llama3_model.load_state_dict(sd)
         merged_ckpt_out = llama3_model(inputs)
         torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5)

tests/recipes/test_lora_dpo_single_device.py

+25 -7

@@ -25,6 +25,14 @@
 )
 from torchtune import config

+from torchtune.training.checkpointing._utils import (
+    ADAPTER_MODEL_FNAME,
+    get_largest_iter_folder,
+    RECIPE_STATE_DIRNAME,
+    safe_torch_load,
+    SHARD_FNAME,
+)
+

 class TestLoRADPOSingleDeviceRecipe:
     def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
@@ -99,18 +107,21 @@ def test_training_state_on_resume(

         resumed_log_dir = (tmpdir / "resumed/").mkdir()
         resumed_log_file = gen_log_file_name(resumed_log_dir)
+
         # Resume training
+        epoch_folder = get_largest_iter_folder(tmpdir)
+        epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
         cmd_2 = f"""
         tune run lora_dpo_single_device \
             --config llama2/7B_lora_dpo_single_device \
             output_dir={tmpdir} \
             model.lora_attn_modules=['q_proj','v_proj'] \
             model.apply_lora_to_mlp=False \
             checkpointer=torchtune.training.FullModelHFCheckpointer \
-            checkpointer.checkpoint_dir={tmpdir} \
+            checkpointer.checkpoint_dir={ckpt_dir} \
             checkpointer.checkpoint_files=[{ckpt_path}]\
-            checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
-            checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
+            checkpointer.adapter_checkpoint={os.path.join(tmpdir, epoch_folder_minus_one, f"{ADAPTER_MODEL_FNAME}.pt")}
+            checkpointer.recipe_checkpoint={os.path.join(tmpdir, RECIPE_STATE_DIRNAME, "recipe_state.pt")}
             checkpointer.output_dir={tmpdir} \
             checkpointer.model_type=LLAMA2 \
             resume_from_checkpoint=True \
@@ -177,17 +188,24 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
         )

         # Load base model and trained adapter weights into LoRA model and call fwd
-        with open(f"{tmpdir}/adapter_1.pt", "rb") as f:
-            lora_sd = torch.load(f, weights_only=True)
+        epoch_folder = get_largest_iter_folder(tmpdir)
+        adpt_path = os.path.join(tmpdir, epoch_folder, f"{ADAPTER_MODEL_FNAME}.pt")
+        lora_sd = safe_torch_load(adpt_path, weights_only=True)
+
         with open(ckpt_path, "rb") as f:
             base_model_sd = torch.load(f, weights_only=True)
         lora_model.load_state_dict(lora_sd, strict=False)
         lora_model.load_state_dict(base_model_sd, strict=False)
         baseline_out = lora_model(inputs)

         # Load merged final ckpt directly into llama2 and call fwd
-        with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f:
-            sd = torch.load(f, weights_only=True)
+        suffix = ".bin"
+        model_ckpt_fname = (
+            SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
+        )
+        model_path = os.path.join(tmpdir, epoch_folder, model_ckpt_fname)
+        sd = safe_torch_load(model_path, weights_only=True)
+
         llama2_model.load_state_dict(sd)
         merged_ckpt_out = llama2_model(inputs)
         torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5)

tests/recipes/test_lora_finetune_distributed.py

+25 -7

@@ -28,6 +28,14 @@
 )
 from torchtune import config

+from torchtune.training.checkpointing._utils import (
+    ADAPTER_MODEL_FNAME,
+    get_largest_iter_folder,
+    RECIPE_STATE_DIRNAME,
+    safe_torch_load,
+    SHARD_FNAME,
+)
+

 class TestLoRAFinetuneDistributedRecipe:
     def _get_test_config_overrides(self):
@@ -169,6 +177,8 @@ def test_training_state_on_resume(
             runpy.run_path(TUNE_PATH, run_name="__main__")

         # Resume training
+        epoch_folder = get_largest_iter_folder(tmpdir)
+        epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
         cmd_2 = f"""
         tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed \
             --config {config} \
@@ -178,10 +188,10 @@ def test_training_state_on_resume(
             model.lora_attn_modules=['q_proj','v_proj'] \
             model.apply_lora_to_mlp=False \
             checkpointer._component_={ckpt_component} \
-            checkpointer.checkpoint_dir={tmpdir} \
+            checkpointer.checkpoint_dir={ckpt_dir} \
             checkpointer.checkpoint_files=[{ckpt_path}]\
-            checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
-            checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
+            checkpointer.adapter_checkpoint={os.path.join(epoch_folder_minus_one, f"{ADAPTER_MODEL_FNAME}.pt")}
+            checkpointer.recipe_checkpoint={os.path.join(RECIPE_STATE_DIRNAME, "recipe_state.pt")}
             checkpointer.output_dir={tmpdir} \
             checkpointer.model_type={model_type.upper()} \
             tokenizer.path='{tokenizer_path}' \
@@ -259,17 +269,25 @@ def test_save_and_load_merged_weights(
         model = config.instantiate(OmegaConf.from_dotlist(base_config).model)

         # Load base model and trained adapter weights into LoRA model and call fwd
-        with open(f"{tmpdir}/adapter_1.pt", "rb") as f:
-            lora_sd = torch.load(f, weights_only=True)
+        epoch_folder = get_largest_iter_folder(tmpdir)
+        adpt_path = os.path.join(tmpdir, epoch_folder, f"{ADAPTER_MODEL_FNAME}.pt")
+        lora_sd = safe_torch_load(adpt_path, weights_only=True)
+
         with open(ckpt_path, "rb") as f:
             base_model_sd = torch.load(f, weights_only=True)
+
         lora_model.load_state_dict(lora_sd, strict=False)
         lora_model.load_state_dict(base_model_sd, strict=False)
         baseline_out = lora_model(inputs)

         # Load merged final ckpt directly into model and call fwd
-        with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f:
-            sd = torch.load(f, weights_only=True)
+        suffix = ".safetensors" if ckpt_type == "hf" else ".bin"
+        model_ckpt_fname = (
+            SHARD_FNAME.format(cpt_idx="1".zfill(5), num_shards="1".zfill(5)) + suffix
+        )
+        model_path = os.path.join(tmpdir, epoch_folder, model_ckpt_fname)
+        sd = safe_torch_load(model_path, weights_only=True)
+
         model.load_state_dict(sd)
         merged_ckpt_out = model(inputs)
