
Commit aae770c

xnuinside authored and Alice Berard committed
[AIRFLOW-491] Add feature to pass extra api configs to BQ Hook (apache#3733)
1 parent 5595c1e commit aae770c

File tree: 3 files changed, +208 -99 lines


airflow/contrib/hooks/bigquery_hook.py (+148 -85)
@@ -24,6 +24,7 @@

 import time
 from builtins import range
+from copy import deepcopy

 from past.builtins import basestring

@@ -195,10 +196,19 @@ class BigQueryBaseCursor(LoggingMixin):
     PEP 249 cursor isn't needed.
     """

-    def __init__(self, service, project_id, use_legacy_sql=True):
+    def __init__(self,
+                 service,
+                 project_id,
+                 use_legacy_sql=True,
+                 api_resource_configs=None):
+
         self.service = service
         self.project_id = project_id
         self.use_legacy_sql = use_legacy_sql
+        if api_resource_configs:
+            _validate_value("api_resource_configs", api_resource_configs, dict)
+        self.api_resource_configs = api_resource_configs \
+            if api_resource_configs else {}
         self.running_job_id = None

     def create_empty_table(self,
@@ -238,8 +248,7 @@ def create_empty_table(self,

         :return:
         """
-        if time_partitioning is None:
-            time_partitioning = dict()
+
         project_id = project_id if project_id is not None else self.project_id

         table_resource = {
@@ -473,11 +482,11 @@ def create_external_table(self,
     def run_query(self,
                   bql=None,
                   sql=None,
-                  destination_dataset_table=False,
+                  destination_dataset_table=None,
                   write_disposition='WRITE_EMPTY',
                   allow_large_results=False,
                   flatten_results=None,
-                  udf_config=False,
+                  udf_config=None,
                   use_legacy_sql=None,
                   maximum_billing_tier=None,
                   maximum_bytes_billed=None,
@@ -486,7 +495,8 @@ def run_query(self,
                   labels=None,
                   schema_update_options=(),
                   priority='INTERACTIVE',
-                  time_partitioning=None):
+                  time_partitioning=None,
+                  api_resource_configs=None):
         """
         Executes a BigQuery SQL query. Optionally persists results in a BigQuery
         table. See here:
@@ -518,6 +528,13 @@ def run_query(self,
         :param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false).
             If `None`, defaults to `self.use_legacy_sql`.
         :type use_legacy_sql: boolean
+        :param api_resource_configs: a dictionary containing the 'configuration'
+            params to be applied to the Google BigQuery Jobs API:
+            https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs
+            for example, {'query': {'useQueryCache': False}}. Use it when
+            you need to pass params that have no dedicated argument on
+            the BigQueryHook.
+        :type api_resource_configs: dict
         :type udf_config: list
         :param maximum_billing_tier: Positive integer that serves as a
             multiplier of the basic price.
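
For orientation, a minimal sketch of how the new argument can be passed at call
time (not part of the diff; the connection id is illustrative, and the cursor is
obtained through the existing contrib hook API):

    # Hedged usage sketch: pass a Jobs API option that has no dedicated kwarg.
    from airflow.contrib.hooks.bigquery_hook import BigQueryHook

    hook = BigQueryHook(bigquery_conn_id='bigquery_default')  # illustrative conn id
    cursor = hook.get_conn().cursor()

    # 'useQueryCache' rides along in the 'query' section of the Jobs API
    # configuration instead of needing its own run_query() parameter.
    cursor.run_query(
        sql='SELECT 1',
        api_resource_configs={'query': {'useQueryCache': False}})
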
@@ -550,12 +567,22 @@ def run_query(self,
         :type time_partitioning: dict

         """
+        if not api_resource_configs:
+            api_resource_configs = self.api_resource_configs
+        else:
+            _validate_value('api_resource_configs',
+                            api_resource_configs, dict)
+        configuration = deepcopy(api_resource_configs)
+        if 'query' not in configuration:
+            configuration['query'] = {}
+
+        else:
+            _validate_value("api_resource_configs['query']",
+                            configuration['query'], dict)

-        # TODO remove `bql` in Airflow 2.0 - Jira: [AIRFLOW-2513]
-        if time_partitioning is None:
-            time_partitioning = {}
         sql = bql if sql is None else sql

+        # TODO remove `bql` in Airflow 2.0 - Jira: [AIRFLOW-2513]
         if bql:
             import warnings
             warnings.warn('Deprecated parameter `bql` used in '
@@ -566,95 +593,109 @@ def run_query(self,
                           'Airflow.',
                           category=DeprecationWarning)

-        if sql is None:
-            raise TypeError('`BigQueryBaseCursor.run_query` missing 1 required '
-                            'positional argument: `sql`')
+        if sql is None and not configuration['query'].get('query', None):
+            raise TypeError('`BigQueryBaseCursor.run_query` '
+                            'missing 1 required positional argument: `sql`')

         # BigQuery also allows you to define how you want a table's schema to change
         # as a side effect of a query job
         # for more details:
         # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions
+
         allowed_schema_update_options = [
             'ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION"
         ]
-        if not set(allowed_schema_update_options).issuperset(
-                set(schema_update_options)):
-            raise ValueError(
-                "{0} contains invalid schema update options. "
-                "Please only use one or more of the following options: {1}"
-                .format(schema_update_options, allowed_schema_update_options))

-        if use_legacy_sql is None:
-            use_legacy_sql = self.use_legacy_sql
+        if not set(allowed_schema_update_options
+                   ).issuperset(set(schema_update_options)):
+            raise ValueError("{0} contains invalid schema update options. "
+                             "Please only use one or more of the following "
+                             "options: {1}"
+                             .format(schema_update_options,
                                      allowed_schema_update_options))

-        configuration = {
-            'query': {
-                'query': sql,
-                'useLegacySql': use_legacy_sql,
-                'maximumBillingTier': maximum_billing_tier,
-                'maximumBytesBilled': maximum_bytes_billed,
-                'priority': priority
-            }
-        }
+        if schema_update_options:
+            if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]:
+                raise ValueError("schema_update_options is only "
+                                 "allowed if write_disposition is "
+                                 "'WRITE_APPEND' or 'WRITE_TRUNCATE'.")

         if destination_dataset_table:
-            if '.' not in destination_dataset_table:
-                raise ValueError(
-                    'Expected destination_dataset_table name in the format of '
-                    '<dataset>.<table>. Got: {}'.format(
-                        destination_dataset_table))
             destination_project, destination_dataset, destination_table = \
                 _split_tablename(table_input=destination_dataset_table,
                                  default_project_id=self.project_id)
-            configuration['query'].update({
-                'allowLargeResults': allow_large_results,
-                'flattenResults': flatten_results,
-                'writeDisposition': write_disposition,
-                'createDisposition': create_disposition,
-                'destinationTable': {
-                    'projectId': destination_project,
-                    'datasetId': destination_dataset,
-                    'tableId': destination_table,
-                }
-            })
-        if udf_config:
-            if not isinstance(udf_config, list):
-                raise TypeError("udf_config argument must have a type 'list'"
-                                " not {}".format(type(udf_config)))
-            configuration['query'].update({
-                'userDefinedFunctionResources': udf_config
-            })

-        if query_params:
-            if self.use_legacy_sql:
-                raise ValueError("Query parameters are not allowed when using "
-                                 "legacy SQL")
-            else:
-                configuration['query']['queryParameters'] = query_params
+            destination_dataset_table = {
+                'projectId': destination_project,
+                'datasetId': destination_dataset,
+                'tableId': destination_table,
+            }

-        if labels:
-            configuration['labels'] = labels
+        query_param_list = [
+            (sql, 'query', None, str),
+            (priority, 'priority', 'INTERACTIVE', str),
+            (use_legacy_sql, 'useLegacySql', self.use_legacy_sql, bool),
+            (query_params, 'queryParameters', None, dict),
+            (udf_config, 'userDefinedFunctionResources', None, list),
+            (maximum_billing_tier, 'maximumBillingTier', None, int),
+            (maximum_bytes_billed, 'maximumBytesBilled', None, float),
+            (time_partitioning, 'timePartitioning', {}, dict),
+            (schema_update_options, 'schemaUpdateOptions', None, tuple),
+            (destination_dataset_table, 'destinationTable', None, dict)
+        ]

-        time_partitioning = _cleanse_time_partitioning(
-            destination_dataset_table,
-            time_partitioning
-        )
-        if time_partitioning:
-            configuration['query'].update({
-                'timePartitioning': time_partitioning
-            })
+        for param_tuple in query_param_list:

-        if schema_update_options:
-            if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]:
-                raise ValueError("schema_update_options is only "
-                                 "allowed if write_disposition is "
-                                 "'WRITE_APPEND' or 'WRITE_TRUNCATE'.")
-            else:
-                self.log.info(
-                    "Adding experimental "
-                    "'schemaUpdateOptions': {0}".format(schema_update_options))
-                configuration['query'][
-                    'schemaUpdateOptions'] = schema_update_options
+            param, param_name, param_default, param_type = param_tuple
+
+            if param_name not in configuration['query'] and param in [None, {}, ()]:
+                if param_name == 'timePartitioning':
+                    param_default = _cleanse_time_partitioning(
+                        destination_dataset_table, time_partitioning)
+                param = param_default
+
+            if param not in [None, {}, ()]:
+                _api_resource_configs_duplication_check(
+                    param_name, param, configuration['query'])
+
+                configuration['query'][param_name] = param
+
+                # check the type of the provided param; this is done last
+                # because the value can come from two sources, and it
+                # must be resolved before it can be validated
+
+                _validate_value(param_name, configuration['query'][param_name],
+                                param_type)
+
+                if param_name == 'schemaUpdateOptions' and param:
+                    self.log.info("Adding experimental 'schemaUpdateOptions': "
+                                  "{0}".format(schema_update_options))
+
+                if param_name == 'destinationTable':
+                    for key in ['projectId', 'datasetId', 'tableId']:
+                        if key not in configuration['query']['destinationTable']:
+                            raise ValueError(
+                                "Not correct 'destinationTable' in "
+                                "api_resource_configs. 'destinationTable' "
+                                "must be a dict with {'projectId':'', "
+                                "'datasetId':'', 'tableId':''}")
+
+                    configuration['query'].update({
+                        'allowLargeResults': allow_large_results,
+                        'flattenResults': flatten_results,
+                        'writeDisposition': write_disposition,
+                        'createDisposition': create_disposition,
+                    })
+
+        if 'useLegacySql' in configuration['query'] and \
+                'queryParameters' in configuration['query']:
+            raise ValueError("Query parameters are not allowed "
+                             "when using legacy SQL")
+
+        if labels:
+            _api_resource_configs_duplication_check(
+                'labels', labels, configuration)
+            configuration['labels'] = labels

         return self.run_with_configuration(configuration)

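
The loop above resolves each query option from two possible sources: the
explicit run_query() kwarg and the matching key already present in
api_resource_configs['query']. A simplified, standalone illustration of the
resulting precedence (a sketch, not the hook's actual code):

    # Merge rule for a single option, e.g. 'priority':
    configuration = {'query': {'priority': 'BATCH'}}  # from api_resource_configs
    priority = 'INTERACTIVE'                          # run_query() kwarg/default

    if 'priority' not in configuration['query']:
        # only one source set it: the kwarg (or its default) is written in
        configuration['query']['priority'] = priority
    elif configuration['query']['priority'] != priority:
        # both sources set different values: the hook raises rather than guessing
        raise ValueError("Values of priority param are duplicated.")
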
@@ -888,8 +929,7 @@ def run_load(self,
         # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.tableDefinitions.(key).sourceFormat
         if src_fmt_configs is None:
             src_fmt_configs = {}
-        if time_partitioning is None:
-            time_partitioning = {}
+
         source_format = source_format.upper()
         allowed_formats = [
             "CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS",
@@ -1167,10 +1207,6 @@ def run_table_delete(self, deletion_dataset_table,
         :type ignore_if_missing: boolean
         :return:
         """
-        if '.' not in deletion_dataset_table:
-            raise ValueError(
-                'Expected deletion_dataset_table name in the format of '
-                '<dataset>.<table>. Got: {}'.format(deletion_dataset_table))
         deletion_project, deletion_dataset, deletion_table = \
             _split_tablename(table_input=deletion_dataset_table,
                              default_project_id=self.project_id)
@@ -1536,6 +1572,12 @@ def _bq_cast(string_field, bq_type):


 def _split_tablename(table_input, default_project_id, var_name=None):
+
+    if '.' not in table_input:
+        raise ValueError(
+            'Expected deletion_dataset_table name in the format of '
+            '<dataset>.<table>. Got: {}'.format(table_input))
+
     if not default_project_id:
         raise ValueError("INTERNAL: No default project is specified")

@@ -1597,6 +1639,10 @@ def var_print(var_name):

 def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):
     # if it is a partitioned table ($ is in the table name) add partition load option
+
+    if time_partitioning_in is None:
+        time_partitioning_in = {}
+
     time_partitioning_out = {}
     if destination_dataset_table and '$' in destination_dataset_table:
         if time_partitioning_in.get('field'):
@@ -1607,3 +1653,20 @@ def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):

     time_partitioning_out.update(time_partitioning_in)
     return time_partitioning_out
+
+
+def _validate_value(key, value, expected_type):
+    """ Check that a value has the expected type and raise
+    a TypeError if it does not. """
+    if not isinstance(value, expected_type):
+        raise TypeError("{} argument must have a type {} not {}".format(
+            key, expected_type, type(value)))
+
+
+def _api_resource_configs_duplication_check(key, value, config_dict):
+    if key in config_dict and value != config_dict[key]:
+        raise ValueError("Values of {param_name} param are duplicated. "
+                         "`api_resource_configs` contained {param_name} param "
+                         "in `query` config and {param_name} was also provided "
+                         "with arg to run_query() method. Please remove duplicates."
+                         .format(param_name=key))
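
A quick sketch of how the two new helpers behave in isolation (illustrative
calls, not part of the diff):

    # _validate_value is a type gate: it returns None or raises TypeError.
    _validate_value('api_resource_configs', {'query': {}}, dict)  # passes
    _validate_value('api_resource_configs', 'oops', dict)         # TypeError

    # _api_resource_configs_duplication_check lets equal values through and
    # raises ValueError when the same key carries diverging values.
    config = {'useQueryCache': False}
    _api_resource_configs_duplication_check('useQueryCache', False, config)  # ok
    _api_resource_configs_duplication_check('useQueryCache', True, config)   # ValueError
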

airflow/contrib/operators/bigquery_operator.py (+15 -3)
@@ -75,6 +75,13 @@ class BigQueryOperator(BaseOperator):
         (without incurring a charge). If unspecified, this will be
         set to your project default.
     :type maximum_bytes_billed: float
+    :param api_resource_configs: a dictionary containing the 'configuration'
+        params to be applied to the Google BigQuery Jobs API:
+        https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs
+        for example, {'query': {'useQueryCache': False}}. Use it when
+        you need to pass params that have no dedicated argument on
+        the BigQueryOperator.
+    :type api_resource_configs: dict
     :param schema_update_options: Allows the schema of the destination
         table to be updated as a side effect of the load job.
     :type schema_update_options: tuple
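
As with the hook, a hedged sketch of operator-level usage (the task id is
hypothetical and a `dag` object is assumed to be in scope):

    # Illustrative only: forward extra Jobs API configuration via the operator.
    from airflow.contrib.operators.bigquery_operator import BigQueryOperator

    no_cache_query = BigQueryOperator(
        task_id='bq_query_no_cache',  # hypothetical task id
        sql='SELECT 1',
        use_legacy_sql=False,
        api_resource_configs={'query': {'useQueryCache': False}},
        dag=dag)  # assumes an existing DAG object named `dag`
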
@@ -118,7 +125,8 @@ def __init__(self,
                  query_params=None,
                  labels=None,
                  priority='INTERACTIVE',
-                 time_partitioning={},
+                 time_partitioning=None,
+                 api_resource_configs=None,
                  *args,
                  **kwargs):
         super(BigQueryOperator, self).__init__(*args, **kwargs)
@@ -140,7 +148,10 @@ def __init__(self,
         self.labels = labels
         self.bq_cursor = None
         self.priority = priority
-        self.time_partitioning = time_partitioning
+        if time_partitioning is None:
+            self.time_partitioning = {}
+        if api_resource_configs is None:
+            self.api_resource_configs = {}

         # TODO remove `bql` in Airflow 2.0
         if self.bql:
@@ -179,7 +190,8 @@ def execute(self, context):
                 labels=self.labels,
                 schema_update_options=self.schema_update_options,
                 priority=self.priority,
-                time_partitioning=self.time_partitioning
+                time_partitioning=self.time_partitioning,
+                api_resource_configs=self.api_resource_configs,
             )

     def on_kill(self):
