1414import yaml
1515import torch
1616import torch .nn as nn
17+ import matplotlib .pyplot as plt
1718from torch import optim
1819from torch .utils .data import DataLoader
1920from . import NISQA_lib as NL
@@ -50,7 +51,81 @@ def evaluate(self, mapping='first_order', do_print=True, do_plot=False):
5051 self ._evaluate_dim (mapping = mapping , do_print = do_print , do_plot = do_plot )
5152 else :
5253 self ._evaluate_mos (mapping = mapping , do_print = do_print , do_plot = do_plot )
53-
54+
55+ def _draw_barchart (self ):
56+ # settings of barchart
57+ bar_width = 0.15
58+ index = np .arange (len (self .ds_val .df ['deg' ]))
59+ colors = ['r' , 'g' , 'b' , 'c' , 'm' ]
60+ fig , ax = plt .subplots (figsize = (12 , 8 ))
61+ for i , col in enumerate (['mos_pred' , 'noi_pred' , 'dis_pred' , 'col_pred' , 'loud_pred' ]):
62+ ax .bar (index + (i * bar_width ), self .ds_val .df [col ], bar_width , color = colors [i ], label = col )
63+ for j , val in enumerate (self .ds_val .df [col ]):
64+ ax .annotate (str (f'{ val :.2f} ' ), xy = (index [j ] + (i * bar_width ), val ), xytext = (0 , 3 ),
65+ textcoords = "offset points" , ha = 'center' , va = 'bottom' )
66+ # set labels and title
67+ ax .set_xlabel ('WavName' )
68+ ax .set_ylabel ('Scores' )
69+ title = 'BarChart for ' + str (self .args ['deg' ]) if self .args ['mode' ] == 'predict_file' else 'Barchart for Wavs under: ' + str (self .args ['data_dir' ])
70+ ax .set_title (title )
71+ ax .set_xticks (index + (2 * bar_width ))
72+ ax .set_xticklabels (self .ds_val .df ['deg' ], rotation = 90 )
73+ ax .set_ylim (0 , 5 )
74+ # add legend
75+ ax .legend ()
76+ # save plot
77+ if 'plot_name' in self .args and self .args ['plot_name' ] != 'None' :
78+ save_path = self .args ['output_dir' ] if self .args ['output_dir' ][- 1 ] == '/' else self .args ['output_dir' ]+ '/'
79+ plt .savefig (save_path + 'BarChart_' + self .args ['plot_name' ])
80+ # display the plot
81+ plt .tight_layout ()
82+ plt .show ()
83+
84+
85+ def _draw_linechart (self ):
86+ # settings of line chart
87+ fig , ax = plt .subplots (figsize = (12 , 8 ))
88+ index = np .arange (len (self .ds_val .df ['deg' ]))
89+ colors = ['r' , 'g' , 'b' , 'c' , 'm' ]
90+ for i , col in enumerate (['mos_pred' , 'noi_pred' , 'dis_pred' , 'col_pred' , 'loud_pred' ]):
91+ ax .plot (index , self .ds_val .df [col ], marker = 'o' , linestyle = '-' , color = colors [i ], label = col )
92+ for x , y in zip (index , self .ds_val .df [col ]):
93+ ax .text (x , y , f'{ y :.2f} ' , ha = 'left' , va = 'top' )
94+ # set labels and title
95+ ax .set_xlabel ('WavName' )
96+ ax .set_ylabel ('Scores' )
97+ title = 'LineChart for ' + str (self .args ['deg' ]) if self .args ['mode' ] == 'predict_file' else 'Line Chart for Wavs under: ' + str (self .args ['data_dir' ])
98+ ax .set_title (title )
99+ ax .set_xticks (index )
100+ ax .set_xticklabels (self .ds_val .df ['deg' ], rotation = 90 )
101+ ax .set_ylim (0 , 5 )
102+ # add legend
103+ ax .legend ()
104+ # save plot
105+ if 'plot_name' in self .args and self .args ['plot_name' ] != 'None' :
106+ save_path = self .args ['output_dir' ] if self .args ['output_dir' ][- 1 ] == '/' else self .args ['output_dir' ] + '/'
107+ plt .savefig (save_path + 'LineChart_' + self .args ['plot_name' ])
108+ # display the plot
109+ plt .tight_layout ()
110+ plt .show ()
111+
112+ # calculate mean and standard deviation of the results and concatenate with the initial results.
113+ def _compute_mean_stdDev (self ):
114+ mean = self .ds_val .df .drop (['deg' , 'model' ], axis = 1 ).mean ()
115+ std = self .ds_val .df .drop (['deg' , 'model' ], axis = 1 ).std ()
116+ # alter NAN to 0
117+ std .fillna (0 , inplace = True )
118+
119+ stat = pd .DataFrame ({'deg' : ['**Mean**' , '**Standard Deviation**' ],
120+ 'mos_pred' : [mean ['mos_pred' ], std ['mos_pred' ]],
121+ 'noi_pred' : [mean ['noi_pred' ], std ['noi_pred' ]],
122+ 'dis_pred' : [mean ['dis_pred' ], std ['dis_pred' ]],
123+ 'col_pred' : [mean ['col_pred' ], std ['col_pred' ]],
124+ 'loud_pred' : [mean ['loud_pred' ], std ['loud_pred' ]],
125+ 'model' : self .ds_val .df ['model' ][0 ]})
126+ stat = pd .concat ([self .ds_val .df , stat ])
127+ return stat
128+
54129 def predict (self ):
55130 print ('---> Predicting ...' )
56131 if self .args ['tr_parallel' ]:
@@ -69,15 +144,38 @@ def predict(self):
69144 self .ds_val ,
70145 self .args ['tr_bs_val' ],
71146 self .dev ,
72- num_workers = self .args ['tr_num_workers' ])
73-
147+ num_workers = self .args ['tr_num_workers' ])
148+
74149 if self .args ['output_dir' ]:
75150 self .ds_val .df ['model' ] = self .args ['name' ]
76- self .ds_val .df .to_csv (
77- os .path .join (self .args ['output_dir' ], 'NISQA_results.csv' ),
78- index = False )
79-
80- print (self .ds_val .df .to_string (index = False ))
151+ csv_name = 'NISQA_results.csv'
152+ if 'output_name' in self .args :
153+ csv_name = str (self .args ['output_name' ]) + '.csv'
154+ # whether print mean and standard deviation or not
155+ if 'compute_stats' in self .args and self .args ['compute_stats' ]:
156+ # generate results with mean and deviation to self.statistics.
157+ self .statistics = self ._compute_mean_stdDev ()
158+ self .statistics .to_csv (
159+ os .path .join (self .args ['output_dir' ], csv_name ),
160+ index = False )
161+ else :
162+ self .ds_val .df .to_csv (
163+ os .path .join (self .args ['output_dir' ], csv_name ),
164+ index = False )
165+
166+ # print either statistics or ds_val based on 'compute_stats' parameter
167+ if 'compute_stats' in self .args and self .args ['compute_stats' ]:
168+ print (self .statistics .to_string (index = False ))
169+ else :
170+ print (self .ds_val .df .to_string (index = False ))
171+
172+ # Visualization of the results
173+ if 'plot_type' in self .args and self .args ['plot_type' ] == 'barchart' :
174+ self ._draw_barchart ()
175+ elif 'plot_type' in self .args and self .args ['plot_type' ] == 'linechart' :
176+ self ._draw_linechart ()
177+
178+ # returned DataFrame is not changed.
81179 return self .ds_val .df
82180
83181 def _train_mos (self ):
0 commit comments