Skip to content

Commit a56e1d8

Browse files
authored
Offline analysis updates for multimodal fusion. (#360)
1 parent 171c706 commit a56e1d8

20 files changed

+762
-421
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,4 @@ bcipy/language/out/
3737

3838

3939
bcipy/simulator/tests/resource/
40-
!bcipy/simulator/data
40+
!bcipy/simulator/data

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,14 @@ Our final release candidate before the official 2.0 release!
3434
- session data to VEP calibration #322
3535
- Model
3636
- Offline analysis to support multimodal fusion. Initial release of GazeModel, GazeReshaper, and Gaze Visualization #294
37+
- Updates to ensure seamless offline analysis for both EEG and Gaze data #305
38+
- Offline analysis support for EEG and (multiple) gaze models. Updates to support Eye Tracker Evidence class #360
3739
- Language Model
3840
- Add Oracle model #316
3941
- Random Uniform model #311
4042
- Stimuli
4143
- Updates to ensure stimuli are presented at the same frequency #287
44+
- Output stimuli position, screen capture and monitor information after Matrix tasks #303
4245
- Dynamic Selection Window
4346
- Updated trial_length to trial_window to allow for greater control of window used after stimulus presentations #291
4447
- Report

bcipy/helpers/demo/demo_visualization.py

-3
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,6 @@
9797
trigger_timing,
9898
labels,
9999
trial_window,
100-
transform=default_transform,
101-
plot_average=True,
102-
plot_topomaps=True,
103100
save_path=save_path,
104101
show=args.show
105102
)

bcipy/helpers/stimuli.py

+47-47
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from bcipy.config import DEFAULT_FIXATION_PATH, DEFAULT_TEXT_FIXATION, SESSION_LOG_FILENAME
2525
from bcipy.exceptions import BciPyCoreException
2626
from bcipy.helpers.list import grouper
27+
from bcipy.helpers.symbols import alphabet
2728

2829
# Prevents pillow from filling the console with debug info
2930
logging.getLogger('PIL').setLevel(logging.WARNING)
@@ -198,7 +199,7 @@ def extract_trials(
198199
prestimulus_samples: int = 0) -> np.ndarray:
199200
"""Extract Trials.
200201
201-
After using the InquiryReshaper, it may be necessary to further trial the data for processing.
202+
After using the InquiryReshaper, it may be necessary to further extract the trials for processing.
202203
Using the number of samples and inquiry timing, the data is reshaped from Channels, Inquiry, Samples to
203204
Channels, Trials, Samples. These should match with the trials extracted from the TrialReshaper given the same
204205
slicing parameters.
@@ -245,38 +246,46 @@ def __call__(self,
245246
target_symbols: List[str],
246247
gaze_data: np.ndarray,
247248
sample_rate: int,
248-
symbol_set: List[str],
249+
stimulus_duration: float,
250+
num_stimuli_per_inquiry: int,
251+
symbol_set: List[str] = alphabet(),
249252
channel_map: Optional[List[int]] = None,
250-
) -> dict:
251-
"""Extract inquiry data and labels. Different from the EEG inquiry, the gaze inquiry window starts with
252-
the first flicker and ends with the last flicker in the inquiry. Each inquiry has a length of ~3 seconds.
253-
The labels are provided in the target_symbols list. It returns a Dict, where keys are the target symbols and
254-
the values are inquiries (appended in order of appearance) where the corresponding target symbol is prompted.
253+
) -> Tuple[dict, list, List[str]]:
254+
"""Extract gaze trajectory data and labels.
255+
256+
Different from the EEG, gaze inquiry windows start with the first highlighted symbol and end with the
257+
last highlighted symbol in the inquiry. Each inquiry has a length of (trial duration x num of trials)
258+
seconds. Labels are provided in 'target_symbols'. It returns a Dict, where keys are the target symbols
259+
and the values are inquiries (appended in order of appearance) where the corresponding target symbol is
260+
prompted.
261+
255262
Optional outputs:
256-
reshape_data is the list of data reshaped into (Inquiries, Channels, Samples), where inquirires are appended
257-
in chronological order. labels returns the list of target symbols in each inquiry.
263+
reshape_data is the list of data reshaped into (Inquiries, Channels, Samples), where inquirires are
264+
appended in chronological order.
265+
labels returns the list of target symbols in each inquiry.
258266
259-
Args:
267+
Parameters
268+
----------
260269
inq_start_times (List[float]): Timestamp of each event in seconds
261270
target_symbols (List[str]): Prompted symbol in each inquiry
262271
gaze_data (np.ndarray): shape (channels, samples) eye tracking data
263-
sample_rate (int): sample rate of data provided in eeg_data
272+
sample_rate (int): sample rate of eye tracker data
273+
stimulus_duration (float): duration of flash time (in seconds) for each trial
274+
num_stimuli_per_inquiry (int): number of stimuli in each inquiry (default: 10)
275+
symbol_set (List[str]): list of all symbols for the task
264276
channel_map (List[int], optional): Describes which channels to include or discard.
265277
Defaults to None; all channels will be used.
266278
267-
Returns:
268-
data_by_targets (dict): Dictionary where keys are the symbol set and values are the appended inquiries
269-
for each symbol. dict[Key] = (np.ndarray) of shape (Channels, Samples)
279+
Returns
280+
-------
281+
data_by_targets (dict): Dictionary where keys consist of the symbol set, and values
282+
the appended inquiries for each symbol. dict[Key] = (np.ndarray) of shape (Channels, Samples)
270283
271284
reshaped_data (List[float]) [optional]: inquiry data of shape (Inquiries, Channels, Samples)
272285
labels (List[str]) [optional] : Target symbol in each inquiry.
273286
"""
274-
if channel_map:
275-
# Remove the channels that we are not interested in
276-
channels_to_remove = [idx for idx, value in enumerate(channel_map) if value == 0]
277-
gaze_data = np.delete(gaze_data, channels_to_remove, axis=0)
278-
279-
# Find the value closest to (& greater than) inq_start_times
287+
# Find the timestamp value closest to (& greater than) inq_start_times.
288+
# Lsl timestamps are the last row in the gaze_data
280289
gaze_data_timing = gaze_data[-1, :].tolist()
281290

282291
start_times = []
@@ -294,54 +303,45 @@ def __call__(self,
294303

295304
# Create a dictionary with symbols as keys and data as values
296305
# 'A': [], 'B': [] ...
297-
data_by_targets: Dict[str, list] = {}
306+
data_by_targets_dict: Dict[str, list] = {}
298307
for symbol in symbol_set:
299-
data_by_targets[symbol] = []
308+
data_by_targets_dict[symbol] = []
300309

301-
window_length = 3 # seconds, total length of flickering after prompt for each inquiry
310+
buffer = stimulus_duration / 5 # seconds, buffer for each inquiry
311+
# NOTE: This buffer is used to account for the screen downtime between each stimulus.
312+
# There is a "duty cycle" of 80% for the stimuli, so we add a buffer of 20% of the stimulus length
313+
window_length = (stimulus_duration + buffer) * num_stimuli_per_inquiry # in seconds
302314

303315
reshaped_data = []
304316
# Merge the inquiries if they have the same target letter:
305317
for i, inquiry_index in enumerate(triggers):
306318
start = inquiry_index
307-
stop = int(inquiry_index + (sample_rate * window_length)) # (60 samples * 3 seconds)
319+
stop = int(inquiry_index + (sample_rate * window_length))
308320
# Check if the data exists for the inquiry:
309321
if stop > len(gaze_data[0, :]):
310322
continue
311-
312-
reshaped_data.append(gaze_data[:, start:stop])
313323
# (Optional) extracted data (Inquiries x Channels x Samples)
324+
reshaped_data.append(gaze_data[:, start:stop])
314325

315-
# Populate the dict by appending the inquiry to the correct key:
316-
data_by_targets[labels[i]].append(gaze_data[:, start:stop])
317-
318-
# After populating, flatten the arrays in the dictionary to (Channels x Samples):
319-
for symbol in symbol_set:
320-
if len(data_by_targets[symbol]) > 0:
321-
data_by_targets[symbol] = np.transpose(np.array(data_by_targets[symbol]), (1, 0, 2))
322-
data_by_targets[symbol] = np.reshape(data_by_targets[symbol], (len(data_by_targets[symbol]), -1))
323-
324-
# Note that this is a workaround to the issue of having different number of targetness in
325-
# each symbol. If a target symbol is prompted more than once, the data is appended to the dict as a list.
326-
# Which is why we need to convert it to a (np.ndarray) and flatten the dimensions.
327-
# This is not ideal, but it works for now.
326+
# Populate the dict by appending the inquiry to the corresponding key:
327+
data_by_targets_dict[labels[i]].append(gaze_data[:, start:stop])
328328

329-
# return np.stack(reshaped_data, 0), labels
330-
return data_by_targets
329+
return data_by_targets_dict, reshaped_data, labels
331330

332-
@staticmethod
333-
def centralize_all_data(data, symbol_pos):
331+
def centralize_all_data(self, data: np.ndarray, symbol_pos: np.ndarray) -> np.ndarray:
334332
""" Using the symbol locations in matrix, centralize all data (in Tobii units).
335333
This data will only be used in certain model types.
336334
Args:
337-
data (np.ndarray): Data in shape of num_channels x num_samples
335+
data (np.ndarray): Data in shape of num_samples x num_dimensions
338336
symbol_pos (np.ndarray(float)): Array of the current symbol posiiton in Tobii units
339337
Returns:
340-
data (np.ndarray): Centralized data in shape of num_channels x num_samples
338+
new_data (np.ndarray): Centralized data in shape of num_samples x num_dimensions
341339
"""
340+
new_data = np.copy(data)
342341
for i in range(len(data)):
343-
data[i] = data[i] - symbol_pos
344-
return data
342+
new_data[i] = data[i] - symbol_pos
343+
344+
return new_data
345345

346346

347347
class TrialReshaper(Reshaper):

0 commit comments

Comments
 (0)