mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-25 20:34:27 -08:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | 7e4985ca56 | | |
| | f05ceda4cb | | |
| | 2bd39dd813 | | |
| | f017815083 | | |
| | 297755fac3 | | |
| | d05075205f | | |
| | 8722cf0766 | | |
| | 48d1a9312e | | |
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "f5-tts"
|
||||
version = "1.0.0"
|
||||
version = "1.0.3"
|
||||
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT License"}
|
||||
|
||||
@@ -74,8 +74,6 @@ class F5TTS:
|
||||
elif model == "E2TTS_Base":
|
||||
repo_name = "E2-TTS"
|
||||
ckpt_step = 1200000
|
||||
else:
|
||||
raise ValueError(f"Unknown model type: {model}")
|
||||
|
||||
if not ckpt_file:
|
||||
ckpt_file = str(
|
||||
@@ -117,8 +115,9 @@ class F5TTS:
|
||||
seed=None,
|
||||
):
|
||||
if seed is None:
|
||||
self.seed = random.randint(0, sys.maxsize)
|
||||
seed_everything(self.seed)
|
||||
seed = random.randint(0, sys.maxsize)
|
||||
seed_everything(seed)
|
||||
self.seed = seed
|
||||
|
||||
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
|
||||
|
||||
|
||||
@@ -479,14 +479,15 @@ def infer_batch_process(
|
||||
cfg_strength=cfg_strength,
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
)
|
||||
del _
|
||||
|
||||
generated = generated.to(torch.float32)
|
||||
generated = generated.to(torch.float32) # generated mel spectrogram
|
||||
generated = generated[:, ref_audio_len:, :]
|
||||
generated_mel_spec = generated.permute(0, 2, 1)
|
||||
generated = generated.permute(0, 2, 1)
|
||||
if mel_spec_type == "vocos":
|
||||
generated_wave = vocoder.decode(generated_mel_spec)
|
||||
generated_wave = vocoder.decode(generated)
|
||||
elif mel_spec_type == "bigvgan":
|
||||
generated_wave = vocoder(generated_mel_spec)
|
||||
generated_wave = vocoder(generated)
|
||||
if rms < target_rms:
|
||||
generated_wave = generated_wave * rms / target_rms
|
||||
|
||||
@@ -497,7 +498,9 @@ def infer_batch_process(
|
||||
for j in range(0, len(generated_wave), chunk_size):
|
||||
yield generated_wave[j : j + chunk_size], target_sample_rate
|
||||
else:
|
||||
yield generated_wave, generated_mel_spec[0].cpu().numpy()
|
||||
generated_cpu = generated[0].cpu().numpy()
|
||||
del generated
|
||||
yield generated_wave, generated_cpu
|
||||
|
||||
if streaming:
|
||||
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
|
||||
|
||||
@@ -219,7 +219,8 @@ class DiT(nn.Module):
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if self.checkpoint_activations:
|
||||
x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
|
||||
# https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
|
||||
x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
|
||||
else:
|
||||
x = block(x, t, mask=mask, rope=rope)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user