Offline analysis updates for multimodal fusion. #360

Merged
merged 26 commits on Nov 5, 2024
Changes from 19 commits
Commits
26 commits
a7a02a8
Offline analysis updates, bug fixes
celikbasak Dec 1, 2023
2bc15db
Changelog updated
celikbasak Dec 1, 2023
bda7072
plotting updates v2
celikbasak Mar 11, 2024
03a07b2
Changes to offline analysis and gaze model, GazeEvaluator added
celikbasak May 23, 2024
88951a4
add to the ignore
tab-cmd May 23, 2024
c7d20c4
Merge branch '2.0.0rc4' into matrix_copyphrase
tab-cmd May 23, 2024
c07590b
Update to Inquiry Reshaper, comply with signal base model
tab-cmd May 23, 2024
99ddd36
Updates on gaze models - will be refactored soon
celikbasak Jul 8, 2024
a6d1a31
GP model added
celikbasak Aug 11, 2024
b04a523
changes in offline analysis
celikbasak Aug 16, 2024
040f552
updates WIP
celikbasak Aug 22, 2024
cc3933a
more updates to the model
celikbasak Aug 27, 2024
14ae79c
updates on GP
celikbasak Sep 3, 2024
477d930
updates WIP
celikbasak Sep 11, 2024
0d8dc14
changes Oct 28
celikbasak Oct 28, 2024
a497d96
Merge remote-tracking branch 'origin/2.0.0rc4' into aaai_submission_t…
tab-cmd Oct 28, 2024
77e3f07
Offline analysis updated to train and save each model separately. Pat…
celikbasak Nov 4, 2024
b0d11b1
All csv files containing results are removed
celikbasak Nov 4, 2024
3e10534
Commented out codes and TODOs removed.
celikbasak Nov 4, 2024
b693180
PR reviews resolved. Deleted unused scripts and comments
celikbasak Nov 5, 2024
18047a9
Setup.py restored, Changelog updated
celikbasak Nov 5, 2024
1f0157f
Merge branch '2.0.0rc4' into aaai_submission_tbd_later
celikbasak Nov 5, 2024
6fb6eb0
Update tests, typing
tab-cmd Nov 5, 2024
2a116a0
lint, fix remaining tests
tab-cmd Nov 5, 2024
4b457e9
remove used fusion/eval module
tab-cmd Nov 5, 2024
76cf4e7
Add some documentation and cleanup unused variable assignment
tab-cmd Nov 5, 2024
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -38,3 +38,6 @@ bcipy/language/out/

bcipy/simulator/tests/resource/
!bcipy/simulator/data


*.csv
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ Our final release candidate before the official 2.0 release!
- New task protocol for orchestrating tasks in a session. This refactors several Task and Cli functionality #339
- Model
- Offline analysis to support multimodal fusion. Initial release of GazeModel, GazeReshaper, and Gaze Visualization #294
- Updates to ensure seamless offline analysis for both EEG and Gaze data #305
- Stimuli
- Updates to ensure stimuli are presented at the same frequency #287 Output stimuli position, screen capture and monitor information after Matrix tasks #303
- Dynamic Selection Window
7 changes: 2 additions & 5 deletions bcipy/helpers/demo/demo_visualization.py
@@ -21,7 +21,7 @@
from bcipy.helpers.load import (load_experimental_data, load_json_parameters,
load_raw_data)
from bcipy.helpers.triggers import TriggerType, trigger_decoder
from bcipy.helpers.visualization import visualize_erp
from bcipy.helpers.visualization import visualize_gaze
from bcipy.signal.process import get_default_transform

if __name__ == '__main__':
@@ -91,15 +91,12 @@

save_path = None if not args.save else path

figure_handles = visualize_erp(
figure_handles = visualize_gaze(
raw_data,
channel_map,
trigger_timing,
labels,
trial_window,
transform=default_transform,
plot_average=True,
plot_topomaps=True,
save_path=save_path,
show=args.show
)
95 changes: 47 additions & 48 deletions bcipy/helpers/stimuli.py
@@ -198,7 +198,7 @@ def extract_trials(
prestimulus_samples: int = 0) -> np.ndarray:
"""Extract Trials.

After using the InquiryReshaper, it may be necessary to further trial the data for processing.
After using the InquiryReshaper, it may be necessary to further extract the trials for processing.
Using the number of samples and inquiry timing, the data is reshaped from Channels, Inquiry, Samples to
Channels, Trials, Samples. These should match with the trials extracted from the TrialReshaper given the same
slicing parameters.
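The reshaping described in the docstring above (Channels, Inquiry, Samples → Channels, Trials, Samples) can be sketched in NumPy. The function below is a simplified, hypothetical stand-in for `InquiryReshaper.extract_trials`, not the BciPy implementation itself; the name and the per-inquiry `inquiry_timing` layout are assumptions for illustration:

```python
import numpy as np

def extract_trials_sketch(inquiries, samples_per_trial, inquiry_timing,
                          prestimulus_samples=0):
    """Slice fixed-length trials out of inquiry-shaped data.

    inquiries: array of shape (channels, n_inquiries, samples)
    inquiry_timing: per-inquiry lists of trial onset sample indices
    Returns an array of shape (channels, n_trials, trial_samples).
    """
    trials = []
    for inq_idx, onsets in enumerate(inquiry_timing):
        for onset in onsets:
            start = onset - prestimulus_samples
            stop = onset + samples_per_trial
            trials.append(inquiries[:, inq_idx, start:stop])
    # Stack along a new trial axis: (channels, trials, samples)
    return np.stack(trials, axis=1)
```

Given the same slicing parameters, trials produced this way should line up with those from the TrialReshaper, as the docstring notes.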
@@ -245,38 +245,46 @@ def __call__(self,
target_symbols: List[str],
gaze_data: np.ndarray,
sample_rate: int,
stimulus_duration: float,
num_stimuli_per_inquiry: int,
symbol_set: List[str],
channel_map: Optional[List[int]] = None,
) -> dict:
"""Extract inquiry data and labels. Different from the EEG inquiry, the gaze inquiry window starts with
the first flicker and ends with the last flicker in the inquiry. Each inquiry has a length of ~3 seconds.
The labels are provided in the target_symbols list. It returns a Dict, where keys are the target symbols and
the values are inquiries (appended in order of appearance) where the corresponding target symbol is prompted.
) -> Tuple[dict, List[float], List[str]]:
"""Extract gaze trajectory data and labels.

Different from the EEG, gaze inquiry windows start with the first highlighted symbol and end with the
last highlighted symbol in the inquiry. Each inquiry has a length of (trial duration x num of trials)
seconds. Labels are provided in 'target_symbols'. It returns a Dict, where keys are the target symbols
and the values are inquiries (appended in order of appearance) where the corresponding target symbol is
prompted.

Optional outputs:
reshape_data is the list of data reshaped into (Inquiries, Channels, Samples), where inquiries are appended
in chronological order. labels returns the list of target symbols in each inquiry.
reshape_data is the list of data reshaped into (Inquiries, Channels, Samples), where inquiries are
appended in chronological order.
labels returns the list of target symbols in each inquiry.

Args:
Parameters
----------
inq_start_times (List[float]): Timestamp of each event in seconds
target_symbols (List[str]): Prompted symbol in each inquiry
gaze_data (np.ndarray): shape (channels, samples) eye tracking data
sample_rate (int): sample rate of data provided in eeg_data
sample_rate (int): sample rate of eye tracker data
stimulus_duration (float): duration of flash time (in seconds) for each trial
num_stimuli_per_inquiry (int): number of stimuli in each inquiry (default: 10)
symbol_set (List[str]): list of all symbols for the task
channel_map (List[int], optional): Describes which channels to include or discard.
Defaults to None; all channels will be used.

Returns:
data_by_targets (dict): Dictionary where keys are the symbol set and values are the appended inquiries
for each symbol. dict[Key] = (np.ndarray) of shape (Channels, Samples)

Returns
-------
data_by_targets (dict): Dictionary where keys are the symbols in the symbol set and values are
the appended inquiries for each symbol. dict[Key] = (np.ndarray) of shape (Channels, Samples)

reshaped_data (List[np.ndarray]) [optional]: inquiry data of shape (Inquiries, Channels, Samples)
labels (List[str]) [optional] : Target symbol in each inquiry.
"""
if channel_map:
# Remove the channels that we are not interested in
channels_to_remove = [idx for idx, value in enumerate(channel_map) if value == 0]
gaze_data = np.delete(gaze_data, channels_to_remove, axis=0)

# Find the value closest to (& greater than) inq_start_times
# Find the timestamp value closest to (& greater than) inq_start_times.
# LSL timestamps are the last row in the gaze_data
gaze_data_timing = gaze_data[-1, :].tolist()

start_times = []
@@ -294,54 +302,45 @@ def __call__(self,

# Create a dictionary with symbols as keys and data as values
# 'A': [], 'B': [] ...
data_by_targets: Dict[str, list] = {}
data_by_targets_dict: Dict[str, list] = {}
for symbol in symbol_set:
data_by_targets[symbol] = []
data_by_targets_dict[symbol] = []

window_length = 3 # seconds, total length of flickering after prompt for each inquiry
buffer = stimulus_duration / 5 # seconds, buffer for each inquiry
# NOTE: This buffer is used to account for the screen downtime between each stimulus.
# There is a "duty cycle" of 80% for the stimuli, so we add a buffer of 20% of the stimulus length
window_length = (stimulus_duration + buffer) * num_stimuli_per_inquiry # in seconds

reshaped_data = []
# Merge the inquiries if they have the same target letter:
for i, inquiry_index in enumerate(triggers):
start = inquiry_index
stop = int(inquiry_index + (sample_rate * window_length)) # (60 samples * 3 seconds)
stop = int(inquiry_index + (sample_rate * window_length))
# Check if the data exists for the inquiry:
if stop > len(gaze_data[0, :]):
continue

reshaped_data.append(gaze_data[:, start:stop])
# (Optional) extracted data (Inquiries x Channels x Samples)
reshaped_data.append(gaze_data[:, start:stop])

# Populate the dict by appending the inquiry to the correct key:
data_by_targets[labels[i]].append(gaze_data[:, start:stop])

# After populating, flatten the arrays in the dictionary to (Channels x Samples):
for symbol in symbol_set:
if len(data_by_targets[symbol]) > 0:
data_by_targets[symbol] = np.transpose(np.array(data_by_targets[symbol]), (1, 0, 2))
data_by_targets[symbol] = np.reshape(data_by_targets[symbol], (len(data_by_targets[symbol]), -1))

# Note that this is a workaround to the issue of having different number of targetness in
# each symbol. If a target symbol is prompted more than once, the data is appended to the dict as a list.
# Which is why we need to convert it to a (np.ndarray) and flatten the dimensions.
# This is not ideal, but it works for now.

# return np.stack(reshaped_data, 0), labels
return data_by_targets
# Populate the dict by appending the inquiry to the corresponding key:
data_by_targets_dict[labels[i]].append(gaze_data[:, start:stop])

return data_by_targets_dict, reshaped_data, labels
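The new window computation above replaces the old hard-coded 3-second window with one derived from the stimulus timing and the 80% duty cycle. A minimal standalone sketch of that slicing logic follows; the function name and setup are illustrative, not the actual `GazeReshaper` API:

```python
import numpy as np

def gaze_inquiry_windows(gaze_data, start_indices, sample_rate,
                         stimulus_duration, num_stimuli_per_inquiry):
    """Slice fixed-length inquiry windows from continuous gaze data.

    gaze_data: (channels, samples); start_indices: sample index of each
    inquiry's first highlighted symbol. The buffer accounts for screen
    downtime between stimuli (20% of the flash time, per the 80% duty
    cycle noted in the diff).
    """
    buffer = stimulus_duration / 5
    window_length = (stimulus_duration + buffer) * num_stimuli_per_inquiry
    window_samples = int(sample_rate * window_length)

    windows = []
    for start in start_indices:
        stop = start + window_samples
        if stop > gaze_data.shape[1]:  # skip inquiries cut off by end of recording
            continue
        windows.append(gaze_data[:, start:stop])
    return windows
```

For example, at a 60 Hz eye tracker with a 0.25 s flash and 10 stimuli per inquiry, each window spans (0.25 + 0.05) × 10 = 3 seconds, i.e. 180 samples, matching the old hard-coded value.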

@staticmethod
def centralize_all_data(data, symbol_pos):
def centralize_all_data(data: np.ndarray, symbol_pos: np.ndarray) -> np.ndarray:
""" Using the symbol locations in the matrix, centralize all data (in Tobii units).
This data will only be used in certain model types.
Args:
data (np.ndarray): Data in shape of num_channels x num_samples
data (np.ndarray): Data in shape of num_samples x num_dimensions
symbol_pos (np.ndarray(float)): Array of the current symbol position in Tobii units
Returns:
data (np.ndarray): Centralized data in shape of num_channels x num_samples
new_data (np.ndarray): Centralized data in shape of num_samples x num_dimensions
"""
new_data = np.copy(data)
for i in range(len(data)):
data[i] = data[i] - symbol_pos
return data
new_data[i] = data[i] - symbol_pos

return new_data
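The fix above makes `centralize_all_data` non-mutating: the original method wrote the shifted values back into the caller's array, while the new version copies first. A quick self-contained check of that behavior (the gaze samples and symbol position are made-up values):

```python
import numpy as np

def centralize_all_data(data, symbol_pos):
    """Shift each gaze sample so the prompted symbol sits at the origin,
    leaving the caller's array untouched (mirrors the fixed method)."""
    new_data = np.copy(data)
    for i in range(len(data)):
        new_data[i] = data[i] - symbol_pos
    return new_data

# Hypothetical samples (num_samples x num_dimensions, Tobii units)
data = np.array([[0.6, 0.4], [0.7, 0.5]])
centered = centralize_all_data(data, np.array([0.5, 0.5]))
```

After the call, `centered` holds the shifted samples and `data` is unchanged, which matters because the same raw gaze array may be reused for other model types.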


class TrialReshaper(Reshaper):