Skip to content

Commit e768eb2

Browse files
dlebechwayne.morris
authored and
wayne.morris
committed
[AIRFLOW-3049] Add extra operations for Mongo hook (apache#3890)
This commit adds update, replace and delete operations for the Mongo hook.
1 parent 79ae2ec commit e768eb2

File tree

2 files changed

+308
-1
lines changed

2 files changed

+308
-1
lines changed

airflow/contrib/hooks/mongo_hook.py

+156-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ssl import CERT_NONE
1515

1616
from airflow.hooks.base_hook import BaseHook
17-
from pymongo import MongoClient
17+
from pymongo import MongoClient, ReplaceOne
1818

1919

2020
class MongoHook(BaseHook):
@@ -130,3 +130,158 @@ def insert_many(self, mongo_collection, docs, mongo_db=None, **kwargs):
130130
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)
131131

132132
return collection.insert_many(docs, **kwargs)
133+
134+
def update_one(self, mongo_collection, filter_doc, update_doc,
135+
mongo_db=None, **kwargs):
136+
"""
137+
Updates a single document in a mongo collection.
138+
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.update_one
139+
140+
:param mongo_collection: The name of the collection to update.
141+
:type mongo_collection: str
142+
:param filter_doc: A query that matches the documents to update.
143+
:type filter_doc: dict
144+
:param update_doc: The modifications to apply.
145+
:type update_doc: dict
146+
:param mongo_db: The name of the database to use.
147+
Can be omitted; then the database from the connection string is used.
148+
:type mongo_db: str
149+
150+
"""
151+
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)
152+
153+
return collection.update_one(filter_doc, update_doc, **kwargs)
154+
155+
def update_many(self, mongo_collection, filter_doc, update_doc,
156+
mongo_db=None, **kwargs):
157+
"""
158+
Updates one or more documents in a mongo collection.
159+
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.update_many
160+
161+
:param mongo_collection: The name of the collection to update.
162+
:type mongo_collection: str
163+
:param filter_doc: A query that matches the documents to update.
164+
:type filter_doc: dict
165+
:param update_doc: The modifications to apply.
166+
:type update_doc: dict
167+
:param mongo_db: The name of the database to use.
168+
Can be omitted; then the database from the connection string is used.
169+
:type mongo_db: str
170+
171+
"""
172+
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)
173+
174+
return collection.update_many(filter_doc, update_doc, **kwargs)
175+
176+
def replace_one(self, mongo_collection, doc, filter_doc=None,
177+
mongo_db=None, **kwargs):
178+
"""
179+
Replaces a single document in a mongo collection.
180+
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.replace_one
181+
182+
.. note::
183+
If no ``filter_doc`` is given, it is assumed that the replacement
184+
document contain the ``_id`` field which is then used as filters.
185+
186+
:param mongo_collection: The name of the collection to update.
187+
:type mongo_collection: str
188+
:param doc: The new document.
189+
:type doc: dict
190+
:param filter_doc: A query that matches the documents to replace.
191+
Can be omitted; then the _id field from doc will be used.
192+
:type filter_doc: dict
193+
:param mongo_db: The name of the database to use.
194+
Can be omitted; then the database from the connection string is used.
195+
:type mongo_db: str
196+
"""
197+
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)
198+
199+
if not filter_doc:
200+
filter_doc = {'_id': doc['_id']}
201+
202+
return collection.replace_one(filter_doc, doc, **kwargs)
203+
204+
def replace_many(self, mongo_collection, docs,
205+
filter_docs=None, mongo_db=None, upsert=False, collation=None,
206+
**kwargs):
207+
"""
208+
Replaces many documents in a mongo collection.
209+
210+
Uses bulk_write with multiple ReplaceOne operations
211+
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.bulk_write
212+
213+
.. note::
214+
If no ``filter_docs``are given, it is assumed that all
215+
replacement documents contain the ``_id`` field which are then
216+
used as filters.
217+
218+
:param mongo_collection: The name of the collection to update.
219+
:type mongo_collection: str
220+
:param docs: The new documents.
221+
:type docs: list(dict)
222+
:param filter_docs: A list of queries that match the documents to replace.
223+
Can be omitted; then the _id fields from docs will be used.
224+
:type filter_docs: list(dict)
225+
:param mongo_db: The name of the database to use.
226+
Can be omitted; then the database from the connection string is used.
227+
:type mongo_db: str
228+
:param upsert: If ``True``, perform an insert if no documents
229+
match the filters for the replace operation.
230+
:type upsert: bool
231+
:param collation: An instance of
232+
:class:`~pymongo.collation.Collation`. This option is only
233+
supported on MongoDB 3.4 and above.
234+
:type collation: :class:`~pymongo.collation.Collation`
235+
236+
"""
237+
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)
238+
239+
if not filter_docs:
240+
filter_docs = [{'_id': doc['_id']} for doc in docs]
241+
242+
requests = [
243+
ReplaceOne(
244+
filter_docs[i],
245+
docs[i],
246+
upsert=upsert,
247+
collation=collation)
248+
for i in range(len(docs))
249+
]
250+
251+
return collection.bulk_write(requests, **kwargs)
252+
253+
def delete_one(self, mongo_collection, filter_doc, mongo_db=None, **kwargs):
254+
"""
255+
Deletes a single document in a mongo collection.
256+
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.delete_one
257+
258+
:param mongo_collection: The name of the collection to delete from.
259+
:type mongo_collection: str
260+
:param filter_doc: A query that matches the document to delete.
261+
:type filter_doc: dict
262+
:param mongo_db: The name of the database to use.
263+
Can be omitted; then the database from the connection string is used.
264+
:type mongo_db: str
265+
266+
"""
267+
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)
268+
269+
return collection.delete_one(filter_doc, **kwargs)
270+
271+
def delete_many(self, mongo_collection, filter_doc, mongo_db=None, **kwargs):
272+
"""
273+
Deletes one or more documents in a mongo collection.
274+
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.delete_many
275+
276+
:param mongo_collection: The name of the collection to delete from.
277+
:type mongo_collection: str
278+
:param filter_doc: A query that matches the documents to delete.
279+
:type filter_doc: dict
280+
:param mongo_db: The name of the database to use.
281+
Can be omitted; then the database from the connection string is used.
282+
:type mongo_db: str
283+
284+
"""
285+
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)
286+
287+
return collection.delete_many(filter_doc, **kwargs)

tests/contrib/hooks/test_mongo_hook.py

+152
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,158 @@ def test_insert_many(self):
6969
result_objs = [result for result in result_objs]
7070
self.assertEqual(len(result_objs), 2)
7171

72+
@unittest.skipIf(mongomock is None, 'mongomock package not present')
73+
def test_update_one(self):
74+
collection = mongomock.MongoClient().db.collection
75+
obj = {'_id': '1', 'field': 0}
76+
collection.insert_one(obj)
77+
78+
filter_doc = obj
79+
update_doc = {'$inc': {'field': 123}}
80+
81+
self.hook.update_one(collection, filter_doc, update_doc)
82+
83+
result_obj = collection.find_one(filter='1')
84+
self.assertEqual(123, result_obj['field'])
85+
86+
@unittest.skipIf(mongomock is None, 'mongomock package not present')
87+
def test_update_one_with_upsert(self):
88+
collection = mongomock.MongoClient().db.collection
89+
90+
filter_doc = {'_id': '1', 'field': 0}
91+
update_doc = {'$inc': {'field': 123}}
92+
93+
self.hook.update_one(collection, filter_doc, update_doc, upsert=True)
94+
95+
result_obj = collection.find_one(filter='1')
96+
self.assertEqual(123, result_obj['field'])
97+
98+
@unittest.skipIf(mongomock is None, 'mongomock package not present')
99+
def test_update_many(self):
100+
collection = mongomock.MongoClient().db.collection
101+
obj1 = {'_id': '1', 'field': 0}
102+
obj2 = {'_id': '2', 'field': 0}
103+
collection.insert_many([obj1, obj2])
104+
105+
filter_doc = {'field': 0}
106+
update_doc = {'$inc': {'field': 123}}
107+
108+
self.hook.update_many(collection, filter_doc, update_doc)
109+
110+
result_obj = collection.find_one(filter='1')
111+
self.assertEqual(123, result_obj['field'])
112+
113+
result_obj = collection.find_one(filter='2')
114+
self.assertEqual(123, result_obj['field'])
115+
116+
@unittest.skipIf(mongomock is None, 'mongomock package not present')
117+
def test_update_many_with_upsert(self):
118+
collection = mongomock.MongoClient().db.collection
119+
120+
filter_doc = {'_id': '1', 'field': 0}
121+
update_doc = {'$inc': {'field': 123}}
122+
123+
self.hook.update_many(collection, filter_doc, update_doc, upsert=True)
124+
125+
result_obj = collection.find_one(filter='1')
126+
self.assertEqual(123, result_obj['field'])
127+
128+
@unittest.skipIf(mongomock is None, 'mongomock package not present')
129+
def test_replace_one(self):
130+
collection = mongomock.MongoClient().db.collection
131+
obj1 = {'_id': '1', 'field': 'test_value_1'}
132+
obj2 = {'_id': '2', 'field': 'test_value_2'}
133+
collection.insert_many([obj1, obj2])
134+
135+
obj1['field'] = 'test_value_1_updated'
136+
self.hook.replace_one(collection, obj1)
137+
138+
result_obj = collection.find_one(filter='1')
139+
self.assertEqual('test_value_1_updated', result_obj['field'])
140+
141+
# Other document should stay intact
142+
result_obj = collection.find_one(filter='2')
143+
self.assertEqual('test_value_2', result_obj['field'])
144+
145+
@unittest.skipIf(mongomock is None, 'mongomock package not present')
146+
def test_replace_one_with_filter(self):
147+
collection = mongomock.MongoClient().db.collection
148+
obj1 = {'_id': '1', 'field': 'test_value_1'}
149+
obj2 = {'_id': '2', 'field': 'test_value_2'}
150+
collection.insert_many([obj1, obj2])
151+
152+
obj1['field'] = 'test_value_1_updated'
153+
self.hook.replace_one(collection, obj1, {'field': 'test_value_1'})
154+
155+
result_obj = collection.find_one(filter='1')
156+
self.assertEqual('test_value_1_updated', result_obj['field'])
157+
158+
# Other document should stay intact
159+
result_obj = collection.find_one(filter='2')
160+
self.assertEqual('test_value_2', result_obj['field'])
161+
162+
@unittest.skipIf(mongomock is None, 'mongomock package not present')
163+
def test_replace_one_with_upsert(self):
164+
collection = mongomock.MongoClient().db.collection
165+
166+
obj = {'_id': '1', 'field': 'test_value_1'}
167+
self.hook.replace_one(collection, obj, upsert=True)
168+
169+
result_obj = collection.find_one(filter='1')
170+
self.assertEqual('test_value_1', result_obj['field'])
171+
172+
@unittest.skipIf(mongomock is None, 'mongomock package not present')
173+
def test_replace_many(self):
174+
collection = mongomock.MongoClient().db.collection
175+
obj1 = {'_id': '1', 'field': 'test_value_1'}
176+
obj2 = {'_id': '2', 'field': 'test_value_2'}
177+
collection.insert_many([obj1, obj2])
178+
179+
obj1['field'] = 'test_value_1_updated'
180+
obj2['field'] = 'test_value_2_updated'
181+
self.hook.replace_many(collection, [obj1, obj2])
182+
183+
result_obj = collection.find_one(filter='1')
184+
self.assertEqual('test_value_1_updated', result_obj['field'])
185+
186+
result_obj = collection.find_one(filter='2')
187+
self.assertEqual('test_value_2_updated', result_obj['field'])
188+
189+
@unittest.skipIf(mongomock is None, 'mongomock package not present')
190+
def test_replace_many_with_upsert(self):
191+
collection = mongomock.MongoClient().db.collection
192+
obj1 = {'_id': '1', 'field': 'test_value_1'}
193+
obj2 = {'_id': '2', 'field': 'test_value_2'}
194+
195+
self.hook.replace_many(collection, [obj1, obj2], upsert=True)
196+
197+
result_obj = collection.find_one(filter='1')
198+
self.assertEqual('test_value_1', result_obj['field'])
199+
200+
result_obj = collection.find_one(filter='2')
201+
self.assertEqual('test_value_2', result_obj['field'])
202+
203+
@unittest.skipIf(mongomock is None, 'mongomock package not present')
204+
def test_delete_one(self):
205+
collection = mongomock.MongoClient().db.collection
206+
obj = {'_id': '1'}
207+
collection.insert_one(obj)
208+
209+
self.hook.delete_one(collection, {'_id': '1'})
210+
211+
self.assertEqual(0, collection.count())
212+
213+
@unittest.skipIf(mongomock is None, 'mongomock package not present')
214+
def test_delete_many(self):
215+
collection = mongomock.MongoClient().db.collection
216+
obj1 = {'_id': '1', 'field': 'value'}
217+
obj2 = {'_id': '2', 'field': 'value'}
218+
collection.insert_many([obj1, obj2])
219+
220+
self.hook.delete_many(collection, {'field': 'value'})
221+
222+
self.assertEqual(0, collection.count())
223+
72224
@unittest.skipIf(mongomock is None, 'mongomock package not present')
73225
def test_find_one(self):
74226
collection = mongomock.MongoClient().db.collection

0 commit comments

Comments
 (0)