diff --git a/finetune_gradio.py b/finetune_gradio.py index e61fc53..062e68d 100644 --- a/finetune_gradio.py +++ b/finetune_gradio.py @@ -1,9 +1,12 @@ import os import sys +import tempfile +import random from transformers import pipeline import gradio as gr import torch +import gc import click import torchaudio from glob import glob @@ -20,11 +23,16 @@ import psutil import platform import subprocess from datasets.arrow_writer import ArrowWriter +from datasets import Dataset as Dataset_ +from api import F5TTS training_process = None system = platform.system() python_executable = sys.executable or "python" +tts_api = None +last_checkpoint = "" +last_device = "" path_data = "data" @@ -240,7 +248,12 @@ def start_training( last_per_steps=800, finetune=True, ): - global training_process + global training_process, tts_api + + if tts_api is not None: + del tts_api + gc.collect() + torch.cuda.empty_cache() path_project = os.path.join(path_data, dataset_name + "_pinyin") @@ -288,7 +301,7 @@ def start_training( training_process = subprocess.Popen(cmd, shell=True) time.sleep(5) - yield "check terminal for wandb", gr.update(interactive=False), gr.update(interactive=True) + yield "train start", gr.update(interactive=False), gr.update(interactive=True) # Wait for the training process to finish training_process.wait() @@ -519,6 +532,17 @@ def calculate_train( path_project = os.path.join(path_data, name_project) file_duraction = os.path.join(path_project, "duration.json") + if not os.path.isfile(file_duraction): + return ( + 1000, + max_samples, + num_warmup_updates, + save_per_updates, + last_per_steps, + "project not found !", + learning_rate, + ) + with open(file_duraction, "r") as file: data = json.load(file) @@ -549,8 +573,8 @@ def calculate_train( else: max_samples = 64 - num_warmup_updates = int(samples * 0.10) - save_per_updates = int(samples * 0.25) + num_warmup_updates = int(samples * 0.05) + save_per_updates = int(samples * 0.10) last_per_steps = int(save_per_updates * 5) max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples) @@ -559,7 +583,7 @@ def calculate_train( last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps) if finetune: - learning_rate = 1e-4 + learning_rate = 1e-5 else: learning_rate = 7.5e-5 @@ -611,6 +635,7 @@ def vocab_check(project_name): sp = item.split("|") if len(sp) != 2: continue + text = sp[1].lower().strip() for t in text: @@ -625,6 +650,80 @@ def vocab_check(project_name): return info +def get_random_sample_prepare(project_name): + name_project = project_name + "_pinyin" + path_project = os.path.join(path_data, name_project) + file_arrow = os.path.join(path_project, "raw.arrow") + if not os.path.isfile(file_arrow): + return "", None + dataset = Dataset_.from_file(file_arrow) + random_sample = dataset.shuffle(seed=random.randint(0, 1000)).select([0]) + text = "[" + " , ".join(["' " + t + " '" for t in random_sample["text"][0]]) + "]" + audio_path = random_sample["audio_path"][0] + return text, audio_path + + +def get_random_sample_transcribe(project_name): + name_project = project_name + "_pinyin" + path_project = os.path.join(path_data, name_project) + file_metadata = os.path.join(path_project, "metadata.csv") + if not os.path.isfile(file_metadata): + return "", None + + data = "" + with open(file_metadata, "r", encoding="utf-8") as f: + data = f.read() + + list_data = [] + for item in data.split("\n"): + sp = item.split("|") + if len(sp) != 2: + continue + list_data.append([os.path.join(path_project, "wavs", sp[0] + ".wav"), sp[1]]) + + if list_data == []: + return "", None + + random_item = random.choice(list_data) + + return random_item[1], random_item[0] + + +def get_random_sample_infer(project_name): + text, audio = get_random_sample_transcribe(project_name) + return ( + text, + text, + audio, + ) + + +def infer(project_name, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step): + global last_checkpoint, last_device, tts_api + + if not os.path.isfile(file_checkpoint): + return None + + if training_process is not None: + device_test = "cpu" + else: + device_test = None + + if last_checkpoint != file_checkpoint or last_device != device_test: + if last_checkpoint != file_checkpoint: + last_checkpoint = file_checkpoint + if last_device != device_test: + last_device = device_test + + tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test) + + print("update", device_test, file_checkpoint) + + with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f: + tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name) + return f.name + + with gr.Blocks() as app: with gr.Row(): project_name = gr.Textbox(label="project name", value="my_speak") @@ -661,6 +760,18 @@ with gr.Blocks() as app: ) ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe]) + random_sample_transcribe = gr.Button("random sample") + + with gr.Row(): + random_text_transcribe = gr.Text(label="Text") + random_audio_transcribe = gr.Audio(label="Audio", type="filepath") + + random_sample_transcribe.click( + fn=get_random_sample_transcribe, + inputs=[project_name], + outputs=[random_text_transcribe, random_audio_transcribe], + ) + with gr.TabItem("prepare Data"): gr.Markdown( """```plaintext @@ -687,6 +798,16 @@ with gr.Blocks() as app: txt_info_prepare = gr.Text(label="info", value="") bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare]) + random_sample_prepare = gr.Button("random sample") + + with gr.Row(): + random_text_prepare = gr.Text(label="Pinyin") + random_audio_prepare = gr.Audio(label="Audio", type="filepath") + + random_sample_prepare.click( + fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare] + ) + with gr.TabItem("train Data"): with gr.Row(): bt_calculate = bt_create = gr.Button("Auto Settings") @@ -696,11 +817,11 @@ with gr.Blocks() as app: with gr.Row(): exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base") - learning_rate = gr.Number(label="Learning Rate", value=1e-4, step=1e-4) + learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5) with gr.Row(): batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000) - max_samples = gr.Number(label="Max Samples", value=16) + max_samples = gr.Number(label="Max Samples", value=64) with gr.Row(): grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1) @@ -778,6 +899,28 @@ with gr.Blocks() as app: txt_info_check = gr.Text(label="info", value="") check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check]) + with gr.TabItem("test model"): + exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS") + nfe_step = gr.Number(label="n_step", value=32) + file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="") + + random_sample_infer = gr.Button("random sample") + + ref_text = gr.Textbox(label="ref text") + ref_audio = gr.Audio(label="audio ref", type="filepath") + gen_text = gr.Textbox(label="gen text") + random_sample_infer.click( + fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio] + ) + check_button_infer = gr.Button("infer") + gen_audio = gr.Audio(label="audio gen", type="filepath") + + check_button_infer.click( + fn=infer, + inputs=[project_name, file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step], + outputs=[gen_audio], + ) + @click.command() @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")