
Commit 3f05548

Added save and load of the scaler's state for FP16 training, to avoid a default re-initialization on resume
1 parent: 932a5dd

2 files changed: 10 additions & 3 deletions
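
For context, a torch.amp.GradScaler carries very little state: its state_dict() is a handful of scalars tracking the current loss scale and its growth schedule, and that is exactly what this commit starts persisting. A quick illustration (fresh scaler on a CUDA machine; the values shown are the library defaults):

import torch

scaler = torch.amp.GradScaler(enabled=True)
print(scaler.state_dict())
# {'scale': 65536.0, 'growth_factor': 2.0, 'backoff_factor': 0.5,
#  'growth_interval': 2000, '_growth_tracker': 0}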


rvc/train/train.py

Lines changed: 7 additions & 2 deletions
@@ -473,12 +473,13 @@ def run(
         print("Using Float16 for training.")
 
     # Load checkpoint if available
+    scaler_dict = {}
     try:
         print("Starting training...")
-        _, _, _, epoch_str = load_checkpoint(
+        _, _, _, epoch_str, scaler_dict = load_checkpoint(
             latest_checkpoint_path(experiment_dir, "D_*.pth"), net_d, optim_d
         )
-        _, _, _, epoch_str = load_checkpoint(
+        _, _, _, epoch_str, _ = load_checkpoint(
             latest_checkpoint_path(experiment_dir, "G_*.pth"), net_g, optim_g
         )
         epoch_str += 1
@@ -536,6 +537,8 @@ def run(
 
     use_scaler = device.type == "cuda" and train_dtype == torch.float16
     scaler = torch.amp.GradScaler(enabled=use_scaler)
+    if len(scaler_dict) > 0:
+        scaler.load_state_dict(scaler_dict)
 
     cache = []
     # collect the reference audio for tensorboard evaluation
@@ -1007,13 +1010,15 @@ def train_and_evaluate(
                 config.train.learning_rate,
                 epoch,
                 os.path.join(experiment_dir, "G_" + checkpoint_suffix),
+                scaler,
             )
             save_checkpoint(
                 net_d,
                 optim_d,
                 config.train.learning_rate,
                 epoch,
                 os.path.join(experiment_dir, "D_" + checkpoint_suffix),
+                scaler,
             )
             if custom_save_every_weights:
                 model_add.append(
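
This is the point of the scaler_dict plumbing above: torch.amp.GradScaler calibrates its loss scale as training runs, and a freshly constructed scaler restarts at the default init_scale of 65536.0, which tends to overflow FP16 gradients and skip the first optimizer steps after a resume. Note also that GradScaler.load_state_dict raises a RuntimeError on an empty dict, so the len(scaler_dict) > 0 guard protects old checkpoints that have no scaler entry. A minimal, self-contained sketch of the resume pattern, with a placeholder model and checkpoint path rather than the project's actual objects:

import torch

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters())
scaler = torch.amp.GradScaler(enabled=True)

# Resume: restore the calibrated loss scale instead of the default one.
# Pre-commit checkpoints have no "scaler" key and yield {}, which must not
# be passed to load_state_dict (it raises), so the guard skips it.
checkpoint = torch.load("checkpoint.pth", map_location="cpu")
scaler_dict = checkpoint.get("scaler", {})
if len(scaler_dict) > 0:
    scaler.load_state_dict(scaler_dict)

# One FP16 training step under autocast.
x, y = torch.randn(4, 10).cuda(), torch.randn(4, 1).cuda()
optimizer.zero_grad()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()  # scale the loss so FP16 grads stay in range
scaler.step(optimizer)         # unscales grads; skips the step on inf/nan
scaler.update()                # grows or backs off the loss scale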

rvc/train/utils.py

Lines changed: 3 additions & 1 deletion
@@ -79,10 +79,11 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
         optimizer,
         checkpoint_dict.get("learning_rate", 0),
         checkpoint_dict["iteration"],
+        checkpoint_dict.get("scaler", {})
     )
 
 
-def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
+def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path, scaler):
     """
     Save the model and optimizer state to a checkpoint file.
 
@@ -101,6 +102,7 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
         "iteration": iteration,
         "optimizer": optimizer.state_dict(),
         "learning_rate": learning_rate,
+        "scaler": scaler.state_dict(),
     }
 
     # Create a backwards-compatible checkpoint
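
A detail worth noting in load_checkpoint: checkpoint_dict.get("scaler", {}) keeps checkpoints written before this commit loadable, and a disabled GradScaler (CPU or non-FP16 training, where use_scaler is False) serializes to an empty dict anyway, so the same guard in train.py covers both cases. A small round-trip sketch; the file name and modules are illustrative, not the project's own:

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler(enabled=False)  # disabled -> state_dict() == {}

torch.save(
    {
        "model": model.state_dict(),
        "iteration": 42,
        "optimizer": optimizer.state_dict(),
        "learning_rate": 0.1,
        "scaler": scaler.state_dict(),  # the key this commit adds
    },
    "demo_checkpoint.pth",
)

ckpt = torch.load("demo_checkpoint.pth")
scaler_dict = ckpt.get("scaler", {})  # also {} for pre-commit checkpoints
assert scaler_dict == {}              # nothing to restore; the load is skipped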
