
Commit f99a027

xnuinside authored and kaxil committed
[AIRFLOW-461] Support autodetected schemas in BigQuery run_load (#3880)
1 parent 07fd6e1 commit f99a027

File tree: 4 files changed, +39 -14 lines

airflow/contrib/hooks/bigquery_hook.py

+14 -3
@@ -849,8 +849,8 @@ def run_copy(self,
 
     def run_load(self,
                  destination_project_dataset_table,
-                 schema_fields,
                  source_uris,
+                 schema_fields=None,
                  source_format='CSV',
                  create_disposition='CREATE_IF_NEEDED',
                  skip_leading_rows=0,
@@ -864,7 +864,8 @@ def run_load(self,
                  schema_update_options=(),
                  src_fmt_configs=None,
                  time_partitioning=None,
-                 cluster_fields=None):
+                 cluster_fields=None,
+                 autodetect=False):
         """
         Executes a BigQuery load command to load data from Google Cloud Storage
         to BigQuery. See here:
@@ -882,7 +883,11 @@ def run_load(self,
         :type destination_project_dataset_table: string
         :param schema_fields: The schema field list as defined here:
             https://cloud.google.com/bigquery/docs/reference/v2/jobs#configuration.load
+            Required if autodetect=False; optional if autodetect=True.
         :type schema_fields: list
+        :param autodetect: Attempt to autodetect the schema for CSV and JSON
+            source files.
+        :type autodetect: bool
         :param source_uris: The source Google Cloud
             Storage URI (e.g. gs://some-bucket/some-file.txt). A single wild
             per-object name can be used.
@@ -937,6 +942,11 @@ def run_load(self,
         # if it's not, we raise a ValueError
         # Refer to this link for more details:
         # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.tableDefinitions.(key).sourceFormat
+
+        if schema_fields is None and not autodetect:
+            raise ValueError(
+                'You must either pass a schema or autodetect=True.')
+
         if src_fmt_configs is None:
             src_fmt_configs = {}
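With this guard, run_load now fails fast when neither a schema nor autodetection is supplied. A minimal usage sketch against the new signature (connection id is Airflow's stock `bigquery_default`; the table and bucket names are placeholders, not part of this commit):

    from airflow.contrib.hooks.bigquery_hook import BigQueryHook

    cursor = BigQueryHook(bigquery_conn_id='bigquery_default').get_conn().cursor()

    # OK: schema_fields omitted, BigQuery infers the schema from the CSV files.
    cursor.run_load(
        destination_project_dataset_table='my_dataset.my_table',
        source_uris=['gs://my-bucket/data/part-*.csv'],
        autodetect=True)

    # Raises ValueError: 'You must either pass a schema or autodetect=True.'
    cursor.run_load(
        destination_project_dataset_table='my_dataset.my_table',
        source_uris=['gs://my-bucket/data/part-*.csv'])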

@@ -971,6 +981,7 @@ def run_load(self,
 
         configuration = {
             'load': {
+                'autodetect': autodetect,
                 'createDisposition': create_disposition,
                 'destinationTable': {
                     'projectId': destination_project,
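For reference, this is roughly the job payload the cursor then hands to run_with_configuration; a trimmed sketch with placeholder project, dataset, and table ids, showing only the keys visible in this diff plus the standard source keys:

    configuration = {
        'load': {
            'autodetect': True,                # the new flag
            'createDisposition': 'CREATE_IF_NEEDED',
            'destinationTable': {
                'projectId': 'my-project',     # placeholder
                'datasetId': 'my_dataset',     # placeholder
                'tableId': 'my_table',         # placeholder
            },
            'sourceFormat': 'CSV',
            'sourceUris': ['gs://my-bucket/data/part-*.csv'],
        },
    }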
@@ -1734,7 +1745,7 @@ def _split_tablename(table_input, default_project_id, var_name=None):
 
     if '.' not in table_input:
         raise ValueError(
-            'Expected deletion_dataset_table name in the format of '
+            'Expected target table name in the format of '
             '<dataset>.<table>. Got: {}'.format(table_input))
 
     if not default_project_id:
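The reworded message also fixes a long-standing oddity: _split_tablename is shared by every BigQuery code path, yet its error mentioned only deletion_dataset_table. A quick sketch of the failure it now reports cleanly (module-level function in this contrib hook, as shown above; names are placeholders):

    from airflow.contrib.hooks.bigquery_hook import _split_tablename

    # No <dataset>. qualifier, so this raises:
    # ValueError: Expected target table name in the format of
    # <dataset>.<table>. Got: my_table
    _split_tablename('my_table', default_project_id='my-project')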

airflow/contrib/operators/bigquery_operator.py

+2 -2
@@ -306,7 +306,7 @@ def __init__(self,
                  project_id=None,
                  schema_fields=None,
                  gcs_schema_object=None,
-                 time_partitioning={},
+                 time_partitioning=None,
                  bigquery_conn_id='bigquery_default',
                  google_cloud_storage_conn_id='google_cloud_default',
                  delegate_to=None,
@@ -323,7 +323,7 @@ def __init__(self,
         self.bigquery_conn_id = bigquery_conn_id
         self.google_cloud_storage_conn_id = google_cloud_storage_conn_id
         self.delegate_to = delegate_to
-        self.time_partitioning = time_partitioning
+        self.time_partitioning = {} if time_partitioning is None else time_partitioning
         self.labels = labels
 
     def execute(self, context):
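The time_partitioning change is the standard mutable-default-argument fix: a `{}` default is built once at function definition, so every operator instance that relied on it shared (and could mutate) the same dict. A standalone sketch of the hazard, not Airflow code itself:

    class Op(object):
        def __init__(self, time_partitioning={}):    # one dict shared by all calls
            self.time_partitioning = time_partitioning

    a, b = Op(), Op()
    a.time_partitioning['type'] = 'DAY'
    print(b.time_partitioning)   # {'type': 'DAY'} -- b sees a's mutation

    class SafeOp(object):
        def __init__(self, time_partitioning=None):  # the pattern this commit adopts
            self.time_partitioning = {} if time_partitioning is None else time_partitioning

    c, d = SafeOp(), SafeOp()
    c.time_partitioning['type'] = 'DAY'
    print(d.time_partitioning)   # {} -- each instance gets its own dict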

airflow/contrib/operators/gcs_to_bq.py

+15 -9
@@ -152,6 +152,7 @@ def __init__(self,
                  external_table=False,
                  time_partitioning=None,
                  cluster_fields=None,
+                 autodetect=False,
                  *args, **kwargs):
 
         super(GoogleCloudStorageToBigQueryOperator, self).__init__(*args, **kwargs)
@@ -190,20 +191,24 @@ def __init__(self,
         self.src_fmt_configs = src_fmt_configs
         self.time_partitioning = time_partitioning
         self.cluster_fields = cluster_fields
+        self.autodetect = autodetect
 
     def execute(self, context):
         bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
                                delegate_to=self.delegate_to)
 
-        if not self.schema_fields and \
-                self.schema_object and \
-                self.source_format != 'DATASTORE_BACKUP':
-            gcs_hook = GoogleCloudStorageHook(
-                google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
-                delegate_to=self.delegate_to)
-            schema_fields = json.loads(gcs_hook.download(
-                self.bucket,
-                self.schema_object).decode("utf-8"))
+        if not self.schema_fields:
+            if self.schema_object and self.source_format != 'DATASTORE_BACKUP':
+                gcs_hook = GoogleCloudStorageHook(
+                    google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
+                    delegate_to=self.delegate_to)
+                schema_fields = json.loads(gcs_hook.download(
+                    self.bucket,
+                    self.schema_object).decode("utf-8"))
+            elif self.schema_object is None and self.autodetect is False:
+                raise ValueError('At least one of `schema_fields`, `schema_object`, '
+                                 'or `autodetect` must be passed.')
+
         else:
             schema_fields = self.schema_fields

@@ -234,6 +239,7 @@ def execute(self, context):
             schema_fields=schema_fields,
             source_uris=source_uris,
             source_format=self.source_format,
+            autodetect=self.autodetect,
             create_disposition=self.create_disposition,
             skip_leading_rows=self.skip_leading_rows,
             write_disposition=self.write_disposition,
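End to end, the operator can now load CSV/JSON with no schema at all. A hedged DAG sketch (bucket, object, and table names are invented for illustration; the other arguments are the operator's existing ones):

    from airflow import DAG
    from airflow.contrib.operators.gcs_to_bq import \
        GoogleCloudStorageToBigQueryOperator
    from airflow.utils.dates import days_ago

    dag = DAG('gcs_to_bq_autodetect', schedule_interval=None,
              start_date=days_ago(1))

    load_events = GoogleCloudStorageToBigQueryOperator(
        task_id='load_events',
        bucket='my-bucket',                                      # placeholder
        source_objects=['data/part-*.csv'],                      # placeholder
        destination_project_dataset_table='my_dataset.events',   # placeholder
        source_format='CSV',
        autodetect=True,    # no schema_fields or schema_object needed
        write_disposition='WRITE_TRUNCATE',
        dag=dag)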

tests/contrib/hooks/test_bigquery_hook.py

+8 -0
@@ -443,6 +443,14 @@ def run_with_config(config):
 
         mocked_rwc.assert_called_once()
 
+    @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+    def test_run_with_auto_detect(self, run_with_config):
+        destination_project_dataset_table = "autodetect.table"
+        cursor = hook.BigQueryBaseCursor(mock.Mock(), "project_id")
+        cursor.run_load(destination_project_dataset_table, [], [], autodetect=True)
+        args, kwargs = run_with_config.call_args
+        self.assertIs(args[0]['load']['autodetect'], True)
+
     @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
     @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
     @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
