
Commit f8d3e93

xnuinside authored and Alice Berard committed
[AIRFLOW-2845] Asserts in contrib package code are changed on raise ValueError and TypeError (apache#3690)
1 parent 6926b19 commit f8d3e93

10 files changed, +74 -58 lines changed
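
The point of the swap: `assert` statements are compiled out when Python runs with optimizations (`python -O`), so validation written as asserts can silently vanish, while an explicit `raise` always fires and lets callers catch a meaningful exception type. A standalone sketch of the difference, using the Databricks retry-limit check from this commit as the example:

def validate_with_assert(retry_limit):
    # Stripped entirely under `python -O`: no check is performed at all.
    assert retry_limit >= 1, 'Retry limit must be greater than equal to 1'


def validate_with_raise(retry_limit):
    # Enforced in every mode, and callers can catch ValueError specifically.
    if retry_limit < 1:
        raise ValueError('Retry limit must be greater than equal to 1')


for check in (validate_with_assert, validate_with_raise):
    try:
        check(0)
    except (AssertionError, ValueError) as err:
        print(type(err).__name__, err)
# Under `python -O` only the ValueError branch prints.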

airflow/contrib/hooks/bigquery_hook.py (+29 -26)

@@ -592,9 +592,11 @@ def run_query(self,
         }

         if destination_dataset_table:
-            assert '.' in destination_dataset_table, (
-                'Expected destination_dataset_table in the format of '
-                '<dataset>.<table>. Got: {}').format(destination_dataset_table)
+            if '.' not in destination_dataset_table:
+                raise ValueError(
+                    'Expected destination_dataset_table name in the format of '
+                    '<dataset>.<table>. Got: {}'.format(
+                        destination_dataset_table))
             destination_project, destination_dataset, destination_table = \
                 _split_tablename(table_input=destination_dataset_table,
                                  default_project_id=self.project_id)
@@ -610,7 +612,9 @@ def run_query(self,
             }
         })
         if udf_config:
-            assert isinstance(udf_config, list)
+            if not isinstance(udf_config, list):
+                raise TypeError("udf_config argument must have a type 'list'"
+                                " not {}".format(type(udf_config)))
             configuration['query'].update({
                 'userDefinedFunctionResources': udf_config
             })
@@ -1153,10 +1157,10 @@ def run_table_delete(self, deletion_dataset_table,
         :type ignore_if_missing: boolean
         :return:
         """
-
-        assert '.' in deletion_dataset_table, (
-            'Expected deletion_dataset_table in the format of '
-            '<dataset>.<table>. Got: {}').format(deletion_dataset_table)
+        if '.' not in deletion_dataset_table:
+            raise ValueError(
+                'Expected deletion_dataset_table name in the format of '
+                '<dataset>.<table>. Got: {}'.format(deletion_dataset_table))
         deletion_project, deletion_dataset, deletion_table = \
             _split_tablename(table_input=deletion_dataset_table,
                              default_project_id=self.project_id)
@@ -1284,14 +1288,10 @@ def run_grant_dataset_view_access(self,
             # if view is already in access, do nothing.
             self.log.info(
                 'Table %s:%s.%s already has authorized view access to %s:%s dataset.',
-                view_project, view_dataset, view_table, source_project,
-                source_dataset)
+                view_project, view_dataset, view_table, source_project, source_dataset)
             return source_dataset_resource

-    def delete_dataset(self,
-                       project_id,
-                       dataset_id
-                       ):
+    def delete_dataset(self, project_id, dataset_id):
         """
         Delete a dataset of Big query in your project.
         :param project_id: The name of the project where we have the dataset .
@@ -1308,9 +1308,8 @@ def delete_dataset(self,
             self.service.datasets().delete(
                 projectId=project_id,
                 datasetId=dataset_id).execute()
-
-            self.log.info('Dataset deleted successfully: In project %s Dataset %s',
-                          project_id, dataset_id)
+            self.log.info('Dataset deleted successfully: In project %s '
+                          'Dataset %s', project_id, dataset_id)

         except HttpError as err:
             raise AirflowException(
@@ -1518,14 +1517,17 @@ def _bq_cast(string_field, bq_type):
     elif bq_type == 'FLOAT' or bq_type == 'TIMESTAMP':
         return float(string_field)
     elif bq_type == 'BOOLEAN':
-        assert string_field in set(['true', 'false'])
+        if string_field not in ['true', 'false']:
+            raise ValueError("{} must have value 'true' or 'false'".format(
+                string_field))
         return string_field == 'true'
     else:
         return string_field


 def _split_tablename(table_input, default_project_id, var_name=None):
-    assert default_project_id is not None, "INTERNAL: No default project is specified"
+    if not default_project_id:
+        raise ValueError("INTERNAL: No default project is specified")

     def var_print(var_name):
         if var_name is None:
@@ -1537,7 +1539,6 @@ def var_print(var_name):
         raise Exception(('{var}Use either : or . to specify project '
                          'got {input}').format(
                              var=var_print(var_name), input=table_input))
-
     cmpt = table_input.rsplit(':', 1)
     project_id = None
     rest = table_input
@@ -1555,8 +1556,10 @@ def var_print(var_name):

     cmpt = rest.split('.')
     if len(cmpt) == 3:
-        assert project_id is None, ("{var}Use either : or . to specify project"
-                                    ).format(var=var_print(var_name))
+        if project_id:
+            raise ValueError(
+                "{var}Use either : or . to specify project".format(
+                    var=var_print(var_name)))
         project_id = cmpt[0]
         dataset_id = cmpt[1]
         table_id = cmpt[2]
@@ -1586,10 +1589,10 @@ def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):
     # if it is a partitioned table ($ is in the table name) add partition load option
     time_partitioning_out = {}
     if destination_dataset_table and '$' in destination_dataset_table:
-        assert not time_partitioning_in.get('field'), (
-            "Cannot specify field partition and partition name "
-            "(dataset.table$partition) at the same time"
-        )
+        if time_partitioning_in.get('field'):
+            raise ValueError(
+                "Cannot specify field partition and partition name"
+                "(dataset.table$partition) at the same time")
         time_partitioning_out['type'] = 'DAY'

     time_partitioning_out.update(time_partitioning_in)
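
For illustration, the module-level helpers changed above now surface plain ValueErrors that callers can catch; a small sketch (assuming Airflow with its BigQuery extras installed, so the import works):

from airflow.contrib.hooks.bigquery_hook import (
    _cleanse_time_partitioning, _split_tablename)

# Mixing a partition suffix with a partition field is rejected explicitly.
try:
    _cleanse_time_partitioning(
        'test.table$20170101',
        {'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000})
except ValueError as err:
    print(err)

# A missing default project is a ValueError rather than an AssertionError.
try:
    _split_tablename('dataset.table', default_project_id=None)
except ValueError as err:
    print(err)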

airflow/contrib/hooks/databricks_hook.py (+2 -1)

@@ -61,7 +61,8 @@ def __init__(
         self.databricks_conn_id = databricks_conn_id
         self.databricks_conn = self.get_connection(databricks_conn_id)
         self.timeout_seconds = timeout_seconds
-        assert retry_limit >= 1, 'Retry limit must be greater than equal to 1'
+        if retry_limit < 1:
+            raise ValueError('Retry limit must be greater than equal to 1')
         self.retry_limit = retry_limit

     def _parse_host(self, host):

airflow/contrib/hooks/gcp_dataflow_hook.py (+14 -10)

@@ -225,11 +225,11 @@ def label_formatter(labels_dict):
     def _build_dataflow_job_name(task_id, append_job_name=True):
         task_id = str(task_id).replace('_', '-')

-        assert re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", task_id), \
-            'Invalid job_name ({}); the name must consist of ' \
-            'only the characters [-a-z0-9], starting with a ' \
-            'letter and ending with a letter or number '.format(
-                task_id)
+        if not re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", task_id):
+            raise ValueError(
+                'Invalid job_name ({}); the name must consist of'
+                'only the characters [-a-z0-9], starting with a '
+                'letter and ending with a letter or number '.format(task_id))

         if append_job_name:
             job_name = task_id + "-" + str(uuid.uuid1())[:8]
@@ -238,7 +238,8 @@ def _build_dataflow_job_name(task_id, append_job_name=True):

         return job_name

-    def _build_cmd(self, task_id, variables, label_formatter):
+    @staticmethod
+    def _build_cmd(task_id, variables, label_formatter):
         command = ["--runner=DataflowRunner"]
         if variables is not None:
             for attr, value in variables.items():
@@ -250,7 +251,8 @@ def _build_cmd(self, task_id, variables, label_formatter):
                 command.append("--" + attr + "=" + value)
         return command

-    def _start_template_dataflow(self, name, variables, parameters, dataflow_template):
+    def _start_template_dataflow(self, name, variables, parameters,
+                                 dataflow_template):
         # Builds RuntimeEnvironment from variables dictionary
         # https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
         environment = {}
@@ -262,9 +264,11 @@ def _start_template_dataflow(self, name, variables, parameters, dataflow_templat
             "parameters": parameters,
             "environment": environment}
         service = self.get_conn()
-        request = service.projects().templates().launch(projectId=variables['project'],
-                                                        gcsPath=dataflow_template,
-                                                        body=body)
+        request = service.projects().templates().launch(
+            projectId=variables['project'],
+            gcsPath=dataflow_template,
+            body=body
+        )
         response = request.execute()
         variables = self._set_variables(variables)
         _DataflowJob(self.get_conn(), variables['project'], name, variables['region'],
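
The Dataflow job-name rule is just the regex in the hunk above; a standalone reproduction of the check (hypothetical helper, not the hook method, so it runs without a Dataflow connection):

import re


def check_dataflow_job_name(task_id):
    # Same rule as _build_dataflow_job_name: a lowercase letter first, then
    # only [-a-z0-9], ending with a letter or digit.
    task_id = str(task_id).replace('_', '-')
    if not re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", task_id):
        raise ValueError('Invalid job_name ({})'.format(task_id))
    return task_id


print(check_dataflow_job_name('df_job_1'))  # underscores become dashes: df-job-1
try:
    check_dataflow_job_name('1dfjob')  # starts with a digit, as in the tests below
except ValueError as err:
    print(err)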

airflow/contrib/hooks/gcp_mlengine_hook.py (+8 -3)

@@ -152,7 +152,8 @@ def _wait_for_job_done(self, project_id, job_id, interval=30):
             apiclient.errors.HttpError: if HTTP error is returned when getting
             the job
         """
-        assert interval > 0
+        if interval <= 0:
+            raise ValueError("Interval must be > 0")
         while True:
             job = self._get_job(project_id, job_id)
             if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
@@ -242,7 +243,9 @@ def create_model(self, project_id, model):
         """
         Create a Model. Blocks until finished.
         """
-        assert model['name'] is not None and model['name'] is not ''
+        if not model['name']:
+            raise ValueError("Model name must be provided and "
+                             "could not be an empty string")
         project = 'projects/{}'.format(project_id)

         request = self._mlengine.projects().models().create(
@@ -253,7 +256,9 @@ def get_model(self, project_id, model_name):
         """
         Gets a Model. Blocks until finished.
         """
-        assert model_name is not None and model_name is not ''
+        if not model_name:
+            raise ValueError("Model name must be provided and "
+                             "it could not be an empty string")
         full_model_name = 'projects/{}/models/{}'.format(
             project_id, model_name)
         request = self._mlengine.projects().models().get(name=full_model_name)
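
The ML Engine checks also move from identity comparisons (`is not ''` tests identity rather than equality) to a plain truthiness test, so `None` and `''` are rejected the same way. A minimal standalone sketch of the new behaviour (placeholder project id, not a call into the hook):

def validate_model_name(model_name):
    # Mirrors the get_model guard above: missing or empty names raise ValueError.
    if not model_name:
        raise ValueError("Model name must be provided and "
                         "it could not be an empty string")
    # 'example-project' is a placeholder, only used to show the resource path.
    return 'projects/{}/models/{}'.format('example-project', model_name)


for name in (None, '', 'census'):
    try:
        print(validate_model_name(name))
    except ValueError as err:
        print('rejected:', err)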

airflow/contrib/hooks/gcs_hook.py (+8 -7)

@@ -477,15 +477,16 @@ def create_bucket(self,

         self.log.info('Creating Bucket: %s; Location: %s; Storage Class: %s',
                       bucket_name, location, storage_class)
-        assert storage_class in storage_classes, \
-            'Invalid value ({}) passed to storage_class. Value should be ' \
-            'one of {}'.format(storage_class, storage_classes)
+        if storage_class not in storage_classes:
+            raise ValueError(
+                'Invalid value ({}) passed to storage_class. Value should be '
+                'one of {}'.format(storage_class, storage_classes))

-        assert re.match('[a-zA-Z0-9]+', bucket_name[0]), \
-            'Bucket names must start with a number or letter.'
+        if not re.match('[a-zA-Z0-9]+', bucket_name[0]):
+            raise ValueError('Bucket names must start with a number or letter.')

-        assert re.match('[a-zA-Z0-9]+', bucket_name[-1]), \
-            'Bucket names must end with a number or letter.'
+        if not re.match('[a-zA-Z0-9]+', bucket_name[-1]):
+            raise ValueError('Bucket names must end with a number or letter.')

         service = self.get_conn()
         bucket_resource = {
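
The bucket-name rules reduce to the two first/last-character regex checks shown above; a sketch that reproduces just that validation (not the hook itself, which also needs a GCP connection):

import re


def check_bucket_name(bucket_name):
    # Same first/last-character rules as create_bucket above.
    if not re.match('[a-zA-Z0-9]+', bucket_name[0]):
        raise ValueError('Bucket names must start with a number or letter.')
    if not re.match('[a-zA-Z0-9]+', bucket_name[-1]):
        raise ValueError('Bucket names must end with a number or letter.')
    return bucket_name


for name in ('/testing123', 'testing123/', 'testing123'):
    try:
        print('ok:', check_bucket_name(name))
    except ValueError as err:
        print('rejected:', err)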

airflow/contrib/operators/mlengine_operator.py (+3 -1)

@@ -427,7 +427,9 @@ def execute(self, context):
             gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)

         if self._operation == 'create':
-            assert self._version is not None
+            if not self._version:
+                raise ValueError("version attribute of {} could not "
+                                 "be empty".format(self.__class__.__name__))
             return hook.create_version(self._project_id, self._model_name,
                                        self._version)
         elif self._operation == 'set_default':

tests/contrib/hooks/test_bigquery_hook.py (+1 -1)

@@ -414,7 +414,7 @@ def test_extra_time_partitioning_options(self):
         self.assertEqual(tp_out, expect)

     def test_cant_add_dollar_and_field_name(self):
-        with self.assertRaises(AssertionError):
+        with self.assertRaises(ValueError):
             _cleanse_time_partitioning(
                 'test.teast$20170101',
                 {'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000}

tests/contrib/hooks/test_databricks_hook.py (+3 -3)

@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License. You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -110,7 +110,7 @@ def test_parse_host_with_scheme(self):
         self.assertEquals(host, HOST)

     def test_init_bad_retry_limit(self):
-        with self.assertRaises(AssertionError):
+        with self.assertRaises(ValueError):
             DatabricksHook(retry_limit = 0)

     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')

tests/contrib/hooks/test_gcp_dataflow_hook.py (+4 -4)

@@ -195,7 +195,7 @@ def test_invalid_dataflow_job_name(self):
         fixed_name = invalid_job_name.replace(
             '_', '-')

-        with self.assertRaises(AssertionError) as e:
+        with self.assertRaises(ValueError) as e:
             self.dataflow_hook._build_dataflow_job_name(
                 task_id=invalid_job_name, append_job_name=False
             )
@@ -222,19 +222,19 @@ def test_dataflow_job_regex_check(self):
         ), 'dfjob1')

         self.assertRaises(
-            AssertionError,
+            ValueError,
             self.dataflow_hook._build_dataflow_job_name,
             task_id='1dfjob', append_job_name=False
         )

         self.assertRaises(
-            AssertionError,
+            ValueError,
             self.dataflow_hook._build_dataflow_job_name,
             task_id='dfjob@', append_job_name=False
         )

         self.assertRaises(
-            AssertionError,
+            ValueError,
             self.dataflow_hook._build_dataflow_job_name,
             task_id='df^jo', append_job_name=False
         )

tests/contrib/hooks/test_gcs_hook.py (+2 -2)

@@ -66,14 +66,14 @@ class TestGCSBucket(unittest.TestCase):
     def test_bucket_name_value(self):

         bad_start_bucket_name = '/testing123'
-        with self.assertRaises(AssertionError):
+        with self.assertRaises(ValueError):

             gcs_hook.GoogleCloudStorageHook().create_bucket(
                 bucket_name=bad_start_bucket_name
             )

         bad_end_bucket_name = 'testing123/'
-        with self.assertRaises(AssertionError):
+        with self.assertRaises(ValueError):
             gcs_hook.GoogleCloudStorageHook().create_bucket(
                 bucket_name=bad_end_bucket_name
             )
