mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 15:50:07 -08:00
Merge pull request #729 from hcsolakoglu/fix-ckpt-rotation
Exclude pretrained models from the checkpoint rotation logic
This commit is contained in:
@@ -160,10 +160,14 @@ class Trainer:
|
||||
return
|
||||
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
|
||||
if self.keep_last_n_checkpoints > 0:
|
||||
# Updated logic to exclude pretrained model from rotation
|
||||
checkpoints = [
|
||||
f
|
||||
for f in os.listdir(self.checkpoint_path)
|
||||
if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
|
||||
if f.startswith("model_")
|
||||
and not f.startswith("pretrained_") # Exclude pretrained models
|
||||
and f.endswith(".pt")
|
||||
and f != "model_last.pt"
|
||||
]
|
||||
checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
|
||||
while len(checkpoints) > self.keep_last_n_checkpoints:
|
||||
@@ -183,10 +187,24 @@ class Trainer:
|
||||
if "model_last.pt" in os.listdir(self.checkpoint_path):
|
||||
latest_checkpoint = "model_last.pt"
|
||||
else:
|
||||
latest_checkpoint = sorted(
|
||||
[f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
|
||||
key=lambda x: int("".join(filter(str.isdigit, x))),
|
||||
)[-1]
|
||||
# Updated to consider pretrained models for loading but prioritize training checkpoints
|
||||
all_checkpoints = [
|
||||
f
|
||||
for f in os.listdir(self.checkpoint_path)
|
||||
if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")
|
||||
]
|
||||
|
||||
# First try to find regular training checkpoints
|
||||
training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"]
|
||||
if training_checkpoints:
|
||||
latest_checkpoint = sorted(
|
||||
training_checkpoints,
|
||||
key=lambda x: int("".join(filter(str.isdigit, x))),
|
||||
)[-1]
|
||||
else:
|
||||
# If no training checkpoints, use pretrained model
|
||||
latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
|
||||
|
||||
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
||||
checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
|
||||
|
||||
|
||||
@@ -111,7 +111,8 @@ def main():
|
||||
if not os.path.isdir(checkpoint_path):
|
||||
os.makedirs(checkpoint_path, exist_ok=True)
|
||||
|
||||
file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path))
|
||||
# Change: Add 'pretrained_' prefix to copied model
|
||||
file_checkpoint = os.path.join(checkpoint_path, "pretrained_" + os.path.basename(ckpt_path))
|
||||
if not os.path.isfile(file_checkpoint):
|
||||
shutil.copy2(ckpt_path, file_checkpoint)
|
||||
print("copy checkpoint for finetune")
|
||||
|
||||
@@ -1099,7 +1099,9 @@ def vocab_extend(project_name, symbols, model_type):
|
||||
dataset_name = name_project.replace("_pinyin", "").replace("_char", "")
|
||||
new_ckpt_path = os.path.join(path_project_ckpts, dataset_name)
|
||||
os.makedirs(new_ckpt_path, exist_ok=True)
|
||||
new_ckpt_file = os.path.join(new_ckpt_path, "model_1200000.pt")
|
||||
|
||||
# Add pretrained_ prefix to model when copying for consistency with finetune_cli.py
|
||||
new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_model_1200000.pt")
|
||||
|
||||
size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user