mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 15:50:07 -08:00
Merge branch 'main' of github.com:SWivid/F5-TTS into main
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import tempfile
|
||||
import random
|
||||
from transformers import pipeline
|
||||
import gradio as gr
|
||||
import torch
|
||||
import gc
|
||||
import click
|
||||
import torchaudio
|
||||
from glob import glob
|
||||
@@ -20,11 +23,16 @@ import psutil
|
||||
import platform
|
||||
import subprocess
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from datasets import Dataset as Dataset_
|
||||
from api import F5TTS
|
||||
|
||||
|
||||
training_process = None
|
||||
system = platform.system()
|
||||
python_executable = sys.executable or "python"
|
||||
tts_api = None
|
||||
last_checkpoint = ""
|
||||
last_device = ""
|
||||
|
||||
path_data = "data"
|
||||
|
||||
@@ -240,7 +248,12 @@ def start_training(
|
||||
last_per_steps=800,
|
||||
finetune=True,
|
||||
):
|
||||
global training_process
|
||||
global training_process, tts_api
|
||||
|
||||
if tts_api is not None:
|
||||
del tts_api
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
path_project = os.path.join(path_data, dataset_name + "_pinyin")
|
||||
|
||||
@@ -288,7 +301,7 @@ def start_training(
|
||||
training_process = subprocess.Popen(cmd, shell=True)
|
||||
|
||||
time.sleep(5)
|
||||
yield "check terminal for wandb", gr.update(interactive=False), gr.update(interactive=True)
|
||||
yield "train start", gr.update(interactive=False), gr.update(interactive=True)
|
||||
|
||||
# Wait for the training process to finish
|
||||
training_process.wait()
|
||||
@@ -519,6 +532,17 @@ def calculate_train(
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
file_duraction = os.path.join(path_project, "duration.json")
|
||||
|
||||
if not os.path.isfile(file_duraction):
|
||||
return (
|
||||
1000,
|
||||
max_samples,
|
||||
num_warmup_updates,
|
||||
save_per_updates,
|
||||
last_per_steps,
|
||||
"project not found !",
|
||||
learning_rate,
|
||||
)
|
||||
|
||||
with open(file_duraction, "r") as file:
|
||||
data = json.load(file)
|
||||
|
||||
@@ -549,8 +573,8 @@ def calculate_train(
|
||||
else:
|
||||
max_samples = 64
|
||||
|
||||
num_warmup_updates = int(samples * 0.10)
|
||||
save_per_updates = int(samples * 0.25)
|
||||
num_warmup_updates = int(samples * 0.05)
|
||||
save_per_updates = int(samples * 0.10)
|
||||
last_per_steps = int(save_per_updates * 5)
|
||||
|
||||
max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
|
||||
@@ -559,7 +583,7 @@ def calculate_train(
|
||||
last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
|
||||
|
||||
if finetune:
|
||||
learning_rate = 1e-4
|
||||
learning_rate = 1e-5
|
||||
else:
|
||||
learning_rate = 7.5e-5
|
||||
|
||||
@@ -611,6 +635,7 @@ def vocab_check(project_name):
|
||||
sp = item.split("|")
|
||||
if len(sp) != 2:
|
||||
continue
|
||||
|
||||
text = sp[1].lower().strip()
|
||||
|
||||
for t in text:
|
||||
@@ -625,6 +650,80 @@ def vocab_check(project_name):
|
||||
return info
|
||||
|
||||
|
||||
def get_random_sample_prepare(project_name):
|
||||
name_project = project_name + "_pinyin"
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
file_arrow = os.path.join(path_project, "raw.arrow")
|
||||
if not os.path.isfile(file_arrow):
|
||||
return "", None
|
||||
dataset = Dataset_.from_file(file_arrow)
|
||||
random_sample = dataset.shuffle(seed=random.randint(0, 1000)).select([0])
|
||||
text = "[" + " , ".join(["' " + t + " '" for t in random_sample["text"][0]]) + "]"
|
||||
audio_path = random_sample["audio_path"][0]
|
||||
return text, audio_path
|
||||
|
||||
|
||||
def get_random_sample_transcribe(project_name):
|
||||
name_project = project_name + "_pinyin"
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
file_metadata = os.path.join(path_project, "metadata.csv")
|
||||
if not os.path.isfile(file_metadata):
|
||||
return "", None
|
||||
|
||||
data = ""
|
||||
with open(file_metadata, "r", encoding="utf-8") as f:
|
||||
data = f.read()
|
||||
|
||||
list_data = []
|
||||
for item in data.split("\n"):
|
||||
sp = item.split("|")
|
||||
if len(sp) != 2:
|
||||
continue
|
||||
list_data.append([os.path.join(path_project, "wavs", sp[0] + ".wav"), sp[1]])
|
||||
|
||||
if list_data == []:
|
||||
return "", None
|
||||
|
||||
random_item = random.choice(list_data)
|
||||
|
||||
return random_item[1], random_item[0]
|
||||
|
||||
|
||||
def get_random_sample_infer(project_name):
|
||||
text, audio = get_random_sample_transcribe(project_name)
|
||||
return (
|
||||
text,
|
||||
text,
|
||||
audio,
|
||||
)
|
||||
|
||||
|
||||
def infer(project_name, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
|
||||
global last_checkpoint, last_device, tts_api
|
||||
|
||||
if not os.path.isfile(file_checkpoint):
|
||||
return None
|
||||
|
||||
if training_process is not None:
|
||||
device_test = "cpu"
|
||||
else:
|
||||
device_test = None
|
||||
|
||||
if last_checkpoint != file_checkpoint or last_device != device_test:
|
||||
if last_checkpoint != file_checkpoint:
|
||||
last_checkpoint = file_checkpoint
|
||||
if last_device != device_test:
|
||||
last_device = device_test
|
||||
|
||||
tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test)
|
||||
|
||||
print("update", device_test, file_checkpoint)
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
||||
tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
|
||||
return f.name
|
||||
|
||||
|
||||
with gr.Blocks() as app:
|
||||
with gr.Row():
|
||||
project_name = gr.Textbox(label="project name", value="my_speak")
|
||||
@@ -661,6 +760,18 @@ with gr.Blocks() as app:
|
||||
)
|
||||
ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
|
||||
|
||||
random_sample_transcribe = gr.Button("random sample")
|
||||
|
||||
with gr.Row():
|
||||
random_text_transcribe = gr.Text(label="Text")
|
||||
random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
|
||||
|
||||
random_sample_transcribe.click(
|
||||
fn=get_random_sample_transcribe,
|
||||
inputs=[project_name],
|
||||
outputs=[random_text_transcribe, random_audio_transcribe],
|
||||
)
|
||||
|
||||
with gr.TabItem("prepare Data"):
|
||||
gr.Markdown(
|
||||
"""```plaintext
|
||||
@@ -687,6 +798,16 @@ with gr.Blocks() as app:
|
||||
txt_info_prepare = gr.Text(label="info", value="")
|
||||
bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
|
||||
|
||||
random_sample_prepare = gr.Button("random sample")
|
||||
|
||||
with gr.Row():
|
||||
random_text_prepare = gr.Text(label="Pinyin")
|
||||
random_audio_prepare = gr.Audio(label="Audio", type="filepath")
|
||||
|
||||
random_sample_prepare.click(
|
||||
fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare]
|
||||
)
|
||||
|
||||
with gr.TabItem("train Data"):
|
||||
with gr.Row():
|
||||
bt_calculate = bt_create = gr.Button("Auto Settings")
|
||||
@@ -696,11 +817,11 @@ with gr.Blocks() as app:
|
||||
|
||||
with gr.Row():
|
||||
exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
|
||||
learning_rate = gr.Number(label="Learning Rate", value=1e-4, step=1e-4)
|
||||
learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
|
||||
|
||||
with gr.Row():
|
||||
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
|
||||
max_samples = gr.Number(label="Max Samples", value=16)
|
||||
max_samples = gr.Number(label="Max Samples", value=64)
|
||||
|
||||
with gr.Row():
|
||||
grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
|
||||
@@ -778,6 +899,28 @@ with gr.Blocks() as app:
|
||||
txt_info_check = gr.Text(label="info", value="")
|
||||
check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
|
||||
|
||||
with gr.TabItem("test model"):
|
||||
exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
|
||||
nfe_step = gr.Number(label="n_step", value=32)
|
||||
file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="")
|
||||
|
||||
random_sample_infer = gr.Button("random sample")
|
||||
|
||||
ref_text = gr.Textbox(label="ref text")
|
||||
ref_audio = gr.Audio(label="audio ref", type="filepath")
|
||||
gen_text = gr.Textbox(label="gen text")
|
||||
random_sample_infer.click(
|
||||
fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio]
|
||||
)
|
||||
check_button_infer = gr.Button("infer")
|
||||
gen_audio = gr.Audio(label="audio gen", type="filepath")
|
||||
|
||||
check_button_infer.click(
|
||||
fn=infer,
|
||||
inputs=[project_name, file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step],
|
||||
outputs=[gen_audio],
|
||||
)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
|
||||
|
||||
Reference in New Issue
Block a user