
Commit 667fd6e

Fokko Driesprong authored and committed
[AIRFLOW-3059] Log how many rows are read from Postgres
To know how much data is being read from Postgres, it is nice to log this to the Airflow log. Previously, when there was no data, a single empty file would still be created; this is not something we want, so the behaviour has been changed. The tests have also been refactored to use Postgres itself, since it is already running, which makes them more realistic than mocking everything.
1 parent 7d60d26 commit 667fd6e
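For illustration (the values are hypothetical, assuming a query that returns three rows small enough to fit in a single file), the new log statement introduced by this commit would emit a line roughly like:

    INFO - Received 3 rows over 1 files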

File tree

2 files changed (+94, -60 lines)


airflow/contrib/operators/postgres_to_gcs_operator.py (+32, -22)
@@ -133,28 +133,38 @@ def _write_local_data_files(self, cursor):
             contain the data for the GCS objects.
         """
         schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
-        file_no = 0
-        tmp_file_handle = NamedTemporaryFile(delete=True)
-        tmp_file_handles = {self.filename.format(file_no): tmp_file_handle}
-
-        for row in cursor:
-            # Convert datetime objects to utc seconds, and decimals to floats
-            row = map(self.convert_types, row)
-            row_dict = dict(zip(schema, row))
-
-            s = json.dumps(row_dict, sort_keys=True)
-            if PY3:
-                s = s.encode('utf-8')
-            tmp_file_handle.write(s)
-
-            # Append newline to make dumps BigQuery compatible.
-            tmp_file_handle.write(b'\n')
-
-            # Stop if the file exceeds the file size limit.
-            if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
-                file_no += 1
-                tmp_file_handle = NamedTemporaryFile(delete=True)
-                tmp_file_handles[self.filename.format(file_no)] = tmp_file_handle
+        tmp_file_handles = {}
+        row_no = 0
+
+        def _create_new_file():
+            handle = NamedTemporaryFile(delete=True)
+            filename = self.filename.format(len(tmp_file_handles))
+            tmp_file_handles[filename] = handle
+            return handle
+
+        # Don't create a file if there is nothing to write
+        if cursor.rowcount > 0:
+            tmp_file_handle = _create_new_file()
+
+            for row in cursor:
+                # Convert datetime objects to utc seconds, and decimals to floats
+                row = map(self.convert_types, row)
+                row_dict = dict(zip(schema, row))
+
+                s = json.dumps(row_dict, sort_keys=True)
+                if PY3:
+                    s = s.encode('utf-8')
+                tmp_file_handle.write(s)
+
+                # Append newline to make dumps BigQuery compatible.
+                tmp_file_handle.write(b'\n')
+
+                # Stop if the file exceeds the file size limit.
+                if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
+                    tmp_file_handle = _create_new_file()
+                row_no += 1
+
+        self.log.info('Received %s rows over %s files', row_no, len(tmp_file_handles))
 
         return tmp_file_handles
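A minimal usage sketch (the DAG id, table, and bucket names below are placeholders, not part of this commit) showing how the operator is typically wired up; the {} in the filename template is what _create_new_file fills in with the running file count:

    # Hypothetical wiring of PostgresToGoogleCloudStorageOperator in a DAG.
    # All ids, the table, and the bucket are placeholders for illustration.
    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.postgres_to_gcs_operator import \
        PostgresToGoogleCloudStorageOperator

    dag = DAG('example_postgres_to_gcs',
              start_date=datetime(2018, 1, 1),
              schedule_interval=None)

    upload = PostgresToGoogleCloudStorageOperator(
        task_id='postgres_to_gcs',
        postgres_conn_id='postgres_default',
        sql='SELECT * FROM my_table',            # placeholder query
        bucket='my-bucket',                      # placeholder GCS bucket
        filename='my_table_{}.ndjson',           # {} becomes the file number (0, 1, ...)
        approx_max_file_size_bytes=1900000000,   # split the export above ~1.9 GB
        dag=dag)

With this change, a query that returns no rows produces no files at all, whereas previously a single empty file would still have been created and uploaded.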

tests/contrib/operators/test_postgres_to_gcs_operator.py (+62, -38)
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License. You may obtain a copy of the License at
-# 
+#
 # http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -25,40 +25,66 @@
 import sys
 import unittest
 
-from airflow.contrib.operators.postgres_to_gcs_operator import PostgresToGoogleCloudStorageOperator
+from airflow.hooks.postgres_hook import PostgresHook
+from airflow.contrib.operators.postgres_to_gcs_operator import \
+    PostgresToGoogleCloudStorageOperator
 
 try:
-    from unittest import mock
+    from unittest.mock import patch
 except ImportError:
     try:
-        import mock
+        from mock import patch
     except ImportError:
         mock = None
 
-PY3 = sys.version_info[0] == 3
+TABLES = {'postgres_to_gcs_operator', 'postgres_to_gcs_operator_empty'}
 
 TASK_ID = 'test-postgres-to-gcs'
-POSTGRES_CONN_ID = 'postgres_conn_test'
-SQL = 'select 1'
+POSTGRES_CONN_ID = 'postgres_default'
+SQL = 'SELECT * FROM postgres_to_gcs_operator'
 BUCKET = 'gs://test'
 FILENAME = 'test_{}.ndjson'
-# we expect the psycopg cursor to return encoded strs in py2 and decoded in py3
-if PY3:
-    ROWS = [('mock_row_content_1', 42), ('mock_row_content_2', 43), ('mock_row_content_3', 44)]
-    CURSOR_DESCRIPTION = (('some_str', 0), ('some_num', 1005))
-else:
-    ROWS = [(b'mock_row_content_1', 42), (b'mock_row_content_2', 43), (b'mock_row_content_3', 44)]
-    CURSOR_DESCRIPTION = ((b'some_str', 0), (b'some_num', 1005))
+
 NDJSON_LINES = [
     b'{"some_num": 42, "some_str": "mock_row_content_1"}\n',
     b'{"some_num": 43, "some_str": "mock_row_content_2"}\n',
     b'{"some_num": 44, "some_str": "mock_row_content_3"}\n'
 ]
 SCHEMA_FILENAME = 'schema_test.json'
-SCHEMA_JSON = b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, {"mode": "REPEATED", "name": "some_num", "type": "INTEGER"}]'
+SCHEMA_JSON = b'[{"mode": "NULLABLE", "name": "some_str", "type": "STRING"}, ' \
+              b'{"mode": "NULLABLE", "name": "some_num", "type": "INTEGER"}]'
 
 
 class PostgresToGoogleCloudStorageOperatorTest(unittest.TestCase):
+    def setUp(self):
+        postgres = PostgresHook()
+        with postgres.get_conn() as conn:
+            with conn.cursor() as cur:
+                for table in TABLES:
+                    cur.execute("DROP TABLE IF EXISTS {} CASCADE;".format(table))
+                    cur.execute("CREATE TABLE {}(some_str varchar, some_num integer);"
+                                .format(table))
+
+                cur.execute(
+                    "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
+                    ('mock_row_content_1', 42)
+                )
+                cur.execute(
+                    "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
+                    ('mock_row_content_2', 43)
+                )
+                cur.execute(
+                    "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
+                    ('mock_row_content_3', 44)
+                )
+
+    def tearDown(self):
+        postgres = PostgresHook()
+        with postgres.get_conn() as conn:
+            with conn.cursor() as cur:
+                for table in TABLES:
+                    cur.execute("DROP TABLE IF EXISTS {} CASCADE;".format(table))
+
     def test_init(self):
         """Test PostgresToGoogleCloudStorageOperator instance is properly initialized."""
         op = PostgresToGoogleCloudStorageOperator(
@@ -68,9 +94,8 @@ def test_init(self):
         self.assertEqual(op.bucket, BUCKET)
         self.assertEqual(op.filename, FILENAME)
 
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
-    def test_exec_success(self, gcs_hook_mock_class, pg_hook_mock_class):
+    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
+    def test_exec_success(self, gcs_hook_mock_class):
         """Test the execute function in case where the run is successful."""
         op = PostgresToGoogleCloudStorageOperator(
             task_id=TASK_ID,
@@ -79,10 +104,6 @@ def test_exec_success(self, gcs_hook_mock_class, pg_hook_mock_class):
             bucket=BUCKET,
             filename=FILENAME)
 
-        pg_hook_mock = pg_hook_mock_class.return_value
-        pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
-        pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
-
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
         def _assert_upload(bucket, obj, tmp_filename, content_type):
@@ -96,16 +117,9 @@ def _assert_upload(bucket, obj, tmp_filename, content_type):
 
         op.execute(None)
 
-        pg_hook_mock_class.assert_called_once_with(postgres_conn_id=POSTGRES_CONN_ID)
-        pg_hook_mock.get_conn().cursor().execute.assert_called_once_with(SQL, None)
-
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
-    def test_file_splitting(self, gcs_hook_mock_class, pg_hook_mock_class):
+    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
+    def test_file_splitting(self, gcs_hook_mock_class):
         """Test that ndjson is split by approx_max_file_size_bytes param."""
-        pg_hook_mock = pg_hook_mock_class.return_value
-        pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
-        pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
         expected_upload = {
@@ -129,13 +143,23 @@ def _assert_upload(bucket, obj, tmp_filename, content_type):
             approx_max_file_size_bytes=len(expected_upload[FILENAME.format(0)]))
         op.execute(None)
 
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.PostgresHook')
-    @mock.patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
-    def test_schema_file(self, gcs_hook_mock_class, pg_hook_mock_class):
+    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
+    def test_empty_query(self, gcs_hook_mock_class):
+        """If the sql returns no rows, we should not upload any files"""
+        gcs_hook_mock = gcs_hook_mock_class.return_value
+
+        op = PostgresToGoogleCloudStorageOperator(
+            task_id=TASK_ID,
+            sql='SELECT * FROM postgres_to_gcs_operator_empty',
+            bucket=BUCKET,
+            filename=FILENAME)
+        op.execute(None)
+
+        assert not gcs_hook_mock.upload.called, 'No data means no files in the bucket'
+
+    @patch('airflow.contrib.operators.postgres_to_gcs_operator.GoogleCloudStorageHook')
+    def test_schema_file(self, gcs_hook_mock_class):
         """Test writing schema files."""
-        pg_hook_mock = pg_hook_mock_class.return_value
-        pg_hook_mock.get_conn().cursor().__iter__.return_value = iter(ROWS)
-        pg_hook_mock.get_conn().cursor().description = CURSOR_DESCRIPTION
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 