# Source code for skoots.train.loss

from typing import List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from skoots.lib.morphology import binary_erosion
from skoots.lib.utils import crop_to_identical_size


def tversky_graphable(pred, gt, alpha, beta):
    """Tversky loss over whole tensors, without per-instance mask expansion.

    The Tversky index is ``TP / (TP + alpha * FP + beta * FN)`` (plus tiny
    stabilising constants); this function returns ``1 - index``.

    Args:
        pred: predicted probabilities.
        gt: ground-truth mask. Presumably binary {0, 1} -- TODO confirm with
            callers. False positives are counted where ``logical_not(gt)`` is
            true.
        alpha: weight applied to false positives.
        beta: weight applied to false negatives.

    Returns:
        Scalar Tensor holding the Tversky loss.
    """
    tiny = 1e-10
    background = torch.logical_not(gt)

    true_pos = (pred * gt).sum()
    # The constant is added before weighting, exactly as the original did.
    false_pos = ((background * pred).sum() + tiny) * alpha
    false_neg = beta * ((1 - pred) * gt).sum()

    index = (true_pos + tiny) / (true_pos + false_pos + false_neg + tiny)
    return 1 - index
class jaccard(nn.Module):
    """Soft Jaccard (intersection-over-union) loss."""

    def __init__(self):
        super().__init__()

    def forward(
        self, predicted: torch.Tensor, ground_truth: torch.Tensor, eps: float = 1e-10
    ) -> torch.Tensor:
        """
        Returns the soft Jaccard loss ``1 - (|P∩G| + eps) / (|P∪G| + eps)`` of two torch.Tensors.

        :param predicted: [B, I, X, Y, Z] torch.Tensor - probabilities calculated from
            hcat.utils.embedding_to_probability where B: is batch size, I: instances in image
        :param ground_truth: [B, I, X, Y, Z] torch.Tensor - segmentation mask for each instance (I).
        :param eps: float - Very small number to ensure numerical stability. Default 1e-10
        :return: jaccard_loss: [1] torch.Tensor - Result of Loss Function Calculation
        """
        # Crop both tensors to the same shape
        predicted, ground_truth = crop_to_identical_size(predicted, ground_truth)

        intersection = (predicted * ground_truth).sum()
        union = (predicted + ground_truth).sum() - intersection

        # eps is added symmetrically so that two empty masks give 1 - eps/eps = 0.
        # Previously the eps folded into the intersection cancelled the union's eps,
        # leaving eps / 0 = inf (loss of -inf).
        return 1.0 - (intersection + eps) / (union + eps)
def _dice(pred: Tensor, ground_truth: Tensor, eps: float):
    """Soft Dice loss ``1 - (2 * |P∩G| + eps) / (|P| + |G| + eps)``.

    :param pred: predicted probabilities
    :param ground_truth: ground-truth mask, same (broadcastable) shape as ``pred``
    :param eps: small constant for numerical stability
    :return: scalar Tensor, 0 for a perfect match and close to 1 for no overlap
    """
    intersection = (pred * ground_truth).sum()
    denominator = (pred + ground_truth).sum()
    # eps is added once to numerator and denominator. Adding it to the intersection
    # before doubling made empty-vs-empty evaluate to 2 (a loss of -1).
    coefficient = (2 * intersection + eps) / (denominator + eps)
    return 1 - coefficient
class dice(nn.Module):
    """Soft Dice loss module: crops inputs to a common shape, then calls ``_dice``."""

    def __init__(self):
        super().__init__()

    def forward(
        self, predicted: torch.Tensor, ground_truth: torch.Tensor, eps: float = 1e-10
    ) -> torch.Tensor:
        """
        Returns the soft Dice loss (1 - Dice coefficient) of two torch.Tensors.

        :param predicted: [B, I, X, Y, Z] torch.Tensor - probabilities calculated from
            hcat.utils.embedding_to_probability where B: is batch size, I: instances in image
        :param ground_truth: [B, I, X, Y, Z] torch.Tensor - segmentation mask for each instance (I).
        :param eps: float - Very small number to ensure numerical stability. Default 1e-10
        :return: dice_loss: [1] torch.Tensor - Result of Loss Function Calculation
        """
        # Crop both tensors to the same shape before comparing them.
        predicted, ground_truth = crop_to_identical_size(predicted, ground_truth)
        return _dice(predicted, ground_truth, eps)
class tversky(nn.Module):
    """Tversky loss over instance masks, computed per image and averaged over the batch."""

    def __init__(self, alpha: float, beta: float, eps: float):
        """
        Tversky loss module.

        :param alpha: float - Value which penalizes False Positive Values
        :param beta: float - Value which penalizes False Negatives
        :param eps: float - Numerical stability term
        """
        super().__init__()
        self.alpha = torch.tensor(float(alpha))
        self.beta = torch.tensor(float(beta))
        self.eps = torch.tensor(float(eps))

    def forward(
        self, predicted: Union[Tensor, List[Tensor]], ground_truth: Tensor
    ) -> Tensor:
        """
        Mean Tversky loss over the batch.

        :param predicted: either a batched Tensor (indexed along dim 0 per image) or a
            list of per-image Tensors, each [N, X, Y, Z] as expected by ``_tversky``.
        :param ground_truth: batched label Tensor; ``ground_truth[i]`` must be 4-D
            (``_tversky`` unpacks its shape into four values), with integer instance ids
            and 0 as background.
        :return: scalar Tensor, the mean of the per-image losses.
        :raises ValueError: if ``predicted`` is an empty list.
        """
        if isinstance(predicted, list):
            if not predicted:
                raise ValueError("predicted must contain at least one tensor")
            device = predicted[0].device
        else:
            device = predicted.device

        # Tensor.to is not in-place: re-assign so the hyper-parameters really
        # follow the prediction's device (previously the results were discarded).
        if self.alpha.device != device:
            self.alpha = self.alpha.to(device)
            self.beta = self.beta.to(device)
            self.eps = self.eps.to(device)

        if isinstance(predicted, list):
            per_image = predicted
        else:
            per_image = [predicted[i, ...] for i in range(predicted.shape[0])]

        futures: List[torch.jit.Future[torch.Tensor]] = []
        for i, pred in enumerate(per_image):
            futures.append(
                torch.jit.fork(
                    self._tversky,
                    pred,
                    ground_truth[i, ...],
                    self.alpha,
                    self.beta,
                    self.eps,
                )
            )

        results: List[Tensor] = [torch.jit.wait(future) for future in futures]
        return torch.mean(torch.stack(results))

    @staticmethod
    def _tversky(
        pred: Tensor,
        gt: Tensor,
        alpha: Tensor,
        beta: Tensor,
        eps: Union[Tensor, float] = 1e-8,
    ):
        """
        tversky loss on per image basis.

        Args:
            pred: [N, X, Y, Z] Tensor of predicted segmentation masks (N instances)
            gt: label image with 4 dims; each non-zero unique value is one instance
                (presumably [1, X, Y, Z] given the squeeze(0) below -- TODO confirm)
            alpha: Penalty to false positives
            beta: Penalty to false negatives
            eps: stability parameter

        Returns:
            Scalar Tensor, 1 - Tversky index.
        """
        # ------------------- Expand label image into one binary mask per instance id
        unique = torch.unique(gt)
        unique = unique[unique != 0]

        _, x, y, z = gt.shape
        nd_masks = torch.zeros((unique.shape[0], x, y, z), device=pred.device)
        for i, label in enumerate(unique):
            nd_masks[i, ...] = (gt == label).float().squeeze(0)

        pred, nd_masks = crop_to_identical_size(pred, nd_masks)

        true_positive: Tensor = pred.mul(nd_masks).sum()
        # The hard-coded 1e-10 is kept for parity with tversky_graphable.
        false_positive: Tensor = (
            torch.logical_not(nd_masks).mul(pred).sum().add(1e-10).mul(alpha)
        )
        false_negative: Tensor = ((1 - pred) * nd_masks).sum() * beta

        tversky = (true_positive + eps) / (
            true_positive + false_positive + false_negative + eps
        )

        return 1 - tversky

    def __repr__(self):
        return f"LossFn[name=tversky, alpha={self.alpha.item()}, beta={self.beta.item()}, eps={self.eps.item()}]"
class split(nn.Module):
    def __init__(self, n_iter: int = 2, alpha: float = 2.0, device: str = "cpu"):
        r"""
        Loss that penalizes predicted skeletons which have split across an object.

        It checks whether a predicted edge crosses the middle of a ground-truth object
        and, if so, applies a large (exponential) penalty.

        The distance function is approximated by repeatedly eroding the ground truth and
        summing the eroded masks. The edge function is approximated by subtracting an
        eroded copy of the prediction from the prediction (scaled by 2).

        For speed, only :math:`n_{iter}` erosions are performed, so only pixels within
        :math:`n_{iter}` erosions of an object's boundary get distinct distance values.

        Formally, if :math:`E` is the edge function and :math:`\Phi` is the distance
        function, we compute the loss :math:`L(s, p)` where :math:`s` is the ground truth
        skeleton and :math:`p` is the predicted skeleton:

        .. math::

            L(s, p) = E(p)^{ \alpha \Phi(s)} - 1

        The implementation omits the constant :math:`-1`, which does not change gradients.

        :param n_iter: Number of times to perform erosion for distance calculation. Must be
            at least 2, because the accumulated distance is divided by ``n_iter - 1``.
        :param alpha: Scale factor for exponential loss. Large values penalize breakages more.
        :param device: a torch.device - 'cuda' or 'cpu'
        :raises ValueError: if ``n_iter < 2``.
        """
        super().__init__()
        if n_iter < 2:
            raise ValueError(
                f"n_iter must be >= 2 (distance is divided by n_iter - 1), got {n_iter}"
            )
        self.n = n_iter
        self._device = device
        self.a = torch.tensor(alpha, device=device)

    def forward(self, pred, gt):
        """
        Mean split loss between a predicted and a ground-truth skeleton.

        Both inputs are passed to ``binary_erosion``; they are presumably binary masks of
        the same shape -- TODO confirm against callers.
        """
        distance = torch.zeros_like(gt) + gt
        for _ in range(self.n):
            gt = binary_erosion(gt)
            distance = distance + gt  # pseudo distance function

        distance = distance.div(self.n - 1)
        edges = pred.sub(binary_erosion(pred)).mul(2)  # cheap edge detection

        return self._split_loss(edges=edges, distance=distance, a=self.a).mean()

    @staticmethod
    def _split_loss(edges: Tensor, distance: Tensor, a: Tensor):
        """Element-wise ``edges ** (a * distance)``."""
        return torch.pow(edges, a * distance)
############## clDICE LOSS FROM: https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py
def soft_erode(img: Tensor) -> Tensor:
    """Approximate morphological erosion through (negated) max-pooling for 2D and 3D.

    A min over axis-aligned 3-long windows gives a cross-shaped structuring element.

    :param img: 4-D [B, C, H, W] or 5-D [B, C, X, Y, Z] tensor
    :return: eroded tensor of the same shape
    :raises ValueError: if ``img`` is neither 4-D nor 5-D (previously returned None silently)
    """
    if img.ndim == 4:
        p1 = -F.max_pool2d(-img, (3, 1), (1, 1), (1, 0))
        p2 = -F.max_pool2d(-img, (1, 3), (1, 1), (0, 1))
        return torch.min(p1, p2)
    if img.ndim == 5:
        p1 = -F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0))
        p2 = -F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0))
        p3 = -F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1))
        return torch.min(torch.min(p1, p2), p3)
    raise ValueError(f"soft_erode expects a 4D or 5D tensor, got {img.ndim}D")
def soft_dilate(img: Tensor) -> Tensor:
    """Approximate morphological dilation through max-pooling for 2D and 3D.

    Uses a full 3x3 (2D) or 3x3x3 (3D) window with stride 1 and padding 1,
    so the output has the same shape as the input.

    :param img: 4-D [B, C, H, W] or 5-D [B, C, X, Y, Z] tensor
    :return: dilated tensor of the same shape
    :raises ValueError: if ``img`` is neither 4-D nor 5-D (previously returned None silently)
    """
    if img.ndim == 4:
        return F.max_pool2d(img, (3, 3), (1, 1), (1, 1))
    if img.ndim == 5:
        return F.max_pool3d(img, (3, 3, 3), (1, 1, 1), (1, 1, 1))
    raise ValueError(f"soft_dilate expects a 4D or 5D tensor, got {img.ndim}D")
def soft_open(img: Tensor) -> Tensor:
    """Approximate morphological opening: a soft erosion followed by a soft dilation (2D and 3D)."""
    eroded = soft_erode(img)
    opened = soft_dilate(eroded)
    return opened
def soft_skeletonize(img: Tensor, iter_: int) -> Tensor:
    """
    Compute a soft skeleton by iteratively applying soft morphological operations.

    Each step erodes the image, then keeps what a soft opening would remove. These
    residues are merged into the running skeleton without double-counting.

    :param img: Image to perform operation on
    :param iter_: Number of erosion steps performed after the initial pass
    :return: Soft-skeleton
    """

    def _residue(x: Tensor) -> Tensor:
        # Parts of x removed by a soft opening, i.e. its thin structures.
        return F.relu(x - soft_open(x))

    skeleton = _residue(img)
    current = img
    for _ in range(iter_):
        current = soft_erode(current)
        delta = _residue(current)
        skeleton = skeleton + F.relu(delta - skeleton * delta)
    return skeleton
class soft_cldice(nn.Module):
    """Soft-clDice loss (centerline Dice) built on soft skeletonization."""

    def __init__(self, iter_=3, smooth=1.0):
        """
        :param iter_: number of soft-skeletonization iterations
        :param smooth: smoothing term added to numerators and denominators
        """
        super().__init__()
        self.iter = iter_
        self.smooth = smooth

    def forward(self, predicted: Tensor, ground_truth: Tensor) -> Tensor:
        """
        Calculates the soft-clDice loss between a predicted and a ground-truth tensor.

        For inputs with more than one channel (dim 1), channel 0 is treated as
        background and excluded, as in the reference clDice implementation. Single-channel
        inputs are used whole: slicing off channel 0 there left empty tensors and made
        the loss a constant (0 for any positive ``smooth``).

        :param predicted: predicted probabilities, channels on dim 1
        :param ground_truth: ground-truth mask with the same layout as ``predicted``
        :return: scalar Tensor, 1 - clDice
        """
        skeleton_predicted = soft_skeletonize(predicted, self.iter)
        skeleton_true = soft_skeletonize(ground_truth, self.iter)

        first = 1 if predicted.shape[1] > 1 else 0

        tprec = (
            torch.sum(torch.multiply(skeleton_predicted, ground_truth)[:, first:, ...])
            + self.smooth
        ) / (torch.sum(skeleton_predicted[:, first:, ...]) + self.smooth)

        tsens = (
            torch.sum(torch.multiply(skeleton_true, predicted)[:, first:, ...])
            + self.smooth
        ) / (torch.sum(skeleton_true[:, first:, ...]) + self.smooth)

        cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
        return cl_dice
def soft_dice(predicted: Tensor, ground_truth: Tensor, smooth: int = 1) -> Tensor:
    """
    Soft Dice loss: ``1 - (2 * sum(G * P) + smooth) / (sum(G) + sum(P) + smooth)``.

    :param predicted: predicted probabilities
    :param ground_truth: ground-truth mask
    :param smooth: smoothing factor to prevent division by zero
    :return: scalar Tensor holding the loss
    """
    overlap = (ground_truth * predicted).sum()
    total = ground_truth.sum() + predicted.sum()
    return 1.0 - (2.0 * overlap + smooth) / (total + smooth)
class soft_dice_cldice(nn.Module):
    """Weighted combination of soft-Dice and soft-clDice losses."""

    def __init__(self, iter_=3, alpha=0.5, smooth=1.0):
        """
        :param iter_: number of soft-skeletonization iterations
        :param alpha: weight of the clDice term; the Dice term gets ``1 - alpha``
        :param smooth: smoothing term used by both the Dice and the clDice terms
        """
        super().__init__()
        self.iter = iter_
        self.smooth = smooth
        self.alpha = alpha

    def forward(self, predicted: Tensor, ground_truth: Tensor) -> Tensor:
        """
        Calculates a single loss value combining soft-Dice and soft-clDice, which can be
        used to train a neural network.

        :param predicted: Input tensor
        :param ground_truth: Ground Truth Tensor
        :return: ``(1 - alpha) * dice + alpha * cl_dice``
        """
        # Pass self.smooth through; previously the Dice term always used its default
        # of 1, ignoring the constructor argument. Named `dice_loss` to avoid shadowing
        # the module-level `dice` class.
        dice_loss = soft_dice(predicted, ground_truth, smooth=self.smooth)

        skel_pred = soft_skeletonize(predicted, self.iter)
        skel_true = soft_skeletonize(ground_truth, self.iter)

        tprec = (torch.sum(skel_pred * ground_truth) + self.smooth) / (
            torch.sum(skel_pred) + self.smooth
        )
        tsens = (torch.sum(skel_true * predicted) + self.smooth) / (
            torch.sum(skel_true) + self.smooth
        )
        cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)

        return (1.0 - self.alpha) * dice_loss + self.alpha * cl_dice
if __name__ == "__main__":
    # Smoke test: evaluate the combined Dice/clDice loss on random 3D data.
    loss_fn = soft_dice_cldice()
    shape = (1, 1, 20, 20, 10)
    prediction = torch.rand(shape, device="cpu")
    target = torch.rand(shape, device="cpu").round().float()
    print(loss_fn(prediction, target))