Source code for skoots.train.sigma

from typing import Dict, List

import torch
from yacs.config import CfgNode


# DOCUMENTED


[docs] class Sigma: """an adjustable sigma parameter for training.""" def __init__( self, adjustments: List[Dict[str, float]], initial_sigma: List[float] = [0.1, 0.1, 0.8], device="cpu", ): """ Creates an object which inputs an epoch, and returns a torch.tensor of [sigma_x, sigma_y, sigma_z] Sigma should reflect the Error of embedding vectors at each spatial dim. :param adjustments: lost of adjustments and when to apply them :param initial_sigma: initial values of sigma at epoch=0 :param device: device to load sigma on ('cpu' or 'cuda') """ self.adjutments = adjustments self.device = device self.initial_sigma = ( torch.tensor(initial_sigma, device=device) if isinstance(initial_sigma, list) else initial_sigma ) values = [1] # Initial values set so sigma isnt zero at 3 zero epochs = [-1] for d in adjustments: values.append(d["multiplier"]) epochs.append(d["epoch"]) self.values = torch.tensor(values, device=self.device) self.epochs = torch.tensor(epochs, device=self.device) def __call__(self, e: int): """ Returns sigma at epoch "e" :param e: epoch :return: sigma [x,y,z] """ multiplier = self.values[self.epochs < e].prod() return self.initial_sigma * multiplier
[docs] def init_sigma(cfg: CfgNode, device: torch.device) -> Sigma: start = torch.tensor(cfg.TRAIN.INITIAL_SIGMA, device=device) multipliers = [{"multiplier": a, "epoch": b} for a, b in cfg.TRAIN.SIGMA_DECAY] return Sigma(initial_sigma=start, adjustments=multipliers)
if __name__ == "__main__": from skoots.config import get_cfg_defaults cfg = get_cfg_defaults() _ = init_sigma(cfg, "cpu")