Skip to content

Commit 21a762c

Browse files
ashbAlice Berard
authored and
Alice Berard
committed
[AIRFLOW-3343] Update DockerOperator for Docker-py 3.0.0 API changes (apache#4187)
The API of `wait()` changed to return a dict, not just a number so this Operator wasn't actually working, but the tests were passing because the return was mocked in-correctly. I also removed `shm_size` from kwargs passed to BaseOperator to avoid the deprecation warning about unknown args.
1 parent 42a9ea1 commit 21a762c

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

airflow/operators/docker_operator.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ class DockerOperator(BaseOperator):
108108
:type xcom_all: bool
109109
:param docker_conn_id: ID of the Airflow connection to use
110110
:type docker_conn_id: str
111+
:param shm_size: Size of ``/dev/shm`` in bytes. The size must be
112+
greater than 0. If omitted uses system default.
113+
:type shm_size: int
111114
"""
112115
template_fields = ('command', 'environment',)
113116
template_ext = ('.sh', '.bash',)
@@ -139,6 +142,7 @@ def __init__(
139142
dns=None,
140143
dns_search=None,
141144
auto_remove=False,
145+
shm_size=None,
142146
*args,
143147
**kwargs):
144148

@@ -167,7 +171,7 @@ def __init__(
167171
self.xcom_push_flag = xcom_push
168172
self.xcom_all = xcom_all
169173
self.docker_conn_id = docker_conn_id
170-
self.shm_size = kwargs.get('shm_size')
174+
self.shm_size = shm_size
171175

172176
self.cli = None
173177
self.container = None
@@ -197,7 +201,7 @@ def execute(self, context):
197201
if self.force_pull or len(self.cli.images(name=self.image)) == 0:
198202
self.log.info('Pulling docker image %s', self.image)
199203
for l in self.cli.pull(self.image, stream=True):
200-
output = json.loads(l.decode('utf-8'))
204+
output = json.loads(l.decode('utf-8').strip())
201205
if 'status' in output:
202206
self.log.info("%s", output['status'])
203207

@@ -230,9 +234,9 @@ def execute(self, context):
230234
line = line.decode('utf-8')
231235
self.log.info(line)
232236

233-
exit_code = self.cli.wait(self.container['Id'])
234-
if exit_code != 0:
235-
raise AirflowException('docker container failed')
237+
result = self.cli.wait(self.container['Id'])
238+
if result['StatusCode'] != 0:
239+
raise AirflowException('docker container failed: ' + repr(result))
236240

237241
if self.xcom_push_flag:
238242
return self.cli.logs(container=self.container['Id']) \

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def write_version(filename=os.path.join(*['airflow',
174174
'sphinx-rtd-theme>=0.1.6',
175175
'Sphinx-PyPI-upload>=0.2.1'
176176
]
177-
docker = ['docker>=3.0.0']
177+
docker = ['docker~=3.0']
178178
druid = ['pydruid>=0.4.1']
179179
elasticsearch = [
180180
'elasticsearch>=5.0.0,<6.0.0',

tests/operators/test_docker_operator.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_execute(self, client_class_mock, mkdtemp_mock):
5151
client_mock.images.return_value = []
5252
client_mock.logs.return_value = ['container log']
5353
client_mock.pull.return_value = [b'{"status":"pull log"}']
54-
client_mock.wait.return_value = 0
54+
client_mock.wait.return_value = {"StatusCode": 0}
5555

5656
client_class_mock.return_value = client_mock
5757

@@ -97,7 +97,7 @@ def test_execute_tls(self, client_class_mock, tls_class_mock):
9797
client_mock.images.return_value = []
9898
client_mock.logs.return_value = []
9999
client_mock.pull.return_value = []
100-
client_mock.wait.return_value = 0
100+
client_mock.wait.return_value = {"StatusCode": 0}
101101

102102
client_class_mock.return_value = client_mock
103103
tls_mock = mock.Mock()
@@ -123,7 +123,7 @@ def test_execute_unicode_logs(self, client_class_mock):
123123
client_mock.images.return_value = []
124124
client_mock.logs.return_value = ['unicode container log 😁']
125125
client_mock.pull.return_value = []
126-
client_mock.wait.return_value = 0
126+
client_mock.wait.return_value = {"StatusCode": 0}
127127

128128
client_class_mock.return_value = client_mock
129129

@@ -145,7 +145,7 @@ def test_execute_container_fails(self, client_class_mock):
145145
client_mock.images.return_value = []
146146
client_mock.logs.return_value = []
147147
client_mock.pull.return_value = []
148-
client_mock.wait.return_value = 1
148+
client_mock.wait.return_value = {"StatusCode": 1}
149149

150150
client_class_mock.return_value = client_mock
151151

@@ -174,7 +174,7 @@ def test_execute_no_docker_conn_id_no_hook(self, operator_client_mock):
174174
client_mock.create_container.return_value = {'Id': 'some_id'}
175175
client_mock.logs.return_value = []
176176
client_mock.pull.return_value = []
177-
client_mock.wait.return_value = 0
177+
client_mock.wait.return_value = {"StatusCode": 0}
178178
operator_client_mock.return_value = client_mock
179179

180180
# Create the DockerOperator
@@ -209,7 +209,7 @@ def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock,
209209
client_mock.create_container.return_value = {'Id': 'some_id'}
210210
client_mock.logs.return_value = []
211211
client_mock.pull.return_value = []
212-
client_mock.wait.return_value = 0
212+
client_mock.wait.return_value = {"StatusCode": 0}
213213
operator_client_mock.return_value = client_mock
214214

215215
# Create the DockerOperator

0 commit comments

Comments
 (0)