
Commit de6d19f

xnuinside authored and Chris Fei committed Jan 23, 2019
Add feature to pass extra api configs to BQ Hook (apache#3733)
1 parent a5dcbde commit de6d19f

File tree

3 files changed, +208 -99 lines changed


airflow/contrib/hooks/bigquery_hook.py (+148 -85)
@@ -24,6 +24,7 @@
 
 import time
 from builtins import range
+from copy import deepcopy
 
 from past.builtins import basestring
 
@@ -195,10 +196,19 @@ class BigQueryBaseCursor(LoggingMixin):
     PEP 249 cursor isn't needed.
     """
 
-    def __init__(self, service, project_id, use_legacy_sql=True):
+    def __init__(self,
+                 service,
+                 project_id,
+                 use_legacy_sql=True,
+                 api_resource_configs=None):
+
         self.service = service
         self.project_id = project_id
         self.use_legacy_sql = use_legacy_sql
+        if api_resource_configs:
+            _validate_value("api_resource_configs", api_resource_configs, dict)
+        self.api_resource_configs = api_resource_configs \
+            if api_resource_configs else {}
         self.running_job_id = None
 
     def create_empty_table(self,
@@ -238,8 +248,7 @@ def create_empty_table(self,
 
         :return:
         """
-        if time_partitioning is None:
-            time_partitioning = dict()
+
         project_id = project_id if project_id is not None else self.project_id
 
         table_resource = {
@@ -473,11 +482,11 @@ def create_external_table(self,
     def run_query(self,
                   bql=None,
                   sql=None,
-                  destination_dataset_table=False,
+                  destination_dataset_table=None,
                   write_disposition='WRITE_EMPTY',
                   allow_large_results=False,
                   flatten_results=None,
-                  udf_config=False,
+                  udf_config=None,
                   use_legacy_sql=None,
                   maximum_billing_tier=None,
                   maximum_bytes_billed=None,
@@ -486,7 +495,8 @@ def run_query(self,
                   labels=None,
                   schema_update_options=(),
                   priority='INTERACTIVE',
-                  time_partitioning=None):
+                  time_partitioning=None,
+                  api_resource_configs=None):
         """
         Executes a BigQuery SQL query. Optionally persists results in a BigQuery
         table. See here:
@@ -518,6 +528,13 @@ def run_query(self,
         :param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false).
             If `None`, defaults to `self.use_legacy_sql`.
         :type use_legacy_sql: boolean
+        :param api_resource_configs: a dictionary that contains params
+            'configuration' applied for Google BigQuery Jobs API:
+            https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs
+            for example, {'query': {'useQueryCache': False}}. You could use it
+            if you need to provide some params that are not supported by
+            BigQueryHook as arguments.
+        :type api_resource_configs: dict
         :type udf_config: list
         :param maximum_billing_tier: Positive integer that serves as a
             multiplier of the basic price.
@@ -548,12 +565,22 @@ def run_query(self,
         :type time_partitioning: dict
 
         """
+        if not api_resource_configs:
+            api_resource_configs = self.api_resource_configs
+        else:
+            _validate_value('api_resource_configs',
+                            api_resource_configs, dict)
+        configuration = deepcopy(api_resource_configs)
+        if 'query' not in configuration:
+            configuration['query'] = {}
+
+        else:
+            _validate_value("api_resource_configs['query']",
+                            configuration['query'], dict)
 
-        # TODO remove `bql` in Airflow 2.0 - Jira: [AIRFLOW-2513]
-        if time_partitioning is None:
-            time_partitioning = {}
         sql = bql if sql is None else sql
 
+        # TODO remove `bql` in Airflow 2.0 - Jira: [AIRFLOW-2513]
         if bql:
             import warnings
             warnings.warn('Deprecated parameter `bql` used in '
@@ -564,95 +591,109 @@ def run_query(self,
                           'Airflow.',
                           category=DeprecationWarning)
 
-        if sql is None:
-            raise TypeError('`BigQueryBaseCursor.run_query` missing 1 required '
-                            'positional argument: `sql`')
+        if sql is None and not configuration['query'].get('query', None):
+            raise TypeError('`BigQueryBaseCursor.run_query` '
+                            'missing 1 required positional argument: `sql`')
 
         # BigQuery also allows you to define how you want a table's schema to change
         # as a side effect of a query job
         # for more details:
         # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions
+
         allowed_schema_update_options = [
             'ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION"
         ]
-        if not set(allowed_schema_update_options).issuperset(
-                set(schema_update_options)):
-            raise ValueError(
-                "{0} contains invalid schema update options. "
-                "Please only use one or more of the following options: {1}"
-                .format(schema_update_options, allowed_schema_update_options))
 
-        if use_legacy_sql is None:
-            use_legacy_sql = self.use_legacy_sql
+        if not set(allowed_schema_update_options
+                   ).issuperset(set(schema_update_options)):
+            raise ValueError("{0} contains invalid schema update options. "
+                             "Please only use one or more of the following "
+                             "options: {1}"
+                             .format(schema_update_options,
+                                     allowed_schema_update_options))
 
-        configuration = {
-            'query': {
-                'query': sql,
-                'useLegacySql': use_legacy_sql,
-                'maximumBillingTier': maximum_billing_tier,
-                'maximumBytesBilled': maximum_bytes_billed,
-                'priority': priority
-            }
-        }
+        if schema_update_options:
+            if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]:
+                raise ValueError("schema_update_options is only "
+                                 "allowed if write_disposition is "
+                                 "'WRITE_APPEND' or 'WRITE_TRUNCATE'.")
 
         if destination_dataset_table:
-            if '.' not in destination_dataset_table:
-                raise ValueError(
-                    'Expected destination_dataset_table name in the format of '
-                    '<dataset>.<table>. Got: {}'.format(
-                        destination_dataset_table))
             destination_project, destination_dataset, destination_table = \
                 _split_tablename(table_input=destination_dataset_table,
                                  default_project_id=self.project_id)
-            configuration['query'].update({
-                'allowLargeResults': allow_large_results,
-                'flattenResults': flatten_results,
-                'writeDisposition': write_disposition,
-                'createDisposition': create_disposition,
-                'destinationTable': {
-                    'projectId': destination_project,
-                    'datasetId': destination_dataset,
-                    'tableId': destination_table,
-                }
-            })
-        if udf_config:
-            if not isinstance(udf_config, list):
-                raise TypeError("udf_config argument must have a type 'list'"
-                                " not {}".format(type(udf_config)))
-            configuration['query'].update({
-                'userDefinedFunctionResources': udf_config
-            })
 
-        if query_params:
-            if self.use_legacy_sql:
-                raise ValueError("Query parameters are not allowed when using "
-                                 "legacy SQL")
-            else:
-                configuration['query']['queryParameters'] = query_params
+            destination_dataset_table = {
+                'projectId': destination_project,
+                'datasetId': destination_dataset,
+                'tableId': destination_table,
+            }
 
-        if labels:
-            configuration['labels'] = labels
+        query_param_list = [
+            (sql, 'query', None, str),
+            (priority, 'priority', 'INTERACTIVE', str),
+            (use_legacy_sql, 'useLegacySql', self.use_legacy_sql, bool),
+            (query_params, 'queryParameters', None, dict),
+            (udf_config, 'userDefinedFunctionResources', None, list),
+            (maximum_billing_tier, 'maximumBillingTier', None, int),
+            (maximum_bytes_billed, 'maximumBytesBilled', None, float),
+            (time_partitioning, 'timePartitioning', {}, dict),
+            (schema_update_options, 'schemaUpdateOptions', None, tuple),
+            (destination_dataset_table, 'destinationTable', None, dict)
+        ]
 
-        time_partitioning = _cleanse_time_partitioning(
-            destination_dataset_table,
-            time_partitioning
-        )
-        if time_partitioning:
-            configuration['query'].update({
-                'timePartitioning': time_partitioning
-            })
+        for param_tuple in query_param_list:
 
-        if schema_update_options:
-            if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]:
-                raise ValueError("schema_update_options is only "
-                                 "allowed if write_disposition is "
-                                 "'WRITE_APPEND' or 'WRITE_TRUNCATE'.")
-            else:
-                self.log.info(
-                    "Adding experimental "
-                    "'schemaUpdateOptions': {0}".format(schema_update_options))
-                configuration['query'][
-                    'schemaUpdateOptions'] = schema_update_options
+            param, param_name, param_default, param_type = param_tuple
+
+            if param_name not in configuration['query'] and param in [None, {}, ()]:
+                if param_name == 'timePartitioning':
+                    param_default = _cleanse_time_partitioning(
+                        destination_dataset_table, time_partitioning)
+                param = param_default
+
+            if param not in [None, {}, ()]:
+                _api_resource_configs_duplication_check(
+                    param_name, param, configuration['query'])
+
+                configuration['query'][param_name] = param
+
+                # check the type of the provided param last, because the value
+                # can come from two sources (an argument or api_resource_configs)
+                # and has to be resolved first
+
+                _validate_value(param_name, configuration['query'][param_name],
+                                param_type)
+
+                if param_name == 'schemaUpdateOptions' and param:
+                    self.log.info("Adding experimental 'schemaUpdateOptions': "
+                                  "{0}".format(schema_update_options))
+
+                if param_name == 'destinationTable':
+                    for key in ['projectId', 'datasetId', 'tableId']:
+                        if key not in configuration['query']['destinationTable']:
+                            raise ValueError(
+                                "Not correct 'destinationTable' in "
+                                "api_resource_configs. 'destinationTable' "
+                                "must be a dict with {'projectId':'', "
+                                "'datasetId':'', 'tableId':''}")
+
+                    configuration['query'].update({
+                        'allowLargeResults': allow_large_results,
+                        'flattenResults': flatten_results,
+                        'writeDisposition': write_disposition,
+                        'createDisposition': create_disposition,
+                    })
+
+        if 'useLegacySql' in configuration['query'] and \
+                'queryParameters' in configuration['query']:
+            raise ValueError("Query parameters are not allowed "
+                             "when using legacy SQL")
+
+        if labels:
+            _api_resource_configs_duplication_check(
+                'labels', labels, configuration)
+            configuration['labels'] = labels
 
         return self.run_with_configuration(configuration)
 
@@ -884,8 +925,7 @@ def run_load(self,
         # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.tableDefinitions.(key).sourceFormat
         if src_fmt_configs is None:
            src_fmt_configs = {}
-        if time_partitioning is None:
-            time_partitioning = {}
+
         source_format = source_format.upper()
         allowed_formats = [
             "CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS",
@@ -1163,10 +1203,6 @@ def run_table_delete(self, deletion_dataset_table,
         :type ignore_if_missing: boolean
         :return:
         """
-        if '.' not in deletion_dataset_table:
-            raise ValueError(
-                'Expected deletion_dataset_table name in the format of '
-                '<dataset>.<table>. Got: {}'.format(deletion_dataset_table))
         deletion_project, deletion_dataset, deletion_table = \
             _split_tablename(table_input=deletion_dataset_table,
                              default_project_id=self.project_id)
@@ -1532,6 +1568,12 @@ def _bq_cast(string_field, bq_type):
 
 
 def _split_tablename(table_input, default_project_id, var_name=None):
+
+    if '.' not in table_input:
+        raise ValueError(
+            'Expected deletion_dataset_table name in the format of '
+            '<dataset>.<table>. Got: {}'.format(table_input))
+
     if not default_project_id:
         raise ValueError("INTERNAL: No default project is specified")
 
@@ -1593,8 +1635,29 @@ def var_print(var_name):
 
 def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):
     # if it is a partitioned table ($ is in the table name) add partition load option
+
+    if time_partitioning_in is None:
+        time_partitioning_in = {}
+
     time_partitioning_out = {}
     if destination_dataset_table and '$' in destination_dataset_table:
         time_partitioning_out['type'] = 'DAY'
     time_partitioning_out.update(time_partitioning_in)
     return time_partitioning_out
+
+
+def _validate_value(key, value, expected_type):
+    """ function to check expected type and raise
+    error if type is not correct """
+    if not isinstance(value, expected_type):
+        raise TypeError("{} argument must have a type {} not {}".format(
+            key, expected_type, type(value)))
+
+
+def _api_resource_configs_duplication_check(key, value, config_dict):
+    if key in config_dict and value != config_dict[key]:
+        raise ValueError("Values of {param_name} param are duplicated. "
+                         "`api_resource_configs` contained {param_name} param "
+                         "in `query` config and {param_name} was also provided "
+                         "with arg to run_query() method. Please remove duplicates."
+                         .format(param_name=key))
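
A minimal sketch of how the new hook-level parameter might be exercised, assuming a working `bigquery_default` connection; the connection id, query, and config values below are illustrative, not part of this commit:

    from airflow.contrib.hooks.bigquery_hook import BigQueryHook

    # Assumed connection id; adjust to your environment.
    hook = BigQueryHook(bigquery_conn_id='bigquery_default', use_legacy_sql=False)
    cursor = hook.get_conn().cursor()

    # Extra keys for the Jobs API 'configuration' object are deep-copied into
    # the request body; a key supplied both here and as a run_query() argument
    # with a different value now trips _api_resource_configs_duplication_check().
    cursor.run_query(
        sql='SELECT 1',
        api_resource_configs={'query': {'useQueryCache': False}},
    )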

airflow/contrib/operators/bigquery_operator.py (+15 -3)
@@ -75,6 +75,13 @@ class BigQueryOperator(BaseOperator):
         (without incurring a charge). If unspecified, this will be
         set to your project default.
     :type maximum_bytes_billed: float
+    :param api_resource_configs: a dictionary that contains params
+        'configuration' applied for Google BigQuery Jobs API:
+        https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs
+        for example, {'query': {'useQueryCache': False}}. You could use it
+        if you need to provide some params that are not supported by
+        BigQueryOperator as arguments.
+    :type api_resource_configs: dict
     :param schema_update_options: Allows the schema of the destination
         table to be updated as a side effect of the load job.
     :type schema_update_options: tuple
@@ -116,7 +123,8 @@ def __init__(self,
                  query_params=None,
                  labels=None,
                  priority='INTERACTIVE',
-                 time_partitioning={},
+                 time_partitioning=None,
+                 api_resource_configs=None,
                  *args,
                  **kwargs):
         super(BigQueryOperator, self).__init__(*args, **kwargs)
@@ -138,7 +146,10 @@ def __init__(self,
         self.labels = labels
         self.bq_cursor = None
         self.priority = priority
-        self.time_partitioning = time_partitioning
+        if time_partitioning is None:
+            self.time_partitioning = {}
+        if api_resource_configs is None:
+            self.api_resource_configs = {}
 
         # TODO remove `bql` in Airflow 2.0
         if self.bql:
@@ -177,7 +188,8 @@ def execute(self, context):
             labels=self.labels,
             schema_update_options=self.schema_update_options,
             priority=self.priority,
-            time_partitioning=self.time_partitioning
+            time_partitioning=self.time_partitioning,
+            api_resource_configs=self.api_resource_configs,
         )
 
     def on_kill(self):
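
A hypothetical DAG snippet sketching how the new operator argument is meant to reach run_query(); the DAG id, task id, and query are placeholders, not part of this commit:

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.bigquery_operator import BigQueryOperator

    dag = DAG('example_bq_api_configs',
              start_date=datetime(2019, 1, 23),
              schedule_interval=None)

    # api_resource_configs is forwarded to BigQueryBaseCursor.run_query(),
    # where it is merged with the regular operator arguments.
    query_without_cache = BigQueryOperator(
        task_id='query_without_cache',
        sql='SELECT 1',
        use_legacy_sql=False,
        api_resource_configs={'query': {'useQueryCache': False}},
        dag=dag,
    )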
