From 4f1c85a306f90c5f7e096191f805fc0201a80940 Mon Sep 17 00:00:00 2001
From: jlowin
Date: Wed, 3 Aug 2016 17:23:22 -0400
Subject: [PATCH] [AIRFLOW-393] Add progress callback for FTP download

---
 airflow/contrib/hooks/ftp_hook.py | 30 ++++++++++++++++++++++++++++--
 1 file changed, 28 insertions(+), 2 deletions(-)

diff --git a/airflow/contrib/hooks/ftp_hook.py b/airflow/contrib/hooks/ftp_hook.py
index 2f2e4c287fc42..7864db8bb283c 100644
--- a/airflow/contrib/hooks/ftp_hook.py
+++ b/airflow/contrib/hooks/ftp_hook.py
@@ -141,7 +141,11 @@ def delete_directory(self, path):
         conn = self.get_conn()
         conn.rmd(path)
 
-    def retrieve_file(self, remote_full_path, local_full_path_or_buffer):
+    def retrieve_file(
+            self,
+            remote_full_path,
+            local_full_path_or_buffer,
+            progress_callback=None):
         """
         Transfers the remote file to a local location.
 
@@ -154,6 +158,11 @@ def retrieve_file(self, remote_full_path, local_full_path_or_buffer):
         :param local_full_path_or_buffer: full path to the local file or a
             file-like buffer
         :type local_full_path: str or file-like buffer
+        :param progress_callback: a function that is called prior to processing
+            each block of data. It is passed the number of bytes about to be
+            processed. If the file size is known, this can be used to track
+            progress.
+        :type progress_callback: callable
         """
         conn = self.get_conn()
 
@@ -164,10 +173,17 @@ def retrieve_file(self, remote_full_path, local_full_path_or_buffer):
         else:
             output_handle = local_full_path_or_buffer
 
+        if progress_callback is not None:
+            def callback(data):
+                progress_callback(len(data))
+                output_handle.write(data)
+        else:
+            callback = output_handle.write
+
         remote_path, remote_file_name = os.path.split(remote_full_path)
         conn.cwd(remote_path)
         logging.info('Retrieving file from FTP: {}'.format(remote_full_path))
-        conn.retrbinary('RETR %s' % remote_file_name, output_handle.write)
+        conn.retrbinary('RETR %s' % remote_file_name, callback)
         logging.info('Finished retrieving file from FTP: {}'.format(
             remote_full_path))
 
@@ -228,6 +244,16 @@ def get_mod_time(self, path):
         ftp_mdtm = conn.sendcmd('MDTM ' + path)
         return datetime.datetime.strptime(ftp_mdtm[4:], '%Y%m%d%H%M%S')
 
+    def size(self, path):
+        """
+        Returns the size of a file (in bytes)
+
+        :param path: remote file path
+        :type path: string
+        """
+        conn = self.get_conn()
+        return conn.size(path)
+
 
 class FTPSHook(FTPHook):
 