Commit b30e628

racoiaws and claude authored
Reduce code duplication in audio collection + some small fixes (#15587)
* Simplify SchroedingerBridge _step to return a scalar loss. Move component loss logging (train_loss_encoded, train_loss_time) into _step itself, so it returns a plain scalar like all other models.
* Extract a _parse_batch helper into the AudioToAudioModel base class. Replace duplicated batch parsing and 2D-to-3D reshape logic across all 6 audio model subclasses with a single _parse_batch method on the base class. FlowMatchingAudioToAudioModel overrides it to allow a missing target_signal for SSL pretraining.
* Move training_step into the AudioToAudioModel base class. Add an abstract _compute_train_loss method that each subclass implements with its model-specific loss computation. The base-class training_step handles batch parsing, logging, and the return value.
* Guard SB component loss logging with a self.training check. _step is called from both training and evaluation; the train_loss_encoded and train_loss_time logs should only fire during training.
* Move sample_rate and setup_optimization_flags to the base AudioToAudioModel.__init__. Both are set identically by all 6 subclasses, and setup_optimization_flags only reads self._cfg, so it is safe to call before subclass-specific module initialization.
* Remove redundant world_size init from EncMaskDec and BNR2. ModelPT.__init__ calls set_trainer → set_world_size before any data loader setup, so the pre-super assignment is always overwritten before it can be read.
* Use self.from_config_dict in EncMaskDecAudioToAudioModel, consistent with all other audio model subclasses, which use self.from_config_dict rather than the concrete class name.
* Update the setup_optimization_flags docstring: it is now called from the base __init__ and no longer requires an explicit subclass call.
* Extract _normalize/_denormalize helpers into the base class. Replace repeated normalize/denormalize boilerplate across 4 forward() and 3 _step() methods with calls to shared helpers on AudioToAudioModel.
* Remove the misleading -> tuple annotation from _normalize.
* Fix a CodeQL warning: use pass instead of ... in the abstract method.
* Fix the SB test that calls _step outside the Lightning training loop. The test calls _step directly, which now logs component losses via self.log; disable logging in this test, since there is no active Lightning loop context. Also update the test to use _parse_batch and the scalar return from _step.
* Fix _denormalize to be a proper inverse of _normalize. _normalize divides by (norm_scale + eps), so _denormalize should multiply by (norm_scale + eps) to recover the original signal.

Signed-off-by: Roman Korostik <rkorostik@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
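For illustration, here is a minimal sketch of the contract this refactor introduces, seen from a subclass: the model now implements only _compute_train_loss, while the base-class training_step handles batch parsing, logging, and the return. The subclass name and the exact forward/loss call signatures below are hypothetical stand-ins, not code from this PR.

```python
import torch

from nemo.collections.audio.models.audio_to_audio import AudioToAudioModel


class MyEnhancementModel(AudioToAudioModel):
    """Hypothetical subclass, for illustration only."""

    def _compute_train_loss(self, input_signal, target_signal, input_length):
        # The base-class training_step has already parsed the batch into
        # (B, C, T) tensors and will log 'train_loss' from the returned scalar.
        output_signal, _ = self.forward(input_signal=input_signal, input_length=input_length)
        return self.loss(estimate=output_signal, target=target_signal, input_length=input_length)
```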
1 parent: c66a379

4 files changed: 125 additions & 342 deletions

File tree

nemo/collections/audio/models/audio_to_audio.py (72 additions & 4 deletions)
@@ -18,6 +18,7 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Optional, Union
 
+import einops
 import hydra
 import librosa
 import soundfile as sf
@@ -50,7 +51,9 @@ class AudioToAudioModel(ModelPT, ABC):
     def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         super().__init__(cfg=cfg, trainer=trainer)
 
+        self.sample_rate = self._cfg.sample_rate
         self._setup_loss()
+        self.setup_optimization_flags()
 
     def _setup_loss(self):
         """Setup loss for this model."""
@@ -130,6 +133,55 @@ def _setup_metrics(self, tag: str = 'val'):
             'Setup metrics for %s, dataloader %d: %s', tag, dataloader_idx, ', '.join(metrics_dataloader_idx)
         )
 
+    def _parse_batch(self, batch):
+        """Parse a batch into input signal, target signal, and input length.
+
+        Handles both dict-style (lhotse) and tuple-style (AudioToTargetDataset)
+        batches, and ensures signals are in multi-channel format (B, C, T).
+
+        Returns:
+            Tuple of (input_signal, target_signal, input_length).
+        """
+        if isinstance(batch, dict):
+            # Lhotse dataloaders produce dict batches
+            input_signal = batch['input_signal']
+            input_length = batch['input_length']
+            target_signal = batch['target_signal']
+        else:
+            # Standard audio datasets produce tuple batches
+            input_signal, input_length, target_signal, _ = batch
+
+        if input_signal.ndim == 2:
+            input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
+        if target_signal.ndim == 2:
+            target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
+
+        return input_signal, target_signal, input_length
+
+    @abstractmethod
+    def _compute_train_loss(self, input_signal, target_signal, input_length):
+        """Compute training loss from parsed batch signals.
+
+        Args:
+            input_signal: input audio tensor (B, C, T)
+            target_signal: target audio tensor (B, C, T)
+            input_length: length of each example in the batch (B,)
+
+        Returns:
+            Scalar loss tensor.
+        """
+        pass
+
+    def training_step(self, batch, batch_idx):
+        input_signal, target_signal, input_length = self._parse_batch(batch)
+        loss = self._compute_train_loss(input_signal, target_signal, input_length)
+
+        self.log('train_loss', loss)
+        self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
+        self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))
+
+        return loss
+
     @abstractmethod
     def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
         pass
@@ -313,6 +365,23 @@ def _setup_process_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoade
         temporary_dataloader = self._setup_dataloader_from_config(config=DictConfig(dl_config))
         return temporary_dataloader
 
+    def _normalize(self, signal: torch.Tensor):
+        """Normalize signal so its peak amplitude is 1.
+
+        Args:
+            signal: tensor with shape (B, C, T)
+
+        Returns:
+            Tuple of (normalized_signal, norm_scale). Pass norm_scale to
+            _denormalize to restore the original scale.
+        """
+        norm_scale = torch.amax(signal.abs(), dim=(-1, -2), keepdim=True)
+        return signal / (norm_scale + self.eps), norm_scale
+
+    def _denormalize(self, signal: torch.Tensor, norm_scale: torch.Tensor) -> torch.Tensor:
+        """Restore original scale after _normalize."""
+        return signal * (norm_scale + self.eps)
+
     @staticmethod
     def match_batch_length(input: torch.Tensor, batch_length: int) -> torch.Tensor:
         """Trim or pad the output to match the batch length.
@@ -467,11 +536,10 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]':
         return list_of_models
 
     def setup_optimization_flags(self):
-        """
-        Utility method that must be explicitly called by the subclass in order to support optional optimization flags.
-        This method is the only valid place to access self.cfg prior to DDP training occurs.
+        """Setup optional optimization flags from the model config.
 
-        The subclass may chose not to support this method, therefore all variables here must be checked via hasattr()
+        Called automatically during __init__. This is the only valid place
+        to access self.cfg prior to DDP training.
         """
         # Skip update if nan/inf grads appear on any rank.
         self._skip_nan_grad = False