reorganize infer_cli, eval scripts, and related docs

Author: SWivid
Date: 2024-12-15 22:49:31 +08:00
parent 3c60f99714
commit b7bc6419e7
11 changed files with 272 additions and 167 deletions


@@ -147,11 +147,11 @@ Note: Some model components have linting exceptions for E722 to accommodate tens
## Acknowledgements
- [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
- [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763) valuable datasets
- [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763), [LibriTTS](https://arxiv.org/abs/1904.02882), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) valuable datasets
- [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
- [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
- [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
- [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
- [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech), [SpeechMOS](https://github.com/tarepan/SpeechMOS) for evaluation tools
- [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
- [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
- [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)


@@ -39,11 +39,14 @@ Then update in the following scripts with the paths you put evaluation model ckp
### Objective Evaluation
Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
Update the path with your batch-inferenced results, and carry out WER / SIM / UTMOS evaluations:
```bash
# Evaluation for Seed-TTS test set
python src/f5_tts/eval/eval_seedtts_testset.py --gen_wav_dir <GEN_WAVE_DIR>
# Evaluation [WER] for Seed-TTS test [ZH] set
python src/f5_tts/eval/eval_seedtts_testset.py --eval_task wer --lang zh --gen_wav_dir <GEN_WAV_DIR> --gpu_nums 8
# Evaluation for LibriSpeech-PC test-clean (cross-sentence)
python src/f5_tts/eval/eval_librispeech_test_clean.py --gen_wav_dir <GEN_WAVE_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
```
# Evaluation [SIM] for LibriSpeech-PC test-clean (cross-sentence)
python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_dir <GEN_WAV_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
# Evaluation [UTMOS]; --ext specifies the audio file extension
python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
```


@@ -1,8 +1,9 @@
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
import sys
import os
import argparse
import json
import os
import sys
sys.path.append(os.getcwd())
@@ -10,7 +11,6 @@ import multiprocessing as mp
from importlib.resources import files
import numpy as np
import json
from f5_tts.eval.utils_eval import (
get_librispeech_test,
run_asr_wer,
@@ -54,36 +54,41 @@ def main():
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
# --------------------------- WER ---------------------------
if eval_task == "wer":
wers = []
wer_results = []
wers = []
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_asr_wer, args)
for wers_ in results:
wers.extend(wers_)
for r in results:
wer_results.extend(r)
with open(f"{gen_wav_dir}/{lang}_wer_results.jsonl", "w") as f:
for line in wers:
wer_results.append(line["wer"])
wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
with open(wer_result_path, "w") as f:
for line in wer_results:
wers.append(line["wer"])
json_line = json.dumps(line, ensure_ascii=False)
f.write(json_line + "\n")
wer = round(np.mean(wers) * 100, 3)
print(f"\nTotal {len(wers)} samples")
print(f"WER : {wer}%")
print(f"Results have been saved to {wer_result_path}")
# --------------------------- SIM ---------------------------
if eval_task == "sim":
sim_list = []
sims = []
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_sim, args)
for sim_ in results:
sim_list.extend(sim_)
for r in results:
sims.extend(r)
sim = round(sum(sim_list) / len(sim_list), 3)
print(f"\nTotal {len(sim_list)} samples")
sim = round(sum(sims) / len(sims), 3)
print(f"\nTotal {len(sims)} samples")
print(f"SIM : {sim}")


@@ -1,8 +1,9 @@
# Evaluate with Seed-TTS testset
import sys
import os
import argparse
import json
import os
import sys
sys.path.append(os.getcwd())
@@ -10,7 +11,6 @@ import multiprocessing as mp
from importlib.resources import files
import numpy as np
import json
from f5_tts.eval.utils_eval import (
get_seed_tts_test,
run_asr_wer,
@@ -55,35 +55,39 @@ def main():
# --------------------------- WER ---------------------------
if eval_task == "wer":
wers = []
wer_results = []
wers = []
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_asr_wer, args)
for wers_ in results:
wers.extend(wers_)
for r in results:
wer_results.extend(r)
with open(f"{gen_wav_dir}/{lang}_wer_results.jsonl", "w") as f:
for line in wers:
wer_results.append(line["wer"])
wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
with open(wer_result_path, "w") as f:
for line in wer_results:
wers.append(line["wer"])
json_line = json.dumps(line, ensure_ascii=False)
f.write(json_line + "\n")
wer = round(np.mean(wers) * 100, 3)
print(f"\nTotal {len(wers)} samples")
print(f"WER : {wer}%")
print(f"Results have been saved to {wer_result_path}")
# --------------------------- SIM ---------------------------
if eval_task == "sim":
sim_list = []
sims = []
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_sim, args)
for sim_ in results:
sim_list.extend(sim_)
for r in results:
sims.extend(r)
sim = round(sum(sim_list) / len(sim_list), 3)
print(f"\nTotal {len(sim_list)} samples")
sim = round(sum(sims) / len(sims), 3)
print(f"\nTotal {len(sims)} samples")
print(f"SIM : {sim}")
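
Both evaluation scripts above now save per-sample results to `<GEN_WAV_DIR>/<lang>_wer_results.jsonl`, one JSON object per line with the `wav`, `truth`, `hypo`, and `wer` fields written by `run_asr_wer` in `utils_eval.py` (further below). A hedged sketch of re-aggregating the corpus WER from that file, assuming `jq` is installed and using a placeholder path:

```bash
# Placeholder path; each line looks like {"wav": ..., "truth": ..., "hypo": ..., "wer": ...}
# Slurp all lines, average the per-sample WER values, and report as a percentage.
jq -s 'map(.wer) | add / length * 100' <GEN_WAV_DIR>/zh_wer_results.jsonl
```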


@@ -1,46 +1,43 @@
import torch
import librosa
from pathlib import Path
import json
from tqdm import tqdm
import argparse
import json
from pathlib import Path
import librosa
import torch
from tqdm import tqdm
def main():
parser = argparse.ArgumentParser(description="Evaluate UTMOS scores for audio files.")
parser.add_argument(
"--audio_dir", type=str, required=True, help="Path to the directory containing WAV audio files."
)
parser.add_argument("--ext", type=str, default="wav", help="audio extension.")
parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (e.g. 'cuda' or 'cpu').")
parser = argparse.ArgumentParser(description="UTMOS Evaluation")
parser.add_argument("--audio_dir", type=str, required=True, help="Path to the directory of audio files.")
parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
args = parser.parse_args()
device = "cuda" if args.device and torch.cuda.is_available() else "cpu"
device = "cuda" if torch.cuda.is_available() else "cpu"
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
predictor = predictor.to(device)
lines = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
results = {}
utmos_result = 0
audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
utmos_results = {}
utmos_score = 0
for line in tqdm(lines, desc="Processing"):
wave_name = line.stem
wave, sr = librosa.load(line, sr=None, mono=True)
wave_tensor = torch.from_numpy(wave).to(device).unsqueeze(0)
score = predictor(wave_tensor, sr)
results[str(wave_name)] = score.item()
utmos_result += score.item()
for audio_path in tqdm(audio_paths, desc="Processing"):
wav_name = audio_path.stem
wav, sr = librosa.load(audio_path, sr=None, mono=True)
wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
score = predictor(wav_tensor, sr)
utmos_results[str(wav_name)] = score.item()
utmos_score += score.item()
avg_score = utmos_result / len(lines) if len(lines) > 0 else 0
avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
print(f"UTMOS: {avg_score}")
output_path = Path(args.audio_dir) / "utmos_results.json"
with open(output_path, "w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=4)
utmos_result_path = Path(args.audio_dir) / "utmos_results.json"
with open(utmos_result_path, "w", encoding="utf-8") as f:
json.dump(utmos_results, f, ensure_ascii=False, indent=4)
print(f"Results have been saved to {output_path}")
print(f"Results have been saved to {utmos_result_path}")
if __name__ == "__main__":


@@ -2,12 +2,13 @@ import math
import os
import random
import string
from pathlib import Path
import torch
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm
from pathlib import Path
from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import convert_char_to_pinyin
@@ -320,7 +321,7 @@ def run_asr_wer(args):
from zhon.hanzi import punctuation
punctuation_all = punctuation + string.punctuation
wers = []
wer_results = []
from jiwer import compute_measures
@@ -335,8 +336,8 @@ def run_asr_wer(args):
for segment in segments:
hypo = hypo + " " + segment.text
# raw_truth = truth
# raw_hypo = hypo
raw_truth = truth
raw_hypo = hypo
for x in punctuation_all:
truth = truth.replace(x, "")
@@ -360,16 +361,16 @@ def run_asr_wer(args):
# dele = measures["deletions"] / len(ref_list)
# inse = measures["insertions"] / len(ref_list)
wers.append(
wer_results.append(
{
"wav": Path(gen_wav).stem, # wav name
"truth": truth, # raw_truth
"hypo": hypo, # raw_hypo
"wer": wer, # wer score
"wav": Path(gen_wav).stem,
"truth": raw_truth,
"hypo": raw_hypo,
"wer": wer,
}
)
return wers
return wer_results
# SIM Evaluation
@@ -388,7 +389,7 @@ def run_sim(args):
model = model.cuda(device)
model.eval()
sim_list = []
sims = []
for wav1, wav2, truth in tqdm(test_set):
wav1, sr1 = torchaudio.load(wav1)
wav2, sr2 = torchaudio.load(wav2)
@@ -407,6 +408,6 @@ def run_sim(args):
sim = F.cosine_similarity(emb1, emb2)[0].item()
# print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
sim_list.append(sim)
sims.append(sim)
return sim_list
return sims


@@ -64,6 +64,9 @@ f5-tts_infer-cli \
# Choose Vocoder
f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
# More instructions
f5-tts_infer-cli --help
```
And a `.toml` config file enables more flexible usage, for example:
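
A minimal sketch of such a config (the file name is hypothetical; the keys mirror `infer/examples/basic/basic.toml` and the fallbacks read by `infer_cli.py`):

```bash
# Write a hypothetical custom config; any key omitted here falls back to infer_cli's defaults
cat > my_config.toml <<'EOF'
model = "F5-TTS"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
ref_text = "Some call me nature, others call me mother nature."
gen_text = "Here we generate something just for test."
gen_file = ""
remove_silence = false
output_dir = "tests"
output_file = "infer_cli_basic.wav"
EOF

# Flags given on the command line override the corresponding config entries
f5-tts_infer-cli -c my_config.toml
```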


@@ -22,12 +22,12 @@
- [Finnish Common\_Voice Vox\_Populi @ finetune @ fi](#finnish-common_voice-vox_populi--finetune--fi)
- [French](#french)
- [French LibriVox @ finetune @ fr](#french-librivox--finetune--fr)
- [Hindi](#hindi)
- [F5-TTS Small @ pretrain @ hi](#f5-tts-small--pretrain--hi)
- [Italian](#italian)
- [F5-TTS Italian @ finetune @ it](#f5-tts-italian--finetune--it)
- [Japanese](#japanese)
- [F5-TTS Japanese @ pretrain/finetune @ ja](#f5-tts-japanese--pretrainfinetune--ja)
- [Hindi](#hindi)
- [F5-TTS Small @ pretrain @ hi](#f5-tts-small--pretrain--hi)
- [Mandarin](#mandarin)
- [Spanish](#spanish)
- [F5-TTS Spanish @ pretrain/finetune @ es](#f5-tts-spanish--pretrainfinetune--es)
@@ -81,6 +81,23 @@ VOCAB_FILE: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
- [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
## Hindi
#### F5-TTS Small @ pretrain @ hi
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0|
```bash
MODEL_CKPT: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
VOCAB_FILE: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
```
Authors: SPRING Lab, Indian Institute of Technology, Madras
<br>
Website: https://asr.iitm.ac.in/
## Italian
#### F5-TTS Italian @ finetune @ it
@@ -110,21 +127,6 @@ MODEL_CKPT: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
```
## Hindi
#### F5-TTS Small @ pretrain @ hi
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0|
```bash
MODEL_CKPT: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
VOCAB_FILE: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
```
Authors: SPRING Lab, Indian Institute of Technology, Madras
<br>
Website: https://asr.iitm.ac.in/
## Mandarin


@@ -8,4 +8,4 @@ gen_text = "I don't really care what you call me. I've been a silent spectator,
gen_file = ""
remove_silence = false
output_dir = "tests"
output_file = "infer_cli_out.wav"
output_file = "infer_cli_basic.wav"


@@ -8,6 +8,7 @@ gen_text = ""
gen_file = "infer/examples/multi/story.txt"
remove_silence = true
output_dir = "tests"
output_file = "infer_cli_story.wav"
[voices.town]
ref_audio = "infer/examples/multi/town.flac"
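
For context on the multi-voice example edited here: each `[voices.<name>]` section defines an extra reference voice, and the text in `gen_file` switches voices with `[name]` tags (unknown tags fall back to `main`), as handled by `infer_cli.py` further below. A hedged sketch with a hypothetical story file; paths assume a repository checkout:

```bash
# Hypothetical text file; [town] must match the [voices.town] section of the config
cat > my_story.txt <<'EOF'
[main] I walked into the town square and looked around.
[town] Welcome, traveler! What brings you here today?
[main] Just passing through, I said, and kept walking.
EOF

# --gen_file overrides the gen_file entry of the story config
f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml --gen_file my_story.txt
```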


@@ -2,6 +2,7 @@ import argparse
import codecs
import os
import re
from datetime import datetime
from importlib.resources import files
from pathlib import Path
@@ -11,6 +12,14 @@ import tomli
from cached_path import cached_path
from f5_tts.infer.utils_infer import (
mel_spec_type,
target_rms,
cross_fade_duration,
nfe_step,
cfg_strength,
sway_sampling_coef,
speed,
fix_duration,
infer_process,
load_model,
load_vocoder,
@@ -19,6 +28,7 @@ from f5_tts.infer.utils_infer import (
)
from f5_tts.model import DiT, UNetT
parser = argparse.ArgumentParser(
prog="python3 infer-cli.py",
description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
@@ -27,86 +37,161 @@ parser = argparse.ArgumentParser(
parser.add_argument(
"-c",
"--config",
help="Configuration file. Default=infer/examples/basic/basic.toml",
type=str,
default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
help="The configuration file, default see infer/examples/basic/basic.toml",
)
# Note. Not to provide default value here in order to read default from config file
parser.add_argument(
"-m",
"--model",
help="F5-TTS | E2-TTS",
type=str,
help="The model name: F5-TTS | E2-TTS",
)
parser.add_argument(
"-p",
"--ckpt_file",
help="The Checkpoint .pt",
type=str,
help="The path to model checkpoint .pt, leave blank to use default",
)
parser.add_argument(
"-v",
"--vocab_file",
help="The vocab .txt",
type=str,
help="The path to vocab file .txt, leave blank to use default",
)
parser.add_argument(
"-r",
"--ref_audio",
type=str,
help="The reference audio file.",
)
parser.add_argument(
"-s",
"--ref_text",
type=str,
help="The transcript/subtitle for the reference audio",
)
parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
parser.add_argument(
"-t",
"--gen_text",
type=str,
help="Text to generate.",
help="The text for the model to synthesize",
)
parser.add_argument(
"-f",
"--gen_file",
type=str,
help="File with text to generate. Ignores --gen_text",
help="The file with text to generate; ignores --gen_text",
)
parser.add_argument(
"-o",
"--output_dir",
type=str,
help="Path to output folder..",
help="The path to output folder",
)
parser.add_argument(
"-w",
"--output_file",
type=str,
help="Filename of output file..",
help="The name of output file",
)
parser.add_argument(
"--save_chunk",
action="store_true",
help="Save chunk audio if your text is too long.",
help="Save each audio chunk during inference",
)
parser.add_argument(
"--remove_silence",
help="Remove silence.",
action="store_true",
help="To remove long silence found in output",
)
parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
parser.add_argument(
"--load_vocoder_from_local",
action="store_true",
help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
help="To load vocoder from a local dir, defaults to ../checkpoints/charactr/vocos-mel-24khz",
)
parser.add_argument(
"--speed",
"--vocoder_name",
type=str,
choices=["vocos", "bigvgan"],
help=f"Used vocoder name: vocos | bigvgan, default {mel_spec_type}",
)
parser.add_argument(
"--target_rms",
type=float,
default=1.0,
help="Adjust the speed of the audio generation (default: 1.0)",
help=f"Target output speech loudness normalization value, default {target_rms}",
)
parser.add_argument(
"--cross_fade_duration",
type=float,
help=f"Duration of cross-fade between audio segments in seconds, default {cross_fade_duration}",
)
parser.add_argument(
"--nfe_step",
type=int,
default=32,
help="Set the number of denoising steps (default: 32)",
help=f"The number of function evaluations (denoising steps), default {nfe_step}",
)
parser.add_argument(
"--cfg_strength",
type=float,
help=f"Classifier-free guidance strength, default {cfg_strength}",
)
parser.add_argument(
"--sway_sampling_coef",
type=float,
help=f"Sway Sampling coefficient, default {sway_sampling_coef}",
)
parser.add_argument(
"--speed",
type=float,
help=f"The speed of the generated audio, default {speed}",
)
parser.add_argument(
"--fix_duration",
type=float,
help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
)
args = parser.parse_args()
# config file
config = tomli.load(open(args.config, "rb"))
ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
gen_text = args.gen_text if args.gen_text else config["gen_text"]
gen_file = args.gen_file if args.gen_file else config["gen_file"]
save_chunk = args.save_chunk if args.save_chunk else False
# command-line interface parameters
model = args.model or config.get("model", "F5-TTS")
ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
vocab_file = args.vocab_file or config.get("vocab_file", "")
ref_audio = args.ref_audio or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav")
ref_text = args.ref_text or config.get("ref_text", "Some call me nature, others call me mother nature.")
gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.")
gen_file = args.gen_file or config.get("gen_file", "")
output_dir = args.output_dir or config.get("output_dir", "tests")
output_file = args.output_file or config.get(
"output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav"
)
save_chunk = args.save_chunk
remove_silence = args.remove_silence
load_vocoder_from_local = args.load_vocoder_from_local
vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)
target_rms = args.target_rms or config.get("target_rms", target_rms)
cross_fade_duration = args.cross_fade_duration or config.get("cross_fade_duration", cross_fade_duration)
nfe_step = args.nfe_step or config.get("nfe_step", nfe_step)
cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
speed = args.speed or config.get("speed", speed)
fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
# patches for pip pkg user
if "infer/examples/" in ref_audio:
@@ -119,35 +204,39 @@ if "voices" in config:
if "infer/examples/" in voice_ref_audio:
config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))
# ignore gen_text if gen_file provided
if gen_file:
gen_text = codecs.open(gen_file, "r", "utf-8").read()
output_dir = args.output_dir if args.output_dir else config["output_dir"]
output_file = args.output_file if args.output_file else config["output_file"]
model = args.model if args.model else config["model"]
ckpt_file = args.ckpt_file if args.ckpt_file else ""
vocab_file = args.vocab_file if args.vocab_file else ""
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
speed = args.speed
nfe_step = args.nfe_step
# output path
wave_path = Path(output_dir) / output_file
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
if save_chunk:
output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
if not os.path.exists(output_chunk_dir):
os.makedirs(output_chunk_dir)
# load vocoder
vocoder_name = args.vocoder_name
mel_spec_type = args.vocoder_name
if vocoder_name == "vocos":
vocoder_local_path = "../checkpoints/vocos-mel-24khz"
elif vocoder_name == "bigvgan":
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
# load models
# load TTS model
if model == "F5-TTS":
model_cls = DiT
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
if ckpt_file == "":
if not ckpt_file: # path not specified, download from repo
if vocoder_name == "vocos":
repo_name = "F5-TTS"
exp_name = "F5TTS_Base"
@@ -164,19 +253,21 @@ elif model == "E2-TTS":
assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
if ckpt_file == "":
if not ckpt_file: # path not specified, download from repo
repo_name = "E2-TTS"
exp_name = "E2TTS_Base"
ckpt_step = 1200000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
print(f"Using {model}...")
ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file)
ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
# inference process
def main():
main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
if "voices" not in config:
voices = {"main": main_voice}
@@ -184,16 +275,16 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
voices = config["voices"]
voices["main"] = main_voice
for voice in voices:
print("Voice:", voice)
print("ref_audio ", voices[voice]["ref_audio"])
voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
voices[voice]["ref_audio"], voices[voice]["ref_text"]
)
print("Voice:", voice)
print("Ref_audio:", voices[voice]["ref_audio"])
print("Ref_text:", voices[voice]["ref_text"])
print("ref_audio_", voices[voice]["ref_audio"], "\n\n")
generated_audio_segments = []
reg1 = r"(?=\[\w+\])"
chunks = re.split(reg1, text_gen)
chunks = re.split(reg1, gen_text)
reg2 = r"\[(\w+)\]"
for text in chunks:
if not text.strip():
@@ -208,21 +299,35 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
print(f"Voice {voice} not found, using main.")
voice = "main"
text = re.sub(reg2, "", text)
gen_text = text.strip()
ref_audio = voices[voice]["ref_audio"]
ref_text = voices[voice]["ref_text"]
ref_audio_ = voices[voice]["ref_audio"]
ref_text_ = voices[voice]["ref_text"]
gen_text_ = text.strip()
print(f"Voice: {voice}")
audio, final_sample_rate, spectragram = infer_process(
ref_audio,
ref_text,
gen_text,
model_obj,
audio_segment, final_sample_rate, spectragram = infer_process(
ref_audio_,
ref_text_,
gen_text_,
ema_model,
vocoder,
mel_spec_type=mel_spec_type,
speed=speed,
mel_spec_type=vocoder_name,
target_rms=target_rms,
cross_fade_duration=cross_fade_duration,
nfe_step=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
)
generated_audio_segments.append(audio)
generated_audio_segments.append(audio_segment)
if save_chunk:
if len(gen_text_) > 200:
gen_text_ = gen_text_[:200] + " ... "
sf.write(
os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"),
audio_segment,
final_sample_rate,
)
if generated_audio_segments:
final_wave = np.concatenate(generated_audio_segments)
@@ -236,22 +341,6 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
if remove_silence:
remove_silence_for_generated_wav(f.name)
print(f.name)
# Ensure the gen_text chunk directory exists
if save_chunk:
gen_text_chunk_dir = os.path.join(output_dir, "chunks")
if not os.path.exists(gen_text_chunk_dir): # if Not create directory
os.makedirs(gen_text_chunk_dir)
# Save individual chunks as separate files
for idx, segment in enumerate(generated_audio_segments):
gen_text_chunk_path = os.path.join(output_dir, gen_text_chunk_dir, f"chunk_{idx}.wav")
sf.write(gen_text_chunk_path, segment, final_sample_rate)
print(f"Saved gen_text chunk {idx} at {gen_text_chunk_path}")
def main():
main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed)
if __name__ == "__main__":
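
Taken together, the reorganized CLI resolves most settings as: command-line flag first, then the `.toml` config entry, then the library default imported from `utils_infer` (the defaults are echoed in the `--help` text); the boolean switches (`--remove_silence`, `--save_chunk`, `--load_vocoder_from_local`) come from the command line only. A hedged usage sketch exercising the newly exposed flags with illustrative values (the default config, `infer/examples/basic/basic.toml`, is used since no `-c` is given):

```bash
f5-tts_infer-cli \
    --gen_text "Trying out the reorganized inference CLI." \
    --vocoder_name vocos \
    --nfe_step 32 \
    --cfg_strength 2.0 \
    --sway_sampling_coef -1.0 \
    --speed 1.0 \
    --target_rms 0.1 \
    --cross_fade_duration 0.15 \
    --remove_silence \
    --save_chunk \
    --output_dir tests \
    --output_file infer_cli_tuned.wav
```

With `--save_chunk`, each generated chunk is also written to `tests/infer_cli_tuned_chunks/` alongside the final waveform.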