@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License. You may obtain a copy of the License at
-#
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-#
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -25,40 +25,66 @@
 import sys
 import unittest
 
-from airflow.contrib.operators.postgres_to_gcs_operator import PostgresToGoogleCloudStorageOperator
+from airflow.hooks.postgres_hook import PostgresHook
+from airflow.contrib.operators.postgres_to_gcs_operator import \
+    PostgresToGoogleCloudStorageOperator
 
 try:
-    from unittest import mock
+    from unittest.mock import patch
 except ImportError:
     try:
-        import mock
+        from mock import patch
     except ImportError:
         mock = None
 
-PY3 = sys.version_info[0] == 3
+TABLES = {'postgres_to_gcs_operator', 'postgres_to_gcs_operator_empty'}
 
 TASK_ID = 'test-postgres-to-gcs'
-POSTGRES_CONN_ID = 'postgres_conn_test'
-SQL = 'select 1'
+POSTGRES_CONN_ID = 'postgres_default'
+SQL = 'SELECT * FROM postgres_to_gcs_operator'
 BUCKET = 'gs://test'
 FILENAME = 'test_{}.ndjson'
-# we expect the psycopg cursor to return encoded strs in py2 and decoded in py3
-if PY3:
-    ROWS = [('mock_row_content_1', 42), ('mock_row_content_2', 43), ('mock_row_content_3', 44)]
-    CURSOR_DESCRIPTION = (('some_str', 0), ('some_num', 1005))
-else:
-    ROWS = [(b'mock_row_content_1', 42), (b'mock_row_content_2', 43), (b'mock_row_content_3', 44)]
-    CURSOR_DESCRIPTION = ((b'some_str', 0), (b'some_num', 1005))
+
 NDJSON_LINES = [
     b'{"some_num": 42, "some_str": "mock_row_content_1"}\n',
     b'{"some_num": 43, "some_str": "mock_row_content_2"}\n',
     b'{"some_num": 44, "some_str": "mock_row_content_3"}\n'
 ]
 SCHEMA_FILENAME = 'schema_test.json'
-SCHEMA_JSON = b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, {"mode": "REPEATED", "name": "some_num", "type": "INTEGER"}]'
+SCHEMA_JSON = b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, ' \
+              b'{"mode": "NULLABLE", "name": "some_num", "type": "INTEGER"}]'
 
 
 class PostgresToGoogleCloudStorageOperatorTest(unittest.TestCase):
+    def setUp(self):
+        postgres = PostgresHook()
+        with postgres.get_conn() as conn:
+            with conn.cursor() as cur:
+                for table in TABLES:
+                    cur.execute("DROP TABLE IF EXISTS {} CASCADE;".format(table))
+                    cur.execute("CREATE TABLE {}(some_str varchar, some_num integer);"
+                                .format(table))
+
+                cur.execute(
+                    "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
+                    ('mock_row_content_1', 42)
+                )
+                cur.execute(
+                    "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
+                    ('mock_row_content_2', 43)
+                )
+                cur.execute(
+                    "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
+                    ('mock_row_content_3', 44)
+                )
+
+    def tearDown(self):
+        postgres = PostgresHook()
+        with postgres.get_conn() as conn:
+            with conn.cursor() as cur:
+                for table in TABLES:
+                    cur.execute("DROP TABLE IF EXISTS {} CASCADE;".format(table))
+
     def test_init(self):
         """Test PostgresToGoogleCloudStorageOperator instance is properly initialized."""
         op = PostgresToGoogleCloudStorageOperator(
@@ -68,9 +94,8 @@ def test_init(self):
         self.assertEqual(op.bucket, BUCKET)
         self.assertEqual(op.filename, FILENAME)
 
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
-    def test_exec_success(self, gcs_hook_mock_class, pg_hook_mock_class):
+    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
+    def test_exec_success(self, gcs_hook_mock_class):
         """Test the execute function in case where the run is successful."""
         op = PostgresToGoogleCloudStorageOperator(
             task_id=TASK_ID,
@@ -79,10 +104,6 @@ def test_exec_success(self, gcs_hook_mock_class, pg_hook_mock_class):
             bucket=BUCKET,
             filename=FILENAME)
 
-        pg_hook_mock = pg_hook_mock_class.return_value
-        pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
-        pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
-
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
         def _assert_upload(bucket, obj, tmp_filename, content_type):
@@ -96,16 +117,9 @@ def _assert_upload(bucket, obj, tmp_filename, content_type):
 
         op.execute(None)
 
-        pg_hook_mock_class.assert_called_once_with(postgres_conn_id=POSTGRES_CONN_ID)
-        pg_hook_mock.get_conn().cursor().execute.assert_called_once_with(SQL, None)
-
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
-    def test_file_splitting(self, gcs_hook_mock_class, pg_hook_mock_class):
+    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
+    def test_file_splitting(self, gcs_hook_mock_class):
         """Test that ndjson is split by approx_max_file_size_bytes param."""
-        pg_hook_mock = pg_hook_mock_class.return_value
-        pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
-        pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
         expected_upload = {
@@ -129,13 +143,23 @@ def _assert_upload(bucket, obj, tmp_filename, content_type):
             approx_max_file_size_bytes=len(expected_upload[FILENAME.format(0)]))
         op.execute(None)
 
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
-    def test_schema_file(self, gcs_hook_mock_class, pg_hook_mock_class):
+    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
+    def test_empty_query(self, gcs_hook_mock_class):
+        """If the sql returns no rows, we should not upload any files"""
+        gcs_hook_mock = gcs_hook_mock_class.return_value
+
+        op = PostgresToGoogleCloudStorageOperator(
+            task_id=TASK_ID,
+            sql='SELECT * FROM postgres_to_gcs_operator_empty',
+            bucket=BUCKET,
+            filename=FILENAME)
+        op.execute(None)
+
+        assert not gcs_hook_mock.upload.called, 'No data means no files in the bucket'
+
+    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
+    def test_schema_file(self, gcs_hook_mock_class):
         """Test writing schema files."""
-        pg_hook_mock = pg_hook_mock_class.return_value
-        pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
-        pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
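
For context, this is roughly how the operator exercised by these tests is wired into a DAG. A minimal sketch, assuming an Airflow install with the default 'postgres_default' connection configured; the DAG id, start date, and schedule are illustrative assumptions, and only parameters that appear in the tests above are passed to the operator:

from datetime import datetime

from airflow.models import DAG
from airflow.contrib.operators.postgres_to_gcs_operator import \
    PostgresToGoogleCloudStorageOperator

# Illustrative DAG wiring; dag_id, start_date, and schedule_interval
# are assumptions, not taken from the tests.
with DAG(dag_id='postgres_to_gcs_example',
         start_date=datetime(2018, 1, 1),
         schedule_interval=None) as dag:
    export = PostgresToGoogleCloudStorageOperator(
        task_id='test-postgres-to-gcs',
        postgres_conn_id='postgres_default',   # connection the tests rely on
        sql='SELECT * FROM postgres_to_gcs_operator',
        bucket='gs://test',                    # target GCS bucket
        filename='test_{}.ndjson',             # '{}' becomes a chunk index on split
        schema_filename='schema_test.json')    # optional BigQuery schema object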