Skip to content

Commit 31782b7

Browse files
committed
Reverted the in-place optimizations for the generators and discriminators, simplified the checkpointing code, and reworked RefineGAN.
1 parent 0575c8a commit 31782b7

6 files changed

Lines changed: 106 additions & 211 deletions

File tree

rvc/lib/algorithm/discriminators.py

Lines changed: 13 additions & 40 deletions
Original file line number · Diff line number · Diff line change
@@ -21,36 +21,20 @@ class MultiPeriodDiscriminator(torch.nn.Module):
2121
"""
2222

2323
def __init__(self, use_spectral_norm: bool = False, checkpointing: bool = False):
24-
super(MultiPeriodDiscriminator, self).__init__()
24+
super().__init__()
2525
periods = [2, 3, 5, 7, 11, 17, 23, 37]
2626
self.checkpointing = checkpointing
2727
self.discriminators = torch.nn.ModuleList(
28-
[
29-
DiscriminatorS(
30-
use_spectral_norm=use_spectral_norm, checkpointing=checkpointing
31-
)
32-
]
33-
+ [
34-
DiscriminatorP(
35-
p, use_spectral_norm=use_spectral_norm, checkpointing=checkpointing
36-
)
37-
for p in periods
38-
]
28+
[DiscriminatorS(use_spectral_norm=use_spectral_norm)]
29+
+ [DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in periods]
3930
)
4031

4132
def forward(self, y, y_hat):
4233
y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
4334
for d in self.discriminators:
4435
if self.training and self.checkpointing:
45-
46-
def forward_discriminator(d, y, y_hat):
47-
y_d_r, fmap_r = d(y)
48-
y_d_g, fmap_g = d(y_hat)
49-
return y_d_r, fmap_r, y_d_g, fmap_g
50-
51-
y_d_r, fmap_r, y_d_g, fmap_g = checkpoint(
52-
forward_discriminator, d, y, y_hat, use_reentrant=False
53-
)
36+
y_d_r, fmap_r = checkpoint(d, y, use_reentrant=False)
37+
y_d_g, fmap_g = checkpoint(d, y_hat, use_reentrant=False)
5438
else:
5539
y_d_r, fmap_r = d(y)
5640
y_d_g, fmap_g = d(y_hat)
@@ -71,9 +55,9 @@ class DiscriminatorS(torch.nn.Module):
7155
convolutional layers that are applied to the input signal.
7256
"""
7357

74-
def __init__(self, use_spectral_norm: bool = False, checkpointing: bool = False):
75-
super(DiscriminatorS, self).__init__()
76-
self.checkpointing = checkpointing
58+
def __init__(self, use_spectral_norm: bool = False):
59+
super().__init__()
60+
7761
norm_f = spectral_norm if use_spectral_norm else weight_norm
7862
self.convs = torch.nn.ModuleList(
7963
[
@@ -86,16 +70,12 @@ def __init__(self, use_spectral_norm: bool = False, checkpointing: bool = False)
8670
]
8771
)
8872
self.conv_post = norm_f(torch.nn.Conv1d(1024, 1, 3, 1, padding=1))
89-
self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE, inplace=True)
73+
self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)
9074

9175
def forward(self, x):
9276
fmap = []
9377
for conv in self.convs:
94-
if self.training and self.checkpointing:
95-
x = checkpoint(conv, x, use_reentrant=False)
96-
x = checkpoint(self.lrelu, x, use_reentrant=False)
97-
else:
98-
x = self.lrelu(conv(x))
78+
x = self.lrelu(conv(x))
9979
fmap.append(x)
10080
x = self.conv_post(x)
10181
fmap.append(x)
@@ -125,10 +105,8 @@ def __init__(
125105
kernel_size: int = 5,
126106
stride: int = 3,
127107
use_spectral_norm: bool = False,
128-
checkpointing: bool = False,
129108
):
130-
super(DiscriminatorP, self).__init__()
131-
self.checkpointing = checkpointing
109+
super().__init__()
132110
self.period = period
133111
norm_f = spectral_norm if use_spectral_norm else weight_norm
134112

@@ -151,7 +129,7 @@ def __init__(
151129
)
152130

153131
self.conv_post = norm_f(torch.nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
154-
self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE, inplace=True)
132+
self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)
155133

156134
def forward(self, x):
157135
fmap = []
@@ -162,13 +140,8 @@ def forward(self, x):
162140
x = x.view(b, c, -1, self.period)
163141

164142
for conv in self.convs:
165-
if self.training and self.checkpointing:
166-
x = checkpoint(conv, x, use_reentrant=False)
167-
x = checkpoint(self.lrelu, x, use_reentrant=False)
168-
else:
169-
x = self.lrelu(conv(x))
143+
x = self.lrelu(conv(x))
170144
fmap.append(x)
171-
172145
x = self.conv_post(x)
173146
fmap.append(x)
174147
x = torch.flatten(x, 1, -1)

rvc/lib/algorithm/generators/hifigan.py

Lines changed: 4 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -75,12 +75,10 @@ def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None):
7575
x = self.conv_pre(x)
7676

7777
if g is not None:
78-
# in-place call
79-
x += self.cond(g)
78+
x = x + self.cond(g)
8079

8180
for i in range(self.num_upsamples):
82-
# in-place call
83-
x = torch.nn.functional.leaky_relu_(x, LRELU_SLOPE)
81+
x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
8482
x = self.ups[i](x)
8583
xs = None
8684
for j in range(self.num_kernels):
@@ -90,10 +88,10 @@ def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None):
9088
xs += self.resblocks[i * self.num_kernels + j](x)
9189
x = xs / self.num_kernels
9290
# in-place call
93-
x = torch.nn.functional.leaky_relu_(x)
91+
x = torch.nn.functional.leaky_relu(x)
9492
x = self.conv_post(x)
9593
# in-place call
96-
x = torch.tanh_(x)
94+
x = torch.tanh(x)
9795

9896
return x
9997

rvc/lib/algorithm/generators/hifigan_mrf.py

Lines changed: 14 additions & 21 deletions
Original file line number · Diff line number · Diff line change
@@ -43,11 +43,9 @@ def __init__(self, channels, kernel_size, dilation):
4343
)
4444

4545
def forward(self, x: torch.Tensor):
46-
# new tensor
4746
y = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
4847
y = self.conv1(y)
49-
# in-place call
50-
y = torch.nn.functional.leaky_relu_(y, LRELU_SLOPE)
48+
y = torch.nn.functional.leaky_relu(y, LRELU_SLOPE)
5149
y = self.conv2(y)
5250
return x + y
5351

@@ -344,36 +342,31 @@ def forward(
344342
f0 = self.f0_upsample(f0[:, None, :]).transpose(-1, -2)
345343
har_source, _, _ = self.m_source(f0)
346344
har_source = har_source.transpose(-1, -2)
347-
# new tensor
348345
x = self.conv_pre(x)
349346

350347
if g is not None:
351-
# in-place call
352-
x += self.cond(g)
348+
x = x + self.cond(g)
353349

354350
for ups, mrf, noise_conv in zip(self.upsamples, self.mrfs, self.noise_convs):
355-
# in-place call
356-
x = torch.nn.functional.leaky_relu_(x, LRELU_SLOPE)
351+
x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
357352

358353
if self.training and self.checkpointing:
359354
x = checkpoint(ups, x, use_reentrant=False)
355+
x = x + noise_conv(har_source)
356+
xs = sum([
357+
checkpoint(layer, x, use_reentrant=False)
358+
for layer in mrf])
360359
else:
361360
x = ups(x)
361+
x = x + noise_conv(har_source)
362+
xs = sum([
363+
layer(x)
364+
for layer in mrf])
365+
x = xs / self.num_kernels
362366

363-
x += noise_conv(har_source)
367+
x = torch.nn.functional.leaky_relu(x)
368+
x = torch.tanh(self.conv_post(x))
364369

365-
def mrf_sum(x, layers):
366-
return sum(layer(x) for layer in layers) / self.num_kernels
367-
368-
if self.training and self.checkpointing:
369-
x = checkpoint(mrf_sum, x, mrf, use_reentrant=False)
370-
else:
371-
x = mrf_sum(x, mrf)
372-
# in-place call
373-
x = torch.nn.functional.leaky_relu_(x)
374-
x = self.conv_post(x)
375-
# in-place call
376-
x = torch.tanh_(x)
377370
return x
378371

379372
def remove_weight_norm(self):

rvc/lib/algorithm/generators/hifigan_nsf.py

Lines changed: 16 additions & 24 deletions
Original file line number · Diff line number · Diff line change
@@ -179,37 +179,29 @@ def forward(
179179
x = self.conv_pre(x)
180180

181181
if g is not None:
182-
# in-place call
183-
x += self.cond(g)
182+
x = x + self.cond(g)
184183

185184
for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
186-
# in-place call
187-
x = torch.nn.functional.leaky_relu_(x, self.lrelu_slope)
188-
185+
x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
189186
# Apply upsampling layer
190187
if self.training and self.checkpointing:
191188
x = checkpoint(ups, x, use_reentrant=False)
189+
x = x + noise_convs(har_source)
190+
xs = sum([
191+
checkpoint(resblock, x, use_reentrant=False)
192+
for j, resblock in enumerate(self.resblocks)
193+
if j in range(i * self.num_kernels, (i + 1) * self.num_kernels)])
192194
else:
193195
x = ups(x)
194-
195-
# Add noise excitation
196-
x += noise_convs(har_source)
197-
198-
# Apply residual blocks
199-
def resblock_forward(x, blocks):
200-
return sum(block(x) for block in blocks) / len(blocks)
201-
202-
blocks = self.resblocks[i * self.num_kernels : (i + 1) * self.num_kernels]
203-
204-
# Checkpoint or regular computation for ResBlocks
205-
if self.training and self.checkpointing:
206-
x = checkpoint(resblock_forward, x, blocks, use_reentrant=False)
207-
else:
208-
x = resblock_forward(x, blocks)
209-
# in-place call
210-
x = torch.nn.functional.leaky_relu_(x)
211-
# in-place call
212-
x = torch.tanh_(self.conv_post(x))
196+
x = x + noise_convs(har_source)
197+
xs = sum([
198+
resblock(x)
199+
for j, resblock in enumerate(self.resblocks)
200+
if j in range(i * self.num_kernels, (i + 1) * self.num_kernels)])
201+
x = xs / self.num_kernels
202+
203+
x = torch.nn.functional.leaky_relu(x)
204+
x = torch.tanh(self.conv_post(x))
213205

214206
return x
215207

0 commit comments

Comments (0)