
Commit b310d92

XD-DENG authored and Chris Fei committed
Fix DockerOperator & some operator test (apache#4049)
- For argument `image`, there is no need to explicitly append ":latest" when the tag is omitted: "latest" is used by default if no tag is provided, and this is handled by the `docker` package itself.
- The intermediate variable `cpu_shares` is not needed.
- Fix wrong usage of `cpu_shares` and `mem_limit`. Based on https://docker-py.readthedocs.io/en/stable/api.html#docker.api.container.ContainerApiMixin.create_host_config, they should be arguments of `self.cli.create_host_config()` rather than `APIClient.create_container()`.
- Rename the corresponding test script to ensure it can be discovered.
- Fix the test itself.
- Rename some other test scripts that were not named properly, which caused test discovery to miss them.
1 parent 39cc181 commit b310d92

10 files changed: +69 -89 lines
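The core of the fix is a docker-py API contract: resource limits such as `cpu_shares` and `mem_limit` are accepted by `create_host_config()`, not by `create_container()`, and the client itself treats a missing image tag as "latest". A minimal sketch of that contract follows (the socket URL, image name, and limit values are illustrative assumptions, not commit code):

import docker

cli = docker.APIClient(base_url='unix://var/run/docker.sock', version='auto')

# Resource limits live on the host config; recent docker-py versions reject
# cpu_shares/mem_limit when passed to create_container() directly.
host_config = cli.create_host_config(
    binds=['/host/path:/container/path'],
    cpu_shares=1024,      # relative CPU weight; 1024 is roughly one full share
    mem_limit='512m',     # int bytes or a unit string such as '512m'
)

# No need to append ':latest' by hand -- docker-py defaults the tag itself.
container = cli.create_container(
    image='ubuntu',
    command='env',
    host_config=host_config,
)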

airflow/operators/docker_operator.py

+9 -15
@@ -43,6 +43,7 @@ class DockerOperator(BaseOperator):
         be provided with the parameter ``docker_conn_id``.
 
     :param image: Docker image from which to create the container.
+        If image tag is omitted, "latest" will be used.
     :type image: str
     :param api_version: Remote API version. Set to ``auto`` to automatically
         detect the server's version.
@@ -62,7 +63,7 @@ class DockerOperator(BaseOperator):
     :type docker_url: str
     :param environment: Environment variables to set in the container. (templated)
     :type environment: dict
-    :param force_pull: Pull the docker image on every run. Default is false.
+    :param force_pull: Pull the docker image on every run. Default is False.
     :type force_pull: bool
    :param mem_limit: Maximum amount of memory the container can use.
        Either a float value, which represents the limit in bytes,
@@ -187,35 +188,28 @@ def execute(self, context):
             tls=tls_config
         )
 
-        if ':' not in self.image:
-            image = self.image + ':latest'
-        else:
-            image = self.image
-
-        if self.force_pull or len(self.cli.images(name=image)) == 0:
-            self.log.info('Pulling docker image %s', image)
-            for l in self.cli.pull(image, stream=True):
+        if self.force_pull or len(self.cli.images(name=self.image)) == 0:
+            self.log.info('Pulling docker image %s', self.image)
+            for l in self.cli.pull(self.image, stream=True):
                 output = json.loads(l.decode('utf-8'))
                 self.log.info("%s", output['status'])
 
-        cpu_shares = int(round(self.cpus * 1024))
-
         with TemporaryDirectory(prefix='airflowtmp') as host_tmp_dir:
             self.environment['AIRFLOW_TMP_DIR'] = self.tmp_dir
             self.volumes.append('{0}:{1}'.format(host_tmp_dir, self.tmp_dir))
 
             self.container = self.cli.create_container(
                 command=self.get_command(),
-                cpu_shares=cpu_shares,
                 environment=self.environment,
                 host_config=self.cli.create_host_config(
                     binds=self.volumes,
                     network_mode=self.network_mode,
                     shm_size=self.shm_size,
                     dns=self.dns,
-                    dns_search=self.dns_search),
-                image=image,
-                mem_limit=self.mem_limit,
+                    dns_search=self.dns_search,
+                    cpu_shares=int(round(self.cpus * 1024)),
+                    mem_limit=self.mem_limit),
+                image=self.image,
                 user=self.user,
                 working_dir=self.working_dir
             )

tests/operators/__init__.py

+2 -11
@@ -7,21 +7,12 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License. You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
-from .docker_operator import *
-from .subdag_operator import *
-from .operators import *
-from .hive_operator import *
-from .s3_to_hive_operator import *
-from .python_operator import *
-from .latest_only_operator import *
-
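With the wildcard re-exports gone from the package `__init__`, each test module has to be found by the runner on its own, which is why the scripts below are renamed to the `test_` prefix. A hypothetical spot check, assuming stock unittest discovery (Airflow's nose-based runner of the time keyed on the same `test` prefix):

import unittest

# unittest's default discovery pattern is 'test*.py', so a file named
# 'docker_operator.py' is silently skipped while 'test_docker_operator.py'
# is collected.
suite = unittest.defaultTestLoader.discover('tests/operators', pattern='test*.py')
print(suite.countTestCases())  # non-zero once the scripts follow the convention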
tests/operators/docker_operator.py → tests/operators/test_docker_operator.py (renamed)

+7 -4
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License. You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -64,20 +64,22 @@ def test_execute(self, client_class_mock, mkdtemp_mock):
         client_class_mock.assert_called_with(base_url='unix://var/run/docker.sock', tls=None,
                                              version='1.19')
 
-        client_mock.create_container.assert_called_with(command='env', cpu_shares=1024,
+        client_mock.create_container.assert_called_with(command='env',
                                                         environment={
                                                             'AIRFLOW_TMP_DIR': '/tmp/airflow',
                                                             'UNIT': 'TEST'
                                                         },
                                                         host_config=host_config,
                                                         image='ubuntu:latest',
-                                                        mem_limit=None, user=None,
+                                                        user=None,
                                                         working_dir='/container/path'
                                                         )
         client_mock.create_host_config.assert_called_with(binds=['/host/path:/container/path',
                                                                  '/mkdtemp:/tmp/airflow'],
                                                           network_mode='bridge',
                                                           shm_size=1000,
+                                                          cpu_shares=1024,
+                                                          mem_limit=None,
                                                           dns=None,
                                                           dns_search=None)
         client_mock.images.assert_called_with(name='ubuntu:latest')
@@ -236,5 +238,6 @@ def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock,
             'Image was not pulled using operator client'
         )
 
+
 if __name__ == "__main__":
     unittest.main()
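The assertion changes above mirror the operator change exactly, because `assert_called_with` compares the complete keyword set of the most recent call. A self-contained illustration (not commit code) of that behavior:

from unittest import mock

cli = mock.MagicMock()
cli.create_host_config(cpu_shares=1024, mem_limit=None)

# Passes: the keyword set matches the recorded call exactly.
cli.create_host_config.assert_called_with(cpu_shares=1024, mem_limit=None)

# Would raise AssertionError: a missing (or extra) keyword is a mismatch,
# so kwargs moved in the operator must move in the test as well.
# cli.create_host_config.assert_called_with(cpu_shares=1024)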

tests/operators/hive_operator.py → tests/operators/test_hive_operator.py (renamed)

+26 -27
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License. You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -22,11 +22,10 @@
 import datetime
 import os
 import unittest
-import mock
 import nose
-import six
 
-from airflow import DAG, configuration, operators
+from airflow import DAG, configuration
+import airflow.operators.hive_operator
 configuration.load_test_config()
 
 
@@ -61,7 +60,7 @@ def setUp(self):
 class HiveOperatorConfigTest(HiveEnvironmentTest):
 
     def test_hive_airflow_default_config_queue(self):
-        t = operators.hive_operator.HiveOperator(
+        t = airflow.operators.hive_operator.HiveOperator(
             task_id='test_default_config_queue',
             hql=self.hql,
             mapred_queue_priority='HIGH',
@@ -77,7 +76,7 @@ def test_hive_airflow_default_config_queue(self):
 
     def test_hive_airflow_default_config_queue_override(self):
         specific_mapred_queue = 'default'
-        t = operators.hive_operator.HiveOperator(
+        t = airflow.operators.hive_operator.HiveOperator(
             task_id='test_default_config_queue',
             hql=self.hql,
             mapred_queue=specific_mapred_queue,
@@ -92,15 +91,15 @@ class HiveOperatorTest(HiveEnvironmentTest):
 
     def test_hiveconf_jinja_translate(self):
         hql = "SELECT ${num_col} FROM ${hiveconf:table};"
-        t = operators.hive_operator.HiveOperator(
+        t = airflow.operators.hive_operator.HiveOperator(
             hiveconf_jinja_translate=True,
             task_id='dry_run_basic_hql', hql=hql, dag=self.dag)
         t.prepare_template()
         self.assertEqual(t.hql, "SELECT {{ num_col }} FROM {{ table }};")
 
     def test_hiveconf(self):
         hql = "SELECT * FROM ${hiveconf:table} PARTITION (${hiveconf:day});"
-        t = operators.hive_operator.HiveOperator(
+        t = airflow.operators.hive_operator.HiveOperator(
             hiveconfs={'table': 'static_babynames', 'day': '{{ ds }}'},
             task_id='dry_run_basic_hql', hql=hql, dag=self.dag)
         t.prepare_template()
@@ -117,13 +116,13 @@ def test_hiveconf(self):
 class HivePrestoTest(HiveEnvironmentTest):
 
     def test_hive(self):
-        t = operators.hive_operator.HiveOperator(
+        t = airflow.operators.hive_operator.HiveOperator(
             task_id='basic_hql', hql=self.hql, dag=self.dag)
         t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
               ignore_ti_state=True)
 
     def test_hive_queues(self):
-        t = operators.hive_operator.HiveOperator(
+        t = airflow.operators.hive_operator.HiveOperator(
             task_id='test_hive_queues', hql=self.hql,
             mapred_queue='default', mapred_queue_priority='HIGH',
             mapred_job_name='airflow.test_hive_queues',
@@ -132,12 +131,12 @@ def test_hive_queues(self):
             ignore_ti_state=True)
 
     def test_hive_dryrun(self):
-        t = operators.hive_operator.HiveOperator(
+        t = airflow.operators.hive_operator.HiveOperator(
             task_id='dry_run_basic_hql', hql=self.hql, dag=self.dag)
         t.dry_run()
 
     def test_beeline(self):
-        t = operators.hive_operator.HiveOperator(
+        t = airflow.operators.hive_operator.HiveOperator(
             task_id='beeline_hql', hive_cli_conn_id='beeline_default',
             hql=self.hql, dag=self.dag)
         t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
@@ -147,13 +146,13 @@ def test_presto(self):
         sql = """
         SELECT count(1) FROM airflow.static_babynames_partitioned;
         """
-        t = operators.presto_check_operator.PrestoCheckOperator(
+        t = airflow.operators.presto_check_operator.PrestoCheckOperator(
             task_id='presto_check', sql=sql, dag=self.dag)
         t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
               ignore_ti_state=True)
 
     def test_presto_to_mysql(self):
-        t = operators.presto_to_mysql.PrestoToMySqlTransfer(
+        t = airflow.operators.presto_to_mysql.PrestoToMySqlTransfer(
             task_id='presto_to_mysql_check',
             sql="""
             SELECT name, count(*) as ccount
@@ -167,15 +166,15 @@ def test_presto_to_mysql(self):
               ignore_ti_state=True)
 
     def test_hdfs_sensor(self):
-        t = operators.sensors.HdfsSensor(
+        t = airflow.operators.sensors.HdfsSensor(
             task_id='hdfs_sensor_check',
             filepath='hdfs://user/hive/warehouse/airflow.db/static_babynames',
             dag=self.dag)
         t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
               ignore_ti_state=True)
 
     def test_webhdfs_sensor(self):
-        t = operators.sensors.WebHdfsSensor(
+        t = airflow.operators.sensors.WebHdfsSensor(
             task_id='webhdfs_sensor_check',
             filepath='hdfs://user/hive/warehouse/airflow.db/static_babynames',
             timeout=120,
@@ -184,7 +183,7 @@ def test_webhdfs_sensor(self):
             ignore_ti_state=True)
 
     def test_sql_sensor(self):
-        t = operators.sensors.SqlSensor(
+        t = airflow.operators.sensors.SqlSensor(
             task_id='hdfs_sensor_check',
             conn_id='presto_default',
             sql="SELECT 'x' FROM airflow.static_babynames LIMIT 1;",
@@ -193,7 +192,7 @@
             ignore_ti_state=True)
 
     def test_hive_stats(self):
-        t = operators.hive_stats_operator.HiveStatsCollectionOperator(
+        t = airflow.operators.hive_stats_operator.HiveStatsCollectionOperator(
             task_id='hive_stats_check',
             table="airflow.static_babynames_partitioned",
             partition={'ds': DEFAULT_DATE_DS},
@@ -202,7 +201,7 @@
             ignore_ti_state=True)
 
     def test_named_hive_partition_sensor(self):
-        t = operators.sensors.NamedHivePartitionSensor(
+        t = airflow.operators.sensors.NamedHivePartitionSensor(
             task_id='hive_partition_check',
             partition_names=[
                 "airflow.static_babynames_partitioned/ds={{ds}}"
@@ -212,7 +211,7 @@
             ignore_ti_state=True)
 
     def test_named_hive_partition_sensor_succeeds_on_multiple_partitions(self):
-        t = operators.sensors.NamedHivePartitionSensor(
+        t = airflow.operators.sensors.NamedHivePartitionSensor(
             task_id='hive_partition_check',
             partition_names=[
                 "airflow.static_babynames_partitioned/ds={{ds}}",
@@ -223,15 +222,15 @@
             ignore_ti_state=True)
 
     def test_named_hive_partition_sensor_parses_partitions_with_periods(self):
-        t = operators.sensors.NamedHivePartitionSensor.parse_partition_name(
+        t = airflow.operators.sensors.NamedHivePartitionSensor.parse_partition_name(
             partition="schema.table/part1=this.can.be.an.issue/part2=ok")
         self.assertEqual(t[0], "schema")
         self.assertEqual(t[1], "table")
         self.assertEqual(t[2], "part1=this.can.be.an.issue/part2=this_should_be_ok")
 
     @nose.tools.raises(airflow.exceptions.AirflowSensorTimeout)
     def test_named_hive_partition_sensor_times_out_on_nonexistent_partition(self):
-        t = operators.sensors.NamedHivePartitionSensor(
+        t = airflow.operators.sensors.NamedHivePartitionSensor(
             task_id='hive_partition_check',
             partition_names=[
                 "airflow.static_babynames_partitioned/ds={{ds}}",
@@ -244,15 +243,15 @@
             ignore_ti_state=True)
 
     def test_hive_partition_sensor(self):
-        t = operators.sensors.HivePartitionSensor(
+        t = airflow.operators.sensors.HivePartitionSensor(
             task_id='hive_partition_check',
             table='airflow.static_babynames_partitioned',
             dag=self.dag)
         t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
               ignore_ti_state=True)
 
     def test_hive_metastore_sql_sensor(self):
-        t = operators.sensors.MetastorePartitionSensor(
+        t = airflow.operators.sensors.MetastorePartitionSensor(
             task_id='hive_partition_check',
             table='airflow.static_babynames_partitioned',
             partition_name='ds={}'.format(DEFAULT_DATE_DS),
@@ -261,7 +260,7 @@
             ignore_ti_state=True)
 
     def test_hive2samba(self):
-        t = operators.hive_to_samba_operator.Hive2SambaOperator(
+        t = airflow.operators.hive_to_samba_operator.Hive2SambaOperator(
             task_id='hive2samba_check',
             samba_conn_id='tableau_samba',
             hql="SELECT * FROM airflow.static_babynames LIMIT 10000",
@@ -271,7 +270,7 @@
             ignore_ti_state=True)
 
     def test_hive_to_mysql(self):
-        t = operators.hive_to_mysql.HiveToMySqlTransfer(
+        t = airflow.operators.hive_to_mysql.HiveToMySqlTransfer(
             mysql_conn_id='airflow_db',
             task_id='hive_to_mysql_check',
             create=True,
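The repeated substitution above replaces the package-attribute style `operators.hive_operator...`, which only resolves if the submodule has already been imported somewhere as a side effect, with explicit module imports that work regardless of import order. A minimal sketch of the new style (module path per the 1.10-era layout; the `task_id` and `hql` values are placeholders, not commit code):

import airflow.operators.hive_operator

# Instantiation only; actually running the task would additionally need a DAG
# and a Hive connection, as prepared in the test's setUp().
t = airflow.operators.hive_operator.HiveOperator(
    task_id='hql_example',
    hql="SELECT 1;",
)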