Merge pull request #729 from hcsolakoglu/fix-ckpt-rotation

Exclude pretrained models from the checkpoint rotation logic
Authored by Yushen CHEN on 2025-01-27 19:57:05 +08:00, committed by GitHub.
3 changed files with 28 additions and 7 deletions


@@ -160,10 +160,14 @@ class Trainer:
                     return
                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
                 if self.keep_last_n_checkpoints > 0:
+                    # Updated logic to exclude pretrained model from rotation
                     checkpoints = [
                         f
                         for f in os.listdir(self.checkpoint_path)
-                        if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
+                        if f.startswith("model_")
+                        and not f.startswith("pretrained_")  # Exclude pretrained models
+                        and f.endswith(".pt")
+                        and f != "model_last.pt"
                     ]
                     checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
                     while len(checkpoints) > self.keep_last_n_checkpoints:
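Read on its own, the new rotation filter amounts to the standalone sketch below (rotate_checkpoints is a hypothetical helper, not code from this commit, and the body of the while loop is assumed since the hunk ends before it). The added conditions guarantee that a checkpoint copied in as pretrained_*.pt never enters the rotation list, so it can neither be deleted nor count against keep_last_n_checkpoints.

import os

def rotate_checkpoints(checkpoint_dir: str, keep_last_n: int) -> None:
    # Only numbered training checkpoints are rotation candidates; the copied
    # pretrained model and model_last.pt are never touched.
    checkpoints = [
        f
        for f in os.listdir(checkpoint_dir)
        if f.startswith("model_")
        and not f.startswith("pretrained_")  # e.g. pretrained_model_1200000.pt
        and f.endswith(".pt")
        and f != "model_last.pt"
    ]
    # Same sort key as the trainer: the update number in "model_<update>.pt".
    checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
    # Assumed clean-up step: drop the oldest files until keep_last_n remain.
    while len(checkpoints) > keep_last_n:
        oldest = checkpoints.pop(0)
        os.remove(os.path.join(checkpoint_dir, oldest))

Note that the sort key would also fail on a name like pretrained_model_1200000.pt (int("model") raises ValueError), so keeping prefixed files out of the list is necessary, not just tidy.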
@@ -183,10 +187,24 @@ class Trainer:
         if "model_last.pt" in os.listdir(self.checkpoint_path):
             latest_checkpoint = "model_last.pt"
         else:
-            latest_checkpoint = sorted(
-                [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
-                key=lambda x: int("".join(filter(str.isdigit, x))),
-            )[-1]
+            # Updated to consider pretrained models for loading but prioritize training checkpoints
+            all_checkpoints = [
+                f
+                for f in os.listdir(self.checkpoint_path)
+                if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")
+            ]
+            # First try to find regular training checkpoints
+            training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"]
+            if training_checkpoints:
+                latest_checkpoint = sorted(
+                    training_checkpoints,
+                    key=lambda x: int("".join(filter(str.isdigit, x))),
+                )[-1]
+            else:
+                # If no training checkpoints, use pretrained model
+                latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
         # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
         checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
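The loading side then boils down to a three-step preference order, sketched here as a hypothetical pick_resume_checkpoint helper (the trainer does this inline; max(...) replaces the diff's sorted(...)[-1], and returning None instead of letting next() raise is a liberty taken for the sketch):

import os

def pick_resume_checkpoint(checkpoint_dir: str):
    files = os.listdir(checkpoint_dir)
    # 1) an explicit "last" checkpoint always wins
    if "model_last.pt" in files:
        return "model_last.pt"
    # 2) otherwise the newest numbered training checkpoint
    training = [f for f in files if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"]
    if training:
        return max(training, key=lambda x: int("".join(filter(str.isdigit, x))))
    # 3) only if no training checkpoint exists, fall back to the copied pretrained model
    return next((f for f in files if f.startswith("pretrained_") and f.endswith(".pt")), None)

So a fresh fine-tuning directory that only contains pretrained_model_1200000.pt starts from the pretrained weights, while any later model_<update>.pt or model_last.pt takes precedence on resume.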


@@ -111,7 +111,8 @@ def main():
     if not os.path.isdir(checkpoint_path):
         os.makedirs(checkpoint_path, exist_ok=True)
-    file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path))
+    # Change: Add 'pretrained_' prefix to copied model
+    file_checkpoint = os.path.join(checkpoint_path, "pretrained_" + os.path.basename(ckpt_path))
     if not os.path.isfile(file_checkpoint):
         shutil.copy2(ckpt_path, file_checkpoint)
         print("copy checkpoint for finetune")
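The finetune_cli.py change is purely a rename at copy time; its effect can be pictured with this self-contained sketch (copy_pretrained_for_finetune is a made-up name, and the example paths are assumptions):

import os
import shutil

def copy_pretrained_for_finetune(ckpt_path: str, checkpoint_path: str) -> str:
    # e.g. ckpt_path=".../model_1200000.pt" is copied to
    # "<checkpoint_path>/pretrained_model_1200000.pt", a name the trainer's
    # rotation filter skips and its loader uses only as a last resort.
    os.makedirs(checkpoint_path, exist_ok=True)
    file_checkpoint = os.path.join(checkpoint_path, "pretrained_" + os.path.basename(ckpt_path))
    if not os.path.isfile(file_checkpoint):
        shutil.copy2(ckpt_path, file_checkpoint)
    return file_checkpoint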


@@ -1099,7 +1099,9 @@ def vocab_extend(project_name, symbols, model_type):
     dataset_name = name_project.replace("_pinyin", "").replace("_char", "")
     new_ckpt_path = os.path.join(path_project_ckpts, dataset_name)
     os.makedirs(new_ckpt_path, exist_ok=True)
-    new_ckpt_file = os.path.join(new_ckpt_path, "model_1200000.pt")
+    # Add pretrained_ prefix to model when copying for consistency with finetune_cli.py
+    new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_model_1200000.pt")
     size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
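With vocab_extend writing pretrained_model_1200000.pt as well, both entry points now produce names the trainer handles identically. A quick runnable check of that behaviour (the directory contents are fabricated for the demonstration):

import os
import tempfile

with tempfile.TemporaryDirectory() as d:
    # freshly prepared project: only the vocab-expanded pretrained model exists
    open(os.path.join(d, "pretrained_model_1200000.pt"), "wb").close()

    rotation_candidates = [
        f
        for f in os.listdir(d)
        if f.startswith("model_") and not f.startswith("pretrained_")
        and f.endswith(".pt") and f != "model_last.pt"
    ]
    pretrained_fallback = [f for f in os.listdir(d) if f.startswith("pretrained_") and f.endswith(".pt")]

    print(rotation_candidates)   # [] -- nothing for the rotation loop to delete
    print(pretrained_fallback)   # ['pretrained_model_1200000.pt'] -- what the loader falls back to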