mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 07:40:43 -08:00
convert to pkg, reorganize repo (#228)
* group files in f5_tts directory * add setup.py * use global imports * simplify demo * add install directions for library mode * fix old huggingface_hub version constraint * move finetune to package * change imports to f5_tts.model * bump version * fix bad merge * Update inference-cli.py * fix HF space * reformat * fix utils.py vocab.txt import * fix format * adapt README for f5_tts package structure * simplify app.py * add gradio.Dockerfile and workflow * refactored for pyproject.toml * refactored for pyproject.toml * added in reference to packaged files * use fork for testing docker image * added in reference to packaged files * minor tweaks * fixed inference-cli.toml path * fixed inference-cli.toml path * fixed inference-cli.toml path * fixed inference-cli.toml path * refactor eval_infer_batch.py * fix typo * added eval_infer_batch to scripts --------- Co-authored-by: Roberts Slisans <rsxdalv@gmail.com> Co-authored-by: Adam Kessel <adam@rosi-kessel.org> Co-authored-by: Roberts Slisans <roberts.slisans@gmail.com>
This commit is contained in:
132
src/f5_tts/api.py
Normal file
132
src/f5_tts/api.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import soundfile as sf
|
||||
import torch
|
||||
import tqdm
|
||||
from cached_path import cached_path
|
||||
|
||||
from f5_tts.model import DiT, UNetT
|
||||
from f5_tts.model.utils import save_spectrogram
|
||||
|
||||
from f5_tts.model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
|
||||
from f5_tts.model.utils import seed_everything
|
||||
import random
|
||||
import sys
|
||||
|
||||
|
||||
class F5TTS:
    """High-level wrapper that loads a vocoder and a TTS model once and runs inference.

    Construction picks a device, loads the vocoder through `load_vocoder` and
    the acoustic model through `load_model`. `infer` then forwards its
    arguments to `infer_process` and optionally writes the results to disk.
    """

    def __init__(
        self,
        model_type="F5-TTS",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        local_path=None,
        device=None,
    ):
        """Select a device and load the vocoder and the TTS model.

        Args:
            model_type: "F5-TTS" or "E2-TTS"; any other value raises ValueError
                (see `load_ema_model`).
            ckpt_file: Checkpoint path. When empty, a default checkpoint for
                the chosen model type is resolved with `cached_path`.
            vocab_file: Forwarded unchanged to `load_model`.
            ode_method: Forwarded unchanged to `load_model`.
            use_ema: Forwarded unchanged to `load_model`.
            local_path: Vocoder location forwarded to `load_vocoder`; None
                means no local path is supplied.
            device: Device to use. When falsy, "cuda" is preferred, then
                "mps", then "cpu", depending on availability.
        """
        # Initialize parameters
        self.final_wave = None
        self.target_sample_rate = 24000
        self.n_mel_channels = 100
        self.hop_length = 256
        self.target_rms = 0.1
        self.seed = -1  # overwritten by every call to `infer` with the seed actually used

        # Set device
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        )

        # Load models
        self.load_vocoder_model(local_path)
        self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)

    def load_vocoder_model(self, local_path):
        """Load the vocoder via `load_vocoder` and store it on `self.vocos`.

        The first argument passed to `load_vocoder` is `local_path is not None`
        (presumably an "is local" flag -- confirm against `load_vocoder`).
        """
        self.vocos = load_vocoder(local_path is not None, local_path, self.device)

    def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
        """Build the model selected by `model_type` and store it on `self.ema_model`.

        Args:
            model_type: "F5-TTS" selects `DiT`; "E2-TTS" selects `UNetT`.
            ckpt_file: Checkpoint path; when empty, the default checkpoint for
                the model type is resolved through `cached_path`.
            vocab_file: Forwarded to `load_model`.
            ode_method: Forwarded to `load_model`.
            use_ema: Forwarded to `load_model`.

        Raises:
            ValueError: If `model_type` is neither "F5-TTS" nor "E2-TTS".
        """
        if model_type == "F5-TTS":
            if not ckpt_file:
                ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
            model_cls = DiT
        elif model_type == "E2-TTS":
            if not ckpt_file:
                ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
            model_cls = UNetT
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)

    def export_wav(self, wav, file_wave, remove_silence=False):
        """Write `wav` to `file_wave` at `self.target_sample_rate`.

        When `remove_silence` is true, `remove_silence_for_generated_wav` is
        then called on the file that was just written.
        """
        sf.write(file_wave, wav, self.target_sample_rate)

        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spect, file_spect):
        """Save `spect` to `file_spect` using `save_spectrogram`."""
        save_spectrogram(spect, file_spect)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spect=None,
        seed=-1,
    ):
        """Run `infer_process` with the loaded model and optionally export the results.

        Args:
            ref_file, ref_text, gen_text: Forwarded positionally to
                `infer_process`.
            show_info, progress, target_rms, cross_fade_duration,
            sway_sampling_coef, cfg_strength, nfe_step, speed, fix_duration:
                Forwarded by keyword to `infer_process`.
            remove_silence: Passed to `export_wav`; only has an effect when
                `file_wave` is given.
            file_wave: If not None, the waveform is written to this path.
            file_spect: If not None, the spectrogram is saved to this path.
            seed: Seed passed to `seed_everything`. -1 (or None) means
                "pick a fresh random seed".

        Returns:
            The `(wav, sr, spect)` tuple returned by `infer_process`. The seed
            actually used is available afterwards as `self.seed`.
        """
        if seed is None or seed == -1:
            # Draw from the OS entropy source rather than the global `random`
            # generator: every call hands its seed to `seed_everything`, which
            # presumably reseeds that global generator, so drawing from it
            # would make "random" seeds predictable after a seeded call.
            seed = random.SystemRandom().randint(0, sys.maxsize)
        seed_everything(seed)
        self.seed = seed
        wav, sr, spect = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
        )

        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)

        if file_spect is not None:
            self.export_spectrogram(spect, file_spect)

        return wav, sr, spect
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Demo run: synthesize one sample with the default model and write the
    # waveform and spectrogram under tests/.
    f5tts = F5TTS()

    demo_kwargs = {
        "ref_file": "tests/ref_audio/test_en_1_ref_short.wav",
        "ref_text": "some call me nature, others call me mother nature.",
        "gen_text": """I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
        "file_wave": "tests/out.wav",
        "file_spect": "tests/out.png",
        "seed": -1,  # random seed = -1
    }
    wav, sr, spect = f5tts.infer(**demo_kwargs)

    # Report the seed that was actually used so the run can be reproduced.
    print("seed :", f5tts.seed)
|
||||
Reference in New Issue
Block a user