diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py
index fe2fa1b476be6..3e0c9d8a1652a 100644
--- a/deepspeed/launcher/multinode_runner.py
+++ b/deepspeed/launcher/multinode_runner.py
@@ -134,12 +134,40 @@ def name(self):
 
     def validate_args(self):
         super().validate_args()
+
+        # Validate and set MPI environment variables
+        self._setup_mpi_environment()
+
         #TODO: Allow for include/exclude at node-level but not gpu-level
         if self.args.include != "" or self.args.exclude != "":
             raise ValueError(f"{self.name} backend does not support worker include/exclusion")
         if self.args.num_nodes != -1 or self.args.num_gpus != -1:
             raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")
 
+    def _setup_mpi_environment(self):
+        """Export LOCAL_RANK, RANK and WORLD_SIZE from the Open MPI environment.
+
+        Raises:
+            EnvironmentError: if any required OMPI_COMM_WORLD_* variable is unset.
+        """
+        # NOTE(review): these variables are normally set by mpirun for the ranks it
+        # spawns; confirm they are expected in the process that calls validate_args().
+        env_map = {
+            'OMPI_COMM_WORLD_LOCAL_RANK': 'LOCAL_RANK',
+            'OMPI_COMM_WORLD_RANK': 'RANK',
+            'OMPI_COMM_WORLD_SIZE': 'WORLD_SIZE',
+        }
+
+        # Name the missing variables so the failure is actionable
+        missing = [var for var in env_map if var not in os.environ]
+        if missing:
+            raise EnvironmentError(f"MPI environment variables are not set: {', '.join(missing)}. "
+                                   "Ensure you are running the script with an MPI-compatible launcher.")
+
+        # Now safe to read all
+        for ompi_var, target_var in env_map.items():
+            os.environ[target_var] = os.environ[ompi_var]
+
     def get_cmd(self, environment, active_resources):
         total_process_count = sum(self.resource_pool.values())
 
diff --git a/tests/unit/launcher/test_multinode_runner.py b/tests/unit/launcher/test_multinode_runner.py
index a3b50a4c90ab2..801d4223afcec 100644
--- a/tests/unit/launcher/test_multinode_runner.py
+++ b/tests/unit/launcher/test_multinode_runner.py
@@ -19,6 +19,14 @@ def runner_info():
     return env, hosts, world_info, args
 
 
+@pytest.fixture
+def mock_mpi_env(monkeypatch):
+    # Provide the 3 required MPI variables:
+    monkeypatch.setenv('OMPI_COMM_WORLD_LOCAL_RANK', '0')
+    monkeypatch.setenv('OMPI_COMM_WORLD_RANK', '0')
+    monkeypatch.setenv('OMPI_COMM_WORLD_SIZE', '1')
+
+
 def test_pdsh_runner(runner_info):
     env, resource_pool, world_info, args = runner_info
     runner = mnrunner.PDSHRunner(args, world_info)
@@ -27,7 +35,7 @@ def test_pdsh_runner(runner_info):
     assert env['PDSH_RCMD_TYPE'] == 'ssh'
 
 
-def test_openmpi_runner(runner_info):
+def test_openmpi_runner(runner_info, mock_mpi_env):
     env, resource_pool, world_info, args = runner_info
     runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool)
     cmd = runner.get_cmd(env, resource_pool)
@@ -35,26 +43,69 @@ def test_openmpi_runner(runner_info):
     assert 'eth0' in cmd
 
 
-def test_btl_nic_openmpi_runner(runner_info):
+def test_btl_nic_openmpi_runner(runner_info, mock_mpi_env):
     env, resource_pool, world_info, _ = runner_info
     args = parse_args(['--launcher_arg', '-mca btl_tcp_if_include eth1', 'test_launcher.py'])
-
     runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool)
     cmd = runner.get_cmd(env, resource_pool)
     assert 'eth0' not in cmd
     assert 'eth1' in cmd
 
 
-def test_btl_nic_two_dashes_openmpi_runner(runner_info):
+def test_btl_nic_two_dashes_openmpi_runner(runner_info, mock_mpi_env):
     env, resource_pool, world_info, _ = runner_info
     args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])
-
     runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool)
     cmd = runner.get_cmd(env, resource_pool)
     assert 'eth0' not in cmd
     assert 'eth1' in cmd
 
 
+def test_setup_mpi_environment_success(monkeypatch):
+    """Test that _setup_mpi_environment correctly sets environment variables when MPI variables exist."""
+    monkeypatch.setenv('OMPI_COMM_WORLD_LOCAL_RANK', '0')
+    monkeypatch.setenv('OMPI_COMM_WORLD_RANK', '1')
+    monkeypatch.setenv('OMPI_COMM_WORLD_SIZE', '2')
+    # Pre-register the derived variables so monkeypatch rolls back what the runner writes
+    for name in ('LOCAL_RANK', 'RANK', 'WORLD_SIZE'):
+        monkeypatch.setenv(name, 'unset')
+
+    args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])
+
+    runner = mnrunner.OpenMPIRunner(args, None, None)
+    # Set up the MPI environment
+    runner._setup_mpi_environment()
+
+    assert os.environ['LOCAL_RANK'] == '0'
+    assert os.environ['RANK'] == '1'
+    assert os.environ['WORLD_SIZE'] == '2'
+
+
+def test_setup_mpi_environment_missing_variables(monkeypatch):
+    """Test that _setup_mpi_environment raises an EnvironmentError when MPI variables are missing."""
+    # Clear relevant environment variables
+    monkeypatch.delenv('OMPI_COMM_WORLD_LOCAL_RANK', raising=False)
+    monkeypatch.delenv('OMPI_COMM_WORLD_RANK', raising=False)
+    monkeypatch.delenv('OMPI_COMM_WORLD_SIZE', raising=False)
+
+    args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])
+
+    with pytest.raises(EnvironmentError, match="MPI environment variables are not set"):
+        mnrunner.OpenMPIRunner(args, None, None)
+
+
+def test_setup_mpi_environment_fail(monkeypatch):
+    """Test that _setup_mpi_environment fails if only partial MPI variables are provided."""
+    monkeypatch.setenv('OMPI_COMM_WORLD_LOCAL_RANK', '0')
+    monkeypatch.delenv('OMPI_COMM_WORLD_RANK', raising=False)  # missing variable
+    monkeypatch.setenv('OMPI_COMM_WORLD_SIZE', '2')
+
+    args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])
+
+    with pytest.raises(EnvironmentError, match="MPI environment variables are not set: OMPI_COMM_WORLD_RANK"):
+        mnrunner.OpenMPIRunner(args, None, None)
+
+
 def test_mpich_runner(runner_info):
     env, resource_pool, world_info, args = runner_info
     runner = mnrunner.MPICHRunner(args, world_info, resource_pool)