Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 4cffac1

Browse files
committed
add script for fast choice of the best validation checkpoint
1 parent 8d77e1b commit 4cffac1

5 files changed

Lines changed: 205 additions & 7 deletions

File tree

test_all_checkpoints.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from utils.tensorboard import TensorboardWriter
1919

20-
from utils.dataset import test_dataloader
20+
from utils.dataset import test_dataloader, eval_dataloader
2121

2222
from utils.generic_utils import validation, PowerLaw_Compressed_Loss, SiSNR_With_Pit
2323

@@ -105,8 +105,8 @@ def test(args, log_dir, checkpoint_path, testloader, tensorboard, c, model_name,
105105
c.dataset['test_dir'] = args.dataset_dir
106106
# set batchsize = 1
107107
c.train_config['batch_size'] = 1
108-
test_dataloader = test_dataloader(c, ap)
109-
108+
test_dataloader = eval_dataloader(c, ap)
109+
print(c.dataset['format'])
110110
best_sdr = 0
111111
best_loss = 999999999
112112
best_sdr_checkpoint = ''

test_fast_all_checkpoints.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import os
2+
import math
3+
import torch
4+
import torch.nn as nn
5+
import traceback
6+
from glob import glob
7+
8+
import time
9+
import numpy as np
10+
11+
import tqdm
12+
13+
import argparse
14+
15+
from utils.generic_utils import load_config, load_config_from_str
16+
from utils.generic_utils import set_init_dict
17+
18+
from utils.tensorboard import TensorboardWriter
19+
20+
from utils.dataset import test_dataloader
21+
22+
from utils.generic_utils import validation, PowerLaw_Compressed_Loss, SiSNR_With_Pit, test_fast_with_si_srn
23+
24+
from models.voicefilter.model import VoiceFilter
25+
from models.voicesplit.model import VoiceSplit
26+
from utils.audio_processor import WrapperAudioProcessor as AudioProcessor
27+
28+
from shutil import copyfile
29+
import yaml
30+
31+
def test(args, log_dir, checkpoint_path, testloader, tensorboard, c, model_name, ap, cuda=True):
    """Load one checkpoint and run the fast SI-SNR evaluation over `testloader`.

    Args:
        args: parsed CLI arguments (unused here, kept for interface parity).
        log_dir: log directory (unused here, kept for interface parity).
        checkpoint_path: path to the .pt checkpoint to evaluate (required).
        testloader: DataLoader yielding evaluation batches.
        tensorboard: TensorboardWriter passed through to the scorer.
        c: config object (AttrDict) with model/train/loss sections.
        model_name: 'voicefilter' or 'voicesplit'.
        ap: audio processor used for spectrogram inversion.
        cuda: move the model to GPU when True.

    Returns:
        Mean SI-SNR-with-PIT loss from test_fast_with_si_srn.

    Raises:
        Exception: on unsupported model/optimizer/loss, missing checkpoint,
            or failure to load the checkpoint's model weights.
    """
    if model_name == 'voicefilter':
        model = VoiceFilter(c)
    elif model_name == 'voicesplit':
        model = VoiceSplit(c)
    else:
        raise Exception(" The model '" + model_name + "' is not supported")

    if c.train_config['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=c.train_config['learning_rate'])
    else:
        # BUG FIX: original read c.train['optimizer'] here, which does not exist
        raise Exception("The %s optimizer is not supported" % c.train_config['optimizer'])

    if checkpoint_path is None:
        raise Exception("You need to specify a checkpoint for test")

    # Load outside the try so a read failure reports its own error instead of
    # a NameError on the unbound `checkpoint` variable (bug in the original).
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    try:
        model.load_state_dict(checkpoint['model'])
    except Exception:
        raise Exception("Failed to load checkpoint, you need to use these configs: %s"
                        % checkpoint['config_str'])

    try:
        optimizer.load_state_dict(checkpoint['optimizer'])
    except Exception:
        # Best-effort: a changed optimizer makes its state incompatible.
        print(" > Optimizer state is not loaded from checkpoint path; maybe you changed the optimizer")

    step = checkpoint['step']

    # move model to GPU once (the original did this twice)
    if cuda:
        model = model.cuda()

    # definitions for power-law compressed loss
    power = c.loss['power']
    complex_ratio = c.loss['complex_loss_ratio']

    if c.loss['loss_name'] == 'power_law_compression':
        criterion = PowerLaw_Compressed_Loss(power, complex_ratio)
    elif c.loss['loss_name'] == 'si_snr':
        criterion = SiSNR_With_Pit()
    else:
        raise Exception(" The loss '" + c.loss['loss_name'] + "' is not supported")
    return test_fast_with_si_srn(criterion, ap, model, testloader, tensorboard, step,
                                 cuda=cuda, loss_name=c.loss['loss_name'], test=True)
78+
79+
80+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('-d', '--dataset_dir', type=str, default='./',
                        help="Root directory of run.")
    parser.add_argument('-c', '--config_path', type=str, required=False, default=None,
                        help="json file with configurations")
    parser.add_argument('--checkpoints_path', type=str, required=True,
                        help="path of checkpoint pt file, for continue training")
    args = parser.parse_args()

    all_checkpoints = sorted(glob(os.path.join(args.checkpoints_path, '*.pt')))
    if args.config_path:
        c = load_config(args.config_path)
    else:  # no config given: fall back to the config stored in the first checkpoint
        checkpoint = torch.load(all_checkpoints[0], map_location='cpu')
        c = load_config_from_str(checkpoint['config_str'])

    ap = AudioProcessor(c.audio)

    log_path = os.path.join(c.train_config['logs_path'], c.model_name)
    audio_config = c.audio[c.audio['backend']]
    tensorboard = TensorboardWriter(log_path, audio_config)
    # point the test set at the requested dataset directory
    c.dataset['test_dir'] = args.dataset_dir
    # small batch size for the fast evaluation pass (comment/value disagreed before)
    c.test_config['batch_size'] = 5
    # renamed so the imported test_dataloader factory is not shadowed
    testloader = test_dataloader(c, ap)
    best_loss = float('inf')
    best_loss_checkpoint = ''
    losses_per_checkpoint = []
    for checkpoint_file in tqdm.tqdm(all_checkpoints):
        mean_loss = test(args, log_path, checkpoint_file, testloader, tensorboard,
                         c, c.model_name, ap, cuda=True)
        losses_per_checkpoint.append([mean_loss, checkpoint_file])
        if mean_loss < best_loss:
            best_loss = mean_loss
            best_loss_checkpoint = checkpoint_file
    print("Best Loss checkpoint is: ", best_loss_checkpoint, "Best Loss:", best_loss)
    # BUG FIX: original referenced undefined best_sdr_checkpoint / best_sdr here,
    # which raised NameError after the whole sweep finished.
    copyfile(best_loss_checkpoint, os.path.join(args.checkpoints_path, 'fast_best_checkpoint.pt'))
    np.save(os.path.join(args.checkpoints_path,
                         "Loss_validation_with_VCTK_best_SI-SNR_is_" + str(best_loss) + ".np"),
            np.array(losses_per_checkpoint))

train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from utils.tensorboard import TensorboardWriter
1616

17-
from utils.dataset import train_dataloader, test_dataloader
17+
from utils.dataset import train_dataloader, eval_dataloader
1818

1919
from utils.generic_utils import validation, PowerLaw_Compressed_Loss, SiSNR_With_Pit
2020

@@ -159,5 +159,5 @@ def train(args, log_dir, checkpoint_path, trainloader, testloader, tensorboard,
159159
raise Exception("Please verify directories of dataset in "+args.config_path)
160160

161161
train_dataloader = train_dataloader(c, ap)
162-
test_dataloader = test_dataloader(c, ap)
162+
test_dataloader = eval_dataloader(c, ap)
163163
train(args, log_path, args.checkpoint_path, train_dataloader, test_dataloader, tensorboard, c, c.model_name, ap, cuda=True)

utils/dataset.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ def __getitem__(self, idx):
4949
target_spec, _ = self.ap.get_spec_from_audio(target_wav, return_phase=True)
5050
target_spec = torch.from_numpy(target_spec)
5151
mixed_spec = torch.from_numpy(mixed_spec)
52+
mixed_phase = torch.from_numpy(mixed_phase)
53+
target_wav = torch.from_numpy(target_wav)
54+
mixed_wav = torch.from_numpy(mixed_wav)
5255
seq_len = torch.from_numpy(np.array([mixed_wav.shape[0]]))
5356
return emb, target_spec, mixed_spec, target_wav, mixed_wav, mixed_phase, seq_len
5457

@@ -69,6 +72,15 @@ def test_dataloader(c, ap):
6972
collate_fn=test_collate_fn, batch_size=c.test_config['batch_size'],
7073
shuffle=False, num_workers=c.test_config['num_workers'])
7174

75+
def eval_dataloader(c, ap):
    """Build the evaluation DataLoader (no shuffling; raw sample batches via eval_collate_fn)."""
    eval_set = Dataset(c, ap, train=False)
    loader_opts = {
        'collate_fn': eval_collate_fn,
        'batch_size': c.test_config['batch_size'],
        'shuffle': False,
        'num_workers': c.test_config['num_workers'],
    }
    return DataLoader(dataset=eval_set, **loader_opts)
79+
80+
81+
def eval_collate_fn(batch):
    """Identity collate: pass the raw list of samples through unchanged."""
    return batch
83+
7284
def train_collate_fn(item):
7385
embs_list = []
7486
target_list = []
@@ -102,4 +114,38 @@ def train_collate_fn(item):
102114
return embs_list, target_list, mixed_list, seq_len_list, target_wav_list, mixed_phase_list
103115

104116
def test_collate_fn(batch):
105-
return batch
117+
embs_list = []
118+
target_list = []
119+
mixed_list = []
120+
seq_len_list = []
121+
mixed_phase_list = []
122+
target_wav_list = []
123+
mixed_wav_list = []
124+
125+
for emb, target, mixed, target_wav, mixed_wav, mixed_phase, seq_len in batch:
126+
#print(emb)
127+
if emb.tolist() == [0]:
128+
#print("ignorado ", emb)
129+
continue
130+
embs_list.append(emb)
131+
target_list.append(target)
132+
mixed_list.append(mixed)
133+
seq_len_list.append(seq_len)
134+
mixed_phase_list.append(mixed_phase)
135+
target_wav_list.append(target_wav)
136+
mixed_wav_list.append(mixed_wav)
137+
138+
# concate tensors in dim 0
139+
target_list = stack(target_list, dim=0)
140+
mixed_list = stack(mixed_list, dim=0)
141+
seq_len_list = stack(seq_len_list, dim=0)
142+
target_wav_list = stack(target_wav_list, dim=0)
143+
mixed_phase_list = stack(mixed_phase_list, dim=0) # np.array(mixed_phase_list)
144+
mixed_wav_list = stack(mixed_wav_list, dim=0)
145+
try:
146+
embs_list = stack(embs_list, dim=0)
147+
except:
148+
#print('erro, stack')
149+
embs_list = embs_list
150+
return embs_list, target_list, mixed_list, target_wav_list, mixed_wav_list, mixed_phase_list, seq_len_list
151+

utils/generic_utils.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,11 @@ def validation(criterion, ap, model, testloader, tensorboard, step, cuda=True, l
218218
mixed_spec = mixed_spec[0].cpu().detach().numpy()
219219
clean_spec = clean_spec[0].cpu().detach().numpy()
220220
est_mag = est_mag[0].cpu().detach().numpy()
221+
mixed_phase = mixed_phase[0].cpu().detach().numpy()
221222

222223
est_wav = ap.inv_spectrogram(est_mag, phase=mixed_phase)
223224
est_mask = est_mask[0].cpu().detach().numpy()
225+
224226
if loss_name == 'si_snr':
225227
test_loss = criterion(torch.from_numpy(np.array([[clean_wav]])), torch.from_numpy(np.array([[est_wav]])), seq_len).item()
226228
sdr = bss_eval_sources(clean_wav, est_wav, False)[0][0]
@@ -244,12 +246,41 @@ def validation(criterion, ap, model, testloader, tensorboard, step, cuda=True, l
244246
print("Mean Test Loss:", mean_test_loss)
245247
print("Mean Test SDR:", mean_sdr)
246248
return mean_test_loss, mean_sdr
249+
250+
def test_fast_with_si_srn(criterion, ap, model, testloader, tensorboard, step,
                          cuda=True, loss_name='si_snr', test=False):
    """Fast validation: mean SI-SNR-with-PIT loss over the whole test loader.

    NOTE: the `criterion` argument is deliberately replaced by SiSNR_With_Pit
    below, so whatever criterion is passed in is ignored; `tensorboard`,
    `step`, `loss_name` and `test` are also unused. All are kept so the
    signature stays interchangeable with `validation`.

    Returns:
        float: mean SI-SNR-with-PIT loss over all batches.
    """
    losses = []
    model.eval()
    # always score with SI-SNR + PIT, regardless of the training criterion
    criterion = SiSNR_With_Pit()
    with torch.no_grad():
        for emb, clean_spec, mixed_spec, clean_wav, mixed_wav, mixed_phase, seq_len in testloader:
            if cuda:
                emb = emb.cuda()
                clean_spec = clean_spec.cuda()
                mixed_spec = mixed_spec.cuda()
                mixed_phase = mixed_phase.cuda()
                seq_len = seq_len.cuda()
            est_mask = model(mixed_spec, emb)
            est_mag = est_mask * mixed_spec
            # reconstruct waveforms from magnitude using the mixture phase
            output = ap.torch_inv_spectrogram(est_mag, mixed_phase)
            target = ap.torch_inv_spectrogram(clean_spec, mixed_phase)
            shape = list(target.shape)
            # append a channel dim expected by the PIT loss
            target = torch.reshape(target, [shape[0], 1] + shape[1:])
            output = torch.reshape(output, [shape[0], 1] + shape[1:])
            test_loss = criterion(output, target, seq_len).item()
            losses.append(test_loss)

    mean_test_loss = np.array(losses).mean()
    print("Mean Si-SRN with Pit Loss:", mean_test_loss)
    return mean_test_loss
278+
247279
class AttrDict(dict):
248280
def __init__(self, *args, **kwargs):
249281
super(AttrDict, self).__init__(*args, **kwargs)
250282
self.__dict__ = self
251283

252-
253284
def load_config(config_path):
254285
config = AttrDict()
255286
with open(config_path, "r") as f:

0 commit comments

Comments
 (0)