Skip to content

Commit 586e4d1

Browse files
authored
Artifact Detection and BCI-Fit Artifact Report (#336)
1 parent 72dd333 commit 586e4d1

File tree

14 files changed

+1023
-306
lines changed

14 files changed

+1023
-306
lines changed

bcipy/helpers/demo/demo_report.py

+131
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
from pathlib import Path
2+
from bcipy.helpers.load import load_json_parameters, load_raw_data, load_experimental_data
3+
from bcipy.helpers.triggers import trigger_decoder, TriggerType
4+
from bcipy.config import (
5+
BCIPY_ROOT,
6+
DEFAULT_PARAMETER_FILENAME,
7+
RAW_DATA_FILENAME,
8+
TRIGGER_FILENAME,
9+
DEFAULT_DEVICE_SPEC_FILENAME)
10+
11+
from bcipy.acquisition import devices
12+
from bcipy.helpers.acquisition import analysis_channels
13+
from bcipy.helpers.visualization import visualize_erp
14+
from bcipy.signal.process import get_default_transform
15+
from bcipy.signal.evaluate.artifact import ArtifactDetection
16+
from bcipy.helpers.report import Report, SignalReportSection, SessionReportSection
17+
18+
19+
if __name__ == "__main__":
20+
import argparse
21+
22+
parser = argparse.ArgumentParser()
23+
parser.add_argument(
24+
'-p',
25+
'--path',
26+
help='Path to the directory with >= 1 sessions to be analyzed for artifacts',
27+
required=False)
28+
29+
args = parser.parse_args()
30+
colabel = True
31+
# if no path is provided, prompt for one using a GUI
32+
path = args.path
33+
if not path:
34+
path = load_experimental_data()
35+
36+
trial_window = (0, 1.0)
37+
38+
positions = None
39+
for session in Path(path).iterdir():
40+
# loop through the sessions, pausing after each one to allow for manual stopping
41+
if session.is_dir():
42+
print(f'Processing {session}')
43+
prompt = input('Hit enter to continue or type "skip" to skip processing: ')
44+
if prompt != 'skip':
45+
# load the parameters from the data directory
46+
parameters = load_json_parameters(
47+
f'{session}/{DEFAULT_PARAMETER_FILENAME}', value_cast=True)
48+
49+
# load the raw data from the data directory
50+
raw_data = load_raw_data(Path(session, f'{RAW_DATA_FILENAME}.csv'))
51+
type_amp = raw_data.daq_type
52+
channels = raw_data.channels
53+
sample_rate = raw_data.sample_rate
54+
downsample_rate = parameters.get("down_sampling_rate")
55+
notch_filter = parameters.get("notch_filter_frequency")
56+
filter_high = parameters.get("filter_high")
57+
filter_low = parameters.get("filter_low")
58+
filter_order = parameters.get("filter_order")
59+
static_offset = parameters.get("static_trigger_offset")
60+
61+
default_transform = get_default_transform(
62+
sample_rate_hz=sample_rate,
63+
notch_freq_hz=notch_filter,
64+
bandpass_low=filter_low,
65+
bandpass_high=filter_high,
66+
bandpass_order=filter_order,
67+
downsample_factor=downsample_rate,
68+
)
69+
70+
# load the triggers
71+
if colabel:
72+
trigger_type, trigger_timing, trigger_label = trigger_decoder(
73+
offset=parameters.get('static_trigger_offset'),
74+
trigger_path=f"{session}/{TRIGGER_FILENAME}",
75+
exclusion=[TriggerType.PREVIEW, TriggerType.EVENT, TriggerType.FIXATION],
76+
)
77+
triggers = (trigger_type, trigger_timing, trigger_label)
78+
else:
79+
triggers = None
80+
81+
devices.load(Path(BCIPY_ROOT, DEFAULT_DEVICE_SPEC_FILENAME))
82+
device_spec = devices.preconfigured_device(raw_data.daq_type)
83+
channel_map = analysis_channels(channels, device_spec)
84+
85+
# check the device spec for any frontal channels to use for EOG detection
86+
eye_channels = []
87+
for channel in device_spec.channels:
88+
if 'F' in channel:
89+
eye_channels.append(channel)
90+
if len(eye_channels) == 0:
91+
eye_channels = None
92+
93+
artifact_detector = ArtifactDetection(
94+
raw_data,
95+
parameters,
96+
device_spec,
97+
eye_channels=eye_channels,
98+
session_triggers=triggers)
99+
100+
detected = artifact_detector.detect_artifacts()
101+
figure_handles = visualize_erp(
102+
raw_data,
103+
channel_map,
104+
trigger_timing,
105+
trigger_label,
106+
trial_window,
107+
transform=default_transform,
108+
plot_average=True,
109+
plot_topomaps=True,
110+
)
111+
112+
# Try to find a pkl file in the session folder
113+
pkl_file = None
114+
for file in session.iterdir():
115+
if file.suffix == '.pkl':
116+
pkl_file = file
117+
break
118+
119+
if pkl_file:
120+
auc = pkl_file.stem.split('_')[-1]
121+
else:
122+
auc = 'No Signal Model found in session folder'
123+
124+
sr = SignalReportSection(figure_handles, artifact_detector)
125+
report = Report(session)
126+
session = {'Label': 'Demo Session Report', 'AUC': auc}
127+
session_text = SessionReportSection(session)
128+
report.add(session_text)
129+
report.add(sr)
130+
report.compile()
131+
report.save()

bcipy/helpers/report.py

+68-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
# mypy: disable-error-code="union-attr"
12
import io
23
from abc import ABC
3-
from typing import List, Optional
4+
from typing import List, Optional, Tuple
5+
6+
from matplotlib import pyplot as plt
47

58
from matplotlib.figure import Figure
69
from reportlab.platypus import SimpleDocTemplate, Paragraph, Image
@@ -10,6 +13,7 @@
1013
from reportlab.lib.units import inch
1114

1215
from bcipy.config import BCIPY_FULL_LOGO_PATH
16+
from bcipy.signal.evaluate.artifact import ArtifactDetection
1317

1418

1519
class ReportSection(ABC):
@@ -40,9 +44,15 @@ class SignalReportSection(ReportSection):
4044

4145
def __init__(
4246
self,
43-
figures: List[Figure]) -> None:
47+
figures: List[Figure],
48+
artifact: Optional[ArtifactDetection] = None) -> None:
4449
self.figures = figures
4550
self.report_flowables: List[Flowable] = []
51+
self.artifact = artifact
52+
53+
if self.artifact:
54+
assert self.artifact.analysis_done is not False, (
55+
"If providing artifact for this report, an analysis must be complete to run this report.")
4656
self.style = getSampleStyleSheet()
4757

4858
def compile(self) -> Flowable:
@@ -51,9 +61,65 @@ def compile(self) -> Flowable:
5161
Compiles the Signal Report sections into a flowable that can be used to generate a Report.
5262
"""
5363
self.report_flowables.append(self._create_header())
64+
if self.artifact:
65+
self.report_flowables.append(self._create_artifact_section())
5466
self.report_flowables.extend(self._create_epochs_section())
67+
5568
return KeepTogether(self.report_flowables)
5669

70+
def _create_artifact_section(self) -> Flowable:
71+
"""Create Artifact Section.
72+
73+
Creates a paragraph with the artifact information. This is only included if an artifact detection is provided.
74+
"""
75+
artifact_report = []
76+
artifacts_detected = self.artifact.dropped
77+
artifact_text = '<b>Artifact:</b>'
78+
artifact_section = Paragraph(artifact_text, self.style['BodyText'])
79+
artifact_overview = f'<b>Artifacts Detected:</b> {artifacts_detected}'
80+
artifact_section = Paragraph(artifact_overview, self.style['BodyText'])
81+
artifact_report.append(artifact_section)
82+
83+
if self.artifact.eog_annotations:
84+
eog_artifacts = f'<b>EOG Artifacts:</b> {len(self.artifact.eog_annotations)}'
85+
eog_section = Paragraph(eog_artifacts, self.style['BodyText'])
86+
artifact_report.append(eog_section)
87+
heatmap = self._create_heatmap(
88+
self.artifact.eog_annotations.onset,
89+
(0, self.artifact.total_time),
90+
'EOG')
91+
artifact_report.append(heatmap)
92+
93+
if self.artifact.voltage_annotations:
94+
voltage_artifacts = f'<b>Voltage Artifacts:</b> {len(self.artifact.voltage_annotations)}'
95+
voltage_section = Paragraph(voltage_artifacts, self.style['BodyText'])
96+
artifact_report.append(voltage_section)
97+
98+
# create a heatmap with the onset values of the voltage artifacts
99+
onsets = self.artifact.voltage_annotations.onset
100+
heatmap = self._create_heatmap(
101+
onsets,
102+
(0, self.artifact.total_time),
103+
'Voltage')
104+
artifact_report.append(heatmap)
105+
return KeepTogether(artifact_report)
106+
107+
def _create_heatmap(self, onsets: List[float], range: Tuple[float, float], type: str) -> Image:
108+
"""Create Heatmap.
109+
110+
Creates a heatmap image with the onset values of the voltage artifacts.
111+
"""
112+
# create a heatmap with the onset values
113+
fig, ax = plt.subplots()
114+
fig.set_size_inches(6, 3)
115+
ax.hist(onsets, bins=100, range=range, color='red', alpha=0.7)
116+
ax.set_title(f'{type} Artifact Onsets')
117+
ax.set_xlabel('Time (s)')
118+
# make the label text smaller
119+
ax.set_ylabel('Frequency')
120+
heatmap = self.convert_figure_to_image(fig)
121+
return heatmap
122+
57123
def _create_epochs_section(self) -> List[Image]:
58124
"""Create Epochs Section.
59125

bcipy/helpers/stimuli.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -412,21 +412,37 @@ def update_inquiry_timing(timing: List[List[float]], downsample: int) -> List[Li
412412

413413

414414
def mne_epochs(mne_data: RawArray,
415-
trigger_timing: List[float],
416415
trial_length: float,
417-
trigger_labels: List[int],
418-
baseline: Optional[Tuple[Any, float]] = None) -> Epochs:
416+
trigger_timing: Optional[List[float]] = None,
417+
trigger_labels: Optional[List[int]] = None,
418+
baseline: Optional[Tuple[Any, float]] = None,
419+
reject_by_annotation: bool = False,
420+
preload: bool = False) -> Epochs:
419421
"""MNE Epochs.
420422
421423
Using an MNE RawArray, reshape the data given trigger information. If two labels present [0, 1],
422424
each may be accessed by numbered order. Ex. first_class = epochs['1'], second_class = epochs['2']
423425
"""
424-
annotations = Annotations(trigger_timing, [trial_length] * len(trigger_timing), trigger_labels)
425-
mne_data.set_annotations(annotations)
426-
events_from_annot, _ = mne.events_from_annotations(mne_data)
427-
if not baseline:
428-
baseline = (None, 0.0)
429-
return Epochs(mne_data, events_from_annot, tmax=trial_length, baseline=baseline)
426+
old_annotations = mne_data.annotations
427+
if trigger_timing and trigger_labels:
428+
new_annotations = Annotations(trigger_timing, [trial_length] * len(trigger_timing), trigger_labels)
429+
all_annotations = new_annotations + old_annotations
430+
else:
431+
all_annotations = old_annotations
432+
433+
tmp_data = mne_data.copy()
434+
tmp_data.set_annotations(all_annotations)
435+
436+
events_from_annot, _ = mne.events_from_annotations(tmp_data)
437+
return Epochs(
438+
mne_data,
439+
events_from_annot,
440+
baseline=baseline,
441+
tmax=trial_length,
442+
tmin=-0.05,
443+
proj=False, # apply SSP projection to data. Defaults to True in Epochs.
444+
reject_by_annotation=reject_by_annotation,
445+
preload=preload)
430446

431447

432448
def alphabetize(stimuli: List[str]) -> List[str]:

bcipy/helpers/triggers.py

+1
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ class TriggerType(Enum):
161161
OFFSET = "offset"
162162
EVENT = "event"
163163
PREVIEW = "preview"
164+
ARTIFACT = "artifact"
164165

165166
@classmethod
166167
def list(cls) -> List[str]:

bcipy/helpers/visualization.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def visualize_erp(
9191
baseline = None
9292

9393
mne_data = convert_to_mne(raw_data, channel_map=channel_map, transform=transform)
94-
epochs = mne_epochs(mne_data, trigger_timing, trial_length, trigger_labels, baseline=baseline)
94+
epochs = mne_epochs(mne_data, trial_length, trigger_timing, trigger_labels, baseline=baseline)
9595
# *Note* We assume, as described above, two trigger classes are defined for use in trigger_labels
9696
# (Nontarget=0 and Target=1). This will map into two corresponding MNE epochs whose indexing starts at 1.
9797
# Therefore, epochs['1'] == Nontarget and epochs['2'] == Target.

bcipy/parameters/devices.json

+2-8
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,7 @@
3333
"excluded_from_analysis": [
3434
"TRG",
3535
"X1", "X2", "X3",
36-
"A2",
37-
"T3", "T4",
38-
"Fp1", "Fp2",
39-
"F7", "F8",
40-
"P3", "P4",
41-
"F3", "F4",
42-
"C3", "C4"
36+
"A2"
4337
],
4438
"status": "active",
4539
"static_offset": 0.1
@@ -67,7 +61,7 @@
6761
"name": "DSI-Flex",
6862
"content_type": "EEG",
6963
"channels": [
70-
{ "name": "F3", "label": "Fz", "units": "microvolts", "type": "EEG" },
64+
{ "name": "P4", "label": "Cz", "units": "microvolts", "type": "EEG" },
7165
{ "name": "S2", "label": "Oz", "units": "microvolts", "type": "EEG" },
7266
{ "name": "S3", "label": "P4", "units": "microvolts", "type": "EEG" },
7367
{ "name": "S4", "label": "P3", "units": "microvolts", "type": "EEG" },

0 commit comments

Comments
 (0)