Merge branch 'main' of github.com:SWivid/F5-TTS into main

This commit is contained in:
SWivid
2024-10-22 01:16:32 +08:00

View File

@@ -1,9 +1,12 @@
import os
import sys
import tempfile
import random
from transformers import pipeline
import gradio as gr
import torch
import gc
import click
import torchaudio
from glob import glob
@@ -20,11 +23,16 @@ import psutil
import platform
import subprocess
from datasets.arrow_writer import ArrowWriter
from datasets import Dataset as Dataset_
from api import F5TTS
training_process = None
system = platform.system()
python_executable = sys.executable or "python"
tts_api = None
last_checkpoint = ""
last_device = ""
path_data = "data"
@@ -240,7 +248,12 @@ def start_training(
last_per_steps=800,
finetune=True,
):
global training_process
global training_process, tts_api
if tts_api is not None:
del tts_api
gc.collect()
torch.cuda.empty_cache()
path_project = os.path.join(path_data, dataset_name + "_pinyin")
@@ -288,7 +301,7 @@ def start_training(
training_process = subprocess.Popen(cmd, shell=True)
time.sleep(5)
yield "check terminal for wandb", gr.update(interactive=False), gr.update(interactive=True)
yield "train start", gr.update(interactive=False), gr.update(interactive=True)
# Wait for the training process to finish
training_process.wait()
@@ -519,6 +532,17 @@ def calculate_train(
path_project = os.path.join(path_data, name_project)
file_duraction = os.path.join(path_project, "duration.json")
if not os.path.isfile(file_duraction):
return (
1000,
max_samples,
num_warmup_updates,
save_per_updates,
last_per_steps,
"project not found !",
learning_rate,
)
with open(file_duraction, "r") as file:
data = json.load(file)
@@ -549,8 +573,8 @@ def calculate_train(
else:
max_samples = 64
num_warmup_updates = int(samples * 0.10)
save_per_updates = int(samples * 0.25)
num_warmup_updates = int(samples * 0.05)
save_per_updates = int(samples * 0.10)
last_per_steps = int(save_per_updates * 5)
max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
@@ -559,7 +583,7 @@ def calculate_train(
last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
if finetune:
learning_rate = 1e-4
learning_rate = 1e-5
else:
learning_rate = 7.5e-5
@@ -611,6 +635,7 @@ def vocab_check(project_name):
sp = item.split("|")
if len(sp) != 2:
continue
text = sp[1].lower().strip()
for t in text:
@@ -625,6 +650,80 @@ def vocab_check(project_name):
return info
def get_random_sample_prepare(project_name):
name_project = project_name + "_pinyin"
path_project = os.path.join(path_data, name_project)
file_arrow = os.path.join(path_project, "raw.arrow")
if not os.path.isfile(file_arrow):
return "", None
dataset = Dataset_.from_file(file_arrow)
random_sample = dataset.shuffle(seed=random.randint(0, 1000)).select([0])
text = "[" + " , ".join(["' " + t + " '" for t in random_sample["text"][0]]) + "]"
audio_path = random_sample["audio_path"][0]
return text, audio_path
def get_random_sample_transcribe(project_name):
name_project = project_name + "_pinyin"
path_project = os.path.join(path_data, name_project)
file_metadata = os.path.join(path_project, "metadata.csv")
if not os.path.isfile(file_metadata):
return "", None
data = ""
with open(file_metadata, "r", encoding="utf-8") as f:
data = f.read()
list_data = []
for item in data.split("\n"):
sp = item.split("|")
if len(sp) != 2:
continue
list_data.append([os.path.join(path_project, "wavs", sp[0] + ".wav"), sp[1]])
if list_data == []:
return "", None
random_item = random.choice(list_data)
return random_item[1], random_item[0]
def get_random_sample_infer(project_name):
text, audio = get_random_sample_transcribe(project_name)
return (
text,
text,
audio,
)
def infer(project_name, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
global last_checkpoint, last_device, tts_api
if not os.path.isfile(file_checkpoint):
return None
if training_process is not None:
device_test = "cpu"
else:
device_test = None
if last_checkpoint != file_checkpoint or last_device != device_test:
if last_checkpoint != file_checkpoint:
last_checkpoint = file_checkpoint
if last_device != device_test:
last_device = device_test
tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test)
print("update", device_test, file_checkpoint)
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
return f.name
with gr.Blocks() as app:
with gr.Row():
project_name = gr.Textbox(label="project name", value="my_speak")
@@ -661,6 +760,18 @@ with gr.Blocks() as app:
)
ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
random_sample_transcribe = gr.Button("random sample")
with gr.Row():
random_text_transcribe = gr.Text(label="Text")
random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
random_sample_transcribe.click(
fn=get_random_sample_transcribe,
inputs=[project_name],
outputs=[random_text_transcribe, random_audio_transcribe],
)
with gr.TabItem("prepare Data"):
gr.Markdown(
"""```plaintext
@@ -687,6 +798,16 @@ with gr.Blocks() as app:
txt_info_prepare = gr.Text(label="info", value="")
bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
random_sample_prepare = gr.Button("random sample")
with gr.Row():
random_text_prepare = gr.Text(label="Pinyin")
random_audio_prepare = gr.Audio(label="Audio", type="filepath")
random_sample_prepare.click(
fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare]
)
with gr.TabItem("train Data"):
with gr.Row():
bt_calculate = bt_create = gr.Button("Auto Settings")
@@ -696,11 +817,11 @@ with gr.Blocks() as app:
with gr.Row():
exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
learning_rate = gr.Number(label="Learning Rate", value=1e-4, step=1e-4)
learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
with gr.Row():
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
max_samples = gr.Number(label="Max Samples", value=16)
max_samples = gr.Number(label="Max Samples", value=64)
with gr.Row():
grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
@@ -778,6 +899,28 @@ with gr.Blocks() as app:
txt_info_check = gr.Text(label="info", value="")
check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
with gr.TabItem("test model"):
exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
nfe_step = gr.Number(label="n_step", value=32)
file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="")
random_sample_infer = gr.Button("random sample")
ref_text = gr.Textbox(label="ref text")
ref_audio = gr.Audio(label="audio ref", type="filepath")
gen_text = gr.Textbox(label="gen text")
random_sample_infer.click(
fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio]
)
check_button_infer = gr.Button("infer")
gen_audio = gr.Audio(label="audio gen", type="filepath")
check_button_infer.click(
fn=infer,
inputs=[project_name, file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step],
outputs=[gen_audio],
)
@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")