Skip to content

Commit e3e97b5

Browse files
authored
Merge pull request #268 from uktrade/feature/search
[DH-335] Make default search type whole string.
2 parents 7be9a77 + 376a4e1 commit e3e97b5

File tree

3 files changed

+180
-33
lines changed

3 files changed

+180
-33
lines changed

datahub/search/elasticsearch.py

+13 -3
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from elasticsearch.helpers import bulk as es_bulk
77
from elasticsearch_dsl import Search
88
from elasticsearch_dsl.connections import connections
9-
from elasticsearch_dsl.query import Q
9+
from elasticsearch_dsl.query import Match, MatchPhrase, Q
1010

1111

1212
def configure_connection():
@@ -33,12 +33,22 @@ def configure_connection():
3333
)
3434

3535

36+
def get_search_term_query(term):
37+
"""Returns search term query."""
38+
return Q('bool', should=[
39+
MatchPhrase(name={'query': term, 'boost': 2}),
40+
MatchPhrase(_all={'query': term, 'boost': 1.5}),
41+
Match(name={'query': term, 'boost': 1.0}),
42+
Match(_all={'query': term, 'boost': 0.5}),
43+
])
44+
45+
3646
def get_basic_search_query(term, entities=('company',), offset=0, limit=100):
3747
"""Performs basic search looking for name and then _all in entity.
3848
3949
Also returns number of results in other entities.
4050
"""
41-
query = Q('multi_match', query=term, fields=['name', '_all'])
51+
query = get_search_term_query(term)
4252
s = Search(index=settings.ES_INDEX).query(query)
4353
s = s.post_filter(
4454
Q('bool', should=[Q('term', _type=entity) for entity in entities])
@@ -55,7 +65,7 @@ def get_search_by_entity_query(term=None, filters=None, entity=None, ranges=None
5565
"""Perform filtered search for given terms in given entity."""
5666
query = [Q('term', _type=entity)]
5767
if term != '':
58-
query.append(Q('multi_match', query=term, fields=['name', '_all']))
68+
query.append(get_search_term_query(term))
5969

6070
query_filter = []
6171

datahub/search/test/test_elasticsearch.py

+141 -29
Original file line number | Diff line number | Diff line change
@@ -4,27 +4,103 @@
44
from datahub.search import elasticsearch
55

66

7+
def test_get_search_term_query():
8+
"""Tests search term query."""
9+
query = elasticsearch.get_search_term_query('hello')
10+
11+
assert query.to_dict() == {
12+
'bool': {
13+
'should': [
14+
{
15+
'match_phrase': {
16+
'name': {
17+
'query': 'hello',
18+
'boost': 2
19+
}
20+
}
21+
}, {
22+
'match_phrase': {
23+
'_all': {
24+
'query': 'hello',
25+
'boost': 1.5
26+
}
27+
}
28+
}, {
29+
'match': {
30+
'name': {
31+
'query': 'hello',
32+
'boost': 1.0
33+
}
34+
}
35+
}, {
36+
'match': {
37+
'_all': {
38+
'query': 'hello',
39+
'boost': 0.5
40+
}
41+
}
42+
}
43+
]
44+
}
45+
}
46+
47+
748
def test_get_basic_search_query():
849
"""Tests basic search query."""
950
query = elasticsearch.get_basic_search_query('test', entities=('contact',), offset=5, limit=5)
1051

1152
assert query.to_dict() == {
1253
'query': {
13-
'multi_match': {
14-
'query': 'test',
15-
'fields': ['name', '_all']
54+
'bool': {
55+
'should': [
56+
{
57+
'match_phrase': {
58+
'name': {
59+
'query': 'test',
60+
'boost': 2
61+
}
62+
}
63+
}, {
64+
'match_phrase': {
65+
'_all': {
66+
'query': 'test',
67+
'boost': 1.5
68+
}
69+
}
70+
}, {
71+
'match': {
72+
'name': {
73+
'query': 'test',
74+
'boost': 1.0
75+
}
76+
}
77+
}, {
78+
'match': {
79+
'_all': {
80+
'query': 'test',
81+
'boost': 0.5
82+
}
83+
}
84+
}
85+
]
1686
}
1787
},
1888
'post_filter': {
1989
'bool': {
2090
'should': [
21-
{'term': {'_type': 'contact'}}
91+
{
92+
'term': {
93+
'_type': 'contact'
94+
}
95+
}
2296
]
2397
}
2498
},
2599
'aggs': {
26100
'count_by_type': {
27-
'terms': {'field': '_type'}
101+
'terms': {
102+
'field': '_type'
103+
}
28104
}
29105
},
30106
'from': 5,
@@ -57,37 +133,73 @@ def test_search_by_entity_query():
57133
assert query.to_dict() == {
58134
'query': {
59135
'bool': {
60-
'must': [{
61-
'term': {
62-
'_type': 'company'
63-
}}, {
64-
'multi_match': {
65-
'query': 'test',
66-
'fields': ['name', '_all']
67-
}}]
136+
'must': [
137+
{
138+
'term': {
139+
'_type': 'company'
140+
}
141+
}, {
142+
'bool': {
143+
'should': [
144+
{
145+
'match_phrase': {
146+
'name': {
147+
'query': 'test',
148+
'boost': 2
149+
}
150+
}
151+
}, {
152+
'match_phrase': {
153+
'_all': {
154+
'query': 'test',
155+
'boost': 1.5
156+
}
157+
}
158+
}, {
159+
'match': {
160+
'name': {
161+
'query': 'test',
162+
'boost': 1.0
163+
}
164+
}
165+
}, {
166+
'match': {
167+
'_all': {
168+
'query': 'test',
169+
'boost': 0.5
170+
}
171+
}
172+
}
173+
]
174+
}
175+
}
176+
]
68177
}
69178
},
70179
'post_filter': {
71180
'bool': {
72-
'must': [{
73-
'term': {
74-
'address_town': 'Woodside'
75-
}}, {
76-
'nested': {
77-
'path': 'trading_address_country',
78-
'query': {
79-
'term': {
80-
'trading_address_country.id':
81-
'80756b9a-5d95-e211-a939-e4115bead28a'
181+
'must': [
182+
{
183+
'term': {
184+
'address_town': 'Woodside'
185+
}
186+
}, {
187+
'nested': {
188+
'path': 'trading_address_country',
189+
'query': {
190+
'term': {
191+
'trading_address_country.id': '80756b9a-5d95-e211-a939-e4115bead28a'
192+
}
82193
}
83194
}
84-
}}, {
85-
'range': {
86-
'estimated_land_date': {
87-
'gte': '2017-06-13T09:44:31.062870',
88-
'lte': '2017-06-13T09:44:31.062870'
195+
}, {
196+
'range': {
197+
'estimated_land_date': {
198+
'gte': '2017-06-13T09:44:31.062870',
199+
'lte': '2017-06-13T09:44:31.062870'
200+
}
89201
}
90-
}}
202+
}
91203
]
92204
}
93205
},

datahub/search/test/test_views.py

+26 -1
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,18 @@
11
import datetime
22

33
import pytest
4+
from elasticsearch_dsl.connections import connections
45
from rest_framework import status
56
from rest_framework.reverse import reverse
67

8+
from datahub.company.test.factories import CompanyFactory
79
from datahub.core import constants
810
from datahub.core.test_utils import LeelooTestCase
911

1012
pytestmark = pytest.mark.django_db
1113

1214

13-
@pytest.mark.usefixtures('setup_data')
15+
@pytest.mark.usefixtures('setup_data', 'post_save_handlers')
1416
class SearchTestCase(LeelooTestCase):
1517
"""Tests search views."""
1618

@@ -190,3 +192,26 @@ def test_search_investment_project_no_filters(self):
190192
response = self.api_client.post(url, {})
191193

192194
assert response.status_code == status.HTTP_400_BAD_REQUEST
195+
196+
def test_search_results_quality(self):
197+
"""Tests quality of results."""
198+
CompanyFactory(name='The Risk Advisory Group').save()
199+
CompanyFactory(name='The Advisory Group').save()
200+
CompanyFactory(name='The Advisory').save()
201+
CompanyFactory(name='The Advisories').save()
202+
203+
connections.get_connection().indices.refresh()
204+
205+
term = 'The Advisory'
206+
207+
url = reverse('api-v3:search:basic')
208+
response = self.api_client.get(url, {
209+
'term': term,
210+
'entity': 'company'
211+
})
212+
213+
assert response.data['count'] == 4
214+
assert ['The Advisory',
215+
'The Advisory Group',
216+
'The Risk Advisory Group',
217+
'The Advisories'] == [company['name'] for company in response.data['companies']]

0 commit comments

Comments
 (0)