Source code for STL_main.STL_utils

import torch


class Gaussianise:
    """
    Rank-based Gaussianisation with an invertible mapping via interpolation.

    Method:

    - Fit on a reference tensor ``x_ref``.
    - Build the empirical CDF ``F_X`` from ``sorted(x_ref)``.
    - Forward: for any ``x``, approximate ``u = F_X(x)`` by linear
      interpolation, then ``z = Phi^{-1}(u)`` (standard normal).
    - Inverse: for any ``z``, ``u = Phi(z)``, then approximate
      ``x = F_X^{-1}(u)`` by linear interpolation between sorted samples.

    Inside the range covered by the reference data the mapping is piecewise
    linear, hence differentiable w.r.t. ``z`` in the inverse and w.r.t. ``x``
    in the forward direction. Outside that range both directions are constant
    (clamped to the reference extremes), so the gradient is zero there.

    Attributes
    ----------
    device : torch.device
        Device of the reference data; inputs are moved to it.
    dtype : torch.dtype
        Floating dtype of the reference data; inputs are cast to it.
    eps : float
        Clamp margin applied to probabilities before ``Phi^{-1}``.
    x_sorted : torch.Tensor
        Flattened reference values sorted in ascending order, shape ``(N,)``.
    u_sorted : torch.Tensor
        Empirical CDF positions ``(r + 0.5) / N`` for ranks ``r = 0..N-1``,
        shape ``(N,)``.
    """

    # sqrt(2) as a plain Python float: avoids allocating a tensor on every
    # call to _phi / _phi_inv and adapts to the dtype of the tensor operand.
    _SQRT2 = 2.0 ** 0.5

    def __init__(self, x_ref: torch.Tensor, eps: float = 1e-6):
        """
        Parameters
        ----------
        x_ref : torch.Tensor
            Reference data used to define the mapping. Arbitrary shape;
            non-tensor input is converted with ``torch.as_tensor`` and
            non-floating input is cast to float32.
        eps : float
            Small value used to clamp probabilities away from 0 and 1.

        Raises
        ------
        ValueError
            If ``x_ref`` contains no elements (no CDF can be built).
        """
        if not torch.is_tensor(x_ref):
            x_ref = torch.as_tensor(x_ref)
        if not x_ref.is_floating_point():
            x_ref = x_ref.float()
        if x_ref.numel() == 0:
            # Fail early: an empty reference would only surface later as an
            # opaque indexing error inside forward() / invert().
            raise ValueError("x_ref must contain at least one element")

        self.device = x_ref.device
        self.dtype = x_ref.dtype
        self.eps = eps

        # Flatten and sort the reference values.
        x_sorted, _ = torch.sort(x_ref.reshape(-1))  # (N,)
        n_ref = x_sorted.numel()

        # Empirical CDF positions: (r + 0.5) / N for ranks r = 0..N-1.
        ranks = torch.arange(n_ref, device=self.device, dtype=self.dtype)
        u_sorted = (ranks + 0.5) / float(n_ref)  # (N,)

        # Store the mapping x_sorted <-> u_sorted.
        self.x_sorted = x_sorted
        self.u_sorted = u_sorted

    # ----- basic normal CDF and inverse CDF using erf/erfinv -----

    def _phi_inv(self, u: torch.Tensor) -> torch.Tensor:
        """
        Inverse CDF of the standard normal using erfinv.

        Φ^{-1}(u) = sqrt(2) * erfinv(2u - 1)

        ``u`` is clamped to ``[eps, 1 - eps]`` first so the result is finite.
        """
        u = u.clamp(self.eps, 1.0 - self.eps)
        return self._SQRT2 * torch.erfinv(2.0 * u - 1.0)

    def _phi(self, z: torch.Tensor) -> torch.Tensor:
        """
        CDF of the standard normal using erf.

        Φ(z) = 0.5 * [1 + erf(z / sqrt(2))]
        """
        return 0.5 * (1.0 + torch.erf(z / self._SQRT2))

    # ----- helper: 1D monotone interpolation -----

    def _interp_1d(
        self, xq: torch.Tensor, x: torch.Tensor, y: torch.Tensor
    ) -> torch.Tensor:
        """
        1D linear interpolation of ``y`` over sorted ``x`` at queries ``xq``.

        Parameters
        ----------
        xq : torch.Tensor
            Query points, arbitrary shape.
        x : torch.Tensor
            Knot positions, shape ``(N,)``, sorted in ascending order.
        y : torch.Tensor
            Knot values, shape ``(N,)``.

        Returns
        -------
        torch.Tensor
            Interpolated values with the same shape as ``xq``. Queries below
            ``x[0]`` return ``y[0]`` and queries at or above ``x[-1]`` return
            ``y[-1]`` (clamping, no extrapolation).
        """
        n_knots = x.numel()
        xq_flat = xq.reshape(-1)  # (M,)

        # Insertion position in [0, N]: index of the first knot strictly
        # greater than the query.
        pos = torch.searchsorted(x, xq_flat, right=True)  # (M,)

        # Derive BOTH neighbours from the unclamped position. Out-of-range
        # queries then get idx_lo == idx_hi (first or last knot), which
        # clamps to the endpoint value on both sides instead of
        # extrapolating past the last knot.
        idx_hi = pos.clamp(max=n_knots - 1)
        idx_lo = (pos - 1).clamp(min=0)

        x_lo = x[idx_lo]
        x_hi = x[idx_hi]
        y_lo = y[idx_lo]
        y_hi = y[idx_hi]

        # An in-range bracket always satisfies x_hi > xq >= x_lo, so the
        # width is only non-positive for the degenerate out-of-range
        # brackets; substitute 1 there to avoid 0/0 (y_hi - y_lo is 0 anyway).
        width = x_hi - x_lo
        width = torch.where(width > 0, width, torch.ones_like(width))

        # Clamp the weight so infinite / far-out queries cannot produce
        # inf * 0 = NaN; for in-range queries t is already in [0, 1).
        t = ((xq_flat - x_lo) / width).clamp(0.0, 1.0)
        yq_flat = y_lo + t * (y_hi - y_lo)

        return yq_flat.reshape(xq.shape)

    # ----- forward: data -> Gaussian -----

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Gaussianise input tensor x based on the reference distribution.

        Steps:

        - For each value, approximate ``u = F_X(x)`` via interpolation in
          ``(x_sorted, u_sorted)``.
        - Then ``z = Φ^{-1}(u)``.

        Parameters
        ----------
        x : torch.Tensor
            Data tensor (any shape). Moved/cast to the reference device and
            dtype.

        Returns
        -------
        z : torch.Tensor
            Gaussianised tensor with same shape as x.
        """
        if not torch.is_tensor(x):
            x = torch.as_tensor(x, device=self.device, dtype=self.dtype)
        else:
            x = x.to(device=self.device, dtype=self.dtype)

        # Approximate empirical CDF of x using reference x_sorted -> u_sorted.
        u = self._interp_1d(x, self.x_sorted, self.u_sorted)

        # Map to Gaussian.
        return self._phi_inv(u)

    __call__ = forward  # so you can do g(x) directly

    # ----- inverse: Gaussian -> data -----

    def invert(self, z: torch.Tensor) -> torch.Tensor:
        """
        Map Gaussianised tensor z back to the original data space using the
        inverse empirical CDF.

        Steps:

        - ``u = Φ(z)``
        - ``x = F_X^{-1}(u)`` via interpolation in ``(u_sorted, x_sorted)``.

        Parameters
        ----------
        z : torch.Tensor
            Gaussian-space tensor (any shape). Moved/cast to the reference
            device and dtype.

        Returns
        -------
        x_rec : torch.Tensor
            Reconstructed data tensor in the original amplitude space; values
            always lie within ``[x_sorted[0], x_sorted[-1]]``.
        """
        if not torch.is_tensor(z):
            z = torch.as_tensor(z, device=self.device, dtype=self.dtype)
        else:
            z = z.to(device=self.device, dtype=self.dtype)

        # Map to uniform using the normal CDF.
        u = self._phi(z)

        # Approximate inverse empirical CDF using u_sorted -> x_sorted.
        return self._interp_1d(u, self.u_sorted, self.x_sorted)