Commit c67a396

wmorris75 authored and ashb committed
[AIRFLOW-2993] s3_to_sftp and sftp_to_s3 operators (#3828)
Add operators for transferring files between s3 and sftp.
1 parent 8b33948 commit c67a396

4 files changed, +484 -0 lines changed

airflow/contrib/operators/s3_to_sftp_operator.py
@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from airflow.models import BaseOperator
from airflow.hooks.S3_hook import S3Hook
from airflow.contrib.hooks.ssh_hook import SSHHook
from tempfile import NamedTemporaryFile
from urllib.parse import urlparse
from airflow.utils.decorators import apply_defaults


class S3ToSFTPOperator(BaseOperator):
    """
    This operator enables the transferring of files from S3 to an SFTP server.
    :param sftp_conn_id: The sftp connection id. The name or
        identifier for establishing a connection to the SFTP server.
    :type sftp_conn_id: string
    :param sftp_path: The sftp remote path. This is the specified
        file path for uploading the file to the SFTP server.
    :type sftp_path: string
    :param s3_conn_id: The s3 connection id. The name or identifier for establishing
        a connection to S3.
    :type s3_conn_id: string
    :param s3_bucket: The targeted s3 bucket. This is the S3 bucket
        from where the file is downloaded.
    :type s3_bucket: string
    :param s3_key: The targeted s3 key. This is the specified file path
        for downloading the file from S3.
    :type s3_key: string
    """

    template_fields = ('s3_key', 'sftp_path')

    @apply_defaults
    def __init__(self,
                 s3_bucket,
                 s3_key,
                 sftp_path,
                 sftp_conn_id='ssh_default',
                 s3_conn_id='aws_default',
                 *args,
                 **kwargs):
        super(S3ToSFTPOperator, self).__init__(*args, **kwargs)
        self.sftp_conn_id = sftp_conn_id
        self.sftp_path = sftp_path
        self.s3_bucket = s3_bucket
        self.s3_key = s3_key
        self.s3_conn_id = s3_conn_id

    @staticmethod
    def get_s3_key(s3_key):
        """This parses the correct format for S3 keys
        regardless of how the S3 url is passed."""

        parsed_s3_key = urlparse(s3_key)
        return parsed_s3_key.path.lstrip('/')

    def execute(self, context):
        self.s3_key = self.get_s3_key(self.s3_key)
        ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
        s3_hook = S3Hook(self.s3_conn_id)

        s3_client = s3_hook.get_conn()
        sftp_client = ssh_hook.get_conn().open_sftp()

        with NamedTemporaryFile("w") as f:
            s3_client.download_file(self.s3_bucket, self.s3_key, f.name)
            sftp_client.put(f.name, self.sftp_path)
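
For orientation, here is a minimal usage sketch (not part of the commit) showing how S3ToSFTPOperator might be wired into a DAG. The DAG id, bucket, key, and remote path below are placeholder values, and the connection ids shown are simply the operator's defaults.

from datetime import datetime

from airflow.models import DAG
from airflow.contrib.operators.s3_to_sftp_operator import S3ToSFTPOperator

dag = DAG(
    dag_id='example_s3_to_sftp',                          # placeholder DAG id
    default_args={'owner': 'airflow', 'start_date': datetime(2018, 1, 1)},
    schedule_interval=None,
)

push_to_sftp = S3ToSFTPOperator(
    task_id='push_report_to_sftp',
    s3_bucket='my-example-bucket',                        # placeholder bucket
    s3_key='s3://my-example-bucket/reports/report.csv',   # a full URL or a bare key both work;
                                                          # get_s3_key() reduces either to the key
    sftp_path='/tmp/report.csv',                          # destination path on the SFTP server
    s3_conn_id='aws_default',                             # operator defaults, shown explicitly
    sftp_conn_id='ssh_default',
    dag=dag,
)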
airflow/contrib/operators/sftp_to_s3_operator.py
@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from airflow.models import BaseOperator
from airflow.hooks.S3_hook import S3Hook
from airflow.contrib.hooks.ssh_hook import SSHHook
from tempfile import NamedTemporaryFile
from urllib.parse import urlparse
from airflow.utils.decorators import apply_defaults


class SFTPToS3Operator(BaseOperator):
    """
    This operator enables the transferring of files from an SFTP server to Amazon S3.
    :param sftp_conn_id: The sftp connection id. The name or identifier for
        establishing a connection to the SFTP server.
    :type sftp_conn_id: string
    :param sftp_path: The sftp remote path. This is the specified file
        path for downloading the file from the SFTP server.
    :type sftp_path: string
    :param s3_conn_id: The s3 connection id. The name or identifier for
        establishing a connection to S3.
    :type s3_conn_id: string
    :param s3_bucket: The targeted s3 bucket. This is the S3 bucket
        to where the file is uploaded.
    :type s3_bucket: string
    :param s3_key: The targeted s3 key. This is the specified path
        for uploading the file to S3.
    :type s3_key: string
    """

    template_fields = ('s3_key', 'sftp_path')

    @apply_defaults
    def __init__(self,
                 s3_bucket,
                 s3_key,
                 sftp_path,
                 sftp_conn_id='ssh_default',
                 s3_conn_id='aws_default',
                 *args,
                 **kwargs):
        super(SFTPToS3Operator, self).__init__(*args, **kwargs)
        self.sftp_conn_id = sftp_conn_id
        self.sftp_path = sftp_path
        self.s3_bucket = s3_bucket
        self.s3_key = s3_key
        self.s3_conn_id = s3_conn_id

    @staticmethod
    def get_s3_key(s3_key):
        """This parses the correct format for S3 keys
        regardless of how the S3 url is passed."""

        parsed_s3_key = urlparse(s3_key)
        return parsed_s3_key.path.lstrip('/')

    def execute(self, context):
        self.s3_key = self.get_s3_key(self.s3_key)
        ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
        s3_hook = S3Hook(self.s3_conn_id)

        sftp_client = ssh_hook.get_conn().open_sftp()

        with NamedTemporaryFile("w") as f:
            sftp_client.get(self.sftp_path, f.name)

            s3_hook.load_file(
                filename=f.name,
                key=self.s3_key,
                bucket_name=self.s3_bucket,
                replace=True
            )
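
The reverse direction, again as a hedged sketch rather than part of the commit: SFTPToS3Operator pulling a file from the SFTP server into S3. The paths, bucket, and DAG id are illustrative placeholders; note that the upload overwrites any existing object at the key because execute() calls load_file with replace=True.

from datetime import datetime

from airflow.models import DAG
from airflow.contrib.operators.sftp_to_s3_operator import SFTPToS3Operator

dag = DAG(
    dag_id='example_sftp_to_s3',                          # placeholder DAG id
    default_args={'owner': 'airflow', 'start_date': datetime(2018, 1, 1)},
    schedule_interval=None,
)

pull_from_sftp = SFTPToS3Operator(
    task_id='pull_export_from_sftp',
    sftp_path='/data/exports/export.csv',                 # file to fetch from the SFTP server
    s3_bucket='my-example-bucket',                        # placeholder bucket
    s3_key='incoming/export.csv',                         # bare key; an s3:// URL is also accepted
    sftp_conn_id='ssh_default',
    s3_conn_id='aws_default',
    dag=dag,
)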
tests/contrib/operators/test_s3_to_sftp_operator.py
@@ -0,0 +1,167 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import unittest

from airflow import configuration
from airflow import models
from airflow.contrib.operators.s3_to_sftp_operator import S3ToSFTPOperator
from airflow.contrib.operators.ssh_operator import SSHOperator
from airflow.models import DAG, TaskInstance
from airflow.settings import Session
from airflow.utils import timezone
from airflow.utils.timezone import datetime
import boto3
from moto import mock_s3


TASK_ID = 'test_s3_to_sftp'
BUCKET = 'test-s3-bucket'
S3_KEY = 'test/test_1_file.csv'
SFTP_PATH = '/tmp/remote_path.txt'
SFTP_CONN_ID = 'ssh_default'
S3_CONN_ID = 'aws_default'
LOCAL_FILE_PATH = '/tmp/test_s3_upload'

SFTP_MOCK_FILE = 'test_sftp_file.csv'
S3_MOCK_FILES = 'test_1_file.csv'

TEST_DAG_ID = 'unit_tests'
DEFAULT_DATE = datetime(2018, 1, 1)


def reset(dag_id=TEST_DAG_ID):
    session = Session()
    tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id)
    tis.delete()
    session.commit()
    session.close()


reset()


class S3ToSFTPOperatorTest(unittest.TestCase):
    @mock_s3
    def setUp(self):
        configuration.load_test_config()
        from airflow.contrib.hooks.ssh_hook import SSHHook
        from airflow.hooks.S3_hook import S3Hook

        hook = SSHHook(ssh_conn_id='ssh_default')
        s3_hook = S3Hook('aws_default')
        hook.no_host_key_check = True
        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE,
            'provide_context': True
        }
        dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
        dag.schedule_interval = '@once'

        self.hook = hook
        self.s3_hook = s3_hook

        self.ssh_client = self.hook.get_conn()
        self.sftp_client = self.ssh_client.open_sftp()

        self.dag = dag
        self.s3_bucket = BUCKET
        self.sftp_path = SFTP_PATH
        self.s3_key = S3_KEY

    @mock_s3
    def test_s3_to_sftp_operation(self):
        # Setting: allow the SSH check task's byte output to be pushed to XCom
        configuration.conf.set("core", "enable_xcom_pickling", "True")
        test_remote_file_content = \
            "This is remote file content \n which is also multiline " \
            "another line here \n this is last line. EOF"

        # Create the (mocked) S3 bucket and verify it exists
        conn = boto3.client('s3')
        conn.create_bucket(Bucket=self.s3_bucket)
        self.assertTrue((self.s3_hook.check_for_bucket(self.s3_bucket)))

        with open(LOCAL_FILE_PATH, 'w') as f:
            f.write(test_remote_file_content)
        self.s3_hook.load_file(LOCAL_FILE_PATH, self.s3_key, bucket_name=BUCKET)

        # Check if the object was created in s3
        objects_in_dest_bucket = conn.list_objects(Bucket=self.s3_bucket,
                                                   Prefix=self.s3_key)
        # there should be an object found, and only one object found
        self.assertEqual(len(objects_in_dest_bucket['Contents']), 1)

        # the object found should be consistent with dest_key specified earlier
        self.assertEqual(objects_in_dest_bucket['Contents'][0]['Key'], self.s3_key)

        # Transfer the S3 object to the SFTP server
        run_task = S3ToSFTPOperator(
            s3_bucket=BUCKET,
            s3_key=S3_KEY,
            sftp_path=SFTP_PATH,
            sftp_conn_id=SFTP_CONN_ID,
            s3_conn_id=S3_CONN_ID,
            task_id=TASK_ID,
            dag=self.dag
        )
        self.assertIsNotNone(run_task)

        run_task.execute(None)

        # Check that the file is created remotely
        check_file_task = SSHOperator(
            task_id="test_check_file",
            ssh_hook=self.hook,
            command="cat {0}".format(self.sftp_path),
            do_xcom_push=True,
            dag=self.dag
        )
        self.assertIsNotNone(check_file_task)
        ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
        ti3.run()
        self.assertEqual(
            ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
            test_remote_file_content.encode('utf-8'))

        # Clean up after finishing with test
        conn.delete_object(Bucket=self.s3_bucket, Key=self.s3_key)
        conn.delete_bucket(Bucket=self.s3_bucket)
        self.assertFalse((self.s3_hook.check_for_bucket(self.s3_bucket)))

    def delete_remote_resource(self):
        # remove the remote file created during the test
        remove_file_task = SSHOperator(
            task_id="test_check_file",
            ssh_hook=self.hook,
            command="rm {0}".format(self.sftp_path),
            do_xcom_push=True,
            dag=self.dag
        )
        self.assertIsNotNone(remove_file_task)
        ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow())
        ti3.run()

    def tearDown(self):
        self.delete_remote_resource()


if __name__ == '__main__':
    unittest.main()
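
The commit message mentions both operators, but only the S3-to-SFTP test appears in this view. As an illustration only, the following hypothetical method could be added to S3ToSFTPOperatorTest above to exercise SFTPToS3Operator with the same fixtures; the method name and staged payload are invented, while BUCKET, S3_KEY, SFTP_PATH, the connection ids, the hooks, and the boto3 import come from the test module above.

    @mock_s3
    def test_sftp_to_s3_operation(self):
        # Hypothetical companion check, reusing the fixtures built in setUp().
        from airflow.contrib.operators.sftp_to_s3_operator import SFTPToS3Operator

        # Stage a file on the SFTP side, then create the (mocked) target bucket.
        remote_file = self.sftp_client.open(SFTP_PATH, 'w')
        remote_file.write("sample payload")
        remote_file.close()
        conn = boto3.client('s3')
        conn.create_bucket(Bucket=self.s3_bucket)

        # Run the operator directly, as the existing test does for S3ToSFTPOperator.
        run_task = SFTPToS3Operator(
            task_id='test_sftp_to_s3',
            sftp_path=SFTP_PATH,
            s3_bucket=BUCKET,
            s3_key=S3_KEY,
            sftp_conn_id=SFTP_CONN_ID,
            s3_conn_id=S3_CONN_ID,
            dag=self.dag
        )
        run_task.execute(None)

        # The uploaded object should now be visible through the S3 hook.
        self.assertTrue(self.s3_hook.check_for_key(self.s3_key,
                                                   bucket_name=self.s3_bucket))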
