import numbers

import pandas as pd
import torch
from PIL import Image
from torch import nn
from torch.utils.data import Dataset
class BaseTorchTileDataset(Dataset):
    """Base class for a tile dataset.

    Implements the usual torch dataset methods (``__len__`` and ``__getitem__``)
    and additionally decodes the raw binary tile data referenced by the tile
    manifest into PIL images. Each decoded image is handed to the abstract
    ``preprocess`` method, which subclasses implement to turn the PIL image
    into a torch tensor.

    The dataset itself performs no device placement; tensors are returned
    exactly as ``preprocess`` produced them.
    """

    def __init__(self, tile_manifest=None, tile_path=None, label_cols=None, **kwargs):
        """Initialize BaseTileDataset.

        Accepts either a tile dataframe or a path to tile data. When both are
        given, ``tile_manifest`` takes precedence and ``tile_path`` is ignored.

        Args:
            tile_manifest (pd.DataFrame): Dataframe of tile data.
            tile_path (str): Base path of tile data. It is joined to the file
                names by plain string concatenation, so it must already end
                with a path separator.
            label_cols (list[str]): (Optional) label columns to return as
                tensors, e.g. for training. ``None`` (the default) means no
                label columns.
            **kwargs: Forwarded unchanged to ``setup()``.

        Raises:
            RuntimeError: If neither ``tile_manifest`` nor ``tile_path`` is given.
        """
        if tile_manifest is not None:
            self.tile_manifest = tile_manifest
        elif tile_path is not None:
            self.tile_manifest = pd.read_csv(tile_path + 'address.slice.csv').set_index("address")
            # Every tile of a path-based manifest lives in the same binary file.
            self.tile_manifest['data_path'] = tile_path + 'tiles.slice.pil'
        else:
            raise RuntimeError("Must specifiy either tile_manifest or tile_path")

        # A fresh list per instance: a ``[]`` default in the signature would be
        # one shared object, mutated by every instance that touches it.
        self.label_cols = [] if label_cols is None else label_cols
        self.setup(**kwargs)

    def __len__(self):
        """Return the number of rows (tiles) in the tile manifest."""
        return len(self.tile_manifest)

    def __repr__(self):
        """Return a short summary: tile count, index names and label columns."""
        return f"TileDataset with {len(self.tile_manifest)} tiles, indexed by {self.tile_manifest.index.names}, returning label columns: {self.label_cols}"

    def __getitem__(self, idx: int):
        """Tile accessor.

        Looks up row ``idx`` of the tile manifest by position, reads that
        tile's bytes from the file named in the row's ``data_path`` column and
        decodes them into a square PIL image, which is then passed through
        ``preprocess``. The row must provide ``data_path``,
        ``tile_image_offset``, ``tile_image_length``, ``tile_image_mode`` and
        ``tile_image_size_xy``.

        Args:
            idx (int): Integer position of the tile in the manifest. Any
                integral type (e.g. ``numpy.int64``) is accepted; ``bool`` is not.

        Returns:
            tuple: ``(row.name, preprocess(img))`` where ``row.name`` is the
            manifest index label of the tile. If label columns were specified,
            a third element holds a tensor built from those columns of the row.

        Raises:
            TypeError: If ``idx`` is not an integer.
        """
        # bool is registered as Integral, but it was never a valid index here.
        if isinstance(idx, bool) or not isinstance(idx, numbers.Integral):
            raise TypeError(f"BaseTileDataset only accepts interger indicies, got {type(idx)}")

        row = self.tile_manifest.iloc[int(idx)]

        with open(row.data_path, "rb") as fp:
            fp.seek(int(row.tile_image_offset))
            raw_bytes = fp.read(int(row.tile_image_length))

        # Tiles are square: the single size column is used for both dimensions.
        side = int(row.tile_image_size_xy)
        img = Image.frombytes(row.tile_image_mode, (side, side), raw_bytes)
        tile = self.preprocess(img)

        # len() rather than truthiness so array-like label_cols also work.
        if len(self.label_cols):
            return row.name, tile, torch.tensor(row[self.label_cols].to_list())
        return row.name, tile

    def setup(self, **kwargs):
        """Set additional attributes for dataset class.

        Template/abstract method where a dataset is configured. Called at the
        end of ``__init__`` with the extra keyword arguments it received.

        Args:
            kwargs: Keyword arguments passed on to the subclass method.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("setup() has not be implimented in the subclass!")

    def preprocess(self, input_tile: "Image.Image"):
        """Preprocessing method called for each tile patch.

        Must be implemented by the subclass to accept a single PIL image and
        return a torch tensor.

        Args:
            input_tile (Image.Image): Decoded tile image.

        Returns:
            torch.Tensor: Output tile as preprocessed tensor.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("preprocess() has not be implimented in the subclass!")
class BaseTorchTileClassifier(nn.Module):
    """Base class for a torch tile classifier used for inference.

    Subclasses configure their modules in ``setup()`` and implement the actual
    tensor-to-tensor computation in ``predict()``. ``forward()`` wraps
    ``predict()``: it runs without gradient tracking and returns the result as
    a pandas DataFrame indexed by the tile addresses.
    """

    def __init__(self, **kwargs):
        """Initialize BaseTorchTileClassifier.

        Calls ``setup()``, switches the module to eval mode and, if CUDA is
        available at construction time, moves the module with ``.cuda()``.

        Args:
            kwargs: Keyword arguments passed on to the subclass ``setup()`` method.
        """
        super().__init__()
        # Decided once here; forward() reuses this flag instead of re-querying.
        self.cuda_is_available = torch.cuda.is_available()
        self.setup(**kwargs)
        self.eval()
        if self.cuda_is_available:
            self.cuda()

    def forward(self, index, tile_data):
        """Forward pass for base classifier class.

        Moves the input to CUDA when it was available at construction time,
        runs ``predict()`` under ``torch.no_grad()`` and converts the result
        to a DataFrame.

        Args:
            index (list[str]): Tile address indices with length B.
            tile_data (torch.Tensor): Input tiles of shape (B, *).

        Returns:
            pd.DataFrame: Dataframe of output features, indexed by ``index``.
        """
        if self.cuda_is_available:
            tile_data = tile_data.cuda()
        with torch.no_grad():
            # Back to host memory so the output can be held by a DataFrame.
            features = self.predict(tile_data).cpu().numpy()
            return pd.DataFrame(features, index=index)

    def setup(self, **kwargs):
        """Set classifier modules.

        Template/abstract method where individual modules that make up the
        forward pass are configured. Called from ``__init__`` with the keyword
        arguments it received.

        Args:
            kwargs: Keyword arguments passed on to the subclass method.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("setup() has not be implimented in the subclass!")

    def predict(self, input_tiles: torch.Tensor):
        """Predict method.

        Must be implemented by the subclass to pass the input tensor through
        the modules specified in ``setup()``.

        Args:
            input_tiles (torch.Tensor): Input tiles of shape (B, *).

        Returns:
            torch.Tensor: 2D tensor with (B, C) where B is the batch dimension
            and C are output classes or features.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("predict() has not be implimented in the subclass!")