Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIRFLOW-3049] Add extra operations for Mongo hook #3890

Merged
merged 2 commits into from
Oct 29, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 145 additions & 1 deletion airflow/contrib/hooks/mongo_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ssl import CERT_NONE

from airflow.hooks.base_hook import BaseHook
from pymongo import MongoClient
from pymongo import MongoClient, ReplaceOne


class MongoHook(BaseHook):
Expand Down Expand Up @@ -130,3 +130,147 @@ def insert_many(self, mongo_collection, docs, mongo_db=None, **kwargs):
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)

return collection.insert_many(docs, **kwargs)

def update_one(self, mongo_collection, filter_doc, update_doc,
mongo_db=None, **kwargs):
"""
Updates a single document in a mongo collection.
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.update_one

:param mongo_collection: The name of the collection to update.
:type mongo_collection: str
:param filter_doc: A query that matches the documents to update.
:type filter_doc: dict
:param update_doc: The modifications to apply.
:type update_doc: dict
:param mongo_db: The name of the database to use.
Can be omitted; then the database from the connection string is used.
:type mongo_db: str

"""
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)

return collection.update_one(filter_doc, update_doc, **kwargs)

def update_many(self, mongo_collection, filter_doc, update_doc,
mongo_db=None, **kwargs):
"""
Updates one or more documents in a mongo collection.
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.update_many

:param mongo_collection: The name of the collection to update.
:type mongo_collection: str
:param filter_doc: A query that matches the documents to update.
:type filter_doc: dict
:param update_doc: The modifications to apply.
:type update_doc: dict
:param mongo_db: The name of the database to use.
Can be omitted; then the database from the connection string is used.
:type mongo_db: str

"""
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)

return collection.update_many(filter_doc, update_doc, **kwargs)

def replace_one(self, mongo_collection, doc, filter_doc=None,
mongo_db=None, **kwargs):
"""
Replaces a single document in a mongo collection.
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.replace_one

If no filter document is given, it is assumed that the replacement
document contains the _id field which is then used as filter.

:param mongo_collection: The name of the collection to update.
:type mongo_collection: str
:param doc: The new document.
:type doc: dict
:param filter_doc: A query that matches the documents to replace.
Can be omitted; then the _id field from doc will be used.
:type filter_doc: dict
:param mongo_db: The name of the database to use.
Can be omitted; then the database from the connection string is used.
:type mongo_db: str
"""
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)

if not filter_doc:
filter_doc = {'_id': doc['_id']}

return collection.replace_one(filter_doc, doc, **kwargs)

def replace_many(self, mongo_collection, docs,
filter_docs=None, mongo_db=None, upsert=False, collation=None,
**bulk_kwargs):
Copy link
Contributor

@Fokko Fokko Sep 15, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe call this one just kwargs as well, to make it congruent with the other calls.

"""
Replaces many documents in a mongo collection.

Uses bulk_write with multiple ReplaceOne operations
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.bulk_write

If no filter documents are given, it is assumed that all replacement
documents contain the _id field which are then used as filters.

:param mongo_collection: The name of the collection to update.
:type mongo_collection: str
:param docs: The new documents.
:type docs: list(dict)
:param filter_docs: A list of queries that match the documents to replace.
Can be omitted; then the _id fields from docs will be used.
:type filter_docs: list(dict)
:param mongo_db: The name of the database to use.
Can be omitted; then the database from the connection string is used.
:type mongo_db: str
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The upsert and collation are missing from the docstring.

"""
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)

if not filter_docs:
filter_docs = [{'_id': doc['_id']} for doc in docs]

requests = [
ReplaceOne(
filter_docs[i],
docs[i],
upsert=upsert,
collation=collation)
for i in range(len(docs))
]

return collection.bulk_write(requests, **bulk_kwargs)

def delete_one(self, mongo_collection, filter_doc, mongo_db=None, **kwargs):
"""
Deletes a single document in a mongo collection.
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.delete_one

:param mongo_collection: The name of the collection to delete from.
:type mongo_collection: str
:param filter_doc: A query that matches the document to delete.
:type filter_doc: dict
:param mongo_db: The name of the database to use.
Can be omitted; then the database from the connection string is used.
:type mongo_db: str

"""
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)

return collection.delete_one(filter_doc, **kwargs)

def delete_many(self, mongo_collection, filter_doc, mongo_db=None, **kwargs):
"""
Deletes one or more documents in a mongo collection.
https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.delete_many

:param mongo_collection: The name of the collection to delete from.
:type mongo_collection: str
:param filter_doc: A query that matches the documents to delete.
:type filter_doc: dict
:param mongo_db: The name of the database to use.
Can be omitted; then the database from the connection string is used.
:type mongo_db: str

"""
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)

return collection.delete_many(filter_doc, **kwargs)
152 changes: 152 additions & 0 deletions tests/contrib/hooks/test_mongo_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,158 @@ def test_insert_many(self):
result_objs = [result for result in result_objs]
self.assertEqual(len(result_objs), 2)

@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_update_one(self):
collection = mongomock.MongoClient().db.collection
obj = {'_id': '1', 'field': 0}
collection.insert_one(obj)

filter_doc = obj
update_doc = {'$inc': {'field': 123}}

self.hook.update_one(collection, filter_doc, update_doc)

result_obj = collection.find_one(filter='1')
self.assertEqual(123, result_obj['field'])

@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_update_one_with_upsert(self):
collection = mongomock.MongoClient().db.collection

filter_doc = {'_id': '1', 'field': 0}
update_doc = {'$inc': {'field': 123}}

self.hook.update_one(collection, filter_doc, update_doc, upsert=True)

result_obj = collection.find_one(filter='1')
self.assertEqual(123, result_obj['field'])

@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_update_many(self):
collection = mongomock.MongoClient().db.collection
obj1 = {'_id': '1', 'field': 0}
obj2 = {'_id': '2', 'field': 0}
collection.insert_many([obj1, obj2])

filter_doc = {'field': 0}
update_doc = {'$inc': {'field': 123}}

self.hook.update_many(collection, filter_doc, update_doc)

result_obj = collection.find_one(filter='1')
self.assertEqual(123, result_obj['field'])

result_obj = collection.find_one(filter='2')
self.assertEqual(123, result_obj['field'])

@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_update_many_with_upsert(self):
collection = mongomock.MongoClient().db.collection

filter_doc = {'_id': '1', 'field': 0}
update_doc = {'$inc': {'field': 123}}

self.hook.update_many(collection, filter_doc, update_doc, upsert=True)

result_obj = collection.find_one(filter='1')
self.assertEqual(123, result_obj['field'])

@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_replace_one(self):
collection = mongomock.MongoClient().db.collection
obj1 = {'_id': '1', 'field': 'test_value_1'}
obj2 = {'_id': '2', 'field': 'test_value_2'}
collection.insert_many([obj1, obj2])

obj1['field'] = 'test_value_1_updated'
self.hook.replace_one(collection, obj1)

result_obj = collection.find_one(filter='1')
self.assertEqual('test_value_1_updated', result_obj['field'])

# Other document should stay intact
result_obj = collection.find_one(filter='2')
self.assertEqual('test_value_2', result_obj['field'])

@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_replace_one_with_filter(self):
collection = mongomock.MongoClient().db.collection
obj1 = {'_id': '1', 'field': 'test_value_1'}
obj2 = {'_id': '2', 'field': 'test_value_2'}
collection.insert_many([obj1, obj2])

obj1['field'] = 'test_value_1_updated'
self.hook.replace_one(collection, obj1, {'field': 'test_value_1'})

result_obj = collection.find_one(filter='1')
self.assertEqual('test_value_1_updated', result_obj['field'])

# Other document should stay intact
result_obj = collection.find_one(filter='2')
self.assertEqual('test_value_2', result_obj['field'])

@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_replace_one_with_upsert(self):
collection = mongomock.MongoClient().db.collection

obj = {'_id': '1', 'field': 'test_value_1'}
self.hook.replace_one(collection, obj, upsert=True)

result_obj = collection.find_one(filter='1')
self.assertEqual('test_value_1', result_obj['field'])

@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_replace_many(self):
collection = mongomock.MongoClient().db.collection
obj1 = {'_id': '1', 'field': 'test_value_1'}
obj2 = {'_id': '2', 'field': 'test_value_2'}
collection.insert_many([obj1, obj2])

obj1['field'] = 'test_value_1_updated'
obj2['field'] = 'test_value_2_updated'
self.hook.replace_many(collection, [obj1, obj2])

result_obj = collection.find_one(filter='1')
self.assertEqual('test_value_1_updated', result_obj['field'])

result_obj = collection.find_one(filter='2')
self.assertEqual('test_value_2_updated', result_obj['field'])

@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_replace_many_with_upsert(self):
collection = mongomock.MongoClient().db.collection
obj1 = {'_id': '1', 'field': 'test_value_1'}
obj2 = {'_id': '2', 'field': 'test_value_2'}

self.hook.replace_many(collection, [obj1, obj2], upsert=True)

result_obj = collection.find_one(filter='1')
self.assertEqual('test_value_1', result_obj['field'])

result_obj = collection.find_one(filter='2')
self.assertEqual('test_value_2', result_obj['field'])

@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_delete_one(self):
collection = mongomock.MongoClient().db.collection
obj = {'_id': '1'}
collection.insert_one(obj)

self.hook.delete_one(collection, {'_id': '1'})

self.assertEqual(0, collection.count())

@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_delete_many(self):
collection = mongomock.MongoClient().db.collection
obj1 = {'_id': '1', 'field': 'value'}
obj2 = {'_id': '2', 'field': 'value'}
collection.insert_many([obj1, obj2])

self.hook.delete_many(collection, {'field': 'value'})

self.assertEqual(0, collection.count())

@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_find_one(self):
collection = mongomock.MongoClient().db.collection
Expand Down