Enable multi rank safetensor consolidation #1625
Conversation
This is good. One general question: can the work be split across all ranks on different nodes? Or does this require the assumption that the underlying file system is distributed?
    output_dir=checkpoint_id,
    fqn_to_index_mapping=self.sd_adapter.fqn_to_index_mapping,
    num_threads=5,
)
Does this API take PG as an argument?
No, it's just inferring it: https://github.com/pytorch/pytorch/blob/e20f6d798606f3245686e950c43635bbe526232d/torch/distributed/checkpoint/_consolidate_hf_safetensors.py#L650. Do you think it should?
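For context, the inference amounts to something like the minimal sketch below; the helper name is hypothetical, but dist.get_rank() / dist.get_world_size() falling back to the default (WORLD) group is standard torch.distributed behavior.

```python
import torch.distributed as dist

def _infer_world_info():
    # Hypothetical helper: with no explicit group argument, torch.distributed
    # falls back to the default (WORLD) group created by the first
    # init_process_group() call, which is what the consolidation API
    # currently relies on.
    assert dist.is_initialized(), "expects a process group to already exist"
    return dist.get_rank(), dist.get_world_size()
```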
It would be better if it did. I don't know what post-training will look like and how people are going to split the nodes, but if not all ranks join the checkpoint save and load, dist.get_world() is not correct. I'm not familiar with the post-training use case though. cc @tianyu-l
@fegin
I don't know either. cc: @allenwang28 if you know.
Do dist.get_rank() and dist.get_world_size() rely on the NCCL PG? It sounds a bit strange & unnecessary that CPU consolidation relies on GPU info.
@tianyu-l Yeah, your concern makes sense. I can try to think of a better way to do this; let me know if you have suggestions. Since we needed GPUs for the DCP save, I thought using torch.dist again to split the work would make sense, as we can assume that users have GPUs, even though these aren't GPU operations. It lets us use multiple CPUs to do the consolidation, if they exist, to speed it up.
But if not all ranks join the checkpoint save and load, dist.get_world() is not correct.
Could you elaborate on when this might be the case? (Just so I'm understanding correctly.)
In my mental model for Forge/post-training, we would model "training" as its own "distributed world", so e.g. the presence of inference servers won't impact the world size that Titan sees; maybe I'm missing an edge case.
Ideally we could use Monarch APIs to gather rank and world size if we wanted to avoid getting this from the process group, but of course that assumes we're always running Titan through Monarch.
@allenwang28 dist.get_world() is the default world size, which is all the ranks that participate in the first init_process_group(). If the inference server is separate from the training world, this shouldn't be a problem.
@ankitageorge That being said, it is better to have an optional argument for PG in case OSS users have a different world setting.
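A minimal sketch of what such an optional PG argument could look like; the signature is illustrative (only output_dir, fqn_to_index_mapping, and num_threads appear in this PR's diff), not the actual API that later landed upstream.

```python
from typing import Optional

import torch.distributed as dist

def consolidate_safetensors_files_on_every_rank(
    output_dir: str,
    fqn_to_index_mapping: Optional[dict] = None,
    num_threads: int = 1,
    pg: Optional[dist.ProcessGroup] = None,  # hypothetical optional argument
) -> None:
    # Fall back to the default (WORLD) group when no group is supplied,
    # preserving the current inferring behavior; OSS users with a custom
    # world setup could pass their own group instead.
    rank = dist.get_rank(group=pg)
    world_size = dist.get_world_size(group=pg)
    # ... each rank would then consolidate its share of the sharded files ...
```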
Added a pg argument in pytorch/pytorch#161421. If someone could review, that would be great.
But it sounds like this shouldn't be a blocker for this PR?
@fegin If multiple nodes are used, then we have to assume that some distributed filesystem is being used. This is true for both the single-rank and multi-rank consolidation.
Summary: Based on feedback on pytorch/torchtitan#1625, adding a pg argument to consolidate_safetensors_files_on_every_rank so that we don't infer the pg and users can supply one if needed.
Test Plan: ensure existing tests pass
Rollback Plan:
Differential Revision: D80954339
@@ -387,6 +399,14 @@ def dcp_save(
    checkpoint_id=checkpoint_save_id,
)

if to_hf and self.sd_adapter.fqn_to_index_mapping:
    consolidate_safetensors_files_on_every_rank(
What if async_save is used? Does consolidate_safetensors_files_on_every_rank work with this case?
We only use to_hf for the final checkpoint save, which isn't async. But yeah, this wouldn't work with async; we'd need to await the save and then run this.
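A rough sketch of how the async case might be handled, assuming the future returned by dcp.async_save is awaited before consolidation; the wrapper function is hypothetical, and the consolidation arguments simply mirror those shown in the diff above (the real call may take more).

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint._consolidate_hf_safetensors import (
    consolidate_safetensors_files_on_every_rank,
)

def async_save_then_consolidate(state_dict, checkpoint_id, fqn_to_index_mapping):
    # Hypothetical wrapper: stage the save asynchronously, then block on the
    # returned future before consolidating, since the sharded safetensors
    # files must be fully written to disk first.
    future = dcp.async_save(state_dict, checkpoint_id=checkpoint_id)
    future.result()  # wait for the sharded files to land
    consolidate_safetensors_files_on_every_rank(
        output_dir=checkpoint_id,
        fqn_to_index_mapping=fqn_to_index_mapping,
        num_threads=5,
    )
```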
I see. I think some check is required to ensure all the configuration combinations are correct. But we can do this in another PR.
Summary: Based on feedback on pytorch/torchtitan#1625, adding a pg argument to consolidate_safetensors_files_on_every_rank so that we don't infer the pg and users can supply one if needed.
Test Plan: ensure existing tests pass
Rollback Plan:
Differential Revision: D80954339
Pull Request resolved: #161421
Approved by: https://github.com/fegin
On saves, we were relying on rank-0 to consolidate the sharded safetensor files, as it was being done in the DCP finish step, which is only done on rank-0. We can instead rely on all ranks available to split this work, speeding up the overall save operation. For the 8B model, the save without consolidation was ~40s on my server with 8 ranks. An extra 20s was for consolidation. This is brought down to 10s with this change. For larger models with more files to be split across more ranks, I would expect larger gains. ``` (titan) [ankitageorge@devvm2888.eag0 /data/users/ankitageorge/torchtitan (multi-rank-safetensor-consolidation)]$ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh + NGPU=8 + export LOG_RANK=0,1,2,3,4 + LOG_RANK=0,1,2,3,4 + CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml + TORCHFT_LIGHTHOUSE=http://localhost:29510 + PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + TORCHFT_LIGHTHOUSE=http://localhost:29510 + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0,1,2,3,4 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml W0822 07:52:34.351000 2536224 site-packages/torch/distributed/run.py:803] W0822 07:52:34.351000 2536224 site-packages/torch/distributed/run.py:803] ***************************************** W0822 07:52:34.351000 2536224 site-packages/torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0822 07:52:34.351000 2536224 site-packages/torch/distributed/run.py:803] ***************************************** [rank3]:[titan] 2025-08-22 07:52:40,492 - root - INFO - Starting job: Llama 3 8B training [rank3]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/cuda/__init__.py:410: UserWarning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (Triggered internally at /pytorch/c10/core/AllocatorConfig.cpp:28.) [rank3]: torch._C._cuda_init() [rank2]:[titan] 2025-08-22 07:52:40,768 - root - INFO - Starting job: Llama 3 8B training [rank2]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/cuda/__init__.py:410: UserWarning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (Triggered internally at /pytorch/c10/core/AllocatorConfig.cpp:28.) [rank2]: torch._C._cuda_init() [rank0]:[titan] 2025-08-22 07:52:40,796 - root - INFO - Starting job: Llama 3 8B training [rank0]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/cuda/__init__.py:410: UserWarning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (Triggered internally at /pytorch/c10/core/AllocatorConfig.cpp:28.) [rank0]: torch._C._cuda_init() [rank1]:[titan] 2025-08-22 07:52:40,788 - root - INFO - Starting job: Llama 3 8B training [rank1]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/cuda/__init__.py:410: UserWarning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (Triggered internally at /pytorch/c10/core/AllocatorConfig.cpp:28.) 
[rank1]: torch._C._cuda_init() [rank4]:[titan] 2025-08-22 07:52:40,728 - root - INFO - Starting job: Llama 3 8B training [rank4]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/cuda/__init__.py:410: UserWarning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (Triggered internally at /pytorch/c10/core/AllocatorConfig.cpp:28.) [rank4]: torch._C._cuda_init() [rank3]:[titan] 2025-08-22 07:52:42,690 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank3]:[titan] 2025-08-22 07:52:42,692 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank3]:[titan] 2025-08-22 07:52:42,698 - root - INFO - [GC] Initial GC collection 0.00 seconds [rank4]:[titan] 2025-08-22 07:52:44,081 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank4]:[titan] 2025-08-22 07:52:44,084 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank4]:[titan] 2025-08-22 07:52:44,090 - root - INFO - [GC] Initial GC collection 0.00 seconds [rank1]:[titan] 2025-08-22 07:52:44,137 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank1]:[titan] 2025-08-22 07:52:44,139 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank1]:[titan] 2025-08-22 07:52:44,145 - root - INFO - [GC] Initial GC collection 0.00 seconds [rank0]:[titan] 2025-08-22 07:52:44,215 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank0]:[titan] 2025-08-22 07:52:44,217 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank0]:NCCL version 2.27.5+cuda12.9 [rank2]:[titan] 2025-08-22 07:52:44,238 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank2]:[titan] 2025-08-22 07:52:44,240 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank2]:[titan] 2025-08-22 07:52:44,246 - root - INFO - [GC] Initial GC collection 0.00 seconds [rank0]:[titan] 2025-08-22 07:52:44,223 - root - INFO - [GC] Initial GC collection 0.00 seconds [rank1]:[titan] 2025-08-22 07:52:50,420 - root - INFO - Loading tokenizer from tokenizer.json [rank2]:[titan] 2025-08-22 07:52:50,420 - root - INFO - Loading tokenizer from tokenizer.json [rank4]:[titan] 2025-08-22 07:52:50,421 - root - INFO - Loading tokenizer from tokenizer.json [rank0]:[titan] 2025-08-22 07:52:50,420 - root - INFO - Loading tokenizer from tokenizer.json [rank3]:[titan] 2025-08-22 07:52:50,420 - root - INFO - Loading tokenizer from tokenizer.json [rank1]:[titan] 2025-08-22 07:52:50,696 - root - INFO - Preparing c4 dataset from allenai/c4 [rank2]:[titan] 2025-08-22 07:52:50,718 - root - INFO - Preparing c4 dataset from allenai/c4 [rank4]:[titan] 2025-08-22 07:52:50,696 - root - INFO - Preparing c4 dataset from allenai/c4 [rank0]:[titan] 2025-08-22 07:52:50,718 - root - INFO - Preparing c4 dataset from allenai/c4 [rank3]:[titan] 2025-08-22 07:52:50,717 - root - INFO - Preparing c4 dataset from allenai/c4 [rank4]:[titan] 2025-08-22 07:52:56,570 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0) [rank4]:[titan] 2025-08-22 07:52:56,707 - root - 
INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory [rank4]:[titan] 2025-08-22 07:52:56,723 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank4]:[titan] 2025-08-22 07:52:56,724 - root - INFO - Applied selective activation checkpointing to the model [rank4]:[titan] 2025-08-22 07:52:56,802 - root - INFO - Applied FSDP to the model [rank0]:[titan] 2025-08-22 07:52:56,923 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0) [rank1]:[titan] 2025-08-22 07:52:56,872 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0) [rank3]:[titan] 2025-08-22 07:52:56,875 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0) [rank1]:[titan] 2025-08-22 07:52:57,012 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory [rank1]:[titan] 2025-08-22 07:52:57,028 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank1]:[titan] 2025-08-22 07:52:57,029 - root - INFO - Applied selective activation checkpointing to the model [rank2]:[titan] 2025-08-22 07:52:57,137 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0) [rank0]:[titan] 2025-08-22 07:52:57,066 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250822-0752 [rank0]:[titan] 2025-08-22 07:52:57,067 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory [rank0]:[titan] 2025-08-22 07:52:57,081 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank0]:[titan] 2025-08-22 07:52:57,082 - root - INFO - Applied selective activation checkpointing to the model [rank1]:[titan] 2025-08-22 07:52:57,104 - root - INFO - Applied FSDP to the model [rank3]:[titan] 2025-08-22 07:52:57,047 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory [rank3]:[titan] 2025-08-22 07:52:57,065 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank3]:[titan] 2025-08-22 07:52:57,066 - root - INFO - Applied selective activation checkpointing to the model [rank4]:[titan] 2025-08-22 07:52:57,048 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank4]:[titan] 2025-08-22 07:52:57,049 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%) [rank4]:[titan] 2025-08-22 07:52:57,050 - root - WARNING - Warmup steps (200) exceed total training steps (10). Adjusting warmup steps to 10. 
[rank4]:[titan] 2025-08-22 07:52:57,078 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/hf__checkpoint [rank4]:[titan] 2025-08-22 07:52:57,078 - root - INFO - Mixed precision training is handled by fully_shard [rank4]:[titan] 2025-08-22 07:52:57,078 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 10 (warmup 200) [rank4]:[titan] 2025-08-22 07:52:57,078 - root - INFO - Training starts at step 1 [rank4]:[titan] 2025-08-22 07:52:57,078 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace [rank0]:[titan] 2025-08-22 07:52:57,159 - root - INFO - Applied FSDP to the model [rank3]:[titan] 2025-08-22 07:52:57,160 - root - INFO - Applied FSDP to the model [rank2]:[titan] 2025-08-22 07:52:57,280 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory [rank2]:[titan] 2025-08-22 07:52:57,297 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank2]:[titan] 2025-08-22 07:52:57,298 - root - INFO - Applied selective activation checkpointing to the model [rank2]:[titan] 2025-08-22 07:52:57,374 - root - INFO - Applied FSDP to the model [rank0]:[titan] 2025-08-22 07:52:57,467 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank0]:[titan] 2025-08-22 07:52:57,467 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%) [rank0]:[titan] 2025-08-22 07:52:57,469 - root - WARNING - Warmup steps (200) exceed total training steps (10). Adjusting warmup steps to 10. [rank0]:[titan] 2025-08-22 07:52:57,498 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/hf__checkpoint [rank0]:[titan] 2025-08-22 07:52:57,498 - root - INFO - Mixed precision training is handled by fully_shard [rank0]:[titan] 2025-08-22 07:52:57,498 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 10 (warmup 200) [rank0]:[titan] 2025-08-22 07:52:57,498 - root - INFO - Training starts at step 1 [rank0]:[titan] 2025-08-22 07:52:57,498 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace [rank1]:[titan] 2025-08-22 07:52:57,470 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank1]:[titan] 2025-08-22 07:52:57,470 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%) [rank1]:[titan] 2025-08-22 07:52:57,472 - root - WARNING - Warmup steps (200) exceed total training steps (10). Adjusting warmup steps to 10. [rank1]:[titan] 2025-08-22 07:52:57,500 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/hf__checkpoint [rank1]:[titan] 2025-08-22 07:52:57,500 - root - INFO - Mixed precision training is handled by fully_shard [rank1]:[titan] 2025-08-22 07:52:57,500 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 10 (warmup 200) [rank1]:[titan] 2025-08-22 07:52:57,500 - root - INFO - Training starts at step 1 [rank1]:[titan] 2025-08-22 07:52:57,500 - root - INFO - Profiling active. 
Traces will be saved at ./outputs/profile_trace [rank3]:[titan] 2025-08-22 07:52:57,459 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank3]:[titan] 2025-08-22 07:52:57,459 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%) [rank3]:[titan] 2025-08-22 07:52:57,461 - root - WARNING - Warmup steps (200) exceed total training steps (10). Adjusting warmup steps to 10. [rank3]:[titan] 2025-08-22 07:52:57,489 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/hf__checkpoint [rank3]:[titan] 2025-08-22 07:52:57,489 - root - INFO - Mixed precision training is handled by fully_shard [rank3]:[titan] 2025-08-22 07:52:57,489 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 10 (warmup 200) [rank3]:[titan] 2025-08-22 07:52:57,489 - root - INFO - Training starts at step 1 [rank3]:[titan] 2025-08-22 07:52:57,489 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace [rank2]:[titan] 2025-08-22 07:52:57,633 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank2]:[titan] 2025-08-22 07:52:57,633 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%) [rank2]:[titan] 2025-08-22 07:52:57,635 - root - WARNING - Warmup steps (200) exceed total training steps (10). Adjusting warmup steps to 10. [rank2]:[titan] 2025-08-22 07:52:57,663 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/hf__checkpoint [rank2]:[titan] 2025-08-22 07:52:57,663 - root - INFO - Mixed precision training is handled by fully_shard [rank2]:[titan] 2025-08-22 07:52:57,663 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 10 (warmup 200) [rank2]:[titan] 2025-08-22 07:52:57,664 - root - INFO - Training starts at step 1 [rank2]:[titan] 2025-08-22 07:52:57,664 - root - INFO - Profiling active. 
Traces will be saved at ./outputs/profile_trace [rank1]:[titan] 2025-08-22 07:53:02,519 - root - INFO - step: 1 loss: 12.2541 grad_norm: 4.0170 memory: 39.86GiB(41.96%) tps: 1,492 tflops: 86.40 mfu: 8.74% [rank1]:[titan] 2025-08-22 07:53:02,519 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-08-22 07:53:02,519 - root - INFO - step: 1 loss: 12.2541 grad_norm: 4.0170 memory: 39.86GiB(41.96%) tps: 1,506 tflops: 87.24 mfu: 8.82% [rank0]:[titan] 2025-08-22 07:53:02,519 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank2]:[titan] 2025-08-22 07:53:02,519 - root - INFO - step: 1 loss: 12.2541 grad_norm: 4.0170 memory: 39.86GiB(41.96%) tps: 1,569 tflops: 90.84 mfu: 9.19% [rank2]:[titan] 2025-08-22 07:53:02,519 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank3]:[titan] 2025-08-22 07:53:02,519 - root - INFO - step: 1 loss: 12.2541 grad_norm: 4.0170 memory: 39.86GiB(41.96%) tps: 1,502 tflops: 86.97 mfu: 8.79% [rank3]:[titan] 2025-08-22 07:53:02,519 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank4]:[titan] 2025-08-22 07:53:02,519 - root - INFO - step: 1 loss: 12.2541 grad_norm: 4.0170 memory: 39.86GiB(41.96%) tps: 1,413 tflops: 81.85 mfu: 8.28% [rank4]:[titan] 2025-08-22 07:53:02,519 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank3]:[titan] 2025-08-22 07:53:13,006 - root - INFO - step: 10 loss: 9.8309 grad_norm: 4.4717 memory: 47.38GiB(49.87%) tps: 7,031 tflops: 407.19 mfu: 41.17% [rank3]:[titan] 2025-08-22 07:53:13,007 - root - INFO - Saving the checkpoint (or staging if async is enabled). [rank3]:[titan] 2025-08-22 07:53:13,007 - root - INFO - Saving a model only checkpoint in torch.float32 at last step, step 10. [rank0]:[titan] 2025-08-22 07:53:13,006 - root - INFO - step: 10 loss: 9.8309 grad_norm: 4.4717 memory: 47.38GiB(49.87%) tps: 7,031 tflops: 407.21 mfu: 41.17% [rank0]:[titan] 2025-08-22 07:53:13,007 - root - INFO - Saving the checkpoint (or staging if async is enabled). [rank0]:[titan] 2025-08-22 07:53:13,007 - root - INFO - Saving a model only checkpoint in torch.float32 at last step, step 10. [rank1]:[titan] 2025-08-22 07:53:13,006 - root - INFO - step: 10 loss: 9.8309 grad_norm: 4.4717 memory: 47.38GiB(49.87%) tps: 7,031 tflops: 407.19 mfu: 41.17% [rank1]:[titan] 2025-08-22 07:53:13,006 - root - INFO - Saving the checkpoint (or staging if async is enabled). [rank1]:[titan] 2025-08-22 07:53:13,006 - root - INFO - Saving a model only checkpoint in torch.float32 at last step, step 10. [rank2]:[titan] 2025-08-22 07:53:13,006 - root - INFO - step: 10 loss: 9.8309 grad_norm: 4.4717 memory: 47.38GiB(49.87%) tps: 7,031 tflops: 407.19 mfu: 41.17% [rank2]:[titan] 2025-08-22 07:53:13,006 - root - INFO - Saving the checkpoint (or staging if async is enabled). [rank2]:[titan] 2025-08-22 07:53:13,006 - root - INFO - Saving a model only checkpoint in torch.float32 at last step, step 10. [rank4]:[titan] 2025-08-22 07:53:13,006 - root - INFO - step: 10 loss: 9.8309 grad_norm: 4.4717 memory: 47.38GiB(49.87%) tps: 7,031 tflops: 407.19 mfu: 41.17% [rank4]:[titan] 2025-08-22 07:53:13,006 - root - INFO - Saving the checkpoint (or staging if async is enabled). [rank4]:[titan] 2025-08-22 07:53:13,006 - root - INFO - Saving a model only checkpoint in torch.float32 at last step, step 10. 
[rank0]:Time to save 1: 39.62316728616133 seconds [rank4]:Time to save 1: 39.62515217997134 seconds [rank0]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/distributed_c10d.py:4814: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. [rank0]: warnings.warn( # warn only once [rank1]:Time to save 1: 39.62614830210805 seconds [rank2]:Time to save 1: 39.62664035195485 seconds [rank3]:Time to save 1: 39.6256582220085 seconds [rank4]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/distributed_c10d.py:4814: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. [rank4]: warnings.warn( # warn only once [rank2]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/distributed_c10d.py:4814: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. [rank2]: warnings.warn( # warn only once [rank3]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/distributed_c10d.py:4814: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. [rank3]: warnings.warn( # warn only once [rank4]:[titan] 2025-08-22 07:54:02,586 - root - INFO - [GC] GC collection invoked by checkpointer. 0.03 seconds [rank4]:[titan] 2025-08-22 07:54:02,587 - root - INFO - Training completed [rank4]:[titan] 2025-08-22 07:54:02,587 - root - INFO - Destroying the purge thread. [rank0]:Time to save 2: 49.51617495715618 seconds [rank4]:Time to save 2: 49.517131249886006 seconds [rank1]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/distributed_c10d.py:4814: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. [rank1]: warnings.warn( # warn only once [rank1]:[titan] 2025-08-22 07:54:02,574 - root - INFO - [GC] GC collection invoked by checkpointer. 0.02 seconds [rank1]:[titan] 2025-08-22 07:54:02,575 - root - INFO - Training completed [rank1]:[titan] 2025-08-22 07:54:02,575 - root - INFO - Destroying the purge thread. [rank0]:[titan] 2025-08-22 07:54:02,575 - root - INFO - [GC] GC collection invoked by checkpointer. 0.02 seconds [rank0]:[titan] 2025-08-22 07:54:02,576 - root - INFO - Sleeping 2 seconds for other ranks to complete [rank1]:Time to save 2: 49.51811962015927 seconds [rank2]:[titan] 2025-08-22 07:54:02,574 - root - INFO - [GC] GC collection invoked by checkpointer. 0.02 seconds [rank2]:[titan] 2025-08-22 07:54:02,575 - root - INFO - Training completed [rank2]:[titan] 2025-08-22 07:54:02,575 - root - INFO - Destroying the purge thread. [rank2]:Time to save 2: 49.51844248594716 seconds [rank3]:Time to save 2: 49.51776701770723 seconds [rank3]:[titan] 2025-08-22 07:54:02,575 - root - INFO - [GC] GC collection invoked by checkpointer. 0.02 seconds [rank3]:[titan] 2025-08-22 07:54:02,576 - root - INFO - Training completed [rank3]:[titan] 2025-08-22 07:54:02,576 - root - INFO - Destroying the purge thread. [rank0]:[titan] 2025-08-22 07:54:04,576 - root - INFO - Training completed [rank0]:[titan] 2025-08-22 07:54:04,577 - root - INFO - Destroying the purge thread. 
[rank3]:[titan] 2025-08-22 07:54:05,224 - root - INFO - Process group destroyed [rank4]:[titan] 2025-08-22 07:54:05,272 - root - INFO - Process group destroyed [rank2]:[titan] 2025-08-22 07:54:05,272 - root - INFO - Process group destroyed [rank1]:[titan] 2025-08-22 07:54:05,388 - root - INFO - Process group destroyed [rank0]:[titan] 2025-08-22 07:54:05,924 - root - INFO - Process group destroyed [W822 07:54:09.143827219 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator()) ``` --------- Co-authored-by: ankitageorge <ankitageorge@devvm2888.eag0.facebook.com>
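As a rough illustration of the work-splitting idea described above (the actual partitioning inside consolidate_safetensors_files_on_every_rank may differ), each rank could take a disjoint, round-robin share of the consolidated output files:

```python
import torch.distributed as dist

def files_for_this_rank(num_output_files: int) -> list[int]:
    # Hypothetical round-robin split: rank r handles output files
    # r, r + world_size, r + 2 * world_size, ... so the consolidation work
    # scales with the number of ranks instead of running only on rank 0.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    return [i for i in range(num_output_files) if i % world_size == rank]
```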