From 605fa13b42b40e860961bac8ce30fe49f02dfa0d Mon Sep 17 00:00:00 2001
From: Jim <121858197+jneuendorf-i4h@users.noreply.github.com>
Date: Tue, 22 Jul 2025 13:38:44 +0200
Subject: [PATCH] Fix raw.arrow missing rows (#1145)

* fix raw.arrow missing rows

---------

Co-authored-by: SWivid
---
 src/f5_tts/train/datasets/prepare_csv_wavs.py  | 4 ++--
 src/f5_tts/train/datasets/prepare_emilia.py    | 1 +
 src/f5_tts/train/datasets/prepare_emilia_v2.py | 1 +
 src/f5_tts/train/datasets/prepare_libritts.py  | 1 +
 src/f5_tts/train/datasets/prepare_ljspeech.py  | 1 +
 src/f5_tts/train/finetune_gradio.py            | 3 ++-
 6 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/f5_tts/train/datasets/prepare_csv_wavs.py b/src/f5_tts/train/datasets/prepare_csv_wavs.py
index 26ad6f8..4717b8b 100644
--- a/src/f5_tts/train/datasets/prepare_csv_wavs.py
+++ b/src/f5_tts/train/datasets/prepare_csv_wavs.py
@@ -208,11 +208,11 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
     out_dir.mkdir(exist_ok=True, parents=True)
     print(f"\nSaving to {out_dir} ...")
 
-    # Save dataset with improved batch size for better I/O performance
     raw_arrow_path = out_dir / "raw.arrow"
-    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
+    with ArrowWriter(path=raw_arrow_path.as_posix()) as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
+        writer.finalize()
 
     # Save durations to JSON
     dur_json_path = out_dir / "duration.json"
diff --git a/src/f5_tts/train/datasets/prepare_emilia.py b/src/f5_tts/train/datasets/prepare_emilia.py
index 4c4a771..7c6b805 100644
--- a/src/f5_tts/train/datasets/prepare_emilia.py
+++ b/src/f5_tts/train/datasets/prepare_emilia.py
@@ -181,6 +181,7 @@ def main():
     with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
+        writer.finalize()
 
     # dup a json separately saving duration in case for DynamicBatchSampler ease
     with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
diff --git a/src/f5_tts/train/datasets/prepare_emilia_v2.py b/src/f5_tts/train/datasets/prepare_emilia_v2.py
index 50322c0..e839412 100644
--- a/src/f5_tts/train/datasets/prepare_emilia_v2.py
+++ b/src/f5_tts/train/datasets/prepare_emilia_v2.py
@@ -68,6 +68,7 @@ def main():
     with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
+        writer.finalize()
 
     with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
         json.dump({"duration": duration_list}, f, ensure_ascii=False)
diff --git a/src/f5_tts/train/datasets/prepare_libritts.py b/src/f5_tts/train/datasets/prepare_libritts.py
index a892dd6..0a11eb0 100644
--- a/src/f5_tts/train/datasets/prepare_libritts.py
+++ b/src/f5_tts/train/datasets/prepare_libritts.py
@@ -62,6 +62,7 @@ def main():
     with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
+        writer.finalize()
 
     # dup a json separately saving duration in case for DynamicBatchSampler ease
     with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
diff --git a/src/f5_tts/train/datasets/prepare_ljspeech.py b/src/f5_tts/train/datasets/prepare_ljspeech.py
index 9f64b0a..4a60e39 100644
--- a/src/f5_tts/train/datasets/prepare_ljspeech.py
+++ b/src/f5_tts/train/datasets/prepare_ljspeech.py
@@ -39,6 +39,7 @@ def main():
     with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
+        writer.finalize()
 
     # dup a json separately saving duration in case for DynamicBatchSampler ease
     with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py
index eee2a3f..692c95c 100644
--- a/src/f5_tts/train/finetune_gradio.py
+++ b/src/f5_tts/train/finetune_gradio.py
@@ -796,9 +796,10 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
     min_second = round(min(duration_list), 2)
     max_second = round(max(duration_list), 2)
 
-    with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
+    with ArrowWriter(path=file_raw) as writer:
         for line in progress.tqdm(result, total=len(result), desc="prepare data"):
             writer.write(line)
+        writer.finalize()
 
     with open(file_duration, "w") as f:
         json.dump({"duration": duration_list}, f, ensure_ascii=False)