Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query parser: Map search field values #557

Merged
merged 1 commit into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@

from .query import QueryParser
from .suggest import SuggestQueryParser
from .transformer import SearchFieldTransformer
from .transformer import FieldValueMapper, SearchFieldTransformer

# Public names exported by this package, kept in alphabetical order with
# each name listed exactly once.
__all__ = (
    "FieldValueMapper",
    "QueryParser",
    "SearchFieldTransformer",
    "SuggestQueryParser",
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from functools import partial

from invenio_search.engine import dsl
from luqum.auto_head_tail import auto_head_tail
from luqum.exceptions import ParseError
from luqum.parser import parser as luqum_parser
from werkzeug.utils import cached_property
Expand Down Expand Up @@ -54,7 +55,7 @@ class SearchOptions:
class SearchOptions:
query_parser_cls = QueryParser.factory(
fields=["metadata.title^2", "metadata.description"],
tree_transformer_factory=FieldTransformer.factory(
tree_transformer_factory=SearchFieldTransformer.factory(
mapping={
"title": "metadata.title",
"description": "metadata.description",
Expand Down Expand Up @@ -125,9 +126,11 @@ def parse(self, query_str):
# Perform transformation on the abstract syntax tree (AST)
if self.tree_transformer_cls is not None:
transformer = self.tree_transformer_cls(
mapping=self.mapping, allow_list=self.allow_list
mapping=self.mapping,
allow_list=self.allow_list,
)
new_tree = transformer.visit(tree, context={"identity": self.identity})
new_tree = auto_head_tail(new_tree)
query_str = str(new_tree)
return dsl.Q("query_string", query=query_str, **self.extra_params)
except (ParseError, QuerystringValidationError):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,75 @@
how to build your own query tree transformer.
"""

from functools import partial

from invenio_i18n import gettext as _
from luqum.visitor import TreeTransformer

from invenio_records_resources.services.errors import QuerystringValidationError


class FieldValueMapper:
    """Class used to remap values to new terms.

    Bundles the name a search field is mapped to with optional callables
    that rewrite the word and phrase nodes handed to ``map_word`` and
    ``map_phrase``.
    """

    def __init__(self, term_name, word=None, phrase=None):
        """Initialize field value mapper.

        :param term_name: Name the search field is mapped to.
        :param word: Optional callable taking a word node and returning the
            node to use in its place. When omitted, word nodes are returned
            unchanged.
        :param phrase: Optional callable taking a phrase node and returning
            the node to use in its place. When omitted, phrase nodes are
            returned unchanged.
        """
        self._term_name = term_name
        self._word_fun = word
        self._phrase_fun = phrase

    @property
    def term_name(self):
        """Get the term name."""
        return self._term_name

    def map_word(self, node):
        """Modify a word node.

        :param node: The word node to map.
        :returns: The result of the ``word`` callable, or ``node`` itself
            when no callable was configured.
        """
        return self._word_fun(node) if self._word_fun else node

    def map_phrase(self, node):
        """Modify a phrase node.

        :param node: The phrase node to map.
        :returns: The result of the ``phrase`` callable, or ``node`` itself
            when no callable was configured.
        """
        return self._phrase_fun(node) if self._phrase_fun else node


class SearchFieldTransformer(TreeTransformer):
    """Transform from user-friendly field names to internal field names.

    Search field names are looked up in ``mapping``. A mapping value may be
    a plain term name or a ``FieldValueMapper``, in which case the word and
    phrase nodes below the search field are additionally passed through the
    mapper.
    """

    def __init__(self, mapping, allow_list, *args, **kwargs):
        """Constructor.

        :param mapping: Dict keyed by the field name used in the query. Each
            value is either the term name to use instead or a
            ``FieldValueMapper``.
        :param allow_list: Collection of term names that may be queried. A
            falsy value disables the check.
        """
        self._mapping = mapping
        self._allow_list = allow_list
        # Initialise the base transformer exactly once. ``super()`` already
        # binds ``self``, so it must not be passed again as an argument.
        super().__init__(*args, **kwargs)

    def visit_search_field(self, node, context):
        """Visit a search field.

        Yields a copy of ``node`` renamed to the mapped term name.

        :raises QuerystringValidationError: If an allow list is set and the
            mapped term name is not part of it.
        """
        # Use the node name if not mapped for transformation.
        term_name = self._mapping.get(node.name, node.name)
        field_value_mapper = None

        if isinstance(term_name, FieldValueMapper):
            field_value_mapper = term_name
            term_name = field_value_mapper.term_name

        # If an allow list exists, the term must be allowed to be queried.
        if self._allow_list and term_name not in self._allow_list:
            raise QuerystringValidationError(
                _("Invalid search field: {field_name}.").format(field_name=node.name)
            )

        if field_value_mapper is not None:
            # Stored on the context handed to ``clone_children`` below so
            # that ``visit_word`` / ``visit_phrase`` can pick it up.
            context["field_value_mapper"] = field_value_mapper

        # Returns a copy of the node.
        new_node = node.clone_item()
        new_node.name = term_name
        new_node.children = list(self.clone_children(node, new_node, context))
        yield new_node

    def visit_word(self, node, context):
        """Visit a word term.

        Yields the node unchanged unless the context carries a field value
        mapper, in which case the mapper's ``map_word`` result is yielded.
        """
        mapper = context.get("field_value_mapper")
        yield node if mapper is None else mapper.map_word(node)

    def visit_phrase(self, node, context):
        """Visit a phrase term.

        Yields the node unchanged unless the context carries a field value
        mapper, in which case the mapper's ``map_phrase`` result is yielded.
        """
        mapper = context.get("field_value_mapper")
        yield node if mapper is None else mapper.map_phrase(node)
43 changes: 43 additions & 0 deletions tests/services/test_service_queryparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

import pytest
from invenio_access.permissions import system_identity
from luqum.tree import Phrase

from invenio_records_resources.services.records.queryparser import (
FieldValueMapper,
QueryParser,
SearchFieldTransformer,
)
Expand Down Expand Up @@ -151,3 +153,44 @@ def test_parser_fields(allow_list, fields, expected_fields):
)

assert not set(p(system_identity).fields).difference(expected_fields)


@pytest.mark.parametrize(
    "query,transformed_query",
    [
        ("doi:10.5281/zenodo.123", 'metadata.doi:"10.5281/zenodo.123"'),
        ("doi:(blr OR biosyslit)", 'metadata.doi:("blr" OR "biosyslit")'),
        ("doi:(+blr -biosyslit) test", 'metadata.doi:(+"blr" -"biosyslit") test'),
        ("lol:test", "lol:lol"),
        (
            "doi:(b1 OR b2) lol:(test test1 test2)^2",
            'metadata.doi:("b1" OR "b2") lol:(lol lol lol)^2',
        ),
    ],
)
def test_querystring_valuemapper(query, transformed_query):
    """Field value mappers rewrite the words of the fields they are mapped to."""

    def word_to_phrase(node):
        # Wrap the word in double quotes, turning it into a phrase. The
        # size grows by two to account for the added quote characters.
        return Phrase(
            f'"{node.value}"',
            pos=node.pos,
            size=node.size + 2,
            head=node.head,
            tail=node.tail,
        )

    def lol(node):
        # Overwrite the word's value in place and hand the same node back.
        node.value = "lol"
        return node

    p = QueryParser.factory(
        mapping={
            "doi": FieldValueMapper("metadata.doi", word=word_to_phrase),
            "lol": FieldValueMapper("lol", word=lol),
        },
        tree_transformer_cls=SearchFieldTransformer,
    )
    assert p(system_identity).parse(query).to_dict() == {
        "query_string": {"query": transformed_query}
    }
Loading