diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py index 6b1e0754e..e7d1bd863 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py @@ -20,7 +20,13 @@ from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2 from .. import max_logging import orbax.checkpoint as ocp -from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding +from flax import nnx +from maxdiffusion.checkpointing.checkpointing_utils import ( + add_sharding_to_struct, + get_cpu_mesh_and_sharding, + create_orbax_checkpoint_manager, + WAN_CHECKPOINT, +) from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer @@ -83,20 +89,100 @@ def load_diffusers_checkpoint(self): pipeline = WanPipeline2_2.from_pretrained(self.config) return pipeline + def _get_pretrained_orbax_dir(self) -> str: + return getattr(self.config, "pretrained_orbax_dir", "") + + def save_pretrained_checkpoint(self, pretrained_dir: str, pipeline: WanPipeline2_2): + """Save pretrained weights (no optimizer state) to orbax for fast subsequent loads.""" + max_logging.log(f"Saving pretrained WAN 2.2 weights to orbax at {pretrained_dir}") + pretrained_mgr = create_orbax_checkpoint_manager( + pretrained_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=WAN_CHECKPOINT, + use_async=False, + ) + _, low_state, _ = nnx.split(pipeline.low_noise_transformer, nnx.Param, ...) + _, high_state, _ = nnx.split(pipeline.high_noise_transformer, nnx.Param, ...) + low_params = low_state.to_pure_dict() + high_params = high_state.to_pure_dict() + wan_config = json.loads(pipeline.low_noise_transformer.to_json_string()) + pretrained_mgr.save( + 0, + args=ocp.args.Composite( + wan_config=ocp.args.JsonSave(wan_config), + low_noise_transformer_state=ocp.args.StandardSave(low_params), + high_noise_transformer_state=ocp.args.StandardSave(high_params), + ), + ) + pretrained_mgr.wait_until_finished() + max_logging.log(f"Pretrained weights saved to {pretrained_dir}") + + def load_pretrained_from_orbax(self, pretrained_dir: str) -> Tuple[Optional[object], Optional[int]]: + """Load pretrained weights from orbax cache if available.""" + try: + pretrained_mgr = create_orbax_checkpoint_manager( + pretrained_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=WAN_CHECKPOINT, + use_async=False, + ) + step = pretrained_mgr.latest_step() + if step is None: + max_logging.log(f"No pretrained orbax checkpoint found in {pretrained_dir}") + return None, None + max_logging.log(f"Found pretrained orbax checkpoint (step {step}) in {pretrained_dir}") + mesh, replicated_sharding = get_cpu_mesh_and_sharding() + metadatas = pretrained_mgr.item_metadata(step) + low_meta = metadatas.low_noise_transformer_state + high_meta = metadatas.high_noise_transformer_state + target_shardings_low = jax.tree_util.tree_map(lambda x: replicated_sharding, low_meta) + target_shardings_high = jax.tree_util.tree_map(lambda x: replicated_sharding, high_meta) + with mesh: + abstract_low = jax.tree_util.tree_map(add_sharding_to_struct, low_meta, target_shardings_low) + abstract_high = jax.tree_util.tree_map(add_sharding_to_struct, high_meta, target_shardings_high) + max_logging.log("Restoring pretrained WAN 2.2 weights from orbax") + restored = pretrained_mgr.restore( + step, + args=ocp.args.Composite( + wan_config=ocp.args.JsonRestore(), + low_noise_transformer_state=ocp.args.StandardRestore(abstract_low), + high_noise_transformer_state=ocp.args.StandardRestore(abstract_high), + ), + ) + return restored, step + except Exception as e: # pylint: disable=broad-except + max_logging.log(f"Failed to load pretrained orbax checkpoint from {pretrained_dir}: {e}") + return None, None + def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_2, Optional[dict], Optional[int]]: + pretrained_dir = self._get_pretrained_orbax_dir() + + # 1. Fast path: load from pretrained orbax cache (skips diffusers entirely). + if pretrained_dir: + restored, loaded_step = self.load_pretrained_from_orbax(pretrained_dir) + if restored is not None: + max_logging.log("Loading WAN 2.2 pipeline from pretrained orbax checkpoint") + pipeline = WanPipeline2_2.from_checkpoint(self.config, restored) + return pipeline, None, loaded_step + + # 2. Try training checkpoint from checkpoint_dir. restored_checkpoint, step = self.load_wan_configs_from_orbax(step) opt_state = None if restored_checkpoint: - max_logging.log("Loading WAN pipeline from checkpoint") + max_logging.log("Loading WAN pipeline from training checkpoint") pipeline = WanPipeline2_2.from_checkpoint(self.config, restored_checkpoint) - # Check for optimizer state in either transformer if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"] elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"] else: - max_logging.log("No checkpoint found, loading default pipeline.") + # 3. Slow path: load from diffusers, then cache to orbax for next time. + max_logging.log("No checkpoint found, loading pipeline from diffusers.") pipeline = self.load_diffusers_checkpoint() + if pretrained_dir: + self.save_pretrained_checkpoint(pretrained_dir, pipeline) return pipeline, opt_state, step diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py index ce3cc7bb1..1845f5e8b 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py @@ -20,7 +20,13 @@ from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2 from .. import max_logging import orbax.checkpoint as ocp -from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding +from flax import nnx +from maxdiffusion.checkpointing.checkpointing_utils import ( + add_sharding_to_struct, + get_cpu_mesh_and_sharding, + create_orbax_checkpoint_manager, + WAN_CHECKPOINT, +) from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer @@ -83,20 +89,100 @@ def load_diffusers_checkpoint(self): pipeline = WanPipelineI2V_2_2.from_pretrained(self.config) return pipeline + def _get_pretrained_orbax_dir(self) -> str: + return getattr(self.config, "pretrained_orbax_dir", "") + + def save_pretrained_checkpoint(self, pretrained_dir: str, pipeline: WanPipelineI2V_2_2): + """Save pretrained weights (no optimizer state) to orbax for fast subsequent loads.""" + max_logging.log(f"Saving pretrained WAN 2.2 I2V weights to orbax at {pretrained_dir}") + pretrained_mgr = create_orbax_checkpoint_manager( + pretrained_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=WAN_CHECKPOINT, + use_async=False, + ) + _, low_state, _ = nnx.split(pipeline.low_noise_transformer, nnx.Param, ...) + _, high_state, _ = nnx.split(pipeline.high_noise_transformer, nnx.Param, ...) + low_params = low_state.to_pure_dict() + high_params = high_state.to_pure_dict() + wan_config = json.loads(pipeline.low_noise_transformer.to_json_string()) + pretrained_mgr.save( + 0, + args=ocp.args.Composite( + wan_config=ocp.args.JsonSave(wan_config), + low_noise_transformer_state=ocp.args.StandardSave(low_params), + high_noise_transformer_state=ocp.args.StandardSave(high_params), + ), + ) + pretrained_mgr.wait_until_finished() + max_logging.log(f"Pretrained weights saved to {pretrained_dir}") + + def load_pretrained_from_orbax(self, pretrained_dir: str) -> Tuple[Optional[object], Optional[int]]: + """Load pretrained weights from orbax cache if available.""" + try: + pretrained_mgr = create_orbax_checkpoint_manager( + pretrained_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=WAN_CHECKPOINT, + use_async=False, + ) + step = pretrained_mgr.latest_step() + if step is None: + max_logging.log(f"No pretrained orbax checkpoint found in {pretrained_dir}") + return None, None + max_logging.log(f"Found pretrained orbax checkpoint (step {step}) in {pretrained_dir}") + mesh, replicated_sharding = get_cpu_mesh_and_sharding() + metadatas = pretrained_mgr.item_metadata(step) + low_meta = metadatas.low_noise_transformer_state + high_meta = metadatas.high_noise_transformer_state + target_shardings_low = jax.tree_util.tree_map(lambda x: replicated_sharding, low_meta) + target_shardings_high = jax.tree_util.tree_map(lambda x: replicated_sharding, high_meta) + with mesh: + abstract_low = jax.tree_util.tree_map(add_sharding_to_struct, low_meta, target_shardings_low) + abstract_high = jax.tree_util.tree_map(add_sharding_to_struct, high_meta, target_shardings_high) + max_logging.log("Restoring pretrained WAN 2.2 I2V weights from orbax") + restored = pretrained_mgr.restore( + step, + args=ocp.args.Composite( + wan_config=ocp.args.JsonRestore(), + low_noise_transformer_state=ocp.args.StandardRestore(abstract_low), + high_noise_transformer_state=ocp.args.StandardRestore(abstract_high), + ), + ) + return restored, step + except Exception as e: # pylint: disable=broad-except + max_logging.log(f"Failed to load pretrained orbax checkpoint from {pretrained_dir}: {e}") + return None, None + def load_checkpoint(self, step=None) -> Tuple[WanPipelineI2V_2_2, Optional[dict], Optional[int]]: + pretrained_dir = self._get_pretrained_orbax_dir() + + # 1. Fast path: load from pretrained orbax cache (skips diffusers entirely). + if pretrained_dir: + restored, loaded_step = self.load_pretrained_from_orbax(pretrained_dir) + if restored is not None: + max_logging.log("Loading WAN 2.2 I2V pipeline from pretrained orbax checkpoint") + pipeline = WanPipelineI2V_2_2.from_checkpoint(self.config, restored) + return pipeline, None, loaded_step + + # 2. Try training checkpoint from checkpoint_dir. restored_checkpoint, step = self.load_wan_configs_from_orbax(step) opt_state = None if restored_checkpoint: - max_logging.log("Loading WAN pipeline from checkpoint") + max_logging.log("Loading WAN pipeline from training checkpoint") pipeline = WanPipelineI2V_2_2.from_checkpoint(self.config, restored_checkpoint) - # Check for optimizer state in either transformer if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"] elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"] else: - max_logging.log("No checkpoint found, loading default pipeline.") + # 3. Slow path: load from diffusers, then cache to orbax for next time. + max_logging.log("No checkpoint found, loading pipeline from diffusers.") pipeline = self.load_diffusers_checkpoint() + if pretrained_dir: + self.save_pretrained_checkpoint(pretrained_dir, pipeline) return pipeline, opt_state, step diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index bf29fa867..7bbde39db 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -253,6 +253,10 @@ names_which_can_be_offloaded: [] # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 checkpoint_dir: "" +# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads. +# On first run (slow, diffusers load), weights are saved here automatically. +# On subsequent runs, weights are loaded from here instead (~10x faster). +pretrained_orbax_dir: "" # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index 90799524c..fdcd7de5a 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -252,6 +252,10 @@ names_which_can_be_offloaded: [] # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 checkpoint_dir: "" +# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads. +# On first run (slow, diffusers load), weights are saved here automatically. +# On subsequent runs, weights are loaded from here instead (~10x faster). +pretrained_orbax_dir: "" # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: False diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 2c294a124..0ffd0a10a 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -48,12 +48,27 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t common_components = cls._create_common_components(config, vae_only) low_noise_transformer, high_noise_transformer = None, None if not vae_only and load_transformer: + # Restructure the combined checkpoint into per-transformer checkpoints. + # create_sharded_logical_transformer expects {"wan_config": ..., "wan_state": ...}. + if restored_checkpoint is not None: + low_noise_ckpt = { + "wan_config": restored_checkpoint["wan_config"], + "wan_state": restored_checkpoint["low_noise_transformer_state"], + } + high_noise_ckpt = { + "wan_config": restored_checkpoint["wan_config"], + "wan_state": restored_checkpoint["high_noise_transformer_state"], + } + else: + low_noise_ckpt = None + high_noise_ckpt = None + low_noise_transformer = super().load_transformer( devices_array=common_components["devices_array"], mesh=common_components["mesh"], rngs=common_components["rngs"], config=config, - restored_checkpoint=restored_checkpoint, + restored_checkpoint=low_noise_ckpt, subfolder="transformer_2", ) high_noise_transformer = super().load_transformer( @@ -61,7 +76,7 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t mesh=common_components["mesh"], rngs=common_components["rngs"], config=config, - restored_checkpoint=restored_checkpoint, + restored_checkpoint=high_noise_ckpt, subfolder="transformer", ) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index 1ba54f2eb..9c8044f12 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -51,12 +51,27 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t low_noise_transformer, high_noise_transformer = None, None if not vae_only: if load_transformer: + # Restructure the combined checkpoint into per-transformer checkpoints. + # create_sharded_logical_transformer expects {"wan_config": ..., "wan_state": ...}. + if restored_checkpoint is not None: + high_noise_ckpt = { + "wan_config": restored_checkpoint["wan_config"], + "wan_state": restored_checkpoint["high_noise_transformer_state"], + } + low_noise_ckpt = { + "wan_config": restored_checkpoint["wan_config"], + "wan_state": restored_checkpoint["low_noise_transformer_state"], + } + else: + high_noise_ckpt = None + low_noise_ckpt = None + high_noise_transformer = super().load_transformer( devices_array=common_components["devices_array"], mesh=common_components["mesh"], rngs=common_components["rngs"], config=config, - restored_checkpoint=restored_checkpoint, + restored_checkpoint=high_noise_ckpt, subfolder="transformer", ) low_noise_transformer = super().load_transformer( @@ -64,7 +79,7 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t mesh=common_components["mesh"], rngs=common_components["rngs"], config=config, - restored_checkpoint=restored_checkpoint, + restored_checkpoint=low_noise_ckpt, subfolder="transformer_2", )