Add Support for DeepSpeed's exclude_frozen_parameters argument in DeepSpeedStrategy #20949

@tempoxylophone

Description & Motivation

When training LoRA adapters for a large pretrained model, we freeze the original model weights during training. When checkpointing, we may only want to save the adapter weights, which are small relative to the base model.

There are a few ways to implement this, but when using DeepSpeed together with Lightning, a single keyword argument takes care of it: exclude_frozen_parameters in the call to DeepSpeed's deepspeed.DeepSpeedEngine.save_checkpoint function.
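
For reference, this is what the underlying DeepSpeed call looks like outside of Lightning (a minimal sketch; engine is assumed to be a deepspeed.DeepSpeedEngine that has already been created via deepspeed.initialize):

# `engine` is assumed to be an initialized deepspeed.DeepSpeedEngine wrapping a
# model whose frozen layers have requires_grad=False
engine.save_checkpoint(
    "checkpoints/",                   # target directory
    tag="checkpoint",
    exclude_frozen_parameters=True,   # skip parameters with requires_grad=False
)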

This can dramatically reduce disk usage; in my use case, enabling it dropped the .ckpt folder size from ~50 GB to ~2 GB.

This feature should allow a user running DeepSpeed with Lightning to pass this parameter down to the DeepSpeed engine so that exclude_frozen_parameters can be set according to the user's needs.

The impact of this feature is ultimately:

  • the ability to configure a high-level boolean flag in DeepSpeed
  • potentially dramatic disk-space savings with very few code changes

Pitch

Small Proof of Concept

I have implemented this through a hack that verifiably works in my environment. The code below is NOT how I propose to implement the feature, but a very small proof of concept that works today with Lightning v2.5.2:

from functools import partial
from typing import Any, Optional

import lightning.pytorch as pl
from typing_extensions import override

from lightning.pytorch.strategies import DeepSpeedStrategy
# reuse the module-level WarningCache instance from the strategy module
from lightning.pytorch.strategies.deepspeed import warning_cache


# this save_checkpoint function is adapted from:
# https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/strategies/deepspeed.html#DeepSpeedStrategy.save_checkpoint
@override
def save_checkpoint(self, checkpoint: dict, filepath, storage_options: Optional[Any] = None) -> None:
    """Save model/training states as a checkpoint file through state-dump and file-write.

    Args:
        checkpoint: The checkpoint state dictionary
        filepath: write-target file's path
        storage_options: not used for ``DeepSpeedStrategy`` as ``CheckpointIO`` is not used

    Raises:
        TypeError:
            If ``storage_options`` arg is passed in

    """
    # broadcast the filepath from rank 0 to ensure all the states are saved in a common filepath
    filepath = self.broadcast(filepath)

    if storage_options is not None:
        raise TypeError(
            "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
            f" is not supported for `{self.__class__.__name__}` as `CheckpointIO` is not used."
        )

    if self.zero_stage_3 and self._multi_device and self.is_global_zero:
        warning_cache.warn(
            "When saving the DeepSpeed Stage 3 checkpoint, "
            "each worker will save a shard of the checkpoint within a directory. "
            "If a single file is required after training, "
            "see https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#"
            "deepspeed-zero-stage-3-single-file for instructions."
        )

    # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
    # dump states as a checkpoint dictionary object
    _exclude_keys = ["state_dict", "optimizer_states"]
    checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
    self.deepspeed_engine.save_checkpoint(
        filepath,
        client_state=checkpoint,
        tag="checkpoint",
        # --- 🟢 THE ONLY MODIFICATION 🟢 ---
        exclude_frozen_parameters=True,
        # ----------------------------
    )


def patch_strategy(trainer: pl.Trainer) -> None:
    """Mutate the trainer's strategy in place if it is a DeepSpeedStrategy."""
    strategy = trainer.strategy
    if not isinstance(strategy, DeepSpeedStrategy):
        # nothing to do
        return
    # bind the patched function as this strategy instance's save_checkpoint
    strategy.save_checkpoint = partial(save_checkpoint, strategy)
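
Applying the patch is then a one-liner after the Trainer is built (a sketch; LitModel and MyDataModule stand in for the user's own module and datamodule):

trainer = pl.Trainer(strategy="deepspeed_stage_2", accelerator="gpu", devices=2)
patch_strategy(trainer)  # checkpoints written from now on skip frozen weights
trainer.fit(LitModel(), datamodule=MyDataModule())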

Implementation Suggestion

To implement this in a more reasonable way, we can add a new argument to the constructor of lightning.pytorch.strategies.DeepSpeedStrategy:

def __init__(
        self,
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        zero_optimization: bool = True,
        stage: int = 2,
        remote_device: Optional[str] = None,
        offload_optimizer: bool = False,
        offload_parameters: bool = False,
        offload_params_device: str = "cpu",
        nvme_path: str = "/local_nvme",
        params_buffer_count: int = 5,
        params_buffer_size: int = 100_000_000,
        max_in_cpu: int = 1_000_000_000,
        offload_optimizer_device: str = "cpu",
        optimizer_buffer_count: int = 4,
        block_size: int = 1048576,
        queue_depth: int = 8,
        single_submit: bool = False,
        overlap_events: bool = True,
        thread_count: int = 1,
        pin_memory: bool = False,
        sub_group_size: int = 1_000_000_000_000,
        contiguous_gradients: bool = True,
        overlap_comm: bool = True,
        allgather_partitions: bool = True,
        reduce_scatter: bool = True,
        allgather_bucket_size: int = 200_000_000,
        reduce_bucket_size: int = 200_000_000,
        zero_allow_untested_optimizer: bool = True,
        logging_batch_size_per_gpu: Union[str, int] = "auto",
        config: Optional[Union[_PATH, dict[str, Any]]] = None,
        logging_level: int = logging.WARN,
        parallel_devices: Optional[list[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        loss_scale: float = 0,
        initial_scale_power: int = 16,
        loss_scale_window: int = 1000,
        hysteresis: int = 2,
        min_loss_scale: int = 1,
        partition_activations: bool = False,
        cpu_checkpointing: bool = False,
        contiguous_memory_optimization: bool = False,
        synchronize_checkpoint_boundary: bool = False,
        load_full_weights: bool = False,
        precision_plugin: Optional[Precision] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = default_pg_timeout,
        # ---- here ----
        on_checkpoint_save_exclude_frozen_parameters: bool = False,
    ) -> None:
    ...
    # after the default config is set up:
    self.on_checkpoint_save_exclude_frozen_parameters = on_checkpoint_save_exclude_frozen_parameters

We can then modify the save_checkpoint function to reference this attribute:

def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None:
    """Save model/training states as a checkpoint file through state-dump and file-write.

    Args:
        checkpoint: The checkpoint state dictionary
        filepath: write-target file's path
        storage_options: not used for ``DeepSpeedStrategy`` as ``CheckpointIO`` is not used

    Raises:
        TypeError:
            If ``storage_options`` arg is passed in

    """
    # keep everything else the same as before
    self.deepspeed_engine.save_checkpoint(
        filepath,
        client_state=checkpoint,
        tag="checkpoint",
        exclude_frozen_parameters=self.on_checkpoint_save_exclude_frozen_parameters,
    )

There is likely a bit more work to be done (which unit and integration tests should hopefully catch), but the above describes the general idea.
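
From the user's side, enabling the behavior would then be a single constructor argument (a sketch; on_checkpoint_save_exclude_frozen_parameters is the name proposed here and does not exist in current releases):

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(
    stage=2,
    on_checkpoint_save_exclude_frozen_parameters=True,  # proposed flag
)
trainer = Trainer(strategy=strategy, accelerator="gpu", devices=2)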

Alternatives

Through Callbacks

One can create a callback that prunes the checkpoint after it has been saved to disk. This is problematic if RAM (even CPU RAM) is limited, however, because the checkpoint must be loaded back in order to edit it; with a .pt file larger than 10 GB this can cause an OOM. Furthermore, that approach does not avoid any wasted work; it adds more work on top.
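
For illustration, the pruning step would look roughly like this (a sketch under two assumptions that are not part of this proposal: the checkpoint is a single file, and adapter weights can be identified by a name prefix such as "lora_"):

import torch

def prune_checkpoint(path: str, keep_prefix: str = "lora_") -> None:
    # the whole file must fit in CPU RAM just to edit it
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    state_dict = ckpt.get("state_dict", {})
    ckpt["state_dict"] = {k: v for k, v in state_dict.items() if keep_prefix in k}
    torch.save(ckpt, path)  # rewrites everything that was just written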

Through Lifecycle Function Override of Modules

An experienced Lightning user may achieve this by overriding the state_dict, on_save_checkpoint, or load_state_dict hooks in their modules. I tried to explicitly save only the adapter weights this way and was ultimately unsuccessful in cutting down the checkpoint size. I am not an experienced Lightning user, but I think letting newcomers use this DeepSpeed flag without much work aligns with the project philosophy.
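
For reference, the kind of override meant here looks roughly like the following (a sketch; note that, as the proof of concept above shows, DeepSpeedStrategy drops "state_dict" from the client state and lets the engine write the weights itself, which is consistent with this hook alone not shrinking DeepSpeed checkpoints):

import lightning.pytorch as pl

class AdapterOnlyModule(pl.LightningModule):
    def state_dict(self, *args, **kwargs):
        full = super().state_dict(*args, **kwargs)
        trainable = {name for name, p in self.named_parameters() if p.requires_grad}
        # keep only entries that correspond to trainable (adapter) parameters
        return {k: v for k, v in full.items() if k in trainable}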

Additional context

I am using the preconfigured strategy "deepspeed_stage_2", and this works for me on a SLURM cluster with multiple GPUs. I am unfamiliar with Lightning's internals and inexperienced with the framework, so my understanding of this code change may be incorrect or have wider implications than I realize.

cc @lantiga @Borda
