Merge pull request #765 from hcsolakoglu/dynbatchsampler-epoch-shuffle

Add Per-Epoch Batch Shuffling to DynamicBatchSampler
Yushen CHEN, 2025-02-05 15:10:08 +08:00 (committed by GitHub)

2 changed files with 26 additions and 10 deletions


@@ -1,5 +1,4 @@
 import json
-import random
 from importlib.resources import files
 
 import torch
@@ -170,6 +169,7 @@ class DynamicBatchSampler(Sampler[list[int]]):
         in a batch to ensure that the total number of frames are less
         than a certain threshold.
     2. Make sure the padding efficiency in the batch is high.
+    3. Shuffle batches each epoch while maintaining reproducibility.
     """
 
     def __init__(
@@ -178,6 +178,8 @@ class DynamicBatchSampler(Sampler[list[int]]):
         self.sampler = sampler
         self.frames_threshold = frames_threshold
         self.max_samples = max_samples
+        self.random_seed = random_seed
+        self.epoch = 0
 
         indices, batches = [], []
         data_source = self.sampler.data_source
@@ -210,17 +212,23 @@ class DynamicBatchSampler(Sampler[list[int]]):
                 batches.append(batch)
 
         del indices
 
-        # if want to have different batches between epochs, may just set a seed and log it in ckpt
-        # cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different
-        # e.g. for epoch n, use (random_seed + n)
-        random.seed(random_seed)
-        random.shuffle(batches)
-
         self.batches = batches
 
+    def set_epoch(self, epoch: int) -> None:
+        """Sets the epoch for this sampler."""
+        self.epoch = epoch
+
     def __iter__(self):
-        return iter(self.batches)
+        # Use both random_seed and epoch for deterministic but different shuffling per epoch
+        if self.random_seed is not None:
+            g = torch.Generator()
+            g.manual_seed(self.random_seed + self.epoch)
+            # Use PyTorch's random permutation for better reproducibility across PyTorch versions
+            indices = torch.randperm(len(self.batches), generator=g).tolist()
+            batches = [self.batches[i] for i in indices]
+        else:
+            batches = self.batches
+        return iter(batches)
 
     def __len__(self):
         return len(self.batches)
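
Taken on its own, the new __iter__ logic amounts to: seed a torch.Generator with (random_seed + epoch), draw one permutation with torch.randperm, and replay the precomputed batches in that order. The result is reproducible for a given (seed, epoch) pair yet different from epoch to epoch. A minimal standalone sketch of that behavior (the helper name shuffled_order and the toy numbers are illustrative, not part of the diff):

    import torch

    def shuffled_order(num_batches: int, random_seed: int, epoch: int) -> list[int]:
        # Mirrors the added __iter__ logic: one Generator seeded per
        # (seed, epoch) pair, so re-running an epoch reproduces its order.
        g = torch.Generator()
        g.manual_seed(random_seed + epoch)
        return torch.randperm(num_batches, generator=g).tolist()

    print(shuffled_order(8, 42, 0))  # identical on every run for epoch 0
    print(shuffled_order(8, 42, 1))  # a different, but equally reproducible, order

Unlike the replaced random.seed/random.shuffle approach, which fixed a single order at construction time and mutated global random state, the generator here is local and the permutation is re-derived on every pass.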


@@ -279,7 +279,11 @@ class Trainer:
             self.accelerator.even_batches = False
             sampler = SequentialSampler(train_dataset)
             batch_sampler = DynamicBatchSampler(
-                sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
+                sampler,
+                self.batch_size,
+                max_samples=self.max_samples,
+                random_seed=resumable_with_seed,  # This enables reproducible shuffling
+                drop_last=False,
             )
             train_dataloader = DataLoader(
                 train_dataset,
@@ -329,6 +333,10 @@ class Trainer:
                 progress_bar_initial = 0
                 current_dataloader = train_dataloader
 
+            # Set epoch for the batch sampler if it exists
+            if hasattr(train_dataloader, "batch_sampler") and hasattr(train_dataloader.batch_sampler, "set_epoch"):
+                train_dataloader.batch_sampler.set_epoch(epoch)
+
             progress_bar = tqdm(
                 range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
                 desc=f"Epoch {epoch+1}/{self.epochs}",