24
24
from bcipy .config import DEFAULT_FIXATION_PATH , DEFAULT_TEXT_FIXATION , SESSION_LOG_FILENAME
25
25
from bcipy .exceptions import BciPyCoreException
26
26
from bcipy .helpers .list import grouper
27
+ from bcipy .helpers .symbols import alphabet
27
28
28
29
# Prevents pillow from filling the console with debug info
29
30
logging .getLogger ('PIL' ).setLevel (logging .WARNING )
@@ -198,7 +199,7 @@ def extract_trials(
198
199
prestimulus_samples : int = 0 ) -> np .ndarray :
199
200
"""Extract Trials.
200
201
201
- After using the InquiryReshaper, it may be necessary to further trial the data for processing.
202
+ After using the InquiryReshaper, it may be necessary to further extract the trials for processing.
202
203
Using the number of samples and inquiry timing, the data is reshaped from Channels, Inquiry, Samples to
203
204
Channels, Trials, Samples. These should match with the trials extracted from the TrialReshaper given the same
204
205
slicing parameters.
@@ -245,38 +246,46 @@ def __call__(self,
245
246
target_symbols : List [str ],
246
247
gaze_data : np .ndarray ,
247
248
sample_rate : int ,
248
- symbol_set : List [str ],
249
+ stimulus_duration : float ,
250
+ num_stimuli_per_inquiry : int ,
251
+ symbol_set : List [str ] = alphabet (),
249
252
channel_map : Optional [List [int ]] = None ,
250
- ) -> dict :
251
- """Extract inquiry data and labels. Different from the EEG inquiry, the gaze inquiry window starts with
252
- the first flicker and ends with the last flicker in the inquiry. Each inquiry has a length of ~3 seconds.
253
- The labels are provided in the target_symbols list. It returns a Dict, where keys are the target symbols and
254
- the values are inquiries (appended in order of appearance) where the corresponding target symbol is prompted.
253
+ ) -> Tuple [dict , list , List [str ]]:
254
+ """Extract gaze trajectory data and labels.
255
+
256
+ Different from the EEG, gaze inquiry windows start with the first highlighted symbol and end with the
257
+ last highlighted symbol in the inquiry. Each inquiry window has a length of (stimulus duration plus a 20% buffer) x (number of stimuli)
258
+ seconds. Labels are provided in 'target_symbols'. It returns a Dict, where keys are the target symbols
259
+ and the values are inquiries (appended in order of appearance) where the corresponding target symbol is
260
+ prompted.
261
+
255
262
Optional outputs:
256
- reshape_data is the list of data reshaped into (Inquiries, Channels, Samples), where inquirires are appended
257
- in chronological order. labels returns the list of target symbols in each inquiry.
263
+ reshaped_data is the list of data reshaped into (Inquiries, Channels, Samples), where inquiries are
264
+ appended in chronological order.
265
+ labels returns the list of target symbols in each inquiry.
258
266
259
- Args:
267
+ Parameters
268
+ ----------
260
269
inq_start_times (List[float]): Timestamp of each event in seconds
261
270
target_symbols (List[str]): Prompted symbol in each inquiry
262
271
gaze_data (np.ndarray): shape (channels, samples) eye tracking data
263
- sample_rate (int): sample rate of data provided in eeg_data
272
+ sample_rate (int): sample rate of eye tracker data
273
+ stimulus_duration (float): duration of flash time (in seconds) for each trial
274
+ num_stimuli_per_inquiry (int): number of stimuli in each inquiry (required; no default)
275
+ symbol_set (List[str]): list of all symbols for the task
264
276
channel_map (List[int], optional): Describes which channels to include or discard.
265
277
Defaults to None; all channels will be used.
266
278
267
- Returns:
268
- data_by_targets (dict): Dictionary where keys are the symbol set and values are the appended inquiries
269
- for each symbol. dict[Key] = (np.ndarray) of shape (Channels, Samples)
279
+ Returns
280
+ -------
281
+ data_by_targets (dict): Dictionary where keys consist of the symbol set, and values
282
+ the appended inquiries for each symbol. dict[Key] = list of (np.ndarray), each of shape (Channels, Samples)
270
283
271
284
reshaped_data (List[float]) [optional]: inquiry data of shape (Inquiries, Channels, Samples)
272
285
labels (List[str]) [optional] : Target symbol in each inquiry.
273
286
"""
274
- if channel_map :
275
- # Remove the channels that we are not interested in
276
- channels_to_remove = [idx for idx , value in enumerate (channel_map ) if value == 0 ]
277
- gaze_data = np .delete (gaze_data , channels_to_remove , axis = 0 )
278
-
279
- # Find the value closest to (& greater than) inq_start_times
287
+ # Find the timestamp value closest to (& greater than) inq_start_times.
288
+ # Lsl timestamps are the last row in the gaze_data
280
289
gaze_data_timing = gaze_data [- 1 , :].tolist ()
281
290
282
291
start_times = []
@@ -294,54 +303,45 @@ def __call__(self,
294
303
295
304
# Create a dictionary with symbols as keys and data as values
296
305
# 'A': [], 'B': [] ...
297
- data_by_targets : Dict [str , list ] = {}
306
+ data_by_targets_dict : Dict [str , list ] = {}
298
307
for symbol in symbol_set :
299
- data_by_targets [symbol ] = []
308
+ data_by_targets_dict [symbol ] = []
300
309
301
- window_length = 3 # seconds, total length of flickering after prompt for each inquiry
310
+ buffer = stimulus_duration / 5 # seconds, buffer for each inquiry
311
+ # NOTE: This buffer is used to account for the screen downtime between each stimulus.
312
+ # There is a "duty cycle" of 80% for the stimuli, so we add a buffer of 20% of the stimulus length
313
+ window_length = (stimulus_duration + buffer ) * num_stimuli_per_inquiry # in seconds
302
314
303
315
reshaped_data = []
304
316
# Merge the inquiries if they have the same target letter:
305
317
for i , inquiry_index in enumerate (triggers ):
306
318
start = inquiry_index
307
- stop = int (inquiry_index + (sample_rate * window_length )) # (60 samples * 3 seconds)
319
+ stop = int (inquiry_index + (sample_rate * window_length ))
308
320
# Check if the data exists for the inquiry:
309
321
if stop > len (gaze_data [0 , :]):
310
322
continue
311
-
312
- reshaped_data .append (gaze_data [:, start :stop ])
313
323
# (Optional) extracted data (Inquiries x Channels x Samples)
324
+ reshaped_data .append (gaze_data [:, start :stop ])
314
325
315
- # Populate the dict by appending the inquiry to the correct key:
316
- data_by_targets [labels [i ]].append (gaze_data [:, start :stop ])
317
-
318
- # After populating, flatten the arrays in the dictionary to (Channels x Samples):
319
- for symbol in symbol_set :
320
- if len (data_by_targets [symbol ]) > 0 :
321
- data_by_targets [symbol ] = np .transpose (np .array (data_by_targets [symbol ]), (1 , 0 , 2 ))
322
- data_by_targets [symbol ] = np .reshape (data_by_targets [symbol ], (len (data_by_targets [symbol ]), - 1 ))
323
-
324
- # Note that this is a workaround to the issue of having different number of targetness in
325
- # each symbol. If a target symbol is prompted more than once, the data is appended to the dict as a list.
326
- # Which is why we need to convert it to a (np.ndarray) and flatten the dimensions.
327
- # This is not ideal, but it works for now.
326
+ # Populate the dict by appending the inquiry to the corresponding key:
327
+ data_by_targets_dict [labels [i ]].append (gaze_data [:, start :stop ])
328
328
329
- # return np.stack(reshaped_data, 0), labels
330
- return data_by_targets
329
+ return data_by_targets_dict , reshaped_data , labels
331
330
332
- @staticmethod
333
- def centralize_all_data (data , symbol_pos ):
331
+ def centralize_all_data (self , data : np .ndarray , symbol_pos : np .ndarray ) -> np .ndarray :
334
332
""" Using the symbol locations in matrix, centralize all data (in Tobii units).
335
333
This data will only be used in certain model types.
336
334
Args:
337
- data (np.ndarray): Data in shape of num_channels x num_samples
335
+ data (np.ndarray): Data in shape of num_samples x num_dimensions
338
336
symbol_pos (np.ndarray(float)): Array of the current symbol position in Tobii units
339
337
Returns:
340
- data (np.ndarray): Centralized data in shape of num_channels x num_samples
338
+ new_data (np.ndarray): Centralized data in shape of num_samples x num_dimensions
341
339
"""
340
+ new_data = np .copy (data )
342
341
for i in range (len (data )):
343
- data [i ] = data [i ] - symbol_pos
344
- return data
342
+ new_data [i ] = data [i ] - symbol_pos
343
+
344
+ return new_data
345
345
346
346
347
347
class TrialReshaper (Reshaper ):
0 commit comments