1.0.0 F5-TTS v1 base model with better training and inference performance

This commit is contained in:
SWivid
2025-03-12 17:23:10 +08:00
parent 09b478b7d7
commit ca6e49adaa
40 changed files with 1036 additions and 652 deletions

.github/workflows/publish-pypi.yaml
View File

@@ -0,0 +1,66 @@
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

# GitHub recommends pinning actions to a commit SHA.
# To get a newer version, you will need to update the SHA.
# You can also reference a tag or branch, but the action may change without warning.

name: Upload Python Package

on:
  release:
    types: [published]

permissions:
  contents: read

jobs:
  release-build:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: "3.x"

      - name: Build release distributions
        run: |
          # NOTE: put your own distribution build steps here.
          python -m pip install build
          python -m build

      - name: Upload distributions
        uses: actions/upload-artifact@v4
        with:
          name: release-dists
          path: dist/

  pypi-publish:
    runs-on: ubuntu-latest
    needs:
      - release-build
    permissions:
      # IMPORTANT: this permission is mandatory for trusted publishing
      id-token: write

    # Dedicated environments with protections for publishing are strongly recommended.
    environment:
      name: pypi
      # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status:
      # url: https://pypi.org/p/YOURPROJECT

    steps:
      - name: Retrieve release distributions
        uses: actions/download-artifact@v4
        with:
          name: release-dists
          path: dist/

      - name: Publish release distributions to PyPI
        uses: pypa/gh-action-pypi-publish@6f7e8d9c0b1a2c3d4e5f6a7b8c9d0e1f2a3b4c5d

View File

@@ -18,6 +18,7 @@
### Thanks to all the contributors !
## News
- **2025/03/12**: F5-TTS v1 base model with better training and inference performance.
- **2024/10/08**: F5-TTS & E2 TTS base models on [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), [🟣 Wisemodel](https://wisemodel.cn/models/SJTU_X-LANCE/F5-TTS_Emilia-ZH-EN).
## Installation
@@ -37,7 +38,7 @@ conda activate f5-tts
> ```bash
> # Install pytorch with your CUDA version, e.g.
> pip install torch==2.3.0+cu118 torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
> pip install torch==2.4.0+cu124 torchaudio==2.4.0+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
> ```
</details>
@@ -159,7 +160,7 @@ volumes:
# Run with flags
# Leaving --ref_text "" will have the ASR model transcribe (extra GPU memory usage)
f5-tts_infer-cli \
--model "F5-TTS" \
--model "F5-TTS_v1" \
--ref_audio "ref_audio.wav" \
--ref_text "The content, subtitle or transcription of reference audio." \
--gen_text "Some text you want TTS model generate for you."

View File

@@ -3,8 +3,10 @@ Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS
```
ckpts/
    E2TTS_Base/
        model_1200000.pt
    F5TTS_v1_Base/
        model_1250000.safetensors
    F5TTS_Base/
        model_1200000.pt
        model_1200000.safetensors
    E2TTS_Base/
        model_1200000.safetensors
```
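
If the files are not present locally, they are pulled from the Hugging Face hub at runtime; a minimal sketch of pre-fetching the v1 checkpoint into the cache, using the `cached_path` helper and the `hf://` URIs that appear in this commit:

```python
from cached_path import cached_path

# Resolve the released v1 checkpoint and vocab to local cache paths.
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
vocab_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt"))
print(ckpt_file, vocab_file)
```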

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "0.6.2"
version = "1.0.0"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
@@ -25,7 +25,6 @@ dependencies = [
"jieba",
"librosa",
"matplotlib",
"nltk",
"numpy<=1.26.4",
"pydub",
"pypinyin",

View File

@@ -5,43 +5,43 @@ from importlib.resources import files
import soundfile as sf
import tqdm
from cached_path import cached_path
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
hop_length,
infer_process,
load_model,
load_vocoder,
transcribe,
preprocess_ref_audio_text,
infer_process,
remove_silence_for_generated_wav,
save_spectrogram,
transcribe,
target_sample_rate,
)
from f5_tts.model import DiT, UNetT
from f5_tts.model import DiT, UNetT # noqa: F401. used for config
from f5_tts.model.utils import seed_everything
class F5TTS:
def __init__(
self,
model_type="F5-TTS",
model="F5TTS_v1_Base",
ckpt_file="",
vocab_file="",
ode_method="euler",
use_ema=True,
vocoder_name="vocos",
local_path=None,
vocoder_local_path=None,
device=None,
hf_cache_dir=None,
):
# Initialize parameters
self.final_wave = None
self.target_sample_rate = target_sample_rate
self.hop_length = hop_length
self.seed = -1
self.mel_spec_type = vocoder_name
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
model_cls = globals()[model_cfg.model.backbone]
model_arc = model_cfg.model.arch
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
self.ode_method = ode_method
self.use_ema = use_ema
# Set device
if device is not None:
self.device = device
else:
@@ -58,39 +58,31 @@ class F5TTS:
)
# Load models
self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
self.load_ema_model(
model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
self.vocoder = load_vocoder(
self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir
)
def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
if model_type == "F5-TTS":
if not ckpt_file:
if mel_spec_type == "vocos":
ckpt_file = str(
cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
)
elif mel_spec_type == "bigvgan":
ckpt_file = str(
cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
)
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_cls = DiT
elif model_type == "E2-TTS":
if not ckpt_file:
ckpt_file = str(
cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
)
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
model_cls = UNetT
# override for previous models
if model == "F5TTS_Base":
if self.mel_spec_type == "vocos":
ckpt_step = 1200000
elif self.mel_spec_type == "bigvgan":
model = "F5TTS_Base_bigvgan"
ckpt_type = "pt"
elif model == "E2TTS_Base":
repo_name = "E2-TTS"
ckpt_step = 1200000
else:
raise ValueError(f"Unknown model type: {model_type}")
raise ValueError(f"Unknown model type: {model}")
if not ckpt_file:
ckpt_file = str(
cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir)
)
self.ema_model = load_model(
model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device
)
def transcribe(self, ref_audio, language=None):
@@ -102,8 +94,8 @@ class F5TTS:
if remove_silence:
remove_silence_for_generated_wav(file_wave)
def export_spectrogram(self, spect, file_spect):
save_spectrogram(spect, file_spect)
def export_spectrogram(self, spec, file_spec):
save_spectrogram(spec, file_spec)
def infer(
self,
@@ -121,17 +113,16 @@ class F5TTS:
fix_duration=None,
remove_silence=False,
file_wave=None,
file_spect=None,
seed=-1,
file_spec=None,
seed=None,
):
if seed == -1:
seed = random.randint(0, sys.maxsize)
seed_everything(seed)
self.seed = seed
if seed is None:
self.seed = random.randint(0, sys.maxsize)
seed_everything(self.seed)
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
wav, sr, spect = infer_process(
wav, sr, spec = infer_process(
ref_file,
ref_text,
gen_text,
@@ -153,22 +144,22 @@ class F5TTS:
if file_wave is not None:
self.export_wav(wav, file_wave, remove_silence)
if file_spect is not None:
self.export_spectrogram(spect, file_spect)
if file_spec is not None:
self.export_spectrogram(spec, file_spec)
return wav, sr, spect
return wav, sr, spec
if __name__ == "__main__":
f5tts = F5TTS()
wav, sr, spect = f5tts.infer(
wav, sr, spec = f5tts.infer(
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
ref_text="some call me nature, others call me mother nature.",
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
seed=-1, # random seed = -1
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
seed=None,
)
print("seed :", f5tts.seed)

View File

@@ -1,16 +1,16 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN # dataset name
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # "frame" or "sample"
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if using frame-wise batch_size; we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 15
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,25 +20,29 @@ optim:
model:
name: E2TTS_Base
tokenizer: pinyin
tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
tokenizer_path: null # if 'custom' tokenizer, define the path you want to use (should be vocab.txt)
backbone: UNetT
arch:
dim: 1024
depth: 24
heads: 16
ff_mult: 4
text_mask_padding: False
pe_attn_head: 1
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # 'vocos' or 'bigvgan'
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: None # local vocoder path
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | None
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates

View File

@@ -1,16 +1,16 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # "frame" or "sample"
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if using frame-wise batch_size; we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 15
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,25 +20,29 @@ optim:
model:
name: E2TTS_Small
tokenizer: pinyin
tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
tokenizer_path: null # if 'custom' tokenizer, define the path you want to use (should be vocab.txt)
backbone: UNetT
arch:
dim: 768
depth: 20
heads: 12
ff_mult: 4
text_mask_padding: False
pe_attn_head: 1
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # 'vocos' or 'bigvgan'
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: None # local vocoder path
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | None
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates

View File

@@ -1,16 +1,16 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN # dataset name
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # "frame" or "sample"
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if using frame-wise batch_size; we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 15
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,14 +20,17 @@ optim:
model:
name: F5TTS_Base # model name
tokenizer: pinyin # tokenizer type
tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
tokenizer_path: null # if 'custom' tokenizer, define the path you want to use (should be vocab.txt)
backbone: DiT
arch:
dim: 1024
depth: 22
heads: 16
ff_mult: 2
text_dim: 512
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000
@@ -35,13 +38,14 @@ model:
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # 'vocos' or 'bigvgan'
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: None # local vocoder path
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | None
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates

View File

@@ -1,16 +1,16 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # "frame" or "sample"
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if using frame-wise batch_size; we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 15
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,14 +20,17 @@ optim:
model:
name: F5TTS_Small
tokenizer: pinyin
tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
tokenizer_path: null # if 'custom' tokenizer, define the path you want to use (should be vocab.txt)
backbone: DiT
arch:
dim: 768
depth: 18
heads: 12
ff_mult: 2
text_dim: 512
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000
@@ -35,13 +38,14 @@ model:
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # 'vocos' or 'bigvgan'
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: None # local vocoder path
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | None
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates

View File

@@ -0,0 +1,53 @@
hydra:
  run:
    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

datasets:
  name: Emilia_ZH_EN # dataset name
  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
  batch_size_type: frame # frame | sample
  max_samples: 64 # max sequences per batch if using frame-wise batch_size; we set 32 for small models, 64 for base models
  num_workers: 16

optim:
  epochs: 11
  learning_rate: 7.5e-5
  num_warmup_updates: 20000 # warmup updates
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0 # gradient clipping
  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not

model:
  name: F5TTS_v1_Base # model name
  tokenizer: pinyin # tokenizer type
  tokenizer_path: null # if 'custom' tokenizer, define the path you want to use (should be vocab.txt)
  backbone: DiT
  arch:
    dim: 1024
    depth: 22
    heads: 16
    ff_mult: 2
    text_dim: 512
    text_mask_padding: True
    qk_norm: null # null | rms_norm
    conv_layers: 4
    pe_attn_head: null
    checkpoint_activations: False # recompute activations and save memory for extra compute
  mel_spec:
    target_sample_rate: 24000
    n_mel_channels: 100
    hop_length: 256
    win_length: 1024
    n_fft: 1024
    mel_spec_type: vocos # vocos | bigvgan
  vocoder:
    is_local: False # use local offline ckpt or not
    local_path: null # local vocoder path

ckpts:
  logger: wandb # wandb | tensorboard | null
  log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
  save_per_updates: 50000 # save checkpoint per updates
  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
  last_per_updates: 5000 # save last checkpoint per updates
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
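
A sketch of how the refactored entry points consume such a config at runtime (the loading pattern is taken from this commit's `api.py`; the printed values are from the file above):

```python
from importlib.resources import files

from omegaconf import OmegaConf

model = "F5TTS_v1_Base"
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
print(model_cfg.model.backbone)                # DiT
print(model_cfg.model.arch.text_mask_padding)  # True
print(model_cfg.model.mel_spec.mel_spec_type)  # vocos
```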

View File

@@ -10,6 +10,7 @@ from importlib.resources import files
import torch
import torchaudio
from accelerate import Accelerator
from omegaconf import OmegaConf
from tqdm import tqdm
from f5_tts.eval.utils_eval import (
@@ -18,36 +19,26 @@ from f5_tts.eval.utils_eval import (
get_seedtts_testset_metainfo,
)
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
from f5_tts.model import CFM, DiT, UNetT
from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config
from f5_tts.model.utils import get_tokenizer
accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"
# --------------------- Dataset Settings -------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
use_ema = True
target_rms = 0.1
rel_path = str(files("f5_tts").joinpath("../../"))
def main():
# ---------------------- infer setting ---------------------- #
parser = argparse.ArgumentParser(description="batch inference")
parser.add_argument("-s", "--seed", default=None, type=int)
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
parser.add_argument("-n", "--expname", required=True)
parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"])
parser.add_argument("-c", "--ckptstep", default=1250000, type=int)
parser.add_argument("-nfe", "--nfestep", default=32, type=int)
parser.add_argument("-o", "--odemethod", default="euler")
@@ -58,12 +49,8 @@ def main():
args = parser.parse_args()
seed = args.seed
dataset_name = args.dataset
exp_name = args.expname
ckpt_step = args.ckptstep
ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
mel_spec_type = args.mel_spec_type
tokenizer = args.tokenizer
nfe_step = args.nfestep
ode_method = args.odemethod
@@ -77,13 +64,19 @@ def main():
use_truth_duration = False
no_ref_audio = False
if exp_name == "F5TTS_Base":
model_cls = DiT
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
model_cls = globals()[model_cfg.model.backbone]
model_arc = model_cfg.model.arch
elif exp_name == "E2TTS_Base":
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
dataset_name = model_cfg.datasets.name
tokenizer = model_cfg.model.tokenizer
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
hop_length = model_cfg.model.mel_spec.hop_length
win_length = model_cfg.model.mel_spec.win_length
n_fft = model_cfg.model.mel_spec.n_fft
if testset == "ls_pc_test_clean":
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
@@ -111,8 +104,6 @@ def main():
# -------------------------------------------------#
use_ema = True
prompts_all = get_inference_prompt(
metainfo,
speed=speed,
@@ -139,7 +130,7 @@ def main():
# Model
model = CFM(
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
n_fft=n_fft,
hop_length=hop_length,
@@ -154,6 +145,10 @@ def main():
vocab_char_map=vocab_char_map,
).to(device)
ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
if not os.path.exists(ckpt_path):
print("Loading from self-organized training checkpoints rather than released pretrained.")
ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

View File

@@ -1,13 +1,18 @@
#!/bin/bash
# e.g. F5-TTS, 16 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16
# e.g. Vanilla E2 TTS, 32 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0
# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0
# etc.

View File

@@ -53,43 +53,37 @@ def main():
asr_ckpt_dir = "" # auto download to cache dir
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
# --------------------------- WER ---------------------------
# --------------------------------------------------------------------------
full_results = []
metrics = []
if eval_task == "wer":
wer_results = []
wers = []
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_asr_wer, args)
for r in results:
wer_results.extend(r)
wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
with open(wer_result_path, "w") as f:
for line in wer_results:
wers.append(line["wer"])
json_line = json.dumps(line, ensure_ascii=False)
f.write(json_line + "\n")
wer = round(np.mean(wers) * 100, 3)
print(f"\nTotal {len(wers)} samples")
print(f"WER : {wer}%")
print(f"Results have been saved to {wer_result_path}")
# --------------------------- SIM ---------------------------
if eval_task == "sim":
sims = []
full_results.extend(r)
elif eval_task == "sim":
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_sim, args)
for r in results:
sims.extend(r)
full_results.extend(r)
else:
raise ValueError(f"Unknown metric type: {eval_task}")
sim = round(sum(sims) / len(sims), 3)
print(f"\nTotal {len(sims)} samples")
print(f"SIM : {sim}")
result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
with open(result_path, "w") as f:
for line in full_results:
metrics.append(line[eval_task])
f.write(json.dumps(line, ensure_ascii=False) + "\n")
metric = round(np.mean(metrics), 5)
f.write(f"\n{eval_task.upper()}: {metric}\n")
print(f"\nTotal {len(metrics)} samples")
print(f"{eval_task.upper()}: {metric}")
print(f"{eval_task.upper()} results saved to {result_path}")
if __name__ == "__main__":
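
Note that the unified `_{eval_task}_results.jsonl` written above ends with a plain-text summary line after the JSON records; a reading sketch that tolerates it (file path illustrative):

```python
import json

records, summary = [], None
with open("_wer_results.jsonl") as f:  # illustrative path
    for line in f:
        line = line.strip()
        if not line:
            continue
        try:
            records.append(json.loads(line))
        except json.JSONDecodeError:
            summary = line  # e.g. "WER: 2.04245"
print(f"{len(records)} samples, summary: {summary}")
```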

View File

@@ -52,43 +52,37 @@ def main():
asr_ckpt_dir = "" # auto download to cache dir
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
# --------------------------- WER ---------------------------
# --------------------------------------------------------------------------
full_results = []
metrics = []
if eval_task == "wer":
wer_results = []
wers = []
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_asr_wer, args)
for r in results:
wer_results.extend(r)
wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
with open(wer_result_path, "w") as f:
for line in wer_results:
wers.append(line["wer"])
json_line = json.dumps(line, ensure_ascii=False)
f.write(json_line + "\n")
wer = round(np.mean(wers) * 100, 3)
print(f"\nTotal {len(wers)} samples")
print(f"WER : {wer}%")
print(f"Results have been saved to {wer_result_path}")
# --------------------------- SIM ---------------------------
if eval_task == "sim":
sims = []
full_results.extend(r)
elif eval_task == "sim":
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_sim, args)
for r in results:
sims.extend(r)
full_results.extend(r)
else:
raise ValueError(f"Unknown metric type: {eval_task}")
sim = round(sum(sims) / len(sims), 3)
print(f"\nTotal {len(sims)} samples")
print(f"SIM : {sim}")
result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
with open(result_path, "w") as f:
for line in full_results:
metrics.append(line[eval_task])
f.write(json.dumps(line, ensure_ascii=False) + "\n")
metric = round(np.mean(metrics), 5)
f.write(f"\n{eval_task.upper()}: {metric}\n")
print(f"\nTotal {len(metrics)} samples")
print(f"{eval_task.upper()}: {metric}")
print(f"{eval_task.upper()} results saved to {result_path}")
if __name__ == "__main__":

View File

@@ -19,25 +19,23 @@ def main():
predictor = predictor.to(device)
audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
utmos_results = {}
utmos_score = 0
for audio_path in tqdm(audio_paths, desc="Processing"):
wav_name = audio_path.stem
wav, sr = librosa.load(audio_path, sr=None, mono=True)
wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
score = predictor(wav_tensor, sr)
utmos_results[str(wav_name)] = score.item()
utmos_score += score.item()
avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
print(f"UTMOS: {avg_score}")
utmos_result_path = Path(args.audio_dir) / "utmos_results.json"
utmos_result_path = Path(args.audio_dir) / "_utmos_results.jsonl"
with open(utmos_result_path, "w", encoding="utf-8") as f:
json.dump(utmos_results, f, ensure_ascii=False, indent=4)
for audio_path in tqdm(audio_paths, desc="Processing"):
wav, sr = librosa.load(audio_path, sr=None, mono=True)
wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
score = predictor(wav_tensor, sr)
line = {}
line["wav"], line["utmos"] = str(audio_path.stem), score.item()
utmos_score += score.item()
f.write(json.dumps(line, ensure_ascii=False) + "\n")
avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
f.write(f"\nUTMOS: {avg_score:.4f}\n")
print(f"Results have been saved to {utmos_result_path}")
print(f"UTMOS: {avg_score:.4f}")
print(f"UTMOS results saved to {utmos_result_path}")
if __name__ == "__main__":

View File

@@ -389,10 +389,10 @@ def run_sim(args):
model = model.cuda(device)
model.eval()
sims = []
for wav1, wav2, truth in tqdm(test_set):
wav1, sr1 = torchaudio.load(wav1)
wav2, sr2 = torchaudio.load(wav2)
sim_results = []
for gen_wav, prompt_wav, truth in tqdm(test_set):
wav1, sr1 = torchaudio.load(gen_wav)
wav2, sr2 = torchaudio.load(prompt_wav)
resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
@@ -408,6 +408,11 @@ def run_sim(args):
sim = F.cosine_similarity(emb1, emb2)[0].item()
# print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
sims.append(sim)
sim_results.append(
{
"wav": Path(gen_wav).stem,
"sim": sim,
}
)
return sims
return sim_results

View File

@@ -68,14 +68,16 @@ Basically you can inference with flags:
```bash
# Leaving --ref_text "" will have the ASR model transcribe (extra GPU memory usage)
f5-tts_infer-cli \
--model "F5-TTS" \
--model F5TTS_v1_Base \
--ref_audio "ref_audio.wav" \
--ref_text "The content, subtitle or transcription of reference audio." \
--gen_text "Some text you want TTS model generate for you."
# Choose Vocoder
f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
# Use BigVGAN as vocoder. Currently only supports F5TTS_Base.
f5-tts_infer-cli --model F5TTS_Base --vocoder_name bigvgan --load_vocoder_from_local
# Use custom path checkpoint, e.g.
f5-tts_infer-cli --ckpt_file ckpts/F5TTS_Base/model_1200000.safetensors
# More instructions
f5-tts_infer-cli --help
@@ -90,8 +92,8 @@ f5-tts_infer-cli -c custom.toml
For example, you can use a `.toml` file to pass in variables; refer to `src/f5_tts/infer/examples/basic/basic.toml`:
```toml
# F5-TTS | E2-TTS
model = "F5-TTS"
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
# If left as an empty "", the reference audio is transcribed automatically.
ref_text = "Some call me nature, others call me mother nature."
@@ -105,8 +107,8 @@ output_dir = "tests"
You can also leverage a `.toml` file to do multi-style generation; refer to `src/f5_tts/infer/examples/multi/story.toml`.
```toml
# F5-TTS | E2-TTS
model = "F5-TTS"
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/multi/main.flac"
# If left as an empty "", the reference audio is transcribed automatically.
ref_text = ""
@@ -126,6 +128,22 @@ ref_text = ""
```
You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change the voice; refer to `src/f5_tts/infer/examples/multi/story.txt`.
## Socket Real-time Service
Real-time voice output with chunked streaming:
```bash
# Start socket server
python src/f5_tts/socket_server.py
# If PyAudio is not installed
sudo apt-get install portaudio19-dev
pip install pyaudio
# Communicate with socket client
python src/f5_tts/socket_client.py
```
## Speech Editing
To test speech editing capabilities, use the following command:
@@ -134,86 +152,3 @@ To test speech editing capabilities, use the following command:
python src/f5_tts/infer/speech_edit.py
```
## Socket Realtime Client
To communicate with socket server you need to run
```bash
python src/f5_tts/socket_server.py
```
<details>
<summary>Then create client to communicate</summary>
```bash
# If PyAudio is not installed
sudo apt-get install portaudio19-dev
pip install pyaudio
```
```python
# Create the socket_client.py
import socket
import asyncio
import pyaudio
import numpy as np
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
    client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))

    start_time = time.time()
    first_chunk_time = None

    async def play_audio_stream():
        nonlocal first_chunk_time
        p = pyaudio.PyAudio()
        stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)

        try:
            while True:
                data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192)
                if not data:
                    break
                if data == b"END":
                    logger.info("End of audio received.")
                    break

                audio_array = np.frombuffer(data, dtype=np.float32)
                stream.write(audio_array.tobytes())

                if first_chunk_time is None:
                    first_chunk_time = time.time()

        finally:
            stream.stop_stream()
            stream.close()
            p.terminate()

        logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")

    try:
        data_to_send = f"{text}".encode("utf-8")
        await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send)
        await play_audio_stream()

    except Exception as e:
        logger.error(f"Error in listen_to_F5TTS: {e}")

    finally:
        client_socket.close()


if __name__ == "__main__":
    text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"
    asyncio.run(listen_to_F5TTS(text_to_send))
```
</details>

View File

@@ -16,7 +16,7 @@
<!-- omit in toc -->
### Supported Languages
- [Multilingual](#multilingual)
- [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts)
- [F5-TTS v1 v0 Base @ zh \& en @ F5-TTS](#f5-tts-v1-v0-base--zh--en--f5-tts)
- [English](#english)
- [Finnish](#finnish)
- [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
@@ -37,7 +37,17 @@
## Multilingual
#### F5-TTS Base @ zh & en @ F5-TTS
#### F5-TTS v1 v0 Base @ zh & en @ F5-TTS
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS v1 Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
```bash
Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
```
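
For programmatic loading of a card's artifacts, this commit's Gradio app resolves the `hf://` URIs with `cached_path` and hands the config dict to `load_model`; a minimal sketch under those assumptions:

```python
from cached_path import cached_path

from f5_tts.infer.utils_infer import load_model
from f5_tts.model import DiT

ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
vocab_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt"))
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
ema_model = load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
```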
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
@@ -45,7 +55,7 @@
```bash
Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
*Other info, e.g. author info, GitHub repo, links to sampled results, usage instructions, tutorials (blog, video, etc.) ...*
@@ -64,7 +74,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
```bash
Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
@@ -78,7 +88,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
```bash
Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
@@ -96,7 +106,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
```bash
Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Authors: SPRING Lab, Indian Institute of Technology, Madras
@@ -113,7 +123,7 @@ Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "c
```bash
Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
Vocab: hf://alien79/F5-TTS-italian/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Trained by [Mithril Man](https://github.com/MithrilMan)
@@ -131,7 +141,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
```bash
Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt
Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
@@ -148,7 +158,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
```bash
Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
- Any improvements are welcome

View File

@@ -1,5 +1,5 @@
# F5-TTS | E2-TTS
model = "F5-TTS"
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
# If left as an empty "", the reference audio is transcribed automatically.
ref_text = "Some call me nature, others call me mother nature."

View File

@@ -1,5 +1,5 @@
# F5-TTS | E2-TTS
model = "F5-TTS"
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/multi/main.flac"
# If left as an empty "", the reference audio is transcribed automatically.
ref_text = ""

View File

@@ -27,7 +27,7 @@ from f5_tts.infer.utils_infer import (
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
)
from f5_tts.model import DiT, UNetT
from f5_tts.model import DiT, UNetT # noqa: F401. used for config
parser = argparse.ArgumentParser(
@@ -50,7 +50,7 @@ parser.add_argument(
"-m",
"--model",
type=str,
help="The model name: F5-TTS | E2-TTS",
help="The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc.",
)
parser.add_argument(
"-mc",
@@ -172,8 +172,7 @@ config = tomli.load(open(args.config, "rb"))
# command-line interface parameters
model = args.model or config.get("model", "F5-TTS")
model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml")))
model = args.model or config.get("model", "F5TTS_v1_Base")
ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
vocab_file = args.vocab_file or config.get("vocab_file", "")
@@ -245,36 +244,32 @@ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_loc
# load TTS model
if model == "F5-TTS":
model_cls = DiT
model_cfg = OmegaConf.load(model_cfg).model.arch
if not ckpt_file: # path not specified, download from repo
if vocoder_name == "vocos":
repo_name = "F5-TTS"
exp_name = "F5TTS_Base"
ckpt_step = 1200000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
elif vocoder_name == "bigvgan":
repo_name = "F5-TTS"
exp_name = "F5TTS_Base_bigvgan"
ckpt_step = 1250000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
model_cfg = OmegaConf.load(
args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
).model
model_cls = globals()[model_cfg.backbone]
elif model == "E2-TTS":
assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet"
assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos yet"
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
if not ckpt_file: # path not specified, download from repo
repo_name = "E2-TTS"
exp_name = "E2TTS_Base"
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
if model != "F5TTS_Base":
assert vocoder_name == model_cfg.mel_spec.mel_spec_type
# override for previous models
if model == "F5TTS_Base":
if vocoder_name == "vocos":
ckpt_step = 1200000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
elif vocoder_name == "bigvgan":
model = "F5TTS_Base_bigvgan"
ckpt_type = "pt"
elif model == "E2TTS_Base":
repo_name = "E2-TTS"
ckpt_step = 1200000
if not ckpt_file:
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
print(f"Using {model}...")
ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
# inference process

View File

@@ -41,12 +41,12 @@ from f5_tts.infer.utils_infer import (
)
DEFAULT_TTS_MODEL = "F5-TTS"
DEFAULT_TTS_MODEL = "F5-TTS_v1"
tts_model_choice = DEFAULT_TTS_MODEL
DEFAULT_TTS_MODEL_CFG = [
"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
"hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
"hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
"hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt",
json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
]
@@ -56,13 +56,15 @@ DEFAULT_TTS_MODEL_CFG = [
vocoder = load_vocoder()
def load_f5tts(ckpt_path=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))):
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
def load_f5tts():
ckpt_path = str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
F5TTS_model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
return load_model(DiT, F5TTS_model_cfg, ckpt_path)
def load_e2tts(ckpt_path=str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))):
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
def load_e2tts():
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1)
return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
@@ -73,7 +75,7 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
if vocab_path.startswith("hf://"):
vocab_path = str(cached_path(vocab_path))
if model_cfg is None:
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
@@ -130,7 +132,7 @@ def infer(
ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
if model == "F5-TTS":
if model == DEFAULT_TTS_MODEL:
ema_model = F5TTS_ema_model
elif model == "E2-TTS":
global E2TTS_ema_model
@@ -762,7 +764,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
"""
)
last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info.txt")
last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info_v1.txt")
def load_last_used_custom():
try:
@@ -821,7 +823,30 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
custom_model_cfg = gr.Dropdown(
choices=[
DEFAULT_TTS_MODEL_CFG[2],
json.dumps(dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)),
json.dumps(
dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
text_mask_padding=False,
conv_layers=4,
pe_attn_head=1,
)
),
json.dumps(
dict(
dim=768,
depth=18,
heads=12,
ff_mult=2,
text_dim=512,
text_mask_padding=False,
conv_layers=4,
pe_attn_head=1,
)
),
],
value=load_last_used_custom()[2],
allow_custom_value=True,

View File

@@ -2,12 +2,15 @@ import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
from importlib.resources import files
import torch
import torch.nn.functional as F
import torchaudio
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
from f5_tts.model import CFM, DiT, UNetT
from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
device = (
@@ -21,44 +24,40 @@ device = (
)
# --------------------- Dataset Settings -------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
target_rms = 0.1
tokenizer = "pinyin"
dataset_name = "Emilia_ZH_EN"
# ---------------------- infer setting ---------------------- #
seed = None # int | None
exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
ckpt_step = 1200000
exp_name = "F5TTS_v1_Base" # F5TTS_v1_Base | E2TTS_Base
ckpt_step = 1250000
nfe_step = 32 # 16, 32
cfg_strength = 2.0
ode_method = "euler" # euler | midpoint
sway_sampling_coef = -1.0
speed = 1.0
target_rms = 0.1
if exp_name == "F5TTS_Base":
model_cls = DiT
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
elif exp_name == "E2TTS_Base":
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
model_cls = globals()[model_cfg.model.backbone]
model_arc = model_cfg.model.arch
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
dataset_name = model_cfg.datasets.name
tokenizer = model_cfg.model.tokenizer
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
hop_length = model_cfg.model.mel_spec.hop_length
win_length = model_cfg.model.mel_spec.win_length
n_fft = model_cfg.model.mel_spec.n_fft
ckpt_path = str(files("f5_tts").joinpath("../../")) + f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
output_dir = "tests"
# [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
# pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
# [write the origin_text into a file, e.g. tests/test_edit.txt]
@@ -67,7 +66,7 @@ output_dir = "tests"
# [--language "zho" for Chinese, "eng" for English]
# [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_en.wav"
audio_to_edit = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav"))
origin_text = "Some call me nature, others call me mother nature."
target_text = "Some call me optimist, others call me realist."
parts_to_edit = [
@@ -106,7 +105,7 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
# Model
model = CFM(
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
n_fft=n_fft,
hop_length=hop_length,

View File

@@ -301,19 +301,19 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 15s, clipping short. (1)")
break
non_silent_wave += non_silent_seg
# 2. try to find short silence for clipping if 1. failed
if len(non_silent_wave) > 15000:
if len(non_silent_wave) > 12000:
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 15s, clipping short. (2)")
break
non_silent_wave += non_silent_seg
@@ -321,8 +321,8 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
aseg = non_silent_wave
# 3. if no proper silence found for clipping
if len(aseg) > 15000:
aseg = aseg[:15000]
if len(aseg) > 12000:
aseg = aseg[:12000]
show_info("Audio is over 15s, clipping short. (3)")
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
@@ -383,7 +383,7 @@ def infer_process(
):
# Split the input text into batches
audio, sr = torchaudio.load(ref_audio)
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
for i, gen_text in enumerate(gen_text_batches):
print(f"gen_text {i}", gen_text)

View File

@@ -4,7 +4,7 @@
### unett.py
- flat unet transformer
- structure same as in e2-tts & voicebox paper except using rotary pos emb
- update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
- possible abs pos emb & convnextv2 blocks for embedded text before concat
### dit.py
- adaln-zero dit
@@ -14,7 +14,7 @@
- possible long skip connection (first layer to last layer)
### mmdit.py
- sd3 structure
- stable diffusion 3 block structure
- timestep as condition
- left stream: text embedded and applied a abs pos emb
- right stream: masked_cond & noised_input concatted and with same conv pos emb as unett

View File

@@ -20,7 +20,7 @@ from f5_tts.model.modules import (
ConvNeXtV2Block,
ConvPositionEmbedding,
DiTBlock,
AdaLayerNormZero_Final,
AdaLayerNorm_Final,
precompute_freqs_cis,
get_pos_embed_indices,
)
@@ -30,10 +30,12 @@ from f5_tts.model.modules import (
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
if conv_layers > 0:
self.extra_modeling = True
self.precompute_max_pos = 4096 # ~44s of 24khz audio
@@ -49,6 +51,8 @@ class TextEmbedding(nn.Module):
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
batch, text_len = text.shape[0], text.shape[1]
text = F.pad(text, (0, seq_len - text_len), value=0)
if self.mask_padding:
text_mask = text == 0
if drop_text: # cfg for text
text = torch.zeros_like(text)
@@ -64,7 +68,13 @@ class TextEmbedding(nn.Module):
text = text + text_pos_embed
# convnextv2 blocks
text = self.text_blocks(text)
if self.mask_padding:
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
for block in self.text_blocks:
text = block(text)
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
else:
text = self.text_blocks(text)
return text
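
In isolation, the padding-mask behavior added to `TextEmbedding` amounts to this toy sketch (assumed shapes; token id 0 marks filler and batch padding):

```python
import torch

text = torch.tensor([[5, 9, 0, 0]])  # b nt; trailing zeros are padding
feats = torch.randn(1, 4, 8)         # b nt d, embedded text features
text_mask = text == 0

# Zero padded positions so they contribute nothing downstream; the diff
# re-applies this after every ConvNeXtV2 block, whose convolutions would
# otherwise smear real token features into the padding.
feats = feats.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, feats.size(-1)), 0.0)
assert feats[0, 2:].abs().sum() == 0
```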
@@ -103,7 +113,10 @@ class DiT(nn.Module):
mel_dim=100,
text_num_embeds=256,
text_dim=None,
text_mask_padding=True,
qk_norm=None,
conv_layers=0,
pe_attn_head=None,
long_skip_connection=False,
checkpoint_activations=False,
):
@@ -112,7 +125,10 @@ class DiT(nn.Module):
self.time_embed = TimestepEmbedding(dim)
if text_dim is None:
text_dim = mel_dim
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
self.text_embed = TextEmbedding(
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
)
self.text_cond, self.text_uncond = None, None # text cache
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
self.rotary_embed = RotaryEmbedding(dim_head)
@@ -121,15 +137,40 @@ class DiT(nn.Module):
self.depth = depth
self.transformer_blocks = nn.ModuleList(
[DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
[
DiTBlock(
dim=dim,
heads=heads,
dim_head=dim_head,
ff_mult=ff_mult,
dropout=dropout,
qk_norm=qk_norm,
pe_attn_head=pe_attn_head,
)
for _ in range(depth)
]
)
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
self.norm_out = AdaLayerNorm_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim)
self.checkpoint_activations = checkpoint_activations
self.initialize_weights()
def initialize_weights(self):
# Zero-out AdaLN layers in DiT blocks:
for block in self.transformer_blocks:
nn.init.constant_(block.attn_norm.linear.weight, 0)
nn.init.constant_(block.attn_norm.linear.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.norm_out.linear.weight, 0)
nn.init.constant_(self.norm_out.linear.bias, 0)
nn.init.constant_(self.proj_out.weight, 0)
nn.init.constant_(self.proj_out.bias, 0)
def ckpt_wrapper(self, module):
# https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
def ckpt_forward(*inputs):
@@ -138,6 +179,9 @@ class DiT(nn.Module):
return ckpt_forward
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
@@ -147,14 +191,25 @@ class DiT(nn.Module):
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
# t: conditioning time, text: text, x: noised audio + cond audio + text
t = self.time_embed(time)
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
if cache:
if drop_text:
if self.text_uncond is None:
self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
text_embed = self.text_uncond
else:
if self.text_cond is None:
self.text_cond = self.text_embed(text, seq_len, drop_text=False)
text_embed = self.text_cond
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
rope = self.rotary_embed.forward_from_seq_len(seq_len)
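
The new `cache` flag memoizes both CFG branches of the text embedding, since the text is constant across ODE solver steps. Stripped of module details, the behavior is roughly this sketch (a hypothetical minimal wrapper, not the repo's API):

```python
class TextEmbedCache:
    """Toy stand-in for the text_cond/text_uncond cache added to DiT/MMDiT."""

    def __init__(self, embed_fn):
        self.embed_fn = embed_fn
        self.text_cond = self.text_uncond = None

    def __call__(self, text, drop_text, cache=False):
        if not cache:
            return self.embed_fn(text, drop_text=drop_text)
        if drop_text:  # unconditional branch of CFG
            if self.text_uncond is None:
                self.text_uncond = self.embed_fn(text, drop_text=True)
            return self.text_uncond
        if self.text_cond is None:
            self.text_cond = self.embed_fn(text, drop_text=False)
        return self.text_cond

    def clear_cache(self):
        self.text_cond = self.text_uncond = None
```

`clear_cache()` mirrors the method added above; it is reset between utterances so embeddings never leak across inputs.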

View File

@@ -18,7 +18,7 @@ from f5_tts.model.modules import (
TimestepEmbedding,
ConvPositionEmbedding,
MMDiTBlock,
AdaLayerNormZero_Final,
AdaLayerNorm_Final,
precompute_freqs_cis,
get_pos_embed_indices,
)
@@ -28,18 +28,24 @@ from f5_tts.model.modules import (
class TextEmbedding(nn.Module):
def __init__(self, out_dim, text_num_embeds):
def __init__(self, out_dim, text_num_embeds, mask_padding=True):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
self.precompute_max_pos = 1024
self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
text = text + 1
if drop_text:
text = text + 1 # use 0 as filler token; batch padding uses -1 beforehand, see list_str_to_idx()
if self.mask_padding:
text_mask = text == 0
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text)
text = self.text_embed(text) # b nt -> b nt d
# sinus pos emb
batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
@@ -49,6 +55,9 @@ class TextEmbedding(nn.Module):
text = text + text_pos_embed
if self.mask_padding:
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
return text
@@ -83,13 +92,16 @@ class MMDiT(nn.Module):
dim_head=64,
dropout=0.1,
ff_mult=4,
text_num_embeds=256,
mel_dim=100,
text_num_embeds=256,
text_mask_padding=True,
qk_norm=None,
):
super().__init__()
self.time_embed = TimestepEmbedding(dim)
self.text_embed = TextEmbedding(dim, text_num_embeds)
self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding)
self.text_cond, self.text_uncond = None, None # text cache
self.audio_embed = AudioEmbedding(mel_dim, dim)
self.rotary_embed = RotaryEmbedding(dim_head)
@@ -106,13 +118,33 @@ class MMDiT(nn.Module):
dropout=dropout,
ff_mult=ff_mult,
context_pre_only=i == depth - 1,
qk_norm=qk_norm,
)
for i in range(depth)
]
)
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
self.norm_out = AdaLayerNorm_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim)
self.initialize_weights()
def initialize_weights(self):
# Zero-out AdaLN layers in MMDiT blocks:
for block in self.transformer_blocks:
nn.init.constant_(block.attn_norm_x.linear.weight, 0)
nn.init.constant_(block.attn_norm_x.linear.bias, 0)
nn.init.constant_(block.attn_norm_c.linear.weight, 0)
nn.init.constant_(block.attn_norm_c.linear.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.norm_out.linear.weight, 0)
nn.init.constant_(self.norm_out.linear.bias, 0)
nn.init.constant_(self.proj_out.weight, 0)
nn.init.constant_(self.proj_out.bias, 0)
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
@@ -122,6 +154,7 @@ class MMDiT(nn.Module):
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
):
batch = x.shape[0]
if time.ndim == 0:
@@ -129,7 +162,17 @@ class MMDiT(nn.Module):
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
c = self.text_embed(text, drop_text=drop_text)
if cache:
if drop_text:
if self.text_uncond is None:
self.text_uncond = self.text_embed(text, drop_text=True)
c = self.text_uncond
else:
if self.text_cond is None:
self.text_cond = self.text_embed(text, drop_text=False)
c = self.text_cond
else:
c = self.text_embed(text, drop_text=drop_text)
x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
seq_len = x.shape[1]

View File

@@ -33,10 +33,12 @@ from f5_tts.model.modules import (
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
if conv_layers > 0:
self.extra_modeling = True
self.precompute_max_pos = 4096 # ~44s of 24khz audio
@@ -52,6 +54,8 @@ class TextEmbedding(nn.Module):
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
batch, text_len = text.shape[0], text.shape[1]
text = F.pad(text, (0, seq_len - text_len), value=0)
if self.mask_padding:
text_mask = text == 0
if drop_text: # cfg for text
text = torch.zeros_like(text)
@@ -67,7 +71,13 @@ class TextEmbedding(nn.Module):
text = text + text_pos_embed
# convnextv2 blocks
text = self.text_blocks(text)
if self.mask_padding:
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
for block in self.text_blocks:
text = block(text)
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
else:
text = self.text_blocks(text)
return text
@@ -106,7 +116,10 @@ class UNetT(nn.Module):
mel_dim=100,
text_num_embeds=256,
text_dim=None,
text_mask_padding=True,
qk_norm=None,
conv_layers=0,
pe_attn_head=None,
skip_connect_type: Literal["add", "concat", "none"] = "concat",
):
super().__init__()
@@ -115,7 +128,10 @@ class UNetT(nn.Module):
self.time_embed = TimestepEmbedding(dim)
if text_dim is None:
text_dim = mel_dim
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
self.text_embed = TextEmbedding(
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
)
self.text_cond, self.text_uncond = None, None # text cache
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
self.rotary_embed = RotaryEmbedding(dim_head)
@@ -134,11 +150,12 @@ class UNetT(nn.Module):
attn_norm = RMSNorm(dim)
attn = Attention(
processor=AttnProcessor(),
processor=AttnProcessor(pe_attn_head=pe_attn_head),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
qk_norm=qk_norm,
)
ff_norm = RMSNorm(dim)
@@ -161,6 +178,9 @@ class UNetT(nn.Module):
self.norm_out = RMSNorm(dim)
self.proj_out = nn.Linear(dim, mel_dim)
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
@@ -170,6 +190,7 @@ class UNetT(nn.Module):
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
@@ -177,7 +198,17 @@ class UNetT(nn.Module):
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
if cache:
if drop_text:
if self.text_uncond is None:
self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
text_embed = self.text_uncond
else:
if self.text_cond is None:
self.text_cond = self.text_embed(text, seq_len, drop_text=False)
text_embed = self.text_cond
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
# postfix time t to input x, [b n d] -> [b n+1 d]

View File

@@ -162,13 +162,13 @@ class CFM(nn.Module):
# predict flow
pred = self.transformer(
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True
)
if cfg_strength < 1e-5:
return pred
null_pred = self.transformer(
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True
)
return pred + (pred - null_pred) * cfg_strength
@@ -195,6 +195,7 @@ class CFM(nn.Module):
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
self.transformer.clear_cache()
sampled = trajectory[-1]
out = sampled
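
For reference, the two transformer calls implement standard classifier-free guidance, and `cache=True` lets both reuse the text embeddings from the first ODE step. A minimal sketch of the guidance combine and of the sway-sampling time warp used just above (coefficient `s=-1.0` is the repo default; negative `s` bends the uniform time grid toward `t=0`, spending more function evaluations early):

```python
import torch

def cfg_combine(pred, null_pred, cfg_strength: float):
    # cond + w * (cond - uncond); w is cfg_strength (commonly ~2.0)
    if cfg_strength < 1e-5:  # guidance effectively disabled
        return pred
    return pred + (pred - null_pred) * cfg_strength

def sway(t: torch.Tensor, s: float = -1.0) -> torch.Tensor:
    # t <- t + s * (cos(pi/2 * t) - 1 + t); at s = -1 this is 1 - cos(pi/2 * t)
    return t + s * (torch.cos(torch.pi / 2 * t) - 1 + t)
```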

View File

@@ -173,7 +173,7 @@ class DynamicBatchSampler(Sampler[list[int]]):
"""
def __init__(
self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_residual: bool = False
):
self.sampler = sampler
self.frames_threshold = frames_threshold
@@ -208,12 +208,15 @@ class DynamicBatchSampler(Sampler[list[int]]):
batch = []
batch_frames = 0
if not drop_last and len(batch) > 0:
if not drop_residual and len(batch) > 0:
batches.append(batch)
del indices
self.batches = batches
# Ensure even batches with accelerate BatchSamplerShard cls under frame_per_batch setting
self.drop_last = True
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler."""
self.epoch = epoch
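
The rename from `drop_last` to `drop_residual` matches what the flag actually controls: whether the final, partially filled batch is kept. A simplified sketch of the frame-threshold packing (the real sampler additionally caps `max_samples` per batch and shuffles with `random_seed`):

```python
def pack_by_frames(frame_lengths, frames_threshold, drop_residual=False):
    batches, batch, frames = [], [], 0
    for idx, n in enumerate(frame_lengths):
        if batch and frames + n > frames_threshold:
            batches.append(batch)  # frame budget reached, start a new batch
            batch, frames = [], 0
        batch.append(idx)
        frames += n
    if batch and not drop_residual:
        batches.append(batch)  # keep the leftover partial batch
    return batches
```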

View File

@@ -269,11 +269,36 @@ class ConvNeXtV2Block(nn.Module):
return residual + x
# AdaLayerNormZero
# RMSNorm
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.native_rms_norm = float(torch.__version__[:3]) >= 2.4
def forward(self, x):
if self.native_rms_norm:
if self.weight.dtype in [torch.float16, torch.bfloat16]:
x = x.to(self.weight.dtype)
x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps)
else:
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
x = x.to(self.weight.dtype)
x = x * self.weight
return x
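
A quick self-check of the two branches (assuming PyTorch >= 2.4, where `F.rms_norm` is available): the manual fp32 path and the native op should agree to numerical tolerance.

```python
import torch
import torch.nn.functional as F

x, weight, eps = torch.randn(2, 5, 8), torch.ones(8), 1e-6
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
native = F.rms_norm(x, normalized_shape=(8,), weight=weight, eps=eps)
assert torch.allclose(manual, native, atol=1e-5)
```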
# AdaLayerNorm
# return with modulated x for attn input, and params for later mlp modulation
class AdaLayerNormZero(nn.Module):
class AdaLayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -290,11 +315,11 @@ class AdaLayerNormZero(nn.Module):
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
# AdaLayerNormZero for final layer
# AdaLayerNorm for final layer
# return only modulated x for attn input, since there is no further mlp modulation
class AdaLayerNormZero_Final(nn.Module):
class AdaLayerNorm_Final(nn.Module):
def __init__(self, dim):
super().__init__()
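
The dropped `Zero` in the names refers to initialization, not behavior: the zeroing now lives in `initialize_weights`, while the modules themselves do plain adaLN modulation. A minimal sketch of the final-layer variant (names assumed):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaLNFinalSketch(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim * 2)  # zeroed at model init
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, t_emb):
        # time embedding -> per-channel shift/scale; x: (b, n, d), t_emb: (b, d)
        shift, scale = self.linear(F.silu(t_emb)).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
```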
@@ -341,7 +366,8 @@ class Attention(nn.Module):
dim_head: int = 64,
dropout: float = 0.0,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only=None,
context_pre_only: bool = False,
qk_norm: Optional[str] = None,
):
super().__init__()
@@ -362,18 +388,32 @@ class Attention(nn.Module):
self.to_k = nn.Linear(dim, self.inner_dim)
self.to_v = nn.Linear(dim, self.inner_dim)
if qk_norm is None:
self.q_norm = None
self.k_norm = None
elif qk_norm == "rms_norm":
self.q_norm = RMSNorm(dim_head, eps=1e-6)
self.k_norm = RMSNorm(dim_head, eps=1e-6)
else:
raise ValueError(f"Unimplemented qk_norm: {qk_norm}")
if self.context_dim is not None:
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
if self.context_pre_only is not None:
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
if qk_norm is None:
self.c_q_norm = None
self.c_k_norm = None
elif qk_norm == "rms_norm":
self.c_q_norm = RMSNorm(dim_head, eps=1e-6)
self.c_k_norm = RMSNorm(dim_head, eps=1e-6)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, dim))
self.to_out.append(nn.Dropout(dropout))
if self.context_pre_only is not None and not self.context_pre_only:
self.to_out_c = nn.Linear(self.inner_dim, dim)
if self.context_dim is not None and not self.context_pre_only:
self.to_out_c = nn.Linear(self.inner_dim, context_dim)
def forward(
self,
@@ -393,8 +433,11 @@ class Attention(nn.Module):
class AttnProcessor:
def __init__(self):
pass
def __init__(
self,
pe_attn_head: int | None = None, # number of attention heads to apply rope to, None for all
):
self.pe_attn_head = pe_attn_head
def __call__(
self,
@@ -405,19 +448,11 @@ class AttnProcessor:
) -> torch.FloatTensor:
batch_size = x.shape[0]
# `sample` projections.
# `sample` projections
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# apply rotary position embedding
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -425,6 +460,25 @@ class AttnProcessor:
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# qk norm
if attn.q_norm is not None:
query = attn.q_norm(query)
if attn.k_norm is not None:
key = attn.k_norm(key)
# apply rotary position embedding
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
if self.pe_attn_head is not None:
pn = self.pe_attn_head
query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale)
key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale)
else:
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# mask. e.g. at inference, a batch may have different target durations; mask out the padding
if mask is not None:
attn_mask = mask
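
`pe_attn_head` controls how many heads receive rotary position embedding once tensors are in `(batch, heads, seq, head_dim)` layout: the v1 config leaves it `None` (RoPE on all heads), while the older Base configs further down set `pe_attn_head=1`. A trimmed sketch, with `apply_rotary_pos_emb` assumed from `f5_tts.model.modules` and the xpos scale factors omitted:

```python
def apply_rope_partial(query, key, freqs, pe_attn_head=None):
    # query/key: (b, h, n, d_head); rotate only the first pe_attn_head heads
    if pe_attn_head is None:
        return apply_rotary_pos_emb(query, freqs), apply_rotary_pos_emb(key, freqs)
    pn = pe_attn_head
    query[:, :pn] = apply_rotary_pos_emb(query[:, :pn], freqs)
    key[:, :pn] = apply_rotary_pos_emb(key[:, :pn], freqs)
    return query, key
```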
@@ -470,16 +524,36 @@ class JointAttnProcessor:
batch_size = c.shape[0]
# `sample` projections.
# `sample` projections
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# `context` projections.
# `context` projections
c_query = attn.to_q_c(c)
c_key = attn.to_k_c(c)
c_value = attn.to_v_c(c)
# attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
c_query = c_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# qk norm
if attn.q_norm is not None:
query = attn.q_norm(query)
if attn.k_norm is not None:
key = attn.k_norm(key)
if attn.c_q_norm is not None:
c_query = attn.c_q_norm(c_query)
if attn.c_k_norm is not None:
c_key = attn.c_k_norm(c_key)
# apply rope for context and noised input independently
if rope is not None:
freqs, xpos_scale = rope
@@ -492,16 +566,10 @@ class JointAttnProcessor:
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
# attention
query = torch.cat([query, c_query], dim=1)
key = torch.cat([key, c_key], dim=1)
value = torch.cat([value, c_value], dim=1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# joint attention
query = torch.cat([query, c_query], dim=2)
key = torch.cat([key, c_key], dim=2)
value = torch.cat([value, c_value], dim=2)
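
The reordering matters for tensor layout: projections are now reshaped to `(batch, heads, seq, head_dim)` before joining, so that qk-norm and RoPE can be applied to the audio and context streams independently, and the concatenation consequently moves to `dim=2`, the sequence axis. A toy shape check:

```python
import torch

b, h, n_audio, n_text, d = 2, 16, 100, 30, 64
q = torch.randn(b, h, n_audio, d)   # audio queries after head reshape
cq = torch.randn(b, h, n_text, d)   # context (text) queries
joint = torch.cat([q, cq], dim=2)
assert joint.shape == (b, h, n_audio + n_text, d)  # tokens grow, not heads
```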
# mask. e.g. at inference, a batch may have different target durations; mask out the padding
if mask is not None:
@@ -540,16 +608,17 @@ class JointAttnProcessor:
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn_norm = AdaLayerNorm(dim)
self.attn = Attention(
processor=AttnProcessor(),
processor=AttnProcessor(pe_attn_head=pe_attn_head),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
qk_norm=qk_norm,
)
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
@@ -585,26 +654,30 @@ class MMDiTBlock(nn.Module):
context_pre_only: the last layer only does prenorm + modulation since there is no further ffn
"""
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
def __init__(
self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None
):
super().__init__()
if context_dim is None:
context_dim = dim
self.context_pre_only = context_pre_only
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
self.attn_norm_x = AdaLayerNormZero(dim)
self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim)
self.attn_norm_x = AdaLayerNorm(dim)
self.attn = Attention(
processor=JointAttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
context_dim=dim,
context_dim=context_dim,
context_pre_only=context_pre_only,
qk_norm=qk_norm,
)
if not context_pre_only:
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
self.ff_norm_c = nn.LayerNorm(context_dim, elementwise_affine=False, eps=1e-6)
self.ff_c = FeedForward(dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh")
else:
self.ff_norm_c = None
self.ff_c = None

View File

@@ -32,7 +32,7 @@ class Trainer:
save_per_updates=1000,
keep_last_n_checkpoints: int = -1, # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
checkpoint_path=None,
batch_size=32,
batch_size_per_gpu=32,
batch_size_type: str = "sample",
max_samples=32,
grad_accumulation_steps=1,
@@ -40,7 +40,7 @@ class Trainer:
noise_scheduler: str | None = None,
duration_predictor: torch.nn.Module | None = None,
logger: str | None = "wandb", # "wandb" | "tensorboard" | None
wandb_project="test_e2-tts",
wandb_project="test_f5-tts",
wandb_run_name="test_run",
wandb_resume_id: str = None,
log_samples: bool = False,
@@ -51,6 +51,7 @@ class Trainer:
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
is_local_vocoder: bool = False, # use local path vocoder
local_vocoder_path: str = "", # local vocoder path
cfg_dict: dict = dict(), # training config
):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
@@ -72,21 +73,23 @@ class Trainer:
else:
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
self.accelerator.init_trackers(
project_name=wandb_project,
init_kwargs=init_kwargs,
config={
if not cfg_dict:
cfg_dict = {
"epochs": epochs,
"learning_rate": learning_rate,
"num_warmup_updates": num_warmup_updates,
"batch_size": batch_size,
"batch_size_per_gpu": batch_size_per_gpu,
"batch_size_type": batch_size_type,
"max_samples": max_samples,
"grad_accumulation_steps": grad_accumulation_steps,
"max_grad_norm": max_grad_norm,
"gpus": self.accelerator.num_processes,
"noise_scheduler": noise_scheduler,
},
}
cfg_dict["gpus"] = self.accelerator.num_processes
self.accelerator.init_trackers(
project_name=wandb_project,
init_kwargs=init_kwargs,
config=cfg_dict,
)
elif self.logger == "tensorboard":
@@ -111,9 +114,9 @@ class Trainer:
self.save_per_updates = save_per_updates
self.keep_last_n_checkpoints = keep_last_n_checkpoints
self.last_per_updates = default(last_per_updates, save_per_updates)
self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
self.checkpoint_path = default(checkpoint_path, "ckpts/test_f5-tts")
self.batch_size = batch_size
self.batch_size_per_gpu = batch_size_per_gpu
self.batch_size_type = batch_size_type
self.max_samples = max_samples
self.grad_accumulation_steps = grad_accumulation_steps
@@ -179,7 +182,7 @@ class Trainer:
if (
not exists(self.checkpoint_path)
or not os.path.exists(self.checkpoint_path)
or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path))
or not any(filename.endswith((".pt", ".safetensors")) for filename in os.listdir(self.checkpoint_path))
):
return 0
@@ -191,7 +194,7 @@ class Trainer:
all_checkpoints = [
f
for f in os.listdir(self.checkpoint_path)
if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")
if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith((".pt", ".safetensors"))
]
# First try to find regular training checkpoints
@@ -205,8 +208,16 @@ class Trainer:
# If no training checkpoints, use pretrained model
latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
if latest_checkpoint.endswith(".safetensors"): # always a pretrained checkpoint
from safetensors.torch import load_file
checkpoint = load_file(f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu")
checkpoint = {"ema_model_state_dict": checkpoint}
elif latest_checkpoint.endswith(".pt"):
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
checkpoint = torch.load(
f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu"
)
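
A condensed sketch of the dual-format loading introduced here: `.safetensors` files carry a bare state dict (wrapped so downstream code sees it as an EMA checkpoint, matching the published pretrained weights), while `.pt` files remain full torch checkpoints.

```python
import torch
from safetensors.torch import load_file

def load_any_checkpoint(path: str) -> dict:
    if path.endswith(".safetensors"):
        return {"ema_model_state_dict": load_file(path, device="cpu")}
    return torch.load(path, weights_only=True, map_location="cpu")
```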
# patch for backward compatibility, 305e3ea
for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
@@ -271,7 +282,7 @@ class Trainer:
num_workers=num_workers,
pin_memory=True,
persistent_workers=True,
batch_size=self.batch_size,
batch_size=self.batch_size_per_gpu,
shuffle=True,
generator=generator,
)
@@ -280,10 +291,10 @@ class Trainer:
sampler = SequentialSampler(train_dataset)
batch_sampler = DynamicBatchSampler(
sampler,
self.batch_size,
self.batch_size_per_gpu,
max_samples=self.max_samples,
random_seed=resumable_with_seed, # This enables reproducible shuffling
drop_last=False,
drop_residual=False,
)
train_dataloader = DataLoader(
train_dataset,

View File

@@ -133,11 +133,12 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
# convert char to pinyin
jieba.initialize()
print("Word segmentation module jieba initialized.\n")
def convert_char_to_pinyin(text_list, polyphone=True):
if jieba.dt.initialized is False:
jieba.default_logger.setLevel(50) # CRITICAL
jieba.initialize()
final_text_list = []
custom_trans = str.maketrans(
{";": ",", "": '"', "": '"', "": "'", "": "'"}

View File

@@ -9,7 +9,7 @@ mel_hop_length = 256
mel_sampling_rate = 24000
# target
wanted_max_updates = 1000000
wanted_max_updates = 1200000
# train params
gpus = 8

View File

@@ -0,0 +1,61 @@
import socket
import asyncio
import pyaudio
import numpy as np
import logging
import time
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))
start_time = time.time()
first_chunk_time = None
async def play_audio_stream():
nonlocal first_chunk_time
p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
try:
while True:
data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192)
if not data:
break
if data == b"END":
logger.info("End of audio received.")
break
audio_array = np.frombuffer(data, dtype=np.float32)
stream.write(audio_array.tobytes())
if first_chunk_time is None:
first_chunk_time = time.time()
finally:
stream.stop_stream()
stream.close()
p.terminate()
logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")
try:
data_to_send = f"{text}".encode("utf-8")
await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send)
await play_audio_stream()
except Exception as e:
logger.error(f"Error in listen_to_F5TTS: {e}")
finally:
client_socket.close()
if __name__ == "__main__":
text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"
asyncio.run(listen_to_F5TTS(text_to_send))
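
One observation on the client above: `first_chunk_time` is captured but does not appear to be reported in this snippet, even though time-to-first-audio is the natural latency metric for a streaming client. A hypothetical addition (not in the original file) after `play_audio_stream()` returns:

```python
if first_chunk_time is not None:
    logger.info(f"Time to first audio chunk: {first_chunk_time - start_time:.4f} seconds")
```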

View File

@@ -13,8 +13,9 @@ from importlib.resources import files
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from f5_tts.model.backbones.dit import DiT
from f5_tts.model.backbones.dit import DiT # noqa: F401. used for config
from f5_tts.infer.utils_infer import (
chunk_text,
preprocess_ref_audio_text,
@@ -68,7 +69,7 @@ class AudioFileWriterThread(threading.Thread):
class TTSStreamingProcessor:
def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
def __init__(self, model, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
self.device = device or (
"cuda"
if torch.cuda.is_available()
@@ -78,21 +79,24 @@ class TTSStreamingProcessor:
if torch.backends.mps.is_available()
else "cpu"
)
self.mel_spec_type = "vocos"
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
self.model_cls = globals()[model_cfg.model.backbone]
self.model_arc = model_cfg.model.arch
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate
self.model = self.load_ema_model(ckpt_file, vocab_file, dtype)
self.vocoder = self.load_vocoder_model()
self.sampling_rate = 24000
self.update_reference(ref_audio, ref_text)
self._warm_up()
self.file_writer_thread = None
self.first_package = True
def load_ema_model(self, ckpt_file, vocab_file, dtype):
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_cls = DiT
return load_model(
model_cls=model_cls,
model_cfg=model_cfg,
self.model_cls,
self.model_arc,
ckpt_path=ckpt_file,
mel_spec_type=self.mel_spec_type,
vocab_file=vocab_file,
@@ -212,9 +216,14 @@ if __name__ == "__main__":
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", default=9998)
parser.add_argument(
"--model",
default="F5TTS_v1_Base",
help="The model name, e.g. F5TTS_v1_Base",
)
parser.add_argument(
"--ckpt_file",
default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_Base/model_1200000.safetensors")),
default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_v1_Base/model_1250000.safetensors")),
help="Path to the model checkpoint file",
)
parser.add_argument(
@@ -242,6 +251,7 @@ if __name__ == "__main__":
try:
# Initialize the processor with the model and vocoder
processor = TTSStreamingProcessor(
model=args.model,
ckpt_file=args.ckpt_file,
vocab_file=args.vocab_file,
ref_audio=args.ref_audio,

View File

@@ -40,10 +40,10 @@ Once your datasets are prepared, you can start the training process.
accelerate config
# .yaml files are under src/f5_tts/configs directory
accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml
accelerate launch src/f5_tts/train/train.py --config-name F5TTS_v1_Base_train.yaml
# possible to overwrite accelerate and hydra config
accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_Small_train.yaml ++datasets.batch_size_per_gpu=19200
accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base_train.yaml ++datasets.batch_size_per_gpu=19200
```
### 2. Finetuning practice
@@ -53,7 +53,7 @@ Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#1
Setting `use_ema = True` can be harmful for early-stage finetuned checkpoints (trained for only a few updates, so the EMA weights are still dominated by the pretrained ones); try turning it off and see if it gives better results.
### 3. Wandb Logging
### 3. W&B Logging
The `wandb/` dir will be created under path you run training/finetuning scripts.
@@ -62,7 +62,7 @@ By default, the training script does NOT use logging (assuming you didn't manual
To turn on wandb logging, you can either:
1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/site/ and set the environment variable as follows:
2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/authorize and set the environment variable as follows:
On Mac & Linux:
@@ -75,7 +75,7 @@ On Windows:
```
set WANDB_API_KEY=<YOUR WANDB API KEY>
```
Moreover, if you couldn't access Wandb and want to log metrics offline, you can the environment variable as follows:
Moreover, if you couldn't access W&B and want to log metrics offline, you can set the environment variable as follows:
```
export WANDB_MODE=offline

View File

@@ -1,12 +1,13 @@
import argparse
import os
import shutil
from importlib.resources import files
from cached_path import cached_path
from f5_tts.model import CFM, UNetT, DiT, Trainer
from f5_tts.model.utils import get_tokenizer
from f5_tts.model.dataset import load_dataset
from importlib.resources import files
# -------------------------- Dataset Settings --------------------------- #
@@ -20,19 +21,14 @@ mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
# -------------------------- Argument Parsing --------------------------- #
def parse_args():
# batch_size_per_gpu = 1000 setting for gpu 8GB
# batch_size_per_gpu = 1600 setting for gpu 12GB
# batch_size_per_gpu = 2000 setting for gpu 16GB
# batch_size_per_gpu = 3200 setting for gpu 24GB
# num_warmup_updates = 300 for 5000 samples, about 10 hours
# change save_per_updates and last_per_updates to the values you need
parser = argparse.ArgumentParser(description="Train CFM Model")
parser.add_argument(
"--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
"--exp_name",
type=str,
default="F5TTS_v1_Base",
choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"],
help="Experiment name",
)
parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
@@ -88,19 +84,54 @@ def main():
checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
# Model parameters based on experiment name
if args.exp_name == "F5TTS_Base":
if args.exp_name == "F5TTS_v1_Base":
wandb_resume_id = None
model_cls = DiT
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_cfg = dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
conv_layers=4,
)
if args.finetune:
if args.pretrain is None:
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
else:
ckpt_path = args.pretrain
elif args.exp_name == "F5TTS_Base":
wandb_resume_id = None
model_cls = DiT
model_cfg = dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
text_mask_padding=False,
conv_layers=4,
pe_attn_head=1,
)
if args.finetune:
if args.pretrain is None:
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
else:
ckpt_path = args.pretrain
elif args.exp_name == "E2TTS_Base":
wandb_resume_id = None
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
model_cfg = dict(
dim=1024,
depth=24,
heads=16,
ff_mult=4,
text_mask_padding=False,
pe_attn_head=1,
)
if args.finetune:
if args.pretrain is None:
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
@@ -120,6 +151,7 @@ def main():
print("copy checkpoint for finetune")
# Use the tokenizer and tokenizer_path provided in the command line arguments
tokenizer = args.tokenizer
if tokenizer == "custom":
if not args.tokenizer_path:
@@ -156,7 +188,7 @@ def main():
save_per_updates=args.save_per_updates,
keep_last_n_checkpoints=args.keep_last_n_checkpoints,
checkpoint_path=checkpoint_path,
batch_size=args.batch_size_per_gpu,
batch_size_per_gpu=args.batch_size_per_gpu,
batch_size_type=args.batch_size_type,
max_samples=args.max_samples,
grad_accumulation_steps=args.grad_accumulation_steps,

View File

@@ -1,36 +1,36 @@
import threading
import queue
import re
import gc
import json
import numpy as np
import os
import platform
import psutil
import queue
import random
import re
import signal
import shutil
import subprocess
import sys
import tempfile
import threading
import time
from glob import glob
from importlib.resources import files
from scipy.io import wavfile
import click
import gradio as gr
import librosa
import numpy as np
import torch
import torchaudio
from cached_path import cached_path
from datasets import Dataset as Dataset_
from datasets.arrow_writer import ArrowWriter
from safetensors.torch import save_file
from scipy.io import wavfile
from cached_path import cached_path
from safetensors.torch import load_file, save_file
from f5_tts.api import F5TTS
from f5_tts.model.utils import convert_char_to_pinyin
from f5_tts.infer.utils_infer import transcribe
from importlib.resources import files
training_process = None
@@ -118,16 +118,16 @@ def load_settings(project_name):
# Default settings
default_settings = {
"exp_name": "F5TTS_Base",
"learning_rate": 1e-05,
"batch_size_per_gpu": 1000,
"batch_size_type": "frame",
"exp_name": "F5TTS_v1_Base",
"learning_rate": 1e-5,
"batch_size_per_gpu": 1,
"batch_size_type": "sample",
"max_samples": 64,
"grad_accumulation_steps": 1,
"grad_accumulation_steps": 4,
"max_grad_norm": 1,
"epochs": 100,
"num_warmup_updates": 2,
"save_per_updates": 300,
"num_warmup_updates": 100,
"save_per_updates": 500,
"keep_last_n_checkpoints": -1,
"last_per_updates": 100,
"finetune": True,
@@ -362,18 +362,18 @@ def terminate_process(pid):
def start_training(
dataset_name="",
exp_name="F5TTS_Base",
learning_rate=1e-4,
batch_size_per_gpu=400,
batch_size_type="frame",
exp_name="F5TTS_v1_Base",
learning_rate=1e-5,
batch_size_per_gpu=1,
batch_size_type="sample",
max_samples=64,
grad_accumulation_steps=1,
grad_accumulation_steps=4,
max_grad_norm=1.0,
epochs=11,
num_warmup_updates=200,
save_per_updates=400,
epochs=100,
num_warmup_updates=100,
save_per_updates=500,
keep_last_n_checkpoints=-1,
last_per_updates=800,
last_per_updates=100,
finetune=True,
file_checkpoint_train="",
tokenizer_type="pinyin",
@@ -797,14 +797,14 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
print(f"Error processing {file_audio}: {e}")
continue
if duration < 1 or duration > 25:
if duration > 25:
error_files.append([file_audio, "duration > 25 sec"])
if duration < 1 or duration > 30:
if duration > 30:
error_files.append([file_audio, "duration > 30 sec"])
if duration < 1:
error_files.append([file_audio, "duration < 1 sec "])
continue
if len(text) < 3:
error_files.append([file_audio, "very small text len 3"])
error_files.append([file_audio, "very short text length 3"])
continue
text = clear_text(text)
@@ -871,40 +871,37 @@ def check_user(value):
def calculate_train(
name_project,
epochs,
learning_rate,
batch_size_per_gpu,
batch_size_type,
max_samples,
learning_rate,
num_warmup_updates,
save_per_updates,
last_per_updates,
finetune,
):
path_project = os.path.join(path_data, name_project)
file_duraction = os.path.join(path_project, "duration.json")
file_duration = os.path.join(path_project, "duration.json")
if not os.path.isfile(file_duraction):
hop_length = 256
sampling_rate = 24000
if not os.path.isfile(file_duration):
return (
1000,
epochs,
learning_rate,
batch_size_per_gpu,
max_samples,
num_warmup_updates,
save_per_updates,
last_per_updates,
"project not found !",
learning_rate,
)
with open(file_duraction, "r") as file:
with open(file_duration, "r") as file:
data = json.load(file)
duration_list = data["duration"]
samples = len(duration_list)
hours = sum(duration_list) / 3600
# if torch.cuda.is_available():
# gpu_properties = torch.cuda.get_device_properties(0)
# total_memory = gpu_properties.total_memory / (1024**3)
# elif torch.backends.mps.is_available():
# total_memory = psutil.virtual_memory().available / (1024**3)
max_sample_length = max(duration_list) * sampling_rate / hop_length
total_samples = len(duration_list)
total_duration = sum(duration_list)
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
@@ -912,64 +909,39 @@ def calculate_train(
for i in range(gpu_count):
gpu_properties = torch.cuda.get_device_properties(i)
total_memory += gpu_properties.total_memory / (1024**3) # in GB
elif torch.xpu.is_available():
gpu_count = torch.xpu.device_count()
total_memory = 0
for i in range(gpu_count):
gpu_properties = torch.xpu.get_device_properties(i)
total_memory += gpu_properties.total_memory / (1024**3)
elif torch.backends.mps.is_available():
gpu_count = 1
total_memory = psutil.virtual_memory().available / (1024**3)
avg_gpu_memory = total_memory / gpu_count
# rough estimate of batch size
if batch_size_type == "frame":
batch = int(total_memory * 0.5)
batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
batch_size_per_gpu = int(38400 / batch)
else:
batch_size_per_gpu = int(total_memory / 8)
batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
batch = batch_size_per_gpu
batch_size_per_gpu = max(int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length))
elif batch_size_type == "sample":
batch_size_per_gpu = int(200 / (total_duration / total_samples))
if batch_size_per_gpu <= 0:
batch_size_per_gpu = 1
if total_samples < 64:
max_samples = int(total_samples * 0.25)
if samples < 64:
max_samples = int(samples * 0.25)
else:
max_samples = 64
num_warmup_updates = max(num_warmup_updates, int(total_samples * 0.05))
num_warmup_updates = int(samples * 0.05)
save_per_updates = int(samples * 0.10)
last_per_updates = int(save_per_updates * 0.25)
# take 1.2M updates as the maximum
max_updates = 1200000
max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
last_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_updates)
if last_per_updates <= 0:
last_per_updates = 2
if batch_size_type == "frame":
mini_batch_duration = batch_size_per_gpu * gpu_count * hop_length / sampling_rate
updates_per_epoch = total_duration / mini_batch_duration
elif batch_size_type == "sample":
updates_per_epoch = total_samples / batch_size_per_gpu / gpu_count
total_hours = hours
mel_hop_length = 256
mel_sampling_rate = 24000
# target
wanted_max_updates = 1000000
# train params
gpus = gpu_count
frames_per_gpu = batch_size_per_gpu # 8 * 38400 = 307200
grad_accum = 1
# intermediate
mini_batch_frames = frames_per_gpu * grad_accum * gpus
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
updates_per_epoch = total_hours / mini_batch_hours
# steps_per_epoch = updates_per_epoch * grad_accum
epochs = wanted_max_updates / updates_per_epoch
epochs = int(max_updates / updates_per_epoch)
if finetune:
learning_rate = 1e-5
@@ -977,14 +949,12 @@ def calculate_train(
learning_rate = 7.5e-5
return (
epochs,
learning_rate,
batch_size_per_gpu,
max_samples,
num_warmup_updates,
save_per_updates,
last_per_updates,
samples,
learning_rate,
int(epochs),
total_samples,
)
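
The frame-mode heuristic above appears to scale a reference recipe of 38,400 frames per GPU linearly with available memory minus roughly 5 GB of fixed overhead, flooring at the longest single sample so every utterance still fits in one batch. A hedged restatement of that reading:

```python
def estimate_frames_per_gpu(avg_gpu_memory_gb: float, max_sample_length: float) -> int:
    # assumption: ~38400 frames fit alongside ~5 GB overhead on an 80 GB GPU
    return max(int(38400 * (avg_gpu_memory_gb - 5) / 75), int(max_sample_length))
```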
@@ -1021,7 +991,11 @@ def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
ckpt = torch.load(ckpt_path, map_location="cpu")
if ckpt_path.endswith(".safetensors"):
ckpt = load_file(ckpt_path, device="cpu")
ckpt = {"ema_model_state_dict": ckpt}
elif ckpt_path.endswith(".pt"):
ckpt = torch.load(ckpt_path, map_location="cpu")
ema_sd = ckpt.get("ema_model_state_dict", {})
embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
@@ -1089,9 +1063,11 @@ def vocab_extend(project_name, symbols, model_type):
with open(file_vocab_project, "w", encoding="utf-8") as f:
f.write("\n".join(vocab))
if model_type == "F5-TTS":
if model_type == "F5TTS_v1_Base":
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
elif model_type == "F5TTS_Base":
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
else:
elif model_type == "E2TTS_Base":
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
vocab_size_new = len(miss_symbols)
@@ -1101,7 +1077,7 @@ def vocab_extend(project_name, symbols, model_type):
os.makedirs(new_ckpt_path, exist_ok=True)
# Add pretrained_ prefix to model when copying for consistency with finetune_cli.py
new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_model_1200000.pt")
new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_" + os.path.basename(ckpt_path))
size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
@@ -1231,21 +1207,21 @@ def infer(
vocab_file = os.path.join(path_data, project, "vocab.txt")
tts_api = F5TTS(
model_type=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema
model=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema
)
print("update >> ", device_test, file_checkpoint, use_ema)
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
tts_api.infer(
gen_text=gen_text.lower().strip(),
ref_text=ref_text.lower().strip(),
ref_file=ref_audio,
ref_text=ref_text.lower().strip(),
gen_text=gen_text.lower().strip(),
nfe_step=nfe_step,
file_wave=f.name,
speed=speed,
seed=seed,
remove_silence=remove_silence,
file_wave=f.name,
seed=seed,
)
return f.name, tts_api.device, str(tts_api.seed)
@@ -1404,14 +1380,14 @@ def get_audio_select(file_sample):
with gr.Blocks() as app:
gr.Markdown(
"""
# E2/F5 TTS Automatic Finetune
# F5 TTS Automatic Finetune
This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
This is a local web UI for F5 TTS finetuning. This app supports the following TTS models:
* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
The checkpoints support English and Chinese.
The pretrained checkpoints support English and Chinese.
For tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143)
"""
@@ -1488,7 +1464,9 @@ Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are incl
Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
```""")
exp_name_extend = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
exp_name_extend = gr.Radio(
label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
)
with gr.Row():
txt_extend = gr.Textbox(
@@ -1557,9 +1535,9 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
)
with gr.TabItem("Train Data"):
with gr.TabItem("Train Model"):
gr.Markdown("""```plaintext
The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per updates are set correctly, or change them manually as needed.
The auto-setting is still experimental. Set a large epoch value if unsure, and keep only the last N checkpoints if disk space is limited.
If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
```""")
with gr.Row():
@@ -1573,11 +1551,13 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="")
with gr.Row():
exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
exp_name = gr.Radio(
label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
)
learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
with gr.Row():
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=3200)
max_samples = gr.Number(label="Max Samples", value=64)
with gr.Row():
@@ -1585,23 +1565,23 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
with gr.Row():
epochs = gr.Number(label="Epochs", value=10)
num_warmup_updates = gr.Number(label="Warmup Updates", value=2)
epochs = gr.Number(label="Epochs", value=100)
num_warmup_updates = gr.Number(label="Warmup Updates", value=100)
with gr.Row():
save_per_updates = gr.Number(label="Save per Updates", value=300)
save_per_updates = gr.Number(label="Save per Updates", value=500)
keep_last_n_checkpoints = gr.Number(
label="Keep Last N Checkpoints",
value=-1,
step=1,
precision=0,
info="-1: Keep all checkpoints, 0: Only save final model_last.pt, N>0: Keep last N checkpoints",
info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
)
last_per_updates = gr.Number(label="Last per Updates", value=100)
with gr.Row():
ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none")
mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="fp16")
cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
start_button = gr.Button("Start Training")
stop_button = gr.Button("Stop Training", interactive=False)
@@ -1718,23 +1698,21 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
fn=calculate_train,
inputs=[
cm_project,
epochs,
learning_rate,
batch_size_per_gpu,
batch_size_type,
max_samples,
learning_rate,
num_warmup_updates,
save_per_updates,
last_per_updates,
ch_finetune,
],
outputs=[
epochs,
learning_rate,
batch_size_per_gpu,
max_samples,
num_warmup_updates,
save_per_updates,
last_per_updates,
lb_samples,
learning_rate,
epochs,
],
)
@@ -1744,25 +1722,25 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
def setup_load_settings():
output_components = [
exp_name, # 1
learning_rate, # 2
batch_size_per_gpu, # 3
batch_size_type, # 4
max_samples, # 5
grad_accumulation_steps, # 6
max_grad_norm, # 7
epochs, # 8
num_warmup_updates, # 9
save_per_updates, # 10
keep_last_n_checkpoints, # 11
last_per_updates, # 12
ch_finetune, # 13
file_checkpoint_train, # 14
tokenizer_type, # 15
tokenizer_file, # 16
mixed_precision, # 17
cd_logger, # 18
ch_8bit_adam, # 19
exp_name,
learning_rate,
batch_size_per_gpu,
batch_size_type,
max_samples,
grad_accumulation_steps,
max_grad_norm,
epochs,
num_warmup_updates,
save_per_updates,
keep_last_n_checkpoints,
last_per_updates,
ch_finetune,
file_checkpoint_train,
tokenizer_type,
tokenizer_file,
mixed_precision,
cd_logger,
ch_8bit_adam,
]
return output_components
@@ -1784,7 +1762,9 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
gr.Markdown("""```plaintext
SOS: Check the use_ema setting (True or False) for your model to see what works best for you. Use seed -1 for a random seed.
```""")
exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
exp_name = gr.Radio(
label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
)
list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
with gr.Row():
@@ -1838,9 +1818,9 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
with gr.TabItem("Reduce Checkpoint"):
with gr.TabItem("Prune Checkpoint"):
gr.Markdown("""```plaintext
Reduce the model size from 5GB to 1.3GB. The new checkpoint can be used for inference or fine-tuning afterward, but it cannot be used to continue training.
Reduce the Base model size from 5GB to 1.3GB. The new checkpoint file prunes out the optimizer state and other training-only tensors; it can be used for inference or finetuning afterward, but cannot resume pretraining.
```""")
txt_path_checkpoint = gr.Text(label="Path to Checkpoint:")
txt_path_checkpoint_small = gr.Text(label="Path to Output:")

View File

@@ -4,8 +4,9 @@ import os
from importlib.resources import files
import hydra
from omegaconf import OmegaConf
from f5_tts.model import CFM, DiT, Trainer, UNetT
from f5_tts.model import CFM, DiT, UNetT, Trainer # noqa: F401. used for config
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
@@ -14,9 +15,13 @@ os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to
@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
def main(cfg):
model_cls = globals()[cfg.model.backbone]
model_arc = cfg.model.arch
tokenizer = cfg.model.tokenizer
mel_spec_type = cfg.model.mel_spec.mel_spec_type
exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
wandb_resume_id = None
# set text tokenizer
if tokenizer != "custom":
@@ -26,14 +31,8 @@ def main(cfg):
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
# set model
if "F5TTS" in cfg.model.name:
model_cls = DiT
elif "E2TTS" in cfg.model.name:
model_cls = UNetT
wandb_resume_id = None
model = CFM(
transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
mel_spec_kwargs=cfg.model.mel_spec,
vocab_char_map=vocab_char_map,
)
@@ -45,9 +44,9 @@ def main(cfg):
learning_rate=cfg.optim.learning_rate,
num_warmup_updates=cfg.optim.num_warmup_updates,
save_per_updates=cfg.ckpts.save_per_updates,
keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", -1),
keep_last_n_checkpoints=cfg.ckpts.keep_last_n_checkpoints,
checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
batch_size=cfg.datasets.batch_size_per_gpu,
batch_size_per_gpu=cfg.datasets.batch_size_per_gpu,
batch_size_type=cfg.datasets.batch_size_type,
max_samples=cfg.datasets.max_samples,
grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
@@ -57,11 +56,12 @@ def main(cfg):
wandb_run_name=exp_name,
wandb_resume_id=wandb_resume_id,
last_per_updates=cfg.ckpts.last_per_updates,
log_samples=True,
log_samples=cfg.ckpts.log_samples,
bnb_optimizer=cfg.optim.bnb_optimizer,
mel_spec_type=mel_spec_type,
is_local_vocoder=cfg.model.vocoder.is_local,
local_vocoder_path=cfg.model.vocoder.local_path,
cfg_dict=OmegaConf.to_container(cfg, resolve=True),
)
train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)