diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 16cd42adc90d9..cb37fcc6a937f 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed missing `device_id` when initializing the distributed process group for PyTorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))
 
 ---
 
@@ -33,7 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
-- Added support for NVIDIA H200 GPUs in `get_available_flops` ([#20913](https://github.com/Lightning-AI/pytorch-lightning/pull/21119))
+- Added support for NVIDIA H200 GPUs in `get_available_flops` ([#21119](https://github.com/Lightning-AI/pytorch-lightning/pull/21119))
 
 
diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py
index ce47e4e403c34..e826b910c16d3 100644
--- a/src/lightning/fabric/strategies/ddp.py
+++ b/src/lightning/fabric/strategies/ddp.py
@@ -41,6 +41,7 @@
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 
 _DDP_FORK_ALIASES = (
@@ -212,7 +213,10 @@ def _setup_distributed(self) -> None:
         self._set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        kwargs: dict[str, Any] = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+        _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)
 
     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index 972cb0a2cd840..baaee74af0ec9 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -663,7 +663,10 @@ def _setup_distributed(self) -> None:
         self._set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        kwargs: dict[str, Any] = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+        _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)
 
     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py
index ace23a9c7a2c5..0d49ddf91a0bc 100644
--- a/src/lightning/fabric/strategies/model_parallel.py
+++ b/src/lightning/fabric/strategies/model_parallel.py
@@ -302,7 +302,10 @@ def _setup_distributed(self) -> None:
         self._set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        kwargs: dict[str, Any] = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+        _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)
 
     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
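Reviewer note: the gate added to the three Fabric strategies above boils down to the standalone sketch below: build the `init_process_group` keyword arguments once, and only add `device_id` on torch >= 2.3 (matching the version check used in the patch), leaving CPU/gloo groups unchanged. The helper name, the hard-coded timeout, and the env:// rendezvous assumption are illustrative only, not part of the patch.

```python
from datetime import timedelta

import torch
import torch.distributed as dist

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3


def _init_connection_sketch(backend: str, rank: int, world_size: int, root_device: torch.device) -> None:
    # Mirror of the strategy logic: always pass the timeout, add device_id only where supported.
    kwargs: dict = {"timeout": timedelta(minutes=30)}
    if _TORCH_GREATER_EQUAL_2_3:
        # None preserves the previous behaviour for CPU (gloo) process groups.
        kwargs["device_id"] = root_device if root_device.type != "cpu" else None
    # Assumes MASTER_ADDR/MASTER_PORT are set, as Lightning's cluster environment does.
    dist.init_process_group(backend, rank=rank, world_size=world_size, **kwargs)
```

Binding the group to a concrete device is what the changelog entry refers to for PyTorch 2.8, where omitting `device_id` leaves torch to infer the device for collectives.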
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 176e34273d776..03664c8e2d1ad 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -31,6 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed callbacks by defer step/time-triggered `ModelCheckpoint` saves until validation metrics are available ([#21106](https://github.com/Lightning-AI/pytorch-lightning/pull/21106))
 
+- Fixed missing `device_id` when initializing the distributed process group for PyTorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))
+
 
 ---
 
diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index fd3f66ef42471..92206e1accc31 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -36,7 +36,7 @@
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _IS_WINDOWS
+from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_3
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import ReduceOp
@@ -200,7 +200,10 @@ def setup_distributed(self) -> None:
         self.set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        kwargs: dict[str, Any] = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+        _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)
 
     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
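Reviewer note: the Trainer-side change above is exercised by any multi-GPU DDP run. A minimal sketch, assuming a machine with two CUDA devices, that goes through `setup_distributed()` and therefore through the new `device_id` path on torch >= 2.3:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

# Each rank's root_device (cuda:0, cuda:1) is now forwarded as device_id when the
# process group is created; on CPU-only runs the extra argument stays None.
trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp", max_epochs=1, logger=False)
trainer.fit(BoringModel())
```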
diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py
index 55ea354a5cb60..3fbd0f9cd5f0a 100644
--- a/src/lightning/pytorch/strategies/fsdp.py
+++ b/src/lightning/pytorch/strategies/fsdp.py
@@ -61,7 +61,7 @@
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3
 from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers
 from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
@@ -260,7 +260,10 @@ def setup_environment(self) -> None:
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        kwargs: dict[str, Any] = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+        _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)
 
         # if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here
         if isinstance(self.kwargs.get("device_mesh"), tuple):
diff --git a/src/lightning/pytorch/strategies/model_parallel.py b/src/lightning/pytorch/strategies/model_parallel.py
index 82fec205af731..e0286dbe2e0e6 100644
--- a/src/lightning/pytorch/strategies/model_parallel.py
+++ b/src/lightning/pytorch/strategies/model_parallel.py
@@ -39,7 +39,7 @@
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3, _TORCH_GREATER_EQUAL_2_4
 from lightning.fabric.utilities.init import _materialize_distributed_module
 from lightning.fabric.utilities.load import _METADATA_FILENAME
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
@@ -350,7 +350,10 @@ def _setup_distributed(self) -> None:
         self.set_world_ranks()
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
-        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+        kwargs: dict[str, Any] = {"timeout": self._timeout}
+        if _TORCH_GREATER_EQUAL_2_3:
+            kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+        _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)
 
     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py
index fa5c975228a5e..f302da5d1bc4f 100644
--- a/tests/tests_fabric/strategies/test_ddp.py
+++ b/tests/tests_fabric/strategies/test_ddp.py
@@ -25,6 +25,7 @@
 from lightning.fabric.plugins.environments import LightningEnvironment
 from lightning.fabric.strategies import DDPStrategy
 from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
 from tests_fabric.helpers.runif import RunIf
 
 
@@ -168,6 +169,52 @@ def test_set_timeout(init_process_group_mock):
     process_group_backend = strategy._get_process_group_backend()
     global_rank = strategy.cluster_environment.global_rank()
     world_size = strategy.cluster_environment.world_size()
+    kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_3:
+        kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
     init_process_group_mock.assert_called_with(
-        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
+    )
+
+
+@mock.patch("torch.distributed.init_process_group")
+def test_device_id_passed_for_cuda_devices(init_process_group_mock):
+    """Test that device_id is passed to init_process_group for CUDA devices but not for CPU."""
+    # Test with CPU device - device_id should be None
+    cpu_strategy = DDPStrategy(parallel_devices=[torch.device("cpu")])
+    cpu_strategy.cluster_environment = LightningEnvironment()
+    cpu_strategy.accelerator = Mock()
+    cpu_strategy.setup_environment()
+
+    process_group_backend = cpu_strategy._get_process_group_backend()
+    global_rank = cpu_strategy.cluster_environment.global_rank()
+    world_size = cpu_strategy.cluster_environment.world_size()
+    kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_3:
+        kwargs["device_id"] = cpu_strategy.root_device if cpu_strategy.root_device.type != "cpu" else None
+    init_process_group_mock.assert_called_with(
+        process_group_backend, rank=global_rank, world_size=world_size, timeout=cpu_strategy._timeout, **kwargs
+    )
+
+    init_process_group_mock.reset_mock()
+
+    # Test with CUDA device - device_id should be the device
+    cuda_device = torch.device("cuda", 0)
+    cuda_strategy = DDPStrategy(parallel_devices=[cuda_device])
+    cuda_strategy.cluster_environment = LightningEnvironment()
+    cuda_strategy.accelerator = Mock()
+    cuda_strategy.setup_environment()
+
+    process_group_backend = cuda_strategy._get_process_group_backend()
+    global_rank = cuda_strategy.cluster_environment.global_rank()
+    world_size = cuda_strategy.cluster_environment.world_size()
+    kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_3:
+        kwargs["device_id"] = cuda_strategy.root_device if cuda_strategy.root_device.type != "cpu" else None
+    init_process_group_mock.assert_called_with(
+        process_group_backend,
+        rank=global_rank,
+        world_size=world_size,
+        timeout=cuda_strategy._timeout,
+        **kwargs,
     )
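Reviewer note: the same three-line `kwargs` block is now repeated in every `test_set_timeout` variant. A possible follow-up (purely a suggestion, the helper below does not exist in the repo) would be a small shared test utility that mirrors the strategy-side gate:

```python
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3


def _expected_init_kwargs(strategy, timeout):
    # Hypothetical test helper: build the kwargs the strategy is expected to pass.
    kwargs = {"timeout": timeout}
    if _TORCH_GREATER_EQUAL_2_3:
        kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
    return kwargs


# Usage inside a test:
# init_process_group_mock.assert_called_with(
#     process_group_backend,
#     rank=global_rank,
#     world_size=world_size,
#     **_expected_init_kwargs(strategy, test_timedelta),
# )
```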
diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py
index d5f82752a9176..6be379d36582c 100644
--- a/tests/tests_fabric/strategies/test_fsdp.py
+++ b/tests/tests_fabric/strategies/test_fsdp.py
@@ -31,7 +31,7 @@
     _get_full_state_dict_context,
     _is_sharded_checkpoint,
 )
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3
 
 
 def test_custom_mixed_precision():
@@ -381,8 +381,11 @@ def test_set_timeout(init_process_group_mock):
     process_group_backend = strategy._get_process_group_backend()
     global_rank = strategy.cluster_environment.global_rank()
     world_size = strategy.cluster_environment.world_size()
+    kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_3:
+        kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
     init_process_group_mock.assert_called_with(
-        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
     )
diff --git a/tests/tests_fabric/strategies/test_model_parallel.py b/tests/tests_fabric/strategies/test_model_parallel.py
index d044626bf8389..0e38f6e7777d1 100644
--- a/tests/tests_fabric/strategies/test_model_parallel.py
+++ b/tests/tests_fabric/strategies/test_model_parallel.py
@@ -25,6 +25,7 @@
 from lightning.fabric.strategies import ModelParallelStrategy
 from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint
 from lightning.fabric.strategies.model_parallel import _ParallelBackwardSyncControl
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
 from tests_fabric.helpers.runif import RunIf
 
 
@@ -316,8 +317,11 @@ def test_set_timeout(init_process_group_mock, _):
     process_group_backend = strategy._get_process_group_backend()
     global_rank = strategy.cluster_environment.global_rank()
     world_size = strategy.cluster_environment.world_size()
+    kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_3:
+        kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
     init_process_group_mock.assert_called_with(
-        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
     )
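Reviewer note: the reason every assertion had to change is that `assert_called_with` compares the complete argument set of the most recent call, so once the strategies pass `device_id` the old assertions no longer match. A self-contained illustration with plain `unittest.mock`, nothing Lightning-specific:

```python
from unittest.mock import Mock

init_pg = Mock()
init_pg("nccl", rank=0, world_size=2, timeout=None, device_id=None)

# Passes only because device_id is included; omitting it raises AssertionError,
# which is exactly what would happen to the old assertions on torch >= 2.3.
init_pg.assert_called_with("nccl", rank=0, world_size=2, timeout=None, device_id=None)
```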
a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py
index 915e57440b40f..823d77d0d5848 100644
--- a/tests/tests_pytorch/strategies/test_ddp.py
+++ b/tests/tests_pytorch/strategies/test_ddp.py
@@ -20,6 +20,7 @@
 from torch.nn.parallel import DistributedDataParallel
 
 from lightning.fabric.plugins.environments import LightningEnvironment
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.plugins import DoublePrecision, HalfPrecision, Precision
@@ -132,8 +133,41 @@ def test_set_timeout(mock_init_process_group):
     process_group_backend = trainer.strategy._get_process_group_backend()
     global_rank = trainer.strategy.cluster_environment.global_rank()
     world_size = trainer.strategy.cluster_environment.world_size()
+    kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_3:
+        kwargs["device_id"] = trainer.strategy.root_device if trainer.strategy.root_device.type != "cpu" else None
     mock_init_process_group.assert_called_with(
-        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
+    )
+
+
+@mock.patch("torch.distributed.init_process_group")
+def test_device_id_passed_for_cuda_devices_pytorch(mock_init_process_group):
+    """Test that device_id is passed to init_process_group for CUDA devices but not for CPU."""
+    # Test with CPU device - device_id should be None
+    model = BoringModel()
+    ddp_strategy = DDPStrategy()
+    trainer = Trainer(
+        max_epochs=1,
+        accelerator="cpu",
+        strategy=ddp_strategy,
+    )
+    trainer.strategy.connect(model)
+    trainer.lightning_module.trainer = trainer
+    trainer.strategy.setup_environment()
+
+    process_group_backend = trainer.strategy._get_process_group_backend()
+    global_rank = trainer.strategy.cluster_environment.global_rank()
+    world_size = trainer.strategy.cluster_environment.world_size()
+    kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_3:
+        kwargs["device_id"] = trainer.strategy.root_device if trainer.strategy.root_device.type != "cpu" else None
+    mock_init_process_group.assert_called_with(
+        process_group_backend,
+        rank=global_rank,
+        world_size=world_size,
+        timeout=trainer.strategy._timeout,
+        **kwargs,
     )
diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py
index 048403366ebc7..fc3a8cfebbac0 100644
--- a/tests/tests_pytorch/strategies/test_ddp_integration.py
+++ b/tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -66,7 +66,7 @@ def test_multi_gpu_model_ddp_fit_test(tmp_path):
     assert out["test_acc"] > 0.7
 
 
-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, max_torch="2.7")
 @mock.patch("torch.cuda.set_device")
 @mock.patch("lightning.pytorch.accelerators.cuda._check_cuda_matmul_precision")
 @mock.patch("lightning.pytorch.accelerators.cuda._clear_cuda_memory")
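Reviewer note: the `RunIf(skip_windows=True, max_torch="2.7")` change above restricts the mocked-CUDA integration test to older torch. A rough standalone equivalent with plain pytest, assuming `max_torch` means "skip once the installed torch reaches that version":

```python
import pytest
import torch
from packaging.version import Version

# Hypothetical marker with the assumed semantics of RunIf(max_torch="2.7").
requires_torch_below_2_7 = pytest.mark.skipif(
    Version(torch.__version__) >= Version("2.7.0"),
    reason="requires torch < 2.7",
)
```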
_TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3
 from lightning.fabric.utilities.load import _load_distributed_checkpoint
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import ModelCheckpoint
@@ -532,8 +532,11 @@ def test_set_timeout(init_process_group_mock):
     process_group_backend = strategy._get_process_group_backend()
     global_rank = strategy.cluster_environment.global_rank()
     world_size = strategy.cluster_environment.world_size()
+    kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_3:
+        kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
     init_process_group_mock.assert_called_with(
-        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
     )
diff --git a/tests/tests_pytorch/strategies/test_model_parallel.py b/tests/tests_pytorch/strategies/test_model_parallel.py
index 86a95944ac20d..c803c10afa4b4 100644
--- a/tests/tests_pytorch/strategies/test_model_parallel.py
+++ b/tests/tests_pytorch/strategies/test_model_parallel.py
@@ -22,6 +22,7 @@
 import torch.nn as nn
 
 from lightning.fabric.strategies.model_parallel import _is_sharded_checkpoint
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
 from lightning.pytorch import LightningModule
 from lightning.pytorch.plugins.environments import LightningEnvironment
 from lightning.pytorch.strategies import ModelParallelStrategy
@@ -202,8 +203,11 @@ def test_set_timeout(init_process_group_mock, _):
     process_group_backend = strategy._get_process_group_backend()
     global_rank = strategy.cluster_environment.global_rank()
     world_size = strategy.cluster_environment.world_size()
+    kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_3:
+        kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
     init_process_group_mock.assert_called_with(
-        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
     )
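Reviewer note: for completeness, the Fabric entry point that ends up in the patched `_setup_distributed()`; a minimal sketch assuming a machine with two CUDA devices:

```python
from lightning.fabric import Fabric

# fabric.launch() spawns one process per device; each rank's DDPStrategy now forwards
# device_id=<its cuda device> to torch.distributed.init_process_group on torch >= 2.3.
fabric = Fabric(accelerator="cuda", devices=2, strategy="ddp")
fabric.launch()
print(fabric.global_rank, fabric.device)
```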