Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed distributed initialization by passing the missing `device_id` to `init_process_group` for PyTorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))


---
Expand All @@ -33,7 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Added support for NVIDIA H200 GPUs in `get_available_flops` ([#20913](https://github.com/Lightning-AI/pytorch-lightning/pull/21119))
- Added support for NVIDIA H200 GPUs in `get_available_flops` ([#21119](https://github.com/Lightning-AI/pytorch-lightning/pull/21119))



Expand Down
6 changes: 5 additions & 1 deletion src/lightning/fabric/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning.fabric.utilities.rank_zero import rank_zero_only

_DDP_FORK_ALIASES = (
Expand Down Expand Up @@ -212,7 +213,10 @@ def _setup_distributed(self) -> None:
self._set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
kwargs: dict[str, Any] = {"timeout": self._timeout}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
_init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
Expand Down
5 changes: 4 additions & 1 deletion src/lightning/fabric/strategies/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,10 @@ def _setup_distributed(self) -> None:
self._set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
kwargs: dict[str, Any] = {"timeout": self._timeout}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
_init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
Expand Down
5 changes: 4 additions & 1 deletion src/lightning/fabric/strategies/model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,10 @@ def _setup_distributed(self) -> None:
self._set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
kwargs: dict[str, Any] = {"timeout": self._timeout}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
_init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
Expand Down
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed callbacks by deferring step/time-triggered `ModelCheckpoint` saves until validation metrics are available ([#21106](https://github.com/Lightning-AI/pytorch-lightning/pull/21106))


- Fixed distributed initialization by passing the missing `device_id` to `init_process_group` for PyTorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))


---

Expand Down
7 changes: 5 additions & 2 deletions src/lightning/pytorch/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_3
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import ReduceOp
Expand Down Expand Up @@ -200,7 +200,10 @@ def setup_distributed(self) -> None:
self.set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
kwargs: dict[str, Any] = {"timeout": self._timeout}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
_init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
Expand Down
7 changes: 5 additions & 2 deletions src/lightning/pytorch/strategies/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3
from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers
from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
from lightning.fabric.utilities.optimizer import _optimizers_to_device
Expand Down Expand Up @@ -260,7 +260,10 @@ def setup_environment(self) -> None:

self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
kwargs: dict[str, Any] = {"timeout": self._timeout}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
_init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)

# if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here
if isinstance(self.kwargs.get("device_mesh"), tuple):
Expand Down
7 changes: 5 additions & 2 deletions src/lightning/pytorch/strategies/model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3, _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.init import _materialize_distributed_module
from lightning.fabric.utilities.load import _METADATA_FILENAME
from lightning.fabric.utilities.optimizer import _optimizers_to_device
Expand Down Expand Up @@ -350,7 +350,10 @@ def _setup_distributed(self) -> None:
self.set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
kwargs: dict[str, Any] = {"timeout": self._timeout}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
_init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
Expand Down
49 changes: 48 additions & 1 deletion tests/tests_fabric/strategies/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import DDPStrategy
from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from tests_fabric.helpers.runif import RunIf


Expand Down Expand Up @@ -168,6 +169,52 @@ def test_set_timeout(init_process_group_mock):
process_group_backend = strategy._get_process_group_backend()
global_rank = strategy.cluster_environment.global_rank()
world_size = strategy.cluster_environment.world_size()
kwargs = {}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
init_process_group_mock.assert_called_with(
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
)


@mock.patch("torch.distributed.init_process_group")
def test_device_id_passed_for_cuda_devices(init_process_group_mock):
    """Check the ``device_id`` keyword forwarded to ``init_process_group`` for a CPU and a CUDA root device.

    When ``_TORCH_GREATER_EQUAL_2_3`` is true, the expected ``device_id`` is ``None`` for a CPU root device
    and the root device itself otherwise. When it is false, no ``device_id`` keyword is expected at all.

    """

    def _assert_init_call_for(device):
        # Build a strategy around a single device and run its environment setup against the mocked
        # ``init_process_group``.
        strategy = DDPStrategy(parallel_devices=[device])
        strategy.cluster_environment = LightningEnvironment()
        strategy.accelerator = Mock()
        strategy.setup_environment()

        backend = strategy._get_process_group_backend()
        rank = strategy.cluster_environment.global_rank()
        size = strategy.cluster_environment.world_size()

        # Only torch >= 2.3 is expected to receive the ``device_id`` keyword.
        expected_extra = {}
        if _TORCH_GREATER_EQUAL_2_3:
            root = strategy.root_device
            expected_extra["device_id"] = None if root.type == "cpu" else root

        init_process_group_mock.assert_called_with(
            backend,
            rank=rank,
            world_size=size,
            timeout=strategy._timeout,
            **expected_extra,
        )

    # CPU root device: ``device_id`` is expected to be ``None``.
    _assert_init_call_for(torch.device("cpu"))

    # Forget the CPU call so the second assertion only sees the CUDA call.
    init_process_group_mock.reset_mock()

    # CUDA root device: ``device_id`` is expected to be the device itself.
    _assert_init_call_for(torch.device("cuda", 0))
7 changes: 5 additions & 2 deletions tests/tests_fabric/strategies/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
_get_full_state_dict_context,
_is_sharded_checkpoint,
)
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3


def test_custom_mixed_precision():
Expand Down Expand Up @@ -381,8 +381,11 @@ def test_set_timeout(init_process_group_mock):
process_group_backend = strategy._get_process_group_backend()
global_rank = strategy.cluster_environment.global_rank()
world_size = strategy.cluster_environment.world_size()
kwargs = {}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
init_process_group_mock.assert_called_with(
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
)


Expand Down
6 changes: 5 additions & 1 deletion tests/tests_fabric/strategies/test_model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from lightning.fabric.strategies import ModelParallelStrategy
from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint
from lightning.fabric.strategies.model_parallel import _ParallelBackwardSyncControl
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from tests_fabric.helpers.runif import RunIf


Expand Down Expand Up @@ -316,8 +317,11 @@ def test_set_timeout(init_process_group_mock, _):
process_group_backend = strategy._get_process_group_backend()
global_rank = strategy.cluster_environment.global_rank()
world_size = strategy.cluster_environment.world_size()
kwargs = {}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
init_process_group_mock.assert_called_with(
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
)


Expand Down
36 changes: 35 additions & 1 deletion tests/tests_pytorch/strategies/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torch.nn.parallel import DistributedDataParallel

from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.plugins import DoublePrecision, HalfPrecision, Precision
Expand Down Expand Up @@ -132,8 +133,41 @@ def test_set_timeout(mock_init_process_group):
process_group_backend = trainer.strategy._get_process_group_backend()
global_rank = trainer.strategy.cluster_environment.global_rank()
world_size = trainer.strategy.cluster_environment.world_size()
kwargs = {}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = trainer.strategy.root_device if trainer.strategy.root_device.type != "cpu" else None
mock_init_process_group.assert_called_with(
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
)


@mock.patch("torch.distributed.init_process_group")
def test_device_id_passed_for_cuda_devices_pytorch(mock_init_process_group):
    """Check the ``device_id`` keyword forwarded to ``init_process_group`` by a CPU ``Trainer`` using DDP.

    Only the CPU accelerator is exercised here. When ``_TORCH_GREATER_EQUAL_2_3`` is true, the expected
    ``device_id`` is ``None`` for a CPU root device (and the root device itself otherwise); when it is false,
    no ``device_id`` keyword is expected at all.

    """
    model = BoringModel()
    trainer = Trainer(max_epochs=1, accelerator="cpu", strategy=DDPStrategy())
    strategy = trainer.strategy

    # Attach the model and wire the trainer reference before running the environment setup.
    strategy.connect(model)
    trainer.lightning_module.trainer = trainer
    strategy.setup_environment()

    backend = strategy._get_process_group_backend()
    rank = strategy.cluster_environment.global_rank()
    size = strategy.cluster_environment.world_size()

    # Only torch >= 2.3 is expected to receive the ``device_id`` keyword.
    expected_extra = {}
    if _TORCH_GREATER_EQUAL_2_3:
        root = strategy.root_device
        expected_extra["device_id"] = None if root.type == "cpu" else root

    mock_init_process_group.assert_called_with(
        backend,
        rank=rank,
        world_size=size,
        timeout=strategy._timeout,
        **expected_extra,
    )


Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/strategies/test_ddp_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_multi_gpu_model_ddp_fit_test(tmp_path):
assert out["test_acc"] > 0.7


@RunIf(skip_windows=True)
@RunIf(skip_windows=True, max_torch="2.7")
@mock.patch("torch.cuda.set_device")
@mock.patch("lightning.pytorch.accelerators.cuda._check_cuda_matmul_precision")
@mock.patch("lightning.pytorch.accelerators.cuda._clear_cuda_memory")
Expand Down
7 changes: 5 additions & 2 deletions tests/tests_pytorch/strategies/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3
from lightning.fabric.utilities.load import _load_distributed_checkpoint
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
Expand Down Expand Up @@ -532,8 +532,11 @@ def test_set_timeout(init_process_group_mock):
process_group_backend = strategy._get_process_group_backend()
global_rank = strategy.cluster_environment.global_rank()
world_size = strategy.cluster_environment.world_size()
kwargs = {}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
init_process_group_mock.assert_called_with(
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
)


Expand Down
6 changes: 5 additions & 1 deletion tests/tests_pytorch/strategies/test_model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import torch.nn as nn

from lightning.fabric.strategies.model_parallel import _is_sharded_checkpoint
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning.pytorch import LightningModule
from lightning.pytorch.plugins.environments import LightningEnvironment
from lightning.pytorch.strategies import ModelParallelStrategy
Expand Down Expand Up @@ -202,8 +203,11 @@ def test_set_timeout(init_process_group_mock, _):
process_group_backend = strategy._get_process_group_backend()
global_rank = strategy.cluster_environment.global_rank()
world_size = strategy.cluster_environment.world_size()
kwargs = {}
if _TORCH_GREATER_EQUAL_2_3:
kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
init_process_group_mock.assert_called_with(
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
)


Expand Down
Loading