pack_tensor crashes on non-contiguous tensors during Megatron→vLLM refit (Qwen3.5 / Qwen3ForCausalLM converter) #2417

Description

@jlcanta

Describe the bug

pack_tensor in stream_weights_via_ipc_zmq_impl (nemo_rl/models/policy/utils.py) calls tensor.data.view(-1) on every tensor yielded by megatron_bridge.export_hf_weights. The bridge can yield non-contiguous tensors: with the Qwen3ForCausalLM converter, QKVMapping.megatron_to_hf → split_qkv_weights ends with .reshape(-1, hidden_size), and reshape may return a non-contiguous view rather than a contiguous copy. On such a tensor, view(-1) raises:

RuntimeError: view size is not compatible with input tensor's size and stride
(at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

The crash hits the very first refit_policy_generation call, before any training step completes. The same precondition violation exists in nemo_rl/utils/packed_tensor.py:78 (packed_broadcast_producer), which uses .view(torch.uint8).view(-1) on iterator output and would fail equivalently when the broadcast path is exercised.

Steps/Code to reproduce bug

Environment: NeMo-RL v0.6.0, Megatron-Bridge vendored under 3rdparty/Megatron-Bridge-workspace, Ray + vLLM colocated (async engine, IPC weight streaming), 1 node × 2 H200, TP=2.

What I tested:

| Model | converter_type | Result |
|---|---|---|
| Qwen/Qwen3-1.7B | Qwen2ForCausalLM | refit works |
| Qwen/Qwen3.5-2B | Qwen3ForCausalLM | crashes in pack_tensor |

Same hardware, same backends, same versions. The trigger is the converter / model architecture; the bug itself is in NeMo-RL's tensor handling.

To reproduce, launch GRPO with a Qwen3.5-2B Megatron + vLLM colocated config:

policy:
  model_name: "Qwen/Qwen3.5-2B"
  megatron_cfg:
    enabled: true
    tensor_model_parallel_size: 2
    converter_type: "Qwen3ForCausalLM"
    sequence_parallel: true
  generation:
    backend: "vllm"
    vllm_cfg:
      tensor_parallel_size: 2
    colocated:
      enabled: true

Crash hits at the first refit. Full traceback:

ray.exceptions.RayTaskError(RuntimeError): ray::MegatronPolicyWorker.stream_weights_via_ipc_zmq()
  File "/opt/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.py", line 1089, in stream_weights_via_ipc_zmq
    stream_weights_via_ipc_zmq_impl(...)
  File "/opt/nemo-rl/nemo_rl/models/policy/utils.py", line 336, in stream_weights_via_ipc_zmq_impl
    used_bytes = pack_tensor(current_buffer, tensor, used_bytes)
  File "/opt/nemo-rl/nemo_rl/models/policy/utils.py", line 296, in pack_tensor
    tensor.data.view(-1).view(dtype=torch.uint8), non_blocking=True
RuntimeError: view size is not compatible with input tensor's size and stride
(at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

Minimal Python reproduction of the failing operation (no GPU / no Ray needed):

import torch
t = torch.randn(8, 16).t()        # non-contiguous
t.data.view(-1)                    # RuntimeError as above
t.data.reshape(-1)                 # works
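
Continuing the snippet above, the byte-reinterpret step used in packed_broadcast_producer fails on the same tensor as well; the message differs because a dtype-changing view additionally requires the last dimension to have stride 1:

t.view(torch.uint8)                            # RuntimeError (last dim stride must be 1)
t.contiguous().view(torch.uint8).view(-1)      # works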

Expected behavior

The refit path should accept any tensor the params iterator yields, regardless of contiguity. The user shouldn't need to pick a specific converter_type or TP/sharding combination just to avoid an unrelated view(-1) precondition.

Proposed fix: replace tensor.data.view(-1) with tensor.data.reshape(-1) in pack_tensor, and add .contiguous() before .view(torch.uint8).view(-1) in packed_broadcast_producer. Both changes are zero-copy when the tensor is already contiguous (the common case) and add a single extra copy when it isn't; pack_tensor copies the data into a buffer right afterwards anyway. A PR with a regression test (transpose / strided slice / permute roundtrip through stream_weights_via_ipc_zmq_impl) is coming shortly.
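
For concreteness, a minimal sketch of what the patched pack_tensor could look like. The signature matches the call site in the traceback, and the buffer/offset bookkeeping is an assumption for illustration (a flat uint8 staging tensor), not the actual NeMo-RL implementation:

import torch

def pack_tensor(buffer: torch.Tensor, tensor: torch.Tensor, used_bytes: int) -> int:
    # reshape(-1) is zero-copy for contiguous tensors and makes a single
    # contiguous copy otherwise, so any tensor from the params iterator is accepted.
    flat = tensor.data.reshape(-1).view(dtype=torch.uint8)  # was: tensor.data.view(-1).view(dtype=torch.uint8)
    nbytes = flat.numel()
    # illustrative bookkeeping: copy the flattened bytes into the staging buffer
    buffer[used_bytes : used_bytes + nbytes].copy_(flat, non_blocking=True)
    return used_bytes + nbytes

# and in packed_broadcast_producer (nemo_rl/utils/packed_tensor.py), per the proposed fix:
#     tensor.contiguous().view(torch.uint8).view(-1)   # instead of tensor.view(torch.uint8).view(-1)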
