Skip to content

Commit 21b7c4f

Browse files
author
Hadley-Zhang
committed
feat: Add descriptive statistics and charts
Added descriptive tools for Speech Quality Prediction: - Mean and Standard Deviation of the results - BarChart and LineChart of the results This change helps provide better assistance in Speech Quality Prediction, especially for repeated averaging of results of the same test sample in practical Audio Testing Scenarios. Also, output name of CSV is added just like gabrielmittag#30.
1 parent ac83137 commit 21b7c4f

2 files changed

Lines changed: 123 additions & 8 deletions

File tree

nisqa/NISQA_model.py

Lines changed: 106 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import yaml
1515
import torch
1616
import torch.nn as nn
17+
import matplotlib.pyplot as plt
1718
from torch import optim
1819
from torch.utils.data import DataLoader
1920
from . import NISQA_lib as NL
@@ -50,7 +51,81 @@ def evaluate(self, mapping='first_order', do_print=True, do_plot=False):
5051
self._evaluate_dim(mapping=mapping, do_print=do_print, do_plot=do_plot)
5152
else:
5253
self._evaluate_mos(mapping=mapping, do_print=do_print, do_plot=do_plot)
53-
54+
55+
def _draw_barchart(self):
56+
# settings of barchart
57+
bar_width = 0.15
58+
index = np.arange(len(self.ds_val.df['deg']))
59+
colors = ['r', 'g', 'b', 'c', 'm']
60+
fig, ax = plt.subplots(figsize=(12, 8))
61+
for i, col in enumerate(['mos_pred', 'noi_pred', 'dis_pred', 'col_pred', 'loud_pred']):
62+
ax.bar(index + (i * bar_width), self.ds_val.df[col], bar_width, color=colors[i], label=col)
63+
for j, val in enumerate(self.ds_val.df[col]):
64+
ax.annotate(str(f'{val:.2f}'), xy=(index[j] + (i * bar_width), val), xytext=(0, 3),
65+
textcoords="offset points", ha='center', va='bottom')
66+
# set labels and title
67+
ax.set_xlabel('WavName')
68+
ax.set_ylabel('Scores')
69+
title = 'BarChart for '+str(self.args['deg']) if self.args['mode'] == 'predict_file' else 'Barchart for Wavs under: '+str(self.args['data_dir'])
70+
ax.set_title(title)
71+
ax.set_xticks(index + (2 * bar_width))
72+
ax.set_xticklabels(self.ds_val.df['deg'], rotation=90)
73+
ax.set_ylim(0, 5)
74+
# add legend
75+
ax.legend()
76+
# save plot
77+
if 'plot_name' in self.args and self.args['plot_name'] != 'None':
78+
save_path = self.args['output_dir'] if self.args['output_dir'][-1] == '/' else self.args['output_dir']+'/'
79+
plt.savefig(save_path+'BarChart_' + self.args['plot_name'])
80+
# display the plot
81+
plt.tight_layout()
82+
plt.show()
83+
84+
85+
def _draw_linechart(self):
86+
# settings of line chart
87+
fig, ax = plt.subplots(figsize=(12, 8))
88+
index = np.arange(len(self.ds_val.df['deg']))
89+
colors = ['r', 'g', 'b', 'c', 'm']
90+
for i, col in enumerate(['mos_pred', 'noi_pred', 'dis_pred', 'col_pred', 'loud_pred']):
91+
ax.plot(index, self.ds_val.df[col], marker='o', linestyle='-', color=colors[i], label=col)
92+
for x, y in zip(index, self.ds_val.df[col]):
93+
ax.text(x, y, f'{y:.2f}', ha='left', va='top')
94+
# set labels and title
95+
ax.set_xlabel('WavName')
96+
ax.set_ylabel('Scores')
97+
title = 'LineChart for ' + str(self.args['deg']) if self.args['mode'] == 'predict_file' else 'Line Chart for Wavs under: ' + str(self.args['data_dir'])
98+
ax.set_title(title)
99+
ax.set_xticks(index)
100+
ax.set_xticklabels(self.ds_val.df['deg'], rotation=90)
101+
ax.set_ylim(0, 5)
102+
# add legend
103+
ax.legend()
104+
# save plot
105+
if 'plot_name' in self.args and self.args['plot_name'] != 'None':
106+
save_path = self.args['output_dir'] if self.args['output_dir'][-1] == '/' else self.args['output_dir'] + '/'
107+
plt.savefig(save_path + 'LineChart_' + self.args['plot_name'])
108+
# display the plot
109+
plt.tight_layout()
110+
plt.show()
111+
112+
# calculate mean and standard deviation of the results and concatenate with the initial results.
113+
def _compute_mean_stdDev(self):
114+
mean = self.ds_val.df.drop(['deg', 'model'], axis=1).mean()
115+
std = self.ds_val.df.drop(['deg', 'model'], axis=1).std()
116+
# alter NAN to 0
117+
std.fillna(0, inplace=True)
118+
119+
stat = pd.DataFrame({'deg': ['**Mean**', '**Standard Deviation**'],
120+
'mos_pred': [mean['mos_pred'], std['mos_pred']],
121+
'noi_pred': [mean['noi_pred'], std['noi_pred']],
122+
'dis_pred': [mean['dis_pred'], std['dis_pred']],
123+
'col_pred': [mean['col_pred'], std['col_pred']],
124+
'loud_pred': [mean['loud_pred'], std['loud_pred']],
125+
'model': self.ds_val.df['model'][0]})
126+
stat = pd.concat([self.ds_val.df, stat])
127+
return stat
128+
54129
def predict(self):
55130
print('---> Predicting ...')
56131
if self.args['tr_parallel']:
@@ -69,15 +144,38 @@ def predict(self):
69144
self.ds_val,
70145
self.args['tr_bs_val'],
71146
self.dev,
72-
num_workers=self.args['tr_num_workers'])
73-
147+
num_workers=self.args['tr_num_workers'])
148+
74149
if self.args['output_dir']:
75150
self.ds_val.df['model'] = self.args['name']
76-
self.ds_val.df.to_csv(
77-
os.path.join(self.args['output_dir'], 'NISQA_results.csv'),
78-
index=False)
79-
80-
print(self.ds_val.df.to_string(index=False))
151+
csv_name = 'NISQA_results.csv'
152+
if 'output_name' in self.args:
153+
csv_name = str(self.args['output_name']) + '.csv'
154+
# whether print mean and standard deviation or not
155+
if 'compute_stats' in self.args and self.args['compute_stats']:
156+
# generate results with mean and deviation to self.statistics.
157+
self.statistics = self._compute_mean_stdDev()
158+
self.statistics.to_csv(
159+
os.path.join(self.args['output_dir'], csv_name),
160+
index=False)
161+
else:
162+
self.ds_val.df.to_csv(
163+
os.path.join(self.args['output_dir'], csv_name),
164+
index=False)
165+
166+
# print either statistics or ds_val based on 'compute_stats' parameter
167+
if 'compute_stats' in self.args and self.args['compute_stats']:
168+
print(self.statistics.to_string(index=False))
169+
else:
170+
print(self.ds_val.df.to_string(index=False))
171+
172+
# Visualization of the results
173+
if 'plot_type' in self.args and self.args['plot_type'] == 'barchart':
174+
self._draw_barchart()
175+
elif 'plot_type' in self.args and self.args['plot_type'] == 'linechart':
176+
self._draw_linechart()
177+
178+
# returned DataFrame is not changed.
81179
return self.ds_val.df
82180

83181
def _train_mos(self):

run_predict.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,17 @@
1616
parser.add_argument('--num_workers', type=int, default=0, help='number of workers for pytorchs dataloader')
1717
parser.add_argument('--bs', type=int, default=1, help='batch size for predicting')
1818
parser.add_argument('--ms_channel', type=int, help='audio channel in case of stereo file')
19+
### Add
20+
# 1) the name of result csv,
21+
# 2) statistics(mean and standard deviation),
22+
# 3) type of visualizations
23+
# 4) whether to save the plot
24+
### to the parameters
25+
parser.add_argument('--output_name', type=str, help='name of the csv result file')
26+
parser.add_argument('--compute_stats', action='store_true', help='whether to calculate the mean and the standard deviation of the results')
27+
parser.add_argument('--plot_type', type=str, default='None', help='Visualization of the results. Either barchart, linechart or None')
28+
parser.add_argument('--plot_name', type=str, default='None', help='name of the plot file if saving plot is needed')
29+
###
1930

2031
args = parser.parse_args()
2132
args = vars(args)
@@ -35,6 +46,12 @@
3546
args['data_dir'] = ''
3647
else:
3748
raise NotImplementedError('--mode given not available')
49+
50+
# `plot_name` can only be set when `plot_type` is not `None``
51+
if 'plot_name' in args and args['plot_name'] != 'None':
52+
if 'plot_type' in args and args['plot_type'] == 'None':
53+
raise ValueError('--plot_name argument can only be set when `plot_type` is either `barchart` or `linechart`')
54+
3855
args['tr_bs_val'] = args['bs']
3956
args['tr_num_workers'] = args['num_workers']
4057

0 commit comments

Comments
 (0)