fix(infra): keep model.to(device) on unsharded post-shard load#2146

Open
HuiyingLi wants to merge 1 commit into main from
huiyingl/fix/single-gpu-peft-buffer-device

Conversation

@HuiyingLi
Contributor

Summary

  • Single-GPU + custom-registry HF model + PEFT crashed in Gemma4's vision tower at torch.clamp(hidden_states, self.input_min, self.input_max) because persistent buffers (Gemma4ClippableLinear.input_min/max/output_min/max, Gemma4TextDecoderLayer.layer_scalar) were stranded on CPU while params were on cuda:0.
  • Root cause: init_empty_weights() only patches register_parameter, not register_buffer, so buffers created via torch.tensor(±inf) / torch.ones(1) stay on CPU. The post-shard load path then unconditionally skipped model.to(device), leaving them stranded (a minimal repro sketch follows this list).
  • The skip was added to work around FSDP's reset_sharded_param issue with tied params under TP>1 (pytorch/pytorch#151085, "DISABLED test_pp_fsdp_dp_type_FSDP_MP_ScheduleClass2 (__main__.ComposabilityTest)"). This PR narrows the skip to its actual precondition, the model having at least one DTensor param, so single-GPU, DDP, and other unsharded configs still get model.to(device).
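
For illustration, a minimal repro sketch of the root cause. The module and buffer names below are simplified stand-ins for the Gemma4 classes, not the actual HF source:

import torch
import torch.nn as nn
from accelerate import init_empty_weights

class ClippableLinear(nn.Module):
    # simplified stand-in for Gemma4ClippableLinear
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # persistent buffer created with a concrete value at __init__ time
        self.register_buffer("input_min", torch.tensor(float("-inf")))

with init_empty_weights():
    m = ClippableLinear()

print(m.linear.weight.device)  # meta -> later materialized/loaded onto cuda:0
print(m.input_min.device)      # cpu  -> stranded if model.to(device) is skipped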

Why this location, not a model-level or init_empty_weights fix

  • Patching register_buffer inside init_empty_weights is universal but breaks training: many models (Gemma4 included) register non-persistent buffers such as softcap and inv_timescales that are not repopulated by _init_weights. Those buffers came back as uninitialized GPU memory after to_empty(device) and produced NaN losses from step 1 (a rough sketch of this rejected approach follows this list).
  • A Gemma4-only fix in the custom wrapper would have to enumerate every persistent buffer (input_min/max, output_min/max, layer_scalar, std_bias/scale, Gemma4TextRouter.scale/per_expert_scale) and stay in sync with HF's source.
  • The infrastructure-level condition is one line, exactly tracks the FSDP precondition, and catches every model with this pattern.
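
For context, a rough sketch of what the rejected universal fix would look like. This is hypothetical, not accelerate's actual implementation, just the shape of the idea:

import contextlib
import torch
import torch.nn as nn

@contextlib.contextmanager
def also_patch_register_buffer():
    # hypothetical analogue of init_empty_weights()'s register_parameter patch
    original = nn.Module.register_buffer

    def register_meta_buffer(module, name, tensor, persistent=True):
        original(module, name, tensor, persistent=persistent)
        if tensor is not None:
            # every buffer loses its concrete value here, persistent or not...
            module._buffers[name] = tensor.to(torch.device("meta"))

    nn.Module.register_buffer = register_meta_buffer
    try:
        yield
    finally:
        nn.Module.register_buffer = original

# ...so after to_empty(device), persistent buffers come back from the checkpoint,
# but non-persistent ones (softcap, inv_timescales) are never rewritten by
# _init_weights and end up as uninitialized memory, hence the NaN losses.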

Verification

  • 1-GPU PEFT (examples/vlm_finetune/gemma4/gemma4_2b_peft.yaml) on google/gemma-4-E2B-it: training runs end-to-end. Loss 4.2456 → ~1.7 over 250 steps.
  • 8-GPU FSDP run on the same yaml: loss matches the 1-GPU run within 0.01–0.08 over the first 20 steps (same seed, same global batch size of 8).

Test plan

  • CI green
  • Existing single-GPU and multi-GPU PEFT recipes unaffected
  • Multi-GPU TP>1 / FSDP runs still skip model.to(device) (verified by has_sharded_params=True check)

🤖 Generated with Claude Code

Persistent buffers initialized via torch.tensor()/torch.ones() inside
init_empty_weights() (e.g. Gemma4's Gemma4ClippableLinear input_min/max,
Gemma4TextDecoderLayer layer_scalar) stay on CPU because the context
only patches register_parameter, not register_buffer. The post-shard
load path then unconditionally skipped model.to(device), leaving these
buffers stranded and tripping torch.clamp on cuda:0 vs cpu.

The skip exists for FSDP's reset_sharded_param issue with tied params
under TP>1 (pytorch/pytorch#151085). Narrow it to its actual
precondition — any DTensor in the model — so single-GPU, DDP, and other
unsharded configs still run model.to(device).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@copy-pr-bot

copy-pr-bot Bot commented May 6, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HuiyingLi
Contributor Author

/claude review

Comment on lines +580 to +581
has_sharded_params = any(isinstance(p, DTensor) for p in model.parameters())
if not (should_load_checkpoint and has_sharded_params):
Contributor

Bug: test_skips_model_to_device_when_checkpoint_loaded in tests/unit_tests/_transformers/test_infrastructure.py:193 will fail with this change. _DummyModel uses plain torch.nn.Linear (not DTensor), so has_sharded_params is False, not (True and False) evaluates to True, and model.to() gets called — but the test asserts mock_to.assert_not_called().

The test needs to be updated:

  1. Rename/update the existing test to assert model.to() is called (unsharded params + checkpoint loaded = the new Gemma4 fix path).
  2. Add a new test where at least one param is a DTensor mock/instance, to verify the skip still works for the FSDP tied-params case (see the sketch below).
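
A rough sketch of what the two tests could look like; the entry point name (setup_model_post_shard_load) and the mocking approach are assumptions to be adapted to the actual helper in test_infrastructure.py:

from unittest import mock
from torch.distributed.tensor import DTensor  # import path may differ across torch versions

def test_calls_model_to_device_when_unsharded():
    # plain nn.Parameters only: has_sharded_params is False, so the narrowed
    # condition lets model.to(device) run (the Gemma4 fix path)
    model = _DummyModel()
    with mock.patch.object(model, "to") as mock_to:
        setup_model_post_shard_load(model, should_load_checkpoint=True)  # hypothetical entry point
        mock_to.assert_called_once()

def test_skips_model_to_device_when_sharded():
    # at least one DTensor-typed param: the FSDP tied-params skip must still apply
    model = _DummyModel()
    fake_dtensor = mock.MagicMock(spec=DTensor)
    with mock.patch.object(type(model), "parameters", return_value=iter([fake_dtensor])), \
         mock.patch.object(model, "to") as mock_to:
        setup_model_post_shard_load(model, should_load_checkpoint=True)
        mock_to.assert_not_called()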
