diff --git a/invenio_records_resources/services/records/queryparser/__init__.py b/invenio_records_resources/services/records/queryparser/__init__.py index 97fe874c..a4a46f9e 100644 --- a/invenio_records_resources/services/records/queryparser/__init__.py +++ b/invenio_records_resources/services/records/queryparser/__init__.py @@ -10,10 +10,11 @@ from .query import QueryParser from .suggest import SuggestQueryParser -from .transformer import SearchFieldTransformer +from .transformer import FieldValueMapper, SearchFieldTransformer __all__ = ( + "FieldValueMapper", "QueryParser", - "SuggestQueryParser", "SearchFieldTransformer", + "SuggestQueryParser", ) diff --git a/invenio_records_resources/services/records/queryparser/query.py b/invenio_records_resources/services/records/queryparser/query.py index 96c4f4aa..e2e5d66e 100644 --- a/invenio_records_resources/services/records/queryparser/query.py +++ b/invenio_records_resources/services/records/queryparser/query.py @@ -12,6 +12,7 @@ from functools import partial from invenio_search.engine import dsl +from luqum.auto_head_tail import auto_head_tail from luqum.exceptions import ParseError from luqum.parser import parser as luqum_parser from werkzeug.utils import cached_property @@ -54,7 +55,7 @@ class SearchOptions: class SearchOptions: query_parser_cls = QueryParser.factory( fields=["metadata.title^2", "metadata.description"], - tree_transformer_factory=FieldTransformer.factory( + tree_transformer_factory=SearchFieldTransformer.factory( mapping={ "title": "metadata.title", "description": "metadata.description", @@ -125,9 +126,11 @@ def parse(self, query_str): # Perform transformation on the abstract syntax tree (AST) if self.tree_transformer_cls is not None: transformer = self.tree_transformer_cls( - mapping=self.mapping, allow_list=self.allow_list + mapping=self.mapping, + allow_list=self.allow_list, ) new_tree = transformer.visit(tree, context={"identity": self.identity}) + new_tree = auto_head_tail(new_tree) query_str = str(new_tree) return dsl.Q("query_string", query=query_str, **self.extra_params) except (ParseError, QuerystringValidationError): diff --git a/invenio_records_resources/services/records/queryparser/transformer.py b/invenio_records_resources/services/records/queryparser/transformer.py index b22a4bc9..3a68a918 100644 --- a/invenio_records_resources/services/records/queryparser/transformer.py +++ b/invenio_records_resources/services/records/queryparser/transformer.py @@ -13,14 +13,35 @@ how to build your own query tree transformer. """ -from functools import partial - from invenio_i18n import gettext as _ from luqum.visitor import TreeTransformer from invenio_records_resources.services.errors import QuerystringValidationError +class FieldValueMapper: + """Class used to remap values to new terms.""" + + def __init__(self, term_name, word=None, phrase=None): + """Initialize field value mapper.""" + self._term_name = term_name + self._word_fun = word + self._phrase_fun = phrase + + @property + def term_name(self): + """Get the term name.""" + return self._term_name + + def map_word(self, node): + """Modify a word node.""" + return self._word_fun(node) if self._word_fun else node + + def map_phrase(self, node): + """Modify a phrase node.""" + return self._phrase_fun(node) if self._phrase_fun else node + + class SearchFieldTransformer(TreeTransformer): """Transform from user-friendly field names to internal field names.""" @@ -28,12 +49,17 @@ def __init__(self, mapping, allow_list, *args, **kwargs): """Constructor.""" self._mapping = mapping self._allow_list = allow_list - super().__init__(self, *args, **kwargs) + super().__init__(*args, **kwargs) def visit_search_field(self, node, context): """Visit a search field.""" # Use the node name if not mapped for transformation. term_name = self._mapping.get(node.name, node.name) + field_value_mapper = None + + if isinstance(term_name, FieldValueMapper): + field_value_mapper = term_name + term_name = field_value_mapper.term_name # If a allow list exists, the term must be allowed to be queried. if self._allow_list and not term_name in self._allow_list: @@ -41,8 +67,21 @@ def visit_search_field(self, node, context): _("Invalid search field: {field_name}.").format(field_name=node.name) ) + if field_value_mapper: + context["field_value_mapper"] = field_value_mapper + # Returns a copy of the node. new_node = node.clone_item() new_node.name = term_name new_node.children = list(self.clone_children(node, new_node, context)) yield new_node + + def visit_word(self, node, context): + """Visit a word term.""" + mapper = context.get("field_value_mapper") + yield node if mapper is None else mapper.map_word(node) + + def visit_phrase(self, node, context): + """Visit a phrase term.""" + mapper = context.get("field_value_mapper") + yield node if mapper is None else mapper.map_phrase(node) diff --git a/tests/services/test_service_queryparser.py b/tests/services/test_service_queryparser.py index 53080e8d..7c497ff2 100644 --- a/tests/services/test_service_queryparser.py +++ b/tests/services/test_service_queryparser.py @@ -10,8 +10,10 @@ import pytest from invenio_access.permissions import system_identity +from luqum.tree import Phrase from invenio_records_resources.services.records.queryparser import ( + FieldValueMapper, QueryParser, SearchFieldTransformer, ) @@ -151,3 +153,44 @@ def test_parser_fields(allow_list, fields, expected_fields): ) assert not set(p(system_identity).fields).difference(expected_fields) + + +@pytest.mark.parametrize( + "query,transformed_query", + [ + ("doi:10.5281/zenodo.123", 'metadata.doi:"10.5281/zenodo.123"'), + ("doi:(blr OR biosyslit)", 'metadata.doi:("blr" OR "biosyslit")'), + ("doi:(+blr -biosyslit) test", 'metadata.doi:(+"blr" -"biosyslit") test'), + ("lol:test", "lol:lol"), + ( + "doi:(b1 OR b2) lol:(test test1 test2)^2", + 'metadata.doi:("b1" OR "b2") lol:(lol lol lol)^2', + ), + ], +) +def test_querystring_valuemapper(query, transformed_query): + """Invalid syntax falls back to multi match query.""" + + def word_to_phrase(node): + return Phrase( + f'"{node.value}"', + pos=node.pos, + size=node.size + 2, + head=node.head, + tail=node.tail, + ) + + def lol(node): + node.value = "lol" + return node + + p = QueryParser.factory( + mapping={ + "doi": FieldValueMapper("metadata.doi", word=word_to_phrase), + "lol": FieldValueMapper("lol", word=lol), + }, + tree_transformer_cls=SearchFieldTransformer, + ) + assert p(system_identity).parse(query).to_dict() == { + "query_string": {"query": transformed_query} + }