[AIRFLOW-2845] Remove asserts from the contrib package #3690

Merged (1 commit, Aug 6, 2018)
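Context for the change, with a minimal sketch that is not part of the diff: CPython strips assert statements when the interpreter runs with optimizations enabled (python -O), so input validation written as an assert silently disappears in optimized deployments, and AssertionError is an awkward contract for callers to catch in any case. An explicit if/raise keeps the check in every mode:

    # demo.py -- hypothetical file illustrating the pattern this PR applies
    def set_retry_limit_with_assert(retry_limit):
        # Compiled away entirely under `python -O`; bad input passes through.
        assert retry_limit >= 1, 'Retry limit must be greater than or equal to 1'
        return retry_limit

    def set_retry_limit_with_raise(retry_limit):
        # Survives `python -O` and raises a conventional, catchable exception.
        if retry_limit < 1:
            raise ValueError('Retry limit must be greater than or equal to 1')
        return retry_limit

Under python demo.py both variants reject 0; under python -O demo.py the assert-based one returns 0 unchecked.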
55 changes: 29 additions & 26 deletions airflow/contrib/hooks/bigquery_hook.py
@@ -592,9 +592,11 @@ def run_query(self,
             }
 
         if destination_dataset_table:
-            assert '.' in destination_dataset_table, (
-                'Expected destination_dataset_table in the format of '
-                '<dataset>.<table>. Got: {}').format(destination_dataset_table)
+            if '.' not in destination_dataset_table:
+                raise ValueError(
+                    'Expected destination_dataset_table name in the format of '
+                    '<dataset>.<table>. Got: {}'.format(
+                        destination_dataset_table))
             destination_project, destination_dataset, destination_table = \
                 _split_tablename(table_input=destination_dataset_table,
                                  default_project_id=self.project_id)
@@ -610,7 +612,9 @@ def run_query(self,
                 }
             })
         if udf_config:
-            assert isinstance(udf_config, list)
+            if not isinstance(udf_config, list):
+                raise TypeError("udf_config argument must have a type 'list'"
+                                " not {}".format(type(udf_config)))
             configuration['query'].update({
                 'userDefinedFunctionResources': udf_config
             })
@@ -1153,10 +1157,10 @@ def run_table_delete(self, deletion_dataset_table,
         :type ignore_if_missing: boolean
         :return:
         """
-
-        assert '.' in deletion_dataset_table, (
-            'Expected deletion_dataset_table in the format of '
-            '<dataset>.<table>. Got: {}').format(deletion_dataset_table)
+        if '.' not in deletion_dataset_table:
+            raise ValueError(
+                'Expected deletion_dataset_table name in the format of '
+                '<dataset>.<table>. Got: {}'.format(deletion_dataset_table))
         deletion_project, deletion_dataset, deletion_table = \
             _split_tablename(table_input=deletion_dataset_table,
                              default_project_id=self.project_id)
@@ -1284,14 +1288,10 @@ def run_grant_dataset_view_access(self,
             # if view is already in access, do nothing.
             self.log.info(
                 'Table %s:%s.%s already has authorized view access to %s:%s dataset.',
-                view_project, view_dataset, view_table, source_project,
-                source_dataset)
+                view_project, view_dataset, view_table, source_project, source_dataset)
             return source_dataset_resource
 
-    def delete_dataset(self,
-                       project_id,
-                       dataset_id
-                       ):
+    def delete_dataset(self, project_id, dataset_id):
         """
         Delete a dataset of Big query in your project.
         :param project_id: The name of the project where we have the dataset .
@@ -1308,9 +1308,8 @@ def delete_dataset(self,
             self.service.datasets().delete(
                 projectId=project_id,
                 datasetId=dataset_id).execute()
-
-            self.log.info('Dataset deleted successfully: In project %s Dataset %s',
-                          project_id, dataset_id)
+            self.log.info('Dataset deleted successfully: In project %s '
+                          'Dataset %s', project_id, dataset_id)
 
         except HttpError as err:
             raise AirflowException(
@@ -1518,14 +1517,17 @@ def _bq_cast(string_field, bq_type):
     elif bq_type == 'FLOAT' or bq_type == 'TIMESTAMP':
         return float(string_field)
     elif bq_type == 'BOOLEAN':
-        assert string_field in set(['true', 'false'])
+        if string_field not in ['true', 'false']:
+            raise ValueError("{} must have value 'true' or 'false'".format(
+                string_field))
         return string_field == 'true'
     else:
         return string_field
 
 
 def _split_tablename(table_input, default_project_id, var_name=None):
-    assert default_project_id is not None, "INTERNAL: No default project is specified"
+    if not default_project_id:
+        raise ValueError("INTERNAL: No default project is specified")
 
     def var_print(var_name):
         if var_name is None:
@@ -1537,7 +1539,6 @@ def var_print(var_name):
         raise Exception(('{var}Use either : or . to specify project '
                          'got {input}').format(
             var=var_print(var_name), input=table_input))
-
     cmpt = table_input.rsplit(':', 1)
     project_id = None
     rest = table_input
@@ -1555,8 +1556,10 @@ def var_print(var_name):
 
     cmpt = rest.split('.')
     if len(cmpt) == 3:
-        assert project_id is None, ("{var}Use either : or . to specify project"
-                                    ).format(var=var_print(var_name))
+        if project_id:
+            raise ValueError(
+                "{var}Use either : or . to specify project".format(
+                    var=var_print(var_name)))
         project_id = cmpt[0]
         dataset_id = cmpt[1]
         table_id = cmpt[2]
@@ -1586,10 +1589,10 @@ def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):
     # if it is a partitioned table ($ is in the table name) add partition load option
     time_partitioning_out = {}
     if destination_dataset_table and '$' in destination_dataset_table:
-        assert not time_partitioning_in.get('field'), (
-            "Cannot specify field partition and partition name "
-            "(dataset.table$partition) at the same time"
-        )
+        if time_partitioning_in.get('field'):
+            raise ValueError(
+                "Cannot specify field partition and partition name "
+                "(dataset.table$partition) at the same time")
         time_partitioning_out['type'] = 'DAY'
 
     time_partitioning_out.update(time_partitioning_in)
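For illustration (not part of the diff): a self-contained sketch of the new BOOLEAN branch of _bq_cast, showing that bad input now surfaces as a catchable ValueError; the helper name below is hypothetical.

    def bq_cast_boolean(string_field):
        # Mirrors the BOOLEAN branch of _bq_cast after this change.
        if string_field not in ['true', 'false']:
            raise ValueError("{} must have value 'true' or 'false'".format(
                string_field))
        return string_field == 'true'

    print(bq_cast_boolean('true'))    # True
    try:
        bq_cast_boolean('TRUE')       # BigQuery serializes booleans lowercase
    except ValueError as err:
        print(err)                    # TRUE must have value 'true' or 'false'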
3 changes: 2 additions & 1 deletion airflow/contrib/hooks/databricks_hook.py
@@ -61,7 +61,8 @@ def __init__(
         self.databricks_conn_id = databricks_conn_id
         self.databricks_conn = self.get_connection(databricks_conn_id)
         self.timeout_seconds = timeout_seconds
-        assert retry_limit >= 1, 'Retry limit must be greater than equal to 1'
+        if retry_limit < 1:
+            raise ValueError('Retry limit must be greater than or equal to 1')
         self.retry_limit = retry_limit
 
     def _parse_host(self, host):
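A hedged usage sketch (it assumes a configured databricks_default connection, as in Airflow's test database; hook construction looks up the connection before the retry_limit check runs): a bad retry_limit is now a ValueError that calling code can handle with a standard except clause, where it was previously an AssertionError, or no error at all under python -O.

    from airflow.contrib.hooks.databricks_hook import DatabricksHook

    try:
        hook = DatabricksHook(retry_limit=0)
    except ValueError as err:
        print('rejected: {}'.format(err))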
24 changes: 14 additions & 10 deletions airflow/contrib/hooks/gcp_dataflow_hook.py
@@ -225,11 +225,11 @@ def label_formatter(labels_dict):
     def _build_dataflow_job_name(task_id, append_job_name=True):
         task_id = str(task_id).replace('_', '-')
 
-        assert re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", task_id), \
-            'Invalid job_name ({}); the name must consist of ' \
-            'only the characters [-a-z0-9], starting with a ' \
-            'letter and ending with a letter or number '.format(
-                task_id)
+        if not re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", task_id):
+            raise ValueError(
+                'Invalid job_name ({}); the name must consist of '
+                'only the characters [-a-z0-9], starting with a '
+                'letter and ending with a letter or number'.format(task_id))
 
         if append_job_name:
             job_name = task_id + "-" + str(uuid.uuid1())[:8]
@@ -238,7 +238,8 @@ def _build_dataflow_job_name(task_id, append_job_name=True):
 
         return job_name
 
-    def _build_cmd(self, task_id, variables, label_formatter):
+    @staticmethod
+    def _build_cmd(task_id, variables, label_formatter):
         command = ["--runner=DataflowRunner"]
         if variables is not None:
             for attr, value in variables.items():
@@ -250,7 +251,8 @@ def _build_cmd(self, task_id, variables, label_formatter):
                     command.append("--" + attr + "=" + value)
         return command
 
-    def _start_template_dataflow(self, name, variables, parameters, dataflow_template):
+    def _start_template_dataflow(self, name, variables, parameters,
+                                 dataflow_template):
         # Builds RuntimeEnvironment from variables dictionary
         # https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
         environment = {}
@@ -262,9 +264,11 @@ def _start_template_dataflow(self, name, variables, parameters, dataflow_templat
                 "parameters": parameters,
                 "environment": environment}
         service = self.get_conn()
-        request = service.projects().templates().launch(projectId=variables['project'],
-                                                        gcsPath=dataflow_template,
-                                                        body=body)
+        request = service.projects().templates().launch(
+            projectId=variables['project'],
+            gcsPath=dataflow_template,
+            body=body
+        )
         response = request.execute()
         variables = self._set_variables(variables)
         _DataflowJob(self.get_conn(), variables['project'], name, variables['region'],
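A self-contained sketch of the naming rule _build_dataflow_job_name enforces (a simplified module-level rewrite of the static method, for illustration only):

    import re
    import uuid

    def build_dataflow_job_name(task_id, append_job_name=True):
        # Dataflow job names: lowercase letters, digits and dashes, starting
        # with a letter and ending with a letter or digit.
        job_name = str(task_id).replace('_', '-')
        if not re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", job_name):
            raise ValueError('Invalid job_name ({})'.format(job_name))
        if append_job_name:
            job_name = job_name + '-' + str(uuid.uuid1())[:8]
        return job_name

    print(build_dataflow_job_name('my_task', append_job_name=False))  # my-task
    # build_dataflow_job_name('1dfjob')  raises ValueError: starts with a digit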
11 changes: 8 additions & 3 deletions airflow/contrib/hooks/gcp_mlengine_hook.py
@@ -152,7 +152,8 @@ def _wait_for_job_done(self, project_id, job_id, interval=30):
             apiclient.errors.HttpError: if HTTP error is returned when getting
             the job
         """
-        assert interval > 0
+        if interval <= 0:
+            raise ValueError("Interval must be > 0")
         while True:
             job = self._get_job(project_id, job_id)
             if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
@@ -242,7 +243,9 @@ def create_model(self, project_id, model):
         """
         Create a Model. Blocks until finished.
         """
-        assert model['name'] is not None and model['name'] is not ''
+        if not model['name']:
+            raise ValueError("Model name must be provided and "
+                             "cannot be an empty string")
         project = 'projects/{}'.format(project_id)
 
         request = self._mlengine.projects().models().create(
@@ -253,7 +256,9 @@ def get_model(self, project_id, model_name):
         """
         Gets a Model. Blocks until finished.
         """
-        assert model_name is not None and model_name is not ''
+        if not model_name:
+            raise ValueError("Model name must be provided and "
+                             "cannot be an empty string")
         full_model_name = 'projects/{}/models/{}'.format(
             project_id, model_name)
         request = self._mlengine.projects().models().get(name=full_model_name)
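Worth noting why the replacement is a truthiness test rather than a literal translation: the old `model_name is not ''` compared strings by identity, not equality (a pattern newer CPython versions flag with a SyntaxWarning), while `not model_name` reliably rejects both None and the empty string:

    for name in (None, '', 'my_model'):
        # `if not name` covers both missing cases the old assert tried to catch.
        print(repr(name), 'valid' if name else 'rejected')
    # None rejected
    # '' rejected
    # 'my_model' valid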
15 changes: 8 additions & 7 deletions airflow/contrib/hooks/gcs_hook.py
@@ -477,15 +477,16 @@ def create_bucket(self,
 
         self.log.info('Creating Bucket: %s; Location: %s; Storage Class: %s',
                       bucket_name, location, storage_class)
-        assert storage_class in storage_classes, \
-            'Invalid value ({}) passed to storage_class. Value should be ' \
-            'one of {}'.format(storage_class, storage_classes)
+        if storage_class not in storage_classes:
+            raise ValueError(
+                'Invalid value ({}) passed to storage_class. Value should be '
+                'one of {}'.format(storage_class, storage_classes))
 
-        assert re.match('[a-zA-Z0-9]+', bucket_name[0]), \
-            'Bucket names must start with a number or letter.'
+        if not re.match('[a-zA-Z0-9]+', bucket_name[0]):
+            raise ValueError('Bucket names must start with a number or letter.')
 
-        assert re.match('[a-zA-Z0-9]+', bucket_name[-1]), \
-            'Bucket names must end with a number or letter.'
+        if not re.match('[a-zA-Z0-9]+', bucket_name[-1]):
+            raise ValueError('Bucket names must end with a number or letter.')
 
         service = self.get_conn()
         bucket_resource = {
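A self-contained sketch of the create_bucket guards (the storage_classes list is assumed from the hook's surrounding code, which this hunk does not show):

    import re

    storage_classes = ['MULTI_REGIONAL', 'REGIONAL', 'STANDARD',
                       'NEARLINE', 'COLDLINE']  # assumed from the hook

    def validate_bucket_args(bucket_name, storage_class):
        if storage_class not in storage_classes:
            raise ValueError(
                'Invalid value ({}) passed to storage_class. Value should be '
                'one of {}'.format(storage_class, storage_classes))
        if not re.match('[a-zA-Z0-9]+', bucket_name[0]):
            raise ValueError('Bucket names must start with a number or letter.')
        if not re.match('[a-zA-Z0-9]+', bucket_name[-1]):
            raise ValueError('Bucket names must end with a number or letter.')

    validate_bucket_args('testing123', 'MULTI_REGIONAL')    # passes
    # validate_bucket_args('/testing123', 'MULTI_REGIONAL') raises ValueError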
4 changes: 3 additions & 1 deletion airflow/contrib/operators/mlengine_operator.py
@@ -427,7 +427,9 @@ def execute(self, context):
             gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
 
         if self._operation == 'create':
-            assert self._version is not None
+            if not self._version:
+                raise ValueError("version attribute of {} cannot "
+                                 "be empty".format(self.__class__.__name__))
             return hook.create_version(self._project_id, self._model_name,
                                        self._version)
         elif self._operation == 'set_default':
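The pattern in the added lines, sketched in isolation: embedding self.__class__.__name__ in the message means subclasses of the operator report themselves by name in task logs.

    class ExampleVersionOperator(object):
        # Hypothetical stand-in for the operator; only the guard is shown.
        _version = None

        def execute(self):
            if not self._version:
                raise ValueError("version attribute of {} cannot "
                                 "be empty".format(self.__class__.__name__))

    try:
        ExampleVersionOperator().execute()
    except ValueError as err:
        print(err)  # version attribute of ExampleVersionOperator cannot be empty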
2 changes: 1 addition & 1 deletion tests/contrib/hooks/test_bigquery_hook.py
@@ -414,7 +414,7 @@ def test_extra_time_partitioning_options(self):
         self.assertEqual(tp_out, expect)
 
     def test_cant_add_dollar_and_field_name(self):
-        with self.assertRaises(AssertionError):
+        with self.assertRaises(ValueError):
             _cleanse_time_partitioning(
                 'test.teast$20170101',
                 {'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000}
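Downstream suites that pinned the old behavior need the same one-line change; a sketch in the unittest style used here (table name and options are hypothetical):

    import unittest
    from airflow.contrib.hooks.bigquery_hook import _cleanse_time_partitioning

    class TestPartitioning(unittest.TestCase):
        def test_dollar_and_field_conflict(self):
            with self.assertRaises(ValueError):  # was: AssertionError
                _cleanse_time_partitioning(
                    'test.table$20170101',
                    {'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000})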
6 changes: 3 additions & 3 deletions tests/contrib/hooks/test_databricks_hook.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License. You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -110,7 +110,7 @@ def test_parse_host_with_scheme(self):
         self.assertEquals(host, HOST)
 
     def test_init_bad_retry_limit(self):
-        with self.assertRaises(AssertionError):
+        with self.assertRaises(ValueError):
             DatabricksHook(retry_limit = 0)
 
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
8 changes: 4 additions & 4 deletions tests/contrib/hooks/test_gcp_dataflow_hook.py
@@ -195,7 +195,7 @@ def test_invalid_dataflow_job_name(self):
         fixed_name = invalid_job_name.replace(
             '_', '-')
 
-        with self.assertRaises(AssertionError) as e:
+        with self.assertRaises(ValueError) as e:
             self.dataflow_hook._build_dataflow_job_name(
                 task_id=invalid_job_name, append_job_name=False
             )
@@ -222,19 +222,19 @@ def test_dataflow_job_regex_check(self):
             ), 'dfjob1')
 
         self.assertRaises(
-            AssertionError,
+            ValueError,
             self.dataflow_hook._build_dataflow_job_name,
             task_id='1dfjob', append_job_name=False
         )
 
         self.assertRaises(
-            AssertionError,
+            ValueError,
             self.dataflow_hook._build_dataflow_job_name,
             task_id='dfjob@', append_job_name=False
        )
 
         self.assertRaises(
-            AssertionError,
+            ValueError,
             self.dataflow_hook._build_dataflow_job_name,
             task_id='df^jo', append_job_name=False
         )
4 changes: 2 additions & 2 deletions tests/contrib/hooks/test_gcs_hook.py
@@ -66,14 +66,14 @@ class TestGCSBucket(unittest.TestCase):
     def test_bucket_name_value(self):
 
         bad_start_bucket_name = '/testing123'
-        with self.assertRaises(AssertionError):
+        with self.assertRaises(ValueError):
 
             gcs_hook.GoogleCloudStorageHook().create_bucket(
                 bucket_name=bad_start_bucket_name
             )
 
         bad_end_bucket_name = 'testing123/'
-        with self.assertRaises(AssertionError):
+        with self.assertRaises(ValueError):
             gcs_hook.GoogleCloudStorageHook().create_bucket(
                 bucket_name=bad_end_bucket_name
             )