minor fix

SWivid
2024-11-01 15:11:48 +08:00
parent 305e3eab35
commit 315230210d
8 changed files with 36 additions and 86 deletions

View File

@@ -43,14 +43,10 @@ pip install git+https://github.com/SWivid/F5-TTS.git
 ```bash
 git clone https://github.com/SWivid/F5-TTS.git
 cd F5-TTS
+# git submodule update --init --recursive  # (optional, if you need bigvgan)
 pip install -e .
-# Init submodule (optional, if you want to change the vocoder from vocos to bigvgan)
-# git submodule update --init --recursive
-# pip install -e .
 ```
-After init submodule, you need to change the `src/third_party/BigVGAN/bigvgan.py` by adding the following code at the beginning of the file.
+If you initialize the submodule, you should add the following code at the beginning of `src/third_party/BigVGAN/bigvgan.py`.
 ```python
 import os
 import sys
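
For context, the snippet the README points to continues past this hunk. A minimal sketch of the path shim it is describing (treat the exact lines as an assumption; the authoritative version is in the README itself):

```python
# Assumed shim at the top of src/third_party/BigVGAN/bigvgan.py: put the
# submodule's own directory on sys.path so BigVGAN's top-level absolute
# imports resolve when the repo is vendored as a git submodule.
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
```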

View File

@@ -120,6 +120,7 @@ def main():
         target_sample_rate=target_sample_rate,
         n_mel_channels=n_mel_channels,
         hop_length=hop_length,
+        mel_spec_type=mel_spec_type,
         target_rms=target_rms,
         use_truth_duration=use_truth_duration,
         infer_batch_size=infer_batch_size,
@@ -153,12 +154,7 @@ def main():
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
-    if supports_fp16 and mel_spec_type == "vocos":
-        dtype = torch.float16
-    elif mel_spec_type == "bigvgan":
-        dtype = torch.float32
+    dtype = torch.float32 if mel_spec_type == "bigvgan" else None
     model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
     if not os.path.exists(output_dir) and accelerator.is_main_process:
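
The same five-line block is removed in two more files below. It folds into one line because `load_checkpoint` (see the utils_infer hunk further down) now resolves `dtype=None` itself; this also covers the case the old branching left `dtype` unbound (vocos on a device without fp16 support). A sketch of the combined policy, where `resolve_dtype` is a hypothetical name used only for illustration:

```python
import torch

def resolve_dtype(mel_spec_type: str, device: str) -> torch.dtype:
    # Caller side after this commit: pin fp32 only for bigvgan, else defer.
    dtype = torch.float32 if mel_spec_type == "bigvgan" else None
    # load_checkpoint side: None means fp16 on capable CUDA devices, else fp32.
    if dtype is None:
        supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
        dtype = torch.float16 if supports_fp16 else torch.float32
    return dtype
```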

View File

@@ -78,7 +78,7 @@ def get_inference_prompt(
     win_length=1024,
     n_mel_channels=100,
     hop_length=256,
-    mel_spec_type="bigvgan",
+    mel_spec_type="vocos",
     target_rms=0.1,
     use_truth_duration=False,
     infer_batch_size=1,

View File

@@ -58,8 +58,8 @@ f5-tts_infer-cli \
 --gen_text "Some text you want TTS model generate for you."
 
 # Choose Vocoder
-f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/model_1250000.pt >
-f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors >
+f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
+f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
 ```
 
 And a `.toml` file would help with more flexible usage.

View File

@@ -111,12 +111,7 @@ model = CFM(
     vocab_char_map=vocab_char_map,
 ).to(device)
 
-supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
-if supports_fp16 and mel_spec_type == "vocos":
-    dtype = torch.float16
-elif mel_spec_type == "bigvgan":
-    dtype = torch.float32
+dtype = torch.float32 if mel_spec_type == "bigvgan" else None
 model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
 # Audio

View File

@@ -40,6 +40,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
+mel_spec_type = "vocos"
 target_rms = 0.1
 cross_fade_duration = 0.15
 ode_method = "euler"
@@ -131,7 +132,7 @@ def initialize_asr_pipeline(device=device):
 # load model checkpoint for inference
-def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True):
+def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
     if dtype is None:
         dtype = (
             torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
@@ -175,7 +176,7 @@ def load_model(
     model_cls,
     model_cfg,
     ckpt_path,
-    mel_spec_type="vocos",
+    mel_spec_type=mel_spec_type,
     vocab_file="",
     ode_method=ode_method,
     use_ema=True,
@@ -206,12 +207,7 @@ def load_model(
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
-    if supports_fp16 and mel_spec_type == "vocos":
-        dtype = torch.float16
-    elif mel_spec_type == "bigvgan":
-        dtype = torch.float32
+    dtype = torch.float32 if mel_spec_type == "bigvgan" else None
     model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
 
     return model
@@ -307,7 +303,7 @@ def infer_process(
     gen_text,
     model_obj,
     vocoder,
-    mel_spec_type="vocos",
+    mel_spec_type=mel_spec_type,
     show_info=print,
     progress=tqdm,
     target_rms=target_rms,

View File

@@ -19,57 +19,44 @@ from librosa.filters import mel as librosa_mel_fn
 from torch import nn
 from x_transformers.x_transformers import apply_rotary_pos_emb
 
 # raw wav to mel spec
 
-def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
-    return torch.log(torch.clamp(x, min=clip_val) * C)
-
-def dynamic_range_decompression_torch(x, C=1):
-    return torch.exp(x) / C
-
-def spectral_normalize_torch(magnitudes):
-    return dynamic_range_compression_torch(magnitudes)
-
 mel_basis_cache = {}
 hann_window_cache = {}
 
-# BigVGAN extract mel spectrogram
-def mel_spectrogram(
-    y: torch.Tensor,
-    n_fft: int,
-    num_mels: int,
-    sampling_rate: int,
-    hop_size: int,
-    win_size: int,
-    fmin: int,
-    fmax: int = None,
-    center: bool = False,
-) -> torch.Tensor:
-    """Copy from https://github.com/NVIDIA/BigVGAN/tree/main"""
-    device = y.device
-    key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
+def get_bigvgan_mel_spectrogram(
+    waveform,
+    n_fft=1024,
+    n_mel_channels=100,
+    target_sample_rate=24000,
+    hop_length=256,
+    win_length=1024,
+    fmin=0,
+    fmax=None,
+    center=False,
+):  # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
+    device = waveform.device
+    key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
 
     if key not in mel_basis_cache:
-        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+        mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
         mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)  # TODO: why they need .float()?
-        hann_window_cache[key] = torch.hann_window(win_size).to(device)
+        hann_window_cache[key] = torch.hann_window(win_length).to(device)
 
     mel_basis = mel_basis_cache[key]
     hann_window = hann_window_cache[key]
 
-    padding = (n_fft - hop_size) // 2
-    y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
+    padding = (n_fft - hop_length) // 2
+    waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
 
     spec = torch.stft(
-        y,
+        waveform,
         n_fft,
-        hop_length=hop_size,
-        win_length=win_size,
+        hop_length=hop_length,
+        win_length=win_length,
         window=hann_window,
         center=center,
         pad_mode="reflect",
@@ -80,31 +67,11 @@ def mel_spectrogram(
     spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
 
     mel_spec = torch.matmul(mel_basis, spec)
-    mel_spec = spectral_normalize_torch(mel_spec)
+    mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
 
     return mel_spec
 
-def get_bigvgan_mel_spectrogram(
-    waveform,
-    n_fft=1024,
-    n_mel_channels=100,
-    target_sample_rate=24000,
-    hop_length=256,
-    win_length=1024,
-):
-    return mel_spectrogram(
-        waveform,
-        n_fft,  # 1024
-        n_mel_channels,  # 100
-        target_sample_rate,  # 24000
-        hop_length,  # 256
-        win_length,  # 1024
-        fmin=0,  # 0
-        fmax=None,  # null
-    )
-
 def get_vocos_mel_spectrogram(
     waveform,
     n_fft=1024,
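
With the wrapper gone, both extractors expose the same keyword-style signature, so callers can dispatch on `mel_spec_type` directly. A hedged usage sketch (the two function names and their defaults come from this diff; the import path and the dispatch helper are assumptions):

```python
import torch

# Assumed import path, following the repo layout src/f5_tts/model/modules.py.
from f5_tts.model.modules import get_bigvgan_mel_spectrogram, get_vocos_mel_spectrogram

# Illustrative dispatch, mirroring how mel_spec_type is threaded through the
# inference utilities in this commit.
def extract_mel(waveform: torch.Tensor, mel_spec_type: str = "vocos") -> torch.Tensor:
    mel_fn = get_bigvgan_mel_spectrogram if mel_spec_type == "bigvgan" else get_vocos_mel_spectrogram
    return mel_fn(waveform, n_fft=1024, n_mel_channels=100, target_sample_rate=24000, hop_length=256)

# e.g. one second of audio at 24 kHz -> a (1, 100, n_frames) mel tensor
mel = extract_mel(torch.randn(1, 24000), mel_spec_type="bigvgan")
```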

View File

@@ -13,7 +13,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'
+mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'
 tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
 tokenizer_path = None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)