Commit ad8520c

pzelasko and claude committed
refactor(speechlm2): drop dead BSHD+CP code paths in SALMAutomodel
The fit-start validator already rejects BSHD + CP > 1 with a hard error pointing users to model.packed_sequences=true (see validate_parallelism_compatibility in parts/parallel.py), so any code that exists only to support BSHD under CP is unreachable.

In SALMAutomodel.prepare_inputs, the BSHD branch's ``if cp_size > 1: shard_bshd_for_cp(...)`` call and the ``llm_attention_mask = None if cp_size > 1 else attention_mask`` ternary both presupposed BSHD + CP > 1; remove them and inline the TP truncation into the BSHD path. Drop the now-unused shard_bshd_for_cp helper from cp_helpers.py and update its module docstring and the cp_helpers test docstring accordingly.

No behavior change for any reachable configuration.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent df2b08e commit ad8520c
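
For context, a minimal sketch of the kind of fit-start guard the message refers to. The real validate_parallelism_compatibility in parts/parallel.py is authoritative; the signature and arguments below are assumptions for illustration only.

```python
# Hypothetical sketch only -- mirrors the behavior described in the commit
# message, not the actual implementation in parts/parallel.py.
def validate_parallelism_compatibility(packed_sequences: bool, cp_size: int) -> None:
    # BSHD (padded, batch-first) batches cannot be fed through context
    # parallelism here, so the combination is rejected before fit starts.
    if cp_size > 1 and not packed_sequences:
        raise ValueError(
            "BSHD layout is incompatible with context parallelism (CP > 1); "
            "set model.packed_sequences=true to use the THD path instead."
        )
```

Because this error fires before training begins, any branch in prepare_inputs that only runs when both BSHD and CP > 1 are active can never execute, which is what makes the deletions below safe.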

3 files changed: 17 additions and 89 deletions


nemo/collections/speechlm2/models/salm_automodel.py

Lines changed: 8 additions & 26 deletions
```diff
@@ -197,10 +197,9 @@ def prepare_inputs(self, batch: dict):
         from nemo.collections.speechlm2.parts.cp_helpers import (
             encode_audio_with_cp_distribution,
             get_cp_mesh,
-            shard_bshd_for_cp,
         )
 
-        cp_mesh, cp_size, _ = get_cp_mesh(getattr(self, "_device_mesh", None))
+        cp_mesh, _, _ = get_cp_mesh(getattr(self, "_device_mesh", None))
 
         # Source audio encoding (distributed across CP ranks when CP is active).
         # Input audio: (B_aud, T_samples) → list of (L_i, H) embeddings.
@@ -243,37 +242,20 @@ def prepare_inputs(self, batch: dict):
         attention_mask = attention_mask[:, :-1]
         target_ids = target_ids[:, 1:]
 
-        # Sequence-length divisibility for sequence/context parallelism.
-        # CP path: pad to 2*cp_size*tp_size and partition along the seq dim
-        # (the existing TP truncation is folded into the CP padding). BSHD-only
-        # path keeps the original TP-truncation behavior.
-        tp_size = self.device_mesh["tp"].size() if self._use_tp else 1
-        if cp_size > 1:
-            sharded = shard_bshd_for_cp(input_embs, attention_mask, target_ids, cp_mesh, tp_size=tp_size)
-            input_embs = sharded["input_embs"]
-            attention_mask = sharded["attention_mask"]
-            target_ids = sharded["target_ids"]
-        elif self._use_tp:
+        # BSHD path runs only when CP is inactive (the fit-start validator
+        # rejects BSHD + CP > 1, see _validate_parallelism_compatibility).
+        # Truncate the seq dim to be divisible by tp_size so sequence
+        # parallelism doesn't reshape the input under us.
+        if self._use_tp:
+            tp_size = self.device_mesh["tp"].size()
             if (remainder := (input_embs.shape[1] - 1) % tp_size) != 0:
-                # Truncate some tokens from the end to make the sequence length shape divisible by tensor parallelism
-                # world size. Otherwise, sequence parallelism will change the input shape making leading to mismatches.
                 input_embs = input_embs[:, :-remainder]
                 attention_mask = attention_mask[:, :-remainder]
                 target_ids = target_ids[:, :-remainder]
 
-        # TE's fused-attention CP path rejects ``padding_causal``; only ``causal``
-        # is supported. BSHD batches are left-padded so dropping the padding mask
-        # lets pad K/V leak into real-token attention — empirically this drives
-        # the loss to NaN at step 2 (the gradient through the LoRA / projection
-        # parameters is corrupted by the leak after one optimizer step). BSHD +
-        # CP is therefore not a supported configuration; set
-        # ``model.packed_sequences: true`` to use the THD path under CP, which
-        # uses cu_seqlens-aware attention and has no equivalent issue.
-        llm_attention_mask = None if cp_size > 1 else attention_mask
-
         return {
             "input_embeds": input_embs,
-            "attention_mask": llm_attention_mask,
+            "attention_mask": attention_mask,
             "target_ids": target_ids,
             "llm_kwargs": {},
         }
```
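
To make the surviving BSHD/TP branch concrete, here is a standalone rerun of the kept truncation logic. The tensors and sizes below are made up for illustration; the truncation itself is the code the diff keeps.

```python
import torch

# Illustrative shapes only: B=2, T=38, H=16, with tp_size=4.
tp_size = 4
input_embs = torch.randn(2, 38, 16)                    # [B, T, H]
attention_mask = torch.ones(2, 38, dtype=torch.bool)   # [B, T]
target_ids = torch.randint(0, 100, (2, 38))            # [B, T]

# Same truncation the diff keeps: drop trailing tokens so the sequence
# length satisfies the divisibility that sequence parallelism expects.
if (remainder := (input_embs.shape[1] - 1) % tp_size) != 0:
    input_embs = input_embs[:, :-remainder]
    attention_mask = attention_mask[:, :-remainder]
    target_ids = target_ids[:, :-remainder]

# 38 -> 37 here, since (38 - 1) % 4 == 1 and (37 - 1) % 4 == 0.
assert (input_embs.shape[1] - 1) % tp_size == 0
```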

nemo/collections/speechlm2/parts/cp_helpers.py

Lines changed: 5 additions & 58 deletions
```diff
@@ -13,17 +13,14 @@
 # limitations under the License.
 """Context-Parallelism (CP) helpers for SALMAutomodel.
 
-These helpers consolidate the CP-shape work needed to feed both BSHD and THD
-batches into a Nemotron-V3 LLM whose attention/Mamba layers were CP-wired by
-the Automodel parallelizer (`set_context_parallel_group()` / `mixer.cp =
-MambaContextParallel(...)`). Three concerns:
+These helpers consolidate the CP-shape work needed to feed THD packed
+batches into a Nemotron-V3 LLM whose attention/Mamba layers were CP-wired
+by the Automodel parallelizer (``set_context_parallel_group()`` /
+``mixer.cp = MambaContextParallel(...)``). Two concerns:
 
 1. ``get_cp_mesh`` — read the CP submesh out of a device mesh, returning
    ``(None, 1, 0)`` when CP is inactive so callers can short-circuit.
-2. ``shard_bshd_for_cp`` — pad and partition a BSHD batch along the seq dim
-   using TE's DualChunkSwap pattern (matches Automodel's Config 1 reference
-   test in ``run_hybrid_nemotron_v3_cp.py``).
-3. ``encode_audio_with_cp_distribution`` — distribute the audio encoder
+2. ``encode_audio_with_cp_distribution`` — distribute the audio encoder
    forward across CP ranks so it isn't recomputed cp_size times. Pads to a
    multiple of cp_size with dummy zero-audios so every rank participates in
    FSDP all-gather; dummies are dropped after the post-encoder all-gather.
@@ -34,7 +31,6 @@
 
 import torch
 import torch.distributed as dist
-import torch.nn.functional as F
 from torch import Tensor
 
 from nemo.collections.speechlm2.parts.encoder_chunking import encode_audio_with_optional_chunking
@@ -52,55 +48,6 @@ def get_cp_mesh(device_mesh) -> tuple[Optional[object], int, int]:
     return cp_mesh, cp_mesh.size(), cp_rank
 
 
-def shard_bshd_for_cp(
-    input_embs: Tensor,
-    attention_mask: Tensor,
-    target_ids: Tensor,
-    cp_mesh,
-    tp_size: int = 1,
-) -> dict[str, Tensor]:
-    """Pre-shard a BSHD batch across CP ranks via TE's DualChunkSwap pattern.
-
-    Right-pads the seq dim to a multiple of ``2 * cp_size * tp_size`` (TE-CP
-    requires ``2 * cp_size``; SP requires per-rank len divisible by ``tp_size``)
-    and partitions along the seq dim using
-    ``transformer_engine_torch.thd_get_partitioned_indices``.
-
-    Args:
-        input_embs: ``[B, T, H]`` float.
-        attention_mask: ``[B, T]`` bool/long; pad slots become 0.
-        target_ids: ``[B, T]`` int64; pad slots become ``-100``.
-        cp_mesh: the CP submesh of size ``cp_size > 1``.
-        tp_size: tensor-parallel world size (1 if TP is inactive).
-
-    Returns dict with keys ``input_embs``, ``attention_mask``, ``target_ids``,
-    each shape ``[B, T_padded // cp_size, ...]``.
-    """
-    import transformer_engine_torch as tex
-
-    cp_size = cp_mesh.size()
-    cp_rank = dist.get_rank(group=cp_mesh.get_group())
-    device = input_embs.device
-
-    B, T, H = input_embs.shape
-    mult = 2 * cp_size * max(1, tp_size)
-    T_padded = ((T + mult - 1) // mult) * mult
-    pad_n = T_padded - T
-    if pad_n > 0:
-        input_embs = F.pad(input_embs, (0, 0, 0, pad_n), value=0.0)
-        attention_mask = F.pad(attention_mask.to(torch.long), (0, pad_n), value=0).to(torch.bool)
-        target_ids = F.pad(target_ids, (0, pad_n), value=-100)
-
-    cu_seqlens = torch.tensor([0, T_padded], dtype=torch.int32, device=device)
-    indices = tex.thd_get_partitioned_indices(cu_seqlens, T_padded, cp_size, cp_rank)
-
-    return {
-        "input_embs": input_embs.index_select(1, indices).contiguous(),
-        "attention_mask": attention_mask.index_select(1, indices).contiguous(),
-        "target_ids": target_ids.index_select(1, indices).contiguous(),
-    }
-
-
 def encode_audio_with_cp_distribution(
     perception,
     audios: Tensor,
```
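
The dummy-padding step the surviving docstring describes for encode_audio_with_cp_distribution boils down to ceil-to-multiple arithmetic. A sketch under stated assumptions (the helper name and shapes below are invented for illustration; the real function does this internally):

```python
import torch

# Hypothetical helper illustrating the padding described in the docstring:
# pad the audio batch to a multiple of cp_size with zero-audios so every
# CP rank has work for the encoder forward and the FSDP all-gather.
def pad_audios_to_cp_multiple(audios: torch.Tensor, cp_size: int) -> tuple[torch.Tensor, int]:
    n_dummy = (-audios.shape[0]) % cp_size  # 0 when already a multiple
    if n_dummy > 0:
        dummies = torch.zeros(n_dummy, *audios.shape[1:], dtype=audios.dtype, device=audios.device)
        audios = torch.cat([audios, dummies], dim=0)
    # Caller drops the last n_dummy embeddings after the post-encoder all-gather.
    return audios, n_dummy
```

For example, a batch of 5 audios with cp_size=2 comes back as 6 rows with n_dummy == 1, so both ranks encode 3 audios each and one dummy embedding is discarded afterwards.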

tests/collections/speechlm2/test_salm_cp_helpers.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -13,11 +13,10 @@
 # limitations under the License.
 """CPU-only tests for the CP-helper module.
 
-The ``cp_size > 1`` paths in ``shard_bshd_for_cp`` and
-``encode_audio_with_cp_distribution`` require ``transformer_engine_torch``
-and a real ``torch.distributed`` process group respectively; they're
-exercised by the 2-GPU smoke. These tests cover the fallback contracts
-that run on every machine (``cp_mesh is None``, ``B_aud == 0``).
+The ``cp_size > 1`` path in ``encode_audio_with_cp_distribution`` requires
+a real ``torch.distributed`` process group; it's exercised by the 2-GPU
+smoke. These tests cover the fallback contracts that run on every machine
+(``cp_mesh is None``, ``B_aud == 0``).
 """
 import torch
 
```
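
In the spirit of the docstring's "fallback contracts", a minimal CPU-only check of the get_cp_mesh short-circuit might look like the sketch below. It assumes passing None is the inactive case, per the documented (None, 1, 0) return; the test name is invented.

```python
from nemo.collections.speechlm2.parts.cp_helpers import get_cp_mesh

def test_get_cp_mesh_returns_none_1_0_when_cp_inactive():
    # Documented contract: no CP submesh -> (None, 1, 0), so callers
    # can short-circuit without touching torch.distributed.
    cp_mesh, cp_size, cp_rank = get_cp_mesh(None)
    assert cp_mesh is None
    assert cp_size == 1
    assert cp_rank == 0
```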
