Source code for skoots.train.generate_skeletons

import os.path
from typing import Tuple, Dict

import torch
import torch.nn.functional as F
from skimage.morphology import skeletonize
from torch import Tensor
from tqdm import tqdm
import kimimaro
import numpy as np
import skimage.io as io
import glob


def save_train_test_split(
    mask: Tensor, skeleton: Dict[int, Tensor], z_split: int, base: str
):
    """
    Splits the skeletons of a labelled volume into a train and a validation set.

    You CANNOT naively just split the mask in two and skeletonize each half, as
    skeletons of objects on the border might not be properly calculated.
    Instead, the skeletons of the whole volume are passed in and distributed
    according to which object ids are present in each half of the mask.

    The train half is ``mask[..., : z_split + 1]`` and the validation half is
    ``mask[..., z_split:]``; the slice at index ``z_split`` therefore belongs
    to both halves. Validation skeletons have ``z_split`` subtracted from
    their third coordinate so they are expressed relative to the start of the
    validation half.

    Saves pickled Dict[int, Tensor] to base+'_train.skeletons.trch' and
    base+'_validate.skeletons.trch'

    :param mask: Instance masks. The last axis is the one that is split.
    :param skeleton: Dict of skeleton of EVERY object in mask. Each value is
        indexed as ``[:, 2]`` to reach the coordinate along the split axis.
        The tensors in this dict are never modified.
    :param z_split: Z index of the train test split
    :param base: base filepath by which to save.
    :return: None
    """
    # --- train: everything up to and including z_split ---
    train_mask = mask[..., 0 : z_split + 1]
    train_skeletons = {}
    for u in train_mask.unique():
        u = int(u)
        # 0 is background; ids without a skeleton are silently skipped.
        if u != 0 and u in skeleton:
            train_skeletons[u] = skeleton[u]
    torch.save(train_skeletons, base + "_train.skeletons.trch")

    # --- validate: everything from z_split onwards ---
    validate_mask = mask[..., z_split:]
    validate_skeletons = {}
    for u in validate_mask.unique():
        u = int(u)
        if u == 0 or u not in skeleton:
            continue
        # Clone before shifting so the caller's tensors are left untouched,
        # and shift by the actual split index rather than a fixed constant.
        shifted = skeleton[u].clone()
        shifted[:, 2] -= z_split
        validate_skeletons[u] = shifted
    torch.save(validate_skeletons, base + "_validate.skeletons.trch")
def calculate_skeletons(mask: Tensor, scale: Tensor) -> Dict[int, Tensor]:
    """
    Calculates the skeleton of each object in mask.

    The mask is first resampled by ``scale`` with nearest-neighbour
    interpolation (skipped when every factor is 1). Each object is then
    cropped to its bounding box, skeletonized with
    ``skimage.morphology.skeletonize(method="lee")``, and the skeleton voxel
    indices are divided by ``scale`` again so that the result is expressed in
    the index space of the *input* mask. If skeletonization yields no voxels
    for an object, the mean position of its voxels is used as a single-point
    skeleton.

    :param mask: [X, Y, Z] integer instance mask. Label 0 is skipped.
    :param scale: per-axis scale factors for (X, Y, Z). A Tensor, or anything
        ``torch.as_tensor`` accepts (e.g. a tuple).
    :return: Dict[int, Tensor] dict of skeletons where int is the object id
        and Tensor is [K, 3] skeleton coordinates
    :raises AssertionError: if resampling changed the set of object ids, or
        if an object ends up with an empty skeleton.
    """
    # Accept plain sequences as well as tensors.
    scale = torch.as_tensor(scale)

    unique = torch.unique(mask)
    print(f'found {len(unique)-1} objects to skeletonize.')
    print(f'object IDs: {unique.tolist()}')

    x, y, z = mask.shape

    # Only an all-ones scale is the identity. Comparing the sum against 3
    # would wrongly treat e.g. (0.5, 0.5, 2) as "no rescaling needed".
    if not bool(torch.all(scale == 1)):
        print(scale, x, y, z)
        large_mask = (
            F.interpolate(
                mask.unsqueeze(0).unsqueeze(0).float(),
                size=torch.tensor([x, y, z]).mul(scale).float().round().int().tolist(),
                mode="nearest",
            )
            # Drop only the batch and channel axes; a bare squeeze() would
            # also remove spatial axes of size 1.
            .squeeze(0)
            .squeeze(0)
            .int()
        )
    else:
        large_mask = mask.clone()

    # Resampling must neither lose nor invent object ids.
    assert set(unique.tolist()) == set(
        large_mask.unique().tolist()
    ), "Downscaled too much!"

    output = {}
    for label in tqdm(torch.unique(large_mask)):
        if label == 0:
            continue

        object_mask = large_mask == label
        nonzero = torch.nonzero(object_mask)
        lower = nonzero.min(0)[0]
        # max() is the last occupied index; slicing needs an exclusive end,
        # otherwise the final plane along every axis is cut off.
        upper = nonzero.max(0)[0] + 1

        # Get just the region of a particular instance of the binary image...
        crop = object_mask[
            lower[0].item() : upper[0].item(),  # x
            lower[1].item() : upper[1].item(),  # y
            lower[2].item() : upper[2].item(),  # z
        ].float()

        # Calculate the binary skeleton of that image...
        skeleton = torch.from_numpy(skeletonize(crop.cpu().numpy(), method="lee"))
        skeleton_voxels = torch.nonzero(skeleton)

        # Position of the crop origin, expressed in input-mask coordinates.
        offset = lower.cpu().div(scale)

        key = int(label)
        if skeleton_voxels.shape[0] != 0:
            output[key] = skeleton_voxels.div(scale).add(offset)
        else:
            # Skeletonization removed everything: fall back to the centroid.
            centroid = torch.nonzero(crop.cpu()).float().mean(0)
            output[key] = centroid.div(scale).add(offset).unsqueeze(0)

        assert (
            output[key].shape[0] > 0 and output[key].ndim > 1
        ), f"{crop.nonzero().shape=}, {lower=} {label}, {output[key].shape}"

    return output
def _calculate_skeletons(mask: Tensor, scale: Tensor) -> Dict[int, Tensor]:
    """
    Skeletonize every object of an instance mask with kimimaro.

    The vertices of each kimimaro skeleton are divided by ``scale``, rounded
    and cast to integer indices.

    :param mask: image of shape [X, Y, Z] with int masks
    :param scale: forwarded to kimimaro as ``anisotropy`` and used to divide
        the returned vertices
        (NOTE(review): presumably kimimaro reports vertices in anisotropy-scaled
        units, which this division undoes - confirm against the kimimaro docs)
    :return: dict[int, Tensor] keyed the same way as kimimaro's result
    """
    raw_skeletons = kimimaro.skeletonize(
        mask.numpy(),
        anisotropy=scale,
        progress=True,
        dust_threshold=0,
        fill_holes=False,
        fix_avocados=False,
        parallel=0,
    )

    voxel_skeletons = {}
    for label, skeleton in raw_skeletons.items():
        vertices = torch.from_numpy(skeleton.vertices)
        voxel_skeletons[label] = vertices.div(torch.tensor(scale)).round().long()

    return voxel_skeletons
def create_gt_skeletons(base_dir, mask_filter, scale: Tuple[float, float, float]):
    """
    Skeletonize one mask file, or every matching mask file in a directory.

    For each mask file the result of ``calculate_skeletons`` is written next
    to it with ``torch.save`` as ``<file>.skeletons.trch``.

    :param base_dir: a directory to search for ``*{mask_filter}.tif`` files;
        anything that is not a directory is treated as the path of a single
        mask file
    :param mask_filter: filename suffix (before ``.tif``) used when globbing a
        directory; unused when ``base_dir`` is a single file
    :param scale: per-axis scale factors, converted to a tensor and handed to
        ``calculate_skeletons``
    :return: None
    :raises AssertionError: if a non-zero label of a mask has no skeleton
    """
    if not os.path.isdir(base_dir):
        files = [base_dir]
        print(f'skeletonizing: {base_dir}')
    else:
        pattern = os.path.join(base_dir, f"*{mask_filter}.tif")
        files = glob.glob(pattern)
        print(
            f"found the following files in dir: {base_dir} with mask_filer: {mask_filter}:\n{files}"
        )

    scale = torch.tensor(scale)

    for path in files:
        print("attempting to skeletonize ", path)

        # The image is read with the first axis moved to the end before
        # skeletonizing (calculate_skeletons unpacks the shape as x, y, z).
        volume = torch.from_numpy(io.imread(path).astype(np.int32))
        mask = volume.permute((1, 2, 0))
        print('loaded image: ', path)

        output = calculate_skeletons(mask, scale)

        # Every labelled object must have received a skeleton.
        for label in mask.unique():
            if label == 0:
                continue
            assert int(label) in output, f"{path}, {label}, {output.keys()=}"

        destination = path + ".skeletons.trch"
        torch.save(output, destination)
        print("SAVED", destination)
if __name__ == "__main__":
    # Manual smoke test: skeletonize one labelled volume and write the
    # skeleton voxels out as a binary image for visual inspection.
    #
    # The move -> Calculate the skeleton of each instance and save it as a
    # tensor of nonzero indices. For each pixel in the instance mask, we now
    # need to calculate which skeleton point we need to point to. To do this,
    # we take the index of each point, find the skeleton point closest to it,
    # and replace it.

    image = io.imread(
        "/home/chris/Dropbox (Partners HealthCare)/skoots-experiments/data/mitochondria/train/external/kasthuri.labels.tif"
    )
    # Move the first axis of the file to the end, as calculate_skeletons
    # unpacks the shape as (x, y, z).
    image = torch.from_numpy(image.transpose((1, 2, 0)))

    # calculate_skeletons calls tensor methods on scale, so pass a tensor.
    skels = calculate_skeletons(image, torch.tensor((1, 1, 7)))

    skeletons = torch.zeros_like(image)
    print(image.shape, skeletons.shape)

    for skel in skels.values():
        print(skel.shape)
        # The returned coordinates are already in the index space of `image`
        # (calculate_skeletons divides by scale itself), so they must not be
        # divided by the z scale again. Truncating keeps every index in
        # bounds, whereas rounding could yield an index one past the end.
        skeletons[skel[:, 0].long(), skel[:, 1].long(), skel[:, 2].long()] = 1

    # permute() returns the tensor; .numpy() has to be called on that result,
    # not on the tuple of axes.
    io.imsave(
        "/home/chris/Desktop/skeletons_test.tif",
        skeletons.permute((2, 0, 1)).numpy(),
    )

    # Previous batch workflow, kept for reference:
    #
    # bases = [
    #     # '/home/chris/Dropbox (Partners HealthCare)/trainMitochondriaSegmentation/data/train/',
    #     "/home/chris/Dropbox (Partners HealthCare)/trainMitochondriaSegmentation/data/toBeSkeletonized/",
    # ]
    # if not torch.cuda.is_available():
    #     raise RuntimeError("NO CUDA")
    #
    # # Z_SCALE = 2
    # SCALE = torch.tensor([0.2, 0.2, 3])
    # # SCALE = torch.tensor([1,1,1])
    #
    # bases = glob.glob(bases[0] + "*.labels.tif")
    # bases = [b[:-11:] for b in bases]
    #
    # for base in bases:
    #     mask = io.imread(base + ".labels.tif")
    #     mask = torch.from_numpy(mask.astype(np.int32))
    #
    #     mask = mask.permute((1, 2, 0))
    #     x, y, z = mask.shape
    #
    #     output = calculate_skeletons(mask, SCALE)
    #
    #     torch.save(output, base + ".skeletons.trch")
    #     print("SAVED", base + ".skeletons.trch")
    #
    #     save_train_test_split(mask, output, 150, base)