diff --git a/baseline_model/Prosit_baseline_model.keras b/baseline_model/Prosit_baseline_model.keras new file mode 100644 index 00000000..9b4a1b2d Binary files /dev/null and b/baseline_model/Prosit_baseline_model.keras differ diff --git a/notebooks/Example_automatic_refinement_transfer_learning.ipynb b/notebooks/Example_automatic_refinement_transfer_learning.ipynb new file mode 100644 index 00000000..c37f09e7 --- /dev/null +++ b/notebooks/Example_automatic_refinement_transfer_learning.ipynb @@ -0,0 +1,106 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = '-1'\n", + "os.environ['HF_HOME'] = '/cmnfs/proj/bmpc_dlomix/datasets'\n", + "os.environ['HF_DATASETS_CACHE'] = '/cmnfs/proj/bmpc_dlomix/datasets/hf_cache'\n", + "\n", + "num_proc = 16\n", + "os.environ[\"OMP_NUM_THREADS\"] = f\"{num_proc}\"\n", + "os.environ[\"TF_NUM_INTRAOP_THREADS\"] = f\"{num_proc}\"\n", + "os.environ[\"TF_NUM_INTEROP_THREADS\"] = f\"{num_proc}\"\n", + "\n", + "import tensorflow as tf\n", + "tf.config.threading.set_inter_op_parallelism_threads(num_proc)\n", + "tf.config.threading.set_intra_op_parallelism_threads(num_proc)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dlomix.data import load_processed_dataset\n", + "\n", + "dataset = load_processed_dataset('/cmnfs/proj/bmpc_dlomix/datasets/processed/ptm_baseline_small_cleaned_bs1024')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dlomix.models import PrositIntensityPredictor\n", + "from dlomix.losses import masked_spectral_distance, masked_pearson_correlation_distance\n", + "\n", + "model = tf.keras.models.load_model('/cmnfs/proj/bmpc_dlomix/models/baseline_models/noptm_baseline_full_bs1024_unmod_extended/7ef3360f-2349-46c0-a905-01187d4899e2.keras')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dlomix.refinement_transfer_learning.automatic_rl_tl import AutomaticRlTlTraining, AutomaticRlTlTrainingConfig\n", + "\n", + "config = AutomaticRlTlTrainingConfig(\n", + " dataset=dataset,\n", + " baseline_model=model,\n", + " use_wandb=True\n", + ")\n", + "\n", + "trainer = AutomaticRlTlTraining(config)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_model = trainer.train()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/setup.py b/setup.py index 3ba35a21..71719ed1 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ packages=setuptools.find_packages(where="src"), package_dir={"": "src"}, include_package_data=True, - package_data={"": ["data/processing/pickled_feature_dicts/*"]}, + package_data={"": ["data/processing/pickled_feature_dicts/*", "prosit_baseline_model.txt", 
"refinement_transfer_learning/user_report.ipynb"]}, install_requires=[ "datasets", "fpdf", @@ -45,6 +45,10 @@ "wandb": [ "wandb >= 0.15", ], + "rltl-report": [ + "nbconvert", + "ipykernel" + ] }, classifiers=[ "Programming Language :: Python :: 3", diff --git a/src/dlomix/data/charge_state.py b/src/dlomix/data/charge_state.py index 3042d898..f779449e 100644 --- a/src/dlomix/data/charge_state.py +++ b/src/dlomix/data/charge_state.py @@ -42,6 +42,8 @@ def __init__( sequence_column: str = "modified_sequence", label_column: str = "most_abundant_charge_by_count", val_ratio: float = 0.2, + test_ratio: float = 0.2, + advanced_splitting: bool = False, max_seq_len: Union[int, str] = 30, dataset_type: str = "tf", batch_size: int = 256, @@ -59,6 +61,8 @@ def __init__( auto_cleanup_cache: bool = True, num_proc: Optional[int] = None, batch_processing_size: int = 1000, + inference_only: bool = False, + ion_types: Optional[List[str]] = None, ): super().__init__( data_source, @@ -68,6 +72,8 @@ def __init__( sequence_column, label_column, val_ratio, + test_ratio, + advanced_splitting, max_seq_len, dataset_type, batch_size, @@ -85,4 +91,6 @@ def __init__( auto_cleanup_cache, num_proc, batch_processing_size, + inference_only, + ion_types, ) diff --git a/src/dlomix/data/dataset.py b/src/dlomix/data/dataset.py index db5d8900..ef34f550 100644 --- a/src/dlomix/data/dataset.py +++ b/src/dlomix/data/dataset.py @@ -6,6 +6,9 @@ import os import warnings from typing import Callable, Dict, List, Optional, Union +import random +import collections +import numpy as np from datasets import Dataset, DatasetDict, Sequence, Value, load_dataset, load_from_disk @@ -26,6 +29,7 @@ ) logger = logging.getLogger(__name__) +logging.captureWarnings(True) class PeptideDataset: @@ -50,6 +54,10 @@ class PeptideDataset: Name of the column in the data source file that contains the labels. val_ratio : float Ratio of the validation data to the training data. The value should be between 0 and 1. + test_ratio : float + Ratio of the test data to the validation data. The value should be between 0 and 1. (Splits the validation data again, not the train data.) + advanced_splitting: bool + Flag to indicate whether to use advanced train/val/test splitting if providing a single data file or not max_seq_len : int Maximum sequence length to pad the sequences to. If set to 0, the sequences will not be padded. dataset_type : str @@ -102,7 +110,7 @@ class PeptideDataset: Create a PeptideDataset object from a DatasetConfig object. 
""" - DEFAULT_SPLIT_NAMES = ["train", "val", "test"] + DEFAULT_SPLIT_NAMES = ["train", "val", "test", "inference"] CONFIG_JSON_NAME = "dlomix_peptide_dataset_config.json" def __init__( @@ -114,6 +122,8 @@ def __init__( sequence_column: str, label_column: str, val_ratio: float, + test_ratio: float, + advanced_splitting: bool, max_seq_len: int, dataset_type: str, batch_size: int, @@ -131,6 +141,8 @@ def __init__( auto_cleanup_cache: bool = True, num_proc: Optional[int] = None, batch_processing_size: Optional[int] = 1000, + inference_only: bool = False, + ion_types: Optional[List[str]] = None, ): super(PeptideDataset, self).__init__() self.data_source = data_source @@ -143,10 +155,17 @@ def __init__( self.label_column = label_column self.val_ratio = val_ratio + self.test_ratio = test_ratio + self.advanced_splitting = advanced_splitting self.max_seq_len = max_seq_len self.dataset_type = dataset_type self.batch_size = batch_size self.model_features = model_features + self.inference_only = inference_only + if isinstance(inference_only, bool) and inference_only: + self.label_column = None + # y and b ions are the standard ions which are used for Prosit + self.ion_types = ['y', 'b'] if ion_types is None else ion_types # to be kept in the hf dataset, but not returned in the tensor dataset if dataset_columns_to_keep is None: @@ -174,6 +193,7 @@ def __init__( self.hf_dataset: Optional[Union[Dataset, DatasetDict]] = None self._empty_dataset_mode = False self._is_predefined_split = False + self._is_predefined_test_split = False self._test_set_only = False self._num_proc = num_proc self._set_num_proc() @@ -224,6 +244,7 @@ def _refresh_config(self): sequence_column=self.sequence_column, label_column=self.label_column, val_ratio=self.val_ratio, + test_ratio=self.test_ratio, max_seq_len=self.max_seq_len, dataset_type=self.dataset_type, batch_size=self.batch_size, @@ -235,6 +256,7 @@ def _refresh_config(self): alphabet=self.alphabet, encoding_scheme=self.encoding_scheme, processed=self.processed, + inference_only=self.inference_only ) self._config._additional_data.update( @@ -331,11 +353,15 @@ def _decide_on_splitting(self): self._test_set_only = True if self.val_data_source is not None: self._is_predefined_split = True + if self.test_ratio == 0.0 or self.test_ratio is None: + self._is_predefined_test_split = True # two or more data sources provided -> no splitting in all cases if count_loaded_data_sources >= 2: if self.val_data_source is not None: self._is_predefined_split = True + if self.test_data_source is not None: + self._is_predefined_test_split = True if self._is_predefined_split: warnings.warn( @@ -346,7 +372,17 @@ def _decide_on_splitting(self): ) def _remove_unnecessary_columns(self): - self._relevant_columns = [self.sequence_column, self.label_column] + # if inference only dataset, no sequence column necessary + if self.inference_only: + warnings.warn( + """ + This is a inference only dataset! You can only make predictions with this dataset! Attempting to + train a model with this dataset will result in an error! 
+ """ + ) + self._relevant_columns = [self.sequence_column] + else: + self._relevant_columns = [self.sequence_column, self.label_column] if self.model_features is not None: self._relevant_columns.extend(self.model_features) @@ -361,16 +397,47 @@ def _remove_unnecessary_columns(self): def _split_dataset(self): if self._is_predefined_split or self._test_set_only: return + + if self.inference_only: + # if inference only dataset -> only keep the inference dataset and remove the train set + self.hf_dataset["inference"] = self.hf_dataset[PeptideDataset.DEFAULT_SPLIT_NAMES[0]] + self.hf_dataset.pop(PeptideDataset.DEFAULT_SPLIT_NAMES[0]) + return # only a train dataset or a train and a test but no val -> split train into train/val - - splitted_dataset = self.hf_dataset[ - PeptideDataset.DEFAULT_SPLIT_NAMES[0] - ].train_test_split(test_size=self.val_ratio) + splitted_dataset = dict() + if not self.advanced_splitting: + splitted_dataset = self.hf_dataset[ + PeptideDataset.DEFAULT_SPLIT_NAMES[0] + ].train_test_split(test_size=self.val_ratio) + else: + splitted_dataset["train"], splitted_dataset["test"] = _split_dataset_advanced( + self.hf_dataset[PeptideDataset.DEFAULT_SPLIT_NAMES[0]], + ratio=self.val_ratio + ) self.hf_dataset["train"] = splitted_dataset["train"] self.hf_dataset["val"] = splitted_dataset["test"] + # if test set is not specified -> split train set into test set and remaining train set + if self._is_predefined_test_split: + del splitted_dataset["train"] + del splitted_dataset["test"] + del splitted_dataset + return + + self.test_ratio = self.test_ratio / (1 - self.val_ratio) + if not self.advanced_splitting: + splitted_dataset = self.hf_dataset["train"].train_test_split(test_size=self.test_ratio) + else: + splitted_dataset = _split_dataset_advanced( + self.hf_dataset['train'], + ratio=self.test_ratio + ) + + self.hf_dataset["train"] = splitted_dataset["train"] + self.hf_dataset["test"] = splitted_dataset["test"] + del splitted_dataset["train"] del splitted_dataset["test"] del splitted_dataset @@ -507,7 +574,7 @@ def _apply_processing_pipeline(self): if isinstance(processor, SequencePaddingProcessor): for split in self.hf_dataset.keys(): - if split != "test": + if split != "test" and split != "inference": logger.info( f"Removing truncated sequences in the {split} split ..." 
             )
@@ -601,6 +668,8 @@ def from_dataset_config(cls, config: DatasetConfig):
             sequence_column=config.sequence_column,
             label_column=config.label_column,
             val_ratio=config.val_ratio,
+            test_ratio=config.test_ratio,
+            advanced_splitting=config.advanced_splitting,
             max_seq_len=config.max_seq_len,
             dataset_type=config.dataset_type,
             batch_size=config.batch_size,
@@ -612,6 +681,8 @@ def from_dataset_config(cls, config: DatasetConfig):
             alphabet=config.alphabet,
             encoding_scheme=config.encoding_scheme,
             processed=config.processed,
+            inference_only=config.inference_only,
+            ion_types=config.ion_types,
         )
 
         for k, v in config._additional_data.items():
@@ -646,7 +717,10 @@ def _get_input_tensor_column_names(self):
         input_tensor_columns = self._relevant_columns.copy()
 
         # remove the label column from the input tensor columns since the to_tf_dataset method has a separate label_cols argument
-        input_tensor_columns.remove(self.label_column)
+        try:
+            input_tensor_columns.remove(self.label_column)
+        except ValueError:  # inference-only datasets have no label column to remove
+            pass
 
         # remove the columns that are not needed in the tensor dataset
         input_tensor_columns = list(
@@ -687,6 +761,16 @@ def tensor_test_data(self):
             tf_dataset = tf_dataset.cache()
 
         return tf_dataset
+
+    @property
+    def tensor_inference_data(self):
+        """TensorFlow Dataset object for the inference data"""
+        tf_dataset = self._get_split_tf_dataset(PeptideDataset.DEFAULT_SPLIT_NAMES[3])
+
+        if self.enable_tf_dataset_cache:
+            tf_dataset = tf_dataset.cache()
+
+        return tf_dataset
 
     def _get_split_tf_dataset(self, split_name: str):
         existing_splits = list(self.hf_dataset.keys())
@@ -729,3 +813,66 @@ def load_processed_dataset(path: str):
     dataset = class_obj.from_dataset_config(config)
     dataset.hf_dataset = hf_dataset
     return dataset
+
+
+def _argshuffle_splits(idx, trainsplit):
+    """Splits the input list of indices at the given cut-off index and shuffles the indices
+    inside both the train and the validation/test split, so that all peptides are shuffled
+
+    Args:
+        idx (list): All available indices
+        trainsplit (int): index at which the (peptide-grouped) index list is cut; everything before it becomes the train split, everything after it the validation/test split
+    Returns:
+        tuple[list]: first entry are the shuffled train indices, second entry are the shuffled val/test indices
+    """
+    train_idx = list(np.random.permutation(idx[:trainsplit]))
+    test_idx = list(np.random.permutation(idx[trainsplit:]))
+    return train_idx, test_idx
+
+def _peptide_argsort(sequence_integer, seed=42):
+    """Groups the peptides by sequence and shuffles both the peptides as groups and the peptides within each group
+
+    Args:
+        sequence_integer: list of all peptide sequences present in the dataset
+        seed (int): Seed for reproducible results (defaults to 42)
+
+    Returns:
+        list: list of shuffled indices, which are ordered by unique peptides
+    """
+    random.seed(seed)
+    peptide_groups = collections.defaultdict(list)
+    for index, row in enumerate(sequence_integer):
+        peptide_groups[tuple(row)].append(index)
+
+    # shuffle within peptides
+    for indices in peptide_groups.values():
+        random.shuffle(indices)
+
+    # shuffle peptides
+    peptides = list(peptide_groups.keys())
+    random.shuffle(peptides)
+
+    # join indices
+    indices = []
+    for peptide in peptides:
+        indices.extend(peptide_groups[peptide])
+    return indices
+
+def _split_dataset_advanced(data_to_split: Dataset, ratio: float) -> tuple[Dataset]:
+    """Function to split a pyarrow dataset with advanced splitting logic.
This function takes into account the peptide sequences and splits the dataset + according to the given ratio in such a way, that at most one peptide overlaps in the resulting splits. + + Args: + data_to_split: (Dataset): pyarrow dataset, which is to be used for splitting + ratio: (float): the ratio for the smaller dataset (e.g. 0.3 for 30% validation and 70% train split) + + Returns: + tuple[Dataset]: first entry of the tuple is the train split, second entry is the validation/test split + + """ + sequences = data_to_split['modified_sequence'] + n = len(sequences) + n_split = int(n * float(1 - ratio)) + idx = _peptide_argsort(sequences, seed=42) + idx_1, idx_2 = _argshuffle_splits(idx, n_split) + return data_to_split.select(idx_1), data_to_split.select(idx_2) diff --git a/src/dlomix/data/dataset_config.py b/src/dlomix/data/dataset_config.py index 2ef10764..a3baabaa 100644 --- a/src/dlomix/data/dataset_config.py +++ b/src/dlomix/data/dataset_config.py @@ -31,6 +31,10 @@ class DatasetConfig: encoding_scheme: Union[str, EncodingScheme] processed: bool _additional_data: dict = field(default_factory=dict) + test_ratio: Optional[float] = 0 + inference_only: bool = False + ion_types: Optional[List[str]] = None + advanced_splitting: bool = False def save_config_json(self, path: str): """ diff --git a/src/dlomix/data/fragment_ion_intensity.py b/src/dlomix/data/fragment_ion_intensity.py index 8fe42e44..51647472 100644 --- a/src/dlomix/data/fragment_ion_intensity.py +++ b/src/dlomix/data/fragment_ion_intensity.py @@ -17,6 +17,7 @@ class FragmentIonIntensityDataset(PeptideDataset): sequence_column (str): The name of the column containing the peptide sequences. label_column (str): The name of the column containing the intensity labels. val_ratio (float): The ratio of validation data to split from the training data. + test_ratio (float): The ratio of the test data to split from the training data. max_seq_len (Union[int, str]): The maximum length of the peptide sequences. dataset_type (str): The type of dataset to use (e.g., "tf" for TensorFlow dataset). batch_size (int): The batch size for training and evaluation. 
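# --- Illustrative usage sketch (not part of the patch) ---
# The snippet below shows how the split-related options introduced in this diff
# (test_ratio, advanced_splitting, ion_types) are expected to be combined when building a
# FragmentIonIntensityDataset. The file path, batch size, and column choices are placeholder
# assumptions; the keyword arguments mirror the signatures added above and the call made in
# the oktoberfest_interface further down in this diff.
from dlomix.data import FragmentIonIntensityDataset

dataset = FragmentIonIntensityDataset(
    data_source="example_peptides.parquet",   # hypothetical input file
    data_format="parquet",
    label_column="intensities_raw",
    val_ratio=0.2,            # 20% of the data goes to the validation split
    test_ratio=0.1,           # 10% of the full data is split off the remaining training data
    advanced_splitting=True,  # peptide-aware splitting so splits share (almost) no sequences
    ion_types=["y", "b"],     # default Prosit ion types
    model_features=["precursor_charge_onehot", "collision_energy_aligned_normed", "method_nbr"],
    batch_size=1024,
)
print(list(dataset.hf_dataset.keys()))  # e.g. ['train', 'val', 'test']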
@@ -43,6 +44,8 @@ def __init__( sequence_column: str = "modified_sequence", label_column: str = "intensities_raw", val_ratio: float = 0.2, + test_ratio: float = 0.0, + advanced_splitting: bool = False, max_seq_len: Union[int, str] = 30, dataset_type: str = "tf", batch_size: int = 64, @@ -60,6 +63,8 @@ def __init__( auto_cleanup_cache: bool = True, num_proc: Optional[int] = None, batch_processing_size: int = 1000, + inference_only: bool = False, + ion_types: Optional[List[str]] = None, ): super().__init__( data_source, @@ -69,6 +74,8 @@ def __init__( sequence_column, label_column, val_ratio, + test_ratio, + advanced_splitting, max_seq_len, dataset_type, batch_size, @@ -86,4 +93,6 @@ def __init__( auto_cleanup_cache, num_proc, batch_processing_size, + inference_only, + ion_types, ) diff --git a/src/dlomix/data/processing/__init__.py b/src/dlomix/data/processing/__init__.py index 62b3748e..9f63ccc7 100644 --- a/src/dlomix/data/processing/__init__.py +++ b/src/dlomix/data/processing/__init__.py @@ -1,4 +1,5 @@ import json +import logging from .feature_extractors import ( AVAILABLE_FEATURE_EXTRACTORS, @@ -8,6 +9,8 @@ ) from .processors import FunctionProcessor, SequenceParsingProcessor +logger = logging.getLogger(__name__) + __all__ = [ "AVAILABLE_FEATURE_EXTRACTORS", "LookupFeatureExtractor", @@ -25,7 +28,7 @@ ) ) -print( +logger.debug( f""" Avaliable feature extractors are (use the key of the following dict and pass it to features_to_extract in the Dataset Class): {json.dumps(d, indent=3, sort_keys=True)}. diff --git a/src/dlomix/data/processing/processors.py b/src/dlomix/data/processing/processors.py index f14154bf..5650d711 100644 --- a/src/dlomix/data/processing/processors.py +++ b/src/dlomix/data/processing/processors.py @@ -110,7 +110,12 @@ def _parse_proforma_sequence(self, sequence_string): splitted = sequence_string.split("-") if len(splitted) == 1: - n_term, seq, c_term = "[]-", splitted[0], "-[]" + if splitted[0].startswith('[UNIMOD:'): + n_term = splitted[0][:splitted[0].find(']') + 1] + '-' + seq = splitted[0][splitted[0].find(']') + 1:] + c_term = '-[]' + else: + n_term, seq, c_term = '[]-', splitted[0], '-[]' elif len(splitted) == 2: if splitted[0].startswith("[UNIMOD:"): n_term, seq, c_term = splitted[0] + "-", splitted[1], "-[]" @@ -120,8 +125,9 @@ def _parse_proforma_sequence(self, sequence_string): n_term, seq, c_term = splitted n_term += "-" c_term = "-" + c_term - - seq = re.findall(r"[A-Za-z](?:\[UNIMOD:\d+\])*|[^\[\]]", seq) + # last option of the regex is to find unimod modifications at the beginning of the sequence when sequence does not start with []- + # e.g. 
seq = "[UNIMOD:1]ADEFGLMN" + seq = re.findall(r"[A-Za-z](?:\[UNIMOD:\d+\])*|[^\[\]]|\[UNIMOD:\d+\]", seq) return n_term, seq, c_term def __update_sequence_column_with_termini(self, n_terms, seq, c_terms): @@ -268,7 +274,7 @@ def single_process(self, input_data, **kwargs): } def _encode(self, sequence): - encoded = [self.alphabet.get(amino_acid) for amino_acid in sequence] + encoded = [self.alphabet[amino_acid] for amino_acid in sequence] return encoded diff --git a/src/dlomix/data/retention_time.py b/src/dlomix/data/retention_time.py index e1f3465c..002b8dc0 100644 --- a/src/dlomix/data/retention_time.py +++ b/src/dlomix/data/retention_time.py @@ -42,6 +42,8 @@ def __init__( sequence_column: str = "modified_sequence", label_column: str = "indexed_retention_time", val_ratio: float = 0.2, + test_ratio: float = 0.2, + advanced_splitting: bool = False, max_seq_len: Union[int, str] = 30, dataset_type: str = "tf", batch_size: int = 256, @@ -59,6 +61,8 @@ def __init__( auto_cleanup_cache: bool = True, num_proc: Optional[int] = None, batch_processing_size: int = 1000, + inference_only: bool = False, + ion_types: Optional[List[str]] = None, ): super().__init__( data_source, @@ -68,6 +72,8 @@ def __init__( sequence_column, label_column, val_ratio, + test_ratio, + advanced_splitting, max_seq_len, dataset_type, batch_size, @@ -85,4 +91,6 @@ def __init__( auto_cleanup_cache, num_proc, batch_processing_size, + inference_only, + ion_types, ) diff --git a/src/dlomix/interface/__init__.py b/src/dlomix/interface/__init__.py new file mode 100644 index 00000000..36ae0f5c --- /dev/null +++ b/src/dlomix/interface/__init__.py @@ -0,0 +1,9 @@ +from .oktoberfest_interface import load_keras_model, save_keras_model, process_dataset, download_model_from_github, get_model_url + +__all__ = { + 'download_model_from_github', + 'get_model_url', + 'load_keras_model', + 'save_keras_model', + 'process_dataset' +} \ No newline at end of file diff --git a/src/dlomix/interface/oktoberfest_interface.py b/src/dlomix/interface/oktoberfest_interface.py new file mode 100644 index 00000000..a3a0c639 --- /dev/null +++ b/src/dlomix/interface/oktoberfest_interface.py @@ -0,0 +1,230 @@ +import logging +from pathlib import Path +import importlib.resources as pkg_resources +from copy import deepcopy +import requests +from tensorflow.keras.models import load_model +import pyarrow.parquet as pq + + +import dlomix +from dlomix.losses import masked_spectral_distance +from dlomix.data.fragment_ion_intensity import FragmentIonIntensityDataset +from dlomix.models.prosit import PrositIntensityPredictor + +logger = logging.getLogger(__name__) + +MODEL_FILENAME = 'prosit_baseline_model.keras' +MODEL_DIR = Path.home() / '.dlomix' / 'models' + + +def get_model_url(): + with pkg_resources.open_text(dlomix, 'prosit_baseline_model.txt') as url_file: + return url_file.read().strip() + + +def download_model_from_github() -> str: + MODEL_DIR.mkdir(parents=True, exist_ok=True) + model_path = MODEL_DIR / MODEL_FILENAME + + if model_path.exists(): + logger.info(f'Using cached model: {str(model_path)}') + return model_path + + logger.info('Start downloading model from GitHub...') + model_url = get_model_url() + response = requests.get(model_url) + response.raise_for_status() + + with open(model_path, 'wb') as f: + f.write(response.content) + + logger.info(f'Model downloaded successfully under {str(model_path)}') + return str(model_path) + + +def load_keras_model(model_file_path: str = 'baseline') -> PrositIntensityPredictor: + """Load a PrositIntensityPredictor 
model given a model file path.
+
+    Args:
+        model_file_path (str): Path to a saved PrositIntensityPredictor model (.keras format).
+            If no path is given, automatically downloads the baseline model. Defaults to 'baseline'
+
+    Raises:
+        ValueError: If the model_file_path does not end with the .keras extension
+        FileNotFoundError: If the given model_file_path does not exist
+
+    Returns:
+        PrositIntensityPredictor: A loaded PrositIntensityPredictor model that can be used for predictions, refinement or transfer learning purposes.
+    """
+
+    # download the model file from github if the baseline model should be used, otherwise a model path can be specified
+    if model_file_path == 'baseline':
+        model_file_path = download_model_from_github()
+        return load_model(model_file_path, compile=False)
+
+    if not str(model_file_path).endswith('.keras'):
+        raise ValueError('The given model file is not saved with the .keras format! Please specify a path with the .keras extension.')
+    if not Path(model_file_path).exists():
+        raise FileNotFoundError('Given model file was not found. Please specify an existing saved model file.')
+    return load_model(model_file_path, compile=False)
+
+
+def save_keras_model(model: PrositIntensityPredictor, path_to_model: str) -> None:
+    """Saves a given Keras model to the path_to_model path.
+    Automatically adds the .keras extension if the given path does not end in it. This is important so that
+    the model is saved correctly and can be loaded again.
+
+    Args:
+        model (PrositIntensityPredictor): The model object which should be saved
+        path_to_model (str): Path to where the model should be saved
+
+    Raises:
+        FileExistsError: If the model file already exists
+    """
+    if Path(path_to_model).exists():
+        raise FileExistsError('This model file already exists. Specify a file which does not exist yet.')
+    if not path_to_model.endswith('.keras'):
+        path_to_model += '.keras'
+    model.save(path_to_model)
+
+
+def process_dataset(
+        parquet_file_path: str,
+        model: PrositIntensityPredictor = None,
+        modifications: list = None,
+        ion_types: list = None,
+        label_column: str = 'intensities_raw',
+        batch_size: int = 1024,
+        val_ratio: float = 0.2,
+        test_ratio: float = 0.0,
+        additional_columns: list[str] = None
+    ) -> FragmentIonIntensityDataset:
+    """Interface function for the Oktoberfest package to correctly process a dataset and load a baseline model
+
+    Processes the parquet file into a FragmentIonIntensityDataset, which is ready to be used for prediction and/or refinement/transfer learning.
+    The data splits can be investigated with dataset.hf_dataset.keys().
+    If the label column is not present in the given parquet_file_path, the dataset can only be used for prediction.
+    If the format of the given parquet file does not match the model format, the user is warned that refinement/transfer learning steps
+    may be necessary.
+
+    Args:
+        parquet_file_path (str): Path to the .parquet file which has the necessary data stored.
+            Necessary columns are: ['modified_sequence', 'precursor_charge_onehot', 'collision_energy_aligned_normed', 'method_nbr']
+            Optional columns are: ['intensities_raw']
+        model (PrositIntensityPredictor, optional): Specify a loaded model of the PrositIntensityPredictor class. If None is given,
+            the baseline model will be automatically downloaded from GitHub and loaded. Defaults to None
+        modifications (list, optional): A list of all modifications which are present in the dataset. Defaults to None.
+        ion_types (list, optional): A list of the ion types which are present in the dataset. Defaults to ['y', 'b'].
+        label_column (str, optional): The column identifier for where the intensity labels are, if there are any. Defaults to 'intensities_raw'.
+        batch_size (int, optional): The batch size for the dataset. Defaults to 1024.
+        val_ratio (float, optional): A validation split ratio. Defaults to 0.2.
+        test_ratio (float, optional): A test split ratio. Defaults to 0.0.
+        additional_columns (list[str], optional): List of additional columns to keep in the dataset for downstream analysis (will not be returned as tensors).
+
+    Raises:
+        ValueError: If the parquet_file_path does not have the .parquet extension
+        FileNotFoundError: If the parquet_file_path does not exist
+
+    Returns:
+        FragmentIonIntensityDataset: The fully processed dataset, which is ready to be used for prediction or transfer/refinement learning
+    """
+
+
+    modifications = [] if modifications is None else modifications
+    ion_types = ['y', 'b'] if ion_types is None else ion_types
+
+    # load the baseline model if None is given
+    if model is None:
+        model = load_keras_model('baseline')
+
+    val_data_source, test_data_source = None, None
+    if not parquet_file_path.endswith('.parquet'):
+        # check if dataset is already split
+        train_path = parquet_file_path + '_train.parquet'
+        val_data_path = parquet_file_path + '_val.parquet'
+        test_data_path = parquet_file_path + '_test.parquet'
+
+        # check if the train split exists, if not -> raise ValueError (val and test split are not necessary)
+        if not Path(train_path).exists():
+            raise ValueError('The specified file is not a parquet file! Please specify a path with the .parquet extension.')
+        else:
+            parquet_file_path = train_path
+
+        # check if validation split exists
+        if Path(val_data_path).exists():
+            val_data_source = val_data_path
+        # check if test split exists
+        if Path(test_data_path).exists():
+            test_data_source = test_data_path
+
+    if not Path(parquet_file_path).exists():
+        raise FileNotFoundError('Specified parquet file was not found. Please specify a valid parquet file.')
+
+    # check if intensities_raw column is in the parquet file
+    inference_only = True
+    col_names = pq.read_schema(parquet_file_path).names
+    if label_column in col_names:
+        inference_only = False
+
+    # get the differences between the model and the dataset tokens
+    difference = set(modifications) - set(model.alphabet.keys())
+    if not difference:
+        new_alphabet = model.alphabet
+    else:
+        logger.warning(
+            """
+            There are new tokens in the dataset, which are not supported by the loaded model.
+            Either load a different model or transfer learning needs to be done.
+            """)
+        new_alphabet = deepcopy(model.alphabet)
+        new_alphabet.update({k: i for i, k in enumerate(difference, start=len(new_alphabet) + 1)})
+
+    # check for new ion types
+    if any([ion_type in ['c', 'z', 'a', 'x'] for ion_type in ion_types]):
+        if len(ion_types) == 2:
+            logger.warning(
+                """
+                Number of ions is the same as the loaded model supports, but the ion types are different.
+                The model probably needs to be refined to achieve a better performance on these new ion types.
+                """)
+        if len(ion_types) > 2:
+            if 'y' in ion_types and 'b' in ion_types:
+                logger.warning(
+                    """
+                    New ion types in addition to y and b ions detected.
+                    A new output layer is necessary, but it can keep trained weights for y and b ions.
+                    """)
+            else:
+                logger.warning(
+                    """
+                    Only new ion types are detected. A totally new output layer is necessary.
+ """ + ) + + # put additional columns in lower case TODO: remove if CAPS issue is fixed on Oktoberfest side + if additional_columns is not None: + additional_columns = [c.lower() for c in additional_columns] + + logger.info('Start processing the dataset...') + dataset = FragmentIonIntensityDataset( + data_source=parquet_file_path, + val_data_source=val_data_source, + test_data_source=test_data_source, + data_format='parquet', + label_column=label_column, + inference_only=inference_only, + val_ratio=val_ratio, + advanced_splitting=True, + batch_size=batch_size, + test_ratio=test_ratio, + alphabet=new_alphabet, + encoding_scheme='naive-mods', + model_features=['precursor_charge_onehot', 'collision_energy_aligned_normed', 'method_nbr'], + ion_types=ion_types, + dataset_columns_to_keep=additional_columns + ) + + logger.info(f'The available data splits are: {", ".join(list(dataset.hf_dataset.keys()))}') + + return dataset diff --git a/src/dlomix/losses/intensity.py b/src/dlomix/losses/intensity.py index 2b713c3b..3b2223b4 100644 --- a/src/dlomix/losses/intensity.py +++ b/src/dlomix/losses/intensity.py @@ -1,8 +1,10 @@ import numpy as np import tensorflow as tf import tensorflow.keras.backend as K +import keras +@keras.saving.register_keras_serializable() def masked_spectral_distance(y_true, y_pred): """ Calculates the masked spectral distance between true and predicted intensity vectors. @@ -49,6 +51,7 @@ def masked_spectral_distance(y_true, y_pred): return 2 * arccos / np.pi +@keras.saving.register_keras_serializable() def masked_pearson_correlation_distance(y_true, y_pred): """ Calculates the masked Pearson correlation distance between true and predicted intensity vectors. diff --git a/src/dlomix/models/prosit.py b/src/dlomix/models/prosit.py index c709db47..596a597d 100644 --- a/src/dlomix/models/prosit.py +++ b/src/dlomix/models/prosit.py @@ -1,3 +1,4 @@ +import logging import warnings import tensorflow as tf @@ -6,6 +7,9 @@ from ..data.processing.feature_extractors import FEATURE_EXTRACTORS_PARAMETERS from ..layers.attention import AttentionLayer, DecoderAttentionLayer +logger = logging.getLogger(__name__) +logging.captureWarnings(True) + class PrositRetentionTimePredictor(tf.keras.Model): """ @@ -123,6 +127,8 @@ class PrositIntensityPredictor(tf.keras.Model): List of string values corresponding to fixed keys in the inputs dict that are considered meta data. Defaults to None, which corresponds then to the default meta data keys `META_DATA_KEYS`. with_termini : boolean, optional Whether to consider the termini in the sequence. Defaults to True. + ion_types : list[str], optional + Save the ion types with the model to later check compatibility with a given dataset. Attributes ---------- @@ -167,6 +173,7 @@ def __init__( input_keys=None, meta_data_keys=None, with_termini=True, + ion_types=None ): super(PrositIntensityPredictor, self).__init__() @@ -180,6 +187,7 @@ def __init__( self.use_prosit_ptm_features = use_prosit_ptm_features self.input_keys = input_keys self.meta_data_keys = meta_data_keys + self.ion_types = ion_types # maximum number of fragment ions self.max_ion = self.seq_length - 1 @@ -319,7 +327,7 @@ def call(self, inputs, **kwargs): encoded_ptm = self.ptm_input_encoder(ptm_ac_features) elif self.use_prosit_ptm_features: warnings.warn( - f"PTM features enabled and following PTM features are expected in the model for Prosit Intesity: {PrositIntensityPredictor.PTM_INPUT_KEYS}. The actual input passed to the model contains the following keys: {list(inputs.keys())}. 
Falling back to no PTM features."
+                f"PTM features enabled and following PTM features are expected in the model for Prosit Intensity: {PrositIntensityPredictor.PTM_INPUT_KEYS}. The actual input passed to the model contains the following keys: {list(inputs.keys())}. Falling back to no PTM features."
             )
 
         x = self.embedding(peptides_in)
diff --git a/src/dlomix/prosit_baseline_model.txt b/src/dlomix/prosit_baseline_model.txt
new file mode 100644
index 00000000..dd2f04bf
--- /dev/null
+++ b/src/dlomix/prosit_baseline_model.txt
@@ -0,0 +1 @@
+https://github.com/wilhelm-lab/dlomix/raw/feature/bmpc/baseline_model/Prosit_baseline_model.keras
\ No newline at end of file
diff --git a/src/dlomix/refinement_transfer_learning/__init__.py b/src/dlomix/refinement_transfer_learning/__init__.py
new file mode 100644
index 00000000..a7f7d885
--- /dev/null
+++ b/src/dlomix/refinement_transfer_learning/__init__.py
@@ -0,0 +1,8 @@
+from .automatic_rl_tl import AutomaticRlTlTraining, AutomaticRlTlTrainingConfig, AutomaticRlTlTrainingInstance, TrainingInstanceConfig
+
+__all__ = {
+    'AutomaticRlTlTraining',
+    'AutomaticRlTlTrainingConfig',
+    'AutomaticRlTlTrainingInstance',
+    'TrainingInstanceConfig'
+}
\ No newline at end of file
diff --git a/src/dlomix/refinement_transfer_learning/automatic_rl_tl.py b/src/dlomix/refinement_transfer_learning/automatic_rl_tl.py
new file mode 100644
index 00000000..3c9b31fd
--- /dev/null
+++ b/src/dlomix/refinement_transfer_learning/automatic_rl_tl.py
@@ -0,0 +1,826 @@
+import os
+import sys
+import logging
+
+import tensorflow as tf
+from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, LearningRateScheduler, CSVLogger
+
+from .custom_callbacks import CustomCSVLogger, BatchEvaluationCallback, InflectionPointEarlyStopping, LearningRateWarmupPerStep, InflectionPointLRReducer, OverfittingEarlyStopping
+
+from dlomix.constants import PTMS_ALPHABET, ALPHABET_NAIVE_MODS, ALPHABET_UNMOD
+from dlomix.data import load_processed_dataset, FragmentIonIntensityDataset
+from dlomix.models import PrositIntensityPredictor
+from dlomix.losses import masked_spectral_distance, masked_pearson_correlation_distance
+from dlomix.refinement_transfer_learning import change_layers, freezing
+
+from dataclasses import dataclass, asdict, field
+from typing import Optional
+import math
+import json
+import numpy as np
+from pathlib import Path
+import importlib.resources as importlib_resources
+import shutil
+
+from nbconvert import HTMLExporter
+import nbformat
+from nbconvert.preprocessors import ExecutePreprocessor
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class AutomaticRlTlTrainingConfig:
+    """Configuration for an automatic refinement/transfer learning run.
+
+    Attributes:
+        dataset (FragmentIonIntensityDataset): Dataset that should be used for training. The dataset needs a train and a validation split and must not be an inference-only dataset.
+        baseline_model (Optional[PrositIntensityPredictor]): If a model is provided, this model is used as the baseline for training. If no model is specified, a new model is trained from scratch.
+        min_warmup_sequences_new_weights (int): Determines the length of the learning rate warmup phase in phase 1 of the automatic training pipeline (training of newly added weights). Default: 4000000
+        min_warmup_sequences_whole_model (int): Determines the length of the learning rate warmup phase in phase 2 of the automatic training pipeline (training of all weights in the model).
Default: 4000000 + improve_further (bool): Determines whether a third training phase is performed which has more restrictive early stopping criterions and learning rate scheduling. Default: True + use_wandb (bool): Determines whether to use wandb to log the training run. Wandb needs to be installed as dependency if this is set to True. Default: False + wandb_project (str): Selects the wandb project that the run should correspond to. This is ignored if use_wandb is set to False. Default: "DLOmix_auto_RL_TL" + wandb_tags (list[str]): List of wandb tags to add to the run. This is ignored if use_wandb is set to False. Default: [] + """ + + # dataset/model parameters + dataset : FragmentIonIntensityDataset + baseline_model : Optional[PrositIntensityPredictor] + + # training parameters + min_warmup_sequences_new_weights : int = 4000000 + min_warmup_sequences_whole_model : int = 4000000 + improve_further : bool = True + + # wandb parameters + use_wandb : bool = False + wandb_project : str = 'DLOmix_auto_RL_TL' + wandb_tags : list[str] = field(default_factory=list) + + # csv logger parameters + results_log : str = 'results_log' + + + def to_dict(self): + """Converts configuration to a python dict object. Only attributes are included which can be easily represented as text. + + Returns: + dict: Configuration options as dictionary + """ + return { + 'min_warmup_sequences_new_weights': self.min_warmup_sequences_new_weights, + 'min_warmup_sequences_whole_model': self.min_warmup_sequences_whole_model, + 'improve_further': self.improve_further + } + + +@dataclass +class TrainingInstanceConfig: + learning_rate : float + num_epochs : int + + freeze_inner_layers : bool = False + freeze_whole_embedding_layer : bool = False + freeze_whole_regressor_layer : bool = False + freeze_old_embedding_weights : bool = False + freeze_old_regressor_weights : bool = False + + plateau_early_stopping : bool = False + plateau_early_stopping_patience : int = 0 + plateau_early_stopping_min_delta : float = 0 + + inflection_early_stopping : bool = False + inflection_early_stopping_min_improvement : float = 0 + inflection_early_stopping_patience : int = 0 + inflection_early_stopping_ignore_first_n : int = 0 + + inflection_lr_reducer : bool = False + inflection_lr_reducer_factor : float = 0 + inflection_lr_reducer_min_improvement : float = 0 + inflection_lr_reducer_patience : int = 0 + + lr_warmup : bool = False + lr_warmup_num_steps : int = 0 + lr_warmup_start_lr : float = 0 + + +class AutomaticRlTlTraining: + config : AutomaticRlTlTrainingConfig + results_data_path : Path + results_notebook_path : Path + + model : PrositIntensityPredictor + is_new_model : bool + + requires_new_embedding_layer : bool + can_reuse_old_embedding_weights : bool + requires_new_regressor_layer : bool + can_reuse_old_regressor_weights : bool + + current_epoch_offset : int = 0 + callbacks : list = [] + training_schedule : list = [] + validation_steps : Optional[int] + + initial_loss : float = None + + def __init__(self, config : AutomaticRlTlTrainingConfig): + """Automatic refinement/transfer learning given a dataset and optionally an existing model. The training process consists of the following phases: + + Phase 1: + This phase is only performed, if new weights were added to the model in the embedding or regressor layer (extended embedding or additional ions). Only the new weights are trained while all other weights are frozen. The training process starts with a learning rate warmup. 
The phase automatically stops as soon as no major improvements are detected anymore. + + Phase 2: + This phase resembles the main training process. All weights are trained and no freezing is applied. The phase starts with a learning rate warmup and automatically stops as soon as no major improvements are detected anymore. + + Phase 3: + Optional finetuning phase that is only performed if config.improve_further is set to True. This phase starts with a slightly lower learning rate than the one used in phase 2 and reduces the learning rate when as no significant improvement can be detected anymore. The phase stops automatically as soon as no improvements are detected over a longer period. + + Args: + config (AutomaticRlTlTrainingConfig): Contains all relevant configuration parameters for performing the automatic refinement/transfer learning process. Please refer to the documentation of AutomaticRlTlTrainingConfig for further documentation. + """ + self.config = config + + self._init_wandb() + self._init_logging() + self._init_model() + self._update_model_inputs() + self._update_model_outputs() + self._init_training() + self._construct_training_schedule() + self._explore_data() + + def _init_wandb(self): + """ Initializes Weights & Biases Logging if the user requested that in the config. + """ + if self.config.use_wandb: + global wandb + global WandbCallback + import wandb + from wandb.integration.keras import WandbCallback + + wandb.init( + project=self.config.wandb_project, + config=self.config.to_dict(), + tags=self.config.wandb_tags + ) + + def _init_logging(self): + """ Initializes Weights & Biases Logging and CSV Logging if the user requested that in the config. + """ + + self.results_data_path = Path(self.config.results_log) / 'log_data/' + + if not os.path.exists(self.results_data_path): + os.makedirs(self.results_data_path) + + notebook_ref = importlib_resources.files('dlomix') / 'refinement_transfer_learning' / 'user_report.ipynb' + self.results_notebook_path = Path(self.config.results_log) / 'report.ipynb' + with importlib_resources.as_file(notebook_ref) as path: + shutil.copyfile(path, self.results_notebook_path) + + self.csv_logger = CustomCSVLogger(f'{self.results_data_path}/training_log.csv', separator=',', append=True) + + def _init_model(self): + """Configures the given baseline model or creates a new model if no baseline model is provided in the config. 
+ """ + if self.config.baseline_model is not None: + self.model = self.config.baseline_model + self.is_new_model = False + else: + # initialize new model + input_mapping = { + "SEQUENCE_KEY": "modified_sequence", + "COLLISION_ENERGY_KEY": "collision_energy_aligned_normed", + "PRECURSOR_CHARGE_KEY": "precursor_charge_onehot", + "FRAGMENTATION_TYPE_KEY": "method_nbr", + } + + meta_data_keys=["collision_energy_aligned_normed", "precursor_charge_onehot", "method_nbr"] + + self.model = PrositIntensityPredictor( + seq_length=self.config.dataset.max_seq_len, + alphabet=self.config.dataset.alphabet, + use_prosit_ptm_features=False, + with_termini=False, + input_keys=input_mapping, + meta_data_keys=meta_data_keys + ) + self.is_new_model = True + + optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4) + self.model.compile( + optimizer=optimizer, + loss=masked_spectral_distance, + metrics=[masked_pearson_correlation_distance] + ) + + + def _calculate_spectral_angles(self, stage): + """Calculates and saves the spectral angle distributions before and after training.""" + + def calculate_spectral_distance(dataset, model, max_batches=1000): + spectral_dists = [] + for i, (batch, y_true) in enumerate(dataset): + if i >= max_batches: + break + y_pred = model.predict(batch, verbose=0) + spectral_dists.extend(masked_spectral_distance(y_true=y_true, y_pred=y_pred).numpy()) + return spectral_dists + + def calculate_and_save_spectral_angle_distribution(data, model, results_log, stage, datasets=['train', 'val', 'test']): + """ + Predict the intensities, calculate spectral distances, and save the spectral angle distribution for the specified datasets. + + Args: + data: A dataset containing tensor_train_data, tensor_val_data, and tensor_test_data. + model: A trained model used for making predictions. + results_log: Directory to save the JSON files. + stage: A string indicating the stage ('before' or 'after'). + datasets: A list of strings indicating which datasets to use ('train', 'val', 'test'). + + Returns: + None (saves JSON files) + """ + def save_json(data, filename): + with open(os.path.join(results_log, filename), 'w') as f: + json.dump(data, f) + + for dataset in datasets: + if dataset not in ['train', 'val', 'test']: + raise ValueError("Invalid dataset type. Choose 'train', 'val', or 'test'.") + + try: + if dataset == 'train': + dataset_data = data.tensor_train_data + elif dataset == 'val': + dataset_data = data.tensor_val_data + elif dataset == 'test': + dataset_data = data.tensor_test_data + except ValueError: + continue + + + spectral_dists = calculate_spectral_distance(dataset_data, model) + sa_data = [1 - sd for sd in spectral_dists] + avg_sa = np.mean(sa_data) + + data_to_save = { + 'spectral_angles': sa_data, + 'average_spectral_angle': avg_sa + } + + # Load existing data if present + filename = f'spectral_angle_distribution_{dataset}.json' + file_path = os.path.join(results_log, filename) + if os.path.exists(file_path): + with open(file_path, 'r') as f: + existing_data = json.load(f) + else: + existing_data = {} + + existing_data[stage] = data_to_save + + save_json(existing_data, filename) + + calculate_and_save_spectral_angle_distribution( + data=self.config.dataset, + model=self.model, + results_log=self.results_data_path, + stage=stage, + datasets=['train', 'val', 'test'] + ) + + def _update_model_inputs(self): + """Modifies the model's embedding layer to fit the provided dataset. All decisions here are made automatically based on the provided model and dataset. 
+ """ + model_alphabet = self.model.alphabet + dataset_alphabet = self.config.dataset.alphabet + + if self.is_new_model: + logger.info('[embedding layer] created new model with fresh embedding layer') + self.requires_new_embedding_layer = False + self.can_reuse_old_embedding_weights = False + return + + if model_alphabet == dataset_alphabet: + logger.info('[embedding layer] model and dataset modifications match') + self.requires_new_embedding_layer = False + self.can_reuse_old_embedding_weights = False + else: + logger.info('[embedding layer] model and dataset modifications do not match') + self.requires_new_embedding_layer = True + # check if the existing embedding can be reused + including_entries = [model_val == dataset_alphabet[key] for key, model_val in model_alphabet.items()] + if all(including_entries): + logger.info('[embedding layer] can reuse old embedding weights') + self.can_reuse_old_embedding_weights = True + else: + logger.info('[embedding layer] old embedding weights cannot be reused (mismatch in the mapping)') + self.can_reuse_old_embedding_weights = False + + change_layers.change_input_layer( + self.model, + list(dataset_alphabet.keys()), + freeze_old_embeds=self.can_reuse_old_embedding_weights + ) + self.model.alphabet = dataset_alphabet + + def _update_model_outputs(self): + """Modifies the model's regressor layer to fit the provided dataset. All decisions here are made automatically based on the provided model and dataset. + + Raises: + RuntimeError: Error is raised if the model and the dataset have a different sequence length. A mismatch in the sequence length is not supported. + """ + # check that sequence length matches + if self.model.seq_length != self.config.dataset.max_seq_len: + raise RuntimeError(f"Max. sequence length does not match between dataset and model (dataset: {self.config.dataset.max_seq_len}, model: {self.model.seq_length})") + + if self.is_new_model: + logger.info('[regressor layer] created new model with fresh regressor layer') + self.requires_new_regressor_layer = False + self.can_reuse_old_regressor_weights = False + return + + # check whether number of ions matches + model_ions = ['y', 'b'] + if hasattr(self.model, 'ion_types') and self.model.ion_types is not None: + model_ions = self.model.ion_types + + dataset_ions = ['y', 'b'] + if hasattr(self.config.dataset, 'ion_types') and self.config.dataset.ion_types is not None: + dataset_ions = self.config.dataset.ion_types + + if model_ions == dataset_ions: + logger.info('[regressor layer] matching ion types') + self.requires_new_regressor_layer = False + self.can_reuse_old_regressor_weights = False + else: + logger.info('[regressor layer] ion types not matching') + self.requires_new_regressor_layer = True + + if len(model_ions) <= len(dataset_ions) and all([m == d for m, d in zip(model_ions, dataset_ions)]): + logger.info('[regressor layer] can reuse existing regressor weights') + self.can_reuse_old_regressor_weights = True + else: + logger.info('[regressor layer] old regressor weights cannot be reused (mismatch in the ion ordering / num. ions)') + self.can_reuse_old_regressor_weights = False + + change_layers.change_output_layer( + self.model, + len(dataset_ions), + freeze_old_output=self.can_reuse_old_regressor_weights + ) + self.model.ion_types = dataset_ions + + def _init_training(self): + """Configures relevant training settings that are used across all phases of the training. 
+ """ + self.callbacks = [] + if self.config.use_wandb: + class LearningRateReporter(tf.keras.callbacks.Callback): + def on_train_batch_end(self, batch, *args): + wandb.log({'learning_rate': self.model.optimizer.lr.read_value()}) + + class RealEpochReporter(tf.keras.callbacks.Callback): + def on_epoch_begin(self_inner, epoch, *args): + wandb.log({'epoch_total': epoch + self.current_epoch_offset}) + + self.callbacks = [WandbCallback(save_model=False, log_batch_frequency=True, verbose=1), LearningRateReporter(), RealEpochReporter(), self.csv_logger] + else: + self.callbacks = [ + self.csv_logger + ] + + self.progress_reporter_min_loss = None + class LossProgressReporter(tf.keras.callbacks.Callback): + counter : int = 0 + def on_train_batch_end(self_inner, batch, logs): + loss = logs['loss'] + + if self.progress_reporter_min_loss is None: + self.progress_reporter_min_loss = loss + + loss = min(self.progress_reporter_min_loss, loss) + + if self_inner.counter % 1000 == 0: + approx_progress = min(0.9999, max(0, (self.initial_loss - loss) / (self.initial_loss - 0.1))) + logger.info(f'[training] masked spectral distance: {loss}, approx. progress: {approx_progress * 100:.2f}%') + + self_inner.counter += 1 + + self.callbacks.append(LossProgressReporter()) + + num_train_batches = self.config.dataset.tensor_train_data.cardinality().numpy() + batch_size = self.config.dataset.batch_size + num_train_sequences = batch_size * num_train_batches + self.callbacks.append(OverfittingEarlyStopping( + max_validation_train_difference=0.1, + patience=max(2, math.ceil(2000000 / num_train_sequences)), + wandb_log=self.config.use_wandb + )) + + num_val_batches = self.config.dataset.tensor_val_data.cardinality().numpy() + self.validation_steps = 1000 if num_val_batches > 1000 else None + + + + def _evaluate_model(self): + """Runs an evaluation over max. 1000 batches of the validation set and logs the validation performance. + """ + loss, metric = self.model.evaluate( + self.config.dataset.tensor_val_data, + steps=self.validation_steps, + verbose=0 + ) + + if self.initial_loss is None: + self.initial_loss = loss + + logger.info(f'[validation] loss: {loss}, pearson distance: {metric}') + if self.config.use_wandb: + wandb.log({'val_loss': loss, 'val_masked_pearson_correlation_distance': metric}) + + self.csv_logger.set_validation_metrics(val_loss=loss, val_masked_pearson_correlation_distance=metric) + return {'val_loss': loss, 'val_masked_pearson_correlation_distance': metric} + + + def _construct_training_schedule(self): + """Configures the phases of the training process based on the given config and the provided dataset and model. 
+ """ + self.training_schedule = [] + + num_train_batches = self.config.dataset.tensor_train_data.cardinality().numpy() + batch_size = self.config.dataset.batch_size + num_train_sequences = batch_size * num_train_batches + + is_transfer_learning = self.requires_new_embedding_layer or self.requires_new_regressor_layer + + # step 1: + # warm up new weights in embedding/regressor layer + if is_transfer_learning: + warmup_sequences = self.config.min_warmup_sequences_new_weights + warmup_epochs = math.ceil(warmup_sequences / num_train_sequences) + warmup_batches = math.ceil(warmup_sequences / batch_size) + training_epochs = 10000 + self.training_schedule.append(TrainingInstanceConfig( + num_epochs=warmup_epochs + training_epochs, + learning_rate=1e-4, + lr_warmup=True, + lr_warmup_num_steps=warmup_batches, + lr_warmup_start_lr=1e-8, + inflection_early_stopping=True, + inflection_early_stopping_min_improvement=1e-4, + inflection_early_stopping_ignore_first_n=warmup_batches, + inflection_early_stopping_patience=1000, + freeze_inner_layers=True, + freeze_whole_embedding_layer=not self.requires_new_embedding_layer, + freeze_whole_regressor_layer=not self.requires_new_regressor_layer, + freeze_old_embedding_weights=self.requires_new_embedding_layer and self.can_reuse_old_embedding_weights, + freeze_old_regressor_weights=self.requires_new_regressor_layer and self.can_reuse_old_regressor_weights + )) + else: + self.csv_logger.reset_phase() + + # step 2: + # warmup whole model and do main fitting process + warmup_sequences = self.config.min_warmup_sequences_whole_model + warmup_epochs = math.ceil(warmup_sequences / num_train_sequences) + warmup_batches = math.ceil(warmup_sequences / batch_size) + training_epochs = 10000 + self.training_schedule.append(TrainingInstanceConfig( + num_epochs=warmup_epochs + training_epochs, + learning_rate=1e-4, + lr_warmup=True, + lr_warmup_num_steps=warmup_batches, + lr_warmup_start_lr=1e-8, + inflection_early_stopping=True, + inflection_early_stopping_min_improvement=1e-5, + inflection_early_stopping_ignore_first_n=warmup_batches, + inflection_early_stopping_patience=2000 + )) + + # step 3: + # optional: refine the model further to get a really good model + if self.config.improve_further: + training_epochs = 10000 + self.training_schedule.append(TrainingInstanceConfig( + num_epochs=training_epochs, + learning_rate=1e-4, + inflection_early_stopping=True, + inflection_early_stopping_min_improvement=1e-7, + inflection_early_stopping_ignore_first_n=0, + inflection_early_stopping_patience=100000, + inflection_lr_reducer=True, + inflection_lr_reducer_factor=0.7, + inflection_lr_reducer_min_improvement=1e-7, + inflection_lr_reducer_patience=5000 + )) + + def _explore_data(self): + """Generates and saves exploratory data plots in the results_log folder.""" + def save_json(data, filename): + with open(os.path.join(self.results_data_path, filename), 'w') as f: + json.dump(data, f) + + def plot_amino_acid_distribution(dataset, alphabet, dataset_name): + """Plots the frequency of each amino acid in the sequences for a given dataset split.""" + def count_amino_acids(sequences): + aa_counts = {aa: 0 for aa in alphabet} + for seq in sequences: + for aa in alphabet: + aa_counts[aa] += seq.count(aa) + return list(aa_counts.values()) + + sequences = dataset['modified_sequence_raw'] + aa_counts = count_amino_acids(sequences) + alphabet_keys = list(alphabet.keys()) + + data = { + 'alphabet': alphabet_keys, + 'counts': aa_counts + } + save_json(data, 
f'amino_acid_distribution_{dataset_name}.json') + + def plot_distribution(dataset, feature, dataset_name, transform_func=None, bins=None, xlabel='', ylabel='Frequency', is_sequence=False): + """General function to plot distributions for different features.""" + feature_data = dataset[feature] + if transform_func: + feature_data = transform_func(feature_data) + if is_sequence: + feature_data = [len(seq) for seq in feature_data] + + if is_sequence: + # Define bins to cover the integer range of sequence lengths + actual_bins = np.arange(min(feature_data) - 0.5, max(feature_data) + 1.5, 1) + else: + actual_bins = bins(feature_data) if callable(bins) else bins if bins is not None else 30 + + hist, bin_edges = np.histogram(feature_data, bins=actual_bins) + + data = { + 'hist': hist.tolist(), + 'bin_edges': bin_edges.tolist(), + 'xlabel': xlabel, + 'ylabel': ylabel + } + + if is_sequence: + feature = 'sequence' + + save_json(data, f'{feature}_distribution_{dataset_name}.json') + + eval_datasets = { + 'train': self.config.dataset.hf_dataset['train'], + 'val': self.config.dataset.hf_dataset['val'], + 'test': self.config.dataset.hf_dataset['test'] if 'test' in self.config.dataset.hf_dataset else None + } + + for dataset_name, dataset in eval_datasets.items(): + if dataset: + if 'modified_sequence_raw' in self.config.dataset.dataset_columns_to_keep: + plot_amino_acid_distribution(dataset, self.config.dataset.alphabet, dataset_name) + if 'sequence' in self.config.dataset.dataset_columns_to_keep: + plot_distribution(dataset, 'sequence', dataset_name, is_sequence=True, bins=None, xlabel='Sequence Length') + + plot_distribution(dataset, 'collision_energy_aligned_normed', dataset_name, xlabel='Collision Energy') + # plot_distribution(dataset, 'intensities_raw', dataset_name, lambda x: [i for sub in x for i in sub], xlabel='Intensity') + plot_distribution(dataset, 'precursor_charge_onehot', dataset_name, lambda x: np.argmax(x, axis=1), bins=np.arange(6) - 0.5, xlabel='Precursor Charge') + + def _compile_report(self): + """Creates a visual PDF report from the jupyter notebook in the results folder + """ + with open(self.results_notebook_path, 'r') as notebook_file: + notebook = nbformat.read(notebook_file, as_version=4) + + current_cwd = os.getcwd() + os.chdir(self.config.results_log) + executor = ExecutePreprocessor() + executor.preprocess(notebook) + os.chdir(current_cwd) + + exporter = HTMLExporter() + exporter.exclude_input = True + result, resources = exporter.from_notebook_node(notebook) + + result_path = Path(self.config.results_log) / 'report.html' + with open(result_path, "w") as f: + f.write(result) + + + def train(self): + """Performs the training process and returns the final model. + + Returns: + PrositIntensityPredictor: The refined model that results from the training process. This model can be used for predictions or further training steps. 
+ """ + self._calculate_spectral_angles('before') + self._evaluate_model() + + # Add the batch evaluation callback to the callbacks list + batch_eval_callback = BatchEvaluationCallback(self._evaluate_model, 1000) + self.callbacks.append(batch_eval_callback) + + for instance_config in self.training_schedule: + + self.csv_logger.reset_phase() + + training = AutomaticRlTlTrainingInstance( + instance_config=instance_config, + model=self.model, + dataset=self.config.dataset, + current_epoch_offset=self.current_epoch_offset, + wandb_logging=self.config.use_wandb, + results_log=self.results_data_path, + callbacks=self.callbacks, + validation_steps=self.validation_steps + ) + training.run() + + self.current_epoch_offset = training.current_epoch_offset + self._evaluate_model() + + if self.config.use_wandb: + wandb.finish() + + self._calculate_spectral_angles('after') + self._compile_report() + + return self.model + + + +class AutomaticRlTlTrainingInstance: + + model : PrositIntensityPredictor + dataset : FragmentIonIntensityDataset + instance_config : TrainingInstanceConfig + current_epoch_offset : int + wandb_logging : bool + results_log: str + callbacks : list + validation_steps : Optional[int] + + stopped_early : bool + final_learning_rate : float + inflection_early_stopping : Optional[InflectionPointEarlyStopping] = None + + + def __init__( + self, + instance_config : TrainingInstanceConfig, + model : PrositIntensityPredictor, + current_epoch_offset : int, + dataset : FragmentIonIntensityDataset, + wandb_logging : bool, + results_log: str, + callbacks : list, + validation_steps : Optional[int] + ): + self.instance_config = instance_config + self.model = model + self.dataset = dataset + self.current_epoch_offset = current_epoch_offset + self.wandb_logging = wandb_logging + self.results_log = results_log + self.callbacks = callbacks.copy() + self.validation_steps = validation_steps + + self._configure_training() + + + def _configure_training(self): + + # freezing of old embedding weights + if self.instance_config.freeze_old_embedding_weights: + if self.wandb_logging: + wandb.log({'freeze_old_embedding_weights': 1}) + with open(f'{self.results_log}/freeze_log.csv', 'a') as f: + f.write('freeze_old_embedding_weights,1\n') + change_layers.freeze_old_embeddings(self.model) + else: + if self.wandb_logging: + wandb.log({'freeze_old_embedding_weights': 0}) + with open(f'{self.results_log}/freeze_log.csv', 'a') as f: + f.write('freeze_old_embedding_weights,0\n') + change_layers.release_old_embeddings(self.model) + + # freezing of old regressor weights + if self.instance_config.freeze_old_regressor_weights: + if self.wandb_logging: + wandb.log({'freeze_old_regressor_weights': 1}) + with open(f'{self.results_log}/freeze_log.csv', 'a') as f: + f.write('freeze_old_regressor_weights,1\n') + change_layers.freeze_old_regressor(self.model) + else: + if self.wandb_logging: + wandb.log({'freeze_old_regressor_weights': 0}) + with open(f'{self.results_log}/freeze_log.csv', 'a') as f: + f.write('freeze_old_regressor_weights,0\n') + change_layers.release_old_regressor(self.model) + + + # freezing of inner layers + if self.instance_config.freeze_inner_layers: + freezing.freeze_model( + self.model, + self.instance_config.freeze_whole_embedding_layer, + self.instance_config.freeze_whole_regressor_layer + ) + + if self.wandb_logging: + wandb.log({ + 'freeze_inner_layers': 1, + 'freeze_embedding_layer': 1 if self.instance_config.freeze_whole_embedding_layer else 0, + 'freeze_regressor_layer': 1 if 
self.instance_config.freeze_whole_regressor_layer else 0 + }) + with open(f'{self.results_log}/freeze_log.csv', 'a') as f: + f.write(f'freeze_inner_layers,1\nfreeze_embedding_layer,{1 if self.instance_config.freeze_whole_embedding_layer else 0}\nfreeze_regressor_layer,{1 if self.instance_config.freeze_whole_regressor_layer else 0}\n\n') + else: + if self.instance_config.freeze_whole_embedding_layer: + raise RuntimeError('Cannot freeze whole embedding layer without freezing inner part of the model.') + if self.instance_config.freeze_whole_regressor_layer: + raise RuntimeError('Cannot freeze whole regressor layer without freezing inner part of the model.') + + freezing.release_model(self.model) + + if self.wandb_logging: + wandb.log({ + 'freeze_inner_layers': 0, + 'freeze_embedding_layer': 0, + 'freeze_regressor_layer': 0 + }) + with open(f'{self.results_log}/freeze_log.csv', 'a') as f: + f.write(f'freeze_inner_layers,0\nfreeze_embedding_layer,0\nfreeze_regressor_layer,0\n\n') + + + if self.instance_config.plateau_early_stopping: + early_stopping = EarlyStopping( + monitor="val_loss", + min_delta=self.instance_config.plateau_early_stopping_min_delta, + patience=self.instance_config.plateau_early_stopping_patience, + restore_best_weights=True) + + self.callbacks.append(early_stopping) + + + if self.instance_config.inflection_early_stopping: + self.inflection_early_stopping = InflectionPointEarlyStopping( + min_improvement=self.instance_config.inflection_early_stopping_min_improvement, + patience=self.instance_config.inflection_early_stopping_patience, + ignore_first_n=self.instance_config.inflection_early_stopping_ignore_first_n, + wandb_log=self.wandb_logging + ) + + self.callbacks.append(self.inflection_early_stopping) + + + if self.instance_config.inflection_lr_reducer: + reduce_lr = InflectionPointLRReducer( + factor=self.instance_config.inflection_lr_reducer_factor, + patience=self.instance_config.inflection_lr_reducer_patience, + min_improvement=self.instance_config.inflection_lr_reducer_min_improvement, + wandb_log=self.wandb_logging + ) + + self.callbacks.append(reduce_lr) + + if self.instance_config.lr_warmup: + lr_warmup_linear = LearningRateWarmupPerStep( + num_steps=self.instance_config.lr_warmup_num_steps, + start_lr=self.instance_config.lr_warmup_start_lr, + end_lr=self.instance_config.learning_rate + ) + self.callbacks.append(lr_warmup_linear) + + + def run(self): + + # perform all training runs + optimizer = tf.keras.optimizers.Adam(learning_rate=self.instance_config.learning_rate) + self.model.compile( + optimizer=optimizer, + loss=masked_spectral_distance, + metrics=[masked_pearson_correlation_distance] + ) + + # train model + history = self.model.fit( + self.dataset.tensor_train_data, + validation_data=self.dataset.tensor_val_data, + validation_steps=self.validation_steps, + epochs=self.instance_config.num_epochs, + callbacks=self.callbacks, + verbose=0 + ) + + inflection_ES_stopped = self.inflection_early_stopping is not None and self.inflection_early_stopping.stopped_early + if len(history.history['loss']) < self.instance_config.num_epochs or inflection_ES_stopped: + self.stopped_early = True + else: + self.stopped_early = False + + self.final_learning_rate = self.model.optimizer._learning_rate.numpy() + self.current_epoch_offset += len(history.history['loss']) + diff --git a/src/dlomix/refinement_transfer_learning/change_layers.py b/src/dlomix/refinement_transfer_learning/change_layers.py new file mode 100644 index 00000000..3367938f --- /dev/null +++ 
b/src/dlomix/refinement_transfer_learning/change_layers.py
@@ -0,0 +1,168 @@
+import tensorflow as tf
+from tensorflow.keras.constraints import Constraint
+import keras.backend as K
+import keras
+from dlomix.models import PrositIntensityPredictor
+
+
+@keras.saving.register_keras_serializable()
+class FixRegressorWeights(Constraint):
+    def __init__(self, old_weights, old_fions):
+        self.old_weights = old_weights
+        self.freeze_weights = True
+        self.old_fions = old_fions
+
+    def __call__(self, w):
+        if self.freeze_weights:
+            return K.concatenate([self.old_weights, w[:, self.old_fions:]], axis=1)
+        return w
+
+
+@keras.saving.register_keras_serializable()
+class FixBias(Constraint):
+    def __init__(self, old_bias, old_fions):
+        self.old_bias = old_bias
+        self.freeze_bias = True
+        self.old_fions = old_fions
+
+    def __call__(self, b):
+        if self.freeze_bias:
+            return K.concatenate([self.old_bias, b[self.old_fions:]], axis=0)
+        return b
+
+
+def change_output_layer(model: PrositIntensityPredictor, number_of_ions: int = 2, freeze_old_output: bool = False) -> None:
+    """
+    Change the output layer of a PrositIntensityPredictor model.
+    This means changing the number of predicted ions to the number of ion types in the dataset.
+    The default PrositIntensityPredictor predicts two ion types (y- and b-ions).
+    If the number of ions is not given, this function replaces the output layer with a randomly initialized layer with the same dimensions as before.
+
+    If the number of ions changes to, for example, 4, the regressor will have an output dimension of:
+    (batch_size, number_of_ions * charge_states * possible ions) = (batch_size, 4 * 3 * 29) = (batch_size, 348)
+    After changing the output layer, the model needs to be compiled again before training.
+
+    It is possible to fix the old weights and the old bias of the regressor layer before reinitializing the regressor layer.
+    To do so, set freeze_old_output to True.
+
+    Args:
+        model (PrositIntensityPredictor): the model whose output layer is changed
+        number_of_ions (int, optional): Number of ions the model should be able to predict. Defaults to 2.
+        freeze_old_output (bool, optional): Specify whether the pre-trained regressor weights should be kept in place when reinitializing the regressor layer. Defaults to False.
+    """
+    kernel_constraint = None
+    bias_constraint = None
+    if freeze_old_output:
+        old_weights = model.regressor.get_layer('time_dense').get_weights()[0]
+        old_bias = model.regressor.get_layer('time_dense').get_weights()[1]
+
+        kernel_constraint = FixRegressorWeights(old_weights, model.len_fion)
+        bias_constraint = FixBias(old_bias, model.len_fion)
+
+    model.len_fion = 3 * number_of_ions
+    model.regressor = tf.keras.Sequential(
+        [
+            tf.keras.layers.TimeDistributed(
+                tf.keras.layers.Dense(
+                    model.len_fion,
+                    kernel_constraint=kernel_constraint,
+                    bias_constraint=bias_constraint
+                ), name='time_dense'
+            ),
+            tf.keras.layers.LeakyReLU(name='activation'),
+            tf.keras.layers.Flatten(name='out')],
+        name='regressor'
+    )
+
+
+def release_old_regressor(model: PrositIntensityPredictor):
+    """Function to release the pre-trained regressor weights of a re-initialized regressor layer of the Prosit model.
+    The freeze_weights/freeze_bias flags change the constraints so that the weights no longer get overwritten by the old weights.
+    In theory, the regressor weights and bias can be frozen again.
+
+    Args:
+        model (PrositIntensityPredictor): the model whose regressor weights should be released
+    """
+    if model.regressor.get_layer('time_dense').layer.kernel_constraint is not None:
+        model.regressor.get_layer('time_dense').layer.kernel_constraint.freeze_weights = False
+        model.regressor.get_layer('time_dense').layer.bias_constraint.freeze_bias = False
+
+
+def freeze_old_regressor(model: PrositIntensityPredictor):
+    """Function to freeze the pre-trained regressor weights of a re-initialized regressor layer of the Prosit model.
+    The freeze_weights/freeze_bias flags change the constraints so that the old weights are kept in place during training.
+
+    Args:
+        model (PrositIntensityPredictor): the model whose regressor weights should be frozen
+    """
+    if model.regressor.get_layer('time_dense').layer.kernel_constraint is not None:
+        model.regressor.get_layer('time_dense').layer.kernel_constraint.freeze_weights = True
+        model.regressor.get_layer('time_dense').layer.bias_constraint.freeze_bias = True
+
+
+@keras.saving.register_keras_serializable()
+class FixWeights(Constraint):
+    def __init__(self, old_weights, max_old_embedding):
+        self.old_weights = old_weights
+        self.freeze_weights = True
+        self.max_embedding_value = max_old_embedding
+
+    def __call__(self, w):
+        if self.freeze_weights:
+            return K.concatenate([self.old_weights[:self.max_embedding_value + 1], w[self.max_embedding_value + 1:]], axis=0)
+        return w
+
+
+def change_input_layer(model: PrositIntensityPredictor, modifications: list = None, freeze_old_embeds: bool = False) -> None:
+    """Change the input layer of a PrositIntensityPredictor model.
+    This means changing the number of embeddings the Embedding layer can produce, which is directly tied to the size of the model's alphabet.
+    A list of new modifications the model should support is given and these modifications are added to the alphabet, increasing its size.
+    If no new modifications are given, the weights of the Embedding layer are re-initialized.
+
+    This function also allows the user to freeze the old embedding weights trained by the loaded model,
+    meaning that only the weights for the embeddings of the new modifications are updated during training.
+
+    After changing the input layer, the model needs to be compiled again before training.
+
+    Args:
+        model (PrositIntensityPredictor): The model whose input layer needs to be changed
+        modifications (list, optional): List of modifications the model should support. Defaults to None.
+        freeze_old_embeds (bool): If set to True, the old embeddings of the loaded model are not changed during training. Defaults to False.
+ """ + old_embedding_max = max(model.alphabet.values()) + if modifications: + new_modifications = set(modifications) - set(model.alphabet.keys()) + model.alphabet.update({k: i for i, k in enumerate(new_modifications, start=len(model.alphabet) + 1)}) + + embeddings_constraint = None + if freeze_old_embeds: + # if added names to the model, replace get_layer index with name + trained_embeds_weights = model.layers[0].get_weights()[0] + embeddings_constraint = FixWeights(trained_embeds_weights, max_old_embedding=old_embedding_max) + + model.embedding = tf.keras.layers.Embedding( + input_dim=len(model.alphabet) + 2, + output_dim=model.embedding_output_dim, + input_length=model.seq_length, + embeddings_constraint=embeddings_constraint, + name='embedding' + ) + +def release_old_embeddings(model: PrositIntensityPredictor): + """Function to release the pre-trained embeddings of a re-initialized embedding layer of the Prosit model + The freeze_weights parameter changes the constraint, so that the weights do not get overwritten by the old weights + In theory, the embeddings can be frozen again. + + Args: + model (PrositIntensityPredictor): model with a changed embedding layer named 'embedding' + """ + if model.embedding.embeddings_constraint is not None: + model.embedding.embeddings_constraint.freeze_weights = False + +def freeze_old_embeddings(model: PrositIntensityPredictor): + """Function to freeze the pre-trained embeddings of a re-initialized embedding layer of the Prosit model + The freeze_weights parameter changes the constraint, so that the weights do not get overwritten by the old weights. + + Args: + model (PrositIntensityPredictor): model with a changed embedding layer named 'embedding' + """ + if model.embedding.embeddings_constraint is not None: + model.embedding.embeddings_constraint.freeze_weights = True \ No newline at end of file diff --git a/src/dlomix/refinement_transfer_learning/custom_callbacks.py b/src/dlomix/refinement_transfer_learning/custom_callbacks.py new file mode 100644 index 00000000..4aa7c486 --- /dev/null +++ b/src/dlomix/refinement_transfer_learning/custom_callbacks.py @@ -0,0 +1,256 @@ +import tensorflow as tf +from tensorflow.keras.callbacks import Callback +from typing import Optional +import math + +class CustomCSVLogger(tf.keras.callbacks.Callback): + def __init__(self, filename, separator=',', append=True): + super().__init__() + self.filename = filename + self.separator = separator + self.append = append + self.file_writer = None + self.keys = ['phase', 'epoch', 'batch', 'learning_rate', 'loss', 'masked_pearson_correlation_distance', 'val_loss', 'val_masked_pearson_correlation_distance'] + self.epoch = 0 + self.batch_counter = 0 + self.val_loss = None + self.val_masked_pearson_correlation_distance = None + self.phase = 0 + + def on_train_begin(self, logs=None): + mode = 'a' if self.append else 'w' + self.file_writer = open(self.filename, mode) + # Set up headers if file is empty + if not self.append or self.file_writer.tell() == 0: + header = self.separator.join(self.keys) + self.file_writer.write(header + '\n') + + def on_batch_end(self, batch, logs=None): + logs = logs or {} + self.batch_counter += 1 + logs['phase'] = self.phase + logs['epoch'] = self.epoch + logs['batch'] = self.batch_counter + logs['learning_rate'] = float(tf.keras.backend.get_value(self.model.optimizer.lr)) + + # Ensure all keys are present, even if some values are missing + data_to_log = {key: logs.get(key, '') for key in self.keys} + data_to_log['val_loss'] = self.val_loss + 
data_to_log['val_masked_pearson_correlation_distance'] = self.val_masked_pearson_correlation_distance + + # Write the log data for the current batch + row = [str(data_to_log.get(key, '')) for key in self.keys] + row_line = self.separator.join(row) + self.file_writer.write(row_line + '\n') + self.file_writer.flush() + + def on_epoch_end(self, epoch, logs=None): + self.epoch += 1 + + def on_train_end(self, logs=None): + if self.file_writer: + self.file_writer.close() + + def reset_phase(self): + """Resets the epoch counter and increments the phase.""" + self.phase += 1 + + def set_validation_metrics(self, val_loss, val_masked_pearson_correlation_distance): + self.val_loss = val_loss + self.val_masked_pearson_correlation_distance = val_masked_pearson_correlation_distance + + + +class BatchEvaluationCallback(tf.keras.callbacks.Callback): + def __init__(self, evaluate_model_func, batch_interval): + super().__init__() + self.evaluate_model_func = evaluate_model_func + self.batch_interval = batch_interval + self.batch_counter = 0 + + def on_batch_end(self, batch, logs=None): + self.batch_counter += 1 + if self.batch_counter % self.batch_interval == 0: + self.evaluate_model_func() + + +class OverfittingEarlyStopping(tf.keras.callbacks.Callback): + max_validation_train_difference : float + patience : int + wandb_log : bool + + patience_counter : int = 0 + + def __init__(self, max_validation_train_difference, patience, wandb_log): + super().__init__() + self.max_validation_train_difference = max_validation_train_difference + self.patience = patience + self.wandb_log = wandb_log + + if self.wandb_log: + global wandb + import wandb + + def on_epoch_end(self, epoch, logs=None): + train_loss = logs['loss'] + val_loss = logs['val_loss'] + + if not math.isfinite(train_loss) or not math.isfinite(val_loss): + self.model.stop_training = True + return + + if self.wandb_log: + wandb.log({ + 'overfitting_early_stopping_loss_diff': val_loss - train_loss, + 'overfitting_early_stopping_current_patience': self.patience_counter / self.patience + }) + + if val_loss - train_loss < self.max_validation_train_difference: + # difference is within allowed margin + self.patience_counter = max(0, self.patience_counter - max(2, math.ceil(0.1 * self.patience))) + return + + # difference is too high + self.patience_counter += 1 + if self.patience_counter > self.patience: + self.model.stop_training = True + + + +class InflectionPointDetector: + min_improvement : float + patience : int + ignore_first_n : int + smoothing_window : int + wandb_log : bool + wandb_log_name : str = 'InflectionPointDetector' + + change_sum : float = 0 + num_steps : int = 0 + previous_loss : float = 0 + global_min : float = float('inf') + initial_loss : float = None + patience_counter : int = 0 + current_changes : list[float] = [] + + def __init__(self, min_improvement : float, patience : int, ignore_first_n : int = 0, wandb_log : bool = False): + self.min_improvement = min_improvement + self.patience = patience + self.ignore_first_n = ignore_first_n + self.smoothing_window = 3 * patience + self.wandb_log = wandb_log + + if self.wandb_log: + global wandb + import wandb + + def reset_detector(self): + # self.change_sum = 0 + # self.num_steps = 0 + self.patience_counter = 0 + # self.current_changes = [] + + + def inflection_reached(self, loss : float): + if loss < self.global_min: + self.global_min = loss + + if self.initial_loss is None: + self.initial_loss = loss + + loss = self.initial_loss - self.global_min + + change = loss - self.previous_loss + 
self.change_sum += change + self.num_steps += 1 + + self.current_changes.append(change) + if len(self.current_changes) > self.smoothing_window: + self.current_changes.pop(0) + change = sum(self.current_changes) / len(self.current_changes) + + avg_change = self.change_sum / self.num_steps + + if self.wandb_log: + wandb.log({ + f'{self.wandb_log_name}_avg_change': avg_change, + f'{self.wandb_log_name}_current_change': change, + f'{self.wandb_log_name}_current_patience': self.patience_counter / self.patience + }) + + self.previous_loss = loss + + if self.num_steps < self.ignore_first_n: + return False + + if self.num_steps > self.patience: + # enough datapoints to do estimation + if change < self.min_improvement and change < avg_change: + # we are likely after the inflection point and have a low avg change + self.patience_counter += 1 + + if self.patience_counter >= self.patience: + return True + else: + self.patience_counter = 0 + + return False + + +class InflectionPointEarlyStopping(tf.keras.callbacks.Callback, InflectionPointDetector): + stopped_early : bool = False + + def __init__(self, *args, **kwargs): + InflectionPointDetector.__init__(self, *args, **kwargs) + tf.keras.callbacks.Callback.__init__(self) + self.wandb_log_name = 'InflectionPointEarlyStopping' + + def on_train_batch_end(self, batch, logs): + loss = logs['loss'] + + if self.inflection_reached(loss): + self.stopped_early = True + self.model.stop_training = True + +class InflectionPointLRReducer(tf.keras.callbacks.Callback, InflectionPointDetector): + factor : float + + def __init__(self, factor : float, *args, **kwargs): + InflectionPointDetector.__init__(self, *args, **kwargs) + tf.keras.callbacks.Callback.__init__(self) + self.wandb_log_name = 'InflectionPointLRReducer' + self.factor = factor + + def on_train_batch_end(self, batch, logs): + loss = logs['loss'] + + if self.inflection_reached(loss): + lr = self.model.optimizer.lr.read_value() + lr *= self.factor + self.model.optimizer.lr.assign(lr) + self.reset_detector() + + +class LearningRateWarmupPerStep(tf.keras.callbacks.Callback): + num_steps : int + start_lr : float + end_lr : float + + steps_counter : int = 0 + + def __init__(self, num_steps : int, start_lr : float, end_lr : float, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.num_steps = num_steps + self.start_lr = start_lr + self.end_lr = end_lr + + def on_train_batch_begin(self, batch, logs): + lr = self.model.optimizer.lr.read_value() + if self.steps_counter < self.num_steps: + factor = self.steps_counter / self.num_steps + # lr = factor * self.end_lr + (1-factor) * self.start_lr + lr = self.end_lr ** factor * self.start_lr ** (1-factor) + self.model.optimizer.lr.assign(lr) + + self.steps_counter += 1 \ No newline at end of file diff --git a/src/dlomix/refinement_transfer_learning/freezing.py b/src/dlomix/refinement_transfer_learning/freezing.py new file mode 100644 index 00000000..39358d31 --- /dev/null +++ b/src/dlomix/refinement_transfer_learning/freezing.py @@ -0,0 +1,61 @@ +from dlomix.models import PrositIntensityPredictor + + + + +# function to freeze all layers except first and/or last layer +def freeze_model(model:PrositIntensityPredictor, trainable_first_layer:bool = False, trainable_last_layer:bool = False) -> None: + ''' Freezes all layers of a PrositIntensityPredictor and keep first and/or last layer trainable. + + First setting the whole model to trainable because this attribute overshadows the trainable attribute of every sublayer. 
+ Then iterating through all sublayers and sets the trainable attribute of every layer to 'False', model is now frozen. + Next, setting the trainable attribute of either the first embedding layer or the last time density layer to trainable. + + Parameter + --------- + model : dlomix.models.prosit.PrositIntensityPredictor + The model to be frozen. + trainable_first_layer : bool + Whether the first layer should remain trainable. + trainable_last_layer : bool + Whether the last layer should remain trainable + -------- + + ''' + + model.trainable = True + for lay in model.layers: + try: + for sublay in lay.layers: + sublay.trainable = False + except (AttributeError): + lay.trainable = False + + if (trainable_first_layer): + first_layer = model.get_layer(name="embedding") + first_layer.trainable = True + + if (trainable_last_layer): + last_layer = model.regressor.get_layer(name = "time_dense") + last_layer.trainable = True + + +def release_model(model:PrositIntensityPredictor) -> None: + '''Unfreezes all layers of a PrositIntensityPredictor model. + + Sets the trainable attribute of every layer to 'True'. + + Parameter + --------- + model : dlomix.models.prosit.PrositIntensityPredictor + The model to be unfrozen. + -------- + ''' + model.trainable = True + + for lay in model.layers: + try: + for sublay in lay.layers: + sublay.trainable = True + except (AttributeError): + lay.trainable = True \ No newline at end of file diff --git a/src/dlomix/refinement_transfer_learning/user_report.ipynb b/src/dlomix/refinement_transfer_learning/user_report.ipynb new file mode 100644 index 00000000..8c1f1c68 --- /dev/null +++ b/src/dlomix/refinement_transfer_learning/user_report.ipynb @@ -0,0 +1,341 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## User Report for Refinement and Transfer Learning\n", + "\n", + "This notebook provides a detailed analysis of data exploration and training results for an automatic refinement and transfer learning pipeline. It includes visualizations of key dataset features like amino acid distribution, sequence lengths, collision energy, and intensity values for train, validation, and test sets. Spectral angle distributions are calculated and saved before and after training, enabling comparison of model performance. Training results, including learning curves and performance metrics, are documented to highlight improvements through the training stages." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "import seaborn as sns\n", + "import pandas as pd\n", + "from matplotlib import pyplot as plt\n", + "from pathlib import Path\n", + "\n", + "LOGGING_DIR = Path(os.getcwd()) / \"log_data\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data exploration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_json_data(results_log):\n", + " def load_json(filename):\n", + " with open(filename, 'r') as f:\n", + " return json.load(f)\n", + "\n", + " def plot_amino_acid_distribution(datasets=['train', 'val', 'test']):\n", + " fig, axes = plt.subplots(len(datasets), 1, figsize=(18, 5 * len(datasets)), sharey=True)\n", + " for i, dataset in enumerate(datasets):\n", + " file_path = results_log / f'amino_acid_distribution_{dataset}.json'\n", + " if os.path.exists(file_path):\n", + " data = load_json(file_path)\n", + " alphabet_keys = data['alphabet']\n", + " aa_counts = data['counts']\n", + "\n", + " # Filter out amino acids with a count of 0\n", + " filtered_keys_counts = [(k, c) for k, c in zip(alphabet_keys, aa_counts) if c > 0]\n", + " filtered_keys = [k for k, c in filtered_keys_counts]\n", + " filtered_counts = [c for k, c in filtered_keys_counts]\n", + "\n", + " axes[i].bar(filtered_keys, filtered_counts, edgecolor='black')\n", + " axes[i].set_title(f'{dataset.capitalize()} Set')\n", + " axes[i].set_xlabel('Amino Acid')\n", + " axes[i].set_ylabel('Frequency')\n", + " axes[i].tick_params(axis='x', rotation=45) # Rotate x-axis labels\n", + "\n", + " plt.tight_layout()\n", + " plt.show()\n", + "\n", + " def plot_distribution(feature, xlabel='', ylabel='Frequency', is_sequence=False, transform_func=None):\n", + " datasets = ['train', 'val', 'test']\n", + " fig, axes = plt.subplots(1, len(datasets), figsize=(18, 6), sharey=True)\n", + " \n", + " for i, dataset in enumerate(datasets):\n", + " file_path = results_log / f'{feature}_distribution_{dataset}.json'\n", + " if os.path.exists(file_path):\n", + " data = load_json(file_path)\n", + " feature_data = data['hist']\n", + " bin_edges = data['bin_edges']\n", + " axes[i].hist(bin_edges[:-1], bins=bin_edges, weights=feature_data, edgecolor='black')\n", + " axes[i].set_title(f'{dataset.capitalize()} Set')\n", + " axes[i].set_xlabel(xlabel)\n", + " if i == 0:\n", + " axes[i].set_ylabel(ylabel)\n", + "\n", + " fig.suptitle(f'{xlabel.title()} Distribution')\n", + " plt.tight_layout()\n", + " plt.show()\n", + "\n", + " plot_amino_acid_distribution()\n", + " plot_distribution('collision_energy_aligned_normed', xlabel='Collision Energy')\n", + " # plot_distribution('intensities_raw', xlabel='Intensity')\n", + " plot_distribution('sequence', xlabel='Sequence Length')\n", + " plot_distribution('precursor_charge_onehot', xlabel='Precursor Charge')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Call the function to plot the data\n", + "plot_json_data(LOGGING_DIR)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training results\n", + "\n", + "##### Spectral Angle Distribution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_spectral_angle_distributions(results_log, datasets=['train', 'val', 'test']):\n", + " \"\"\"\n", + " Reads the 
spectral angle distributions from JSON files and plots them before and after training.\n", + "\n", + " Args:\n", + " results_log: Directory where the JSON files are saved.\n", + " datasets: A list of strings indicating which datasets to plot ('train', 'val', 'test').\n", + "\n", + " Returns:\n", + " None (plots the distributions)\n", + " \"\"\"\n", + " def load_json(filename):\n", + " with open(filename, 'r') as f:\n", + " return json.load(f)\n", + " \n", + " fig_before, axes_before = plt.subplots(1, len(datasets), figsize=(18, 6), sharey=True)\n", + " fig_after, axes_after = plt.subplots(1, len(datasets), figsize=(18, 6), sharey=True)\n", + " \n", + " for ax_before, ax_after, dataset in zip(axes_before, axes_after, datasets):\n", + " filename = results_log / f'spectral_angle_distribution_{dataset}.json'\n", + " if os.path.exists(filename):\n", + " data = load_json(filename)\n", + " spectral_angles_before = data['before']['spectral_angles']\n", + " avg_sa_before = data['before']['average_spectral_angle']\n", + " spectral_angles_after = data['after']['spectral_angles']\n", + " avg_sa_after = data['after']['average_spectral_angle']\n", + "\n", + " # Plot before training\n", + " ax_before.hist(spectral_angles_before, bins=30, alpha=0.75, edgecolor='black')\n", + " ax_before.axvline(avg_sa_before, color='r', linestyle='dashed', linewidth=1)\n", + " ax_before.text(ax_before.get_xlim()[1] * 0.3, ax_before.get_ylim()[1] * 0.9, f'Avg. SA = {avg_sa_before:.2f}', color='r')\n", + " ax_before.set_title(f'{dataset.capitalize()} Set Before Training')\n", + " ax_before.set_xlabel('Spectral Angle')\n", + " if dataset == datasets[0]:\n", + " ax_before.set_ylabel('Frequency')\n", + "\n", + " # Plot after training\n", + " ax_after.hist(spectral_angles_after, bins=30, alpha=0.75, edgecolor='black')\n", + " ax_after.axvline(avg_sa_after, color='r', linestyle='dashed', linewidth=1)\n", + " ax_after.text(ax_after.get_xlim()[1] * 0.3, ax_after.get_ylim()[1] * 0.9, f'Avg. SA = {avg_sa_after:.2f}', color='r')\n", + " ax_after.set_title(f'{dataset.capitalize()} Set After Training')\n", + " ax_after.set_xlabel('Spectral Angle')\n", + " if dataset == datasets[0]:\n", + " ax_after.set_ylabel('Frequency')\n", + " else:\n", + " print(f\"No data found for {dataset} set.\")\n", + "\n", + " plt.tight_layout()\n", + " plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_spectral_angle_distributions(LOGGING_DIR)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The spectral angle distribution is computed using the first 1000 batches for each dataset type (training, validation, and test) to provide a representative overview. Calculating the distribution across the entire dataset would result in excessive runtime; therefore, using 1000 batches serves as a practical and efficient approximation." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Training process evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_data(data):\n", + " \"\"\"\n", + " Plots specified columns in the provided DataFrame, indicating the start of new phases and labeling them as they appear in the data.\n", + " \"\"\"\n", + " sns.set_theme(style=\"whitegrid\")\n", + " fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 15))\n", + " fig.subplots_adjust(hspace=0.4, wspace=0.4)\n", + "\n", + " unique_phases = data['phase'].unique()\n", + " colors = ['blue', 'green', 'brown']\n", + " \n", + " # Plot Loss and Validation Loss with new colors\n", + " sns.lineplot(ax=axes[0, 0], data=data, x='batch', y='loss', color='#1f77b4') # Deep Blue for Training\n", + " sns.lineplot(ax=axes[0, 0], data=data, x='batch', y='val_loss', color='#b41f4c') # Soft Red for Validation\n", + " for i, phase in enumerate(unique_phases):\n", + " pt_indices = data.index[data['phase'] == phase].tolist()\n", + " if pt_indices:\n", + " pt = pt_indices[0]\n", + " axes[0, 0].axvline(x=data['batch'].iloc[pt], color=colors[i % len(colors)], linestyle='--', alpha=0.5)\n", + " axes[0, 0].set_title('Loss and Validation Loss')\n", + " axes[0, 0].set_xlabel('Batches')\n", + " axes[0, 0].set_ylabel('Loss')\n", + "\n", + " # Plot Masked Pearson Correlation Distance and Validation Masked Pearson Correlation Distance\n", + " sns.lineplot(ax=axes[0, 1], data=data, x='batch', y='masked_pearson_correlation_distance', color='#1f77b4')\n", + " sns.lineplot(ax=axes[0, 1], data=data, x='batch', y='val_masked_pearson_correlation_distance', color='#b41f4c')\n", + " for i, phase in enumerate(unique_phases):\n", + " pt_indices = data.index[data['phase'] == phase].tolist()\n", + " if pt_indices:\n", + " pt = pt_indices[0]\n", + " axes[0, 1].axvline(x=data['batch'].iloc[pt], color=colors[i % len(colors)], linestyle='--', alpha=0.5)\n", + " axes[0, 1].set_title('Masked Pearson Correlation Distance and Validation')\n", + " axes[0, 1].set_xlabel('Batches')\n", + " axes[0, 1].set_ylabel('Masked Pearson Correlation Distance')\n", + "\n", + " # Plot Learning Rate \n", + " sns.lineplot(ax=axes[1, 0], data=data, x='batch', y='learning_rate', color='#4d4d4d') \n", + " for i, phase in enumerate(unique_phases):\n", + " pt_indices = data.index[data['phase'] == phase].tolist()\n", + " if pt_indices:\n", + " pt = pt_indices[0]\n", + " axes[1, 0].axvline(x=data['batch'].iloc[pt], color=colors[i % len(colors)], linestyle='--', alpha=0.5)\n", + " axes[1, 0].set_title('Learning Rate')\n", + " axes[1, 0].set_xlabel('Batches')\n", + " axes[1, 0].set_ylabel('Learning Rate')\n", + " axes[1, 0].set_yscale('log')\n", + "\n", + " # Remove the unused fourth subplot\n", + " fig.delaxes(axes[1, 1])\n", + "\n", + " # Create a custom legend\n", + " handles = [\n", + " plt.Line2D([0], [0], color='#1f77b4', lw=2, label='Training'),\n", + " plt.Line2D([0], [0], color='#b41f4c', lw=2, label='Validation')\n", + " ]\n", + " # Add phase transitions to the legend\n", + " for i, phase in enumerate(unique_phases):\n", + " handles.append(plt.Line2D([0], [0], color=colors[i % len(colors)], linestyle='--', lw=2, label=f'Phase {phase}'))\n", + "\n", + " # Add overall legend\n", + " fig.legend(handles=handles, loc='upper center', ncol=3)\n", + "\n", + " plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + 
"logged_training_results = pd.read_csv(LOGGING_DIR / 'training_log.csv')\n", + "plot_data(logged_training_results)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if logged_training_results.iloc[-1]['val_loss'] > 0.2:\n", + " val_loss_value = logged_training_results.iloc[-1]['val_loss']\n", + " print(f\"The model didn't learn enough about the data, leading to a validation loss higher than 0.2. The current validation loss is {val_loss_value:.6f}.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Freezing per phase" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load training log to find out the starting phase\n", + "training_log = pd.read_csv(LOGGING_DIR / 'training_log.csv')\n", + "start_phase = training_log['phase'].min()\n", + "\n", + "# Load freezing log\n", + "logged_freezing = pd.read_csv(LOGGING_DIR / 'freeze_log.csv', header=None, names=[\"freezing\", \"status\"], skip_blank_lines=False)\n", + "\n", + "# Assign phases based on the start phase from training log\n", + "logged_freezing['Phase'] = logged_freezing['freezing'].isna().cumsum() + start_phase\n", + "logged_freezing = logged_freezing.dropna().reset_index(drop=True)\n", + "\n", + "# Adjust phase labeling\n", + "logged_freezing['Phase'] = logged_freezing['Phase'].apply(lambda x: f\"Phase {x}\")\n", + "logged_freezing['status'] = logged_freezing['status'].astype(int)\n", + "\n", + "# Display the processed data\n", + "logged_freezing" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dlomix", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/dlomix/reports/Report.py b/src/dlomix/reports/Report.py index c9705b0d..52613fab 100644 --- a/src/dlomix/reports/Report.py +++ b/src/dlomix/reports/Report.py @@ -1,5 +1,6 @@ import abc import glob +import logging import warnings from os import makedirs from os.path import join @@ -8,6 +9,9 @@ from fpdf import FPDF from matplotlib import pyplot as plt +logger = logging.getLogger(__name__) +logging.captureWarnings(True) + class Report(abc.ABC): """Base class for reports. diff --git a/src/dlomix/reports/RetentionTimeReport.py b/src/dlomix/reports/RetentionTimeReport.py index 5bfa51bc..c6c08adb 100644 --- a/src/dlomix/reports/RetentionTimeReport.py +++ b/src/dlomix/reports/RetentionTimeReport.py @@ -1,5 +1,5 @@ +import logging from os.path import join -from warnings import warn import numpy as np from matplotlib import pyplot as plt @@ -8,6 +8,9 @@ from ..reports.Report import PDFFile, Report +logger = logging.getLogger(__name__) +logging.captureWarnings(True) + class RetentionTimeReport(Report): """Report generation for Retention Time Prediction tasks."""