Source code for skoots.lib.vector_to_embedding

from typing import List
from typing import Tuple

import torch
from torch import Tensor

"""
2D and 3D implementations of vector to embedding. 

vector is predicted by a network and is an array of values ranging from -1 to 1. These values can be scaled by
"scale" and are used to project pixels from their location, to a new location in embedding space. 

Functionally, a trained network will push pixels to the center of an object, forming clusters which can be used to infer
object locations. 

"""


@torch.jit.script
def get_vector_mesh(shape: Tuple[int, int, int, int, int], device: str) -> Tensor:
    """
    Builds a 3D coordinate mesh matching the spatial size of a vector field.

    Only the last three entries of ``shape`` (the spatial sizes) are read.
    Channel ``d`` of the result holds, at every voxel, that voxel's index
    along spatial axis ``d``.

    :param shape: Shape of the vector field, read as [B, C, X, Y, Z]
    :param device: Device on which the mesh is created
    :return: Coordinate mesh of shape [1, 3, X, Y, Z]
    """
    x_coords = torch.linspace(0, shape[2] - 1, shape[2], device=device)
    y_coords = torch.linspace(0, shape[3] - 1, shape[3], device=device)
    z_coords = torch.linspace(0, shape[4] - 1, shape[4], device=device)

    grids: List[Tensor] = torch.meshgrid([x_coords, y_coords, z_coords], indexing="ij")

    # Stack the per-axis grids into a channel axis, then add a leading batch axis.
    return torch.stack(grids, dim=0).unsqueeze(0)
[docs] @torch.jit.script def _vec2embed3D_graphable(static_scale: Tensor, vector: Tensor, static_mesh) -> Tensor: """ 3D vector to embedding which uses static inputs for cuda graphs Could be a faster way to do this with strides but idk... :param scale: Tensor with shape (3) :param vector: [B, C, X, Y, Z] :return: embedding vector """ return static_mesh + vector.mul(static_scale.view((1, 3, 1, 1, 1)))
[docs] @torch.jit.script def _vec2embed2D(scale: Tensor, vector: Tensor) -> Tensor: """ 2D vector to embedding :param scale: The offest in XY of the vectors. :param vector: [B, C=2, X, Y] the vector matrix predicted by the unet :return: Pixel Spatial Embeddings (i.e. vector + pixel_indicies) """ num: Tensor = torch.clone(scale.float()) newshape: Tuple[int, int, int, int] = (1, 2, 1, 1) axis_ind: List[Tensor] = [ torch.linspace(0, vector.shape[2] - 1, vector.shape[2], device=vector.device), torch.linspace(0, vector.shape[3] - 1, vector.shape[3], device=vector.device), ] mesh = torch.meshgrid(axis_ind, indexing="ij") mesh: List[Tensor] = [m.unsqueeze(0).unsqueeze(0) for m in mesh] mesh: Tensor = torch.cat(mesh, dim=1) vector = vector.mul(num.view(newshape)) return mesh + vector
[docs] @torch.jit.script def _vec2embed3D(scale: Tensor, vector: Tensor, n: int = 1, decay: float = 1.0) -> Tensor: """ 2D or 3D vector to embedding Could be a faster way to do this with strides but idk... :param scale: [N=2/3] :param vector: [B, C, X, Y, Z?] :param n: number of times to apply vectors :return: """ num: Tensor = torch.clone(scale.float()) newshape: Tuple[int, int, int, int, int] = (1, 3, 1, 1, 1) axis_ind: List[Tensor] = [ torch.linspace(0, vector.shape[2] - 1, vector.shape[2], device=vector.device), torch.linspace(0, vector.shape[3] - 1, vector.shape[3], device=vector.device), torch.linspace(0, vector.shape[4] - 1, vector.shape[4], device=vector.device), ] mesh: List[Tensor] = torch.meshgrid(axis_ind, indexing="ij") mesh: List[Tensor] = [m.unsqueeze(0).unsqueeze(0) for m in mesh] mesh: Tensor = torch.cat(mesh, dim=1) scaled_vector = vector.mul(num.view(newshape)) mesh = mesh + scaled_vector scale = 1.0 for _ in range(n - 1): # Only executes if n > 1 # convert to index. scale *= decay scaled_vector = vector.mul(scale * num.view(newshape)) index = mesh.round() b, c, x, y, z = index.shape for i, k in enumerate([x, y, z]): index[:, i, ...] = torch.clamp(index[:, i, ...], 0, k) # 3d index to raveled index = ( (index[:, [0], ...] * y * z) + (index[:, [1], ...] * z) + (index[:, [2], ...]) ) index = index.clamp(0, x * y * z - 1).long() for i in range(c): mesh[:, [i], ...] = mesh[:, [i], ...] + scaled_vector[:, [i], ...].take(index) return mesh
def vector_to_embedding(scale: Tensor, vector: Tensor, N: int = 1, decay: float = 1.0) -> Tensor:
    r"""
    Converts a 2D or 3D vector field to a spatial embedding by adding the vector at any position to its own position.

    ``vector`` is a 2D or 3D vector field of shape :math:`(B, 2, X, Y)` for 2D or :math:`(B, 3, X, Y, Z)`
    for 3D. Each vector :math:`v` lies within the range -1 and 1 and is scaled by scale :math:`s`.
    The scaled vector is then added to its own position to form a spatial embedding :math:`\phi`.

    Formally:

    .. math::

        i,j,k \in \mathbb{Z}_{\geq 0}

        v_{i,j,k} \in [-1, 1]

        s = [s_i, s_j, s_k]

        \phi_{i,j,k} = v_{i,j,k} * s + [i, j, k]

    Shapes:
        - scale: :math:`(2)` or :math:`(3)`
        - vector: :math:`(B_{in}, 2, X_{in}, Y_{in})` or :math:`(B_{in}, 3, X_{in}, Y_{in}, Z_{in})`
        - Returns: :math:`(B_{in}, 2, X_{in}, Y_{in})` or :math:`(B_{in}, 3, X_{in}, Y_{in}, Z_{in})`

    :param scale: Scaling factors for each vector spatial dimension
    :param vector: Vector field predicted by a neural network
    :param N: Number of iterations to apply the vectors. Must be 1 for 4D (2D spatial) input.
    :param decay: vector strength decay after each iteration. Default 1.0. Must be 1.0 for 4D input.
    :raises ValueError: if ``vector`` is neither a 4D nor a 5D tensor
    :raises AssertionError: if ``N`` or ``decay`` is changed from its default for a 4D tensor
    :return: Pixel spatial embeddings
    """
    # Reject unsupported ranks up front; otherwise any non-5D input would be
    # routed to the 2D implementation and fail with an unrelated error.
    if vector.ndim not in (4, 5):
        raise ValueError(
            f"Vector must be a 4D or 5D tensor, not {vector.ndim}D: {vector.shape=}"
        )

    if vector.ndim == 4:
        # Iterated application and decay are only implemented for the 3D path.
        # Kept as assertions so existing callers see the same exception type.
        assert decay == 1.0, 'decay parameter only valid for 5D tensor'
        assert N == 1, 'N must be equal to 1 for 4D tensors.'
        return _vec2embed2D(scale, vector)

    return _vec2embed3D(scale, vector, N, decay)
def vec2embedND(scale, vector):
    """
    Generic N-dimensional vector to embedding.

    Scales each vector channel by the matching entry of ``scale`` and adds the
    result to the index of the position it sits on, for any number of spatial
    dimensions.

    :param scale: 1D tensor with one entry per spatial dimension
    :param vector: [B, C, ...] vector field where C equals the number of spatial dimensions
    :return: Spatial embedding with the same shape as ``vector``
    """
    n_spatial = vector.ndim - 2

    assert (
        scale.shape[0] == vector.shape[1]
    ), f"Cannot use {scale.shape[0]}D scale with vector shape: {vector.shape}"
    assert (
        scale.shape[0] == vector.ndim - 2
    ), f"Cannot use {scale.shape[0]}D scale with {vector.ndim - 2}D vector shape [B, C, ...]: {vector.shape}"

    # Shape (1, C, 1, ..., 1) so the scale broadcasts over the channel axis only.
    broadcast_shape = (1, scale.shape[0]) + (1,) * n_spatial
    per_channel_scale = torch.clone(scale.float()).view(broadcast_shape)

    # One coordinate vector per spatial axis.
    coords = [
        torch.linspace(0, size - 1, size, device=vector.device)
        for size in vector.shape[2:]
    ]
    grids = torch.meshgrid(coords, indexing="ij")

    # Give each grid leading batch and channel axes, then join along channels.
    position_index = torch.cat([g[None, None] for g in grids], dim=1)

    return position_index + vector.mul(per_channel_scale)
if __name__ == "__main__":
    # Smoke test: every voxel points +1 along each axis, except (5, 5, 5) which
    # points -1 (towards (4, 4, 4)) and (4, 4, 4) which points +2.
    vector = torch.ones((1, 3, 10, 10, 10)).float()
    vector[:, :, 5, 5, 5] = -1
    vector[:, :, 4, 4, 4] = 2.0

    unit_scale = torch.tensor((1, 1, 1))
    out = vector_to_embedding(unit_scale, vector, N=2)

    # Two applications: (5, 5, 5) -> (4, 4, 4) -> (6, 6, 6).
    print(f"{out[0, :, 5, 5, 5]=}")  # should equal 6 6 6