Skip to content

Commit fc9a5b7

Browse files
NielsZeilemakerashb
authored andcommitted
[AIRFLOW-1762] Implement key_file support in ssh_hook create_tunnel
Switched to using sshtunnel package instead of popen approach Closes #3473 from NielsZeilemaker/ssh_hook
1 parent 23be2a3 commit fc9a5b7

File tree

5 files changed

+354
-259
lines changed

5 files changed

+354
-259
lines changed

airflow/contrib/hooks/ssh_hook.py

+148-142
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@
1919

2020
import getpass
2121
import os
22+
import warnings
2223

2324
import paramiko
2425
from paramiko.config import SSH_PORT
26+
from sshtunnel import SSHTunnelForwarder
2527

26-
from contextlib import contextmanager
2728
from airflow.exceptions import AirflowException
2829
from airflow.hooks.base_hook import BaseHook
2930
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -62,7 +63,7 @@ def __init__(self,
6263
username=None,
6364
password=None,
6465
key_file=None,
65-
port=SSH_PORT,
66+
port=None,
6667
timeout=10,
6768
keepalive_interval=30
6869
):
@@ -72,162 +73,167 @@ def __init__(self,
7273
self.username = username
7374
self.password = password
7475
self.key_file = key_file
76+
self.port = port
7577
self.timeout = timeout
7678
self.keepalive_interval = keepalive_interval
79+
7780
# Default values, overridable from Connection
7881
self.compress = True
7982
self.no_host_key_check = True
83+
self.host_proxy = None
84+
85+
# Placeholder for deprecated __enter__
8086
self.client = None
81-
self.port = port
87+
88+
# Use connection to override defaults
89+
if self.ssh_conn_id is not None:
90+
conn = self.get_connection(self.ssh_conn_id)
91+
if self.username is None:
92+
self.username = conn.login
93+
if self.password is None:
94+
self.password = conn.password
95+
if self.remote_host is None:
96+
self.remote_host = conn.host
97+
if self.port is None:
98+
self.port = conn.port
99+
if conn.extra is not None:
100+
extra_options = conn.extra_dejson
101+
self.key_file = extra_options.get("key_file")
102+
103+
if "timeout" in extra_options:
104+
self.timeout = int(extra_options["timeout"], 10)
105+
106+
if "compress" in extra_options\
107+
and str(extra_options["compress"]).lower() == 'false':
108+
self.compress = False
109+
if "no_host_key_check" in extra_options\
110+
and\
111+
str(extra_options["no_host_key_check"]).lower() == 'false':
112+
self.no_host_key_check = False
113+
114+
if not self.remote_host:
115+
raise AirflowException("Missing required param: remote_host")
116+
117+
# Auto detecting username values from system
118+
if not self.username:
119+
self.log.debug(
120+
"username to ssh to host: %s is not specified for connection id"
121+
" %s. Using system's default provided by getpass.getuser()",
122+
self.remote_host, self.ssh_conn_id
123+
)
124+
self.username = getpass.getuser()
125+
126+
user_ssh_config_filename = os.path.expanduser('~/.ssh/config')
127+
if os.path.isfile(user_ssh_config_filename):
128+
ssh_conf = paramiko.SSHConfig()
129+
ssh_conf.parse(open(user_ssh_config_filename))
130+
host_info = ssh_conf.lookup(self.remote_host)
131+
if host_info and host_info.get('proxycommand'):
132+
self.host_proxy = paramiko.ProxyCommand(host_info.get('proxycommand'))
133+
134+
if not (self.password or self.key_file):
135+
if host_info and host_info.get('identityfile'):
136+
self.key_file = host_info.get('identityfile')[0]
137+
138+
self.port = self.port or SSH_PORT
82139

83140
def get_conn(self):
84-
if not self.client:
85-
self.log.debug('Creating SSH client for conn_id: %s', self.ssh_conn_id)
86-
if self.ssh_conn_id is not None:
87-
conn = self.get_connection(self.ssh_conn_id)
88-
if self.username is None:
89-
self.username = conn.login
90-
if self.password is None:
91-
self.password = conn.password
92-
if self.remote_host is None:
93-
self.remote_host = conn.host
94-
if conn.port is not None:
95-
self.port = conn.port
96-
if conn.extra is not None:
97-
extra_options = conn.extra_dejson
98-
self.key_file = extra_options.get("key_file")
99-
100-
if "timeout" in extra_options:
101-
self.timeout = int(extra_options["timeout"], 10)
102-
103-
if "compress" in extra_options \
104-
and str(extra_options["compress"]).lower() == 'false':
105-
self.compress = False
106-
if "no_host_key_check" in extra_options \
107-
and \
108-
str(extra_options["no_host_key_check"]).lower() == 'false':
109-
self.no_host_key_check = False
110-
111-
if not self.remote_host:
112-
raise AirflowException("Missing required param: remote_host")
113-
114-
# Auto detecting username values from system
115-
if not self.username:
116-
self.log.debug(
117-
"username to ssh to host: %s is not specified for connection id"
118-
" %s. Using system's default provided by getpass.getuser()",
119-
self.remote_host, self.ssh_conn_id
120-
)
121-
self.username = getpass.getuser()
122-
123-
host_proxy = None
124-
user_ssh_config_filename = os.path.expanduser('~/.ssh/config')
125-
if os.path.isfile(user_ssh_config_filename):
126-
ssh_conf = paramiko.SSHConfig()
127-
ssh_conf.parse(open(user_ssh_config_filename))
128-
host_info = ssh_conf.lookup(self.remote_host)
129-
if host_info and host_info.get('proxycommand'):
130-
host_proxy = paramiko.ProxyCommand(host_info.get('proxycommand'))
131-
132-
if not (self.password or self.key_file):
133-
if host_info and host_info.get('identityfile'):
134-
self.key_file = host_info.get('identityfile')[0]
135-
136-
try:
137-
client = paramiko.SSHClient()
138-
client.load_system_host_keys()
139-
if self.no_host_key_check:
140-
# Default is RejectPolicy
141-
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
142-
143-
if self.password and self.password.strip():
144-
client.connect(hostname=self.remote_host,
145-
username=self.username,
146-
password=self.password,
147-
timeout=self.timeout,
148-
compress=self.compress,
149-
port=self.port,
150-
sock=host_proxy)
151-
else:
152-
client.connect(hostname=self.remote_host,
153-
username=self.username,
154-
key_filename=self.key_file,
155-
timeout=self.timeout,
156-
compress=self.compress,
157-
port=self.port,
158-
sock=host_proxy)
159-
160-
if self.keepalive_interval:
161-
client.get_transport().set_keepalive(self.keepalive_interval)
162-
163-
self.client = client
164-
except paramiko.AuthenticationException as auth_error:
165-
self.log.error(
166-
"Auth failed while connecting to host: %s, error: %s",
167-
self.remote_host, auth_error
168-
)
169-
except paramiko.SSHException as ssh_error:
170-
self.log.error(
171-
"Failed connecting to host: %s, error: %s",
172-
self.remote_host, ssh_error
173-
)
174-
except Exception as error:
175-
self.log.error(
176-
"Error connecting to host: %s, error: %s",
177-
self.remote_host, error
178-
)
179-
return self.client
180-
181-
@contextmanager
182-
def create_tunnel(self, local_port, remote_port=None, remote_host="localhost"):
183141
"""
184-
Creates a tunnel between two hosts. Like ssh -L <LOCAL_PORT>:host:<REMOTE_PORT>.
185-
Remember to close() the returned "tunnel" object in order to clean up
186-
after yourself when you are done with the tunnel.
142+
Opens a ssh connection to the remote host.
187143
188-
:param local_port:
189-
:type local_port: int
190-
:param remote_port:
191-
:type remote_port: int
192-
:param remote_host:
193-
:type remote_host: str
194-
:return:
144+
:return paramiko.SSHClient object
195145
"""
196146

197-
import subprocess
198-
# this will ensure the connection to the ssh.remote_host from where the tunnel
199-
# is getting created
200-
self.get_conn()
201-
202-
tunnel_host = "{0}:{1}:{2}".format(local_port, remote_host, remote_port)
203-
204-
ssh_cmd = ["ssh", "{0}@{1}".format(self.username, self.remote_host),
205-
"-o", "ControlMaster=no",
206-
"-o", "UserKnownHostsFile=/dev/null",
207-
"-o", "StrictHostKeyChecking=no"]
208-
209-
ssh_tunnel_cmd = ["-L", tunnel_host,
210-
"echo -n ready && cat"
211-
]
212-
213-
ssh_cmd += ssh_tunnel_cmd
214-
self.log.debug("Creating tunnel with cmd: %s", ssh_cmd)
215-
216-
proc = subprocess.Popen(ssh_cmd,
217-
stdin=subprocess.PIPE,
218-
stdout=subprocess.PIPE,
219-
close_fds=True)
220-
ready = proc.stdout.read(5)
221-
assert ready == b"ready", \
222-
"Did not get 'ready' from remote, got '{0}' instead".format(ready)
223-
yield
224-
proc.communicate()
225-
assert proc.returncode == 0, \
226-
"Tunnel process did unclean exit (returncode {}".format(proc.returncode)
147+
self.log.debug('Creating SSH client for conn_id: %s', self.ssh_conn_id)
148+
client = paramiko.SSHClient()
149+
client.load_system_host_keys()
150+
if self.no_host_key_check:
151+
# Default is RejectPolicy
152+
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
153+
154+
if self.password and self.password.strip():
155+
client.connect(hostname=self.remote_host,
156+
username=self.username,
157+
password=self.password,
158+
key_filename=self.key_file,
159+
timeout=self.timeout,
160+
compress=self.compress,
161+
port=self.port,
162+
sock=self.host_proxy)
163+
else:
164+
client.connect(hostname=self.remote_host,
165+
username=self.username,
166+
key_filename=self.key_file,
167+
timeout=self.timeout,
168+
compress=self.compress,
169+
port=self.port,
170+
sock=self.host_proxy)
171+
172+
if self.keepalive_interval:
173+
client.get_transport().set_keepalive(self.keepalive_interval)
174+
175+
self.client = client
176+
return client
227177

228178
def __enter__(self):
179+
warnings.warn('The contextmanager of SSHHook is deprecated.'
180+
'Please use get_conn() as a contextmanager instead.'
181+
'This method will be removed in Airflow 2.0',
182+
category=DeprecationWarning)
229183
return self
230184

231185
def __exit__(self, exc_type, exc_val, exc_tb):
232186
if self.client is not None:
233187
self.client.close()
188+
self.client = None
189+
190+
def get_tunnel(self, remote_port, remote_host="localhost", local_port=None):
191+
"""
192+
Creates a tunnel between two hosts. Like ssh -L <LOCAL_PORT>:host:<REMOTE_PORT>.
193+
194+
:param remote_port: The remote port to create a tunnel to
195+
:type remote_port: int
196+
:param remote_host: The remote host to create a tunnel to (default localhost)
197+
:type remote_host: str
198+
:param local_port: The local port to attach the tunnel to
199+
:type local_port: int
200+
201+
:return: sshtunnel.SSHTunnelForwarder object
202+
"""
203+
204+
if local_port:
205+
local_bind_address = ('localhost', local_port)
206+
else:
207+
local_bind_address = ('localhost',)
208+
209+
if self.password and self.password.strip():
210+
client = SSHTunnelForwarder(self.remote_host,
211+
ssh_port=self.port,
212+
ssh_username=self.username,
213+
ssh_password=self.password,
214+
ssh_pkey=self.key_file,
215+
ssh_proxy=self.host_proxy,
216+
local_bind_address=local_bind_address,
217+
remote_bind_address=(remote_host, remote_port),
218+
logger=self.log)
219+
else:
220+
client = SSHTunnelForwarder(self.remote_host,
221+
ssh_port=self.port,
222+
ssh_username=self.username,
223+
ssh_pkey=self.key_file,
224+
ssh_proxy=self.host_proxy,
225+
local_bind_address=local_bind_address,
226+
remote_bind_address=(remote_host, remote_port),
227+
host_pkey_directories=[],
228+
logger=self.log)
229+
230+
return client
231+
232+
def create_tunnel(self, local_port, remote_port=None, remote_host="localhost"):
233+
warnings.warn('SSHHook.create_tunnel is deprecated, Please'
234+
'use get_tunnel() instead. But please note that the'
235+
'order of the parameters have changed'
236+
'This method will be removed in Airflow 2.0',
237+
category=DeprecationWarning)
238+
239+
return self.get_tunnel(remote_port, remote_host, local_port)

airflow/contrib/operators/sftp_operator.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -82,18 +82,20 @@ def execute(self, context):
8282
if self.remote_host is not None:
8383
self.ssh_hook.remote_host = self.remote_host
8484

85-
ssh_client = self.ssh_hook.get_conn()
86-
sftp_client = ssh_client.open_sftp()
87-
if self.operation.lower() == SFTPOperation.GET:
88-
file_msg = "from {0} to {1}".format(self.remote_filepath,
89-
self.local_filepath)
90-
self.log.debug("Starting to transfer %s", file_msg)
91-
sftp_client.get(self.remote_filepath, self.local_filepath)
92-
else:
93-
file_msg = "from {0} to {1}".format(self.local_filepath,
94-
self.remote_filepath)
95-
self.log.debug("Starting to transfer file %s", file_msg)
96-
sftp_client.put(self.local_filepath, self.remote_filepath)
85+
with self.ssh_hook.get_conn() as ssh_client:
86+
sftp_client = ssh_client.open_sftp()
87+
if self.operation.lower() == SFTPOperation.GET:
88+
file_msg = "from {0} to {1}".format(self.remote_filepath,
89+
self.local_filepath)
90+
self.log.debug("Starting to transfer %s", file_msg)
91+
sftp_client.get(self.remote_filepath, self.local_filepath)
92+
else:
93+
file_msg = "from {0} to {1}".format(self.local_filepath,
94+
self.remote_filepath)
95+
self.log.debug("Starting to transfer file %s", file_msg)
96+
sftp_client.put(self.local_filepath,
97+
self.remote_filepath,
98+
confirm=self.confirm)
9799

98100
except Exception as e:
99101
raise AirflowException("Error while transferring {0}, error: {1}"

0 commit comments

Comments
 (0)