1+ import os
2+ import math
3+ import torch
4+ import torch .nn as nn
5+ import traceback
6+ from glob import glob
7+
8+ import time
9+ import numpy as np
10+
11+ import tqdm
12+
13+ import argparse
14+
15+ from utils .generic_utils import load_config , load_config_from_str
16+ from utils .generic_utils import set_init_dict
17+
18+ from utils .tensorboard import TensorboardWriter
19+
20+ from utils .dataset import test_dataloader
21+
22+ from utils .generic_utils import validation , PowerLaw_Compressed_Loss , SiSNR_With_Pit , test_fast_with_si_srn
23+
24+ from models .voicefilter .model import VoiceFilter
25+ from models .voicesplit .model import VoiceSplit
26+ from utils .audio_processor import WrapperAudioProcessor as AudioProcessor
27+
28+ from shutil import copyfile
29+ import yaml
30+
def test(args, log_dir, checkpoint_path, testloader, tensorboard, c, model_name, ap, cuda=True):
    """Load a checkpoint into the requested model and run a fast SI-SNR evaluation.

    Args:
        args: parsed CLI arguments (unused here; kept for interface compatibility).
        log_dir: log directory path (unused here; kept for interface compatibility).
        checkpoint_path: path to a ``.pt`` checkpoint file; must not be None.
        testloader: DataLoader yielding test batches.
        tensorboard: TensorboardWriter passed through to the evaluation routine.
        c: loaded configuration object (exposes train_config, loss, ...).
        model_name: either 'voicefilter' or 'voicesplit'.
        ap: audio processor instance.
        cuda: move the model to GPU when True.

    Returns:
        The value returned by ``test_fast_with_si_srn`` (mean test loss).

    Raises:
        Exception: for an unsupported model/optimizer/loss name, a missing
            checkpoint path, or a checkpoint incompatible with the config.
    """
    # Instantiate the requested model architecture.
    if model_name == 'voicefilter':
        model = VoiceFilter(c)
    elif model_name == 'voicesplit':
        model = VoiceSplit(c)
    else:
        raise Exception(" The model '" + model_name + "' is not suported")

    if c.train_config['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=c.train_config['learning_rate'])
    else:
        # BUG FIX: the original read c.train['optimizer'], but the config
        # section is named train_config (see the check above), so this branch
        # raised AttributeError instead of the intended message.
        raise Exception("The %s not is a optimizer supported" % c.train_config['optimizer'])

    if checkpoint_path is None:
        raise Exception("You need specific a checkpoint for test")

    # BUG FIX: torch.load is now outside the try block — in the original, a
    # failed load entered an except handler that referenced the undefined
    # `checkpoint` variable and raised NameError instead of the real error.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    try:
        model.load_state_dict(checkpoint['model'])
    except Exception:
        # The checkpoint stores the config it was trained with; surface it so
        # the user can reproduce the matching settings.
        raise Exception("Fail in load checkpoint, you need use this configs: %s" % checkpoint['config_str'])

    try:
        optimizer.load_state_dict(checkpoint['optimizer'])
    except Exception:
        # Best-effort: a changed optimizer makes its stored state unloadable,
        # which is acceptable for evaluation-only runs.
        print(" > Optimizer state is not loaded from checkpoint path, you see this mybe you change the optimizer")

    step = checkpoint['step']

    # Move the model to GPU once, after all state has been loaded
    # (the original called .cuda() twice, which was redundant).
    if cuda:
        model = model.cuda()

    # definitions for power-law compressed loss
    power = c.loss['power']
    complex_ratio = c.loss['complex_loss_ratio']

    if c.loss['loss_name'] == 'power_law_compression':
        criterion = PowerLaw_Compressed_Loss(power, complex_ratio)
    elif c.loss['loss_name'] == 'si_snr':
        criterion = SiSNR_With_Pit()
    else:
        raise Exception(" The loss '" + c.loss['loss_name'] + "' is not suported")

    return test_fast_with_si_srn(criterion, ap, model, testloader, tensorboard,
                                 step, cuda=cuda, loss_name=c.loss['loss_name'], test=True)
78+
79+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('-d', '--dataset_dir', type=str, default='./',
                        help="Root directory of run.")
    parser.add_argument('-c', '--config_path', type=str, required=False, default=None,
                        help="json file with configurations")
    parser.add_argument('--checkpoints_path', type=str, required=True,
                        help="path of checkpoint pt file, for continue training")
    args = parser.parse_args()

    # Evaluate every checkpoint in the directory, sorted by filename.
    all_checkpoints = sorted(glob(os.path.join(args.checkpoints_path, '*.pt')))
    if args.config_path:
        c = load_config(args.config_path)
    else:  # no config given: recover the config stored inside the first checkpoint
        checkpoint = torch.load(all_checkpoints[0], map_location='cpu')
        c = load_config_from_str(checkpoint['config_str'])

    ap = AudioProcessor(c.audio)

    log_path = os.path.join(c.train_config['logs_path'], c.model_name)
    audio_config = c.audio[c.audio['backend']]
    tensorboard = TensorboardWriter(log_path, audio_config)
    # set test dataset dir
    c.dataset['test_dir'] = args.dataset_dir
    # force a small evaluation batch size (the original comment claimed 32,
    # but the code has always set 5)
    c.test_config['batch_size'] = 5
    # BUG FIX: use a distinct name for the instantiated loader instead of
    # shadowing the imported test_dataloader() factory.
    testloader = test_dataloader(c, ap)

    best_loss = 999999999
    best_loss_checkpoint = ''
    losses_per_checkpoint = []  # [mean_loss, checkpoint_path] pairs, saved at the end
    for ckpt in tqdm.tqdm(all_checkpoints):
        mean_loss = test(args, log_path, ckpt, testloader, tensorboard, c, c.model_name, ap, cuda=True)
        losses_per_checkpoint.append([mean_loss, ckpt])
        if mean_loss < best_loss:
            best_loss = mean_loss
            best_loss_checkpoint = ckpt
    print("Best Loss checkpoint is: ", best_loss_checkpoint, "Best Loss:", best_loss)
    # BUG FIX: the original referenced the undefined names best_sdr_checkpoint
    # and best_sdr here, crashing with NameError after the whole evaluation
    # loop; the variables tracked above are best_loss_checkpoint / best_loss.
    copyfile(best_loss_checkpoint, os.path.join(args.checkpoints_path, 'fast_best_checkpoint.pt'))
    # NOTE(review): np.save appends '.npy', so the file ends up as '...np.npy'.
    # The original filename is kept for backward compatibility.
    np.save(os.path.join(args.checkpoints_path,
                         "Loss_validation_with_VCTK_best_SI-SNR_is_" + str(best_loss) + ".np"),
            np.array(losses_per_checkpoint))