diff --git a/src/f5_tts/train/datasets/prepare_csv_wavs.py b/src/f5_tts/train/datasets/prepare_csv_wavs.py
index 26ad6f8..4717b8b 100644
--- a/src/f5_tts/train/datasets/prepare_csv_wavs.py
+++ b/src/f5_tts/train/datasets/prepare_csv_wavs.py
@@ -208,11 +208,11 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
     out_dir.mkdir(exist_ok=True, parents=True)
     print(f"\nSaving to {out_dir} ...")
 
-    # Save dataset with improved batch size for better I/O performance
     raw_arrow_path = out_dir / "raw.arrow"
-    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
+    with ArrowWriter(path=raw_arrow_path.as_posix()) as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
+        writer.finalize()
 
     # Save durations to JSON
     dur_json_path = out_dir / "duration.json"
diff --git a/src/f5_tts/train/datasets/prepare_emilia.py b/src/f5_tts/train/datasets/prepare_emilia.py
index 4c4a771..7c6b805 100644
--- a/src/f5_tts/train/datasets/prepare_emilia.py
+++ b/src/f5_tts/train/datasets/prepare_emilia.py
@@ -181,6 +181,7 @@ def main():
     with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
+        writer.finalize()
 
     # dup a json separately saving duration in case for DynamicBatchSampler ease
     with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
diff --git a/src/f5_tts/train/datasets/prepare_emilia_v2.py b/src/f5_tts/train/datasets/prepare_emilia_v2.py
index 50322c0..e839412 100644
--- a/src/f5_tts/train/datasets/prepare_emilia_v2.py
+++ b/src/f5_tts/train/datasets/prepare_emilia_v2.py
@@ -68,6 +68,7 @@ def main():
     with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
+        writer.finalize()
 
     with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
         json.dump({"duration": duration_list}, f, ensure_ascii=False)
diff --git a/src/f5_tts/train/datasets/prepare_libritts.py b/src/f5_tts/train/datasets/prepare_libritts.py
index a892dd6..0a11eb0 100644
--- a/src/f5_tts/train/datasets/prepare_libritts.py
+++ b/src/f5_tts/train/datasets/prepare_libritts.py
@@ -62,6 +62,7 @@ def main():
     with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
+        writer.finalize()
 
     # dup a json separately saving duration in case for DynamicBatchSampler ease
     with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
diff --git a/src/f5_tts/train/datasets/prepare_ljspeech.py b/src/f5_tts/train/datasets/prepare_ljspeech.py
index 9f64b0a..4a60e39 100644
--- a/src/f5_tts/train/datasets/prepare_ljspeech.py
+++ b/src/f5_tts/train/datasets/prepare_ljspeech.py
@@ -39,6 +39,7 @@ def main():
     with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
+        writer.finalize()
 
     # dup a json separately saving duration in case for DynamicBatchSampler ease
     with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py
index eee2a3f..692c95c 100644
--- a/src/f5_tts/train/finetune_gradio.py
+++ b/src/f5_tts/train/finetune_gradio.py
@@ -796,9 +796,10 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
     min_second = round(min(duration_list), 2)
     max_second = round(max(duration_list), 2)
 
-    with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
+    with ArrowWriter(path=file_raw) as writer:
         for line in progress.tqdm(result, total=len(result), desc="prepare data"):
             writer.write(line)
+        writer.finalize()
 
     with open(file_duration, "w") as f:
         json.dump({"duration": duration_list}, f, ensure_ascii=False)