diff --git a/notebooks/Example_Detectability_Model_Walkthrough_prediction_colab.ipynb b/notebooks/Example_Detectability_Model_Walkthrough_prediction_colab.ipynb new file mode 100644 index 00000000..4fb3ed82 --- /dev/null +++ b/notebooks/Example_Detectability_Model_Walkthrough_prediction_colab.ipynb @@ -0,0 +1,768 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "7YWkUVjVr7qJ" + }, + "source": [ + "# Peptide Detectability Prediction \n", + "\n", + "This notebook is prepared to be run in Google [Colaboratory](https://colab.research.google.com/)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "S3DlTOq3r7qM" + }, + "source": [ + "One of the example datasets used in this notebook is deposited in the ProteomeXchange Consortium via the MAssIVE partner repository with the identifier PXD024364. The other dataset is deposited to the ProteomeXchange Consortium via the PRIDE partner repository with identifier PXD010154. \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Installing the DLOmix Package\n", + "\n", + "If you have not installed the DLOmix package yet, you need to do so before running the code. \n", + "\n", + "You can install the DLOmix package using pip." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "aO-69zbKsGey", + "outputId": "c2064411-9f80-47e6-ca5b-312d547e0f6a", + "scrolled": true + }, + "outputs": [], + "source": [ + "# uncomment the following line to install the DLOmix package in the current environment using pip\n", + "\n", + "#!python -m pip install dlomix>0.1.3" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Mo7H9qzWr7qN" + }, + "source": [ + "#### Importing Required Libraries\n", + "\n", + "Before running the code, ensure you import all the necessary libraries. These imports are essential for accessing the functionalities needed for data processing, model training, and evaluation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "l0CS0tFur7qN", + "outputId": "664e0978-980a-4254-90d1-61e9f1603234" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import tensorflow as tf\n", + "import dlomix\n", + "import sys\n", + "import os" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oWeVi0iar7qT" + }, + "source": [ + "## Model\n", + "\n", + "We can now create the model. The model architecture is an encoder-decoder with an attention mechanism, that is based on Bidirectional Recurrent Neural Network (BRNN) with Gated Recurrent Units (GRU). Both the Encoder and Decoder consists of a single layer, with the Decoder also including a Dense layer. The model has the default working arguments." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Q8SGTvfRr7qT" + }, + "outputs": [], + "source": [ + "from dlomix.models import DetectabilityModel\n", + "from dlomix.constants import CLASSES_LABELS, alphabet, aa_to_int_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CLASSES_LABELS, len(alphabet), aa_to_int_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZqrsF6APr7qU" + }, + "outputs": [], + "source": [ + "total_num_classes = len(CLASSES_LABELS)\n", + "input_dimension = len(alphabet)\n", + "num_cells = 64\n", + "\n", + "model = DetectabilityModel(num_units = num_cells, num_clases = total_num_classes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Model Weights Configuration\n", + "\n", + "In the following section, you need to specify the path to the model weights you wish to use. The default path provided is set to the weights for the **Pfly** model, which is the fine-tuned model mentioned in the publication associated with this notebook.\n", + "\n", + "- **Using the Default Pfly Model**: If you are utilizing the fine-tuned Pfly model as described in the publication, you can keep the default path unchanged. This will load the model weights for Pfly.\n", + "\n", + "- **Using the Base Model or Different Weights**: If you intend to use the base model or have different weights (e.g., for a custom model), you should update the path to reflect the location of these weights." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## Loading model weights \n", + "\n", + "model_save_path = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability'\n", + "\n", + "model.load_weights(model_save_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Workflow Overview\n", + "\n", + "This notebook supports two different workflows depending on your dataset:\n", + "\n", + "- **Labeled Data**: Use this pipeline when your dataset includes ground truth labels. This setup not only makes predictions but also allows for detailed evaluation by comparing the true labels with the predicted values, facilitating the generation of a comprehensive evaluation report.\n", + "\n", + "- **Unlabeled Data**: Use this pipeline when your dataset does not include labels. Here, the focus is on making predictions only, without generating a detailed performance report, as there are no labels to compare against.\n", + "\n", + "### Notebook Structure\n", + "\n", + "Subtitles throughout the notebook indicate the sections for each type of data:\n", + "\n", + "- **Labeled Data Section**: Follow these when your dataset includes labels to receive predictions and a comprehensive evaluation report.\n", + "\n", + "- **Unlabeled Data Section**: Use these when your dataset lacks labels, focusing solely on generating predictions.\n", + "\n", + "Make sure to select the appropriate pipeline based on your dataset." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Labeled Data" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "41qXroyKr7qP" + }, + "source": [ + "## 1. Load Data \n", + "\n", + "You can import the `DetectabilityDataset` class and create an instance to manage data for training, validation, and testing. This instance handles TensorFlow dataset objects and simplifies configuring and controlling how your data is preprocessed and split.\n", + "\n", + "For the paramters of the dataset class, please refer to the DLOmix documentation: https://dlomix.readthedocs.io/en/main/dlomix.data.html#\n", + "\n", + "\n", + "**Note**: If class labels are provided, the following encoding scheme should be used:\n", + "- **Non-Flyer**: 0\n", + "- **Weak Flyer**: 1\n", + "- **Intermediate Flyer**: 2\n", + "- **Strong Flyer**: 3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RiXz_epEr7qQ" + }, + "outputs": [], + "source": [ + "from dlomix.data import DetectabilityDataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# load the dataset from huggingface for prediction\n", + "\n", + "from datasets import load_dataset, DatasetDict\n", + "\n", + "# pick one of the available datasets on the HuggingFace Hub\n", + "# Collection: https://huggingface.co/collections/Wilhelmlab/detectability-datasets-671e76fb77035878c50a9c1d\n", + "\n", + "hf_data_name = \"Wilhelmlab/detectability-sinitcyn\"\n", + "#hf_data_name = \"Wilhelmlab/detectability-wang\"\n", + "\n", + "hf_dataset_split = load_dataset(hf_data_name, split=\"test\")\n", + "hf_dataset = DatasetDict({\"test\": hf_dataset_split})\n", + "hf_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "max_pep_length = 40\n", + "BATCH_SIZE = 128\n", + "\n", + "detectability_data = DetectabilityDataset(data_source=hf_dataset,\n", + " data_format='hf',\n", + " max_seq_len=max_pep_length,\n", + " label_column=\"Classes\",\n", + " sequence_column=\"Sequences\",\n", + " dataset_columns_to_keep=['Proteins'],\n", + " batch_size=BATCH_SIZE,\n", + " with_termini=False,\n", + " alphabet=aa_to_int_dict)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lzNXJ-s6r7qQ" + }, + "outputs": [], + "source": [ + "# This is the dataset with the test split \n", + "# You can see the column names under each split (the columns starting with _ are internal, but can also be used to look up original sequences for example \"_parsed_sequence\")\n", + "detectability_data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Accessing elements in the dataset is done by specificing the split name and then the column name\n", + "# Example here for one sequence after encoding & padding comapred to the original sequence\n", + "\n", + "detectability_data[\"test\"][\"Sequences\"][0], \"\".join(detectability_data[\"test\"][\"_parsed_sequence\"][0])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oukZ4AyMr7qV" + }, + "source": [ + "## 2. Testing and Reporting\n", + "\n", + "We use the test dataset to assess our model's performance, which is only applicable if labels are available. The `DetectabilityReport` class allows us to compute various metrics, generate reports, and create plots for a comprehensive evaluation of the model.\n", + "\n", + "Note: The reporting module is currently under development, so some features may be unstable or subject to change." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Generate Predictions on Test Data Using `model.predict`\n", + "\n", + "To obtain predictions for your test data, use the Keras `model.predict` method. Simply pass your test dataset to this method, and it will return the model's predictions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RrvR8Cl3r7qV" + }, + "outputs": [], + "source": [ + "predictions = model.predict(detectability_data.tensor_test_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predictions.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To generate reports and calculate evaluation metrics against predictions, we obtain the targets and the data for the specific dataset split. This can be achieved using the `DetectabilityDataset` class directly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wKk7MD7Wr7qW" + }, + "outputs": [], + "source": [ + "# access val dataset and get the Classes column\n", + "test_targets = detectability_data[\"test\"][\"Classes\"]\n", + "\n", + "\n", + "# if needed, the decoded version of the classes can be retrieved by looking up the class names\n", + "test_targets_decoded = [CLASSES_LABELS[x] for x in test_targets]\n", + "\n", + "\n", + "test_targets[0:5], test_targets_decoded[0:5]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The dataframe needed for the report\n", + "\n", + "test_data_df = pd.DataFrame(\n", + " {\n", + " \"Sequences\": detectability_data[\"test\"][\"_parsed_sequence\"], # get the raw parsed sequences\n", + " \"Classes\": test_targets, # get the test targets from above\n", + " \"Proteins\": detectability_data[\"test\"][\"Proteins\"] # get the Proteins column from the dataset object\n", + " }\n", + ")\n", + "\n", + "test_data_df.Sequences = test_data_df.Sequences.apply(lambda x: \"\".join(x)) # join the sequences since they are a list of string amino acids.\n", + "test_data_df.head(5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4kzCh0gwr7qX" + }, + "outputs": [], + "source": [ + "from dlomix.reports.DetectabilityReport import DetectabilityReport, predictions_report\n", + "WANDB_REPORT_API_DISABLE_MESSAGE=True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Generate a Report Using the `DetectabilityReport` Class\n", + "\n", + "The `DetectabilityReport` class provides a comprehensive way to evaluate your model by generating detailed reports and visualizations. The outputs include:\n", + "\n", + "1. **A PDF Report**: This includes evaluation metrics and plots.\n", + "2. **A CSV File**: Contains the model’s predictions.\n", + "3. **Independent Image Files**: Visualizations are saved as separate image files.\n", + "\n", + "To generate a report, provide the following parameters to the `DetectabilityReport` class:\n", + "\n", + "- **targets**: The true labels for the dataset, which are used to assess the model’s performance.\n", + "- **predictions**: The model’s output predictions for the dataset, which will be compared against the true labels.\n", + "- **input_data_df**: The DataFrame containing the input data used for generating predictions.\n", + "- **output_path**: The directory path where the generated reports, images, and CSV file will be saved.\n", + "- **history**: The training history object (e.g., containing metrics from training) if available. Set this to `None` if not applicable, such as when the report is generated for predictions without training.\n", + "- **rank_by_prot**: A boolean indicating whether to rank peptides based on their associated proteins (`True` or `False`). Defaults to `False`.\n", + "- **threshold**: The classification threshold used to adjust the decision boundary for predictions. By default, this is set to `None`, meaning no specific threshold is applied.\n", + "- **name_of_dataset**: The name of the dataset used for generating predictions, which will be included in the report to provide context.\n", + "- **name_of_model**: The name of the model used to generate the predictions, which will be specified in the report for reference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Since the detectabiliy report expects the true labels in one-hot encoded format, we expand them here.\n", + "\n", + "num_classes = np.max(test_targets) + 1\n", + "test_targets_one_hot = np.eye(num_classes)[test_targets]\n", + "test_targets_one_hot.shape, len(test_targets)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_7LJZ3TLr7qX" + }, + "outputs": [], + "source": [ + "report = DetectabilityReport(targets = test_targets_one_hot, \n", + " predictions = predictions, \n", + " input_data_df = test_data_df, \n", + " output_path = \"./output/report_on_Sinitcyn_2000_proteins_test_set_labeled\", \n", + " history = None, \n", + " rank_by_prot = True,\n", + " threshold = None,\n", + " name_of_dataset = 'Sinitcyn 2000 proteins test set',\n", + " name_of_model = 'Fine-tuned model (Original)')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Predictions report" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results_df = report.detectability_report_table\n", + "results_df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generating Evaluation Plots with `DetectabilityReport`\n", + "\n", + "The `DetectabilityReport` class enables you to generate a range of plots to visualize and evaluate model performance. It offers a comprehensive suite of visualizations to help you interpret the results of your model's predictions. Here’s how to use it:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### ROC curve (Binary)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.plot_roc_curve_binary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Confusion matrix (Binary)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.plot_confusion_matrix_binary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### ROC curve (Multi-class)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.plot_roc_curve()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Confusion matrix (Multi-class)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.plot_confusion_matrix_multiclass()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Heatmap of Average Error Between Actual and Predicted Classes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.plot_heatmap_prediction_prob_error()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also produce a complete evaluation report with all the relevant plots in one PDF file by calling the `generate_report` function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.generate_report()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Unlabeled Data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Load data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For predicting on unlabeled data, follow the same workflow as described earlier (refer to the \"Load Data\" section for labeled data). Specifically, create an instance of the `DetectabilityDataset` class using your unlabeled data.The configuration below ensures that the entire dataset is treated as test data without generating additional splits (i.e., training and validation sets)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# load the dataset from huggingface for prediction\n", + "\n", + "from datasets import load_dataset, DatasetDict\n", + "\n", + "# pick one of the available datasets on the HuggingFace Hub\n", + "# Collection: https://huggingface.co/collections/Wilhelmlab/detectability-datasets-671e76fb77035878c50a9c1d\n", + "\n", + "hf_data_name = \"Wilhelmlab/detectability-sinitcyn\"\n", + "#hf_data_name = \"Wilhelmlab/detectability-wang\"\n", + "\n", + "hf_dataset_split = load_dataset(hf_data_name, split=\"test\")\n", + "hf_dataset = DatasetDict({\"test\": hf_dataset_split})\n", + "hf_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# simulate that the class labels are not there (insert None), but we keep the column since it is needed for the dataset class\n", + "\n", + "hf_dataset = hf_dataset.map(lambda example: {**example, 'Classes': None})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "max_pep_length = 40\n", + "BATCH_SIZE = 128\n", + " \n", + "test_data_unlabeled = DetectabilityDataset(data_source=hf_dataset,\n", + " data_format='hf',\n", + " max_seq_len=max_pep_length,\n", + " label_column='Classes',\n", + " sequence_column=\"Sequences\",\n", + " dataset_columns_to_keep=['Proteins'],\n", + " batch_size=BATCH_SIZE,\n", + " with_termini=False,\n", + " alphabet=aa_to_int_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "test_data_unlabeled[\"test\"][\"Classes\"][0:5]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O1uYK1ZWr7qZ" + }, + "source": [ + "## 2. Predicting and reporting\n", + "\n", + "We use the previously loaded model to generate predictions on the dataset. If labels are not available, you can utilize the `predictions_report` function to produce a clear and organized report based on these predictions. Note that the `predictions_report` function is specifically designed for scenarios where labels are not present." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Generate Predictions on Test Data Using `model.predict`\n", + "\n", + "To obtain predictions for your test data, use the Keras `model.predict` method. Simply pass your test dataset to this method, and it will return the model's predictions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predictions_unlabeled = model.predict(test_data_unlabeled.tensor_test_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To generate reports we obtain the data for the specific dataset split. This can be achieved using the `DetectabilityDataset` class directly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The dataframe needed for the report\n", + "\n", + "test_data_unlabeled_df = pd.DataFrame(\n", + " {\n", + " \"Sequences\": test_data_unlabeled[\"test\"][\"_parsed_sequence\"], # get the raw parsed sequences\n", + " \"Proteins\": test_data_unlabeled[\"test\"][\"Proteins\"] # get the Proteins column from the dataset object\n", + " }\n", + ")\n", + "\n", + "test_data_unlabeled_df.Sequences = test_data_unlabeled_df.Sequences.apply(lambda x: \"\".join(x)) # join the sequences since they are a list of string amino acids.\n", + "test_data_unlabeled_df.head(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Generate a report using the `predictions_report` class by providing the following parameters:\n", + "\n", + "- **predictions**: The model's output predictions for the dataset.\n", + "- **input_data_df**: The DataFrame containing the input data used for generating the predictions.\n", + "- **output_path**: The path where the generated report (in CSV format) will be saved.\n", + "- **rank_by_prot**: A boolean indicating whether to rank peptides based on their associated proteins (`True` or `False`). Defaults to `False`.\n", + "- **threshold**: The classification threshold used to adjust the decision boundary for predictions. By default, this is set to `None`, meaning no specific threshold is applied.\n", + "\n", + "The `predictions_report` class processes the model’s predictions and generates a comprehensive CSV report with the results, including any specified settings, which facilitates evaluation and interpretation of the predictions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_predictions_report = predictions_report(predictions = predictions_unlabeled, \n", + " input_data_df = test_data_unlabeled_df, \n", + " output_path = \"./output/report_on_Sinitcyn_2000_proteins_test_set_unlabeled\", \n", + " rank_by_prot = True,\n", + " threshold = None)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results_unlabeled_df = new_predictions_report.predictions_report\n", + "results_unlabeled_df" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "Example_RTModel_Walkthrough.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/Example_Detectability_Model_Walkthrough_training_and_fine_tuning.ipynb b/notebooks/Example_Detectability_Model_Walkthrough_training_and_fine_tuning.ipynb new file mode 100644 index 00000000..af8ffafc --- /dev/null +++ b/notebooks/Example_Detectability_Model_Walkthrough_training_and_fine_tuning.ipynb @@ -0,0 +1,1146 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "7YWkUVjVr7qJ" + }, + "source": [ + "# Peptide Detectability (Training and Fine-tuning) \n", + "\n", + "This notebook is prepared to be run in Google [Colaboratory](https://colab.research.google.com/). In order to train the model faster, please change the runtime of Colab to use Hardware Accelerator, either GPU or TPU." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "S3DlTOq3r7qM" + }, + "source": [ + "This notebook provides a concise walkthrough of the process for reading a dataset, training, and fine-tuning a model for peptide detectability prediction. \n", + "\n", + "The dataset used in this example is derived from:\n", + "\n", + "- **ProteomTools Dataset**: Includes data from the PRIDE repository with the following identifiers: `PXD004732`, `PXD010595`, and `PXD021013`.\n", + "- **MAssIVE Dataset**: Deposited in the ProteomeXchange Consortium via the MAssIVE partner repository with the identifier `PXD024364`.\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Installing the DLOmix Package\n", + "\n", + "If you have not installed the DLOmix package yet, you need to do so before running the code. \n", + "\n", + "You can install the DLOmix package using pip." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "aO-69zbKsGey", + "outputId": "c2064411-9f80-47e6-ca5b-312d547e0f6a", + "scrolled": true + }, + "outputs": [], + "source": [ + "# uncomment the following line to install the DLOmix package in the current environment using pip\n", + "\n", + "#!python -m pip install dlomix>0.1.3" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Mo7H9qzWr7qN" + }, + "source": [ + "#### Importing Required Libraries\n", + "\n", + "Before running the code, ensure you import all the necessary libraries. These imports are essential for accessing the functionalities needed for data processing, model training, and evaluation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "l0CS0tFur7qN", + "outputId": "664e0978-980a-4254-90d1-61e9f1603234" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import tensorflow as tf\n", + "import dlomix\n", + "import sys\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dlomix.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "41qXroyKr7qP" + }, + "source": [ + "## 1. Load Data for Training\n", + "\n", + "You can import the `DetectabilityDataset` class and create an instance to manage data for training, validation, and testing. This instance handles TensorFlow dataset objects and simplifies configuring and controlling how your data is preprocessed and split.\n", + "\n", + "For the paramters of the dataset class, please refer to the DLOmix documentation: https://dlomix.readthedocs.io/en/main/dlomix.data.html#\n", + "\n", + "\n", + "**Note**: If class labels are provided, the following encoding scheme should be used:\n", + "- **Non-Flyer**: 0\n", + "- **Weak Flyer**: 1\n", + "- **Intermediate Flyer**: 2\n", + "- **Strong Flyer**: 3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RiXz_epEr7qQ" + }, + "outputs": [], + "source": [ + "from dlomix.data import DetectabilityDataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dlomix.constants import CLASSES_LABELS, alphabet, aa_to_int_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CLASSES_LABELS, len(alphabet), aa_to_int_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "max_pep_length = 40\n", + "BATCH_SIZE = 128 \n", + " \n", + "# The Class handles all the inner details, we have to provide the column names and the alphabet for encoding\n", + "# If the data is already split with a specific logic (which is generally recommended) -> val_data_source and test_data_source are available as well\n", + "\n", + "hf_data = \"Wilhelmlab/detectability-proteometools\"\n", + "detectability_data = DetectabilityDataset(data_source=hf_data,\n", + " data_format='hub',\n", + " max_seq_len=max_pep_length,\n", + " label_column=\"Classes\",\n", + " sequence_column=\"Sequences\",\n", + " dataset_columns_to_keep=None,\n", + " batch_size=BATCH_SIZE,\n", + " with_termini=False,\n", + " alphabet=aa_to_int_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# This is the dataset with train, val, and test splits \n", + "# You can see the column names under each split (the columns starting with _ are internal, but can also be used to look up original sequences for example \"_parsed_sequence\")\n", + "detectability_data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Accessing elements in the dataset is done by specificing the split name and then the column name\n", + "# Example here for one sequence after encoding & padding comapred to the original sequence\n", + "\n", + "detectability_data[\"train\"][\"Sequences\"][0], \"\".join(detectability_data[\"train\"][\"_parsed_sequence\"][0])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oWeVi0iar7qT" + }, + "source": [ + "## 2. Model\n", + "\n", + "We can now create the model. The model architecture is an encoder-decoder with an attention mechanism, that is based on Bidirectional Recurrent Neural Network (BRNN) with Gated Recurrent Units (GRU). Both the Encoder and Decoder consists of a single layer, with the Decoder also including a Dense layer. The model has the default working arguments." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Q8SGTvfRr7qT" + }, + "outputs": [], + "source": [ + "from dlomix.models import DetectabilityModel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZqrsF6APr7qU" + }, + "outputs": [], + "source": [ + "total_num_classes = len(CLASSES_LABELS)\n", + "input_dimension = len(alphabet)\n", + "num_cells = 64\n", + "\n", + "model = DetectabilityModel(num_units = num_cells, num_clases = total_num_classes)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "adD60VwQr7qU" + }, + "source": [ + "## 3. Training and saving the model\n", + "\n", + "You can train the model using the standard Keras approach. The training parameters provided here are those initially configured for the detectability model. However, you have the flexibility to modify these parameters to suit your specific needs." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Compile the Model\n", + "\n", + "Compile the model with the selected settings. You can use built-in TensorFlow options or define and pass custom settings for the optimizer, loss function, and metrics. The default configurations match those used in the original study, but you can modify these settings according to your preferences.\n", + "\n", + "Early stopping is also configured with the original settings, but the parameters can be adjusted based on user preferences. Early stopping monitors a performance metric (e.g., validation loss) and halts training when no improvement is observed for a specified number of epochs. This feature helps prevent overfitting and ensures efficient training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xLy32wk7r7qU", + "outputId": "34f9961e-1abc-4f8f-904c-7aac4a404241" + }, + "outputs": [], + "source": [ + "callback = tf.keras.callbacks.EarlyStopping(monitor = 'val_loss', \n", + " mode = 'min', \n", + " verbose = 1, \n", + " patience = 5)\n", + "\n", + "\n", + "model_save_path = 'output/weights/new_base_model/base_model_weights_detectability'\n", + "\n", + "model_checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path,\n", + " monitor='val_sparse_categorical_accuracy',\n", + " mode='max',\n", + " verbose=1,\n", + " save_best_only=True, \n", + " save_weights_only=True)\n", + "\n", + "model.compile(optimizer='adam',\n", + " loss='SparseCategoricalCrossentropy', \n", + " metrics='sparse_categorical_accuracy')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wtEUn_vdr7qV" + }, + "source": [ + "We save the results of the training process to enable a detailed examination of the metrics and losses at a later stage. We define the number of epochs for training and supply the training and validation data previously generated. This approach allows us to effectively monitor the model’s performance and make any necessary adjustments." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "E14EcoYTr7qV", + "outputId": "9c88b2d5-e1cb-46b4-e263-73468e222554", + "scrolled": true + }, + "outputs": [], + "source": [ + "# Access to the tensorflow datasets is done by referencing the tensor_train_data or tensor_val_data\n", + "\n", + "history = model.fit(detectability_data.tensor_train_data,\n", + " validation_data = detectability_data.tensor_val_data,\n", + " epochs = 50, \n", + " callbacks=[callback, model_checkpoint])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oukZ4AyMr7qV" + }, + "source": [ + "## 4. Testing and Reporting\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We use the test dataset to assess our model's performance, which is only applicable if labels are available. The `DetectabilityReport` class allows us to compute various metrics, generate reports, and create plots for a comprehensive evaluation of the model.\n", + "\n", + "Note: The reporting module is currently under development, so some features may be unstable or subject to change.\n", + "\n", + "In the next cell, set the path to the model weights. By default, it points to the newly trained base model. If using different weights, update the path accordingly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# edit the path to save the trained model\n", + "model_save_path = 'output/weights/new_base_model/base_model_weights_detectability'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## Loading best model weights \n", + "\n", + "model.load_weights(model_save_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Generate Predictions on Test Data Using `model.predict`\n", + "\n", + "To obtain predictions for your test data, use the Keras `model.predict` method. Simply pass your test dataset to this method, and it will return the model's predictions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RrvR8Cl3r7qV" + }, + "outputs": [], + "source": [ + "predictions = model.predict(detectability_data.tensor_test_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predictions.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To generate reports and calculate evaluation metrics against predictions, we obtain the targets and the data for the specific dataset split. This can be achieved using the `DetectabilityDataset` class directly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# access val dataset and get the Classes column\n", + "test_targets = detectability_data[\"test\"][\"Classes\"]\n", + "\n", + "\n", + "# if needed, the decoded version of the classes can be retrieved by looking up the class names\n", + "test_targets_decoded = [CLASSES_LABELS[x] for x in test_targets]\n", + "\n", + "\n", + "test_targets[0:5], test_targets_decoded[0:5]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The dataframe needed for the report\n", + "\n", + "test_data_df = pd.DataFrame(\n", + " {\n", + " \"Sequences\": detectability_data[\"test\"][\"_parsed_sequence\"], # get the raw parsed sequences\n", + " \"Classes\": test_targets, # get the test targets from above\n", + "# \"Proteins\": detectability_data[\"test\"][\"Proteins\"] # get the Proteins column from the dataset object (if the dataset has \"Proteins\" column)\n", + " }\n", + ")\n", + "\n", + "test_data_df.Sequences = test_data_df.Sequences.apply(lambda x: \"\".join(x)) # join the sequences since they are a list of string amino acids.\n", + "test_data_df.head(5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4kzCh0gwr7qX" + }, + "outputs": [], + "source": [ + "from dlomix.reports.DetectabilityReport import DetectabilityReport, predictions_report\n", + "WANDB_REPORT_API_DISABLE_MESSAGE=True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Generate a Report Using the `DetectabilityReport` Class\n", + "\n", + "The `DetectabilityReport` class provides a comprehensive way to evaluate your model by generating detailed reports and visualizations. The outputs include:\n", + "\n", + "1. **A PDF Report**: This includes evaluation metrics and plots.\n", + "2. **A CSV File**: Contains the model’s predictions.\n", + "3. **Independent Image Files**: Visualizations are saved as separate image files.\n", + "\n", + "To generate a report, provide the following parameters to the `DetectabilityReport` class:\n", + "\n", + "- **targets**: The true labels for the dataset, which are used to assess the model’s performance.\n", + "- **predictions**: The model’s output predictions for the dataset, which will be compared against the true labels.\n", + "- **input_data_df**: The DataFrame containing the input data used for generating predictions.\n", + "- **output_path**: The directory path where the generated reports, images, and CSV file will be saved.\n", + "- **history**: The training history object (e.g., containing metrics from training) if available. Set this to `None` if not applicable, such as when the report is generated for predictions without training.\n", + "- **rank_by_prot**: A boolean indicating whether to rank peptides based on their associated proteins (`True` or `False`). Defaults to `False`.\n", + "- **threshold**: The classification threshold used to adjust the decision boundary for predictions. By default, this is set to `None`, meaning no specific threshold is applied.\n", + "- **name_of_dataset**: The name of the dataset used for generating predictions, which will be included in the report to provide context.\n", + "- **name_of_model**: The name of the model used to generate the predictions, which will be specified in the report for reference.\n", + "\n", + "Note: The reporting module is currently under development, so some features may be unstable or subject to change." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Since the detectabiliy report expects the true labels in one-hot encoded format, we expand them here.\n", + "\n", + "num_classes = np.max(test_targets) + 1\n", + "test_targets_one_hot = np.eye(num_classes)[test_targets]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_7LJZ3TLr7qX" + }, + "outputs": [], + "source": [ + "report = DetectabilityReport(targets = test_targets_one_hot, \n", + " predictions = predictions, \n", + " input_data_df = test_data_df,\n", + " output_path = \"./output/report_on_ProteomeTools\",\n", + " history = history, \n", + " rank_by_prot = False,\n", + " threshold = None,\n", + " name_of_dataset = 'ProteomeTools',\n", + " name_of_model = 'Base model (new)')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Predictions report" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results_df = report.detectability_report_table\n", + "results_df.head(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Generating Evaluation Plots with `DetectabilityReport`\n", + "\n", + "The `DetectabilityReport` class enables you to generate a range of plots to visualize and evaluate model performance. It offers a comprehensive suite of visualizations to help you interpret the results of your model's predictions. Here’s how to use it:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Training and Validation Metrics\n", + "\n", + "These plots show the training and validation metrics over epochs. The first plot displays the loss, and the second shows the categorical accuracy. Both plots are generated from the `history` object recorded during the model training process." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 295 + }, + "id": "1iI-_Nufr7qX", + "outputId": "25baa9f5-1d5b-47ed-d75a-def6a55e43bc" + }, + "outputs": [], + "source": [ + "report.plot_keras_metric(\"loss\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.plot_keras_metric(\"sparse_categorical_accuracy\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### ROC curve (Binary)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.plot_roc_curve_binary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Confusion matrix (Binary)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.plot_confusion_matrix_binary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### ROC curve (Multi-class)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.plot_roc_curve()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Confusion matrix (Multi-class)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.plot_confusion_matrix_multiclass()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Heatmap of Average Error Between Actual and Predicted Classes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.plot_heatmap_prediction_prob_error()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also produce a complete report with all the relevant plots in one PDF file by calling the `generate_report` function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report.generate_report()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example: Defining a Classification Threshold\n", + "\n", + "In the following example, a specific classification threshold is defined to adjust the decision boundary for the model's predictions. By setting a threshold, you can control the sensitivity of the model, influencing how it categorizes the output into different classes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report_using_threshold = DetectabilityReport(test_targets_one_hot, \n", + " predictions, \n", + " test_data_df, \n", + " output_path = \"./output/report_on_ProteomeTools_with_threshold\", \n", + " history = history, \n", + " rank_by_prot = False,\n", + " threshold = 0.5, \n", + " name_of_dataset = 'ProteomeTools',\n", + " name_of_model = 'Base model (new) with threshold')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Predictions report " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report_using_threshold.detectability_report_table.head(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generating a complete PDF report using the `generate_report` function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report_using_threshold.generate_report()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Load data for fine tuning" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For fine-tuning, the process mirrors the steps used during training. Simply create a `DetectabilityDataset` object with the fine-tuning data (refer to **Section 1: Load Data for Training**)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "max_pep_length = 40\n", + "BATCH_SIZE = 128 \n", + " \n", + "# The Class handles all the inner details, we have to provide the column names and the alphabet for encoding\n", + "# If the data is already split with a specific logic (which is generally recommended) -> val_data_source and test_data_source are available as well\n", + "\n", + "hf_data = \"Wilhelmlab/detectability-sinitcyn\"\n", + "fine_tune_data = DetectabilityDataset(data_source=hf_data,\n", + " data_format='hub',\n", + " max_seq_len=max_pep_length,\n", + " label_column=\"Classes\",\n", + " sequence_column=\"Sequences\",\n", + " dataset_columns_to_keep=['Proteins'],\n", + " batch_size=BATCH_SIZE,\n", + " with_termini=False,\n", + " alphabet=aa_to_int_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fine_tune_data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Fine tuning the model\n", + "\n", + "In the next cell, we create the model and load its weights for fine-tuning. By default, the path is set to the weights of the most recently trained base model. To use different weights, update the path to point to your desired model's weights." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# define again if not in environment from training run\n", + "load_model_path = model_save_path #'output/weights/new_base_model/base_model_weights_detectability'\n", + "\n", + "fine_tuned_model = DetectabilityModel(num_units = num_cells, \n", + " num_clases = total_num_classes)\n", + "\n", + "fine_tuned_model.load_weights(load_model_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Compile the Model\n", + "\n", + "Compile the model with the selected settings. You can use built-in TensorFlow options or define and pass custom settings for the optimizer, loss function, and metrics. The default configurations match those used in the original study, but you can modify these settings according to your preferences.Early stopping is also configured with the original settings, but the parameters can be adjusted based on user preferences." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# compile the model with the optimizer and the metrics we want to use.\n", + "callback_FT = tf.keras.callbacks.EarlyStopping(monitor = 'val_loss', \n", + " mode = 'min', \n", + " verbose = 1, \n", + " patience = 5)\n", + "\n", + "\n", + "model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability'\n", + "\n", + "model_checkpoint_FT = tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path_FT,\n", + " monitor='val_sparse_categorical_accuracy', \n", + " mode='max',\n", + " verbose=1,\n", + " save_best_only=True, \n", + " save_weights_only=True)\n", + "\n", + "fine_tuned_model.compile(optimizer='adam',\n", + " loss='SparseCategoricalCrossentropy', \n", + " metrics='sparse_categorical_accuracy') " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We store the result of training so that we can explore the metrics and the losses later. We specify the number of epochs for training and pass the training and validation data as previously described." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "history_fine_tuned = fine_tuned_model.fit(fine_tune_data.tensor_train_data,\n", + " validation_data=fine_tune_data.tensor_val_data,\n", + " epochs=50, \n", + " callbacks=[callback_FT, model_checkpoint_FT])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Testing and Reporting (Fine-Tuned Model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the following cell, we load the best model weights obtained from fine-tuning. By default, the path points to the most recently fine-tuned model from the previous cell. Update the path if you wish to load different weights." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## Loading best model weights \n", + "\n", + "model_save_path_FT = 'output/weights/new_fine_tuned_model/fine_tuned_model_weights_detectability'\n", + "\n", + "fine_tuned_model.load_weights(model_save_path_FT)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generating predictions on the test data using the fine-tuned model with `model.predict`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predictions_FT = fine_tuned_model.predict(fine_tune_data.tensor_test_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predictions_FT.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To generate reports and calculate evaluation metrics against predictions, we obtain the targets and the data for the specific dataset split. This can be achieved using the DetectabilityDataset class directly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# access val dataset and get the Classes column\n", + "test_targets_FT = fine_tune_data[\"test\"][\"Classes\"]\n", + "\n", + "\n", + "# if needed, the decoded version of the classes can be retrieved by looking up the class names\n", + "test_targets_decoded_FT = [CLASSES_LABELS[x] for x in test_targets_FT]\n", + "\n", + "\n", + "test_targets_FT[0:5], test_targets_decoded_FT[0:5]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The dataframe needed for the report\n", + "\n", + "test_data_df_FT = pd.DataFrame(\n", + " {\n", + " \"Sequences\": fine_tune_data[\"test\"][\"_parsed_sequence\"], # get the raw parsed sequences\n", + " \"Classes\": test_targets_FT, # get the test targets from above\n", + " \"Proteins\": fine_tune_data[\"test\"][\"Proteins\"] # get the Proteins column from the dataset object\n", + " }\n", + ")\n", + "\n", + "test_data_df_FT.Sequences = test_data_df_FT.Sequences.apply(lambda x: \"\".join(x)) # join the sequences since they are a list of string amino acids.\n", + "test_data_df_FT.head(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Creating a report object with the test targets, predictions, and history to generate metrics and plots for the fine-tuned model. For more details, refer to Section 4: Testing and Reporting, which provides a detailed description of the same process for the initial or base model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Since the detectabiliy report expects the true labels in one-hot encoded format, we expand them here. \n", + "\n", + "num_classes = np.max(test_targets_FT) + 1\n", + "test_targets_FT_one_hot = np.eye(num_classes)[test_targets_FT]\n", + "test_targets_FT_one_hot.shape, len(test_targets_FT)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report_FT = DetectabilityReport(test_targets_FT_one_hot, \n", + " predictions_FT, \n", + " test_data_df_FT, \n", + " output_path = './output/report_on_Sinitcyn (Fine-tuned model)', \n", + " history = history_fine_tuned, \n", + " rank_by_prot = True,\n", + " threshold = None, \n", + " name_of_dataset = 'Sinitcyn test dataset',\n", + " name_of_model = 'Fine tuned model (new)')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Predictions report (Fine-tuned model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results_df_FT = report_FT.detectability_report_table\n", + "results_df_FT" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generating a complete PDF report using the `generate_report` function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report_FT.generate_report()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Generating the Evaluation Plots for the Fine-Tuned Model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Training and Validation Metrics\n", + "\n", + "These plots show the training and validation metrics over epochs. The first plot displays the loss, and the second shows the categorical accuracy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report_FT.plot_keras_metric(\"loss\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report_FT.plot_keras_metric(\"sparse_categorical_accuracy\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### ROC curve (Binary)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report_FT.plot_roc_curve_binary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Confusion matrix (Binary)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report_FT.plot_confusion_matrix_binary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### ROC curve (Multi-class)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report_FT.plot_roc_curve()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Confusion matrix (Multi-class)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report_FT.plot_confusion_matrix_multiclass()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Heatmap of Average Error Between Actual and Predicted Classes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "report_FT.plot_heatmap_prediction_prob_error()" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "Example_RTModel_Walkthrough.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/pretrained_models/original_detectability_base_model/checkpoint b/pretrained_models/original_detectability_base_model/checkpoint new file mode 100644 index 00000000..98188d7d --- /dev/null +++ b/pretrained_models/original_detectability_base_model/checkpoint @@ -0,0 +1,2 @@ +model_checkpoint_path: "base_attention_model_es_final" +all_model_checkpoint_paths: "base_attention_model_es_final" diff --git a/pretrained_models/original_detectability_base_model/original_detectability_base_model.data-00000-of-00001 b/pretrained_models/original_detectability_base_model/original_detectability_base_model.data-00000-of-00001 new file mode 100644 index 00000000..920e403e Binary files /dev/null and b/pretrained_models/original_detectability_base_model/original_detectability_base_model.data-00000-of-00001 differ diff --git a/pretrained_models/original_detectability_base_model/original_detectability_base_model.index b/pretrained_models/original_detectability_base_model/original_detectability_base_model.index new file mode 100644 index 00000000..13e8c1bd Binary files /dev/null and b/pretrained_models/original_detectability_base_model/original_detectability_base_model.index differ diff --git a/pretrained_models/original_detectability_fine_tuned_model_FINAL/checkpoint b/pretrained_models/original_detectability_fine_tuned_model_FINAL/checkpoint new file mode 100644 index 00000000..71c50630 --- /dev/null +++ b/pretrained_models/original_detectability_fine_tuned_model_FINAL/checkpoint @@ -0,0 +1,2 @@ +model_checkpoint_path: "fine_tuned_weights_attention_model_FINAL_NON" +all_model_checkpoint_paths: "fine_tuned_weights_attention_model_FINAL_NON" diff --git a/pretrained_models/original_detectability_fine_tuned_model_FINAL/original_detectability_fine_tuned_model_FINAL.data-00000-of-00001 b/pretrained_models/original_detectability_fine_tuned_model_FINAL/original_detectability_fine_tuned_model_FINAL.data-00000-of-00001 new file mode 100644 index 00000000..c3298ea3 Binary files /dev/null and b/pretrained_models/original_detectability_fine_tuned_model_FINAL/original_detectability_fine_tuned_model_FINAL.data-00000-of-00001 differ diff --git a/pretrained_models/original_detectability_fine_tuned_model_FINAL/original_detectability_fine_tuned_model_FINAL.index b/pretrained_models/original_detectability_fine_tuned_model_FINAL/original_detectability_fine_tuned_model_FINAL.index new file mode 100644 index 00000000..2057cec2 Binary files /dev/null and b/pretrained_models/original_detectability_fine_tuned_model_FINAL/original_detectability_fine_tuned_model_FINAL.index differ diff --git a/src/dlomix/__init__.py b/src/dlomix/__init__.py index 4154a85e..5da66068 100644 --- a/src/dlomix/__init__.py +++ b/src/dlomix/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.1.3" +__version__ = "0.1.3dev" META_DATA = { "author": "Omar Shouman", diff --git a/src/dlomix/constants.py b/src/dlomix/constants.py index 2f3bfc0e..71d087be 100644 --- a/src/dlomix/constants.py +++ b/src/dlomix/constants.py @@ -88,3 +88,40 @@ "P[UNIMOD:35]": 53, "Y[UNIMOD:354]": 54, } + + +# ---- detectability_model_constants.py ---- +CLASSES_LABELS = ["Non-Flyer", "Weak Flyer", "Intermediate Flyer", "Strong Flyer"] + +alphabet = [ + "0", + "A", + "C", + "D", + "E", + "F", + "G", + "H", + "I", + "K", + "L", + "M", + "N", + "P", + "Q", + "R", + "S", + "T", + "V", + "W", + "Y", +] + +aa_to_int_dict = dict((aa, i) for i, aa in enumerate(alphabet)) + +int_to_aa_dict = dict((i, aa) for i, aa in enumerate(alphabet)) + +import numpy as np + +padding_char = np.zeros(len(alphabet)) +padding_char[0] = 1 diff --git a/src/dlomix/data/__init__.py b/src/dlomix/data/__init__.py index 048fa35f..f3d02e00 100644 --- a/src/dlomix/data/__init__.py +++ b/src/dlomix/data/__init__.py @@ -1,5 +1,6 @@ from .charge_state import ChargeStateDataset from .dataset import PeptideDataset, load_processed_dataset +from .detectability import DetectabilityDataset from .fragment_ion_intensity import FragmentIonIntensityDataset from .retention_time import RetentionTimeDataset @@ -9,4 +10,5 @@ "ChargeStateDataset", "PeptideDataset", "load_processed_dataset", + "DetectabilityDataset", ] diff --git a/src/dlomix/data/dataset.py b/src/dlomix/data/dataset.py index 42aa8614..dfa2ce10 100644 --- a/src/dlomix/data/dataset.py +++ b/src/dlomix/data/dataset.py @@ -322,8 +322,7 @@ def _decide_on_splitting(self): # two or more data sources provided -> no splitting in all cases if count_loaded_data_sources >= 2: - if self.val_data_source is not None: - self._is_predefined_split = True + self._is_predefined_split = True if self._is_predefined_split: warnings.warn( diff --git a/src/dlomix/data/detectability.py b/src/dlomix/data/detectability.py new file mode 100644 index 00000000..3bbfa42e --- /dev/null +++ b/src/dlomix/data/detectability.py @@ -0,0 +1,65 @@ +from typing import Callable, Dict, List, Optional, Union + +from ..constants import ALPHABET_UNMOD +from .dataset import PeptideDataset +from .dataset_config import DatasetConfig +from .dataset_utils import EncodingScheme + + +class DetectabilityDataset(PeptideDataset): + """ + A dataset class for handling Detectability prediction data. + + Args: + data_source (Optional[Union[str, List]]): The path or list of paths to the data source file(s). + val_data_source (Optional[Union[str, List]]): The path or list of paths to the validation data source file(s). + test_data_source (Optional[Union[str, List]]): The path or list of paths to the test data source file(s). + data_format (str): The format of the data source file(s). Default is "parquet". + sequence_column (str): The name of the column containing the peptide sequences. Default is "Sequences". + label_column (str): The name of the column containing the class labels. Default is "Classes". + val_ratio (float): The ratio of validation data to split from the training data. Default is 0.2. + max_seq_len (Union[int, str]): The maximum length of the peptide sequences. Default is 30. + dataset_type (str): The type of dataset to use. Default is "tf". + batch_size (int): The batch size for training and evaluation. Default is 256. + model_features (Optional[List[str]]): The list of features to use for the model. Default is None. + dataset_columns_to_keep (Optional[List[str]]): The list of columns to keep in the dataset. Default is ["Proteins"]. + features_to_extract (Optional[List[Union[Callable, str]]]): The list of features to extract from the dataset. Default is None. + pad (bool): Whether to pad the sequences to the maximum length. Default is True. + padding_value (int): The value to use for padding. Default is 0. + alphabet (Dict): The mapping of characters to integers for encoding the sequences. Default is ALPHABET_UNMOD. + with_termini (bool): Whether to add the N- and C-termini in the sequence column, even if they do not exist. Defaults to True. + encoding_scheme (Union[str, EncodingScheme]): The encoding scheme to use for encoding the sequences. Default is EncodingScheme.UNMOD. + processed (bool): Whether the data has been preprocessed. Default is False. + enable_tf_dataset_cache (bool): Flag to indicate whether to enable TensorFlow Dataset caching (call `.cahce()` on the generate TF Datasets). + disable_cache (bool): Whether to disable Hugging Face datasets caching. Default is False. + """ + + def __init__( + self, + data_source: Optional[Union[str, List]] = None, + val_data_source: Optional[Union[str, List]] = None, + test_data_source: Optional[Union[str, List]] = None, + data_format: str = "csv", + sequence_column: str = "Sequences", + label_column: str = "Classes", + val_ratio: float = 0.2, + max_seq_len: Union[int, str] = 40, + dataset_type: str = "tf", + batch_size: int = 256, + model_features: Optional[List[str]] = None, + dataset_columns_to_keep: Optional[List[str]] = ["Proteins"], + features_to_extract: Optional[List[Union[Callable, str]]] = None, + pad: bool = True, + padding_value: int = 0, + alphabet: Dict = ALPHABET_UNMOD, + with_termini: bool = True, + encoding_scheme: Union[str, EncodingScheme] = EncodingScheme.UNMOD, + processed: bool = False, + enable_tf_dataset_cache: bool = False, + disable_cache: bool = False, + auto_cleanup_cache: bool = True, + num_proc: Optional[int] = None, + batch_processing_size: int = 1000, + ): + kwargs = {k: v for k, v in locals().items() if k not in ["self", "__class__"]} + super().__init__(DatasetConfig(**kwargs)) diff --git a/src/dlomix/detectability_model_constants.py b/src/dlomix/detectability_model_constants.py new file mode 100644 index 00000000..807aef84 --- /dev/null +++ b/src/dlomix/detectability_model_constants.py @@ -0,0 +1,12 @@ +import numpy as np + +CLASSES_LABELS = ['Non-Flyer', 'Weak Flyer', 'Intermediate Flyer', 'Strong Flyer'] + +alphabet = ['0', 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'] + +aa_to_int_dict = dict((aa, i) for i, aa in enumerate(alphabet)) + +int_to_aa_dict = dict((i, aa) for i, aa in enumerate(alphabet)) + +padding_char = np.zeros(len(alphabet)) +padding_char[0] = 1 \ No newline at end of file diff --git a/src/dlomix/models/__init__.py b/src/dlomix/models/__init__.py index 5f6c4df6..8bf7ec47 100644 --- a/src/dlomix/models/__init__.py +++ b/src/dlomix/models/__init__.py @@ -1,5 +1,6 @@ from .base import * from .deepLC import * +from .detectability import * from .prosit import * __all__ = [ @@ -7,4 +8,5 @@ "PrositRetentionTimePredictor", "DeepLCRetentionTimePredictor", "PrositIntensityPredictor", + "DetectabilityModel", ] diff --git a/src/dlomix/models/detectability.py b/src/dlomix/models/detectability.py new file mode 100644 index 00000000..18fe7186 --- /dev/null +++ b/src/dlomix/models/detectability.py @@ -0,0 +1,140 @@ +import tensorflow as tf + +from ..constants import CLASSES_LABELS, padding_char + + +class DetectabilityModel(tf.keras.Model): + def __init__( + self, + num_units, + num_clases=len(CLASSES_LABELS), + name="autoencoder", + padding_char=padding_char, + **kwargs + ): + super(DetectabilityModel, self).__init__(name=name, **kwargs) + + self.num_units = num_units + self.num_clases = num_clases + self.padding_char = padding_char + self.alphabet_size = len(padding_char) + self.one_hot_encoder = tf.keras.layers.Lambda( + lambda x: tf.one_hot(tf.cast(x, "int32"), depth=self.alphabet_size) + ) + self.encoder = Encoder(self.num_units) + self.decoder = Decoder(self.num_units, self.num_clases) + + def call(self, inputs): + onehot_inputs = self.one_hot_encoder(inputs) + enc_outputs, enc_state_f, enc_state_b = self.encoder(onehot_inputs) + + dec_outputs = tf.concat([enc_state_f, enc_state_b], axis=-1) + + decoder_inputs = { + "decoder_outputs": dec_outputs, + "state_f": enc_state_f, + "state_b": enc_state_b, + "encoder_outputs": enc_outputs, + } + + decoder_output = self.decoder(decoder_inputs) + + return decoder_output + + +class Encoder(tf.keras.layers.Layer): + def __init__(self, units, name="encoder", **kwargs): + super(Encoder, self).__init__(name=name, **kwargs) + + self.units = units + + self.mask_enco = tf.keras.layers.Masking(mask_value=padding_char) + + self.encoder_gru = tf.keras.layers.GRU( + self.units, + return_sequences=True, + return_state=True, + recurrent_initializer="glorot_uniform", + ) + + self.encoder_bi = tf.keras.layers.Bidirectional(self.encoder_gru) + + def call(self, inputs): + mask_ = self.mask_enco.compute_mask(inputs) + + mask_bi = self.encoder_bi.compute_mask(inputs, mask_) + + encoder_outputs, encoder_state_f, encoder_state_b = self.encoder_bi( + inputs, initial_state=None, mask=mask_bi + ) + + return encoder_outputs, encoder_state_f, encoder_state_b + + +class BahdanauAttention(tf.keras.layers.Layer): + def __init__(self, units, name="attention_layer", **kwargs): + super(BahdanauAttention, self).__init__(name=name, **kwargs) + self.W1 = tf.keras.layers.Dense(units) + self.W2 = tf.keras.layers.Dense(units) + self.V = tf.keras.layers.Dense(1) + + def call(self, inputs): + query = inputs["query"] + values = inputs["values"] + + query_with_time_axis = tf.expand_dims(query, axis=1) + + scores = self.V(tf.nn.tanh(self.W1(query_with_time_axis) + self.W2(values))) + + attention_weights = tf.nn.softmax(scores, axis=1) + + context_vector = attention_weights * values + + context_vector = tf.reduce_sum(context_vector, axis=1) + + return context_vector + + +class Decoder(tf.keras.layers.Layer): + def __init__(self, units, num_classes, name="decoder", **kwargs): + super(Decoder, self).__init__(name=name, **kwargs) + self.units = units + self.num_classes = num_classes + + self.decoder_gru = tf.keras.layers.GRU( + self.units, return_state=True, recurrent_initializer="glorot_uniform" + ) + + self.attention = BahdanauAttention(self.units) + + self.decoder_bi = tf.keras.layers.Bidirectional(self.decoder_gru) + + self.decoder_dense = tf.keras.layers.Dense( + self.num_classes, activation=tf.nn.softmax + ) + + def call(self, inputs): + decoder_outputs = inputs["decoder_outputs"] + state_f = inputs["state_f"] + state_b = inputs["state_b"] + encoder_outputs = inputs["encoder_outputs"] + + states = [state_f, state_b] + + attention_inputs = {"query": decoder_outputs, "values": encoder_outputs} + + context_vector = self.attention(attention_inputs) + + context_vector = tf.expand_dims(context_vector, axis=1) + + x = context_vector + + ( + decoder_outputs, + decoder_state_forward, + decoder_state_backward, + ) = self.decoder_bi(x, initial_state=states) + + x = self.decoder_dense(decoder_outputs) + # x = tf.expand_dims(x, axis = 1) + return x diff --git a/src/dlomix/reports/DetectabilityReport.py b/src/dlomix/reports/DetectabilityReport.py new file mode 100644 index 00000000..af55d7d1 --- /dev/null +++ b/src/dlomix/reports/DetectabilityReport.py @@ -0,0 +1,818 @@ +# -*- coding: utf-8 -*- + +import os +from itertools import cycle +from os.path import join + +import numpy as np +import pandas as pd +import seaborn as sns +from matplotlib import pyplot as plt +from sklearn.metrics import ConfusionMatrixDisplay, auc, confusion_matrix, roc_curve + +from ..constants import CLASSES_LABELS +from .Report import PDFFile, Report + + +class DetectabilityReport(Report): + """Report generation for Detectability Prediction tasks.""" + + def __init__( + self, + targets, + predictions, + input_data_df, + output_path, + history, + rank_by_prot=False, + threshold=None, + figures_ext="png", + name_of_dataset="unspecified", + name_of_model="unspecified", + ): + super(DetectabilityReport, self).__init__(output_path, history, figures_ext) + + self.pdf_file = PDFFile("DLOmix - Detectability Report") + + self.predictions = predictions + self.test_size = self.predictions.shape[ + 0 + ] # Removing the last part of the test data which don't fit the batch size + self.targets = targets[: self.test_size] + self.input_data_df = input_data_df.loc[: self.test_size - 1] + self.output_path = output_path + self.rank_by_prot = rank_by_prot + self.threshold = threshold + self.name_of_dataset = name_of_dataset + self.name_of_model = name_of_model + self.results_metrics_dict = None + self.results_report_df = None + self.detectability_report_table = None + + if not os.path.exists(self.output_path): + os.makedirs(self.output_path) + + self.results_dict_and_df() + + def generate_report(self, **kwargs): + self._init_report_resources() + + self._add_report_resource( + "name of dataset", + "Dataset", + f"The dataset used is {self.name_of_dataset}\n", + self.name_of_dataset, + ) + + self._add_report_resource( + "name of model", + "Model", + f"The model used to make the prediction is {self.name_of_model}\n", + self.name_of_model, + ) + self._add_report_resource( + "binary_accuracy", + "Binary accuracy", + f"The Binary Accuracy value for the predictions is {round(self.results_metrics_dict['binary_accuracy'], 4)}\n", + self.results_metrics_dict["binary_accuracy"], + ) + + self._add_report_resource( + "categorical_accuracy", + "Categorical Accuracy", + f"The Categorical Accuracy value for the predictions is {round(self.results_metrics_dict['categorical_accuracy'], 4)}\n", + self.results_metrics_dict["categorical_accuracy"], + ) + + self._add_report_resource( + "true_positive_rate", + "True Positive Rate (Recall)", + f"The True Positive Rate (Recall) value for the predictions is {round(self.results_metrics_dict['true_positive_rate'], 4)}\n", + self.results_metrics_dict["true_positive_rate"], + ) + + self._add_report_resource( + "false_positive_rate", + "False Positive Rate (Specificity)", + f"The False Positive Rate (Specificity) value for the predictions is {round(self.results_metrics_dict['false_positive_rate'], 4)}\n", + self.results_metrics_dict["false_positive_rate"], + ) + + self._add_report_resource( + "precision", + "Precision", + f"The Presicion value for the predictions is {round(self.results_metrics_dict['precision'], 4)}\n", + self.results_metrics_dict["precision"], + ) + + self._add_report_resource( + "f1_score", + "F1 Score", + f"The F1 Score value for the predictions is {round(self.results_metrics_dict['f1_score'], 4)}\n", + self.results_metrics_dict["f1_score"], + ) + + self._add_report_resource( + "MCC", + "Matthews Correlation Coefficient (MCC)", + f"The Matthews Correlation Coefficient (MCC) value for the predictions is {round(self.results_metrics_dict['MCC'], 4)}\n", + self.results_metrics_dict["MCC"], + ) + + self.plot_all_metrics() + self.plot_roc_curve_binary() + self.plot_confusion_matrix_binary() + self.plot_roc_curve() + self.plot_confusion_matrix_multiclass() + self.plot_heatmap_prediction_prob_error() + self._compile_report_resources_add_pdf_pages() + self.pdf_file.output( + join(self._output_path, "Detectability_evaluation_report.pdf"), "F" + ) + + def results_dict_and_df(self): + eval_result = evaluation_results( + self.predictions, self.targets, threshold=self.threshold + ) + self.results_metrics_dict = eval_result.eval_results + + target_labels = { + 0: "Non-Flyer", + 1: "Weak Flyer", + 2: "Intermediate Flyer", + 3: "Strong Flyer", + } + binary_labels = {0: "Non-Flyer", 1: "Flyer"} + + df_data_results = self.input_data_df.copy().reset_index(drop=True) + df_data_results = pd.concat( + [ + df_data_results, + pd.DataFrame( + self.results_metrics_dict["predictions"], columns=["Predictions"] + ), + ], + axis=1, + ) + + for i, label in enumerate(CLASSES_LABELS): + df_data_results[label] = np.round_( + np.array(self.results_metrics_dict["probabilities"])[:, i], decimals=3 + ) + + df_data_results["Flyer"] = ( + df_data_results["Weak Flyer"] + + df_data_results["Intermediate Flyer"] + + df_data_results["Strong Flyer"] + ) + # df_data_results['Flyer'] = round(df_data_results['Flyer'], ndigits = 3) + df_data_results["Binary Predictions"] = np.where( + df_data_results["Predictions"] == 0, 0, 1 + ) + df_data_results["Binary Classes"] = np.where( + df_data_results["Classes"] == 0, 0, 1 + ) + + sorted_columns = [ + "Sequences", + "Proteins", + "Weak Flyer", + "Intermediate Flyer", + "Strong Flyer", + "Non-Flyer", + "Flyer", + "Classes", + "Predictions", + "Binary Classes", + "Binary Predictions", + ] + + all_columns = [x for x in sorted_columns if x in df_data_results.columns] + + df_data_results = df_data_results[all_columns] + df_final_results = df_data_results.copy() + + if "Proteins" in df_final_results.columns and self.rank_by_prot: + df_final_results = df_final_results.sort_values( + by=[ + "Proteins", + "Flyer", + "Predictions", + "Strong Flyer", + "Intermediate Flyer", + "Weak Flyer", + ], + ascending=[True, False, False, False, False, False], + ).reset_index(drop=True) + + df_final_results["Rank"] = ( + df_final_results.groupby("Proteins")["Flyer"] + .rank(ascending=False, method="first") + .astype(int) + ) + # df_final_results['Rank_2'] = df_final_results.groupby('Proteins')['Flyer'].rank(ascending = False, method = 'dense').astype(int) + + else: + df_final_results = df_final_results.sort_values( + by=[ + "Flyer", + "Predictions", + "Strong Flyer", + "Intermediate Flyer", + "Weak Flyer", + ], + ascending=[False, False, False, False, False], + ).reset_index(drop=True) + + df_final_results["Rank"] = ( + df_final_results["Flyer"] + .rank(ascending=False, method="first") + .astype(int) + ) + # df_final_results['Rank_2'] = df_final_results['Flyer'].rank(ascending = False, method = 'dense').astype(int) + + df_final_results["Classes"] = df_final_results["Classes"].map(target_labels) + df_final_results["Binary Classes"] = df_final_results["Binary Classes"].map( + binary_labels + ) + df_final_results["Predictions"] = df_final_results["Predictions"].map( + target_labels + ) + df_final_results["Binary Predictions"] = df_final_results[ + "Binary Predictions" + ].map(binary_labels) + + self.results_report_df = df_data_results + self.detectability_report_table = df_final_results + save_path_ = join(self.output_path, "Dectetability_prediction_report.csv") + self.detectability_report_table.to_csv(save_path_, index=False) + + def plot_roc_curve_binary(self): + """Plot ROC curve (Binary classification) + + Arguments + ---------- + binary_targets: Array with binary target values + binary_predictions_prob: Array with binary prediction probability values + """ + + fpr, tpr, thresholds = roc_curve( + np.array(self.results_report_df["Binary Classes"]), + np.array(self.results_report_df["Flyer"]), + ) + AUC_score = auc(fpr, tpr) + + # create ROC curve + + plt.plot(fpr, tpr) + plt.title( + "Receiver operating characteristic curve (Binary classification)", + y=1.04, + fontsize=10, + ) + plt.ylabel("True Positive Rate") + plt.xlabel("False Positive Rate") + save_path = join( + self._output_path, "ROC_curve_binary_classification" + self._figures_ext + ) + + plt.savefig(save_path, bbox_inches="tight", dpi=90) + plt.show() + plt.close() + + self._add_report_resource( + "roc_curve_binary", + "ROC curve (Binary classification)", + "The following plot shows the Receiver operating characteristic (ROC) curve for the binary classification.", + save_path, + ) + + self._add_report_resource( + "AUC_binary_score", + "AUC Binary Score", + f"The AUC score value for the binary classification is {round(AUC_score, 4)}", + AUC_score, + ) + + def plot_confusion_matrix_binary(self): + """Plot confusion matrix (Binary classification) + + Arguments + ---------- + binary_targets: Array-like of shape (n_samples,) with binary target values + binary_predictions_prob: Array-like of shape (n_samples,) with binary prediction classes (not probabilities) values + """ + conf_matrix = confusion_matrix( + self.results_report_df["Binary Classes"], + self.results_report_df["Binary Predictions"], + ) + + conf_matrix + conf_matrix_disp = ConfusionMatrixDisplay( + confusion_matrix=conf_matrix, display_labels=["Non-Flyer", "Flyer"] + ) + fig, ax = plt.subplots() + conf_matrix_disp.plot(xticks_rotation=45, ax=ax) + plt.title("Confusion Matrix (Binary Classification)", y=1.04, fontsize=11) + save_path = join( + self._output_path, "confusion_matrix_binary" + self._figures_ext + ) + plt.savefig(save_path, bbox_inches="tight", dpi=80) + plt.show() + plt.close() + + self._add_report_resource( + "confusion_matrix_binary", + "Confusion Matrix (Binary Classification)", + "The following plot shows the Confusion Matrix (Binary Classification).", + save_path, + ) + + def plot_roc_curve(self): + """Plot ROC curve (Multiclass classification) + + Arguments + ---------- + multiclass_targets: Array with multiclass targets values + multiclass_predictions_prob: Array with multiclass prediction probability values + """ + # Compute ROC curve and ROC area for each class + fpr = {} + tpr = {} + roc_auc = {} + + for i in range(len(CLASSES_LABELS)): + fpr[i], tpr[i], _ = roc_curve( + np.squeeze(self.targets)[:, i], np.squeeze(self.predictions)[:, i] + ) + roc_auc[i] = auc(fpr[i], tpr[i]) + + lw = 2 + colors = cycle(["blue", "red", "green", "gold"]) + for i, color in zip(range(len(CLASSES_LABELS)), colors): + plt.plot( + fpr[i], + tpr[i], + color=color, + lw=2, + label="ROC curve of class {0} (area = {1:0.2f})" + "".format(CLASSES_LABELS[i], roc_auc[i]), + ) + + plt.plot([0, 1], [0, 1], "k--", lw=lw) + plt.xlim([-0.05, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel("False Positive Rate") + plt.ylabel("True Positive Rate") + plt.title("ROC curve for multi-class data", y=1.04, fontsize=10) + plt.legend(loc="best", fontsize="small") + save_path = join( + self._output_path, "ROC_curve_multiclass_classification" + self._figures_ext + ) + plt.savefig(save_path, bbox_inches="tight", dpi=90) + plt.show() + plt.close() + + self._add_report_resource( + "roc_curve_multiclass", + "ROC curve (Multiclass classification)", + "The following plot shows the Receiver operating characteristic (ROC) curve for the multiclass classification.", + save_path, + ) + + self._add_report_resource( + "AUC_multiclass_score", + "AUC Multiclass Score", + f"The AUC score value for the multiclass classification is: {CLASSES_LABELS[0]}: {round(roc_auc[0], 4)}, {CLASSES_LABELS[1]}: {round(roc_auc[1], 4)},\ + {CLASSES_LABELS[2]}: {round(roc_auc[2], 4)}, {CLASSES_LABELS[3]}: {round(roc_auc[3], 4)}.", + roc_auc, + ) + + def plot_confusion_matrix_multiclass(self): + """Plot confusion matrix (Multiclass classification) + + Arguments + ---------- + multiclass_targets: Array-like of shape (n_samples,) with multiclass target values + multiclass_predictions: Array-like of shape (n_samples,) with multiclass prediction classes (not probabilities) values + """ + + multi_conf_matrix = confusion_matrix( + self.results_report_df["Classes"], self.results_report_df["Predictions"] + ) + + conf_matrix_disp = ConfusionMatrixDisplay( + confusion_matrix=multi_conf_matrix, display_labels=CLASSES_LABELS + ) # + fig, ax = plt.subplots() + conf_matrix_disp.plot(xticks_rotation=45, ax=ax) + plt.title( + "Confusion Matrix (Multiclass Classification)", y=1.04, fontsize=11 + ) # , y=1.12 + save_path = join( + self._output_path, "confusion_matrix_multiclass" + self._figures_ext + ) + plt.savefig(save_path, bbox_inches="tight", dpi=80) + plt.show() + plt.close() + + self._add_report_resource( + "confusion_matrix_multiclass", + "Confusion Matrix (Multiclass Classification)", + "The following plot shows the Confusion Matrix (Multiclass Classification).", + save_path, + ) + + def plot_heatmap_prediction_prob_error(self): + """Plot Heatmap of average error between probabilities of real classes vs predicted + + Arguments + ---------- + dict_of_prob_difference: Dictionary containing the average difference between the predicted probabilities of + the predicted classes and the predicted probabilities of real classes + + """ + probability_var = {} + # probability_var_with_std = {} + + for k, v in self.results_metrics_dict["delta_prob_pred"].items(): + probability_var[k] = {} + # probability_var_with_std[k] = {} + + for m, n in self.results_metrics_dict["delta_prob_pred"][k].items(): + # probability_var_with_std[k][m] = {} + + probability_var[k][m] = round( + np.mean(self.results_metrics_dict["delta_prob_pred"][k][m]), + ndigits=3, + ) + + # probability_var_with_std[k][m]['mean'] = round(np.mean(self.results_metrics_dict['delta_prob_pred'][k][m]), ndigits = 3) + # probability_var_with_std[k][m]['std'] = round(np.std(self.results_metrics_dict['delta_prob_pred'][k][m]), ndigits = 3) + + df_probability_var = pd.DataFrame(probability_var) # \ + df_probability_var.columns = CLASSES_LABELS + df_probability_var.index = CLASSES_LABELS + sns.heatmap(df_probability_var, cmap="viridis", linewidths=0.05, annot=True) + plt.yticks(rotation=0) + plt.xticks(rotation=45) + plt.title( + "Heatmap of average error between probabilities of real classes vs predicted", + y=1.04, + fontsize=11, + ) # , y=1.12 + save_path = join( + self._output_path, "heatmap_prediction_prob_error" + self._figures_ext + ) + plt.savefig(save_path, bbox_inches="tight", dpi=80) + plt.show() + plt.close() + + self._add_report_resource( + "heatmap_prediction_prob_error", + "Average error between probabilities of real classes vs predicted", + "The following plot shows the average error between probabilities of real classes vs predicted.", + save_path, + ) + + +class evaluation_results: + def __init__( + self, + predictions, + targets, + num_clases=len(CLASSES_LABELS), + threshold=None, + print_results=True, + ): + super(evaluation_results, self).__init__() + + self.predictions = predictions + self.targets = targets + self.num_clases = num_clases + self.threshold = threshold + self.print_results = print_results + self.all_pred = None + + self.correct_p = 0 + self.incorrect_p = 0 + + self.TP = 0 + self.TN = 0 + self.FP = 0 + self.FN = 0 + + self.results_dict = {} + self.diff_prob_dict = {} + self.eval_results = {} + self.evaluate() + + def evaluate(self): + predicted = np.empty(len(self.predictions)) + + for i in range(self.num_clases): + I = str(i) + + self.results_dict[I] = {} + self.diff_prob_dict[I] = {} + + for j in range(self.num_clases): + J = str(j) + + self.results_dict[I][J] = 0 + self.diff_prob_dict[I][J] = [] + + if self.threshold: + thresh_pred = np.squeeze(self.predictions) + + index_1 = thresh_pred[:, 0] >= self.threshold + index_2 = thresh_pred[:, 0] < self.threshold + + predicted[index_1] = 0 + predicted[index_2] = np.argmax(thresh_pred[:, 1:][index_2], axis=-1) + 1 + + else: + predicted = np.argmax(self.predictions, axis=-1) + + probabilities = np.squeeze(self.predictions) + + expected = self.targets + + expected_str = np.squeeze([str(x) for x in np.argmax(expected, axis=-1)]) + expected_int = np.squeeze(np.argmax(expected, axis=-1)) + predicted_str = np.array([str(int(x)) for x in predicted]) + predicted = np.array([int(x) for x in predicted]) + + self.all_pred = predicted + + for i in range(len(predicted)): + self.results_dict[expected_str[i]][predicted_str[i]] += 1 + + diff_prob = np.absolute( + probabilities[i, expected_int[i]] - probabilities[i, predicted[i]] + ) + + self.diff_prob_dict[expected_str[i]][predicted_str[i]].append(diff_prob) + + correct_p = sum(expected_int == predicted) + + incorrect_p = sum(expected_int != predicted) + + non_flyer_index = np.array(expected_int == 0) + + self.TN = int(sum(expected_int[non_flyer_index] == predicted[non_flyer_index])) + + self.FP = int(sum(expected_int[non_flyer_index] != predicted[non_flyer_index])) + + flyer_index = np.array(expected_int != 0) + + self.TP = int(sum(predicted[flyer_index] != 0)) + + self.FN = int(sum(predicted[flyer_index] == 0)) + + if (self.TP + self.TN + self.FP + self.FN) > 0: + binary_accuracy = (self.TP + self.TN) / ( + self.TP + self.TN + self.FP + self.FN + ) + else: + binary_accuracy = None + + if (correct_p + incorrect_p) > 0: + overall_accuracy = correct_p / (correct_p + incorrect_p) + else: + overall_accuracy = None + + if (self.TP + self.FN) > 0: + true_positive_rate = self.TP / (self.TP + self.FN) + else: + true_positive_rate = None + + if (self.TN + self.FP) > 0: + false_positive_rate = self.TN / (self.TN + self.FP) + else: + false_positive_rate = None + + if (self.TP + self.FP) > 0: + precision = self.TP / (self.TP + self.FP) + else: + precision = None + + if precision > 0 and true_positive_rate > 0: + f_score = 2 / ((1 / precision) + (1 / true_positive_rate)) + else: + f_score = None + + if ( + (self.TP + self.FP) + * (self.TP + self.FN) + * (self.TN + self.FP) + * (self.TN + self.FN) + ) > 0: + MCC = ((self.TP * self.TN) - (self.FP * self.FN)) / np.sqrt( + float( + (self.TP + self.FP) + * (self.TP + self.FN) + * (self.TN + self.FP) + * (self.TN + self.FN) + ) + ) + else: + MCC = None + + conf_matrix = {"TP": self.TP, "TN": self.TN, "FP": self.FP, "FN": self.FN} + + self.eval_results = { + "predictions": predicted, + "probabilities": probabilities, + "categorical_accuracy": overall_accuracy, + "binary_accuracy": binary_accuracy, + "true_positive_rate": true_positive_rate, + "false_positive_rate": false_positive_rate, + "precision": precision, + "f1_score": f_score, + "MCC": MCC, + "conf_matrix": conf_matrix, + "results_dict": self.results_dict, + "delta_prob_pred": self.diff_prob_dict, + } + + if self.print_results: + print( + f'Binary Accuracy: {round(self.eval_results["binary_accuracy"], ndigits = 2)}' + ) + + print( + f'\nCategorical Accuracy: {round(self.eval_results["categorical_accuracy"], ndigits = 2)}' + ) + + print( + f'\nMatthews Correlation Coefficient (MCC): {round(self.eval_results["MCC"], ndigits = 2)}' + ) + + print( + f'\nTrue Positive Rate (Recall): {round(self.eval_results["true_positive_rate"], ndigits = 2)}' + ) + + print( + f'\nFalse Positive Rate (Specificity): {round(self.eval_results["false_positive_rate"], ndigits = 2)}' + ) + + print(f'\nPrecision: {round(self.eval_results["precision"], ndigits = 2)}') + + print(f'\nF1 Score: {round(self.eval_results["f1_score"], ndigits = 2)}') + + +class predictions_report: + def __init__( + self, + predictions, + input_data_df, + output_path, + num_clases=len(CLASSES_LABELS), + rank_by_prot=False, + threshold=None, + ): + super(predictions_report, self).__init__() + + self.predictions = np.squeeze(predictions) + self.test_size = self.predictions.shape[ + 0 + ] # Removing the last part of the test data which don't fit the batch size + self.input_data_df = input_data_df.loc[: self.test_size - 1] + self.output_path = output_path + self.num_clases = num_clases + self.rank_by_prot = rank_by_prot + self.threshold = threshold + + if not os.path.exists(self.output_path): + os.makedirs(self.output_path) + + self.all_pred = None + self.predictions_report = None + self.evaluate() + + def evaluate(self): + predicted = np.empty(len(self.predictions)) + + target_labels = { + 0: "Non-Flyer", + 1: "Weak Flyer", + 2: "Intermediate Flyer", + 3: "Strong Flyer", + } + binary_labels = {0: "Non-Flyer", 1: "Flyer"} + + if self.threshold: + thresh_pred = np.squeeze(self.predictions) + + index_1 = thresh_pred[:, 0] >= self.threshold + index_2 = thresh_pred[:, 0] < self.threshold + + predicted[index_1] = 0 + predicted[index_2] = np.argmax(thresh_pred[:, 1:][index_2], axis=-1) + 1 + + else: + predicted = np.argmax(self.predictions, axis=-1) + predicted = [x for x in predicted] + + self.all_pred = predicted + + df_data_results = self.input_data_df.copy().reset_index(drop=True) + df_data_results = pd.concat( + [df_data_results, pd.DataFrame(self.all_pred, columns=["Predictions"])], + axis=1, + ) + + for i, label in enumerate(CLASSES_LABELS): + df_data_results[label] = np.round_( + np.array(self.predictions)[:, i], decimals=3 + ) + + df_data_results["Flyer"] = ( + df_data_results["Weak Flyer"] + + df_data_results["Intermediate Flyer"] + + df_data_results["Strong Flyer"] + ) + # df_data_results['Flyer'] = round(df_data_results['Flyer'], ndigits = 3) + df_data_results["Flyer"] = np.round_(df_data_results["Flyer"], decimals=3) + df_data_results["Binary Predictions"] = np.where( + df_data_results["Predictions"] == 0, 0, 1 + ) + + if "Classes" in df_data_results.columns: + df_data_results["Binary Classes"] = np.where( + df_data_results["Classes"] == 0, 0, 1 + ) + df_data_results["Classes"] = df_data_results["Classes"].map(target_labels) + df_data_results["Binary Classes"] = df_data_results["Binary Classes"].map( + binary_labels + ) + + sorted_columns = [ + "Sequences", + "Proteins", + "Weak Flyer", + "Intermediate Flyer", + "Strong Flyer", + "Non-Flyer", + "Flyer", + "Classes", + "Predictions", + "Binary Classes", + "Binary Predictions", + ] + + all_columns = [x for x in sorted_columns if x in df_data_results.columns] + + df_data_results = df_data_results[all_columns] + + if "Proteins" in df_data_results.columns and self.rank_by_prot: + df_data_results = df_data_results.sort_values( + by=[ + "Proteins", + "Flyer", + "Predictions", + "Strong Flyer", + "Intermediate Flyer", + "Weak Flyer", + ], + ascending=[True, False, False, False, False, False], + ).reset_index(drop=True) + + df_data_results["Rank"] = ( + df_data_results.groupby("Proteins")["Flyer"] + .rank(ascending=False, method="first") + .astype(int) + ) + # df_data_results['Rank_2'] = df_data_results.groupby('Proteins')['Flyer'].rank(ascending = False, method = 'dense').astype(int) + + else: + df_data_results = df_data_results.sort_values( + by=[ + "Flyer", + "Predictions", + "Strong Flyer", + "Intermediate Flyer", + "Weak Flyer", + ], + ascending=[False, False, False, False, False], + ).reset_index(drop=True) + + df_data_results["Rank"] = ( + df_data_results["Flyer"] + .rank(ascending=False, method="first") + .astype(int) + ) + # df_data_results['Rank_2'] = df_data_results['Flyer'].rank(ascending = False, method = 'dense').astype(int) + + df_data_results["Predictions"] = df_data_results["Predictions"].map( + target_labels + ) + df_data_results["Binary Predictions"] = df_data_results[ + "Binary Predictions" + ].map(binary_labels) + + self.predictions_report = df_data_results + + save_path = join(self.output_path, "Dectetability_prediction_report.csv") + self.predictions_report.to_csv(save_path, index=False) diff --git a/src/dlomix/reports/Report.py b/src/dlomix/reports/Report.py index c9705b0d..05d84967 100644 --- a/src/dlomix/reports/Report.py +++ b/src/dlomix/reports/Report.py @@ -1,3 +1,5 @@ +# -*- coding: utf-8 -*- + import abc import glob import warnings @@ -10,33 +12,13 @@ class Report(abc.ABC): - """Base class for reports. - - Child classes should implement the abstract method `generate_report`. + """Base class for reports, child classes should implement the abstract method generate_report. Parameters ---------- - output_path : str - Path to save output files and figures. - history : tf.keras.callbacks.History or dict - Reference to a Keras History object or its history dict attribute (History.history). - figures_ext : str - File extension and format for saving figures. - - Attributes - ---------- - VALID_FIGURE_FORMATS : list - List of valid figure formats. - - Methods - ------- - plot_keras_metric(metric_name, save_plot=True) - Plot a Keras metric given its name and the history object. - plot_all_metrics() - Plot all available Keras metrics in the History object. - generate_report(targets, predictions, **kwargs) - Abstract method to generate a complete report. - + output_path: path to save output files and figures. + history : reference to a Keras History object or its history dict attribute (History.history). + figures_ext: File extension and format for saving figures. """ VALID_FIGURE_FORMATS = ["pdf", "jpeg", "jpg", "png"] @@ -114,14 +96,12 @@ def _compile_report_resources_add_pdf_pages(self): ) def plot_keras_metric(self, metric_name, save_plot=True): - """Plot a Keras metric given its name and the history object. + """Plot a keras metric given its name and the history object returned by model.fit() Arguments --------- - metric_name : str - The name of the metric. - save_plot : bool, optional - Whether to save the plot to disk or not. Defaults to True. + metric_name: String with the name of the metric. + save_plot (bool, optional): whether to save plot to disk or not. Defaults to True. """ if metric_name.lower() not in self._history_dict.keys(): @@ -137,13 +117,15 @@ def plot_keras_metric(self, metric_name, save_plot=True): ) plt.plot(self._history_dict[metric_name]) plt.plot(self._history_dict["val_" + metric_name]) - plt.title(metric_name) + plt.title(metric_name, fontsize=10) # Modified Original plt.title(metric_name) plt.ylabel(metric_name) plt.xlabel("epoch") - plt.legend(["train", "val"], loc="upper left") + plt.legend(["train", "val"], loc="best") if save_plot: save_path = join(self._output_path, metric_name + self._figures_ext) - plt.savefig(save_path) + plt.savefig( + save_path, bbox_inches="tight", dpi=90 + ) # Modification Original plt.savefig(save_path) plt.show() plt.close() metric_name_spaced = metric_name.replace("_", " ") @@ -163,16 +145,12 @@ def plot_all_metrics(self): @abc.abstractmethod def generate_report(self, targets, predictions, **kwargs): - """Abstract method to generate a complete report. - - Child classes need to implement this method. + """Abstract method to generate a complete report. Child classes need to implement this method. Arguments --------- - targets : array-like - Array with target values. - predictions : array-like - Array with prediction values. + targets: Array with target values. + predictions: Array with prediction values. """ pass @@ -226,7 +204,7 @@ def _add_section_content(self, section_title, section_body): if section_body != "": self.set_font(*PDFFile.SECTION_PARAGRAPH_FONT) self.multi_cell(w=0, h=PDFFile.LINE_HEIGHT, txt=section_body) - self.ln(PDFFile.LINE_HEIGHT) + self.ln(2 * PDFFile.LINE_HEIGHT) def _create_first_page_if_document_empty(self): if self.document_empty: @@ -244,122 +222,17 @@ def add_content_text_page(self, section_title, section_body): self._create_first_page_if_document_empty() self._add_section_content(section_title, section_body) - class PDFFile(FPDF): - """PDF file template class. - - Parameters - ---------- - title : str - Title for the pdf file. - - Attributes - ---------- - PAGE_WIDTH : int - Width of the PDF page. - PAGE_HEIGHT : int - Height of the PDF page. - SECTION_PARAGRAPH_FONT : list - Font settings for section paragraphs. - SECTION_TITLE_FONT : list - Font settings for section titles. - LINE_HEIGHT : int - Height of a line in the PDF. - - Methods - ------- - header() - Override method to add a header to the PDF. - footer() - Override method to add a footer to the PDF. - _add_plot(plot_filepath) - Add a plot to the PDF. - _add_section_content(section_title, section_body) - Add a section title and body to the PDF. - _create_first_page_if_document_empty() - Create the first page of the PDF if it is empty. - add_content_text_page(section_title, section_body) - Add a section title, paragraph, and plot to the PDF. - add_content_plot_page(plot_filepath, section_title="", section_body="") - Add a new page with a section title, paragraph, and plot to the PDF. + def add_content_plot_page(self, plot_filepath, section_title="", section_body=""): + """Add a new page with a section title, a paragraph, and a plot. At least a plot has to be provided. + Arguments + --------- + plot_filepath (str): filepath of the plot to be inserted in the new page. + section_title (str, optional): title for the section. Defaults to "". + section_body (str, optional): paragraph text to add. Defaults to "". """ - PAGE_WIDTH = 210 - PAGE_HEIGHT = 297 - SECTION_PARAGRAPH_FONT = ["Arial", "", 11] - SECTION_TITLE_FONT = ["Arial", "B", 13] - LINE_HEIGHT = 5 - - def __init__(self, title): - super().__init__() - self.title = title - self.width = PDFFile.PAGE_WIDTH - self.height = PDFFile.PAGE_HEIGHT - - self.set_auto_page_break(True) - self.document_empty = True - - def header(self): - self.set_font("Arial", "B", 11) - self.cell(self.width - 80) - self.cell(60, 1, self.title, 0, 0, "R") - self.ln(20) - - def footer(self): - # Page numbers in the footer - self.set_y(-15) - self.set_font("Arial", "I", 8) - self.set_text_color(128) - self.cell(0, 10, "Page " + str(self.page_no()), 0, 0, "C") - - def _add_plot(self, plot_filepath): - self.image(plot_filepath) - self.ln(3 * PDFFile.LINE_HEIGHT) - - def _add_section_content(self, section_title, section_body): - if section_title != "": - self.set_font(*PDFFile.SECTION_TITLE_FONT) - self.cell(w=0, txt=section_title) - self.ln(PDFFile.LINE_HEIGHT) - if section_body != "": - self.set_font(*PDFFile.SECTION_PARAGRAPH_FONT) - self.multi_cell(w=0, h=PDFFile.LINE_HEIGHT, txt=section_body) - self.ln(PDFFile.LINE_HEIGHT) - - def _create_first_page_if_document_empty(self): - if self.document_empty: - self.add_page() - self.document_empty = False - - def add_content_text_page(self, section_title, section_body): - """Add a section title and a paragraph. - - Parameters - ---------- - section_title : str - Title for the section. - section_body : str - Paragraph text to add. - - """ - self._create_first_page_if_document_empty() - self._add_section_content(section_title, section_body) - - def add_content_plot_page( - self, plot_filepath, section_title="", section_body="" - ): - """Add a new page with a section title, a paragraph, and a plot. At least a plot has to be provided. - - Parameters - ---------- - plot_filepath : str - Filepath of the plot to be inserted in the new page. - section_title : str, optional - Title for the section. Defaults to "". - section_body : str, optional - Paragraph text to add. Defaults to "". - - """ - self._create_first_page_if_document_empty() - self._add_section_content(section_title, section_body) - self._add_plot(plot_filepath) + self._create_first_page_if_document_empty() + self.add_page() + self._add_section_content(section_title, section_body) + self._add_plot(plot_filepath) diff --git a/src/dlomix/reports/__init__.py b/src/dlomix/reports/__init__.py index 2e5c6a59..66e595a2 100644 --- a/src/dlomix/reports/__init__.py +++ b/src/dlomix/reports/__init__.py @@ -1,7 +1,9 @@ +from .DetectabilityReport import DetectabilityReport from .IntensityReport import IntensityReport from .RetentionTimeReport import RetentionTimeReport __all__ = [ "RetentionTimeReport", "IntensityReport", + "DetectabilityReport", ]