{ "cells": [ { "cell_type": "markdown", "id": "04879e6b-3718-4d23-90fa-35e5c8956861", "metadata": { "id": "04879e6b-3718-4d23-90fa-35e5c8956861" }, "source": [ "# Testing ESMB for Protein Binding Residue Prediction\n", "\n", "This notebook is meant to test out ESM-2 LoRA models on the datasets found [here](https://github.com/hamzagamouh/pt-lm-gnn/tree/main/datasets/yu_merged) for the paper [Hybrid protein-ligand binding residue prediction with protein\n", "language models: Does the structure matter?](https://www.biorxiv.org/content/10.1101/2023.08.11.553028v1). The models referenced in the paper are GCN, GAT, and ensemble structural models trained on PDB sequences to predict binding residues. They are the best performing models that could be found as of 17/09/23. You will need to download the datasets you want to test out from the github above and provide the file path in the code below." ] }, { "cell_type": "markdown", "source": [ "## Mount Your Google Drive if Necessary" ], "metadata": { "id": "Fhq4pEpf--n1" }, "id": "Fhq4pEpf--n1" }, { "cell_type": "code", "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')" ], "metadata": { "id": "jkhgCCq0TjLk", "outputId": "7954b983-0e1a-4ff9-ab74-923b3fa675ba", "colab": { "base_uri": "https://localhost:8080/" } }, "id": "jkhgCCq0TjLk", "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Mounted at /content/drive\n" ] } ] }, { "cell_type": "code", "execution_count": null, "id": "b8df2453-f478-4ef5-b69f-633aff114438", "metadata": { "id": "b8df2453-f478-4ef5-b69f-633aff114438", "outputId": "f97f2168-7871-4e28-a46d-8381b865c4aa", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m10.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m294.8/294.8 kB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m15.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m19.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m258.1/258.1 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m85.6/85.6 kB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m519.6/519.6 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m7.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m9.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h" ] } ], "source": [ "!pip install transformers -q\n", "!pip install accelerate -q\n", "!pip install peft -q\n", "!pip install datasets -q" ] }, { "cell_type": "code", "source": [ "import pandas as pd\n", "\n", "# Load the dataset\n", "data_df = pd.read_csv(\"/content/drive/MyDrive/esmb_testing/CA_Training.txt\", delimiter=';')\n", "\n", "# Display the first few rows of the dataframe to understand its structure\n", "data_df.head()\n", "\n" ], "metadata": { "id": "VMkshrhNAJ6z", "outputId": "8ac47820-9cb9-42e0-a859-064e3d2450f0", "colab": { "base_uri": "https://localhost:8080/", "height": 261 } }, "id": "VMkshrhNAJ6z", "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " pdb_id chain_id binding_residues \\\n", "0 1CB8 A E380 D382 K383 D391 Y392 \n", "1 3ALS A E112 N114 N115 D135 \n", "2 2X7Q A N52 D197 G71 E73 \n", "3 3BBY A D75 E77 \n", "4 1B2L A D2 T4 \n", "\n", " sequence \n", "0 GTAELIMKRVMLDLKKPLRNMDKVAEKNLNTLQPDGSWKDVPYKDD... \n", "1 LTSCPPLWTGFNGKCFRLFHNHLNFDNAENACRQFGLASCSGDELA... \n", "2 LPTLKVAYIPEHFSTPLFFAQQQGYYKAHDLSIEFVKVPEGSGRLI... \n", "3 KPAITLWSDAHFFSPYVLSAWVALQEKGLSFHIKTIDRVPLLQIDD... \n", "4 MDLTNKNVIFVAALGGIGLDTSRELVKRNLKNFVILDRVENPTALA... " ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pdb_idchain_idbinding_residuessequence
01CB8AE380 D382 K383 D391 Y392GTAELIMKRVMLDLKKPLRNMDKVAEKNLNTLQPDGSWKDVPYKDD...
13ALSAE112 N114 N115 D135LTSCPPLWTGFNGKCFRLFHNHLNFDNAENACRQFGLASCSGDELA...
22X7QAN52 D197 G71 E73LPTLKVAYIPEHFSTPLFFAQQQGYYKAHDLSIEFVKVPEGSGRLI...
33BBYAD75 E77KPAITLWSDAHFFSPYVLSAWVALQEKGLSFHIKTIDRVPLLQIDD...
41B2LAD2 T4MDLTNKNVIFVAALGGIGLDTSRELVKRNLKNFVILDRVENPTALA...
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "
\n", "
\n" ] }, "metadata": {}, "execution_count": 14 } ] }, { "cell_type": "code", "source": [ "# Define a function to convert binding residues to binary labels\n", "def binding_residues_to_labels(row):\n", " sequence = row['sequence']\n", " binding_residues = row['binding_residues']\n", "\n", " # Initialize a list with zeros\n", " labels = [0] * len(sequence)\n", "\n", " # If binding_residues is not NaN, mark the binding residues in the labels list with 1\n", " if isinstance(binding_residues, str):\n", " # Get the indices of the binding residues\n", " binding_residues_indices = [int(residue[1:]) - 1 for residue in binding_residues.split()]\n", "\n", " # Mark the binding residues in the labels list with 1\n", " for idx in binding_residues_indices:\n", " if idx < len(labels):\n", " labels[idx] = 1\n", "\n", " return labels\n", "\n", "# Apply the function to each row in the DataFrame to get the binary labels\n", "data_df['binding_labels'] = data_df.apply(binding_residues_to_labels, axis=1)\n", "\n", "# Display the first few rows of the DataFrame\n", "data_df.head()\n", "\n" ], "metadata": { "id": "lQEGt8EwAJyH", "outputId": "53cf653c-222e-481c-df65-bacb8696db95", "colab": { "base_uri": "https://localhost:8080/", "height": 400 } }, "id": "lQEGt8EwAJyH", "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " pdb_id chain_id binding_residues \\\n", "0 1CB8 A E380 D382 K383 D391 Y392 \n", "1 3ALS A E112 N114 N115 D135 \n", "2 2X7Q A N52 D197 G71 E73 \n", "3 3BBY A D75 E77 \n", "4 1B2L A D2 T4 \n", "\n", " sequence \\\n", "0 GTAELIMKRVMLDLKKPLRNMDKVAEKNLNTLQPDGSWKDVPYKDD... \n", "1 LTSCPPLWTGFNGKCFRLFHNHLNFDNAENACRQFGLASCSGDELA... \n", "2 LPTLKVAYIPEHFSTPLFFAQQQGYYKAHDLSIEFVKVPEGSGRLI... \n", "3 KPAITLWSDAHFFSPYVLSAWVALQEKGLSFHIKTIDRVPLLQIDD... \n", "4 MDLTNKNVIFVAALGGIGLDTSRELVKRNLKNFVILDRVENPTALA... \n", "\n", " binding_labels \n", "0 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "1 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "2 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "3 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "4 [0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... " ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pdb_idchain_idbinding_residuessequencebinding_labels
01CB8AE380 D382 K383 D391 Y392GTAELIMKRVMLDLKKPLRNMDKVAEKNLNTLQPDGSWKDVPYKDD...[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
13ALSAE112 N114 N115 D135LTSCPPLWTGFNGKCFRLFHNHLNFDNAENACRQFGLASCSGDELA...[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
22X7QAN52 D197 G71 E73LPTLKVAYIPEHFSTPLFFAQQQGYYKAHDLSIEFVKVPEGSGRLI...[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
33BBYAD75 E77KPAITLWSDAHFFSPYVLSAWVALQEKGLSFHIKTIDRVPLLQIDD...[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
41B2LAD2 T4MDLTNKNVIFVAALGGIGLDTSRELVKRNLKNFVILDRVENPTALA...[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "
\n", "
\n" ] }, "metadata": {}, "execution_count": 15 } ] }, { "cell_type": "code", "source": [ "# Define the maximum chunk size\n", "MAX_CHUNK_SIZE = 900\n", "\n", "# Function to segment sequences and labels into chunks of size <= 1022\n", "def segment_into_chunks(row):\n", " sequence = row['sequence']\n", " labels = row['binding_labels']\n", "\n", " # Segment the sequence and labels into chunks of size <= 1022\n", " sequence_chunks = [sequence[i:i+MAX_CHUNK_SIZE] for i in range(0, len(sequence), MAX_CHUNK_SIZE)]\n", " label_chunks = [labels[i:i+MAX_CHUNK_SIZE] for i in range(0, len(labels), MAX_CHUNK_SIZE)]\n", "\n", " return sequence_chunks, label_chunks\n", "\n", "# Apply the function to each row in the DataFrame to get the segmented sequences and labels\n", "data_df['sequence_chunks'] = None\n", "data_df['label_chunks'] = None\n", "for idx, row in data_df.iterrows():\n", " data_df.at[idx, 'sequence_chunks'], data_df.at[idx, 'label_chunks'] = segment_into_chunks(row)\n", "\n", "# Display the first few rows of the DataFrame\n", "data_df[['pdb_id', 'chain_id', 'sequence_chunks', 'label_chunks']].head()\n", "\n" ], "metadata": { "id": "DvgnGWnDAJjV", "outputId": "5e5bbadf-50ef-49a3-b8f9-56b8db241055", "colab": { "base_uri": "https://localhost:8080/", "height": 400 } }, "id": "DvgnGWnDAJjV", "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " pdb_id chain_id sequence_chunks \\\n", "0 1CB8 A [GTAELIMKRVMLDLKKPLRNMDKVAEKNLNTLQPDGSWKDVPYKD... \n", "1 3ALS A [LTSCPPLWTGFNGKCFRLFHNHLNFDNAENACRQFGLASCSGDEL... \n", "2 2X7Q A [LPTLKVAYIPEHFSTPLFFAQQQGYYKAHDLSIEFVKVPEGSGRL... \n", "3 3BBY A [KPAITLWSDAHFFSPYVLSAWVALQEKGLSFHIKTIDRVPLLQID... \n", "4 1B2L A [MDLTNKNVIFVAALGGIGLDTSRELVKRNLKNFVILDRVENPTAL... \n", "\n", " label_chunks \n", "0 [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", "1 [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", "2 [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", "3 [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", "4 [[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... " ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pdb_idchain_idsequence_chunkslabel_chunks
01CB8A[GTAELIMKRVMLDLKKPLRNMDKVAEKNLNTLQPDGSWKDVPYKD...[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
13ALSA[LTSCPPLWTGFNGKCFRLFHNHLNFDNAENACRQFGLASCSGDEL...[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
22X7QA[LPTLKVAYIPEHFSTPLFFAQQQGYYKAHDLSIEFVKVPEGSGRL...[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
33BBYA[KPAITLWSDAHFFSPYVLSAWVALQEKGLSFHIKTIDRVPLLQID...[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
41B2LA[MDLTNKNVIFVAALGGIGLDTSRELVKRNLKNFVILDRVENPTAL...[[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "
\n", "
\n" ] }, "metadata": {}, "execution_count": 16 } ] }, { "cell_type": "code", "source": [ "from transformers import AutoModelForTokenClassification, AutoTokenizer\n", "from peft import PeftModel\n", "import torch\n", "\n", "def get_predictions(protein_sequence):\n", " # Path to the saved LoRA model\n", " model_path = \"AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3\"\n", " # ESM2 base model\n", " base_model_path = \"facebook/esm2_t12_35M_UR50D\"\n", "\n", " # Load the model\n", " base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)\n", " loaded_model = PeftModel.from_pretrained(base_model, model_path)\n", "\n", " # Ensure the model is in evaluation mode\n", " loaded_model.eval()\n", "\n", " # Load the tokenizer\n", " loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)\n", "\n", " # Tokenize the sequence\n", " inputs = loaded_tokenizer(protein_sequence, return_tensors=\"pt\", truncation=True, max_length=1024, padding='max_length')\n", "\n", " # Run the model\n", " with torch.no_grad():\n", " logits = loaded_model(**inputs).logits\n", "\n", " # Get predictions\n", " tokens = loaded_tokenizer.convert_ids_to_tokens(inputs[\"input_ids\"][0]) # Convert input ids back to tokens\n", " predictions = torch.argmax(logits, dim=2)[0].numpy()\n", "\n", " # Define labels\n", " id2label = {\n", " 0: \"No binding site\",\n", " 1: \"Binding site\"\n", " }\n", "\n", " # Convert predictions to binary labels (1 for binding site, 0 otherwise)\n", " special_tokens = ['', '', '', '', '.', '-', '', '']\n", " binary_predictions = [1 if id2label[pred] == \"Binding site\" else 0 for token, pred in zip(tokens, predictions) if token not in special_tokens]\n", "\n", " return binary_predictions\n", "\n", "# Use the function to get predictions for a test sequence\n", "test_sequence = \"MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT\"\n", "print(get_predictions(test_sequence))\n", "\n" ], "metadata": { "id": "gfbCVd_VAtoF", "outputId": "054ef295-57f3-44ee-9d27-b65d7d96326e", "colab": { "base_uri": "https://localhost:8080/", "height": 211, "referenced_widgets": [ "36441384bfdc46eb9bcbd5c4faa3b4ea", "fc195c8505c6467b8e399ea17566a6bb", "c68e986ed38b4b3386d2d6b05812e10f", "293b58e6082f4e5b9e8f127c878ec48f", "b8df700ff90346d0a21eed74b0fe9a83", "723113d4cbbd4f2bace792606d811abf", "70abccffc23d4f2fba49bba2334f25cc", "a38ab65e26b94c1493958f7ef0b127a5", "063115df07914fba94f2b7974ea686e1", "a3d1887fca2545639019d0773051e3eb", "4c7e49659eb54d3ba5d4b967e622988c", "8e229cd9f36e497c8a0b7bbe5754911d", "067236d291914321b66b8fb2208f083f", "addcd3ea0cc140b0bd90e0337b847883", "dcf7d9ba3b2c40fa86a1ecf295cff662", "db42b3baf214409dbd3af6b99d6df4d7", "ddee2a1d14874114bae4eec328f18081", "7f27da02414e471d98ca74ee2e2b0156", "50e062dd562b456db4c9f76bb38a3a05", "13cc8293591b45d8906a85c8b2db6af7", "6d54c0c9d22d45f59929e03b68a446e3", "aa1d53df052b4965ac2ae0ecb3d5b9b4" ] } }, "id": "gfbCVd_VAtoF", "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "Downloading (…)/adapter_config.json: 0%| | 0.00/456 [00:00', '', '', '', '.', '-', '', '']\n", " binary_predictions = [1 if id2label[pred] == \"Binding site\" else 0 for token, pred in zip(tokens, predictions) if token not in special_tokens]\n", "\n", " return binary_predictions\n", "\n", "# Load the model and tokenizer\n", "base_model_path = \"facebook/esm2_t12_35M_UR50D\"\n", "model_path = \"AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3\"\n", "loaded_model = PeftModel.from_pretrained(AutoModelForTokenClassification.from_pretrained(base_model_path), model_path)\n", "loaded_model.eval()\n", "loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)\n", "\n", "# Step 2: Create a function to get predictions for each chunk and store them in a new column\n", "def get_chunk_predictions(row):\n", " global loaded_model, loaded_tokenizer\n", " sequence_chunks = row['sequence_chunks']\n", " predictions = [get_predictions(chunk, loaded_model, loaded_tokenizer) for chunk in sequence_chunks]\n", " return predictions\n", "\n", "data_df['predictions_chunks'] = data_df.apply(get_chunk_predictions, axis=1)\n", "\n", "# Step 3: Flatten the predictions and true labels columns to calculate metrics\n", "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, matthews_corrcoef\n", "\n", "# Flatten the lists of labels and predictions to calculate metrics\n", "true_labels_flat = [label for sublist in data_df['label_chunks'].tolist() for subsublist in sublist for label in subsublist]\n", "predictions_flat = [label for sublist in data_df['predictions_chunks'].tolist() for subsublist in sublist for label in subsublist]\n", "\n", "# Calculate the metrics\n", "accuracy = accuracy_score(true_labels_flat, predictions_flat)\n", "precision = precision_score(true_labels_flat, predictions_flat)\n", "recall = recall_score(true_labels_flat, predictions_flat)\n", "f1 = f1_score(true_labels_flat, predictions_flat)\n", "auc = roc_auc_score(true_labels_flat, predictions_flat)\n", "mcc = matthews_corrcoef(true_labels_flat, predictions_flat)\n", "\n", "# Print the metrics\n", "print(f'Accuracy: {accuracy:.4f}')\n", "print(f'Precision: {precision:.4f}')\n", "print(f'Recall: {recall:.4f}')\n", "print(f'F1 Score: {f1:.4f}')\n", "print(f'AUC: {auc:.4f}')\n", "print(f'MCC: {mcc:.4f}')\n" ], "metadata": { "id": "b7Fe7TroNz-C", "outputId": "93d2f0f5-d87c-4e53-ceee-48c0a5cc6ad3", "colab": { "base_uri": "https://localhost:8080/" } }, "id": "b7Fe7TroNz-C", "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Accuracy: 0.8673\n", "Precision: 0.0408\n", "Recall: 0.3071\n", "F1 Score: 0.0721\n", "AUC: 0.5920\n", "MCC: 0.0712\n" ] } ] }, { "cell_type": "markdown", "id": "38cebb36-6758-4e7d-b82d-e43cabfbe798", "metadata": { "id": "38cebb36-6758-4e7d-b82d-e43cabfbe798" }, "source": [ "## Train/Test Metrics\n", "\n", "Here you can get the train and test metrics the model was originally trained on. Perhaps you can figure out why they are so different from the metrics on the datasets above?!\n", "\n", "### Loading and Tokenizing the Datasets\n", "\n", "To use this notebook to run the model on the train/test split and get the various metrics (accuracy, precision, recall, F1 score, AUC, and MCC) you will need to download the pickle files [found on Hugging Face here](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family_550K). Navigate to the \"Files and versions\" and download the four pickle files (you can ignore the TSV files unless you want to preprocess the data in a different way yourself). Once you have downloaded the pickle files, change the four file pickle paths in the cell below to match the local paths of the pickle files on your machine, then run the cell." ] }, { "cell_type": "code", "execution_count": null, "id": "763eba61-fd1e-45d5-a427-0075e46c6293", "metadata": { "id": "763eba61-fd1e-45d5-a427-0075e46c6293", "outputId": "db820a77-894d-469d-aaf0-5218ec0e9320" }, "outputs": [ { "data": { "text/plain": [ "(Dataset({\n", " features: ['input_ids', 'attention_mask', 'labels'],\n", " num_rows: 450330\n", " }),\n", " Dataset({\n", " features: ['input_ids', 'attention_mask', 'labels'],\n", " num_rows: 113475\n", " }))" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from datasets import Dataset\n", "from transformers import AutoTokenizer\n", "import pickle\n", "\n", "# Load tokenizer\n", "tokenizer = AutoTokenizer.from_pretrained(\"facebook/esm2_t12_35M_UR50D\")\n", "\n", "# Function to truncate labels\n", "def truncate_labels(labels, max_length):\n", " \"\"\"Truncate labels to the specified max_length.\"\"\"\n", " return [label[:max_length] for label in labels]\n", "\n", "# Set the maximum sequence length\n", "max_sequence_length = 1000\n", "\n", "# Load the data from pickle files (change to match your local paths)\n", "with open(\"train_sequences_chunked_by_family.pkl\", \"rb\") as f:\n", " train_sequences = pickle.load(f)\n", "with open(\"test_sequences_chunked_by_family.pkl\", \"rb\") as f:\n", " test_sequences = pickle.load(f)\n", "with open(\"train_labels_chunked_by_family.pkl\", \"rb\") as f:\n", " train_labels = pickle.load(f)\n", "with open(\"test_labels_chunked_by_family.pkl\", \"rb\") as f:\n", " test_labels = pickle.load(f)\n", "\n", "# Tokenize the sequences\n", "train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors=\"pt\", is_split_into_words=False)\n", "test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors=\"pt\", is_split_into_words=False)\n", "\n", "# Truncate the labels to match the tokenized sequence lengths\n", "train_labels = truncate_labels(train_labels, max_sequence_length)\n", "test_labels = truncate_labels(test_labels, max_sequence_length)\n", "\n", "# Create train and test datasets\n", "train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column(\"labels\", train_labels)\n", "test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column(\"labels\", test_labels)\n", "\n", "train_dataset, test_dataset\n" ] }, { "cell_type": "markdown", "id": "c56556a3-93a5-45c6-935d-dc959b18c608", "metadata": { "id": "c56556a3-93a5-45c6-935d-dc959b18c608" }, "source": [ "### Getting the Train/Test Metrics\n", "\n", "Next, run the following cell. Depending on your hardware, this may take a while. There are ~549K protein sequences to process in total. The train dataset will obviously take much longer than the test dataset. Be patient and let both of them complete to see both the train and test metrics." ] }, { "cell_type": "code", "execution_count": null, "id": "65dd11e8-f502-44cd-b439-a593bf4d5019", "metadata": { "id": "65dd11e8-f502-44cd-b439-a593bf4d5019", "outputId": "a795325a-dfa5-4fcf-85c4-734dc8be42e1", "colab": { "referenced_widgets": [ "f110a2bca7314f278e1b97a37f4ab033", "2bd08fb8fcb644d080746c42dc4d77d1" ] } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f110a2bca7314f278e1b97a37f4ab033", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)/adapter_config.json: 0%| | 0.00/457 [00:00\n", " \n", " \n", " [ 200/56292 01:32 < 7:13:37, 2.16 it/s]\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from sklearn.metrics import(\n", " matthews_corrcoef,\n", " accuracy_score,\n", " precision_recall_fscore_support,\n", " roc_auc_score\n", ")\n", "from peft import PeftModel\n", "from transformers import DataCollatorForTokenClassification, AutoModelForTokenClassification\n", "from transformers import Trainer\n", "from accelerate import Accelerator\n", "\n", "# Instantiate the accelerator\n", "accelerator = Accelerator()\n", "\n", "# Define paths to the LoRA and base models\n", "base_model_path = \"facebook/esm2_t12_35M_UR50D\"\n", "lora_model_path = \"AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3\" # \"path/to/your/lora/model\" # Replace with the correct path to your LoRA model\n", "\n", "# Load the base model\n", "base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)\n", "\n", "# Load the LoRA model\n", "model = PeftModel.from_pretrained(base_model, lora_model_path)\n", "model = accelerator.prepare(model) # Prepare the model using the accelerator\n", "\n", "# Define label mappings\n", "id2label = {0: \"No binding site\", 1: \"Binding site\"}\n", "label2id = {v: k for k, v in id2label.items()}\n", "\n", "# Create a data collator\n", "data_collator = DataCollatorForTokenClassification(tokenizer)\n", "\n", "# Define a function to compute the metrics\n", "def compute_metrics(dataset):\n", " # Get the predictions using the trained model\n", " trainer = Trainer(model=model, data_collator=data_collator)\n", " predictions, labels, _ = trainer.predict(test_dataset=dataset)\n", "\n", " # Remove padding and special tokens\n", " mask = labels != -100\n", " true_labels = labels[mask].flatten()\n", " flat_predictions = np.argmax(predictions, axis=2)[mask].flatten().tolist()\n", "\n", " # Compute the metrics\n", " accuracy = accuracy_score(true_labels, flat_predictions)\n", " precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')\n", " auc = roc_auc_score(true_labels, flat_predictions)\n", " mcc = matthews_corrcoef(true_labels, flat_predictions) # Compute the MCC\n", "\n", " return {\"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1\": f1, \"auc\": auc, \"mcc\": mcc} # Include the MCC in the returned dictionary\n", "\n", "# Get the metrics for the training and test datasets\n", "train_metrics = compute_metrics(train_dataset)\n", "test_metrics = compute_metrics(test_dataset)\n", "\n", "train_metrics, test_metrics" ] }, { "cell_type": "code", "execution_count": null, "id": "d8cc0058-1f81-466d-9fed-4a7ef55ba11f", "metadata": { "id": "d8cc0058-1f81-466d-9fed-4a7ef55ba11f" }, "outputs": [], "source": [] } ], "metadata": { "language_info": { "name": "python" }, "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU", "widgets": { "application/vnd.jupyter.widget-state+json": { "36441384bfdc46eb9bcbd5c4faa3b4ea": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_fc195c8505c6467b8e399ea17566a6bb", "IPY_MODEL_c68e986ed38b4b3386d2d6b05812e10f", "IPY_MODEL_293b58e6082f4e5b9e8f127c878ec48f" ], "layout": "IPY_MODEL_b8df700ff90346d0a21eed74b0fe9a83" } }, "fc195c8505c6467b8e399ea17566a6bb": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_723113d4cbbd4f2bace792606d811abf", "placeholder": "​", "style": "IPY_MODEL_70abccffc23d4f2fba49bba2334f25cc", "value": "Downloading (…)/adapter_config.json: 100%" } }, "c68e986ed38b4b3386d2d6b05812e10f": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_a38ab65e26b94c1493958f7ef0b127a5", "max": 456, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_063115df07914fba94f2b7974ea686e1", "value": 456 } }, "293b58e6082f4e5b9e8f127c878ec48f": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_a3d1887fca2545639019d0773051e3eb", "placeholder": "​", "style": "IPY_MODEL_4c7e49659eb54d3ba5d4b967e622988c", "value": " 456/456 [00:00<00:00, 26.4kB/s]" } }, "b8df700ff90346d0a21eed74b0fe9a83": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "723113d4cbbd4f2bace792606d811abf": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "70abccffc23d4f2fba49bba2334f25cc": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "a38ab65e26b94c1493958f7ef0b127a5": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "063115df07914fba94f2b7974ea686e1": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "a3d1887fca2545639019d0773051e3eb": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "4c7e49659eb54d3ba5d4b967e622988c": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "8e229cd9f36e497c8a0b7bbe5754911d": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_067236d291914321b66b8fb2208f083f", "IPY_MODEL_addcd3ea0cc140b0bd90e0337b847883", "IPY_MODEL_dcf7d9ba3b2c40fa86a1ecf295cff662" ], "layout": "IPY_MODEL_db42b3baf214409dbd3af6b99d6df4d7" } }, "067236d291914321b66b8fb2208f083f": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_ddee2a1d14874114bae4eec328f18081", "placeholder": "​", "style": "IPY_MODEL_7f27da02414e471d98ca74ee2e2b0156", "value": "Downloading adapter_model.bin: 100%" } }, "addcd3ea0cc140b0bd90e0337b847883": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_50e062dd562b456db4c9f76bb38a3a05", "max": 307151, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_13cc8293591b45d8906a85c8b2db6af7", "value": 307151 } }, "dcf7d9ba3b2c40fa86a1ecf295cff662": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_6d54c0c9d22d45f59929e03b68a446e3", "placeholder": "​", "style": "IPY_MODEL_aa1d53df052b4965ac2ae0ecb3d5b9b4", "value": " 307k/307k [00:00<00:00, 1.61MB/s]" } }, "db42b3baf214409dbd3af6b99d6df4d7": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ddee2a1d14874114bae4eec328f18081": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7f27da02414e471d98ca74ee2e2b0156": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "50e062dd562b456db4c9f76bb38a3a05": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "13cc8293591b45d8906a85c8b2db6af7": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "6d54c0c9d22d45f59929e03b68a446e3": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "aa1d53df052b4965ac2ae0ecb3d5b9b4": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } } } } }, "nbformat": 4, "nbformat_minor": 5 }