fix inference-cli; clean-up

Author: SWivid
Date: 2024-10-14 23:40:31 +08:00
parent 9ec24868a9
commit 9d2b8cb3da
12 changed files with 61 additions and 235 deletions

View File: README.md

@@ -58,38 +58,28 @@ Once your datasets are prepared, you can start the training process.
# setup accelerate config, e.g. use multi-gpu ddp, fp16
# config will be saved to: ~/.cache/huggingface/accelerate/default_config.yaml
accelerate config
accelerate launch test_train.py
accelerate launch train.py
```
An initial guide to fine-tuning is available in discussion [#57](https://github.com/SWivid/F5-TTS/discussions/57).
## Inference
To run inference with pretrained models, download the checkpoints from [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS).
To run inference with pretrained models, download the checkpoints from [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), or let `inference-cli` and `gradio_app` download them automatically.
Currently support up to 30s generation, which is the **TOTAL** length of prompt audio and the generated. Batch inference with chunks is supported by Gradio APP now.
Currently supports up to 30s for a single generation, which is the **TOTAL** length of the prompt audio and the generated output. Batch inference with chunks is supported by `inference-cli` and `gradio_app`.
- To avoid possible inference failures, make sure you have read through the following instructions.
- A longer prompt audio allows shorter generated output. The part longer than 30s cannot be generated properly. Consider split your text and do several separate inferences or leverage the local Gradio APP which enables a batch inference with chunks.
- A longer prompt audio leaves less room for the generated output, since anything beyond 30s in total cannot be generated properly. Consider using a prompt audio shorter than 15s.
- Uppercase letters are uttered letter by letter, so use lowercase letters for normal words.
- Add some spaces (blank: " ") or punctuation (e.g. "," ".") to explicitly introduce pauses. If the first few words are skipped in code-switched generation (caused by differing speaking speeds across languages), this may also help; see the example right after this list.
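A hedged illustration of the tips above, using the CLI introduced in the next section (the reference audio and its text are the repo's bundled English sample; the generated sentence is a made-up placeholder):
```bash
# lowercase words plus explicit commas and periods give the model clear pause cues
python inference-cli.py --model "F5-TTS" \
    --ref_audio "tests/ref_audio/test_en_1_ref_short.wav" \
    --ref_text "Some call me nature, others call me mother nature." \
    --gen_text "well, let me think for a moment. alright, here is my answer, short and clear."
```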
### Single Inference
### CLI Inference
You can test single inference using the following command. Before running the command, modify the config up to your need.
You can either specify everything in `inference-cli.toml` or override it with flags. Leaving `--ref_text ""` will have an ASR model transcribe the reference audio automatically (this uses extra GPU memory). If you encounter a network error, consider using a local checkpoint: just set `ckpt_path` in `inference-cli.py`.
```bash
# modify the config up to your need,
# e.g. fix_duration (the total length of prompt + to_generate, currently support up to 30s)
# nfe_step (larger takes more time to do more precise inference ode)
# ode_method (switch to 'midpoint' for better compatibility with small nfe_step, )
# ( though 'midpoint' is 2nd-order ode solver, slower compared to 1st-order 'Euler')
python test_infer_single.py
```
### Speech Editing
python inference-cli.py --model "F5-TTS" --ref_audio "tests/ref_audio/test_en_1_ref_short.wav" --ref_text "Some call me nature, others call me mother nature." --gen_text "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
To test speech editing capabilities, use the following command.
```bash
python test_infer_single_edit.py
python inference-cli.py --model "E2-TTS" --ref_audio "tests/ref_audio/test_zh_1_ref_short.wav" --ref_text "对,这就是我,万人敬仰的太乙真人。" --gen_text "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:\"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?\""
```
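For reference, a minimal sketch of the two invocation styles described above, assuming `inference-cli.toml` is the config the script reads by default (flag names follow the `inference-cli.py` arguments in this commit):
```bash
# 1) rely entirely on the values in inference-cli.toml
python inference-cli.py

# 2) keep the toml as a base and override selected fields with flags
python inference-cli.py --model "E2-TTS" --gen_text "a short test of the flag overrides."
```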
### Gradio App
@@ -102,7 +92,7 @@ First, make sure you have the dependencies installed (`pip install -r requiremen
pip install -r requirements_gradio.txt
```
After installing the dependencies, launch the app (will load ckpt from Huggingface, you may set `ckpt_path` to local file in `gradio_app.py`):
After installing the dependencies, launch the app (it will load the checkpoint from Hugging Face; you may set `ckpt_path` to a local file in `gradio_app.py`). It currently loads the ASR model, F5-TTS and E2 TTS all at once, and thus uses more GPU memory than `inference-cli`.
```bash
python gradio_app.py
@@ -120,6 +110,14 @@ Or launch a share link:
python gradio_app.py --share
```
### Speech Editing
To test speech editing capabilities, use the following command.
```bash
python speech_edit.py
```
## Evaluation
### Prepare Test Datasets
@@ -127,7 +125,7 @@ python gradio_app.py --share
1. Seed-TTS test set: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
2. LibriSpeech test-clean: Download from [OpenSLR](http://www.openslr.org/12/).
3. Unzip the downloaded datasets and place them in the data/ directory.
4. Update the path for the test-clean data in `test_infer_batch.py`
4. Update the path for the test-clean data in `scripts/eval_infer_batch.py`
5. Our filtered LibriSpeech-PC 4-10s subset is already under data/ in this repo
### Batch Inference for Test Set
@@ -137,7 +135,7 @@ To run batch inference for evaluations, execute the following commands:
```bash
# batch inference for evaluations
accelerate config # if not set before
bash test_infer_batch.sh
bash scripts/eval_infer_batch.sh
```
### Download Evaluation Model Checkpoints

View File: inference-cli.py

@@ -1,4 +1,3 @@
import os
import re
import torch
import torchaudio
@@ -16,10 +15,8 @@ from model.utils import (
save_spectrogram,
)
from transformers import pipeline
import librosa
import click
import soundfile as sf
import tomllib
import tomli
import argparse
import tqdm
from pathlib import Path
@@ -42,19 +39,19 @@ parser.add_argument(
)
parser.add_argument(
"-r",
"--reference",
"--ref_audio",
type=str,
help="Reference audio file < 15 seconds."
)
parser.add_argument(
"-s",
"--subtitle",
"--ref_text",
type=str,
help="Subtitle for the reference audio."
)
parser.add_argument(
"-t",
"--text",
"--gen_text",
type=str,
help="Text to generate.",
)
@@ -70,11 +67,11 @@ parser.add_argument(
)
args = parser.parse_args()
config = tomllib.load(open(args.config, "rb"))
config = tomli.load(open(args.config, "rb"))
ref_audio = args.reference if args.reference else config["reference"]
ref_text = args.subtitle if args.subtitle else config["subtitle"]
gen_text = args.text if args.text else config["text"]
ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
ref_text = args.ref_text if args.ref_text else config["ref_text"]
gen_text = args.gen_text if args.gen_text else config["gen_text"]
output_dir = args.output_dir if args.output_dir else config["output_dir"]
exp_name = args.model if args.model else config["model"]
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
@@ -100,13 +97,6 @@ device = (
print(f"Using {device} device")
pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-large-v3-turbo",
torch_dtype=torch.float16,
device=device,
)
# --------------------- Settings -------------------- #
target_sample_rate = 24000
@@ -151,13 +141,6 @@ F5TTS_model_cfg = dict(
)
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
F5TTS_ema_model = load_model(
"F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
)
E2TTS_ema_model = load_model(
"E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
)
def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
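# text within max_chars (UTF-8 bytes) is kept as a single batch; longer text is split into chunks, with split_words as preferred break points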
if len(text.encode('utf-8')) <= max_chars:
return [text]
@@ -256,9 +239,9 @@ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence):
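# load only the requested checkpoint here, on demand, so a single TTS model is kept in memory at a time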
if exp_name == "F5-TTS":
ema_model = F5TTS_ema_model
ema_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
elif exp_name == "E2-TTS":
ema_model = E2TTS_ema_model
ema_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
audio, sr = torchaudio.load(ref_audio)
if audio.shape[0] > 1:
@@ -363,6 +346,12 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_s
if not ref_text.strip():
print("No reference text provided, transcribing reference audio...")
pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-large-v3-turbo",
torch_dtype=torch.float16,
device=device,
)
ref_text = pipe(
ref_audio,
chunk_length_s=30,

View File: inference-cli.toml

@@ -1,8 +1,8 @@
# F5-TTS | E2-TTS
model = "F5-TTS"
reference = "tests/ref_audio/test_en_1_ref_short.wav"
ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
# If set to an empty string "", the reference audio will be transcribed automatically.
subtitle = "Some call me nature, others call me mother nature."
text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
ref_text = "Some call me nature, others call me mother nature."
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
remove_silence = true
output_dir = "tests"

View File: model/dataset.py

@@ -188,7 +188,7 @@ def load_dataset(
dataset_type: str = "CustomDataset",
audio_type: str = "raw",
mel_spec_kwargs: dict = dict()
) -> CustomDataset | HFDataset:
) -> CustomDataset:
print("Loading dataset ...")

View File: requirements.txt

@@ -1,16 +1,21 @@
accelerate>=0.33.0
cached_path
click
datasets
einops>=0.8.0
einx>=0.3.0
ema_pytorch>=0.5.2
faster_whisper
funasr
gradio
jieba
jiwer
librosa
matplotlib
pydub
pypinyin
safetensors
soundfile
# torch>=2.0
# torchaudio>=2.3.0
torchdiffeq
@@ -20,6 +25,4 @@ vocos
wandb
x_transformers>=1.31.14
zhconv
zhon
pydub
cached_path
zhon

View File: requirements_gradio.txt (deleted)

@@ -1,5 +0,0 @@
cached_path
click
gradio
pydub
soundfile

View File: scripts/eval_infer_batch.py

@@ -1,4 +1,6 @@
import os
import sys, os
sys.path.append(os.getcwd())
import time
import random
from tqdm import tqdm

View File: scripts/eval_infer_batch.sh (new)

@@ -0,0 +1,13 @@
#!/bin/bash
# e.g. F5-TTS, 16 NFE
accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
# e.g. Vanilla E2 TTS, 32 NFE
accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
# etc.

View File: test_infer_batch.sh (deleted)

@@ -1,13 +0,0 @@
#!/bin/bash
# e.g. F5-TTS, 16 NFE
accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
# e.g. Vanilla E2 TTS, 32 NFE
accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
# etc.

View File: test_infer_single.py (deleted)

@@ -1,161 +0,0 @@
import os
import re
import torch
import torchaudio
from einops import rearrange
from vocos import Vocos
from model import CFM, UNetT, DiT, MMDiT
from model.utils import (
load_checkpoint,
get_tokenizer,
convert_char_to_pinyin,
save_spectrogram,
)
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# --------------------- Dataset Settings -------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
tokenizer = "pinyin"
dataset_name = "Emilia_ZH_EN"
# ---------------------- infer setting ---------------------- #
seed = None # int | None
exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
ckpt_step = 1200000
nfe_step = 32 # 16, 32
cfg_strength = 2.
ode_method = 'euler' # euler | midpoint
sway_sampling_coef = -1.
speed = 1.
fix_duration = 27 # None (will linear estimate. if code-switched, consider fix) | float (total in seconds, include ref audio)
if exp_name == "F5TTS_Base":
model_cls = DiT
model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
elif exp_name == "E2TTS_Base":
model_cls = UNetT
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
output_dir = "tests"
ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
ref_text = "Some call me nature, others call me mother nature."
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
# ref_audio = "tests/ref_audio/test_zh_1_ref_short.wav"
# ref_text = "对,这就是我,万人敬仰的太乙真人。"
# gen_text = "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:\"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?\""
# -------------------------------------------------#
use_ema = True
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Vocoder model
local = False
if local:
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
vocos.load_state_dict(state_dict)
vocos.eval()
else:
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
# Tokenizer
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
# Model
model = CFM(
transformer = model_cls(
**model_cfg,
text_num_embeds = vocab_size,
mel_dim = n_mel_channels
),
mel_spec_kwargs = dict(
target_sample_rate = target_sample_rate,
n_mel_channels = n_mel_channels,
hop_length = hop_length,
),
odeint_kwargs = dict(
method = ode_method,
),
vocab_char_map = vocab_char_map,
).to(device)
model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
# Audio
audio, sr = torchaudio.load(ref_audio)
if audio.shape[0] > 1:
audio = torch.mean(audio, dim=0, keepdim=True)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < target_rms:
audio = audio * target_rms / rms
if sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
audio = resampler(audio)
audio = audio.to(device)
# Text
if len(ref_text[-1].encode('utf-8')) == 1:
ref_text = ref_text + " "
text_list = [ref_text + gen_text]
if tokenizer == "pinyin":
final_text_list = convert_char_to_pinyin(text_list)
else:
final_text_list = [text_list]
print(f"text : {text_list}")
print(f"pinyin: {final_text_list}")
# Duration
ref_audio_len = audio.shape[-1] // hop_length
if fix_duration is not None:
duration = int(fix_duration * target_sample_rate / hop_length)
else: # simple linear scale calcul
zh_pause_punc = r"。,、;:?!"
ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
# Inference
with torch.inference_mode():
generated, trajectory = model.sample(
cond = audio,
text = final_text_list,
duration = duration,
steps = nfe_step,
cfg_strength = cfg_strength,
sway_sampling_coef = sway_sampling_coef,
seed = seed,
)
print(f"Generated mel: {generated.shape}")
# Final result
generated = generated[:, ref_audio_len:, :]
generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
generated_wave = vocos.decode(generated_mel_spec.cpu())
if rms < target_rms:
generated_wave = generated_wave * rms / target_rms
save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single.png")
torchaudio.save(f"{output_dir}/test_single.wav", generated_wave, target_sample_rate)
print(f"Generated wav: {generated_wave.shape}")