Detailed Training Protocol#
The training process is typically invoked from the command line via the skoots-train command. This calls into the main function in file skoots.train.__main__.py. This function parses all command line arguments, loads the config file and model, initializes PyTorch DistributedDataParallel, and finally calls the train() function from skoots.train.engine.py. To understand how we train SKOOTS, we will go through that function in detail. Throughout the training script, you will see references to a variable cfg which stores the user's configuration data.
Imports#
We must first import each package necessary for training. SKOOTS tries to take a functional approach to training. It is not exactly in line with functional programming best practices, but it saves you from descending into a hell of inheritance.
import os
import os.path
from functools import partial
from statistics import mean
from typing import Callable, Union, Dict
import torch
import torch.nn as nn
import torch.optim.lr_scheduler
import torch.optim.swa_utils
from torch import Tensor
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import trange
from yacs.config import CfgNode
import skoots.train.loss
from skoots.lib.embedding_to_prob import baked_embed_to_prob
from skoots.lib.vector_to_embedding import vector_to_embedding
from skoots.train.dataloader import dataset, MultiDataset, skeleton_colate
from skoots.train.merged_transform import (
transform_from_cfg,
background_transform_from_cfg,
)
from skoots.train.setup import setup_process
from skoots.train.sigma import Sigma, init_sigma
from skoots.train.utils import write_progress
Setup DistributedDataParallel#
We need to define 3 mandatory inputs: rank, port, and world_size. Starting in reverse, world_size is the total number of devices to run distributed training on. If you have two GPUs in one machine, then your world size would be 2. port is the port of a local web server through which distributed training is coordinated. rank is the process number. So for a world_size of 2, we would get two processes, one with rank=0 and one with rank=1. World size is set in the configuration file with cfg.SYSTEM.NUM_GPUS. This function should be called through PyTorch multiprocessing. See skoots.train.__main__.py.
# Invoked from skoots.train.__main__.py
def train(
    rank: str,
    port: str,
    world_size: int,
    base_model: nn.Module,
    cfg: CfgNode,
) -> None:
    """Train a SKOOTS model inside one distributed worker process.

    This stub only fixes the signature; the body is assembled step by step in
    the sections that follow.

    Args:
        rank: Number of this worker process. It is forwarded to
            ``setup_process`` and selects the CUDA device (``cuda:{rank}``).
        port: Port forwarded to ``setup_process``.
        world_size: Total number of devices taking part in training.
        base_model: The model to train, before it is moved to the device and
            wrapped for distributed training.
        cfg: The user's configuration.
    """
    pass
From here we set up the required processes for torch DistributedDataParallel and compile the model using torch inductor (if available). This lets us use multiple GPUs for training, as well as just-in-time compiled CUDA kernels for accelerated training.
# Set up this worker for distributed training using the "nccl" backend.
setup_process(rank, world_size, port, backend="nccl")

# Each worker uses the CUDA device whose index equals its rank.
device = f"cuda:{rank}"
base_model = base_model.to(device)
base_model = torch.nn.parallel.DistributedDataParallel(base_model)

# Only the rank-0 worker echoes the configuration.
if int(rank) == 0:
    print(cfg)

# Parse the whole major-version component rather than just the first
# character of the version string, so that a two-digit major version (e.g.
# "10.1") is not mistaken for 1.x.
torch_major_version = int(torch.__version__.split(".")[0])
if torch_major_version >= 2:
    print("Comiled with Inductor")
    model = torch.compile(base_model)
else:
    # Older releases have no torch.compile; fall back to TorchScript.
    model = torch.jit.script(base_model)
Data Loading and Augmentation#
Data augmentation parameters are set by the configuration file and executed as a single function from skoots.train.merged_transform.py. This is to reduce the overhead of chaining multiple augmentation classes together, which some augmentation libraries like to do. There is a separate set of transformations for background data, as this does not need to process masks or skeletons.
# Both transforms have the same call shape: a data_dict goes in and a
# data_dict comes out.
_DataDictTransform = Callable[[Dict[str, Tensor]], Dict[str, Tensor]]

# cfg and device are bound once here, so later callers only pass the data_dict.
augmentations: _DataDictTransform = partial(
    transform_from_cfg,
    cfg=cfg,
    device=device,
)
background_agumentations: _DataDictTransform = partial(
    background_transform_from_cfg,
    cfg=cfg,
    device=device,
)
This function takes in a data_dict, which is simply a Python dictionary containing the image, masks, and skeletons. Next, we load our data using the dataset class from skoots.train.dataloader.py. This dataset class looks for multiple sets of three files in a single folder with a common prefix and the extensions: *.tif (the image), *.label.tif (the masks), and *.skeletons.trch (the precomputed skeletons). Training data often consists of one really large file, too large to fit in a neural network. Therefore, the usual notion of an epoch doesn’t make sense. Instead, SKOOTS defines an epoch as a set number of samples from each image in a dataset. This number might change for different images (you don’t want to sample a small image 30 times), and therefore SKOOTS enables the user to split their datasets up into multiple folders and define a sample rate for each.
This is set in the config by specifying a list of potential data locations: _C.TRAIN.TRAIN_DATA_DIR = [data_loc_1, data_loc_2, ...]. For each data location, we let the user define the number of samples which defines an epoch. This is reflected in code here:
_datasets = []  # one dataset per configured training folder

train_folders = zip(cfg.TRAIN.TRAIN_DATA_DIR, cfg.TRAIN.TRAIN_SAMPLE_PER_IMAGE)
for path, N in train_folders:
    # Where the loaded data is kept: on the GPU only if the config asks for it.
    _device = device if cfg.TRAIN.STORE_DATA_ON_GPU else "cpu"

    folder_dataset = dataset(
        path=path,  # folder holding the data
        transforms=augmentations,  # augmentation function
        sample_per_image=N,  # how many samples each image contributes
        device=device,  # device passed through to the dataset
        pad_size=10,  # padding added to each image
    )
    folder_dataset = folder_dataset.pin_memory()
    folder_dataset = folder_dataset.to(_device)
    _datasets.append(folder_dataset)

# Present every folder's dataset through a single object.
merged_train = MultiDataset(*_datasets)
train_sampler = torch.utils.data.distributed.DistributedSampler(merged_train)

_n_workers = 0  # if _device != 'cpu' else 2

# The DataLoader takes care of batching and sampling.
dataloader = DataLoader(
    merged_train,
    batch_size=cfg.TRAIN.TRAIN_BATCH_SIZE,
    sampler=train_sampler,
    num_workers=_n_workers,
    collate_fn=skeleton_colate,
)
We do the same for validation and background datasets.
# Background folders are appended to the existing `_datasets` list (it is not
# re-created here), and the merged dataset, sampler and loader are rebuilt
# from that list.
background_folders = zip(
    cfg.TRAIN.BACKGROUND_DATA_DIR, cfg.TRAIN.BACKGROUND_SAMPLE_PER_IMAGE
)
for path, N in background_folders:
    _device = device if cfg.TRAIN.STORE_DATA_ON_GPU else "cpu"

    folder_dataset = dataset(
        path=path,
        transforms=background_agumentations,  # background-only transform
        sample_per_image=N,
        device=device,
        pad_size=100,
    )
    folder_dataset = folder_dataset.pin_memory()
    folder_dataset = folder_dataset.to(_device)
    _datasets.append(folder_dataset)

merged_train = MultiDataset(*_datasets)
train_sampler = torch.utils.data.distributed.DistributedSampler(merged_train)

_n_workers = 0  # if _device != 'cpu' else 2

dataloader = DataLoader(
    merged_train,
    batch_size=cfg.TRAIN.TRAIN_BATCH_SIZE,
    sampler=train_sampler,
    num_workers=_n_workers,
    collate_fn=skeleton_colate,
)
# Validation Dataset
_datasets = []
for path, N in zip(
    cfg.TRAIN.VALIDATION_DATA_DIR, cfg.TRAIN.VALIDATION_SAMPLE_PER_IMAGE
):
    # Keep the data on the GPU only if the config asks for it.
    _device = device if cfg.TRAIN.STORE_DATA_ON_GPU else "cpu"
    _datasets.append(
        dataset(
            path=path,
            transforms=augmentations,
            sample_per_image=N,
            device=device,
            pad_size=10,
        )
        .pin_memory()
        .to(_device)
    )

merged_validation = MultiDataset(*_datasets)
test_sampler = torch.utils.data.distributed.DistributedSampler(merged_validation)

# Validation needs BOTH some data AND a usable batch size. With `or`, an
# empty dataset list still produced a loader, and a batch size of 0 combined
# with real data was handed to DataLoader, which rejects it.
if _datasets and cfg.TRAIN.VALIDATION_BATCH_SIZE >= 1:
    _n_workers = 0  # if _device != 'cpu' else 2
    valdiation_dataloader = DataLoader(
        merged_validation,
        num_workers=_n_workers,
        batch_size=cfg.TRAIN.VALIDATION_BATCH_SIZE,
        sampler=test_sampler,
        collate_fn=skeleton_colate,
    )
else:  # we might not want to run validation...
    valdiation_dataloader = None
Optimizers, Schedulers, Loss#
We set optimizers, learning rate schedulers, and loss functions through the config file. The constructors for each come from a list of dictionaries at the top of skoots.train.engine.py:
# Optimizer constructors, looked up by the name given in cfg.TRAIN.OPTIMIZER.
_valid_optimizers = dict(
    adamw=torch.optim.AdamW,
    adam=torch.optim.Adam,
    sgd=torch.optim.SGD,
    adamax=torch.optim.Adamax,
)
# Loss constructors, looked up by the names given in cfg.TRAIN.LOSS_EMBED,
# cfg.TRAIN.LOSS_PROBABILITY and cfg.TRAIN.LOSS_SKELETON.
_valid_loss_functions = dict(
    soft_cldice=skoots.train.loss.soft_dice_cldice,
    tversky=skoots.train.loss.tversky,
)
# Learning-rate scheduler constructors, looked up by cfg.TRAIN.SCHEDULER.
_valid_lr_schedulers = dict(
    cosine_annealing_warm_restarts=(
        torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
    ),
)
Within the training script, we get the constructor for each from these valid options, and call into it with other arguments set by the config file. We can set keyword arguments and values for the loss functions via the configuration as well. This is helpful when using tversky loss with different penalties for foreground and background.
# Resolve each constructor by its configured name, then build it.
optimizer_cls = _valid_optimizers[cfg.TRAIN.OPTIMIZER]
optimizer = optimizer_cls(
    model.parameters(),
    lr=cfg.TRAIN.LEARNING_RATE,
    weight_decay=cfg.TRAIN.WEIGHT_DECAY,
)

scheduler_cls = _valid_lr_schedulers[cfg.TRAIN.SCHEDULER]
scheduler = scheduler_cls(optimizer, T_0=cfg.TRAIN.SCHEDULER_T0)

# Gradient scaler; switched on and off together with mixed precision.
scaler = GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)

# Stochastic weight averaging helpers.
swa_model = torch.optim.swa_utils.AveragedModel(model)
swa_start = 100
swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)

# Each loss gets its keyword arguments from two parallel config lists
# (names and values) that are paired up position by position.
_kwarg = dict(zip(cfg.TRAIN.LOSS_EMBED_KEYWORDS, cfg.TRAIN.LOSS_EMBED_VALUES))
loss_embed: Callable = _valid_loss_functions[cfg.TRAIN.LOSS_EMBED](**_kwarg)

_kwarg = dict(
    zip(cfg.TRAIN.LOSS_PROBABILITY_KEYWORDS, cfg.TRAIN.LOSS_PROBABILITY_VALUES)
)
loss_prob: Callable = _valid_loss_functions[cfg.TRAIN.LOSS_PROBABILITY](**_kwarg)

_kwarg = dict(
    zip(cfg.TRAIN.LOSS_SKELETON_KEYWORDS, cfg.TRAIN.LOSS_SKELETON_VALUES)
)
loss_skele: Callable = _valid_loss_functions[cfg.TRAIN.LOSS_SKELETON](**_kwarg)
Sigma#
To evaluate embedding accuracy, SKOOTS defines a distance penalty variable called sigma. This is implemented in its own class: skoots.train.sigma.py. The parameters for this are set in the config file, and the class is constructed with the helper function skoots.train.sigma.init_sigma(). This penalty decays over multiple epochs and is called like a function:
# Build the sigma object from the configuration and the target device.
sigma: Sigma = init_sigma(cfg, device)
# Usage example only: the object is called with an epoch number and the result
# is discarded here.
_ = sigma(100) # what's the sigma at epoch 100?
Vector Scaling#
Our model will ultimately output a set of vectors from -1 to 1. This must be scaled to fit the maximum radius of any object you wish to segment. That is set here.
# Scaling factor(s) from cfg.SKOOTS.VECTOR_SCALING, placed on the training device.
vector_scale = torch.tensor(cfg.SKOOTS.VECTOR_SCALING, device=device)
Before final training we also set/initialize a couple of other things:
# these settings trade some torch checks for training speed
torch.backends.cudnn.benchmark = cfg.TRAIN.CUDNN_BENCHMARK

# NOTE: torch.autograd.profiler.profile and torch.autograd.profiler.emit_nvtx
# are context managers; they only act inside a `with` block. Never assign to
# `torch.autograd.profiler.profile` (that replaces the profiler class with the
# assigned value), and a bare `emit_nvtx(enabled=...)` call only builds an
# object that is thrown away. Neither statement influenced training, so both
# are left out. To honour cfg.TRAIN.AUTOGRAD_PROFILE or
# cfg.TRAIN.AUTOGRAD_EMIT_NVTX, wrap the training loop in the corresponding
# context manager.
torch.autograd.set_detect_anomaly(cfg.TRAIN.AUTOGRAD_DETECT_ANOMALY)

# we use tensorboard for logging; only the rank-0 process owns a writer.
# int() matches the rank check used when the config is printed, so the test
# works whether rank arrives as an int or as a numeric string.
writer = SummaryWriter() if int(rank) == 0 else None
if writer:
    print("SUMMARY WRITER LOG DIR: ", writer.get_logdir())

# Save each loss value in a list... the first entry is only a placeholder and
# is disregarded ;)
_placeholder_loss = 9999999999.9999999999
avg_epoch_loss = [_placeholder_loss]
avg_epoch_embed_loss = [_placeholder_loss]
avg_epoch_prob_loss = [_placeholder_loss]
avg_epoch_skele_loss = [_placeholder_loss]

avg_val_loss = [_placeholder_loss]
avg_val_embed_loss = [_placeholder_loss]
avg_val_prob_loss = [_placeholder_loss]
avg_val_skele_loss = [_placeholder_loss]
Calling the DataLoader and a Simple Training Iteration#
The DataLoader acts like an iterable which returns 5 pieces of information: the image, the labeled mask, the skeleton dictionary, the skeleton masks, and the “baked” skeleton. For more reference on what these are, see the Training section. We use each of these to perform a training step. First the image is passed through the model:
# assume current epoch is set here:
current_epoch = 0
# Each item from the dataloader unpacks into five values.
for images, masks, skeleton, skele_masks, baked in dataloader:
# Forward pass: only the images are given to the model.
out: Tensor = model(images)
The out tensor is a 5 channel tensor which contains the semantic probability map, the embedding vectors, and the skeleton map. We can separate these here:
# Split the model output along dimension 1.
# Last channel; indexing with a list keeps that dimension.
probability_map: Tensor = out[:, [-1], ...]
# Channels 0, 1 and 2.
vector: Tensor = out[:, 0:3:1, ...]
# Second-to-last channel; indexing with a list keeps that dimension.
predicted_skeleton: Tensor = out[:, [-2], ...]
To calculate a loss, we need a skeleton embedding. To calculate the skeleton embedding we need the vectors, the vector scale, and the function vector_to_embedding from skoots.lib.vector_to_embedding.py:
# Turn the predicted vectors into an embedding; note the scale is passed first.
embedding: Tensor = vector_to_embedding(vector_scale, vector)
Once we have an embedding, we need a way to calculate a loss value. We do this by generating a probability score for each pixel based on how close the embedding is to its “true” destination. This true destination is its closest skeleton, and is contained in the baked skeleton tensor. To calculate this probability we call the function baked_embed_to_prob from skoots.lib.embedding_to_prob.py.
# `out` is rebound here: it no longer holds the raw model output but the
# result of baked_embed_to_prob, computed with sigma for the current epoch.
out: Tensor = baked_embed_to_prob(embedding, baked, sigma(current_epoch))
This probability map is just a tensor from 0-1. It’s essentially a semantic map, and therefore we can use the tversky loss with the semantic map to generate a single loss value.
# masks.gt(0).float() turns the mask into a 0/1 float target: every value
# greater than zero becomes 1.0.
_loss_embed = loss_embed(out, masks.gt(0).float())
The predicted skeletons and probability map have targets generated by the dataloader, and therefore we simply generate a loss using a similar method.
# Both targets are binarised the same way: values greater than zero become 1.0.
_loss_prob = loss_prob(probability_map, masks.gt(0).float())
_loss_skeleton = loss_skele(
predicted_skeleton, skele_masks.gt(0).float()
) # + skel_crossover_loss(predicted_skeleton, skele_masks.gt(0).float())
Finally, we let the user define the relative weight each loss value has on the overall training and the epoch at which each should first be considered. This is defined in the configuration file and represented in code here.
# A loss term is switched on (gate = 1) only once the current epoch is
# strictly greater than its configured start epoch; otherwise its gate is 0.
embed_gate = 1 if current_epoch > cfg.TRAIN.LOSS_EMBED_START_EPOCH else 0
prob_gate = 1 if current_epoch > cfg.TRAIN.LOSS_PROBABILITY_START_EPOCH else 0
skeleton_gate = 1 if current_epoch > cfg.TRAIN.LOSS_SKELETON_START_EPOCH else 0

# Each term is: relative weight * gate * raw loss value.
weighted_embed = cfg.TRAIN.LOSS_EMBED_RELATIVE_WEIGHT * embed_gate * _loss_embed
weighted_prob = cfg.TRAIN.LOSS_PROBABILITY_RELATIVE_WEIGHT * prob_gate * _loss_prob
weighted_skeleton = (
    cfg.TRAIN.LOSS_SKELETON_RELATIVE_WEIGHT * skeleton_gate * _loss_skeleton
)

loss = weighted_embed + weighted_prob + weighted_skeleton
Now we scale the loss (if using mixed precision) and run backpropagation.
# Backpropagate through the scaled loss.
scaler.scale(loss).backward()
# The optimizer step is taken through the scaler rather than called directly.
scaler.step(optimizer)
# Let the scaler update its state for the next iteration.
scaler.update()
Warmup#
We found that overtraining a randomly initialized model helps that model learn the task on new data down the line. We repeat all of the steps above, but using only a single batch of the training data.
# Warmup... Get the first batch from train_data.
# Only one batch is needed, so take it directly instead of running through
# the whole dataloader, and fail with a clear message if there is none.
first_batch = next(iter(dataloader), None)
if first_batch is None:
    raise RuntimeError(
        f"Warmup needs at least one training batch, but the dataloader "
        f"yielded none (len(dataloader)={len(dataloader)})."
    )
images, masks, skeleton, skele_masks, baked = first_batch
if images is None:
    raise RuntimeError("Warmup batch contains no images.")

# The same batch is reused for every warmup iteration.
warmup_range = trange(cfg.TRAIN.N_WARMUP, desc="Warmup: {}")
for w in warmup_range:
    optimizer.zero_grad(set_to_none=True)

    with autocast(enabled=cfg.TRAIN.MIXED_PRECISION):  # Saves Memory!
        out: Tensor = model(images)

        # Split the model output along dimension 1.
        probability_map: Tensor = out[:, [-1], ...]
        vector: Tensor = out[:, 0:3:1, ...]
        predicted_skeleton: Tensor = out[:, [-2], ...]

        embedding: Tensor = vector_to_embedding(vector_scale, vector)
        # Warmup always uses the sigma of epoch 0.
        out: Tensor = baked_embed_to_prob(embedding, baked, sigma(0))

        _loss_embed = loss_embed(
            out, masks.gt(0).float()
        )  # out = [B, 2/3, X, Y, Z?]
        _loss_prob = loss_prob(probability_map, masks.gt(0).float())
        _loss_skeleton = loss_skele(
            predicted_skeleton, skele_masks.gt(0).float()
        )  # + skel_crossover_loss(predicted_skeleton, skele_masks.gt(0).float())

        # Unlike the main loop, warmup applies no start-epoch gating: every
        # term contributes with its relative weight.
        loss = (
            (cfg.TRAIN.LOSS_EMBED_RELATIVE_WEIGHT * _loss_embed)
            + (cfg.TRAIN.LOSS_PROBABILITY_RELATIVE_WEIGHT * _loss_prob)
            + (cfg.TRAIN.LOSS_SKELETON_RELATIVE_WEIGHT * _loss_skeleton)
        )

    warmup_range.desc = f"{loss.item()}"

    # A NaN loss is reported and the optimizer step is skipped.
    if torch.isnan(loss):
        print(
            f"Found NaN value in loss.\n\tLoss Embed: {_loss_embed}\n\tLoss Probability: {_loss_prob}"
        )
        print(f"\t{torch.any(torch.isnan(vector))}")
        print(f"\t{torch.any(torch.isnan(embedding))}")
        continue

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
Main Training Loop#
We can now train our entire model. This simply takes the previous method, but applies it over multiple images in our dataset, multiple times. The only difference here is we do some logging to tensorboard.
# Train Step...
# Only the rank-0 process shows a progress bar, logs and writes checkpoints.
# int() matches the rank check used when the config is printed.
is_main_process = int(rank) == 0

epoch_range = (
    trange(cfg.TRAIN.NUM_EPOCHS, desc=f"Loss = {1.0000000}")
    if is_main_process
    else range(cfg.TRAIN.NUM_EPOCHS)
)

for e in epoch_range:
    # Per-batch loss values collected during this epoch.
    _loss, _embed, _prob, _skele = [], [], [], []

    if cfg.TRAIN.DISTRIBUTED:
        train_sampler.set_epoch(e)

    for images, masks, skeleton, skele_masks, baked in dataloader:
        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=cfg.TRAIN.MIXED_PRECISION):  # Saves Memory!
            out: Tensor = model(images)

            # Split the model output along dimension 1.
            probability_map: Tensor = out[:, [-1], ...]
            vector: Tensor = out[:, 0:3:1, ...]
            predicted_skeleton: Tensor = out[:, [-2], ...]

            embedding: Tensor = vector_to_embedding(vector_scale, vector)
            out: Tensor = baked_embed_to_prob(embedding, baked, sigma(e))

            _loss_embed = loss_embed(
                out, masks.gt(0).float()
            )  # out = [B, 2/3, X, Y, Z?]
            _loss_prob = loss_prob(probability_map, masks.gt(0).float())
            _loss_skeleton = loss_skele(
                predicted_skeleton, skele_masks.gt(0).float()
            )  # + skel_crossover_loss(predicted_skeleton, skele_masks.gt(0).float())

            # Each term is: relative weight * gate * raw loss, where the gate
            # is 1 only once the epoch is past the term's start epoch.
            # NOTE(review): the comparison is a strict ">", so a term with a
            # start epoch of 0 is still switched off during epoch 0 -
            # confirm this is intended.
            loss = (
                (
                    cfg.TRAIN.LOSS_EMBED_RELATIVE_WEIGHT
                    * (1 if e > cfg.TRAIN.LOSS_EMBED_START_EPOCH else 0)
                    * _loss_embed
                )
                + (
                    cfg.TRAIN.LOSS_PROBABILITY_RELATIVE_WEIGHT
                    * (1 if e > cfg.TRAIN.LOSS_PROBABILITY_START_EPOCH else 0)
                    * _loss_prob
                )
                + (
                    cfg.TRAIN.LOSS_SKELETON_RELATIVE_WEIGHT
                    * (1 if e > cfg.TRAIN.LOSS_SKELETON_START_EPOCH else 0)
                    * _loss_skeleton
                )
            )

        # A NaN loss is reported and the batch is skipped entirely.
        if torch.isnan(loss):
            print(
                f"Found NaN value in loss.\n\tLoss Embed: {_loss_embed}\n\tLoss Probability: {_loss_prob}"
            )
            print(f"\t{torch.any(torch.isnan(vector))}")
            print(f"\t{torch.any(torch.isnan(embedding))}")
            continue

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Past swa_start, the averaged model is updated after every batch.
        if e > swa_start:
            swa_model.update_parameters(model)

        _loss.append(loss.item())
        _embed.append(_loss_embed.item())
        _prob.append(_loss_prob.item())
        _skele.append(_loss_skeleton.item())

    avg_epoch_loss.append(mean(_loss))
    avg_epoch_embed_loss.append(mean(_embed))
    avg_epoch_prob_loss.append(mean(_prob))
    avg_epoch_skele_loss.append(mean(_skele))
    scheduler.step()

    if writer and is_main_process:
        writer.add_scalar("lr", scheduler.get_last_lr()[-1], e)
        writer.add_scalar("Loss/train", avg_epoch_loss[-1], e)
        writer.add_scalar("Loss/embed", avg_epoch_embed_loss[-1], e)
        writer.add_scalar("Loss/prob", avg_epoch_prob_loss[-1], e)
        writer.add_scalar("Loss/skele-mask", avg_epoch_skele_loss[-1], e)
        # The tensors passed here are the ones left over from the last batch.
        write_progress(
            writer=writer,
            tag="Train",
            epoch=e,
            images=images,
            masks=masks,
            probability_map=probability_map,
            vector=vector,
            out=out,
            skeleton=skeleton,
            predicted_skeleton=predicted_skeleton,
            gt_skeleton=skele_masks,
        )

    # Validation Step: every 10th epoch, and only if a loader exists.
    if e % 10 == 0 and valdiation_dataloader:
        _loss, _embed, _prob, _skele = [], [], [], []
        for images, masks, skeleton, skele_masks, baked in valdiation_dataloader:
            with autocast(enabled=cfg.TRAIN.MIXED_PRECISION):  # Saves Memory!
                with torch.no_grad():
                    out: Tensor = model(images)

                    probability_map: Tensor = out[:, [-1], ...]
                    predicted_skeleton: Tensor = out[:, [-2], ...]
                    vector: Tensor = out[:, 0:3:1, ...]

                    embedding: Tensor = vector_to_embedding(vector_scale, vector)
                    out: Tensor = baked_embed_to_prob(embedding, baked, sigma(e))

                    _loss_embed = loss_embed(out, masks.gt(0).float())
                    _loss_prob = loss_prob(probability_map, masks.gt(0).float())
                    # Use the skeleton loss here, as the training step does
                    # (this previously called loss_prob by mistake).
                    _loss_skeleton = loss_skele(
                        predicted_skeleton, skele_masks.gt(0).float()
                    )

                    # NOTE(review): validation uses fixed 2/2/1 weights, not
                    # the configured relative weights - confirm intended.
                    loss = (2 * _loss_embed) + (2 * _loss_prob) + _loss_skeleton

            if torch.isnan(loss):
                print(
                    f"Found NaN value in loss.\n\tLoss Embed: {_loss_embed}\n\tLoss Probability: {_loss_prob}"
                )
                print(f"\t{torch.any(torch.isnan(vector))}")
                print(f"\t{torch.any(torch.isnan(embedding))}")
                continue

            # No backward pass happens during validation, so the loss is not
            # passed through the gradient scaler.
            _loss.append(loss.item())
            _embed.append(_loss_embed.item())
            _prob.append(_loss_prob.item())
            _skele.append(_loss_skeleton.item())

        avg_val_loss.append(mean(_loss))
        avg_val_embed_loss.append(mean(_embed))
        avg_val_prob_loss.append(mean(_prob))
        avg_val_skele_loss.append(mean(_skele))

        if writer and is_main_process:
            writer.add_scalar("Validation/train", avg_val_loss[-1], e)
            writer.add_scalar("Validation/embed", avg_val_embed_loss[-1], e)
            writer.add_scalar("Validation/prob", avg_val_prob_loss[-1], e)
            # Logged for parity with the training "Loss/skele-mask" scalar.
            writer.add_scalar("Validation/skele-mask", avg_val_skele_loss[-1], e)
            write_progress(
                writer=writer,
                tag="Validation",
                epoch=e,
                images=images,
                masks=masks,
                probability_map=probability_map,
                vector=vector,
                out=out,
                skeleton=skeleton,
                predicted_skeleton=predicted_skeleton,
                gt_skeleton=skele_masks,
            )

    if is_main_process:
        epoch_range.desc = (
            f"lr={scheduler.get_last_lr()[-1]:.3e}, Loss (train | val): "
            + f"{avg_epoch_loss[-1]:.5f} | {avg_val_loss[-1]:.5f}"
        )

        # Periodic checkpoint every 100 epochs. It is written by the main
        # process only, so several processes never write the same file, and
        # the state dict is only built when it is actually saved.
        if e % 100 == 0:
            state_dict = (
                model.module.state_dict()
                if hasattr(model, "module")
                else model.state_dict()
            )
            torch.save(state_dict, cfg.TRAIN.SAVE_PATH + f"/test_{e}.trch")
Save the model#
Finally, each model trained by this script is saved as a dictionary with the configuration file cfg, model_state_dict, and the optimizer_state_dict, along with the recorded loss histories. It is saved under the same name as the log directory of the SummaryWriter object for tensorboard, linking the two.
# Only the rank-0 process saves. The file name is taken from the tensorboard
# writer's log directory, so a writer must exist as well.
if int(rank) == 0 and writer is not None:
    # Unwrap the model if it exposes the underlying network as `.module`.
    state_dict = (
        model.module.state_dict()
        if hasattr(model, "module")
        else model.state_dict()
    )
    constants = {
        "cfg": cfg,
        "model_state_dict": state_dict,
        "optimizer_state_dict": optimizer.state_dict(),
        "avg_epoch_loss": avg_epoch_loss,
        "avg_epoch_embed_loss": avg_epoch_embed_loss,
        "avg_epoch_prob_loss": avg_epoch_prob_loss,
        "avg_epoch_skele_loss": avg_epoch_skele_loss,
        # The validation keys store the validation histories (they previously
        # pointed at the training lists).
        "avg_val_loss": avg_val_loss,
        "avg_val_embed_loss": avg_val_embed_loss,
        "avg_val_prob_loss": avg_val_prob_loss,
        "avg_val_skele_loss": avg_val_skele_loss,
    }

    # Same base name as the tensorboard run directory, linking the two.
    filename = f"{os.path.split(writer.log_dir)[-1]}.trch"
    primary_path = f"{cfg.TRAIN.SAVE_PATH}/{filename}"
    fallback_path = f"{os.getcwd()}/{filename}"

    try:
        torch.save(constants, primary_path)
    except (OSError, RuntimeError):
        # The configured location could not be written (for example a missing
        # or read-only directory): fall back to the current working directory
        # rather than losing the trained model.
        print(
            f"Could not save at: {primary_path}. "
            f"Saving at {fallback_path} instead"
        )
        torch.save(constants, fallback_path)