mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-12 15:50:07 -08:00
Refactor imports and improve code formatting in dataset and trainer modules
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import random
|
||||
from importlib.resources import files
|
||||
|
||||
import torch
|
||||
|
||||
@@ -279,11 +279,11 @@ class Trainer:
|
||||
self.accelerator.even_batches = False
|
||||
sampler = SequentialSampler(train_dataset)
|
||||
batch_sampler = DynamicBatchSampler(
|
||||
sampler,
|
||||
self.batch_size,
|
||||
max_samples=self.max_samples,
|
||||
sampler,
|
||||
self.batch_size,
|
||||
max_samples=self.max_samples,
|
||||
random_seed=resumable_with_seed, # This enables reproducible shuffling
|
||||
drop_last=False
|
||||
drop_last=False,
|
||||
)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
@@ -334,7 +334,7 @@ class Trainer:
|
||||
current_dataloader = train_dataloader
|
||||
|
||||
# Set epoch for the batch sampler if it exists
|
||||
if hasattr(train_dataloader, 'batch_sampler') and hasattr(train_dataloader.batch_sampler, 'set_epoch'):
|
||||
if hasattr(train_dataloader, "batch_sampler") and hasattr(train_dataloader.batch_sampler, "set_epoch"):
|
||||
train_dataloader.batch_sampler.set_epoch(epoch)
|
||||
|
||||
progress_bar = tqdm(
|
||||
|
||||
Reference in New Issue
Block a user