From b180961782bf87f07cb4b99562a76b344b356025 Mon Sep 17 00:00:00 2001
From: ZhikangNiu
Date: Fri, 1 Nov 2024 11:02:39 +0800
Subject: [PATCH] refactor: more details about bigvgan, clearer function
 definitions

---
 README.md                           |  8 +++-----
 src/f5_tts/api.py                   | 13 +++++++-----
 src/f5_tts/eval/eval_infer_batch.py | 21 +++++++++++--------
 src/f5_tts/eval/utils_eval.py       |  4 ++--
 src/f5_tts/infer/README.md          |  4 ++++
 src/f5_tts/infer/infer_cli.py       | 12 +++++------
 src/f5_tts/infer/speech_edit.py     | 21 +++++++++++--------
 src/f5_tts/infer/utils_infer.py     | 31 +++++++++++++++++++----------
 src/f5_tts/model/cfm.py             |  3 ---
 src/f5_tts/model/dataset.py         | 10 +++++-----
 src/f5_tts/model/modules.py         | 10 ++++------
 src/f5_tts/model/trainer.py         |  4 ++--
 src/f5_tts/train/train.py           |  6 +++---
 13 files changed, 82 insertions(+), 65 deletions(-)

diff --git a/README.md b/README.md
index 34d6750..cd703ef 100644
--- a/README.md
+++ b/README.md
@@ -46,11 +46,13 @@ cd F5-TTS
 pip install -e .
 
 # Init submodule(optional, if you want to change the vocoder from vocos to bigvgan)
-git submodule update --init --recursive
+# git submodule update --init --recursive
+# pip install -e .
 ```
 
 After init submodule, you need to change the `src/third_party/BigVGAN/bigvgan.py` by adding the following code at the beginning of the file.
 ```python
+import os
 import sys
 sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 ```
@@ -104,10 +106,6 @@ f5-tts_infer-cli -c custom.toml
 
 # Multi voice. See src/f5_tts/infer/README.md
 f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
-
-# Choose Vocoder
-f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file
-f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file
 ```
 
 ### 3. More instructions
diff --git a/src/f5_tts/api.py b/src/f5_tts/api.py
index 2fb5f40..a4196f2 100644
--- a/src/f5_tts/api.py
+++ b/src/f5_tts/api.py
@@ -38,7 +38,7 @@ class F5TTS:
         self.target_sample_rate = target_sample_rate
         self.hop_length = hop_length
         self.seed = -1
-        self.extract_backend = vocoder_name
+        self.mel_spec_type = vocoder_name
 
         # Set device
         self.device = device or (
@@ -52,10 +52,13 @@
     def load_vocoder_model(self, vocoder_name, local_path):
         self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
 
-    def load_ema_model(self, model_type, ckpt_file, extract_backend, vocab_file, ode_method, use_ema):
+    def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema):
         if model_type == "F5-TTS":
             if not ckpt_file:
-                ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
+                if mel_spec_type == "vocos":
+                    ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
+                elif mel_spec_type == "bigvgan":
+                    ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"))
             model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
             model_cls = DiT
         elif model_type == "E2-TTS":
@@ -67,7 +70,7 @@
             raise ValueError(f"Unknown model type: {model_type}")
 
         self.ema_model = load_model(
-            model_cls, model_cfg, ckpt_file, extract_backend, vocab_file, ode_method, use_ema, self.device
+            model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
         )
 
     def export_wav(self, wav, file_wave, remove_silence=False):
@@ -111,7 +114,7 @@
             gen_text,
             self.ema_model,
             self.vocoder,
-            self.extract_backend,
+            self.mel_spec_type,
             show_info=show_info,
             progress=progress,
             target_rms=target_rms,
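
With the api.py change above, an empty `ckpt_file` now resolves to a checkpoint that matches the chosen vocoder. A minimal usage sketch, assuming the `F5TTS` constructor keeps its `vocoder_name` keyword and the `infer(ref_file, ref_text, gen_text, ...)` signature (file names here are illustrative):

    from f5_tts.api import F5TTS

    # "bigvgan" resolves to F5TTS_Base_bigvgan/model_1250000.pt when no
    # ckpt_file is given; "vocos" keeps the original .safetensors checkpoint
    f5tts = F5TTS(model_type="F5-TTS", vocoder_name="bigvgan")
    wav, sr, spect = f5tts.infer(
        ref_file="ref_audio.wav",
        ref_text="The content of the reference audio.",
        gen_text="Some text you want the model to generate.",
    )
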
diff --git a/src/f5_tts/eval/eval_infer_batch.py b/src/f5_tts/eval/eval_infer_batch.py
index f45604d..94c2875 100644
--- a/src/f5_tts/eval/eval_infer_batch.py
+++ b/src/f5_tts/eval/eval_infer_batch.py
@@ -32,7 +32,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-extract_backend = "bigvgan"  # 'vocos' or 'bigvgan'
+mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'
 
 target_rms = 0.1
 
@@ -126,11 +126,11 @@ def main():
 
     # Vocoder model
     local = False
-    if extract_backend == "vocos":
+    if mel_spec_type == "vocos":
         vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
-    elif extract_backend == "bigvgan":
+    elif mel_spec_type == "bigvgan":
         vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
-    vocoder = load_vocoder(vocoder_name=extract_backend, is_local=local, local_path=vocoder_local_path)
+    vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
 
     # Tokenizer
     vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
@@ -144,7 +144,7 @@ def main():
             win_length=win_length,
             n_mel_channels=n_mel_channels,
             target_sample_rate=target_sample_rate,
-            extract_backend=extract_backend,
+            mel_spec_type=mel_spec_type,
         ),
         odeint_kwargs=dict(
             method=ode_method,
@@ -152,7 +152,12 @@
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    dtype = torch.float16 if extract_backend == "vocos" else torch.float32
+    supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
+    if supports_fp16 and mel_spec_type == "vocos":
+        dtype = torch.float16
+    else:
+        dtype = torch.float32
+
     model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
 
     if not os.path.exists(output_dir) and accelerator.is_main_process:
@@ -186,9 +191,9 @@
             for i, gen in enumerate(generated):
                 gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
                 gen_mel_spec = gen.permute(0, 2, 1)
-                if extract_backend == "vocos":
+                if mel_spec_type == "vocos":
                     generated_wave = vocoder.decode(gen_mel_spec)
-                elif extract_backend == "bigvgan":
+                elif mel_spec_type == "bigvgan":
                     generated_wave = vocoder(gen_mel_spec)
 
                 if ref_rms_list[i] < target_rms:
diff --git a/src/f5_tts/eval/utils_eval.py b/src/f5_tts/eval/utils_eval.py
index 3b79268..a03d262 100644
--- a/src/f5_tts/eval/utils_eval.py
+++ b/src/f5_tts/eval/utils_eval.py
@@ -78,7 +78,7 @@ def get_inference_prompt(
     win_length=1024,
     n_mel_channels=100,
     hop_length=256,
-    extract_backend="bigvgan",
+    mel_spec_type="bigvgan",
     target_rms=0.1,
     use_truth_duration=False,
     infer_batch_size=1,
@@ -102,7 +102,7 @@ def get_inference_prompt(
         win_length=win_length,
         n_mel_channels=n_mel_channels,
         target_sample_rate=target_sample_rate,
-        extract_backend=extract_backend,
+        mel_spec_type=mel_spec_type,
     )
 
     for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
diff --git a/src/f5_tts/infer/README.md b/src/f5_tts/infer/README.md
index e193def..0e84484 100644
--- a/src/f5_tts/infer/README.md
+++ b/src/f5_tts/infer/README.md
@@ -56,6 +56,10 @@ f5-tts_infer-cli \
 --ref_audio "ref_audio.wav" \
 --ref_text "The content, subtitle or transcription of reference audio." \
 --gen_text "Some text you want TTS model generate for you."
+
+# Choose Vocoder
+f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file
+f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file
 ```
 
 And a `.toml` file would help with more flexible usage.
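
The relocated `# Choose Vocoder` snippet pairs `--vocoder_name` with `--load_vocoder_from_local`. The equivalent Python calls, sketched with the `load_vocoder` signature and the default local paths used elsewhere in this patch:

    from f5_tts.infer.utils_infer import load_vocoder

    # is_local=True reads weights from disk instead of downloading from Hugging Face
    vocos = load_vocoder(vocoder_name="vocos", is_local=True,
                         local_path="../checkpoints/vocos-mel-24khz")
    bigvgan = load_vocoder(vocoder_name="bigvgan", is_local=True,
                           local_path="../checkpoints/bigvgan_v2_24khz_100band_256x")
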
diff --git a/src/f5_tts/infer/infer_cli.py b/src/f5_tts/infer/infer_cli.py
index 49e96ff..6138aed 100644
--- a/src/f5_tts/infer/infer_cli.py
+++ b/src/f5_tts/infer/infer_cli.py
@@ -115,11 +115,9 @@ if args.vocoder_name == "vocos":
     vocoder_local_path = "../checkpoints/vocos-mel-24khz"
 elif args.vocoder_name == "bigvgan":
     vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
-extract_backend = args.vocoder_name
+mel_spec_type = args.vocoder_name
 
-vocoder = load_vocoder(
-    vocoder_name=extract_backend, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path
-)
+vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
 
 # load models
 
@@ -159,7 +157,7 @@ print(f"Using {model}...")
 ema_model = load_model(model_cls, model_cfg, ckpt_file, args.vocoder_name, vocab_file)
 
 
-def main_process(ref_audio, ref_text, text_gen, model_obj, extract_backend, remove_silence, speed):
+def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
     main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
@@ -194,7 +192,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, extract_backend, remo
         ref_text = voices[voice]["ref_text"]
         print(f"Voice: {voice}")
         audio, final_sample_rate, spectragram = infer_process(
-            ref_audio, ref_text, gen_text, model_obj, vocoder, extract_backend, speed=speed
+            ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type, speed=speed
         )
         generated_audio_segments.append(audio)
 
@@ -213,7 +211,7 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, extract_backend, remo
 
 
 def main():
-    main_process(ref_audio, ref_text, gen_text, ema_model, extract_backend, remove_silence, speed)
+    main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed)
 
 
 if __name__ == "__main__":
diff --git a/src/f5_tts/infer/speech_edit.py b/src/f5_tts/infer/speech_edit.py
index e417723..b808792 100644
--- a/src/f5_tts/infer/speech_edit.py
+++ b/src/f5_tts/infer/speech_edit.py
@@ -18,7 +18,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-extract_backend = "bigvgan"  # 'vocos' or 'bigvgan'
+mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'
 target_rms = 0.1
 
 tokenizer = "pinyin"
@@ -85,11 +85,11 @@ if not os.path.exists(output_dir):
 
 # Vocoder model
 local = False
-if extract_backend == "vocos":
+if mel_spec_type == "vocos":
     vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
-elif extract_backend == "bigvgan":
+elif mel_spec_type == "bigvgan":
     vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
-vocoder = load_vocoder(vocoder_name=extract_backend, is_local=local, local_path=vocoder_local_path)
+vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
 
 # Tokenizer
 vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
@@ -103,7 +103,7 @@ model = CFM(
         win_length=win_length,
         n_mel_channels=n_mel_channels,
         target_sample_rate=target_sample_rate,
-        extract_backend=extract_backend,
+        mel_spec_type=mel_spec_type,
     ),
     odeint_kwargs=dict(
         method=ode_method,
@@ -111,7 +111,12 @@
     vocab_char_map=vocab_char_map,
 ).to(device)
 
-dtype = torch.float16 if extract_backend == "vocos" else torch.float32
+supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
+if supports_fp16 and mel_spec_type == "vocos":
+    dtype = torch.float16
+else:
+    dtype = torch.float32
+
 model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
 
 # Audio
@@ -178,9 +183,9 @@ with torch.inference_mode():
     generated = generated.to(torch.float32)
     generated = generated[:, ref_audio_len:, :]
     gen_mel_spec = generated.permute(0, 2, 1)
-    if extract_backend == "vocos":
+    if mel_spec_type == "vocos":
         generated_wave = vocoder.decode(gen_mel_spec)
-    elif extract_backend == "bigvgan":
+    elif mel_spec_type == "bigvgan":
         generated_wave = vocoder(gen_mel_spec)
 
     if rms < target_rms:
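
The former one-line ternary is replaced in three scripts by an explicit capability check: fp16 is only used on a CUDA device of compute capability 6.0 or newer, and only with the vocos checkpoint; BigVGAN always runs in fp32. The same logic as a standalone sketch (the helper name is not part of this patch):

    import torch

    def pick_inference_dtype(device: str, mel_spec_type: str) -> torch.dtype:
        # fp16 requires CUDA compute capability >= 6 (Pascal or newer);
        # the bigvgan checkpoint is loaded in fp32 regardless of device
        supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
        return torch.float16 if supports_fp16 and mel_spec_type == "vocos" else torch.float32
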
diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py
index 71f4491..7cb2a4d 100644
--- a/src/f5_tts/infer/utils_infer.py
+++ b/src/f5_tts/infer/utils_infer.py
@@ -4,7 +4,7 @@ import os
 import sys
 
 sys.path.append(f"../../{os.path.dirname(os.path.abspath(__file__))}/third_party/BigVGAN/")
-from third_party.BigVGAN import bigvgan
+
 import hashlib
 import re
 import tempfile
@@ -40,7 +40,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-extract_backend = "bigvgan"  # 'vocos' or 'bigvgan'
+mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'
 target_rms = 0.1
 cross_fade_duration = 0.15
 ode_method = "euler"
@@ -97,8 +97,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
             vocoder = vocoder.eval().to(device)
         else:
             print("Download Vocos from huggingface charactr/vocos-mel-24khz")
-            vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+            vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
     elif vocoder_name == "bigvgan":
+        try:
+            from third_party.BigVGAN import bigvgan
+        except ImportError:
+            raise ImportError("You need to follow the README to init submodule and change the BigVGAN source code.")
         if is_local:
             """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
             vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
@@ -165,7 +169,7 @@ def load_checkpoint(model, ckpt_path, device, dtype, use_ema=True):
 
 
 def load_model(
-    model_cls, model_cfg, ckpt_path, extract_backend, vocab_file="", ode_method=ode_method, use_ema=True, device=device
+    model_cls, model_cfg, ckpt_path, mel_spec_type, vocab_file="", ode_method=ode_method, use_ema=True, device=device
 ):
     if vocab_file == "":
         vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
@@ -184,7 +188,7 @@ def load_model(
             win_length=win_length,
             n_mel_channels=n_mel_channels,
             target_sample_rate=target_sample_rate,
-            extract_backend=extract_backend,
+            mel_spec_type=mel_spec_type,
         ),
         odeint_kwargs=dict(
             method=ode_method,
@@ -192,7 +196,12 @@
         vocab_char_map=vocab_char_map,
     ).to(device)
 
-    dtype = torch.float16 if extract_backend == "vocos" else torch.float32
+    supports_fp16 = device == "cuda" and torch.cuda.get_device_properties(device).major >= 6
+    if supports_fp16 and mel_spec_type == "vocos":
+        dtype = torch.float16
+    else:
+        dtype = torch.float32
+
     model = load_checkpoint(model, ckpt_path, device, dtype, use_ema=use_ema)
 
     return model
@@ -288,7 +297,7 @@ def infer_process(
     gen_text,
     model_obj,
     vocoder,
-    extract_backend,
+    mel_spec_type,
     show_info=print,
     progress=tqdm,
     target_rms=target_rms,
@@ -314,7 +323,7 @@ def infer_process(
         gen_text_batches,
         model_obj,
         vocoder,
-        extract_backend,
+        mel_spec_type,
         progress=progress,
         target_rms=target_rms,
         cross_fade_duration=cross_fade_duration,
@@ -336,7 +345,7 @@ def infer_batch_process(
     gen_text_batches,
     model_obj,
     vocoder,
-    extract_backend,
+    mel_spec_type,
     progress=tqdm,
     target_rms=0.1,
     cross_fade_duration=0.15,
@@ -392,9 +401,9 @@ def infer_batch_process(
         generated = generated.to(torch.float32)
         generated = generated[:, ref_audio_len:, :]
         generated_mel_spec = generated.permute(0, 2, 1)
-        if extract_backend == "vocos":
+        if mel_spec_type == "vocos":
             generated_wave = vocoder.decode(generated_mel_spec)
-        elif extract_backend == "bigvgan":
+        elif mel_spec_type == "bigvgan":
             generated_wave = vocoder(generated_mel_spec)
         if rms < target_rms:
             generated_wave = generated_wave * rms / target_rms
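
Vocos and BigVGAN are driven differently at synthesis time: Vocos exposes a dedicated `decode()` method, while the BigVGAN module is called directly on the mel batch. Since the branch now appears in three files, here is the convention folded into a single helper (illustrative only, not part of this patch):

    import torch

    def decode_mel(vocoder, mel, mel_spec_type):
        # mel: (batch, n_mel_channels, frames), i.e. after .permute(0, 2, 1)
        with torch.inference_mode():
            if mel_spec_type == "vocos":
                return vocoder.decode(mel)  # Vocos API
            if mel_spec_type == "bigvgan":
                return vocoder(mel)  # BigVGAN is a plain nn.Module
        raise ValueError(f"Unsupported mel_spec_type: {mel_spec_type}")
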
diff --git a/src/f5_tts/model/cfm.py b/src/f5_tts/model/cfm.py
index c011980..2c88de3 100644
--- a/src/f5_tts/model/cfm.py
+++ b/src/f5_tts/model/cfm.py
@@ -105,9 +105,6 @@ class CFM(nn.Module):
             cond = cond.permute(0, 2, 1)
             assert cond.shape[-1] == self.num_channels
 
-        assert next(self.parameters()).dtype == torch.float32 or next(self.parameters()).dtype == torch.float16, print(
-            "Only support fp16 and fp32 inference currently"
-        )
         cond = cond.to(next(self.parameters()).dtype)
 
         batch, cond_seq_len, device = *cond.shape[:2], cond.device
diff --git a/src/f5_tts/model/dataset.py b/src/f5_tts/model/dataset.py
index 93ddbe0..937836d 100644
--- a/src/f5_tts/model/dataset.py
+++ b/src/f5_tts/model/dataset.py
@@ -24,7 +24,7 @@ class HFDataset(Dataset):
         hop_length=256,
         n_fft=1024,
         win_length=1024,
-        extract_backend="vocos",
+        mel_spec_type="vocos",
     ):
         self.data = hf_dataset
         self.target_sample_rate = target_sample_rate
@@ -36,7 +36,7 @@ class HFDataset(Dataset):
             win_length=win_length,
             n_mel_channels=n_mel_channels,
             target_sample_rate=target_sample_rate,
-            extract_backend=extract_backend,
+            mel_spec_type=mel_spec_type,
         )
 
     def get_frame_len(self, index):
@@ -90,7 +90,7 @@ class CustomDataset(Dataset):
         n_mel_channels=100,
         n_fft=1024,
         win_length=1024,
-        extract_backend="vocos",
+        mel_spec_type="vocos",
         preprocessed_mel=False,
         mel_spec_module: nn.Module | None = None,
     ):
@@ -100,7 +100,7 @@ class CustomDataset(Dataset):
         self.hop_length = hop_length
         self.n_fft = n_fft
         self.win_length = win_length
-        self.extract_backend = extract_backend
+        self.mel_spec_type = mel_spec_type
         self.preprocessed_mel = preprocessed_mel
 
         if not preprocessed_mel:
@@ -112,7 +112,7 @@ class CustomDataset(Dataset):
                 win_length=win_length,
                 n_mel_channels=n_mel_channels,
                 target_sample_rate=target_sample_rate,
-                extract_backend=extract_backend,
+                mel_spec_type=mel_spec_type,
             ),
         )
 
diff --git a/src/f5_tts/model/modules.py b/src/f5_tts/model/modules.py
index 061a5fd..d3da679 100644
--- a/src/f5_tts/model/modules.py
+++ b/src/f5_tts/model/modules.py
@@ -142,12 +142,10 @@ class MelSpec(nn.Module):
         win_length=1024,
         n_mel_channels=100,
         target_sample_rate=24_000,
-        extract_backend="vocos",
+        mel_spec_type="vocos",
     ):
         super().__init__()
-        assert extract_backend in ["vocos", "bigvgan"], print(
-            "We only support two extract mel backend: vocos or bigvgan"
-        )
+        assert mel_spec_type in ["vocos", "bigvgan"], "We only support two mel extraction backends: vocos or bigvgan"
 
         self.n_fft = n_fft
         self.hop_length = hop_length
@@ -155,9 +153,9 @@ class MelSpec(nn.Module):
         self.win_length = win_length
         self.n_mel_channels = n_mel_channels
         self.target_sample_rate = target_sample_rate
-        if extract_backend == "vocos":
+        if mel_spec_type == "vocos":
             self.extractor = get_vocos_mel_spectrogram
-        elif extract_backend == "bigvgan":
+        elif mel_spec_type == "bigvgan":
             self.extractor = get_bigvgan_mel_spectrogram
 
         self.register_buffer("dummy", torch.tensor(0), persistent=False)
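
A quick smoke test of the renamed `MelSpec` argument; a sketch that assumes the module's forward takes a `(batch, samples)` waveform and returns `(batch, n_mel_channels, frames)` at the 24 kHz / 100-bin defaults:

    import torch
    from f5_tts.model.modules import MelSpec

    mel_extractor = MelSpec(mel_spec_type="bigvgan")  # or "vocos"
    wav = torch.zeros(1, 24_000)  # one second of silence at 24 kHz
    mel = mel_extractor(wav)
    print(mel.shape)  # expected: torch.Size([1, 100, frames])
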
diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py
index 51472f9..85c1cb2 100644
--- a/src/f5_tts/model/trainer.py
+++ b/src/f5_tts/model/trainer.py
@@ -46,7 +46,7 @@ class Trainer:
         accelerate_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
         bnb_optimizer: bool = False,
-        extract_backend: str = "vocos",  # "vocos" | "bigvgan"
+        mel_spec_type: str = "vocos",  # "vocos" | "bigvgan"
     ):
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
 
@@ -108,7 +108,7 @@ class Trainer:
         self.max_samples = max_samples
         self.grad_accumulation_steps = grad_accumulation_steps
         self.max_grad_norm = max_grad_norm
-        self.vocoder_name = extract_backend
+        self.vocoder_name = mel_spec_type
 
         self.noise_scheduler = noise_scheduler
 
diff --git a/src/f5_tts/train/train.py b/src/f5_tts/train/train.py
index 44e8cb4..9ef7db4 100644
--- a/src/f5_tts/train/train.py
+++ b/src/f5_tts/train/train.py
@@ -13,7 +13,7 @@ n_mel_channels = 100
 hop_length = 256
 win_length = 1024
 n_fft = 1024
-extract_backend = "bigvgan"  # 'vocos' or 'bigvgan'
+mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'
 
 tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
 tokenizer_path = None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
@@ -63,7 +63,7 @@ def main():
         win_length=win_length,
         n_mel_channels=n_mel_channels,
         target_sample_rate=target_sample_rate,
-        extract_backend=extract_backend,
+        mel_spec_type=mel_spec_type,
     )
 
     model = CFM(
@@ -89,7 +89,7 @@ def main():
         wandb_resume_id=wandb_resume_id,
         last_per_steps=last_per_steps,
         log_samples=True,
-        extract_backend=extract_backend,
+        mel_spec_type=mel_spec_type,
     )
 
     train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
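
On the training side the renamed flag flows from `train.py` through the dataset's `MelSpec` and into the `Trainer`, which stores it as `self.vocoder_name`. A condensed sketch of that wiring, with values taken from the hunks above:

    mel_spec_type = "bigvgan"  # 'vocos' or 'bigvgan'

    mel_spec_kwargs = dict(
        hop_length=256,
        n_fft=1024,
        win_length=1024,
        n_mel_channels=100,
        target_sample_rate=24_000,
        mel_spec_type=mel_spec_type,  # renamed from extract_backend
    )
    # the Trainer receives the same value:
    # trainer = Trainer(..., log_samples=True, mel_spec_type=mel_spec_type)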