mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 15:50:07 -08:00
fix inference-cli; clean-up
This commit is contained in:
40
README.md
40
README.md
@@ -58,38 +58,28 @@ Once your datasets are prepared, you can start the training process.
|
||||
# setup accelerate config, e.g. use multi-gpu ddp, fp16
|
||||
# will be to: ~/.cache/huggingface/accelerate/default_config.yaml
|
||||
accelerate config
|
||||
accelerate launch test_train.py
|
||||
accelerate launch train.py
|
||||
```
|
||||
An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57).
|
||||
|
||||
## Inference
|
||||
|
||||
To run inference with pretrained models, download the checkpoints from [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS).
|
||||
To run inference with pretrained models, download the checkpoints from [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), or automatically downloaded with `inference-cli` and `gradio_app`.
|
||||
|
||||
Currently support up to 30s generation, which is the **TOTAL** length of prompt audio and the generated. Batch inference with chunks is supported by Gradio APP now.
|
||||
Currently support 30s for a single generation, which is the **TOTAL** length of prompt audio and the generated. Batch inference with chunks is supported by `inference-cli` and `gradio_app`.
|
||||
- To avoid possible inference failures, make sure you have seen through the following instructions.
|
||||
- A longer prompt audio allows shorter generated output. The part longer than 30s cannot be generated properly. Consider split your text and do several separate inferences or leverage the local Gradio APP which enables a batch inference with chunks.
|
||||
- A longer prompt audio allows shorter generated output. The part longer than 30s cannot be generated properly. Consider using a prompt audio <15s.
|
||||
- Uppercased letters will be uttered letter by letter, so use lowercased letters for normal words.
|
||||
- Add some spaces (blank: " ") or punctuations (e.g. "," ".") to explicitly introduce some pauses. If first few words skipped in code-switched generation (cuz different speed with different languages), this might help.
|
||||
|
||||
### Single Inference
|
||||
### CLI Inference
|
||||
|
||||
You can test single inference using the following command. Before running the command, modify the config up to your need.
|
||||
Either you can specify everything in `inference-cli.toml` or override with flags. Leave `--ref_text ""` will have ASR model transcribe the reference audio automatically (use extra GPU memory). If encounter network error, consider use local ckpt, just set `ckpt_path` in `inference-cli.py`
|
||||
|
||||
```bash
|
||||
# modify the config up to your need,
|
||||
# e.g. fix_duration (the total length of prompt + to_generate, currently support up to 30s)
|
||||
# nfe_step (larger takes more time to do more precise inference ode)
|
||||
# ode_method (switch to 'midpoint' for better compatibility with small nfe_step, )
|
||||
# ( though 'midpoint' is 2nd-order ode solver, slower compared to 1st-order 'Euler')
|
||||
python test_infer_single.py
|
||||
```
|
||||
### Speech Editing
|
||||
python inference-cli.py --model "F5-TTS" --ref_audio "tests/ref_audio/test_en_1_ref_short.wav" --ref_text "Some call me nature, others call me mother nature." --gen_text "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
|
||||
|
||||
To test speech editing capabilities, use the following command.
|
||||
|
||||
```bash
|
||||
python test_infer_single_edit.py
|
||||
python inference-cli.py --model "E2-TTS" --ref_audio "tests/ref_audio/test_zh_1_ref_short.wav" --ref_text "对,这就是我,万人敬仰的太乙真人。" --gen_text "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:\"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?\""
|
||||
```
|
||||
|
||||
### Gradio App
|
||||
@@ -102,7 +92,7 @@ First, make sure you have the dependencies installed (`pip install -r requiremen
|
||||
pip install -r requirements_gradio.txt
|
||||
```
|
||||
|
||||
After installing the dependencies, launch the app (will load ckpt from Huggingface, you may set `ckpt_path` to local file in `gradio_app.py`):
|
||||
After installing the dependencies, launch the app (will load ckpt from Huggingface, you may set `ckpt_path` to local file in `gradio_app.py`). Currently load ASR model, F5-TTS and E2 TTS all in once, thus use more GPU memory than `inference-cli`.
|
||||
|
||||
```bash
|
||||
python gradio_app.py
|
||||
@@ -120,6 +110,14 @@ Or launch a share link:
|
||||
python gradio_app.py --share
|
||||
```
|
||||
|
||||
### Speech Editing
|
||||
|
||||
To test speech editing capabilities, use the following command.
|
||||
|
||||
```bash
|
||||
python speech_edit.py
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Prepare Test Datasets
|
||||
@@ -127,7 +125,7 @@ python gradio_app.py --share
|
||||
1. Seed-TTS test set: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
|
||||
2. LibriSpeech test-clean: Download from [OpenSLR](http://www.openslr.org/12/).
|
||||
3. Unzip the downloaded datasets and place them in the data/ directory.
|
||||
4. Update the path for the test-clean data in `test_infer_batch.py`
|
||||
4. Update the path for the test-clean data in `scripts/eval_infer_batch.py`
|
||||
5. Our filtered LibriSpeech-PC 4-10s subset is already under data/ in this repo
|
||||
|
||||
### Batch Inference for Test Set
|
||||
@@ -137,7 +135,7 @@ To run batch inference for evaluations, execute the following commands:
|
||||
```bash
|
||||
# batch inference for evaluations
|
||||
accelerate config # if not set before
|
||||
bash test_infer_batch.sh
|
||||
bash scripts/eval_infer_batch.sh
|
||||
```
|
||||
|
||||
### Download Evaluation Model Checkpoints
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
import torchaudio
|
||||
@@ -16,10 +15,8 @@ from model.utils import (
|
||||
save_spectrogram,
|
||||
)
|
||||
from transformers import pipeline
|
||||
import librosa
|
||||
import click
|
||||
import soundfile as sf
|
||||
import tomllib
|
||||
import tomli
|
||||
import argparse
|
||||
import tqdm
|
||||
from pathlib import Path
|
||||
@@ -42,19 +39,19 @@ parser.add_argument(
|
||||
)
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--reference",
|
||||
"--ref_audio",
|
||||
type=str,
|
||||
help="Reference audio file < 15 seconds."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--subtitle",
|
||||
"--ref_text",
|
||||
type=str,
|
||||
help="Subtitle for the reference audio."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--text",
|
||||
"--gen_text",
|
||||
type=str,
|
||||
help="Text to generate.",
|
||||
)
|
||||
@@ -70,11 +67,11 @@ parser.add_argument(
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
config = tomllib.load(open(args.config, "rb"))
|
||||
config = tomli.load(open(args.config, "rb"))
|
||||
|
||||
ref_audio = args.reference if args.reference else config["reference"]
|
||||
ref_text = args.subtitle if args.subtitle else config["subtitle"]
|
||||
gen_text = args.text if args.text else config["text"]
|
||||
ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
|
||||
ref_text = args.ref_text if args.ref_text else config["ref_text"]
|
||||
gen_text = args.gen_text if args.gen_text else config["gen_text"]
|
||||
output_dir = args.output_dir if args.output_dir else config["output_dir"]
|
||||
exp_name = args.model if args.model else config["model"]
|
||||
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
|
||||
@@ -100,13 +97,6 @@ device = (
|
||||
|
||||
print(f"Using {device} device")
|
||||
|
||||
pipe = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model="openai/whisper-large-v3-turbo",
|
||||
torch_dtype=torch.float16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# --------------------- Settings -------------------- #
|
||||
|
||||
target_sample_rate = 24000
|
||||
@@ -151,13 +141,6 @@ F5TTS_model_cfg = dict(
|
||||
)
|
||||
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
||||
|
||||
F5TTS_ema_model = load_model(
|
||||
"F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
|
||||
)
|
||||
E2TTS_ema_model = load_model(
|
||||
"E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
|
||||
)
|
||||
|
||||
def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
|
||||
if len(text.encode('utf-8')) <= max_chars:
|
||||
return [text]
|
||||
@@ -256,9 +239,9 @@ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
|
||||
|
||||
def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence):
|
||||
if exp_name == "F5-TTS":
|
||||
ema_model = F5TTS_ema_model
|
||||
ema_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
|
||||
elif exp_name == "E2-TTS":
|
||||
ema_model = E2TTS_ema_model
|
||||
ema_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
|
||||
|
||||
audio, sr = torchaudio.load(ref_audio)
|
||||
if audio.shape[0] > 1:
|
||||
@@ -363,6 +346,12 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
|
||||
|
||||
if not ref_text.strip():
|
||||
print("No reference text provided, transcribing reference audio...")
|
||||
pipe = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model="openai/whisper-large-v3-turbo",
|
||||
torch_dtype=torch.float16,
|
||||
device=device,
|
||||
)
|
||||
ref_text = pipe(
|
||||
ref_audio,
|
||||
chunk_length_s=30,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# F5-TTS | E2-TTS
|
||||
model = "F5-TTS"
|
||||
reference = "tests/ref_audio/test_en_1_ref_short.wav"
|
||||
ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
|
||||
# If an empty "", transcribes the reference audio automatically.
|
||||
subtitle = "Some call me nature, others call me mother nature."
|
||||
text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
|
||||
ref_text = "Some call me nature, others call me mother nature."
|
||||
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
|
||||
remove_silence = true
|
||||
output_dir = "tests"
|
||||
@@ -188,7 +188,7 @@ def load_dataset(
|
||||
dataset_type: str = "CustomDataset",
|
||||
audio_type: str = "raw",
|
||||
mel_spec_kwargs: dict = dict()
|
||||
) -> CustomDataset | HFDataset:
|
||||
) -> CustomDataset:
|
||||
|
||||
print("Loading dataset ...")
|
||||
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
accelerate>=0.33.0
|
||||
cached_path
|
||||
click
|
||||
datasets
|
||||
einops>=0.8.0
|
||||
einx>=0.3.0
|
||||
ema_pytorch>=0.5.2
|
||||
faster_whisper
|
||||
funasr
|
||||
gradio
|
||||
jieba
|
||||
jiwer
|
||||
librosa
|
||||
matplotlib
|
||||
pydub
|
||||
pypinyin
|
||||
safetensors
|
||||
soundfile
|
||||
# torch>=2.0
|
||||
# torchaudio>=2.3.0
|
||||
torchdiffeq
|
||||
@@ -20,6 +25,4 @@ vocos
|
||||
wandb
|
||||
x_transformers>=1.31.14
|
||||
zhconv
|
||||
zhon
|
||||
pydub
|
||||
cached_path
|
||||
zhon
|
||||
@@ -1,5 +0,0 @@
|
||||
cached_path
|
||||
click
|
||||
gradio
|
||||
pydub
|
||||
soundfile
|
||||
@@ -1,4 +1,6 @@
|
||||
import os
|
||||
import sys, os
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import time
|
||||
import random
|
||||
from tqdm import tqdm
|
||||
13
scripts/eval_infer_batch.sh
Normal file
13
scripts/eval_infer_batch.sh
Normal file
@@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
|
||||
# e.g. F5-TTS, 16 NFE
|
||||
accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
|
||||
accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
|
||||
accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
|
||||
|
||||
# e.g. Vanilla E2 TTS, 32 NFE
|
||||
accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
|
||||
accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
|
||||
accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
|
||||
|
||||
# etc.
|
||||
@@ -1,13 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# e.g. F5-TTS, 16 NFE
|
||||
accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
|
||||
accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
|
||||
accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
|
||||
|
||||
# e.g. Vanilla E2 TTS, 32 NFE
|
||||
accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
|
||||
accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
|
||||
accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
|
||||
|
||||
# etc.
|
||||
@@ -1,161 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from einops import rearrange
|
||||
from vocos import Vocos
|
||||
|
||||
from model import CFM, UNetT, DiT, MMDiT
|
||||
from model.utils import (
|
||||
load_checkpoint,
|
||||
get_tokenizer,
|
||||
convert_char_to_pinyin,
|
||||
save_spectrogram,
|
||||
)
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
|
||||
|
||||
# --------------------- Dataset Settings -------------------- #
|
||||
|
||||
target_sample_rate = 24000
|
||||
n_mel_channels = 100
|
||||
hop_length = 256
|
||||
target_rms = 0.1
|
||||
|
||||
tokenizer = "pinyin"
|
||||
dataset_name = "Emilia_ZH_EN"
|
||||
|
||||
|
||||
# ---------------------- infer setting ---------------------- #
|
||||
|
||||
seed = None # int | None
|
||||
|
||||
exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
|
||||
ckpt_step = 1200000
|
||||
|
||||
nfe_step = 32 # 16, 32
|
||||
cfg_strength = 2.
|
||||
ode_method = 'euler' # euler | midpoint
|
||||
sway_sampling_coef = -1.
|
||||
speed = 1.
|
||||
fix_duration = 27 # None (will linear estimate. if code-switched, consider fix) | float (total in seconds, include ref audio)
|
||||
|
||||
if exp_name == "F5TTS_Base":
|
||||
model_cls = DiT
|
||||
model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
|
||||
|
||||
elif exp_name == "E2TTS_Base":
|
||||
model_cls = UNetT
|
||||
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
||||
|
||||
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
|
||||
output_dir = "tests"
|
||||
|
||||
ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
|
||||
ref_text = "Some call me nature, others call me mother nature."
|
||||
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
|
||||
|
||||
# ref_audio = "tests/ref_audio/test_zh_1_ref_short.wav"
|
||||
# ref_text = "对,这就是我,万人敬仰的太乙真人。"
|
||||
# gen_text = "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:\"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?\""
|
||||
|
||||
|
||||
# -------------------------------------------------#
|
||||
|
||||
use_ema = True
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
# Vocoder model
|
||||
local = False
|
||||
if local:
|
||||
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
||||
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
||||
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
|
||||
vocos.load_state_dict(state_dict)
|
||||
vocos.eval()
|
||||
else:
|
||||
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
||||
|
||||
# Tokenizer
|
||||
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
|
||||
|
||||
# Model
|
||||
model = CFM(
|
||||
transformer = model_cls(
|
||||
**model_cfg,
|
||||
text_num_embeds = vocab_size,
|
||||
mel_dim = n_mel_channels
|
||||
),
|
||||
mel_spec_kwargs = dict(
|
||||
target_sample_rate = target_sample_rate,
|
||||
n_mel_channels = n_mel_channels,
|
||||
hop_length = hop_length,
|
||||
),
|
||||
odeint_kwargs = dict(
|
||||
method = ode_method,
|
||||
),
|
||||
vocab_char_map = vocab_char_map,
|
||||
).to(device)
|
||||
|
||||
model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
|
||||
|
||||
# Audio
|
||||
audio, sr = torchaudio.load(ref_audio)
|
||||
if audio.shape[0] > 1:
|
||||
audio = torch.mean(audio, dim=0, keepdim=True)
|
||||
rms = torch.sqrt(torch.mean(torch.square(audio)))
|
||||
if rms < target_rms:
|
||||
audio = audio * target_rms / rms
|
||||
if sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
|
||||
audio = resampler(audio)
|
||||
audio = audio.to(device)
|
||||
|
||||
# Text
|
||||
if len(ref_text[-1].encode('utf-8')) == 1:
|
||||
ref_text = ref_text + " "
|
||||
text_list = [ref_text + gen_text]
|
||||
if tokenizer == "pinyin":
|
||||
final_text_list = convert_char_to_pinyin(text_list)
|
||||
else:
|
||||
final_text_list = [text_list]
|
||||
print(f"text : {text_list}")
|
||||
print(f"pinyin: {final_text_list}")
|
||||
|
||||
# Duration
|
||||
ref_audio_len = audio.shape[-1] // hop_length
|
||||
if fix_duration is not None:
|
||||
duration = int(fix_duration * target_sample_rate / hop_length)
|
||||
else: # simple linear scale calcul
|
||||
zh_pause_punc = r"。,、;:?!"
|
||||
ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
|
||||
gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
|
||||
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
|
||||
|
||||
# Inference
|
||||
with torch.inference_mode():
|
||||
generated, trajectory = model.sample(
|
||||
cond = audio,
|
||||
text = final_text_list,
|
||||
duration = duration,
|
||||
steps = nfe_step,
|
||||
cfg_strength = cfg_strength,
|
||||
sway_sampling_coef = sway_sampling_coef,
|
||||
seed = seed,
|
||||
)
|
||||
print(f"Generated mel: {generated.shape}")
|
||||
|
||||
# Final result
|
||||
generated = generated[:, ref_audio_len:, :]
|
||||
generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
|
||||
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
||||
if rms < target_rms:
|
||||
generated_wave = generated_wave * rms / target_rms
|
||||
|
||||
save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single.png")
|
||||
torchaudio.save(f"{output_dir}/test_single.wav", generated_wave, target_sample_rate)
|
||||
print(f"Generated wav: {generated_wave.shape}")
|
||||
Reference in New Issue
Block a user