fix typo in trainer.py with 4ae5347282 formatting #909

This commit is contained in:
SWivid
2025-03-25 16:17:03 +08:00
parent b9156c0ad5
commit 6b7f6eefdc

View File

@@ -51,7 +51,7 @@ class Trainer:
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
is_local_vocoder: bool = False, # use local path vocoder
local_vocoder_path: str = "", # local vocoder path
cfg_dict: dict = dict(), # training config
model_cfg_dict: dict = dict(), # training config
):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
@@ -73,8 +73,8 @@ class Trainer:
else:
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
if not cfg_dict:
cfg_dict = {
if not model_cfg_dict:
model_cfg_dict = {
"epochs": epochs,
"learning_rate": learning_rate,
"num_warmup_updates": num_warmup_updates,
@@ -85,11 +85,11 @@ class Trainer:
"max_grad_norm": max_grad_norm,
"noise_scheduler": noise_scheduler,
}
cfg_dict["gpus"] = self.accelerator.num_processes
model_cfg_dict["gpus"] = self.accelerator.num_processes
self.accelerator.init_trackers(
project_name=wandb_project,
init_kwargs=init_kwargs,
config=cfg_dict,
config=model_cfg_dict,
)
elif self.logger == "tensorboard":