"""Source code for skoots.train.dataloader."""

from __future__ import annotations
import glob
import os.path
from typing import Dict
from typing import Tuple, Callable, List, Union, Optional, Any
import math

import numpy as np
import skimage.io as io
import torch
from torch import Tensor
from torch.utils.data import Dataset
import logging

# from skoots.train.merged_transform import get_centroids
from skoots.lib.custom_types import DataDict
import torch.jit
from numba import njit, prange


@njit(parallel=True, fastmath=True)
def _sub_sq_sum(x: np.ndarray, other):
    """
    Subtract ``other`` from every element of ``x``, square each difference, and sum them.

    :param x: np.ndarray of any shape; it is flattened before the reduction
    :param other: scalar subtracted from every element
    :return: 1-element float64 np.ndarray holding the sum of squared differences
    """
    flat = x.flatten()
    acc = np.array([0.0], dtype=np.float64)
    # prange parallelizes the loop; ``acc += ...`` is treated by numba as a reduction
    for idx in prange(flat.shape[0]):
        diff = float(flat[idx]) - other
        acc += diff ** 2
    return acc
# A data-augmentation callable: takes a dict of named tensors and returns a (new) dict of named tensors.
_TensorDict = Dict[str, Tensor]
Transform = Callable[[_TensorDict], _TensorDict]
class dataset(Dataset):
    def __init__(
        self,
        path: Union[List[str], str],
        transforms: Optional[Transform] = lambda x: x,
        pad_size: Optional[int] = 100,
        device: Optional[str] = "cpu",
        sample_per_image: Optional[int] = 1,
    ):
        r"""
        Custom dataset for loading and accessing skoots training data.

        This class loads data based on filenames and specific extensions:
        '.tif' (raw image), '.labels.tif' (instance masks), '.skeletons.trch' (precomputed skeletons).

        An example training data folder might contain the following:

        ::

            data\
             └ train\
                 │ train_data.tif
                 │ train_data.labels.tif
                 └ train_data.skeletons.trch

        :param path: Path (or list of paths) to training data directories
        :param transforms: A function which applies dataset augmentation on a data_dict.
            ``None`` is treated as the identity transform.
        :param pad_size: padding to add to every image in the dataset
        :param device: torch.device which to **output** all data on
        :param sample_per_image: number of times each image/mask pair is sampled per iteration over a dataset
        :raises FileNotFoundError: if the image or skeleton file paired with a ``*.labels.tif`` is missing
        :raises ValueError: if any skeleton entry is empty
        """
        super().__init__()

        # Reassigning variables
        self.path = path
        self.files: List[str] = []
        self.image: List[Tensor] = []
        self.centroids: List[Tensor] = []
        self.masks: List[Tensor] = []
        self.skeletons: List[Dict[int, Tensor]] = []
        self.baked_skeleton: List[Tensor | None] = []
        # ``None`` would crash in __getitem__; treat it as "no augmentation"
        self.transforms: Callable[[DataDict], DataDict] = (
            transforms if transforms is not None else (lambda x: x)
        )
        self.device = device
        self.pad_size: List[int] = [pad_size, pad_size]

        # cached statistics; invalidated whenever map() touches "image"
        self.__mean: Dict[str, Any] | None = None
        self.__std: Dict[str, Any] | None = None
        self.__numel: Dict[str, Any] | None = None
        self.__sum: Dict[str, Any] | None = None

        self.sample_per_image: int = sample_per_image

        # Store all possible directories containing data in a list
        path: List[str] = [path] if isinstance(path, str) else path
        logging.debug(f"Attempting to load files from: {path}")
        for p in path:
            self.files.extend(glob.glob(f"{p}{os.sep}*.labels.tif"))

        logging.debug(f"found the following files to load:")
        for f in self.files:
            logging.debug(f"----> {f}")

        for f in self.files:
            base = f[: -len(".labels.tif")]  # strip the label suffix once, reuse for siblings

            image_path = base + ".tif"
            if not os.path.exists(image_path):
                raise FileNotFoundError(
                    f"Could not find file: {image_path[:-4:]} with extensions .tif"
                )

            skeleton_path = base + ".skeletons.trch"
            if not os.path.exists(skeleton_path):
                raise FileNotFoundError(f"cannot find skeleton file for: {f}")
            # NOTE(review): torch.load unpickles the file - only load skeletons from trusted sources.
            skeleton = torch.load(skeleton_path)

            logging.info(f"Loading Image: {image_path}")
            image: np.ndarray = io.imread(image_path)
            masks: np.ndarray = io.imread(f)  # [Z, X, Y]

            assert image.dtype == np.uint8, f"image must be 8bit"

            # add a trailing channel axis to 3D volumes, then move it first:
            # assumes [Z, X, Y, C] on disk -> [C, X, Y, Z] in memory
            image = image[..., np.newaxis] if image.ndim == 3 else image
            image = image.transpose(-1, 1, 2, 0)
            # with more than 3 channels only channel index 2 is kept
            image = image[[2], ...] if image.shape[0] > 3 else image

            masks = masks.transpose(1, 2, 0)  # [Z, X, Y] -> [X, Y, Z]

            # store masks in the smallest integer dtype that holds every label
            mask_max = masks.max()
            if mask_max < 256:
                logging.info(f"saving mask at {f} as dtype: uint8")
                masks = masks.astype(np.uint8)
            elif mask_max < (2**16 // 2) - 1:
                logging.info(f"saving mask at {f} as dtype: int16")
                masks = masks.astype(np.int16)
            else:
                logging.info(f"saving mask at {f} as dtype: int32")
                masks = masks.astype(np.int32)

            # every skeleton entry must contain at least one point
            for k, v in skeleton.items():
                if v.numel() == 0:
                    raise ValueError(f"{f} instance label {k} has {v.numel()=}")

            # Images stay uint8 here; conversion to float is left to the transforms
            # (the original notes that torch AMP is used downstream).
            self.image.append(torch.from_numpy(image))
            self.masks.append(torch.from_numpy(masks).unsqueeze(0))
            self.skeletons.append(skeleton)
            self.baked_skeleton.append(None)

        logging.info(f"done loading from source: {path}")

    def __len__(self) -> int:
        return len(self.image) * self.sample_per_image

    def __getitem__(self, item: int) -> DataDict:
        # We might artificially want to sample more times per image.
        # Useful when training on very large images containing a lot of data.
        item = item // self.sample_per_image

        with torch.no_grad():
            data_dict: DataDict = {
                "image": self.image[item],
                "masks": self.masks[item],
                "skeletons": self.skeletons[item],
                "baked_skeleton": self.baked_skeleton[item],
            }
            # Transformation pipeline
            data_dict = self.transforms(data_dict)

        # Send every tensor (and every tensor inside a nested dict) to the output device
        for k in data_dict:
            if isinstance(data_dict[k], torch.Tensor):
                data_dict[k] = data_dict[k].to(self.device)
            elif isinstance(data_dict[k], dict):
                data_dict[k] = {
                    key: value.to(self.device) for (key, value) in data_dict[k].items()
                }

        return data_dict

    def to(self, device: str) -> dataset:
        """
        Sends all data stored in the dataloader to a device.

        :param device: torch device for images, masks, and skeletons
        :return: self
        """
        self.image = [x.to(device) for x in self.image]
        self.masks = [x.to(device) for x in self.masks]
        self.skeletons = [
            {k: v.to(device) for (k, v) in x.items()} for x in self.skeletons
        ]
        return self

    def cuda(self) -> dataset:
        """alias for self.to('cuda:0')"""
        self.to("cuda:0")
        return self

    def cpu(self) -> dataset:
        """alias for self.to('cpu') - does not require CUDA to be available"""
        self.to("cpu")
        return self

    def pin_memory(self) -> dataset:
        """
        Pins underlying memory allowing faster transfer to GPU
        """
        self.image = [x.pin_memory() for x in self.image]
        self.masks = [x.pin_memory() for x in self.masks]
        self.skeletons = [
            {k: v.pin_memory() for (k, v) in x.items()} for x in self.skeletons
        ]
        return self

    def _reset_stat_cache(self) -> None:
        """Drop cached sum/numel/mean/std so they are recomputed from the current images."""
        self.__mean = None
        self.__std = None
        self.__numel = None
        self.__sum = None

    def map(self, fn, key: List[str] | str) -> dataset:
        """
        applies a fn to an internal datastructure (in place), provided by key.

        valid keys: ['image', 'masks', 'skeletons']

        :param fn: callable applied to every element of the selected structure
        :param key: a key or list of keys
        :return: self
        """
        _valid_keys = ["image", "masks", "skeletons"]
        keys: List[str] = [key] if isinstance(key, str) else list(key)
        for k in keys:
            assert (
                k in _valid_keys
            ), f"key: {k} is invalid. Valid keys are: {_valid_keys}"

        for k in keys:
            if k == "image":
                self.image = [fn(im) for im in self.image]
            elif k == "masks":
                self.masks = [fn(im) for im in self.masks]
            elif k == "skeletons":
                self.skeletons = [fn(im) for im in self.skeletons]

        # cached statistics are derived from the images, so they are now stale
        if "image" in keys:
            self._reset_stat_cache()

        return self

    def sum(self, with_invert: bool = False) -> int:
        """
        Sum of every pixel in the dataset (cached).

        :param with_invert: also add the sum of the inverted (255 - x) images
        :return: the sum (a torch scalar, or 0 for an empty dataset)
        """
        if self.__sum is None or self.__sum["with_invert"] != with_invert:
            logging.info(
                f"calculating sum of dataset from: {self.path} | {with_invert=}"
            )
            total = 0
            for x in self.image:
                x_sum = x.cpu().sum()
                total += x_sum
                if with_invert:
                    # sum(255 - x) == 255 * numel - sum(x); avoids uint8 wrap-around tricks
                    total += 255 * x.numel() - x_sum
            self.__sum = {"sum": total, "with_invert": with_invert}
        return self.__sum["sum"]

    def numel(self, with_invert: bool = False) -> int:
        """
        Number of pixels in the dataset (cached); doubled when ``with_invert``.
        """
        if self.__numel is None or self.__numel["with_invert"] != with_invert:
            logging.info(
                f"calculating numel of dataset from: {self.path} | {with_invert=}"
            )
            numel = sum([x.numel() for x in self.image]) if self.image else 0
            numel = numel * 2 if with_invert else numel
            self.__numel = {"numel": numel, "with_invert": with_invert}
        return self.__numel["numel"]

    def mean(self, with_invert: bool = False) -> float | None:
        """
        Mean pixel value of the dataset (cached).

        :return: the mean, or None when the dataset has no pixels
        """
        if self.__mean is None or self.__mean["with_invert"] != with_invert:
            logging.info(
                f"calculating mean of dataset from: {self.path} | {with_invert=}"
            )
            n = self.numel(with_invert=with_invert)
            if n == 0:
                return None
            self.__mean = {
                "mean": self.sum(with_invert=with_invert) / n,
                "with_invert": with_invert,
            }
        return self.__mean["mean"]

    def std(self, with_invert: bool = False) -> float | None:
        """
        Population standard deviation of the pixel values (cached).

        :return: the std, or None when the dataset has no pixels
        """
        if self.__std is None or self.__std["with_invert"] != with_invert:
            logging.info(
                f"calculating std of dataset from: {self.path} | {with_invert=}"
            )
            mean = self.mean(with_invert=with_invert)
            if mean is None:
                return None
            n = self.numel(with_invert=with_invert)
            # subtract_square_sum already returns sum((x - mean)^2): do not square again.
            # float(): the numba kernel cannot accept a torch scalar.
            numerator = self.subtract_square_sum(float(mean), with_invert=with_invert)
            self.__std = {"std": math.sqrt(numerator / n), "with_invert": with_invert}
        return self.__std["std"]

    def subtract_square_sum(self, other, with_invert: bool = False) -> float:
        """
        returns sum((px - other)^2) over every pixel in the dataset

        :param other: scalar (python number or 0-d tensor) subtracted from every pixel
        :param with_invert: also include the inverted (255 - px) pixels
        :return: the sum as a python float
        """
        logging.info(
            f"performing substract_square_sum on dataset {self.path} | {other=}"
        )
        other = float(other)
        with torch.no_grad():
            total = 0.0
            for x in self.image:
                arr = x.cpu().numpy()
                total += float(np.sum(_sub_sq_sum(arr, other)))
                if with_invert:
                    # (255 - px - m)^2 == (px - (255 - m))^2: reuse the kernel on the raw pixels
                    total += float(np.sum(_sub_sq_sum(arr, 255.0 - other)))
        return total
class BackgroundDataset(Dataset):
    def __init__(
        self,
        path: Union[List[str], str],
        transforms: Optional[Transform] = lambda x: x,
        device: Optional[str] = "cpu",
        sample_per_image: Optional[int] = 1,
    ):
        r"""
        Custom dataset for loading and accessing skoots background training data.

        Unlike skoots.train.dataloader.dataset, which looks for masks and skeletons, this dataset needs
        only images, given that the images do not contain any actual instances of the thing you're
        trying to segment - i.e. it is background. Every ``*.tif`` in the given folders is loaded,
        except companion ``*.labels.tif`` / ``*.skeletons.tif`` volumes.

        An example training data folder might contain the following:

        ::

            data\
             └ background\
                 └ background_image.tif

        :param path: Path (or list of paths) to background data directories
        :param transforms: A function which applies background_dataset augmentation on a data_dict.
            ``None`` is treated as the identity transform.
        :param device: torch.device which to **output** all data on
        :param sample_per_image: number of times each image is sampled per iteration over a dataset
        """
        super().__init__()

        # Reassigning variables
        self.files: List[str] = []
        self.image: List[Tensor] = []
        self.transforms = transforms if transforms is not None else (lambda x: x)
        self.device = device
        self.sample_per_image = sample_per_image

        path: List[str] = [path] if isinstance(path, str) else path
        for p in path:
            for candidate in glob.glob(f"{p}{os.sep}*.tif"):
                # label / skeleton volumes are not background images
                if candidate.endswith((".labels.tif", ".skeletons.tif")):
                    continue
                self.files.append(candidate)

        for f in self.files:
            image: np.ndarray = io.imread(f)  # [Z, X, Y, C]

            # add a trailing channel axis to 3D volumes, then reorder to [C, X, Y, Z]
            image = image[..., np.newaxis] if image.ndim == 3 else image
            image = image.transpose(-1, 1, 2, 0)
            # with more than 3 channels only channel index 2 is kept
            image = image[[2], ...] if image.shape[0] > 3 else image

            assert image.max() < 256, "16bit images not supported"
            self.image.append(torch.from_numpy(image.astype(np.uint8)))

    def __len__(self) -> int:
        return len(self.image) * self.sample_per_image

    def __getitem__(self, item: int) -> Dict[str, Tensor]:
        # We might artificially want to sample more times per image.
        # Useful when training on very large images containing a lot of data.
        item = item // self.sample_per_image

        with torch.no_grad():
            # background has no instances: placeholder masks / skeletons
            data_dict = {
                "image": self.image[item],
                "masks": torch.empty((1)),
                "skeletons": {-1: torch.empty((1))},
                "baked_skeleton": None,
            }
            # Transformation pipeline
            data_dict = self.transforms(data_dict)

        for k in data_dict:
            if isinstance(data_dict[k], torch.Tensor):
                data_dict[k] = data_dict[k].to(self.device, non_blocking=True)
            elif isinstance(data_dict[k], dict):
                data_dict[k] = {
                    key: value.to(self.device, non_blocking=True)
                    for (key, value) in data_dict[k].items()
                }

        return data_dict

    def to(self, device: str) -> BackgroundDataset:
        """
        Sends all images stored in the dataloader to a device.

        This dataset stores only images (no masks or skeletons).

        :param device: torch device for images
        :return: self
        """
        self.image = [x.to(device) for x in self.image]
        return self

    def cuda(self) -> BackgroundDataset:
        """alias for self.to('cuda:0')"""
        return self.to("cuda:0")

    def cpu(self) -> BackgroundDataset:
        """alias for self.to('cpu')"""
        return self.to("cpu")

    def map(self, fn, key: List[str] | str) -> BackgroundDataset:
        """
        applies a fn to an internal datastructure (in place), provided by key.

        valid keys: ['image']

        :return: self
        """
        _valid_keys = ["image"]
        keys: List[str] = [key] if isinstance(key, str) else list(key)
        for k in keys:
            assert (
                k in _valid_keys
            ), f"key: {k} is invalid. Valid keys are: {_valid_keys}"

        if "image" in keys:
            self.image = [fn(im) for im in self.image]

        return self

    def sum(self, with_invert: bool = False):
        """Sum of every pixel; with_invert also adds the sum of (255 - x)."""
        total = 0
        for x in self.image:
            x_sum = x.sum()
            total += x_sum
            if with_invert:
                total += 255 * x.numel() - x_sum
        return total

    def numel(self, with_invert: bool = False):
        """Number of pixels; doubled when with_invert."""
        numel = sum([x.numel() for x in self.image]) if self.image else 0
        return numel * 2 if with_invert else numel

    def mean(self, with_invert: bool = False):
        """Mean pixel value, or None if the dataset holds no images."""
        logging.debug(f"Calculating dataset mean for {self}")
        if self.image:
            return self.sum(with_invert=with_invert) / self.numel(with_invert=with_invert)
        else:
            return None

    def std(self, with_invert: bool = False):
        """Population std of the pixel values, or None if the dataset holds no images."""
        logging.debug(f"Calculating dataset std for {self}")
        mean = self.mean(with_invert=with_invert)
        if mean is None:
            return None
        n = self.numel(with_invert=with_invert)
        # subtract_square_sum already returns sum((x - mean)^2): do not square again
        numerator = self.subtract_square_sum(mean, with_invert=with_invert)
        return math.sqrt(float(numerator) / n)

    def subtract_square_sum(self, other, with_invert: bool = False):
        """
        returns sum((px - other)^2) over the entire dataset

        :param other: scalar subtracted from every pixel
        :param with_invert: also include the inverted (255 - px) pixels
        :return: the sum (a float64 torch scalar, or 0 for an empty dataset)
        """
        total = 0
        for x in self.image:
            x64 = x.to(torch.float64)
            total += x64.sub(other).pow(2).sum()
            if with_invert:
                # (255 - px - m)^2 == (px - (255 - m))^2
                total += x64.sub(255 - other).pow(2).sum()
        return total
class MultiDataset(Dataset):
    def __init__(self, *args):
        r"""
        A utility class for joining multiple datasets into one accessible class.

        Sometimes, you may subdivide your training data based on some criteria.
        The most common is size: data from folder data/train/train_alot must be sampled 100 times
        per epoch, while data from folder data/train/train_notsomuch might only want to be sampled 1 times per epoch.

        You could construct a skoots.train.dataloader.dataset object for each and access both in a single
        MultiDataset class...

        >>> from skoots.train.dataloader import dataset
        >>>
        >>> # has one image sampled 100 times
        >>> data0 = dataset('data/train/train_alot', sample_per_image=100)
        >>> print(len(data0)) # 100
        >>>
        >>> # has one image sampled once
        >>> data1 = dataset('data/train/train_notsomuch', sample_per_image=1)
        >>> print(len(data1)) # 1
        >>>
        >>> merged_data = MultiDataset(data0, data1)
        >>> print(len(merged_data)) # 101, they've been merged!

        :param args: datasets to merge; arguments that are not a torch Dataset are ignored
        """
        self.datasets: List[dataset] = [ds for ds in args if isinstance(ds, Dataset)]

        self._dataset_lengths = [len(ds) for ds in self.datasets]
        self.num_datasets = len(self.datasets)

        # first merged index belonging to each wrapped dataset (precomputed once)
        self._offsets: List[int] = []
        running = 0
        for length in self._dataset_lengths:
            self._offsets.append(running)
            running += length

        # merged index -> index of the wrapped dataset that owns it
        self._mapped_indicies: List[int] = []
        for i, length in enumerate(self._dataset_lengths):
            self._mapped_indicies.extend([i] * length)

    def __len__(self) -> int:
        return len(self._mapped_indicies)

    def __getitem__(self, item: int) -> DataDict:
        i = self._mapped_indicies[item]  # Get the ind for the dataset
        _offset = self._offsets[i]  # Ind offset
        try:
            return self.datasets[i][item - _offset]
        except Exception:
            logging.error(
                "MultiDataset lookup failed: dataset=%s offset=%s local_index=%s "
                "item=%s dataset_len=%s",
                i,
                _offset,
                item - _offset,
                item,
                len(self.datasets[i]),
            )
            raise

    def to(self, device: str) -> MultiDataset:
        """
        Sends all data stored in the dataloader to a device. Occurs for ALL wrapped datasets.

        :param device: torch device for images, masks, and skeletons
        :return: self
        """
        for i in range(self.num_datasets):
            self.datasets[i].to(device)
        return self

    def cuda(self) -> MultiDataset:
        """alias for self.to('cuda:0')"""
        return self.to("cuda:0")

    def cpu(self) -> MultiDataset:
        """alias for self.to('cpu')"""
        return self.to("cpu")

    def map(self, fn, key) -> MultiDataset:
        """Apply ``fn`` to ``key`` of every wrapped dataset (each occurs in place)."""
        for i in range(self.num_datasets):
            self.datasets[i].map(fn, key)
        return self

    def sum(self, with_invert: bool = False):
        """Sum over all wrapped datasets; None if the total is zero/falsy."""
        logging.debug(f"calculating dataset sum for {self} | {with_invert=}")
        total = 0
        for d in self.datasets:
            _sum = d.sum(with_invert=with_invert)
            total = total + _sum if _sum is not None else total
        if total:
            return total
        else:
            return None

    def numel(self, with_invert: bool = False):
        """Pixel count over all wrapped datasets; None if zero."""
        logging.debug(f"calculating dataset numel for {self} | {with_invert=}")
        total = 0
        for d in self.datasets:
            _numel = d.numel(with_invert=with_invert)
            total = total + _numel if _numel is not None else total
        if total:
            return total
        else:
            return None

    def mean(self, with_invert: bool = False):
        """Mean pixel value over all wrapped datasets, or None."""
        logging.debug(f"calculating dataset mean for {self} | {with_invert=}")
        _sum = self.sum(with_invert=with_invert)
        _numel = self.numel(with_invert=with_invert)
        if _sum and _numel:
            return _sum / _numel
        else:
            return None

    @staticmethod
    def _to_float(value) -> float:
        """Collapse a python number, numpy scalar/1-element array or 0-d tensor into a float."""
        if hasattr(value, "sum"):
            value = value.sum()
        return float(value)

    def std(self, with_invert: bool = False):
        """Population std over all wrapped datasets, or None when the mean is unavailable."""
        logging.debug(f"calculating dataset std for {self} | {with_invert=}")
        mean = self.mean(with_invert=with_invert)
        n = self.numel(with_invert=with_invert)
        # check for None *before* converting: float(None) raises TypeError
        if mean is None or not n:
            return None
        mean = float(mean)

        numerator = 0.0
        for d in self.datasets:
            numerator += self._to_float(d.subtract_square_sum(mean))
            if with_invert:
                # inverted pixels y = 255 - x: (y - m)^2 == (x - (255 - m))^2
                numerator += self._to_float(d.subtract_square_sum(255.0 - mean))
        return math.sqrt(numerator / n)
# Custom batching function!
def skeleton_colate(
    data_dict: List[Dict[str, Tensor]]
) -> Tuple[Tensor, Tensor, List[Dict[str, Tensor]], Tensor, Tensor]:
    """
    Collate function which defines how training samples are batched.

    Pops the keys 'image', 'masks', 'skele_masks', 'baked_skeleton' and 'skeletons' from every
    sample (the input dicts are emptied of those keys). 'image', 'masks' and 'skele_masks' are
    stacked along a new leading batch dimension. 'baked_skeleton' is stacked only when the first
    sample's value is not None; otherwise the list of values is returned unchanged. 'skeletons'
    is returned as a plain list.

    This should not be called outright, rather passed to a torch.DataLoader for automatic batching.

    :param data_dict: list of augmented training samples
    :return: (images, masks, skeletons, skele_masks, baked_skeleton)
    """

    def _take(name: str) -> list:
        # pop the same key from every sample, preserving batch order
        return [sample.pop(name) for sample in data_dict]

    images = torch.stack(_take("image"), dim=0)
    masks = torch.stack(_take("masks"), dim=0)
    skele_masks = torch.stack(_take("skele_masks"), dim=0)

    baked_list = _take("baked_skeleton")
    baked = baked_list if baked_list[0] is None else torch.stack(baked_list, dim=0)

    skeletons = _take("skeletons")

    return images, masks, skeletons, skele_masks, baked
if __name__ == "__main__":
    # Manual smoke test: load a training dataset and print mask statistics.
    # Usage: python dataloader.py [DATA_DIR]   (falls back to the author's local path)
    import sys

    from skoots.train.merged_transform import merged_transform_3D

    data_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/home/chris/Dropbox (Partners HealthCare)/skoots-experiments/data/mitochondria/train/hide"
    )

    data = dataset(
        path=data_path,
        transforms=merged_transform_3D,
    )

    for m in data.masks:
        print(m.max(), m.shape)

    for i in range(len(data)):
        # index once: each __getitem__ call re-runs the (random) transform pipeline
        sample = data[i]
        print(sample["masks"].max(), sample["masks"].shape)