Source code for skoots.validate.lib

import skoots.train.loss
from typing import Dict, Tuple

import torch
from torch import Tensor
from tqdm import tqdm

import skoots.train.loss
from skoots.validate.utils import imread


[docs] def mask_to_bbox(mask: Tensor) -> Tuple[Tensor, Tensor]: """ Calculates the 3D bbox for each instance of an instance segmentation mask. Assumes each positive integer is an instance for class label. Returns a tensor of id labels and a tensor of bboxes bboxes are in format: [x0,y0,z0,x1,y1,z1] Assigns a bbox to each unique lablel! Does not mean each label has a valid bbox!!! Shapes: - mask: :math:`(1, X_{in}, Y_{in}, Z_{in})` - return[0]: Id labels: :math: `(N)` - return[1]: bboxes: :math: `(6, N)` :param mask: Input instance segmentation mask :return: id labels and bboxes """ assert mask.ndim == 4, "Mask ndim != 4" sparse_mask = mask.to_sparse() indices = sparse_mask.indices() values = sparse_mask.values() unique = torch.unique(values) # assured to have no zeros because tensor is sparse bboxes = torch.empty( (6, unique.shape[0]), device=mask.device, dtype=torch.int16 ) # preallocate for speed for i, u in enumerate(unique): ind = indices[1::, values == u] # just x,y,z dim of indicies bboxes[0, i] = ind[0].min() # x0 bboxes[1, i] = ind[1].min() # y0 bboxes[2, i] = ind[2].min() # z0 bboxes[3, i] = ind[0].max() # x1 bboxes[4, i] = ind[1].max() # y1 bboxes[5, i] = ind[2].max() # z1 return unique, bboxes
[docs] def valid_box_inds(boxes): """ returns the inds of all valid boxes :param boxes: [6, N] :return: [N] """ x0, y0, z0, x1, y1, z1 = ( boxes[0, :], boxes[1, :], boxes[2, :], boxes[3, :], boxes[4, :], boxes[5, :], ) inds = torch.logical_and(torch.logical_and((x1 > x0), (y1 > y0)), (z1 > z0)) return inds
[docs] def box_iou(a: Tensor, b: Tensor) -> Tensor: """ Compute the IoU of the cartesian product of two sets of boxes. Each box in each set shall be (x0, y0, z0, x0, y0, z0). Shapes: a: :math:`(6, N)`. b: :math:`(6, M)`. returns: :math: `(N, M)` :param a: box 1 :param b: box 2 :return: iou of each box in 1 against all boxes in 2 """ if not valid_box_inds(a).bool().all(): raise AssertionError("a does not follow (x0, y0, z0, x1, y1, z1) format.") if not valid_box_inds(b).bool().all(): raise AssertionError("b does not follow (x0, y0, z0, x1, y2, z1) format.") a = a.T.float() # we'll have buffer overflow otherwise b = b.T.float() # we'll have buffer overflow otherwise # find intersection lower_bounds = torch.max( a[:, :3].unsqueeze(1), b[:, :3].unsqueeze(0) ).float() # (n, m, 3) upper_bounds = torch.min( a[:, 3:].unsqueeze(1), b[:, 3:].unsqueeze(0) ).float() # (n, m, 3) intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0) # (n1, n2, 3) intersection = ( intersection_dims[:, :, 0].float() * intersection_dims[:, :, 1].float() * intersection_dims[:, :, 2].float() ) # (n, m) # Find areas of each box in both sets areas_a = (a[:, 3] - a[:, 0]) * (a[:, 4] - a[:, 1]) * (a[:, 5] - a[:, 2]) # (n) areas_b = (b[:, 3] - b[:, 0]) * (b[:, 4] - b[:, 1]) * (b[:, 5] - b[:, 2]) # (m) union = areas_a.unsqueeze(1) + areas_b.unsqueeze(0) - intersection # (n1, n2) return intersection / union # (n1, n2)
[docs] def calculate_accuracies_from_bbox( ground_truth: Dict[str, Tensor], predictions: Dict[str, Tensor], device: str | None = None, threshold=0.1, ): """ Calculates True positive, False Positive, False Negative from data_dict of segmentation 3d bboxes :param ground_truth: :param predictions: :param device: :param threshold: :return: """ device = device if device else ground_truth["boxes"].device _gt = ground_truth["boxes"].to(device) _pred = predictions["boxes"].to(device) iou = box_iou(_gt, _pred) gt_max, gt_indicies = iou.max(dim=1) gt = torch.logical_not(gt_max.gt(threshold)) if iou.shape[1] > 0 else torch.ones(0) pred = ( torch.logical_not(iou.max(dim=0)[0].gt(threshold)) if iou.shape[0] > 0 else torch.ones(0) ) true_positive = torch.sum(torch.logical_not(gt)) false_positive = torch.sum(pred) false_negative = torch.sum(gt) return ( true_positive, false_positive, false_negative, )
[docs] def accuracies_from_iou(iou: Tensor, thr: float = 0.1) -> Tensor: gt_max, gt_indicies = iou.max(dim=1) gt = torch.logical_not(gt_max.gt(thr)) if iou.shape[1] > 0 else torch.ones(0) pred = ( torch.logical_not(iou.max(dim=0)[0].gt(thr)) if iou.shape[0] > 0 else torch.ones(0) ) true_positive = torch.sum(torch.logical_not(gt)) false_positive = torch.sum(pred) false_negative = torch.sum(gt) return ( true_positive.cpu().item(), false_positive.cpu().item(), false_negative.cpu().item(), )
[docs] def mask_iou(gt: Tensor, pred: Tensor): """ Calculates the IoU of each object on a per-mask-basis. :param gt: mask 1 with N instances :param pred: mask 2 with M instances :return: NxM matrix of IoU's """ assert gt.shape == pred.shape, "Input tensors must be the same shape" assert gt.device == pred.device, "Input tensors must be on the same device" a_unique = gt.unique() a_unique = a_unique[a_unique > 0] b_unique = pred.unique() b_unique = b_unique[b_unique > 0] iou = torch.zeros( (a_unique.shape[0], b_unique.shape[0]), dtype=torch.float, device=gt.device ) for i, au in tqdm(enumerate(a_unique), total=len(a_unique)): _a = gt == au touching = pred[ _a ].unique() # we only calculate iou of lables which have "contact with" our mask touching = touching[touching != 0] for j, bu in enumerate(b_unique): if torch.any(touching == bu): _b = pred == bu intersection = torch.logical_and(_a, _b).sum() union = torch.logical_or(_a, _b).sum() iou[i, j] = intersection / union else: iou[i, j] = 0.0 return iou
[docs] def mask_dice(gt: Tensor, pred: Tensor): """ Calculates the Dice Index of each object on a per-mask-basis. :param gt: mask 1 with N instances :param pred: mask 2 with M instances :return: NxM matrix of IoU's """ assert gt.shape == pred.shape, "Input tensors must be the same shape" assert gt.device == pred.device, "Input tensors must be on the same device" a_unique = gt.unique() a_unique = a_unique[a_unique > 0] b_unique = pred.unique() b_unique = b_unique[b_unique > 0] dice = torch.zeros( (a_unique.shape[0], b_unique.shape[0]), dtype=torch.float, device=gt.device ) for i, au in tqdm(enumerate(a_unique), total=len(a_unique)): _a = gt == au touching = pred[ _a ].unique() # we only calculate iou of lables which have "contact with" our mask touching = touching[touching != 0] for j, bu in enumerate(b_unique): if torch.any(touching == bu): _b = pred == bu numerator = torch.logical_and(_a, _b).sum() * 2 denominator = _a.sum() + _b.sum() assert ( numerator < denominator ), f"{numerator=}, {denominator=}, {_a.sum()=}, {_b.sum()=}, {(_a*_b).sum()=}" dice[i, j] = numerator / denominator else: dice[i, j] = 0.0 return dice
[docs] def mask_soft_cldice(gt: Tensor, pred: Tensor): """ Calculates the Dice Index of each object on a per-mask-basis. :param gt: mask 1 with N instances :param pred: mask 2 with M instances :return: NxM matrix of IoU's """ assert gt.shape == pred.shape, "Input tensors must be the same shape" assert gt.device == pred.device, "Input tensors must be on the same device" a_unique = gt.unique() a_unique = a_unique[a_unique > 0] b_unique = pred.unique() b_unique = b_unique[b_unique > 0] criterion = torch.compile(skoots.train.loss.soft_cldice()) cldice = torch.zeros( (a_unique.shape[0], b_unique.shape[0]), dtype=torch.float, device=gt.device ) for i, au in tqdm(enumerate(a_unique), total=len(a_unique)): _a = gt == au touching = pred[ _a ].unique() # we only calculate iou of lables which have "contact with" our mask touching = touching[touching != 0] for j, bu in enumerate(b_unique): if torch.any(touching == bu): _b = pred == bu cldice[i, j] = criterion(_b.float(), _a.float()) else: cldice[i, j] = 0.0 return cldice
[docs] def sparse_mask_iou(a: Tensor, b: Tensor) -> Tensor: """ Calculates the IoU of each object on a per-mask-basis using sparse tensors. :param a: mask 1 with N instances :param b: mask 2 with M instances :return: NxM matrix of IoU's """ raise NotImplementedError("In Development...") assert a.shape == b.shape, "Input tensors must be the same shape" assert a.device == b.device, "Input tensors must be on the same device" shape = a.shape a = a.to_sparse_coo() b = b.to_sparse_coo() a_unique = a.labels().unique() b_unique = b.labels().unique() a_indicies = a.indicies() b_indicies = b.indicies() iou = torch.zeros( (a_unique.shape[0], b_unique.shape[0]), dtype=torch.float, device=a.device ) for i, au in tqdm(enumerate(a_unique)): for j, bu in enumerate(b_unique): _a_ind = a_indicies[a.lables() == au] _b_ind = b_indicies[b.lables() == bu] a_sparse = torch.sparse_coo_tensor( indices=_a_ind, labels=torch.ones_like(a.labels() == au), size=shape ) b_sparse = torch.sparse_coo_tensor( indices=_b_ind, labels=torch.ones_like(b.labels() == bu), size=shape )
[docs] def f1_score(tp, fp, fn): num = 2 * tp dem = 2 * tp + fp + fn return num / dem
[docs] def _iou_instance_dict(a: Tensor, b: Tensor) -> Dict[int, Tensor]: """ Given two instance masks, compares each instance in b against a. Usually assumes A is the ground truth. :param a: Mask A :param b: Mask B :return: Dict of instances and every IOU for each instance """ a_unique = a.unique() a_unique = a_unique[a_unique > 0] b_unique = b.unique() b_unique = b_unique[b_unique > 0] iou = {} for i, au in tqdm(enumerate(a_unique), total=len(a_unique)): _a = a == au touching = b[ _a ].unique() # we only calculate iou of lables which have "contact with" our mask touching = touching[touching != 0] iou[au] = [] for j, bu in enumerate(b_unique): if torch.any(touching == bu): _b = b == bu intersection = torch.logical_and(_a, _b).sum() union = torch.logical_or(_a, _b).sum() iou[au].append((intersection / union).item()) return iou
[docs] def get_segmentation_errors(ground_truth: Tensor, predicted: Tensor) -> float: """ Calculates the IoU of each object on a per-mask-basis. :param ground_truth: mask 1 with N instances :param predicted: mask 2 with M instances :return: NxM matrix of IoU's """ iou = _iou_instance_dict(ground_truth, predicted) for k, v in iou.items(): iou[k] = torch.tensor(v) num_split = 0 for k, v in iou.items(): if v.gt(0.2).int().sum() > 1: num_split += 1 over_segmentation_rate = num_split / len(iou) iou = _iou_instance_dict(predicted, ground_truth) for k, v in iou.items(): iou[k] = torch.tensor(v) num_split = 0 for k, v in iou.items(): if v.gt(0.2).int().sum() > 1: num_split += 1 under_segmentation_rate = num_split / len(iou) return over_segmentation_rate, under_segmentation_rate
if __name__ == "__main__": gt = imread("../../tests/test_data/hide_validate.labels.tif")[..., 2:-2] pred = imread("../../tests/test_data/hide_validate_skeleton_instance_mask.tif")[ ..., 2:-2 ] aff_pred = imread( "../../tests/test_data/hide_validation_affinity_instance_segmentaiton.tif" )[..., 2:-2] device = "cuda" if torch.cuda.is_available() else "cpu" gt = gt.to(device) pred = pred.to(device) aff_pred = aff_pred.to(device) u, c = aff_pred.unique(return_counts=True) for a, b in tqdm(zip(u, c)): if b < 500: aff_pred[aff_pred == a] = 0 skoots_seg_errors = get_segmentation_errors(gt, pred) aff_seg_errors = get_segmentation_errors(gt, aff_pred) # print('SKOOTS IOU') # if not os.path.exists('../../tests/iou_gt_skoots_mask.trch'): # iou_skoots = mask_iou(gt, pred) # torch.save(iou_skoots, '../../tests/iou_gt_skoots_mask.trch') # else: # iou_skoots = torch.load('../../tests/iou_gt_skoots_mask.trch') # # print('AFFINITES IOU') # if not os.path.exists('../../tests/iou_gt_affinites_mask.trch'): # iou_aff = mask_iou(gt, aff_pred) # torch.save(iou_aff, '../../tests/iou_gt_affinites_mask.trch') # else: # iou_aff = torch.load('../../tests/iou_gt_affinites_mask.trch') # # # iou_skoots = torch.load('../../tests/iou_gt_skoots_mask.trch') # # iou_aff = torch.load('../../tests/iou_gt_affinites_mask.trch') # # # # tfp_skoots = [accuracies_from_iou(iou_skoots, thr/100) for thr in range(100)] # tfp_aff = [accuracies_from_iou(iou_aff, thr/100) for thr in range(100)] # # precision_skoots = [(tp /(tp + fp)) for (tp, fp, fn) in tfp_skoots] # recall_skoots = [(tp / (tp + fn)) for (tp, fp, fn) in tfp_skoots] # # precision_aff = [(tp/(tp+fp)) for (tp, fp, fn) in tfp_aff] # recall_aff = [(tp / (tp + fn)) for (tp, fp, fn) in tfp_aff] # # f1_skoots = [f1_score(*a) for a in tfp_skoots] # f1_aff = [f1_score(*a) for a in tfp_aff] # # plt.plot(np.arange(0, 100), precision_skoots) # plt.plot(np.arange(0, 100), precision_aff) # plt.legend(['SKOOTS', 'AFFINITIES']) # plt.ylabel('Precision') # plt.xlabel('IoU Threshold: (%)') # plt.show() # # plt.plot(np.arange(0, 100), recall_skoots) # plt.plot(np.arange(0, 100), recall_aff) # plt.legend(['SKOOTS', 'AFFINITIES']) # plt.ylabel('Recall') # plt.xlabel('IoU Threshold: (%)') # plt.show() # # plt.plot(np.arange(0, 100), f1_skoots) # plt.plot(np.arange(0, 100), f1_aff) # plt.legend(['SKOOTS', 'AFFINITIES']) # plt.ylabel('F1') # plt.xlabel('IoU Threshold: (%)') # plt.show()