Skip to content

Commit

Permalink
Add sentence/text chunk count to match context
Browse files Browse the repository at this point in the history
Fix unmatched entity text
  • Loading branch information
synesthesiam committed Jan 3, 2024
1 parent 738b0db commit cf3d92b
Showing 1 changed file with 74 additions and 19 deletions.
93 changes: 74 additions & 19 deletions hassil/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ class MatchContext:
edit_cost: int = 0
"""Number of edits that were required to match."""

text_chunks_matched: int = 0
"""Number of literal text chunks that were matched."""

intent_sentence: Optional[Sentence] = None
"""Sentence template that is being matched."""

def __post_init__(self):
if self.close_wildcards:
for entity in self.entities:
Expand Down Expand Up @@ -241,6 +247,13 @@ class RecognizeResult:
"""Unmatched entities as a list (duplicates allowed)."""

edit_cost: int = 0
"""Number of edits that were required for the match to succeed."""

text_chunks_matched: int = 0
"""Number of literal text chunks that were successfully matched."""

intent_sentence: Optional[Sentence] = None
"""Sentence template that was matched."""


def recognize(
Expand Down Expand Up @@ -449,6 +462,7 @@ def recognize_all(
match_context = MatchContext(
text=text,
intent_context=intent_context,
intent_sentence=intent_sentence,
)
maybe_match_contexts = match_expression(
local_settings, match_context, intent_sentence
Expand Down Expand Up @@ -628,6 +642,8 @@ def recognize_all(
},
unmatched_entities_list=maybe_match_context.unmatched_entities,
edit_cost=maybe_match_context.edit_cost,
text_chunks_matched=maybe_match_context.text_chunks_matched,
intent_sentence=maybe_match_context.intent_sentence,
)


Expand Down Expand Up @@ -675,6 +691,7 @@ def is_match(
match_context = MatchContext(
text=text,
intent_context=intent_context,
intent_sentence=sentence,
)

for maybe_match_context in match_expression(settings, match_context, sentence):
Expand Down Expand Up @@ -753,6 +770,8 @@ def match_expression(
intent_context=context.intent_context,
unmatched_entities=context.unmatched_entities,
edit_cost=context.edit_cost,
text_chunks_matched=context.text_chunks_matched,
intent_sentence=context.intent_sentence,
)
return

Expand Down Expand Up @@ -789,6 +808,8 @@ def match_expression(
intent_context=context.intent_context,
unmatched_entities=context.unmatched_entities,
edit_cost=context.edit_cost,
text_chunks_matched=context.text_chunks_matched,
intent_sentence=context.intent_sentence,
),
expression,
)
Expand All @@ -805,11 +826,13 @@ def match_expression(
text=context_text,
# must use chunk.text because it hasn't been stripped
is_start_of_word=chunk.text.endswith(" "),
text_chunks_matched=context.text_chunks_matched + 1,
# Copy over
entities=context.entities,
intent_context=context.intent_context,
unmatched_entities=context.unmatched_entities,
edit_cost=context.edit_cost,
intent_sentence=context.intent_sentence,
#
close_wildcards=is_chunk_word,
close_unmatched=is_chunk_word,
Expand Down Expand Up @@ -868,6 +891,8 @@ def match_expression(
is_start_of_word=context.is_start_of_word,
unmatched_entities=context.unmatched_entities,
edit_cost=context.edit_cost,
text_chunks_matched=context.text_chunks_matched,
intent_sentence=context.intent_sentence,
)
elif wildcard is not None:
# Add to wildcard by skipping ahead in the text until we find
Expand All @@ -887,26 +912,38 @@ def match_expression(
is_start_of_word=True,
unmatched_entities=context.unmatched_entities,
edit_cost=context.edit_cost,
text_chunks_matched=context.text_chunks_matched,
intent_sentence=context.intent_sentence,
)
elif settings.allow_unmatched_entities and (
unmatched_entity := context.get_open_entity()
):
# Add to the most recent unmatched entity by skipping ahead in
# the text until we find the current chunk text.
skip_idx = context_text.find(chunk_text)
if skip_idx >= 0:
unmatched_entity.text += context_text[:skip_idx]
re_chunk_text = re.escape(chunk_text.strip())
if settings.ignore_whitespace:
chunk_match = re.search(re_chunk_text, context_text)
else:
# Only skip to a word boundary
chunk_match = re.search(
rf"\s{re_chunk_text}(\s|$)", context_text
)

if chunk_match:
unmatched_entity.text += context_text[: chunk_match.start() + 1]

# Unmatched entities cannot be empty
if unmatched_entity.text:
yield MatchContext(
text=context.text[skip_idx + len(chunk_text) :],
text=context.text[chunk_match.end() :],
# Copy over
entities=context.entities,
intent_context=context.intent_context,
is_start_of_word=True,
unmatched_entities=context.unmatched_entities,
edit_cost=context.edit_cost,
text_chunks_matched=context.text_chunks_matched,
intent_sentence=context.intent_sentence,
)
else:
# Match failed
Expand Down Expand Up @@ -951,6 +988,7 @@ def match_expression(
if context.text:
text_list: TextSlotList = slot_list
# Any value may match
has_matches = False
for slot_value in text_list.values:
value_contexts = match_expression(
settings,
Expand All @@ -962,11 +1000,12 @@ def match_expression(
is_start_of_word=context.is_start_of_word,
unmatched_entities=context.unmatched_entities,
edit_cost=context.edit_cost,
text_chunks_matched=context.text_chunks_matched,
intent_sentence=context.intent_sentence,
),
slot_value.text_in,
)

has_matches = False
for value_context in value_contexts:
has_matches = True
entities = context.entities + [
Expand All @@ -992,6 +1031,8 @@ def match_expression(
is_start_of_word=context.is_start_of_word,
unmatched_entities=context.unmatched_entities,
edit_cost=context.edit_cost,
text_chunks_matched=context.text_chunks_matched,
intent_sentence=context.intent_sentence,
)
else:
yield MatchContext(
Expand All @@ -1002,22 +1043,26 @@ def match_expression(
is_start_of_word=context.is_start_of_word,
unmatched_entities=context.unmatched_entities,
edit_cost=context.edit_cost,
text_chunks_matched=context.text_chunks_matched,
intent_sentence=context.intent_sentence,
)

if (not has_matches) and settings.allow_unmatched_entities:
# Report mismatch
yield MatchContext(
# Copy over
text=context.text,
entities=context.entities,
intent_context=context.intent_context,
is_start_of_word=context.is_start_of_word,
edit_cost=context.edit_cost,
#
unmatched_entities=context.unmatched_entities
+ [UnmatchedTextEntity(name=list_ref.slot_name, text="")],
close_wildcards=True,
)
if (not has_matches) and settings.allow_unmatched_entities:
# Report mismatch
yield MatchContext(
# Copy over
text=context.text,
entities=context.entities,
intent_context=context.intent_context,
is_start_of_word=context.is_start_of_word,
edit_cost=context.edit_cost,
text_chunks_matched=context.text_chunks_matched,
intent_sentence=context.intent_sentence,
#
unmatched_entities=context.unmatched_entities
+ [UnmatchedTextEntity(name=list_ref.slot_name, text="")],
close_wildcards=True,
)

elif isinstance(slot_list, RangeSlotList):
if context.text:
Expand Down Expand Up @@ -1062,6 +1107,8 @@ def match_expression(
is_start_of_word=context.is_start_of_word,
unmatched_entities=context.unmatched_entities,
edit_cost=context.edit_cost,
text_chunks_matched=context.text_chunks_matched,
intent_sentence=context.intent_sentence,
)
elif settings.allow_unmatched_entities:
# Report out of range
Expand All @@ -1072,6 +1119,8 @@ def match_expression(
intent_context=context.intent_context,
is_start_of_word=context.is_start_of_word,
edit_cost=context.edit_cost,
text_chunks_matched=context.text_chunks_matched,
intent_sentence=context.intent_sentence,
#
unmatched_entities=context.unmatched_entities
+ [
Expand Down Expand Up @@ -1121,6 +1170,8 @@ def match_expression(
is_start_of_word=context.is_start_of_word,
unmatched_entities=context.unmatched_entities,
edit_cost=context.edit_cost,
text_chunks_matched=context.text_chunks_matched,
intent_sentence=context.intent_sentence,
),
TextChunk(number_words),
)
Expand All @@ -1143,6 +1194,8 @@ def match_expression(
intent_context=context.intent_context,
is_start_of_word=context.is_start_of_word,
edit_cost=context.edit_cost,
text_chunks_matched=context.text_chunks_matched,
intent_sentence=context.intent_sentence,
#
unmatched_entities=context.unmatched_entities
+ [UnmatchedTextEntity(name=list_ref.slot_name, text="")],
Expand All @@ -1158,6 +1211,8 @@ def match_expression(
is_start_of_word=context.is_start_of_word,
unmatched_entities=context.unmatched_entities,
edit_cost=context.edit_cost,
text_chunks_matched=context.text_chunks_matched,
intent_sentence=context.intent_sentence,
#
entities=context.entities
+ [
Expand Down

0 comments on commit cf3d92b

Please sign in to comment.