|
52 | 52 | 'node_type_id': 'r3.xlarge',
|
53 | 53 | 'num_workers': 1
|
54 | 54 | }
|
# Fixture constants shared by the cluster start/restart/terminate tests below.
CLUSTER_ID = 'cluster_id'  # dummy Databricks cluster id passed in the request payloads
RUN_ID = 1  # dummy job run id
HOST = 'xx.cloud.databricks.com'  # bare Databricks host, no scheme
HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'  # same host with explicit https scheme
|
@@ -93,6 +94,26 @@ def cancel_run_endpoint(host):
|
93 | 94 | return 'https://{}/api/2.0/jobs/runs/cancel'.format(host)
|
94 | 95 |
|
95 | 96 |
|
def start_cluster_endpoint(host):
    """
    Utility function to generate the start cluster endpoint given the host.
    """
    return 'https://{}/api/2.0/clusters/start'.format(host)
| 102 | + |
| 103 | + |
def restart_cluster_endpoint(host):
    """
    Utility function to generate the restart cluster endpoint given the host.
    """
    return 'https://{}/api/2.0/clusters/restart'.format(host)
| 109 | + |
| 110 | + |
def terminate_cluster_endpoint(host):
    """
    Utility function to generate the terminate cluster endpoint given the host.

    Note: terminating a cluster maps to the ``clusters/delete`` REST path in
    the Databricks 2.0 API.
    """
    return 'https://{}/api/2.0/clusters/delete'.format(host)
| 116 | + |
96 | 117 | def create_valid_response_mock(content):
|
97 | 118 | response = mock.MagicMock()
|
98 | 119 | response.json.return_value = content
|
@@ -293,6 +314,54 @@ def test_cancel_run(self, mock_requests):
|
293 | 314 | headers=USER_AGENT_HEADER,
|
294 | 315 | timeout=self.hook.timeout_seconds)
|
295 | 316 |
|
| 317 | + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') |
| 318 | + def test_start_cluster(self, mock_requests): |
| 319 | + mock_requests.codes.ok = 200 |
| 320 | + mock_requests.post.return_value.json.return_value = {} |
| 321 | + status_code_mock = mock.PropertyMock(return_value=200) |
| 322 | + type(mock_requests.post.return_value).status_code = status_code_mock |
| 323 | + |
| 324 | + self.hook.start_cluster({"cluster_id": CLUSTER_ID}) |
| 325 | + |
| 326 | + mock_requests.post.assert_called_once_with( |
| 327 | + start_cluster_endpoint(HOST), |
| 328 | + json={'cluster_id': CLUSTER_ID}, |
| 329 | + auth=(LOGIN, PASSWORD), |
| 330 | + headers=USER_AGENT_HEADER, |
| 331 | + timeout=self.hook.timeout_seconds) |
| 332 | + |
| 333 | + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') |
| 334 | + def test_restart_cluster(self, mock_requests): |
| 335 | + mock_requests.codes.ok = 200 |
| 336 | + mock_requests.post.return_value.json.return_value = {} |
| 337 | + status_code_mock = mock.PropertyMock(return_value=200) |
| 338 | + type(mock_requests.post.return_value).status_code = status_code_mock |
| 339 | + |
| 340 | + self.hook.restart_cluster({"cluster_id": CLUSTER_ID}) |
| 341 | + |
| 342 | + mock_requests.post.assert_called_once_with( |
| 343 | + restart_cluster_endpoint(HOST), |
| 344 | + json={'cluster_id': CLUSTER_ID}, |
| 345 | + auth=(LOGIN, PASSWORD), |
| 346 | + headers=USER_AGENT_HEADER, |
| 347 | + timeout=self.hook.timeout_seconds) |
| 348 | + |
| 349 | + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') |
| 350 | + def test_terminate_cluster(self, mock_requests): |
| 351 | + mock_requests.codes.ok = 200 |
| 352 | + mock_requests.post.return_value.json.return_value = {} |
| 353 | + status_code_mock = mock.PropertyMock(return_value=200) |
| 354 | + type(mock_requests.post.return_value).status_code = status_code_mock |
| 355 | + |
| 356 | + self.hook.terminate_cluster({"cluster_id": CLUSTER_ID}) |
| 357 | + |
| 358 | + mock_requests.post.assert_called_once_with( |
| 359 | + terminate_cluster_endpoint(HOST), |
| 360 | + json={'cluster_id': CLUSTER_ID}, |
| 361 | + auth=(LOGIN, PASSWORD), |
| 362 | + headers=USER_AGENT_HEADER, |
| 363 | + timeout=self.hook.timeout_seconds) |
| 364 | + |
296 | 365 |
|
297 | 366 | class DatabricksHookTokenTest(unittest.TestCase):
|
298 | 367 | """
|
|
0 commit comments