Refactor imports and improve code formatting in dataset and trainer modules

This commit is contained in:
Can
2025-02-04 22:20:42 +03:00
parent 93ae7d3fc8
commit 33e865120c
2 changed files with 5 additions and 6 deletions

View File

@@ -1,5 +1,4 @@
import json
import random
from importlib.resources import files
import torch

View File

@@ -279,11 +279,11 @@ class Trainer:
self.accelerator.even_batches = False
sampler = SequentialSampler(train_dataset)
batch_sampler = DynamicBatchSampler(
sampler,
self.batch_size,
max_samples=self.max_samples,
sampler,
self.batch_size,
max_samples=self.max_samples,
random_seed=resumable_with_seed, # This enables reproducible shuffling
drop_last=False
drop_last=False,
)
train_dataloader = DataLoader(
train_dataset,
@@ -334,7 +334,7 @@ class Trainer:
current_dataloader = train_dataloader
# Set epoch for the batch sampler if it exists
if hasattr(train_dataloader, 'batch_sampler') and hasattr(train_dataloader.batch_sampler, 'set_epoch'):
if hasattr(train_dataloader, "batch_sampler") and hasattr(train_dataloader.batch_sampler, "set_epoch"):
train_dataloader.batch_sampler.set_epoch(epoch)
progress_bar = tqdm(