Detailed Inference Protocol

Imports

Inference is handled by the `eval()` function in `skoots.lib.eval`. We first need to import all required libraries, notably zarr, fastremap, numpy, and torch.

import logging
import os.path
import tracemalloc
import warnings
import time

import fastremap
import zarr

import numpy as np
import skimage.io as io

import torch
import torch._dynamo
import torch.nn as nn
from torch import Tensor
from torch.cuda.amp import autocast
from tqdm import tqdm
from yacs.config import CfgNode

import skoots.lib.skeleton
from skoots.lib.cropper import crops
from skoots.lib.flood_fill import efficient_flood_fill
from skoots.lib.morphology import binary_dilation, binary_dilation_2d
from skoots.lib.utils import cfg_to_bism_model
from skoots.lib.vector_to_embedding import vector_to_embedding

# Silence ALL Python warnings for the rest of the process. With no category,
# message, or module filter this is global: it also hides warnings raised by
# torch, zarr, etc., not only ones originating in this file.
# NOTE(review): consider scoping this with `warnings.catch_warnings()` around
# inference instead of a module-level side effect.
warnings.filterwarnings("ignore")

Load the image and model

SKOOTS performs best when evaluated with the same core parameters used during training. Therefore, we package these parameters with the model file so that they never need to be remembered. They are stored as a YACS config node (`CfgNode`) under the checkpoint's `"cfg"` key. We therefore load this config, the model, and the input image.

@torch.inference_mode()  # disables autograd and reference counting for SPEED
def eval(
    image_path: str,
    checkpoint_path: str = "/home/chris/Dropbox (Partners HealthCare)/trainMitochondriaSegmentation/models/Oct21_17-15-08_CHRISUBUNTU.trch",
) -> None:
    tracemalloc.start()
    start = time.time()

    torch._dynamo.config.log_level = logging.ERROR
    logging.info(f"Loading model file: {checkpoint_path}")

    checkpoint = torch.load(checkpoint_path)
    if "cfg" in checkpoint:
        cfg: CfgNode = checkpoint["cfg"]
    else:
        raise RuntimeError("Attempting to evaluate skoots on a legacy model file.")

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    filename_without_extensions = os.path.splitext(image_path)[0]

    # make sure the image is 5 channels, in [C, X, Y, Z] shape, and normalized between 0 and 1
    logging.info(f"Loading image from file: {image_path}")
    image: np.array = io.imread(image_path)  # [Z, X, Y, C]
    image: np.array = image[..., np.newaxis] if image.ndim == 3 else image
    image: np.array = image.transpose(-1, 1, 2, 0)
    image: np.array = image[[2], ...] if image.shape[0] > 3 else image  # [C=1, X, Y, Z]

    scale: int = 2**16 if image.dtype == np.uint16 else 2**8
    image: Tensor = torch.from_numpy(image).pin_memory()

    vector_scale = torch.tensor(cfg.SKOOTS.VECTOR_SCALING)

    # we use bism for constructing the models.
    logging.info(f"Constructing SKOOTS model")
    base_model: nn.Module = cfg_to_bism_model(cfg)  # This is our skoots torch model
    base_model.load_state_dict(state_dict=checkpoint["model_state_dict"])
    base_model = base_model.to(device).train()

    logging.info(f"Compiling SKOOTS model with torch inductor")
    model = torch.compile(base_model)
    for _ in range(10):  # warmup torchinductor
        _ = model(torch.rand((1, 1, 300, 300, 20), device=device, dtype=torch.float))

Preallocate intermediary arrays

For inference, SKOOTS needs to keep track of the skeleton and the embedding vectors for the whole image. We pre-allocate both buffers here.

    c, x, y, z = image.shape
    skeleton = torch.zeros(size=(1, x, y, z), dtype=torch.int16)
    vectors = torch.zeros((3, x, y, z), dtype=torch.half)

Iterative Evaluation

The deep neural network model likely cannot process the entire image at once. Instead, we crop the image using the cropping utility `skoots.lib.cropper.crops`. This simply creates a generator that yields each crop together with its (x, y, z) index in the image. I have been tempted to increase the crop size, but doing so somehow degrades performance. I’d keep it at the default.

    cropsize = [300, 300, 20]  # DEFAULT (300, 300, 20), If you change it, the model might screw up!!!
    overlap = [10, 10, 5]

    total = skoots.lib.cropper.get_total_num_crops(image.shape, cropsize, overlap)
    iterator = tqdm(
        crops(image, cropsize, overlap, device=device), desc="", total=total
    )
    benchmark_start = time.time()

We can now loop over the crops and evaluate the model on each one. To reduce the amount of storage, we only keep the vectors and skeletons which are likely inside an object, as defined by the probability map (probability > 0.5).

    for crop, (x, y, z) in iterator:
        with autocast(enabled=True):  # Saves Memory!
            out = model(crop.div(scale).float().cuda())

        probability_map = out[:, [-1], ...]
        skeleton_map = out[:, [-2], ...].float()
        vec = out[:, 0:3:1, ...]

        vec = vec * probability_map.gt(0.5)
        skeleton_map = skeleton_map * probability_map.gt(0.5)

We have found that performing binary dilation on the skeletons, in both 3D and 2D, helps overall accuracy, because the predicted skeletons can otherwise get too thin.

        for _ in range(
            1
        ):  # expand the skeletons in x/y/z. Only  because they can get too skinny
            skeleton_map = binary_dilation(skeleton_map)
            for _ in range(3):  # expand 2 times just in x/y
                skeleton_map = binary_dilation_2d(skeleton_map)

We now copy each crop's predictions into the preallocated buffers, discarding the overlapping border of the crop.

        # put the predictions into the preallocated tensors...
        _destination = (
            ...,
            slice(x + overlap[0], x + cropsize[0] - overlap[0]),
            slice(y + overlap[1], y + cropsize[1] - overlap[1]),
            slice(z + overlap[2], z + cropsize[2] - overlap[2]),
        )

        _source = (
            0,
            ...,
            slice(overlap[0], -overlap[0]),
            slice(overlap[1], -overlap[1]),
            slice(overlap[2], -overlap[2]),
        )

        skeleton[_destination] = skeleton_map[_source].gt(0.8).cpu()
        vectors[_destination] = vec[_source].half().cpu()

        iterator.desc = f"Evaluating UNet on slice [x{x}:y{y}:z{z}]"

Assigning IDs

Once we have stored the entire skeleton and vector fields, we must assign an ID to each skeleton. This is done via flood fill from `skoots.lib.flood_fill`. It’s not very efficient.

    skeleton: Tensor = efficient_flood_fill(skeleton)

Generate Instance Masks

From the labeled skeletons, we can obtain instance masks using the embeddings. We again do this in crops, using functions similar to those used in training.

    cropsize = [500, 500, 50]
    overlap = (50, 50, 5)
    iterator = tqdm(
        crops(vectors, crop_size=cropsize, overlap=overlap), desc="Assigning Instances:"
    )

    instance_mask = torch.zeros_like(skeleton, dtype=torch.int16)
    skeleton = skeleton.unsqueeze(0).unsqueeze(0)

    logging.info(f"Identifying connected components...")
    for _vec, (x, y, z) in iterator:
        _destination = (
            slice(x + overlap[0], x + cropsize[0] - overlap[0]),
            slice(y + overlap[1], y + cropsize[1] - overlap[1]),
            slice(z + overlap[2], z + cropsize[2] - overlap[2]),
        )

        _source = (
            slice(overlap[0], -overlap[0]),
            slice(overlap[1], -overlap[1]),
            slice(overlap[2], -overlap[2]),
        )

        _embed = skoots.lib.vector_to_embedding.vector_to_embedding(
            scale=vector_scale, vector=_vec, N=2
        )
        _embed += torch.tensor((x, y, z)).view(
            1, 3, 1, 1, 1
        )  # We adjust embedding to region of the crop

Unlike during training, we don’t know where these embeddings are supposed to point. Instead, we trust that they point into a skeleton. Therefore, to assign an instance label to a voxel, we follow its embedding to the labeled skeleton it points at and take that skeleton’s label. This is simply an indexing operation, and is therefore fast.

        # This gives the instance mask!
        _inst_maks = skoots.lib.skeleton.index_skeleton_by_embed(
            skeleton=skeleton, embed=_embed
        ).squeeze()

        # Plop it back into the pre-allocated array
        instance_mask[_destination] = (
            _inst_maks[_source] if torch.tensor(overlap).gt(0).all() else _inst_maks
        )