Mirror of https://github.com/SWivid/F5-TTS.git (synced 2025-12-12 07:40:43 -08:00)

Commit: add and run pre-commit with ruff

14  .github/workflows/pre-commit.yaml  (vendored, new file)
@@ -0,0 +1,14 @@
name: pre-commit

on:
  pull_request:
  push:
    branches: [main]

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v3
      - uses: pre-commit/action@v3.0.1
14  .pre-commit-config.yaml  (new file)
@@ -0,0 +1,14 @@
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.7.0
    hooks:
      # Run the linter.
      - id: ruff
        args: [--fix]
      # Run the formatter.
      - id: ruff-format
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.3.0
    hooks:
      - id: check-yaml
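Most of the changes in the diffs below are mechanical rewrites produced by the two Ruff hooks configured above. As a rough, hypothetical sketch (not taken from the repository), this is the kind of style `ruff check --fix` and `ruff format` enforce: double quotes, no spaces around `=` in keyword arguments, and calls wrapped only when they no longer fit on one line.

```python
# Hypothetical before/after sketch of the style the Ruff hooks enforce below.

def build_config(tokenizer="pinyin", sample_rate=24000, hop_length=256):
    # Before formatting, this might have looked like:
    #   cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2)
    #   name = 'F5TTS_Base'
    cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2)
    name = "F5TTS_Base"
    return {"name": name, "tokenizer": tokenizer, "sample_rate": sample_rate, "hop_length": hop_length, **cfg}


print(build_config())
```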
20  README.md
@@ -43,6 +43,26 @@ pip install -r requirements.txt
docker build -t f5tts:v1 .
```

### Development

When making a pull request, please use pre-commit to ensure code quality:

```bash
pip install pre-commit
pre-commit install
```

This will run the linters and formatters automatically before each commit.

To run the checks manually on the whole repository:

```bash
pre-commit run --all-files
```

Note: Some model components have linting exceptions for F722 to accommodate tensor-shape notation.
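For context, here is a minimal sketch (the `Projector` module is hypothetical, not part of the repository) of why those exceptions exist: the model code annotates tensor shapes with strings such as `float["b n d"]`, which Ruff's pyflakes rules report as F722 (syntax error in forward annotation) unless suppressed.

```python
from __future__ import annotations  # keep the shape annotations as lazy strings

import torch
import torch.nn as nn


class Projector(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    # "b n d" documents the expected shape (batch, sequence, dim); pyflakes tries to
    # parse the inner string as a forward reference and reports F722, hence the noqa.
    def forward(self, x: float["b n d"]) -> float["b n d"]:  # noqa: F722
        return self.proj(x)


print(Projector(8)(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4, 8])
```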
## Prepare Dataset

Example data processing scripts are provided for Emilia and Wenetspeech4TTS; you may tailor your own, along with a Dataset class in `model/dataset.py`.
101  finetune-cli.py
@@ -1,42 +1,57 @@
|
||||
import argparse
|
||||
from model import CFM, UNetT, DiT, MMDiT, Trainer
|
||||
from model import CFM, UNetT, DiT, Trainer
|
||||
from model.utils import get_tokenizer
|
||||
from model.dataset import load_dataset
|
||||
from cached_path import cached_path
|
||||
import shutil,os
|
||||
import shutil
|
||||
import os
|
||||
|
||||
# -------------------------- Dataset Settings --------------------------- #
|
||||
target_sample_rate = 24000
|
||||
n_mel_channels = 100
|
||||
hop_length = 256
|
||||
|
||||
tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
|
||||
tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
|
||||
|
||||
# -------------------------- Argument Parsing --------------------------- #
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Train CFM Model')
|
||||
parser = argparse.ArgumentParser(description="Train CFM Model")
|
||||
|
||||
parser.add_argument(
|
||||
"--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
|
||||
)
|
||||
parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for training")
|
||||
parser.add_argument("--batch_size_per_gpu", type=int, default=256, help="Batch size per GPU")
|
||||
parser.add_argument(
|
||||
"--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
|
||||
)
|
||||
parser.add_argument("--max_samples", type=int, default=16, help="Max sequences per batch")
|
||||
parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
|
||||
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
|
||||
parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
|
||||
parser.add_argument("--num_warmup_updates", type=int, default=5, help="Warmup steps")
|
||||
parser.add_argument("--save_per_updates", type=int, default=10, help="Save checkpoint every X steps")
|
||||
parser.add_argument("--last_per_steps", type=int, default=10, help="Save last checkpoint every X steps")
|
||||
parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
|
||||
|
||||
parser.add_argument(
|
||||
"--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
|
||||
)
|
||||
|
||||
parser.add_argument('--exp_name', type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"],help='Experiment name')
|
||||
parser.add_argument('--dataset_name', type=str, default="Emilia_ZH_EN", help='Name of the dataset to use')
|
||||
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate for training')
|
||||
parser.add_argument('--batch_size_per_gpu', type=int, default=256, help='Batch size per GPU')
|
||||
parser.add_argument('--batch_size_type', type=str, default="frame", choices=["frame", "sample"],help='Batch size type')
|
||||
parser.add_argument('--max_samples', type=int, default=16, help='Max sequences per batch')
|
||||
parser.add_argument('--grad_accumulation_steps', type=int, default=1,help='Gradient accumulation steps')
|
||||
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
|
||||
parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
|
||||
parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
|
||||
parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
|
||||
parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
|
||||
parser.add_argument('--finetune', type=bool, default=True, help='Use Finetune')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
# -------------------------- Training Settings -------------------------- #
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
|
||||
# Model parameters based on experiment name
|
||||
if args.exp_name == "F5TTS_Base":
|
||||
@@ -44,24 +59,31 @@ def main():
|
||||
model_cls = DiT
|
||||
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
||||
if args.finetune:
|
||||
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
|
||||
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
|
||||
elif args.exp_name == "E2TTS_Base":
|
||||
wandb_resume_id = None
|
||||
model_cls = UNetT
|
||||
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
||||
if args.finetune:
|
||||
ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
|
||||
|
||||
if args.finetune:
|
||||
path_ckpt = os.path.join("ckpts",args.dataset_name)
|
||||
if os.path.isdir(path_ckpt)==False:
|
||||
os.makedirs(path_ckpt,exist_ok=True)
|
||||
shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
|
||||
|
||||
checkpoint_path=os.path.join("ckpts",args.dataset_name)
|
||||
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
|
||||
|
||||
if args.finetune:
|
||||
path_ckpt = os.path.join("ckpts", args.dataset_name)
|
||||
if not os.path.isdir(path_ckpt):
|
||||
os.makedirs(path_ckpt, exist_ok=True)
|
||||
shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path)))
|
||||
|
||||
checkpoint_path = os.path.join("ckpts", args.dataset_name)
|
||||
|
||||
# Use the tokenizer and tokenizer_path provided in the command line arguments
|
||||
tokenizer = args.tokenizer
|
||||
if tokenizer == "custom":
|
||||
if not args.tokenizer_path:
|
||||
raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
|
||||
tokenizer_path = args.tokenizer_path
|
||||
else:
|
||||
tokenizer_path = args.dataset_name
|
||||
|
||||
# Use the dataset_name provided in the command line
|
||||
tokenizer_path = args.dataset_name if tokenizer != "custom" else tokenizer_path
|
||||
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
|
||||
|
||||
mel_spec_kwargs = dict(
|
||||
@@ -71,11 +93,7 @@ def main():
|
||||
)
|
||||
|
||||
e2tts = CFM(
|
||||
transformer=model_cls(
|
||||
**model_cfg,
|
||||
text_num_embeds=vocab_size,
|
||||
mel_dim=n_mel_channels
|
||||
),
|
||||
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
||||
mel_spec_kwargs=mel_spec_kwargs,
|
||||
vocab_char_map=vocab_char_map,
|
||||
)
|
||||
@@ -99,10 +117,11 @@ def main():
|
||||
)
|
||||
|
||||
train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
|
||||
trainer.train(train_dataset,
|
||||
resumable_with_seed=666 # seed for shuffling dataset
|
||||
)
|
||||
trainer.train(
|
||||
train_dataset,
|
||||
resumable_with_seed=666, # seed for shuffling dataset
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
(File diff suppressed because it is too large.)

182  gradio_app.py
@@ -1,3 +1,6 @@
|
||||
# ruff: noqa: E402
|
||||
# Above allows ruff to ignore E402: module level import not at top of file
|
||||
|
||||
import re
|
||||
import tempfile
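As a side note, a minimal hypothetical sketch (not the real gradio_app.py) of the pattern that triggers E402 and why the file-level directive above is used: the module defines `gpu_decorator` before importing the model code, so later imports are no longer "at the top of the file".

```python
# ruff: noqa: E402
# Illustration only: E402 fires because these imports appear after executable
# statements; the file-level noqa above silences it for the whole module.

USING_SPACES = False  # stand-in for the try/except `import spaces` probe


def gpu_decorator(func):
    # On HF Spaces this would wrap func with spaces.GPU(); locally it is a no-op.
    return func


import re        # would be flagged as E402 without the file-level noqa
import tempfile  # would be flagged as E402 without the file-level noqa

print(callable(gpu_decorator), re.__name__, tempfile.__name__)
```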
|
||||
|
||||
@@ -11,16 +14,19 @@ from pydub import AudioSegment
|
||||
|
||||
try:
|
||||
import spaces
|
||||
|
||||
USING_SPACES = True
|
||||
except ImportError:
|
||||
USING_SPACES = False
|
||||
|
||||
|
||||
def gpu_decorator(func):
|
||||
if USING_SPACES:
|
||||
return spaces.GPU(func)
|
||||
else:
|
||||
return func
|
||||
|
||||
|
||||
from model import DiT, UNetT
|
||||
from model.utils import (
|
||||
save_spectrogram,
|
||||
@@ -38,15 +44,18 @@ vocos = load_vocoder()
|
||||
|
||||
# load models
|
||||
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
||||
F5TTS_ema_model = load_model(DiT, F5TTS_model_cfg, str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")))
|
||||
F5TTS_ema_model = load_model(
|
||||
DiT, F5TTS_model_cfg, str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
|
||||
)
|
||||
|
||||
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
||||
E2TTS_ema_model = load_model(UNetT, E2TTS_model_cfg, str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors")))
|
||||
E2TTS_ema_model = load_model(
|
||||
UNetT, E2TTS_model_cfg, str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
|
||||
)
|
||||
|
||||
|
||||
@gpu_decorator
|
||||
def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1):
|
||||
|
||||
ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=gr.Info)
|
||||
|
||||
if model == "F5-TTS":
|
||||
@@ -54,7 +63,16 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_
|
||||
elif model == "E2-TTS":
|
||||
ema_model = E2TTS_ema_model
|
||||
|
||||
final_wave, final_sample_rate, combined_spectrogram = infer_process(ref_audio, ref_text, gen_text, ema_model, cross_fade_duration=cross_fade_duration, speed=speed, show_info=gr.Info, progress=gr.Progress())
|
||||
final_wave, final_sample_rate, combined_spectrogram = infer_process(
|
||||
ref_audio,
|
||||
ref_text,
|
||||
gen_text,
|
||||
ema_model,
|
||||
cross_fade_duration=cross_fade_duration,
|
||||
speed=speed,
|
||||
show_info=gr.Info,
|
||||
progress=gr.Progress(),
|
||||
)
|
||||
|
||||
# Remove silence
|
||||
if remove_silence:
|
||||
@@ -73,17 +91,19 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_
|
||||
|
||||
|
||||
@gpu_decorator
|
||||
def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, model, remove_silence):
|
||||
def generate_podcast(
|
||||
script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, model, remove_silence
|
||||
):
|
||||
# Split the script into speaker blocks
|
||||
speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
|
||||
speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
|
||||
|
||||
|
||||
generated_audio_segments = []
|
||||
|
||||
|
||||
for i in range(0, len(speaker_blocks), 2):
|
||||
speaker = speaker_blocks[i]
|
||||
text = speaker_blocks[i+1].strip()
|
||||
|
||||
text = speaker_blocks[i + 1].strip()
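For clarity, a small self-contained illustration (with hypothetical speaker names and script) of how the splitting above yields alternating speaker/text pairs:

```python
import re

# Hypothetical two-speaker script in the format generate_podcast() expects.
script = (
    "Sean: How did you start studying speech synthesis?\n"
    "Meghan: I came to it through signal processing.\n"
    "Sean: That's fascinating."
)

speaker1_name, speaker2_name = "Sean", "Meghan"
speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)

# split() returns ["", "Sean", " How did ...", "Meghan", " I came ...", "Sean", " That's ..."],
# so dropping the leading empty element leaves alternating (speaker, text) pairs.
speaker_blocks = speaker_pattern.split(script)[1:]
for i in range(0, len(speaker_blocks), 2):
    speaker = speaker_blocks[i]
    text = speaker_blocks[i + 1].strip()
    print(speaker, "->", text)
```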
|
||||
|
||||
# Determine which speaker is talking
|
||||
if speaker == speaker1_name:
|
||||
ref_audio = ref_audio1
|
||||
@@ -93,51 +113,52 @@ def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name
|
||||
ref_text = ref_text2
|
||||
else:
|
||||
continue # Skip if the speaker is neither speaker1 nor speaker2
|
||||
|
||||
|
||||
# Generate audio for this block
|
||||
audio, _ = infer(ref_audio, ref_text, text, model, remove_silence)
|
||||
|
||||
|
||||
# Convert the generated audio to a numpy array
|
||||
sr, audio_data = audio
|
||||
|
||||
|
||||
# Save the audio data as a WAV file
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
|
||||
sf.write(temp_file.name, audio_data, sr)
|
||||
audio_segment = AudioSegment.from_wav(temp_file.name)
|
||||
|
||||
|
||||
generated_audio_segments.append(audio_segment)
|
||||
|
||||
|
||||
# Add a short pause between speakers
|
||||
pause = AudioSegment.silent(duration=500) # 500ms pause
|
||||
generated_audio_segments.append(pause)
|
||||
|
||||
|
||||
# Concatenate all audio segments
|
||||
final_podcast = sum(generated_audio_segments)
|
||||
|
||||
|
||||
# Export the final podcast
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
|
||||
podcast_path = temp_file.name
|
||||
final_podcast.export(podcast_path, format="wav")
|
||||
|
||||
|
||||
return podcast_path
|
||||
|
||||
|
||||
def parse_speechtypes_text(gen_text):
|
||||
# Pattern to find (Emotion)
|
||||
pattern = r'\((.*?)\)'
|
||||
pattern = r"\((.*?)\)"
|
||||
|
||||
# Split the text by the pattern
|
||||
tokens = re.split(pattern, gen_text)
|
||||
|
||||
segments = []
|
||||
|
||||
current_emotion = 'Regular'
|
||||
current_emotion = "Regular"
|
||||
|
||||
for i in range(len(tokens)):
|
||||
if i % 2 == 0:
|
||||
# This is text
|
||||
text = tokens[i].strip()
|
||||
if text:
|
||||
segments.append({'emotion': current_emotion, 'text': text})
|
||||
segments.append({"emotion": current_emotion, "text": text})
|
||||
else:
|
||||
# This is emotion
|
||||
emotion = tokens[i].strip()
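A minimal illustration (hypothetical text) of what this split produces: because the pattern contains a capture group, `re.split` keeps the emotion names in the output, so even indices are text and odd indices are emotions.

```python
import re

gen_text = "Hello there. (Happy) Great to see you! (Sad) I have to go now."
tokens = re.split(r"\((.*?)\)", gen_text)
print(tokens)
# ['Hello there. ', 'Happy', ' Great to see you! ', 'Sad', ' I have to go now.']

segments, current_emotion = [], "Regular"
for i, tok in enumerate(tokens):
    if i % 2 == 0:
        if tok.strip():
            segments.append({"emotion": current_emotion, "text": tok.strip()})
    else:
        current_emotion = tok.strip()
print(segments)
```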
|
||||
@@ -158,9 +179,7 @@ with gr.Blocks() as app_tts:
|
||||
gr.Markdown("# Batched TTS")
|
||||
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
|
||||
gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
|
||||
model_choice = gr.Radio(
|
||||
choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
|
||||
)
|
||||
model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
|
||||
generate_btn = gr.Button("Synthesize", variant="primary")
|
||||
with gr.Accordion("Advanced Settings", open=False):
|
||||
ref_text_input = gr.Textbox(
|
||||
@@ -206,23 +225,24 @@ with gr.Blocks() as app_tts:
|
||||
],
|
||||
outputs=[audio_output, spectrogram_output],
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks() as app_podcast:
|
||||
gr.Markdown("# Podcast Generation")
|
||||
speaker1_name = gr.Textbox(label="Speaker 1 Name")
|
||||
ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
|
||||
ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
|
||||
|
||||
|
||||
speaker2_name = gr.Textbox(label="Speaker 2 Name")
|
||||
ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
|
||||
ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
|
||||
|
||||
script_input = gr.Textbox(label="Podcast Script", lines=10,
|
||||
placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...")
|
||||
|
||||
podcast_model_choice = gr.Radio(
|
||||
choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
|
||||
|
||||
script_input = gr.Textbox(
|
||||
label="Podcast Script",
|
||||
lines=10,
|
||||
placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...",
|
||||
)
|
||||
|
||||
podcast_model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
|
||||
podcast_remove_silence = gr.Checkbox(
|
||||
label="Remove Silences",
|
||||
value=True,
|
||||
@@ -230,8 +250,12 @@ with gr.Blocks() as app_podcast:
|
||||
generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
|
||||
podcast_output = gr.Audio(label="Generated Podcast")
|
||||
|
||||
def podcast_generation(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence):
|
||||
return generate_podcast(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence)
|
||||
def podcast_generation(
|
||||
script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence
|
||||
):
|
||||
return generate_podcast(
|
||||
script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence
|
||||
)
|
||||
|
||||
generate_podcast_btn.click(
|
||||
podcast_generation,
|
||||
@@ -249,23 +273,24 @@ with gr.Blocks() as app_podcast:
|
||||
outputs=podcast_output,
|
||||
)
|
||||
|
||||
|
||||
def parse_emotional_text(gen_text):
|
||||
# Pattern to find (Emotion)
|
||||
pattern = r'\((.*?)\)'
|
||||
pattern = r"\((.*?)\)"
|
||||
|
||||
# Split the text by the pattern
|
||||
tokens = re.split(pattern, gen_text)
|
||||
|
||||
segments = []
|
||||
|
||||
current_emotion = 'Regular'
|
||||
current_emotion = "Regular"
|
||||
|
||||
for i in range(len(tokens)):
|
||||
if i % 2 == 0:
|
||||
# This is text
|
||||
text = tokens[i].strip()
|
||||
if text:
|
||||
segments.append({'emotion': current_emotion, 'text': text})
|
||||
segments.append({"emotion": current_emotion, "text": text})
|
||||
else:
|
||||
# This is emotion
|
||||
emotion = tokens[i].strip()
|
||||
@@ -273,6 +298,7 @@ def parse_emotional_text(gen_text):
|
||||
|
||||
return segments
|
||||
|
||||
|
||||
with gr.Blocks() as app_emotional:
|
||||
# New section for emotional generation
|
||||
gr.Markdown(
|
||||
@@ -287,13 +313,15 @@ with gr.Blocks() as app_emotional:
|
||||
"""
|
||||
)
|
||||
|
||||
gr.Markdown("Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button.")
|
||||
gr.Markdown(
|
||||
"Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button."
|
||||
)
|
||||
|
||||
# Regular speech type (mandatory)
|
||||
with gr.Row():
|
||||
regular_name = gr.Textbox(value='Regular', label='Speech Type Name', interactive=False)
|
||||
regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath')
|
||||
regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2)
|
||||
regular_name = gr.Textbox(value="Regular", label="Speech Type Name", interactive=False)
|
||||
regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
|
||||
regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2)
|
||||
|
||||
# Additional speech types (up to 99 more)
|
||||
max_speech_types = 100
|
||||
@@ -304,9 +332,9 @@ with gr.Blocks() as app_emotional:
|
||||
|
||||
for i in range(max_speech_types - 1):
|
||||
with gr.Row():
|
||||
name_input = gr.Textbox(label='Speech Type Name', visible=False)
|
||||
audio_input = gr.Audio(label='Reference Audio', type='filepath', visible=False)
|
||||
ref_text_input = gr.Textbox(label='Reference Text', lines=2, visible=False)
|
||||
name_input = gr.Textbox(label="Speech Type Name", visible=False)
|
||||
audio_input = gr.Audio(label="Reference Audio", type="filepath", visible=False)
|
||||
ref_text_input = gr.Textbox(label="Reference Text", lines=2, visible=False)
|
||||
delete_btn = gr.Button("Delete", variant="secondary", visible=False)
|
||||
speech_type_names.append(name_input)
|
||||
speech_type_audios.append(audio_input)
|
||||
@@ -351,7 +379,11 @@ with gr.Blocks() as app_emotional:
|
||||
add_speech_type_btn.click(
|
||||
add_speech_type_fn,
|
||||
inputs=speech_type_count,
|
||||
outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
|
||||
outputs=[speech_type_count]
|
||||
+ speech_type_names
|
||||
+ speech_type_audios
|
||||
+ speech_type_ref_texts
|
||||
+ speech_type_delete_btns,
|
||||
)
|
||||
|
||||
# Function to delete a speech type
|
||||
@@ -365,9 +397,9 @@ with gr.Blocks() as app_emotional:
|
||||
|
||||
for i in range(max_speech_types - 1):
|
||||
if i == index:
|
||||
name_updates.append(gr.update(visible=False, value=''))
|
||||
name_updates.append(gr.update(visible=False, value=""))
|
||||
audio_updates.append(gr.update(visible=False, value=None))
|
||||
ref_text_updates.append(gr.update(visible=False, value=''))
|
||||
ref_text_updates.append(gr.update(visible=False, value=""))
|
||||
delete_btn_updates.append(gr.update(visible=False))
|
||||
else:
|
||||
name_updates.append(gr.update())
|
||||
@@ -386,16 +418,18 @@ with gr.Blocks() as app_emotional:
|
||||
delete_btn.click(
|
||||
delete_fn,
|
||||
inputs=speech_type_count,
|
||||
outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
|
||||
outputs=[speech_type_count]
|
||||
+ speech_type_names
|
||||
+ speech_type_audios
|
||||
+ speech_type_ref_texts
|
||||
+ speech_type_delete_btns,
|
||||
)
|
||||
|
||||
# Text input for the prompt
|
||||
gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
|
||||
|
||||
# Model choice
|
||||
model_choice_emotional = gr.Radio(
|
||||
choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
|
||||
)
|
||||
model_choice_emotional = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
|
||||
|
||||
with gr.Accordion("Advanced Settings", open=False):
|
||||
remove_silence_emotional = gr.Checkbox(
|
||||
@@ -408,6 +442,7 @@ with gr.Blocks() as app_emotional:
|
||||
|
||||
# Output audio
|
||||
audio_output_emotional = gr.Audio(label="Synthesized Audio")
|
||||
|
||||
@gpu_decorator
|
||||
def generate_emotional_speech(
|
||||
regular_audio,
|
||||
@@ -417,37 +452,39 @@ with gr.Blocks() as app_emotional:
|
||||
):
|
||||
num_additional_speech_types = max_speech_types - 1
|
||||
speech_type_names_list = args[:num_additional_speech_types]
|
||||
speech_type_audios_list = args[num_additional_speech_types:2 * num_additional_speech_types]
|
||||
speech_type_ref_texts_list = args[2 * num_additional_speech_types:3 * num_additional_speech_types]
|
||||
speech_type_audios_list = args[num_additional_speech_types : 2 * num_additional_speech_types]
|
||||
speech_type_ref_texts_list = args[2 * num_additional_speech_types : 3 * num_additional_speech_types]
|
||||
model_choice = args[3 * num_additional_speech_types]
|
||||
remove_silence = args[3 * num_additional_speech_types + 1]
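A small sketch (hypothetical values, with `max_speech_types` reduced to 4 so there are 3 additional speech types) of how this flat `*args` layout unpacks:

```python
num_additional_speech_types = 3
args = (
    "Happy", "Sad", "",            # speech type names
    "happy.wav", "sad.wav", None,  # reference audios
    "ref A", "ref B", "",          # reference texts
    "F5-TTS",                      # model choice
    True,                          # remove_silence
)

speech_type_names_list = args[:num_additional_speech_types]
speech_type_audios_list = args[num_additional_speech_types : 2 * num_additional_speech_types]
speech_type_ref_texts_list = args[2 * num_additional_speech_types : 3 * num_additional_speech_types]
model_choice = args[3 * num_additional_speech_types]
remove_silence = args[3 * num_additional_speech_types + 1]
print(speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list, model_choice, remove_silence)
```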
|
||||
|
||||
# Collect the speech types and their audios into a dict
|
||||
speech_types = {'Regular': {'audio': regular_audio, 'ref_text': regular_ref_text}}
|
||||
speech_types = {"Regular": {"audio": regular_audio, "ref_text": regular_ref_text}}
|
||||
|
||||
for name_input, audio_input, ref_text_input in zip(speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list):
|
||||
for name_input, audio_input, ref_text_input in zip(
|
||||
speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
|
||||
):
|
||||
if name_input and audio_input:
|
||||
speech_types[name_input] = {'audio': audio_input, 'ref_text': ref_text_input}
|
||||
speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input}
|
||||
|
||||
# Parse the gen_text into segments
|
||||
segments = parse_speechtypes_text(gen_text)
|
||||
|
||||
# For each segment, generate speech
|
||||
generated_audio_segments = []
|
||||
current_emotion = 'Regular'
|
||||
current_emotion = "Regular"
|
||||
|
||||
for segment in segments:
|
||||
emotion = segment['emotion']
|
||||
text = segment['text']
|
||||
emotion = segment["emotion"]
|
||||
text = segment["text"]
|
||||
|
||||
if emotion in speech_types:
|
||||
current_emotion = emotion
|
||||
else:
|
||||
# If emotion not available, default to Regular
|
||||
current_emotion = 'Regular'
|
||||
current_emotion = "Regular"
|
||||
|
||||
ref_audio = speech_types[current_emotion]['audio']
|
||||
ref_text = speech_types[current_emotion].get('ref_text', '')
|
||||
ref_audio = speech_types[current_emotion]["audio"]
|
||||
ref_text = speech_types[current_emotion].get("ref_text", "")
|
||||
|
||||
# Generate speech for this segment
|
||||
audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, 0)
|
||||
@@ -469,7 +506,11 @@ with gr.Blocks() as app_emotional:
|
||||
regular_audio,
|
||||
regular_ref_text,
|
||||
gen_text_input_emotional,
|
||||
] + speech_type_names + speech_type_audios + speech_type_ref_texts + [
|
||||
]
|
||||
+ speech_type_names
|
||||
+ speech_type_audios
|
||||
+ speech_type_ref_texts
|
||||
+ [
|
||||
model_choice_emotional,
|
||||
remove_silence_emotional,
|
||||
],
|
||||
@@ -477,11 +518,7 @@ with gr.Blocks() as app_emotional:
|
||||
)
|
||||
|
||||
# Validation function to disable Generate button if speech types are missing
|
||||
def validate_speech_types(
|
||||
gen_text,
|
||||
regular_name,
|
||||
*args
|
||||
):
|
||||
def validate_speech_types(gen_text, regular_name, *args):
|
||||
num_additional_speech_types = max_speech_types - 1
|
||||
speech_type_names_list = args[:num_additional_speech_types]
|
||||
|
||||
@@ -495,7 +532,7 @@ with gr.Blocks() as app_emotional:
|
||||
|
||||
# Parse the gen_text to get the speech types used
|
||||
segments = parse_emotional_text(gen_text)
|
||||
speech_types_in_text = set(segment['emotion'] for segment in segments)
|
||||
speech_types_in_text = set(segment["emotion"] for segment in segments)
|
||||
|
||||
# Check if all speech types in text are available
|
||||
missing_speech_types = speech_types_in_text - speech_types_available
|
||||
@@ -510,7 +547,7 @@ with gr.Blocks() as app_emotional:
|
||||
gen_text_input_emotional.change(
|
||||
validate_speech_types,
|
||||
inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
|
||||
outputs=generate_emotional_btn
|
||||
outputs=generate_emotional_btn,
|
||||
)
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown(
|
||||
@@ -531,6 +568,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
|
||||
)
|
||||
gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
|
||||
@click.option("--host", "-H", default=None, help="Host to run the app on")
|
||||
@@ -544,10 +582,8 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
|
||||
@click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
|
||||
def main(port, host, share, api):
|
||||
global app
|
||||
print(f"Starting app...")
|
||||
app.queue(api_open=api).launch(
|
||||
server_name=host, server_port=port, share=share, show_api=api
|
||||
)
|
||||
print("Starting app...")
|
||||
app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -44,19 +44,8 @@ parser.add_argument(
|
||||
"--vocab_file",
|
||||
help="The vocab .txt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--ref_audio",
|
||||
type=str,
|
||||
help="Reference audio file < 15 seconds."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--ref_text",
|
||||
type=str,
|
||||
default="666",
|
||||
help="Subtitle for the reference audio."
|
||||
)
|
||||
parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
|
||||
parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--gen_text",
|
||||
@@ -99,8 +88,8 @@ model = args.model if args.model else config["model"]
|
||||
ckpt_file = args.ckpt_file if args.ckpt_file else ""
|
||||
vocab_file = args.vocab_file if args.vocab_file else ""
|
||||
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
|
||||
wave_path = Path(output_dir)/"out.wav"
|
||||
spectrogram_path = Path(output_dir)/"out.png"
|
||||
wave_path = Path(output_dir) / "out.wav"
|
||||
spectrogram_path = Path(output_dir) / "out.png"
|
||||
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
||||
|
||||
vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
|
||||
@@ -110,44 +99,46 @@ vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_loc
|
||||
if model == "F5-TTS":
|
||||
model_cls = DiT
|
||||
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
||||
if ckpt_file == "":
|
||||
repo_name= "F5-TTS"
|
||||
if ckpt_file == "":
|
||||
repo_name = "F5-TTS"
|
||||
exp_name = "F5TTS_Base"
|
||||
ckpt_step= 1200000
|
||||
ckpt_step = 1200000
|
||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
||||
|
||||
elif model == "E2-TTS":
|
||||
model_cls = UNetT
|
||||
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
||||
if ckpt_file == "":
|
||||
repo_name= "E2-TTS"
|
||||
if ckpt_file == "":
|
||||
repo_name = "E2-TTS"
|
||||
exp_name = "E2TTS_Base"
|
||||
ckpt_step= 1200000
|
||||
ckpt_step = 1200000
|
||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
||||
|
||||
print(f"Using {model}...")
|
||||
ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
|
||||
|
||||
|
||||
|
||||
def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
|
||||
main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
|
||||
main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
|
||||
if "voices" not in config:
|
||||
voices = {"main": main_voice}
|
||||
else:
|
||||
voices = config["voices"]
|
||||
voices["main"] = main_voice
|
||||
for voice in voices:
|
||||
voices[voice]['ref_audio'], voices[voice]['ref_text'] = preprocess_ref_audio_text(voices[voice]['ref_audio'], voices[voice]['ref_text'])
|
||||
voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
|
||||
voices[voice]["ref_audio"], voices[voice]["ref_text"]
|
||||
)
|
||||
print("Voice:", voice)
|
||||
print("Ref_audio:", voices[voice]['ref_audio'])
|
||||
print("Ref_text:", voices[voice]['ref_text'])
|
||||
print("Ref_audio:", voices[voice]["ref_audio"])
|
||||
print("Ref_text:", voices[voice]["ref_text"])
|
||||
|
||||
generated_audio_segments = []
|
||||
reg1 = r'(?=\[\w+\])'
|
||||
reg1 = r"(?=\[\w+\])"
|
||||
chunks = re.split(reg1, text_gen)
|
||||
reg2 = r'\[(\w+)\]'
|
||||
reg2 = r"\[(\w+)\]"
|
||||
for text in chunks:
|
||||
match = re.match(reg2, text)
|
||||
if match:
|
||||
@@ -160,8 +151,8 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
|
||||
voice = "main"
|
||||
text = re.sub(reg2, "", text)
|
||||
gen_text = text.strip()
|
||||
ref_audio = voices[voice]['ref_audio']
|
||||
ref_text = voices[voice]['ref_text']
|
||||
ref_audio = voices[voice]["ref_audio"]
|
||||
ref_text = voices[voice]["ref_text"]
|
||||
print(f"Voice: {voice}")
|
||||
audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
|
||||
generated_audio_segments.append(audio)
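A minimal illustration (hypothetical text) of how the `reg1`/`reg2` patterns above chunk a multi-voice script and recover the voice name for each chunk:

```python
import re

text_gen = "[main]Hello there. [town]Good morning! [main]See you soon."

reg1 = r"(?=\[\w+\])"  # split just before every [voice] tag (lookahead keeps the tag)
reg2 = r"\[(\w+)\]"    # capture the voice name from the tag

for chunk in re.split(reg1, text_gen):
    if not chunk.strip():
        continue
    match = re.match(reg2, chunk)
    voice = match[1] if match else "main"
    gen_text = re.sub(reg2, "", chunk).strip()
    print(voice, "->", gen_text)
```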
|
||||
|
||||
@@ -5,3 +5,6 @@ from model.backbones.dit import DiT
|
||||
from model.backbones.mmdit import MMDiT
|
||||
|
||||
from model.trainer import Trainer
|
||||
|
||||
|
||||
__all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
|
||||
|
||||
@@ -21,14 +21,16 @@ from model.modules import (
|
||||
ConvPositionEmbedding,
|
||||
DiTBlock,
|
||||
AdaLayerNormZero_Final,
|
||||
precompute_freqs_cis, get_pos_embed_indices,
|
||||
precompute_freqs_cis,
|
||||
get_pos_embed_indices,
|
||||
)
|
||||
|
||||
|
||||
# Text embedding
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
|
||||
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
|
||||
super().__init__()
|
||||
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
||||
|
||||
@@ -36,20 +38,22 @@ class TextEmbedding(nn.Module):
|
||||
self.extra_modeling = True
|
||||
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
||||
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
|
||||
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
|
||||
self.text_blocks = nn.Sequential(
|
||||
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
|
||||
)
|
||||
else:
|
||||
self.extra_modeling = False
|
||||
|
||||
def forward(self, text: int['b nt'], seq_len, drop_text = False):
|
||||
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
||||
batch, text_len = text.shape[0], text.shape[1]
|
||||
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
||||
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||
text = F.pad(text, (0, seq_len - text_len), value = 0)
|
||||
text = F.pad(text, (0, seq_len - text_len), value=0)
|
||||
|
||||
if drop_text: # cfg for text
|
||||
text = torch.zeros_like(text)
|
||||
|
||||
text = self.text_embed(text) # b n -> b n d
|
||||
text = self.text_embed(text) # b n -> b n d
|
||||
|
||||
# possible extra modeling
|
||||
if self.extra_modeling:
|
||||
@@ -67,88 +71,91 @@ class TextEmbedding(nn.Module):
|
||||
|
||||
# noised input audio and context mixing embedding
|
||||
|
||||
|
||||
class InputEmbedding(nn.Module):
|
||||
def __init__(self, mel_dim, text_dim, out_dim):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
|
||||
self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
|
||||
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
|
||||
|
||||
def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
|
||||
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
|
||||
if drop_audio_cond: # cfg for cond audio
|
||||
cond = torch.zeros_like(cond)
|
||||
|
||||
x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
|
||||
x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
|
||||
x = self.conv_pos_embed(x) + x
|
||||
return x
|
||||
|
||||
|
||||
|
||||
# Transformer backbone using DiT blocks
|
||||
|
||||
|
||||
class DiT(nn.Module):
|
||||
def __init__(self, *,
|
||||
dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
|
||||
mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
|
||||
long_skip_connection = False,
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
depth=8,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.1,
|
||||
ff_mult=4,
|
||||
mel_dim=100,
|
||||
text_num_embeds=256,
|
||||
text_dim=None,
|
||||
conv_layers=0,
|
||||
long_skip_connection=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.time_embed = TimestepEmbedding(dim)
|
||||
if text_dim is None:
|
||||
text_dim = mel_dim
|
||||
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
|
||||
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
|
||||
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
||||
|
||||
self.rotary_embed = RotaryEmbedding(dim_head)
|
||||
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
DiTBlock(
|
||||
dim = dim,
|
||||
heads = heads,
|
||||
dim_head = dim_head,
|
||||
ff_mult = ff_mult,
|
||||
dropout = dropout
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
[DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
|
||||
)
|
||||
self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None
|
||||
|
||||
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
|
||||
|
||||
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
||||
self.proj_out = nn.Linear(dim, mel_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float['b n d'], # nosied input audio
|
||||
cond: float['b n d'], # masked cond audio
|
||||
text: int['b nt'], # text
|
||||
time: float['b'] | float[''], # time step
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
drop_audio_cond, # cfg for cond audio
|
||||
drop_text, # cfg for text
|
||||
mask: bool['b n'] | None = None,
|
||||
drop_text, # cfg for text
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
):
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if time.ndim == 0:
|
||||
time = time.repeat(batch)
|
||||
|
||||
|
||||
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
|
||||
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||
|
||||
if self.long_skip_connection is not None:
|
||||
residual = x
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, t, mask = mask, rope = rope)
|
||||
x = block(x, t, mask=mask, rope=rope)
|
||||
|
||||
if self.long_skip_connection is not None:
|
||||
x = self.long_skip_connection(torch.cat((x, residual), dim = -1))
|
||||
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
|
||||
|
||||
x = self.norm_out(x, t)
|
||||
output = self.proj_out(x)
|
||||
|
||||
@@ -19,12 +19,14 @@ from model.modules import (
|
||||
ConvPositionEmbedding,
|
||||
MMDiTBlock,
|
||||
AdaLayerNormZero_Final,
|
||||
precompute_freqs_cis, get_pos_embed_indices,
|
||||
precompute_freqs_cis,
|
||||
get_pos_embed_indices,
|
||||
)
|
||||
|
||||
|
||||
# text embedding
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, out_dim, text_num_embeds):
|
||||
super().__init__()
|
||||
@@ -33,7 +35,7 @@ class TextEmbedding(nn.Module):
|
||||
self.precompute_max_pos = 1024
|
||||
self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
|
||||
|
||||
def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']:
|
||||
def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
|
||||
text = text + 1
|
||||
if drop_text:
|
||||
text = torch.zeros_like(text)
|
||||
@@ -52,27 +54,37 @@ class TextEmbedding(nn.Module):
|
||||
|
||||
# noised input & masked cond audio embedding
|
||||
|
||||
|
||||
class AudioEmbedding(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(2 * in_dim, out_dim)
|
||||
self.conv_pos_embed = ConvPositionEmbedding(out_dim)
|
||||
|
||||
def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False):
|
||||
def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
|
||||
if drop_audio_cond:
|
||||
cond = torch.zeros_like(cond)
|
||||
x = torch.cat((x, cond), dim = -1)
|
||||
x = torch.cat((x, cond), dim=-1)
|
||||
x = self.linear(x)
|
||||
x = self.conv_pos_embed(x) + x
|
||||
return x
|
||||
|
||||
|
||||
|
||||
# Transformer backbone using MM-DiT blocks
|
||||
|
||||
|
||||
class MMDiT(nn.Module):
|
||||
def __init__(self, *,
|
||||
dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
|
||||
text_num_embeds = 256, mel_dim = 100,
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
depth=8,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.1,
|
||||
ff_mult=4,
|
||||
text_num_embeds=256,
|
||||
mel_dim=100,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -84,16 +96,16 @@ class MMDiT(nn.Module):
|
||||
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
MMDiTBlock(
|
||||
dim = dim,
|
||||
heads = heads,
|
||||
dim_head = dim_head,
|
||||
dropout = dropout,
|
||||
ff_mult = ff_mult,
|
||||
context_pre_only = i == depth - 1,
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
dropout=dropout,
|
||||
ff_mult=ff_mult,
|
||||
context_pre_only=i == depth - 1,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
@@ -103,13 +115,13 @@ class MMDiT(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float['b n d'], # nosied input audio
|
||||
cond: float['b n d'], # masked cond audio
|
||||
text: int['b nt'], # text
|
||||
time: float['b'] | float[''], # time step
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
drop_audio_cond, # cfg for cond audio
|
||||
drop_text, # cfg for text
|
||||
mask: bool['b n'] | None = None,
|
||||
drop_text, # cfg for text
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
):
|
||||
batch = x.shape[0]
|
||||
if time.ndim == 0:
|
||||
@@ -117,16 +129,16 @@ class MMDiT(nn.Module):
|
||||
|
||||
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
c = self.text_embed(text, drop_text = drop_text)
|
||||
x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond)
|
||||
c = self.text_embed(text, drop_text=drop_text)
|
||||
x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
seq_len = x.shape[1]
|
||||
text_len = text.shape[1]
|
||||
rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||
rope_text = self.rotary_embed.forward_from_seq_len(text_len)
|
||||
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text)
|
||||
c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
|
||||
|
||||
x = self.norm_out(x, t)
|
||||
output = self.proj_out(x)
|
||||
|
||||
@@ -24,14 +24,16 @@ from model.modules import (
|
||||
Attention,
|
||||
AttnProcessor,
|
||||
FeedForward,
|
||||
precompute_freqs_cis, get_pos_embed_indices,
|
||||
precompute_freqs_cis,
|
||||
get_pos_embed_indices,
|
||||
)
|
||||
|
||||
|
||||
# Text embedding
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
|
||||
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
|
||||
super().__init__()
|
||||
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
||||
|
||||
@@ -39,20 +41,22 @@ class TextEmbedding(nn.Module):
|
||||
self.extra_modeling = True
|
||||
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
||||
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
|
||||
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
|
||||
self.text_blocks = nn.Sequential(
|
||||
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
|
||||
)
|
||||
else:
|
||||
self.extra_modeling = False
|
||||
|
||||
def forward(self, text: int['b nt'], seq_len, drop_text = False):
|
||||
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
||||
batch, text_len = text.shape[0], text.shape[1]
|
||||
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
||||
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||
text = F.pad(text, (0, seq_len - text_len), value = 0)
|
||||
text = F.pad(text, (0, seq_len - text_len), value=0)
|
||||
|
||||
if drop_text: # cfg for text
|
||||
text = torch.zeros_like(text)
|
||||
|
||||
text = self.text_embed(text) # b n -> b n d
|
||||
text = self.text_embed(text) # b n -> b n d
|
||||
|
||||
# possible extra modeling
|
||||
if self.extra_modeling:
|
||||
@@ -70,28 +74,40 @@ class TextEmbedding(nn.Module):
|
||||
|
||||
# noised input audio and context mixing embedding
|
||||
|
||||
|
||||
class InputEmbedding(nn.Module):
|
||||
def __init__(self, mel_dim, text_dim, out_dim):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
|
||||
self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
|
||||
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
|
||||
|
||||
def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
|
||||
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
|
||||
if drop_audio_cond: # cfg for cond audio
|
||||
cond = torch.zeros_like(cond)
|
||||
|
||||
x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
|
||||
x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
|
||||
x = self.conv_pos_embed(x) + x
|
||||
return x
|
||||
|
||||
|
||||
# Flat UNet Transformer backbone
|
||||
|
||||
|
||||
class UNetT(nn.Module):
|
||||
def __init__(self, *,
|
||||
dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
|
||||
mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
|
||||
skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
depth=8,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.1,
|
||||
ff_mult=4,
|
||||
mel_dim=100,
|
||||
text_num_embeds=256,
|
||||
text_dim=None,
|
||||
conv_layers=0,
|
||||
skip_connect_type: Literal["add", "concat", "none"] = "concat",
|
||||
):
|
||||
super().__init__()
|
||||
assert depth % 2 == 0, "UNet-Transformer's depth should be even."
|
||||
@@ -99,7 +115,7 @@ class UNetT(nn.Module):
|
||||
self.time_embed = TimestepEmbedding(dim)
|
||||
if text_dim is None:
|
||||
text_dim = mel_dim
|
||||
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
|
||||
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
|
||||
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
||||
|
||||
self.rotary_embed = RotaryEmbedding(dim_head)
|
||||
@@ -108,7 +124,7 @@ class UNetT(nn.Module):
|
||||
|
||||
self.dim = dim
|
||||
self.skip_connect_type = skip_connect_type
|
||||
needs_skip_proj = skip_connect_type == 'concat'
|
||||
needs_skip_proj = skip_connect_type == "concat"
|
||||
|
||||
self.depth = depth
|
||||
self.layers = nn.ModuleList([])
|
||||
@@ -118,53 +134,57 @@ class UNetT(nn.Module):
|
||||
|
||||
attn_norm = RMSNorm(dim)
|
||||
attn = Attention(
|
||||
processor = AttnProcessor(),
|
||||
dim = dim,
|
||||
heads = heads,
|
||||
dim_head = dim_head,
|
||||
dropout = dropout,
|
||||
)
|
||||
processor=AttnProcessor(),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
ff_norm = RMSNorm(dim)
|
||||
ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
|
||||
ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||
|
||||
skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None
|
||||
skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
skip_proj,
|
||||
attn_norm,
|
||||
attn,
|
||||
ff_norm,
|
||||
ff,
|
||||
]))
|
||||
self.layers.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
skip_proj,
|
||||
attn_norm,
|
||||
attn,
|
||||
ff_norm,
|
||||
ff,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
self.norm_out = RMSNorm(dim)
|
||||
self.proj_out = nn.Linear(dim, mel_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float['b n d'], # nosied input audio
|
||||
cond: float['b n d'], # masked cond audio
|
||||
text: int['b nt'], # text
|
||||
time: float['b'] | float[''], # time step
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
drop_audio_cond, # cfg for cond audio
|
||||
drop_text, # cfg for text
|
||||
mask: bool['b n'] | None = None,
|
||||
drop_text, # cfg for text
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
):
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if time.ndim == 0:
|
||||
time = time.repeat(batch)
|
||||
|
||||
|
||||
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
# postfix time t to input x, [b n d] -> [b n+1 d]
|
||||
x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
|
||||
if mask is not None:
|
||||
mask = F.pad(mask, (1, 0), value=1)
|
||||
|
||||
|
||||
rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
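A small shape sketch (arbitrary sizes) of the packing above: the time embedding is packed in as an extra leading token, so the sequence grows from n to n + 1 and the padding mask is extended by one valid position.

```python
import torch
import torch.nn.functional as F

b, n, d = 2, 5, 8
x = torch.randn(b, n, d)      # noised input features
t = torch.randn(b, d)         # time embedding, one vector per batch item
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)

x = torch.cat([t.unsqueeze(1), x], dim=1)  # [b, n, d] -> [b, n+1, d]
mask = F.pad(mask, (1, 0), value=1)        # keep the new time token unmasked
print(x.shape, mask.shape)  # torch.Size([2, 6, 8]) torch.Size([2, 6])
```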
|
||||
|
||||
# flat unet transformer
|
||||
@@ -182,14 +202,14 @@ class UNetT(nn.Module):
|
||||
|
||||
if is_later_half:
|
||||
skip = skips.pop()
|
||||
if skip_connect_type == 'concat':
|
||||
x = torch.cat((x, skip), dim = -1)
|
||||
if skip_connect_type == "concat":
|
||||
x = torch.cat((x, skip), dim=-1)
|
||||
x = maybe_skip_proj(x)
|
||||
elif skip_connect_type == 'add':
|
||||
elif skip_connect_type == "add":
|
||||
x = x + skip
|
||||
|
||||
# attention and feedforward blocks
|
||||
x = attn(attn_norm(x), rope = rope, mask = mask) + x
|
||||
x = attn(attn_norm(x), rope=rope, mask=mask) + x
|
||||
x = ff(ff_norm(x)) + x
|
||||
|
||||
assert len(skips) == 0
|
||||
|
||||
126  model/cfm.py
@@ -20,29 +20,32 @@ from torchdiffeq import odeint
|
||||
|
||||
from model.modules import MelSpec
|
||||
from model.utils import (
|
||||
default, exists,
|
||||
list_str_to_idx, list_str_to_tensor,
|
||||
lens_to_mask, mask_from_frac_lengths,
|
||||
)
|
||||
default,
|
||||
exists,
|
||||
list_str_to_idx,
|
||||
list_str_to_tensor,
|
||||
lens_to_mask,
|
||||
mask_from_frac_lengths,
|
||||
)
|
||||
|
||||
|
||||
class CFM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
transformer: nn.Module,
|
||||
sigma = 0.,
|
||||
sigma=0.0,
|
||||
odeint_kwargs: dict = dict(
|
||||
# atol = 1e-5,
|
||||
# rtol = 1e-5,
|
||||
method = 'euler' # 'midpoint'
|
||||
method="euler" # 'midpoint'
|
||||
),
|
||||
audio_drop_prob = 0.3,
|
||||
cond_drop_prob = 0.2,
|
||||
num_channels = None,
|
||||
audio_drop_prob=0.3,
|
||||
cond_drop_prob=0.2,
|
||||
num_channels=None,
|
||||
mel_spec_module: nn.Module | None = None,
|
||||
mel_spec_kwargs: dict = dict(),
|
||||
frac_lengths_mask: tuple[float, float] = (0.7, 1.),
|
||||
vocab_char_map: dict[str: int] | None = None
|
||||
frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
|
||||
vocab_char_map: dict[str:int] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -78,21 +81,21 @@ class CFM(nn.Module):
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
cond: float['b n d'] | float['b nw'],
|
||||
text: int['b nt'] | list[str],
|
||||
duration: int | int['b'],
|
||||
cond: float["b n d"] | float["b nw"], # noqa: F722
|
||||
text: int["b nt"] | list[str], # noqa: F722
|
||||
duration: int | int["b"], # noqa: F821
|
||||
*,
|
||||
lens: int['b'] | None = None,
|
||||
steps = 32,
|
||||
cfg_strength = 1.,
|
||||
sway_sampling_coef = None,
|
||||
lens: int["b"] | None = None, # noqa: F821
|
||||
steps=32,
|
||||
cfg_strength=1.0,
|
||||
sway_sampling_coef=None,
|
||||
seed: int | None = None,
|
||||
max_duration = 4096,
|
||||
vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
|
||||
no_ref_audio = False,
|
||||
duplicate_test = False,
|
||||
t_inter = 0.1,
|
||||
edit_mask = None,
|
||||
max_duration=4096,
|
||||
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
|
||||
no_ref_audio=False,
|
||||
duplicate_test=False,
|
||||
t_inter=0.1,
|
||||
edit_mask=None,
|
||||
):
|
||||
self.eval()
|
||||
|
||||
@@ -108,7 +111,7 @@ class CFM(nn.Module):
|
||||
|
||||
batch, cond_seq_len, device = *cond.shape[:2], cond.device
|
||||
if not exists(lens):
|
||||
lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
|
||||
lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
|
||||
|
||||
# text
|
||||
|
||||
@@ -120,8 +123,8 @@ class CFM(nn.Module):
|
||||
assert text.shape[0] == batch
|
||||
|
||||
if exists(text):
|
||||
text_lens = (text != -1).sum(dim = -1)
|
||||
lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
|
||||
text_lens = (text != -1).sum(dim=-1)
|
||||
lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
|
||||
|
||||
# duration
|
||||
|
||||
@@ -130,20 +133,22 @@ class CFM(nn.Module):
|
||||
cond_mask = cond_mask & edit_mask
|
||||
|
||||
if isinstance(duration, int):
|
||||
duration = torch.full((batch,), duration, device = device, dtype = torch.long)
|
||||
duration = torch.full((batch,), duration, device=device, dtype=torch.long)
|
||||
|
||||
duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
|
||||
duration = duration.clamp(max = max_duration)
|
||||
duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
|
||||
duration = duration.clamp(max=max_duration)
|
||||
max_duration = duration.amax()
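A small worked example (hypothetical lengths, in mel frames) of the duration handling above: each item is forced to generate at least one new frame beyond its reference, then capped at `max_duration`.

```python
import torch

lens = torch.tensor([220, 300])        # reference lengths per batch item
duration = torch.tensor([200, 4500])   # requested total durations

duration = torch.maximum(lens + 1, duration)  # at least one generated frame
duration = duration.clamp(max=4096)           # cap at max_duration
print(duration.tolist(), int(duration.amax()))  # [221, 4096] 4096
```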
|
||||
|
||||
|
||||
# duplicate test corner for inner time step observation
|
||||
if duplicate_test:
|
||||
test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.)
|
||||
|
||||
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
|
||||
cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
|
||||
test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
|
||||
|
||||
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
|
||||
cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
|
||||
cond_mask = cond_mask.unsqueeze(-1)
|
||||
step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
|
||||
step_cond = torch.where(
|
||||
cond_mask, cond, torch.zeros_like(cond)
|
||||
) # allow direct control (cut cond audio) with lens passed in
|
||||
|
||||
if batch > 1:
|
||||
mask = lens_to_mask(duration)
|
||||
@@ -161,11 +166,15 @@ class CFM(nn.Module):
|
||||
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
|
||||
|
||||
# predict flow
|
||||
pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
|
||||
pred = self.transformer(
|
||||
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
|
||||
)
|
||||
if cfg_strength < 1e-5:
|
||||
return pred
|
||||
|
||||
null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
|
||||
|
||||
null_pred = self.transformer(
|
||||
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
|
||||
)
|
||||
return pred + (pred - null_pred) * cfg_strength
|
||||
|
||||
# noise input
|
||||
@@ -175,8 +184,8 @@ class CFM(nn.Module):
|
||||
for dur in duration:
|
||||
if exists(seed):
|
||||
torch.manual_seed(seed)
|
||||
y0.append(torch.randn(dur, self.num_channels, device = self.device, dtype=step_cond.dtype))
|
||||
y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
|
||||
y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
|
||||
y0 = pad_sequence(y0, padding_value=0, batch_first=True)
|
||||
|
||||
t_start = 0
|
||||
|
||||
@@ -186,12 +195,12 @@ class CFM(nn.Module):
|
||||
y0 = (1 - t_start) * y0 + t_start * test_cond
|
||||
steps = int(steps * (1 - t_start))
|
||||
|
||||
t = torch.linspace(t_start, 1, steps, device = self.device, dtype=step_cond.dtype)
|
||||
t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
|
||||
if sway_sampling_coef is not None:
|
||||
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
||||
|
||||
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
|
||||
|
||||
|
||||
sampled = trajectory[-1]
|
||||
out = sampled
|
||||
out = torch.where(cond_mask, cond, out)
|
||||
@@ -204,10 +213,10 @@ class CFM(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inp: float['b n d'] | float['b nw'], # mel or raw wave
|
||||
text: int['b nt'] | list[str],
|
||||
inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
|
||||
text: int["b nt"] | list[str], # noqa: F722
|
||||
*,
|
||||
lens: int['b'] | None = None,
|
||||
lens: int["b"] | None = None, # noqa: F821
|
||||
noise_scheduler: str | None = None,
|
||||
):
|
||||
# handle raw wave
|
||||
@@ -216,7 +225,7 @@ class CFM(nn.Module):
|
||||
inp = inp.permute(0, 2, 1)
|
||||
assert inp.shape[-1] == self.num_channels
|
||||
|
||||
batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
|
||||
batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
|
||||
|
||||
# handle text as string
|
||||
if isinstance(text, list):
|
||||
@@ -228,12 +237,12 @@ class CFM(nn.Module):
|
||||
|
||||
# lens and mask
|
||||
if not exists(lens):
|
||||
lens = torch.full((batch,), seq_len, device = device)
|
||||
|
||||
mask = lens_to_mask(lens, length = seq_len) # useless here, as collate_fn will pad to max length in batch
|
||||
lens = torch.full((batch,), seq_len, device=device)
|
||||
|
||||
mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
|
||||
|
||||
# get a random span to mask out for training conditionally
|
||||
frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
|
||||
frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
|
||||
rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
|
||||
|
||||
if exists(mask):
|
||||
@@ -246,7 +255,7 @@ class CFM(nn.Module):
|
||||
x0 = torch.randn_like(x1)
|
||||
|
||||
# time step
|
||||
time = torch.rand((batch,), dtype = dtype, device = self.device)
|
||||
time = torch.rand((batch,), dtype=dtype, device=self.device)
|
||||
# TODO. noise_scheduler
|
||||
|
||||
# sample xt (φ_t(x) in the paper)
|
||||
@@ -255,10 +264,7 @@ class CFM(nn.Module):
|
||||
flow = x1 - x0
|
||||
|
||||
# only predict what is within the random mask span for infilling
|
||||
cond = torch.where(
|
||||
rand_span_mask[..., None],
|
||||
torch.zeros_like(x1), x1
|
||||
)
|
||||
cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
|
||||
|
||||
# transformer and cfg training with a drop rate
|
||||
drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
|
||||
@@ -267,13 +273,15 @@ class CFM(nn.Module):
|
||||
drop_text = True
|
||||
else:
|
||||
drop_text = False
|
||||
|
||||
|
||||
# if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
|
||||
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
|
||||
pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)
|
||||
pred = self.transformer(
|
||||
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
|
||||
)
|
||||
|
||||
# flow matching loss
|
||||
loss = F.mse_loss(pred, flow, reduction = 'none')
|
||||
loss = F.mse_loss(pred, flow, reduction="none")
|
||||
loss = loss[rand_span_mask]
|
||||
|
||||
return loss.mean(), cond, pred
|
||||
|
||||
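Two expressions above are easy to misread in diff form: the sway-sampling line warps the uniform ODE time grid toward t = 0, and the last line of the guidance branch is standard classifier-free guidance (the conditional prediction extrapolated away from the unconditional one by `cfg_strength`). A minimal sketch of the time-grid warp, assuming `steps=32` and a negative coefficient (both illustrative values, not taken from this commit):

```python
# Illustrative sketch only, not part of the commit.
import torch

steps = 32
sway_sampling_coef = -1.0  # assumed value for demonstration
t = torch.linspace(0, 1, steps)
t_sway = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
# Endpoints are unchanged (t_sway[0] == 0, t_sway[-1] == 1); with a negative
# coefficient the intermediate steps are pulled toward t = 0, spending more of
# the fixed step budget early in the flow.
print(t[:4], t_sway[:4])
```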
173
model/dataset.py
173
model/dataset.py
@@ -6,7 +6,7 @@ import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, Sampler
import torchaudio
from datasets import load_dataset, load_from_disk
from datasets import load_from_disk
from datasets import Dataset as Dataset_

from model.modules import MelSpec
@@ -16,53 +16,55 @@ class HFDataset(Dataset):
def __init__(
self,
hf_dataset: Dataset,
target_sample_rate = 24_000,
n_mel_channels = 100,
hop_length = 256,
target_sample_rate=24_000,
n_mel_channels=100,
hop_length=256,
):
self.data = hf_dataset
self.target_sample_rate = target_sample_rate
self.hop_length = hop_length
self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)

self.mel_spectrogram = MelSpec(
target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
)

def get_frame_len(self, index):
row = self.data[index]
audio = row['audio']['array']
sample_rate = row['audio']['sampling_rate']
audio = row["audio"]["array"]
sample_rate = row["audio"]["sampling_rate"]
return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length

def __len__(self):
return len(self.data)

def __getitem__(self, index):
row = self.data[index]
audio = row['audio']['array']
audio = row["audio"]["array"]

# logger.info(f"Audio shape: {audio.shape}")

sample_rate = row['audio']['sampling_rate']
sample_rate = row["audio"]["sampling_rate"]
duration = audio.shape[-1] / sample_rate

if duration > 30 or duration < 0.3:
return self.__getitem__((index + 1) % len(self.data))

audio_tensor = torch.from_numpy(audio).float()

if sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
audio_tensor = resampler(audio_tensor)

audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')

mel_spec = self.mel_spectrogram(audio_tensor)

mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'

text = row['text']

text = row["text"]

return dict(
mel_spec = mel_spec,
text = text,
mel_spec=mel_spec,
text=text,
)

@@ -70,11 +72,11 @@ class CustomDataset(Dataset):
def __init__(
self,
custom_dataset: Dataset,
durations = None,
target_sample_rate = 24_000,
hop_length = 256,
n_mel_channels = 100,
preprocessed_mel = False,
durations=None,
target_sample_rate=24_000,
hop_length=256,
n_mel_channels=100,
preprocessed_mel=False,
):
self.data = custom_dataset
self.durations = durations
@@ -82,16 +84,20 @@ class CustomDataset(Dataset):
self.hop_length = hop_length
self.preprocessed_mel = preprocessed_mel
if not preprocessed_mel:
self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels)
self.mel_spectrogram = MelSpec(
target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels
)

def get_frame_len(self, index):
if self.durations is not None: # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
if (
self.durations is not None
): # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
return self.durations[index] * self.target_sample_rate / self.hop_length
return self.data[index]["duration"] * self.target_sample_rate / self.hop_length

def __len__(self):
return len(self.data)

def __getitem__(self, index):
row = self.data[index]
audio_path = row["audio_path"]
@@ -108,45 +114,52 @@ class CustomDataset(Dataset):

if duration > 30 or duration < 0.3:
return self.__getitem__((index + 1) % len(self.data))

if source_sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
audio = resampler(audio)

mel_spec = self.mel_spectrogram(audio)
mel_spec = mel_spec.squeeze(0) # '1 d t -> d t')

return dict(
mel_spec = mel_spec,
text = text,
mel_spec=mel_spec,
text=text,
)


# Dynamic Batch Sampler

class DynamicBatchSampler(Sampler[list[int]]):
""" Extension of Sampler that will do the following:
1. Change the batch size (essentially number of sequences)
in a batch to ensure that the total number of frames are less
than a certain threshold.
2. Make sure the padding efficiency in the batch is high.
"""Extension of Sampler that will do the following:
1. Change the batch size (essentially number of sequences)
in a batch to ensure that the total number of frames are less
than a certain threshold.
2. Make sure the padding efficiency in the batch is high.
"""

def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
def __init__(
self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
):
self.sampler = sampler
self.frames_threshold = frames_threshold
self.max_samples = max_samples

indices, batches = [], []
data_source = self.sampler.data_source

for idx in tqdm(self.sampler, desc=f"Sorting with sampler... if slow, check whether dataset is provided with duration"):

for idx in tqdm(
self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
):
indices.append((idx, data_source.get_frame_len(idx)))
indices.sort(key=lambda elem : elem[1])
indices.sort(key=lambda elem: elem[1])

batch = []
batch_frames = 0
for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
for idx, frame_len in tqdm(
indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
):
if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
batch.append(idx)
batch_frames += frame_len
@@ -182,76 +195,86 @@ class DynamicBatchSampler(Sampler[list[int]]):

# Load dataset

def load_dataset(
dataset_name: str,
tokenizer: str = "pinyin",
dataset_type: str = "CustomDataset",
audio_type: str = "raw",
mel_spec_kwargs: dict = dict()
) -> CustomDataset | HFDataset:
'''
dataset_name: str,
tokenizer: str = "pinyin",
dataset_type: str = "CustomDataset",
audio_type: str = "raw",
mel_spec_kwargs: dict = dict(),
) -> CustomDataset | HFDataset:
"""
dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
- "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
'''

"""

print("Loading dataset ...")

if dataset_type == "CustomDataset":
if audio_type == "raw":
try:
train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
except:
except: # noqa: E722
train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
preprocessed_mel = False
elif audio_type == "mel":
train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
preprocessed_mel = True
with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
with open(f"data/{dataset_name}_{tokenizer}/duration.json", "r", encoding="utf-8") as f:
data_dict = json.load(f)
durations = data_dict["duration"]
train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)

train_dataset = CustomDataset(
train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
)

elif dataset_type == "CustomDatasetPath":
try:
train_dataset = load_from_disk(f"{dataset_name}/raw")
except:
except: # noqa: E722
train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")

with open(f"{dataset_name}/duration.json", 'r', encoding='utf-8') as f:

with open(f"{dataset_name}/duration.json", "r", encoding="utf-8") as f:
data_dict = json.load(f)
durations = data_dict["duration"]
train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)

train_dataset = CustomDataset(
train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
)

elif dataset_type == "HFDataset":
print("Should manually modify the path of huggingface dataset to your need.\n" +
"May also the corresponding script cuz different dataset may have different format.")
print(
"Should manually modify the path of huggingface dataset to your need.\n"
+ "May also the corresponding script cuz different dataset may have different format."
)
pre, post = dataset_name.split("_")
train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),)
train_dataset = HFDataset(
load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),
)

return train_dataset


# collation

def collate_fn(batch):
mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
max_mel_length = mel_lengths.amax()

padded_mel_specs = []
for spec in mel_specs: # TODO. maybe records mask for attention here
padding = (0, max_mel_length - spec.size(-1))
padded_spec = F.pad(spec, padding, value = 0)
padded_spec = F.pad(spec, padding, value=0)
padded_mel_specs.append(padded_spec)

mel_specs = torch.stack(padded_mel_specs)

text = [item['text'] for item in batch]
text = [item["text"] for item in batch]
text_lengths = torch.LongTensor([len(item) for item in text])

return dict(
mel = mel_specs,
mel_lengths = mel_lengths,
text = text,
text_lengths = text_lengths,
mel=mel_specs,
mel_lengths=mel_lengths,
text=text,
text_lengths=text_lengths,
)
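For reference, the batching rule that the DynamicBatchSampler docstring above describes reduces to a few lines: sort items by frame length, then greedily fill each batch until the frame threshold (or the optional `max_samples` cap) would be exceeded. A standalone sketch under those assumptions, with made-up frame lengths:

```python
# Illustrative sketch only, not part of the commit.
def make_batches(frame_lens, frames_threshold, max_samples=0):
    # Sort indices by frame length so similarly sized items share a batch (less padding).
    indices = sorted(range(len(frame_lens)), key=lambda i: frame_lens[i])
    batches, batch, batch_frames = [], [], 0
    for i in indices:
        fits = batch_frames + frame_lens[i] <= frames_threshold
        under_cap = max_samples == 0 or len(batch) < max_samples
        if fits and under_cap:
            batch.append(i)
            batch_frames += frame_lens[i]
        else:
            if batch:
                batches.append(batch)
            batch, batch_frames = [i], frame_lens[i]
    if batch:
        batches.append(batch)
    return batches

print(make_batches([120, 80, 300, 60, 220], frames_threshold=400))  # [[3, 1, 0], [4], [2]]
```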
@@ -9,13 +9,14 @@ import torch.nn as nn
import torch.nn.functional as F


''' Res2Conv1d + BatchNorm1d + ReLU
'''
""" Res2Conv1d + BatchNorm1d + ReLU
"""


class Res2Conv1dReluBn(nn.Module):
'''
"""
in_channels == out_channels == channels
'''
"""

def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
super().__init__()
@@ -51,8 +52,9 @@ class Res2Conv1dReluBn(nn.Module):
return out


''' Conv1d + BatchNorm1d + ReLU
'''
""" Conv1d + BatchNorm1d + ReLU
"""


class Conv1dReluBn(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
@@ -64,8 +66,9 @@ class Conv1dReluBn(nn.Module):
return self.bn(F.relu(self.conv(x)))


''' The SE connection of 1D case.
'''
""" The SE connection of 1D case.
"""


class SE_Connect(nn.Module):
def __init__(self, channels, se_bottleneck_dim=128):
@@ -82,8 +85,8 @@ class SE_Connect(nn.Module):
return out


''' SE-Res2Block of the ECAPA-TDNN architecture.
'''
""" SE-Res2Block of the ECAPA-TDNN architecture.
"""

# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
# return nn.Sequential(
@@ -93,6 +96,7 @@ class SE_Connect(nn.Module):
# SE_Connect(channels)
# )


class SE_Res2Block(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
super().__init__()
@@ -122,8 +126,9 @@ class SE_Res2Block(nn.Module):
return x + residual


''' Attentive weighted mean and standard deviation pooling.
'''
""" Attentive weighted mean and standard deviation pooling.
"""


class AttentiveStatsPool(nn.Module):
def __init__(self, in_dim, attention_channels=128, global_context_att=False):
@@ -138,7 +143,6 @@ class AttentiveStatsPool(nn.Module):
self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper

def forward(self, x):

if self.global_context_att:
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
@@ -151,38 +155,52 @@ class AttentiveStatsPool(nn.Module):
# alpha = F.relu(self.linear1(x_in))
alpha = torch.softmax(self.linear2(alpha), dim=2)
mean = torch.sum(alpha * x, dim=2)
residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
std = torch.sqrt(residuals.clamp(min=1e-9))
return torch.cat([mean, std], dim=1)


class ECAPA_TDNN(nn.Module):
def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
def __init__(
self,
feat_dim=80,
channels=512,
emb_dim=192,
global_context_att=False,
feat_type="wavlm_large",
sr=16000,
feature_selection="hidden_states",
update_extract=False,
config_path=None,
):
super().__init__()

self.feat_type = feat_type
self.feature_selection = feature_selection
self.update_extract = update_extract
self.sr = sr

torch.hub._validate_not_a_forked_repo=lambda a,b,c: True

torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
try:
local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path)
except:
self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
except: # noqa: E722
self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)

if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
):
self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
):
self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False

self.feat_num = self.get_feat_num()
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

if feat_type != 'fbank' and feat_type != 'mfcc':
freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
if feat_type != "fbank" and feat_type != "mfcc":
freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
for name, param in self.feature_extract.named_parameters():
for freeze_val in freeze_list:
if freeze_val in name:
@@ -198,18 +216,46 @@ class ECAPA_TDNN(nn.Module):
self.channels = [channels] * 4 + [1536]

self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
self.layer2 = SE_Res2Block(
self.channels[0],
self.channels[1],
kernel_size=3,
stride=1,
padding=2,
dilation=2,
scale=8,
se_bottleneck_dim=128,
)
self.layer3 = SE_Res2Block(
self.channels[1],
self.channels[2],
kernel_size=3,
stride=1,
padding=3,
dilation=3,
scale=8,
se_bottleneck_dim=128,
)
self.layer4 = SE_Res2Block(
self.channels[2],
self.channels[3],
kernel_size=3,
stride=1,
padding=4,
dilation=4,
scale=8,
se_bottleneck_dim=128,
)

# self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
cat_channels = channels * 3
self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
self.pooling = AttentiveStatsPool(
self.channels[-1], attention_channels=128, global_context_att=global_context_att
)
self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)


def get_feat_num(self):
self.feature_extract.eval()
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
@@ -226,12 +272,12 @@ class ECAPA_TDNN(nn.Module):
x = self.feature_extract([sample for sample in x])
else:
with torch.no_grad():
if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
if self.feat_type == "fbank" or self.feat_type == "mfcc":
x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
else:
x = self.feature_extract([sample for sample in x])

if self.feat_type == 'fbank':
if self.feat_type == "fbank":
x = x.log()

if self.feat_type != "fbank" and self.feat_type != "mfcc":
@@ -263,6 +309,22 @@ class ECAPA_TDNN(nn.Module):
return out


def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
def ECAPA_TDNN_SMALL(
feat_dim,
emb_dim=256,
feat_type="wavlm_large",
sr=16000,
feature_selection="hidden_states",
update_extract=False,
config_path=None,
):
return ECAPA_TDNN(
feat_dim=feat_dim,
channels=512,
emb_dim=emb_dim,
feat_type=feat_type,
sr=sr,
feature_selection=feature_selection,
update_extract=update_extract,
config_path=config_path,
)
222
model/modules.py
222
model/modules.py
@@ -21,39 +21,40 @@ from x_transformers.x_transformers import apply_rotary_pos_emb

# raw wav to mel spec


class MelSpec(nn.Module):
def __init__(
self,
filter_length = 1024,
hop_length = 256,
win_length = 1024,
n_mel_channels = 100,
target_sample_rate = 24_000,
normalize = False,
power = 1,
norm = None,
center = True,
filter_length=1024,
hop_length=256,
win_length=1024,
n_mel_channels=100,
target_sample_rate=24_000,
normalize=False,
power=1,
norm=None,
center=True,
):
super().__init__()
self.n_mel_channels = n_mel_channels

self.mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate = target_sample_rate,
n_fft = filter_length,
win_length = win_length,
hop_length = hop_length,
n_mels = n_mel_channels,
power = power,
center = center,
normalized = normalize,
norm = norm,
sample_rate=target_sample_rate,
n_fft=filter_length,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mel_channels,
power=power,
center=center,
normalized=normalize,
norm=norm,
)

self.register_buffer('dummy', torch.tensor(0), persistent = False)
self.register_buffer("dummy", torch.tensor(0), persistent=False)

def forward(self, inp):
if len(inp.shape) == 3:
inp = inp.squeeze(1) # 'b 1 nw -> b nw'
inp = inp.squeeze(1) # 'b 1 nw -> b nw'

assert len(inp.shape) == 2
@@ -61,12 +62,13 @@ class MelSpec(nn.Module):
self.to(inp.device)

mel = self.mel_stft(inp)
mel = mel.clamp(min = 1e-5).log()
mel = mel.clamp(min=1e-5).log()
return mel


# sinusoidal position embedding


class SinusPositionEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -84,35 +86,37 @@ class SinusPositionEmbedding(nn.Module):

# convolutional position embedding


class ConvPositionEmbedding(nn.Module):
def __init__(self, dim, kernel_size = 31, groups = 16):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.conv1d = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
)

def forward(self, x: float['b n d'], mask: bool['b n'] | None = None):
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.)
x = x.masked_fill(~mask, 0.0)

x = x.permute(0, 2, 1)
x = self.conv1d(x)
out = x.permute(0, 2, 1)

if mask is not None:
out = out.masked_fill(~mask, 0.)
out = out.masked_fill(~mask, 0.0)

return out


# rotary positional embedding related

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
@@ -125,12 +129,14 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
freqs_sin = torch.sin(freqs) # imaginary part
return torch.cat([freqs_cos, freqs_sin], dim=-1)

def get_pos_embed_indices(start, length, max_pos, scale=1.):

def get_pos_embed_indices(start, length, max_pos, scale=1.0):
# length = length if isinstance(length, int) else length.max()
scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
pos = start.unsqueeze(1) + (
torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
scale.unsqueeze(1)).long()
pos = (
start.unsqueeze(1)
+ (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
)
# avoid extra long error.
pos = torch.where(pos < max_pos, pos, max_pos - 1)
return pos
@@ -138,6 +144,7 @@ def get_pos_embed_indices(start, length, max_pos, scale=1.):

# Global Response Normalization layer (Instance Normalization ?)


class GRN(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -153,6 +160,7 @@ class GRN(nn.Module):
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108


class ConvNeXtV2Block(nn.Module):
def __init__(
self,
@@ -162,7 +170,9 @@ class ConvNeXtV2Block(nn.Module):
):
super().__init__()
padding = (dilation * (7 - 1)) // 2
self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
@@ -185,6 +195,7 @@ class ConvNeXtV2Block(nn.Module):
# AdaLayerNormZero
# return with modulated x for attn input, and params for later mlp modulation


class AdaLayerNormZero(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -194,7 +205,7 @@ class AdaLayerNormZero(nn.Module):

self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

def forward(self, x, emb = None):
def forward(self, x, emb=None):
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
@@ -205,6 +216,7 @@ class AdaLayerNormZero(nn.Module):
# AdaLayerNormZero for final layer
# return only with modulated x for attn input, cuz no more mlp modulation


class AdaLayerNormZero_Final(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -224,22 +236,16 @@ class AdaLayerNormZero_Final(nn.Module):

# FeedForward


class FeedForward(nn.Module):
def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim

activation = nn.GELU(approximate=approximate)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
activation
)
self.ff = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

def forward(self, x):
return self.ff(x)
@@ -248,6 +254,7 @@ class FeedForward(nn.Module):
# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py


class Attention(nn.Module):
def __init__(
self,
@@ -256,8 +263,8 @@ class Attention(nn.Module):
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only = None,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only=None,
):
super().__init__()

@@ -293,20 +300,21 @@ class Attention(nn.Module):

def forward(
self,
x: float['b n d'], # noised input x
c: float['b n d'] = None, # context c
mask: bool['b n'] | None = None,
rope = None, # rotary position embedding for x
c_rope = None, # rotary position embedding for c
x: float["b n d"], # noised input x # noqa: F722
c: float["b n d"] = None, # context c # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
if c is not None:
return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
else:
return self.processor(self, x, mask = mask, rope = rope)
return self.processor(self, x, mask=mask, rope=rope)


# Attention processor


class AttnProcessor:
def __init__(self):
pass
@@ -314,11 +322,10 @@ class AttnProcessor:
def __call__(
self,
attn: Attention,
x: float['b n d'], # noised input x
mask: bool['b n'] | None = None,
rope = None, # rotary position embedding
x: float["b n d"], # noised input x # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding
) -> torch.FloatTensor:

batch_size = x.shape[0]

# `sample` projections.
@@ -329,7 +336,7 @@ class AttnProcessor:
# apply rotary position embedding
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)

query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
@@ -360,14 +367,15 @@ class AttnProcessor:

if mask is not None:
mask = mask.unsqueeze(-1)
x = x.masked_fill(~mask, 0.)
x = x.masked_fill(~mask, 0.0)

return x


# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py


class JointAttnProcessor:
def __init__(self):
pass
@@ -375,11 +383,11 @@ class JointAttnProcessor:
def __call__(
self,
attn: Attention,
x: float['b n d'], # noised input x
c: float['b nt d'] = None, # context c, here text
mask: bool['b n'] | None = None,
rope = None, # rotary position embedding for x
c_rope = None, # rotary position embedding for c
x: float["b n d"], # noised input x # noqa: F722
c: float["b nt d"] = None, # context c, here text # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.FloatTensor:
residual = x

@@ -398,12 +406,12 @@ class JointAttnProcessor:
# apply rope for context and noised input independently
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
if c_rope is not None:
freqs, xpos_scale = c_rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)

@@ -420,7 +428,7 @@ class JointAttnProcessor:

# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text)
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
@@ -432,8 +440,8 @@ class JointAttnProcessor:

# Split the attention outputs.
x, c = (
x[:, :residual.shape[1]],
x[:, residual.shape[1]:],
x[:, : residual.shape[1]],
x[:, residual.shape[1] :],
)

# linear proj
@@ -445,7 +453,7 @@ class JointAttnProcessor:

if mask is not None:
mask = mask.unsqueeze(-1)
x = x.masked_fill(~mask, 0.)
x = x.masked_fill(~mask, 0.0)
# c = c.masked_fill(~mask, 0.) # no mask for c (text)

return x, c
@@ -453,24 +461,24 @@ class JointAttnProcessor:

# DiT Block

class DiTBlock(nn.Module):

def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
super().__init__()


self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor = AttnProcessor(),
dim = dim,
heads = heads,
dim_head = dim_head,
dropout = dropout,
)

self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
processor=AttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
)

def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

@@ -479,7 +487,7 @@ class DiTBlock(nn.Module):

# process attention output for input x
x = x + gate_msa.unsqueeze(1) * attn_output

norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm)
x = x + gate_mlp.unsqueeze(1) * ff_output
@@ -489,8 +497,9 @@ class DiTBlock(nn.Module):

# MMDiT Block https://arxiv.org/abs/2403.03206


class MMDiTBlock(nn.Module):
r"""
r"""
modified from diffusers/src/diffusers/models/attention.py

notes.
@@ -499,33 +508,33 @@ class MMDiTBlock(nn.Module):
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
"""

def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
super().__init__()

self.context_pre_only = context_pre_only


self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
self.attn_norm_x = AdaLayerNormZero(dim)
self.attn = Attention(
processor = JointAttnProcessor(),
dim = dim,
heads = heads,
dim_head = dim_head,
dropout = dropout,
context_dim = dim,
context_pre_only = context_pre_only,
)
processor=JointAttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
context_dim=dim,
context_pre_only=context_pre_only,
)

if not context_pre_only:
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
else:
self.ff_norm_c = None
self.ff_c = None
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding
def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
# pre-norm & modulation for attention input
if self.context_pre_only:
norm_c = self.attn_norm_c(c, t)
@@ -539,7 +548,7 @@ class MMDiTBlock(nn.Module):
# process attention output for context c
if self.context_pre_only:
c = None
else: # if not last layer
else: # if not last layer
c = c + c_gate_msa.unsqueeze(1) * c_attn_output

norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
@@ -548,7 +557,7 @@ class MMDiTBlock(nn.Module):

# process attention output for input x
x = x + x_gate_msa.unsqueeze(1) * x_attn_output

norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
x_ff_output = self.ff_x(norm_x)
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
@@ -558,17 +567,14 @@ class MMDiTBlock(nn.Module):

# time step conditioning embedding


class TimestepEmbedding(nn.Module):
def __init__(self, dim, freq_embed_dim=256):
super().__init__()
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(
nn.Linear(freq_embed_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

def forward(self, timestep: float['b']):
def forward(self, timestep: float["b"]): # noqa: F821
time_hidden = self.time_embed(timestep)
time_hidden = time_hidden.to(timestep.dtype)
time = self.time_mlp(time_hidden) # b d
205
model/trainer.py
205
model/trainer.py
@@ -22,71 +22,69 @@ from model.dataset import DynamicBatchSampler, collate_fn
|
||||
|
||||
# trainer
|
||||
|
||||
|
||||
class Trainer:
|
||||
def __init__(
|
||||
self,
|
||||
model: CFM,
|
||||
epochs,
|
||||
learning_rate,
|
||||
num_warmup_updates = 20000,
|
||||
save_per_updates = 1000,
|
||||
checkpoint_path = None,
|
||||
batch_size = 32,
|
||||
num_warmup_updates=20000,
|
||||
save_per_updates=1000,
|
||||
checkpoint_path=None,
|
||||
batch_size=32,
|
||||
batch_size_type: str = "sample",
|
||||
max_samples = 32,
|
||||
grad_accumulation_steps = 1,
|
||||
max_grad_norm = 1.0,
|
||||
max_samples=32,
|
||||
grad_accumulation_steps=1,
|
||||
max_grad_norm=1.0,
|
||||
noise_scheduler: str | None = None,
|
||||
duration_predictor: torch.nn.Module | None = None,
|
||||
wandb_project = "test_e2-tts",
|
||||
wandb_run_name = "test_run",
|
||||
wandb_project="test_e2-tts",
|
||||
wandb_run_name="test_run",
|
||||
wandb_resume_id: str = None,
|
||||
last_per_steps = None,
|
||||
last_per_steps=None,
|
||||
accelerate_kwargs: dict = dict(),
|
||||
ema_kwargs: dict = dict(),
|
||||
bnb_optimizer: bool = False,
|
||||
):
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
|
||||
logger = "wandb" if wandb.api.api_key else None
|
||||
print(f"Using logger: {logger}")
|
||||
|
||||
self.accelerator = Accelerator(
|
||||
log_with = logger,
|
||||
kwargs_handlers = [ddp_kwargs],
|
||||
gradient_accumulation_steps = grad_accumulation_steps,
|
||||
**accelerate_kwargs
|
||||
log_with=logger,
|
||||
kwargs_handlers=[ddp_kwargs],
|
||||
gradient_accumulation_steps=grad_accumulation_steps,
|
||||
**accelerate_kwargs,
|
||||
)
|
||||
|
||||
if logger == "wandb":
|
||||
if exists(wandb_resume_id):
|
||||
init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
|
||||
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
|
||||
else:
|
||||
init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
|
||||
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
|
||||
self.accelerator.init_trackers(
|
||||
project_name = wandb_project,
|
||||
project_name=wandb_project,
|
||||
init_kwargs=init_kwargs,
|
||||
config={"epochs": epochs,
|
||||
"learning_rate": learning_rate,
|
||||
"num_warmup_updates": num_warmup_updates,
|
||||
"batch_size": batch_size,
|
||||
"batch_size_type": batch_size_type,
|
||||
"max_samples": max_samples,
|
||||
"grad_accumulation_steps": grad_accumulation_steps,
|
||||
"max_grad_norm": max_grad_norm,
|
||||
"gpus": self.accelerator.num_processes,
|
||||
"noise_scheduler": noise_scheduler}
|
||||
)
|
||||
config={
|
||||
"epochs": epochs,
|
||||
"learning_rate": learning_rate,
|
||||
"num_warmup_updates": num_warmup_updates,
|
||||
"batch_size": batch_size,
|
||||
"batch_size_type": batch_size_type,
|
||||
"max_samples": max_samples,
|
||||
"grad_accumulation_steps": grad_accumulation_steps,
|
||||
"max_grad_norm": max_grad_norm,
|
||||
"gpus": self.accelerator.num_processes,
|
||||
"noise_scheduler": noise_scheduler,
|
||||
},
|
||||
)
|
||||
|
||||
self.model = model
|
||||
|
||||
if self.is_main:
|
||||
self.ema_model = EMA(
|
||||
model,
|
||||
include_online_model = False,
|
||||
**ema_kwargs
|
||||
)
|
||||
self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
|
||||
|
||||
self.ema_model.to(self.accelerator.device)
|
||||
|
||||
@@ -94,7 +92,7 @@ class Trainer:
|
||||
self.num_warmup_updates = num_warmup_updates
|
||||
self.save_per_updates = save_per_updates
|
||||
self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
|
||||
self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
|
||||
self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.batch_size_type = batch_size_type
|
||||
@@ -108,12 +106,11 @@ class Trainer:
|
||||
|
||||
if bnb_optimizer:
|
||||
import bitsandbytes as bnb
|
||||
|
||||
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
|
||||
else:
|
||||
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
||||
self.model, self.optimizer = self.accelerator.prepare(
|
||||
self.model, self.optimizer
|
||||
)
|
||||
self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
|
||||
|
||||
@property
|
||||
def is_main(self):
|
||||
@@ -123,81 +120,112 @@ class Trainer:
|
||||
self.accelerator.wait_for_everyone()
|
||||
if self.is_main:
|
||||
checkpoint = dict(
|
||||
model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
|
||||
optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
|
||||
ema_model_state_dict = self.ema_model.state_dict(),
|
||||
scheduler_state_dict = self.scheduler.state_dict(),
|
||||
step = step
|
||||
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
|
||||
optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
|
||||
ema_model_state_dict=self.ema_model.state_dict(),
|
||||
scheduler_state_dict=self.scheduler.state_dict(),
|
||||
step=step,
|
||||
)
|
||||
if not os.path.exists(self.checkpoint_path):
|
||||
os.makedirs(self.checkpoint_path)
|
||||
if last == True:
|
||||
if last:
|
||||
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
|
||||
print(f"Saved last checkpoint at step {step}")
|
||||
else:
|
||||
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
|
||||
|
||||
def load_checkpoint(self):
|
||||
if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
|
||||
if (
|
||||
not exists(self.checkpoint_path)
|
||||
or not os.path.exists(self.checkpoint_path)
|
||||
or not os.listdir(self.checkpoint_path)
|
||||
):
|
||||
return 0
|
||||
|
||||
|
||||
self.accelerator.wait_for_everyone()
|
||||
if "model_last.pt" in os.listdir(self.checkpoint_path):
|
||||
latest_checkpoint = "model_last.pt"
|
||||
else:
|
||||
latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
|
||||
latest_checkpoint = sorted(
|
||||
[f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
|
||||
key=lambda x: int("".join(filter(str.isdigit, x))),
|
||||
)[-1]
|
||||
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
||||
checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
|
||||
|
||||
if self.is_main:
|
||||
self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
||||
self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
|
||||
|
||||
if 'step' in checkpoint:
|
||||
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
|
||||
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
if "step" in checkpoint:
|
||||
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
|
||||
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
if self.scheduler:
|
||||
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
step = checkpoint['step']
|
||||
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
||||
step = checkpoint["step"]
|
||||
else:
|
||||
checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
|
||||
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
|
||||
checkpoint["model_state_dict"] = {
|
||||
k.replace("ema_model.", ""): v
|
||||
for k, v in checkpoint["ema_model_state_dict"].items()
|
||||
if k not in ["initted", "step"]
|
||||
}
|
||||
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
|
||||
step = 0
|
||||
|
||||
del checkpoint; gc.collect()
|
||||
del checkpoint
|
||||
gc.collect()
|
||||
return step
|
||||
|
||||
def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
|
||||
|
||||
if exists(resumable_with_seed):
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(resumable_with_seed)
|
||||
else:
|
||||
else:
|
||||
generator = None
|
||||
|
||||
if self.batch_size_type == "sample":
|
||||
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
|
||||
batch_size=self.batch_size, shuffle=True, generator=generator)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
persistent_workers=True,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=True,
|
||||
generator=generator,
|
||||
)
|
||||
elif self.batch_size_type == "frame":
|
||||
self.accelerator.even_batches = False
|
||||
sampler = SequentialSampler(train_dataset)
|
||||
batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
|
||||
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
|
||||
batch_sampler=batch_sampler)
|
||||
batch_sampler = DynamicBatchSampler(
|
||||
sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
|
||||
)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
persistent_workers=True,
|
||||
batch_sampler=batch_sampler,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
|
||||
|
||||
|
||||
# accelerator.prepare() dispatches batches to devices;
|
||||
# which means the length of dataloader calculated before, should consider the number of devices
|
||||
warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # consider a fixed warmup steps while using accelerate multi-gpu ddp
|
||||
# otherwise by default with split_batches=False, warmup steps change with num_processes
|
||||
warmup_steps = (
|
||||
self.num_warmup_updates * self.accelerator.num_processes
|
||||
) # consider a fixed warmup steps while using accelerate multi-gpu ddp
|
||||
# otherwise by default with split_batches=False, warmup steps change with num_processes
|
||||
total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
|
||||
decay_steps = total_steps - warmup_steps
|
||||
warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
|
||||
decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
|
||||
self.scheduler = SequentialLR(self.optimizer,
|
||||
schedulers=[warmup_scheduler, decay_scheduler],
|
||||
milestones=[warmup_steps])
|
||||
train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus
|
||||
self.scheduler = SequentialLR(
|
||||
self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
|
||||
)
|
||||
train_dataloader, self.scheduler = self.accelerator.prepare(
|
||||
train_dataloader, self.scheduler
|
||||
) # actual steps = 1 gpu steps / gpus
|
||||
start_step = self.load_checkpoint()
|
||||
global_step = start_step
|
||||
|
||||
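The scheduler assembled above is a linear warmup followed by a linear decay, with the warmup scaled by the number of processes so it stays fixed in optimizer updates under accelerate's default `split_batches=False`. A self-contained sketch with assumed step counts (`num_warmup_updates`, `num_processes`, and `total_steps` are placeholders):

```python
import torch
from torch.optim.lr_scheduler import LinearLR, SequentialLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)

num_warmup_updates = 2000      # per-process warmup updates (assumed value)
num_processes = 8              # accelerate DDP world size (assumed value)
total_steps = 100_000          # len(dataloader) * epochs / grad_accumulation (assumed)

# With split_batches=False, each process sees a shortened dataloader, so the
# warmup is scaled by the process count to keep it constant in updates.
warmup_steps = num_warmup_updates * num_processes
decay_steps = total_steps - warmup_steps

warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])
```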
@@ -212,23 +240,36 @@ class Trainer:
|
||||
for epoch in range(skipped_epoch, self.epochs):
|
||||
self.model.train()
|
||||
if exists(resumable_with_seed) and epoch == skipped_epoch:
|
||||
progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
|
||||
initial=skipped_batch, total=orig_epoch_step)
|
||||
progress_bar = tqdm(
|
||||
skipped_dataloader,
|
||||
desc=f"Epoch {epoch+1}/{self.epochs}",
|
||||
unit="step",
|
||||
disable=not self.accelerator.is_local_main_process,
|
||||
initial=skipped_batch,
|
||||
total=orig_epoch_step,
|
||||
)
|
||||
else:
|
||||
progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
|
||||
progress_bar = tqdm(
|
||||
train_dataloader,
|
||||
desc=f"Epoch {epoch+1}/{self.epochs}",
|
||||
unit="step",
|
||||
disable=not self.accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
for batch in progress_bar:
|
||||
with self.accelerator.accumulate(self.model):
|
||||
text_inputs = batch['text']
|
||||
mel_spec = batch['mel'].permute(0, 2, 1)
|
||||
text_inputs = batch["text"]
|
||||
mel_spec = batch["mel"].permute(0, 2, 1)
|
||||
mel_lengths = batch["mel_lengths"]
|
||||
|
||||
# TODO. add duration predictor training
|
||||
if self.duration_predictor is not None and self.accelerator.is_local_main_process:
|
||||
dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
|
||||
dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
|
||||
self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
|
||||
|
||||
loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
|
||||
loss, cond, pred = self.model(
|
||||
mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
|
||||
)
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
|
||||
@@ -245,13 +286,13 @@ class Trainer:
|
||||
|
||||
if self.accelerator.is_local_main_process:
|
||||
self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
|
||||
|
||||
|
||||
progress_bar.set_postfix(step=str(global_step), loss=loss.item())
|
||||
|
||||
|
||||
if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
|
||||
self.save_checkpoint(global_step)
|
||||
|
||||
|
||||
if global_step % self.last_per_steps == 0:
|
||||
self.save_checkpoint(global_step, last=True)
|
||||
|
||||
|
||||
self.accelerator.end_training()
|
||||
|
||||
314
model/utils.py
@@ -8,6 +8,7 @@ from tqdm import tqdm
|
||||
from collections import defaultdict
|
||||
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pylab as plt
|
||||
|
||||
@@ -25,109 +26,102 @@ from model.modules import MelSpec
|
||||
|
||||
# seed everything
|
||||
|
||||
def seed_everything(seed = 0):
|
||||
|
||||
def seed_everything(seed=0):
|
||||
random.seed(seed)
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
# helpers
|
||||
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
|
||||
# tensor helpers
|
||||
|
||||
def lens_to_mask(
|
||||
t: int['b'],
|
||||
length: int | None = None
|
||||
) -> bool['b n']:
|
||||
|
||||
def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
|
||||
if not exists(length):
|
||||
length = t.amax()
|
||||
|
||||
seq = torch.arange(length, device = t.device)
|
||||
seq = torch.arange(length, device=t.device)
|
||||
return seq[None, :] < t[:, None]
|
||||
|
||||
def mask_from_start_end_indices(
|
||||
seq_len: int['b'],
|
||||
start: int['b'],
|
||||
end: int['b']
|
||||
):
|
||||
max_seq_len = seq_len.max().item()
|
||||
seq = torch.arange(max_seq_len, device = start.device).long()
|
||||
|
||||
def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821
|
||||
max_seq_len = seq_len.max().item()
|
||||
seq = torch.arange(max_seq_len, device=start.device).long()
|
||||
start_mask = seq[None, :] >= start[:, None]
|
||||
end_mask = seq[None, :] < end[:, None]
|
||||
return start_mask & end_mask
|
||||
|
||||
def mask_from_frac_lengths(
|
||||
seq_len: int['b'],
|
||||
frac_lengths: float['b']
|
||||
):
|
||||
|
||||
def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821
|
||||
lengths = (frac_lengths * seq_len).long()
|
||||
max_start = seq_len - lengths
|
||||
|
||||
rand = torch.rand_like(frac_lengths)
|
||||
start = (max_start * rand).long().clamp(min = 0)
|
||||
start = (max_start * rand).long().clamp(min=0)
|
||||
end = start + lengths
|
||||
|
||||
return mask_from_start_end_indices(seq_len, start, end)
|
||||
|
||||
def maybe_masked_mean(
|
||||
t: float['b n d'],
|
||||
mask: bool['b n'] = None
|
||||
) -> float['b d']:
|
||||
|
||||
def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
|
||||
if not exists(mask):
|
||||
return t.mean(dim = 1)
|
||||
return t.mean(dim=1)
|
||||
|
||||
t = torch.where(mask[:, :, None], t, torch.tensor(0., device=t.device))
|
||||
t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
|
||||
num = t.sum(dim=1)
|
||||
den = mask.float().sum(dim=1)
|
||||
|
||||
return num / den.clamp(min=1.)
|
||||
return num / den.clamp(min=1.0)
|
||||
|
||||
|
||||
# simple utf-8 tokenizer, since paper went character based
|
||||
def list_str_to_tensor(
|
||||
text: list[str],
|
||||
padding_value = -1
|
||||
) -> int['b nt']:
|
||||
list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
|
||||
text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
|
||||
def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
|
||||
list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
|
||||
text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
|
||||
return text
|
||||
|
||||
|
||||
# char tokenizer, based on custom dataset's extracted .txt file
|
||||
def list_str_to_idx(
|
||||
text: list[str] | list[list[str]],
|
||||
vocab_char_map: dict[str, int], # {char: idx}
|
||||
padding_value = -1
|
||||
) -> int['b nt']:
|
||||
padding_value=-1,
|
||||
) -> int["b nt"]: # noqa: F722
|
||||
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
|
||||
text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
|
||||
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
|
||||
return text
|
||||
|
||||
|
||||
# Get tokenizer
|
||||
|
||||
|
||||
def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
|
||||
'''
|
||||
"""
|
||||
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
|
||||
- "char" for char-wise tokenizer, need .txt vocab_file
|
||||
- "byte" for utf-8 tokenizer
|
||||
- "custom" if you're directly passing in a path to the vocab.txt you want to use
|
||||
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
|
||||
- if use "char", derived from unfiltered character & symbol counts of custom dataset
|
||||
- if use "byte", set to 256 (unicode byte range)
|
||||
'''
|
||||
- if use "byte", set to 256 (unicode byte range)
|
||||
"""
|
||||
if tokenizer in ["pinyin", "char"]:
|
||||
with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
|
||||
with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
|
||||
vocab_char_map = {}
|
||||
for i, char in enumerate(f):
|
||||
vocab_char_map[char[:-1]] = i
|
||||
@@ -138,7 +132,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
|
||||
vocab_char_map = None
|
||||
vocab_size = 256
|
||||
elif tokenizer == "custom":
|
||||
with open (dataset_name, "r", encoding="utf-8") as f:
|
||||
with open(dataset_name, "r", encoding="utf-8") as f:
|
||||
vocab_char_map = {}
|
||||
for i, char in enumerate(f):
|
||||
vocab_char_map[char[:-1]] = i
|
||||
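For reference, a hedged usage sketch of `get_tokenizer` as documented in the docstring above; the dataset name and the presence of `data/Emilia_ZH_EN_pinyin/vocab.txt` are assumptions.

```python
from model.utils import get_tokenizer

# "pinyin" / "char": reads data/<dataset>_<tokenizer>/vocab.txt into {char: index}
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", tokenizer="pinyin")
print(vocab_size)  # number of lines in vocab.txt

# "byte": no vocab file, raw UTF-8 bytes with vocab_size fixed at 256
# "custom": pass the path to your own vocab.txt as the first argument
```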
@@ -149,16 +143,19 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
|
||||
|
||||
# convert char to pinyin
|
||||
|
||||
def convert_char_to_pinyin(text_list, polyphone = True):
|
||||
|
||||
def convert_char_to_pinyin(text_list, polyphone=True):
|
||||
final_text_list = []
|
||||
god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
|
||||
custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov
|
||||
god_knows_why_en_testset_contains_zh_quote = str.maketrans(
|
||||
{"“": '"', "”": '"', "‘": "'", "’": "'"}
|
||||
) # in case librispeech (orig no-pc) test-clean
|
||||
custom_trans = str.maketrans({";": ","}) # add custom trans here, to address oov
|
||||
for text in text_list:
|
||||
char_list = []
|
||||
text = text.translate(god_knows_why_en_testset_contains_zh_quote)
|
||||
text = text.translate(custom_trans)
|
||||
for seg in jieba.cut(text):
|
||||
seg_byte_len = len(bytes(seg, 'UTF-8'))
|
||||
seg_byte_len = len(bytes(seg, "UTF-8"))
|
||||
if seg_byte_len == len(seg): # if pure alphabets and symbols
|
||||
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
||||
char_list.append(" ")
|
||||
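A usage sketch of `convert_char_to_pinyin`; the printed tokens are indicative only, since the exact output depends on jieba segmentation and pypinyin.

```python
from model.utils import convert_char_to_pinyin

mixed = ["我爱 speech synthesis。"]
print(convert_char_to_pinyin(mixed, polyphone=True))
# Chinese characters become pinyin tokens (e.g. "wo3", "ai4"), while pure-ASCII
# segments such as "speech synthesis" pass through character by character.
```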
@@ -187,7 +184,7 @@ def convert_char_to_pinyin(text_list, polyphone = True):
|
||||
# save spectrogram
|
||||
def save_spectrogram(spectrogram, path):
|
||||
plt.figure(figsize=(12, 4))
|
||||
plt.imshow(spectrogram, origin='lower', aspect='auto')
|
||||
plt.imshow(spectrogram, origin="lower", aspect="auto")
|
||||
plt.colorbar()
|
||||
plt.savefig(path)
|
||||
plt.close()
|
||||
@@ -195,13 +192,15 @@ def save_spectrogram(spectrogram, path):
|
||||
|
||||
# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
|
||||
def get_seedtts_testset_metainfo(metalst):
|
||||
f = open(metalst); lines = f.readlines(); f.close()
|
||||
f = open(metalst)
|
||||
lines = f.readlines()
|
||||
f.close()
|
||||
metainfo = []
|
||||
for line in lines:
|
||||
if len(line.strip().split('|')) == 5:
|
||||
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
|
||||
elif len(line.strip().split('|')) == 4:
|
||||
utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
|
||||
if len(line.strip().split("|")) == 5:
|
||||
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
|
||||
elif len(line.strip().split("|")) == 4:
|
||||
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
||||
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
|
||||
if not os.path.isabs(prompt_wav):
|
||||
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
|
||||
@@ -211,18 +210,20 @@ def get_seedtts_testset_metainfo(metalst):
|
||||
|
||||
# librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
|
||||
def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
|
||||
f = open(metalst); lines = f.readlines(); f.close()
|
||||
f = open(metalst)
|
||||
lines = f.readlines()
|
||||
f.close()
|
||||
metainfo = []
|
||||
for line in lines:
|
||||
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
|
||||
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
|
||||
|
||||
# ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
|
||||
ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
|
||||
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
|
||||
ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
|
||||
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
|
||||
|
||||
# gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
|
||||
gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
|
||||
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
|
||||
gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
|
||||
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
|
||||
|
||||
metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
|
||||
|
||||
@@ -234,7 +235,7 @@ def padded_mel_batch(ref_mels):
|
||||
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
|
||||
padded_ref_mels = []
|
||||
for mel in ref_mels:
|
||||
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
|
||||
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
|
||||
padded_ref_mels.append(padded_ref_mel)
|
||||
padded_ref_mels = torch.stack(padded_ref_mels)
|
||||
padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
|
||||
@@ -243,12 +244,21 @@ def padded_mel_batch(ref_mels):
|
||||
|
||||
# get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
|
||||
|
||||
|
||||
def get_inference_prompt(
|
||||
metainfo,
|
||||
speed = 1., tokenizer = "pinyin", polyphone = True,
|
||||
target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
|
||||
use_truth_duration = False,
|
||||
infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
|
||||
metainfo,
|
||||
speed=1.0,
|
||||
tokenizer="pinyin",
|
||||
polyphone=True,
|
||||
target_sample_rate=24000,
|
||||
n_mel_channels=100,
|
||||
hop_length=256,
|
||||
target_rms=0.1,
|
||||
use_truth_duration=False,
|
||||
infer_batch_size=1,
|
||||
num_buckets=200,
|
||||
min_secs=3,
|
||||
max_secs=40,
|
||||
):
|
||||
prompts_all = []
|
||||
|
||||
@@ -256,13 +266,15 @@ def get_inference_prompt(
|
||||
max_tokens = max_secs * target_sample_rate // hop_length
|
||||
|
||||
batch_accum = [0] * num_buckets
|
||||
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
|
||||
([[] for _ in range(num_buckets)] for _ in range(6))
|
||||
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
|
||||
[[] for _ in range(num_buckets)] for _ in range(6)
|
||||
)
|
||||
|
||||
mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
|
||||
mel_spectrogram = MelSpec(
|
||||
target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
|
||||
)
|
||||
|
||||
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
|
||||
|
||||
# Audio
|
||||
ref_audio, ref_sr = torchaudio.load(prompt_wav)
|
||||
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
|
||||
@@ -274,11 +286,11 @@ def get_inference_prompt(
|
||||
ref_audio = resampler(ref_audio)
|
||||
|
||||
# Text
|
||||
if len(prompt_text[-1].encode('utf-8')) == 1:
|
||||
if len(prompt_text[-1].encode("utf-8")) == 1:
|
||||
prompt_text = prompt_text + " "
|
||||
text = [prompt_text + gt_text]
|
||||
if tokenizer == "pinyin":
|
||||
text_list = convert_char_to_pinyin(text, polyphone = polyphone)
|
||||
text_list = convert_char_to_pinyin(text, polyphone=polyphone)
|
||||
else:
|
||||
text_list = text
|
||||
|
||||
@@ -294,8 +306,8 @@ def get_inference_prompt(
|
||||
# # test vocoder resynthesis
|
||||
# ref_audio = gt_audio
|
||||
else:
|
||||
ref_text_len = len(prompt_text.encode('utf-8'))
|
||||
gen_text_len = len(gt_text.encode('utf-8'))
|
||||
ref_text_len = len(prompt_text.encode("utf-8"))
|
||||
gen_text_len = len(gt_text.encode("utf-8"))
|
||||
total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
|
||||
|
||||
# to mel spectrogram
|
||||
@@ -304,8 +316,9 @@ def get_inference_prompt(
|
||||
|
||||
# deal with batch
|
||||
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
|
||||
assert min_tokens <= total_mel_len <= max_tokens, \
|
||||
f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
|
||||
assert (
|
||||
min_tokens <= total_mel_len <= max_tokens
|
||||
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
|
||||
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
|
||||
|
||||
utts[bucket_i].append(utt)
|
||||
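The two steps above estimate the total mel length proportionally from UTF-8 byte counts, then drop the sample into a duration bucket. A worked example with assumed values:

```python
import math

hop_length, target_sample_rate = 256, 24000
min_secs, max_secs, num_buckets = 3, 40, 200
min_tokens = min_secs * target_sample_rate // hop_length   # 281 frames
max_tokens = max_secs * target_sample_rate // hop_length   # 3750 frames

ref_mel_len = 5 * target_sample_rate // hop_length         # ~5 s prompt -> 468 frames
ref_text_len, gen_text_len, speed = 60, 120, 1.0           # UTF-8 byte counts (assumed)

total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
print(total_mel_len, bucket_i)  # 1404 frames (~15 s total), bucket 64
```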
@@ -319,28 +332,39 @@ def get_inference_prompt(
|
||||
|
||||
if batch_accum[bucket_i] >= infer_batch_size:
|
||||
# print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
|
||||
prompts_all.append((
|
||||
utts[bucket_i],
|
||||
ref_rms_list[bucket_i],
|
||||
padded_mel_batch(ref_mels[bucket_i]),
|
||||
ref_mel_lens[bucket_i],
|
||||
total_mel_lens[bucket_i],
|
||||
final_text_list[bucket_i]
|
||||
))
|
||||
prompts_all.append(
|
||||
(
|
||||
utts[bucket_i],
|
||||
ref_rms_list[bucket_i],
|
||||
padded_mel_batch(ref_mels[bucket_i]),
|
||||
ref_mel_lens[bucket_i],
|
||||
total_mel_lens[bucket_i],
|
||||
final_text_list[bucket_i],
|
||||
)
|
||||
)
|
||||
batch_accum[bucket_i] = 0
|
||||
utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
|
||||
(
|
||||
utts[bucket_i],
|
||||
ref_rms_list[bucket_i],
|
||||
ref_mels[bucket_i],
|
||||
ref_mel_lens[bucket_i],
|
||||
total_mel_lens[bucket_i],
|
||||
final_text_list[bucket_i],
|
||||
) = [], [], [], [], [], []
|
||||
|
||||
# add residual
|
||||
for bucket_i, bucket_frames in enumerate(batch_accum):
|
||||
if bucket_frames > 0:
|
||||
prompts_all.append((
|
||||
utts[bucket_i],
|
||||
ref_rms_list[bucket_i],
|
||||
padded_mel_batch(ref_mels[bucket_i]),
|
||||
ref_mel_lens[bucket_i],
|
||||
total_mel_lens[bucket_i],
|
||||
final_text_list[bucket_i]
|
||||
))
|
||||
prompts_all.append(
|
||||
(
|
||||
utts[bucket_i],
|
||||
ref_rms_list[bucket_i],
|
||||
padded_mel_batch(ref_mels[bucket_i]),
|
||||
ref_mel_lens[bucket_i],
|
||||
total_mel_lens[bucket_i],
|
||||
final_text_list[bucket_i],
|
||||
)
|
||||
)
|
||||
# not only leave easy work for last workers
|
||||
random.seed(666)
|
||||
random.shuffle(prompts_all)
|
||||
@@ -351,6 +375,7 @@ def get_inference_prompt(
|
||||
# get wav_res_ref_text of seed-tts test metalst
|
||||
# https://github.com/BytedanceSpeech/seed-tts-eval
|
||||
|
||||
|
||||
def get_seed_tts_test(metalst, gen_wav_dir, gpus):
|
||||
f = open(metalst)
|
||||
lines = f.readlines()
|
||||
@@ -358,14 +383,14 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
|
||||
|
||||
test_set_ = []
|
||||
for line in tqdm(lines):
|
||||
if len(line.strip().split('|')) == 5:
|
||||
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
|
||||
elif len(line.strip().split('|')) == 4:
|
||||
utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
|
||||
if len(line.strip().split("|")) == 5:
|
||||
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
|
||||
elif len(line.strip().split("|")) == 4:
|
||||
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
||||
|
||||
if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
|
||||
if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")):
|
||||
continue
|
||||
gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
|
||||
gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
|
||||
if not os.path.isabs(prompt_wav):
|
||||
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
|
||||
|
||||
@@ -374,65 +399,69 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
|
||||
num_jobs = len(gpus)
|
||||
if num_jobs == 1:
|
||||
return [(gpus[0], test_set_)]
|
||||
|
||||
|
||||
wav_per_job = len(test_set_) // num_jobs + 1
|
||||
test_set = []
|
||||
for i in range(num_jobs):
|
||||
test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
|
||||
test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
|
||||
|
||||
return test_set
|
||||
|
||||
|
||||
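The tail of this function splits the scored items evenly across GPUs. A tiny worked example with assumed `gpus` and item counts:

```python
gpus = [0, 1, 2, 3]
test_set_ = list(range(10))  # 10 items to score

wav_per_job = len(test_set_) // len(gpus) + 1  # 3 items per job
test_set = [(gpu, test_set_[i * wav_per_job:(i + 1) * wav_per_job]) for i, gpu in enumerate(gpus)]
print(test_set)
# [(0, [0, 1, 2]), (1, [3, 4, 5]), (2, [6, 7, 8]), (3, [9])]
```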
# get librispeech test-clean cross sentence test
|
||||
|
||||
def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
|
||||
|
||||
def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
|
||||
f = open(metalst)
|
||||
lines = f.readlines()
|
||||
f.close()
|
||||
|
||||
test_set_ = []
|
||||
for line in tqdm(lines):
|
||||
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
|
||||
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
|
||||
|
||||
if eval_ground_truth:
|
||||
gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
|
||||
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
|
||||
gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
|
||||
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
|
||||
else:
|
||||
if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
|
||||
if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
|
||||
raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
|
||||
gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
|
||||
gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")
|
||||
|
||||
ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
|
||||
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
|
||||
ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
|
||||
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
|
||||
|
||||
test_set_.append((gen_wav, ref_wav, gen_txt))
|
||||
|
||||
num_jobs = len(gpus)
|
||||
if num_jobs == 1:
|
||||
return [(gpus[0], test_set_)]
|
||||
|
||||
|
||||
wav_per_job = len(test_set_) // num_jobs + 1
|
||||
test_set = []
|
||||
for i in range(num_jobs):
|
||||
test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
|
||||
test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
|
||||
|
||||
return test_set
|
||||
|
||||
|
||||
# load asr model
|
||||
|
||||
def load_asr_model(lang, ckpt_dir = ""):
|
||||
|
||||
def load_asr_model(lang, ckpt_dir=""):
|
||||
if lang == "zh":
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(
|
||||
model = os.path.join(ckpt_dir, "paraformer-zh"),
|
||||
# vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
|
||||
model=os.path.join(ckpt_dir, "paraformer-zh"),
|
||||
# vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
|
||||
# punc_model = os.path.join(ckpt_dir, "ct-punc"),
|
||||
# spk_model = os.path.join(ckpt_dir, "cam++"),
|
||||
# spk_model = os.path.join(ckpt_dir, "cam++"),
|
||||
disable_update=True,
|
||||
) # following seed-tts setting
|
||||
) # following seed-tts setting
|
||||
elif lang == "en":
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
|
||||
model = WhisperModel(model_size, device="cuda", compute_type="float16")
|
||||
return model
|
||||
@@ -440,44 +469,50 @@ def load_asr_model(lang, ckpt_dir = ""):
|
||||
|
||||
# WER Evaluation, the way Seed-TTS does
|
||||
|
||||
|
||||
def run_asr_wer(args):
|
||||
rank, lang, test_set, ckpt_dir = args
|
||||
|
||||
if lang == "zh":
|
||||
import zhconv
|
||||
|
||||
torch.cuda.set_device(rank)
|
||||
elif lang == "en":
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
|
||||
else:
|
||||
raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
|
||||
raise NotImplementedError(
|
||||
"lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
|
||||
)
|
||||
|
||||
asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)
|
||||
|
||||
asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
|
||||
|
||||
from zhon.hanzi import punctuation
|
||||
|
||||
punctuation_all = punctuation + string.punctuation
|
||||
wers = []
|
||||
|
||||
from jiwer import compute_measures
|
||||
|
||||
for gen_wav, prompt_wav, truth in tqdm(test_set):
|
||||
if lang == "zh":
|
||||
res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
|
||||
hypo = res[0]["text"]
|
||||
hypo = zhconv.convert(hypo, 'zh-cn')
|
||||
hypo = zhconv.convert(hypo, "zh-cn")
|
||||
elif lang == "en":
|
||||
segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
|
||||
hypo = ''
|
||||
hypo = ""
|
||||
for segment in segments:
|
||||
hypo = hypo + ' ' + segment.text
|
||||
hypo = hypo + " " + segment.text
|
||||
|
||||
# raw_truth = truth
|
||||
# raw_hypo = hypo
|
||||
|
||||
for x in punctuation_all:
|
||||
truth = truth.replace(x, '')
|
||||
hypo = hypo.replace(x, '')
|
||||
truth = truth.replace(x, "")
|
||||
hypo = hypo.replace(x, "")
|
||||
|
||||
truth = truth.replace(' ', ' ')
|
||||
hypo = hypo.replace(' ', ' ')
|
||||
truth = truth.replace(" ", " ")
|
||||
hypo = hypo.replace(" ", " ")
|
||||
|
||||
if lang == "zh":
|
||||
truth = " ".join([x for x in truth])
|
||||
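After this normalization (punctuation stripped, Chinese split into space-separated characters), each pair is scored with jiwer. A hedged sketch of that step; `compute_measures` is what the surrounding code imports, though its API differs across jiwer versions, and the strings here are made up.

```python
from jiwer import compute_measures

truth = "the quick brown fox"
hypo = "the quick brown box"
measures = compute_measures(truth, hypo)
wer = measures["wer"]      # (S + D + I) / number of reference words
print(f"{wer:.3f}")        # 0.250 -> one substitution over four words
```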
@@ -501,22 +536,22 @@ def run_asr_wer(args):
|
||||
|
||||
# SIM Evaluation
|
||||
|
||||
|
||||
def run_sim(args):
|
||||
rank, test_set, ckpt_dir = args
|
||||
device = f"cuda:{rank}"
|
||||
|
||||
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
|
||||
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
|
||||
state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
|
||||
model.load_state_dict(state_dict['model'], strict=False)
|
||||
model.load_state_dict(state_dict["model"], strict=False)
|
||||
|
||||
use_gpu=True if torch.cuda.is_available() else False
|
||||
use_gpu = True if torch.cuda.is_available() else False
|
||||
if use_gpu:
|
||||
model = model.cuda(device)
|
||||
model.eval()
|
||||
|
||||
sim_list = []
|
||||
for wav1, wav2, truth in tqdm(test_set):
|
||||
|
||||
wav1, sr1 = torchaudio.load(wav1)
|
||||
wav2, sr2 = torchaudio.load(wav2)
|
||||
|
||||
@@ -531,20 +566,21 @@ def run_sim(args):
|
||||
with torch.no_grad():
|
||||
emb1 = model(wav1)
|
||||
emb2 = model(wav2)
|
||||
|
||||
|
||||
sim = F.cosine_similarity(emb1, emb2)[0].item()
|
||||
# print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
|
||||
sim_list.append(sim)
|
||||
|
||||
|
||||
return sim_list
|
||||
|
||||
|
||||
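The similarity score above is plain cosine similarity between two speaker embeddings. A minimal sketch with random tensors standing in for the WavLM/ECAPA embeddings:

```python
import torch
import torch.nn.functional as F

emb1 = torch.randn(1, 256)  # embedding of generated speech (placeholder dim)
emb2 = torch.randn(1, 256)  # embedding of the reference prompt
sim = F.cosine_similarity(emb1, emb2)[0].item()  # in [-1.0, 1.0], higher = more similar
print(f"SIM: {sim:.4f}")
```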
# filter func for dirty data with many repetitions
|
||||
|
||||
def repetition_found(text, length = 2, tolerance = 10):
|
||||
|
||||
def repetition_found(text, length=2, tolerance=10):
|
||||
pattern_count = defaultdict(int)
|
||||
for i in range(len(text) - length + 1):
|
||||
pattern = text[i:i + length]
|
||||
pattern = text[i : i + length]
|
||||
pattern_count[pattern] += 1
|
||||
for pattern, count in pattern_count.items():
|
||||
if count > tolerance:
|
||||
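A quick check of the repetition filter above, using the default `length=2`, `tolerance=10` shown in the diff; the strings are made up.

```python
from model.utils import repetition_found

clean = "it was a bright cold day in april"
dirty = "ha" * 20 + " something"
print(repetition_found(clean))            # False
print(repetition_found(dirty, length=2))  # True: "ha" occurs more than 10 times
```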
@@ -554,25 +590,31 @@ def repetition_found(text, length = 2, tolerance = 10):
|
||||
|
||||
# load model checkpoint for inference
|
||||
|
||||
def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
||||
|
||||
def load_checkpoint(model, ckpt_path, device, use_ema=True):
|
||||
if device == "cuda":
|
||||
model = model.half()
|
||||
|
||||
ckpt_type = ckpt_path.split(".")[-1]
|
||||
if ckpt_type == "safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
checkpoint = load_file(ckpt_path)
|
||||
else:
|
||||
checkpoint = torch.load(ckpt_path, weights_only=True)
|
||||
|
||||
if use_ema:
|
||||
if ckpt_type == "safetensors":
|
||||
checkpoint = {'ema_model_state_dict': checkpoint}
|
||||
checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
checkpoint = {"ema_model_state_dict": checkpoint}
|
||||
checkpoint["model_state_dict"] = {
|
||||
k.replace("ema_model.", ""): v
|
||||
for k, v in checkpoint["ema_model_state_dict"].items()
|
||||
if k not in ["initted", "step"]
|
||||
}
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
else:
|
||||
if ckpt_type == "safetensors":
|
||||
checkpoint = {'model_state_dict': checkpoint}
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
checkpoint = {"model_state_dict": checkpoint}
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
return model.to(device)
|
||||
|
||||
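The EMA branch above remaps checkpoint keys before loading: the EMA wrapper prefixes parameters with `ema_model.` and adds bookkeeping entries that must be dropped. A standalone illustration with made-up keys:

```python
ema_state = {
    "ema_model.transformer.proj.weight": "tensor...",
    "ema_model.transformer.proj.bias": "tensor...",
    "initted": "tensor...",
    "step": "tensor...",
}
model_state = {
    k.replace("ema_model.", ""): v
    for k, v in ema_state.items()
    if k not in ["initted", "step"]
}
print(sorted(model_state))  # ['transformer.proj.bias', 'transformer.proj.weight']
```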
@@ -19,11 +19,7 @@ from model.utils import (
|
||||
convert_char_to_pinyin,
|
||||
)
|
||||
|
||||
device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
)
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
print(f"Using {device} device")
|
||||
|
||||
asr_pipe = pipeline(
|
||||
@@ -54,6 +50,7 @@ fix_duration = None
|
||||
|
||||
# chunk text into smaller pieces
|
||||
|
||||
|
||||
def chunk_text(text, max_chars=135):
|
||||
"""
|
||||
Splits the input text into chunks, each with a maximum number of characters.
|
||||
@@ -68,15 +65,15 @@ def chunk_text(text, max_chars=135):
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
# Split the text into sentences based on punctuation followed by whitespace
|
||||
sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
|
||||
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
|
||||
|
||||
for sentence in sentences:
|
||||
if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
|
||||
current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
|
||||
if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
|
||||
current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
|
||||
else:
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
|
||||
current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
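A usage sketch of `chunk_text`, with the function defined above assumed in scope; `max_chars` is set artificially low so the splitting is visible on a short string.

```python
text = "First sentence. Second one! Third, slightly longer sentence? Fourth."
for i, chunk in enumerate(chunk_text(text, max_chars=30)):
    print(i, repr(chunk))
# Chunks break on the splitting punctuation and stay near the max_chars budget.
```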
@@ -86,6 +83,7 @@ def chunk_text(text, max_chars=135):
|
||||
|
||||
# load vocoder
|
||||
|
||||
|
||||
def load_vocoder(is_local=False, local_path=""):
|
||||
if is_local:
|
||||
print(f"Load vocos from local path {local_path}")
|
||||
@@ -101,23 +99,21 @@ def load_vocoder(is_local=False, local_path=""):
|
||||
|
||||
# load model for inference
|
||||
|
||||
|
||||
def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
|
||||
|
||||
if vocab_file == "":
|
||||
vocab_file = "Emilia_ZH_EN"
|
||||
tokenizer = "pinyin"
|
||||
else:
|
||||
tokenizer = "custom"
|
||||
|
||||
print("\nvocab : ", vocab_file, tokenizer)
|
||||
print("tokenizer : ", tokenizer)
|
||||
print("model : ", ckpt_path,"\n")
|
||||
print("\nvocab : ", vocab_file, tokenizer)
|
||||
print("tokenizer : ", tokenizer)
|
||||
print("model : ", ckpt_path, "\n")
|
||||
|
||||
vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
|
||||
model = CFM(
|
||||
transformer=model_cls(
|
||||
**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
|
||||
),
|
||||
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
||||
mel_spec_kwargs=dict(
|
||||
target_sample_rate=target_sample_rate,
|
||||
n_mel_channels=n_mel_channels,
|
||||
@@ -129,21 +125,20 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
|
||||
vocab_char_map=vocab_char_map,
|
||||
).to(device)
|
||||
|
||||
model = load_checkpoint(model, ckpt_path, device, use_ema = True)
|
||||
model = load_checkpoint(model, ckpt_path, device, use_ema=True)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# preprocess reference audio and text
|
||||
|
||||
|
||||
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
|
||||
show_info("Converting audio...")
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
||||
aseg = AudioSegment.from_file(ref_audio_orig)
|
||||
|
||||
non_silent_segs = silence.split_on_silence(
|
||||
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000
|
||||
)
|
||||
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
non_silent_wave += non_silent_seg
|
||||
@@ -181,22 +176,27 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
|
||||
|
||||
# infer process: chunk text -> infer batches [i.e. infer_batch_process()]
|
||||
|
||||
def infer_process(ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=speed, show_info=print, progress=tqdm):
|
||||
|
||||
def infer_process(
|
||||
ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=speed, show_info=print, progress=tqdm
|
||||
):
|
||||
# Split the input text into batches
|
||||
audio, sr = torchaudio.load(ref_audio)
|
||||
max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
|
||||
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
|
||||
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
||||
for i, gen_text in enumerate(gen_text_batches):
|
||||
print(f'gen_text {i}', gen_text)
|
||||
|
||||
print(f"gen_text {i}", gen_text)
|
||||
|
||||
show_info(f"Generating audio in {len(gen_text_batches)} batches...")
|
||||
return infer_batch_process((audio, sr), ref_text, gen_text_batches, model_obj, cross_fade_duration, speed, progress)
|
||||
|
||||
|
||||
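The `max_chars` line above budgets how much new text each chunk may carry, using the reference transcript's bytes-per-second and the 25 s cap in the formula. A worked example with assumed numbers:

```python
ref_secs = 6.0         # audio.shape[-1] / sr for a 6 s reference clip (assumed)
ref_text_bytes = 90    # len(ref_text.encode("utf-8")) (assumed)
max_chars = int(ref_text_bytes / ref_secs * (25 - ref_secs))
print(max_chars)       # 285 bytes of new text per chunk
```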
# infer batches
|
||||
|
||||
def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_fade_duration=0.15, speed=1, progress=tqdm):
|
||||
|
||||
def infer_batch_process(
|
||||
ref_audio, ref_text, gen_text_batches, model_obj, cross_fade_duration=0.15, speed=1, progress=tqdm
|
||||
):
|
||||
audio, sr = ref_audio
|
||||
if audio.shape[0] > 1:
|
||||
audio = torch.mean(audio, dim=0, keepdim=True)
|
||||
@@ -212,7 +212,7 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
|
||||
generated_waves = []
|
||||
spectrograms = []
|
||||
|
||||
if len(ref_text[-1].encode('utf-8')) == 1:
|
||||
if len(ref_text[-1].encode("utf-8")) == 1:
|
||||
ref_text = ref_text + " "
|
||||
for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
|
||||
# Prepare the text
|
||||
@@ -221,8 +221,8 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
|
||||
|
||||
# Calculate duration
|
||||
ref_audio_len = audio.shape[-1] // hop_length
|
||||
ref_text_len = len(ref_text.encode('utf-8'))
|
||||
gen_text_len = len(gen_text.encode('utf-8'))
|
||||
ref_text_len = len(ref_text.encode("utf-8"))
|
||||
gen_text_len = len(gen_text.encode("utf-8"))
|
||||
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
|
||||
|
||||
# inference
|
||||
@@ -245,7 +245,7 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
|
||||
|
||||
# wav -> numpy
|
||||
generated_wave = generated_wave.squeeze().cpu().numpy()
|
||||
|
||||
|
||||
generated_waves.append(generated_wave)
|
||||
spectrograms.append(generated_mel_spec[0].cpu().numpy())
|
||||
|
||||
@@ -280,11 +280,9 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
|
||||
cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
|
||||
|
||||
# Combine
|
||||
new_wave = np.concatenate([
|
||||
prev_wave[:-cross_fade_samples],
|
||||
cross_faded_overlap,
|
||||
next_wave[cross_fade_samples:]
|
||||
])
|
||||
new_wave = np.concatenate(
|
||||
[prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
|
||||
)
|
||||
|
||||
final_wave = new_wave
|
||||
|
||||
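The joining logic above linearly cross-fades the overlap between consecutive generated batches. A minimal sketch with tiny placeholder waves:

```python
import numpy as np

prev_wave = np.ones(8, dtype=np.float32)
next_wave = np.zeros(8, dtype=np.float32)
cross_fade_samples = 4

fade_out = np.linspace(1.0, 0.0, cross_fade_samples)
fade_in = np.linspace(0.0, 1.0, cross_fade_samples)
overlap = prev_wave[-cross_fade_samples:] * fade_out + next_wave[:cross_fade_samples] * fade_in

new_wave = np.concatenate([prev_wave[:-cross_fade_samples], overlap, next_wave[cross_fade_samples:]])
print(new_wave)  # ramps smoothly from 1.0 down to 0.0 across the overlap
```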
@@ -296,6 +294,7 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
|
||||
|
||||
# remove silence from generated wav
|
||||
|
||||
|
||||
def remove_silence_for_generated_wav(filename):
|
||||
aseg = AudioSegment.from_file(filename)
|
||||
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
|
||||
|
||||
10
ruff.toml
Normal file
@@ -0,0 +1,10 @@
line-length = 120
target-version = "py310"

[lint]
# Only ignore variables with names starting with "_".
dummy-variable-rgx = "^_.*$"

[lint.isort]
force-single-line = true
lines-after-imports = 2
@@ -1,6 +1,7 @@
|
||||
'''ADAPTIVE BATCH SIZE'''
|
||||
print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in')
|
||||
print(' -> least padding, gather wavs with accumulated frames in a batch\n')
|
||||
"""ADAPTIVE BATCH SIZE"""
|
||||
|
||||
print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in")
|
||||
print(" -> least padding, gather wavs with accumulated frames in a batch\n")
|
||||
|
||||
# data
|
||||
total_hours = 95282
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import sys, os
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
from model import M2_TTS, UNetT, DiT, MMDiT
|
||||
from model import M2_TTS, DiT
|
||||
|
||||
import torch
|
||||
import thop
|
||||
|
||||
|
||||
''' ~155M '''
|
||||
""" ~155M """
|
||||
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
|
||||
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
|
||||
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
|
||||
@@ -15,11 +17,11 @@ import thop
|
||||
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
|
||||
# transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
|
||||
|
||||
''' ~335M '''
|
||||
""" ~335M """
|
||||
# FLOPs: 622.1 G, Params: 333.2 M
|
||||
# transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
||||
# FLOPs: 363.4 G, Params: 335.8 M
|
||||
transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
|
||||
transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
||||
|
||||
|
||||
model = M2_TTS(transformer=transformer)
|
||||
@@ -30,6 +32,8 @@ duration = 20
|
||||
frame_length = int(duration * target_sample_rate / hop_length)
|
||||
text_length = 150
|
||||
|
||||
flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
|
||||
flops, params = thop.profile(
|
||||
model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))
|
||||
)
|
||||
print(f"FLOPs: {flops / 1e9} G")
|
||||
print(f"Params: {params / 1e6} M")
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import sys, os
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import time
|
||||
@@ -14,9 +16,9 @@ from vocos import Vocos
|
||||
from model import CFM, UNetT, DiT
|
||||
from model.utils import (
|
||||
load_checkpoint,
|
||||
get_tokenizer,
|
||||
get_seedtts_testset_metainfo,
|
||||
get_librispeech_test_clean_metainfo,
|
||||
get_tokenizer,
|
||||
get_seedtts_testset_metainfo,
|
||||
get_librispeech_test_clean_metainfo,
|
||||
get_inference_prompt,
|
||||
)
|
||||
|
||||
@@ -38,16 +40,16 @@ tokenizer = "pinyin"
|
||||
|
||||
parser = argparse.ArgumentParser(description="batch inference")
|
||||
|
||||
parser.add_argument('-s', '--seed', default=None, type=int)
|
||||
parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
|
||||
parser.add_argument('-n', '--expname', required=True)
|
||||
parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
|
||||
parser.add_argument("-s", "--seed", default=None, type=int)
|
||||
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
|
||||
parser.add_argument("-n", "--expname", required=True)
|
||||
parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
|
||||
|
||||
parser.add_argument('-nfe', '--nfestep', default=32, type=int)
|
||||
parser.add_argument('-o', '--odemethod', default="euler")
|
||||
parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
|
||||
parser.add_argument("-nfe", "--nfestep", default=32, type=int)
|
||||
parser.add_argument("-o", "--odemethod", default="euler")
|
||||
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
|
||||
|
||||
parser.add_argument('-t', '--testset', required=True)
|
||||
parser.add_argument("-t", "--testset", required=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -66,26 +68,26 @@ testset = args.testset
|
||||
|
||||
|
||||
infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
|
||||
cfg_strength = 2.
|
||||
speed = 1.
|
||||
cfg_strength = 2.0
|
||||
speed = 1.0
|
||||
use_truth_duration = False
|
||||
no_ref_audio = False
|
||||
|
||||
|
||||
if exp_name == "F5TTS_Base":
|
||||
model_cls = DiT
|
||||
model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
|
||||
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
||||
|
||||
elif exp_name == "E2TTS_Base":
|
||||
model_cls = UNetT
|
||||
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
|
||||
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
||||
|
||||
|
||||
if testset == "ls_pc_test_clean":
|
||||
metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
|
||||
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
|
||||
metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
|
||||
|
||||
|
||||
elif testset == "seedtts_test_zh":
|
||||
metalst = "data/seedtts_testset/zh/meta.lst"
|
||||
metainfo = get_seedtts_testset_metainfo(metalst)
|
||||
@@ -96,13 +98,16 @@ elif testset == "seedtts_test_en":
|
||||
|
||||
|
||||
# path to save generated wavs
|
||||
if seed is None: seed = random.randint(-10000, 10000)
|
||||
output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
|
||||
f"seed{seed}_{ode_method}_nfe{nfe_step}" \
|
||||
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
|
||||
f"_cfg{cfg_strength}_speed{speed}" \
|
||||
f"{'_gt-dur' if use_truth_duration else ''}" \
|
||||
if seed is None:
|
||||
seed = random.randint(-10000, 10000)
|
||||
output_dir = (
|
||||
f"results/{exp_name}_{ckpt_step}/{testset}/"
|
||||
f"seed{seed}_{ode_method}_nfe{nfe_step}"
|
||||
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
|
||||
f"_cfg{cfg_strength}_speed{speed}"
|
||||
f"{'_gt-dur' if use_truth_duration else ''}"
|
||||
f"{'_no-ref-audio' if no_ref_audio else ''}"
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------#
|
||||
@@ -110,15 +115,15 @@ output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
|
||||
use_ema = True
|
||||
|
||||
prompts_all = get_inference_prompt(
|
||||
metainfo,
|
||||
speed = speed,
|
||||
tokenizer = tokenizer,
|
||||
target_sample_rate = target_sample_rate,
|
||||
n_mel_channels = n_mel_channels,
|
||||
hop_length = hop_length,
|
||||
target_rms = target_rms,
|
||||
use_truth_duration = use_truth_duration,
|
||||
infer_batch_size = infer_batch_size,
|
||||
metainfo,
|
||||
speed=speed,
|
||||
tokenizer=tokenizer,
|
||||
target_sample_rate=target_sample_rate,
|
||||
n_mel_channels=n_mel_channels,
|
||||
hop_length=hop_length,
|
||||
target_rms=target_rms,
|
||||
use_truth_duration=use_truth_duration,
|
||||
infer_batch_size=infer_batch_size,
|
||||
)
|
||||
|
||||
# Vocoder model
|
||||
@@ -137,23 +142,19 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
|
||||
|
||||
# Model
|
||||
model = CFM(
|
||||
transformer = model_cls(
|
||||
**model_cfg,
|
||||
text_num_embeds = vocab_size,
|
||||
mel_dim = n_mel_channels
|
||||
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
|
||||
mel_spec_kwargs=dict(
|
||||
target_sample_rate=target_sample_rate,
|
||||
n_mel_channels=n_mel_channels,
|
||||
hop_length=hop_length,
|
||||
),
|
||||
mel_spec_kwargs = dict(
|
||||
target_sample_rate = target_sample_rate,
|
||||
n_mel_channels = n_mel_channels,
|
||||
hop_length = hop_length,
|
||||
odeint_kwargs=dict(
|
||||
method=ode_method,
|
||||
),
|
||||
odeint_kwargs = dict(
|
||||
method = ode_method,
|
||||
),
|
||||
vocab_char_map = vocab_char_map,
|
||||
vocab_char_map=vocab_char_map,
|
||||
).to(device)
|
||||
|
||||
model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
|
||||
model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
|
||||
|
||||
if not os.path.exists(output_dir) and accelerator.is_main_process:
|
||||
os.makedirs(output_dir)
|
||||
@@ -163,29 +164,28 @@ accelerator.wait_for_everyone()
|
||||
start = time.time()
|
||||
|
||||
with accelerator.split_between_processes(prompts_all) as prompts:
|
||||
|
||||
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
|
||||
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
|
||||
ref_mels = ref_mels.to(device)
|
||||
ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
|
||||
total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
|
||||
|
||||
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
|
||||
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
|
||||
|
||||
# Inference
|
||||
with torch.inference_mode():
|
||||
generated, _ = model.sample(
|
||||
cond = ref_mels,
|
||||
text = final_text_list,
|
||||
duration = total_mel_lens,
|
||||
lens = ref_mel_lens,
|
||||
steps = nfe_step,
|
||||
cfg_strength = cfg_strength,
|
||||
sway_sampling_coef = sway_sampling_coef,
|
||||
no_ref_audio = no_ref_audio,
|
||||
seed = seed,
|
||||
cond=ref_mels,
|
||||
text=final_text_list,
|
||||
duration=total_mel_lens,
|
||||
lens=ref_mel_lens,
|
||||
steps=nfe_step,
|
||||
cfg_strength=cfg_strength,
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
no_ref_audio=no_ref_audio,
|
||||
seed=seed,
|
||||
)
|
||||
# Final result
|
||||
for i, gen in enumerate(generated):
|
||||
gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
|
||||
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
|
||||
gen_mel_spec = gen.permute(0, 2, 1)
|
||||
generated_wave = vocos.decode(gen_mel_spec.cpu())
|
||||
if ref_rms_list[i] < target_rms:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
|
||||
|
||||
import sys, os
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import multiprocessing as mp
|
||||
@@ -19,7 +21,7 @@ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
|
||||
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
|
||||
gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
|
||||
|
||||
gpus = [0,1,2,3,4,5,6,7]
|
||||
gpus = [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
|
||||
|
||||
## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
|
||||
@@ -46,7 +48,7 @@ if eval_task == "wer":
|
||||
for wers_ in results:
|
||||
wers.extend(wers_)
|
||||
|
||||
wer = round(np.mean(wers)*100, 3)
|
||||
wer = round(np.mean(wers) * 100, 3)
|
||||
print(f"\nTotal {len(wers)} samples")
|
||||
print(f"WER : {wer}%")
|
||||
|
||||
@@ -62,6 +64,6 @@ if eval_task == "sim":
|
||||
for sim_ in results:
|
||||
sim_list.extend(sim_)
|
||||
|
||||
sim = round(sum(sim_list)/len(sim_list), 3)
|
||||
sim = round(sum(sim_list) / len(sim_list), 3)
|
||||
print(f"\nTotal {len(sim_list)} samples")
|
||||
print(f"SIM : {sim}")
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Evaluate with Seed-TTS testset
|
||||
|
||||
import sys, os
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import multiprocessing as mp
|
||||
@@ -14,21 +16,21 @@ from model.utils import (
|
||||
|
||||
|
||||
eval_task = "wer" # sim | wer
|
||||
lang = "zh" # zh | en
|
||||
lang = "zh" # zh | en
|
||||
metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
|
||||
# gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
|
||||
gen_wav_dir = f"PATH_TO_GENERATED" # generated wavs
|
||||
gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
|
||||
|
||||
|
||||
# NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
|
||||
# zh 1.254 seems a result of 4 workers wer_seed_tts
|
||||
gpus = [0,1,2,3,4,5,6,7]
|
||||
# zh 1.254 seems a result of 4 workers wer_seed_tts
|
||||
gpus = [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
|
||||
|
||||
local = False
|
||||
if local: # use local custom checkpoint dir
|
||||
if lang == "zh":
|
||||
asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
|
||||
asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
|
||||
elif lang == "en":
|
||||
asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
|
||||
else:
|
||||
@@ -48,7 +50,7 @@ if eval_task == "wer":
|
||||
for wers_ in results:
|
||||
wers.extend(wers_)
|
||||
|
||||
wer = round(np.mean(wers)*100, 3)
|
||||
wer = round(np.mean(wers) * 100, 3)
|
||||
print(f"\nTotal {len(wers)} samples")
|
||||
print(f"WER : {wer}%")
|
||||
|
||||
@@ -64,6 +66,6 @@ if eval_task == "sim":
|
||||
for sim_ in results:
|
||||
sim_list.extend(sim_)
|
||||
|
||||
sim = round(sum(sim_list)/len(sim_list), 3)
|
||||
sim = round(sum(sim_list) / len(sim_list), 3)
|
||||
print(f"\nTotal {len(sim_list)} samples")
|
||||
print(f"SIM : {sim}")
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import sys, os
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
from pathlib import Path
|
||||
@@ -17,10 +19,11 @@ from model.utils import (
|
||||
|
||||
PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
|
||||
|
||||
|
||||
def is_csv_wavs_format(input_dataset_dir):
|
||||
fpath = Path(input_dataset_dir)
|
||||
metadata = fpath / "metadata.csv"
|
||||
wavs = fpath / 'wavs'
|
||||
wavs = fpath / "wavs"
|
||||
return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
|
||||
|
||||
|
||||
@@ -46,22 +49,24 @@ def prepare_csv_wavs_dir(input_dir):
|
||||
|
||||
return sub_result, durations, vocab_set
|
||||
|
||||
|
||||
def get_audio_duration(audio_path):
|
||||
audio, sample_rate = torchaudio.load(audio_path)
|
||||
num_channels = audio.shape[0]
|
||||
return audio.shape[1] / (sample_rate * num_channels)
|
||||
|
||||
|
||||
def read_audio_text_pairs(csv_file_path):
|
||||
audio_text_pairs = []
|
||||
|
||||
parent = Path(csv_file_path).parent
|
||||
with open(csv_file_path, mode='r', newline='', encoding='utf-8') as csvfile:
|
||||
reader = csv.reader(csvfile, delimiter='|')
|
||||
with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile:
|
||||
reader = csv.reader(csvfile, delimiter="|")
|
||||
next(reader) # Skip the header row
|
||||
for row in reader:
|
||||
if len(row) >= 2:
|
||||
audio_file = row[0].strip() # First column: audio file path
|
||||
text = row[1].strip() # Second column: text
|
||||
text = row[1].strip() # Second column: text
|
||||
audio_file_path = parent / audio_file
|
||||
audio_text_pairs.append((audio_file_path.as_posix(), text))
|
||||
|
||||
@@ -78,12 +83,12 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
|
||||
# dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
|
||||
raw_arrow_path = out_dir / "raw.arrow"
|
||||
with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
|
||||
for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
|
||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||
writer.write(line)
|
||||
|
||||
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
||||
dur_json_path = out_dir / "duration.json"
|
||||
with open(dur_json_path.as_posix(), 'w', encoding='utf-8') as f:
|
||||
with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
|
||||
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
||||
|
||||
# vocab map, i.e. tokenizer
|
||||
@@ -120,13 +125,14 @@ def cli():
|
||||
# finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
|
||||
# pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
|
||||
parser = argparse.ArgumentParser(description="Prepare and save dataset.")
|
||||
parser.add_argument('inp_dir', type=str, help="Input directory containing the data.")
|
||||
parser.add_argument('out_dir', type=str, help="Output directory to save the prepared data.")
|
||||
parser.add_argument('--pretrain', action='store_true', help="Enable for new pretrain, otherwise is a fine-tune")
|
||||
parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
|
||||
parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
|
||||
parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
||||
@@ -4,7 +4,9 @@
# generate audio text map for Emilia ZH & EN
# evaluate for vocab size

import sys, os
import sys
import os

sys.path.append(os.getcwd())

from pathlib import Path
@@ -12,7 +14,6 @@ import json
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor

from datasets import Dataset
from datasets.arrow_writer import ArrowWriter

from model.utils import (
@@ -21,13 +22,89 @@ from model.utils import (
)


out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"}
out_zh = {
    "ZH_B00041_S06226",
    "ZH_B00042_S09204",
    "ZH_B00065_S09430",
    "ZH_B00065_S09431",
    "ZH_B00066_S09327",
    "ZH_B00066_S09328",
}
zh_filters = ["い", "て"]
# seems synthesized audios, or heavily code-switched
out_en = {
    "EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375",

    "EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995",
    "EN_B00013_S00913",
    "EN_B00042_S00120",
    "EN_B00055_S04111",
    "EN_B00061_S00693",
    "EN_B00061_S01494",
    "EN_B00061_S03375",
    "EN_B00059_S00092",
    "EN_B00111_S04300",
    "EN_B00100_S03759",
    "EN_B00087_S03811",
    "EN_B00059_S00950",
    "EN_B00089_S00946",
    "EN_B00078_S05127",
    "EN_B00070_S04089",
    "EN_B00074_S09659",
    "EN_B00061_S06983",
    "EN_B00061_S07060",
    "EN_B00059_S08397",
    "EN_B00082_S06192",
    "EN_B00091_S01238",
    "EN_B00089_S07349",
    "EN_B00070_S04343",
    "EN_B00061_S02400",
    "EN_B00076_S01262",
    "EN_B00068_S06467",
    "EN_B00076_S02943",
    "EN_B00064_S05954",
    "EN_B00061_S05386",
    "EN_B00066_S06544",
    "EN_B00076_S06944",
    "EN_B00072_S08620",
    "EN_B00076_S07135",
    "EN_B00076_S09127",
    "EN_B00065_S00497",
    "EN_B00059_S06227",
    "EN_B00063_S02859",
    "EN_B00075_S01547",
    "EN_B00061_S08286",
    "EN_B00079_S02901",
    "EN_B00092_S03643",
    "EN_B00096_S08653",
    "EN_B00063_S04297",
    "EN_B00063_S04614",
    "EN_B00079_S04698",
    "EN_B00104_S01666",
    "EN_B00061_S09504",
    "EN_B00061_S09694",
    "EN_B00065_S05444",
    "EN_B00063_S06860",
    "EN_B00065_S05725",
    "EN_B00069_S07628",
    "EN_B00083_S03875",
    "EN_B00071_S07665",
    "EN_B00071_S07665",
    "EN_B00062_S04187",
    "EN_B00065_S09873",
    "EN_B00065_S09922",
    "EN_B00084_S02463",
    "EN_B00067_S05066",
    "EN_B00106_S08060",
    "EN_B00073_S06399",
    "EN_B00073_S09236",
    "EN_B00087_S00432",
    "EN_B00085_S05618",
    "EN_B00064_S01262",
    "EN_B00072_S01739",
    "EN_B00059_S03913",
    "EN_B00069_S04036",
    "EN_B00067_S05623",
    "EN_B00060_S05389",
    "EN_B00060_S07290",
    "EN_B00062_S08995",
}
en_filters = ["ا", "い", "て"]

@@ -43,18 +120,24 @@ def deal_with_audio_dir(audio_dir):
        for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
            obj = json.loads(line)
            text = obj["text"]
            if obj['language'] == "zh":
            if obj["language"] == "zh":
                if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
                    bad_case_zh += 1
                    continue
                else:
                    text = text.translate(str.maketrans({',': ',', '!': '!', '?': '?'})) # not "。" cuz much code-switched
            if obj['language'] == "en":
                if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
                    text = text.translate(
                        str.maketrans({",": ",", "!": "!", "?": "?"})
                    ) # not "。" cuz much code-switched
            if obj["language"] == "en":
                if (
                    obj["wav"].split("/")[1] in out_en
                    or any(f in text for f in en_filters)
                    or repetition_found(text, length=4)
                ):
                    bad_case_en += 1
                    continue
            if tokenizer == "pinyin":
                text = convert_char_to_pinyin([text], polyphone = polyphone)[0]
                text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
            duration = obj["duration"]
            sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
            durations.append(duration)
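The zh branch above rejects known-bad clips and repetitive transcripts, then maps fullwidth ,/!/? to their ASCII forms (deliberately leaving 。 untouched because of code-switching). The translate call is plain standard library and can be checked in isolation:

```python
# Self-contained illustration of the punctuation normalization used above.
text = "对,这就是我!真的吗?"
table = str.maketrans({",": ",", "!": "!", "?": "?"})
print(text.translate(table))  # -> 对,这就是我!真的吗?
```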
@@ -96,11 +179,11 @@ def main():
    # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
    # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
    with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
        for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)

    # dup a json separately saving duration in case for DynamicBatchSampler ease
    with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f:
    with open(f"data/{dataset_name}/duration.json", "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # vocab map, i.e. tokenizer
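The duration.json written above exists so that a duration-aware sampler (the DynamicBatchSampler mentioned in the comment) does not have to scan the Arrow file for clip lengths. A rough sketch of reading it back and deriving a length-sorted order; the path is a placeholder and the sorting is only illustrative, not the sampler's actual logic:

```python
import json

# Placeholder path; substitute the dataset folder written by this script.
with open("data/<dataset_name>/duration.json", "r", encoding="utf-8") as f:
    durations = json.load(f)["duration"]

# Indices sorted by clip length, the kind of ordering a frame-budget batch sampler can bucket.
order = sorted(range(len(durations)), key=lambda i: durations[i])
print(f"shortest: {durations[order[0]]:.2f}s, longest: {durations[order[-1]]:.2f}s")
```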
@@ -114,12 +197,13 @@ def main():
    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
    if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}")
    if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n")
    if "ZH" in langs:
        print(f"Bad zh transcription case: {total_bad_case_zh}")
    if "EN" in langs:
        print(f"Bad en transcription case: {total_bad_case_en}\n")


if __name__ == "__main__":

    max_workers = 32

    tokenizer = "pinyin" # "pinyin" | "char"

@@ -1,7 +1,9 @@
# generate audio text map for WenetSpeech4TTS
# evaluate for vocab size

import sys, os
import sys
import os

sys.path.append(os.getcwd())

import json
@@ -23,7 +25,7 @@ def deal_with_sub_path_files(dataset_path, sub_path):

    audio_paths, texts, durations = [], [], []
    for text_file in tqdm(text_files):
        with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file:
        with open(os.path.join(text_dir, text_file), "r", encoding="utf-8") as file:
            first_line = file.readline().split("\t")
        audio_nm = first_line[0]
        audio_path = os.path.join(audio_dir, audio_nm + ".wav")
@@ -32,7 +34,7 @@ def deal_with_sub_path_files(dataset_path, sub_path):
        audio_paths.append(audio_path)

        if tokenizer == "pinyin":
            texts.extend(convert_char_to_pinyin([text], polyphone = polyphone))
            texts.extend(convert_char_to_pinyin([text], polyphone=polyphone))
        elif tokenizer == "char":
            texts.append(text)

@@ -46,7 +48,7 @@ def main():
    assert tokenizer in ["pinyin", "char"]

    audio_path_list, text_list, duration_list = [], [], []


    executor = ProcessPoolExecutor(max_workers=max_workers)
    futures = []
    for dataset_path in dataset_paths:
@@ -68,8 +70,10 @@ def main():
    dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
    dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format

    with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False) # dup a json separately saving duration in case for DynamicBatchSampler ease
    with open(f"data/{dataset_name}_{tokenizer}/duration.json", "w", encoding="utf-8") as f:
        json.dump(
            {"duration": duration_list}, f, ensure_ascii=False
        ) # dup a json separately saving duration in case for DynamicBatchSampler ease

    print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
    text_vocab_set = set()
@@ -85,22 +89,21 @@ def main():
        f.write(vocab + "\n")
    print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")


if __name__ == "__main__":

    max_workers = 32

    tokenizer = "pinyin" # "pinyin" | "char"
    polyphone = True
    dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic

    dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
    dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
    dataset_paths = [
        "<SOME_PATH>/WenetSpeech4TTS/Basic",
        "<SOME_PATH>/WenetSpeech4TTS/Standard",
        "<SOME_PATH>/WenetSpeech4TTS/Premium",
    ][-dataset_choice:]
    ][-dataset_choice:]
    print(f"\nChoose Dataset: {dataset_name}\n")

    main()
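The two index expressions in the `__main__` block above do different jobs: `dataset_choice - 1` picks the display name, while `[-dataset_choice:]` keeps the last N paths of the Basic/Standard/Premium list, which appears to assume the chosen tier is processed together with every higher-quality tier. A small check of that arithmetic:

```python
# Illustrative only; mirrors the indexing in the __main__ block above.
dataset_choice = 2  # 1: Premium, 2: Standard, 3: Basic
names = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"]
paths = ["<SOME_PATH>/WenetSpeech4TTS/Basic", "<SOME_PATH>/WenetSpeech4TTS/Standard", "<SOME_PATH>/WenetSpeech4TTS/Premium"]

print(names[dataset_choice - 1])  # WenetSpeech4TTS_Standard
print(paths[-dataset_choice:])    # ['<SOME_PATH>/WenetSpeech4TTS/Standard', '<SOME_PATH>/WenetSpeech4TTS/Premium']
```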
@@ -109,8 +112,8 @@ if __name__ == "__main__":
# WenetSpeech4TTS      Basic     Standard    Premium
# samples count        3932473   1941220     407494
# pinyin vocab size    1349      1348        1344   (no polyphone)
#                      -         -           1459   (polyphone)
#                      -         -           1459   (polyphone)
# char vocab size      5264      5219        5042


# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
# please be careful if using pretrained model, make sure the vocab.txt is same

@@ -5,11 +5,11 @@ import torch.nn.functional as F
import torchaudio
from vocos import Vocos

from model import CFM, UNetT, DiT, MMDiT
from model import CFM, UNetT, DiT
from model.utils import (
    load_checkpoint,
    get_tokenizer,
    convert_char_to_pinyin,
    get_tokenizer,
    convert_char_to_pinyin,
    save_spectrogram,
)

@@ -35,18 +35,18 @@ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
ckpt_step = 1200000

nfe_step = 32 # 16, 32
cfg_strength = 2.
ode_method = 'euler' # euler | midpoint
sway_sampling_coef = -1.
speed = 1.
cfg_strength = 2.0
ode_method = "euler" # euler | midpoint
sway_sampling_coef = -1.0
speed = 1.0

if exp_name == "F5TTS_Base":
    model_cls = DiT
    model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

elif exp_name == "E2TTS_Base":
    model_cls = UNetT
    model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)

ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
output_dir = "tests"
@@ -62,8 +62,14 @@ output_dir = "tests"
audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
origin_text = "Some call me nature, others call me mother nature."
target_text = "Some call me optimist, others call me realist."
parts_to_edit = [[1.42, 2.44], [4.04, 4.9], ] # stard_ends of "nature" & "mother nature", in seconds
fix_duration = [1.2, 1, ] # fix duration for "optimist" & "realist", in seconds
parts_to_edit = [
    [1.42, 2.44],
    [4.04, 4.9],
] # stard_ends of "nature" & "mother nature", in seconds
fix_duration = [
    1.2,
    1,
] # fix duration for "optimist" & "realist", in seconds

# audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
# origin_text = "对,这就是我,万人敬仰的太乙真人。"
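`parts_to_edit` and `fix_duration` are given in seconds; the loop further down converts them into samples and then mel frames. A quick sanity check of that arithmetic, assuming the 24 kHz sample rate and hop length 256 used throughout the repo:

```python
# Back-of-the-envelope check of the seconds -> samples -> frames conversion.
target_sample_rate = 24000
hop_length = 256

start_s, end_s = 1.42, 2.44  # first edited span above, in seconds
span_samples = round((end_s - start_s) * target_sample_rate)
span_frames = round(span_samples / hop_length)
print(span_samples, span_frames)  # 24480 samples, about 96 mel frames
```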
@@ -86,7 +92,7 @@ if local:
    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
    vocos.load_state_dict(state_dict)

    vocos.eval()
else:
    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
@@ -96,23 +102,19 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

# Model
model = CFM(
    transformer = model_cls(
        **model_cfg,
        text_num_embeds = vocab_size,
        mel_dim = n_mel_channels
    transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
    mel_spec_kwargs=dict(
        target_sample_rate=target_sample_rate,
        n_mel_channels=n_mel_channels,
        hop_length=hop_length,
    ),
    mel_spec_kwargs = dict(
        target_sample_rate = target_sample_rate,
        n_mel_channels = n_mel_channels,
        hop_length = hop_length,
    odeint_kwargs=dict(
        method=ode_method,
    ),
    odeint_kwargs = dict(
        method = ode_method,
    ),
    vocab_char_map = vocab_char_map,
    vocab_char_map=vocab_char_map,
).to(device)

model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)

# Audio
audio, sr = torchaudio.load(audio_to_edit)
@@ -132,14 +134,18 @@ for part in parts_to_edit:
    part_dur = end - start if fix_duration is None else fix_duration.pop(0)
    part_dur = part_dur * target_sample_rate
    start = start * target_sample_rate
    audio_ = torch.cat((audio_, audio[:, round(offset):round(start)], torch.zeros(1, round(part_dur))), dim = -1)
    edit_mask = torch.cat((edit_mask,
                           torch.ones(1, round((start - offset) / hop_length), dtype = torch.bool),
                           torch.zeros(1, round(part_dur / hop_length), dtype = torch.bool)
                           ), dim = -1)
    audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
    edit_mask = torch.cat(
        (
            edit_mask,
            torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
            torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
        ),
        dim=-1,
    )
    offset = end * target_sample_rate
# audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value = True)
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
audio = audio.to(device)
edit_mask = edit_mask.to(device)
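As read here, the mask built in this loop is True for frames whose original audio is kept as conditioning and False for frames inside an edited span that the model must regenerate, with the final F.pad marking everything after the last edit as kept. A toy 10-frame version of the same construction:

```python
import torch
import torch.nn.functional as F

# Toy mask: frames 4-6 are regenerated (False), everything else keeps the original audio (True).
edit_mask = torch.cat(
    (
        torch.ones(1, 4, dtype=torch.bool),   # frames before the edit
        torch.zeros(1, 3, dtype=torch.bool),  # edited span
    ),
    dim=-1,
)
edit_mask = F.pad(edit_mask, (0, 10 - edit_mask.shape[-1]), value=True)  # tail after the edit
print(edit_mask)
# tensor([[ True,  True,  True,  True, False, False, False,  True,  True,  True]])
```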
@@ -159,14 +165,14 @@ duration = audio.shape[-1] // hop_length
# Inference
with torch.inference_mode():
    generated, trajectory = model.sample(
        cond = audio,
        text = final_text_list,
        duration = duration,
        steps = nfe_step,
        cfg_strength = cfg_strength,
        sway_sampling_coef = sway_sampling_coef,
        seed = seed,
        edit_mask = edit_mask,
        cond=audio,
        text=final_text_list,
        duration=duration,
        steps=nfe_step,
        cfg_strength=cfg_strength,
        sway_sampling_coef=sway_sampling_coef,
        seed=seed,
        edit_mask=edit_mask,
    )
    print(f"Generated mel: {generated.shape}")

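The sampled tensor is a mel spectrogram, not audio; speech_edit.py hands it to the Vocos vocoder loaded near the top of the script. A self-contained sketch of just that decode step, with a random mel standing in for `generated` and the (batch, n_mel, frames) layout Vocos expects; the script's actual post-processing may differ slightly:

```python
import torch
from vocos import Vocos

vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
mel = torch.randn(1, 100, 256)  # random stand-in: (batch, n_mel_channels, frames)
with torch.inference_mode():
    wave = vocos.decode(mel)    # 24 kHz waveform, roughly frames * hop_length samples long
print(wave.shape)
```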
68
train.py
68
train.py
@@ -1,4 +1,4 @@
from model import CFM, UNetT, DiT, MMDiT, Trainer
from model import CFM, UNetT, DiT, Trainer
from model.utils import get_tokenizer
from model.dataset import load_dataset

@@ -9,8 +9,8 @@ target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256

tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
dataset_name = "Emilia_ZH_EN"

# -------------------------- Training Settings -------------------------- #
@@ -23,7 +23,7 @@ batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type = "frame" # "frame" or "sample"
max_samples = 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm = 1.
max_grad_norm = 1.0

epochs = 11 # use linear decay, thus epochs control the slope
num_warmup_updates = 20000 # warmup steps
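With `batch_size_type = "frame"`, the batch budget above counts mel frames rather than utterances; a quick arithmetic check of what that means in audio time, using the hop length and sample rate set in this file:

```python
# Worked example of the frame-based batch budget above.
target_sample_rate = 24000
hop_length = 256
batch_size_per_gpu = 38400  # frames
num_gpus = 8

print(batch_size_per_gpu * hop_length / target_sample_rate)  # 409.6 seconds of audio per GPU per step
print(num_gpus * batch_size_per_gpu)                         # 307200 frames total, matching the comment
```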
@@ -34,15 +34,16 @@ last_per_steps = 5000 # save last checkpoint per steps
if exp_name == "F5TTS_Base":
    wandb_resume_id = None
    model_cls = DiT
    model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
elif exp_name == "E2TTS_Base":
    wandb_resume_id = None
    model_cls = UNetT
    model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)


# ----------------------------------------------------------------------- #


def main():
    if tokenizer == "custom":
        tokenizer_path = tokenizer_path
@@ -51,44 +52,41 @@ def main():
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

    mel_spec_kwargs = dict(
        target_sample_rate = target_sample_rate,
        n_mel_channels = n_mel_channels,
        hop_length = hop_length,
    )

        target_sample_rate=target_sample_rate,
        n_mel_channels=n_mel_channels,
        hop_length=hop_length,
    )

    model = CFM(
        transformer = model_cls(
            **model_cfg,
            text_num_embeds = vocab_size,
            mel_dim = n_mel_channels
        ),
        mel_spec_kwargs = mel_spec_kwargs,
        vocab_char_map = vocab_char_map,
        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
        mel_spec_kwargs=mel_spec_kwargs,
        vocab_char_map=vocab_char_map,
    )

    trainer = Trainer(
        model,
        epochs,
        epochs,
        learning_rate,
        num_warmup_updates = num_warmup_updates,
        save_per_updates = save_per_updates,
        checkpoint_path = f'ckpts/{exp_name}',
        batch_size = batch_size_per_gpu,
        batch_size_type = batch_size_type,
        max_samples = max_samples,
        grad_accumulation_steps = grad_accumulation_steps,
        max_grad_norm = max_grad_norm,
        wandb_project = "CFM-TTS",
        wandb_run_name = exp_name,
        wandb_resume_id = wandb_resume_id,
        last_per_steps = last_per_steps,
        num_warmup_updates=num_warmup_updates,
        save_per_updates=save_per_updates,
        checkpoint_path=f"ckpts/{exp_name}",
        batch_size=batch_size_per_gpu,
        batch_size_type=batch_size_type,
        max_samples=max_samples,
        grad_accumulation_steps=grad_accumulation_steps,
        max_grad_norm=max_grad_norm,
        wandb_project="CFM-TTS",
        wandb_run_name=exp_name,
        wandb_resume_id=wandb_resume_id,
        last_per_steps=last_per_steps,
    )

    train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
    trainer.train(train_dataset,
                  resumable_with_seed = 666 # seed for shuffling dataset
                  )
    trainer.train(
        train_dataset,
        resumable_with_seed=666, # seed for shuffling dataset
    )


if __name__ == '__main__':
if __name__ == "__main__":
    main()
