{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Training Tutorial\n", "\n", "Welcome to the model training tutorial! In this tutorial, we will train a neural network to classify tiles from our toy data set and visualize its efficacy. For now, the classifier code in full can be found at: https://github.com/msk-mind/sandbox/blob/master/classifier/crc_tile_classifier/tile_classifier.py. Our model is essentially a wrapper around PyTorch's ResNet 18 deep residual network; the LUNA team modified it to suit their work with tiling the slides. Here are the steps we will review in this notebook:\n", "\n", "- Initialize project space for model\n", "- Import libraries and set constants\n", "- Define parameters and set output directory\n", "- Load and prepare data\n", "- Define model\n", "- Train and save model\n", "- Visualize TensorBoard logging" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize project space for model\n", "\n", "First, you must initialize a directory for the model and its outputs. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%bash\n", "\n", "mkdir PRO_12-123/model/\n", "mkdir PRO_12-123/model/outputs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import libraries and set constants\n", "\n", "Since we are not using CLIs with built-in import statements for this notebook, it is necessary to include this on our own. 
" ] }, { "cell_type": "code", "execution_count": 274, "metadata": {}, "outputs": [], "source": [ "# todo: remove libs\n", "import datetime\n", "import sys\n", "import os\n", "\n", "from collections import Counter\n", "from typing import List, Optional, Tuple\n", "\n", "sys.path.append(os.pardir)\n", "\n", "import matplotlib.pyplot as plt\n", "import pickle\n", "import numpy as np\n", "import pandas as pd\n", "import pyarrow.dataset as ds\n", "import pyarrow.parquet as pq\n", "\n", "import torch\n", "\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "import torchvision.models as models\n", "\n", "from PIL import Image\n", "from pyarrow import fs\n", "from tensorboard import notebook\n", "from torchvision import transforms\n", "from torch.utils.data import Dataset, DataLoader\n", "from torch.utils.data.sampler import SubsetRandomSampler\n", "\n", "from classifier.model import Classifier\n", "from classifier.data import (\n", " TileDataset,\n", " get_stratified_sampler,\n", " get_group_stratified_sampler,\n", ")\n", "from classifier.utils import set_seed, seed_workers\n", "\n", "# constants\n", "torch.set_default_tensor_type(torch.FloatTensor)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define parameters and set output directories\n", "\n", "Next, you will define parameters that are crucial to your model. You will also Here are some notes to explain the parameters:\n", "\n", "- `batch_size` refers to the number of training examples used in one iteration (number of times the algorithm's parameters are updated); this model runs with mini-batch mode, where the batch size is less than the total dataset size but greater than one. 
\n", "- `validation_split` refers to the percentage of data that will be used to validate the efficacy of the model as opposed to data that will be used to train the model.\n", "- `shuffle_dataset` is set to true in order for the model to randomize the order of samples within batches between epochs \n", "- `random_seed` is the starting point for generating random numbers in a number sequence. Randomness will be used in many different aspects of the machine learning algorithm, such as shuffling the datset between epochs and randomly initializing weights for the parameters. \n", "- `lr` refers to the learning rate of the gradient descent optimizer.\n", "- `epochs` refers to the number of times the algorithm will loop through the entire dataset during training. Setting this variable to 10 specifies that the dataset will be run through 10 times before the model is finished training." ] }, { "cell_type": "code", "execution_count": 275, "metadata": {}, "outputs": [], "source": [ "# define params\n", "batch_size = 64\n", "validation_split = 0.50\n", "shuffle_dataset = False\n", "random_seed = 42 \n", "lr = 1e-3\n", "epochs = 10\n", "n_workers = 20\n", "\n", "# set random seed\n", "set_seed(random_seed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You will also set your output directory. This directory will include the checkpoints from the model as well as outputs that will enable the TensorBoard extension to generate figures documenting the model's progression as it trains on the data." ] }, { "cell_type": "code", "execution_count": 276, "metadata": {}, "outputs": [], "source": [ "# set output directory\n", "output_dir = 'PRO_12-123/model/outputs/'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load and prepare data\n", "\n", "First, we must specify attributes of our data: the label set and path. The regional annotations referenced many properties of each image, and we will train our model on five of the most prevalent properties. 
Note that the `tumor` class must be the first label in the label set (index 0); otherwise, our inference will be incorrect further down the pipeline." ] }, { "cell_type": "code", "execution_count": 277, "metadata": {}, "outputs": [], "source": [ "# define label set\n", "label_set = {\n", "    \"tumor\": 0,\n", "    \"vessels\": 1,\n", "    \"stroma\": 2\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We must also specify the path to our data and load it into a pandas dataframe. We can see the labels present in our data and the number of instances of each label by grouping the dataframe by label." ] }, { "cell_type": "code", "execution_count": 278, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "regional_label\n", "adipocytes 45\n", "arteries 3921\n", "lympho_poor_tumor 3806\n", "lympho_rich_stroma 991\n", "lympho_rich_tumor 5739\n", "veins 194\n", "dtype: int64\n" ] } ], "source": [ "path_to_tile_manifest = 'PRO_12-123/tiles/ov_tileset'\n", "df = pq.ParquetDataset(path_to_tile_manifest).read().to_pandas()\n", "\n", "print(df.groupby(\"regional_label\").size())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make classification simpler and more effective, we will do additional processing of the data. Specifically, we will group `lympho_poor_tumor` and `lympho_rich_tumor` under a generic `tumor` class, group `arteries` and `veins` under a `vessels` class, rename `lympho_rich_stroma` to `stroma`, and remove `adipocytes`, which has relatively few samples, to reduce noise. After doing so, we can check the regional labels again. Note that the two tumor classes have merged into a single, larger `tumor` class and that there are fewer labels in total. Our regional labels now match the `label_set` that we defined earlier." 
] }, { "cell_type": "code", "execution_count": 290, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "regional_label\n", "stroma 991\n", "tumor 9545\n", "vessels 4115\n", "dtype: int64\n" ] } ], "source": [ "df['regional_label'] = df['regional_label'].str.replace('arteries', 'vessels')\n", "df['regional_label'] = df['regional_label'].str.replace('veins', 'vessels')\n", "\n", "df['regional_label'] = df['regional_label'].str.replace('lympho_poor_tumor', 'tumor')\n", "df['regional_label'] = df['regional_label'].str.replace('lympho_rich_tumor', 'tumor')\n", "\n", "df['regional_label'] = df['regional_label'].str.replace('lympho_rich_stroma', 'stroma')\n", "\n", "df = df[df['regional_label'] != 'adipocytes']\n", "\n", "print(df.groupby(\"regional_label\").size())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TileDataset is a class that inherits from the PyTorch Dataset class. It transforms our dataset into a Tensor, specifies the labelset, and defines functions to calculate the length of the dataset and to get an item from the dataset. The main difference between this class and the PyTorch class is the latter function; to retrieve an item from the dataset, TileDataset uses image offset to retrieve a tile from an image." ] }, { "cell_type": "code", "execution_count": 291, "metadata": {}, "outputs": [], "source": [ "dataset_local = TileDataset(df, label_set)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With just a little more processing, we can reset the indices in our dataframe, convert patient IDs to numerical values if need be (in this case, our patient IDs are already numeric, so this line is commented out), and convert regional labels to integers." ] }, { "cell_type": "code", "execution_count": 292, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| \n", " | patient_id | \n", "id_slide_container | \n", "address | \n", "coordinates | \n", "otsu_score | \n", "purple_score | \n", "regional_label | \n", "tile_image_offset | \n", "tile_image_length | \n", "tile_image_size_xy | \n", "tile_image_mode | \n", "data_path | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "4 | \n", "2551028 | \n", "x128_y72_z20 | \n", "(128, 72) | \n", "1.0 | \n", "1.0 | \n", "1 | \n", "4.446781e+08 | \n", "49152.0 | \n", "128.0 | \n", "RGB | \n", "/gpfs/mskmindhdp_emc/user/shared_data_folder/p... | \n", "
| 1 | \n", "4 | \n", "2551028 | \n", "x128_y73_z20 | \n", "(128, 73) | \n", "1.0 | \n", "1.0 | \n", "1 | \n", "4.447273e+08 | \n", "49152.0 | \n", "128.0 | \n", "RGB | \n", "/gpfs/mskmindhdp_emc/user/shared_data_folder/p... | \n", "
| 2 | \n", "4 | \n", "2551028 | \n", "x128_y74_z20 | \n", "(128, 74) | \n", "1.0 | \n", "1.0 | \n", "1 | \n", "4.447764e+08 | \n", "49152.0 | \n", "128.0 | \n", "RGB | \n", "/gpfs/mskmindhdp_emc/user/shared_data_folder/p... | \n", "
| 3 | \n", "4 | \n", "2551028 | \n", "x128_y75_z20 | \n", "(128, 75) | \n", "1.0 | \n", "1.0 | \n", "1 | \n", "4.448256e+08 | \n", "49152.0 | \n", "128.0 | \n", "RGB | \n", "/gpfs/mskmindhdp_emc/user/shared_data_folder/p... | \n", "
| 4 | \n", "4 | \n", "2551028 | \n", "x128_y76_z20 | \n", "(128, 76) | \n", "1.0 | \n", "1.0 | \n", "1 | \n", "4.448748e+08 | \n", "49152.0 | \n", "128.0 | \n", "RGB | \n", "/gpfs/mskmindhdp_emc/user/shared_data_folder/p... | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 14646 | \n", "1 | \n", "2551571 | \n", "x453_y148_z20 | \n", "(453, 148) | \n", "1.0 | \n", "1.0 | \n", "0 | \n", "3.706257e+09 | \n", "49152.0 | \n", "128.0 | \n", "RGB | \n", "/gpfs/mskmindhdp_emc/user/shared_data_folder/p... | \n", "
| 14647 | \n", "1 | \n", "2551571 | \n", "x454_y146_z20 | \n", "(454, 146) | \n", "1.0 | \n", "1.0 | \n", "0 | \n", "3.713237e+09 | \n", "49152.0 | \n", "128.0 | \n", "RGB | \n", "/gpfs/mskmindhdp_emc/user/shared_data_folder/p... | \n", "
| 14648 | \n", "1 | \n", "2551571 | \n", "x454_y147_z20 | \n", "(454, 147) | \n", "1.0 | \n", "1.0 | \n", "0 | \n", "3.713286e+09 | \n", "49152.0 | \n", "128.0 | \n", "RGB | \n", "/gpfs/mskmindhdp_emc/user/shared_data_folder/p... | \n", "
| 14649 | \n", "1 | \n", "2551571 | \n", "x454_y148_z20 | \n", "(454, 148) | \n", "1.0 | \n", "1.0 | \n", "0 | \n", "3.713335e+09 | \n", "49152.0 | \n", "128.0 | \n", "RGB | \n", "/gpfs/mskmindhdp_emc/user/shared_data_folder/p... | \n", "
| 14650 | \n", "1 | \n", "2551571 | \n", "x455_y143_z20 | \n", "(455, 143) | \n", "1.0 | \n", "1.0 | \n", "2 | \n", "3.719725e+09 | \n", "49152.0 | \n", "128.0 | \n", "RGB | \n", "/gpfs/mskmindhdp_emc/user/shared_data_folder/p... | \n", "
14651 rows × 12 columns
\n", "