"""load_classifier
Modified version of [1], rewritten to load tissue classifier models trained by
the MIND team.
[1] https://github.com/msk-mind/data-processing/blob/refactor-cli-fix-tests/data_processing.pathology/models/tissuenet.py
"""
from typing import Callable
import torch
import torchvision
import torchvision.models as models
def get_classifier(
    checkpoint_path: str = "/gpfs/mskmindhdp_emc/user/shared_data_folder/kohlia/tile_classifier/ckpts/1.ckpt",
    n_classes: int = 5,
) -> torch.nn.Module:
    """Load a ResNet-18 tile classifier from a training checkpoint.

    Builds a torchvision ResNet-18 whose final layer has ``n_classes``
    outputs, then loads the weights stored in the checkpoint under
    ``["model_states"]["net"]``.

    Args:
        checkpoint_path (str): path to the model checkpoint file.
        n_classes (int): number of classes used in training; sets the size
            of the final layer of the PyTorch model. It must match the
            checkpoint, or ``load_state_dict`` will reject the weights.

    Returns:
        torch.nn.Module: ResNet-18 with the loaded weights. Checkpoint
        tensors are mapped to CPU while loading. ``.eval()`` is not called
        here, so callers running inference should switch the module to
        evaluation mode themselves.

    Raises:
        KeyError: if the checkpoint has no ``["model_states"]["net"]`` entry.
    """
    model = models.resnet18(num_classes=n_classes)

    # NOTE(review): torch.load unpickles the file, so only point this at
    # trusted checkpoints. weights_only is left at torch's default to avoid
    # breaking checkpoints that store non-tensor objects.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # The network's state dict is expected to be nested under
    # model_states -> net. Re-raise with context so a mismatched checkpoint
    # layout is easy to diagnose (still a KeyError, as before).
    try:
        state_dict = checkpoint["model_states"]["net"]
    except KeyError as err:
        raise KeyError(
            f"Checkpoint {checkpoint_path!r} has no ['model_states']['net'] "
            f"entry (missing key: {err})"
        ) from err

    model.load_state_dict(state_dict)
    return model