Skip to content

Commit 82200bc

Browse files
exployjeffkpayne
authored andcommitted
[AIRFLOW-2797] Create Google Dataproc cluster with custom image (apache#3871)
1 parent eb163f7 commit 82200bc

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

airflow/contrib/operators/dataproc_operator.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,10 @@ class DataprocClusterCreateOperator(BaseOperator):
6767
to add to all instances
6868
:type metadata: dict
6969
:param image_version: the version of software inside the Dataproc cluster
70-
:type image_version: string
70+
:type image_version: str
71+
:param custom_image: custom Dataproc image for more info see
72+
https://cloud.google.com/dataproc/docs/guides/dataproc-images
73+
:type: custom_image: str
7174
:param properties: dict of properties to set on
7275
config files (e.g. spark-defaults.conf), see
7376
https://cloud.google.com/dataproc/docs/reference/rest/v1/ \
@@ -138,6 +141,7 @@ def __init__(self,
138141
init_actions_uris=None,
139142
init_action_timeout="10m",
140143
metadata=None,
144+
custom_image=None,
141145
image_version=None,
142146
properties=None,
143147
master_machine_type='n1-standard-4',
@@ -168,6 +172,7 @@ def __init__(self,
168172
self.init_actions_uris = init_actions_uris
169173
self.init_action_timeout = init_action_timeout
170174
self.metadata = metadata
175+
self.custom_image = custom_image
171176
self.image_version = image_version
172177
self.properties = properties
173178
self.master_machine_type = master_machine_type
@@ -187,6 +192,9 @@ def __init__(self,
187192
self.auto_delete_time = auto_delete_time
188193
self.auto_delete_ttl = auto_delete_ttl
189194

195+
assert not (self.custom_image and self.image_version), \
196+
"custom_image and image_version can't be both set"
197+
190198
def _get_cluster_list_for_project(self, service):
191199
result = service.projects().regions().clusters().list(
192200
projectId=self.project_id,
@@ -321,6 +329,12 @@ def _build_cluster_data(self):
321329
cluster_data['config']['gceClusterConfig']['tags'] = self.tags
322330
if self.image_version:
323331
cluster_data['config']['softwareConfig']['imageVersion'] = self.image_version
332+
elif self.custom_image:
333+
custom_image_url = 'https://www.googleapis.com/compute/beta/projects/' \
334+
'{}/global/images/{}'.format(self.project_id,
335+
self.custom_image)
336+
cluster_data['config']['masterConfig']['imageUri'] = custom_image_url
337+
cluster_data['config']['workerConfig']['imageUri'] = custom_image_url
324338
if self.properties:
325339
cluster_data['config']['softwareConfig']['properties'] = self.properties
326340
if self.idle_delete_ttl:

tests/contrib/operators/test_dataproc_operator.py

+34
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
TAGS = ['tag1', 'tag2']
6060
STORAGE_BUCKET = 'gs://airflow-test-bucket/'
6161
IMAGE_VERSION = '1.1'
62+
CUSTOM_IMAGE = 'test-custom-image'
6263
MASTER_MACHINE_TYPE = 'n1-standard-2'
6364
MASTER_DISK_SIZE = 100
6465
WORKER_MACHINE_TYPE = 'n1-standard-2'
@@ -258,6 +259,39 @@ def test_build_cluster_data_with_autoDeleteTime_and_autoDeleteTtl(self):
258259
self.assertEqual(cluster_data['config']['lifecycleConfig']['autoDeleteTime'],
259260
"2017-06-07T00:00:00.000000Z")
260261

262+
def test_init_with_image_version_and_custom_image_both_set(self):
263+
with self.assertRaises(AssertionError):
264+
DataprocClusterCreateOperator(
265+
task_id=TASK_ID,
266+
cluster_name=CLUSTER_NAME,
267+
project_id=PROJECT_ID,
268+
num_workers=NUM_WORKERS,
269+
zone=ZONE,
270+
dag=self.dag,
271+
image_version=IMAGE_VERSION,
272+
custom_image=CUSTOM_IMAGE
273+
)
274+
275+
def test_init_with_custom_image(self):
276+
dataproc_operator = DataprocClusterCreateOperator(
277+
task_id=TASK_ID,
278+
cluster_name=CLUSTER_NAME,
279+
project_id=PROJECT_ID,
280+
num_workers=NUM_WORKERS,
281+
zone=ZONE,
282+
dag=self.dag,
283+
custom_image=CUSTOM_IMAGE
284+
)
285+
286+
cluster_data = dataproc_operator._build_cluster_data()
287+
expected_custom_image_url = \
288+
'https://www.googleapis.com/compute/beta/projects/' \
289+
'{}/global/images/{}'.format(PROJECT_ID, CUSTOM_IMAGE)
290+
self.assertEqual(cluster_data['config']['masterConfig']['imageUri'],
291+
expected_custom_image_url)
292+
self.assertEqual(cluster_data['config']['workerConfig']['imageUri'],
293+
expected_custom_image_url)
294+
261295
def test_cluster_name_log_no_sub(self):
262296
with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') \
263297
as mock_hook:

0 commit comments

Comments
 (0)