Source code for skoots.lib.morphology

from typing import List
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
from functools import cache


[docs] @cache def _compute_zero_padding(kernel_size: List[int]) -> Tuple[int, int, int]: r"""Utility function that computes zero padding tuple. Adapted from Kornia """ computed: List[int] = [(k - 1) // 2 for k in kernel_size] return computed[0], computed[1], computed[2]
[docs] @cache def _get_binary_kernel3d(window_size: int, device: str) -> Tensor: r"""Creates a symmetric binary kernel to extract the patches. If the window size is HxWxD will create a (H*W)xHxW kernel. Adapted from a 2D Kornia implementation """ window_range: int = int(window_size**3) kernel: Tensor = torch.zeros( (window_range, window_range, window_range), device=device ) for i in range(window_range): kernel[i, i, i] += 1.0 kernel = kernel.view(-1, 1, window_size, window_size, window_size) # get rid of all zero kernels ind = torch.nonzero(kernel.view(kernel.shape[0], -1).sum(1)) return kernel[ind[:, 0], ...]
[docs] @cache def _get_binary_kernel2d(window_size: int, device: str) -> Tensor: r"""Creates a symmetric binary kernel to extract the patches. If the window size is HxWxD will create a (H*W)xHxW kernel. Adapted from a 2D Kornia implementation """ window_range: int = int(window_size**3) kernel: Tensor = torch.zeros((window_range, window_range, 1), device=device) for i in range(window_range): kernel[i, i, 0] += 1.0 kernel = kernel.view(-1, 1, window_size, window_size, 1) # get rid of all zero kernels ind = torch.nonzero(kernel.view(kernel.shape[0], -1).sum(1)) return kernel[ind[:, 0], ...]
# re-implemented from torchvision.transforms.functional_tensor (gaussian_blur helpers)
[docs] @cache def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor: ksize_half = (kernel_size - 1) * 0.5 x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) pdf = torch.exp(-0.5 * (x / sigma).pow(2)) kernel1d = pdf / pdf.sum() return kernel1d
# re-implemented from torchvision.transforms.functional_tensor (gaussian_blur helpers)
def _get_gaussian_kernel2d(
    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
    """Build a normalized 2D Gaussian kernel as the outer product of two 1D kernels.

    Follows torchvision's layout: rows run along y (``kernel_size[1]``,
    ``sigma[1]``) and columns along x (``kernel_size[0]``, ``sigma[0]``).

    Not wrapped in ``functools.cache``. ``kernel_size`` and ``sigma`` are lists,
    which are unhashable, so a memoized version raises ``TypeError`` for the
    argument types it is annotated with.

    :param kernel_size: ``[kx, ky]`` number of taps along x and y.
    :param sigma: ``[sx, sy]`` standard deviations along x and y.
    :param dtype: dtype of the returned kernel.
    :param device: device of the returned kernel.
    :return: tensor of shape ``(ky, kx)``.
    """
    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
    # (ky, 1) @ (1, kx) -> (ky, kx) outer product; each 1D factor sums to 1,
    # so the 2D kernel does too (up to rounding).
    return torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
# expanded to 3D
def _get_gaussian_kernel3d(
    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
    """Build a normalized, separable 3D Gaussian kernel (3D extension of the 2D helper).

    The kernel is the outer product of three 1D Gaussians. Its axes are ordered
    ``(x, y, z)`` to match a ``(B, C, X, Y, Z)`` input.

    Not wrapped in ``functools.cache``. ``kernel_size`` and ``sigma`` are lists,
    which are unhashable, so a memoized version raises ``TypeError`` for the
    argument types it is annotated with.

    :param kernel_size: ``[kx, ky, kz]`` number of taps along x, y and z.
    :param sigma: ``[sx, sy, sz]`` standard deviations along x, y and z.
    :param dtype: dtype of the returned kernel.
    :param device: device of the returned kernel.
    :return: tensor of shape ``(kx, ky, kz)``.
    """
    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
    kernel1d_z = _get_gaussian_kernel1d(kernel_size[2], sigma[2]).to(device, dtype=dtype)
    # Broadcast outer product: entry [i, j, k] = x[i] * y[j] * z[k].
    return (
        kernel1d_x[:, None, None] * kernel1d_y[None, :, None] * kernel1d_z[None, None, :]
    )
def gauss_filter(input: Tensor, kernel: List[int], sigma: List[float]) -> Tensor:
    """
    Gaussian filter of a 3D tensor, applied to every channel independently.

    Runs eagerly, not under ``torch.jit.script``. It converts its list arguments
    with ``tuple()`` and calls plain-Python kernel helpers, neither of which
    TorchScript compiles.

    :param input: (B, C, X, Y, Z) floating-point tensor
    :param kernel: [int, int, int] kernel extent along X, Y and Z
    :param sigma: [float, float, float] standard deviation along X, Y and Z
    :return: blurred image of shape (B, C, X', Y', Z'); X' == X etc. for odd kernel extents
    """
    # Tuples are hashable, so this works whether or not the kernel helpers are
    # memoized with functools.cache (a list argument would raise TypeError there).
    kernel_size: Tuple[int, ...] = tuple(kernel)
    sigmas: Tuple[float, ...] = tuple(sigma)

    padding = _compute_zero_padding(kernel_size)
    weights = _get_gaussian_kernel3d(kernel_size, sigmas, input.dtype, input.device)

    channels = input.shape[1]
    # Depthwise convolution: one (1, kx, ky, kz) filter per channel. expand()
    # creates a broadcast view instead of copying the kernel C times.
    weights = weights.expand(channels, 1, *weights.shape)
    return F.conv3d(input, weights, padding=padding, stride=(1, 1, 1), groups=channels)
@torch.jit.script
def binary_erosion(image: Tensor) -> Tensor:
    """
    Performs binary erosion on a 5D Tensor with a 3x3x3 cubic structuring element.

    Each output voxel is the minimum of its 3x3x3 neighbourhood. Voxels outside
    the volume count as 0, so foreground touching the border is eroded.

    Implemented as a min-pool, which works for any batch size. Each spatial
    axis is padded by one before pooling, so the spatial size is preserved.

    Shapes:
        - input: :math:`(B, C, X, Y, Z)`
        - output: :math:`(B, C, X, Y, Z)`

    :param image: binary image (floating point)
    :return: eroded image
    """
    # Explicit zero padding: max_pool3d's own padding uses -inf, which would
    # leave border voxels un-eroded.
    padded = F.pad(image, [1, 1, 1, 1, 1, 1], mode="constant", value=0.0)
    # min over each window == -(max over the negated window)
    return -F.max_pool3d(-padded, kernel_size=3, stride=1)
@torch.jit.script
def binary_dilation(image: Tensor) -> Tensor:
    """
    Performs binary dilation on a 5D Tensor with a 3x3x3 cubic structuring element.

    Each output voxel is the maximum of its 3x3x3 neighbourhood, with voxels
    outside the volume treated as 0.

    Implemented as a max-pool over a zero-padded copy. It uses only
    TorchScript-compilable ops and needs no one-hot kernel or 27-channel
    intermediate tensor.

    Shapes:
        - input: :math:`(B, C, X, Y, Z)`
        - output: :math:`(B, C, X, Y, Z)`

    :param image: binary image (floating point)
    :return: dilated image
    """
    # Explicit zero padding keeps the "outside counts as 0" semantics for any
    # input values (max_pool3d's own padding would use -inf).
    padded = F.pad(image, [1, 1, 1, 1, 1, 1], mode="constant", value=0.0)
    return F.max_pool3d(padded, kernel_size=3, stride=1)
@torch.jit.ignore
def binary_dilation_2d(image: Tensor) -> Tensor:
    """
    Performs in-plane binary dilation on a 5D Tensor.

    Each output voxel is the maximum of its 3x3 neighbourhood in the X/Y plane.
    The Z axis is not mixed (3x3x1 structuring element). Voxels outside the
    volume count as 0.

    Shapes:
        - input: :math:`(B, C, X, Y, Z)`
        - output: :math:`(B, C, X, Y, Z)`

    :param image: binary image (floating point)
    :return: dilated image
    """
    # F.pad lists pads from the last axis backwards: Z gets none, Y and X get one
    # zero on each side (max_pool3d's own padding would use -inf).
    padded = F.pad(image, [0, 0, 1, 1, 1, 1], mode="constant", value=0.0)
    return F.max_pool3d(padded, kernel_size=(3, 3, 1), stride=1)
def median_filter(input: Tensor) -> Tensor:
    """
    3x3x3 median filter applied to every channel of a 5D tensor.

    Voxels outside the volume count as 0. The window has 27 (odd) entries, so
    the median is an actual window value.

    Shapes:
        - input: :math:`(B, C, X, Y, Z)`
        - output: :math:`(B, C, X, Y, Z)`

    :param input: image tensor
    :return: median-filtered image
    """
    b, c, h, w, d = input.shape
    padded = F.pad(input, [1, 1, 1, 1, 1, 1], mode="constant", value=0.0)
    # Sliding 3-wide windows along each spatial axis (strided views, no copy yet):
    # (B, C, X, Y, Z, 3, 3, 3).
    windows = padded.unfold(2, 3, 1).unfold(3, 3, 1).unfold(4, 3, 1)
    return windows.reshape(b, c, h, w, d, 27).median(dim=-1).values
def mean_filter(input: Tensor) -> Tensor:
    """
    3x3x3 mean (box) filter applied to every channel of a 5D tensor.

    Voxels outside the volume count as 0 and are included in the average, so
    every output is the window sum divided by 27.

    Shapes:
        - input: :math:`(B, C, X, Y, Z)`
        - output: :math:`(B, C, X, Y, Z)`

    :param input: floating-point image tensor
    :return: mean-filtered image (all batch elements)
    """
    # count_include_pad=True divides by the full 27-voxel window, matching a
    # zero-padded neighbourhood average.
    return F.avg_pool3d(input, kernel_size=3, stride=1, padding=1, count_include_pad=True)
[docs] def dilate(input: Tensor) -> Tensor: padding: Tuple[int, int, int] = _compute_zero_padding((3, 3, 3)) kernel: Tensor = _get_binary_kernel3d(3, input.dtype, input.device) b, c, h, w, d = input.shape # map the local window to single vector features = F.conv3d( input.reshape(b * c, 1, h, w, d), kernel, padding=padding, stride=1 ) return torch.max(features.view(b, c, -1, h, w, d), dim=2)[0]