convert to pkg, reorganize repo (#228)

* group files in f5_tts directory

* add setup.py

* use global imports

* simplify demo

* add install directions for library mode

* fix old huggingface_hub version constraint

* move finetune to package

* change imports to f5_tts.model

* bump version

* fix bad merge

* Update inference-cli.py

* fix HF space

* reformat

* fix utils.py vocab.txt import

* fix format

* adapt README for f5_tts package structure

* simplify app.py

* add gradio.Dockerfile and workflow

* refactored for pyproject.toml

* refactored for pyproject.toml

* added in reference to packaged files

* use fork for testing docker image

* added in reference to packaged files

* minor tweaks

* fixed inference-cli.toml path

* fixed inference-cli.toml path

* fixed inference-cli.toml path

* fixed inference-cli.toml path

* refactor eval_infer_batch.py

* fix typo

* added eval_infer_batch to scripts

---------

Co-authored-by: Roberts Slisans <rsxdalv@gmail.com>
Co-authored-by: Adam Kessel <adam@rosi-kessel.org>
Co-authored-by: Roberts Slisans <roberts.slisans@gmail.com>
This commit is contained in:
Yushen CHEN
2024-10-23 21:07:59 +08:00
committed by GitHub
parent 32c3ee7701
commit c4eee0f96b
38 changed files with 451 additions and 259 deletions

132
src/f5_tts/api.py Normal file
View File

@@ -0,0 +1,132 @@
import soundfile as sf
import torch
import tqdm
from cached_path import cached_path
from f5_tts.model import DiT, UNetT
from f5_tts.model.utils import save_spectrogram
from f5_tts.model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
from f5_tts.model.utils import seed_everything
import random
import sys
class F5TTS:
def __init__(
self,
model_type="F5-TTS",
ckpt_file="",
vocab_file="",
ode_method="euler",
use_ema=True,
local_path=None,
device=None,
):
# Initialize parameters
self.final_wave = None
self.target_sample_rate = 24000
self.n_mel_channels = 100
self.hop_length = 256
self.target_rms = 0.1
self.seed = -1
# Set device
self.device = device or (
"cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
)
# Load models
self.load_vocoder_model(local_path)
self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
def load_vocoder_model(self, local_path):
self.vocos = load_vocoder(local_path is not None, local_path, self.device)
def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
if model_type == "F5-TTS":
if not ckpt_file:
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_cls = DiT
elif model_type == "E2-TTS":
if not ckpt_file:
ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
model_cls = UNetT
else:
raise ValueError(f"Unknown model type: {model_type}")
self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)
def export_wav(self, wav, file_wave, remove_silence=False):
sf.write(file_wave, wav, self.target_sample_rate)
if remove_silence:
remove_silence_for_generated_wav(file_wave)
def export_spectrogram(self, spect, file_spect):
save_spectrogram(spect, file_spect)
def infer(
self,
ref_file,
ref_text,
gen_text,
show_info=print,
progress=tqdm,
target_rms=0.1,
cross_fade_duration=0.15,
sway_sampling_coef=-1,
cfg_strength=2,
nfe_step=32,
speed=1.0,
fix_duration=None,
remove_silence=False,
file_wave=None,
file_spect=None,
seed=-1,
):
if seed == -1:
seed = random.randint(0, sys.maxsize)
seed_everything(seed)
self.seed = seed
wav, sr, spect = infer_process(
ref_file,
ref_text,
gen_text,
self.ema_model,
show_info=show_info,
progress=progress,
target_rms=target_rms,
cross_fade_duration=cross_fade_duration,
nfe_step=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
device=self.device,
)
if file_wave is not None:
self.export_wav(wav, file_wave, remove_silence)
if file_spect is not None:
self.export_spectrogram(spect, file_spect)
return wav, sr, spect
if __name__ == "__main__":
f5tts = F5TTS()
wav, sr, spect = f5tts.infer(
ref_file="tests/ref_audio/test_en_1_ref_short.wav",
ref_text="some call me nature, others call me mother nature.",
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
file_wave="tests/out.wav",
file_spect="tests/out.png",
seed=-1, # random seed = -1
)
print("seed :", f5tts.seed)