From a99f2ae0c9474ce63276de22317fe2628256cd87 Mon Sep 17 00:00:00 2001 From: Kevin Eloff Date: Fri, 11 Oct 2024 11:03:48 +0200 Subject: [PATCH] feat: update notebook, add charge check, fix sdpa (#61) --- README.md | 49 ++++- instanovo/configs/inference/default.yaml | 1 + instanovo/transformer/model.py | 18 +- instanovo/transformer/predict.py | 20 ++ instanovo/transformer/train.py | 79 +++++--- .../getting_started_with_instanovo.ipynb | 183 ++++++++++++------ 6 files changed, 252 insertions(+), 98 deletions(-) diff --git a/README.md b/README.md index a7b4ee1..e16b92b 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ conda env create -f environment.yml conda activate instanovo ``` -Note: InstaNovo is built for Python >= 3.8, <3.12 and tested on Linux and Windows. +Note: InstaNovo is built for Python >= 3.10, <3.12 and tested on Linux. ### Training @@ -115,6 +115,46 @@ The configuration file for inference may be found under Note: the `denovo=True/False` flag controls whether metrics will be calculated. 
+### Models + +InstaNovo 1.0.0 includes a new model `instanovo_extended.ckpt` trained on a larger dataset with more +PTMs + +**Training Datasets** + +- [ProteomeTools](https://www.proteometools.org/) Part + [I (PXD004732)](https://www.ebi.ac.uk/pride/archive/projects/PXD004732), + [II (PXD010595)](https://www.ebi.ac.uk/pride/archive/projects/PXD010595), and + [III (PXD021013)](https://www.ebi.ac.uk/pride/archive/projects/PXD021013) \ + (referred to as the all-confidence ProteomeTools `AC-PT` dataset in our paper) +- Additional PRIDE dataset with more modifications: \ + ([PXD000666](https://www.ebi.ac.uk/pride/archive/projects/PXD000666), [PXD000867](https://www.ebi.ac.uk/pride/archive/projects/PXD000867), + [PXD001839](https://www.ebi.ac.uk/pride/archive/projects/PXD001839), [PXD003155](https://www.ebi.ac.uk/pride/archive/projects/PXD003155), + [PXD004364](https://www.ebi.ac.uk/pride/archive/projects/PXD004364), [PXD004612](https://www.ebi.ac.uk/pride/archive/projects/PXD004612), + [PXD005230](https://www.ebi.ac.uk/pride/archive/projects/PXD005230), [PXD006692](https://www.ebi.ac.uk/pride/archive/projects/PXD006692), + [PXD011360](https://www.ebi.ac.uk/pride/archive/projects/PXD011360), [PXD011536](https://www.ebi.ac.uk/pride/archive/projects/PXD011536), + [PXD013543](https://www.ebi.ac.uk/pride/archive/projects/PXD013543), [PXD015928](https://www.ebi.ac.uk/pride/archive/projects/PXD015928), + [PXD016793](https://www.ebi.ac.uk/pride/archive/projects/PXD016793), [PXD017671](https://www.ebi.ac.uk/pride/archive/projects/PXD017671), + [PXD019431](https://www.ebi.ac.uk/pride/archive/projects/PXD019431), [PXD019852](https://www.ebi.ac.uk/pride/archive/projects/PXD019852), + [PXD026910](https://www.ebi.ac.uk/pride/archive/projects/PXD026910), [PXD027772](https://www.ebi.ac.uk/pride/archive/projects/PXD027772)) +- Additional phosphorylation dataset \ + (not yet publicly released) + +**Natively Supported Modifications** + +- Oxidation of methionine +- Cysteine alkylation / 
Carboxyamidomethylation +- Asparagine and glutamine deamidation +- Serine, Threonine, and Tyrosine phosphorylation +- N-terminal ammonia loss +- N-terminal carbamylation +- N-terminal acetylation + +See residue configuration under +[instanovo/configs/residues/extended.yaml](./instanovo/configs/residues/extended.yaml) + +## Additional features + ### Spectrum Data Class InstaNovo introduces a Spectrum Data Class: [SpectrumDataFrame](./instanovo/utils/data_handler.py). @@ -196,7 +236,7 @@ lazy_df = sdf.to_polars(return_lazy=True) # Returns a pl.LazyFrame sdf.write_mgf("path/to/output.mgf") ``` -**Additional Features:** +**SpectrumDataFrame Features:** - The SpectrumDataFrame supports lazy loading with asynchronous prefetching, mitigating wait times between files. @@ -291,3 +331,8 @@ The model checkpoints are licensed under Creative Commons Non-Commercial journal = {bioRxiv} } ``` + +## Acknowledgements + +Big thanks to Pathmanaban Ramasamy, Tine Claeys, and Lennart Martens for providing us with +additional phosphorylation training data. 
diff --git a/instanovo/configs/inference/default.yaml b/instanovo/configs/inference/default.yaml index b43b64f..93f85bc 100644 --- a/instanovo/configs/inference/default.yaml +++ b/instanovo/configs/inference/default.yaml @@ -11,6 +11,7 @@ data_type: # .csv, .mgf, .mzml, .mzxml denovo: False num_beams: 1 # 1 defaults to greedy search with basic filtering max_length: 40 +max_charge: 10 # Must be <= model max charge isotope_error_range: [0, 1] use_knapsack: False save_beams: False diff --git a/instanovo/transformer/model.py b/instanovo/transformer/model.py index dd9316e..d62de50 100644 --- a/instanovo/transformer/model.py +++ b/instanovo/transformer/model.py @@ -10,8 +10,6 @@ from omegaconf import DictConfig from torch import nn from torch import Tensor -from torch.nn.attention import sdpa_kernel -from torch.nn.attention import SDPBackend from instanovo.constants import MAX_SEQUENCE_LENGTH from instanovo.transformer.layers import ConvPeakEmbedding @@ -354,6 +352,14 @@ def _flash_encoder( latent_spectra = self.latent_spectrum.expand(x.shape[0], -1, -1) x = torch.cat([latent_spectra, x], dim=1).contiguous() + try: + from torch.nn.attention import sdpa_kernel + from torch.nn.attention import SDPBackend + except ImportError: + raise ImportError( + "Training InstaNovo with Flash attention enabled requires at least pytorch v2.3. Please upgrade your pytorch version" + ) + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): x = self.encoder(x) @@ -391,6 +397,14 @@ def _flash_decoder( c_mask = self._get_causal_mask(y.shape[1]).to(y.device) + try: + from torch.nn.attention import sdpa_kernel + from torch.nn.attention import SDPBackend + except ImportError: + raise ImportError( + "Training InstaNovo with Flash attention enabled requires at least pytorch v2.3. 
Please upgrade your pytorch version" + ) + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): y_hat = self.decoder(y, x, tgt_mask=c_mask) diff --git a/instanovo/transformer/predict.py b/instanovo/transformer/predict.py index d243482..68b45d3 100644 --- a/instanovo/transformer/predict.py +++ b/instanovo/transformer/predict.py @@ -84,6 +84,26 @@ def get_preds( else: raise + # Check max charge values: + original_size = len(sdf) + max_charge = config.get("max_charge", 10) + model_max_charge = model_config.get("max_charge", 10) + if max_charge > model_max_charge: + logger.warning( + f"Inference has been configured with max_charge={max_charge}, but model has max_charge={model_max_charge}." + ) + logger.warning(f"Overwriting max_charge to Model value: {max_charge}.") + max_charge = model_max_charge + + sdf.filter_rows( + lambda row: (row["precursor_charge"] <= max_charge) + and (row["precursor_charge"] > 0) + ) + if len(sdf) < original_size: + logger.warning( + f"Found {original_size - len(sdf)} rows with charge > {max_charge}. These rows will be skipped." + ) + sdf.sample_subset(fraction=config.get("subset", 1.0), seed=42) logger.info( f"Data loaded, evaluating {config.get('subset', 1.0)*100:.1f}%, {len(sdf):,} samples in total." diff --git a/instanovo/transformer/train.py b/instanovo/transformer/train.py index ffdc976..a2d948a 100644 --- a/instanovo/transformer/train.py +++ b/instanovo/transformer/train.py @@ -420,7 +420,39 @@ def train( else: raise - # TODO: Add automatic splitting if no validation set is specified. 
+ if config.get("valid_path", None) is None: + logger.info("Validation path not specified, generating from training set.") + sequences = list(train_sdf.get_unique_sequences()) + sequences = sorted(list(set([remove_modifications(x) for x in sequences]))) + train_unique, valid_unique = train_test_split( + sequences, + test_size=config.get("valid_subset_of_train"), + random_state=42, + ) + train_unique = set(train_unique) + valid_unique = set(valid_unique) + + train_sdf.filter_rows( + lambda row: remove_modifications(row["sequence"]) in train_unique + ) + valid_sdf.filter_rows( + lambda row: remove_modifications(row["sequence"]) in valid_unique + ) + # Save splits + # TODO: Optionally load the data splits + # TODO: Allow loading of data splits in `predict.py` + # TODO: Upload to Aichor + split_path = os.path.join( + config.get("model_save_folder_path", "./checkpoints"), "splits.csv" + ) + os.makedirs(os.path.dirname(split_path), exist_ok=True) + pd.DataFrame( + { + "modified_sequence": list(train_unique) + list(valid_unique), + "split": ["train"] * len(train_unique) + ["valid"] * len(valid_unique), + } + ).to_csv(str(split_path), index=False) + logger.info(f"Data splits saved to {split_path}") # Check residues if config.get("perform_data_checks", True): @@ -463,40 +495,25 @@ def train( f"{original_size[1]-new_size[1]:,d} ({(original_size[1]-new_size[1])/original_size[1]*100:.2f}%) validation rows dropped." 
) - # TODO Modify this code to work in the new SpectrumDataFrame - if config.get("valid_path", None) is None: - logger.info("Validation path not specified, generating from training set.") - sequences = list(train_sdf.get_unique_sequences()) - sequences = sorted(list(set([remove_modifications(x) for x in sequences]))) - train_unique, valid_unique = train_test_split( - sequences, - test_size=config.get("valid_subset_of_train"), - random_state=42, - ) - train_unique = set(train_unique) - valid_unique = set(valid_unique) - + # Check charge values: + original_size = (len(train_sdf), len(valid_sdf)) train_sdf.filter_rows( - lambda row: remove_modifications(row["sequence"]) in train_unique + lambda row: (row["precursor_charge"] <= config.get("max_charge", 10)) + and (row["precursor_charge"] > 0) ) + if len(train_sdf) < original_size[0]: + logger.warning( + f"Found {original_size[0] - len(train_sdf)} rows in training set with charge > {config.get('max_charge', 10)} or <= 0. These rows will be skipped." + ) + valid_sdf.filter_rows( - lambda row: remove_modifications(row["sequence"]) in valid_unique + lambda row: (row["precursor_charge"] <= config.get("max_charge", 10)) + and (row["precursor_charge"] > 0) ) - # Save splits - # TODO: Optionally load the data splits - # TODO: Allow loading of data splits in `predict.py` - # TODO: Upload to Aichor - split_path = os.path.join( - config.get("model_save_folder_path", "./checkpoints"), "splits.csv" - ) - os.makedirs(os.path.dirname(split_path), exist_ok=True) - pd.DataFrame( - { - "modified_sequence": list(train_unique) + list(valid_unique), - "split": ["train"] * len(train_unique) + ["valid"] * len(valid_unique), - } - ).to_csv(str(split_path), index=False) - logger.info(f"Data splits saved to {split_path}") + if len(valid_sdf) < original_size[1]: + logger.warning( + f"Found {original_size[1] - len(valid_sdf)} rows in training set with charge > {config.get('max_charge', 10)}. These rows will be skipped." 
+ ) train_sdf.sample_subset(fraction=config.get("train_subset", 1.0), seed=42) valid_sdf.sample_subset(fraction=config.get("valid_subset", 1.0), seed=42) diff --git a/notebooks/getting_started_with_instanovo.ipynb b/notebooks/getting_started_with_instanovo.ipynb index 9d8d7d0..e63f66c 100644 --- a/notebooks/getting_started_with_instanovo.ipynb +++ b/notebooks/getting_started_with_instanovo.ipynb @@ -28,7 +28,7 @@ "id": "v4Rk9kj1NiMU" }, "source": [ - "![](https://raw.githubusercontent.com/instadeepai/InstaNovo/main/graphical_abstract.jpeg)" + "![](https://raw.githubusercontent.com/instadeepai/InstaNovo/main/docs/assets/graphical_abstract.jpeg)" ] }, { @@ -43,7 +43,7 @@ "\n", "- **De novo peptide sequencing with InstaNovo: Accurate, database-free peptide identification for large scale proteomics experiments** \\\n", " Kevin Eloff, Konstantinos Kalogeropoulos, Oliver Morell, Amandla Mabona, Jakob Berg Jespersen, Wesley Williams, Sam van Beljouw, Marcin Skwark, Andreas Hougaard Laustsen, Stan J. J. Brouns, Anne Ljungars, Erwin M. Schoof, Jeroen Van Goey, Ulrich auf dem Keller, Karim Beguir, Nicolas Lopez Carranza, Timothy P. Jenkins \\\n", - " [bioRxiv](https://www.biorxiv.org/content/10.1101/2023.08.30.555055v1), [GitHub](https://github.com/instadeepai/InstaNovo)\n", + " [bioRxiv](https://www.biorxiv.org/content/10.1101/2023.08.30.555055v3), [GitHub](https://github.com/instadeepai/InstaNovo)\n", "\n", "**Important:**\n", "\n", @@ -62,9 +62,7 @@ "source": [ "## Loading the InstaNovo model\n", "\n", - "We first install the latest instanovo from PyPi\n", - "\n", - "_Note: this currently installs directly from GitHub, this will be updated in the next release._" + "We first install the latest instanovo from PyPI" ] }, { @@ -123,7 +121,9 @@ "id": "7QcyM4jKA9qL" }, "source": [ - "We can download the model checkpoint directly from the [InstaNovo releases](https://github.com/instadeepai/InstaNovo/releases)." 
+ "We can download the model checkpoint directly from the [InstaNovo releases](https://github.com/instadeepai/InstaNovo/releases).\n", + "\n", + "_Note: this file is 1.1GB and may take time to download_" ] }, { @@ -142,8 +142,8 @@ "source": [ "# Download checkpoint locally\n", "os.makedirs(\"checkpoints\", exist_ok=True)\n", - "url = \"https://github.com/instadeepai/InstaNovo/releases/download/0.1.4/instanovo_yeast.pt\"\n", - "file_path = os.path.join(\"checkpoints\", \"instanovo_yeast.pt\")\n", + "url = \"https://github.com/instadeepai/InstaNovo/releases/download/1.0.0/instanovo_extended.ckpt\"\n", + "file_path = os.path.join(\"checkpoints\", \"instanovo_extended.ckpt\")\n", "if not os.path.exists(file_path):\n", " urllib.request.urlretrieve(url, file_path)" ] @@ -185,8 +185,7 @@ "metadata": {}, "outputs": [], "source": [ - "# model_path = \"./checkpoints/instanovo_yeast.pt\"\n", - "model_path = \"../checkpoints/extended_147b2a84/epoch=3-step=800000.ckpt\"" + "model_path = \"./checkpoints/instanovo_extended.ckpt\"" ] }, { @@ -230,7 +229,7 @@ " \"InstaDeepAI/ms_ninespecies_benchmark\",\n", " is_annotated=True,\n", " shuffle=False,\n", - " split=\"test[:10%]\",\n", + " split=\"test[:10%]\", # Let's only use a subset of the test data for faster inference\n", ")" ] }, @@ -324,15 +323,57 @@ "id": "8a16c311-6802-49f8-af8e-857f43510c37" }, "source": [ - "## Knapsack beam-search decoder\n", + "## Decoding\n", "\n", - "Setup knapsack beam search decoder. This may take a few minutes." + "We have three options for decoding:\n", + "- Greedy Search\n", + "- Beam Search\n", + "- Knapsack Beam Search\n", + "\n", + "For the best results and highest peptide recall, use **Knapsack Beam Search**. 
\n", + "For fastest results (over 10x speedup), use **Greedy Search**.\n", + "\n", + "We generally use a beam size of 5 for Beam Search and Knapsack Beam Search. A higher beam size should increase recall at the cost of speed, and vice versa.\n", + "\n", + "_Note: in our findings, greedy search has similar performance at 5% FDR. I.e., if you plan to filter at 5% FDR anyway, use greedy search for the fastest inference._\n", + "\n", + "### Greedy Search and Beam Search\n", + "\n", + "Greedy search is used when `num_beams=1`, and beam search is used when `num_beams>1`" ] }, { "cell_type": "code", "execution_count": null, "id": "21", + "metadata": {}, + "outputs": [], + "source": [ + "from instanovo.inference import GreedyDecoder\n", + "from instanovo.inference import BeamSearchDecoder\n", + "\n", + "num_beams = 1 # Change this, defaults are 1 or 5\n", + "\n", + "if num_beams > 1:\n", + " decoder = BeamSearchDecoder(model=model)\n", + "else:\n", + " decoder = GreedyDecoder(model=model)" ] }, { "cell_type": "markdown", "id": "22", + "metadata": {}, + "source": [ + "### Knapsack Beam Search\n", + "\n", + "Setup knapsack beam search decoder. This may take a few minutes." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -385,32 +426,6 @@ " decoder = KnapsackBeamSearchDecoder.from_file(model=model, path=knapsack_path)" ] }, - { - "cell_type": "markdown", - "id": "22", - "metadata": {}, - "source": [ - "Or use greedy search (fastest) or plain BeamSearch" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "23", - "metadata": {}, - "outputs": [], - "source": [ - "from instanovo.inference import GreedyDecoder\n", - "from instanovo.inference import BeamSearchDecoder\n", - "\n", - "num_beams = 1\n", - "\n", - "if num_beams > 1:\n", - " decoder = BeamSearchDecoder(model=model)\n", - "else:\n", - " decoder = GreedyDecoder(model=model)" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -545,10 +560,20 @@ " targs += list(peptides)" ] }, + { + "cell_type": "markdown", + "id": "31", + "metadata": {}, + "source": [ + "### Evaluation metrics\n", + "\n", + "Model performance without filtering:" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "32", "metadata": {}, "outputs": [], "source": [ @@ -566,10 +591,42 @@ "print(f\"area under the PR curve: {auc:.5f}\")" ] }, + { + "cell_type": "markdown", + "id": "33", + "metadata": {}, + "source": [ + "### We can find a threshold to ensure a desired FDR:\n", + "\n", + "Model performance at 5% FDR:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34", + "metadata": {}, + "outputs": [], + "source": [ + "fdr = 5 / 100 # Desired FDR\n", + "\n", + "_, threshold = metrics.find_recall_at_fdr(targs, preds, np.exp(probs), fdr=fdr)\n", + "aa_precision, aa_recall, peptide_recall, peptide_precision = (\n", + " metrics.compute_precision_recall(targs, preds, np.exp(probs), threshold=threshold)\n", + ")\n", + "print(f\"Performance at {fdr*100:.1f}% FDR:\\n\")\n", + "print(f\"amino acid precision: {aa_precision:.5f}\")\n", + "print(f\"amino acid 
recall: {aa_recall:.5f}\")\n", + "print(f\"peptide precision: {peptide_precision:.5f}\")\n", + "print(f\"peptide recall: {peptide_recall:.5f}\")\n", + "print(f\"area under the PR curve: {auc:.5f}\")\n", + "print(f\"confidence threshold: {threshold:.5f} <-- Use this as a confidence cutoff\")" + ] + }, { "attachments": {}, "cell_type": "markdown", - "id": "32", + "id": "35", "metadata": { "id": "IcstKaUGB8Bo" }, @@ -580,7 +637,7 @@ { "attachments": {}, "cell_type": "markdown", - "id": "33", + "id": "36", "metadata": { "id": "ychXR1M3CbKf" }, @@ -591,7 +648,7 @@ { "cell_type": "code", "execution_count": null, - "id": "34", + "id": "37", "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -615,18 +672,18 @@ { "cell_type": "code", "execution_count": null, - "id": "35", + "id": "38", "metadata": { "id": "HJgMPD0YCWIm" }, "outputs": [], "source": [ - "pred_df.to_csv(\"predictions.csv\", index=False)" + "pred_df.to_csv(\"predictions_kbs.csv\", index=False)" ] }, { "cell_type": "markdown", - "id": "36", + "id": "39", "metadata": {}, "source": [ "## InstaNovo+: Iterative Refinement with a Diffusion Model [OUTDATED]\n", @@ -640,17 +697,17 @@ { "cell_type": "code", "execution_count": null, - "id": "37", + "id": "40", "metadata": {}, "outputs": [], "source": [ - "!pip uninstall instanovo && pip install instanovo==0.1.7 " + "!pip uninstall -y instanovo && pip install instanovo==0.1.7" ] }, { "cell_type": "code", "execution_count": null, - "id": "38", + "id": "41", "metadata": {}, "outputs": [], "source": [ @@ -663,7 +720,7 @@ { "cell_type": "code", "execution_count": null, - "id": "39", + "id": "42", "metadata": {}, "outputs": [], "source": [ @@ -687,7 +744,7 @@ }, { "cell_type": "markdown", - "id": "40", + "id": "43", "metadata": {}, "source": [ "Next, we load the checkpoint and create a decoder object." 
@@ -696,7 +753,7 @@ { "cell_type": "code", "execution_count": null, - "id": "41", + "id": "44", "metadata": {}, "outputs": [], "source": [ @@ -710,7 +767,7 @@ }, { "cell_type": "markdown", - "id": "42", + "id": "45", "metadata": {}, "source": [ "Then we prepare the inference data loader using predictions from the InstaNovo transformer model." @@ -719,7 +776,7 @@ { "cell_type": "code", "execution_count": null, - "id": "43", + "id": "46", "metadata": {}, "outputs": [], "source": [ @@ -747,7 +804,7 @@ }, { "cell_type": "markdown", - "id": "44", + "id": "47", "metadata": {}, "source": [ "Finally, we predict sequences by iterating over the spectra and refining the InstaNovo predictions." @@ -756,7 +813,7 @@ { "cell_type": "code", "execution_count": null, - "id": "45", + "id": "48", "metadata": {}, "outputs": [], "source": [ @@ -784,7 +841,7 @@ }, { "cell_type": "markdown", - "id": "46", + "id": "49", "metadata": {}, "source": [ "Iterative refinement improves performance on this sample of the Nine Species dataset. (To replicate the performance reported in the paper, you would need to evaluate on the entire dataset.) 
" @@ -793,7 +850,7 @@ { "cell_type": "code", "execution_count": null, - "id": "47", + "id": "50", "metadata": {}, "outputs": [], "source": [ @@ -817,7 +874,7 @@ { "cell_type": "code", "execution_count": null, - "id": "48", + "id": "51", "metadata": {}, "outputs": [], "source": [ @@ -836,7 +893,7 @@ { "cell_type": "code", "execution_count": null, - "id": "49", + "id": "52", "metadata": {}, "outputs": [], "source": [ @@ -853,7 +910,7 @@ { "cell_type": "code", "execution_count": null, - "id": "50", + "id": "53", "metadata": {}, "outputs": [], "source": [ @@ -868,7 +925,7 @@ "provenance": [] }, "kernelspec": { - "display_name": "instanovo", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -882,7 +939,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.10.15" } }, "nbformat": 4,