@@ -129,7 +129,7 @@ def __call__(self,
129
129
Returns:
130
130
reshaped_data (np.ndarray): inquiry data of shape (Channels, Inquiries, Samples)
131
131
labels (np.ndarray): integer label for each inquiry. With `trials_per_inquiry=K`,
132
- a label of [0, K-1] indicates the position of `target_label`, or label of K indicates
132
+ a label of [0, K-1] indicates the position of `target_label`, or label of [0 ... 0] indicates
133
133
`target_label` was not present.
134
134
reshaped_trigger_timing (List[List[int]]): For each inquiry, a list of the sample index where each trial
135
135
begins, accounting for the prestim buffer that may have been added to the front of each inquiry.
@@ -229,6 +229,111 @@ def extract_trials(
229
229
return np .stack (new_trials , 1 ) # C x T x S
230
230
231
231
232
+ class GazeReshaper :
233
+ def __call__ (self ,
234
+ inq_start_times : List [float ],
235
+ target_symbols : List [str ],
236
+ gaze_data : np .ndarray ,
237
+ sample_rate : int ,
238
+ symbol_set : List [str ],
239
+ channel_map : List [int ] = None ,
240
+ ) -> dict :
241
+ """Extract inquiry data and labels. Different from the EEG inquiry, the gaze inquiry window starts with
242
+ the first flicker and ends with the last flicker in the inquiry. Each inquiry has a length of ~3 seconds.
243
+ The labels are provided in the target_symbols list. It returns a Dict, where keys are the target symbols and
244
+ the values are inquiries (appended in order of appearance) where the corresponding target symbol is prompted.
245
+ Optional outputs:
246
+ reshape_data is the list of data reshaped into (Inquiries, Channels, Samples), where inquirires are appended
247
+ in chronological order. labels returns the list of target symbols in each inquiry.
248
+
249
+ Args:
250
+ inq_start_times (List[float]): Timestamp of each event in seconds
251
+ target_symbols (List[str]): Prompted symbol in each inquiry
252
+ gaze_data (np.ndarray): shape (channels, samples) eye tracking data
253
+ sample_rate (int): sample rate of data provided in eeg_data
254
+ channel_map (List[int], optional): Describes which channels to include or discard.
255
+ Defaults to None; all channels will be used.
256
+
257
+ Returns:
258
+ data_by_targets (dict): Dictionary where keys are the symbol set and values are the appended inquiries
259
+ for each symbol. dict[Key] = (np.ndarray) of shape (Channels, Samples)
260
+
261
+ reshaped_data (List[float]) [optional]: inquiry data of shape (Inquiries, Channels, Samples)
262
+ labels (List[str]) [optional] : Target symbol in each inquiry.
263
+ """
264
+ if channel_map :
265
+ # Remove the channels that we are not interested in
266
+ channels_to_remove = [idx for idx , value in enumerate (channel_map ) if value == 0 ]
267
+ gaze_data = np .delete (gaze_data , channels_to_remove , axis = 0 )
268
+
269
+ # Find the value closest to (& greater than) inq_start_times
270
+ gaze_data_timing = gaze_data [- 1 , :].tolist ()
271
+
272
+ start_times = []
273
+ for times in inq_start_times :
274
+ temp = list (filter (lambda x : x > times , gaze_data_timing ))
275
+ if len (temp ) > 0 :
276
+ start_times .append (temp [0 ])
277
+
278
+ triggers = []
279
+ for val in start_times :
280
+ triggers .append (gaze_data_timing .index (val ))
281
+
282
+ # Label for every inquiry
283
+ labels = target_symbols
284
+
285
+ # Create a dictionary with symbols as keys and data as values
286
+ # 'A': [], 'B': [] ...
287
+ data_by_targets = {}
288
+ for symbol in symbol_set :
289
+ data_by_targets [symbol ] = []
290
+
291
+ window_length = 3 # seconds, total length of flickering after prompt for each inquiry
292
+
293
+ reshaped_data = []
294
+ # Merge the inquiries if they have the same target letter:
295
+ for i , inquiry_index in enumerate (triggers ):
296
+ start = inquiry_index
297
+ stop = int (inquiry_index + (sample_rate * window_length )) # (60 samples * 3 seconds)
298
+ # Check if the data exists for the inquiry:
299
+ if stop > len (gaze_data [0 , :]):
300
+ continue
301
+
302
+ reshaped_data .append (gaze_data [:, start :stop ])
303
+ # (Optional) extracted data (Inquiries x Channels x Samples)
304
+
305
+ # Populate the dict by appending the inquiry to the correct key:
306
+ data_by_targets [labels [i ]].append (gaze_data [:, start :stop ])
307
+
308
+ # After populating, flatten the arrays in the dictionary to (Channels x Samples):
309
+ for symbol in symbol_set :
310
+ if len (data_by_targets [symbol ]) > 0 :
311
+ data_by_targets [symbol ] = np .transpose (np .array (data_by_targets [symbol ]), (1 , 0 , 2 ))
312
+ data_by_targets [symbol ] = np .reshape (data_by_targets [symbol ], (len (data_by_targets [symbol ]), - 1 ))
313
+
314
+ # Note that this is a workaround to the issue of having different number of targetness in
315
+ # each symbol. If a target symbol is prompted more than once, the data is appended to the dict as a list.
316
+ # Which is why we need to convert it to a (np.ndarray) and flatten the dimensions.
317
+ # This is not ideal, but it works for now.
318
+
319
+ # return np.stack(reshaped_data, 0), labels
320
+ return data_by_targets
321
+
322
+ @staticmethod
323
+ def centralize_all_data (data , symbol_pos ):
324
+ """ Using the symbol locations in matrix, centralize all data (in Tobii units).
325
+ This data will only be used in certain model types.
326
+ Args:
327
+ data (np.ndarray): Data in shape of num_channels x num_samples
328
+ symbol_pos (np.ndarray(float)): Array of the current symbol posiiton in Tobii units
329
+ Returns:
330
+ data (np.ndarray): Centralized data in shape of num_channels x num_samples
331
+ """
332
+ for i in range (len (data )):
333
+ data [i ] = data [i ] - symbol_pos
334
+ return data
335
+
336
+
232
337
class TrialReshaper (Reshaper ):
233
338
def __call__ (self ,
234
339
trial_targetness_label : list ,
0 commit comments