
Commit 71bbe11

Merge branch 'lm-toolkit' into BANFF_lm
2 parents 3151958 + 0a421f5 commit 71bbe11

27 files changed: +342 -1000 lines changed

bcipy/core/tests/resources/mock_session/parameters.json

+1 -1

@@ -680,7 +680,7 @@
 "recommended": [
 "UNIFORM",
 "CAUSAL",
-"KENLM",
+"NGRAM",
 "MIXTURE",
 "ORACLE"
 ],

bcipy/exceptions.py

-7

@@ -90,13 +90,6 @@ class TaskConfigurationException(BciPyCoreException):
     ...
 
 
-class InvalidLanguageModelException(BciPyCoreException):
-    """Invalid Language Model Exception.
-
-    Thrown when attempting to load a language model from an invalid path"""
-    ...
-
-
 class KenLMInstallationException(BciPyCoreException):
     """KenLM Installation Exception.
 

bcipy/helpers/copy_phrase_wrapper.py

+2 -2

@@ -10,7 +10,7 @@
 from bcipy.core.symbols import BACKSPACE_CHAR
 from bcipy.exceptions import BciPyCoreException
 from bcipy.helpers.language_model import histogram, with_min_prob
-from bcipy.language.main import LanguageModel
+from bcipy.language.main import LanguageModelAdapter
 from bcipy.task.control.criteria import (CriteriaEvaluator,
                                          MaxIterationsCriteria,
                                          MinIterationsCriteria,
@@ -58,7 +58,7 @@ class CopyPhraseWrapper:
     def __init__(self,
                  min_num_inq: int,
                  max_num_inq: int,
-                 lmodel: LanguageModel,
+                 lmodel: LanguageModelAdapter,
                  alp: List[str],
                  evidence_names: List[EvidenceType] = [
                      EvidenceType.LM, EvidenceType.ERP

bcipy/helpers/language_model.py

+10 -11

@@ -6,27 +6,26 @@
 import numpy as np
 
 from bcipy.core.symbols import alphabet
-from bcipy.language.main import LanguageModel, ResponseType
+from bcipy.language.main import LanguageModelAdapter, ResponseType
 
 # pylint: disable=unused-import
 # flake8: noqa
 
 """Only imported models will be included in language_models_by_name"""
 # flake8: noqa
-from bcipy.exceptions import InvalidLanguageModelException
-from bcipy.language.model.causal import CausalLanguageModel
-from bcipy.language.model.kenlm import KenLMLanguageModel
-from bcipy.language.model.mixture import MixtureLanguageModel
-from bcipy.language.model.oracle import OracleLanguageModel
-from bcipy.language.model.uniform import UniformLanguageModel
+from bcipy.language.model.causal import CausalLanguageModelAdapter
+from bcipy.language.model.ngram import NGramLanguageModelAdapter
+from bcipy.language.model.mixture import MixtureLanguageModelAdapter
+from bcipy.language.model.oracle import OracleLanguageModelAdapter
+from bcipy.language.model.uniform import UniformLanguageModelAdapter
 
 
-def language_models_by_name() -> Dict[str, LanguageModel]:
+def language_models_by_name() -> Dict[str, LanguageModelAdapter]:
     """Returns available language models indexed by name."""
-    return {lm.name(): lm for lm in LanguageModel.__subclasses__()}
+    return {lm.name(): lm for lm in LanguageModelAdapter.__subclasses__()}
 
 
-def init_language_model(parameters: dict) -> LanguageModel:
+def init_language_model(parameters: dict) -> LanguageModelAdapter:
     """
     Init Language Model configured in the parameters. If no language model is
     specified, a uniform language model is returned.
@@ -38,7 +37,7 @@ def init_language_model(parameters: dict) -> LanguageModel:
 
     Returns
     -------
-        instance of a LanguageModel
+        instance of a LanguageModelAdapter
     """
 
     language_models = language_models_by_name()
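For readers skimming this diff: `language_models_by_name` relies on Python's `__subclasses__()` hook, which is why the helper module imports every adapter even though the imports look unused (hence the `pylint`/`flake8` suppressions above). A minimal sketch of that registration pattern, using hypothetical adapter classes and an illustrative parameter key rather than the real BciPy ones:

```python
from typing import Dict, List, Tuple


class LanguageModelAdapter:
    """Stand-in base class for bcipy.language.main.LanguageModelAdapter."""

    @classmethod
    def name(cls) -> str:
        raise NotImplementedError

    def predict(self, evidence: List[str]) -> List[Tuple[str, float]]:
        raise NotImplementedError


class UniformAdapter(LanguageModelAdapter):
    """Hypothetical adapter; it only becomes discoverable once imported."""

    @classmethod
    def name(cls) -> str:
        return "UNIFORM"


def language_models_by_name() -> Dict[str, type]:
    # __subclasses__() only sees classes whose defining module has been
    # imported, so the unused-looking imports are what populate this registry.
    return {lm.name(): lm for lm in LanguageModelAdapter.__subclasses__()}


def init_language_model(parameters: dict) -> LanguageModelAdapter:
    # "lang_model_type" is an illustrative key, not BciPy's actual parameter name.
    model_cls = language_models_by_name().get(
        parameters.get("lang_model_type", "UNIFORM"), UniformAdapter)
    return model_cls()


print(init_language_model({}))  # -> a UniformAdapter instance
```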

bcipy/helpers/tests/test_copy_phrase_wrapper.py

+5 -7

@@ -1,21 +1,20 @@
 import unittest
 
 from bcipy.helpers.copy_phrase_wrapper import CopyPhraseWrapper
-from bcipy.core.symbols import alphabet
-from bcipy.language.model.uniform import UniformLanguageModel
+from bcipy.core.symbols import DEFAULT_SYMBOL_SET
+from bcipy.language.model.uniform import UniformLanguageModelAdapter
 from bcipy.task.data import EvidenceType
 
 
 class TestCopyPhraseWrapper(unittest.TestCase):
     """Test CopyPhraseWrapper"""
 
     def test_valid_letters(self):
-        alp = alphabet()
         cp = CopyPhraseWrapper(
             min_num_inq=1,
             max_num_inq=50,
             lmodel=None,
-            alp=alp,
+            alp=DEFAULT_SYMBOL_SET,
             task_list=[("HELLO_WORLD", "HE")],
             is_txt_stim=True,
             evidence_names=[EvidenceType.LM, EvidenceType.ERP],
@@ -104,13 +103,12 @@ def test_valid_letters(self):
             ["nontarget", "nontarget"])
 
     def test_init_series(self):
-        alp = alphabet()
 
         copy_phrase_task = CopyPhraseWrapper(
             min_num_inq=1,
             max_num_inq=50,
-            lmodel=UniformLanguageModel(symbol_set=alp),
-            alp=alp,
+            lmodel=UniformLanguageModelAdapter(symbol_set=DEFAULT_SYMBOL_SET),
+            alp=DEFAULT_SYMBOL_SET,
             task_list=[("HELLO_WORLD", "HE")],
             is_txt_stim=True,
             evidence_names=[EvidenceType.LM, EvidenceType.ERP],

bcipy/language/README.md

+9 -11

@@ -1,15 +1,13 @@
 # Language
 
-BciPy Language module provides an interface for word and character level predictions.
+BciPy Language module provides an interface for word and character level predictions. This module primarily relies upon the AAC-TextPredict package (aactextpredict on PyPI) for its probability calculations. More information on this package can be found on our [GitHub repo](https://github.com/kdv123/textpredict)
 
-The core methods of any `LanguageModel` include:
+The core methods of any `LanguageModelAdapter` include:
 
 > `predict` - given typing evidence input, return a prediction (character or word).
 
 > `load` - load a pre-trained model given a path (currently BciPy does not support training language models!)
 
-> `update` - update internal state of your model.
-
 You may of course define other methods, however all integrated BciPy experiments using your model will require those to be defined!
 
 The language module has the following structure:
@@ -18,7 +16,7 @@ The language module has the following structure:
 
 > `lms` - The default location for the model resources.
 
-> `model` - The python classes for each LanguageModel subclass. Detailed descriptions of each can be found below.
+> `model` - The python classes for each LanguageModelAdapter subclass. Detailed descriptions of each can be found below.
 
 > `sets` - Different phrase sets that can be used to test the language model classes.
 
@@ -28,22 +26,22 @@ The language module has the following structure:
 
 ## Uniform Model
 
-The UniformLanguageModel provides equal probabilities for all symbols in the symbol set. This model is useful for evaluating other aspects of the system, such as EEG signal quality, without any influence from a language model.
+The UniformLanguageModelAdapter provides equal probabilities for all symbols in the symbol set. This model is useful for evaluating other aspects of the system, such as EEG signal quality, without any influence from a language model.
 
-## KenLM Model
-The KenLMLanguageModel utilizes a pretrained n-gram language model to generate probabilities for all symbols in the symbol set. N-gram models use frequencies of different character sequences to generate their predictions. Models trained on AAC-like data can be found [here](https://imagineville.org/software/lm/dec19_char/). For faster load times, it is recommended to use the binary models located at the bottom of the page. The default parameters file utilizes `lm_dec19_char_large_12gram.kenlm`. If you have issues accessing, please reach out to us on GitHub or via email at `cambi_support@googlegroups.com`.
+## NGram Model
+The NGramLanguageModelAdapter utilizes a pretrained n-gram language model to generate probabilities for all symbols in the symbol set. N-gram models use frequencies of different character sequences to generate their predictions. Models trained on AAC-like data can be found [here](https://imagineville.org/software/lm/dec19_char/). For faster load times, it is recommended to use the binary models located at the bottom of the page. The default parameters file utilizes `lm_dec19_char_large_12gram.kenlm`. If you have issues accessing, please reach out to us on GitHub or via email at `cambi_support@googlegroups.com`.
 
 For models that import the kenlm module, this must be manually installed using `pip install kenlm==0.1 --global-option="max_order=12"`.
 
 
 ## Causal Model
-The CausalLanguageModel class can use any causal language model from Huggingface, though it has only been tested with gpt2, facebook/opt, and distilgpt2 families of models. Causal language models predict the next token in a sequence of tokens. For the many of these models, byte-pair encoding (BPE) is used for tokenization. The main idea of BPE is to create a fixed-size vocabulary that contains common English subword units. Then a less common word would be broken down into several subword units in the vocabulary. For example, the tokenization of character sequence `peanut_butter_and_jel` would be:
+The CausalLanguageModelAdapter class can use any causal language model from Huggingface, though it has only been tested with gpt2, facebook/opt, and distilgpt2 families of models (including the domain-adapted figmtu/opt-350m-aac). Causal language models predict the next token in a sequence of tokens. For many of these models, byte-pair encoding (BPE) is used for tokenization. The main idea of BPE is to create a fixed-size vocabulary that contains common English subword units. Then a less common word would be broken down into several subword units in the vocabulary. For example, the tokenization of character sequence `peanut_butter_and_jel` would be:
 > *['pe', 'anut', '_butter', '_and', '_j', 'el']*
 
-Therefore, in order to generate a predictive distribution on the next character, we need to examine all the possibilities that could complete the final subword tokens in the input sequences. We must remove at least one token from the end of the context to allow the model the option of extending it, as opposed to only adding a new token. Removing more tokens allows the model more flexibility and may lead to better predictions, but at the cost of a higher prediction time. In this model we remove all of the subword tokens in the current (partially-typed) word to allow it the most flexibility. We then ask the model to estimate the likelihood of the next token and evaluate each token that matches our context. For efficiency, we only track a certain number of hypotheses at a time, known as the beam width, and each hypothesis until it surpasses the context. We can then store the likelihood for each final prediction in a list based on the character that directly follows the context. Once we have no more hypotheses to extend, we can sum the likelihoods stored for each character in our symbol set and normalize so they sum to 1, giving us our final distribution.
+Therefore, in order to generate a predictive distribution on the next character, we need to examine all the possibilities that could complete the final subword tokens in the input sequences. We must remove at least one token from the end of the context to allow the model the option of extending it, as opposed to only adding a new token. Removing more tokens allows the model more flexibility and may lead to better predictions, but at the cost of a higher prediction time. In this model we remove all of the subword tokens in the current (partially-typed) word to allow it the most flexibility. We then ask the model to estimate the likelihood of the next token and evaluate each token that matches our context. For efficiency, we only track a certain number of hypotheses at a time, known as the beam width, and extend each hypothesis until it surpasses the context. We can then store the likelihood for each final prediction in a list based on the character that directly follows the context. Once we have no more hypotheses to extend, we can sum the likelihoods stored for each character in our symbol set and normalize so they sum to 1, giving us our final distribution. More details on this process can be found in our paper, [Adapting Large Language Models for Character-based Augmentative and Alternative Communication](https://arxiv.org/abs/2501.10582).
 
 
 ## Mixture Model
-The MixtureLanguageModel class allows for the combination of two or more supported models. The selected models are mixed according to the provided weights, which can be tuned using the Bcipy/scripts/python/mixture_tuning.py script. It is not recommended to use more than one "heavy-weight" model with long prediction times (the CausalLanguageModel) since this model will query each component model and parallelization is not currently supported.
+The MixtureLanguageModelAdapter class allows for the combination of two or more supported models. The selected models are mixed according to the provided weights, which can be tuned using the Bcipy/scripts/python/mixture_tuning.py script. It is not recommended to use more than one "heavy-weight" model with long prediction times (such as the CausalLanguageModelAdapter), since this model will query each component model and parallelization is not currently supported.
 
 # Contact Information
 
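The Causal Model section above describes pooling subword-token probabilities into a per-character distribution. The sketch below is a rough single-step approximation of that idea (no beam search, plain `gpt2` via the HuggingFace `transformers` package); the function name and symbol set are illustrative and are not the BciPy adapter's API:

```python
from collections import defaultdict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

SYMBOLS = set("ABCDEFGHIJKLMNOPQRSTUVWXYZ") | {"_"}


def next_char_distribution(context: str) -> dict:
    """Single-step approximation: drop the partially typed word, take one
    next-token distribution, and pool probability mass by the character that
    immediately follows the typed context. The real adapter instead extends
    multiple hypotheses with a beam rather than stopping after one token."""
    text = context.replace("_", " ")
    kept = text.rpartition(" ")[0]  # remove the partial word's tokens
    ids = tokenizer.encode(kept) if kept else [tokenizer.bos_token_id]
    with torch.no_grad():
        logits = model(torch.tensor([ids])).logits[0, -1]
    probs = torch.softmax(logits, dim=-1).tolist()

    char_mass = defaultdict(float)
    # Decoding the full vocabulary one token at a time is slow but keeps the
    # example short.
    for token_id, prob in enumerate(probs):
        candidate = kept + tokenizer.decode([token_id])
        if not candidate.startswith(text) or len(candidate) == len(text):
            continue  # token is inconsistent with what has already been typed
        next_char = candidate[len(text)]
        next_char = "_" if next_char == " " else next_char.upper()
        if next_char in SYMBOLS:
            char_mass[next_char] += prob
    total = sum(char_mass.values()) or 1.0
    return {ch: mass / total for ch, mass in char_mass.items()}


if __name__ == "__main__":
    dist = next_char_distribution("does_it_make_sen")
    print(sorted(dist.items(), key=lambda kv: kv[1], reverse=True)[:5])
```

In the adapter described above, removing all tokens of the partial word and beam-searching over multi-token completions also recovers probability mass for characters that only appear deeper in a continuation; this sketch simply ignores that refinement.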

bcipy/language/__init__.py

+2 -2

@@ -1,6 +1,6 @@
-from .main import LanguageModel, ResponseType
+from .main import LanguageModelAdapter, ResponseType
 
 __all__ = [
-    "LanguageModel",
+    "LanguageModelAdapter",
     "ResponseType",
 ]

bcipy/language/demo/demo_causal.py

+26 -16

@@ -1,26 +1,36 @@
-from bcipy.language.model.causal import CausalLanguageModel
-from bcipy.core.symbols import alphabet
+from bcipy.language.model.causal import CausalLanguageModelAdapter
+from bcipy.core.symbols import DEFAULT_SYMBOL_SET
 from bcipy.language.main import ResponseType
 
 
 if __name__ == "__main__":
-    symbol_set = alphabet()
     response_type = ResponseType.SYMBOL
-    lm = CausalLanguageModel(response_type, symbol_set, lang_model_name="gpt2")
+    lm = CausalLanguageModelAdapter(response_type, DEFAULT_SYMBOL_SET, lang_model_name="figmtu/opt-350m-aac")
 
-    next_char_pred = lm.state_update(list("does_it_make_sen"))
-    print(next_char_pred)
+    print("Target sentence: does_it_make_sense\n")
+
+    next_char_pred = lm.predict(list("does_it_make_sen"))
+    print(f"Context: does_it_make_sen")
+    print(f"Predictions: {next_char_pred}")
     correct_char_rank = [c[0] for c in next_char_pred].index("S") + 1
-    print(correct_char_rank)
-    next_char_pred = lm.state_update(list("does_it_make_sens"))
-    print(next_char_pred)
+    print(f"Correct character rank: {correct_char_rank}\n")
+
+    next_char_pred = lm.predict(list("does_it_make_sens"))
+    print(f"Context: does_it_make_sens")
+    print(f"Predictions: {next_char_pred}")
     correct_char_rank = [c[0] for c in next_char_pred].index("E") + 1
-    print(correct_char_rank)
-    next_char_pred = lm.state_update(list("does_it_make_sense"))
-    print(next_char_pred)
+    print(f"Correct character rank: {correct_char_rank}\n")
+
+    next_char_pred = lm.predict(list("does_it_make_sense"))
+    print(f"Context: does_it_make_sense")
+    print(f"Predictions: {next_char_pred}")
     correct_char_rank = [c[0] for c in next_char_pred].index("_") + 1
-    print(correct_char_rank)
-    next_char_pred = lm.state_update(list("i_like_zebra"))
-    print(next_char_pred)
+    print(f"Correct character rank: {correct_char_rank}\n")
+
+    print("Target sentence: i_like_zebras\n")
+
+    next_char_pred = lm.predict(list("i_like_zebra"))
+    print(f"Context: i_like_zebra")
+    print(f"Predictions: {next_char_pred}")
     correct_char_rank = [c[0] for c in next_char_pred].index("S") + 1
-    print(correct_char_rank)
+    print(f"Correct character rank: {correct_char_rank}\n")

bcipy/language/demo/demo_mixture.py

+28 -13

@@ -1,22 +1,37 @@
-from bcipy.language.model.mixture import MixtureLanguageModel
-from bcipy.core.symbols import alphabet
+from bcipy.language.model.mixture import MixtureLanguageModelAdapter
+from bcipy.core.symbols import DEFAULT_SYMBOL_SET
 from bcipy.language.main import ResponseType
 
 
 if __name__ == "__main__":
-    symbol_set = alphabet()
     response_type = ResponseType.SYMBOL
-    lm = MixtureLanguageModel(response_type, symbol_set)
+    # Load the default mixture model from lm_params.json
+    lm = MixtureLanguageModelAdapter(response_type, DEFAULT_SYMBOL_SET)
 
-    next_char_pred = lm.state_update(list("does_it_make_sen"))
-    print(next_char_pred)
+    print("Target sentence: does_it_make_sense\n")
+
+    next_char_pred = lm.predict(list("does_it_make_sen"))
+    print(f"Context: does_it_make_sen")
+    print(f"Predictions: {next_char_pred}")
     correct_char_rank = [c[0] for c in next_char_pred].index("S") + 1
-    print(correct_char_rank)
-    next_char_pred = lm.state_update(list("does_it_make_sens"))
-    print(next_char_pred)
+    print(f"Correct character rank: {correct_char_rank}\n")
+
+    next_char_pred = lm.predict(list("does_it_make_sens"))
+    print(f"Context: does_it_make_sens")
+    print(f"Predictions: {next_char_pred}")
     correct_char_rank = [c[0] for c in next_char_pred].index("E") + 1
-    print(correct_char_rank)
-    next_char_pred = lm.state_update(list("does_it_make_sense"))
-    print(next_char_pred)
+    print(f"Correct character rank: {correct_char_rank}\n")
+
+    next_char_pred = lm.predict(list("does_it_make_sense"))
+    print(f"Context: does_it_make_sense")
+    print(f"Predictions: {next_char_pred}")
     correct_char_rank = [c[0] for c in next_char_pred].index("_") + 1
-    print(correct_char_rank)
+    print(f"Correct character rank: {correct_char_rank}\n")
+
+    print("Target sentence: i_like_zebras\n")
+
+    next_char_pred = lm.predict(list("i_like_zebra"))
+    print(f"Context: i_like_zebra")
+    print(f"Predictions: {next_char_pred}")
+    correct_char_rank = [c[0] for c in next_char_pred].index("S") + 1
+    print(f"Correct character rank: {correct_char_rank}\n")
