19
19
20
20
import getpass
21
21
import os
22
+ import warnings
22
23
23
24
import paramiko
24
25
from paramiko .config import SSH_PORT
26
+ from sshtunnel import SSHTunnelForwarder
25
27
26
- from contextlib import contextmanager
27
28
from airflow .exceptions import AirflowException
28
29
from airflow .hooks .base_hook import BaseHook
29
30
from airflow .utils .log .logging_mixin import LoggingMixin
@@ -62,7 +63,7 @@ def __init__(self,
62
63
username = None ,
63
64
password = None ,
64
65
key_file = None ,
65
- port = SSH_PORT ,
66
+ port = None ,
66
67
timeout = 10 ,
67
68
keepalive_interval = 30
68
69
):
@@ -72,162 +73,167 @@ def __init__(self,
72
73
self .username = username
73
74
self .password = password
74
75
self .key_file = key_file
76
+ self .port = port
75
77
self .timeout = timeout
76
78
self .keepalive_interval = keepalive_interval
79
+
77
80
# Default values, overridable from Connection
78
81
self .compress = True
79
82
self .no_host_key_check = True
83
+ self .host_proxy = None
84
+
85
+ # Placeholder for deprecated __enter__
80
86
self .client = None
81
- self .port = port
87
+
88
+ # Use connection to override defaults
89
+ if self .ssh_conn_id is not None :
90
+ conn = self .get_connection (self .ssh_conn_id )
91
+ if self .username is None :
92
+ self .username = conn .login
93
+ if self .password is None :
94
+ self .password = conn .password
95
+ if self .remote_host is None :
96
+ self .remote_host = conn .host
97
+ if self .port is None :
98
+ self .port = conn .port
99
+ if conn .extra is not None :
100
+ extra_options = conn .extra_dejson
101
+ self .key_file = extra_options .get ("key_file" )
102
+
103
+ if "timeout" in extra_options :
104
+ self .timeout = int (extra_options ["timeout" ], 10 )
105
+
106
+ if "compress" in extra_options \
107
+ and str (extra_options ["compress" ]).lower () == 'false' :
108
+ self .compress = False
109
+ if "no_host_key_check" in extra_options \
110
+ and \
111
+ str (extra_options ["no_host_key_check" ]).lower () == 'false' :
112
+ self .no_host_key_check = False
113
+
114
+ if not self .remote_host :
115
+ raise AirflowException ("Missing required param: remote_host" )
116
+
117
+ # Auto detecting username values from system
118
+ if not self .username :
119
+ self .log .debug (
120
+ "username to ssh to host: %s is not specified for connection id"
121
+ " %s. Using system's default provided by getpass.getuser()" ,
122
+ self .remote_host , self .ssh_conn_id
123
+ )
124
+ self .username = getpass .getuser ()
125
+
126
+ user_ssh_config_filename = os .path .expanduser ('~/.ssh/config' )
127
+ if os .path .isfile (user_ssh_config_filename ):
128
+ ssh_conf = paramiko .SSHConfig ()
129
+ ssh_conf .parse (open (user_ssh_config_filename ))
130
+ host_info = ssh_conf .lookup (self .remote_host )
131
+ if host_info and host_info .get ('proxycommand' ):
132
+ self .host_proxy = paramiko .ProxyCommand (host_info .get ('proxycommand' ))
133
+
134
+ if not (self .password or self .key_file ):
135
+ if host_info and host_info .get ('identityfile' ):
136
+ self .key_file = host_info .get ('identityfile' )[0 ]
137
+
138
+ self .port = self .port or SSH_PORT
82
139
83
140
def get_conn(self):
    """
    Opens a ssh connection to the remote host.

    A new ``paramiko.SSHClient`` is created on every call, using the
    ``remote_host``, ``username``, ``password``, ``key_file``, ``timeout``,
    ``compress``, ``port`` and ``host_proxy`` attributes of this hook.
    The client is also stored on ``self.client`` so that the deprecated
    ``__exit__`` can close it.

    :return: the connected client
    :rtype: paramiko.SSHClient
    """
    self.log.debug('Creating SSH client for conn_id: %s', self.ssh_conn_id)
    client = paramiko.SSHClient()
    client.load_system_host_keys()
    if self.no_host_key_check:
        # Default is RejectPolicy
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

    connect_kwargs = dict(
        hostname=self.remote_host,
        username=self.username,
        key_filename=self.key_file,
        timeout=self.timeout,
        compress=self.compress,
        port=self.port,
        sock=self.host_proxy,
    )
    # An empty or whitespace-only password is not sent at all, so that
    # authentication falls back to the key file.
    if self.password and self.password.strip():
        connect_kwargs['password'] = self.password

    try:
        client.connect(**connect_kwargs)
        if self.keepalive_interval:
            client.get_transport().set_keepalive(self.keepalive_interval)
    except Exception:
        # Do not leave a half-open client behind; the caller never
        # receives a reference it could close itself.
        client.close()
        raise

    self.client = client
    return client
227
177
228
178
def __enter__(self):
    """
    Deprecated context-manager entry point.

    Emits a ``DeprecationWarning`` and returns the hook itself.

    :return: this hook instance
    """
    # Each fragment ends with a space so the concatenated message reads as
    # separate sentences; stacklevel=2 attributes the warning to the
    # ``with`` statement that entered the hook rather than to this method.
    warnings.warn('The contextmanager of SSHHook is deprecated. '
                  'Please use get_conn() as a contextmanager instead. '
                  'This method will be removed in Airflow 2.0',
                  category=DeprecationWarning,
                  stacklevel=2)
    return self
230
184
231
185
def __exit__(self, exc_type, exc_val, exc_tb):
    """
    Close the client stored on ``self.client``, if any, and clear the
    reference.

    Returns ``None``, so an exception raised inside the ``with`` body is
    not suppressed.
    """
    active_client = self.client
    if active_client is None:
        return
    active_client.close()
    self.client = None
189
+
190
def get_tunnel(self, remote_port, remote_host="localhost", local_port=None):
    """
    Creates a tunnel between two hosts. Like ssh -L <LOCAL_PORT>:host:<REMOTE_PORT>.

    :param remote_port: The remote port to create a tunnel to
    :type remote_port: int
    :param remote_host: The remote host to create a tunnel to (default localhost)
    :type remote_host: str
    :param local_port: The local port to attach the tunnel to
    :type local_port: int

    :return: sshtunnel.SSHTunnelForwarder object
    """
    # Without an explicit local port only the host is given as the bind
    # address.
    local_bind_address = ('localhost', local_port) if local_port else ('localhost',)

    # Arguments shared by both authentication modes.
    forwarder_kwargs = dict(
        ssh_port=self.port,
        ssh_username=self.username,
        ssh_pkey=self.key_file,
        ssh_proxy=self.host_proxy,
        local_bind_address=local_bind_address,
        remote_bind_address=(remote_host, remote_port),
        logger=self.log,
    )

    if self.password and self.password.strip():
        forwarder_kwargs['ssh_password'] = self.password
    else:
        # No usable password: pass an empty key-directory list.
        forwarder_kwargs['host_pkey_directories'] = []

    return SSHTunnelForwarder(self.remote_host, **forwarder_kwargs)
231
+
232
def create_tunnel(self, local_port, remote_port=None, remote_host="localhost"):
    """
    Deprecated wrapper around ``get_tunnel``.

    Emits a ``DeprecationWarning`` and forwards to ``get_tunnel``. Note that
    the parameter order differs: this method takes ``local_port`` first,
    whereas ``get_tunnel`` takes ``remote_port`` first.

    :param local_port: The local port to attach the tunnel to
    :type local_port: int
    :param remote_port: The remote port to create a tunnel to
    :type remote_port: int
    :param remote_host: The remote host to create a tunnel to (default localhost)
    :type remote_host: str

    :return: whatever ``get_tunnel`` returns
    """
    # Each fragment ends with a space so the concatenated message stays
    # readable; stacklevel=2 attributes the warning to the calling code
    # rather than to this wrapper.
    warnings.warn('SSHHook.create_tunnel is deprecated, Please '
                  'use get_tunnel() instead. But please note that the '
                  'order of the parameters have changed. '
                  'This method will be removed in Airflow 2.0',
                  category=DeprecationWarning,
                  stacklevel=2)

    return self.get_tunnel(remote_port, remote_host, local_port)
0 commit comments