Skip to content

Commit e09b387

Browse files
wmorris75 authored and Fokko committed
[AIRFLOW-2974] Extended Databricks hook with clusters operation (#3817)
Add hooks for: - cluster start, - restart, - terminate. Add unit tests for the added hooks. Add hooks for cluster start, restart and terminate. Add unit tests for the added hooks. Add cluster_id variable for performing cluster operation tests.
1 parent f6a1a14 commit e09b387

File tree

3 files changed

+82
-1
lines changed

3 files changed

+82
-1
lines changed

airflow/contrib/hooks/databricks_hook.py

+12
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333
except ImportError:
3434
import urlparse
3535

36+
RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart")
37+
START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
38+
TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete")
3639

3740
SUBMIT_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/submit')
3841
GET_RUN_ENDPOINT = ('GET', 'api/2.0/jobs/runs/get')
@@ -189,6 +192,15 @@ def cancel_run(self, run_id):
189192
json = {'run_id': run_id}
190193
self._do_api_call(CANCEL_RUN_ENDPOINT, json)
191194

195+
def restart_cluster(self, json):
196+
self._do_api_call(RESTART_CLUSTER_ENDPOINT, json)
197+
198+
def start_cluster(self, json):
199+
self._do_api_call(START_CLUSTER_ENDPOINT, json)
200+
201+
def terminate_cluster(self, json):
202+
self._do_api_call(TERMINATE_CLUSTER_ENDPOINT, json)
203+
192204

193205
def _retryable_error(exception):
194206
return isinstance(exception, requests_exceptions.ConnectionError) \

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
1718
[metadata]
1819
name = Airflow
1920
summary = Airflow is a system to programmatically author, schedule and monitor data pipelines.
@@ -34,4 +35,3 @@ all_files = 1
3435
upload-dir = docs/_build/html
3536

3637
[easy_install]
37-

tests/contrib/hooks/test_databricks_hook.py

+69
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
'node_type_id': 'r3.xlarge',
5353
'num_workers': 1
5454
}
55+
CLUSTER_ID = 'cluster_id'
5556
RUN_ID = 1
5657
HOST = 'xx.cloud.databricks.com'
5758
HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
@@ -93,6 +94,26 @@ def cancel_run_endpoint(host):
9394
return 'https://{}/api/2.0/jobs/runs/cancel'.format(host)
9495

9596

97+
def start_cluster_endpoint(host):
98+
"""
99+
Utility function to generate the start cluster endpoint given the host.
100+
"""
101+
return 'https://{}/api/2.0/clusters/start'.format(host)
102+
103+
104+
def restart_cluster_endpoint(host):
105+
"""
106+
Utility function to generate the restart cluster endpoint given the host.
107+
"""
108+
return 'https://{}/api/2.0/clusters/restart'.format(host)
109+
110+
111+
def terminate_cluster_endpoint(host):
112+
"""
113+
Utility function to generate the terminate cluster endpoint given the host.
114+
"""
115+
return 'https://{}/api/2.0/clusters/delete'.format(host)
116+
96117
def create_valid_response_mock(content):
97118
response = mock.MagicMock()
98119
response.json.return_value = content
@@ -293,6 +314,54 @@ def test_cancel_run(self, mock_requests):
293314
headers=USER_AGENT_HEADER,
294315
timeout=self.hook.timeout_seconds)
295316

317+
@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
318+
def test_start_cluster(self, mock_requests):
319+
mock_requests.codes.ok = 200
320+
mock_requests.post.return_value.json.return_value = {}
321+
status_code_mock = mock.PropertyMock(return_value=200)
322+
type(mock_requests.post.return_value).status_code = status_code_mock
323+
324+
self.hook.start_cluster({"cluster_id": CLUSTER_ID})
325+
326+
mock_requests.post.assert_called_once_with(
327+
start_cluster_endpoint(HOST),
328+
json={'cluster_id': CLUSTER_ID},
329+
auth=(LOGIN, PASSWORD),
330+
headers=USER_AGENT_HEADER,
331+
timeout=self.hook.timeout_seconds)
332+
333+
@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
334+
def test_restart_cluster(self, mock_requests):
335+
mock_requests.codes.ok = 200
336+
mock_requests.post.return_value.json.return_value = {}
337+
status_code_mock = mock.PropertyMock(return_value=200)
338+
type(mock_requests.post.return_value).status_code = status_code_mock
339+
340+
self.hook.restart_cluster({"cluster_id": CLUSTER_ID})
341+
342+
mock_requests.post.assert_called_once_with(
343+
restart_cluster_endpoint(HOST),
344+
json={'cluster_id': CLUSTER_ID},
345+
auth=(LOGIN, PASSWORD),
346+
headers=USER_AGENT_HEADER,
347+
timeout=self.hook.timeout_seconds)
348+
349+
@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
350+
def test_terminate_cluster(self, mock_requests):
351+
mock_requests.codes.ok = 200
352+
mock_requests.post.return_value.json.return_value = {}
353+
status_code_mock = mock.PropertyMock(return_value=200)
354+
type(mock_requests.post.return_value).status_code = status_code_mock
355+
356+
self.hook.terminate_cluster({"cluster_id": CLUSTER_ID})
357+
358+
mock_requests.post.assert_called_once_with(
359+
terminate_cluster_endpoint(HOST),
360+
json={'cluster_id': CLUSTER_ID},
361+
auth=(LOGIN, PASSWORD),
362+
headers=USER_AGENT_HEADER,
363+
timeout=self.hook.timeout_seconds)
364+
296365

297366
class DatabricksHookTokenTest(unittest.TestCase):
298367
"""

0 commit comments

Comments
 (0)