diff --git a/src/f5_tts/infer/speech_edit.py b/src/f5_tts/infer/speech_edit.py
index fdeda9f..1b32656 100644
--- a/src/f5_tts/infer/speech_edit.py
+++ b/src/f5_tts/infer/speech_edit.py
@@ -89,6 +89,12 @@ fix_duration = [
 # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
 # fix_duration = None  # use origin text duration
 
+# audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav"
+# origin_text = "对,这就是我,万人敬仰的太乙真人。"
+# target_text = "对,这就是你,万人敬仰的李白金星。"
+# parts_to_edit = [[1.500, 2.784], [4.083, 6.760]]
+# fix_duration = [1.284, 2.677]
+
 # -------------------------------------------------#
 
 
@@ -138,28 +144,55 @@ if rms < target_rms:
 if sr != target_sample_rate:
     resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
     audio = resampler(audio)
-offset = 0
-audio_ = torch.zeros(1, 0)
-edit_mask = torch.zeros(1, 0, dtype=torch.bool)
+
+# Convert to mel spectrogram FIRST (on clean original audio)
+# This avoids boundary artifacts from mel windows straddling zeros and real audio
+audio = audio.to(device)
+with torch.inference_mode():
+    original_mel = model.mel_spec(audio)  # (batch, n_mel, n_frames)
+    original_mel = original_mel.permute(0, 2, 1)  # (batch, n_frames, n_mel)
+
+# Build mel_cond and edit_mask at FRAME level
+# Insert zero frames in mel domain instead of zero samples in wav domain
+offset_frame = 0
+mel_cond = torch.zeros(1, 0, n_mel_channels, device=device)
+edit_mask = torch.zeros(1, 0, dtype=torch.bool, device=device)
+fix_dur_list = fix_duration.copy() if fix_duration is not None else None
+
 for part in parts_to_edit:
     start, end = part
-    part_dur = end - start if fix_duration is None else fix_duration.pop(0)
-    part_dur = part_dur * target_sample_rate
-    start = start * target_sample_rate
-    audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
+    part_dur_sec = end - start if fix_dur_list is None else fix_dur_list.pop(0)
+
+    # Convert to frames (this is the authoritative unit)
+    start_frame = round(start * target_sample_rate / hop_length)
+    end_frame = round(end * target_sample_rate / hop_length)
+    part_dur_frames = round(part_dur_sec * target_sample_rate / hop_length)
+
+    # Number of frames for the kept (non-edited) region
+    keep_frames = start_frame - offset_frame
+
+    # Build mel_cond: original mel frames + zero frames for edit region
+    mel_cond = torch.cat(
+        (
+            mel_cond,
+            original_mel[:, offset_frame:start_frame, :],
+            torch.zeros(1, part_dur_frames, n_mel_channels, device=device),
+        ),
+        dim=1,
+    )
     edit_mask = torch.cat(
         (
             edit_mask,
-            torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
-            torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
+            torch.ones(1, keep_frames, dtype=torch.bool, device=device),
+            torch.zeros(1, part_dur_frames, dtype=torch.bool, device=device),
        ),
        dim=-1,
    )
-    offset = end * target_sample_rate
-audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1)
-edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
-audio = audio.to(device)
-edit_mask = edit_mask.to(device)
+    offset_frame = end_frame
+
+# Append remaining mel frames after last edit
+mel_cond = torch.cat((mel_cond, original_mel[:, offset_frame:, :]), dim=1)
+edit_mask = F.pad(edit_mask, (0, mel_cond.shape[1] - edit_mask.shape[-1]), value=True)
 
 # Text
 text_list = [target_text]
@@ -170,14 +203,13 @@ else:
 print(f"text  : {text_list}")
 print(f"pinyin: {final_text_list}")
 
-# Duration
-ref_audio_len = 0
-duration = audio.shape[-1] // hop_length
+# Duration - use mel_cond length (not raw audio length)
+duration = mel_cond.shape[1]
 
-# Inference
+# Inference - pass mel_cond directly (not wav)
 with torch.inference_mode():
     generated, trajectory = model.sample(
-        cond=audio,
+        cond=mel_cond,  # Now passing mel directly, not wav
         text=final_text_list,
         duration=duration,
         steps=nfe_step,
@@ -190,7 +222,6 @@ with torch.inference_mode():
 
     # Final result
     generated = generated.to(torch.float32)
-    generated = generated[:, ref_audio_len:, :]
     gen_mel_spec = generated.permute(0, 2, 1)
     if mel_spec_type == "vocos":
         generated_wave = vocoder.decode(gen_mel_spec).cpu()
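
Below is a minimal, self-contained sketch of the frame-level bookkeeping this patch introduces, decoupled from the model so the seconds-to-frames arithmetic and the edit_mask layout can be sanity-checked on their own. The helper name build_frame_edit_mask, the defaults sample_rate=24000 and hop_length=256, and the n_original_frames argument (standing in for original_mel.shape[1]) are illustrative assumptions, not part of the patch.

import torch
import torch.nn.functional as F


def build_frame_edit_mask(parts_to_edit, n_original_frames,
                          fix_duration=None, sample_rate=24000, hop_length=256):
    """Sketch only: True = keep the original mel frame, False = regenerate it."""
    def to_frames(sec):
        return round(sec * sample_rate / hop_length)

    fix = list(fix_duration) if fix_duration is not None else None

    edit_mask = torch.zeros(1, 0, dtype=torch.bool)
    offset_frame = 0
    for start, end in parts_to_edit:
        part_sec = (end - start) if fix is None else fix.pop(0)
        start_frame, end_frame = to_frames(start), to_frames(end)
        part_frames = to_frames(part_sec)  # frames the (possibly re-timed) edit will occupy

        edit_mask = torch.cat(
            (
                edit_mask,
                torch.ones(1, start_frame - offset_frame, dtype=torch.bool),  # kept region
                torch.zeros(1, part_frames, dtype=torch.bool),                # edited region
            ),
            dim=-1,
        )
        offset_frame = end_frame

    # The tail after the last edit is kept verbatim, matching mel_cond's trailing slice.
    return F.pad(edit_mask, (0, n_original_frames - offset_frame), value=True)


# Usage with the commented zh example from the patch: a ~6.9 s clip is roughly
# 648 frames at the assumed 24 kHz sample rate and hop length 256.
mask = build_frame_edit_mask(
    parts_to_edit=[[1.500, 2.784], [4.083, 6.760]],
    n_original_frames=648,
    fix_duration=[1.284, 2.677],
)
print(mask.shape)  # total conditioning length: kept + re-timed edit + tail frames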