
LM Toolkit Refactor #381

Status: Open. Wants to merge 48 commits into base branch `2.0.0`; changes shown from 45 of the 48 commits.

Commits
737a8e2
add script for simulating lm change with different phrases, small upd…
tab-cmd Jan 3, 2025
a18a35c
add processing script, add NullDAQ
tab-cmd Jan 6, 2025
8d9fb88
Update to custom typing parameters
tab-cmd Jan 6, 2025
c68803c
Add progress bar and WIP update to average phrases across LM
tab-cmd Jan 6, 2025
c284227
update the process script to the new output format
tab-cmd Jan 8, 2025
0bdcfc7
Add final phrases
tab-cmd Jan 8, 2025
ab9244a
add plotting and stats to the processing script
tab-cmd Jan 8, 2025
1b12409
add missing params command
tab-cmd Jan 8, 2025
205d929
add more logging and custom metrics
tab-cmd Jan 9, 2025
ff6e9e2
Integration of causal model, add phrases, update processing scripts
tab-cmd Jan 9, 2025
a58b8cf
matrix processing
tab-cmd Jan 15, 2025
0dc02fe
update figure
tab-cmd Jan 24, 2025
2ac395e
move script to a group demo
tab-cmd Feb 27, 2025
a262d40
Merge remote-tracking branch 'origin/2.0.0' into BANFF_lm
tab-cmd Feb 27, 2025
9b29a51
reset devices default
tab-cmd Feb 27, 2025
32b82f6
remove retry logic from language model init
tab-cmd Feb 27, 2025
8cf25e1
reset static defaults
tab-cmd Feb 27, 2025
a43ba68
update parameters
tab-cmd Feb 27, 2025
0fb843a
remove bad test
tab-cmd Feb 27, 2025
7e5994b
lint
tab-cmd Feb 27, 2025
3ee8b9b
remove integration tests (for now) and add some info to sim README + …
tab-cmd Feb 27, 2025
a9ebcd4
drop support for 3.8
tab-cmd Feb 27, 2025
7b7e560
Added textpredict dependency, removed LM dependencies that are includ…
dcgaines Mar 5, 2025
4f827d9
Refactored main language model classes into adapters that use the new…
dcgaines Mar 5, 2025
b408af6
Renamed ngram model
dcgaines Mar 5, 2025
3461308
Updated imports
dcgaines Mar 5, 2025
5f3016f
More ngram renaming
dcgaines Mar 5, 2025
956ca98
More ngram renaming, adjusted mixture default params
dcgaines Mar 5, 2025
7d96e9a
Updated textpredict version
dcgaines Mar 5, 2025
0c4a167
Converted mixture model to adapter
dcgaines Mar 5, 2025
7401c97
Conveted oracle model into adapter
dcgaines Mar 6, 2025
63f8cd2
Deprecated InvalidLanguageModelException from bcipy in favor of aacte…
dcgaines Mar 6, 2025
c75ad28
Upgraded transformers version to address pytorch deprecation warnings
dcgaines Mar 6, 2025
a937b72
Adjusted max bump to 1, all mass on target
dcgaines Mar 6, 2025
8384e34
Updated test cases to use adapters
dcgaines Mar 6, 2025
911676b
Store bcipy symbol set alongside toolkit model symbol set
dcgaines Mar 6, 2025
6e1c4b5
Updated demos to use new adapter classes
dcgaines Mar 7, 2025
2e2b97e
Updated test class names
dcgaines Mar 7, 2025
204134f
Fixed LM class references
dcgaines Mar 7, 2025
1106682
Updated LM documentation
dcgaines Mar 7, 2025
1f337f0
Fixed misc LM class references
dcgaines Mar 7, 2025
0a421f5
Updated LM class references
dcgaines Mar 7, 2025
3151958
Merge pull request #379 from CAMBI-tech/2.0.0
dcgaines Mar 7, 2025
71bbe11
Merge branch 'lm-toolkit' into BANFF_lm
dcgaines Mar 7, 2025
b073149
Merge pull request #380 from CAMBI-tech/BANFF_lm
dcgaines Mar 7, 2025
fceca14
Merge branch '2.0.0' into lm-toolkit
tab-cmd Mar 20, 2025
f809adb
lint
tab-cmd Mar 20, 2025
de49406
Update textpredict dependency to fix 3.10
dcgaines Mar 24, 2025
12 changes: 3 additions & 9 deletions .github/workflows/main.yml
@@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.8, 3.9, 3.10.6]
python-version: [3.9, 3.10.6]

steps:
- uses: actions/checkout@v2
@@ -69,7 +69,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.8, 3.9, 3.10.6]
python-version: [3.9, 3.10.6]

steps:
- uses: actions/checkout@v2
@@ -93,9 +93,6 @@ jobs:
- name: lint
run: |
make lint
- name: integration-test
run: |
make integration-test
- name: build
run: |
make build
@@ -106,7 +103,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.8, 3.9, 3.10.6]
python-version: [3.9, 3.10.6]

steps:
- uses: actions/checkout@v4
@@ -133,9 +130,6 @@ jobs:
- name: lint
run: |
make lint
- name: integration-test
run: |
make integration-test
- name: build
run: |
make build
1 change: 0 additions & 1 deletion Makefile
@@ -13,7 +13,6 @@ test-all:
make coverage-report
make type
make lint
make integration-test

unit-test:
pytest --mpl -k "not slow"
2 changes: 1 addition & 1 deletion bcipy/core/tests/resources/mock_session/parameters.json
@@ -680,7 +680,7 @@
"recommended": [
"UNIFORM",
"CAUSAL",
"KENLM",
"NGRAM",
"MIXTURE",
"ORACLE"
],
7 changes: 0 additions & 7 deletions bcipy/exceptions.py
@@ -90,13 +90,6 @@ class TaskConfigurationException(BciPyCoreException):
...


class InvalidLanguageModelException(BciPyCoreException):
"""Invalid Language Model Exception.

Thrown when attempting to load a language model from an invalid path"""
...


class KenLMInstallationException(BciPyCoreException):
"""KenLM Installation Exception.

4 changes: 2 additions & 2 deletions bcipy/helpers/copy_phrase_wrapper.py
@@ -10,7 +10,7 @@
from bcipy.core.symbols import BACKSPACE_CHAR
from bcipy.exceptions import BciPyCoreException
from bcipy.helpers.language_model import histogram, with_min_prob
from bcipy.language.main import LanguageModel
from bcipy.language.main import LanguageModelAdapter
from bcipy.task.control.criteria import (CriteriaEvaluator,
MaxIterationsCriteria,
MinIterationsCriteria,
@@ -58,7 +58,7 @@ class CopyPhraseWrapper:
def __init__(self,
min_num_inq: int,
max_num_inq: int,
lmodel: LanguageModel,
lmodel: LanguageModelAdapter,
alp: List[str],
evidence_names: List[EvidenceType] = [
EvidenceType.LM, EvidenceType.ERP
32 changes: 17 additions & 15 deletions bcipy/helpers/language_model.py
@@ -6,29 +6,29 @@
import numpy as np

from bcipy.core.symbols import alphabet
from bcipy.language.main import LanguageModel, ResponseType
from bcipy.language.main import LanguageModelAdapter, ResponseType

# pylint: disable=unused-import
# flake8: noqa

"""Only imported models will be included in language_models_by_name"""
# flake8: noqa
from bcipy.exceptions import InvalidLanguageModelException
from bcipy.language.model.causal import CausalLanguageModel
from bcipy.language.model.kenlm import KenLMLanguageModel
from bcipy.language.model.mixture import MixtureLanguageModel
from bcipy.language.model.oracle import OracleLanguageModel
from bcipy.language.model.uniform import UniformLanguageModel
from bcipy.language.model.causal import CausalLanguageModelAdapter
from bcipy.language.model.ngram import NGramLanguageModelAdapter
from bcipy.language.model.mixture import MixtureLanguageModelAdapter
from bcipy.language.model.oracle import OracleLanguageModelAdapter
from bcipy.language.model.uniform import UniformLanguageModelAdapter


def language_models_by_name() -> Dict[str, LanguageModel]:
def language_models_by_name() -> Dict[str, LanguageModelAdapter]:
"""Returns available language models indexed by name."""
return {lm.name(): lm for lm in LanguageModel.__subclasses__()}
return {lm.name(): lm for lm in LanguageModelAdapter.__subclasses__()}


def init_language_model(parameters: dict) -> LanguageModel:
def init_language_model(parameters: dict) -> LanguageModelAdapter:
"""
Init Language Model configured in the parameters.
Init Language Model configured in the parameters. If no language model is
specified, a uniform language model is returned.

Parameters
----------
@@ -37,7 +37,7 @@ def init_language_model(parameters: dict) -> LanguageModel:

Returns
-------
instance of a LanguageModel
instance of a LanguageModelAdapter
"""

language_models = language_models_by_name()
@@ -48,9 +48,11 @@

# select the relevant parameters into a dict.
params = {key: parameters[key] for key in args & parameters.keys()}
return model(response_type=ResponseType.SYMBOL,
symbol_set=alphabet(parameters),
**params)

return model(
response_type=ResponseType.SYMBOL,
symbol_set=alphabet(parameters),
**params)


def norm_domain(priors: List[Tuple[str, float]]) -> List[Tuple[str, float]]:
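The hunks above preserve two useful patterns: `language_models_by_name()` discovers adapters purely through `LanguageModelAdapter.__subclasses__()`, and `init_language_model()` forwards only the constructor arguments a given adapter declares. Below is a minimal, self-contained sketch of both patterns; the class names, the `init_model` helper, and the `lang_model_type` parameter key are illustrative stand-ins, not the exact BciPy names.

```python
import inspect
from typing import Dict, List, Optional, Type


class LanguageModelAdapter:
    """Stand-in for bcipy.language.main.LanguageModelAdapter (illustrative)."""

    @classmethod
    def name(cls) -> str:
        raise NotImplementedError


class UniformAdapter(LanguageModelAdapter):
    """Hypothetical adapter; defining (or importing) it registers it below."""

    def __init__(self, symbol_set: Optional[List[str]] = None):
        self.symbol_set = symbol_set

    @classmethod
    def name(cls) -> str:
        return "UNIFORM"


def models_by_name() -> Dict[str, Type[LanguageModelAdapter]]:
    # __subclasses__() reflects every adapter subclass that has been imported.
    return {m.name(): m for m in LanguageModelAdapter.__subclasses__()}


def init_model(parameters: dict) -> LanguageModelAdapter:
    model = models_by_name()[parameters["lang_model_type"]]  # key name assumed
    # Forward only the kwargs this adapter's constructor actually declares.
    args = set(inspect.signature(model).parameters)
    return model(**{key: parameters[key] for key in args & parameters.keys()})


lm = init_model({"lang_model_type": "UNIFORM", "symbol_set": list("ABC_")})
```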
12 changes: 5 additions & 7 deletions bcipy/helpers/tests/test_copy_phrase_wrapper.py
@@ -1,21 +1,20 @@
import unittest

from bcipy.helpers.copy_phrase_wrapper import CopyPhraseWrapper
from bcipy.core.symbols import alphabet
from bcipy.language.model.uniform import UniformLanguageModel
from bcipy.core.symbols import DEFAULT_SYMBOL_SET
from bcipy.language.model.uniform import UniformLanguageModelAdapter
from bcipy.task.data import EvidenceType


class TestCopyPhraseWrapper(unittest.TestCase):
"""Test CopyPhraseWrapper"""

def test_valid_letters(self):
alp = alphabet()
cp = CopyPhraseWrapper(
min_num_inq=1,
max_num_inq=50,
lmodel=None,
alp=alp,
alp=DEFAULT_SYMBOL_SET,
task_list=[("HELLO_WORLD", "HE")],
is_txt_stim=True,
evidence_names=[EvidenceType.LM, EvidenceType.ERP],
@@ -104,13 +103,12 @@ def test_valid_letters(self):
["nontarget", "nontarget"])

def test_init_series(self):
alp = alphabet()

copy_phrase_task = CopyPhraseWrapper(
min_num_inq=1,
max_num_inq=50,
lmodel=UniformLanguageModel(symbol_set=alp),
alp=alp,
lmodel=UniformLanguageModelAdapter(symbol_set=DEFAULT_SYMBOL_SET),
alp=DEFAULT_SYMBOL_SET,
task_list=[("HELLO_WORLD", "HE")],
is_txt_stim=True,
evidence_names=[EvidenceType.LM, EvidenceType.ERP],
20 changes: 9 additions & 11 deletions bcipy/language/README.md
@@ -1,15 +1,13 @@
# Language

BciPy Language module provides an interface for word and character level predictions.
The BciPy Language module provides an interface for word and character level predictions. This module relies primarily on the AAC-TextPredict package (`aactextpredict` on PyPI) for its probability calculations. More information on this package can be found on our [GitHub repo](https://github.com/kdv123/textpredict).

The core methods of any `LanguageModel` include:
The core methods of any `LanguageModelAdapter` include:

> `predict` - given typing evidence input, return a prediction (character or word).

> `load` - load a pre-trained model given a path (currently BciPy does not support training language models!)

> `update` - update internal state of your model.

You may of course define other methods; however, all integrated BciPy experiments using your model require the three methods above to be defined. A rough sketch of this interface follows.
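Here is a minimal uniform adapter illustrating those three methods; the class name and exact signatures are assumptions made for the sketch, not the real base-class API.

```python
from typing import List, Optional, Tuple


class ToyUniformAdapter:
    """Illustrative only: the three core methods described above."""

    def __init__(self, symbol_set: List[str]):
        self.symbol_set = symbol_set

    def predict(self, evidence: List[str]) -> List[Tuple[str, float]]:
        # Given typing evidence, return a next-symbol distribution.
        p = 1.0 / len(self.symbol_set)
        return [(symbol, p) for symbol in self.symbol_set]

    def load(self, path: Optional[str] = None) -> None:
        # A pretrained model would be loaded from `path` here.
        pass

    def update(self) -> None:
        # Internal model state would be refreshed here.
        pass
```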

The language module has the following structure:
@@ -18,7 +16,7 @@ The language module has the following structure:

> `lms` - The default location for the model resources.

> `model` - The python classes for each LanguageModel subclass. Detailed descriptions of each can be found below.
> `model` - The python classes for each LanguageModelAdapter subclass. Detailed descriptions of each can be found below.

> `sets` - Different phrase sets that can be used to test the language model classes.

@@ -28,22 +26,22 @@ The language module has the following structure:

## Uniform Model

The UniformLanguageModel provides equal probabilities for all symbols in the symbol set. This model is useful for evaluating other aspects of the system, such as EEG signal quality, without any influence from a language model.
The UniformLanguageModelAdapter provides equal probabilities for all symbols in the symbol set. This model is useful for evaluating other aspects of the system, such as EEG signal quality, without any influence from a language model.

## KenLM Model
The KenLMLanguageModel utilizes a pretrained n-gram language model to generate probabilities for all symbols in the symbol set. N-gram models use frequencies of different character sequences to generate their predictions. Models trained on AAC-like data can be found [here](https://imagineville.org/software/lm/dec19_char/). For faster load times, it is recommended to use the binary models located at the bottom of the page. The default parameters file utilizes `lm_dec19_char_large_12gram.kenlm`. If you have issues accessing, please reach out to us on GitHub or via email at `cambi_support@googlegroups.com`.
## NGram Model
The NGramLanguageModelAdapter utilizes a pretrained n-gram language model to generate probabilities for all symbols in the symbol set. N-gram models use frequencies of different character sequences to generate their predictions. Models trained on AAC-like data can be found [here](https://imagineville.org/software/lm/dec19_char/). For faster load times, it is recommended to use the binary models located at the bottom of the page. The default parameters file utilizes `lm_dec19_char_large_12gram.kenlm`. If you have issues accessing the models, please reach out to us on GitHub or via email at `cambi_support@googlegroups.com`.

For models that import the kenlm module, the kenlm package must be installed manually using `pip install kenlm==0.1 --global-option="max_order=12"`.
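As a hedged sketch of querying such a character n-gram model directly (the model path and the `<sp>` space sentinel are assumptions; the BciPy adapter wraps these details for you):

```python
import kenlm  # installed via the pip command above

# Path is illustrative; point this at a downloaded binary model.
model = kenlm.Model("lm_dec19_char_large_12gram.kenlm")

# Character LMs are typically queried as space-separated symbols, with a
# sentinel token (assumed "<sp>" here) standing in for the space character.
context = "d o e s <sp> i t <sp> m a k e <sp> s e n"
log10_prob = model.score(context, bos=True, eos=False)  # log10 probability
print(log10_prob)
```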

## Causal Model
The CausalLanguageModel class can use any causal language model from Huggingface, though it has only been tested with gpt2, facebook/opt, and distilgpt2 families of models. Causal language models predict the next token in a sequence of tokens. For the many of these models, byte-pair encoding (BPE) is used for tokenization. The main idea of BPE is to create a fixed-size vocabulary that contains common English subword units. Then a less common word would be broken down into several subword units in the vocabulary. For example, the tokenization of character sequence `peanut_butter_and_jel` would be:
The CausalLanguageModelAdapter class can use any causal language model from Huggingface, though it has only been tested with the gpt2, facebook/opt, and distilgpt2 families of models (including the domain-adapted figmtu/opt-350m-aac). Causal language models predict the next token in a sequence of tokens. For many of these models, byte-pair encoding (BPE) is used for tokenization. The main idea of BPE is to create a fixed-size vocabulary that contains common English subword units. A less common word is then broken down into several subword units in the vocabulary. For example, the tokenization of the character sequence `peanut_butter_and_jel` would be:
> *['pe', 'anut', '_butter', '_and', '_j', 'el']*

Therefore, in order to generate a predictive distribution on the next character, we need to examine all the possibilities that could complete the final subword tokens in the input sequences. We must remove at least one token from the end of the context to allow the model the option of extending it, as opposed to only adding a new token. Removing more tokens allows the model more flexibility and may lead to better predictions, but at the cost of a higher prediction time. In this model we remove all of the subword tokens in the current (partially-typed) word to allow it the most flexibility. We then ask the model to estimate the likelihood of the next token and evaluate each token that matches our context. For efficiency, we only track a certain number of hypotheses at a time, known as the beam width, and each hypothesis until it surpasses the context. We can then store the likelihood for each final prediction in a list based on the character that directly follows the context. Once we have no more hypotheses to extend, we can sum the likelihoods stored for each character in our symbol set and normalize so they sum to 1, giving us our final distribution.
Therefore, in order to generate a predictive distribution on the next character, we need to examine all the possibilities that could complete the final subword tokens in the input sequence. We must remove at least one token from the end of the context to allow the model the option of extending it, as opposed to only adding a new token. Removing more tokens allows the model more flexibility and may lead to better predictions, but at the cost of a higher prediction time. In this model we remove all of the subword tokens in the current (partially typed) word to allow it the most flexibility. We then ask the model to estimate the likelihood of the next token and evaluate each token that matches our context. For efficiency, we only track a certain number of hypotheses at a time, known as the beam width, and extend each hypothesis until it surpasses the context. We can then store the likelihood for each final prediction in a list based on the character that directly follows the context. Once we have no more hypotheses to extend, we sum the likelihoods stored for each character in our symbol set and normalize so they sum to 1, giving us our final distribution. More details on this process can be found in our paper, [Adapting Large Language Models for Character-based Augmentative and Alternative Communication](https://arxiv.org/abs/2501.10582).
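The sketch below shows only the final marginalization step in a simplified, single-step form (the real adapter maintains multi-token hypotheses with a beam); all names here are illustrative.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def next_char_distribution(token_probs: List[Tuple[str, float]],
                           typed_prefix: str,
                           symbol_set: List[str]) -> Dict[str, float]:
    """Fold next-token probabilities into a next-character distribution.

    token_probs: (token, probability) pairs the causal LM proposes for the
    context with the partially typed word removed; typed_prefix is that
    partially typed word, which surviving tokens must agree with.
    """
    mass: Dict[str, float] = defaultdict(float)
    for token, prob in token_probs:
        # Keep only tokens consistent with what the user already typed.
        if token.startswith(typed_prefix) and len(token) > len(typed_prefix):
            next_char = token[len(typed_prefix)].upper()
            if next_char in symbol_set:
                mass[next_char] += prob
    total = sum(mass.values())
    return {c: p / total for c, p in mass.items()} if total else {}


dist = next_char_distribution(
    [("sense", 0.6), ("sent", 0.3), ("sending", 0.1)],
    typed_prefix="sen",
    symbol_set=list("ABCDEFGHIJKLMNOPQRSTUVWXYZ_"))
# -> {'S': 0.6, 'T': 0.3, 'D': 0.1} after normalization
```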


## Mixture Model
The MixtureLanguageModel class allows for the combination of two or more supported models. The selected models are mixed according to the provided weights, which can be tuned using the Bcipy/scripts/python/mixture_tuning.py script. It is not recommended to use more than one "heavy-weight" model with long prediction times (the CausalLanguageModel) since this model will query each component model and parallelization is not currently supported.
The MixtureLanguageModelAdapter class allows for the combination of two or more supported models. The selected models are mixed according to the provided weights, which can be tuned using the Bcipy/scripts/python/mixture_tuning.py script. It is not recommended to use more than one "heavy-weight" model with long prediction times (such as the CausalLanguageModelAdapter), since the mixture model queries each component model in turn and parallelization is not currently supported.
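The mixing itself amounts to a weighted linear interpolation of the component distributions, roughly as in this sketch (the combination rule and the names are assumptions for illustration):

```python
from typing import Dict, List


def mix_distributions(dists: List[Dict[str, float]],
                      weights: List[float]) -> Dict[str, float]:
    """Weighted linear interpolation of per-character distributions."""
    symbols = set().union(*dists)
    mixed = {s: sum(w * d.get(s, 0.0) for d, w in zip(dists, weights))
             for s in symbols}
    total = sum(mixed.values())  # renormalize in case weights don't sum to 1
    return {s: p / total for s, p in mixed.items()}


ngram = {"S": 0.7, "T": 0.3}
causal = {"S": 0.5, "T": 0.4, "D": 0.1}
print(mix_distributions([ngram, causal], weights=[0.4, 0.6]))
```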

# Contact Information

4 changes: 2 additions & 2 deletions bcipy/language/__init__.py
@@ -1,6 +1,6 @@
from .main import LanguageModel, ResponseType
from .main import LanguageModelAdapter, ResponseType

__all__ = [
"LanguageModel",
"LanguageModelAdapter",
"ResponseType",
]
42 changes: 26 additions & 16 deletions bcipy/language/demo/demo_causal.py
@@ -1,26 +1,36 @@
from bcipy.language.model.causal import CausalLanguageModel
from bcipy.core.symbols import alphabet
from bcipy.language.model.causal import CausalLanguageModelAdapter
from bcipy.core.symbols import DEFAULT_SYMBOL_SET
from bcipy.language.main import ResponseType


if __name__ == "__main__":
symbol_set = alphabet()
response_type = ResponseType.SYMBOL
lm = CausalLanguageModel(response_type, symbol_set, lang_model_name="gpt2")
lm = CausalLanguageModelAdapter(response_type, DEFAULT_SYMBOL_SET, lang_model_name="figmtu/opt-350m-aac")

next_char_pred = lm.state_update(list("does_it_make_sen"))
print(next_char_pred)
print("Target sentence: does_it_make_sense\n")

next_char_pred = lm.predict(list("does_it_make_sen"))
print(f"Context: does_it_make_sen")
print(f"Predictions: {next_char_pred}")
correct_char_rank = [c[0] for c in next_char_pred].index("S") + 1
print(correct_char_rank)
next_char_pred = lm.state_update(list("does_it_make_sens"))
print(next_char_pred)
print(f"Correct character rank: {correct_char_rank}\n")

next_char_pred = lm.predict(list("does_it_make_sens"))
print(f"Context: does_it_make_sens")
print(f"Predictions: {next_char_pred}")
correct_char_rank = [c[0] for c in next_char_pred].index("E") + 1
print(correct_char_rank)
next_char_pred = lm.state_update(list("does_it_make_sense"))
print(next_char_pred)
print(f"Correct character rank: {correct_char_rank}\n")

next_char_pred = lm.predict(list("does_it_make_sense"))
print(f"Context: does_it_make_sense")
print(f"Predictions: {next_char_pred}")
correct_char_rank = [c[0] for c in next_char_pred].index("_") + 1
print(correct_char_rank)
next_char_pred = lm.state_update(list("i_like_zebra"))
print(next_char_pred)
print(f"Correct character rank: {correct_char_rank}\n")

print("Target sentence: i_like_zebras\n")

next_char_pred = lm.predict(list("i_like_zebra"))
print(f"Context: i_like_zebra")
print(f"Predictions: {next_char_pred}")
correct_char_rank = [c[0] for c in next_char_pred].index("S") + 1
print(correct_char_rank)
print(f"Correct character rank: {correct_char_rank}\n")
41 changes: 28 additions & 13 deletions bcipy/language/demo/demo_mixture.py
@@ -1,22 +1,37 @@
from bcipy.language.model.mixture import MixtureLanguageModel
from bcipy.core.symbols import alphabet
from bcipy.language.model.mixture import MixtureLanguageModelAdapter
from bcipy.core.symbols import DEFAULT_SYMBOL_SET
from bcipy.language.main import ResponseType


if __name__ == "__main__":
symbol_set = alphabet()
response_type = ResponseType.SYMBOL
lm = MixtureLanguageModel(response_type, symbol_set)
# Load the default mixture model from lm_params.json
lm = MixtureLanguageModelAdapter(response_type, DEFAULT_SYMBOL_SET)

next_char_pred = lm.state_update(list("does_it_make_sen"))
print(next_char_pred)
print("Target sentence: does_it_make_sense\n")

next_char_pred = lm.predict(list("does_it_make_sen"))
print(f"Context: does_it_make_sen")
print(f"Predictions: {next_char_pred}")
correct_char_rank = [c[0] for c in next_char_pred].index("S") + 1
print(correct_char_rank)
next_char_pred = lm.state_update(list("does_it_make_sens"))
print(next_char_pred)
print(f"Correct character rank: {correct_char_rank}\n")

next_char_pred = lm.predict(list("does_it_make_sens"))
print(f"Context: does_it_make_sens")
print(f"Predictions: {next_char_pred}")
correct_char_rank = [c[0] for c in next_char_pred].index("E") + 1
print(correct_char_rank)
next_char_pred = lm.state_update(list("does_it_make_sense"))
print(next_char_pred)
print(f"Correct character rank: {correct_char_rank}\n")

next_char_pred = lm.predict(list("does_it_make_sense"))
print(f"Context: does_it_make_sense")
print(f"Predictions: {next_char_pred}")
correct_char_rank = [c[0] for c in next_char_pred].index("_") + 1
print(correct_char_rank)
print(f"Correct character rank: {correct_char_rank}\n")

print("Target sentence: i_like_zebras\n")

next_char_pred = lm.predict(list("i_like_zebra"))
print(f"Context: i_like_zebra")
print(f"Predictions: {next_char_pred}")
correct_char_rank = [c[0] for c in next_char_pred].index("S") + 1
print(f"Correct character rank: {correct_char_rank}\n")