mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 15:50:07 -08:00
minor fix
This commit is contained in:
@@ -43,14 +43,10 @@ pip install git+https://github.com/SWivid/F5-TTS.git
|
|||||||
```bash
|
```bash
|
||||||
git clone https://github.com/SWivid/F5-TTS.git
|
git clone https://github.com/SWivid/F5-TTS.git
|
||||||
cd F5-TTS
|
cd F5-TTS
|
||||||
|
# git submodule update --init --recursive  # (optional, only if you need bigvgan)
|
||||||
pip install -e .
|
pip install -e .
|
||||||
|
|
||||||
# Init submodule (optional, if you want to change the vocoder from vocos to bigvgan)
|
|
||||||
# git submodule update --init --recursive
|
|
||||||
# pip install -e .
|
|
||||||
```
|
```
|
||||||
|
If you initialize the submodule, add the following code at the beginning of `src/third_party/BigVGAN/bigvgan.py`.
|
||||||
After initializing the submodule, you need to modify `src/third_party/BigVGAN/bigvgan.py` by adding the following code at the beginning of the file.
|
|
||||||
```python
|
```python
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|||||||
@@ -120,6 +120,7 @@ def main():
|
|||||||
target_sample_rate=target_sample_rate,
|
target_sample_rate=target_sample_rate,
|
||||||
n_mel_channels=n_mel_channels,
|
n_mel_channels=n_mel_channels,
|
||||||
hop_length=hop_length,
|
hop_length=hop_length,
|
||||||
|
mel_spec_type=mel_spec_type,
|
||||||
target_rms=target_rms,
|
target_rms=target_rms,
|
||||||
use_truth_duration=use_truth_duration,
|
use_truth_duration=use_truth_duration,
|
||||||
infer_batch_size=infer_batch_size,
|
infer_batch_size=infer_batch_size,
|
||||||
@@ -153,12 +154,7 @@ def main():
|
|||||||
vocab_char_map=vocab_char_map,
|
vocab_char_map=vocab_char_map,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
|
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
|
||||||
if supports_fp16 and mel_spec_type == "vocos":
|
|
||||||
dtype = torch.float16
|
|
||||||
elif mel_spec_type == "bigvgan":
|
|
||||||
dtype = torch.float32
|
|
||||||
|
|
||||||
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
|
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
|
||||||
|
|
||||||
if not os.path.exists(output_dir) and accelerator.is_main_process:
|
if not os.path.exists(output_dir) and accelerator.is_main_process:
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ def get_inference_prompt(
|
|||||||
win_length=1024,
|
win_length=1024,
|
||||||
n_mel_channels=100,
|
n_mel_channels=100,
|
||||||
hop_length=256,
|
hop_length=256,
|
||||||
mel_spec_type="bigvgan",
|
mel_spec_type="vocos",
|
||||||
target_rms=0.1,
|
target_rms=0.1,
|
||||||
use_truth_duration=False,
|
use_truth_duration=False,
|
||||||
infer_batch_size=1,
|
infer_batch_size=1,
|
||||||
|
|||||||
@@ -58,8 +58,8 @@ f5-tts_infer-cli \
|
|||||||
--gen_text "Some text you want TTS model generate for you."
|
--gen_text "Some text you want TTS model generate for you."
|
||||||
|
|
||||||
# Choose Vocoder
|
# Choose Vocoder
|
||||||
f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/model_1250000.pt >
|
f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
|
||||||
f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors >
|
f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
|
||||||
```
|
```
|
||||||
|
|
||||||
And a `.toml` file would help with more flexible usage.
|
And a `.toml` file would help with more flexible usage.
|
||||||
|
|||||||
@@ -111,12 +111,7 @@ model = CFM(
|
|||||||
vocab_char_map=vocab_char_map,
|
vocab_char_map=vocab_char_map,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
|
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
|
||||||
if supports_fp16 and mel_spec_type == "vocos":
|
|
||||||
dtype = torch.float16
|
|
||||||
elif mel_spec_type == "bigvgan":
|
|
||||||
dtype = torch.float32
|
|
||||||
|
|
||||||
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
|
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
|
||||||
|
|
||||||
# Audio
|
# Audio
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ n_mel_channels = 100
|
|||||||
hop_length = 256
|
hop_length = 256
|
||||||
win_length = 1024
|
win_length = 1024
|
||||||
n_fft = 1024
|
n_fft = 1024
|
||||||
|
mel_spec_type = "vocos"
|
||||||
target_rms = 0.1
|
target_rms = 0.1
|
||||||
cross_fade_duration = 0.15
|
cross_fade_duration = 0.15
|
||||||
ode_method = "euler"
|
ode_method = "euler"
|
||||||
@@ -131,7 +132,7 @@ def initialize_asr_pipeline(device=device):
|
|||||||
# load model checkpoint for inference
|
# load model checkpoint for inference
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True):
|
def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
dtype = (
|
dtype = (
|
||||||
torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
|
torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
|
||||||
@@ -175,7 +176,7 @@ def load_model(
|
|||||||
model_cls,
|
model_cls,
|
||||||
model_cfg,
|
model_cfg,
|
||||||
ckpt_path,
|
ckpt_path,
|
||||||
mel_spec_type="vocos",
|
mel_spec_type=mel_spec_type,
|
||||||
vocab_file="",
|
vocab_file="",
|
||||||
ode_method=ode_method,
|
ode_method=ode_method,
|
||||||
use_ema=True,
|
use_ema=True,
|
||||||
@@ -206,12 +207,7 @@ def load_model(
|
|||||||
vocab_char_map=vocab_char_map,
|
vocab_char_map=vocab_char_map,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
|
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
|
||||||
if supports_fp16 and mel_spec_type == "vocos":
|
|
||||||
dtype = torch.float16
|
|
||||||
elif mel_spec_type == "bigvgan":
|
|
||||||
dtype = torch.float32
|
|
||||||
|
|
||||||
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
|
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
@@ -307,7 +303,7 @@ def infer_process(
|
|||||||
gen_text,
|
gen_text,
|
||||||
model_obj,
|
model_obj,
|
||||||
vocoder,
|
vocoder,
|
||||||
mel_spec_type="vocos",
|
mel_spec_type=mel_spec_type,
|
||||||
show_info=print,
|
show_info=print,
|
||||||
progress=tqdm,
|
progress=tqdm,
|
||||||
target_rms=target_rms,
|
target_rms=target_rms,
|
||||||
|
|||||||
@@ -19,57 +19,44 @@ from librosa.filters import mel as librosa_mel_fn
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from x_transformers.x_transformers import apply_rotary_pos_emb
|
from x_transformers.x_transformers import apply_rotary_pos_emb
|
||||||
|
|
||||||
|
|
||||||
# raw wav to mel spec
|
# raw wav to mel spec
|
||||||
|
|
||||||
|
|
||||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
|
||||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
|
||||||
|
|
||||||
|
|
||||||
def dynamic_range_decompression_torch(x, C=1):
|
|
||||||
return torch.exp(x) / C
|
|
||||||
|
|
||||||
|
|
||||||
def spectral_normalize_torch(magnitudes):
|
|
||||||
return dynamic_range_compression_torch(magnitudes)
|
|
||||||
|
|
||||||
|
|
||||||
mel_basis_cache = {}
|
mel_basis_cache = {}
|
||||||
hann_window_cache = {}
|
hann_window_cache = {}
|
||||||
|
|
||||||
|
|
||||||
# BigVGAN extract mel spectrogram
|
def get_bigvgan_mel_spectrogram(
|
||||||
def mel_spectrogram(
|
waveform,
|
||||||
y: torch.Tensor,
|
n_fft=1024,
|
||||||
n_fft: int,
|
n_mel_channels=100,
|
||||||
num_mels: int,
|
target_sample_rate=24000,
|
||||||
sampling_rate: int,
|
hop_length=256,
|
||||||
hop_size: int,
|
win_length=1024,
|
||||||
win_size: int,
|
fmin=0,
|
||||||
fmin: int,
|
fmax=None,
|
||||||
fmax: int = None,
|
center=False,
|
||||||
center: bool = False,
|
): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
|
||||||
) -> torch.Tensor:
|
device = waveform.device
|
||||||
"""Copy from https://github.com/NVIDIA/BigVGAN/tree/main"""
|
key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
|
||||||
device = y.device
|
|
||||||
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
|
|
||||||
|
|
||||||
if key not in mel_basis_cache:
|
if key not in mel_basis_cache:
|
||||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
|
||||||
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()?
|
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()?
|
||||||
hann_window_cache[key] = torch.hann_window(win_size).to(device)
|
hann_window_cache[key] = torch.hann_window(win_length).to(device)
|
||||||
|
|
||||||
mel_basis = mel_basis_cache[key]
|
mel_basis = mel_basis_cache[key]
|
||||||
hann_window = hann_window_cache[key]
|
hann_window = hann_window_cache[key]
|
||||||
|
|
||||||
padding = (n_fft - hop_size) // 2
|
padding = (n_fft - hop_length) // 2
|
||||||
y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
|
waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
|
||||||
|
|
||||||
spec = torch.stft(
|
spec = torch.stft(
|
||||||
y,
|
waveform,
|
||||||
n_fft,
|
n_fft,
|
||||||
hop_length=hop_size,
|
hop_length=hop_length,
|
||||||
win_length=win_size,
|
win_length=win_length,
|
||||||
window=hann_window,
|
window=hann_window,
|
||||||
center=center,
|
center=center,
|
||||||
pad_mode="reflect",
|
pad_mode="reflect",
|
||||||
@@ -80,31 +67,11 @@ def mel_spectrogram(
|
|||||||
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
|
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
|
||||||
|
|
||||||
mel_spec = torch.matmul(mel_basis, spec)
|
mel_spec = torch.matmul(mel_basis, spec)
|
||||||
mel_spec = spectral_normalize_torch(mel_spec)
|
mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
|
||||||
|
|
||||||
return mel_spec
|
return mel_spec
|
||||||
|
|
||||||
|
|
||||||
def get_bigvgan_mel_spectrogram(
|
|
||||||
waveform,
|
|
||||||
n_fft=1024,
|
|
||||||
n_mel_channels=100,
|
|
||||||
target_sample_rate=24000,
|
|
||||||
hop_length=256,
|
|
||||||
win_length=1024,
|
|
||||||
):
|
|
||||||
return mel_spectrogram(
|
|
||||||
waveform,
|
|
||||||
n_fft, # 1024
|
|
||||||
n_mel_channels, # 100
|
|
||||||
target_sample_rate, # 24000
|
|
||||||
hop_length, # 256
|
|
||||||
win_length, # 1024
|
|
||||||
fmin=0, # 0
|
|
||||||
fmax=None, # null
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_vocos_mel_spectrogram(
|
def get_vocos_mel_spectrogram(
|
||||||
waveform,
|
waveform,
|
||||||
n_fft=1024,
|
n_fft=1024,
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ n_mel_channels = 100
|
|||||||
hop_length = 256
|
hop_length = 256
|
||||||
win_length = 1024
|
win_length = 1024
|
||||||
n_fft = 1024
|
n_fft = 1024
|
||||||
mel_spec_type = "bigvgan" # 'vocos' or 'bigvgan'
|
mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
|
||||||
|
|
||||||
tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
|
tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
|
||||||
tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
|
tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
|
||||||
|
|||||||
Reference in New Issue
Block a user