Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid security issues of subprocess shell #6498

Merged
merged 10 commits into from
Sep 11, 2024
5 changes: 4 additions & 1 deletion bin/ds_bench
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ import sys
required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
if not all(map(lambda v: v in os.environ, required_env)):
import subprocess
subprocess.run("deepspeed $(which ds_bench) " + " ".join(sys.argv[1:]), shell=True)
r = subprocess.check_output(["which", "ds_bench"])
ds_bench_bin = r.decode('utf-8').strip()
safe_cmd = ["deepspeed", ds_bench_bin] + sys.argv[1:]
subprocess.run(safe_cmd)
else:
args = benchmark_parser().parse_args()
rank = args.local_rank
Expand Down
5 changes: 3 additions & 2 deletions csrc/aio/py_test/ds_aio_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Functionality of swapping tensors to/from (NVMe) storage devices.
"""
import subprocess
import shlex


class Job(object):
Expand Down Expand Up @@ -39,10 +40,10 @@ def close_output_file(self):


def run_job(job):
args = ' '.join(job.cmd())
args = shlex.split(' '.join(job.cmd()))
print(f'args = {args}')
job.open_output_file()
proc = subprocess.run(args=args, shell=True, stdout=job.get_stdout(), stderr=job.get_stderr(), cwd=job.get_cwd())
proc = subprocess.run(args=args, stdout=job.get_stdout(), stderr=job.get_stderr(), cwd=job.get_cwd())
job.close_output_file()
assert proc.returncode == 0, \
f"This command failed: {job.cmd()}"
5 changes: 3 additions & 2 deletions deepspeed/comm/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,8 +697,9 @@ def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True)

master_addr = None
if rank == 0:
hostname_cmd = ["hostname -I"]
result = subprocess.check_output(hostname_cmd, shell=True)
import shlex
hostname_cmd = shlex.split("hostname -I")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may not need shlex.split when the command has no placeholder for injection. The same goes for some other fixes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For hostname -I, using socket.gethostname and socket.gethostbyname_ex can be safer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tohtana, thanks for the feedback. I will remove the shlex.split() here. However, I am not getting the same output if I use socket.gethostbyname_ex(socket.gethostname())

Copy link
Contributor Author

@tjruwase tjruwase Sep 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I think the shlex.split() is still need to pass args as a list since we are removing shell=True. Alternatively, I could manually construct the list. However, I think I will keep the shlex.split() to future-proof for arg changes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about this code? This is a general proposal to make the code robust against malicious modifications of system commands. But I don't think this is crucial because it won't be a typical attack that can harm our users. We can just keep hostname -I if this doesn't work.

>>> import socket
>>> hostname = socket.gethostname()
>>> ip_addresses = socket.gethostbyname_ex(hostname)[2]
>>> ip_addresses
['172.17.0.2']

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that as well, it looks like we would need to modify the command to be:

>>> socket.gethostbyname_ex(socket.gethostname()+".local")[2][0]

Thought it is slower to run on my system with that change.

I wonder if hosts or DNS in your system (Windows?) has [HOSTNAME].local but it doesn't work on my env.

Copy link
Contributor Author

@tjruwase tjruwase Sep 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

>>> socket.gethostbyname_ex(socket.gethostname()+".local")[2][0]

This works and gives correct/expected results on my Linux lambda box.

@tohtana, do you mean this does not work for you?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't work on my wsl.

$ python -c 'import socket; socket.gethostbyname_ex(socket.gethostname()+".local")[2][0]'
Traceback (most recent call last):
  File "<string>", line 1, in <module>
socket.gaierror: [Errno -5] No address associated with hostname

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, I'd say lets leave it as hostname -I for now, and we can make another PR to update where we can more strenuously test Windows and other OSs?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree to address this in another PR. This one is urgent and focuses on security.
But I'm not sure it is a good idea to get the IP from the first entry from hostname -I. It is not simple to control it even for the administrator. It is easier to tell users to configure /etc/hosts properly.
After a quick look on the usage, probably it can also be a hostname, not an IP.

result = subprocess.check_output(hostname_cmd)
master_addr = result.decode('utf-8').split()[0]
master_addr = comm.bcast(master_addr, root=0)

Expand Down
4 changes: 3 additions & 1 deletion deepspeed/elasticity/elastic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def _set_master_addr_port(store: Store,

if master_addr is None:
# master_addr = _get_fq_hostname()
result = subprocess.check_output("hostname -I", shell=True)
import shlex
safe_cmd = shlex.split("hostname -I")
result = subprocess.check_output(safe_cmd)
master_addr = result.decode('utf-8').split()[0]

store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8"))
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/launcher/multinode_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def backend_exists(self):
if not mpiname_exists:
warnings.warn("mpiname does not exist, mvapich is not installed properly")
else:
results = subprocess.check_output('mpiname', shell=True)
results = subprocess.check_output(['mpiname'])
mpiname_results = results.decode('utf-8').strip()
if "MVAPICH2-GDR" in mpiname_results:
exists = True
Expand Down
8 changes: 5 additions & 3 deletions deepspeed/launcher/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from copy import deepcopy
import signal
import time
import shlex

from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner, IMPIRunner
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER, IMPI_LAUNCHER
Expand Down Expand Up @@ -445,7 +446,8 @@ def main(args=None):
if args.ssh_port is not None:
ssh_check_cmd += f"-p {args.ssh_port} "
ssh_check_cmd += f"{first_host} hostname"
subprocess.check_call(ssh_check_cmd, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL, shell=True)
safe_ssh_cmd = shlex.split(ssh_check_cmd)
subprocess.check_call(safe_ssh_cmd, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
except subprocess.CalledProcessError:
raise RuntimeError(
f"Using hostfile at {args.hostfile} but host={first_host} was not reachable via ssh. If you are running with a single node please remove {args.hostfile} or setup passwordless ssh."
Expand All @@ -458,9 +460,9 @@ def main(args=None):
if args.ssh_port is not None:
ssh_check_cmd += f" -p {args.ssh_port}"
ssh_check_cmd += f" {first_host} hostname -I"
hostname_cmd = [ssh_check_cmd]
hostname_cmd = shlex.split(ssh_check_cmd)
try:
result = subprocess.check_output(hostname_cmd, shell=True)
result = subprocess.check_output(hostname_cmd)
except subprocess.CalledProcessError as err:
logger.error(
"Unable to detect suitable master address via `hostname -I`, please manually specify one via --master_addr"
Expand Down
8 changes: 5 additions & 3 deletions op_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,8 @@ def get_rocm_gpu_arch():
rocm_info = Path("rocminfo")
rocm_gpu_arch_cmd = str(rocm_info) + " | grep -o -m 1 'gfx.*'"
try:
result = subprocess.check_output(rocm_gpu_arch_cmd, shell=True)
safe_cmd = shlex.split(rocm_gpu_arch_cmd)
result = subprocess.check_output(safe_cmd)
rocm_gpu_arch = result.decode('utf-8').strip()
except subprocess.CalledProcessError:
rocm_gpu_arch = ""
Expand All @@ -271,7 +272,8 @@ def get_rocm_wavefront_size():
rocm_wavefront_size_cmd = str(
rocm_info) + " | grep -Eo -m1 'Wavefront Size:[[:space:]]+[0-9]+' | grep -Eo '[0-9]+'"
try:
result = subprocess.check_output(rocm_wavefront_size_cmd, shell=True)
safe_cmd = shlex.split(rocm_wavefront_size_cmd)
result = subprocess.check_output(rocm_wavefront_size_cmd)
rocm_wavefront_size = result.decode('utf-8').strip()
except subprocess.CalledProcessError:
rocm_wavefront_size = "32"
Expand Down Expand Up @@ -432,7 +434,7 @@ def _backup_cpuinfo(self):
"to detect the CPU architecture. 'lscpu' does not appear to exist on "
"your system, will fall back to use -march=native and non-vectorized execution.")
return None
result = subprocess.check_output('lscpu', shell=True)
result = subprocess.check_output(['lscpu'])
result = result.decode('utf-8').strip().lower()

cpu_info = {}
Expand Down
14 changes: 8 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from setuptools.command import egg_info
import time
import typing
import shlex

torch_available = True
try:
Expand Down Expand Up @@ -157,10 +158,11 @@ def get_env_if_set(key, default: typing.Any = ""):

def command_exists(cmd):
if sys.platform == "win32":
result = subprocess.Popen(f'{cmd}', stdout=subprocess.PIPE, shell=True)
safe_cmd = shlex.split(f'{cmd}')
result = subprocess.Popen(safe_cmd, stdout=subprocess.PIPE)
return result.wait() == 1
else:
safe_cmd = ["bash", "-c", f"type {cmd}"]
safe_cmd = shlex.split(f"bash -c type {cmd}")
result = subprocess.Popen(safe_cmd, stdout=subprocess.PIPE)
return result.wait() == 0

Expand Down Expand Up @@ -200,13 +202,13 @@ def op_enabled(op_name):
print(f'Install Ops={install_ops}')

# Write out version/git info.
git_hash_cmd = "git rev-parse --short HEAD"
git_branch_cmd = "git rev-parse --abbrev-ref HEAD"
git_hash_cmd = shlex.split("bash -c git rev-parse --short HEAD")
git_branch_cmd = shlex.split("bash -c git rev-parse --abbrev-ref HEAD")
if command_exists('git') and not is_env_set('DS_BUILD_STRING'):
try:
result = subprocess.check_output(git_hash_cmd, shell=True)
result = subprocess.check_output(git_hash_cmd)
git_hash = result.decode('utf-8').strip()
result = subprocess.check_output(git_branch_cmd, shell=True)
result = subprocess.check_output(git_branch_cmd)
git_branch = result.decode('utf-8').strip()
except subprocess.CalledProcessError:
git_hash = "unknown"
Expand Down
9 changes: 5 additions & 4 deletions tests/model/BingBertSquad/BingBertSquad_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import subprocess
import os
import time
import shlex


class BaseTestCase(unittest.TestCase):
Expand Down Expand Up @@ -40,18 +41,18 @@ def ensure_directory_exists(self, filename):
os.makedirs(dirname)

def clean_test_env(self):
cmd = "dlts_ssh pkill -9 -f /usr/bin/python"
cmd = shlex.split("dlts_ssh pkill -9 -f /usr/bin/python")
print(cmd)
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash')
subprocess.run(cmd, check=False, executable='/bin/bash')
time.sleep(20)

def run_BingBertSquad_test(self, test_config, output):
ds_flag = " -d --deepspeed_config " + test_config["json"] if test_config["deepspeed"] else " "
other_args = " " + test_config["other_args"] if "other_args" in test_config else " "

cmd = "./run_BingBertSquad_sanity.sh -e 1 -g {0} {1} {2}".format(test_config["gpus"], other_args, ds_flag)

cmd = shlex.split(cmd)
self.ensure_directory_exists(output)
with open(output, "w") as f:
print(cmd)
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash', stdout=f, stderr=f)
subprocess.run(cmd, check=False, executable='/bin/bash', stdout=f, stderr=f)
21 changes: 11 additions & 10 deletions tests/model/Megatron_GPT2/run_checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import subprocess
import os
import re
import shlex
from .test_common import BaseTestCase

LAYERS = 2
Expand All @@ -18,9 +19,9 @@


def remove_file(test_id, filename):
cmd = f"if [ -f {filename} ] ; then rm -v {filename}; fi"
cmd = shlex.split(f"if [ -f {filename} ] ; then rm -v {filename}; fi")
print(f"{test_id} cmd: {cmd}")
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash')
subprocess.run(cmd, check=False, executable='/bin/bash')


def grep_loss_from_file(file_name):
Expand Down Expand Up @@ -451,9 +452,9 @@ def run_test(self, test_config, r_tol):
checkpoint_name = test_config["checkpoint_name"]
#---------------remove old checkpoint---------------#
try:
cmd = f"rm -rf {checkpoint_name}"
cmd = shlex.split(f"rm -rf {checkpoint_name}")
print(f"{self.id()} cmd: {cmd}")
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash')
subprocess.run(cmd, check=False, executable='/bin/bash')
except:
print("No old checkpoint")

Expand All @@ -474,8 +475,8 @@ def run_test(self, test_config, r_tol):

# remove previous test log
try:
cmd = f"rm {base_file}"
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash')
cmd = shlex.split(f"rm {base_file}")
subprocess.run(cmd, check=False, executable='/bin/bash')
except:
print(f"{self.id()} No old logs")

Expand All @@ -489,9 +490,9 @@ def run_test(self, test_config, r_tol):

# set checkpoint load iteration
try:
cmd = f"echo {checkpoint_interval} > {checkpoint_name}/latest_checkpointed_iteration.txt"
cmd = shlex.split(f"echo {checkpoint_interval} > {checkpoint_name}/latest_checkpointed_iteration.txt")
print(f"{self.id()} running cmd: {cmd}")
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash')
subprocess.run(cmd, check=False, executable='/bin/bash')
except:
print(f"{self.id()} Failed to update the checkpoint iteration file")
return False
Expand All @@ -506,8 +507,8 @@ def run_test(self, test_config, r_tol):

# remove previous test log
try:
cmd = f"rm {test_file}"
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash')
cmd = shlex.split(f"rm {test_file}")
subprocess.run(cmd, check=False, executable='/bin/bash')
except:
print(f"{self.id()} no previous logs for")
self.run_gpt2_test(test_config, test_file)
Expand Down
9 changes: 5 additions & 4 deletions tests/model/Megatron_GPT2/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import subprocess
import os
import time
import shlex


class BaseTestCase(unittest.TestCase):
Expand Down Expand Up @@ -46,9 +47,9 @@ def ensure_directory_exists(self, filename):
os.makedirs(dirname)

def clean_test_env(self):
cmd = "dlts_ssh pkill -9 -f /usr/bin/python"
cmd = shlex.split("dlts_ssh pkill -9 -f /usr/bin/python")
print(cmd)
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash')
subprocess.run(cmd, check=False, executable='/bin/bash')
time.sleep(20)

def run_gpt2_test(self, test_config, output):
Expand All @@ -60,8 +61,8 @@ def run_gpt2_test(self, test_config, output):
test_config["mp"], test_config["gpus"], test_config["nodes"], test_config["bs"], test_config["steps"],
test_config["layers"], test_config["hidden_size"], test_config["seq_length"], test_config["heads"],
ckpt_num, other_args, ds_flag)

cmd = shlex.split(cmd)
self.ensure_directory_exists(output)
with open(output, "w") as f:
print(cmd)
subprocess.run(cmd, shell=True, check=False, executable='/bin/bash', stdout=f, stderr=f)
subprocess.run(cmd, check=False, executable='/bin/bash', stdout=f, stderr=f)
18 changes: 15 additions & 3 deletions tests/unit/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,20 @@ def get_master_port(base_port=29500, port_range_size=1000):
raise IOError('no free ports')


def _get_cpu_socket_count():
import shlex
p1 = subprocess.Popen(shlex.split("cat /proc/cpuinfo"), stdout=subprocess.PIPE)
p2 = subprocess.Popen(["grep", "physical id"], stdin=p1.stdout, stdout=subprocess.PIPE)
p1.stdout.close()
p3 = subprocess.Popen(shlex.split("sort -u"), stdin=p2.stdout, stdout=subprocess.PIPE)
p2.stdout.close()
p4 = subprocess.Popen(shlex.split("wc -l"), stdin=p3.stdout, stdout=subprocess.PIPE)
p3.stdout.close()
r = int(p4.communicate()[0])
p4.stdout.close()
return r


def set_accelerator_visible():
cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
xdist_worker_id = get_xdist_worker_id()
Expand Down Expand Up @@ -95,9 +109,7 @@ def set_accelerator_visible():
num_accelerators = int(npu_smi.decode('utf-8').strip().split('\n')[0].split(':')[1].strip())
else:
assert get_accelerator().device_name() == 'cpu'
cpu_sockets = int(
subprocess.check_output('cat /proc/cpuinfo | grep "physical id" | sort -u | wc -l', shell=True))
num_accelerators = cpu_sockets
num_accelerators = _get_cpu_socket_count()

if isinstance(num_accelerators, list):
cuda_visible = ",".join(num_accelerators)
Expand Down
Loading