Training API Flow

Presented here is the rough flow of data necessary to train a SKOOTS segmentation model. This differs markedly from evaluation, which is presented elsewhere.

[Figure: SKOOTS training API flow diagram]
skoots.lib.embedding_to_prob.baked_embed_to_prob(embedding, baked_skeletons, sigma, eps=1e-16)

N-dimensional embedding to probability with a baked skeleton array.

Calculates a probability \(\phi\) based on the Euclidean distance between a spatial embedding \(E_i\) and a baked skeleton pixel \(S_i\).

\[\phi(E_i, S_i) = \exp\left(\sum_{k \in [x,y,z]} \frac{(E_{ki} - S_{ki})^2}{-2\sigma^2_k} \right)\]

In three spatial dimensions, this expands to

\[\phi(E_i, S_i) = \exp\left(\frac{(E_{xi} - S_{xi})^2}{-2\sigma^2_x} + \frac{(E_{yi} - S_{yi})^2}{-2\sigma^2_y} + \frac{(E_{zi} - S_{zi})^2}{-2\sigma^2_z}\right)\]
Shapes:
  • embedding: \((B_{in}, 2, X_{in}, Y_{in})\) or \((B_{in}, 3, X_{in}, Y_{in}, Z_{in})\)

  • baked_skeletons: \((B_{in}, 2, X_{in}, Y_{in})\) or \((B_{in}, 3, X_{in}, Y_{in}, Z_{in})\)

  • sigma: \((2)\) or \((3)\)

  • returns: \((B_{in}, 1, X_{in}, Y_{in})\) or \((B_{in}, 1, X_{in}, Y_{in}, Z_{in})\)

Parameters:
  • embedding (Tensor) – embedding tensor

  • baked_skeletons (Tensor) – a baked skeleton tensor

  • sigma (Tensor) – Standard deviation of the Gaussian, one value per spatial dimension. Larger values give higher probability farther from the skeleton.

  • eps (float) – small value for numerical stability

Return type:

Tensor

Returns:

Probability matrix
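
As a worked example of the formula above, the following plain-Python sketch evaluates \(\phi\) for a single pixel. The helper name embed_to_prob_single is hypothetical and is not part of skoots; the library function operates on whole tensors and also takes an eps term, which is omitted here.

```python
import math


def embed_to_prob_single(embedding, skeleton, sigma):
    """Evaluate phi for one pixel (hypothetical helper, not the skoots API)."""
    # Sum the per-axis squared offsets, each divided by -2 * sigma_k^2.
    exponent = sum(
        (e - s) ** 2 / (-2.0 * sig ** 2)
        for e, s, sig in zip(embedding, skeleton, sigma)
    )
    return math.exp(exponent)


sigma = (3.0, 3.0, 3.0)

# An embedding that lands exactly on its skeleton pixel gives a probability of 1.
on_target = embed_to_prob_single((10.0, 20.0, 5.0), (10.0, 20.0, 5.0), sigma)

# An embedding one sigma away along x gives exp(-0.5).
one_sigma = embed_to_prob_single((13.0, 20.0, 5.0), (10.0, 20.0, 5.0), sigma)
```

Increasing sigma along an axis flattens the Gaussian along that axis, so embeddings farther from the skeleton still receive a high probability.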

skoots.lib.skeleton.bake_skeleton(masks, skeletons, anisotropy=(1.0, 1.0, 1.0), average=True, device='cpu', return_distance=False)

For each pixel \(p_{ik}\) of object \(k\) at index \(i\in[x,y,z]\) in masks, returns a baked skeleton where the value at each index is the closest skeleton point \(s_{jk}\) of any instance \(k\).

The anisotropy argument should reflect the actual spatial distances of your dataset for best results. These models tend to like XY embedding vectors more than Z. For anisotropic datasets, you should provide the approximate anisotropic correction factor of each voxel. For instance, an anisotropy of (1.0, 1.0, 5.0) means that the Z dimension is 5x larger than XY.

Formally, the value at each position \(i\in[x,y,z]\) of the baked skeleton tensor \(S\) is the skeleton point \(s_k\) that minimizes the Euclidean distance function \(f(a, b)\) over the skeleton points of any instance:

\[S_{i} = \underset{s_k}{\arg\min}\ f(i, s_{k})\ \text{for}\ k \in [1, 2, ..., N]\]
Shapes:
  • masks: \((1, X_{in}, Y_{in}, Z_{in})\)

  • skeletons: \((3, N_i)\)

  • anisotropy: \((3)\)

  • returns: \((3, X_{in}, Y_{in}, Z_{in})\)

Parameters:
  • masks (Tensor) – Ground Truth instance mask of shape [1, X, Y, Z] of objects where each pixel is an integer id value.

  • skeletons (Dict[int, Tensor]) – Dict of skeleton indices where each key is a unique instance of an object in masks. Each skeleton has a shape [3, N], where N is the number of pixels constituting the skeleton

  • anisotropy (Tuple[float, float, float]) – Anisotropic correction factor for min distance calculation

  • average (bool) – Average the skeletons such that there is a smooth transition from one area to the next

  • device (str) – torch.device on which to run calculations

  • return_distance (bool) – If True and bake_skeleton is dispatching the Triton kernel, also returns the distance to each closest skeleton point

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]

Returns:

Baked skeleton
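
To illustrate how the anisotropy correction can change which skeleton point is closest, here is a brute-force sketch for a single pixel. It assumes the correction is applied by scaling each axis offset by its anisotropy factor before taking the Euclidean distance; the helper closest_skeleton_point is hypothetical and does not reflect how skoots implements baking.

```python
import math


def closest_skeleton_point(position, skeleton_points, anisotropy=(1.0, 1.0, 1.0)):
    """Return the skeleton point nearest to `position` (hypothetical helper)."""

    def scaled_distance(point):
        # Assumption: each axis offset is multiplied by its anisotropy factor.
        return math.sqrt(
            sum(((p - q) * a) ** 2 for p, q, a in zip(position, point, anisotropy))
        )

    return min(skeleton_points, key=scaled_distance)


pixel = (0, 0, 0)
skeleton = [(0, 0, 3), (4, 0, 0)]

# Isotropic voxels: the point 3 voxels away in Z is closer than the one 4 away in X.
isotropic = closest_skeleton_point(pixel, skeleton)

# With Z voxels 5x larger, the Z offset counts as 15 units, so the X point wins.
anisotropic = closest_skeleton_point(pixel, skeleton, anisotropy=(1.0, 1.0, 5.0))
```

The same pixel is assigned to a different skeleton point once the voxel size is taken into account, which is why the correction factor should roughly match the real spacing of the data.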

skoots.lib.flood_fill.efficient_flood_fill(skeleton)

Efficiently floods a binary skeleton mask in place by first flood filling small regions, then merging connected components later. Avoids memory copies when possible. Returns a skeleton mask where each connected component has a unique label; however, these labels may not be sequential, e.g. unique(skeleton) -> [4, 16, 23, 24, 96]

Parameters:

skeleton (Tensor) – binary skeleton mask to flood fill

Return type:

Tensor

Returns:

Flood filled tensor
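
Because the returned labels may not be sequential, downstream code that expects ids 1..K may want to remap them. The sketch below shows one way to do that on a small nested-list mask; relabel_sequential is a hypothetical helper, not part of skoots, and it assumes that 0 marks background.

```python
def relabel_sequential(labels):
    """Map arbitrary positive labels to 1..K, keeping 0 as background (hypothetical helper)."""
    # Collect the distinct non-background labels in ascending order.
    unique = sorted({value for row in labels for value in row if value != 0})
    mapping = {old: new for new, old in enumerate(unique, start=1)}
    mapping[0] = 0  # assumption: 0 is background and is left unchanged
    return [[mapping[value] for value in row] for row in labels]


flooded = [
    [4, 4, 0, 16],
    [0, 0, 0, 16],
    [23, 0, 96, 96],
]
sequential = relabel_sequential(flooded)
```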

skoots.lib.vector_to_embedding.vector_to_embedding(scale, vector, N=1, decay=1.0)

Converts a 2D or 3D vector field to a spatial embedding by adding the vector at any position to its own position.

vector is a 2D or 3D vector field of shape \((B, 2, X, Y)\) for 2D or \((B, 3, X, Y, Z)\) for 3D. Each vector \(v\) lies in the range \([-1, 1]\) and is scaled by scale \(s\). The scaled vector is then added to its own position to form a spatial embedding \(\phi\).

Formally:
\[ \begin{aligned} i, j, k &\in \mathbb{Z}_{\geq 0} \\ v_{i,j,k} &\in [-1, 1] \\ s &= [s_i, s_j, s_k] \\ \phi_{i,j,k} &= v_{i,j,k} * s + [i, j, k] \end{aligned} \]
Shapes:
  • scale: \((2)\) or \((3)\)

  • vector: \((B_{in}, 2, X_{in}, Y_{in})\) or \((B_{in}, 3, X_{in}, Y_{in}, Z_{in})\)

  • Returns: \((B_{in}, 2, X_{in}, Y_{in})\) or \((B_{in}, 3, X_{in}, Y_{in}, Z_{in})\)

Parameters:
  • scale (Tensor) – Scaling factors for each vector spatial dimension

  • vector (Tensor) – Vector field predicted by a neural network

  • N (int) – Number of iterations to apply the vectors.

  • decay (float) – vector strength decay after each iteration. Default 1.0

Return type:

Tensor

Returns:

Pixel spatial embeddings
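
The formula above can be checked by hand for a single voxel. This sketch covers only one application of the vector (the N=1 case); how the library applies N iterations and decay is not shown here, and vector_to_embedding_single is a hypothetical name.

```python
def vector_to_embedding_single(scale, vector, position):
    """phi = v * s + [i, j, k] for one voxel (hypothetical helper, N=1 only)."""
    return tuple(v * s + p for v, s, p in zip(vector, scale, position))


scale = (10.0, 10.0, 2.0)  # per-axis scaling factors s
vector = (0.5, -1.0, 0.0)  # predicted vector v, each component in [-1, 1]
position = (20, 30, 4)     # voxel index [i, j, k]

# x: 0.5 * 10 + 20 = 25, y: -1 * 10 + 30 = 20, z: 0 * 2 + 4 = 4
phi = vector_to_embedding_single(scale, vector, position)
```

Since every component of \(v\) lies in \([-1, 1]\), scale bounds how far an embedding can point away from its own voxel along each axis.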

class skoots.train.loss.tversky(alpha, beta, eps)

Computes the Tversky loss between two torch.Tensors. The Tversky index generalizes the Dice index by weighting false positives and false negatives separately.

Parameters:
  • alpha (float) – Weight that penalizes false positives

  • beta (float) – Weight that penalizes false negatives

  • eps (float) – Numerical stability term

static _tversky(pred, gt, alpha, beta, eps=1e-08)

Tversky loss on a per-image basis.

Parameters:
  • pred – [N, X, Y, Z] Tensor of predicted segmentation masks (N instances)

  • gt – [N, X, Y, Z] Tensor of ground truth segmentation masks (N instances)

  • alpha – Penalty to false positives

  • beta – Penalty to false negatives

  • eps – Stability parameter

Returns:

forward(predicted, ground_truth)

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
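
To make the roles of alpha and beta concrete, here is a sketch of the standard Tversky index, TP / (TP + alpha * FP + beta * FN), computed on flat lists of probabilities. This is the textbook formulation and is only assumed to match what the tversky class computes; a loss is commonly taken as one minus this index. The function name tversky_index is hypothetical.

```python
def tversky_index(pred, gt, alpha, beta, eps=1e-8):
    """Textbook Tversky index on flat sequences (hypothetical helper, not the skoots class)."""
    tp = sum(p * g for p, g in zip(pred, gt))        # true positives
    fp = sum(p * (1 - g) for p, g in zip(pred, gt))  # false positives
    fn = sum((1 - p) * g for p, g in zip(pred, gt))  # false negatives
    return (tp + eps) / (tp + alpha * fp + beta * fn + eps)


pred = [1.0, 1.0, 1.0, 1.0]  # over-segmented prediction: TP = 2, FP = 2, FN = 0
gt = [1.0, 1.0, 0.0, 0.0]

# alpha = beta = 0.5 reduces to the Dice index: 2 / (2 + 1) = 2/3.
dice_like = tversky_index(pred, gt, alpha=0.5, beta=0.5)

# Raising alpha punishes the two false positives harder: 2 / (2 + 2) = 0.5.
fp_heavy = tversky_index(pred, gt, alpha=1.0, beta=0.5)
```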

class skoots.train.dataloader.dataset(path, transforms=<function dataset.<lambda>>, pad_size=100, device='cpu', sample_per_image=1)

Custom dataset for loading and accessing skoots training data. This class loads data based on filenames and specific extensions: ‘.tif’ (raw image), ‘.labels.tif’ (instance masks), ‘.skeletons.tif’ (precomputed skeletons). An example training data folder might contain the following:

data\
 └  train\
      │ train_data.tif
      │ train_data.labels.tif
      └ train_data.skeletons.tif
Parameters:
  • path (Union[List[str], str]) – Path to training data

  • transforms (Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]]) – A function which applies dataset augmentation on a data_dict

  • pad_size (Optional[int]) – padding to add to every image in the dataset

  • device (Optional[str]) – torch.device on which to output all data

  • sample_per_image (Optional[int]) – number of times each image/mask pair is sampled per iteration over a dataset
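
Before constructing a dataset, it can help to confirm that a folder follows the naming convention described above. The sketch below groups files by the three extensions using only the standard library; find_training_triplets is a hypothetical helper, and the way skoots itself discovers files is not documented here.

```python
import tempfile
from pathlib import Path


def find_training_triplets(folder):
    """Pair each raw image with its label and skeleton files (hypothetical helper)."""
    triplets = []
    for labels in sorted(Path(folder).glob("*.labels.tif")):
        # Strip the '.labels.tif' suffix to recover the shared base name.
        stem = labels.name[: -len(".labels.tif")]
        image = labels.with_name(stem + ".tif")
        skeletons = labels.with_name(stem + ".skeletons.tif")
        if image.exists() and skeletons.exists():
            triplets.append((image.name, labels.name, skeletons.name))
    return triplets


# Build a throwaway folder that mirrors the example layout.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("train_data.tif", "train_data.labels.tif", "train_data.skeletons.tif"):
        (Path(tmp) / name).touch()
    found = find_training_triplets(tmp)
```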

class skoots.train.dataloader.MultiDataset(*args)

A utility class for joining multiple datasets into one accessible class. Sometimes, you may subdivide your training data based on some criteria. The most common is size: data from folder data/train/train_alot must be sampled 100 times per epoch, while data from folder data/train/train_notsomuch might only need to be sampled once per epoch.

You could construct two skoots.train.dataloader.dataset objects, one for each folder, and access both through a single MultiDataset:

>>> from skoots.train.dataloader import MultiDataset, dataset
>>>
>>> # has one image sampled 100 times
>>> data0 = dataset('data/train/train_alot', sample_per_image=100)
>>> print(len(data0))  # 100
>>>
>>> # has one image sampled once
>>> data1 = dataset('data/train/train_notsomuch', sample_per_image=1)
>>> print(len(data1))  # 1
>>>
>>> merged_data = MultiDataset(data0, data1)
>>> print(len(merged_data))  # 101, they've been merged!
Parameters:

args – the dataset objects to join
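
The merging behaviour in the example above (lengths add, indices run through each dataset in turn) can be sketched in a few lines. The class below is illustrative only and makes no claim about how MultiDataset is implemented; plain lists stand in for dataset objects.

```python
class ConcatSketch:
    """Joins several sized, indexable datasets end to end (illustrative only)."""

    def __init__(self, *datasets):
        self.datasets = list(datasets)

    def __len__(self):
        # The merged length is the sum of the individual lengths.
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, index):
        if index < 0 or index >= len(self):
            raise IndexError(index)
        # Walk the datasets in order, shifting the index past each one.
        for d in self.datasets:
            if index < len(d):
                return d[index]
            index -= len(d)


data0 = ["image_a"] * 100  # stands in for one image sampled 100 times
data1 = ["image_b"]        # stands in for one image sampled once
merged = ConcatSketch(data0, data1)
```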

skoots.train.dataloader.skeleton_colate(data_dict)

Collate function which defines how we batch training data. Unpacks a data_dict with keys ‘image’, ‘masks’, ‘skele_masks’, ‘baked_skeleton’, and ‘skeleton’ and puts each into a Tensor. This should not be called directly, but rather passed to a torch.utils.data.DataLoader as its collate_fn for automatic batching.

Parameters:

data_dict (List[Dict[str, Tensor]]) – List of dictionaries of augmented training data, one per sample

Return type:

Tuple[Tensor, Tensor, List[Dict[str, Tensor]], Tensor, Tensor]

Returns:

Tuple of batched data
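
As a rough picture of what a collate function of this kind does, the sketch below regroups a list of per-sample dictionaries into per-key batches. Plain Python lists stand in for Tensors, and the order of the returned tuple is an assumption inferred from the return type, not taken from the skoots source. The per-sample skeleton entries are kept as a list rather than stacked, since each one may hold a different number of points.

```python
def collate_sketch(samples):
    """Regroup per-sample dicts into per-key batches (illustrative; lists stand in for Tensors)."""
    images = [s["image"] for s in samples]
    masks = [s["masks"] for s in samples]
    skele_masks = [s["skele_masks"] for s in samples]
    baked = [s["baked_skeleton"] for s in samples]
    # Skeletons vary in size between samples, so they stay as one entry per sample.
    skeletons = [s["skeleton"] for s in samples]
    # Assumed order only; check the library for the real one.
    return images, masks, skeletons, skele_masks, baked


samples = [
    {"image": [0.1], "masks": [1], "skele_masks": [1], "baked_skeleton": [0.0], "skeleton": {1: [[0], [0], [0]]}},
    {"image": [0.2], "masks": [2], "skele_masks": [0], "baked_skeleton": [1.0], "skeleton": {2: [[1, 2], [1, 2], [0, 0]]}},
]
images, masks, skeletons, skele_masks, baked = collate_sketch(samples)
```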