Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DH-335] Make default search type whole string. #268

Merged
merged 2 commits into from
Jun 21, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions datahub/search/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from elasticsearch.helpers import bulk as es_bulk
from elasticsearch_dsl import Search
from elasticsearch_dsl.connections import connections
from elasticsearch_dsl.query import Q
from elasticsearch_dsl.query import Match, MatchPhrase, Q


def configure_connection():
Expand All @@ -33,12 +33,22 @@ def configure_connection():
)


def get_search_term_query(term):
"""Returns search term query."""
return Q('bool', should=[
MatchPhrase(name={'query': term, 'boost': 2}),
MatchPhrase(_all={'query': term, 'boost': 1.5}),
Match(name={'query': term, 'boost': 1.0}),
Match(_all={'query': term, 'boost': 0.5}),
])


def get_basic_search_query(term, entities=('company',), offset=0, limit=100):
"""Performs basic search looking for name and then _all in entity.

Also returns number of results in other entities.
"""
query = Q('multi_match', query=term, fields=['name', '_all'])
query = get_search_term_query(term)
s = Search(index=settings.ES_INDEX).query(query)
s = s.post_filter(
Q('bool', should=[Q('term', _type=entity) for entity in entities])
Expand All @@ -55,7 +65,7 @@ def get_search_by_entity_query(term=None, filters=None, entity=None, ranges=None
"""Perform filtered search for given terms in given entity."""
query = [Q('term', _type=entity)]
if term != '':
query.append(Q('multi_match', query=term, fields=['name', '_all']))
query.append(get_search_term_query(term))

query_filter = []

Expand Down
170 changes: 141 additions & 29 deletions datahub/search/test/test_elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,103 @@
from datahub.search import elasticsearch


def test_get_search_term_query():
"""Tests search term query."""
query = elasticsearch.get_search_term_query('hello')

assert query.to_dict() == {
'bool': {
'should': [
{
'match_phrase': {
'name': {
'query': 'hello',
'boost': 2
}
}
}, {
'match_phrase': {
'_all': {
'query': 'hello',
'boost': 1.5
}
}
}, {
'match': {
'name': {
'query': 'hello',
'boost': 1.0
}
}
}, {
'match': {
'_all': {
'query': 'hello',
'boost': 0.5
}
}
}
]
}
}


def test_get_basic_search_query():
"""Tests basic search query."""
query = elasticsearch.get_basic_search_query('test', entities=('contact',), offset=5, limit=5)

assert query.to_dict() == {
'query': {
'multi_match': {
'query': 'test',
'fields': ['name', '_all']
'bool': {
'should': [
{
'match_phrase': {
'name': {
'query': 'test',
'boost': 2
}
}
}, {
'match_phrase': {
'_all': {
'query': 'test',
'boost': 1.5
}
}
}, {
'match': {
'name': {
'query': 'test',
'boost': 1.0
}
}
}, {
'match': {
'_all': {
'query': 'test',
'boost': 0.5
}
}
}
]
}
},
'post_filter': {
'bool': {
'should': [
{'term': {'_type': 'contact'}}
{
'term': {
'_type': 'contact'
}
}
]
}
},
'aggs': {
'count_by_type': {
'terms': {'field': '_type'}
'terms': {
'field': '_type'
}
}
},
'from': 5,
Expand Down Expand Up @@ -57,37 +133,73 @@ def test_search_by_entity_query():
assert query.to_dict() == {
'query': {
'bool': {
'must': [{
'term': {
'_type': 'company'
}}, {
'multi_match': {
'query': 'test',
'fields': ['name', '_all']
}}]
'must': [
{
'term': {
'_type': 'company'
}
}, {
'bool': {
'should': [
{
'match_phrase': {
'name': {
'query': 'test',
'boost': 2
}
}
}, {
'match_phrase': {
'_all': {
'query': 'test',
'boost': 1.5
}
}
}, {
'match': {
'name': {
'query': 'test',
'boost': 1.0
}
}
}, {
'match': {
'_all': {
'query': 'test',
'boost': 0.5
}
}
}
]
}
}
]
}
},
'post_filter': {
'bool': {
'must': [{
'term': {
'address_town': 'Woodside'
}}, {
'nested': {
'path': 'trading_address_country',
'query': {
'term': {
'trading_address_country.id':
'80756b9a-5d95-e211-a939-e4115bead28a'
'must': [
{
'term': {
'address_town': 'Woodside'
}
}, {
'nested': {
'path': 'trading_address_country',
'query': {
'term': {
'trading_address_country.id': '80756b9a-5d95-e211-a939-e4115bead28a'
}
}
}
}}, {
'range': {
'estimated_land_date': {
'gte': '2017-06-13T09:44:31.062870',
'lte': '2017-06-13T09:44:31.062870'
}, {
'range': {
'estimated_land_date': {
'gte': '2017-06-13T09:44:31.062870',
'lte': '2017-06-13T09:44:31.062870'
}
}
}}
}
]
}
},
Expand Down
27 changes: 26 additions & 1 deletion datahub/search/test/test_views.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import datetime

import pytest
from elasticsearch_dsl.connections import connections
from rest_framework import status
from rest_framework.reverse import reverse

from datahub.company.test.factories import CompanyFactory
from datahub.core import constants
from datahub.core.test_utils import LeelooTestCase

pytestmark = pytest.mark.django_db


@pytest.mark.usefixtures('setup_data')
@pytest.mark.usefixtures('setup_data', 'post_save_handlers')
class SearchTestCase(LeelooTestCase):
"""Tests search views."""

Expand Down Expand Up @@ -190,3 +192,26 @@ def test_search_investment_project_no_filters(self):
response = self.api_client.post(url, {})

assert response.status_code == status.HTTP_400_BAD_REQUEST

def test_search_results_quality(self):
"""Tests quality of results."""
CompanyFactory(name='The Risk Advisory Group').save()
CompanyFactory(name='The Advisory Group').save()
CompanyFactory(name='The Advisory').save()
CompanyFactory(name='The Advisories').save()

connections.get_connection().indices.refresh()

term = 'The Advisory'

url = reverse('api-v3:search:basic')
response = self.api_client.get(url, {
'term': term,
'entity': 'company'
})

assert response.data['count'] == 4
assert ['The Advisory',
'The Advisory Group',
'The Risk Advisory Group',
'The Advisories'] == [company['name'] for company in response.data['companies']]