Source code for STL_main.STL_2D_FFT_Torch

"""
Created on Wed Nov 14:07 2018
"""

import math
from dataclasses import dataclass, field

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.integrate import quad

import STL_main.torch_backend as bk
from STL_main.Base_DataClass import Base_DataClass
from STL_main.ST_Operator import ST_Operator
from STL_main.torch_backend import (
    _DEFAULT_DEVICE,
    _DEFAULT_DTYPE,
    _get_device,
    _get_dtype,
    maskmean,
    to_torch_tensor,
)


###############################################################################
###############################################################################
@dataclass
class STL_2D_FFT_Torch(Base_DataClass):
    """
    STL_2D_FFT_Torch child class for 2D planar STL data processed with PyTorch FFTs.

    Inherits Base_DataClass. See Base_DataClass for parameter descriptions.

    Additional parameters
    ---------------------
    fourier_status : bool
        Indicates if the data is in Fourier space (True) or real space (False).
    """

    # child class constant
    DT = "Planar2D_FFT_torch"

    # child instance attributes
    fourier_status: bool = False

    def __post_init__(self):
        super().__post_init__()

    ###########################################################################
    def modulus(self, inplace=False):
        """
        Compute the modulus (absolute value) of the array attribute in real space.

        The data is first brought to real space (inverse FFT if needed), then
        ``torch.abs`` is applied and ``dtype`` is updated to the result's dtype.

        Parameters
        ----------
        - inplace : bool
            If True, acts in-place and returns self. If False, returns a new
            STL_2D_FFT_Torch instance.

        Returns
        -------
        STL_2D_FFT_Torch
            Instance (in real space) whose array attribute is the modulus.
        """
        data = self.copy(empty=False) if not inplace else self
        data = data.set_fourier_status(target_fourier_status=False, inplace=True)
        data.array = torch.abs(data.array)
        data.dtype = data.array.dtype
        return data

    ###########################################################################
    def fourier(self, inplace=False):
        """
        Compute the orthonormal 2D FFT (``norm="ortho"``) on the last two dimensions.

        Does nothing to the array if the data is already in Fourier space.

        Parameters
        ----------
        - inplace : bool
            If True, acts in-place and returns self. If False, returns a new
            STL_2D_FFT_Torch instance.

        Returns
        -------
        STL_2D_FFT_Torch
            Instance whose array attribute is in the Fourier domain.
        """
        data = self.copy(empty=False) if not inplace else self
        if data.fourier_status:
            return data
        data.array = torch.fft.fft2(data.array, norm="ortho")
        data.fourier_status = True
        return data

    ###########################################################################
    def ifourier(self, inplace=False):
        """
        Compute the orthonormal inverse 2D FFT (``norm="ortho"``) on the last two
        dimensions.

        Does nothing to the array if the data is already in real space.

        Parameters
        ----------
        - inplace : bool
            If True, acts in-place and returns self. If False, returns a new
            STL_2D_FFT_Torch instance.

        Returns
        -------
        STL_2D_FFT_Torch
            Instance whose array attribute is in real space.
        """
        data = self.copy(empty=False) if not inplace else self
        if not data.fourier_status:
            return data
        data.array = torch.fft.ifft2(data.array, norm="ortho")
        data.fourier_status = False
        return data

    ###########################################################################
    def set_fourier_status(self, target_fourier_status, inplace=False):
        """
        Put the data in the desired Fourier status (target_fourier_status).

        Parameters
        ----------
        - target_fourier_status : bool
            Desired Fourier status: True = Fourier space, False = real space.
        - inplace : bool
            If True, acts in-place and returns self. If False, returns a new
            STL_2D_FFT_Torch instance.

        Returns
        -------
        STL_2D_FFT_Torch
            Instance in the desired Fourier status.
        """
        data = self.copy(empty=False) if not inplace else self
        # Only transform when the current status differs from the desired one.
        if data.fourier_status != target_fourier_status:
            if target_fourier_status:
                data.fourier(inplace=True)
            else:
                data.ifourier(inplace=True)
        return data

    ###########################################################################
    def divide(self, data2, epsilon=1e-8, pow=1.0, inplace=False):
        """
        Divide self.array by (data2.array + epsilon) raised to the power ``pow``.

        - If ``data2`` is in real space: ``self`` is converted to real space
          (if needed), and the division is performed in real space.
        - If ``data2`` is in Fourier space: ``self`` is converted to Fourier space
          (if needed), and the division is performed in Fourier space
          (i.e., deconvolution in real space).

        Parameters
        ----------
        data2 : STL_2D_FFT_Torch
            Instance whose array is used as the denominator. Its Fourier status
            determines the computation domain.
        epsilon : float, optional
            Constant added to the denominator before exponentiation, for numerical
            stability (default 1e-8).
        pow : float, optional
            Exponent applied to the denominator (default 1.0).
        inplace : bool
            If True, performs the operation in-place and returns self. If False,
            returns a new instance.

        Returns
        -------
        STL_2D_FFT_Torch
            Result of the division in the domain of ``data2``.
        """
        # Convert self to the Fourier status of data2.
        data1 = self.set_fourier_status(
            target_fourier_status=data2.fourier_status, inplace=inplace
        )
        data1.array = data1.array / (data2.array + epsilon) ** pow
        data1.dtype = data1.array.dtype
        return data1

    ###########################################################################
    def get_wavelet_op(self, *args, **kwargs):
        """
        Build a WaveletOperator2D_FFT_torch matching this data (N0, DT, device, dtype).

        Extra positional arguments are forwarded after N0, i.e. to
        ``(J, L, WType, ...)``; keyword arguments are forwarded as-is.
        """
        # N0 is passed positionally so that extra positional arguments bind to
        # J, L, ... instead of colliding with the N0 keyword (TypeError before).
        return WaveletOperator2D_FFT_torch(
            self.N0,
            *args,
            DT=self.DT,
            device=self.device,
            dtype=self.dtype,
            **kwargs,
        )

    ###########################################################################
    def get_ST_op(self, *args, **kwargs):
        """Build an ST_Operator using this instance as the data example."""
        return ST_Operator(data_example=self, *args, **kwargs)

    ###########################################################################
    def get_CS_op(self, *args, **kwargs):
        """Build a CS_operator_2D_FFT_torch with shape=N0 and this device/dtype."""
        return CS_operator_2D_FFT_torch(
            shape=self.N0, device=self.device, dtype=self.dtype, *args, **kwargs
        )
[docs] class WaveletOperator2D_FFT_torch: """ Class whose instances correspond to a wavelet transform operator. The wavelet set and the operator is built during the initilization. The operator is applied through apply method. This method is DT-dependent, and actually calls independent iterations, but with common method and attribute structure. The multi-resolution is dealt with several parameters: dg_max, which indicates the maximum dg resolution j_to_dg, which indicate the dg_j resolution associated to each j scale. For example, if you work with J=6 wavelets and N0=256, you can have: - dg_max=4, with associated dg factors (0, 1, 2, 3, 4) - true downsampling factors (1, 2, 4, 8, 16) - actual resolutions (256, 128, 64, 32, 16) - a j_to_dg list (0, 0, 1, 1, 2, 3) associated to j in range(J=6) The wavelet convolution is DT-dependent, an is performed either in real or Fourier space. They are two main types of wavelet arrays. - Single_Kernel==True: In this case, a set of L oriented wavelets is defined at a single pixellation. Convolutions at different scales are then down by subsequent susampling and convolution in pixel space with this single set of L oriented wavelets, and convolutions at all scales can not be done at the initial N0 resolution. - SR_Kernel==False: In this case, a set of J*L dilated and rotated wavelets is defined at the initial N0 resolution, and a convolution at all scales at this initial resolution can be performed. Convolution in a multi-resolution scheme can also be done, where convolution at each scale is done at the proper downsampling factor. This mean that different wavelets need to be stored: - Single_Kernel==True: a single set of L wavelets, stored in the wavelet_array attribute. - SR_Kernel==False: a set of J*L wavelets, store both at the initial N0 resolution in wavelet_array, and in a multi-resolution framework in wavelet_array_MR. => The wavelet_j method then allows to call for the correct quantity when a convolution is performed. 
Rq: when Single_Kernel==True, j_to_dg = range(J). This is not necessarily the case when False. See apply for more details. Parameters ---------- - DT : str Type of data (1d, 2d planar, HealPix, 3d) - N0 : tuple initial size of array (can be multiple dimensions) - J : int number of scales - L : int number of orientations - WType : str type of wavelets (e.g., "Morlet" or "Bump-Steerable") Attributes ---------- - parent parameters (DT,N0,J,L,WType) - dg_max: int (DT- and WType-dependent) maximum dg resolution of the Wavelet Transform (DT- and WType-dependent) - j_to_dg : list of int list of actual dg_j resolutions at each j scale - wavelet_array : torch tensor * array of wavelets at L orientation if Single_Kernel==True. * array of wavelets at J*L scales and orientation at N0 resolution if Single_Kernel==False - wavelet_array_MR : list (len J) of arrays list of arrays of L wavelets at all J scales and at Nj resolution Only if if Single_Kernel==False. - Single_Kernel : bool if convolution done at all scales with the same L oriented wavelets Questions and to do ---------- - Is Single_Kernel sufficient in itself? We could separate the fact to do the convolution with a kernel in real space and to be able to do all convolution at the initial N0 resolution, using two different attribute. - We could similarly attach the fact to use mask to the fact to have Single_Kernel==True. - Do we add a low-pass filter per default for j=J ? - Do we impose j_to_dg = range(J) for simplicity and efficiency? - I propose to anyway only have dyadic wavelets for the "main set". -> the inclusion of P00' power spectrum terms can be done differently. - for __init__, we could ask either for DT and N, or for a stl_array instance, from which DT and N are obtained. Could be added if useful. - A proper "by the book" set of Wavelets should be implemented, with proper Littelwood-Paley and co conditions. """
@staticmethod
def gaussian_2d_rotated(mu, sigma, angle, size, threshold=1e-1):
    """
    Generate a rotated 2D Gaussian offset by mu from the image center.

    The Gaussian is centered at ``(M/2 - mu*sin(angle), N/2 + mu*cos(angle))``
    and values strictly below ``threshold`` are set to zero.

    Parameters
    ----------
    mu : float
        Offset along the rotated axis from the image center (in pixels).
    sigma : float
        Isotropic standard deviation (spread), in pixels.
    angle : float
        Rotation angle in radians (0 to pi).
    size : tuple of int
        Grid size (M, N).
    threshold : float, optional
        Values below this threshold are zeroed (default 0.1, the previously
        hard-coded value).

    Returns
    -------
    torch.Tensor
        A 2D Gaussian of shape [M, N] with peak value 1 at its center.
    """
    M, N = size
    x = torch.linspace(0, M - 1, M)
    y = torch.linspace(0, N - 1, N)
    X, Y = torch.meshgrid(x, y, indexing="ij")

    # Image center
    cx = M / 2
    cy = N / 2

    # Offset the center along the rotated axis. as_tensor avoids a copy warning
    # when angle is already a tensor and matches torch.tensor for python floats.
    angle_t = torch.as_tensor(angle)
    cos_a = torch.cos(angle_t)
    sin_a = torch.sin(angle_t)
    center_x = cx - mu * sin_a
    center_y = cy + mu * cos_a

    G = torch.exp(-((X - center_x) ** 2 + (Y - center_y) ** 2) / (2 * sigma**2))

    # Drop the low tail so that filters have compact support.
    G[G < threshold] = 0
    return G
@classmethod
def gaussian_bank(cls, J, L, size, base_mu=None, base_sigma=None):
    """
    Generate a bank of rotated and scaled 2D Gaussians (Fourier-space filters).

    For scale j, the Gaussian uses ``sigma = base_sigma / 2**j`` and
    ``mu = base_mu / 2**j``; orientation l uses ``angle = l * pi / L``.
    Each filter is built centered, then fftshifted so the image center
    moves to index (0, 0). That zero-frequency entry is set to 0, and the
    whole bank is divided by an ad hoc factor 0.8. No L2 normalisation is
    performed.

    Parameters
    ----------
    J : int
        Number of dyadic scales.
    L : int
        Number of orientations.
    size : tuple of int
        Grid size (Nx, Ny).
    base_mu : float, optional
        Offset along the rotated axis at j=0 (largest). Defaults to
        ``min(Nx, Ny) / (2*sqrt(2))``.
    base_sigma : float, optional
        Sigma at j=0 (the largest one; sigma shrinks with j). Defaults to
        ``base_mu / (2*sqrt(2))``.

    Returns
    -------
    torch.Tensor
        A tensor of shape [J, L, Nx, Ny].
    """
    Nx, Ny = size
    filters_bank = torch.empty((J, L, Nx, Ny))
    if base_mu is None:
        base_mu = min(Nx, Ny) / (2 * torch.sqrt(torch.tensor(2.0)))
    if base_sigma is None:
        base_sigma = base_mu / (2 * torch.sqrt(torch.tensor(2.0)))
    for j in range(J):
        sigma = base_sigma / (2**j)
        mu = base_mu / (2**j)
        for ell in range(L):
            angle = float(ell) * torch.pi / L
            filters_bank[j, ell] = cls.gaussian_2d_rotated(mu, sigma, angle, size)
    # Move the zero frequency to (0, 0), and put it to zero.
    filters_bank = torch.fft.fftshift(filters_bank, dim=(-2, -1))
    filters_bank[:, :, 0, 0] = 0
    # Ad hoc normalisation (not an L2 normalisation).
    filters_bank /= 0.8
    return filters_bank
@staticmethod
def bump_steerable_2d(omega_grid, L, xi0, width_factor=2.5, eps=1e-12):
    """
    Generate a 2D bump steerable wavelet in Fourier space.

    Radial part: with ``r = |(|omega| - xi0)| / xi0`` and
    ``s = width_factor * r**2``, the window is ``exp(-s / (1 - s))`` where
    ``r < width_factor**-0.5`` (i.e. ``s < 1``) and 0 elsewhere.
    Angular part: ``cos(theta)**(L - 1)`` where ``|theta| < pi/2`` and 0
    elsewhere, with ``theta = atan2(omega_y, omega_x)``.

    The product is normalised to a maximum absolute value of 1. This step is
    skipped when the product is identically zero, which would otherwise
    produce NaNs. The result is then scaled by ``c(L) / c(4)``, with
    ``c(L) = 2**(L-1) (L-1)! / sqrt(L (2(L-1))!)``.

    Parameters
    ----------
    omega_grid : torch.Tensor
        Frequency grid of shape [Nx, Ny, 2]; the last dimension is
        (omega_x, omega_y).
    L : int
        Number of orientations (steerability order).
    xi0 : float
        Center frequency xi = (xi0, 0) in Fourier space.
    width_factor : float
        Constant tuned (for L=4) to follow the Littlewood-Paley condition.
    eps : float
        Lower clamp of ``1 - s`` in the bump window, to avoid division by zero.

    Returns
    -------
    torch.Tensor
        The wavelet in Fourier space, of shape [Nx, Ny].
    """
    # Radial part: bump window centered at xi0.
    omega_norm = torch.sqrt(omega_grid[..., 0] ** 2 + omega_grid[..., 1] ** 2)
    r = torch.abs(omega_norm - xi0) / xi0
    # width_factor is optimized (for L=4) to follow at best the Littlewood-Paley
    # condition, see notebook "consistency_Kernel_FFT.ipynb".
    r2 = width_factor * r**2
    support_r = r < width_factor ** (-0.5)  # r >= 0 always holds (abs above)
    denom = (1.0 - r2).clamp_min(eps)
    bump = torch.where(support_r, torch.exp(-r2 / denom), torch.zeros_like(r))

    # Angular part: cos(theta)^(L-1) on the half plane facing xi.
    theta = torch.atan2(omega_grid[..., 1], omega_grid[..., 0])
    support_theta = torch.abs(theta) < torch.pi / 2
    angular = torch.where(
        support_theta, torch.cos(theta).pow(L - 1), torch.zeros_like(theta)
    )

    weights = bump * angular

    # Normalize to unit peak; an all-zero wavelet (support missing the grid)
    # is returned as zeros instead of 0/0 = NaN.
    peak = weights.abs().max()
    if peak > 0:
        weights = weights / peak

    def _lp_constant(order):
        # Littlewood-Paley constant used to rescale orders other than L=4.
        return (
            (2 ** (order - 1))
            * math.factorial(order - 1)
            / math.sqrt(order * math.factorial(2 * (order - 1)))
        )

    return weights * (_lp_constant(L) / _lp_constant(4))
@classmethod
def bump_steerable_bank(cls, J, L, size):
    """
    Generate a bank of 2D bump steerable wavelets in Fourier space.

    For scale j and orientation l, the centered frequency grid is dilated by
    ``2**j`` and rotated by ``pi*l/L + pi/2`` before calling
    ``cls.bump_steerable_2d`` with ``xi0 = min(Nx, Ny) / (2*sqrt(2))``. The
    result is then moved to FFT order (zero frequency at index (0, 0)).

    Parameters
    ----------
    J : int
        Number of dyadic scales.
    L : int
        Number of orientations (steerability order).
    size : tuple of int
        Grid size (Nx, Ny).

    Returns
    -------
    torch.Tensor
        A tensor of shape [J, L, Nx, Ny], in FFT order.
    """
    Nx, Ny = size
    filters_bank = torch.empty((J, L, Nx, Ny))
    xi0 = min(Nx, Ny) / (2 * torch.sqrt(torch.tensor(2.0)))

    # Centered frequency grid (zero frequency at the center of the array).
    omega_x = torch.fft.fftshift(torch.fft.fftfreq(Nx) * Nx)
    omega_y = torch.fft.fftshift(torch.fft.fftfreq(Ny) * Ny)
    Omega_x, Omega_y = torch.meshgrid(omega_x, omega_y, indexing="ij")
    omega_grid = torch.stack((Omega_x, Omega_y), dim=-1)

    for j in range(J):
        scale_factor = 2**j
        for ell in range(L):
            angle = torch.tensor(math.pi * ell / L + math.pi / 2)
            cos_t = torch.cos(angle).item()
            sin_t = torch.sin(angle).item()
            R = torch.tensor(
                [[cos_t, -sin_t], [sin_t, cos_t]],
                dtype=omega_grid.dtype,
                device=omega_grid.device,
            )
            q = scale_factor * omega_grid @ R  # rotate and dilate the grid
            # ifftshift is the exact inverse of the fftshift used to center the
            # grid; fftshift would misplace the zero frequency for odd sizes.
            filters_bank[j, ell] = torch.fft.ifftshift(
                cls.bump_steerable_2d(q, L=L, xi0=xi0)
            )
    return filters_bank
@staticmethod
def _get_crop_border_size_largest_scale_second_layer(data, wavelet_op):
    """
    Crop border assuming two convolutions at the largest scale.

    Returns 0 for periodic data (``data.pbc`` truthy). Otherwise returns twice the
    largest full-resolution border ``wavelet_op.crop_borders[-1, :].max()``,
    rescaled to the current resolution ``data.dg`` and rounded up.
    """
    if data.pbc:
        return 0
    deepest_layer = 2
    return math.ceil(
        deepest_layer
        * wavelet_op.crop_borders[-1, :].max().item()  # largest crop, full res
        / (2**data.dg)  # adapt to current resolution
    )


@staticmethod
def _get_crop_border_size_largest_scale_layer_flexible(data, wavelet_op):
    """
    Crop border scaling with the number of convolutions applied so far.

    Returns 0 for periodic data or when ``data.conv_history`` is empty. Otherwise
    returns ``len(conv_history)`` times the largest full-resolution border,
    rescaled to ``data.dg`` and rounded up.
    """
    if data.pbc or len(data.conv_history) == 0:
        return 0
    return math.ceil(
        len(data.conv_history)
        * wavelet_op.crop_borders[-1, :].max().item()  # largest crop, full res
        / (2**data.dg)  # adapt to current resolution
    )


@staticmethod
def _get_crop_border_size_fully_flexible(data, wavelet_op):
    """
    Crop border accounting for the actual scale of each convolution.

    Uses ``data.conv_history`` (list of scale indices j): 0 for periodic data or
    no convolution; the border of scale ``conv_history[0]`` for one convolution;
    and the sum of both borders, each rescaled to ``data.dg``, for two.

    Raises
    ------
    ValueError
        If ``conv_history`` holds more than two convolutions.
    """
    if data.pbc or len(data.conv_history) == 0:
        return 0
    elif len(data.conv_history) == 1:
        return math.ceil(
            wavelet_op.crop_borders[data.conv_history[0], :].max().item()
            / (2**data.dg)  # adapt to current resolution
        )
    elif len(data.conv_history) == 2:
        # First-convolution border (full resolution) brought to the scale of
        # the second convolution.
        first_conv_border_downgraded = wavelet_op.crop_borders[
            data.conv_history[0], :
        ].max().item() / (2 ** data.conv_history[-1])
        return math.ceil(
            first_conv_border_downgraded
            / (2 ** (data.dg - data.conv_history[-1]))
            + wavelet_op.crop_borders[data.conv_history[1], :].max().item()
            / (2**data.dg)
        )
    else:
        raise ValueError("Invalid data conv_history.")


def __init__(
    self,
    N0,
    J=None,
    L=None,
    WType="Bump-Steerable",
    DT="Planar2D_FFT_torch",
    device=_DEFAULT_DEVICE,
    dtype=_DEFAULT_DTYPE,
    get_crop_border_size_method=None,
):
    """
    Constructor: store parameters, build the wavelet sets and the crop borders.

    Parameters
    ----------
    - N0 : tuple of int
        Initial size of the Fourier domain array (same as data to be processed).
    - J : int, optional
        Number of scales. Defaults to ``int(log2(min(N0))) - 2``.
    - L : int, optional
        Number of orientations. Defaults to 4.
    - WType : str
        Type of wavelets: "Bump-Steerable" or "Morlet".
    - DT : str
        Type of data.
    - device : torch.device or str
        Device to store the wavelet arrays.
    - dtype : torch.dtype
        Data type to store the wavelet arrays.
    - get_crop_border_size_method : callable, optional
        ``f(data=..., wavelet_op=...) -> int`` computing the crop border.
        Defaults to ``_get_crop_border_size_fully_flexible``.
    """
    self.WType = WType

    # Main parameters
    self.N0 = N0
    self.J = J if J is not None else int(np.log2(min(N0))) - 2
    self.L = L if L is not None else 4
    self.DT = DT
    self.device = _get_device(torch.device(device))
    self.dtype = _get_dtype(dtype=dtype, device=self.device)

    self.wavelet_array = None
    self.wavelet_array_MR = None
    self.dg_max = None
    self.j_to_dg = None
    self._build()  # Build all the wavelet-related attributes.

    if get_crop_border_size_method is not None:
        self._get_crop_border_size_method = get_crop_border_size_method
    else:
        self._get_crop_border_size_method = (
            self.__class__._get_crop_border_size_fully_flexible
        )

    # NaN handling: used by other data types; must stay None for this one,
    # which does not handle NaNs.
    self.mask_full_res = None

    self.estimate_crop_borders()


###########################################################################
def _build(self):
    """
    Build the wavelet-related attributes:

    - wavelet_array: [J, L, N0x, N0y] full-resolution wavelets (FFT order)
    - dg_max: maximum downgrade level (never negative)
    - wavelet_array_MR: list of J arrays [L, Njx, Njy] at resolution dg_j
    - j_to_dg: list of J downgrade levels dg_j = min(j, dg_max)

    Raises
    ------
    ValueError
        If WType is neither "Morlet" nor "Bump-Steerable".
    """
    if self.WType == "Morlet":
        self.wavelet_array = self.__class__.gaussian_bank(
            self.J, self.L, self.N0
        ).to(device=self.device, dtype=self.dtype)  # [J, L, N0x, N0y]
    elif self.WType == "Bump-Steerable":
        self.wavelet_array = self.__class__.bump_steerable_bank(
            self.J, self.L, self.N0
        ).to(device=self.device, dtype=self.dtype)  # [J, L, N0x, N0y]
    else:
        raise ValueError("Invalid WType.")

    # dg_max targets a minimal size of 16 = 2 * 8, to avoid storing tensors at
    # the same effective resolution. Clamped at 0: for N0 <= 8 the unclamped
    # value is negative and downsample() would be asked for dg_out < dg_in.
    self.dg_max = max(0, int(np.log2(min(self.N0)) - 4))

    # Multi-resolution list of wavelets.
    self.wavelet_array_MR = []
    self.j_to_dg = []
    for j in range(self.J):
        dg = min(j, self.dg_max)
        subsampled_wavelet = self.__class__.downsample(
            data=STL_2D_FFT_Torch(array=self.wavelet_array[j], fourier_status=True),
            dg_out=dg,
            normalize=False,
            inplace=True,
            target_fourier_status=True,
        )  # [L, Njx, Njy]
        assert subsampled_wavelet.fourier_status
        self.wavelet_array_MR.append(subsampled_wavelet.array)
        self.j_to_dg.append(dg)
def estimate_crop_borders(self, threshold_percent=0.1):
    """
    Estimate per-scale crop borders from the impulse response of the largest
    scale wavelet, and store them in ``self.crop_borders``.

    An impulse at the image center is filtered with ``wavelet_array[-1, 0]``
    (last scale, first orientation). The modulus of the response is read along
    the horizontal line (second axis, length M) through the source pixel,
    left of it. The border is the distance from the center at which this trace
    first exceeds ``threshold_percent`` times its value next to the source.
    Borders for smaller scales are obtained by dividing by powers of 2.

    Parameters
    ----------
    threshold_percent : float, optional
        Relative threshold (fraction, default 0.1).

    Sets
    ----
    self.crop_borders : torch.Tensor of int, shape [J, L]
        ``crop_borders[j, l] = ceil(border / 2**(J - 1 - j))``, identical over l.
    """
    N, M = self.N0

    # Impulse at the center of the image.
    impulse = torch.zeros((N, M), device=self.device, dtype=self.dtype)
    impulse[N // 2, M // 2] = 1.0
    impulse_ft = torch.fft.fftshift(torch.fft.fft2(impulse, norm="ortho"))  # [N, M]

    # Filter with the last-scale (j=J-1), l=0 wavelet and go back to real space.
    psf = torch.fft.ifft2(
        torch.fft.ifftshift(impulse_ft * self.wavelet_array[-1, 0], dim=(-2, -1)),
        norm="ortho",
        dim=(-2, -1),
    ).abs()  # [N, M]

    # Horizontal trace left of the source pixel, along the second axis.
    trace = psf[N // 2, : M // 2]  # [M//2]

    # First position where the trace exceeds the threshold relative to its value
    # next to the source pixel. If none does, argmax returns 0 (maximal border).
    threshold = threshold_percent * trace[-1]
    above_threshold = trace > threshold
    first_above = int(above_threshold.float().argmax(dim=0).item())
    # The trace runs along the second axis, so its length (M) sets the border;
    # using N mixed axes for non-square N0.
    crop_border = math.ceil(M / 2) - (first_above + 1)

    # Expand crop borders to all scales (index j) and orientations.
    crop_borders = torch.tensor(
        [math.ceil(crop_border / (2**j)) for j in range(self.J - 1, -1, -1)]
    )
    self.crop_borders = crop_borders[:, None].repeat(1, self.L)
@staticmethod
def wavelet_conv_full(data, wavelet_set):
    """
    Convolve data with the entire wavelet set at full resolution.

    WARNING: puts ``data`` in Fourier space in place if it is in real space.

    Parameters
    ----------
    - data : STL_2D_FFT_Torch
        Instance whose array is a torch.Tensor of size [..., Nx, Ny].
    - wavelet_set : torch.Tensor of size [J, L, Nx, Ny]
        Wavelet set in Fourier space (FFT order) at all J scales and L
        orientations.

    Returns
    -------
    STL_2D_FFT_Torch
        Instance with array [..., J, L, Nx, Ny] (Fourier-space product),
        fourier_status=True, and pbc/dg/N0 propagated from ``data``.
    """
    data = data.set_fourier_status(target_fourier_status=True, inplace=True)
    # dg and N0 are propagated like in wavelet_conv, so that the output does
    # not fall back to defaults that may not match the input.
    return STL_2D_FFT_Torch(
        array=data[..., None, None, :, :].array * wavelet_set,
        dg=data.dg,
        N0=data.N0,
        pbc=data.pbc,
        fourier_status=True,
    )  # [..., J, L, Nx, Ny]
@staticmethod
def wavelet_conv(data, wavelet_set_MR, j):
    """
    Convolve data with the L oriented wavelets of a single scale j.

    Data and wavelets must both be at the Nj resolution.
    WARNING: puts ``data`` in Fourier space in place if it is in real space.

    Parameters
    ----------
    - data : STL_2D_FFT_Torch
        Instance whose array is a torch.Tensor of size [..., Njx, Njy].
    - wavelet_set_MR : list (len J) of torch.Tensor of size [L, Njx, Njy]
        Multi-resolution wavelet sets in Fourier space.
    - j : int
        Scale index selecting the wavelet set.

    Returns
    -------
    STL_2D_FFT_Torch
        Instance with array [..., L, Njx, Njy] (Fourier-space product),
        fourier_status=True, dg/N0/pbc copied from ``data`` and ``j`` appended
        to conv_history.
    """
    data.set_fourier_status(target_fourier_status=True, inplace=True)
    kernels = wavelet_set_MR[j]  # [L, Njx, Njy]
    expanded = data[..., None, :, :].array  # [..., 1, Njx, Njy]
    return STL_2D_FFT_Torch(
        array=expanded * kernels,
        dg=data.dg,
        N0=data.N0,
        fourier_status=True,
        pbc=data.pbc,
        conv_history=data.conv_history + [j],
    )  # [..., L, Njx, Njy]
###########################################################################
def _crop(self, array, border):
    """
    Crop ``border`` pixels from each side of the last two dimensions.

    If the array is too small (``min(H, W) <= 2 * border``), the border is
    reduced to ``(min(H, W) - 1) // 2``. A warning is printed once per
    operator, tracked via ``self._border_warning_raised``. A reduced border of 0
    leaves the array untouched.

    Parameters
    ----------
    array : torch.Tensor or None
        Input array to be cropped; None is passed through.
    border : int
        Number of pixels to remove from each side; values <= 0 mean no crop.

    Returns
    -------
    torch.Tensor or None
        Cropped array (a view of the input).
    """
    if array is None:
        return None
    if border <= 0:
        return array

    min_size = min(array.shape[-2:])
    if min_size <= 2 * border:
        # Flexible handling of borders larger than the array.
        reduced_border = (min_size - 1) // 2
        if not getattr(self, "_border_warning_raised", False):
            # Warn the user only once per wavelet operator. The shape is read
            # from metadata; no device->host copy of the data is needed.
            print(
                "Warning! Data with shape {:} too small to be cropped with border {:}. Using border={:} instead.".format(
                    tuple(array.shape[-2:]),
                    border,
                    reduced_border,
                )
            )
            self._border_warning_raised = True
        border = reduced_border
        if border <= 0:
            # array[..., 0:-0] would be EMPTY, so return the array unchanged.
            return array

    # Explicit stop indices keep the slicing correct for any border >= 1.
    return array[
        ...,
        border : array.shape[-2] - border,
        border : array.shape[-1] - border,
    ]
def mean(self, data, dim=(-2, -1), **kwargs):
    """
    Mean of ``data`` over the last two dimensions (Nx, Ny).

    - Real space: the array is cropped with the operator's crop-border method
      and averaged with ``maskmean`` over ``dim``.
    - Fourier space with periodic data: returns the zero-frequency coefficient
      divided by sqrt of the number of pixels in ``dim``.

    Parameters
    ----------
    - data : STL_2D_FFT_Torch
        Input data. In the ST_op workflow the array should be in real space.
    - dim : tuple of int
        Dimensions over which the mean is computed.

    Raises
    ------
    ValueError
        If ``data.pbc`` is None while the data has been convolved.
    NotImplementedError
        For Fourier-space data that is not periodic.
    """
    if data.pbc is None and len(data.conv_history) > 0:
        raise ValueError("data.pbc should be specified (True or False).")

    if not data.fourier_status:
        crop_size = self._get_crop_border_size_method(data=data, wavelet_op=self)
        trimmed = self._crop(array=data.array, border=crop_size)
        # No prefactor needed in real space thanks to the downsample function.
        return maskmean(x=trimmed, dim=dim)

    if not data.pbc:
        raise NotImplementedError(
            "Mean computation in Fourier space for non-periodic data is not implemented."
        )
    n_pixels = math.prod(data.array.shape[axis] for axis in dim)
    return data.array[..., 0, 0] / np.sqrt(n_pixels)
def square_mean(self, data, dim=(-2, -1), **kwargs):
    """
    Mean of ``|data|**2`` (``array * conj(array)``) over ``dim``.

    - Real space: the product is cropped with the operator's crop-border method
      and averaged with ``maskmean`` over ``dim``.
    - Fourier space with periodic data: by Parseval's identity (the FFTs here
      use ``norm="ortho"``), the mean of ``|X|**2`` over the Fourier grid equals
      the real-space mean. It is reduced over ``dim`` only, so batch/channel
      dimensions are kept, as in the real-space branch.

    Parameters
    ----------
    - data : STL_2D_FFT_Torch
        Input data.
    - dim : tuple of int
        Dimensions over which the mean is computed.

    Raises
    ------
    ValueError
        If ``data.pbc`` is None while the data has been convolved.
    NotImplementedError
        For Fourier-space data that is not periodic.
    """
    if data.pbc is None and len(data.conv_history) > 0:
        raise ValueError("data.pbc should be specified (True or False).")
    if data.fourier_status:
        if data.pbc:
            # Parseval identity. Previously this averaged over ALL dimensions
            # and collapsed the batch to a scalar.
            return (data.array * data.array.conj()).real.mean(dim=dim)
        raise NotImplementedError(
            "Square mean computation in Fourier space for non-periodic data is not implemented."
        )
    border = self._get_crop_border_size_method(data=data, wavelet_op=self)
    cropped_array = self._crop(array=data.array * data.array.conj(), border=border)
    # No prefactor needed in real space thanks to the downsample function.
    return maskmean(x=cropped_array, dim=dim)
def cov(self, data1, data2, remove_mean=False, dim=(-2, -1), **kwargs):
    """
    Mean of ``data1 * conj(data2)`` over ``dim`` (non-centered covariance).

    If both inputs are periodic and in Fourier space, the product is averaged
    directly in Fourier space (Parseval identity). Otherwise both inputs are
    converted to real space IN PLACE, and the product is cropped with the
    larger of the two crop borders before averaging.

    Parameters
    ----------
    - data1, data2 : STL_2D_FFT_Torch
        Inputs at the same resolution (same dg).
    - remove_mean : bool
        Not implemented; must be False.
    - dim : tuple of int
        Dimensions over which the mean is computed.

    Raises
    ------
    NotImplementedError
        If ``remove_mean`` is True.
    """
    assert data1.dg == data2.dg, "data1 and data2 must have the same resolution."
    if remove_mean:
        raise NotImplementedError("remove_mean is not yet implemented.")

    border = max(
        self._get_crop_border_size_method(data=data1, wavelet_op=self),
        self._get_crop_border_size_method(data=data2, wavelet_op=self),
    )

    both_periodic_fourier = all(
        operand.pbc and operand.fourier_status for operand in (data1, data2)
    )
    if both_periodic_fourier:
        # Parseval identity
        return maskmean(x=data1.array * torch.conj(data2.array), dim=dim)

    for operand in (data1, data2):
        operand.set_fourier_status(target_fourier_status=False, inplace=True)
    product = data1.array * torch.conj(data2.array)
    return maskmean(x=self._crop(array=product, border=border), dim=dim)
###########################################################################
def standardize(self, data, mean_field, inplace=False, dim=None):
    """
    Standardize the data in real space: subtract the mean and divide by the
    standard deviation, computed over the last two dimensions (Nx, Ny).

    Data given in Fourier space is first converted to real space, in the same
    way as ``unstandardize`` operates in real space.

    Parameters
    ----------
    - data : STL_2D_FFT_Torch
        Input data whose array attribute has to be standardized.
    - mean_field : bool
        If True, mean/std are additionally averaged over the batch dimension
        (dim 0).
    - inplace : bool
        If True, perform the operation in-place on the input data.
    - dim : tuple, optional
        Accepted for API compatibility. Statistics are always computed over the
        last two dimensions, which the ``[..., None, None]`` broadcasting below
        relies on.

    Returns
    -------
    - STL_2D_FFT_Torch
        Standardized data (real space).
    - torch.Tensor
        Mean used for standardization.
    - torch.Tensor
        Standard deviation used for standardization.
    """
    if dim is None:
        dim = (-2, -1)

    l_data = data.copy(empty=False) if not inplace else data
    # Centering must happen in real space: subtracting a constant from Fourier
    # coefficients does not center the field.
    l_data = l_data.set_fourier_status(target_fourier_status=False, inplace=True)

    mean = self.mean(l_data)  # [Nb, Nc]
    if mean_field:
        mean = mean.mean(dim=0, keepdim=True)  # [1, Nc]
    # Center first, because cov has no remove_mean.
    l_data.array = l_data.array - mean[..., None, None]

    var = self.cov(l_data, l_data)
    if mean_field:
        var = var.mean(dim=0, keepdim=True)  # [1, Nc]
    std = torch.sqrt(var)
    l_data.array = l_data.array / std[..., None, None]

    return l_data, mean, std
###########################################################################
def unstandardize(self, data, mean, std, inplace=False, dim=None):
    """
    Undo a standardization in real space: ``array * std + mean``.

    Parameters
    ----------
    - data : STL_2D_FFT_Torch
        Input data whose array attribute has to be unstandardized. Data in
        Fourier space is converted to real space first.
    - mean : torch.Tensor
        Mean used for standardization (broadcast over the last two dimensions).
    - std : torch.Tensor
        Standard deviation used for standardization (broadcast over the last two
        dimensions).
    - inplace : bool
        If True, modify ``data`` itself; otherwise work on a copy.
    - dim : unused

    Returns
    -------
    - STL_2D_FFT_Torch
        Unstandardized data (real space).
    """
    l_data = data if inplace else data.copy(empty=False)
    if l_data.fourier_status:
        # Unstandardization is done in real space.
        l_data = l_data.set_fourier_status(target_fourier_status=False, inplace=True)
    rescaled = l_data.array * std[..., None, None]
    l_data.array = rescaled + mean[..., None, None]
    return l_data
###########################################################################
def _compute_and_store_cross_cov(
    self,
    data1,
    data2,
    output,
    compute_cross_matrix,
    redundant_channels,
    remove_mean=False,
    dim=(-2, -1),
):
    """
    Fill ``output[:, c1, c2, ...]`` with ``self.cov(data1[:, c1], data2[:, c2])``.

    Only pairs with ``c1 <= c2`` and a truthy ``compute_cross_matrix[c1, c2]``
    are visited. When ``redundant_channels`` is False, the mirrored entry
    ``output[:, c2, c1, ...]`` is also computed as
    ``cov(data1[:, c2], data2[:, c1])`` for off-diagonal pairs. ``output`` is
    written in place; the function returns None.

    Expected shapes: data arrays [Nb, Nc, ...], output [Nb, Nc, Nc, ...].
    """
    assert (
        data1.array.shape[1] == data2.array.shape[1]
    ), "data1 and data2 arrays must have the same number of channels."
    assert (
        data1.array.ndim == data2.array.ndim
    ), "data1 and data2 arrays must have the same number of dimensions."
    assert (
        data1.array.shape[1] == output.shape[1]
    ), "output and data must have the same number of channels."
    assert (
        output.shape[1] == output.shape[2]
    ), "output must have shape (Nb, Nc, Nc, ...)."

    n_channels = output.shape[1]
    upper_pairs = (
        (first, second)
        for first in range(n_channels)
        for second in range(first, n_channels)
    )
    for first, second in upper_pairs:
        if not compute_cross_matrix[first, second]:
            continue
        output[:, first, second, ...] = self.cov(
            data1=data1[:, first, ...],
            data2=data2[:, second, ...],
            remove_mean=remove_mean,
            dim=dim,
        )
        if redundant_channels or first == second:
            continue
        output[:, second, first, ...] = self.cov(
            data1=data1[:, second, ...],
            data2=data2[:, first, ...],
            remove_mean=remove_mean,
            dim=dim,
        )
    return None


###########################################################################
def apply(self, data, j=None, target_fourier_status=None, **kwargs):
    """
    Compute the wavelet transform (WT) of ``data``.

    Two modes are supported:

    - ``j is None`` (full): convolution with all J*L wavelets at full
      resolution N0. ``data`` must be at ``dg == 0``. Output array shape
      ``[..., J, L, N0x, N0y]``.
    - ``j`` integer (single scale): convolution with the L wavelets of scale
      ``j`` at resolution Nj. ``data`` must be at ``dg == self.j_to_dg[j]``.
      Output array shape ``[..., L, Njx, Njy]``, and ``j`` is appended to the
      output's ``conv_history``. Any integer-like scalar (Python int,
      numpy/torch integer) is accepted and converted to a Python int.

    WARNING: ``data`` is put in Fourier space in place if it is in real space.

    Parameters
    ----------
    - data : STL_2D_FFT_Torch
        Input data with the same DT/N0 as the operator; can be batched on
        leading dimensions.
    - j : int or None
        Scale at which the convolution is done. Done at all scales if None.
    - target_fourier_status : bool or None
        Desired Fourier status of the output. If None, the output stays in
        Fourier space.

    Returns
    -------
    STL_2D_FFT_Torch
        The wavelet transform.

    Raises
    ------
    Exception
        On wrong data type, DT/N0 mismatch, wrong resolution, or non-integer j.
        An out-of-range ``j`` raises IndexError when ``self.j_to_dg`` is indexed.

    Questions and to do
    -------------------
    - Non-periodicity is dealt with in the mean and cov functions, at the end of
      the computations, not here.
    - Possibility to compute WT at fixed (j, l) values, if it helps distributing
      the computations for large batches.
    - An internal Fourier transform may be repeated over multiple calls of this
      method on the same data.
    - For a single-scale convolution, should data not at Nj resolution be
      accepted and downsampled? To be decided depending on usage.
    """
    import operator

    # Check coherence of input data.
    if type(data).__name__ != "STL_2D_FFT_Torch":
        raise Exception(
            f"Data should be a STL_2D_FFT_Torch instance, got {type(data)}"
        )
    if self.DT != data.DT:
        raise Exception("Data and wavelet transform should have same DT")
    if self.N0 != data.N0:
        raise Exception("Data and wavelet transform should have same N0 attributes")

    if j is None:
        # Full convolution at all scales at resolution N0.
        if data.dg != 0:
            raise Exception("Data should be at dg=0 resolution")
        WT = self.__class__.wavelet_conv_full(data, self.wavelet_array)
    else:
        # Single-scale convolution. operator.index accepts any integer-like
        # scalar (numpy/torch ints) and rejects floats, unlike isinstance(j, int).
        try:
            j = operator.index(j)
        except TypeError:
            raise Exception("j should be a single int") from None
        if data.dg != self.j_to_dg[j]:
            raise Exception("Data should be at dg_j resolution")
        WT = self.__class__.wavelet_conv(data, self.wavelet_array_MR, j)

    # Transform to the requested Fourier status if necessary.
    if target_fourier_status is not None:
        WT.set_fourier_status(target_fourier_status, inplace=True)

    return WT
###########################################################################
@staticmethod
def downsample(
    data, dg_out, normalize=True, inplace=True, target_fourier_status=True, **kwargs
):
    """
    Downgrade the data array to dg_out resolution by cropping in Fourier space.

    The fftshift-ed spectrum is cropped around the zero frequency to
    ``(N_in // f, M_in // f)`` modes with ``f = 2 ** (dg_out - dg_in)``.
    The crop is never smaller than 8 modes along the shorter axis of
    ``data.N0`` (the longer axis minimum is scaled to keep the ``data.N0``
    aspect ratio), and never larger than the input array itself.

    Parameters
    ----------
    data : STL_2D_FFT_Torch
        Data object to be downgraded (currently at dg_in = data.dg).
    dg_out : int
        Target resolution after downgrading; must be >= data.dg.
    normalize : bool
        If True, rescale the cropped spectrum by sqrt(n_out / n_in), the
        ratio of pixel counts after/before cropping. With the "ortho" FFT
        convention this keeps the real-space mean unchanged. When the crop
        is not clamped this factor is exactly 1 / f.
    inplace : bool
        If True, modifies data in-place. If False, works on a copy.
    target_fourier_status : bool
        If True, the output is left in Fourier space; if False, it is
        converted back to real space.
    **kwargs
        Accepted for interface compatibility; not used.

    Notes
    -----
    - To remain consistent, if you downgrade an image that was already
      downgraded, it is recommended to keep the output in the same domain
      (Fourier or real) as the previous data. Otherwise, normalization
      issues may appear.
    - Because of the minimal crop size, the output array may be larger
      than ``N0 // 2**dg_out`` even though ``dg`` is set to ``dg_out``.

    Returns
    -------
    STL_2D_FFT_Torch
        Data downgraded to dg_out resolution.

    Raises
    ------
    ValueError
        If dg_out < data.dg.
    """
    dg_in = data.dg
    if dg_out < dg_in:
        raise ValueError("dg_out should be greater than or equal to dg_in.")

    # Prepare target data object
    data = data.copy(empty=False) if not inplace else data

    # Compute input and requested output shapes
    in_shape = data.array.shape
    n_in, m_in = in_shape[-2], in_shape[-1]
    factor = 2 ** (dg_out - dg_in)
    out_shape = (n_in // factor, m_in // factor)

    # Ensure data is in Fourier space, with the zero frequency at the centre
    data.set_fourier_status(target_fourier_status=True, inplace=True)
    data_fft = torch.fft.fftshift(data.array, dim=(-2, -1))

    # Minimal crop size, keeping the original image ratio
    min_x, min_y = 8, 8
    if data.N0[0] > data.N0[1]:
        min_x = int(min_x * data.N0[0] / data.N0[1])
    elif data.N0[1] > data.N0[0]:
        min_y = int(min_y * data.N0[1] / data.N0[0])

    # Never request more modes than the input has: otherwise the slice
    # start below would go negative and wrap around.
    dx = min(max(min_x, out_shape[0]), n_in)
    dy = min(max(min_y, out_shape[1]), m_in)

    # Crop indices around the (shifted) zero frequency
    center_x, center_y = n_in // 2, m_in // 2
    half_dx, half_dy = dx // 2, dy // 2
    cropped_fft = data_fft[
        ...,
        center_x - half_dx : center_x + half_dx,
        center_y - half_dy : center_y + half_dy,
    ]

    # Assign cropped array back, undoing the shift
    data.array = torch.fft.ifftshift(cropped_fft, dim=(-2, -1))

    if normalize:
        # Use the *actual* crop size rather than `factor`: when the crop is
        # clamped to the minimal size, 1 / factor would change the mean.
        n_out_pixels = cropped_fft.shape[-2] * cropped_fft.shape[-1]
        data.array = data.array * math.sqrt(n_out_pixels / (n_in * m_in))

    data.dg = dg_out

    # Optionally convert back to real space
    if not target_fourier_status:
        data.set_fourier_status(target_fourier_status=False, inplace=True)

    return data
class CS_operator_2D_FFT_torch:
    """
    Cross-spectrum operator for 2D planar FFT data (``STL_2D_FFT_Torch``).

    An instance is built for a fixed image shape and holds a bank of radial
    frequency bin masks; cross spectra between the channels of a data
    object are computed through the ``apply`` method. The operator is
    DT-dependent.
    """

    # Useful functions for the bin mask wavelet bank construction
@staticmethod
def s(t):
    """
    Smooth bump supported on the open interval (-1, 1).

    Returns ``exp(-1 / (1 - t**2))`` for ``-1 < t < 1`` and ``0.0``
    everywhere else.
    """
    inside = -1 < t < 1
    if not inside:
        return 0.0
    return np.exp(-1.0 / (1.0 - t**2))
@classmethod
def s_lambda(cls, t, lam):
    """
    Bump ``cls.s`` affinely rescaled so that its support maps onto (1/lam, 1).

    The affine map sends ``t = 1/lam`` to -1 and ``t = 1`` to +1 before
    evaluating ``cls.s``. Requires ``lam != 1``.
    """
    slope = 2.0 * lam / (lam - 1.0)
    u = slope * (t - 1.0 / lam) - 1.0
    return cls.s(u)
@classmethod
def k_lambda(cls, t, lam):
    """
    Smooth step going from 1 (for ``t <= 1/lam``) to 0 (for ``t >= 1``).

    In between, the value is the fraction of the weight
    ``s_lambda(tp, lam)**2 / tp`` lying on ``[t, 1]``, relative to its
    integral over ``[1/lam, 1]``. Both integrals use ``scipy.integrate.quad``.
    """
    lower = 1.0 / lam
    if t <= lower:
        return 1.0
    if t >= 1.0:
        return 0.0

    def weight(tp):
        return cls.s_lambda(tp, lam) ** 2 / tp

    numerator = quad(weight, t, 1.0)[0]
    denominator = quad(weight, lower, 1.0)[0]
    return numerator / denominator
@classmethod
def kappa_lambda(cls, t, lam):
    """
    Radial band-pass profile built from two shifted smooth steps.

    Returns ``sqrt(k_lambda(t / lam) - k_lambda(t))``; the difference is
    clipped at 0 before taking the square root.
    """
    upper_step = cls.k_lambda(t / lam, lam)
    lower_step = cls.k_lambda(t, lam)
    return np.sqrt(max(upper_step - lower_step, 0))
###########################################################################
def __init__(
    self,
    shape,
    n_bins=None,
    J=None,
    device=_DEFAULT_DEVICE,
    dtype=_DEFAULT_DTYPE,
    get_crop_border_size_method="flexible_crop",
):
    """
    Initialize a radial frequency-binning (cross-spectrum) operator.

    Parameters
    ----------
    shape : tuple of int
        Image size (N, M) the operator is built for.
    n_bins : int or None
        Number of radial frequency bins. If None, defaults to
        ``int(2 ** (log2(min(shape)) - 4))`` (about ``min(shape) / 16``).
    J : int or None
        Sets the lowest bin frequency ``1 / 2**J``. If None, defaults to
        ``int(log2(min(shape))) - 2``.
    device : torch device
        Device on which the bin masks are stored.
    dtype : torch dtype
        Dtype of the bin masks.
    get_crop_border_size_method : str
        Default crop strategy for non-PBC data in ``apply``:
        "flexible_crop" or "largest_crop".
    """
    self.shape = shape
    self.n_bins = (
        int(2 ** (np.log2(min(shape)) - 4)) if n_bins is None else n_bins
    )  # adaptive number of bins
    self.J = int(np.log2(min(shape))) - 2 if J is None else J
    self.device = _get_device(torch.device(device))
    self.dtype = _get_dtype(dtype=dtype, device=self.device)
    self.get_crop_border_size_method = get_crop_border_size_method

    # --- Build frequency bin masks ---
    self._build_bin_masks()

    # --- Estimate crop borders for each bin (for non-PBC data apply) ---
    self.estimate_crop_borders()

###########################################################################
def _build_bin_masks(self):
    """
    Build the radial band-pass masks ``self.bin_masks`` ([n_bins, N, M]).

    The 1D radial profile of bin j (j = 1..n_bins) is
    ``kappa_lambda(k / (min_freq * lam**j), lam)``, sampled on a grid of
    radial frequencies k between ``min_freq`` and the Nyquist frequency.
    Each pixel of the fftshift-ed 2D frequency grid takes the value of the
    nearest sample (radii outside the grid take the closest end sample).

    Also sets ``min_freq``, ``max_freq``, ``lam`` and ``bin_centers``.
    """
    N, M = self.shape
    self.min_freq = 1 / (2.0**self.J)
    self.max_freq = 0.5  # Nyquist frequency

    # Geometric bin ratio: n_bins + 1 ratios span [min_freq, max_freq]
    lam = (self.max_freq / self.min_freq) ** (1 / (self.n_bins + 1))
    n_samples = 1000
    k_vals = torch.linspace(self.min_freq, self.max_freq, n_samples)
    scales_j = torch.arange(1, self.n_bins + 1)

    # Radial profiles at high resolution
    psi_kernels = []
    for j in scales_j:
        psi_j = np.array(
            [self.kappa_lambda(k / (self.min_freq * lam**j), lam) for k in k_vals]
        )
        psi_kernels.append(psi_j)

    # Radial frequency of every pixel, zero frequency at the centre
    freq_y = torch.fft.fftfreq(N)
    freq_x = torch.fft.fftfreq(M)
    FY, FX = torch.meshgrid(freq_y, freq_x, indexing="ij")
    radial_freq = torch.fft.fftshift(torch.sqrt(FX**2 + FY**2))

    # Nearest radial sample per pixel. A binary search keeps memory at
    # O(N*M), instead of broadcasting |radial_freq - k_vals| to
    # [N, M, n_samples]. Ties go to the lower index, matching the
    # first-occurrence rule of torch.argmin.
    pos = torch.searchsorted(k_vals, radial_freq.contiguous())
    right = pos.clamp(max=n_samples - 1)
    left = (pos - 1).clamp(min=0)
    dist_left = torch.abs(radial_freq - k_vals[left])
    dist_right = torch.abs(radial_freq - k_vals[right])
    idx = torch.where(dist_right < dist_left, right, left)

    psi_kernels = torch.from_numpy(np.array(psi_kernels))  # [n_bins, n_samples]
    self.bin_masks = torch.zeros(
        (self.n_bins, N, M), device=self.device, dtype=self.dtype
    )
    for j in range(self.n_bins):
        self.bin_masks[j] = psi_kernels[j][idx]

    self.bin_centers = self.min_freq * lam**scales_j
    self.lam = lam
def estimate_crop_borders(self):
    """
    Estimate, per radial bin, how many border pixels to discard for non-PBC data.

    A unit impulse at the image centre ``(N // 2, M // 2)`` is filtered by
    every bin mask. For each bin, the horizontal trace of the resulting
    point-spread function (PSF) is taken from the left edge up to and
    including the source pixel. The first column (from the left) whose
    absolute value exceeds ``threshold_percent`` of the value at the source
    pixel gives the PSF extent, which is converted into a border size.

    Sets ``self.crop_borders``: an integer tensor of shape [n_bins],
    clamped to be non-negative.
    """
    N, M = self.shape

    # Impulse at the image centre
    impulse = torch.zeros((N, M), device=self.device, dtype=self.dtype)
    impulse[N // 2, M // 2] = 1.0

    # FFT of the impulse, zero frequency at the centre (same layout as bin_masks)
    impulse_ft = torch.fft.fftshift(torch.fft.fft2(impulse, norm="ortho"))  # [N, M]

    # Apply all masks in batch
    psfs = torch.fft.ifft2(
        torch.fft.ifftshift(impulse_ft.unsqueeze(0) * self.bin_masks, dim=(-2, -1)),
        norm="ortho",
        dim=(-2, -1),
    ).real  # [n_bins, N, M]

    # Horizontal trace from the left edge up to AND including the source
    # column M // 2, so that traces[:, -1] really is the source pixel.
    traces = psfs[:, N // 2, : M // 2 + 1].abs()  # [n_bins, M // 2 + 1]

    # First column where the PSF exceeds threshold_percent of its source value
    threshold_percent = 0.1
    threshold = threshold_percent * traces[:, -1].unsqueeze(1)  # [n_bins, 1]
    above_thresh = traces > threshold  # [n_bins, M // 2 + 1]
    first_above = above_thresh.float().argmax(dim=1)  # first True per bin

    # A PSF confined to the source pixel would give -1 for even M: clamp at 0.
    self.crop_borders = (math.ceil(M / 2) - (first_above + 1)).clamp(min=0)  # [n_bins]
###########################################################################
def build_mask_crop(self, array, border):
    """
    Build per-bin boolean masks excluding ``border`` pixels on every side.

    Only the spatial size (last two dimensions) of ``array`` is used; the
    array itself is neither cropped nor padded.

    Parameters
    ----------
    array : torch.Tensor
        Tensor with at least 3 dimensions whose last two are (N, M).
    border : torch.Tensor or int
        Number of pixels to exclude from each side. Typically one entry
        per bin (shape [n_bins]); a scalar yields a single mask.

    Returns
    -------
    torch.Tensor (bool)
        Mask of shape [len(border), N, M] (i.e. [n_bins, N, M] for per-bin
        borders), True inside the retained window, on ``array.device``.

    Raises
    ------
    ValueError
        If ``array`` has fewer than 3 dimensions.
    """
    if array.ndim < 3:
        raise ValueError(
            "Input tensor must have at least 3 dimensions to apply per-bin crop."
        )

    N, M = array.shape[-2:]
    rows = torch.arange(N, device=array.device).view(1, N, 1)
    cols = torch.arange(M, device=array.device).view(1, 1, M)
    # One border per mask; moved next to `array` so the comparisons below
    # never mix devices.
    border_broadcast = torch.as_tensor(border, device=array.device).reshape(-1, 1, 1)

    mask = (
        (rows >= border_broadcast)
        & (rows < (N - border_broadcast))
        & (cols >= border_broadcast)
        & (cols < (M - border_broadcast))
    )  # [n_bins, N, M]
    return mask
###########################################################################
def apply(
    self, data, compute_cross_spectrum_matrix=None, get_crop_border_size_method=None
):
    """
    Compute binned radial cross spectra between the channels of ``data``.

    For each requested channel pair (c1, c2) and radial bin k, the value is
    ``sum_f B_k(f) X_c1(f) conj(X_c2(f)) / sum_f B_k(f)``, where X is the
    Fourier transform of the data (as returned by ``set_fourier_status``)
    and ``B_k = self.bin_masks[k]``.

    - If ``data.pbc`` is true, this sum is evaluated directly in Fourier
      space.
    - Otherwise, it is evaluated in real space (Parseval) on the pixels
      that survive a per-bin border crop, rescaled by
      ``N0[0] * N0[1] / n_kept_pixels``.

    Parameters
    ----------
    data : STL_2D_FFT_Torch
        Input data at ``dg == 0`` with ``N0 == self.shape``, on the
        operator's device. Its array may be [N, M], [Nc, N, M] or
        [Nb, Nc, N, M].
    compute_cross_spectrum_matrix : torch.BoolTensor of shape [Nc, Nc] or None
        Which (c1, c2) pairs to compute. If None, only auto-spectra
        (the diagonal) are computed.
    get_crop_border_size_method : str or None
        "flexible_crop" (per-bin borders) or "largest_crop" (the largest
        border for every bin), used for non-PBC data only. If None, uses
        the method given at initialization.

    Returns
    -------
    torch.Tensor
        Complex tensor of shape [Nb, Nc, Nc, n_bins]. Entries that were not
        requested are NaN; the matrix is not symmetrised.

    Raises
    ------
    Exception
        If ``data`` has the wrong type, shape, dg, or device.
    ValueError
        If ``data`` contains NaN, or the crop method is unknown (non-PBC).
    """
    # consistency check
    if type(data).__name__ != "STL_2D_FFT_Torch":
        raise Exception(
            f"Data should be a STL_2D_FFT_Torch instance, got {type(data)}"
        )
    if self.shape != data.N0:
        raise Exception("Data shape does not match operator shape")
    if data.dg != 0:
        raise Exception("Data dg must be 0 for power spectrum computation")
    if data.array.isnan().any():
        raise ValueError(
            "Data array contains NaN values, cannot compute power spectrum"
        )
    if self.device != data.device:
        raise Exception("Data device does not match operator device")

    if get_crop_border_size_method is None:
        get_crop_border_size_method = self.get_crop_border_size_method

    # Fourier-space copy of the data, zero frequency centred to match the
    # layout of self.bin_masks.
    l_data = data.set_fourier_status(target_fourier_status=True, inplace=False)
    l_data.array = torch.fft.fftshift(l_data.array, dim=(-2, -1))

    # Put in the expected [Nb, Nc, N, M] shape if not already
    if l_data.array.ndim == 2:
        l_data.array = l_data.array[None, None, :, :]  # [1, 1, N, M]
    elif l_data.array.ndim == 3:
        l_data.array = l_data.array[None, :, :, :]  # [1, Nc, N, M]
    Nb, Nc, N, M = l_data.array.shape
    n_bins = self.n_bins

    # Unrequested entries stay NaN
    cross_spectrum = (
        bk.zeros((Nb, Nc, Nc, n_bins), dtype=bk._DEFAULT_COMPLEX_DTYPE) + bk.nan
    )
    if compute_cross_spectrum_matrix is None:
        compute_cross_spectrum_matrix = bk.eye(Nc, dtype=bool)

    bin_norm = self.bin_masks.sum(dim=(-2, -1))[None, None, None, :]  # [1, 1, 1, n_bins]
    l_data_bin = (
        l_data.array[:, :, None, :, :] * self.bin_masks[None, None, :, :, :]
    )  # [Nb, Nc, n_bins, N, M]

    if l_data.pbc:
        # The 6-D Fourier-space product is only needed here, so it is built
        # only on this path.
        cross_product_bin = l_data_bin[:, :, None, :, :, :] * torch.conj(
            l_data.array[:, None, :, None, :, :]
        )  # [Nb, Nc, Nc, n_bins, N, M]
        cross_vals = (cross_product_bin.sum(dim=(-2, -1)) / bin_norm).to(
            dtype=bk._DEFAULT_COMPLEX_DTYPE
        )
        # Symmetric part is redundant and therefore not filled, as
        # cross_spectrum(c1, c2) and cross_spectrum(c2, c1) are conjugates.
        cross_spectrum[:, compute_cross_spectrum_matrix, :] = cross_vals[
            :, compute_cross_spectrum_matrix, :
        ]
        return cross_spectrum  # [Nb, Nc, Nc, n_bins]

    if get_crop_border_size_method == "flexible_crop":
        border = self.crop_borders  # [n_bins]
    elif get_crop_border_size_method == "largest_crop":
        border = torch.full_like(
            self.crop_borders, self.crop_borders.max()
        )  # [n_bins]
    else:
        raise ValueError(
            f"Invalid get_crop_border_size_method: {get_crop_border_size_method}"
        )

    # NOTE: l_data_bin and l_data.array are still in fftshift-ed layout, so
    # their inverse FFTs are the real-space fields times the SAME
    # unit-modulus phase. That phase does not move pixels (the spatial crop
    # stays valid) and cancels in a * conj(b). This assumes
    # set_fourier_status inverts with a plain ortho ifft2 — TODO confirm.
    ifft_l_data_bin = torch.fft.ifft2(
        l_data_bin, norm="ortho", dim=(-2, -1)
    )  # [Nb, Nc, n_bins, N, M]
    l_data.set_fourier_status(
        target_fourier_status=False, inplace=True
    )  # [Nb, Nc, N, M]

    mask_crop = self.build_mask_crop(l_data.array, border=border)  # [n_bins, N, M]
    # Compensate for the pixels removed by the crop
    prefactor = (l_data.N0[0] * l_data.N0[1]) / mask_crop.sum(
        dim=(-2, -1)
    )  # [n_bins]

    cross_product_bin_real = ifft_l_data_bin[:, :, None, :, :, :] * torch.conj(
        l_data.array[:, None, :, None, :, :]
    )  # [Nb, Nc, Nc, n_bins, N, M]
    cross_vals = (
        prefactor
        * (cross_product_bin_real * mask_crop[None, None, None, :, :, :]).sum(
            dim=(-2, -1)
        )
        / bin_norm
    ).to(dtype=bk._DEFAULT_COMPLEX_DTYPE)

    cross_spectrum[:, compute_cross_spectrum_matrix, :] = cross_vals[
        :, compute_cross_spectrum_matrix, :
    ]
    return cross_spectrum  # [Nb, Nc, Nc, n_bins]
###########################################################################
def plot_cross_spectrum(self, cs_tensor, b=0, c1=0, c2=0, label=None, color="b"):
    """
    Plot one radial cross spectrum against the bin centre frequencies.

    Both axes use a log scale. Only the real part of the (complex) cross
    spectrum is drawn; NaN or non-positive values cannot appear on the log
    y axis.

    Parameters
    ----------
    cs_tensor : torch.Tensor of shape [Nb, Nc, Nc, n_bins]
        Cross spectrum values, e.g. as returned by ``apply``. Tensors that
        track gradients are accepted.
    b : int
        Batch index (0 <= b < Nb).
    c1, c2 : int
        Channel indices (0 <= c1, c2 < Nc).
    label : str or None
        Legend label. The legend is only drawn when a label is given.
    color : matplotlib color
        Line color.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the number of bins in ``cs_tensor`` does not match
        ``self.bin_centers``.
    """
    cs_values = cs_tensor[b, c1, c2, :].detach().cpu().numpy()
    freqs = self.bin_centers.detach().cpu().numpy()
    if cs_values.shape != freqs.shape:
        raise ValueError(
            f"cs_values shape: {cs_values.shape} and freqs shape: {freqs.shape} must have the same shape."
        )

    # matplotlib would drop the imaginary part implicitly (with a
    # ComplexWarning); make that choice explicit.
    cs_values = np.real(cs_values)

    plt.plot(freqs, cs_values, "-", marker="o", label=label, color=color)
    plt.xscale("log")
    plt.yscale("log")
    plt.xlabel("frequency")
    plt.ylabel("Cross Spectra")
    plt.title(f"Radial Cross Spectra c{c1+1}-c{c2+1} for image {b+1}")
    plt.grid(True, which="both", ls="-", alpha=0.5)
    if label is not None:
        plt.legend()