
Commit dfff4a0

chronitis authored and kaxil committed
[AIRFLOW-2997] Support cluster fields in bigquery (#3838)
This adds a cluster_fields argument to the BigQuery hook, the GCS-to-BigQuery operator, and the BigQuery query operator. This field requests that BigQuery store the result of the query/load operation sorted according to the specified fields (the order of the fields given is significant).
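
For orientation, here is a sketch of the job configuration shape this produces. The project, dataset, table, and column names below are invented for illustration; the mapping of cluster_fields to a 'clustering' key alongside 'timePartitioning' is taken from the diffs that follow.

# Illustrative sketch only -- identifiers are hypothetical, not from this commit.
configuration = {
    'load': {
        'destinationTable': {
            'projectId': 'my-project',
            'datasetId': 'my_dataset',
            'tableId': 'my_table',
        },
        'timePartitioning': {'type': 'DAY'},
        # cluster_fields=['customer_id', 'event_date'] becomes:
        'clustering': {'fields': ['customer_id', 'event_date']},
    }
}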
1 parent 4ec0f0c commit dfff4a0

File tree

4 files changed: +125 −5 lines changed

airflow/contrib/hooks/bigquery_hook.py

+21 −4

@@ -496,7 +496,8 @@ def run_query(self,
                   schema_update_options=(),
                   priority='INTERACTIVE',
                   time_partitioning=None,
-                  api_resource_configs=None):
+                  api_resource_configs=None,
+                  cluster_fields=None):
         """
         Executes a BigQuery SQL query. Optionally persists results in a BigQuery
         table. See here:
@@ -563,8 +564,12 @@ def run_query(self,
         :param time_partitioning: configure optional time partitioning fields i.e.
             partition by field, type and expiration as per API specifications.
         :type time_partitioning: dict
-
+        :param cluster_fields: Request that the result of this query be stored sorted
+            by one or more columns. This is only available in combination with
+            time_partitioning. The order of columns given determines the sort order.
+        :type cluster_fields: list of str
         """
+
         if not api_resource_configs:
             api_resource_configs = self.api_resource_configs
         else:
@@ -629,6 +634,9 @@ def run_query(self,
                 'tableId': destination_table,
             }

+        if cluster_fields:
+            cluster_fields = {'fields': cluster_fields}
+
         query_param_list = [
             (sql, 'query', None, str),
             (priority, 'priority', 'INTERACTIVE', str),
@@ -639,7 +647,8 @@ def run_query(self,
             (maximum_bytes_billed, 'maximumBytesBilled', None, float),
             (time_partitioning, 'timePartitioning', {}, dict),
             (schema_update_options, 'schemaUpdateOptions', None, tuple),
-            (destination_dataset_table, 'destinationTable', None, dict)
+            (destination_dataset_table, 'destinationTable', None, dict),
+            (cluster_fields, 'clustering', None, dict),
         ]

         for param_tuple in query_param_list:
@@ -854,7 +863,8 @@ def run_load(self,
                  allow_jagged_rows=False,
                  schema_update_options=(),
                  src_fmt_configs=None,
-                 time_partitioning=None):
+                 time_partitioning=None,
+                 cluster_fields=None):
         """
         Executes a BigQuery load command to load data from Google Cloud Storage
         to BigQuery. See here:
@@ -916,6 +926,10 @@ def run_load(self,
         :param time_partitioning: configure optional time partitioning fields i.e.
             partition by field, type and expiration as per API specifications.
         :type time_partitioning: dict
+        :param cluster_fields: Request that the result of this load be stored sorted
+            by one or more columns. This is only available in combination with
+            time_partitioning. The order of columns given determines the sort order.
+        :type cluster_fields: list of str
         """

         # bigquery only allows certain source formats
@@ -979,6 +993,9 @@ def run_load(self,
                 'timePartitioning': time_partitioning
             })

+        if cluster_fields:
+            configuration['load'].update({'clustering': {'fields': cluster_fields}})
+
         if schema_fields:
             configuration['load']['schema'] = {'fields': schema_fields}

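A minimal usage sketch against the hook, mirroring the call shape exercised in the new tests at the bottom of this commit. The connection id is Airflow's default; the query, table, and field names are placeholders.

from airflow.contrib.hooks.bigquery_hook import BigQueryHook

bq_hook = BigQueryHook(bigquery_conn_id='bigquery_default')
cursor = bq_hook.get_conn().cursor()
cursor.run_query(
    sql='select 1',                            # placeholder query
    destination_dataset_table='my_dataset.my_table',
    time_partitioning={'type': 'DAY'},         # clustering requires partitioning
    cluster_fields=['field1', 'field2'],       # order determines the sort order
)
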
airflow/contrib/operators/bigquery_operator.py

+7 −0

@@ -98,6 +98,10 @@ class BigQueryOperator(BaseOperator):
     :param time_partitioning: configure optional time partitioning fields i.e.
         partition by field, type and expiration as per API specifications.
     :type time_partitioning: dict
+    :param cluster_fields: Request that the result of this query be stored sorted
+        by one or more columns. This is only available in conjunction with
+        time_partitioning. The order of columns given determines the sort order.
+    :type cluster_fields: list of str
     """

     template_fields = ('bql', 'sql', 'destination_dataset_table', 'labels')
@@ -125,6 +129,7 @@ def __init__(self,
                  priority='INTERACTIVE',
                  time_partitioning=None,
                  api_resource_configs=None,
+                 cluster_fields=None,
                  *args,
                  **kwargs):
         super(BigQueryOperator, self).__init__(*args, **kwargs)
@@ -150,6 +155,7 @@ def __init__(self,
             self.time_partitioning = {}
         if api_resource_configs is None:
             self.api_resource_configs = {}
+        self.cluster_fields = cluster_fields

         # TODO remove `bql` in Airflow 2.0
         if self.bql:
@@ -190,6 +196,7 @@ def execute(self, context):
             priority=self.priority,
             time_partitioning=self.time_partitioning,
             api_resource_configs=self.api_resource_configs,
+            cluster_fields=self.cluster_fields,
         )

     def on_kill(self):
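
At the DAG level, the new argument might be used like this. This is a hedged sketch: the task id, query, and table names are invented, and a dag object is assumed to exist.

from airflow.contrib.operators.bigquery_operator import BigQueryOperator

clustered_query = BigQueryOperator(
    task_id='clustered_query',                 # hypothetical task id
    sql='select 1',                            # placeholder query
    destination_dataset_table='my_dataset.my_table',
    write_disposition='WRITE_TRUNCATE',
    time_partitioning={'type': 'DAY'},         # required for clustering
    cluster_fields=['field1', 'field2'],
    dag=dag,                                   # assumed to be defined elsewhere
)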

airflow/contrib/operators/gcs_to_bq.py

+9 −1

@@ -114,6 +114,11 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
         Note that 'field' is not available in concurrency with
         dataset.table$partition.
     :type time_partitioning: dict
+    :param cluster_fields: Request that the result of this load be stored sorted
+        by one or more columns. This is only available in conjunction with
+        time_partitioning. The order of columns given determines the sort order.
+        Not applicable for external tables.
+    :type cluster_fields: list of str
     """
     template_fields = ('bucket', 'source_objects',
                        'schema_object', 'destination_project_dataset_table')
@@ -146,6 +151,7 @@ def __init__(self,
                  src_fmt_configs=None,
                  external_table=False,
                  time_partitioning=None,
+                 cluster_fields=None,
                  *args, **kwargs):

         super(GoogleCloudStorageToBigQueryOperator, self).__init__(*args, **kwargs)
@@ -183,6 +189,7 @@ def __init__(self,
         self.schema_update_options = schema_update_options
         self.src_fmt_configs = src_fmt_configs
         self.time_partitioning = time_partitioning
+        self.cluster_fields = cluster_fields

     def execute(self, context):
         bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
@@ -238,7 +245,8 @@ def execute(self, context):
             allow_jagged_rows=self.allow_jagged_rows,
             schema_update_options=self.schema_update_options,
             src_fmt_configs=self.src_fmt_configs,
-            time_partitioning=self.time_partitioning)
+            time_partitioning=self.time_partitioning,
+            cluster_fields=self.cluster_fields)

         if self.max_id_key:
             cursor.execute('SELECT MAX({}) FROM {}'.format(
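
The load path looks similar. Again a sketch under assumptions: the bucket, object paths, and schema object are hypothetical, and dag is assumed to exist.

from airflow.contrib.operators.gcs_to_bq import GoogleCloudStorageToBigQueryOperator

load_events = GoogleCloudStorageToBigQueryOperator(
    task_id='load_events',                     # hypothetical task id
    bucket='my-bucket',
    source_objects=['data/events/part-*.csv'],
    destination_project_dataset_table='my_dataset.my_table',
    schema_object='schemas/events.json',
    time_partitioning={'type': 'DAY'},
    cluster_fields=['field1', 'field2'],       # per the docstring, not applicable for external tables
    dag=dag,
)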

tests/contrib/hooks/test_bigquery_hook.py

+88 −0

@@ -448,6 +448,94 @@ def test_extra_time_partitioning_options(self):
         self.assertEqual(tp_out, expect)


+class TestClusteringInRunJob(unittest.TestCase):
+
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_run_load_default(self, mocked_rwc, mocked_time, mocked_logging):
+        project_id = 12345
+
+        def run_with_config(config):
+            self.assertIsNone(config['load'].get('clustering'))
+        mocked_rwc.side_effect = run_with_config
+
+        bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id)
+        bq_hook.run_load(
+            destination_project_dataset_table='my_dataset.my_table',
+            schema_fields=[],
+            source_uris=[],
+        )
+
+        mocked_rwc.assert_called_once()
+
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_run_load_with_arg(self, mocked_rwc, mocked_time, mocked_logging):
+        project_id = 12345
+
+        def run_with_config(config):
+            self.assertEqual(
+                config['load']['clustering'],
+                {
+                    'fields': ['field1', 'field2']
+                }
+            )
+        mocked_rwc.side_effect = run_with_config
+
+        bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id)
+        bq_hook.run_load(
+            destination_project_dataset_table='my_dataset.my_table',
+            schema_fields=[],
+            source_uris=[],
+            cluster_fields=['field1', 'field2'],
+            time_partitioning={'type': 'DAY'}
+        )
+
+        mocked_rwc.assert_called_once()
+
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_run_query_default(self, mocked_rwc, mocked_time, mocked_logging):
+        project_id = 12345
+
+        def run_with_config(config):
+            self.assertIsNone(config['query'].get('clustering'))
+        mocked_rwc.side_effect = run_with_config
+
+        bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id)
+        bq_hook.run_query(sql='select 1')
+
+        mocked_rwc.assert_called_once()
+
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_run_query_with_arg(self, mocked_rwc, mocked_time, mocked_logging):
+        project_id = 12345
+
+        def run_with_config(config):
+            self.assertEqual(
+                config['query']['clustering'],
+                {
+                    'fields': ['field1', 'field2']
+                }
+            )
+        mocked_rwc.side_effect = run_with_config
+
+        bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id)
+        bq_hook.run_query(
+            sql='select 1',
+            destination_dataset_table='my_dataset.my_table',
+            cluster_fields=['field1', 'field2'],
+            time_partitioning={'type': 'DAY'}
+        )
+
+        mocked_rwc.assert_called_once()
+
+
 class TestBigQueryHookLegacySql(unittest.TestCase):
     """Ensure `use_legacy_sql` param in `BigQueryHook` propagates properly."""
