fix(infra): keep model.to(device) on unsharded post-shard load#2146
Open
fix(infra): keep model.to(device) on unsharded post-shard load#2146
Conversation
Persistent buffers initialized via torch.tensor()/torch.ones() inside init_empty_weights() (e.g. Gemma4's Gemma4ClippableLinear input_min/max, Gemma4TextDecoderLayer layer_scalar) stay on CPU because the context only patches register_parameter, not register_buffer. The post-shard load path then unconditionally skipped model.to(device), leaving these buffers stranded and tripping torch.clamp on cuda:0 vs cpu. The skip exists for FSDP's reset_sharded_param issue with tied params under TP>1 (pytorch/pytorch#151085). Narrow it to its actual precondition — any DTensor in the model — so single-GPU, DDP, and other unsharded configs still run model.to(device). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Contributor
Author
|
/claude review |
Comment on lines
+580
to
+581
| has_sharded_params = any(isinstance(p, DTensor) for p in model.parameters()) | ||
| if not (should_load_checkpoint and has_sharded_params): |
Contributor
There was a problem hiding this comment.
Bug: test_skips_model_to_device_when_checkpoint_loaded in tests/unit_tests/_transformers/test_infrastructure.py:193 will fail with this change. _DummyModel uses plain torch.nn.Linear (not DTensor), so has_sharded_params is False, not (True and False) evaluates to True, and model.to() gets called — but the test asserts mock_to.assert_not_called().
The test needs to be updated:
- Rename/update the existing test to assert
model.to()is called (unsharded params + checkpoint loaded = the new Gemma4 fix path). - Add a new test where at least one param is a
DTensormock/instance, to verify the skip still works for the FSDP tied-params case.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
torch.clamp(hidden_states, self.input_min, self.input_max)because persistent buffers (Gemma4ClippableLinear.input_min/max/output_min/max,Gemma4TextDecoderLayer.layer_scalar) were stranded on CPU while params were on cuda:0.init_empty_weights()only patchesregister_parameter, notregister_buffer. So buffers created viatorch.tensor(±inf)/torch.ones(1)stay on CPU. The post-shard load path then unconditionally skippedmodel.to(device), leaving them stranded.reset_sharded_paramissue with tied params under TP>1 (DISABLED test_pp_fsdp_dp_type_FSDP_MP_ScheduleClass2 (__main__.ComposabilityTest) pytorch/pytorch#151085). This PR narrows the skip to its actual precondition — the model has anyDTensorparam — so single-GPU, DDP, and other unsharded configs still getmodel.to(device).Why this location, not a model-level or
init_empty_weightsfixregister_bufferinsideinit_empty_weightsis universal but breaks training: many models (Gemma4 included) register non-persistent buffers likesoftcap,inv_timescalesthat aren't repopulated by_init_weights. They came back as uninitialized GPU memory afterto_empty(device)and produced NaN losses from step 1.input_min/max,output_min/max,layer_scalar,std_bias/scale,Gemma4TextRouter.scale/per_expert_scale) and stay in sync with HF's source.Verification
examples/vlm_finetune/gemma4/gemma4_2b_peft.yaml) ongoogle/gemma-4-E2B-it: training runs end-to-end. Loss 4.2456 → ~1.7 over 250 steps.Test plan
model.to(device)(verified byhas_sharded_params=Truecheck)🤖 Generated with Claude Code