Skip to content

Commit 079b0ee

Browse files
committed
[AIRFLOW-3197] EMRHook is missing new parameters of the AWS API (#4044)
Allow passing any params to the CreateJobFlow API, so that we don't have to stay up to date with AWS API changes.
1 parent 80d113e commit 079b0ee

File tree

4 files changed

+42
-24
lines changed

4 files changed

+42
-24
lines changed

UPDATING.md

+10
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,16 @@ then you need to change it like this
7373
def is_active(self):
7474
return self.active
7575

76+
### EMRHook now passes all of connection's extra to CreateJobFlow API
77+
78+
EMRHook.create_job_flow has been changed to pass all keys to the boto3 run_job_flow API, rather than
79+
just specific known keys, for greater flexibility.
80+
81+
However, prior to this release, the "emr_default" sample connection that was created had invalid
82+
configuration, so creating EMR clusters might fail until your connection is updated. (Ec2KeyName,
83+
Ec2SubnetId, TerminationProtected and KeepJobFlowAliveWhenNoSteps were all top-level keys when they
84+
should be inside the "Instances" dict)
85+
7686
## Airflow 1.10
7787

7888
Installation and upgrading requires setting `SLUGIFY_USES_TEXT_UNIDECODE=yes` in your environment or

airflow/contrib/hooks/emr_hook.py

+1-14
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,6 @@ def create_job_flow(self, job_flow_overrides):
5151
config = emr_conn.extra_dejson.copy()
5252
config.update(job_flow_overrides)
5353

54-
response = self.get_conn().run_job_flow(
55-
Name=config.get('Name'),
56-
LogUri=config.get('LogUri'),
57-
ReleaseLabel=config.get('ReleaseLabel'),
58-
Instances=config.get('Instances'),
59-
Steps=config.get('Steps', []),
60-
BootstrapActions=config.get('BootstrapActions', []),
61-
Applications=config.get('Applications'),
62-
Configurations=config.get('Configurations', []),
63-
VisibleToAllUsers=config.get('VisibleToAllUsers'),
64-
JobFlowRole=config.get('JobFlowRole'),
65-
ServiceRole=config.get('ServiceRole'),
66-
Tags=config.get('Tags')
67-
)
54+
response = self.get_conn().run_job_flow(**config)
6855

6956
return response

airflow/utils/db.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ def initdb(rbac=False):
225225
"LogUri": "s3://my-emr-log-bucket/default_job_flow_location",
226226
"ReleaseLabel": "emr-4.6.0",
227227
"Instances": {
228+
"Ec2KeyName": "mykey",
229+
"Ec2SubnetId": "somesubnet",
228230
"InstanceGroups": [
229231
{
230232
"Name": "Master nodes",
@@ -240,12 +242,10 @@ def initdb(rbac=False):
240242
"InstanceType": "r3.2xlarge",
241243
"InstanceCount": 1
242244
}
243-
]
245+
],
246+
"TerminationProtected": false,
247+
"KeepJobFlowAliveWhenNoSteps": false
244248
},
245-
"Ec2KeyName": "mykey",
246-
"KeepJobFlowAliveWhenNoSteps": false,
247-
"TerminationProtected": false,
248-
"Ec2SubnetId": "somesubnet",
249249
"Applications":[
250250
{ "Name": "Spark" }
251251
],

tests/contrib/hooks/test_emr_hook.py

+26-5
Original file line numberDiff line numberDiff line change
@@ -31,28 +31,49 @@
3131
mock_emr = None
3232

3333

34+
@unittest.skipIf(mock_emr is None, 'moto package not present')
3435
class TestEmrHook(unittest.TestCase):
3536
@mock_emr
3637
def setUp(self):
3738
configuration.load_test_config()
3839

39-
@unittest.skipIf(mock_emr is None, 'mock_emr package not present')
4040
@mock_emr
4141
def test_get_conn_returns_a_boto3_connection(self):
4242
hook = EmrHook(aws_conn_id='aws_default')
4343
self.assertIsNotNone(hook.get_conn().list_clusters())
4444

45-
@unittest.skipIf(mock_emr is None, 'mock_emr package not present')
4645
@mock_emr
4746
def test_create_job_flow_uses_the_emr_config_to_create_a_cluster(self):
4847
client = boto3.client('emr', region_name='us-east-1')
49-
if len(client.list_clusters()['Clusters']):
50-
raise ValueError('AWS not properly mocked')
5148

5249
hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default')
5350
cluster = hook.create_job_flow({'Name': 'test_cluster'})
5451

55-
self.assertEqual(client.list_clusters()['Clusters'][0]['Id'], cluster['JobFlowId'])
52+
self.assertEqual(client.list_clusters()['Clusters'][0]['Id'],
53+
cluster['JobFlowId'])
54+
55+
@mock_emr
56+
def test_create_job_flow_extra_args(self):
57+
"""
58+
Test that we can add extra arguments to the launch call.
59+
60+
This is useful for when AWS add new options, such as
61+
"SecurityConfiguration" so that we don't have to change our code
62+
"""
63+
client = boto3.client('emr', region_name='us-east-1')
64+
65+
hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default')
66+
# AmiVersion is really old and almost no one will use it anymore, but
67+
# it's one of the "optional" request params that moto supports - its
68+
# coverage of EMR isn't 100% it turns out.
69+
cluster = hook.create_job_flow({'Name': 'test_cluster',
70+
'ReleaseLabel': '',
71+
'AmiVersion': '3.2'})
72+
73+
cluster = client.describe_cluster(ClusterId=cluster['JobFlowId'])['Cluster']
74+
75+
# The AmiVersion comes back as {Requested,Running}AmiVersion fields.
76+
self.assertEqual(cluster['RequestedAmiVersion'], '3.2')
5677

5778
if __name__ == '__main__':
5879
unittest.main()

0 commit comments

Comments
 (0)