-
Notifications
You must be signed in to change notification settings - Fork 38
Expand file tree
/
Copy pathutils.py
More file actions
111 lines (76 loc) · 3.22 KB
/
utils.py
File metadata and controls
111 lines (76 loc) · 3.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import copy
import torch
import numpy as np
from scipy import signal
from librosa.filters import mel
from scipy.signal import get_window
import torch
import torch.nn as nn
import torch.nn.functional as F
def butter_highpass(cutoff, fs, order=5):
    """Design a digital Butterworth high-pass filter.

    Args:
        cutoff: cutoff frequency, in the same units as ``fs``.
        fs: sampling rate.
        order: filter order (default 5).

    Returns:
        ``(b, a)``: numerator and denominator polynomial coefficients as
        produced by ``scipy.signal.butter``.
    """
    # scipy.signal.butter expects the critical frequency normalized to the
    # Nyquist frequency (half the sampling rate).
    nyquist = fs * 0.5
    numerator, denominator = signal.butter(
        order, cutoff / nyquist, btype='high', analog=False)
    return numerator, denominator
def pySTFT(x, fft_length=1024, hop_length=256):
    """Compute a magnitude STFT of ``x`` with a periodic Hann window.

    ``x`` is reflect-padded by ``fft_length // 2`` samples. It is then split
    into frames of ``fft_length`` samples spaced ``hop_length`` apart. Each
    frame is windowed and transformed with a real FFT.

    Args:
        x: input signal. Framing runs along the last axis; a 1-D array is the
            expected case (note that ``np.pad`` with a scalar width pads every
            axis).
        fft_length: frame / FFT size in samples.
        hop_length: step between consecutive frames in samples.

    Returns:
        Array of ``|rfft|`` magnitudes. For 1-D ``x`` the shape is
        ``(fft_length // 2 + 1, n_frames)``: frequency bins by frames.
    """
    x = np.pad(x, int(fft_length // 2), mode='reflect')
    # Overlapping frames as a read-only view of shape (..., n_frames, fft_length).
    # sliding_window_view derives and validates the strides itself, replacing
    # hand-computed as_strided shape/stride arithmetic.
    frames = np.lib.stride_tricks.sliding_window_view(
        x, fft_length, axis=-1)[..., ::hop_length, :]
    fft_window = get_window('hann', fft_length, fftbins=True)
    # rfft over the last (sample) axis, then transpose to bins x frames.
    spectrum = np.fft.rfft(fft_window * frames, n=fft_length).T
    return np.abs(spectrum)
class LinearNorm(torch.nn.Module):
    """Fully connected layer whose weight is Xavier-uniform initialized.

    Wraps ``torch.nn.Linear`` and re-initializes its weight with
    ``xavier_uniform_``, using the gain ``calculate_gain(w_init_gain)``.
    The bias (if any) keeps ``nn.Linear``'s default initialization.
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super().__init__()
        # The attribute name determines the state_dict keys; keep it stable.
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        init_gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=init_gain)

    def forward(self, x):
        """Apply the wrapped linear layer to the last dimension of ``x``."""
        return self.linear_layer(x)
class ConvNorm(torch.nn.Module):
    """1-D convolution whose weight is Xavier-uniform initialized.

    Wraps ``torch.nn.Conv1d``. When ``padding`` is None it is set to
    ``dilation * (kernel_size - 1) // 2``. This keeps the sequence length
    unchanged for ``stride=1`` and requires an odd ``kernel_size``.

    Raises:
        ValueError: if ``padding`` is None and ``kernel_size`` is even.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            # Symmetric length-preserving padding only exists for odd kernels.
            # An explicit check (not ``assert``) still fires under ``python -O``,
            # instead of silently truncating the padding.
            if kernel_size % 2 != 1:
                raise ValueError(
                    f'kernel_size must be odd when padding is None, '
                    f'got {kernel_size}')
            padding = int(dilation * (kernel_size - 1) // 2)
        # The attribute name determines the state_dict keys; keep it stable.
        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)
        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        """Convolve ``signal`` (Conv1d layout, e.g. (batch, in_channels, length)).

        The parameter name shadows the module-level ``scipy.signal`` import
        only inside this method.
        """
        return self.conv(signal)
def filter_bank_mean(num_rep, codes_mask, max_len_long):
    '''Build a mean-pooling filterbank that maps L codes onto max_len_long frames.

    For batch item b and code l, let ``right`` be the cumulative (masked)
    repetition count up to and including l, and ``left`` the same sum
    excluding l. Code l covers the 0-based frames j with
    ``floor(left) <= j < ceil(right)``. When a boundary is fractional, both
    neighbouring codes claim the shared frame. Each code's row is divided by
    its frame count, so a non-empty row sums to 1. Codes covering no frames
    give all-zero rows.

    Args:
        num_rep: (B, L) tensor of per-code repetition counts (may be
            fractional).
        codes_mask: (B, L) multiplicative mask over codes. Padding entries must
            be real zeros so they contribute no frames.
        max_len_long: number of output frames.

    Returns:
        (B, L, max_len_long) float filterbank, multiplied by ``codes_mask``.
    '''
    num_rep = num_rep.unsqueeze(-1)  # (B, L, 1)
    codes_mask = codes_mask.unsqueeze(-1)  # (B, L, 1)
    # Zero the counts of padded codes so they do not advance later boundaries.
    num_rep = num_rep * codes_mask
    right_edge = num_rep.cumsum(dim=1)
    left_edge = torch.zeros_like(right_edge)
    left_edge[:, 1:, :] = right_edge[:, :-1, :]
    # Widen each span outward to whole frames.
    right_edge = right_edge.ceil()
    left_edge = left_edge.floor()

    frames = torch.arange(max_len_long, device=num_rep.device).view(1, 1, -1)
    # Binary membership of each frame in each code's span: (B, L, max_len_long).
    fb = ((frames >= left_edge) & (frames < right_edge)).float()

    # Mean pooling: normalize each code's row by its frame count.
    norm = fb.sum(dim=-1, keepdim=True)
    norm[norm == 0] = 1.0  # empty spans: avoid 0/0, the row stays zero
    fb = fb / norm
    return fb * codes_mask