add and run pre-commit with ruff

Tom Hunn committed 2024-10-21 14:46:45 +10:00
parent 77e00db01b · commit a4ca14b5f6
29 changed files with 1827 additions and 1328 deletions

.github/workflows/pre-commit.yaml (new file, +14 lines)

@@ -0,0 +1,14 @@
name: pre-commit
on:
pull_request:
push:
branches: [main]
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
- uses: pre-commit/action@v3.0.1

.pre-commit-config.yaml (new file, +14 lines)

@@ -0,0 +1,14 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.7.0
hooks:
# Run the linter.
- id: ruff
args: [--fix]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
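For reference, the same checks can be run outside pre-commit by calling ruff directly (a sketch, assuming ruff 0.7.x is installed locally):
```bash
ruff check --fix .   # linter, mirroring the "ruff" hook with args: [--fix]
ruff format .        # formatter, mirroring the "ruff-format" hook
```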

View File

@@ -43,6 +43,26 @@ pip install -r requirements.txt
docker build -t f5tts:v1 .
```
### Development
When making a pull request, please use pre-commit to ensure code quality:
```bash
pip install pre-commit
pre-commit install
```
This will run linters and formatters automatically before each commit.
Run the hooks manually on all files with:
```bash
pre-commit run --all-files
```
Note: Some model code carries per-line linting exceptions: F722 for the string tensor-shape annotations (e.g. `float["b n d"]`) and E722 for a few bare `except` fallbacks in dataset loading.
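For context, the shape annotations behind the F722 exceptions look roughly like this (a standalone sketch in the style of `model/backbones/*.py`, not code from this commit):
```python
from __future__ import annotations  # keeps the shape strings as unevaluated annotations


def forward(x: float["b n d"], cond: float["b n d"], drop_audio_cond: bool = False):  # noqa: F722
    # ruff's F722 would flag "b n d" as an invalid forward annotation, hence the per-line noqa
    return x
```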
## Prepare Dataset
Example data processing scripts are provided for Emilia and Wenetspeech4TTS; you can tailor your own alongside a Dataset class in `model/dataset.py`.

View File

@@ -1,42 +1,57 @@
import argparse
from model import CFM, UNetT, DiT, MMDiT, Trainer
from model import CFM, UNetT, DiT, Trainer
from model.utils import get_tokenizer
from model.dataset import load_dataset
from cached_path import cached_path
import shutil,os
import shutil
import os
# -------------------------- Dataset Settings --------------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
# -------------------------- Argument Parsing --------------------------- #
def parse_args():
parser = argparse.ArgumentParser(description='Train CFM Model')
parser = argparse.ArgumentParser(description="Train CFM Model")
parser.add_argument(
"--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
)
parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for training")
parser.add_argument("--batch_size_per_gpu", type=int, default=256, help="Batch size per GPU")
parser.add_argument(
"--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
)
parser.add_argument("--max_samples", type=int, default=16, help="Max sequences per batch")
parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
parser.add_argument("--num_warmup_updates", type=int, default=5, help="Warmup steps")
parser.add_argument("--save_per_updates", type=int, default=10, help="Save checkpoint every X steps")
parser.add_argument("--last_per_steps", type=int, default=10, help="Save last checkpoint every X steps")
parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
parser.add_argument(
"--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
)
parser.add_argument(
"--tokenizer_path",
type=str,
default=None,
help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
)
parser.add_argument('--exp_name', type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"],help='Experiment name')
parser.add_argument('--dataset_name', type=str, default="Emilia_ZH_EN", help='Name of the dataset to use')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate for training')
parser.add_argument('--batch_size_per_gpu', type=int, default=256, help='Batch size per GPU')
parser.add_argument('--batch_size_type', type=str, default="frame", choices=["frame", "sample"],help='Batch size type')
parser.add_argument('--max_samples', type=int, default=16, help='Max sequences per batch')
parser.add_argument('--grad_accumulation_steps', type=int, default=1,help='Gradient accumulation steps')
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
parser.add_argument('--finetune', type=bool, default=True, help='Use Finetune')
return parser.parse_args()
# -------------------------- Training Settings -------------------------- #
def main():
args = parse_args()
# Model parameters based on experiment name
if args.exp_name == "F5TTS_Base":
@@ -44,24 +59,31 @@ def main():
model_cls = DiT
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
if args.finetune:
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
elif args.exp_name == "E2TTS_Base":
wandb_resume_id = None
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
if args.finetune:
ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
if args.finetune:
path_ckpt = os.path.join("ckpts",args.dataset_name)
if os.path.isdir(path_ckpt)==False:
os.makedirs(path_ckpt,exist_ok=True)
shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
checkpoint_path=os.path.join("ckpts",args.dataset_name)
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
if args.finetune:
path_ckpt = os.path.join("ckpts", args.dataset_name)
if not os.path.isdir(path_ckpt):
os.makedirs(path_ckpt, exist_ok=True)
shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path)))
checkpoint_path = os.path.join("ckpts", args.dataset_name)
# Use the tokenizer and tokenizer_path provided in the command line arguments
tokenizer = args.tokenizer
if tokenizer == "custom":
if not args.tokenizer_path:
raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
tokenizer_path = args.tokenizer_path
else:
tokenizer_path = args.dataset_name
# Use the dataset_name provided in the command line
tokenizer_path = args.dataset_name if tokenizer != "custom" else tokenizer_path
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
mel_spec_kwargs = dict(
@@ -71,11 +93,7 @@ def main():
)
e2tts = CFM(
transformer=model_cls(
**model_cfg,
text_num_embeds=vocab_size,
mel_dim=n_mel_channels
),
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=mel_spec_kwargs,
vocab_char_map=vocab_char_map,
)
@@ -99,10 +117,11 @@ def main():
)
train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
trainer.train(train_dataset,
resumable_with_seed=666 # seed for shuffling dataset
)
trainer.train(
train_dataset,
resumable_with_seed=666, # seed for shuffling dataset
)
if __name__ == '__main__':
if __name__ == "__main__":
main()
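A hypothetical invocation using the new tokenizer flags (the script name and vocab path are placeholders, since the file name is not shown in this view):
```bash
python train.py \
    --exp_name F5TTS_Base \
    --dataset_name Emilia_ZH_EN \
    --finetune True \
    --tokenizer custom \
    --tokenizer_path data/my_dataset_custom/vocab.txt
```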

File diff suppressed because it is too large.

View File

@@ -1,3 +1,6 @@
# ruff: noqa: E402
# Above allows ruff to ignore E402: module level import not at top of file
import re
import tempfile
@@ -11,16 +14,19 @@ from pydub import AudioSegment
try:
import spaces
USING_SPACES = True
except ImportError:
USING_SPACES = False
def gpu_decorator(func):
if USING_SPACES:
return spaces.GPU(func)
else:
return func
from model import DiT, UNetT
from model.utils import (
save_spectrogram,
@@ -38,15 +44,18 @@ vocos = load_vocoder()
# load models
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
F5TTS_ema_model = load_model(DiT, F5TTS_model_cfg, str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")))
F5TTS_ema_model = load_model(
DiT, F5TTS_model_cfg, str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
)
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
E2TTS_ema_model = load_model(UNetT, E2TTS_model_cfg, str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors")))
E2TTS_ema_model = load_model(
UNetT, E2TTS_model_cfg, str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
)
@gpu_decorator
def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1):
ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=gr.Info)
if model == "F5-TTS":
@@ -54,7 +63,16 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_
elif model == "E2-TTS":
ema_model = E2TTS_ema_model
final_wave, final_sample_rate, combined_spectrogram = infer_process(ref_audio, ref_text, gen_text, ema_model, cross_fade_duration=cross_fade_duration, speed=speed, show_info=gr.Info, progress=gr.Progress())
final_wave, final_sample_rate, combined_spectrogram = infer_process(
ref_audio,
ref_text,
gen_text,
ema_model,
cross_fade_duration=cross_fade_duration,
speed=speed,
show_info=gr.Info,
progress=gr.Progress(),
)
# Remove silence
if remove_silence:
@@ -73,17 +91,19 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_
@gpu_decorator
def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, model, remove_silence):
def generate_podcast(
script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, model, remove_silence
):
# Split the script into speaker blocks
speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
generated_audio_segments = []
for i in range(0, len(speaker_blocks), 2):
speaker = speaker_blocks[i]
text = speaker_blocks[i+1].strip()
text = speaker_blocks[i + 1].strip()
# Determine which speaker is talking
if speaker == speaker1_name:
ref_audio = ref_audio1
@@ -93,51 +113,52 @@ def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name
ref_text = ref_text2
else:
continue # Skip if the speaker is neither speaker1 nor speaker2
# Generate audio for this block
audio, _ = infer(ref_audio, ref_text, text, model, remove_silence)
# Convert the generated audio to a numpy array
sr, audio_data = audio
# Save the audio data as a WAV file
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
sf.write(temp_file.name, audio_data, sr)
audio_segment = AudioSegment.from_wav(temp_file.name)
generated_audio_segments.append(audio_segment)
# Add a short pause between speakers
pause = AudioSegment.silent(duration=500) # 500ms pause
generated_audio_segments.append(pause)
# Concatenate all audio segments
final_podcast = sum(generated_audio_segments)
# Export the final podcast
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
podcast_path = temp_file.name
final_podcast.export(podcast_path, format="wav")
return podcast_path
def parse_speechtypes_text(gen_text):
# Pattern to find (Emotion)
pattern = r'\((.*?)\)'
pattern = r"\((.*?)\)"
# Split the text by the pattern
tokens = re.split(pattern, gen_text)
segments = []
current_emotion = 'Regular'
current_emotion = "Regular"
for i in range(len(tokens)):
if i % 2 == 0:
# This is text
text = tokens[i].strip()
if text:
segments.append({'emotion': current_emotion, 'text': text})
segments.append({"emotion": current_emotion, "text": text})
else:
# This is emotion
emotion = tokens[i].strip()
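As a reading aid, `parse_speechtypes_text` splits the text on `(Emotion)` markers; a hedged usage sketch (assuming the branch truncated above assigns `current_emotion = emotion`):
```python
parse_speechtypes_text("(Happy) Great to see you! (Sad) I have to leave now.")
# -> [{"emotion": "Happy", "text": "Great to see you!"},
#     {"emotion": "Sad", "text": "I have to leave now."}]
```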
@@ -158,9 +179,7 @@ with gr.Blocks() as app_tts:
gr.Markdown("# Batched TTS")
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
model_choice = gr.Radio(
choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
)
model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
generate_btn = gr.Button("Synthesize", variant="primary")
with gr.Accordion("Advanced Settings", open=False):
ref_text_input = gr.Textbox(
@@ -206,23 +225,24 @@ with gr.Blocks() as app_tts:
],
outputs=[audio_output, spectrogram_output],
)
with gr.Blocks() as app_podcast:
gr.Markdown("# Podcast Generation")
speaker1_name = gr.Textbox(label="Speaker 1 Name")
ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
speaker2_name = gr.Textbox(label="Speaker 2 Name")
ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
script_input = gr.Textbox(label="Podcast Script", lines=10,
placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...")
podcast_model_choice = gr.Radio(
choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
script_input = gr.Textbox(
label="Podcast Script",
lines=10,
placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...",
)
podcast_model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
podcast_remove_silence = gr.Checkbox(
label="Remove Silences",
value=True,
@@ -230,8 +250,12 @@ with gr.Blocks() as app_podcast:
generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
podcast_output = gr.Audio(label="Generated Podcast")
def podcast_generation(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence):
return generate_podcast(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence)
def podcast_generation(
script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence
):
return generate_podcast(
script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence
)
generate_podcast_btn.click(
podcast_generation,
@@ -249,23 +273,24 @@ with gr.Blocks() as app_podcast:
outputs=podcast_output,
)
def parse_emotional_text(gen_text):
# Pattern to find (Emotion)
pattern = r'\((.*?)\)'
pattern = r"\((.*?)\)"
# Split the text by the pattern
tokens = re.split(pattern, gen_text)
segments = []
current_emotion = 'Regular'
current_emotion = "Regular"
for i in range(len(tokens)):
if i % 2 == 0:
# This is text
text = tokens[i].strip()
if text:
segments.append({'emotion': current_emotion, 'text': text})
segments.append({"emotion": current_emotion, "text": text})
else:
# This is emotion
emotion = tokens[i].strip()
@@ -273,6 +298,7 @@ def parse_emotional_text(gen_text):
return segments
with gr.Blocks() as app_emotional:
# New section for emotional generation
gr.Markdown(
@@ -287,13 +313,15 @@ with gr.Blocks() as app_emotional:
"""
)
gr.Markdown("Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button.")
gr.Markdown(
"Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button."
)
# Regular speech type (mandatory)
with gr.Row():
regular_name = gr.Textbox(value='Regular', label='Speech Type Name', interactive=False)
regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath')
regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2)
regular_name = gr.Textbox(value="Regular", label="Speech Type Name", interactive=False)
regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2)
# Additional speech types (up to 99 more)
max_speech_types = 100
@@ -304,9 +332,9 @@ with gr.Blocks() as app_emotional:
for i in range(max_speech_types - 1):
with gr.Row():
name_input = gr.Textbox(label='Speech Type Name', visible=False)
audio_input = gr.Audio(label='Reference Audio', type='filepath', visible=False)
ref_text_input = gr.Textbox(label='Reference Text', lines=2, visible=False)
name_input = gr.Textbox(label="Speech Type Name", visible=False)
audio_input = gr.Audio(label="Reference Audio", type="filepath", visible=False)
ref_text_input = gr.Textbox(label="Reference Text", lines=2, visible=False)
delete_btn = gr.Button("Delete", variant="secondary", visible=False)
speech_type_names.append(name_input)
speech_type_audios.append(audio_input)
@@ -351,7 +379,11 @@ with gr.Blocks() as app_emotional:
add_speech_type_btn.click(
add_speech_type_fn,
inputs=speech_type_count,
outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
outputs=[speech_type_count]
+ speech_type_names
+ speech_type_audios
+ speech_type_ref_texts
+ speech_type_delete_btns,
)
# Function to delete a speech type
@@ -365,9 +397,9 @@ with gr.Blocks() as app_emotional:
for i in range(max_speech_types - 1):
if i == index:
name_updates.append(gr.update(visible=False, value=''))
name_updates.append(gr.update(visible=False, value=""))
audio_updates.append(gr.update(visible=False, value=None))
ref_text_updates.append(gr.update(visible=False, value=''))
ref_text_updates.append(gr.update(visible=False, value=""))
delete_btn_updates.append(gr.update(visible=False))
else:
name_updates.append(gr.update())
@@ -386,16 +418,18 @@ with gr.Blocks() as app_emotional:
delete_btn.click(
delete_fn,
inputs=speech_type_count,
outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
outputs=[speech_type_count]
+ speech_type_names
+ speech_type_audios
+ speech_type_ref_texts
+ speech_type_delete_btns,
)
# Text input for the prompt
gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
# Model choice
model_choice_emotional = gr.Radio(
choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
)
model_choice_emotional = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
with gr.Accordion("Advanced Settings", open=False):
remove_silence_emotional = gr.Checkbox(
@@ -408,6 +442,7 @@ with gr.Blocks() as app_emotional:
# Output audio
audio_output_emotional = gr.Audio(label="Synthesized Audio")
@gpu_decorator
def generate_emotional_speech(
regular_audio,
@@ -417,37 +452,39 @@ with gr.Blocks() as app_emotional:
):
num_additional_speech_types = max_speech_types - 1
speech_type_names_list = args[:num_additional_speech_types]
speech_type_audios_list = args[num_additional_speech_types:2 * num_additional_speech_types]
speech_type_ref_texts_list = args[2 * num_additional_speech_types:3 * num_additional_speech_types]
speech_type_audios_list = args[num_additional_speech_types : 2 * num_additional_speech_types]
speech_type_ref_texts_list = args[2 * num_additional_speech_types : 3 * num_additional_speech_types]
model_choice = args[3 * num_additional_speech_types]
remove_silence = args[3 * num_additional_speech_types + 1]
# Collect the speech types and their audios into a dict
speech_types = {'Regular': {'audio': regular_audio, 'ref_text': regular_ref_text}}
speech_types = {"Regular": {"audio": regular_audio, "ref_text": regular_ref_text}}
for name_input, audio_input, ref_text_input in zip(speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list):
for name_input, audio_input, ref_text_input in zip(
speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
):
if name_input and audio_input:
speech_types[name_input] = {'audio': audio_input, 'ref_text': ref_text_input}
speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input}
# Parse the gen_text into segments
segments = parse_speechtypes_text(gen_text)
# For each segment, generate speech
generated_audio_segments = []
current_emotion = 'Regular'
current_emotion = "Regular"
for segment in segments:
emotion = segment['emotion']
text = segment['text']
emotion = segment["emotion"]
text = segment["text"]
if emotion in speech_types:
current_emotion = emotion
else:
# If emotion not available, default to Regular
current_emotion = 'Regular'
current_emotion = "Regular"
ref_audio = speech_types[current_emotion]['audio']
ref_text = speech_types[current_emotion].get('ref_text', '')
ref_audio = speech_types[current_emotion]["audio"]
ref_text = speech_types[current_emotion].get("ref_text", "")
# Generate speech for this segment
audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, 0)
@@ -469,7 +506,11 @@ with gr.Blocks() as app_emotional:
regular_audio,
regular_ref_text,
gen_text_input_emotional,
] + speech_type_names + speech_type_audios + speech_type_ref_texts + [
]
+ speech_type_names
+ speech_type_audios
+ speech_type_ref_texts
+ [
model_choice_emotional,
remove_silence_emotional,
],
@@ -477,11 +518,7 @@ with gr.Blocks() as app_emotional:
)
# Validation function to disable Generate button if speech types are missing
def validate_speech_types(
gen_text,
regular_name,
*args
):
def validate_speech_types(gen_text, regular_name, *args):
num_additional_speech_types = max_speech_types - 1
speech_type_names_list = args[:num_additional_speech_types]
@@ -495,7 +532,7 @@ with gr.Blocks() as app_emotional:
# Parse the gen_text to get the speech types used
segments = parse_emotional_text(gen_text)
speech_types_in_text = set(segment['emotion'] for segment in segments)
speech_types_in_text = set(segment["emotion"] for segment in segments)
# Check if all speech types in text are available
missing_speech_types = speech_types_in_text - speech_types_available
@@ -510,7 +547,7 @@ with gr.Blocks() as app_emotional:
gen_text_input_emotional.change(
validate_speech_types,
inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
outputs=generate_emotional_btn
outputs=generate_emotional_btn,
)
with gr.Blocks() as app:
gr.Markdown(
@@ -531,6 +568,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
)
gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@@ -544,10 +582,8 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
@click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
def main(port, host, share, api):
global app
print(f"Starting app...")
app.queue(api_open=api).launch(
server_name=host, server_port=port, share=share, show_api=api
)
print("Starting app...")
app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)
if __name__ == "__main__":
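For orientation, the click options above imply a launch roughly like this (a sketch; the script name is a placeholder since the file name is not shown in this view):
```bash
python gradio_app.py --port 7860 --host 0.0.0.0
```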

View File

@@ -44,19 +44,8 @@ parser.add_argument(
"--vocab_file",
help="The vocab .txt",
)
parser.add_argument(
"-r",
"--ref_audio",
type=str,
help="Reference audio file < 15 seconds."
)
parser.add_argument(
"-s",
"--ref_text",
type=str,
default="666",
help="Subtitle for the reference audio."
)
parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
parser.add_argument(
"-t",
"--gen_text",
@@ -99,8 +88,8 @@ model = args.model if args.model else config["model"]
ckpt_file = args.ckpt_file if args.ckpt_file else ""
vocab_file = args.vocab_file if args.vocab_file else ""
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
wave_path = Path(output_dir)/"out.wav"
spectrogram_path = Path(output_dir)/"out.png"
wave_path = Path(output_dir) / "out.wav"
spectrogram_path = Path(output_dir) / "out.png"
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_local_path)
@@ -110,44 +99,46 @@ vocos = load_vocoder(is_local=args.load_vocoder_from_local, local_path=vocos_loc
if model == "F5-TTS":
model_cls = DiT
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
if ckpt_file == "":
repo_name= "F5-TTS"
if ckpt_file == "":
repo_name = "F5-TTS"
exp_name = "F5TTS_Base"
ckpt_step= 1200000
ckpt_step = 1200000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
elif model == "E2-TTS":
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
if ckpt_file == "":
repo_name= "E2-TTS"
if ckpt_file == "":
repo_name = "E2-TTS"
exp_name = "E2TTS_Base"
ckpt_step= 1200000
ckpt_step = 1200000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
print(f"Using {model}...")
ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
if "voices" not in config:
voices = {"main": main_voice}
else:
voices = config["voices"]
voices["main"] = main_voice
for voice in voices:
voices[voice]['ref_audio'], voices[voice]['ref_text'] = preprocess_ref_audio_text(voices[voice]['ref_audio'], voices[voice]['ref_text'])
voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
voices[voice]["ref_audio"], voices[voice]["ref_text"]
)
print("Voice:", voice)
print("Ref_audio:", voices[voice]['ref_audio'])
print("Ref_text:", voices[voice]['ref_text'])
print("Ref_audio:", voices[voice]["ref_audio"])
print("Ref_text:", voices[voice]["ref_text"])
generated_audio_segments = []
reg1 = r'(?=\[\w+\])'
reg1 = r"(?=\[\w+\])"
chunks = re.split(reg1, text_gen)
reg2 = r'\[(\w+)\]'
reg2 = r"\[(\w+)\]"
for text in chunks:
match = re.match(reg2, text)
if match:
@@ -160,8 +151,8 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
voice = "main"
text = re.sub(reg2, "", text)
gen_text = text.strip()
ref_audio = voices[voice]['ref_audio']
ref_text = voices[voice]['ref_text']
ref_audio = voices[voice]["ref_audio"]
ref_text = voices[voice]["ref_text"]
print(f"Voice: {voice}")
audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
generated_audio_segments.append(audio)
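As a reading aid, the `[voice]` handling above implies multi-voice generation text of roughly this shape (a sketch; the voice names are hypothetical and must match keys under `config["voices"]`, with untagged or `[main]` chunks using the main reference voice):
```
[main] Hello, this chunk uses the main reference voice.
[narrator] This chunk switches to a voice named "narrator" defined in the config.
[main] And this one switches back.
```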

View File

@@ -5,3 +5,6 @@ from model.backbones.dit import DiT
from model.backbones.mmdit import MMDiT
from model.trainer import Trainer
__all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]

View File

@@ -21,14 +21,16 @@ from model.modules import (
ConvPositionEmbedding,
DiTBlock,
AdaLayerNormZero_Final,
precompute_freqs_cis, get_pos_embed_indices,
precompute_freqs_cis,
get_pos_embed_indices,
)
# Text embedding
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
@@ -36,20 +38,22 @@ class TextEmbedding(nn.Module):
self.extra_modeling = True
self.precompute_max_pos = 4096 # ~44s of 24khz audio
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
self.text_blocks = nn.Sequential(
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
)
else:
self.extra_modeling = False
def forward(self, text: int['b nt'], seq_len, drop_text = False):
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
batch, text_len = text.shape[0], text.shape[1]
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
text = F.pad(text, (0, seq_len - text_len), value = 0)
text = F.pad(text, (0, seq_len - text_len), value=0)
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text) # b n -> b n d
text = self.text_embed(text) # b n -> b n d
# possible extra modeling
if self.extra_modeling:
@@ -67,88 +71,91 @@ class TextEmbedding(nn.Module):
# noised input audio and context mixing embedding
class InputEmbedding(nn.Module):
def __init__(self, mel_dim, text_dim, out_dim):
super().__init__()
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
if drop_audio_cond: # cfg for cond audio
cond = torch.zeros_like(cond)
x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
x = self.conv_pos_embed(x) + x
return x
# Transformer backbone using DiT blocks
class DiT(nn.Module):
def __init__(self, *,
dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
long_skip_connection = False,
def __init__(
self,
*,
dim,
depth=8,
heads=8,
dim_head=64,
dropout=0.1,
ff_mult=4,
mel_dim=100,
text_num_embeds=256,
text_dim=None,
conv_layers=0,
long_skip_connection=False,
):
super().__init__()
self.time_embed = TimestepEmbedding(dim)
if text_dim is None:
text_dim = mel_dim
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
self.rotary_embed = RotaryEmbedding(dim_head)
self.dim = dim
self.depth = depth
self.transformer_blocks = nn.ModuleList(
[
DiTBlock(
dim = dim,
heads = heads,
dim_head = dim_head,
ff_mult = ff_mult,
dropout = dropout
)
for _ in range(depth)
]
[DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
)
self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim)
def forward(
self,
x: float['b n d'], # nosied input audio
cond: float['b n d'], # masked cond audio
text: int['b nt'], # text
time: float['b'] | float[''], # time step
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool['b n'] | None = None,
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
rope = self.rotary_embed.forward_from_seq_len(seq_len)
if self.long_skip_connection is not None:
residual = x
for block in self.transformer_blocks:
x = block(x, t, mask = mask, rope = rope)
x = block(x, t, mask=mask, rope=rope)
if self.long_skip_connection is not None:
x = self.long_skip_connection(torch.cat((x, residual), dim = -1))
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
x = self.norm_out(x, t)
output = self.proj_out(x)

View File

@@ -19,12 +19,14 @@ from model.modules import (
ConvPositionEmbedding,
MMDiTBlock,
AdaLayerNormZero_Final,
precompute_freqs_cis, get_pos_embed_indices,
precompute_freqs_cis,
get_pos_embed_indices,
)
# text embedding
class TextEmbedding(nn.Module):
def __init__(self, out_dim, text_num_embeds):
super().__init__()
@@ -33,7 +35,7 @@ class TextEmbedding(nn.Module):
self.precompute_max_pos = 1024
self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']:
def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
text = text + 1
if drop_text:
text = torch.zeros_like(text)
@@ -52,27 +54,37 @@ class TextEmbedding(nn.Module):
# noised input & masked cond audio embedding
class AudioEmbedding(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.linear = nn.Linear(2 * in_dim, out_dim)
self.conv_pos_embed = ConvPositionEmbedding(out_dim)
def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False):
def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
if drop_audio_cond:
cond = torch.zeros_like(cond)
x = torch.cat((x, cond), dim = -1)
x = torch.cat((x, cond), dim=-1)
x = self.linear(x)
x = self.conv_pos_embed(x) + x
return x
# Transformer backbone using MM-DiT blocks
class MMDiT(nn.Module):
def __init__(self, *,
dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
text_num_embeds = 256, mel_dim = 100,
def __init__(
self,
*,
dim,
depth=8,
heads=8,
dim_head=64,
dropout=0.1,
ff_mult=4,
text_num_embeds=256,
mel_dim=100,
):
super().__init__()
@@ -84,16 +96,16 @@ class MMDiT(nn.Module):
self.dim = dim
self.depth = depth
self.transformer_blocks = nn.ModuleList(
[
MMDiTBlock(
dim = dim,
heads = heads,
dim_head = dim_head,
dropout = dropout,
ff_mult = ff_mult,
context_pre_only = i == depth - 1,
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
ff_mult=ff_mult,
context_pre_only=i == depth - 1,
)
for i in range(depth)
]
@@ -103,13 +115,13 @@ class MMDiT(nn.Module):
def forward(
self,
x: float['b n d'], # nosied input audio
cond: float['b n d'], # masked cond audio
text: int['b nt'], # text
time: float['b'] | float[''], # time step
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool['b n'] | None = None,
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
):
batch = x.shape[0]
if time.ndim == 0:
@@ -117,16 +129,16 @@ class MMDiT(nn.Module):
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
c = self.text_embed(text, drop_text = drop_text)
x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond)
c = self.text_embed(text, drop_text=drop_text)
x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
seq_len = x.shape[1]
text_len = text.shape[1]
rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
rope_text = self.rotary_embed.forward_from_seq_len(text_len)
for block in self.transformer_blocks:
c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text)
c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
x = self.norm_out(x, t)
output = self.proj_out(x)

View File

@@ -24,14 +24,16 @@ from model.modules import (
Attention,
AttnProcessor,
FeedForward,
precompute_freqs_cis, get_pos_embed_indices,
precompute_freqs_cis,
get_pos_embed_indices,
)
# Text embedding
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
@@ -39,20 +41,22 @@ class TextEmbedding(nn.Module):
self.extra_modeling = True
self.precompute_max_pos = 4096 # ~44s of 24khz audio
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
self.text_blocks = nn.Sequential(
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
)
else:
self.extra_modeling = False
def forward(self, text: int['b nt'], seq_len, drop_text = False):
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
batch, text_len = text.shape[0], text.shape[1]
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
text = F.pad(text, (0, seq_len - text_len), value = 0)
text = F.pad(text, (0, seq_len - text_len), value=0)
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text) # b n -> b n d
text = self.text_embed(text) # b n -> b n d
# possible extra modeling
if self.extra_modeling:
@@ -70,28 +74,40 @@ class TextEmbedding(nn.Module):
# noised input audio and context mixing embedding
class InputEmbedding(nn.Module):
def __init__(self, mel_dim, text_dim, out_dim):
super().__init__()
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
if drop_audio_cond: # cfg for cond audio
cond = torch.zeros_like(cond)
x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
x = self.conv_pos_embed(x) + x
return x
# Flat UNet Transformer backbone
class UNetT(nn.Module):
def __init__(self, *,
dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
def __init__(
self,
*,
dim,
depth=8,
heads=8,
dim_head=64,
dropout=0.1,
ff_mult=4,
mel_dim=100,
text_num_embeds=256,
text_dim=None,
conv_layers=0,
skip_connect_type: Literal["add", "concat", "none"] = "concat",
):
super().__init__()
assert depth % 2 == 0, "UNet-Transformer's depth should be even."
@@ -99,7 +115,7 @@ class UNetT(nn.Module):
self.time_embed = TimestepEmbedding(dim)
if text_dim is None:
text_dim = mel_dim
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
self.rotary_embed = RotaryEmbedding(dim_head)
@@ -108,7 +124,7 @@ class UNetT(nn.Module):
self.dim = dim
self.skip_connect_type = skip_connect_type
needs_skip_proj = skip_connect_type == 'concat'
needs_skip_proj = skip_connect_type == "concat"
self.depth = depth
self.layers = nn.ModuleList([])
@@ -118,53 +134,57 @@ class UNetT(nn.Module):
attn_norm = RMSNorm(dim)
attn = Attention(
processor = AttnProcessor(),
dim = dim,
heads = heads,
dim_head = dim_head,
dropout = dropout,
)
processor=AttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
)
ff_norm = RMSNorm(dim)
ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None
skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
self.layers.append(nn.ModuleList([
skip_proj,
attn_norm,
attn,
ff_norm,
ff,
]))
self.layers.append(
nn.ModuleList(
[
skip_proj,
attn_norm,
attn,
ff_norm,
ff,
]
)
)
self.norm_out = RMSNorm(dim)
self.proj_out = nn.Linear(dim, mel_dim)
def forward(
self,
x: float['b n d'], # nosied input audio
cond: float['b n d'], # masked cond audio
text: int['b nt'], # text
time: float['b'] | float[''], # time step
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool['b n'] | None = None,
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
# postfix time t to input x, [b n d] -> [b n+1 d]
x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
if mask is not None:
mask = F.pad(mask, (1, 0), value=1)
rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
# flat unet transformer
@@ -182,14 +202,14 @@ class UNetT(nn.Module):
if is_later_half:
skip = skips.pop()
if skip_connect_type == 'concat':
x = torch.cat((x, skip), dim = -1)
if skip_connect_type == "concat":
x = torch.cat((x, skip), dim=-1)
x = maybe_skip_proj(x)
elif skip_connect_type == 'add':
elif skip_connect_type == "add":
x = x + skip
# attention and feedforward blocks
x = attn(attn_norm(x), rope = rope, mask = mask) + x
x = attn(attn_norm(x), rope=rope, mask=mask) + x
x = ff(ff_norm(x)) + x
assert len(skips) == 0

View File

@@ -20,29 +20,32 @@ from torchdiffeq import odeint
from model.modules import MelSpec
from model.utils import (
default, exists,
list_str_to_idx, list_str_to_tensor,
lens_to_mask, mask_from_frac_lengths,
)
default,
exists,
list_str_to_idx,
list_str_to_tensor,
lens_to_mask,
mask_from_frac_lengths,
)
class CFM(nn.Module):
def __init__(
self,
transformer: nn.Module,
sigma = 0.,
sigma=0.0,
odeint_kwargs: dict = dict(
# atol = 1e-5,
# rtol = 1e-5,
method = 'euler' # 'midpoint'
method="euler" # 'midpoint'
),
audio_drop_prob = 0.3,
cond_drop_prob = 0.2,
num_channels = None,
audio_drop_prob=0.3,
cond_drop_prob=0.2,
num_channels=None,
mel_spec_module: nn.Module | None = None,
mel_spec_kwargs: dict = dict(),
frac_lengths_mask: tuple[float, float] = (0.7, 1.),
vocab_char_map: dict[str: int] | None = None
frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
vocab_char_map: dict[str:int] | None = None,
):
super().__init__()
@@ -78,21 +81,21 @@ class CFM(nn.Module):
@torch.no_grad()
def sample(
self,
cond: float['b n d'] | float['b nw'],
text: int['b nt'] | list[str],
duration: int | int['b'],
cond: float["b n d"] | float["b nw"], # noqa: F722
text: int["b nt"] | list[str], # noqa: F722
duration: int | int["b"], # noqa: F821
*,
lens: int['b'] | None = None,
steps = 32,
cfg_strength = 1.,
sway_sampling_coef = None,
lens: int["b"] | None = None, # noqa: F821
steps=32,
cfg_strength=1.0,
sway_sampling_coef=None,
seed: int | None = None,
max_duration = 4096,
vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
no_ref_audio = False,
duplicate_test = False,
t_inter = 0.1,
edit_mask = None,
max_duration=4096,
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
no_ref_audio=False,
duplicate_test=False,
t_inter=0.1,
edit_mask=None,
):
self.eval()
@@ -108,7 +111,7 @@ class CFM(nn.Module):
batch, cond_seq_len, device = *cond.shape[:2], cond.device
if not exists(lens):
lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
# text
@@ -120,8 +123,8 @@ class CFM(nn.Module):
assert text.shape[0] == batch
if exists(text):
text_lens = (text != -1).sum(dim = -1)
lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
text_lens = (text != -1).sum(dim=-1)
lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
# duration
@@ -130,20 +133,22 @@ class CFM(nn.Module):
cond_mask = cond_mask & edit_mask
if isinstance(duration, int):
duration = torch.full((batch,), duration, device = device, dtype = torch.long)
duration = torch.full((batch,), duration, device=device, dtype=torch.long)
duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
duration = duration.clamp(max = max_duration)
duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
duration = duration.clamp(max=max_duration)
max_duration = duration.amax()
# duplicate test corner for inner time step oberservation
if duplicate_test:
test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.)
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False)
cond_mask = cond_mask.unsqueeze(-1)
step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
step_cond = torch.where(
cond_mask, cond, torch.zeros_like(cond)
) # allow direct control (cut cond audio) with lens passed in
if batch > 1:
mask = lens_to_mask(duration)
@@ -161,11 +166,15 @@ class CFM(nn.Module):
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
# predict flow
pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
pred = self.transformer(
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
)
if cfg_strength < 1e-5:
return pred
null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
null_pred = self.transformer(
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
)
return pred + (pred - null_pred) * cfg_strength
# noise input
@@ -175,8 +184,8 @@ class CFM(nn.Module):
for dur in duration:
if exists(seed):
torch.manual_seed(seed)
y0.append(torch.randn(dur, self.num_channels, device = self.device, dtype=step_cond.dtype))
y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
y0 = pad_sequence(y0, padding_value=0, batch_first=True)
t_start = 0
@@ -186,12 +195,12 @@ class CFM(nn.Module):
y0 = (1 - t_start) * y0 + t_start * test_cond
steps = int(steps * (1 - t_start))
t = torch.linspace(t_start, 1, steps, device = self.device, dtype=step_cond.dtype)
t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
if sway_sampling_coef is not None:
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
sampled = trajectory[-1]
out = sampled
out = torch.where(cond_mask, cond, out)
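A brief aside on the sway sampling line above: with a negative coefficient the ODE time grid is warped toward t = 0, so more integration steps land early in the trajectory. A standalone sketch (not repo code):
```python
import torch

steps, sway_sampling_coef = 8, -1.0
t = torch.linspace(0, 1, steps)
t_sway = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
print(t_sway)  # values cluster near 0, giving finer resolution to early flow steps
```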
@@ -204,10 +213,10 @@ class CFM(nn.Module):
def forward(
self,
inp: float['b n d'] | float['b nw'], # mel or raw wave
text: int['b nt'] | list[str],
inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
text: int["b nt"] | list[str], # noqa: F722
*,
lens: int['b'] | None = None,
lens: int["b"] | None = None, # noqa: F821
noise_scheduler: str | None = None,
):
# handle raw wave
@@ -216,7 +225,7 @@ class CFM(nn.Module):
inp = inp.permute(0, 2, 1)
assert inp.shape[-1] == self.num_channels
batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
batch, seq_len, dtype, device, _σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
# handle text as string
if isinstance(text, list):
@@ -228,12 +237,12 @@ class CFM(nn.Module):
# lens and mask
if not exists(lens):
lens = torch.full((batch,), seq_len, device = device)
mask = lens_to_mask(lens, length = seq_len) # useless here, as collate_fn will pad to max length in batch
lens = torch.full((batch,), seq_len, device=device)
mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
# get a random span to mask out for training conditionally
frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
if exists(mask):
@@ -246,7 +255,7 @@ class CFM(nn.Module):
x0 = torch.randn_like(x1)
# time step
time = torch.rand((batch,), dtype = dtype, device = self.device)
time = torch.rand((batch,), dtype=dtype, device=self.device)
# TODO. noise_scheduler
# sample xt (φ_t(x) in the paper)
@@ -255,10 +264,7 @@ class CFM(nn.Module):
flow = x1 - x0
# only predict what is within the random mask span for infilling
cond = torch.where(
rand_span_mask[..., None],
torch.zeros_like(x1), x1
)
cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
# transformer and cfg training with a drop rate
drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
@@ -267,13 +273,15 @@ class CFM(nn.Module):
drop_text = True
else:
drop_text = False
# if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)
pred = self.transformer(
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
)
# flow matching loss
loss = F.mse_loss(pred, flow, reduction = 'none')
loss = F.mse_loss(pred, flow, reduction="none")
loss = loss[rand_span_mask]
return loss.mean(), cond, pred
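Restating the objective visible above in formula form (a sketch inferred from this code, with x_0 the sampled noise, x_1 the target mel, and t drawn uniformly per batch):
```latex
\varphi_t(x) = (1 - t)\,x_0 + t\,x_1, \qquad
\mathcal{L} = \mathbb{E}\,\big\lVert v_\theta(\varphi_t(x), \mathrm{cond}, \mathrm{text}, t) - (x_1 - x_0) \big\rVert^2
```
with the loss averaged only over the randomly masked span.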

View File

@@ -6,7 +6,7 @@ import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, Sampler
import torchaudio
from datasets import load_dataset, load_from_disk
from datasets import load_from_disk
from datasets import Dataset as Dataset_
from model.modules import MelSpec
@@ -16,53 +16,55 @@ class HFDataset(Dataset):
def __init__(
self,
hf_dataset: Dataset,
target_sample_rate = 24_000,
n_mel_channels = 100,
hop_length = 256,
target_sample_rate=24_000,
n_mel_channels=100,
hop_length=256,
):
self.data = hf_dataset
self.target_sample_rate = target_sample_rate
self.hop_length = hop_length
self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
self.mel_spectrogram = MelSpec(
target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
)
def get_frame_len(self, index):
row = self.data[index]
audio = row['audio']['array']
sample_rate = row['audio']['sampling_rate']
audio = row["audio"]["array"]
sample_rate = row["audio"]["sampling_rate"]
return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
def __len__(self):
return len(self.data)
def __getitem__(self, index):
row = self.data[index]
audio = row['audio']['array']
audio = row["audio"]["array"]
# logger.info(f"Audio shape: {audio.shape}")
sample_rate = row['audio']['sampling_rate']
sample_rate = row["audio"]["sampling_rate"]
duration = audio.shape[-1] / sample_rate
if duration > 30 or duration < 0.3:
return self.__getitem__((index + 1) % len(self.data))
audio_tensor = torch.from_numpy(audio).float()
if sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
audio_tensor = resampler(audio_tensor)
audio_tensor = audio_tensor.unsqueeze(0) # 't -> 1 t')
mel_spec = self.mel_spectrogram(audio_tensor)
mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
text = row['text']
text = row["text"]
return dict(
mel_spec = mel_spec,
text = text,
mel_spec=mel_spec,
text=text,
)
@@ -70,11 +72,11 @@ class CustomDataset(Dataset):
def __init__(
self,
custom_dataset: Dataset,
durations = None,
target_sample_rate = 24_000,
hop_length = 256,
n_mel_channels = 100,
preprocessed_mel = False,
durations=None,
target_sample_rate=24_000,
hop_length=256,
n_mel_channels=100,
preprocessed_mel=False,
):
self.data = custom_dataset
self.durations = durations
@@ -82,16 +84,20 @@ class CustomDataset(Dataset):
self.hop_length = hop_length
self.preprocessed_mel = preprocessed_mel
if not preprocessed_mel:
self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels)
self.mel_spectrogram = MelSpec(
target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels
)
def get_frame_len(self, index):
if self.durations is not None: # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
if (
self.durations is not None
): # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
return self.durations[index] * self.target_sample_rate / self.hop_length
return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
def __len__(self):
return len(self.data)
def __getitem__(self, index):
row = self.data[index]
audio_path = row["audio_path"]
@@ -108,45 +114,52 @@ class CustomDataset(Dataset):
if duration > 30 or duration < 0.3:
return self.__getitem__((index + 1) % len(self.data))
if source_sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
audio = resampler(audio)
mel_spec = self.mel_spectrogram(audio)
mel_spec = mel_spec.squeeze(0) # '1 d t -> d t')
return dict(
mel_spec = mel_spec,
text = text,
mel_spec=mel_spec,
text=text,
)
# Dynamic Batch Sampler
class DynamicBatchSampler(Sampler[list[int]]):
""" Extension of Sampler that will do the following:
1. Change the batch size (essentially number of sequences)
in a batch to ensure that the total number of frames are less
than a certain threshold.
2. Make sure the padding efficiency in the batch is high.
"""Extension of Sampler that will do the following:
1. Change the batch size (essentially number of sequences)
in a batch to ensure that the total number of frames are less
than a certain threshold.
2. Make sure the padding efficiency in the batch is high.
"""
def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
def __init__(
self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
):
self.sampler = sampler
self.frames_threshold = frames_threshold
self.max_samples = max_samples
indices, batches = [], []
data_source = self.sampler.data_source
for idx in tqdm(self.sampler, desc=f"Sorting with sampler... if slow, check whether dataset is provided with duration"):
for idx in tqdm(
self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
):
indices.append((idx, data_source.get_frame_len(idx)))
indices.sort(key=lambda elem : elem[1])
indices.sort(key=lambda elem: elem[1])
batch = []
batch_frames = 0
for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
for idx, frame_len in tqdm(
indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
):
if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
batch.append(idx)
batch_frames += frame_len
@@ -182,76 +195,86 @@ class DynamicBatchSampler(Sampler[list[int]]):
# Load dataset
def load_dataset(
dataset_name: str,
tokenizer: str = "pinyin",
dataset_type: str = "CustomDataset",
audio_type: str = "raw",
mel_spec_kwargs: dict = dict()
) -> CustomDataset | HFDataset:
'''
dataset_name: str,
tokenizer: str = "pinyin",
dataset_type: str = "CustomDataset",
audio_type: str = "raw",
mel_spec_kwargs: dict = dict(),
) -> CustomDataset | HFDataset:
"""
dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
- "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
'''
"""
print("Loading dataset ...")
if dataset_type == "CustomDataset":
if audio_type == "raw":
try:
train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
except:
except: # noqa: E722
train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
preprocessed_mel = False
elif audio_type == "mel":
train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
preprocessed_mel = True
with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
with open(f"data/{dataset_name}_{tokenizer}/duration.json", "r", encoding="utf-8") as f:
data_dict = json.load(f)
durations = data_dict["duration"]
train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
train_dataset = CustomDataset(
train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
)
elif dataset_type == "CustomDatasetPath":
try:
train_dataset = load_from_disk(f"{dataset_name}/raw")
except:
except: # noqa: E722
train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
with open(f"{dataset_name}/duration.json", 'r', encoding='utf-8') as f:
with open(f"{dataset_name}/duration.json", "r", encoding="utf-8") as f:
data_dict = json.load(f)
durations = data_dict["duration"]
train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
train_dataset = CustomDataset(
train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
)
elif dataset_type == "HFDataset":
print("Should manually modify the path of huggingface dataset to your need.\n" +
"May also the corresponding script cuz different dataset may have different format.")
print(
"Should manually modify the path of huggingface dataset to your need.\n"
+ "May also the corresponding script cuz different dataset may have different format."
)
pre, post = dataset_name.split("_")
train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),)
train_dataset = HFDataset(
load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),
)
return train_dataset
# collation
def collate_fn(batch):
mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
max_mel_length = mel_lengths.amax()
padded_mel_specs = []
for spec in mel_specs: # TODO. maybe records mask for attention here
padding = (0, max_mel_length - spec.size(-1))
padded_spec = F.pad(spec, padding, value = 0)
padded_spec = F.pad(spec, padding, value=0)
padded_mel_specs.append(padded_spec)
mel_specs = torch.stack(padded_mel_specs)
text = [item['text'] for item in batch]
text = [item["text"] for item in batch]
text_lengths = torch.LongTensor([len(item) for item in text])
return dict(
mel = mel_specs,
mel_lengths = mel_lengths,
text = text,
text_lengths = text_lengths,
mel=mel_specs,
mel_lengths=mel_lengths,
text=text,
text_lengths=text_lengths,
)
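# A minimal wiring sketch for the pieces above: frame-budgeted dynamic batching on
# top of load_dataset() and collate_fn(). Dataset name and frame budget below are
# placeholders, and the call assumes the prepared data exists under data/.
def _dataloader_example():
    from torch.utils.data import DataLoader, SequentialSampler

    train_dataset = load_dataset("Emilia_ZH_EN", tokenizer="pinyin")
    sampler = SequentialSampler(train_dataset)
    batch_sampler = DynamicBatchSampler(sampler, frames_threshold=38400, max_samples=64, random_seed=666)
    loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_sampler=batch_sampler, num_workers=4)
    batch = next(iter(loader))
    return batch["mel"].shape, batch["mel_lengths"], batch["text_lengths"]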


@@ -9,13 +9,14 @@ import torch.nn as nn
import torch.nn.functional as F
''' Res2Conv1d + BatchNorm1d + ReLU
'''
""" Res2Conv1d + BatchNorm1d + ReLU
"""
class Res2Conv1dReluBn(nn.Module):
'''
"""
in_channels == out_channels == channels
'''
"""
def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
super().__init__()
@@ -51,8 +52,9 @@ class Res2Conv1dReluBn(nn.Module):
return out
''' Conv1d + BatchNorm1d + ReLU
'''
""" Conv1d + BatchNorm1d + ReLU
"""
class Conv1dReluBn(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
@@ -64,8 +66,9 @@ class Conv1dReluBn(nn.Module):
return self.bn(F.relu(self.conv(x)))
''' The SE connection of 1D case.
'''
""" The SE connection of 1D case.
"""
class SE_Connect(nn.Module):
def __init__(self, channels, se_bottleneck_dim=128):
@@ -82,8 +85,8 @@ class SE_Connect(nn.Module):
return out
''' SE-Res2Block of the ECAPA-TDNN architecture.
'''
""" SE-Res2Block of the ECAPA-TDNN architecture.
"""
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
# return nn.Sequential(
@@ -93,6 +96,7 @@ class SE_Connect(nn.Module):
# SE_Connect(channels)
# )
class SE_Res2Block(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
super().__init__()
@@ -122,8 +126,9 @@ class SE_Res2Block(nn.Module):
return x + residual
''' Attentive weighted mean and standard deviation pooling.
'''
""" Attentive weighted mean and standard deviation pooling.
"""
class AttentiveStatsPool(nn.Module):
def __init__(self, in_dim, attention_channels=128, global_context_att=False):
@@ -138,7 +143,6 @@ class AttentiveStatsPool(nn.Module):
self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
def forward(self, x):
if self.global_context_att:
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
@@ -151,38 +155,52 @@ class AttentiveStatsPool(nn.Module):
# alpha = F.relu(self.linear1(x_in))
alpha = torch.softmax(self.linear2(alpha), dim=2)
mean = torch.sum(alpha * x, dim=2)
residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
std = torch.sqrt(residuals.clamp(min=1e-9))
return torch.cat([mean, std], dim=1)
class ECAPA_TDNN(nn.Module):
def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
def __init__(
self,
feat_dim=80,
channels=512,
emb_dim=192,
global_context_att=False,
feat_type="wavlm_large",
sr=16000,
feature_selection="hidden_states",
update_extract=False,
config_path=None,
):
super().__init__()
self.feat_type = feat_type
self.feature_selection = feature_selection
self.update_extract = update_extract
self.sr = sr
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
try:
local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path)
except:
self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
except: # noqa: E722
self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
):
self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
):
self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
self.feat_num = self.get_feat_num()
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
if feat_type != 'fbank' and feat_type != 'mfcc':
freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
if feat_type != "fbank" and feat_type != "mfcc":
freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
for name, param in self.feature_extract.named_parameters():
for freeze_val in freeze_list:
if freeze_val in name:
@@ -198,18 +216,46 @@ class ECAPA_TDNN(nn.Module):
self.channels = [channels] * 4 + [1536]
self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
self.layer2 = SE_Res2Block(
self.channels[0],
self.channels[1],
kernel_size=3,
stride=1,
padding=2,
dilation=2,
scale=8,
se_bottleneck_dim=128,
)
self.layer3 = SE_Res2Block(
self.channels[1],
self.channels[2],
kernel_size=3,
stride=1,
padding=3,
dilation=3,
scale=8,
se_bottleneck_dim=128,
)
self.layer4 = SE_Res2Block(
self.channels[2],
self.channels[3],
kernel_size=3,
stride=1,
padding=4,
dilation=4,
scale=8,
se_bottleneck_dim=128,
)
# self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
cat_channels = channels * 3
self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
self.pooling = AttentiveStatsPool(
self.channels[-1], attention_channels=128, global_context_att=global_context_att
)
self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
def get_feat_num(self):
self.feature_extract.eval()
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
@@ -226,12 +272,12 @@ class ECAPA_TDNN(nn.Module):
x = self.feature_extract([sample for sample in x])
else:
with torch.no_grad():
if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
if self.feat_type == "fbank" or self.feat_type == "mfcc":
x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
else:
x = self.feature_extract([sample for sample in x])
if self.feat_type == 'fbank':
if self.feat_type == "fbank":
x = x.log()
if self.feat_type != "fbank" and self.feat_type != "mfcc":
@@ -263,6 +309,22 @@ class ECAPA_TDNN(nn.Module):
return out
def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
def ECAPA_TDNN_SMALL(
feat_dim,
emb_dim=256,
feat_type="wavlm_large",
sr=16000,
feature_selection="hidden_states",
update_extract=False,
config_path=None,
):
return ECAPA_TDNN(
feat_dim=feat_dim,
channels=512,
emb_dim=emb_dim,
feat_type=feat_type,
sr=sr,
feature_selection=feature_selection,
update_extract=update_extract,
config_path=config_path,
)
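# A minimal usage sketch for the speaker encoder above. The WavLM backbone is
# fetched through torch.hub on first use, so this assumes network access or a
# cached s3prl checkout; the random waveform is only a stand-in for real audio.
def _speaker_embedding_example():
    model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large")
    model.eval()
    wav = torch.randn(1, 16000)  # (batch, samples) -- one second at 16 kHz
    with torch.no_grad():
        emb = model(wav)         # (1, 256) speaker embedding
    return emb.shape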


@@ -21,39 +21,40 @@ from x_transformers.x_transformers import apply_rotary_pos_emb
# raw wav to mel spec
class MelSpec(nn.Module):
def __init__(
self,
filter_length = 1024,
hop_length = 256,
win_length = 1024,
n_mel_channels = 100,
target_sample_rate = 24_000,
normalize = False,
power = 1,
norm = None,
center = True,
filter_length=1024,
hop_length=256,
win_length=1024,
n_mel_channels=100,
target_sample_rate=24_000,
normalize=False,
power=1,
norm=None,
center=True,
):
super().__init__()
self.n_mel_channels = n_mel_channels
self.mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate = target_sample_rate,
n_fft = filter_length,
win_length = win_length,
hop_length = hop_length,
n_mels = n_mel_channels,
power = power,
center = center,
normalized = normalize,
norm = norm,
sample_rate=target_sample_rate,
n_fft=filter_length,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mel_channels,
power=power,
center=center,
normalized=normalize,
norm=norm,
)
self.register_buffer('dummy', torch.tensor(0), persistent = False)
self.register_buffer("dummy", torch.tensor(0), persistent=False)
def forward(self, inp):
if len(inp.shape) == 3:
inp = inp.squeeze(1) # 'b 1 nw -> b nw'
inp = inp.squeeze(1) # 'b 1 nw -> b nw'
assert len(inp.shape) == 2
@@ -61,12 +62,13 @@ class MelSpec(nn.Module):
self.to(inp.device)
mel = self.mel_stft(inp)
mel = mel.clamp(min = 1e-5).log()
mel = mel.clamp(min=1e-5).log()
return mel
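# A minimal usage sketch for MelSpec with the defaults above (24 kHz, hop 256,
# 100 mel bins); the random waveform is a placeholder for real audio.
def _mel_spec_example():
    wav = torch.randn(2, 24_000)     # (batch, samples), ~1 second at 24 kHz
    mel = MelSpec()(wav)             # (batch, 100, ~94) log-compressed mel frames
    return mel.shape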
# sinusoidal position embedding
class SinusPositionEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -84,35 +86,37 @@ class SinusPositionEmbedding(nn.Module):
# convolutional position embedding
class ConvPositionEmbedding(nn.Module):
def __init__(self, dim, kernel_size = 31, groups = 16):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.conv1d = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
)
def forward(self, x: float['b n d'], mask: bool['b n'] | None = None):
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.)
x = x.masked_fill(~mask, 0.0)
x = x.permute(0, 2, 1)
x = self.conv1d(x)
out = x.permute(0, 2, 1)
if mask is not None:
out = out.masked_fill(~mask, 0.)
out = out.masked_fill(~mask, 0.0)
return out
# rotary positional embedding related
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
@@ -125,12 +129,14 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
freqs_sin = torch.sin(freqs) # imaginary part
return torch.cat([freqs_cos, freqs_sin], dim=-1)
def get_pos_embed_indices(start, length, max_pos, scale=1.):
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
# length = length if isinstance(length, int) else length.max()
scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
pos = start.unsqueeze(1) + (
torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
scale.unsqueeze(1)).long()
pos = (
start.unsqueeze(1)
+ (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
)
# avoid extra long error.
pos = torch.where(pos < max_pos, pos, max_pos - 1)
return pos
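# A sketch of how these two helpers can be combined into a rotary-embedding
# lookup table; dim_head and max_pos are assumed values, not fixed constants.
def _rotary_example():
    dim_head, max_pos = 64, 4096
    freqs_cis = precompute_freqs_cis(dim_head, max_pos)                  # (max_pos, dim_head) cos/sin
    start = torch.zeros(2, dtype=torch.long)                             # batch of 2, both starting at 0
    pos_idx = get_pos_embed_indices(start, length=300, max_pos=max_pos)  # (2, 300) clamped positions
    rope = freqs_cis[pos_idx]                                            # (2, 300, dim_head)
    return rope.shape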
@@ -138,6 +144,7 @@ def get_pos_embed_indices(start, length, max_pos, scale=1.):
# Global Response Normalization layer (Instance Normalization ?)
class GRN(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -153,6 +160,7 @@ class GRN(nn.Module):
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
class ConvNeXtV2Block(nn.Module):
def __init__(
self,
@@ -162,7 +170,9 @@ class ConvNeXtV2Block(nn.Module):
):
super().__init__()
padding = (dilation * (7 - 1)) // 2
self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
@@ -185,6 +195,7 @@ class ConvNeXtV2Block(nn.Module):
# AdaLayerNormZero
# return with modulated x for attn input, and params for later mlp modulation
class AdaLayerNormZero(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -194,7 +205,7 @@ class AdaLayerNormZero(nn.Module):
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb = None):
def forward(self, x, emb=None):
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
@@ -205,6 +216,7 @@ class AdaLayerNormZero(nn.Module):
# AdaLayerNormZero for final layer
# return only with modulated x for attn input, cuz no more mlp modulation
class AdaLayerNormZero_Final(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -224,22 +236,16 @@ class AdaLayerNormZero_Final(nn.Module):
# FeedForward
class FeedForward(nn.Module):
def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
activation = nn.GELU(approximate=approximate)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
activation
)
self.ff = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.ff(x)
@@ -248,6 +254,7 @@ class FeedForward(nn.Module):
# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py
class Attention(nn.Module):
def __init__(
self,
@@ -256,8 +263,8 @@ class Attention(nn.Module):
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only = None,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only=None,
):
super().__init__()
@@ -293,20 +300,21 @@ class Attention(nn.Module):
def forward(
self,
x: float['b n d'], # noised input x
c: float['b n d'] = None, # context c
mask: bool['b n'] | None = None,
rope = None, # rotary position embedding for x
c_rope = None, # rotary position embedding for c
x: float["b n d"], # noised input x # noqa: F722
c: float["b n d"] = None, # context c # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
if c is not None:
return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
else:
return self.processor(self, x, mask = mask, rope = rope)
return self.processor(self, x, mask=mask, rope=rope)
# Attention processor
class AttnProcessor:
def __init__(self):
pass
@@ -314,11 +322,10 @@ class AttnProcessor:
def __call__(
self,
attn: Attention,
x: float['b n d'], # noised input x
mask: bool['b n'] | None = None,
rope = None, # rotary position embedding
x: float["b n d"], # noised input x # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding
) -> torch.FloatTensor:
batch_size = x.shape[0]
# `sample` projections.
@@ -329,7 +336,7 @@ class AttnProcessor:
# apply rotary position embedding
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
@@ -360,14 +367,15 @@ class AttnProcessor:
if mask is not None:
mask = mask.unsqueeze(-1)
x = x.masked_fill(~mask, 0.)
x = x.masked_fill(~mask, 0.0)
return x
# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py
class JointAttnProcessor:
def __init__(self):
pass
@@ -375,11 +383,11 @@ class JointAttnProcessor:
def __call__(
self,
attn: Attention,
x: float['b n d'], # noised input x
c: float['b nt d'] = None, # context c, here text
mask: bool['b n'] | None = None,
rope = None, # rotary position embedding for x
c_rope = None, # rotary position embedding for c
x: float["b n d"], # noised input x # noqa: F722
c: float["b nt d"] = None, # context c, here text # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.FloatTensor:
residual = x
@@ -398,12 +406,12 @@ class JointAttnProcessor:
# apply rope for context and noised input independently
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
if c_rope is not None:
freqs, xpos_scale = c_rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
@@ -420,7 +428,7 @@ class JointAttnProcessor:
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text)
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
@@ -432,8 +440,8 @@ class JointAttnProcessor:
# Split the attention outputs.
x, c = (
x[:, :residual.shape[1]],
x[:, residual.shape[1]:],
x[:, : residual.shape[1]],
x[:, residual.shape[1] :],
)
# linear proj
@@ -445,7 +453,7 @@ class JointAttnProcessor:
if mask is not None:
mask = mask.unsqueeze(-1)
x = x.masked_fill(~mask, 0.)
x = x.masked_fill(~mask, 0.0)
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
return x, c
@@ -453,24 +461,24 @@ class JointAttnProcessor:
# DiT Block
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor = AttnProcessor(),
dim = dim,
heads = heads,
dim_head = dim_head,
dropout = dropout,
)
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
processor=AttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
)
def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
@@ -479,7 +487,7 @@ class DiTBlock(nn.Module):
# process attention output for input x
x = x + gate_msa.unsqueeze(1) * attn_output
norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm)
x = x + gate_mlp.unsqueeze(1) * ff_output
@@ -489,8 +497,9 @@ class DiTBlock(nn.Module):
# MMDiT Block https://arxiv.org/abs/2403.03206
class MMDiTBlock(nn.Module):
r"""
r"""
modified from diffusers/src/diffusers/models/attention.py
notes.
@@ -499,33 +508,33 @@ class MMDiTBlock(nn.Module):
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
"""
def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
super().__init__()
self.context_pre_only = context_pre_only
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
self.attn_norm_x = AdaLayerNormZero(dim)
self.attn = Attention(
processor = JointAttnProcessor(),
dim = dim,
heads = heads,
dim_head = dim_head,
dropout = dropout,
context_dim = dim,
context_pre_only = context_pre_only,
)
processor=JointAttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
context_dim=dim,
context_pre_only=context_pre_only,
)
if not context_pre_only:
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
else:
self.ff_norm_c = None
self.ff_c = None
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding
def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
# pre-norm & modulation for attention input
if self.context_pre_only:
norm_c = self.attn_norm_c(c, t)
@@ -539,7 +548,7 @@ class MMDiTBlock(nn.Module):
# process attention output for context c
if self.context_pre_only:
c = None
else: # if not last layer
else: # if not last layer
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
@@ -548,7 +557,7 @@ class MMDiTBlock(nn.Module):
# process attention output for input x
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
x_ff_output = self.ff_x(norm_x)
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
@@ -558,17 +567,14 @@ class MMDiTBlock(nn.Module):
# time step conditioning embedding
class TimestepEmbedding(nn.Module):
def __init__(self, dim, freq_embed_dim=256):
super().__init__()
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(
nn.Linear(freq_embed_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
def forward(self, timestep: float['b']):
def forward(self, timestep: float["b"]): # noqa: F821
time_hidden = self.time_embed(timestep)
time_hidden = time_hidden.to(timestep.dtype)
time = self.time_mlp(time_hidden) # b d
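# A small illustrative sketch: scalar flow-matching times t in [0, 1] become a
# (batch, dim) conditioning vector consumed as `t` by the DiT/MMDiT blocks above.
def _timestep_embedding_example():
    t = torch.rand(4)                       # (b,)
    cond = TimestepEmbedding(dim=512)(t)    # (b, 512)
    return cond.shape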


@@ -22,71 +22,69 @@ from model.dataset import DynamicBatchSampler, collate_fn
# trainer
class Trainer:
def __init__(
self,
model: CFM,
epochs,
learning_rate,
num_warmup_updates = 20000,
save_per_updates = 1000,
checkpoint_path = None,
batch_size = 32,
num_warmup_updates=20000,
save_per_updates=1000,
checkpoint_path=None,
batch_size=32,
batch_size_type: str = "sample",
max_samples = 32,
grad_accumulation_steps = 1,
max_grad_norm = 1.0,
max_samples=32,
grad_accumulation_steps=1,
max_grad_norm=1.0,
noise_scheduler: str | None = None,
duration_predictor: torch.nn.Module | None = None,
wandb_project = "test_e2-tts",
wandb_run_name = "test_run",
wandb_project="test_e2-tts",
wandb_run_name="test_run",
wandb_resume_id: str = None,
last_per_steps = None,
last_per_steps=None,
accelerate_kwargs: dict = dict(),
ema_kwargs: dict = dict(),
bnb_optimizer: bool = False,
):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
logger = "wandb" if wandb.api.api_key else None
print(f"Using logger: {logger}")
self.accelerator = Accelerator(
log_with = logger,
kwargs_handlers = [ddp_kwargs],
gradient_accumulation_steps = grad_accumulation_steps,
**accelerate_kwargs
log_with=logger,
kwargs_handlers=[ddp_kwargs],
gradient_accumulation_steps=grad_accumulation_steps,
**accelerate_kwargs,
)
if logger == "wandb":
if exists(wandb_resume_id):
init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
else:
init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
self.accelerator.init_trackers(
project_name = wandb_project,
project_name=wandb_project,
init_kwargs=init_kwargs,
config={"epochs": epochs,
"learning_rate": learning_rate,
"num_warmup_updates": num_warmup_updates,
"batch_size": batch_size,
"batch_size_type": batch_size_type,
"max_samples": max_samples,
"grad_accumulation_steps": grad_accumulation_steps,
"max_grad_norm": max_grad_norm,
"gpus": self.accelerator.num_processes,
"noise_scheduler": noise_scheduler}
)
config={
"epochs": epochs,
"learning_rate": learning_rate,
"num_warmup_updates": num_warmup_updates,
"batch_size": batch_size,
"batch_size_type": batch_size_type,
"max_samples": max_samples,
"grad_accumulation_steps": grad_accumulation_steps,
"max_grad_norm": max_grad_norm,
"gpus": self.accelerator.num_processes,
"noise_scheduler": noise_scheduler,
},
)
self.model = model
if self.is_main:
self.ema_model = EMA(
model,
include_online_model = False,
**ema_kwargs
)
self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
self.ema_model.to(self.accelerator.device)
@@ -94,7 +92,7 @@ class Trainer:
self.num_warmup_updates = num_warmup_updates
self.save_per_updates = save_per_updates
self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
self.batch_size = batch_size
self.batch_size_type = batch_size_type
@@ -108,12 +106,11 @@ class Trainer:
if bnb_optimizer:
import bitsandbytes as bnb
self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
else:
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
self.model, self.optimizer = self.accelerator.prepare(
self.model, self.optimizer
)
self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
@property
def is_main(self):
@@ -123,81 +120,112 @@ class Trainer:
self.accelerator.wait_for_everyone()
if self.is_main:
checkpoint = dict(
model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
ema_model_state_dict = self.ema_model.state_dict(),
scheduler_state_dict = self.scheduler.state_dict(),
step = step
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
ema_model_state_dict=self.ema_model.state_dict(),
scheduler_state_dict=self.scheduler.state_dict(),
step=step,
)
if not os.path.exists(self.checkpoint_path):
os.makedirs(self.checkpoint_path)
if last == True:
if last:
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
print(f"Saved last checkpoint at step {step}")
else:
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
def load_checkpoint(self):
if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
if (
not exists(self.checkpoint_path)
or not os.path.exists(self.checkpoint_path)
or not os.listdir(self.checkpoint_path)
):
return 0
self.accelerator.wait_for_everyone()
if "model_last.pt" in os.listdir(self.checkpoint_path):
latest_checkpoint = "model_last.pt"
else:
latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
latest_checkpoint = sorted(
[f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
key=lambda x: int("".join(filter(str.isdigit, x))),
)[-1]
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
if self.is_main:
self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
if 'step' in checkpoint:
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
if "step" in checkpoint:
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
if self.scheduler:
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
step = checkpoint['step']
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
step = checkpoint["step"]
else:
checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"]
}
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
step = 0
del checkpoint; gc.collect()
del checkpoint
gc.collect()
return step
def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
if exists(resumable_with_seed):
generator = torch.Generator()
generator.manual_seed(resumable_with_seed)
else:
else:
generator = None
if self.batch_size_type == "sample":
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
batch_size=self.batch_size, shuffle=True, generator=generator)
train_dataloader = DataLoader(
train_dataset,
collate_fn=collate_fn,
num_workers=num_workers,
pin_memory=True,
persistent_workers=True,
batch_size=self.batch_size,
shuffle=True,
generator=generator,
)
elif self.batch_size_type == "frame":
self.accelerator.even_batches = False
sampler = SequentialSampler(train_dataset)
batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
batch_sampler=batch_sampler)
batch_sampler = DynamicBatchSampler(
sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
)
train_dataloader = DataLoader(
train_dataset,
collate_fn=collate_fn,
num_workers=num_workers,
pin_memory=True,
persistent_workers=True,
batch_sampler=batch_sampler,
)
else:
raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
# accelerator.prepare() dispatches batches to devices;
# which means the length of dataloader calculated before, should consider the number of devices
warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # consider a fixed warmup steps while using accelerate multi-gpu ddp
# otherwise by default with split_batches=False, warmup steps change with num_processes
warmup_steps = (
self.num_warmup_updates * self.accelerator.num_processes
) # consider a fixed warmup steps while using accelerate multi-gpu ddp
# otherwise by default with split_batches=False, warmup steps change with num_processes
total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
decay_steps = total_steps - warmup_steps
warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
self.scheduler = SequentialLR(self.optimizer,
schedulers=[warmup_scheduler, decay_scheduler],
milestones=[warmup_steps])
train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus
self.scheduler = SequentialLR(
self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
)
train_dataloader, self.scheduler = self.accelerator.prepare(
train_dataloader, self.scheduler
) # actual steps = 1 gpu steps / gpus
start_step = self.load_checkpoint()
global_step = start_step
@@ -212,23 +240,36 @@ class Trainer:
for epoch in range(skipped_epoch, self.epochs):
self.model.train()
if exists(resumable_with_seed) and epoch == skipped_epoch:
progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
initial=skipped_batch, total=orig_epoch_step)
progress_bar = tqdm(
skipped_dataloader,
desc=f"Epoch {epoch+1}/{self.epochs}",
unit="step",
disable=not self.accelerator.is_local_main_process,
initial=skipped_batch,
total=orig_epoch_step,
)
else:
progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
progress_bar = tqdm(
train_dataloader,
desc=f"Epoch {epoch+1}/{self.epochs}",
unit="step",
disable=not self.accelerator.is_local_main_process,
)
for batch in progress_bar:
with self.accelerator.accumulate(self.model):
text_inputs = batch['text']
mel_spec = batch['mel'].permute(0, 2, 1)
text_inputs = batch["text"]
mel_spec = batch["mel"].permute(0, 2, 1)
mel_lengths = batch["mel_lengths"]
# TODO. add duration predictor training
if self.duration_predictor is not None and self.accelerator.is_local_main_process:
dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
loss, cond, pred = self.model(
mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
)
self.accelerator.backward(loss)
if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
@@ -245,13 +286,13 @@ class Trainer:
if self.accelerator.is_local_main_process:
self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
progress_bar.set_postfix(step=str(global_step), loss=loss.item())
if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
self.save_checkpoint(global_step)
if global_step % self.last_per_steps == 0:
self.save_checkpoint(global_step, last=True)
self.accelerator.end_training()

View File

@@ -8,6 +8,7 @@ from tqdm import tqdm
from collections import defaultdict
import matplotlib
matplotlib.use("Agg")
import matplotlib.pylab as plt
@@ -25,109 +26,102 @@ from model.modules import MelSpec
# seed everything
def seed_everything(seed = 0):
def seed_everything(seed=0):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
# tensor helpers
def lens_to_mask(
t: int['b'],
length: int | None = None
) -> bool['b n']:
def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa: F722 F821
if not exists(length):
length = t.amax()
seq = torch.arange(length, device = t.device)
seq = torch.arange(length, device=t.device)
return seq[None, :] < t[:, None]
def mask_from_start_end_indices(
seq_len: int['b'],
start: int['b'],
end: int['b']
):
max_seq_len = seq_len.max().item()
seq = torch.arange(max_seq_len, device = start.device).long()
def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821
max_seq_len = seq_len.max().item()
seq = torch.arange(max_seq_len, device=start.device).long()
start_mask = seq[None, :] >= start[:, None]
end_mask = seq[None, :] < end[:, None]
return start_mask & end_mask
def mask_from_frac_lengths(
seq_len: int['b'],
frac_lengths: float['b']
):
def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821
lengths = (frac_lengths * seq_len).long()
max_start = seq_len - lengths
rand = torch.rand_like(frac_lengths)
start = (max_start * rand).long().clamp(min = 0)
start = (max_start * rand).long().clamp(min=0)
end = start + lengths
return mask_from_start_end_indices(seq_len, start, end)
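# A small illustrative sketch: carve out a random contiguous 70% span per
# sequence, e.g. for an infilling-style mask; lengths here are toy values.
def _frac_mask_example():
    seq_len = torch.tensor([10, 8])
    frac_lengths = torch.tensor([0.7, 0.7])
    mask = mask_from_frac_lengths(seq_len, frac_lengths)  # (2, 10) bool, ~7 / ~5 True per row
    return mask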
def maybe_masked_mean(
t: float['b n d'],
mask: bool['b n'] = None
) -> float['b d']:
def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
if not exists(mask):
return t.mean(dim = 1)
return t.mean(dim=1)
t = torch.where(mask[:, :, None], t, torch.tensor(0., device=t.device))
t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
num = t.sum(dim=1)
den = mask.float().sum(dim=1)
return num / den.clamp(min=1.)
return num / den.clamp(min=1.0)
# simple utf-8 tokenizer, since paper went character based
def list_str_to_tensor(
text: list[str],
padding_value = -1
) -> int['b nt']:
list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
return text
# char tokenizer, based on custom dataset's extracted .txt file
def list_str_to_idx(
text: list[str] | list[list[str]],
vocab_char_map: dict[str, int], # {char: idx}
padding_value = -1
) -> int['b nt']:
padding_value=-1,
) -> int["b nt"]: # noqa: F722
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return text
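# A small illustrative sketch of the two text-to-tensor paths above: byte-level
# tokenization vs. a char/pinyin vocabulary lookup (the vocab map here is a toy one).
def _tokenizer_example():
    texts = ["hello", "hi"]
    byte_ids = list_str_to_tensor(texts)                  # (2, 5), padded with -1
    vocab_char_map = {c: i for i, c in enumerate("abcdefghijklmnopqrstuvwxyz")}
    char_ids = list_str_to_idx(texts, vocab_char_map)     # (2, 5), padded with -1
    return byte_ids.shape, char_ids.shape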
# Get tokenizer
def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
'''
"""
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
- "char" for char-wise tokenizer, need .txt vocab_file
- "byte" for utf-8 tokenizer
- "custom" if you're directly passing in a path to the vocab.txt you want to use
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
- if use "char", derived from unfiltered character & symbol counts of custom dataset
- if use "byte", set to 256 (unicode byte range)
'''
- if use "byte", set to 256 (unicode byte range)
"""
if tokenizer in ["pinyin", "char"]:
with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
@@ -138,7 +132,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
vocab_char_map = None
vocab_size = 256
elif tokenizer == "custom":
with open (dataset_name, "r", encoding="utf-8") as f:
with open(dataset_name, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
@@ -149,16 +143,19 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
# convert char to pinyin
def convert_char_to_pinyin(text_list, polyphone = True):
def convert_char_to_pinyin(text_list, polyphone=True):
final_text_list = []
god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov
god_knows_why_en_testset_contains_zh_quote = str.maketrans(
{"“": '"', "”": '"', "‘": "'", "’": "'"}
) # in case librispeech (orig no-pc) test-clean
custom_trans = str.maketrans({";": ","}) # add custom trans here, to address oov
for text in text_list:
char_list = []
text = text.translate(god_knows_why_en_testset_contains_zh_quote)
text = text.translate(custom_trans)
for seg in jieba.cut(text):
seg_byte_len = len(bytes(seg, 'UTF-8'))
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
char_list.append(" ")
@@ -187,7 +184,7 @@ def convert_char_to_pinyin(text_list, polyphone = True):
# save spectrogram
def save_spectrogram(spectrogram, path):
plt.figure(figsize=(12, 4))
plt.imshow(spectrogram, origin='lower', aspect='auto')
plt.imshow(spectrogram, origin="lower", aspect="auto")
plt.colorbar()
plt.savefig(path)
plt.close()
@@ -195,13 +192,15 @@ def save_spectrogram(spectrogram, path):
# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
def get_seedtts_testset_metainfo(metalst):
f = open(metalst); lines = f.readlines(); f.close()
f = open(metalst)
lines = f.readlines()
f.close()
metainfo = []
for line in lines:
if len(line.strip().split('|')) == 5:
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
elif len(line.strip().split('|')) == 4:
utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
if len(line.strip().split("|")) == 5:
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
elif len(line.strip().split("|")) == 4:
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
if not os.path.isabs(prompt_wav):
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
@@ -211,18 +210,20 @@ def get_seedtts_testset_metainfo(metalst):
# librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
f = open(metalst); lines = f.readlines(); f.close()
f = open(metalst)
lines = f.readlines()
f.close()
metainfo = []
for line in lines:
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
# ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
# gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
@@ -234,7 +235,7 @@ def padded_mel_batch(ref_mels):
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
padded_ref_mels = []
for mel in ref_mels:
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
padded_ref_mels.append(padded_ref_mel)
padded_ref_mels = torch.stack(padded_ref_mels)
padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
@@ -243,12 +244,21 @@ def padded_mel_batch(ref_mels):
# get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
def get_inference_prompt(
metainfo,
speed = 1., tokenizer = "pinyin", polyphone = True,
target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
use_truth_duration = False,
infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
metainfo,
speed=1.0,
tokenizer="pinyin",
polyphone=True,
target_sample_rate=24000,
n_mel_channels=100,
hop_length=256,
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
num_buckets=200,
min_secs=3,
max_secs=40,
):
prompts_all = []
@@ -256,13 +266,15 @@ def get_inference_prompt(
max_tokens = max_secs * target_sample_rate // hop_length
batch_accum = [0] * num_buckets
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
([[] for _ in range(num_buckets)] for _ in range(6))
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
[[] for _ in range(num_buckets)] for _ in range(6)
)
mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
mel_spectrogram = MelSpec(
target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length
)
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
# Audio
ref_audio, ref_sr = torchaudio.load(prompt_wav)
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
@@ -274,11 +286,11 @@ def get_inference_prompt(
ref_audio = resampler(ref_audio)
# Text
if len(prompt_text[-1].encode('utf-8')) == 1:
if len(prompt_text[-1].encode("utf-8")) == 1:
prompt_text = prompt_text + " "
text = [prompt_text + gt_text]
if tokenizer == "pinyin":
text_list = convert_char_to_pinyin(text, polyphone = polyphone)
text_list = convert_char_to_pinyin(text, polyphone=polyphone)
else:
text_list = text
@@ -294,8 +306,8 @@ def get_inference_prompt(
# # test vocoder resynthesis
# ref_audio = gt_audio
else:
ref_text_len = len(prompt_text.encode('utf-8'))
gen_text_len = len(gt_text.encode('utf-8'))
ref_text_len = len(prompt_text.encode("utf-8"))
gen_text_len = len(gt_text.encode("utf-8"))
total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
# to mel spectrogram
@@ -304,8 +316,9 @@ def get_inference_prompt(
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert min_tokens <= total_mel_len <= max_tokens, \
f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
assert (
min_tokens <= total_mel_len <= max_tokens
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
utts[bucket_i].append(utt)
@@ -319,28 +332,39 @@ def get_inference_prompt(
if batch_accum[bucket_i] >= infer_batch_size:
# print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
prompts_all.append((
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i]
))
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
batch_accum[bucket_i] = 0
utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
(
utts[bucket_i],
ref_rms_list[bucket_i],
ref_mels[bucket_i],
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
) = [], [], [], [], [], []
# add residual
for bucket_i, bucket_frames in enumerate(batch_accum):
if bucket_frames > 0:
prompts_all.append((
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i]
))
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
# not only leave easy work for last workers
random.seed(666)
random.shuffle(prompts_all)
@@ -351,6 +375,7 @@ def get_inference_prompt(
# get wav_res_ref_text of seed-tts test metalst
# https://github.com/BytedanceSpeech/seed-tts-eval
def get_seed_tts_test(metalst, gen_wav_dir, gpus):
f = open(metalst)
lines = f.readlines()
@@ -358,14 +383,14 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
test_set_ = []
for line in tqdm(lines):
if len(line.strip().split('|')) == 5:
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
elif len(line.strip().split('|')) == 4:
utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
if len(line.strip().split("|")) == 5:
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
elif len(line.strip().split("|")) == 4:
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")):
continue
gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
if not os.path.isabs(prompt_wav):
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
@@ -374,65 +399,69 @@ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
num_jobs = len(gpus)
if num_jobs == 1:
return [(gpus[0], test_set_)]
wav_per_job = len(test_set_) // num_jobs + 1
test_set = []
for i in range(num_jobs):
test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
return test_set
# get librispeech test-clean cross sentence test
def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
f = open(metalst)
lines = f.readlines()
f.close()
test_set_ = []
for line in tqdm(lines):
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")
if eval_ground_truth:
gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
else:
if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")
ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")
test_set_.append((gen_wav, ref_wav, gen_txt))
num_jobs = len(gpus)
if num_jobs == 1:
return [(gpus[0], test_set_)]
wav_per_job = len(test_set_) // num_jobs + 1
test_set = []
for i in range(num_jobs):
test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))
return test_set
# load asr model
def load_asr_model(lang, ckpt_dir = ""):
def load_asr_model(lang, ckpt_dir=""):
if lang == "zh":
from funasr import AutoModel
model = AutoModel(
model = os.path.join(ckpt_dir, "paraformer-zh"),
# vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
model=os.path.join(ckpt_dir, "paraformer-zh"),
# vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
# punc_model = os.path.join(ckpt_dir, "ct-punc"),
# spk_model = os.path.join(ckpt_dir, "cam++"),
# spk_model = os.path.join(ckpt_dir, "cam++"),
disable_update=True,
) # following seed-tts setting
) # following seed-tts setting
elif lang == "en":
from faster_whisper import WhisperModel
model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
model = WhisperModel(model_size, device="cuda", compute_type="float16")
return model
@@ -440,44 +469,50 @@ def load_asr_model(lang, ckpt_dir = ""):
# WER Evaluation, the way Seed-TTS does
def run_asr_wer(args):
rank, lang, test_set, ckpt_dir = args
if lang == "zh":
import zhconv
torch.cuda.set_device(rank)
elif lang == "en":
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
else:
raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
raise NotImplementedError(
"lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
)
asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)
asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
from zhon.hanzi import punctuation
punctuation_all = punctuation + string.punctuation
wers = []
from jiwer import compute_measures
for gen_wav, prompt_wav, truth in tqdm(test_set):
if lang == "zh":
res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
hypo = res[0]["text"]
hypo = zhconv.convert(hypo, 'zh-cn')
hypo = zhconv.convert(hypo, "zh-cn")
elif lang == "en":
segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
hypo = ''
hypo = ""
for segment in segments:
hypo = hypo + ' ' + segment.text
hypo = hypo + " " + segment.text
# raw_truth = truth
# raw_hypo = hypo
for x in punctuation_all:
truth = truth.replace(x, '')
hypo = hypo.replace(x, '')
truth = truth.replace(x, "")
hypo = hypo.replace(x, "")
truth = truth.replace('  ', ' ')
hypo = hypo.replace('  ', ' ')
truth = truth.replace("  ", " ")
hypo = hypo.replace("  ", " ")
if lang == "zh":
truth = " ".join([x for x in truth])
@@ -501,22 +536,22 @@ def run_asr_wer(args):
# SIM Evaluation
def run_sim(args):
rank, test_set, ckpt_dir = args
device = f"cuda:{rank}"
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
model.load_state_dict(state_dict['model'], strict=False)
model.load_state_dict(state_dict["model"], strict=False)
use_gpu=True if torch.cuda.is_available() else False
use_gpu = True if torch.cuda.is_available() else False
if use_gpu:
model = model.cuda(device)
model.eval()
sim_list = []
for wav1, wav2, truth in tqdm(test_set):
wav1, sr1 = torchaudio.load(wav1)
wav2, sr2 = torchaudio.load(wav2)
@@ -531,20 +566,21 @@ def run_sim(args):
with torch.no_grad():
emb1 = model(wav1)
emb2 = model(wav2)
sim = F.cosine_similarity(emb1, emb2)[0].item()
# print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
sim_list.append(sim)
return sim_list
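run_sim compares WavLM/ECAPA speaker embeddings of the generated and prompt audio with a plain cosine similarity. A tiny sketch of that final step with random stand-in embeddings (the 256-dim size is arbitrary; the real embeddings come from the model loaded above):

```python
import torch
import torch.nn.functional as F

emb1 = torch.randn(1, 256)  # stand-in for model(wav1)
emb2 = torch.randn(1, 256)  # stand-in for model(wav2)

sim = F.cosine_similarity(emb1, emb2)[0].item()  # scalar in [-1.0, 1.0]
print(f"SIM: {sim:.4f}")
```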
# filter func for dirty data with many repetitions
def repetition_found(text, length = 2, tolerance = 10):
def repetition_found(text, length=2, tolerance=10):
pattern_count = defaultdict(int)
for i in range(len(text) - length + 1):
pattern = text[i:i + length]
pattern = text[i : i + length]
pattern_count[pattern] += 1
for pattern, count in pattern_count.items():
if count > tolerance:
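The return of repetition_found is elided by the hunk above; assuming it simply reports whether any pattern exceeds the tolerance, a self-contained version looks like this:

```python
from collections import defaultdict


def repetition_found(text, length=2, tolerance=10):
    # count every window of `length` characters and flag texts where one window repeats too often
    pattern_count = defaultdict(int)
    for i in range(len(text) - length + 1):
        pattern_count[text[i : i + length]] += 1
    return any(count > tolerance for count in pattern_count.values())


print(repetition_found("ha" * 20))        # True  ("ha" occurs 20 times)
print(repetition_found("a normal line"))  # False
```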
@@ -554,25 +590,31 @@ def repetition_found(text, length = 2, tolerance = 10):
# load model checkpoint for inference
def load_checkpoint(model, ckpt_path, device, use_ema = True):
def load_checkpoint(model, ckpt_path, device, use_ema=True):
if device == "cuda":
model = model.half()
ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":
from safetensors.torch import load_file
checkpoint = load_file(ckpt_path)
else:
checkpoint = torch.load(ckpt_path, weights_only=True)
if use_ema:
if ckpt_type == "safetensors":
checkpoint = {'ema_model_state_dict': checkpoint}
checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
model.load_state_dict(checkpoint['model_state_dict'])
checkpoint = {"ema_model_state_dict": checkpoint}
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"]
}
model.load_state_dict(checkpoint["model_state_dict"])
else:
if ckpt_type == "safetensors":
checkpoint = {'model_state_dict': checkpoint}
model.load_state_dict(checkpoint['model_state_dict'])
checkpoint = {"model_state_dict": checkpoint}
model.load_state_dict(checkpoint["model_state_dict"])
return model.to(device)
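For .safetensors checkpoints the EMA weights are stored flat, so load_checkpoint strips the ema_model. prefix and drops the EMA bookkeeping entries before calling load_state_dict. A sketch of just that renaming step on a made-up state dict:

```python
import torch

ema_state = {  # dummy EMA-style checkpoint contents
    "ema_model.transformer.proj.weight": torch.zeros(2, 2),
    "initted": torch.tensor(True),
    "step": torch.tensor(1200000),
}

model_state = {
    k.replace("ema_model.", ""): v
    for k, v in ema_state.items()
    if k not in ["initted", "step"]
}
print(list(model_state))  # ['transformer.proj.weight']
```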

View File

@@ -19,11 +19,7 @@ from model.utils import (
convert_char_to_pinyin,
)
device = (
"cuda"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
asr_pipe = pipeline(
@@ -54,6 +50,7 @@ fix_duration = None
# chunk text into smaller pieces
def chunk_text(text, max_chars=135):
"""
Splits the input text into chunks, each with a maximum number of characters.
@@ -68,15 +65,15 @@ def chunk_text(text, max_chars=135):
chunks = []
current_chunk = ""
# Split the text into sentences based on punctuation followed by whitespace
sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
for sentence in sentences:
if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
else:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
if current_chunk:
chunks.append(current_chunk.strip())
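chunk_text first splits on sentence-final punctuation (ASCII punctuation followed by whitespace, or CJK punctuation directly), then packs sentences greedily until the UTF-8 byte budget is reached. The splitting step alone, on a made-up string:

```python
import re

text = "First sentence. Second one! Third? Done."
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
print(sentences)  # ['First sentence.', 'Second one!', 'Third?', 'Done.']
```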
@@ -86,6 +83,7 @@ def chunk_text(text, max_chars=135):
# load vocoder
def load_vocoder(is_local=False, local_path=""):
if is_local:
print(f"Load vocos from local path {local_path}")
@@ -101,23 +99,21 @@ def load_vocoder(is_local=False, local_path=""):
# load model for inference
def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
if vocab_file == "":
vocab_file = "Emilia_ZH_EN"
tokenizer = "pinyin"
else:
tokenizer = "custom"
print("\nvocab : ", vocab_file, tokenizer)
print("tokenizer : ", tokenizer)
print("model : ", ckpt_path,"\n")
print("\nvocab : ", vocab_file, tokenizer)
print("tokenizer : ", tokenizer)
print("model : ", ckpt_path, "\n")
vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
model = CFM(
transformer=model_cls(
**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
),
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
@@ -129,21 +125,20 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file=""):
vocab_char_map=vocab_char_map,
).to(device)
model = load_checkpoint(model, ckpt_path, device, use_ema = True)
model = load_checkpoint(model, ckpt_path, device, use_ema=True)
return model
# preprocess reference audio and text
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
show_info("Converting audio...")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
aseg = AudioSegment.from_file(ref_audio_orig)
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000
)
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
non_silent_wave += non_silent_seg
@@ -181,22 +176,27 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
# infer process: chunk text -> infer batches [i.e. infer_batch_process()]
def infer_process(ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=speed, show_info=print, progress=tqdm):
def infer_process(
ref_audio, ref_text, gen_text, model_obj, cross_fade_duration=0.15, speed=speed, show_info=print, progress=tqdm
):
# Split the input text into batches
audio, sr = torchaudio.load(ref_audio)
max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
for i, gen_text in enumerate(gen_text_batches):
print(f'gen_text {i}', gen_text)
print(f"gen_text {i}", gen_text)
show_info(f"Generating audio in {len(gen_text_batches)} batches...")
return infer_batch_process((audio, sr), ref_text, gen_text_batches, model_obj, cross_fade_duration, speed, progress)
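infer_process sizes the text chunks from the reference clip (roughly: bytes per second of the reference, times the seconds left under a ~25 s cap), and infer_batch_process below estimates the output length in mel frames by scaling the reference frame count with the text-length ratio. A worked example with made-up numbers:

```python
ref_audio_sec = 6.0    # length of the reference clip in seconds
ref_text_bytes = 120   # len(ref_text.encode("utf-8"))

max_chars = int(ref_text_bytes / ref_audio_sec * (25 - ref_audio_sec))
print(max_chars)  # 20 bytes/s * 19 s budget = 380

sr, hop_length = 24000, 256
ref_audio_len = int(ref_audio_sec * sr) // hop_length  # 562 mel frames
gen_text_bytes, speed = 200, 1.0
duration = ref_audio_len + int(ref_audio_len / ref_text_bytes * gen_text_bytes / speed)
print(duration)  # 562 + 936 = 1498 frames
```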
# infer batches
def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_fade_duration=0.15, speed=1, progress=tqdm):
def infer_batch_process(
ref_audio, ref_text, gen_text_batches, model_obj, cross_fade_duration=0.15, speed=1, progress=tqdm
):
audio, sr = ref_audio
if audio.shape[0] > 1:
audio = torch.mean(audio, dim=0, keepdim=True)
@@ -212,7 +212,7 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
generated_waves = []
spectrograms = []
if len(ref_text[-1].encode('utf-8')) == 1:
if len(ref_text[-1].encode("utf-8")) == 1:
ref_text = ref_text + " "
for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
# Prepare the text
@@ -221,8 +221,8 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
# Calculate duration
ref_audio_len = audio.shape[-1] // hop_length
ref_text_len = len(ref_text.encode('utf-8'))
gen_text_len = len(gen_text.encode('utf-8'))
ref_text_len = len(ref_text.encode("utf-8"))
gen_text_len = len(gen_text.encode("utf-8"))
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
# inference
@@ -245,7 +245,7 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
# wav -> numpy
generated_wave = generated_wave.squeeze().cpu().numpy()
generated_waves.append(generated_wave)
spectrograms.append(generated_mel_spec[0].cpu().numpy())
@@ -280,11 +280,9 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
# Combine
new_wave = np.concatenate([
prev_wave[:-cross_fade_samples],
cross_faded_overlap,
next_wave[cross_fade_samples:]
])
new_wave = np.concatenate(
[prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
)
final_wave = new_wave
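Adjacent batches are joined by a cross-fade over the overlap region rather than a hard cut. The fade ramps themselves sit in the elided part of the hunk, so linear ramps are assumed in this toy sketch:

```python
import numpy as np

prev_wave = np.ones(10, dtype=np.float32)   # toy waveforms
next_wave = np.zeros(10, dtype=np.float32)
cross_fade_samples = 4

fade_out = np.linspace(1.0, 0.0, cross_fade_samples)
fade_in = np.linspace(0.0, 1.0, cross_fade_samples)

cross_faded_overlap = prev_wave[-cross_fade_samples:] * fade_out + next_wave[:cross_fade_samples] * fade_in
new_wave = np.concatenate(
    [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
)
print(new_wave)  # seven 1s, a short ramp (~0.67, ~0.33), then seven 0s
```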
@@ -296,6 +294,7 @@ def infer_batch_process(ref_audio, ref_text, gen_text_batches, model_obj, cross_
# remove silence from generated wav
def remove_silence_for_generated_wav(filename):
aseg = AudioSegment.from_file(filename)
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)

10
ruff.toml Normal file
View File

@@ -0,0 +1,10 @@
line-length = 120
target-version = "py310"
[lint]
# Only ignore variables with names starting with "_".
dummy-variable-rgx = "^_.*$"
[lint.isort]
force-single-line = true
lines-after-imports = 2

View File

@@ -1,6 +1,7 @@
'''ADAPTIVE BATCH SIZE'''
print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in')
print(' -> least padding, gather wavs with accumulated frames in a batch\n')
"""ADAPTIVE BATCH SIZE"""
print("Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in")
print(" -> least padding, gather wavs with accumulated frames in a batch\n")
# data
total_hours = 95282

View File

@@ -1,13 +1,15 @@
import sys, os
import sys
import os
sys.path.append(os.getcwd())
from model import M2_TTS, UNetT, DiT, MMDiT
from model import M2_TTS, DiT
import torch
import thop
''' ~155M '''
""" ~155M """
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
# transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
@@ -15,11 +17,11 @@ import thop
# transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
# transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
''' ~335M '''
""" ~335M """
# FLOPs: 622.1 G, Params: 333.2 M
# transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
# FLOPs: 363.4 G, Params: 335.8 M
transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model = M2_TTS(transformer=transformer)
@@ -30,6 +32,8 @@ duration = 20
frame_length = int(duration * target_sample_rate / hop_length)
text_length = 150
flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
flops, params = thop.profile(
model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))
)
print(f"FLOPs: {flops / 1e9} G")
print(f"Params: {params / 1e6} M")

View File

@@ -1,4 +1,6 @@
import sys, os
import sys
import os
sys.path.append(os.getcwd())
import time
@@ -14,9 +16,9 @@ from vocos import Vocos
from model import CFM, UNetT, DiT
from model.utils import (
load_checkpoint,
get_tokenizer,
get_seedtts_testset_metainfo,
get_librispeech_test_clean_metainfo,
get_tokenizer,
get_seedtts_testset_metainfo,
get_librispeech_test_clean_metainfo,
get_inference_prompt,
)
@@ -38,16 +40,16 @@ tokenizer = "pinyin"
parser = argparse.ArgumentParser(description="batch inference")
parser.add_argument('-s', '--seed', default=None, type=int)
parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
parser.add_argument('-n', '--expname', required=True)
parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
parser.add_argument("-s", "--seed", default=None, type=int)
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
parser.add_argument("-n", "--expname", required=True)
parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
parser.add_argument('-nfe', '--nfestep', default=32, type=int)
parser.add_argument('-o', '--odemethod', default="euler")
parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
parser.add_argument("-nfe", "--nfestep", default=32, type=int)
parser.add_argument("-o", "--odemethod", default="euler")
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
parser.add_argument('-t', '--testset', required=True)
parser.add_argument("-t", "--testset", required=True)
args = parser.parse_args()
@@ -66,26 +68,26 @@ testset = args.testset
infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
cfg_strength = 2.
speed = 1.
cfg_strength = 2.0
speed = 1.0
use_truth_duration = False
no_ref_audio = False
if exp_name == "F5TTS_Base":
model_cls = DiT
model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
elif exp_name == "E2TTS_Base":
model_cls = UNetT
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
if testset == "ls_pc_test_clean":
metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
elif testset == "seedtts_test_zh":
metalst = "data/seedtts_testset/zh/meta.lst"
metainfo = get_seedtts_testset_metainfo(metalst)
@@ -96,13 +98,16 @@ elif testset == "seedtts_test_en":
# path to save generated wavs
if seed is None: seed = random.randint(-10000, 10000)
output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
f"seed{seed}_{ode_method}_nfe{nfe_step}" \
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
f"_cfg{cfg_strength}_speed{speed}" \
f"{'_gt-dur' if use_truth_duration else ''}" \
if seed is None:
seed = random.randint(-10000, 10000)
output_dir = (
f"results/{exp_name}_{ckpt_step}/{testset}/"
f"seed{seed}_{ode_method}_nfe{nfe_step}"
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
f"_cfg{cfg_strength}_speed{speed}"
f"{'_gt-dur' if use_truth_duration else ''}"
f"{'_no-ref-audio' if no_ref_audio else ''}"
)
# -------------------------------------------------#
@@ -110,15 +115,15 @@ output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
use_ema = True
prompts_all = get_inference_prompt(
metainfo,
speed = speed,
tokenizer = tokenizer,
target_sample_rate = target_sample_rate,
n_mel_channels = n_mel_channels,
hop_length = hop_length,
target_rms = target_rms,
use_truth_duration = use_truth_duration,
infer_batch_size = infer_batch_size,
metainfo,
speed=speed,
tokenizer=tokenizer,
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
hop_length=hop_length,
target_rms=target_rms,
use_truth_duration=use_truth_duration,
infer_batch_size=infer_batch_size,
)
# Vocoder model
@@ -137,23 +142,19 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
# Model
model = CFM(
transformer = model_cls(
**model_cfg,
text_num_embeds = vocab_size,
mel_dim = n_mel_channels
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
hop_length=hop_length,
),
mel_spec_kwargs = dict(
target_sample_rate = target_sample_rate,
n_mel_channels = n_mel_channels,
hop_length = hop_length,
odeint_kwargs=dict(
method=ode_method,
),
odeint_kwargs = dict(
method = ode_method,
),
vocab_char_map = vocab_char_map,
vocab_char_map=vocab_char_map,
).to(device)
model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
if not os.path.exists(output_dir) and accelerator.is_main_process:
os.makedirs(output_dir)
@@ -163,29 +164,28 @@ accelerator.wait_for_everyone()
start = time.time()
with accelerator.split_between_processes(prompts_all) as prompts:
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
ref_mels = ref_mels.to(device)
ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
# Inference
with torch.inference_mode():
generated, _ = model.sample(
cond = ref_mels,
text = final_text_list,
duration = total_mel_lens,
lens = ref_mel_lens,
steps = nfe_step,
cfg_strength = cfg_strength,
sway_sampling_coef = sway_sampling_coef,
no_ref_audio = no_ref_audio,
seed = seed,
cond=ref_mels,
text=final_text_list,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
no_ref_audio=no_ref_audio,
seed=seed,
)
# Final result
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1)
generated_wave = vocos.decode(gen_mel_spec.cpu())
if ref_rms_list[i] < target_rms:

View File

@@ -1,6 +1,8 @@
# Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
import sys, os
import sys
import os
sys.path.append(os.getcwd())
import multiprocessing as mp
@@ -19,7 +21,7 @@ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
gpus = [0,1,2,3,4,5,6,7]
gpus = [0, 1, 2, 3, 4, 5, 6, 7]
test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
@@ -46,7 +48,7 @@ if eval_task == "wer":
for wers_ in results:
wers.extend(wers_)
wer = round(np.mean(wers)*100, 3)
wer = round(np.mean(wers) * 100, 3)
print(f"\nTotal {len(wers)} samples")
print(f"WER : {wer}%")
@@ -62,6 +64,6 @@ if eval_task == "sim":
for sim_ in results:
sim_list.extend(sim_)
sim = round(sum(sim_list)/len(sim_list), 3)
sim = round(sum(sim_list) / len(sim_list), 3)
print(f"\nTotal {len(sim_list)} samples")
print(f"SIM : {sim}")

View File

@@ -1,6 +1,8 @@
# Evaluate with Seed-TTS testset
import sys, os
import sys
import os
sys.path.append(os.getcwd())
import multiprocessing as mp
@@ -14,21 +16,21 @@ from model.utils import (
eval_task = "wer" # sim | wer
lang = "zh" # zh | en
lang = "zh" # zh | en
metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
# gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
gen_wav_dir = f"PATH_TO_GENERATED" # generated wavs
gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
# NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
# zh 1.254 seems a result of 4 workers wer_seed_tts
gpus = [0,1,2,3,4,5,6,7]
# zh 1.254 seems a result of 4 workers wer_seed_tts
gpus = [0, 1, 2, 3, 4, 5, 6, 7]
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
local = False
if local: # use local custom checkpoint dir
if lang == "zh":
asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
elif lang == "en":
asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
else:
@@ -48,7 +50,7 @@ if eval_task == "wer":
for wers_ in results:
wers.extend(wers_)
wer = round(np.mean(wers)*100, 3)
wer = round(np.mean(wers) * 100, 3)
print(f"\nTotal {len(wers)} samples")
print(f"WER : {wer}%")
@@ -64,6 +66,6 @@ if eval_task == "sim":
for sim_ in results:
sim_list.extend(sim_)
sim = round(sum(sim_list)/len(sim_list), 3)
sim = round(sum(sim_list) / len(sim_list), 3)
print(f"\nTotal {len(sim_list)} samples")
print(f"SIM : {sim}")

View File

@@ -1,4 +1,6 @@
import sys, os
import sys
import os
sys.path.append(os.getcwd())
from pathlib import Path
@@ -17,10 +19,11 @@ from model.utils import (
PRETRAINED_VOCAB_PATH = Path(__file__).parent.parent / "data/Emilia_ZH_EN_pinyin/vocab.txt"
def is_csv_wavs_format(input_dataset_dir):
fpath = Path(input_dataset_dir)
metadata = fpath / "metadata.csv"
wavs = fpath / 'wavs'
wavs = fpath / "wavs"
return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
@@ -46,22 +49,24 @@ def prepare_csv_wavs_dir(input_dir):
return sub_result, durations, vocab_set
def get_audio_duration(audio_path):
audio, sample_rate = torchaudio.load(audio_path)
num_channels = audio.shape[0]
return audio.shape[1] / (sample_rate * num_channels)
def read_audio_text_pairs(csv_file_path):
audio_text_pairs = []
parent = Path(csv_file_path).parent
with open(csv_file_path, mode='r', newline='', encoding='utf-8') as csvfile:
reader = csv.reader(csvfile, delimiter='|')
with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile:
reader = csv.reader(csvfile, delimiter="|")
next(reader) # Skip the header row
for row in reader:
if len(row) >= 2:
audio_file = row[0].strip() # First column: audio file path
text = row[1].strip() # Second column: text
text = row[1].strip() # Second column: text
audio_file_path = parent / audio_file
audio_text_pairs.append((audio_file_path.as_posix(), text))
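read_audio_text_pairs expects a pipe-delimited metadata.csv with a header row, the audio path in the first column and the transcript in the second. A sketch with a hypothetical two-row file (file names and text are made up):

```python
import csv
from pathlib import Path

sample = "audio_file|text\nwavs/utt_0001.wav|Hello world.\nwavs/utt_0002.wav|Second line.\n"
Path("metadata.csv").write_text(sample, encoding="utf-8")

with open("metadata.csv", mode="r", newline="", encoding="utf-8") as csvfile:
    reader = csv.reader(csvfile, delimiter="|")
    next(reader)  # skip the header row
    pairs = [(row[0].strip(), row[1].strip()) for row in reader if len(row) >= 2]

print(pairs)  # [('wavs/utt_0001.wav', 'Hello world.'), ('wavs/utt_0002.wav', 'Second line.')]
```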
@@ -78,12 +83,12 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
# dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
raw_arrow_path = out_dir / "raw.arrow"
with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
# dup a json separately saving duration in case for DynamicBatchSampler ease
dur_json_path = out_dir / "duration.json"
with open(dur_json_path.as_posix(), 'w', encoding='utf-8') as f:
with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
# vocab map, i.e. tokenizer
@@ -120,13 +125,14 @@ def cli():
# finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
# pretrain: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin --pretrain
parser = argparse.ArgumentParser(description="Prepare and save dataset.")
parser.add_argument('inp_dir', type=str, help="Input directory containing the data.")
parser.add_argument('out_dir', type=str, help="Output directory to save the prepared data.")
parser.add_argument('--pretrain', action='store_true', help="Enable for new pretrain, otherwise is a fine-tune")
parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
args = parser.parse_args()
prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
if __name__ == "__main__":
cli()

View File

@@ -4,7 +4,9 @@
# generate audio text map for Emilia ZH & EN
# evaluate for vocab size
import sys, os
import sys
import os
sys.path.append(os.getcwd())
from pathlib import Path
@@ -12,7 +14,6 @@ import json
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
from datasets import Dataset
from datasets.arrow_writer import ArrowWriter
from model.utils import (
@@ -21,13 +22,89 @@ from model.utils import (
)
out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"}
out_zh = {
"ZH_B00041_S06226",
"ZH_B00042_S09204",
"ZH_B00065_S09430",
"ZH_B00065_S09431",
"ZH_B00066_S09327",
"ZH_B00066_S09328",
}
zh_filters = ["", ""]
# seems synthesized audios, or heavily code-switched
out_en = {
"EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375",
"EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995",
"EN_B00013_S00913",
"EN_B00042_S00120",
"EN_B00055_S04111",
"EN_B00061_S00693",
"EN_B00061_S01494",
"EN_B00061_S03375",
"EN_B00059_S00092",
"EN_B00111_S04300",
"EN_B00100_S03759",
"EN_B00087_S03811",
"EN_B00059_S00950",
"EN_B00089_S00946",
"EN_B00078_S05127",
"EN_B00070_S04089",
"EN_B00074_S09659",
"EN_B00061_S06983",
"EN_B00061_S07060",
"EN_B00059_S08397",
"EN_B00082_S06192",
"EN_B00091_S01238",
"EN_B00089_S07349",
"EN_B00070_S04343",
"EN_B00061_S02400",
"EN_B00076_S01262",
"EN_B00068_S06467",
"EN_B00076_S02943",
"EN_B00064_S05954",
"EN_B00061_S05386",
"EN_B00066_S06544",
"EN_B00076_S06944",
"EN_B00072_S08620",
"EN_B00076_S07135",
"EN_B00076_S09127",
"EN_B00065_S00497",
"EN_B00059_S06227",
"EN_B00063_S02859",
"EN_B00075_S01547",
"EN_B00061_S08286",
"EN_B00079_S02901",
"EN_B00092_S03643",
"EN_B00096_S08653",
"EN_B00063_S04297",
"EN_B00063_S04614",
"EN_B00079_S04698",
"EN_B00104_S01666",
"EN_B00061_S09504",
"EN_B00061_S09694",
"EN_B00065_S05444",
"EN_B00063_S06860",
"EN_B00065_S05725",
"EN_B00069_S07628",
"EN_B00083_S03875",
"EN_B00071_S07665",
"EN_B00071_S07665",
"EN_B00062_S04187",
"EN_B00065_S09873",
"EN_B00065_S09922",
"EN_B00084_S02463",
"EN_B00067_S05066",
"EN_B00106_S08060",
"EN_B00073_S06399",
"EN_B00073_S09236",
"EN_B00087_S00432",
"EN_B00085_S05618",
"EN_B00064_S01262",
"EN_B00072_S01739",
"EN_B00059_S03913",
"EN_B00069_S04036",
"EN_B00067_S05623",
"EN_B00060_S05389",
"EN_B00060_S07290",
"EN_B00062_S08995",
}
en_filters = ["ا", "", ""]
@@ -43,18 +120,24 @@ def deal_with_audio_dir(audio_dir):
for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
obj = json.loads(line)
text = obj["text"]
if obj['language'] == "zh":
if obj["language"] == "zh":
if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
bad_case_zh += 1
continue
else:
text = text.translate(str.maketrans({',': '', '!': '', '?': ''})) # not "。" cuz much code-switched
if obj['language'] == "en":
if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
text = text.translate(
str.maketrans({",": "", "!": "", "?": ""})
) # not "。" cuz much code-switched
if obj["language"] == "en":
if (
obj["wav"].split("/")[1] in out_en
or any(f in text for f in en_filters)
or repetition_found(text, length=4)
):
bad_case_en += 1
continue
if tokenizer == "pinyin":
text = convert_char_to_pinyin([text], polyphone = polyphone)[0]
text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
duration = obj["duration"]
sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
durations.append(duration)
@@ -96,11 +179,11 @@ def main():
# dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
# dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
# dup a json separately saving duration in case for DynamicBatchSampler ease
with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f:
with open(f"data/{dataset_name}/duration.json", "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
# vocab map, i.e. tokenizer
@@ -114,12 +197,13 @@ def main():
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}")
if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n")
if "ZH" in langs:
print(f"Bad zh transcription case: {total_bad_case_zh}")
if "EN" in langs:
print(f"Bad en transcription case: {total_bad_case_en}\n")
if __name__ == "__main__":
max_workers = 32
tokenizer = "pinyin" # "pinyin" | "char"

View File

@@ -1,7 +1,9 @@
# generate audio text map for WenetSpeech4TTS
# evaluate for vocab size
import sys, os
import sys
import os
sys.path.append(os.getcwd())
import json
@@ -23,7 +25,7 @@ def deal_with_sub_path_files(dataset_path, sub_path):
audio_paths, texts, durations = [], [], []
for text_file in tqdm(text_files):
with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file:
with open(os.path.join(text_dir, text_file), "r", encoding="utf-8") as file:
first_line = file.readline().split("\t")
audio_nm = first_line[0]
audio_path = os.path.join(audio_dir, audio_nm + ".wav")
@@ -32,7 +34,7 @@ def deal_with_sub_path_files(dataset_path, sub_path):
audio_paths.append(audio_path)
if tokenizer == "pinyin":
texts.extend(convert_char_to_pinyin([text], polyphone = polyphone))
texts.extend(convert_char_to_pinyin([text], polyphone=polyphone))
elif tokenizer == "char":
texts.append(text)
@@ -46,7 +48,7 @@ def main():
assert tokenizer in ["pinyin", "char"]
audio_path_list, text_list, duration_list = [], [], []
executor = ProcessPoolExecutor(max_workers=max_workers)
futures = []
for dataset_path in dataset_paths:
@@ -68,8 +70,10 @@ def main():
dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False) # dup a json separately saving duration in case for DynamicBatchSampler ease
with open(f"data/{dataset_name}_{tokenizer}/duration.json", "w", encoding="utf-8") as f:
json.dump(
{"duration": duration_list}, f, ensure_ascii=False
) # dup a json separately saving duration in case for DynamicBatchSampler ease
print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
text_vocab_set = set()
@@ -85,22 +89,21 @@ def main():
f.write(vocab + "\n")
print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
if __name__ == "__main__":
max_workers = 32
tokenizer = "pinyin" # "pinyin" | "char"
polyphone = True
dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
dataset_paths = [
"<SOME_PATH>/WenetSpeech4TTS/Basic",
"<SOME_PATH>/WenetSpeech4TTS/Standard",
"<SOME_PATH>/WenetSpeech4TTS/Premium",
][-dataset_choice:]
][-dataset_choice:]
print(f"\nChoose Dataset: {dataset_name}\n")
main()
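The dataset_choice index picks one dataset name but slices a growing list of source directories, so choosing Basic sweeps all three tiers while Premium only reads the Premium directory. Traced with the same literals as the script:

```python
for dataset_choice in (1, 2, 3):  # 1: Premium, 2: Standard, 3: Basic
    name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
    paths = ["Basic", "Standard", "Premium"][-dataset_choice:]
    print(dataset_choice, name, paths)
# 1 WenetSpeech4TTS_Premium ['Premium']
# 2 WenetSpeech4TTS_Standard ['Standard', 'Premium']
# 3 WenetSpeech4TTS_Basic ['Basic', 'Standard', 'Premium']
```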
@@ -109,8 +112,8 @@ if __name__ == "__main__":
# WenetSpeech4TTS Basic Standard Premium
# samples count 3932473 1941220 407494
# pinyin vocab size 1349 1348 1344 (no polyphone)
# - - 1459 (polyphone)
# - - 1459 (polyphone)
# char vocab size 5264 5219 5042
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
# please be careful if using pretrained model, make sure the vocab.txt is same

View File

@@ -5,11 +5,11 @@ import torch.nn.functional as F
import torchaudio
from vocos import Vocos
from model import CFM, UNetT, DiT, MMDiT
from model import CFM, UNetT, DiT
from model.utils import (
load_checkpoint,
get_tokenizer,
convert_char_to_pinyin,
get_tokenizer,
convert_char_to_pinyin,
save_spectrogram,
)
@@ -35,18 +35,18 @@ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
ckpt_step = 1200000
nfe_step = 32 # 16, 32
cfg_strength = 2.
ode_method = 'euler' # euler | midpoint
sway_sampling_coef = -1.
speed = 1.
cfg_strength = 2.0
ode_method = "euler" # euler | midpoint
sway_sampling_coef = -1.0
speed = 1.0
if exp_name == "F5TTS_Base":
model_cls = DiT
model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
elif exp_name == "E2TTS_Base":
model_cls = UNetT
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
output_dir = "tests"
@@ -62,8 +62,14 @@ output_dir = "tests"
audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
origin_text = "Some call me nature, others call me mother nature."
target_text = "Some call me optimist, others call me realist."
parts_to_edit = [[1.42, 2.44], [4.04, 4.9], ] # start_ends of "nature" & "mother nature", in seconds
fix_duration = [1.2, 1, ] # fix duration for "optimist" & "realist", in seconds
parts_to_edit = [
[1.42, 2.44],
[4.04, 4.9],
] # start_ends of "nature" & "mother nature", in seconds
fix_duration = [
1.2,
1,
] # fix duration for "optimist" & "realist", in seconds
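Each span in parts_to_edit is resynthesised for fix_duration seconds; later in the script the span is converted to samples and then to the number of mel frames zeroed in edit_mask. The conversion for the two spans above, using the script's constants:

```python
target_sample_rate, hop_length = 24000, 256

parts_to_edit = [[1.42, 2.44], [4.04, 4.9]]
fix_duration = [1.2, 1]

for (start, end), dur in zip(parts_to_edit, fix_duration):
    part_dur = dur * target_sample_rate           # samples to synthesise for this span
    masked_frames = round(part_dur / hop_length)  # mel frames zeroed in edit_mask
    print(f"{start}-{end}s -> {part_dur} samples, {masked_frames} masked frames")
# 1.42-2.44s -> 28800.0 samples, 112 masked frames
# 4.04-4.9s -> 24000 samples, 94 masked frames
```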
# audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
# origin_text = "对,这就是我,万人敬仰的太乙真人。"
@@ -86,7 +92,7 @@ if local:
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
vocos.load_state_dict(state_dict)
vocos.eval()
else:
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
@@ -96,23 +102,19 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
# Model
model = CFM(
transformer = model_cls(
**model_cfg,
text_num_embeds = vocab_size,
mel_dim = n_mel_channels
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
hop_length=hop_length,
),
mel_spec_kwargs = dict(
target_sample_rate = target_sample_rate,
n_mel_channels = n_mel_channels,
hop_length = hop_length,
odeint_kwargs=dict(
method=ode_method,
),
odeint_kwargs = dict(
method = ode_method,
),
vocab_char_map = vocab_char_map,
vocab_char_map=vocab_char_map,
).to(device)
model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
# Audio
audio, sr = torchaudio.load(audio_to_edit)
@@ -132,14 +134,18 @@ for part in parts_to_edit:
part_dur = end - start if fix_duration is None else fix_duration.pop(0)
part_dur = part_dur * target_sample_rate
start = start * target_sample_rate
audio_ = torch.cat((audio_, audio[:, round(offset):round(start)], torch.zeros(1, round(part_dur))), dim = -1)
edit_mask = torch.cat((edit_mask,
torch.ones(1, round((start - offset) / hop_length), dtype = torch.bool),
torch.zeros(1, round(part_dur / hop_length), dtype = torch.bool)
), dim = -1)
audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
edit_mask = torch.cat(
(
edit_mask,
torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
),
dim=-1,
)
offset = end * target_sample_rate
# audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value = True)
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
audio = audio.to(device)
edit_mask = edit_mask.to(device)
@@ -159,14 +165,14 @@ duration = audio.shape[-1] // hop_length
# Inference
with torch.inference_mode():
generated, trajectory = model.sample(
cond = audio,
text = final_text_list,
duration = duration,
steps = nfe_step,
cfg_strength = cfg_strength,
sway_sampling_coef = sway_sampling_coef,
seed = seed,
edit_mask = edit_mask,
cond=audio,
text=final_text_list,
duration=duration,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
seed=seed,
edit_mask=edit_mask,
)
print(f"Generated mel: {generated.shape}")

View File

@@ -1,4 +1,4 @@
from model import CFM, UNetT, DiT, MMDiT, Trainer
from model import CFM, UNetT, DiT, Trainer
from model.utils import get_tokenizer
from model.dataset import load_dataset
@@ -9,8 +9,8 @@ target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
tokenizer = "pinyin" # 'pinyin', 'char', or 'custom'
tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
dataset_name = "Emilia_ZH_EN"
# -------------------------- Training Settings -------------------------- #
@@ -23,7 +23,7 @@ batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type = "frame" # "frame" or "sample"
max_samples = 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm = 1.
max_grad_norm = 1.0
epochs = 11 # use linear decay, thus epochs control the slope
num_warmup_updates = 20000 # warmup steps
@@ -34,15 +34,16 @@ last_per_steps = 5000 # save last checkpoint per steps
if exp_name == "F5TTS_Base":
wandb_resume_id = None
model_cls = DiT
model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
elif exp_name == "E2TTS_Base":
wandb_resume_id = None
model_cls = UNetT
model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
# ----------------------------------------------------------------------- #
def main():
if tokenizer == "custom":
tokenizer_path = tokenizer_path
@@ -51,44 +52,41 @@ def main():
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
mel_spec_kwargs = dict(
target_sample_rate = target_sample_rate,
n_mel_channels = n_mel_channels,
hop_length = hop_length,
)
target_sample_rate=target_sample_rate,
n_mel_channels=n_mel_channels,
hop_length=hop_length,
)
model = CFM(
transformer = model_cls(
**model_cfg,
text_num_embeds = vocab_size,
mel_dim = n_mel_channels
),
mel_spec_kwargs = mel_spec_kwargs,
vocab_char_map = vocab_char_map,
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=mel_spec_kwargs,
vocab_char_map=vocab_char_map,
)
trainer = Trainer(
model,
epochs,
epochs,
learning_rate,
num_warmup_updates = num_warmup_updates,
save_per_updates = save_per_updates,
checkpoint_path = f'ckpts/{exp_name}',
batch_size = batch_size_per_gpu,
batch_size_type = batch_size_type,
max_samples = max_samples,
grad_accumulation_steps = grad_accumulation_steps,
max_grad_norm = max_grad_norm,
wandb_project = "CFM-TTS",
wandb_run_name = exp_name,
wandb_resume_id = wandb_resume_id,
last_per_steps = last_per_steps,
num_warmup_updates=num_warmup_updates,
save_per_updates=save_per_updates,
checkpoint_path=f"ckpts/{exp_name}",
batch_size=batch_size_per_gpu,
batch_size_type=batch_size_type,
max_samples=max_samples,
grad_accumulation_steps=grad_accumulation_steps,
max_grad_norm=max_grad_norm,
wandb_project="CFM-TTS",
wandb_run_name=exp_name,
wandb_resume_id=wandb_resume_id,
last_per_steps=last_per_steps,
)
train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
trainer.train(train_dataset,
resumable_with_seed = 666 # seed for shuffling dataset
)
trainer.train(
train_dataset,
resumable_with_seed=666, # seed for shuffling dataset
)
if __name__ == '__main__':
if __name__ == "__main__":
main()