diff --git a/airflow/contrib/hooks/sftp_hook.py b/airflow/contrib/hooks/sftp_hook.py index 3cb32ab8ba93b..c23bd3012ba32 100644 --- a/airflow/contrib/hooks/sftp_hook.py +++ b/airflow/contrib/hooks/sftp_hook.py @@ -49,6 +49,9 @@ def __init__(self, ftp_conn_id='sftp_default', *args, **kwargs): self.conn = None self.private_key_pass = None + # Fail for unverified hosts, unless this is explicitly allowed + self.no_host_key_check = False + if self.ssh_conn_id is not None: conn = self.get_connection(self.ssh_conn_id) if conn.extra is not None: @@ -59,9 +62,7 @@ def __init__(self, ftp_conn_id='sftp_default', *args, **kwargs): # For backward compatibility # TODO: remove in Airflow 2.1 import warnings - if 'ignore_hostkey_verification' in extra_options \ - and str(extra_options["ignore_hostkey_verification"])\ - .lower() == 'false': + if 'ignore_hostkey_verification' in extra_options: warnings.warn( 'Extra option `ignore_hostkey_verification` is deprecated.' 'Please use `no_host_key_check` instead.' @@ -69,7 +70,14 @@ def __init__(self, ftp_conn_id='sftp_default', *args, **kwargs): DeprecationWarning, stacklevel=2, ) - self.no_host_key_check = False + self.no_host_key_check = str( + extra_options['ignore_hostkey_verification'] + ).lower() == 'true' + + if 'no_host_key_check' in extra_options: + self.no_host_key_check = str( + extra_options['no_host_key_check']).lower() == 'true' + if 'private_key' in extra_options: warnings.warn( 'Extra option `private_key` is deprecated.' diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 30dd84417a2db..ffc41115a1144 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -183,7 +183,7 @@ def initdb(rbac=False): conn_id='sftp_default', conn_type='sftp', host='localhost', port=22, login='airflow', extra=''' - {"private_key": "~/.ssh/id_rsa", "ignore_hostkey_verification": true} + {"key_file": "~/.ssh/id_rsa", "no_host_key_check": true} ''')) merge_conn( models.Connection( diff --git a/tests/contrib/hooks/test_sftp_hook.py b/tests/contrib/hooks/test_sftp_hook.py index 322b5bbbe13a7..ac4d78e9b1609 100644 --- a/tests/contrib/hooks/test_sftp_hook.py +++ b/tests/contrib/hooks/test_sftp_hook.py @@ -19,12 +19,13 @@ from __future__ import print_function +import mock import unittest import shutil import os import pysftp -from airflow import configuration +from airflow import configuration, models from airflow.contrib.hooks.sftp_hook import SFTPHook TMP_PATH = '/tmp' @@ -105,6 +106,63 @@ def test_get_mod_time(self): TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS)) self.assertEqual(len(output), 14) + @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') + def test_no_host_key_check_default(self, get_connection): + connection = models.Connection(login='login', host='host') + get_connection.return_value = connection + hook = SFTPHook() + self.assertEqual(hook.no_host_key_check, False) + + @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') + def test_no_host_key_check_enabled(self, get_connection): + connection = models.Connection( + login='login', host='host', + extra='{"no_host_key_check": true}') + + get_connection.return_value = connection + hook = SFTPHook() + self.assertEqual(hook.no_host_key_check, True) + + @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') + def test_no_host_key_check_disabled(self, get_connection): + connection = models.Connection( + login='login', host='host', + extra='{"no_host_key_check": false}') + + get_connection.return_value = connection + hook = SFTPHook() + self.assertEqual(hook.no_host_key_check, False) + + @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') + def test_no_host_key_check_disabled_for_all_but_true(self, get_connection): + connection = models.Connection( + login='login', host='host', + extra='{"no_host_key_check": "foo"}') + + get_connection.return_value = connection + hook = SFTPHook() + self.assertEqual(hook.no_host_key_check, False) + + @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') + def test_no_host_key_check_ignore(self, get_connection): + connection = models.Connection( + login='login', host='host', + extra='{"ignore_hostkey_verification": true}') + + get_connection.return_value = connection + hook = SFTPHook() + self.assertEqual(hook.no_host_key_check, True) + + @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') + def test_no_host_key_check_no_ignore(self, get_connection): + connection = models.Connection( + login='login', host='host', + extra='{"ignore_hostkey_verification": false}') + + get_connection.return_value = connection + hook = SFTPHook() + self.assertEqual(hook.no_host_key_check, False) + def tearDown(self): shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))