mirror of https://github.com/SWivid/F5-TTS.git, synced 2025-12-12 15:50:07 -08:00
Merge pull request #765 from hcsolakoglu/dynbatchsampler-epoch-shuffle
Add Per-Epoch Batch Shuffling to DynamicBatchSampler
@@ -1,5 +1,4 @@
 import json
-import random
 from importlib.resources import files

 import torch
@@ -170,6 +169,7 @@ class DynamicBatchSampler(Sampler[list[int]]):
         in a batch to ensure that the total number of frames are less
         than a certain threshold.
     2. Make sure the padding efficiency in the batch is high.
+    3. Shuffle batches each epoch while maintaining reproducibility.
     """

     def __init__(
@@ -178,6 +178,8 @@ class DynamicBatchSampler(Sampler[list[int]]):
         self.sampler = sampler
         self.frames_threshold = frames_threshold
         self.max_samples = max_samples
+        self.random_seed = random_seed
+        self.epoch = 0

         indices, batches = [], []
         data_source = self.sampler.data_source
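For readers unfamiliar with the sampler being modified: the docstring above describes grouping sample indices so that each batch stays under a frame budget while keeping padding efficiency high. The toy sketch below is not the upstream implementation; greedy_frame_batches, frame_lengths, frames_threshold, and max_samples are illustrative stand-ins showing one way such frame-threshold batching can work.

# Minimal sketch (not the upstream code): greedily pack sample indices into
# batches so the total frame count per batch stays under a threshold.
def greedy_frame_batches(frame_lengths, frames_threshold, max_samples=0):
    batches, batch, batch_frames = [], [], 0
    # Visiting samples from shortest to longest keeps padding efficiency high
    # (goal 2 in the docstring above).
    for idx in sorted(range(len(frame_lengths)), key=lambda i: frame_lengths[i]):
        length = frame_lengths[idx]
        fits = batch_frames + length <= frames_threshold
        under_cap = max_samples == 0 or len(batch) < max_samples
        if batch and not (fits and under_cap):
            batches.append(batch)
            batch, batch_frames = [], 0
        batch.append(idx)
        batch_frames += length
    if batch:
        batches.append(batch)
    return batches

print(greedy_frame_batches([30, 10, 50, 20, 40], frames_threshold=60))
# -> [[1, 3, 0], [4], [2]]  # each batch's total frames stays <= 60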
@@ -210,17 +212,23 @@ class DynamicBatchSampler(Sampler[list[int]]):
                 batches.append(batch)

         del indices

-        # if want to have different batches between epochs, may just set a seed and log it in ckpt
-        # cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different
-        # e.g. for epoch n, use (random_seed + n)
-        random.seed(random_seed)
-        random.shuffle(batches)
-
         self.batches = batches

+    def set_epoch(self, epoch: int) -> None:
+        """Sets the epoch for this sampler."""
+        self.epoch = epoch
+
     def __iter__(self):
-        return iter(self.batches)
+        # Use both random_seed and epoch for deterministic but different shuffling per epoch
+        if self.random_seed is not None:
+            g = torch.Generator()
+            g.manual_seed(self.random_seed + self.epoch)
+            # Use PyTorch's random permutation for better reproducibility across PyTorch versions
+            indices = torch.randperm(len(self.batches), generator=g).tolist()
+            batches = [self.batches[i] for i in indices]
+        else:
+            batches = self.batches
+        return iter(batches)

     def __len__(self):
         return len(self.batches)
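The new __iter__ derives the batch order from random_seed + epoch via a dedicated torch.Generator, so each epoch gets a different order that is still reproducible for a given seed and epoch, while random_seed=None falls back to the order built in __init__. Below is a small standalone check of that property, a sketch of the technique rather than the sampler itself, assuming only torch; epoch_order is an illustrative helper, not part of the codebase.

import torch

def epoch_order(num_batches, seed, epoch):
    # Same technique as the new __iter__: seed a dedicated generator with
    # seed + epoch and take a random permutation of the batch indices.
    g = torch.Generator()
    g.manual_seed(seed + epoch)
    return torch.randperm(num_batches, generator=g).tolist()

seed = 666
# Reproducible: the same seed and epoch always give the same order.
assert epoch_order(8, seed, epoch=0) == epoch_order(8, seed, epoch=0)
# Reshuffled per epoch: different epochs give a different order (with overwhelming probability).
assert epoch_order(8, seed, epoch=0) != epoch_order(8, seed, epoch=1)
print(epoch_order(8, seed, 0), epoch_order(8, seed, 1))

Using a local Generator also leaves the global torch RNG state untouched, so the shuffle does not interfere with other randomness in training.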
@@ -279,7 +279,11 @@ class Trainer:
             self.accelerator.even_batches = False
             sampler = SequentialSampler(train_dataset)
             batch_sampler = DynamicBatchSampler(
-                sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
+                sampler,
+                self.batch_size,
+                max_samples=self.max_samples,
+                random_seed=resumable_with_seed,  # This enables reproducible shuffling
+                drop_last=False,
             )
             train_dataloader = DataLoader(
                 train_dataset,
@@ -329,6 +333,10 @@ class Trainer:
                 progress_bar_initial = 0
                 current_dataloader = train_dataloader

+            # Set epoch for the batch sampler if it exists
+            if hasattr(train_dataloader, "batch_sampler") and hasattr(train_dataloader.batch_sampler, "set_epoch"):
+                train_dataloader.batch_sampler.set_epoch(epoch)
+
             progress_bar = tqdm(
                 range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
                 desc=f"Epoch {epoch+1}/{self.epochs}",
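On the trainer side the hook is duck-typed: the epoch is propagated only when the dataloader's batch sampler actually exposes set_epoch, and it happens once per epoch before iteration starts. The sketch below shows that wiring with a hypothetical TinyBatchSampler and toy RangeDataset; it is not the F5-TTS Trainer, only an illustration of the same guard and per-epoch call.

import torch
from torch.utils.data import DataLoader, Dataset

class RangeDataset(Dataset):
    """Toy dataset: item i is just the integer i."""
    def __len__(self):
        return 10

    def __getitem__(self, i):
        return i

class TinyBatchSampler:
    """Illustrative stand-in for DynamicBatchSampler: fixed batches of two,
    reshuffled per epoch with the same seed + epoch scheme as the diff above."""
    def __init__(self, num_samples, seed=666):
        self.batches = [[i, i + 1] for i in range(0, num_samples, 2)]
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        order = torch.randperm(len(self.batches), generator=g).tolist()
        return iter([self.batches[i] for i in order])

    def __len__(self):
        return len(self.batches)

dataset = RangeDataset()
loader = DataLoader(dataset, batch_sampler=TinyBatchSampler(len(dataset)))

for epoch in range(2):
    # Same guard as the Trainer change: only call set_epoch if the sampler supports it.
    if hasattr(loader, "batch_sampler") and hasattr(loader.batch_sampler, "set_epoch"):
        loader.batch_sampler.set_epoch(epoch)
    print(epoch, [batch.tolist() for batch in loader])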