@@ -713,6 +713,22 @@ def test_best_case_inquiry_gen(self):
713
713
self .assertEqual ([1 ] + ([0.2 ] * n ), times [0 ])
714
714
self .assertEqual (['red' ] + (['white' ] * n ), colors [0 ])
715
715
716
+ def test_best_case_inquiry_gen_invalid_alp (self ):
717
+ """Test best_case_rsvp_inq_gen throws error when passed invalid alp shape"""
718
+ alp = ['a' , 'b' , 'c' , 'd' ]
719
+ session_stimuli = [0.1 , 0.1 , 0.1 , 0.2 , 0.2 , 0.1 , 0.2 ]
720
+ stim_length = 5
721
+ with self .assertRaises (BciPyCoreException , msg = 'Missing information about the alphabet.' ):
722
+ best_case_rsvp_inq_gen (
723
+ alp = alp ,
724
+ session_stimuli = session_stimuli ,
725
+ timing = [1 , 0.2 ],
726
+ color = ['red' , 'white' ],
727
+ stim_number = 1 ,
728
+ stim_length = stim_length ,
729
+ is_txt = True
730
+ )
731
+
716
732
def test_best_case_inquiry_gen_with_inq_constants (self ):
717
733
"""Test best_case_rsvp_inq_gen with inquiry constants"""
718
734
@@ -854,6 +870,22 @@ def test_trial_reshaper(self):
854
870
self .assertTrue (np .all (labels == [1 , 0 , 0 ]))
855
871
self .assertTrue (reshaped_trials .shape == expected_shape )
856
872
873
+ def test_trial_reshaper_with_no_channel_map (self ):
874
+ sample_rate = 256
875
+ trial_length_s = 0.5
876
+ reshaped_trials , labels = TrialReshaper ()(
877
+ trial_targetness_label = self .target_info ,
878
+ timing_info = self .timing_info ,
879
+ eeg_data = self .eeg ,
880
+ sample_rate = sample_rate ,
881
+ channel_map = None ,
882
+ poststimulus_length = trial_length_s
883
+ )
884
+ trial_length_samples = int (sample_rate * trial_length_s )
885
+ expected_shape = (self .channel_number , len (self .target_info ), trial_length_samples )
886
+ self .assertTrue (np .all (labels == [1 , 0 , 0 ]))
887
+ self .assertTrue (reshaped_trials .shape == expected_shape )
888
+
857
889
858
890
class TestInquiryReshaper (unittest .TestCase ):
859
891
@@ -918,6 +950,20 @@ def test_inquiry_reshaper(self):
918
950
self .assertTrue (reshaped_data .shape == expected_shape )
919
951
self .assertTrue (np .all (labels == self .true_labels ))
920
952
953
+ def test_inquiry_reshaper_with_no_channel_map (self ):
954
+ reshaped_data , labels , _ = InquiryReshaper ()(
955
+ trial_targetness_label = self .target_info ,
956
+ timing_info = self .timing_info ,
957
+ eeg_data = self .eeg ,
958
+ sample_rate = self .sample_rate ,
959
+ trials_per_inquiry = self .trials_per_inquiry ,
960
+ channel_map = None ,
961
+ poststimulus_length = self .trial_length
962
+ )
963
+ expected_shape = (self .n_channel , self .n_inquiry , self .samples_per_inquiry )
964
+ self .assertTrue (reshaped_data .shape == expected_shape )
965
+ self .assertTrue (np .all (labels == self .true_labels ))
966
+
921
967
def test_inquiry_reshaper_trial_extraction (self ):
922
968
timing = [[1 , 3 , 4 ], [1 , 4 , 5 ], [1 , 2 , 3 ], [4 , 5 , 6 ]]
923
969
# make a fake eeg data array (n_channels, n_inquiry, n_samples)
0 commit comments