### Description & Motivation
When training LoRA adapters for a large, pretrained model, the original model weights are frozen during training. For a checkpoint, we might be interested in saving only the adapter weights, which can be tiny compared to the base model.

There are a few ways to implement this, but when using DeepSpeed in tandem with Lightning, a single keyword argument takes care of it: `exclude_frozen_parameters` in the call to DeepSpeed's `deepspeed.DeepSpeedEngine.save_checkpoint` function.

This can dramatically reduce disk usage; in my use case, enabling it dropped the `.ckpt` folder size from ~50 GB to ~2 GB.

This feature should allow a user who runs DeepSpeed in tandem with Lightning to pass this parameter down into the DeepSpeed engine, so that `exclude_frozen_parameters` can be adjusted to the user's needs.
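For context, the setup is the usual LoRA-style split between a frozen base model and a small trainable adapter. The sketch below uses stand-in modules (not my actual model) purely to illustrate why nearly all of the checkpoint consists of frozen weights:

```python
# Stand-in modules, for illustration only: in a LoRA-style setup the frozen
# base model dominates the parameter count, so a checkpoint that skips frozen
# parameters is a small fraction of the full size.
import torch.nn as nn

base_model = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=1024, nhead=8), num_layers=24)
adapter = nn.Linear(1024, 1024)  # stand-in for the trainable LoRA weights

for p in base_model.parameters():
    p.requires_grad = False  # freeze the base model

frozen = sum(p.numel() for p in base_model.parameters())
trainable = sum(p.numel() for p in adapter.parameters())
print(f"frozen: {frozen:,} params, trainable: {trainable:,} params")
```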
The impact of this feature is ultimately:

- the ability to configure a high-level boolean flag in DeepSpeed's checkpointing call
- potentially dramatic savings in disk space with very few code changes
### Pitch

#### Small Proof of Concept
I have implemented this through a hack, which verifiably works in my environment. The snippet below is NOT how I am proposing to implement the feature, but a very small proof of concept that works today with Lightning v2.5.2:
```python
# this save_checkpoint function is copied from:
# https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/strategies/deepspeed.html#DeepSpeedStrategy.save_checkpoint
from functools import partial

import lightning.pytorch as pl
from lightning.pytorch.strategies import DeepSpeedStrategy
from lightning.pytorch.strategies.deepspeed import warning_cache  # module-level cache used by the original implementation
from typing_extensions import override


@override
def save_checkpoint(self, checkpoint: dict, filepath, storage_options=None) -> None:
    """Save model/training states as a checkpoint file through state-dump and file-write.

    Args:
        checkpoint: The checkpoint state dictionary
        filepath: write-target file's path
        storage_options: not used for ``DeepSpeedStrategy`` as ``CheckpointIO`` is not used

    Raises:
        TypeError:
            If ``storage_options`` arg is passed in

    """
    # broadcast the filepath from rank 0 to ensure all the states are saved in a common filepath
    filepath = self.broadcast(filepath)
    if storage_options is not None:
        raise TypeError(
            "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
            f" is not supported for `{self.__class__.__name__}` as `CheckpointIO` is not used."
        )
    if self.zero_stage_3 and self._multi_device and self.is_global_zero:
        warning_cache.warn(
            "When saving the DeepSpeed Stage 3 checkpoint, "
            "each worker will save a shard of the checkpoint within a directory. "
            "If a single file is required after training, "
            "see https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#"
            "deepspeed-zero-stage-3-single-file for instructions."
        )
    # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
    # dump states as a checkpoint dictionary object
    _exclude_keys = ["state_dict", "optimizer_states"]
    checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
    self.deepspeed_engine.save_checkpoint(
        filepath,
        client_state=checkpoint,
        tag="checkpoint",
        # --- 🟢 THE ONLY MODIFICATION 🟢 ---
        exclude_frozen_parameters=True,
        # -----------------------------------
    )


def patch_strategy(trainer: pl.Trainer) -> None:
    """Mutate the trainer's strategy in place if it is a DeepSpeedStrategy."""
    strategy = trainer.strategy
    if not isinstance(strategy, DeepSpeedStrategy):
        # nothing to do
        return
    # bind the patched function as the strategy's save_checkpoint method
    strategy.save_checkpoint = partial(save_checkpoint, strategy)
```
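For completeness, this is roughly how the patch is applied; `MyLitModule` and `MyDataModule` are placeholders for my own module and datamodule:

```python
# Hypothetical usage of the patch above; MyLitModule / MyDataModule are placeholders
# for my own LightningModule and LightningDataModule.
import lightning.pytorch as pl

trainer = pl.Trainer(strategy="deepspeed_stage_2", accelerator="gpu", devices=2, precision="bf16-mixed")
patch_strategy(trainer)  # after this, saved checkpoints omit frozen parameters
trainer.fit(MyLitModule(), datamodule=MyDataModule())
```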
#### Implementation Suggestion
To implement this in a more reasonable way, we can add a new argument to the constructor of `lightning.pytorch.strategies.DeepSpeedStrategy`:
```python
def __init__(
    self,
    accelerator: Optional["pl.accelerators.Accelerator"] = None,
    zero_optimization: bool = True,
    stage: int = 2,
    remote_device: Optional[str] = None,
    offload_optimizer: bool = False,
    offload_parameters: bool = False,
    offload_params_device: str = "cpu",
    nvme_path: str = "/local_nvme",
    params_buffer_count: int = 5,
    params_buffer_size: int = 100_000_000,
    max_in_cpu: int = 1_000_000_000,
    offload_optimizer_device: str = "cpu",
    optimizer_buffer_count: int = 4,
    block_size: int = 1048576,
    queue_depth: int = 8,
    single_submit: bool = False,
    overlap_events: bool = True,
    thread_count: int = 1,
    pin_memory: bool = False,
    sub_group_size: int = 1_000_000_000_000,
    contiguous_gradients: bool = True,
    overlap_comm: bool = True,
    allgather_partitions: bool = True,
    reduce_scatter: bool = True,
    allgather_bucket_size: int = 200_000_000,
    reduce_bucket_size: int = 200_000_000,
    zero_allow_untested_optimizer: bool = True,
    logging_batch_size_per_gpu: Union[str, int] = "auto",
    config: Optional[Union[_PATH, dict[str, Any]]] = None,
    logging_level: int = logging.WARN,
    parallel_devices: Optional[list[torch.device]] = None,
    cluster_environment: Optional[ClusterEnvironment] = None,
    loss_scale: float = 0,
    initial_scale_power: int = 16,
    loss_scale_window: int = 1000,
    hysteresis: int = 2,
    min_loss_scale: int = 1,
    partition_activations: bool = False,
    cpu_checkpointing: bool = False,
    contiguous_memory_optimization: bool = False,
    synchronize_checkpoint_boundary: bool = False,
    load_full_weights: bool = False,
    precision_plugin: Optional[Precision] = None,
    process_group_backend: Optional[str] = None,
    timeout: Optional[timedelta] = default_pg_timeout,
    # ---- the proposed new argument ----
    on_checkpoint_save_exclude_frozen_parameters: bool = False,
) -> None:
    ...
    # after the default config is set up:
    self.on_checkpoint_save_exclude_frozen_parameters = on_checkpoint_save_exclude_frozen_parameters
```
We can then modify the `save_checkpoint` function to reference this parameter:
```python
def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None:
    """Save model/training states as a checkpoint file through state-dump and file-write.

    Args:
        checkpoint: The checkpoint state dictionary
        filepath: write-target file's path
        storage_options: not used for ``DeepSpeedStrategy`` as ``CheckpointIO`` is not used

    Raises:
        TypeError:
            If ``storage_options`` arg is passed in

    """
    # keep everything else the same as before
    self.deepspeed_engine.save_checkpoint(
        filepath,
        client_state=checkpoint,
        tag="checkpoint",
        exclude_frozen_parameters=self.on_checkpoint_save_exclude_frozen_parameters,
    )
```
There is likely a bit more work to be done (which unit and integration tests should hopefully catch), but the above describes the general idea.
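From the user's side, opting in would then look something like the sketch below (note that `on_checkpoint_save_exclude_frozen_parameters` does not exist yet; it is the argument proposed above):

```python
# Sketch of the proposed user-facing API; the new keyword argument is the one
# suggested above and is not part of Lightning today.
import lightning.pytorch as pl
from lightning.pytorch.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(stage=2, on_checkpoint_save_exclude_frozen_parameters=True)
trainer = pl.Trainer(strategy=strategy, accelerator="gpu", devices=2)
```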
### Alternatives

#### Through Callbacks
One can create a callback that prunes the checkpoint after it is saved to disk. This is problematic, however, when RAM (even CPU RAM) is limited, because it requires loading the checkpoint back into memory in order to edit it. If the `.pt` file is larger than 10 GB, this can cause an OOM. Furthermore, that approach does not avoid any wasted work; it adds more work for the computer to do.
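Whether wrapped in a `Callback` or run as a separate post-processing step, the core of that approach would look roughly like this (illustrative only; it assumes a single consolidated `.pt` file and that the user can supply the frozen parameter names):

```python
import torch


def prune_frozen_from_checkpoint(ckpt_path: str, frozen_keys: set[str]) -> None:
    """Sketch of the post-hoc alternative: reload a saved checkpoint, drop the
    frozen weights, and write it back. The whole checkpoint must fit in CPU RAM,
    which is exactly the drawback described above (and with ZeRO stage 2/3 the
    on-disk checkpoint is a directory of shards, not a single .pt file)."""
    ckpt = torch.load(ckpt_path, map_location="cpu")  # loads everything into RAM
    state_dict = ckpt.get("state_dict", {})
    for key in list(state_dict):
        if key in frozen_keys:
            del state_dict[key]
    torch.save(ckpt, ckpt_path)
```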
#### Through Lifecycle Function Override of Modules
An experienced Lightning user may achieve this by overriding the `state_dict`, `on_save_checkpoint`, and `load_state_dict` functions in their modules. I have tried to explicitly save only adapter weights this way and was ultimately unsuccessful in cutting down the checkpoint size. I am not an experienced Lightning user, but I think letting newcomers use this DeepSpeed flag without too much work aligns with the project philosophy.
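For reference, the kind of override I mean is roughly the following; it is a sketch of the idea, not a working solution (in my attempts it did not reduce the on-disk size, presumably because the DeepSpeed engine saves its own sharded weights regardless):

```python
# Sketch of the module-level alternative: export only trainable (adapter)
# parameters from ``state_dict``. Buffers and any frozen-but-required tensors
# would need extra handling, and ``load_state_dict`` / ``on_save_checkpoint``
# would have to tolerate the missing keys.
import lightning.pytorch as pl


class AdapterOnlyModule(pl.LightningModule):
    def state_dict(self, *args, **kwargs):
        full = super().state_dict(*args, **kwargs)
        trainable = {name for name, p in self.named_parameters() if p.requires_grad}
        return {k: v for k, v in full.items() if k in trainable}
```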
### Additional context
I am using the preconfigured strategy `"deepspeed_stage_2"`. This works for me on a SLURM cluster with multiple GPUs. I am unfamiliar with how Lightning works internally and am inexperienced with the framework, so my understanding of this code change may be incorrect or have wider implications than I realize.