Commit abb12be

chronitis authored and Fokko committed
[AIRFLOW-2997] Support cluster fields in bigquery (#3838)
This adds a cluster_fields argument to the BigQuery hook, the GCS-to-BigQuery operator, and the BigQuery query operator. The argument asks BigQuery to store the result of the query or load operation sorted by the specified fields (the order in which the fields are given is significant); a usage sketch follows the file summary below.
1 parent 288fca4 · commit abb12be

4 files changed: +125, -5 lines

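To make the new surface concrete, here is a minimal, hedged sketch of a direct hook call (the connection id, SQL, dataset/table and column names are invented for illustration; cluster_fields and its interaction with time_partitioning are the parts this commit adds):

    from airflow.contrib.hooks.bigquery_hook import BigQueryHook

    # Run a query whose results are written to a day-partitioned table,
    # stored sorted by country first, then user_id (order is significant).
    hook = BigQueryHook(bigquery_conn_id='bigquery_default')  # invented conn id
    cursor = hook.get_conn().cursor()
    cursor.run_query(
        sql='SELECT country, user_id, ts FROM my_dataset.events',  # invented
        destination_dataset_table='my_dataset.events_sorted',      # invented
        time_partitioning={'type': 'DAY'},      # clustering requires partitioning
        cluster_fields=['country', 'user_id'],  # new in this commit
    )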

airflow/contrib/hooks/bigquery_hook.py (+21, -4)

@@ -496,7 +496,8 @@ def run_query(self,
                   schema_update_options=(),
                   priority='INTERACTIVE',
                   time_partitioning=None,
-                  api_resource_configs=None):
+                  api_resource_configs=None,
+                  cluster_fields=None):
         """
         Executes a BigQuery SQL query. Optionally persists results in a BigQuery
         table. See here:
@@ -565,8 +566,12 @@ def run_query(self,
             expiration as per API specifications. Note that 'field' is not available in
             conjunction with dataset.table$partition.
         :type time_partitioning: dict
-
+        :param cluster_fields: Request that the result of this query be stored sorted
+            by one or more columns. This is only available in combination with
+            time_partitioning. The order of columns given determines the sort order.
+        :type cluster_fields: list of str
         """
+
         if not api_resource_configs:
             api_resource_configs = self.api_resource_configs
         else:
@@ -631,6 +636,9 @@ def run_query(self,
                 'tableId': destination_table,
             }
 
+        if cluster_fields:
+            cluster_fields = {'fields': cluster_fields}
+
         query_param_list = [
             (sql, 'query', None, str),
             (priority, 'priority', 'INTERACTIVE', str),
@@ -641,7 +649,8 @@ def run_query(self,
             (maximum_bytes_billed, 'maximumBytesBilled', None, float),
             (time_partitioning, 'timePartitioning', {}, dict),
             (schema_update_options, 'schemaUpdateOptions', None, tuple),
-            (destination_dataset_table, 'destinationTable', None, dict)
+            (destination_dataset_table, 'destinationTable', None, dict),
+            (cluster_fields, 'clustering', None, dict),
         ]
 
         for param_tuple in query_param_list:
@@ -856,7 +865,8 @@ def run_load(self,
                  allow_jagged_rows=False,
                  schema_update_options=(),
                  src_fmt_configs=None,
-                 time_partitioning=None):
+                 time_partitioning=None,
+                 cluster_fields=None):
         """
         Executes a BigQuery load command to load data from Google Cloud Storage
         to BigQuery. See here:
@@ -920,6 +930,10 @@ def run_load(self,
             expiration as per API specifications. Note that 'field' is not available in
             conjunction with dataset.table$partition.
         :type time_partitioning: dict
+        :param cluster_fields: Request that the result of this load be stored sorted
+            by one or more columns. This is only available in combination with
+            time_partitioning. The order of columns given determines the sort order.
+        :type cluster_fields: list of str
         """
 
         # bigquery only allows certain source formats
@@ -983,6 +997,9 @@ def run_load(self,
             'timePartitioning': time_partitioning
         })
 
+        if cluster_fields:
+            configuration['load'].update({'clustering': {'fields': cluster_fields}})
+
         if schema_fields:
             configuration['load']['schema'] = {'fields': schema_fields}

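As the hunks above show, run_query() routes cluster_fields through query_param_list as a 'clustering' dict, and run_load() injects it directly into the load configuration. A sketch of the configuration run_load() would then hand to run_with_configuration (project, bucket, table and field names invented; only the 'clustering' key is new here):

    configuration = {
        'load': {
            'destinationTable': {
                'projectId': 'my-project',   # invented
                'datasetId': 'my_dataset',   # invented
                'tableId': 'events',         # invented
            },
            'sourceUris': ['gs://my-bucket/events/*.json'],  # invented
            'timePartitioning': {'type': 'DAY'},
            # Added by this commit when cluster_fields is passed:
            'clustering': {'fields': ['country', 'user_id']},
        }
    }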

airflow/contrib/operators/bigquery_operator.py (+7)

@@ -100,6 +100,10 @@ class BigQueryOperator(BaseOperator):
         expiration as per API specifications. Note that 'field' is not available in
         conjunction with dataset.table$partition.
     :type time_partitioning: dict
+    :param cluster_fields: Request that the result of this query be stored sorted
+        by one or more columns. This is only available in conjunction with
+        time_partitioning. The order of columns given determines the sort order.
+    :type cluster_fields: list of str
     """
 
     template_fields = ('bql', 'sql', 'destination_dataset_table', 'labels')
@@ -127,6 +131,7 @@ def __init__(self,
                  priority='INTERACTIVE',
                  time_partitioning=None,
                  api_resource_configs=None,
+                 cluster_fields=None,
                  *args,
                  **kwargs):
         super(BigQueryOperator, self).__init__(*args, **kwargs)
@@ -152,6 +157,7 @@ def __init__(self,
             self.time_partitioning = {}
         if api_resource_configs is None:
             self.api_resource_configs = {}
+        self.cluster_fields = cluster_fields
 
         # TODO remove `bql` in Airflow 2.0
         if self.bql:
@@ -192,6 +198,7 @@ def execute(self, context):
             priority=self.priority,
             time_partitioning=self.time_partitioning,
             api_resource_configs=self.api_resource_configs,
+            cluster_fields=self.cluster_fields,
         )
 
     def on_kill(self):

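At the DAG level the new argument is passed like any other operator keyword. A hedged sketch (DAG id, schedule, task id, SQL and table names are all invented):

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.bigquery_operator import BigQueryOperator

    with DAG('clustering_example', schedule_interval=None,
             start_date=datetime(2018, 1, 1)) as dag:
        aggregate = BigQueryOperator(
            task_id='aggregate_events',
            sql='SELECT country, user_id, COUNT(*) AS n '
                'FROM my_dataset.events GROUP BY 1, 2',
            destination_dataset_table='my_dataset.events_agg',
            use_legacy_sql=False,
            time_partitioning={'type': 'DAY'},
            cluster_fields=['country', 'user_id'],  # new in this commit
        )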

airflow/contrib/operators/gcs_to_bq.py (+9, -1)

@@ -114,6 +114,11 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
         Note that 'field' is not available in concurrency with
         dataset.table$partition.
     :type time_partitioning: dict
+    :param cluster_fields: Request that the result of this load be stored sorted
+        by one or more columns. This is only available in conjunction with
+        time_partitioning. The order of columns given determines the sort order.
+        Not applicable for external tables.
+    :type cluster_fields: list of str
     """
     template_fields = ('bucket', 'source_objects',
                        'schema_object', 'destination_project_dataset_table')
@@ -146,6 +151,7 @@ def __init__(self,
                  src_fmt_configs=None,
                  external_table=False,
                  time_partitioning=None,
+                 cluster_fields=None,
                  *args, **kwargs):
 
         super(GoogleCloudStorageToBigQueryOperator, self).__init__(*args, **kwargs)
@@ -183,6 +189,7 @@ def __init__(self,
         self.schema_update_options = schema_update_options
         self.src_fmt_configs = src_fmt_configs
         self.time_partitioning = time_partitioning
+        self.cluster_fields = cluster_fields
 
     def execute(self, context):
         bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
@@ -238,7 +245,8 @@ def execute(self, context):
             allow_jagged_rows=self.allow_jagged_rows,
             schema_update_options=self.schema_update_options,
             src_fmt_configs=self.src_fmt_configs,
-            time_partitioning=self.time_partitioning)
+            time_partitioning=self.time_partitioning,
+            cluster_fields=self.cluster_fields)
 
         if self.max_id_key:
             cursor.execute('SELECT MAX({}) FROM {}'.format(

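The loader side looks similar; note the docstring's caveat that clustering is not applicable when external_table=True. A hedged sketch (bucket, object paths, schema and table names invented):

    from airflow.contrib.operators.gcs_to_bq import \
        GoogleCloudStorageToBigQueryOperator

    load_events = GoogleCloudStorageToBigQueryOperator(
        task_id='load_events',
        bucket='my-bucket',                           # invented
        source_objects=['events/2018-09-01/*.json'],  # invented
        source_format='NEWLINE_DELIMITED_JSON',
        destination_project_dataset_table='my_dataset.events',
        schema_fields=[
            {'name': 'country', 'type': 'STRING', 'mode': 'NULLABLE'},
            {'name': 'user_id', 'type': 'STRING', 'mode': 'NULLABLE'},
            {'name': 'ts', 'type': 'TIMESTAMP', 'mode': 'NULLABLE'},
        ],
        time_partitioning={'type': 'DAY', 'field': 'ts'},
        cluster_fields=['country', 'user_id'],  # not applied to external tables
    )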

tests/contrib/hooks/test_bigquery_hook.py (+88)

@@ -455,6 +455,94 @@ def test_cant_add_dollar_and_field_name(self):
         )
 
 
+class TestClusteringInRunJob(unittest.TestCase):
+
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_run_load_default(self, mocked_rwc, mocked_time, mocked_logging):
+        project_id = 12345
+
+        def run_with_config(config):
+            self.assertIsNone(config['load'].get('clustering'))
+        mocked_rwc.side_effect = run_with_config
+
+        bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id)
+        bq_hook.run_load(
+            destination_project_dataset_table='my_dataset.my_table',
+            schema_fields=[],
+            source_uris=[],
+        )
+
+        mocked_rwc.assert_called_once()
+
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_run_load_with_arg(self, mocked_rwc, mocked_time, mocked_logging):
+        project_id = 12345
+
+        def run_with_config(config):
+            self.assertEqual(
+                config['load']['clustering'],
+                {
+                    'fields': ['field1', 'field2']
+                }
+            )
+        mocked_rwc.side_effect = run_with_config
+
+        bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id)
+        bq_hook.run_load(
+            destination_project_dataset_table='my_dataset.my_table',
+            schema_fields=[],
+            source_uris=[],
+            cluster_fields=['field1', 'field2'],
+            time_partitioning={'type': 'DAY'}
+        )
+
+        mocked_rwc.assert_called_once()
+
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_run_query_default(self, mocked_rwc, mocked_time, mocked_logging):
+        project_id = 12345
+
+        def run_with_config(config):
+            self.assertIsNone(config['query'].get('clustering'))
+        mocked_rwc.side_effect = run_with_config
+
+        bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id)
+        bq_hook.run_query(sql='select 1')
+
+        mocked_rwc.assert_called_once()
+
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
+    @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_run_query_with_arg(self, mocked_rwc, mocked_time, mocked_logging):
+        project_id = 12345
+
+        def run_with_config(config):
+            self.assertEqual(
+                config['query']['clustering'],
+                {
+                    'fields': ['field1', 'field2']
+                }
+            )
+        mocked_rwc.side_effect = run_with_config
+
+        bq_hook = hook.BigQueryBaseCursor(mock.Mock(), project_id)
+        bq_hook.run_query(
+            sql='select 1',
+            destination_dataset_table='my_dataset.my_table',
+            cluster_fields=['field1', 'field2'],
+            time_partitioning={'type': 'DAY'}
+        )
+
+        mocked_rwc.assert_called_once()
+
+
 class TestBigQueryHookLegacySql(unittest.TestCase):
     """Ensure `use_legacy_sql` param in `BigQueryHook` propagates properly."""
 