
FineGrainedFP8Config(dequantize=true) materializes full BF16 model per rank before sharding #2114


What

The standard FP8-checkpoint load path

```yaml
quantization_config:
  target: transformers.FineGrainedFP8Config
  dequantize: true
```

dequantizes the entire FP8 checkpoint into BF16 on every rank before any TP / FSDP / PP sharding kicks in. The full BF16 model is the per-rank peak.
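For reference, the YAML above resolves to roughly this HF call (a sketch; `dequantize=True` is the same flag the config sets):

```python
from transformers import AutoModelForCausalLM, FineGrainedFP8Config

# Every rank executes this load; HF dequantizes each FP8 tensor to full-size
# BF16 before any TP / FSDP / PP sharding, so the whole BF16 model is the
# per-rank peak.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Devstral-2-123B-Instruct-2512",
    quantization_config=FineGrainedFP8Config(dequantize=True),
)
```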

This is fine on small models (24B BF16 ≈ 48 GB, fits 80 GB H100) but OOMs at scale.

Why it matters

| Model | Full BF16 footprint | Fits 80 GB H100? | Fits 141 GB H200? |
|---|---|---|---|
| Devstral-Small-2-24B (FP8) | ~48 GB | yes | yes |
| Devstral-2-123B (FP8) | ~246 GB | no | no |
| Mistral-Medium-3.5-128B (FP8) | ~256 GB | no | no |
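The footprint column is just parameter count times 2 bytes for BF16:

```python
# BF16 stores 2 bytes per parameter, so for the 123B model:
params = 123e9
print(f"~{params * 2 / 1e9:.0f} GB")  # ~246 GB, matching the table
```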

For 123B / 128B+ FP8 checkpoints there's no per-rank configuration — even TP=8 PP=8 across 64 GPUs — that lets the HF dequant path run, because the OOM happens before the model is sharded. Cranking up node count does nothing.

Today's solution

`nemo_automodel/components/models/mistral3_vlm/state_dict_adapter.py` (see `Mistral3FP8StateDictAdapter.for_vlm_full()`) hooks the DCP load so each rank only reads its own shard from disk: pull the FP8 weight + per-tensor `weight_scale_inv` for the local shard, dequantize to BF16 in place, emit the local DTensor. Per-rank peak ≈ shard size, not full model.
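A minimal sketch of the per-shard dequant step (whether the scale is per-tensor or per-block depends on the checkpoint; this shows the block-wise layout that fine-grained FP8 checkpoints typically use, and `dequantize_shard` / `BLOCK` are illustrative names, not the adapter's actual code):

```python
import torch

BLOCK = 128  # assumed block size for block-wise FP8 scales

def dequantize_shard(w_fp8: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    """Dequantize one rank-local FP8 shard to BF16; peak memory ~= shard size."""
    # Expand the per-block scales to element granularity, for this shard only.
    scales = scale_inv.repeat_interleave(BLOCK, dim=0).repeat_interleave(BLOCK, dim=1)
    scales = scales[: w_fp8.shape[0], : w_fp8.shape[1]]  # trim padded blocks
    return (w_fp8.to(torch.float32) * scales).to(torch.bfloat16)
```

The BF16 result is then wrapped as the rank-local DTensor shard, so no rank ever holds more than its own slice in BF16.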

This works but is a one-off — it lives inside the `mistral3_vlm` package and only kicks in when the resolver hook claims the config (i.e. for the dawn-ridge 128B specifically). Any other FP8-native dense / MoE checkpoint at this scale would need to ship its own adapter.

Possible shapes:

  1. Generalize the `Mistral3FP8StateDictAdapter` factory pattern into a `StreamingFP8StateDictAdapter` keyed off `config.quantization_config.quant_method == "fp8"`, registered for any model whose checkpoint declares per-tensor or per-block FP8.
  2. Push the streaming dequant upstream into `transformers/quantizers/quantizer_finegrained_fp8.py` (the loader that processes `FineGrainedFP8Config`) so HF itself reads + dequants per shard rather than per full tensor.
  3. Document the failure mode loudly: today the OOM happens deep in the HF loader with no log line saying "this can't work at this scale, use a streaming adapter". A pre-flight check in `nemo_automodel/_transformers/model_init.py` that estimates BF16 footprint vs available HBM and refuses to start would at least convert the silent OOM into a useful error (sketched below).
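One possible shape for that pre-flight check (a sketch; `preflight_dequant_check` and its error text are hypothetical, not existing nemo_automodel API):

```python
import torch

def preflight_dequant_check(num_params: int, safety_margin: float = 0.9) -> None:
    """Refuse dequantize=true loads whose full BF16 model cannot fit per rank."""
    need = num_params * 2  # BF16 = 2 bytes per parameter
    free_hbm, _ = torch.cuda.mem_get_info()
    if need > free_hbm * safety_margin:
        raise RuntimeError(
            f"dequantize=true materializes the full BF16 model (~{need / 1e9:.0f} GB) "
            f"on every rank before sharding, but only ~{free_hbm / 1e9:.0f} GB HBM is "
            "free; use a streaming FP8 state-dict adapter instead."
        )
```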

Repro

Any FP8 checkpoint > ~80 GB BF16 with `FineGrainedFP8Config(dequantize=true)`, e.g. `mistralai/Devstral-2-123B-Instruct-2512` on 8 × H100-80GB:

```yaml
model:
  target: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: mistralai/Devstral-2-123B-Instruct-2512
  force_hf: true
  quantization_config:
    target: transformers.FineGrainedFP8Config
    dequantize: true
```

OOMs at HF dequant before any sharding.
