refactor: more details about bigvgan, clear function definition
@@ -46,11 +46,13 @@ cd F5-TTS
 pip install -e .
 
 # Init submodule(optional, if you want to change the vocoder from vocos to bigvgan)
-git submodule update --init --recursive
+# git submodule update --init --recursive
+# pip install -e .
 ```
 
+After initializing the submodule, you need to modify `src/third_party/BigVGAN/bigvgan.py` by adding the following code at the beginning of the file.
+
+```python
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+```
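For context, a minimal sketch of what that path patch enables, assuming the submodule is initialized and the `sys.path` lines above have been added to `bigvgan.py` (the local checkpoint directory is an assumption mirroring the paths used later in this diff):

```python
import torch

# Works only after `git submodule update --init --recursive` and the
# sys.path patch above; otherwise this import raises ImportError.
from third_party.BigVGAN import bigvgan

# Local download of https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x
vocoder = bigvgan.BigVGAN.from_pretrained(
    "../checkpoints/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
)
vocoder = vocoder.eval().to("cuda" if torch.cuda.is_available() else "cpu")
```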
@@ -104,10 +106,6 @@ f5-tts_infer-cli -c custom.toml
 
 # Multi voice. See src/f5_tts/infer/README.md
 f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
-
-# Choose Vocoder
-f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/model_1250000.pt >
-f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors >
 ```
 
 ### 3. More instructions
@@ -38,7 +38,7 @@ class F5TTS:
         self.target_sample_rate = target_sample_rate
         self.hop_length = hop_length
         self.seed = -1
-        self.extract_backend = vocoder_name
+        self.mel_spec_type = vocoder_name
 
         # Set device
         self.device = device or (
@@ -52,10 +52,13 @@ class F5TTS:
     def load_vocoder_model(self, vocoder_name, local_path):
         self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
 
-    def load_ema_model(self, model_type, ckpt_file, extract_backend, vocab_file, ode_method, use_ema):
+    def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema):
         if model_type == "F5-TTS":
             if not ckpt_file:
-                ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
+                if mel_spec_type == "vocos":
+                    ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
+                elif mel_spec_type == "bigvgan":
+                    ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"))
             model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
             model_cls = DiT
         elif model_type == "E2-TTS":
@@ -67,7 +70,7 @@ class F5TTS:
             raise ValueError(f"Unknown model type: {model_type}")
 
         self.ema_model = load_model(
-            model_cls, model_cfg, ckpt_file, extract_backend, vocab_file, ode_method, use_ema, self.device
+            model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
         )
 
     def export_wav(self, wav, file_wave, remove_silence=False):
@@ -111,7 +114,7 @@ class F5TTS:
             gen_text,
             self.ema_model,
            self.vocoder,
-            self.extract_backend,
+            self.mel_spec_type,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
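The rename also changes which default checkpoint `load_ema_model` resolves: with `vocos` it keeps the `.safetensors` release, with `bigvgan` it pulls the matching `model_1250000.pt`. A hypothetical sketch of exercising that branch (constructor kwargs are assumptions inferred from this hunk, not the documented API):

```python
from f5_tts.api import F5TTS

# vocoder_name feeds self.mel_spec_type, which load_ema_model uses to pick
# the checkpoint matched to the vocoder (assumed kwargs, per this hunk).
f5tts = F5TTS(model_type="F5-TTS", vocoder_name="bigvgan")
```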
@@ -32,7 +32,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-extract_backend = "bigvgan" # 'vocos' or 'bigvgan'
+mel_spec_type = "bigvgan" # 'vocos' or 'bigvgan'
 target_rms = 0.1
 
@@ -126,11 +126,11 @@ def main():
 
     # Vocoder model
     local = False
-    if extract_backend == "vocos":
+    if mel_spec_type == "vocos":
         vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
-    elif extract_backend == "bigvgan":
+    elif mel_spec_type == "bigvgan":
         vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
-    vocoder = load_vocoder(vocoder_name=extract_backend, is_local=local, local_path=vocoder_local_path)
+    vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
 
     # Tokenizer
     vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
@@ -144,7 +144,7 @@ def main():
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
-            extract_backend=extract_backend,
+            mel_spec_type=mel_spec_type,
        ),
        odeint_kwargs=dict(
            method=ode_method,
@@ -152,7 +152,12 @@ def main():
        vocab_char_map=vocab_char_map,
    ).to(device)
 
-    dtype = torch.float16 if extract_backend == "vocos" else torch.float32
+    supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
+    if supports_fp16 and mel_spec_type == "vocos":
+        dtype = torch.float16
+    else:
+        dtype = torch.float32
+
     model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
 
     if not os.path.exists(output_dir) and accelerator.is_main_process:
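The replaced one-liner assumed fp16 was always safe; the new branch gates it on CUDA compute capability. Since the same logic recurs in two later hunks, here is a standalone sketch of the check as a helper (hypothetical function name, not part of the commit):

```python
import torch

def pick_inference_dtype(device: str, mel_spec_type: str) -> torch.dtype:
    # fp16 only on CUDA devices with compute capability >= 6, and only for the
    # vocos path, mirroring the check added in this commit; otherwise fp32.
    supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
    return torch.float16 if supports_fp16 and mel_spec_type == "vocos" else torch.float32
```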
@@ -186,9 +191,9 @@ def main():
     for i, gen in enumerate(generated):
         gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
         gen_mel_spec = gen.permute(0, 2, 1)
-        if extract_backend == "vocos":
+        if mel_spec_type == "vocos":
             generated_wave = vocoder.decode(gen_mel_spec)
-        elif extract_backend == "bigvgan":
+        elif mel_spec_type == "bigvgan":
             generated_wave = vocoder(gen_mel_spec)
 
         if ref_rms_list[i] < target_rms:
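Note the API asymmetry this branch papers over: Vocos exposes an explicit `decode()`, while BigVGAN is an `nn.Module` applied directly to the mel. A hedged helper sketch (hypothetical function, not in the commit) of the pattern repeated across the inference scripts:

```python
import torch

def mel_to_wave(vocoder, mel_spec_type: str, mel: torch.Tensor) -> torch.Tensor:
    # mel is (batch, n_mel_channels, frames), i.e. already permuted as above.
    if mel_spec_type == "vocos":
        return vocoder.decode(mel)  # Vocos has a dedicated decode() method
    elif mel_spec_type == "bigvgan":
        return vocoder(mel)  # BigVGAN runs through forward()
    raise ValueError(f"Unsupported mel_spec_type: {mel_spec_type}")
```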
@@ -78,7 +78,7 @@ def get_inference_prompt(
     win_length=1024,
     n_mel_channels=100,
     hop_length=256,
-    extract_backend="bigvgan",
+    mel_spec_type="bigvgan",
     target_rms=0.1,
     use_truth_duration=False,
     infer_batch_size=1,
@@ -102,7 +102,7 @@ def get_inference_prompt(
         win_length=win_length,
         n_mel_channels=n_mel_channels,
         target_sample_rate=target_sample_rate,
-        extract_backend=extract_backend,
+        mel_spec_type=mel_spec_type,
     )
 
     for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
@@ -56,6 +56,10 @@ f5-tts_infer-cli \
 --ref_audio "ref_audio.wav" \
 --ref_text "The content, subtitle or transcription of reference audio." \
 --gen_text "Some text you want TTS model generate for you."
+
+# Choose Vocoder
+f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/model_1250000.pt >
+f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors >
 ```
 
 And a `.toml` file would help with more flexible usage.
@@ -115,11 +115,9 @@ if args.vocoder_name == "vocos":
     vocoder_local_path = "../checkpoints/vocos-mel-24khz"
 elif args.vocoder_name == "bigvgan":
     vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
-extract_backend = args.vocoder_name
+mel_spec_type = args.vocoder_name
 
-vocoder = load_vocoder(
-    vocoder_name=extract_backend, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path
-)
+vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
 
 
 # load models
@@ -159,7 +157,7 @@ print(f"Using {model}...")
 ema_model = load_model(model_cls, model_cfg, ckpt_file, args.vocoder_name, vocab_file)
 
 
-def main_process(ref_audio, ref_text, text_gen, model_obj, extract_backend, remove_silence, speed):
+def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
     main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
@@ -194,7 +192,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, extract_backend, remo
         ref_text = voices[voice]["ref_text"]
         print(f"Voice: {voice}")
         audio, final_sample_rate, spectragram = infer_process(
-            ref_audio, ref_text, gen_text, model_obj, vocoder, extract_backend, speed=speed
+            ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type, speed=speed
         )
         generated_audio_segments.append(audio)
@@ -213,7 +211,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, extract_backend, remo
 
 
 def main():
-    main_process(ref_audio, ref_text, gen_text, ema_model, extract_backend, remove_silence, speed)
+    main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed)
 
 
 if __name__ == "__main__":
@@ -18,7 +18,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-extract_backend = "bigvgan" # 'vocos' or 'bigvgan'
+mel_spec_type = "bigvgan" # 'vocos' or 'bigvgan'
 target_rms = 0.1
 
 tokenizer = "pinyin"
@@ -85,11 +85,11 @@ if not os.path.exists(output_dir):
 
 # Vocoder model
 local = False
-if extract_backend == "vocos":
+if mel_spec_type == "vocos":
     vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
-elif extract_backend == "bigvgan":
+elif mel_spec_type == "bigvgan":
     vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
-vocoder = load_vocoder(vocoder_name=extract_backend, is_local=local, local_path=vocoder_local_path)
+vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
 
 # Tokenizer
 vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
@@ -103,7 +103,7 @@ model = CFM(
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
-        extract_backend=extract_backend,
+        mel_spec_type=mel_spec_type,
    ),
    odeint_kwargs=dict(
        method=ode_method,
@@ -111,7 +111,12 @@ model = CFM(
    vocab_char_map=vocab_char_map,
 ).to(device)
 
-dtype = torch.float16 if extract_backend == "vocos" else torch.float32
+supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
+if supports_fp16 and mel_spec_type == "vocos":
+    dtype = torch.float16
+else:
+    dtype = torch.float32
+
 model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
 
 # Audio
@@ -178,9 +183,9 @@ with torch.inference_mode():
     generated = generated.to(torch.float32)
     generated = generated[:, ref_audio_len:, :]
     gen_mel_spec = generated.permute(0, 2, 1)
-    if extract_backend == "vocos":
+    if mel_spec_type == "vocos":
         generated_wave = vocoder.decode(gen_mel_spec)
-    elif extract_backend == "bigvgan":
+    elif mel_spec_type == "bigvgan":
         generated_wave = vocoder(gen_mel_spec)
 
     if rms < target_rms:
@@ -4,7 +4,7 @@ import os
 import sys
 
-sys.path.append(f"../../{os.path.dirname(os.path.abspath(__file__))}/third_party/BigVGAN/")
-from third_party.BigVGAN import bigvgan
 
 import hashlib
 import re
 import tempfile
@@ -40,7 +40,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-extract_backend = "bigvgan" # 'vocos' or 'bigvgan'
+mel_spec_type = "bigvgan" # 'vocos' or 'bigvgan'
 target_rms = 0.1
 cross_fade_duration = 0.15
 ode_method = "euler"
@@ -97,8 +97,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
             vocoder = vocoder.eval().to(device)
         else:
             print("Download Vocos from huggingface charactr/vocos-mel-24khz")
-            vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+            vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
     elif vocoder_name == "bigvgan":
+        try:
+            from third_party.BigVGAN import bigvgan
+        except ImportError:
+            print("You need to follow the README to init submodule and change the BigVGAN source code.")
         if is_local:
             """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
             vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
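A short usage sketch for the function this hunk extends (the local path is an assumption mirroring the checkpoint directories used elsewhere in this diff):

```python
# Load BigVGAN from a local download of
# https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main
vocoder = load_vocoder(
    vocoder_name="bigvgan",
    is_local=True,
    local_path="../checkpoints/bigvgan_v2_24khz_100band_256x",
)

# Or let Vocos be fetched from the Hugging Face hub:
vocoder = load_vocoder(vocoder_name="vocos", is_local=False)
```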
@@ -165,7 +169,7 @@ def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True):
 
 
 def load_model(
-    model_cls, model_cfg, ckpt_path, extract_backend, vocab_file="", ode_method=ode_method, use_ema=True, device=device
+    model_cls, model_cfg, ckpt_path, mel_spec_type, vocab_file="", ode_method=ode_method, use_ema=True, device=device
 ):
     if vocab_file == "":
         vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
@@ -184,7 +188,7 @@ def load_model(
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
-            extract_backend=extract_backend,
+            mel_spec_type=mel_spec_type,
        ),
        odeint_kwargs=dict(
            method=ode_method,
@@ -192,7 +196,12 @@ def load_model(
        vocab_char_map=vocab_char_map,
    ).to(device)
 
-    dtype = torch.float16 if extract_backend == "vocos" else torch.float32
+    supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
+    if supports_fp16 and mel_spec_type == "vocos":
+        dtype = torch.float16
+    else:
+        dtype = torch.float32
+
     model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
 
     return model
@@ -288,7 +297,7 @@ def infer_process(
     gen_text,
     model_obj,
     vocoder,
-    extract_backend,
+    mel_spec_type,
     show_info=print,
     progress=tqdm,
     target_rms=target_rms,
@@ -314,7 +323,7 @@ def infer_process(
         gen_text_batches,
         model_obj,
         vocoder,
-        extract_backend,
+        mel_spec_type,
         progress=progress,
         target_rms=target_rms,
         cross_fade_duration=cross_fade_duration,
@@ -336,7 +345,7 @@ def infer_batch_process(
     gen_text_batches,
     model_obj,
     vocoder,
-    extract_backend,
+    mel_spec_type,
     progress=tqdm,
     target_rms=0.1,
     cross_fade_duration=0.15,
@@ -392,9 +401,9 @@ def infer_batch_process(
         generated = generated.to(torch.float32)
         generated = generated[:, ref_audio_len:, :]
         generated_mel_spec = generated.permute(0, 2, 1)
-        if extract_backend == "vocos":
+        if mel_spec_type == "vocos":
             generated_wave = vocoder.decode(generated_mel_spec)
-        elif extract_backend == "bigvgan":
+        elif mel_spec_type == "bigvgan":
             generated_wave = vocoder(generated_mel_spec)
         if rms < target_rms:
             generated_wave = generated_wave * rms / target_rms
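Putting the renamed pieces together, an end-to-end sketch assembled from the functions this commit touches (module paths and file locations are assumptions based on the repo layout; the DiT config is copied from the api.py hunk above):

```python
from f5_tts.infer.utils_infer import load_model, load_vocoder, infer_process
from f5_tts.model import DiT  # assumed import path

model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
ema_model = load_model(DiT, model_cfg, "ckpts/model_1250000.pt", "bigvgan")
vocoder = load_vocoder(vocoder_name="bigvgan")

audio, sample_rate, spectrogram = infer_process(
    "ref_audio.wav",                      # reference audio (placeholder path)
    "Transcription of reference audio.",  # reference text
    "Some text to synthesize.",           # text to generate
    ema_model,
    vocoder,
    "bigvgan",                            # mel_spec_type must match the vocoder
)
```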
@@ -105,9 +105,6 @@ class CFM(nn.Module):
             cond = cond.permute(0, 2, 1)
             assert cond.shape[-1] == self.num_channels
 
-        assert next(self.parameters()).dtype == torch.float32 or next(self.parameters()).dtype == torch.float16, print(
-            "Only support fp16 and fp32 inference currently"
-        )
         cond = cond.to(next(self.parameters()).dtype)
 
         batch, cond_seq_len, device = *cond.shape[:2], cond.device
@@ -24,7 +24,7 @@ class HFDataset(Dataset):
         hop_length=256,
         n_fft=1024,
         win_length=1024,
-        extract_backend="vocos",
+        mel_spec_type="vocos",
     ):
         self.data = hf_dataset
         self.target_sample_rate = target_sample_rate
@@ -36,7 +36,7 @@ class HFDataset(Dataset):
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
-            extract_backend=extract_backend,
+            mel_spec_type=mel_spec_type,
        )
 
     def get_frame_len(self, index):
@@ -90,7 +90,7 @@ class CustomDataset(Dataset):
         n_mel_channels=100,
         n_fft=1024,
         win_length=1024,
-        extract_backend="vocos",
+        mel_spec_type="vocos",
         preprocessed_mel=False,
         mel_spec_module: nn.Module | None = None,
     ):
@@ -100,7 +100,7 @@ class CustomDataset(Dataset):
         self.hop_length = hop_length
         self.n_fft = n_fft
         self.win_length = win_length
-        self.extract_backend = extract_backend
+        self.mel_spec_type = mel_spec_type
         self.preprocessed_mel = preprocessed_mel
 
         if not preprocessed_mel:
@@ -112,7 +112,7 @@ class CustomDataset(Dataset):
                win_length=win_length,
                n_mel_channels=n_mel_channels,
                target_sample_rate=target_sample_rate,
-                extract_backend=extract_backend,
+                mel_spec_type=mel_spec_type,
            ),
        )
@@ -142,12 +142,10 @@ class MelSpec(nn.Module):
         win_length=1024,
         n_mel_channels=100,
         target_sample_rate=24_000,
-        extract_backend="vocos",
+        mel_spec_type="vocos",
     ):
         super().__init__()
-        assert extract_backend in ["vocos", "bigvgan"], print(
-            "We only support two extract mel backend: vocos or bigvgan"
-        )
+        assert mel_spec_type in ["vocos", "bigvgan"], print("We only support two extract mel backend: vocos or bigvgan")
 
         self.n_fft = n_fft
         self.hop_length = hop_length
@@ -155,9 +153,9 @@ class MelSpec(nn.Module):
         self.n_mel_channels = n_mel_channels
         self.target_sample_rate = target_sample_rate
 
-        if extract_backend == "vocos":
+        if mel_spec_type == "vocos":
             self.extractor = get_vocos_mel_spectrogram
-        elif extract_backend == "bigvgan":
+        elif mel_spec_type == "bigvgan":
             self.extractor = get_bigvgan_mel_spectrogram
 
         self.register_buffer("dummy", torch.tensor(0), persistent=False)
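MelSpec is where the renamed argument actually changes behavior: it selects the mel extractor matching the features the vocoder was trained on. A brief instantiation sketch (defaults taken from this hunk; `n_fft` and `hop_length` values assumed to mirror the other configs in this diff):

```python
# Features for a BigVGAN vocoder; mixing extractors and vocoders
# (e.g. vocos mels through bigvgan) would degrade the synthesized audio.
mel_spec = MelSpec(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    n_mel_channels=100,
    target_sample_rate=24_000,
    mel_spec_type="bigvgan",  # selects get_bigvgan_mel_spectrogram
)
```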
@@ -46,7 +46,7 @@ class Trainer:
         accelerate_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
         bnb_optimizer: bool = False,
-        extract_backend: str = "vocos",  # "vocos" | "bigvgan"
+        mel_spec_type: str = "vocos",  # "vocos" | "bigvgan"
     ):
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
 
@@ -108,7 +108,7 @@ class Trainer:
         self.max_samples = max_samples
         self.grad_accumulation_steps = grad_accumulation_steps
         self.max_grad_norm = max_grad_norm
-        self.vocoder_name = extract_backend
+        self.vocoder_name = mel_spec_type
 
         self.noise_scheduler = noise_scheduler
@@ -13,7 +13,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-extract_backend = "bigvgan" # 'vocos' or 'bigvgan'
+mel_spec_type = "bigvgan" # 'vocos' or 'bigvgan'
 
 tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
 tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
@@ -63,7 +63,7 @@ def main():
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
-        extract_backend=extract_backend,
+        mel_spec_type=mel_spec_type,
    )
 
    model = CFM(
@@ -89,7 +89,7 @@ def main():
        wandb_resume_id=wandb_resume_id,
        last_per_steps=last_per_steps,
        log_samples=True,
-        extract_backend=extract_backend,
+        mel_spec_type=mel_spec_type,
    )
 
    train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)