mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-26 04:41:36 -08:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
646f34b20f | ||
|
|
2e2acc6ea2 | ||
|
|
6fbe7592f5 | ||
|
|
7e37bc5d9a | ||
|
|
35f130ee85 | ||
|
|
e6469f705f | ||
|
|
31cd818095 | ||
|
|
1d13664b24 | ||
|
|
b27471ea06 | ||
|
|
8fb55f107e | ||
|
|
ccb380b752 | ||
|
|
3027b43953 | ||
|
|
ecd1c3949a | ||
|
|
2968aa184f | ||
|
|
fb26b6d93e | ||
|
|
f7f266cdd9 | ||
|
|
695c735737 | ||
|
|
3e2a07da1d | ||
|
|
c47687487c |
@@ -91,7 +91,7 @@ conda activate f5-tts
|
||||
> ```bash
|
||||
> git clone https://github.com/SWivid/F5-TTS.git
|
||||
> cd F5-TTS
|
||||
> # git submodule update --init --recursive # (optional, if need > bigvgan)
|
||||
> # git submodule update --init --recursive # (optional, if use bigvgan as vocoder)
|
||||
> pip install -e .
|
||||
> ```
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "f5-tts"
|
||||
version = "1.1.4"
|
||||
version = "1.1.5"
|
||||
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT License"}
|
||||
@@ -14,7 +14,7 @@ classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
]
|
||||
dependencies = [
|
||||
"accelerate>=0.33.0",
|
||||
"accelerate>=0.33.0,!=1.7.0",
|
||||
"bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
|
||||
"cached_path",
|
||||
"click",
|
||||
|
||||
@@ -13,7 +13,7 @@ To avoid possible inference failures, make sure you have seen through the follow
|
||||
- Add some spaces (blank: " ") or punctuations (e.g. "," ".") <ins>to explicitly introduce some pauses</ins>.
|
||||
- If English punctuation marks the end of a sentence, make sure there is a space " " after it. Otherwise it is not regarded as a sentence end when chunking.
|
||||
- <ins>Preprocess numbers</ins> to Chinese letters if you want to have them read in Chinese, otherwise in English.
|
||||
- If the generation output is blank (pure silence), <ins>check for ffmpeg installation</ins>.
|
||||
- If the generation output is blank (pure silence), <ins>check for FFmpeg installation</ins>.
|
||||
- Try <ins>turning off `use_ema` if using an early-stage</ins> finetuned checkpoint (one that has gone through just a few updates).
|
||||
|
||||
|
||||
@@ -129,6 +129,28 @@ ref_text = ""
|
||||
```
|
||||
You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice; refer to `src/f5_tts/infer/examples/multi/story.txt`.
|
||||
|
||||
## API Usage
|
||||
|
||||
```python
|
||||
from importlib.resources import files
|
||||
from f5_tts.api import F5TTS
|
||||
|
||||
f5tts = F5TTS()
|
||||
wav, sr, spec = f5tts.infer(
|
||||
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
|
||||
ref_text="some call me nature, others call me mother nature.",
|
||||
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
|
||||
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
|
||||
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
|
||||
seed=None,
|
||||
)
|
||||
```
|
||||
Check [api.py](../api.py) for more details.
|
||||
|
||||
## TensorRT-LLM Deployment
|
||||
|
||||
See [detailed instructions](../runtime/triton_trtllm/README.md) for more information.
|
||||
|
||||
## Socket Real-time Service
|
||||
|
||||
Real-time voice output with chunk stream:
|
||||
|
||||
@@ -22,6 +22,8 @@
|
||||
- [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
|
||||
- [French](#french)
|
||||
- [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio)
|
||||
- [German](#german)
|
||||
- [F5-TTS Base @ de @ hvoss-techfak](#f5-tts-base--de--hvoss-techfak)
|
||||
- [Hindi](#hindi)
|
||||
- [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab)
|
||||
- [Italian](#italian)
|
||||
@@ -97,6 +99,22 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
|
||||
- [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
|
||||
|
||||
|
||||
## German
|
||||
|
||||
#### F5-TTS Base @ de @ hvoss-techfak
|
||||
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
||||
|:---:|:------------:|:-----------:|:-------------:|
|
||||
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/hvoss-techfak/F5-TTS-German)|[Mozilla Common Voice 19.0](https://commonvoice.mozilla.org/en/datasets) & 800 hours Crowdsourced |cc-by-nc-4.0|
|
||||
|
||||
```bash
|
||||
Model: hf://hvoss-techfak/F5-TTS-German/model_f5tts_german.pt
|
||||
Vocab: hf://hvoss-techfak/F5-TTS-German/vocab.txt
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
|
||||
- Finetuned by [@hvoss-techfak](https://github.com/hvoss-techfak)
|
||||
|
||||
|
||||
## Hindi
|
||||
|
||||
#### F5-TTS Small @ hi @ SPRINGLab
|
||||
|
||||
@@ -323,7 +323,7 @@ def main():
|
||||
ref_text_ = voices[voice]["ref_text"]
|
||||
gen_text_ = text.strip()
|
||||
print(f"Voice: {voice}")
|
||||
audio_segment, final_sample_rate, spectragram = infer_process(
|
||||
audio_segment, final_sample_rate, spectrogram = infer_process(
|
||||
ref_audio_,
|
||||
ref_text_,
|
||||
gen_text_,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from collections import OrderedDict
|
||||
@@ -41,6 +42,7 @@ from f5_tts.infer.utils_infer import (
|
||||
preprocess_ref_audio_text,
|
||||
remove_silence_for_generated_wav,
|
||||
save_spectrogram,
|
||||
tempfile_kwargs,
|
||||
)
|
||||
from f5_tts.model import DiT, UNetT
|
||||
|
||||
@@ -126,7 +128,7 @@ def load_text_from_file(file):
|
||||
return gr.update(value=text)
|
||||
|
||||
|
||||
@lru_cache(maxsize=100) # NOTE. need to ensure params of infer() hashable
|
||||
@lru_cache(maxsize=1000) # NOTE. need to ensure params of infer() hashable
|
||||
@gpu_decorator
|
||||
def infer(
|
||||
ref_audio_orig,
|
||||
@@ -189,28 +191,24 @@ def infer(
|
||||
|
||||
# Remove silence
|
||||
if remove_silence:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
||||
sf.write(f.name, final_wave, final_sample_rate)
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
|
||||
temp_path = f.name
|
||||
try:
|
||||
sf.write(temp_path, final_wave, final_sample_rate)
|
||||
remove_silence_for_generated_wav(f.name)
|
||||
final_wave, _ = torchaudio.load(f.name)
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
final_wave = final_wave.squeeze().cpu().numpy()
|
||||
|
||||
# Save the spectrogram
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram:
|
||||
spectrogram_path = tmp_spectrogram.name
|
||||
save_spectrogram(combined_spectrogram, spectrogram_path)
|
||||
save_spectrogram(combined_spectrogram, spectrogram_path)
|
||||
|
||||
return (final_sample_rate, final_wave), spectrogram_path, ref_text, used_seed
|
||||
|
||||
|
||||
with gr.Blocks() as app_credits:
|
||||
gr.Markdown("""
|
||||
# Credits
|
||||
|
||||
* [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
|
||||
* [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
|
||||
* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
|
||||
""")
|
||||
with gr.Blocks() as app_tts:
|
||||
gr.Markdown("# Batched TTS")
|
||||
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
|
||||
@@ -314,6 +312,12 @@ with gr.Blocks() as app_tts:
|
||||
outputs=[ref_text_input],
|
||||
)
|
||||
|
||||
ref_audio_input.clear(
|
||||
lambda: [None, None],
|
||||
None,
|
||||
[ref_text_input, ref_text_file],
|
||||
)
|
||||
|
||||
generate_btn.click(
|
||||
basic_tts,
|
||||
inputs=[
|
||||
@@ -926,6 +930,16 @@ Have a conversation with an AI using your reference voice!
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks() as app_credits:
|
||||
gr.Markdown("""
|
||||
# Credits
|
||||
|
||||
* [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
|
||||
* [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
|
||||
* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
|
||||
""")
|
||||
|
||||
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown(
|
||||
f"""
|
||||
|
||||
@@ -33,6 +33,7 @@ from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
||||
|
||||
|
||||
_ref_audio_cache = {}
|
||||
_ref_text_cache = {}
|
||||
|
||||
device = (
|
||||
"cuda"
|
||||
@@ -44,6 +45,8 @@ device = (
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
tempfile_kwargs = {"delete_on_close": False} if sys.version_info >= (3, 12) else {"delete": False}
|
||||
|
||||
# -----------------------------------------
|
||||
|
||||
target_sample_rate = 24000
|
||||
@@ -290,62 +293,74 @@ def remove_silence_edges(audio, silence_threshold=-42):
|
||||
# preprocess reference audio and text
|
||||
|
||||
|
||||
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
|
||||
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
|
||||
show_info("Converting audio...")
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
||||
|
||||
# Compute a hash of the reference audio file
|
||||
with open(ref_audio_orig, "rb") as audio_file:
|
||||
audio_data = audio_file.read()
|
||||
audio_hash = hashlib.md5(audio_data).hexdigest()
|
||||
|
||||
global _ref_audio_cache
|
||||
|
||||
if audio_hash in _ref_audio_cache:
|
||||
show_info("Using cached preprocessed reference audio...")
|
||||
ref_audio = _ref_audio_cache[audio_hash]
|
||||
|
||||
else: # first pass, do preprocess
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
|
||||
temp_path = f.name
|
||||
|
||||
aseg = AudioSegment.from_file(ref_audio_orig)
|
||||
|
||||
if clip_short:
|
||||
# 1. try to find long silence for clipping
|
||||
# 1. try to find long silence for clipping
|
||||
non_silent_segs = silence.split_on_silence(
|
||||
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
||||
show_info("Audio is over 12s, clipping short. (1)")
|
||||
break
|
||||
non_silent_wave += non_silent_seg
|
||||
|
||||
# 2. try to find short silence for clipping if 1. failed
|
||||
if len(non_silent_wave) > 12000:
|
||||
non_silent_segs = silence.split_on_silence(
|
||||
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
|
||||
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
||||
show_info("Audio is over 12s, clipping short. (1)")
|
||||
show_info("Audio is over 12s, clipping short. (2)")
|
||||
break
|
||||
non_silent_wave += non_silent_seg
|
||||
|
||||
# 2. try to find short silence for clipping if 1. failed
|
||||
if len(non_silent_wave) > 12000:
|
||||
non_silent_segs = silence.split_on_silence(
|
||||
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
||||
show_info("Audio is over 12s, clipping short. (2)")
|
||||
break
|
||||
non_silent_wave += non_silent_seg
|
||||
aseg = non_silent_wave
|
||||
|
||||
aseg = non_silent_wave
|
||||
|
||||
# 3. if no proper silence found for clipping
|
||||
if len(aseg) > 12000:
|
||||
aseg = aseg[:12000]
|
||||
show_info("Audio is over 12s, clipping short. (3)")
|
||||
# 3. if no proper silence found for clipping
|
||||
if len(aseg) > 12000:
|
||||
aseg = aseg[:12000]
|
||||
show_info("Audio is over 12s, clipping short. (3)")
|
||||
|
||||
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
|
||||
aseg.export(f.name, format="wav")
|
||||
ref_audio = f.name
|
||||
aseg.export(temp_path, format="wav")
|
||||
ref_audio = temp_path
|
||||
|
||||
# Compute a hash of the reference audio file
|
||||
with open(ref_audio, "rb") as audio_file:
|
||||
audio_data = audio_file.read()
|
||||
audio_hash = hashlib.md5(audio_data).hexdigest()
|
||||
# Cache the processed reference audio
|
||||
_ref_audio_cache[audio_hash] = ref_audio
|
||||
|
||||
if not ref_text.strip():
|
||||
global _ref_audio_cache
|
||||
if audio_hash in _ref_audio_cache:
|
||||
global _ref_text_cache
|
||||
if audio_hash in _ref_text_cache:
|
||||
# Use cached asr transcription
|
||||
show_info("Using cached reference text...")
|
||||
ref_text = _ref_audio_cache[audio_hash]
|
||||
ref_text = _ref_text_cache[audio_hash]
|
||||
else:
|
||||
show_info("No reference text provided, transcribing reference audio...")
|
||||
ref_text = transcribe(ref_audio)
|
||||
# Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
|
||||
_ref_audio_cache[audio_hash] = ref_text
|
||||
_ref_text_cache[audio_hash] = ref_text
|
||||
else:
|
||||
show_info("Using custom reference text...")
|
||||
|
||||
@@ -384,7 +399,7 @@ def infer_process(
|
||||
):
|
||||
# Split the input text into batches
|
||||
audio, sr = torchaudio.load(ref_audio)
|
||||
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
|
||||
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
|
||||
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
||||
for i, gen_text in enumerate(gen_text_batches):
|
||||
print(f"gen_text {i}", gen_text)
|
||||
|
||||
@@ -22,6 +22,7 @@ from f5_tts.model.modules import MelSpec
|
||||
from f5_tts.model.utils import (
|
||||
default,
|
||||
exists,
|
||||
get_epss_timesteps,
|
||||
lens_to_mask,
|
||||
list_str_to_idx,
|
||||
list_str_to_tensor,
|
||||
@@ -92,6 +93,7 @@ class CFM(nn.Module):
|
||||
seed: int | None = None,
|
||||
max_duration=4096,
|
||||
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
|
||||
use_epss=True,
|
||||
no_ref_audio=False,
|
||||
duplicate_test=False,
|
||||
t_inter=0.1,
|
||||
@@ -190,7 +192,10 @@ class CFM(nn.Module):
|
||||
y0 = (1 - t_start) * y0 + t_start * test_cond
|
||||
steps = int(steps * (1 - t_start))
|
||||
|
||||
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
|
||||
if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE
|
||||
t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
|
||||
else:
|
||||
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
|
||||
if sway_sampling_coef is not None:
|
||||
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
||||
|
||||
|
||||
@@ -189,3 +189,22 @@ def repetition_found(text, length=2, tolerance=10):
|
||||
if count > tolerance:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# get the empirically pruned step for sampling
|
||||
|
||||
|
||||
def get_epss_timesteps(n, device, dtype):
    """Return the sampling timesteps for Empirically Pruned Step Sampling.

    For the step counts listed in the table below, the schedule is a
    hand-picked subset of a 32-interval grid spanning [0, 1]. Any other
    ``n`` falls back to a uniform grid of ``n + 1`` points from 0 to 1.

    Args:
        n: number of sampling steps (the result holds ``n + 1`` timesteps).
        device: device on which the returned tensor is created.
        dtype: dtype of the returned tensor.

    Returns:
        A 1-D tensor of ``n + 1`` timesteps starting at 0 and ending at 1.
    """
    grid_step = 1 / 32
    # Keys are step counts; values are indices into the 32-interval grid.
    pruned_grid_indices = {
        5: [0, 2, 4, 8, 16, 32],
        6: [0, 2, 4, 6, 8, 16, 32],
        7: [0, 2, 4, 6, 8, 16, 24, 32],
        10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
        12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
        16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
    }

    if n in pruned_grid_indices:
        indices = torch.tensor(pruned_grid_indices[n], device=device, dtype=dtype)
        return grid_step * indices

    # No pruned schedule for this step count: use evenly spaced timesteps.
    return torch.linspace(0, 1, n + 1, device=device, dtype=dtype)
|
||||
|
||||
@@ -220,8 +220,8 @@ def get_args():
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_audio(wav_path, target_sample_rate=16000):
|
||||
assert target_sample_rate == 16000, "hard coding in server"
|
||||
def load_audio(wav_path, target_sample_rate=24000):
|
||||
assert target_sample_rate == 24000, "hard coding in server"
|
||||
if isinstance(wav_path, dict):
|
||||
waveform = wav_path["array"]
|
||||
sample_rate = wav_path["sampling_rate"]
|
||||
@@ -244,7 +244,7 @@ async def send(
|
||||
model_name: str,
|
||||
padding_duration: int = None,
|
||||
audio_save_dir: str = "./",
|
||||
save_sample_rate: int = 16000,
|
||||
save_sample_rate: int = 24000,
|
||||
):
|
||||
total_duration = 0.0
|
||||
latency_data = []
|
||||
@@ -254,7 +254,7 @@ async def send(
|
||||
for i, item in enumerate(manifest_item_list):
|
||||
if i % log_interval == 0:
|
||||
print(f"{name}: {i}/{len(manifest_item_list)}")
|
||||
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
|
||||
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=24000)
|
||||
duration = len(waveform) / sample_rate
|
||||
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
||||
|
||||
@@ -310,8 +310,9 @@ async def send(
|
||||
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
||||
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
|
||||
|
||||
latency_data.append((end, estimated_target_duration))
|
||||
total_duration += estimated_target_duration
|
||||
actual_duration = len(audio) / save_sample_rate
|
||||
latency_data.append((end, actual_duration))
|
||||
total_duration += actual_duration
|
||||
|
||||
return total_duration, latency_data
|
||||
|
||||
@@ -416,7 +417,7 @@ async def main():
|
||||
model_name=args.model_name,
|
||||
audio_save_dir=args.log_dir,
|
||||
padding_duration=1,
|
||||
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
|
||||
save_sample_rate=24000,
|
||||
)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
@@ -82,7 +82,7 @@ def prepare_request(
|
||||
samples,
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate=16000,
|
||||
sample_rate=24000,
|
||||
audio_save_dir: str = "./",
|
||||
):
|
||||
assert len(samples.shape) == 1, "samples should be 1D"
|
||||
@@ -106,8 +106,8 @@ def prepare_request(
|
||||
return data
|
||||
|
||||
|
||||
def load_audio(wav_path, target_sample_rate=16000):
|
||||
assert target_sample_rate == 16000, "hard coding in server"
|
||||
def load_audio(wav_path, target_sample_rate=24000):
|
||||
assert target_sample_rate == 24000, "hard coding in server"
|
||||
if isinstance(wav_path, dict):
|
||||
samples = wav_path["array"]
|
||||
sample_rate = wav_path["sampling_rate"]
|
||||
@@ -129,7 +129,7 @@ if __name__ == "__main__":
|
||||
|
||||
url = f"{server_url}/v2/models/{args.model_name}/infer"
|
||||
samples, sr = load_audio(args.reference_audio)
|
||||
assert sr == 16000, "sample rate hardcoded in server"
|
||||
assert sr == 24000, "sample rate hardcoded in server"
|
||||
|
||||
samples = np.array(samples, dtype=np.float32)
|
||||
data = prepare_request(samples, args.reference_text, args.target_text)
|
||||
|
||||
@@ -33,7 +33,7 @@ parameters [
|
||||
},
|
||||
{
|
||||
key: "reference_audio_sample_rate",
|
||||
value: {string_value:"16000"}
|
||||
value: {string_value:"24000"}
|
||||
},
|
||||
{
|
||||
key: "vocoder",
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
# Training
|
||||
|
||||
Check your FFmpeg installation:
|
||||
```bash
|
||||
ffmpeg -version
|
||||
```
|
||||
If not found, install it first (or skip this step if you know another backend is available).
|
||||
|
||||
## Prepare Dataset
|
||||
|
||||
Example data processing scripts are provided; you may tailor your own along with a Dataset class in `src/f5_tts/model/dataset.py`.
|
||||
|
||||
@@ -434,7 +434,7 @@ def start_training(
|
||||
fp16 = ""
|
||||
|
||||
cmd = (
|
||||
f"accelerate launch {fp16} {file_train} --exp_name {exp_name}"
|
||||
f'accelerate launch {fp16} "{file_train}" --exp_name {exp_name}'
|
||||
f" --learning_rate {learning_rate}"
|
||||
f" --batch_size_per_gpu {batch_size_per_gpu}"
|
||||
f" --batch_size_type {batch_size_type}"
|
||||
@@ -453,7 +453,7 @@ def start_training(
|
||||
cmd += " --finetune"
|
||||
|
||||
if file_checkpoint_train != "":
|
||||
cmd += f" --pretrain {file_checkpoint_train}"
|
||||
cmd += f' --pretrain "{file_checkpoint_train}"'
|
||||
|
||||
if tokenizer_file != "":
|
||||
cmd += f" --tokenizer_path {tokenizer_file}"
|
||||
@@ -1099,7 +1099,7 @@ def vocab_extend(project_name, symbols, model_type):
|
||||
return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {vocab_size_new}\nnew symbols :\n{vocab_new}"
|
||||
|
||||
|
||||
def vocab_check(project_name):
|
||||
def vocab_check(project_name, tokenizer_type):
|
||||
name_project = project_name
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
|
||||
@@ -1128,6 +1128,8 @@ def vocab_check(project_name):
|
||||
continue
|
||||
|
||||
text = sp[1].lower().strip()
|
||||
if tokenizer_type == "pinyin":
|
||||
text = convert_char_to_pinyin([text], polyphone=True)[0]
|
||||
|
||||
for t in text:
|
||||
if t not in vocab and t not in miss_symbols_keep:
|
||||
@@ -1498,7 +1500,9 @@ Using the extended model, you can finetune to a new language that is missing sym
|
||||
txt_info_extend = gr.Textbox(label="Info", value="")
|
||||
|
||||
txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
|
||||
check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
|
||||
check_button.click(
|
||||
fn=vocab_check, inputs=[cm_project, tokenizer_type], outputs=[txt_info_check, txt_extend]
|
||||
)
|
||||
extend_button.click(
|
||||
fn=vocab_extend, inputs=[cm_project, txt_extend, exp_name_extend], outputs=[txt_info_extend]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user