-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added pearson correlation to losses * fix pearson correlation loss name * added seaborn to setup.py * post-processing intensity temp utils * added intensity report and postprocessing functions * minor fixes * refactored report * bumped up version to v0.0.4 * minor fix spectral angle function * fixed the decoder param for layer size --------- Co-authored-by: WassimG <wassim.gabriel@gmail.com>
- Loading branch information
Showing
9 changed files
with
205 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
__version__ = "0.0.3" | ||
__version__ = "0.0.4" | ||
|
||
META_DATA = { | ||
"author": "Omar Shouman", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from .intensity import masked_spectral_distance | ||
from .intensity import masked_spectral_distance, masked_pearson_correlation_distance | ||
|
||
__all__ = [masked_spectral_distance] | ||
__all__ = [masked_spectral_distance, masked_pearson_correlation_distance] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from os.path import join | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import seaborn as sns | ||
from matplotlib import pyplot as plt | ||
from matplotlib.colors import LogNorm | ||
from matplotlib.ticker import LogLocator | ||
|
||
from .postprocessing import normalize_intensity_predictions | ||
from .Report import PDFFile, Report | ||
|
||
|
||
class IntensityReport(Report): | ||
"""Report generation for Fragment Ion Intensity Prediction tasks.""" | ||
|
||
TARGETS_LABEL = "x" | ||
PREDICTIONS_LABEL = "y" | ||
DEFAULT_BATCH_SIZE = 600 | ||
|
||
def __init__(self, output_path, history, figures_ext="png", batch_size=0): | ||
super(IntensityReport, self).__init__(output_path, history, figures_ext) | ||
|
||
self.pdf_file = PDFFile("DLOmix - Fragment Ion Intensity Report") | ||
|
||
if batch_size: | ||
self.batch_size = batch_size | ||
else: | ||
self.batch_size = IntensityReport.DEFAULT_BATCH_SIZE | ||
|
||
def generate_report(self, dataset, predictions): | ||
self._init_report_resources() | ||
|
||
predictions_df = self.generate_intensity_results_df(dataset, predictions) | ||
self.plot_all_metrics() | ||
|
||
# make custom plots | ||
self.plot_spectral_angle(predictions_df) | ||
|
||
self._compile_report_resources_add_pdf_pages() | ||
self.pdf_file.output(join(self._output_path, "intensity_Report.pdf"), "F") | ||
|
||
|
||
def generate_intensity_results_df(self, dataset, predictions): | ||
predictions_df = pd.DataFrame() | ||
|
||
predictions_df['sequences'] = dataset.sequences | ||
predictions_df['intensities_pred'] = predictions.tolist() | ||
predictions_df['precursor_charge_onehot'] = dataset.precursor_charge.tolist() | ||
predictions_df['intensities_raw'] = dataset.intensities.tolist() | ||
|
||
return predictions_df | ||
|
||
def plot_spectral_angle( | ||
self, | ||
predictions_df | ||
): | ||
"""Create spectral plot | ||
Arguments | ||
--------- | ||
predictions_df: dataframe with raw intensities, predictions, sequences, precursor_charges | ||
""" | ||
|
||
predictions_acc = normalize_intensity_predictions(predictions_df, self.batch_size) | ||
violin_plot = sns.violinplot(predictions_acc['spectral_angle']) | ||
|
||
save_path = join(self._output_path, "violin_spectral_angle_plot" + self._figures_ext) | ||
|
||
fig = violin_plot.get_figure() | ||
fig.savefig(save_path) | ||
|
||
self._add_report_resource( | ||
"spectral_angle_plot", | ||
"Spectral angle plot", | ||
"The following figure shows the spectral angle plot for the test data.", | ||
save_path, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
from .IntensityReport import IntensityReport | ||
from .RetentionTimeReport import RetentionTimeReport | ||
|
||
__all__ = [RetentionTimeReport] | ||
__all__ = ["RetentionTimeReport", | ||
"IntensityReport", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import functools | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
|
||
import dlomix.losses as losses | ||
|
||
|
||
def reshape_dims(array): | ||
n, dims = array.shape | ||
assert dims == 174 | ||
nlosses = 1 | ||
return array.reshape( | ||
[array.shape[0], 30 - 1, 2, nlosses, 3] | ||
) | ||
|
||
|
||
def reshape_flat(array): | ||
s = array.shape | ||
flat_dim = [s[0], functools.reduce(lambda x, y: x * y, s[1:], 1)] | ||
return array.reshape(flat_dim) | ||
|
||
|
||
def normalize_base_peak(array): | ||
# flat | ||
maxima = array.max(axis=1) | ||
array = array / maxima[:, np.newaxis] | ||
return array | ||
|
||
|
||
def mask_outofrange(array, lengths, mask=-1.): | ||
# dim | ||
for i in range(array.shape[0]): | ||
array[i, lengths[i] - 1 :, :, :, :] = mask | ||
return array | ||
|
||
|
||
def mask_outofcharge(array, charges, mask=-1.): | ||
# dim | ||
for i in range(array.shape[0]): | ||
if charges[i] < 3: | ||
array[i, :, :, :, charges[i] :] = mask | ||
return array | ||
|
||
|
||
def get_spectral_angle(true, pred, batch_size=600): | ||
|
||
n = true.shape[0] | ||
sa = np.zeros([n]) | ||
|
||
def iterate(): | ||
if n > batch_size: | ||
for i in range(n // batch_size): | ||
true_sample = true[i * batch_size : (i + 1) * batch_size] | ||
pred_sample = pred[i * batch_size : (i + 1) * batch_size] | ||
yield i, true_sample, pred_sample | ||
i = n // batch_size | ||
yield i, true[(i) * batch_size :], pred[(i) * batch_size :] | ||
else: | ||
yield 0, true, pred | ||
|
||
for i, t_b, p_b in iterate(): | ||
tf.compat.v1.reset_default_graph() | ||
with tf.compat.v1.Session() as s: | ||
sa_graph = losses.masked_spectral_distance(t_b, p_b) | ||
sa_b = 1 - s.run(sa_graph) | ||
sa[i * batch_size : i * batch_size + sa_b.shape[0]] = sa_b | ||
sa = np.nan_to_num(sa) | ||
return sa | ||
|
||
|
||
def normalize_intensity_predictions(data, batch_size=600): | ||
assert "sequences" in data, "Key sequences is missing in the data provided for post-processing" | ||
assert "intensities_pred" in data, "Key intensities_pred is missing in the data provided for post-processing" | ||
assert "precursor_charge_onehot" in data, "Key precursor_charge_onehot is missing in the data provided for post-processing" | ||
|
||
sequence_lengths = data["sequences"].apply(lambda x: len(x)) | ||
intensities = np.stack(data["intensities_pred"].to_numpy()).astype(np.float32) | ||
precursor_charge_onehot = np.stack(data["precursor_charge_onehot"].to_numpy()) | ||
charges = list(precursor_charge_onehot.argmax(axis=1) + 1) | ||
|
||
intensities[intensities < 0] = 0 | ||
intensities = reshape_dims(intensities) | ||
intensities = mask_outofrange(intensities, sequence_lengths) | ||
intensities = mask_outofcharge(intensities, charges) | ||
intensities = reshape_flat(intensities) | ||
m_idx = intensities == -1 | ||
intensities = normalize_base_peak(intensities) | ||
intensities[m_idx] = -1 | ||
data["intensities_pred"] = intensities | ||
|
||
if "intensities_raw" in data: | ||
data["spectral_angle"] = get_spectral_angle( | ||
np.stack(data["intensities_raw"].to_numpy()).astype(np.float32), intensities, batch_size=batch_size | ||
) | ||
return data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters