Skip to content

Commit c6bce92

Browse files
authored
Merge pull request #302 from CAMBI-tech/reshaper_testing
Adds tests to rehsapers
2 parents 3a97da2 + b89c4fe commit c6bce92

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

bcipy/helpers/stimuli.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def extract_trials(
221221

222222
try:
223223
new_trials.append(inquiries[:, inquiry_idx, start:end])
224-
except IndexError:
224+
except IndexError: # pragma: no cover
225225
raise BciPyCoreException(
226226
f'InquiryReshaper.extract_trials: index out of bounds. \n'
227227
f'Inquiry: [{inquiry_idx}] from {start}:{end}. init_time: {time}, '

bcipy/helpers/tests/test_stimuli.py

+46
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,22 @@ def test_best_case_inquiry_gen(self):
713713
self.assertEqual([1] + ([0.2] * n), times[0])
714714
self.assertEqual(['red'] + (['white'] * n), colors[0])
715715

716+
def test_best_case_inquiry_gen_invalid_alp(self):
717+
"""Test best_case_rsvp_inq_gen throws error when passed invalid alp shape"""
718+
alp = ['a', 'b', 'c', 'd']
719+
session_stimuli = [0.1, 0.1, 0.1, 0.2, 0.2, 0.1, 0.2]
720+
stim_length = 5
721+
with self.assertRaises(BciPyCoreException, msg='Missing information about the alphabet.'):
722+
best_case_rsvp_inq_gen(
723+
alp=alp,
724+
session_stimuli=session_stimuli,
725+
timing=[1, 0.2],
726+
color=['red', 'white'],
727+
stim_number=1,
728+
stim_length=stim_length,
729+
is_txt=True
730+
)
731+
716732
def test_best_case_inquiry_gen_with_inq_constants(self):
717733
"""Test best_case_rsvp_inq_gen with inquiry constants"""
718734

@@ -854,6 +870,22 @@ def test_trial_reshaper(self):
854870
self.assertTrue(np.all(labels == [1, 0, 0]))
855871
self.assertTrue(reshaped_trials.shape == expected_shape)
856872

873+
def test_trial_reshaper_with_no_channel_map(self):
874+
sample_rate = 256
875+
trial_length_s = 0.5
876+
reshaped_trials, labels = TrialReshaper()(
877+
trial_targetness_label=self.target_info,
878+
timing_info=self.timing_info,
879+
eeg_data=self.eeg,
880+
sample_rate=sample_rate,
881+
channel_map=None,
882+
poststimulus_length=trial_length_s
883+
)
884+
trial_length_samples = int(sample_rate * trial_length_s)
885+
expected_shape = (self.channel_number, len(self.target_info), trial_length_samples)
886+
self.assertTrue(np.all(labels == [1, 0, 0]))
887+
self.assertTrue(reshaped_trials.shape == expected_shape)
888+
857889

858890
class TestInquiryReshaper(unittest.TestCase):
859891

@@ -918,6 +950,20 @@ def test_inquiry_reshaper(self):
918950
self.assertTrue(reshaped_data.shape == expected_shape)
919951
self.assertTrue(np.all(labels == self.true_labels))
920952

953+
def test_inquiry_reshaper_with_no_channel_map(self):
954+
reshaped_data, labels, _ = InquiryReshaper()(
955+
trial_targetness_label=self.target_info,
956+
timing_info=self.timing_info,
957+
eeg_data=self.eeg,
958+
sample_rate=self.sample_rate,
959+
trials_per_inquiry=self.trials_per_inquiry,
960+
channel_map=None,
961+
poststimulus_length=self.trial_length
962+
)
963+
expected_shape = (self.n_channel, self.n_inquiry, self.samples_per_inquiry)
964+
self.assertTrue(reshaped_data.shape == expected_shape)
965+
self.assertTrue(np.all(labels == self.true_labels))
966+
921967
def test_inquiry_reshaper_trial_extraction(self):
922968
timing = [[1, 3, 4], [1, 4, 5], [1, 2, 3], [4, 5, 6]]
923969
# make a fake eeg data array (n_channels, n_inquiry, n_samples)

0 commit comments

Comments
 (0)