#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Jul 27 14:23:20 2025
Example methods for a test data type.
2D planar maps with convolution in Fourier.
All computations in this module are done in torch.
Characteristics:
- in pytorch
- assume real maps (in real space)
- N0 gives x and y sizes for array shaped (..., Nx, Ny).
- masks are not supported in convolutions
"""
import numpy as np
import torch
###############################################################################
[docs]
def DT1_subsampling_func(array, Fourier, N0, dg, dg_out, mask_MR):
    """
    Downsample the data from downgrading level dg to level dg_out.

    The downsampling crops the low frequencies in Fourier space, using the
    orthonormal FFT convention (norm="ortho"). Each increment of dg halves
    the linear size of the array down to a minimum size of 16 pixels along
    the smallest axis. The minimum along the other axis is scaled to keep the
    aspect ratio of N0. If the target size equals the current size, the array
    is returned unchanged, in its original space.

    Note: Masks are not supported in this data type.

    Parameters
    ----------
    array : torch.Tensor
        Input tensor of shape (..., Nx, Ny) at level dg.
    Fourier : bool
        Indicates whether the input array is in Fourier space.
    N0 : tuple of int
        Initial (dg=0) resolution of the data.
    dg : int
        Current downgrading level of the data.
    dg_out : int
        Desired downgrading level of the data. It must not correspond to a
        larger array than the current one (upsampling is not supported).
    mask_MR : None
        Placeholder for mask. Must be None.

    Returns
    -------
    torch.Tensor
        Data at the downgrading level dg_out.
    Fourier : bool
        Indicates whether the output array is in Fourier space. Always True
        when an actual downsampling was performed.

    Raises
    ------
    Exception
        If a mask is given.
    ValueError
        If going from dg to dg_out would require upsampling the array.
    """
    if mask_MR is not None:
        raise Exception("Masks are not supported in DT1")
    if dg_out == dg:
        return array, Fourier

    # Minimum half-sizes: 8 along the smallest axis, scaled along the other
    # axis so that the aspect ratio of N0 is kept at coarse resolutions.
    min_x, min_y = 8, 8
    if N0[0] > N0[1]:
        min_x = int(min_x * N0[0] / N0[1])
    elif N0[1] > N0[0]:
        min_y = int(min_y * N0[1] / N0[0])

    # Half-sizes of the target and of the expected current array. The full
    # size along each axis is twice these values.
    dx = int(max(min_x, N0[0] // 2 ** (dg_out + 1)))
    dy = int(max(min_y, N0[1] // 2 ** (dg_out + 1)))
    dx_cur = int(max(min_x, N0[0] // 2 ** (dg + 1)))
    dy_cur = int(max(min_y, N0[1] // 2 ** (dg + 1)))

    if dx == dx_cur and dy == dy_cur:
        # Same effective resolution (the minimum size has been reached).
        return array, Fourier
    if dx > dx_cur or dy > dy_cur:
        # Cropping cannot enlarge the array: the overlapping slices below
        # would silently duplicate frequencies instead of upsampling.
        raise ValueError(
            f"DT1 subsampling cannot upsample from dg={dg} to dg_out={dg_out}"
        )

    # Fourier transform if in real space
    if not Fourier:
        array = torch.fft.fft2(array, norm="ortho")
        Fourier = True

    # Keep the lowest frequencies: the first and last dx (resp. dy) indices
    # along each axis hold the positive and negative low frequencies.
    # The sqrt(N_out / N_cur) factor keeps the real-space mean unchanged
    # under the orthonormal FFT convention.
    array_dg = torch.cat(
        (
            torch.cat((array[..., :dx, :dy], array[..., -dx:, :dy]), -2),
            torch.cat((array[..., :dx, -dy:], array[..., -dx:, -dy:]), -2),
        ),
        -1,
    ) * np.sqrt(dx * dy / dx_cur / dy_cur)
    return array_dg, Fourier
###############################################################################
[docs]
def DT1_subsampling_func_toMR(array, Fourier, N0, dg_max, mask_MR):
    """
    Build the multi-resolution list of an array for dg = 0, ..., dg_max.

    The input is expected at the full (dg=0) resolution. It is first brought
    to Fourier space (orthonormal FFT) if needed. Each following level is then
    obtained by downsampling the previous level by one step with
    DT1_subsampling_func.

    Note: Masks are not supported in this data type.

    Parameters
    ----------
    array : torch.Tensor
        Input tensor at dg=0 resolution.
    Fourier : bool
        Indicates whether the input array is in Fourier space.
    N0 : tuple of int
        Initial resolution of the data.
    dg_max : int
        Maximum downgrading level.
    mask_MR : None
        Placeholder for mask. Must be None.

    Returns
    -------
    list of torch.Tensor
        Arrays at levels dg = 0, ..., dg_max, following range(dg_max + 1).
    Fourier : bool
        Fourier status of the output arrays (always True).

    Raises
    ------
    Exception
        If a mask is given.
    """
    if mask_MR is not None:
        raise Exception("Masks are not supported in DT1")

    # Work in Fourier space throughout, with the same "ortho" convention as
    # DT1_subsampling_func.
    if not Fourier:
        Fourier = True
        array = torch.fft.fft2(array, norm="ortho")

    arrays_MR = [array]
    for dg in range(dg_max):
        # Each level is derived from the previous (already coarser) one.
        coarser, _ = DT1_subsampling_func(
            arrays_MR[-1], Fourier, N0, dg, dg + 1, None
        )
        arrays_MR.append(coarser)
    return arrays_MR, Fourier
###############################################################################
[docs]
def DT1_mean_func_MR(array, N0, list_dg, square, mask_MR):
    """
    Compute the mean of each tensor of a list over its last two dimensions.

    The leading dimensions of all tensors must match. The means are stacked
    along the last dimension of the output. An optional multi-resolution
    mask multiplies each tensor elementwise before averaging; it should be of
    unit mean at each resolution.

    Parameters
    ----------
    array : list of torch.Tensor
        Tensors of shape (..., Nx_i, Ny_i) whose means are computed.
    N0 : tuple of int
        Initial resolution of the data (not used in this function).
    list_dg : list of int
        Downgrading levels of the data (not used in this function).
    square : bool
        If True, average |x|**2 instead of x.
    mask_MR : list of torch.Tensor or None
        One mask per entry of `array`, whose last dimensions match that
        entry. They should be of unit mean at each resolution.

    Returns
    -------
    torch.Tensor
        Tensor of shape (..., len(array)) holding the means. Its dtype and
        device follow the inputs, so complex means keep their imaginary part.

    Raises
    ------
    ValueError
        If `array` is empty.
    """
    if len(array) == 0:
        raise ValueError("DT1_mean_func_MR needs at least one tensor")

    means = []
    for i, tensor in enumerate(array):
        # Unit weight when no mask is given.
        mask = 1 if mask_MR is None else mask_MR[i]
        values = tensor.abs() ** 2 if square else tensor
        means.append(torch.mean(values * mask, dim=(-2, -1)))
    # Stacking (instead of filling a pre-allocated float32 CPU tensor) keeps
    # the dtype and device of the computed means: complex values are not
    # truncated to their real part, and float64 precision is preserved.
    return torch.stack(means, dim=-1)
###############################################################################
[docs]
def DT1_wavelet_build(N0, J, L, WType):
    """
    Generate a set of 2D planar wavelets in Fourier space, both at full
    resolution and in a multi-resolution setting, with the related parameters.

    Default values are used for J, L and WType when they are None:
    J = int(log2(min(N0))) - 2, L = 4 and WType = "Crappy".

    Parameters
    ----------
    N0 : tuple of int
        Initial size (Nx, Ny) of the data.
    J : int or None
        Number of scales.
    L : int or None
        Number of orientations.
    WType : str or None
        Type of wavelets. Only "Crappy" is implemented: a test set of
        thresholded Gaussians built by gaussian_bank.

    Returns
    -------
    wavelet_array : torch.Tensor of shape (J, L, Nx, Ny)
        Wavelets at all J scales and L orientations, at N0 resolution.
    wavelet_array_MR : list of torch.Tensor of shape (L, Nx_j, Ny_j)
        For each scale j, the L wavelets at the resolution of level
        j_to_dg[j].
    dg_max : int
        Maximum downgrading level.
    j_to_dg : list of int
        Downgrading level used at each scale j, i.e. min(j, dg_max).
    Single_Kernel : bool
        Always False here: the scales do not share one set of L kernels.
    mask_opt : bool
        Always False here: masks cannot be used during the convolution.
    J, L, WType
        The input parameters, with defaults applied.

    Raises
    ------
    ValueError
        If WType is not a supported wavelet type.
    """
    # Default values
    if J is None:
        J = int(np.log2(min(N0))) - 2
    if L is None:
        L = 4
    if WType is None:
        WType = "Crappy"

    if WType != "Crappy":
        # Fail clearly instead of reaching the return with undefined locals.
        raise ValueError(
            f"Unknown wavelet type {WType!r} for DT1 (only 'Crappy' is implemented)"
        )

    # Crappy wavelet set for test. A proper one should be implemented.
    # Full resolution wavelet set.
    wavelet_array = gaussian_bank(J, L, N0)
    # Deepest level keeping a minimum size of 16 = 2 * 8 along the smallest
    # axis; deeper levels would store tensors at the same effective resolution.
    dg_max = int(np.log2(min(N0)) - 4)
    # Multi-resolution wavelets: scale j is stored at level min(j, dg_max).
    j_to_dg = [min(j, dg_max) for j in range(J)]
    wavelet_array_MR = [
        DT1_subsampling_func(wavelet_array[j], True, N0, 0, dg, None)[0]
        for j, dg in enumerate(j_to_dg)
    ]

    # Values of Single_Kernel and mask_opt for this data type.
    Single_Kernel = False
    mask_opt = False
    return (
        wavelet_array,
        wavelet_array_MR,
        dg_max,
        j_to_dg,
        Single_Kernel,
        mask_opt,
        J,
        L,
        WType,
    )
###############################################################################
[docs]
def gaussian_2d_rotated(mu, sigma, angle, size, eps=0.1):
    """
    Generate a thresholded 2D Gaussian, centered at an offset mu from the
    image center along the axis rotated by `angle`.

    Parameters
    ----------
    mu : float
        Offset along the rotated axis from the image center (in pixels).
    sigma : float
        Isotropic standard deviation (spread), in pixels.
    angle : float
        Rotation angle in radians (0 to pi). For angle=0 the offset is along
        the second (y) axis.
    size : tuple of int
        Grid size (M, N) = (height, width).
    eps : float, optional
        Values below this threshold are set to zero (default 0.1).

    Returns
    -------
    torch.Tensor
        A (M, N) Gaussian with values in [0, 1] (exactly 1 at its center if
        the center falls on a grid point). It is NOT normalized in L2 norm.
    """
    M, N = size
    x = torch.linspace(0, M - 1, M)
    y = torch.linspace(0, N - 1, N)
    X, Y = torch.meshgrid(x, y, indexing="ij")
    # Image center
    cx = M / 2
    cy = N / 2
    # torch.as_tensor accepts both Python floats and tensors without copying
    # or warning (torch.tensor warns when given a tensor).
    angle_t = torch.as_tensor(angle)
    cos_a = torch.cos(angle_t)
    sin_a = torch.sin(angle_t)
    # Offset the center along the rotated axis.
    center_x = cx - mu * sin_a
    center_y = cy + mu * cos_a
    # Gaussian centered at (center_x, center_y)
    G = torch.exp(-((X - center_x) ** 2 + (Y - center_y) ** 2) / (2 * sigma**2))
    # Truncate the tails so the filter has compact support.
    G[G < eps] = 0
    return G
###############################################################################
[docs]
def gaussian_bank(J, L, size, base_mu=None, base_sigma=None):
    """
    Build a bank of J x L rotated Gaussian filters, defined in Fourier space.

    Scale j uses the offset base_mu / 2**j and the spread base_sigma / 2**j.
    Orientation l uses the angle l * pi / L. Each Gaussian is generated by
    gaussian_2d_rotated around the grid center. The bank is then passed
    through fftshift, which moves the grid center to index (0, 0) for even
    sizes, and the (0, 0) coefficient is set to zero.

    Parameters
    ----------
    J : int
        Number of dyadic scales.
    L : int
        Number of orientations, evenly spaced in [0, pi).
    size : tuple of int
        Grid size (M, N).
    base_mu : float, optional
        Offset at scale j=0, the largest one.
        Default: min(M, N) / (2 * sqrt(2)).
    base_sigma : float, optional
        Spread at scale j=0, the largest one.
        Default: base_mu / (2 * sqrt(2)).

    Returns
    -------
    torch.Tensor
        Tensor of shape (J, L, M, N). The entries are the thresholded
        Gaussians of gaussian_2d_rotated; they are not L2-normalized.
    """
    M, N = size
    sqrt2 = torch.sqrt(torch.tensor(2.0))
    if base_mu is None:
        base_mu = min(M, N) / (2 * sqrt2)
    if base_sigma is None:
        base_sigma = base_mu / (2 * sqrt2)

    bank = torch.empty((J, L, M, N))
    for j in range(J):
        # Both the offset and the spread shrink dyadically with j.
        scale_mu = base_mu / (2**j)
        scale_sigma = base_sigma / (2**j)
        for l in range(L):
            bank[j, l] = gaussian_2d_rotated(
                mu=scale_mu,
                sigma=scale_sigma,
                angle=float(l) * torch.pi / L,
                size=size,
            )

    # Move the grid center to (0, 0), the zero frequency of torch's FFT
    # layout, then remove the zero-frequency component.
    bank = torch.fft.fftshift(bank, dim=(-2, -1))
    bank[:, :, 0, 0] = 0
    return bank
###############################################################################
[docs]
def DT1_wavelet_conv_full(data, wavelet_set, Fourier, mask):
    """
    Convolve data by the entire wavelet set at full resolution.

    The convolution is a product in Fourier space. Real-space data is first
    transformed with the orthonormal FFT (norm="ortho"), the same convention
    used by the subsampling functions of this module. The result is thus the
    same whether the caller passes real-space data or data already
    transformed by those functions.
    No mask is allowed in this DT; `mask` is not used.

    Parameters
    ----------
    data : torch.Tensor of shape (..., Nx, Ny)
        Data whose convolution is computed, at N0 resolution.
    wavelet_set : torch.Tensor of shape (J, L, Nx, Ny)
        Wavelet set in Fourier space.
    Fourier : bool
        Fourier status of the data.
    mask : None
        Placeholder for a mask (not used).

    Returns
    -------
    conv : torch.Tensor of shape (..., J, L, Nx, Ny)
        Convolution between data and wavelet_set, in Fourier space.
    Fourier : bool
        Fourier status of the convolution (always True in this DT).
    """
    # Pass data in Fourier if in real space (orthonormal convention).
    _data = data if Fourier else torch.fft.fft2(data, norm="ortho")
    # Broadcast over the (J, L) axes: (..., 1, 1, Nx, Ny) * (J, L, Nx, Ny).
    conv = _data[..., None, None, :, :] * wavelet_set
    return conv, True
###############################################################################
[docs]
def DT1_wavelet_conv_full_MR(data, wavelet_set, Fourier, j_to_dg, mask_MR):
    """
    Convolve data by the entire wavelet set in a multi-resolution setting.

    Scale j is convolved with the data at level j_to_dg[j]. The convolution
    is a product in Fourier space. Real-space data is transformed with the
    orthonormal FFT (norm="ortho"), the convention of
    DT1_subsampling_func_toMR. Each level is transformed only once, even if
    several scales use it.
    No mask is allowed in this DT; `mask_MR` is not used.

    Parameters
    ----------
    data : list of torch.Tensor of shape (..., Nx_dg, Ny_dg)
        Multi-resolution data, indexed by dg in range(dg_max + 1).
    wavelet_set : list of torch.Tensor of shape (L, Nx_j, Ny_j)
        Multi-resolution wavelet set, one entry per scale j. Entry j must be
        at the resolution of level j_to_dg[j].
    Fourier : bool
        Fourier status of the data.
    j_to_dg : list of int
        Downgrading level used at each scale j.
    mask_MR : None
        Placeholder for multi-resolution masks (not used).

    Returns
    -------
    conv : list of torch.Tensor of shape (..., L, Nx_j, Ny_j)
        Convolution between data and wavelet_set for each scale j, in
        Fourier space.
    Fourier : bool
        Fourier status of the convolution (always True in this DT).
    """
    conv = []
    # Fourier-space data, computed once per level (scales may share a level).
    data_hat = {}
    for j, wavelet_j in enumerate(wavelet_set):
        dg = j_to_dg[j]
        if dg not in data_hat:
            data_hat[dg] = (
                data[dg] if Fourier else torch.fft.fft2(data[dg], norm="ortho")
            )
        # Broadcast over the L orientations: (..., 1, Nx, Ny) * (L, Nx, Ny).
        conv.append(data_hat[dg][..., None, :, :] * wavelet_j)
    return conv, True
###############################################################################
[docs]
def DT1_wavelet_conv(data, wavelet_j, Fourier, mask_MR):
    """
    Convolve data by the L oriented wavelets of a given scale j.

    Both the data and the wavelets must be at the Nj resolution. The
    convolution is a product in Fourier space. Real-space data is transformed
    with the orthonormal FFT (norm="ortho"), the convention used by the
    subsampling functions of this module.
    No mask is allowed in this DT; `mask_MR` is not used.

    Parameters
    ----------
    data : torch.Tensor of shape (..., Nx_j, Ny_j)
        Data whose convolution is computed, at resolution Nj.
    wavelet_j : torch.Tensor of shape (L, Nx_j, Ny_j)
        Wavelets at scale j, in Fourier space.
    Fourier : bool
        Fourier status of the data.
    mask_MR : None
        Placeholder for multi-resolution masks (not used).

    Returns
    -------
    conv : torch.Tensor of shape (..., L, Nx_j, Ny_j)
        Convolution between data and wavelet_j, in Fourier space.
    Fourier : bool
        Fourier status of the convolution (always True in this DT).
    """
    # Pass data in Fourier if in real space (orthonormal convention).
    _data = data if Fourier else torch.fft.fft2(data, norm="ortho")
    # Broadcast over the L orientations: (..., 1, Nx, Ny) * (L, Nx, Ny).
    conv = _data[..., None, :, :] * wavelet_j
    return conv, True
###############################################################################
###############################################################################
###############################################################################
###############################################################################
[docs]
def DT1_subsampling_func_fromMR(param):
    """
    Placeholder for this data type: not implemented.

    The argument is ignored and None is returned.
    """
    return None