Source code for skoots.train.erosion

from typing import Dict, Tuple, Union, List

import torch
import torch.nn.functional as F
from torch import Tensor


[docs] @torch.jit.script def binary_erosion(image: Tensor) -> Tensor: kernel = _get_binary_kernel3d(3, image.device) padding = _compute_zero_padding((3, 3, 3)) b, c, h, w, d = image.shape # map the local window to single vector features: Tensor = F.conv3d( image.reshape(b * c, 1, h, w, d), kernel, padding=padding, stride=1 ) return features.min(dim=1)[0].unsqueeze(0)
[docs] class erosion: def __init__( self, kernel_targets: int = 3, rate: float = 0.5, device: str = "cpu" ) -> None: super(erosion, self).__init__() self.device = device self.rate = rate self.kernel_targets = kernel_targets if kernel_targets % 2 != 1: raise ValueError("Expected Kernel target to be Odd") self.padding: Tuple[int, int, int] = self._compute_zero_padding( (kernel_targets, kernel_targets, kernel_targets) ) self.kernel: torch.Tensor = self._get_binary_kernel3d(kernel_targets, device) def __call__( self, input: Union[Dict[str, torch.Tensor], torch.Tensor] ) -> Union[Dict[str, torch.Tensor], torch.Tensor]: """ Randomly applies a gaussian blur :param data_dict Dict[str, torch.Tensor]: data_dictionary from a dataloader. Has keys: key : val 'image' : torch.Tensor of size [C, X, Y, Z] where C is the number of colors, X,Y,Z are the mask height, width, and depth 'masks' : torch.Tensor of size [I, X, Y, Z] where I is the number of identifiable objects in the mask 'centroids' : torch.Tensor of size [I, 3] where dimension two is the [X, Y, Z] position of the centroid for instance i :return: data_dict Dict[str, torch.Tensor]: dictionary with identical keys as input, but with transformed values """ if isinstance(input, Dict): self.check_inputs(input) return self._is_dict(input) elif isinstance(input, torch.Tensor): return self._is_tensor(input)
[docs] def _is_dict(self, data_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: if data_dict["masks"].dtype != self.kernel.dtype: raise ValueError( f'Expected Image dtype to be {self.kernel.dtype} not {data_dict["masks"].dtype}' ) # We expect the image to have ndim==4 but need a batch size of 1 for median filter device = data_dict["image"].device if self.kernel.device != device: self.kernel = self.kernel.to(device) if torch.rand(1) < self.rate: im = data_dict["masks"].unsqueeze(0) b, c, h, w, d = im.shape # map the local window to single vector features: torch.Tensor = F.conv3d( im.reshape(b * c, 1, h, w, d), self.kernel, padding=self.padding, stride=1, ) data_dict["masks"] = features.min(dim=1)[0] return data_dict
[docs] def _is_tensor(self, input: torch.Tensor) -> torch.Tensor: "Assumes [C, X, Y, Z]" if input.device != self.kernel.device: raise ValueError( f"Expected Image Device to be {self.kernel.device} not {input.device}" ) if input.dtype != self.kernel.dtype: raise ValueError( f"Expected Image dtype to be {self.kernel.dtype} not {input.dtype}" ) if input.ndim != 4: raise ValueError( f"Expected Image ndim to be 4 not {input.ndim}, with shape {input.shape}" ) b, c, h, w, d = input.unsqueeze(0).shape # map the local window to single vector features: torch.Tensor = F.conv3d( input.reshape(b * c, 1, h, w, d), self.kernel, padding=self.padding, stride=1, ) return features.min(dim=1)[0]
[docs] def _compute_zero_padding(kernel_size: Tuple[int, int, int]) -> Tuple[int, int, int]: r"""Utility function that computes zero padding tuple. Adapted from Kornia """ computed: List[int] = [(k - 1) // 2 for k in kernel_size] return computed[0], computed[1], computed[2]
[docs] def _get_binary_kernel3d(window_size: int, device: torch.device) -> torch.Tensor: r"""Creates a symetric binary kernel to extract the patches. If the window size is HxWxD will create a (H*W)xHxW kernel. Adapted from Kornia """ window_range: int = int(window_size**3) kernel: torch.Tensor = torch.zeros( (window_range, window_range, window_range), device=device ) for i in range(window_range): kernel[i, i, i] += 1.0 kernel = kernel.view(-1, 1, window_size, window_size, window_size) # get rid of all zero kernels ind = torch.nonzero(kernel.view(kernel.shape[0], -1).sum(1)) return kernel[ind[:, 0], ...]