STL_main package

Submodules

STL_main.DataType1 module

Created on Sun Jul 27 14:23:20 2025

Example methods for a test data type.

2D planar maps with convolution in Fourier space.

This class makes all computations in torch.

Characteristics:
  • in PyTorch

  • assumes real-valued maps (in real space)

  • N0 gives the x and y sizes for arrays shaped (…, Nx, Ny).

  • masks are not supported in convolutions

STL_main.DataType1.DT1_mean_func_MR(array, N0, list_dg, square, mask_MR)[source]

Compute the mean of a list of tensors on their last two dimensions. The other dimensions of the tensors must match.

These means are stacked on the last dimension of the output tensor.

A multi-resolution mask in real space can be given. It should have unit mean at each resolution.

Parameters:
  • array (list of torch.Tensor) – List of input tensors for which the mean is to be computed.

  • N0 (tuple of int) – Initial resolution of the data (not used in this function).

  • list_dg (list of int) – List of downsampling factors of the data (not used in this function).

  • square (bool) – If True, compute the quadratic mean.

  • mask_MR (list of torch.Tensor, optional) – List of mask tensors at the relevant resolutions. Their last dimensions should match those of the input arrays, and they should have unit mean at each resolution.

Returns:

Mean of input arrays, stacked on the last dimension.

Return type:

torch.Tensor
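The behaviour described above can be sketched in NumPy (the module itself works on torch tensors). This is a hypothetical stand-in, not the module's code: the name mean_func_mr_sketch is illustrative, "quadratic mean" is assumed to mean the mean of |x|^2 without a square root, and the mask is applied as a plain product, which acts as a weighted mean because each mask is assumed to have unit mean.

```python
import numpy as np

def mean_func_mr_sketch(arrays, square=False, mask_mr=None):
    """Mean of each array over its last two axes, stacked on a new last axis.

    Hypothetical NumPy stand-in for DT1_mean_func_MR.
    """
    means = []
    for i, a in enumerate(arrays):
        # Assumption: square=True averages |a|**2 (no square root taken).
        values = np.abs(a) ** 2 if square else a
        if mask_mr is not None:
            # Unit-mean masks turn the plain mean into a weighted mean.
            values = values * mask_mr[i]
        means.append(values.mean(axis=(-2, -1)))
    # Leading dimensions must match across the list for np.stack to work.
    return np.stack(means, axis=-1)
```

For a batch of 3 maps stored at 8x8 and 4x4, the output has shape (3, 2): one mean per resolution on the last axis.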

STL_main.DataType1.DT1_subsampling_func(array, Fourier, N0, dg, dg_out, mask_MR)[source]

Downsample the data to the specified resolution.

Note: Masks are not supported in this data type.

Parameters:
  • array (torch.Tensor) – Input tensor to be downsampled.

  • Fourier (bool) – Indicates whether input array is in Fourier space.

  • N0 (tuple of int) – Initial resolution of the data.

  • dg (int) – Current downsampling factor of the data.

  • dg_out (int) – Desired downsampling factor of the data.

  • mask_MR (None) – Placeholder for mask, not used in this function.

Returns:

  • torch.Tensor – Downsampled data at the desired downgrading factor dg_out.

  • fourier (bool) – Indicates whether output array is in Fourier space.

STL_main.DataType1.DT1_subsampling_func_fromMR(param)[source]
STL_main.DataType1.DT1_subsampling_func_toMR(array, Fourier, N0, dg_max, mask_MR)[source]

Generate a list of downsampled versions of the input array, from resolution dg=0 to dg=dg_max, following list_dg = range(dg_max + 1). The input array is expected at dg=0 resolution.

Note: Masks are not supported in this data type.

Parameters:
  • array (torch.Tensor) – Input tensor to be downsampled.

  • Fourier (bool) – Indicates whether the array is in Fourier space.

  • N0 (tuple of int) – Initial resolution of the data.

  • dg_max (int) – Maximum downsampling factor.

  • mask_MR (None) – Placeholder for mask, not used in this function.

Returns:

  • list of torch.Tensor – List of downsampled tensors for each downgrading factor from dg=0 to dg=dg_max.

  • fourier (bool) – Indicates whether output array is in Fourier space.

STL_main.DataType1.DT1_wavelet_build(N0, J, L, WType)[source]

Generate a set of 2D planar wavelets in Fourier space, both at full resolution and in a multi-resolution setting, as well as the related parameters.

Default values for J, L, and WType are used if None.

Parameters:
  • N0 (tuple of int) – Initial size of the array (can be multiple dimensions).

  • J (int) – Number of scales.

  • L (int) – Number of orientations.

  • WType (str) – Type of wavelets (e.g., “Morlet” or “Bump-Steerable”).

Returns:

  • wavelet_array (torch.Tensor of size (J,L,N0)) – Array of wavelets for the J*L scales and orientations at N0 resolution.

  • wavelet_array_MR (list of torch.Tensor of size (L,Nj)) – List of arrays of L wavelets at all J scales, each at its Nj resolution.

  • dg_max (int) – Maximum dg downsampling factor.

  • j_to_dg (list of int) – List of actual dg_j resolutions at each scale j.

  • Single_Kernel (bool, False here) – Whether the convolutions at all scales are done with the same set of L oriented wavelets.

  • mask_opt (bool, False here) – Whether masks can be used during the convolution.

STL_main.DataType1.DT1_wavelet_conv(data, wavelet_j, Fourier, mask_MR)[source]

Perform a convolution of the data with the wavelets at a given scale j, for all L orientations. Both the data and the wavelets should be at the Nj resolution.

No mask is allowed in this DT.

Parameters:
  • data – Data whose convolution is computed, at resolution Nj.

  • wavelet_j – Wavelet set at scale j.

  • Fourier (bool) – Fourier status of the data.

  • mask_MR (list of torch.Tensor of size (…,Nj)) – Multi-resolution masks for the convolution. None is expected.

Returns:

  • conv (torch.Tensor (…, L, Nj)) – Convolution between data and the wavelet set at scale j.

  • Fourier (bool) – Fourier status of the convolution (True in this DT).
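Since the convolution is done in Fourier space, it reduces to a pointwise product broadcast over the L orientations. A minimal NumPy sketch, assuming the wavelets wavelet_j of shape [L, Nj, Nj] are already in Fourier space and the data are real-space maps unless fourier=True; the function name is hypothetical.

```python
import numpy as np

def wavelet_conv_sketch(data, wavelet_j, fourier=False):
    """Convolve data [..., Nj, Nj] with L wavelets wavelet_j [L, Nj, Nj].

    wavelet_j is assumed to be in Fourier space. Returns the convolution in
    Fourier space, of shape [..., L, Nj, Nj], and its Fourier status (True).
    """
    # FFT over the last two axes unless the data are already in Fourier space.
    data_hat = data if fourier else np.fft.fft2(data)
    # Add an orientation axis so the data broadcast against the L wavelets.
    conv_hat = data_hat[..., None, :, :] * wavelet_j
    return conv_hat, True
```

The full-resolution variant follows the same pattern with a [J, L, N0] wavelet array and two inserted axes.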

STL_main.DataType1.DT1_wavelet_conv_full(data, wavelet_set, Fourier, mask)[source]

Perform a convolution of data by the entire wavelet set at full resolution.

No mask is allowed in this DT.

Parameters:
  • data – Data whose convolution is computed.

  • wavelet_set – Wavelet set.

  • Fourier (bool) – Fourier status of the data.

  • mask – Mask for the convolution (not supported in this DT).

Returns:

  • conv (torch.Tensor (…, J, L, N0)) – Convolution between data and wavelet_set.

  • Fourier (bool) – Fourier status of the convolution (True in this DT).

STL_main.DataType1.DT1_wavelet_conv_full_MR(data, wavelet_set, Fourier, j_to_dg, mask_MR)[source]

Perform a convolution of the data with the entire wavelet set in a multi-resolution setting.

A multi-resolution mask can be given.

Parameters:
  • data – Multi-resolution data whose convolution is computed. The associated dg are list_dg = range(dg_max + 1).

  • wavelet_set – Multi-resolution wavelet set. The associated dg are j_to_dg.

  • Fourier (bool) – Fourier status of the data.

  • j_to_dg (list of int) – List of actual dg_j resolutions at each scale j.

  • mask_MR (list of torch.Tensor of size (…,Nj)) – Multi-resolution masks for the convolution. None is expected.

Returns:

  • conv (list of torch.Tensor (…, L, Nj)) – Convolution between data and wavelet_set.

  • Fourier (bool) – Fourier status of the convolution (True in this DT).
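The multi-resolution loop can be sketched as follows: for each scale j, the data stored at dg = j_to_dg[j] are multiplied by the L wavelets of that scale. This NumPy sketch assumes data and wavelets are both already in Fourier space and square; the function name is hypothetical.

```python
import numpy as np

def wavelet_conv_full_mr_sketch(data_mr, wavelet_mr, j_to_dg):
    """data_mr[dg]: data in Fourier space at level dg, for dg in range(dg_max + 1).
    wavelet_mr[j]: L wavelets [L, Nj, Nj] in Fourier space at level j_to_dg[j].
    Returns a list (len J) of convolutions of shape [..., L, Nj, Nj]."""
    conv = []
    for j, wav in enumerate(wavelet_mr):
        # Pick the data at the resolution associated with scale j.
        d_hat = data_mr[j_to_dg[j]]
        conv.append(d_hat[..., None, :, :] * wav)
    return conv
```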

STL_main.DataType1.gaussian_2d_rotated(mu, sigma, angle, size)[source]

Generate a rotated 2D Gaussian centered at an offset mu along the rotated axis from image center.

Parameters:
  • mu (float) – Offset along the rotated axis from the image center (in pixels).

  • sigma (float) – Isotropic standard deviation (spread).

  • angle (float) – Rotation angle in radians (0 to pi).

  • size (tuple of int) – Grid size (M, N) = (height, width).

Returns:

A 2D Gaussian of shape (M, N) with unit L2 norm.

Return type:

torch.Tensor
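A NumPy sketch of such a Gaussian, under stated assumptions: the image center is taken at pixel (M // 2, N // 2), the angle is measured from the column (x) axis, and the function name is hypothetical.

```python
import numpy as np

def gaussian_2d_rotated_sketch(mu, sigma, angle, size):
    """Isotropic Gaussian of shape (M, N), centred at distance mu from the
    image center along the direction `angle`, with unit L2 norm."""
    M, N = size
    # Rows index y (height), columns index x (width); center at (M//2, N//2).
    y, x = np.meshgrid(np.arange(M) - M // 2, np.arange(N) - N // 2, indexing="ij")
    # Offset of the Gaussian center along the rotated axis.
    cx, cy = mu * np.cos(angle), mu * np.sin(angle)
    g = np.exp(-((x - cx) ** 2 + (y - cy) ** 2) / (2 * sigma ** 2))
    return g / np.linalg.norm(g)  # Frobenius norm = L2 norm of the map
```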

STL_main.DataType1.gaussian_bank(J, L, size, base_mu=None, base_sigma=None)[source]

Generate a bank of rotated and scaled 2D Gaussians.

Parameters:
  • J (int) – Number of dyadic scales.

  • L (int) – Number of orientations.

  • base_sigma (float) – Smallest sigma (spread).

  • base_mu (float) – Base offset along the rotated axis.

  • size (tuple of int) – Grid size (M, N).

Returns:

A tensor of shape (J, L, M, N), each entry L2-normalized.

Return type:

torch.Tensor

STL_main.DataType2 module

Created on Sun Jul 27 14:23:20 2025

Example methods for a test data type.

STL_main.DataType2.DT2_Mask_toMR(mask, N0, dg_max)[source]
STL_main.DataType2.DT2_copy(array)[source]
STL_main.DataType2.DT2_cov_func(param)[source]
STL_main.DataType2.DT2_findN(array)[source]
STL_main.DataType2.DT2_fourier(array)[source]
STL_main.DataType2.DT2_ifourier(array)[source]
STL_main.DataType2.DT2_mean_func(param)[source]
STL_main.DataType2.DT2_mean_func_MR(param)[source]
STL_main.DataType2.DT2_modulus(array)[source]
STL_main.DataType2.DT2_subsampling_func(array, Fourier, N0, dg, dg_out, mask_MR)[source]
STL_main.DataType2.DT2_subsampling_func_fromMR(param)[source]
STL_main.DataType2.DT2_subsampling_func_toMR(array, Fourier, N0, dg, dg_max, mask_MR)[source]
STL_main.DataType2.DT2_to_array(array)[source]
STL_main.DataType2.DT2_wavelet_build(param)[source]
STL_main.DataType2.DT2_wavelet_conv(param)[source]
STL_main.DataType2.DT2_wavelet_conv_full(param)[source]
STL_main.DataType2.DT2_wavelet_conv_full_MR(param)[source]

STL_main.LossOptim module

Main structure of STL

Tentative proposal by EA

Still a work in progress; in particular, I am trying to identify the minimum set of necessary parameters.

I work only with single-channel maps here; this needs to be extended to multi-channel maps.

STL_main.LossOptim.synth_from_map(data_target, mode='1to1', N_new=None, J=None, L=None, WType=None, SC='ScatCov')[source]

Perform a synthesis from a target map, or from an ensemble of target maps.

The data_target needs to be a stl_data object, of shape (Nb,Nc,N).

The synthes

STL_main.LossOptim.synth_from_stats(st_target, mode='1to1', N_new=None)[source]

STL_main.STL_2D_FFT_Torch module

Created on Wed Nov 14:07 2018

class STL_main.STL_2D_FFT_Torch.CS_operator_2D_FFT_torch(shape, n_bins=None, J=None, device=device(type='cpu'), dtype=torch.float64, get_crop_border_size_method='flexible_crop')[source]

Bases: object

Class whose instances correspond to a cross spectrum operator for 2D FFT data. The operator is applied through the apply method and is DT-dependent.

apply(data, compute_cross_spectrum_matrix=None, get_crop_border_size_method=None)[source]

Compute the cross spectra of the array attribute of the input data (only the auto-spectra, i.e. power spectra, by default).

Parameters:
  • data – Input data whose array attribute’s cross spectra are to be computed.

  • compute_cross_spectrum_matrix – Boolean matrix indicating which cross-spectra to compute. If None, only auto-spectra are computed.

  • get_crop_border_size_method – Method used to determine the crop border size for non-PBC data. If None, the default method specified at operator initialization is used.

Returns:

Cross spectrum values of shape […, Nc, Nc, n_bins]

Return type:

torch.Tensor
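For periodic maps, a binned cross spectrum of shape [Nb, Nc, Nc, n_bins] can be sketched as below. This is a NumPy illustration only: the bins are assumed to be linear radial bins in |k| (cycles per pixel), no border cropping is applied (the non-PBC crop handled by the operator is omitted), and the function name is hypothetical. The result is returned as complex values; auto-spectra are real.

```python
import numpy as np

def cross_spectrum_sketch(x, n_bins):
    """x: real maps [Nb, Nc, N, M]. Returns complex [Nb, Nc, Nc, n_bins]."""
    N, M = x.shape[-2:]
    # Orthonormal-style scaling so that a single bin recovers mean(x**2).
    x_hat = np.fft.fft2(x) / np.sqrt(N * M)
    kx = np.fft.fftfreq(N)[:, None]
    ky = np.fft.fftfreq(M)[None, :]
    k = np.sqrt(kx ** 2 + ky ** 2)  # |k| on the FFT grid
    edges = np.linspace(0.0, k.max(), n_bins + 1)
    bin_idx = np.clip(np.digitize(k, edges) - 1, 0, n_bins - 1)
    # All channel pairs: cross[b, c1, c2] = x_hat[b, c1] * conj(x_hat[b, c2]).
    cross = x_hat[:, :, None] * np.conj(x_hat[:, None, :])
    out = np.zeros(cross.shape[:3] + (n_bins,), dtype=complex)
    for b in range(n_bins):
        out[..., b] = cross[..., bin_idx == b].mean(axis=-1)
    return out
```

With n_bins=1 the auto-spectrum equals the mean of x**2 (Parseval), which is a convenient sanity check.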

build_mask_crop(array, border)[source]

Crop an array by removing ‘border’ pixels from each side along the last two dimensions, then pad each cropped side with zeros so that the output keeps the same spatial shape. The border may differ for each bin.

Parameters:
  • array (torch.Tensor) – Input array to be cropped.

  • border (torch.Tensor) – Number of pixels to remove from each side. Shape [n_bins].

Returns:

Cropped array. Shape [Nb, Nc, n_bins, N, M].

Return type:

torch.Tensor
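Cropping and zero-padding back to the same shape is equivalent to multiplying by a per-bin binary mask. A NumPy sketch, assuming an input of shape [Nb, Nc, N, M] and integer borders; the function name is hypothetical.

```python
import numpy as np

def build_mask_crop_sketch(array, border):
    """array: [Nb, Nc, N, M]; border: integers [n_bins].
    Returns [Nb, Nc, n_bins, N, M]: for bin b, border[b] pixels on each
    side of the last two axes are set to zero."""
    N, M = array.shape[-2:]
    rows = np.arange(N)[None, :, None]      # [1, N, 1]
    cols = np.arange(M)[None, None, :]      # [1, 1, M]
    b = np.asarray(border)[:, None, None]   # [n_bins, 1, 1]
    # True inside the cropped window, False on the removed borders.
    mask = (rows >= b) & (rows < N - b) & (cols >= b) & (cols < M - b)
    return array[:, :, None] * mask          # broadcast over n_bins
```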

estimate_crop_borders()[source]
classmethod k_lambda(t, lam)[source]
classmethod kappa_lambda(t, lam)[source]
plot_cross_spectrum(cs_tensor, b=0, c1=0, c2=0, label=None, color='b')[source]

Plot the cross spectrum (the power spectrum if c1 == c2).

Parameters:
  • cs_tensor (torch.Tensor of shape [Nb, Nc, Nc, n_bins]) – Cross spectrum values to plot.

  • b (int) – Batch index (0<=b<Nb).

  • c1 (int) – First channel index (0<=c1<Nc).

  • c2 (int) – Second channel index (0<=c2<Nc).

Return type:

None

static s(t)[source]
classmethod s_lambda(t, lam)[source]
class STL_main.STL_2D_FFT_Torch.STL_2D_FFT_Torch(array: Tensor, pbc: bool | None = None, dg: int | None = None, N0: tuple[int, int] | None=None, conv_history: list[int] = <factory>, fourier_status: bool = False)[source]

Bases: Base_DataClass

STL_2D_FFT_Torch child class for 2D planar STL FFT using PyTorch

Inherits Base_DataClass.

See Base_DataClass for parameter descriptions.

Additional parameters:
  • fourier_status (bool) – Indicates whether the data is in Fourier space (True) or real space (False).

DT: ClassVar[str] = 'Planar2D_FFT_torch'
divide(data2, epsilon=1e-08, pow=1.0, inplace=False)[source]

Divide self.array by data2.array raised to a power, with a small epsilon added to the denominator for numerical stability.

  • If data2.array is in real space:

    self.array is converted to real space (if needed), and the division is performed in real space.

  • If data2.array is in Fourier space:

    self.array is converted to Fourier space (if needed), and the division is performed in Fourier space (i.e., deconvolution in real space).

Parameters:
  • data2 (STL_2D_FFT_Torch) – Another instance whose array is used as the denominator. Its Fourier status determines the computation domain.

  • epsilon (float, optional) – Small constant added to the denominator for numerical stability (default is 1e-8).

  • pow (float, optional) – Exponent applied to the denominator (default is 1).

  • inplace (bool) – If True, performs the operation in-place and returns self. If False, returns a new instance.

Returns:

Result of the division in the appropriate domain.

Return type:

STL_2D_FFT_Torch

fourier(inplace=False)[source]

Compute the Fourier Transform on the last two dimensions of the input tensor.

Parameters:

inplace (bool) – If True, acts in-place and returns self. If False, returns a new STL_2D_FFT_Torch instance.

Returns:

STL_2D_FFT_Torch instance whose array attribute is in the Fourier domain.

Return type:

STL_2D_FFT_Torch

fourier_status: bool = False
get_CS_op(*args, **kwargs)[source]
get_ST_op(*args, **kwargs)[source]
get_wavelet_op(*args, **kwargs)[source]

Abstract method. Must be implemented by the child class to return the specific WaveletOperator class for that child.

ifourier(inplace=False)[source]

Compute the inverse Fourier Transform on the last two dimensions of the input tensor.

Parameters:

inplace (bool) – If True, acts in-place and returns self. If False, returns a new STL_2D_FFT_Torch instance.

Returns:

STL_2D_FFT_Torch instance whose array attribute is the inverse Fourier transform (in real space).

Return type:

STL_2D_FFT_Torch

modulus(inplace=False)[source]

Compute the modulus (absolute value) of the array attribute of data.

Parameters:

inplace (bool) – If True, acts in-place and returns self. If False, returns a new STL_2D_FFT_Torch instance.

Returns:

STL_2D_FFT_Torch instance whose array attribute is the modulus

Return type:

STL_2D_FFT_Torch

set_fourier_status(target_fourier_status, inplace=False)[source]

Put the data in the desired Fourier status (target_fourier_status).

Parameters:
  • target_fourier_status (bool) – Desired Fourier status: True = Fourier space, False = real space.

  • inplace (bool) – If True, acts in-place and returns self. If False, returns a new STL_2D_FFT_Torch instance.

Returns:

STL_2D_FFT_Torch instance in the desired Fourier status.

Return type:

STL_2D_FFT_Torch
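The fourier / ifourier / set_fourier_status trio follows a common pattern: a data object carries a Fourier flag, and switching status transforms the array only when needed, either in place or on a copy. The sketch below illustrates that pattern with a hypothetical NumPy dataclass; it is not the STL_2D_FFT_Torch implementation.

```python
import numpy as np
from dataclasses import dataclass, replace

@dataclass
class FourierArraySketch:
    """Hypothetical minimal stand-in for a data object with a Fourier flag."""
    array: np.ndarray
    fourier_status: bool = False

    def set_fourier_status(self, target, inplace=False):
        # Shallow copy unless in-place; the array is reassigned, not mutated.
        out = self if inplace else replace(self)
        if target and not out.fourier_status:
            out.array = np.fft.fft2(out.array)    # real -> Fourier (last two axes)
        elif not target and out.fourier_status:
            out.array = np.fft.ifft2(out.array)   # Fourier -> real
        out.fourier_status = target
        return out
```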

class STL_main.STL_2D_FFT_Torch.WaveletOperator2D_FFT_torch(N0, J=None, L=None, WType='Bump-Steerable', DT='Planar2D_FFT_torch', device=device(type='cpu'), dtype=torch.float64, get_crop_border_size_method=None)[source]

Bases: object

Class whose instances correspond to a wavelet transform operator. The wavelet set and the operator are built during initialization. The operator is applied through the apply method. This method is DT-dependent: each DT provides its own implementation, with a common method and attribute structure.

Multi-resolution is handled through two parameters:
  • dg_max, which indicates the maximum dg resolution;

  • j_to_dg, which indicates the dg_j resolution associated with each scale j.

For example, if you work with J=6 wavelets and N0=256, you can have:
  • dg_max=4, with associated dg factors (0, 1, 2, 3, 4)

  • true downsampling factors (1, 2, 4, 8, 16)

  • actual resolutions (256, 128, 64, 32, 16)

  • a j_to_dg list (0, 0, 1, 1, 2, 3) associated to j in range(J=6)
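The numbers in this example follow from N_dg = N0 // 2^dg. The short Python snippet below reproduces them, using the j_to_dg list given above, and also derives the resolution Nj at which each scale j is convolved.

```python
N0, dg_max = 256, 4
j_to_dg = [0, 0, 1, 1, 2, 3]           # example list from the text, J = 6

dg_factors = list(range(dg_max + 1))    # list_dg = range(dg_max + 1)
true_factors = [2 ** dg for dg in dg_factors]
resolutions = [N0 // 2 ** dg for dg in dg_factors]

# Resolution Nj at which each scale j is convolved.
Nj = [N0 // 2 ** dg for dg in j_to_dg]
```

This gives true factors [1, 2, 4, 8, 16], resolutions [256, 128, 64, 32, 16], and Nj = [256, 256, 128, 128, 64, 32].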

The wavelet convolution is DT-dependent, and is performed either in real or in Fourier space.

There are two main types of wavelet arrays:
  • Single_Kernel==True: a set of L oriented wavelets is defined at a single pixelization. Convolutions at different scales are then done by successive subsampling and convolution in pixel space with this single set of L oriented wavelets, so convolutions at all scales cannot be done at the initial N0 resolution.

  • Single_Kernel==False: a set of J*L dilated and rotated wavelets is defined at the initial N0 resolution, and a convolution at all scales can be performed at this initial resolution. Convolutions can also be done in a multi-resolution scheme, where the convolution at each scale is done at the appropriate downsampling factor.

This means that different wavelets need to be stored:
  • Single_Kernel==True: a single set of L wavelets, stored in the wavelet_array attribute.

  • Single_Kernel==False: a set of J*L wavelets, stored both at the initial N0 resolution in wavelet_array and in a multi-resolution framework in wavelet_array_MR.

The wavelet_j method then returns the correct quantity when a convolution is performed.

Note: when Single_Kernel==True, j_to_dg = range(J). This is not necessarily the case when Single_Kernel==False.

See apply for more details.

Parameters:
  • DT (str) – Type of data (1d, 2d planar, HealPix, 3d).

  • N0 (tuple of int) – Initial size of the array (can be multiple dimensions).

  • J (int) – Number of scales.

  • L (int) – Number of orientations.

  • WType (str) – Type of wavelets (e.g., “Morlet” or “Bump-Steerable”).

Attributes:
  • parent parameters (DT, N0, J, L, WType)

  • dg_max (int) – Maximum dg resolution of the wavelet transform (DT- and WType-dependent).

  • j_to_dg (list of int) – List of actual dg_j resolutions at each scale j.

  • wavelet_array (torch.Tensor) – Array of wavelets at L orientations if Single_Kernel==True; array of wavelets at J*L scales and orientations at N0 resolution if Single_Kernel==False.

  • wavelet_array_MR (list (len J) of arrays) – List of arrays of L wavelets at all J scales, each at its Nj resolution. Only if Single_Kernel==False.

  • Single_Kernel (bool) – Whether the convolution is done at all scales with the same L oriented wavelets.

Questions and to do:
  • Is Single_Kernel sufficient in itself? We could use two different attributes to separate doing the convolution with a kernel in real space from being able to do all convolutions at the initial N0 resolution.

  • We could similarly tie the use of masks to Single_Kernel==True.

  • Do we add a low-pass filter by default for j=J?

  • Do we impose j_to_dg = range(J) for simplicity and efficiency?

  • I propose to only have dyadic wavelets for the "main set" anyway; the inclusion of P00’ power spectrum terms can be done differently.

  • For __init__, we could ask either for DT and N, or for an stl_array instance from which DT and N are obtained. This could be added if useful.

  • A proper "by the book" set of wavelets should be implemented, with proper Littlewood-Paley and related conditions.

apply(data, j=None, target_fourier_status=None, **kwargs)[source]

Compute the Wavelet Transform (WT) of data. This method is DT-dependent: each DT provides its own implementation, with a common method and attribute structure.

Data should be an MR==False StlData instance. The wavelet transform can be computed either at all J*L scales and angles (fullJ), or at a given scale j (single_j) for all L orientations.

For code efficiency, this method requires an MR=True StlData instance for the masks at all resolutions, with list_dg = range(dg_max + 1).

The different modes are:
  • fullJ (j=None and MR=False): convolution at J*L scales and angles, without an MR framework, only possible if Single_Kernel==False. Input data should be an MR=False StlData instance at dg=0 resolution. No masks are a priori allowed in this case.

  • fullJ_MR (j=None and MR=True): convolution at J*L scales and angles, within an MR framework; always possible. Input data should be an MR=True StlData instance at all resolutions between dg=0 and dg_max [ list_dg = range(dg_max + 1) ].

  • single_j: convolution at L angles at a given scale j, within an MR framework. Input data should be an MR=False StlData instance with dg = dg_j.

Note: if j=None, the default value for MR is False if Single_Kernel==False, and True otherwise. For a single_j convolution, MR can only be True. mask_MR is allowed only if mask_opt==True.

Parameters:
  • data – Input data of the same DT/N0, which can be batched over several dimensions. It should be at dg=0 for fullJ, and at dg=dg_j for single_j.

  • j – Scale at which the convolution is done. Done at all scales if None.

  • target_fourier_status – Desired Fourier status of the output. If None, a DT-dependent default is used.

Returns:

  • WT – Wavelet transform of shape […, J, L, N0] if j is None, or […, L, Nj] if j is an int.

Questions and to do:
  • … could be necessary here, but only in the mean and cov functions, at the end of the computations.

  • We could consider computing the WT at fixed (j, l) values, if it helps distribute the computations for large batches.

  • To decide: do we impose a condition on mask_MR, such as it having unit mean?

  • I'm a bit skeptical about an internal Fourier transform here, since it means that the same transform could have to be done on multiple calls of this method.

  • For the convolution at a fixed scale, should we accept data that are not at Nj resolution and downsample them? This needs to be assessed against actual usage.

static bump_steerable_2d(omega_grid, L, xi0, width_factor=2.5, eps=1e-12)[source]

Generate a 2D bump steerable wavelet in Fourier space.

Parameters:
  • omega_grid (torch.Tensor) – Grid of frequencies in Fourier space, of shape [Nx, Ny, 2], where the last dimension corresponds to (omega_x, omega_y).

  • L (int) – Number of orientations (steerability order).

  • xi0 (float) – Center frequency xi = (xi0, 0) in Fourier space.

  • width_factor (float) – Constant chosen to best satisfy the Littlewood-Paley condition (optimized for L=4).

  • eps (float) – Small constant to avoid division by zero in the bump window.

Returns:

A 2D bump steerable wavelet in Fourier space, of shape [Nx, Ny].

Return type:

torch.Tensor

classmethod bump_steerable_bank(J, L, size)[source]

Generate a bank of 2D bump steerable wavelets in Fourier space.

Parameters:
  • J (int) – Number of dyadic scales.

  • L (int) – Number of orientations (steerability order).

  • size (tuple of int) – Grid size (Nx, Ny).

Returns:

A tensor of shape [J, L, Nx, Ny]

Return type:

torch.Tensor
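As an illustration, the classical bump-steerable form can be sketched in NumPy: a radial bump exp(-(r - xi)^2 / (xi^2 - (r - xi)^2)) supported on 0 < r < 2*xi, times an angular window cos^(L-1)(arg(omega) - theta) on the half-plane facing theta. This is an assumed textbook form; the module's width_factor and eps parametrization and its normalization are not reproduced. The dyadic centre frequencies xi0 / 2^j, the orientations theta_l = pi*l/L, the default xi0, and all function names are assumptions.

```python
import numpy as np

def bump_steerable_2d_sketch(omega_grid, L, xi0, theta=0.0):
    """Bump-steerable wavelet in Fourier space on omega_grid [Nx, Ny, 2]."""
    wx, wy = omega_grid[..., 0], omega_grid[..., 1]
    r = np.hypot(wx, wy)
    d = r - xi0
    inside = (r > 0) & (r < 2 * xi0)     # compact radial support
    radial = np.zeros_like(r)
    # |d| < xi0 inside the support, so the denominator stays positive.
    radial[inside] = np.exp(-d[inside] ** 2 / (xi0 ** 2 - d[inside] ** 2))
    c = np.cos(np.arctan2(wy, wx) - theta)
    angular = np.where(c > 0, c ** (L - 1), 0.0)  # half-plane angular window
    return radial * angular

def bump_steerable_bank_sketch(J, L, size, xi0=0.75 * np.pi):
    """Bank [J, L, Nx, Ny] with assumed dyadic scales and uniform orientations."""
    Nx, Ny = size
    wx = 2 * np.pi * np.fft.fftfreq(Nx)
    wy = 2 * np.pi * np.fft.fftfreq(Ny)
    grid = np.stack(np.meshgrid(wx, wy, indexing="ij"), axis=-1)  # [Nx, Ny, 2]
    return np.stack([
        np.stack([bump_steerable_2d_sketch(grid, L, xi0 / 2 ** j, np.pi * l / L)
                  for l in range(L)])
        for j in range(J)
    ])
```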

cov(data1, data2, remove_mean=False, dim=(-2, -1), **kwargs)[source]
static downsample(data, dg_out, normalize=True, inplace=True, target_fourier_status=True, **kwargs)[source]

Downgrade the data array to dg_out resolution by cropping in Fourier space.

Parameters:
  • data (STL_2D_FFT_Torch) – Data object to be downgraded (currently at dg_in resolution).

  • dg_out (int) – Target resolution after downgrading.

  • normalize (bool) – If True, normalize the output in Fourier space to keep the same real-space mean as the input.

  • inplace (bool) – If True, modifies data in-place. If False, returns a new instance.

  • target_fourier_status (bool) – If True, output is in Fourier space. If False, output is in real space and normalized to keep the same real mean as input.

Notes

To remain consistent, if you downgrade an image that was already downgraded, it is recommended to keep the output in the same domain (Fourier or real) as the previous data. Otherwise, normalization issues may appear.

Returns:

Data downgraded to dg_out resolution.

Return type:

STL_2D_FFT_Torch
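Downgrading by cropping in Fourier space keeps only the central low-frequency block of the 2D spectrum. The NumPy sketch below assumes real input maps in real space and returns a real-space output; the rescaling by (n*m)/(N*M) preserves the real-space mean, which is the role of normalize described above. The function name is hypothetical, and the handling of the Nyquist row/column of even-sized grids is simplified.

```python
import numpy as np

def downsample_fourier_crop_sketch(x, dg, dg_out):
    """Downsample real maps x [..., N, M] from level dg to level dg_out."""
    N, M = x.shape[-2:]
    f = 2 ** (dg_out - dg)
    n, m = N // f, M // f
    # Centre the spectrum: zero frequency moves to index (N // 2, M // 2).
    X = np.fft.fftshift(np.fft.fft2(x), axes=(-2, -1))
    r0, c0 = N // 2 - n // 2, M // 2 - m // 2
    Xc = X[..., r0:r0 + n, c0:c0 + m]
    # Undo the shift and rescale so that ifft2 preserves the real-space mean.
    Xc = np.fft.ifftshift(Xc, axes=(-2, -1)) * (n * m) / (N * M)
    return np.fft.ifft2(Xc).real
```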

estimate_crop_borders()[source]
static gaussian_2d_rotated(mu, sigma, angle, size)[source]

Generate a rotated 2D Gaussian centered at an offset mu along the rotated axis from image center.

Parameters:
  • mu (float) – Offset along the rotated axis from the image center (in pixels).

  • sigma (float) – Isotropic standard deviation (spread).

  • angle (float) – Rotation angle in radians (0 to pi).

  • size (tuple of int) – Grid size (Nx, Ny).

Returns:

A 2D Gaussian of shape [Nx, Ny].

Return type:

torch.Tensor

classmethod gaussian_bank(J, L, size, base_mu=None, base_sigma=None)[source]

Generate a bank of rotated and scaled 2D Gaussians.

Parameters:
  • J (int) – Number of dyadic scales.

  • L (int) – Number of orientations.

  • base_sigma (float) – Smallest sigma (spread).

  • base_mu (float) – Base offset along the rotated axis.

  • size (tuple of int) – Grid size (Nx, Ny).

Returns:

A tensor of shape [J, L, Nx, Ny], each entry L2-normalized.

Return type:

torch.Tensor

mean(data, dim=(-2, -1), **kwargs)[source]

Compute the mean on the last two dimensions (Nx, Ny).

Parameters:
  • data – Input data. In the ST_op workflow, the array should be in real space.

  • dim – Dimensions over which the mean is computed.

square_mean(data, dim=(-2, -1), **kwargs)[source]
standardize(data, mean_field, inplace=False, dim=None)[source]

Standardize the data by removing the mean and scaling to unit variance on the last two dimensions (Nx, Ny) in real space.

Parameters:
  • data – Input data whose array attribute has to be standardized.

  • mean_field (bool) – If True, compute the mean/std averaged over the batch dimension.

  • inplace (bool) – If True, perform the operation in-place on the input data.

  • dim – Dimensions over which to compute the mean and standard deviation.

Returns:

  • STL_2D_FFT_Torch – Standardized data.

  • torch.Tensor – Mean used for standardization.

  • torch.Tensor – Standard deviation used for standardization.

unstandardize(data, mean, std, inplace=False, dim=None)[source]

Unstandardize the data by scaling back using the provided mean and std.

Parameters:
  • data – Input data whose array attribute has to be unstandardized.

  • mean – Mean used for standardization.

  • std – Standard deviation used for standardization.

Returns:

Unstandardized data.

Return type:

STL_2D_FFT_Torch
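standardize and unstandardize are inverse operations. A NumPy sketch on raw arrays (not data objects), assuming that mean_field=True means pooled statistics over the batch axis 0 and the last two axes; the function names are hypothetical.

```python
import numpy as np

def standardize_sketch(x, mean_field=False):
    """Zero mean, unit variance over the last two axes (and over axis 0 if
    mean_field=True, an assumed pooled interpretation)."""
    axes = (0, -2, -1) if mean_field else (-2, -1)
    mean = x.mean(axis=axes, keepdims=True)
    std = x.std(axis=axes, keepdims=True)
    return (x - mean) / std, mean, std

def unstandardize_sketch(x, mean, std):
    """Inverse of standardize_sketch with the returned mean and std."""
    return x * std + mean
```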

static wavelet_conv(data, wavelet_set_MR, j)[source]

Perform convolutions of data with a set of L wavelets fixed at a given scale and covering all orientations. Both the data and the wavelet should be at the Nj resolution. WARNING: Sets the data in Fourier space in place if data is in real space.

Parameters:
  • data – Data to be convolved with the wavelet set, at resolution Nj.

  • wavelet_set_MR – Multi-resolution wavelet set.

  • j – Scale index used to select the wavelet set at resolution Nj.

Returns:

STL_2D_FFT_Torch instance with:
  • array: torch.Tensor […, L, Njx, Njy] – convolution in Fourier space between data and the wavelet set at scale j;

  • fourier_status: bool – True.

Return type:

STL_2D_FFT_Torch

static wavelet_conv_full(data, wavelet_set)[source]

Perform convolutions of data with the entire wavelet set at full resolution. WARNING: Sets the data in Fourier space in place if data is in real space.

Parameters:
  • data – Data to be convolved with the wavelet_set.

  • wavelet_set – Wavelet set in Fourier space at all J scales and L orientations.

Returns:

STL_2D_FFT_Torch instance with:
  • array: torch.Tensor […, J, L, Nx, Ny] – convolution in Fourier space between data and wavelet_set;

  • fourier_status: bool – True.

Return type:

STL_2D_FFT_Torch

STL_main.STL_2D_Kernel_Torch module

Created on Tuesday Nov 26 2025

Example methods for a test data type.

2D planar maps with convolution using kernels.

This class makes all computations in torch.

Characteristics:
  • in PyTorch

  • assumes real-valued maps

  • N0 gives the x and y sizes for arrays shaped (…, Nx, Ny).

  • masks are supported in convolutions

class STL_main.STL_2D_Kernel_Torch.CS_operator_2D_Kernel_Torch(shape, n_bins=None, J=None, device=device(type='cpu'), dtype=torch.float64, get_crop_border_size_method='flexible_crop', cross_spectrum_method='fft')[source]

Bases: object

Class whose instances correspond to a cross spectrum operator for 2D Kernel data. The operator is applied through the apply method and is DT-dependent.

apply(data, compute_cross_spectrum_matrix=None, get_crop_border_size_method=None, cross_spectrum_method=None, **kwargs)[source]
apply_fft(data, compute_cross_spectrum_matrix=None, get_crop_border_size_method=None)[source]

Compute the cross spectra of the array attribute of the input data (only the auto-spectra, i.e. power spectra, by default).

Parameters:
  • data – Input data whose array attribute’s cross spectra are to be computed.

  • compute_cross_spectrum_matrix – Boolean matrix indicating which cross-spectra to compute. If None, only auto-spectra are computed.

  • get_crop_border_size_method – Method used to determine the crop border size for non-PBC data. If None, the default method specified at operator initialization is used.

Returns:

Cross spectrum values of shape […, Nc, Nc, n_bins]

Return type:

torch.Tensor

build_mask_crop(array, border)[source]

Crop an array by removing ‘border’ pixels from each side along the last two dimensions, then pad each cropped side with zeros so that the output keeps the same spatial shape. The border may differ for each bin.

Parameters:
  • array (torch.Tensor) – Input array to be cropped.

  • border (torch.Tensor) – Number of pixels to remove from each side. Shape [n_bins].

Returns:

Cropped array. Shape [Nb, Nc, n_bins, N, M].

Return type:

torch.Tensor

estimate_crop_borders()[source]
classmethod k_lambda(t, lam)[source]
classmethod kappa_lambda(t, lam)[source]
plot_cross_spectrum(cs_tensor, b=0, c1=0, c2=0, label=None, color='b')[source]

Plot the cross spectrum (the power spectrum if c1 == c2).

Parameters:
  • cs_tensor (torch.Tensor of shape [Nb, Nc, Nc, n_bins]) – Cross spectrum values to plot.

  • b (int) – Batch index (0<=b<Nb).

  • c1 (int) – First channel index (0<=c1<Nc).

  • c2 (int) – Second channel index (0<=c2<Nc).

Return type:

None

static s(t)[source]
classmethod s_lambda(t, lam)[source]
class STL_main.STL_2D_Kernel_Torch.STL_2D_Kernel_Torch(array: Tensor, pbc: bool | None = None, dg: int | None = None, N0: tuple[int, int] | None=None, conv_history: list[int] = <factory>)[source]

Bases: Base_DataClass

STL_2D_Kernel_Torch child class for 2D planar STL Kernel using PyTorch

Inherits Base_DataClass.

See Base_DataClass for parameter descriptions.

Additional comments

The initial resolution N0 is fixed, but the maps can be downgraded. The downgrading level dg is the exponent of the power of 2 used: a map of initial resolution N0=256 with dg = 3 is thus at resolution 256/2^3 = 32. The downgraded resolutions are called N0, N1, N2, …

It can store an array at a given downgrading level dg:
  • attribute MR is False

  • attribute N0 gives the initial resolution

  • attribute dg gives the downgrading level

  • array is an array of size (…, N) with N = N0 // 2^dg

Or at multi-resolution (MR):
  • attribute MR is True

  • attribute N0 gives the initial resolution

  • attribute dg is None

  • array is a list of arrays of sizes (…, N1), (…, N2), etc., with the same dimensions except N.

Method usage if MR=True:
  • mean and cov give a single vector with last dimension len(list_N).

  • downsample gives an output of size (…, len(list_N), Nout). This is only possible if all resolutions are downsampled this way.

The class initialization is the frontend one, which can work from DT and data only. It enforces MR=False and dg=0. Two backend init functions for MR=False and MR=True also exist.

Attributes:
  • DT (str) – Type of data (1d, 2d planar, HealPix, 3d).

  • N0 (tuple of int) – Initial size of the array (can be multiple dimensions).

  • dg (int) – 2^dg is the downgrading factor with respect to N0.

  • array (array (…, N)) – Array(s) to store.

DT: ClassVar[str] = 'Planar2D_kernel_torch'
divide(data2, epsilon=1e-08, pow=1.0, inplace=False)[source]

Divide self.array by data2.array raised to a power, with a small epsilon added to the denominator for numerical stability.

Division is performed in real space.

Parameters:
  • data2 (STL_2D_Kernel_Torch) – Another instance whose array is used as the denominator.

  • epsilon (float, optional) – Small constant added to the denominator for numerical stability (default is 1e-8).

  • pow (float, optional) – Exponent applied to the denominator (default is 1).

  • inplace (bool) – If True, performs the operation in-place and returns self. If False, returns a new instance.

Returns:

Result of the division in real space.

Return type:

STL_2D_Kernel_Torch

get_CS_op(*args, **kwargs)[source]
get_ST_op(*args, **kwargs)[source]
get_wavelet_op(J=None, mask_full_res=None, *args, **kwargs)[source]

Abstract method. Must be implemented by the child class to return the specific WaveletOperator class for that child.

modulus(inplace=False)[source]

Compute the modulus (absolute value) of the array attribute of data.

Parameters:

inplace (bool) – If True, acts in-place and returns self. If False, returns a new STL_2D_Kernel_Torch instance.

Returns:

STL_2D_Kernel_Torch instance whose array attribute is the modulus

Return type:

STL_2D_Kernel_Torch

class STL_main.STL_2D_Kernel_Torch.WaveletOperator2Dkernel_torch(J, L=None, kernel_size=None, WType='Bump-Steerable', DT='Planar2D_kernel_torch', device=device(type='cpu'), dtype=torch.float64, mask_full_res=None, sigma_smooth=1.0, downsample_nan_weight_threshold=0.33, get_crop_border_size_method=None)[source]

Bases: object

apply(data, j)[source]

Apply the convolution kernel to data.array […, Nx, Ny] and return cdata […, L, Nx, Ny].

Parameters:
  • data (object) – Object with an attribute array storing the data as a tensor or numpy array with shape […, Nx, Ny].

  • j (int) – Scale index of the wavelets to apply.

Returns:

Convolved data with shape […, L, Nx, Ny].

Return type:

torch.Tensor
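The shape contract […, Nx, Ny] -> […, L, Nx, Ny] can be sketched with torch.nn.functional.conv2d, treating the L oriented wavelets as L output channels. This is an illustration under assumptions, not the class's code: apply_kernels is a hypothetical name, kernels are assumed complex with odd size k, borders are zero-padded, and the real and imaginary parts are convolved separately because the input is real.

```python
import torch
import torch.nn.functional as F


def apply_kernels(x, kernels):
    """Convolve a real x of shape (..., Nx, Ny) with L complex kernels of shape (L, k, k).

    Returns a complex tensor of shape (..., L, Nx, Ny); assumes odd k and zero padding.
    """
    lead, (nx, ny) = x.shape[:-2], x.shape[-2:]
    L, k = kernels.shape[0], kernels.shape[-1]
    w = torch.flip(kernels, dims=(-2, -1)).unsqueeze(1)  # flip: F.conv2d is a cross-correlation
    x4 = x.reshape(-1, 1, nx, ny)                          # (B, 1, Nx, Ny)
    w_re = w.real.to(x.dtype).contiguous()                 # (L, 1, k, k)
    w_im = w.imag.to(x.dtype).contiguous()
    re = F.conv2d(x4, w_re, padding=k // 2)                # (B, L, Nx, Ny)
    im = F.conv2d(x4, w_im, padding=k // 2)
    return torch.complex(re, im).reshape(*lead, L, nx, ny)
```

A kernel with a single 1 at its center reproduces the input, which is a convenient sanity check.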

cov(data1, data2, remove_mean=None, dim=None, specific_channel_pair=None)[source]

Compute the covariance between data1 and data2 on the last two dimensions (Nx, Ny).
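A minimal sketch of a covariance over (Nx, Ny) on bare tensors. The conventions are assumptions: the second argument is complex-conjugated, and the means are removed when remove_mean is true. cov_sketch is a hypothetical name.

```python
import torch


def cov_sketch(x1, x2, remove_mean=True):
    # covariance over the last two dims (Nx, Ny); x2 is conjugated for complex inputs
    if remove_mean:
        x1 = x1 - x1.mean(dim=(-2, -1), keepdim=True)
        x2 = x2 - x2.mean(dim=(-2, -1), keepdim=True)
    return (x1 * x2.conj()).mean(dim=(-2, -1))
```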

downsample(data, dg_out, inplace=True, replace_nan_value=nan)[source]

Downsample the data to the dg_out resolution. Downsampling is done in real space along the last two dimensions using torch.conv2d with stride=2, applied dg_out - dg times in succession. If a mask is provided at full resolution, the downsampling is NaN-aware, and sufficiently isolated NaNs can be removed through local averaging.
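A minimal sketch of NaN-aware strided downsampling. Assumptions: a normalized Gaussian smoothing kernel whose width plays the role of sigma_smooth=1.0, and downsample_nan_weight_threshold=0.33 acting as the minimum local sum of valid kernel weights below which an output pixel is set to NaN. Both roles, and the helper names gaussian_kernel and nan_aware_downsample, are hypothetical; the actual kernel and border handling may differ.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel(sigma=1.0, size=5, dtype=torch.float64):
    # normalized 2D Gaussian kernel, shape (1, 1, size, size), for F.conv2d
    ax = torch.arange(size, dtype=dtype) - (size - 1) / 2
    g = torch.exp(-ax ** 2 / (2 * sigma ** 2))
    k = torch.outer(g, g)
    return (k / k.sum())[None, None]


def nan_aware_downsample(x, n_steps, sigma=1.0, weight_threshold=0.33):
    # x: real tensor (..., Nx, Ny); each step halves Nx and Ny (for even sizes)
    lead = x.shape[:-2]
    k = gaussian_kernel(sigma, dtype=x.dtype)
    pad = k.shape[-1] // 2
    for _ in range(n_steps):
        x4 = x.reshape(-1, 1, *x.shape[-2:])
        valid = (~torch.isnan(x4)).to(x4.dtype)
        num = F.conv2d(torch.nan_to_num(x4, nan=0.0), k, stride=2, padding=pad)
        den = F.conv2d(valid, k, stride=2, padding=pad)   # local sum of valid kernel weights
        out = num / den.clamp_min(1e-12)
        out[den < weight_threshold] = float("nan")        # too few valid pixels nearby
        x = out.reshape(*lead, *out.shape[-2:])
    return x
```

Dividing by the convolved validity map renormalizes the average over the valid pixels only, which is what lets an isolated NaN be filled in while large NaN regions stay NaN.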

mean(data, dim=None)[source]

Compute the mean on the last two dimensions (Nx, Ny).

square_mean(data, dim=(-2, -1), **kwargs)[source]
standardize(data, mean_field, inplace=False, dim=None)[source]

Standardize the data by removing the mean and scaling to unit variance on the last two dimensions (Nx, Ny) in real space.

Parameters:
  • data – Input data whose array attribute has to be standardized.

  • mean_field – If True, compute mean/std averaged over the batch dimension.

  • inplace (bool) – If True, perform the operation in-place on the input data.

  • dim – Dimensions over which to compute the mean and standard deviation.

Returns:

  • STL_2D_Kernel_Torch – Standardized data.

  • torch.Tensor – Mean used for standardization.

  • torch.Tensor – Standard deviation used for standardization.

unstandardize(data, mean, std, inplace=False)[source]

Unstandardize the data by scaling back using the provided mean and std.

Parameters:
  • data – Input data whose array attribute has to be unstandardized.

  • mean – Mean used for standardization.

  • std – Standard deviation used for standardization.

  • inplace (bool) – If True, perform the operation in-place on the input data.

Returns:

Unstandardized data.

Return type:

STL_2D_Kernel_Torch
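The standardize / unstandardize pair is a round trip. A minimal sketch on bare tensors, assuming the batch dimension is dim 0 when mean_field is true; standardize_sketch and unstandardize_sketch are hypothetical names.

```python
import torch


def standardize_sketch(x, mean_field=False):
    # statistics over (Nx, Ny); with mean_field=True also over the (assumed) batch dim 0
    dims = (0, -2, -1) if mean_field else (-2, -1)
    mean = x.mean(dim=dims, keepdim=True)
    std = x.std(dim=dims, keepdim=True)
    return (x - mean) / std, mean, std


def unstandardize_sketch(x, mean, std):
    # inverse of standardize_sketch for the same mean and std
    return x * std + mean
```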

STL_main.STL_Healpix_Kernel_Torch module

STL_main.STL_utils module

class STL_main.STL_utils.Gaussianise(x_ref: Tensor, eps: float = 1e-06)[source]

Bases: object

Rank-based Gaussianisation with invertible mapping via interpolation.

Method:
  • Fit on a reference tensor x_ref.

  • Build empirical CDF F_X from sorted(x_ref).

  • Forward: for any x, approximate u = F_X(x) by linear interpolation, then z = Phi^{-1}(u) (standard normal).

  • Inverse: for any z, u = Phi(z), then approximate x = F_X^{-1}(u) by linear interpolation between sorted samples.

The mapping is designed to be differentiable w.r.t. x in the forward direction and w.r.t. z in the inverse, except for the dependence on the sorting, which is piecewise constant and therefore non-differentiable.
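The forward/inverse pair can be sketched with torch.searchsorted for the linear interpolation and torch.special.ndtri / torch.special.ndtr for Φ^{-1} and Φ. GaussianiseSketch and interp1d are hypothetical names, and the plotting positions u_i = (i + 0.5)/n used for the empirical CDF are an assumption; the class may use a different convention.

```python
import torch


def interp1d(x, xp, fp):
    # piecewise-linear interpolation of fp(xp) at x; xp sorted ascending, clamped at the ends
    flat = x.reshape(-1)
    idx = torch.searchsorted(xp, flat).clamp(1, xp.numel() - 1)
    x0, x1, y0, y1 = xp[idx - 1], xp[idx], fp[idx - 1], fp[idx]
    t = ((flat - x0) / (x1 - x0).clamp_min(1e-12)).clamp(0.0, 1.0)
    return (y0 + t * (y1 - y0)).reshape(x.shape)


class GaussianiseSketch:
    def __init__(self, x_ref, eps=1e-6):
        self.x_sorted = torch.sort(x_ref.reshape(-1)).values
        n = self.x_sorted.numel()
        # assumed plotting positions u_i = (i + 0.5) / n for the empirical CDF
        self.u_sorted = (torch.arange(n, dtype=self.x_sorted.dtype) + 0.5) / n
        self.eps = eps

    def forward(self, x):
        u = interp1d(x, self.x_sorted, self.u_sorted).clamp(self.eps, 1 - self.eps)
        return torch.special.ndtri(u)  # z = Phi^{-1}(u)

    def invert(self, z):
        u = torch.special.ndtr(z)      # u = Phi(z)
        return interp1d(u, self.u_sorted, self.x_sorted)
```

Applied to the reference samples themselves, forward followed by invert recovers them up to floating-point error.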

forward(x: Tensor) Tensor[source]

Gaussianise input tensor x based on the reference distribution.

Steps:
  • For each value, approximate u = F_X(x) via interpolation in (x_sorted, u_sorted).

  • Then z = Φ^{-1}(u).

Parameters:

x (torch.Tensor) – Data tensor (any shape).

Returns:

z – Gaussianised tensor with same shape as x.

Return type:

torch.Tensor

invert(z: Tensor) Tensor[source]

Map Gaussianised tensor z back to the original data space using the inverse empirical CDF.

Steps:
  • u = Φ(z)

  • x = F_X^{-1}(u) via interpolation in (u_sorted, x_sorted).

Parameters:

z (torch.Tensor) – Gaussian-space tensor (any shape).

Returns:

x_rec – Reconstructed data tensor in the original amplitude space.

Return type:

torch.Tensor

STL_main.ST_Operator module

Main structure of STL

Tentative proposal by EA

class STL_main.ST_Operator.ST_Operator(data_example, J=None, L=None, SC='ScatCov', has_fewer_convolutions=False, replace_nan_value=nan, mask_full_res=None, norm='store_ref', S2_ref_sqrt_chan_diag=None, iso=False, angular_ft=False, scale_ft=False, flatten=False, mask_st=None, dj=None, harmonics_angle=None, harmonics_scale=None, compute_PS=True, PS_ref_sqrt_chan_diag=None, var_ref=None, WType=None, downsample_nan_weight_threshold=None, get_crop_border_size_method=None, n_bins=None)[source]

Bases: object

Class whose instances correspond to scattering transform operators. The operator is built through the __init__ method and applied through the apply method. It is DT-independent and calls sub-functions with a common I/O structure, which in turn rely on a DT-dependent backend.

When the ST operator is applied to some data, it creates an ST_Statistics instance that stores all the necessary parameters, so that the ST operator used in the computation can be reconstructed from it if necessary.

To allow this, default values for all parameters of the apply method can be stored in the ST operator.

A prescription is also given for the order in which the different normalizations/compressions can be done:

norm -> iso -> angular_ft -> scale_ft -> flatten (mask_st)

Not every transform has to be used, but the ordering must be respected. For instance:

vanilla -> norm -> angular_ft -> flatten (mask_st)
vanilla -> iso -> scale_ft

This allows the operators to be defined in a unique way from these parameters.
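The ordering rule amounts to requiring that the selected transforms appear as a subsequence of the prescribed chain. A hypothetical check (respects_order is not part of the package), where "vanilla" is the starting state rather than a step:

```python
ORDER = ["norm", "iso", "angular_ft", "scale_ft", "flatten"]


def respects_order(steps):
    """Hypothetical check: True if the transforms in `steps` follow the prescribed order."""
    positions = [ORDER.index(step) for step in steps]   # ValueError for an unknown step
    return all(a < b for a, b in zip(positions, positions[1:]))


print(respects_order(["norm", "angular_ft", "flatten"]))  # True
print(respects_order(["iso", "scale_ft"]))                # True
print(respects_order(["angular_ft", "iso"]))              # False
```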

A mask can also be stored in the operator.

Parameters:
  # Scattering Transform

  • data_example – Data (1d, 2d planar, HealPix, 3d)

  • J – number of scales

  • L – number of orientations

  • WType – type of wavelets

  • SC – type of ST coefficients (“ScatCov”, “WPH”)

  • has_fewer_convolutions – For “ScatCov” type, whether the S3 and S4 coefficients are computed with one convolution less (Sihao version)

  # Additional transform/compression

  • norm – type of norm (“self”, “from_ref”)

  • S2_ref_sqrt_chan_diag – array of reference S2 coefficients (square root of the diagonal over channels)

  • iso – keep only isotropic coefficients

  • angular_ft – perform angular Fourier transform on the ST statistics

  • scale_ft – perform scale cosine transform on the ST statistics

  • flatten – only return a 1D array and not an ST_Statistics instance

  • mask_st – mask to be applied when flattening the ST statistics

  # Power spectrum computation

  • compute_PS – whether to compute power spectrum coefficients in addition to ST statistics

  • PS_ref_sqrt_chan_diag – array of reference PS coefficients

- parent parameters (see above)
- wavelet_op

Wavelet Transform operator

Type:

Wavelet_Transform class

apply(data, standardize=False, SC=None, has_fewer_convolutions=None, norm=None, S2_ref_sqrt_chan_diag=None, norm_batch_mean=False, iso=None, angular_ft=None, scale_ft=None, flatten=None, mask_st=None, compute_PS=None, PS_ref_sqrt_chan_diag=None, var_ref=None, compute_cross_matrix=None)[source]

Compute the Scattering Transform (ST) of data. The result is either stored in an ST_Statistics instance or returned as a flattened array.

This DT-independent method calls sub-functions which have a common I/O structure, and which in turn rely on a DT-dependent backend.

It outputs an ST_Statistics instance, whose additional methods can be called directly to get the desired output.

Uses ST operator parameters unless explicitly overridden in apply.

!!! Attention: I give an example in torch here, but we should consider how to include different backends !!!

!!! Attention: I give here the version with standard scat cov !!!

Parameters:
  # Data

  • data – input data, with Nc channels and batch size Nb. Should have dg=0.

  # Scattering Transform

  • SC – type of ST coefficients (“ScatCov”, “WPH”)

  • has_fewer_convolutions – For “ScatCov” type, whether the S3 and S4 coefficients are computed with one convolution less (Sihao version)

  • pass_mask – Pass mask to ST statistics object if True

  # Additional transform/compression

  • norm – type of norm (“self”, “from_ref”)

  • S2_ref_sqrt_chan_diag – array of reference S2 coefficients (square root of the diagonal over channels)

  • iso – keep only isotropic coefficients

  • angular_ft – perform angular Fourier transform on the ST statistics

  • scale_ft – perform scale cosine transform on the ST statistics

  • flatten – only return a 1D array and not an ST_Statistics instance

  • mask_st – mask to be applied when flattening the ST statistics

  # Power spectrum computation

  • compute_PS – whether to compute power spectrum coefficients in addition to ST statistics

  • PS_ref_sqrt_chan_diag – array of reference PS coefficients

  # Cross statistics computation

  • compute_cross_matrix –

    Upper triangular matrix with shape (Nc,Nc), which determines pairs of channels for which to compute cross-statistics. More precisely:

    • computes S1(c1), S2(c1,c1), S3(c1,c1) and S4(c1,c1) if and only if compute_cross_matrix[c1,c1] == True

    • for c1 < c2, computes S2(c1,c2), S3(c1,c2), S3(c2,c1), S4(c1,c2) and S4(c2,c1) if and only if compute_cross_matrix[c1,c2] == True

    • for c1 > c2, compute_cross_matrix[c1,c2] is ignored and should not be specified

    If None, it is replaced by a boolean matrix full of True, so that all cross-statistics are computed.
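To see which channel pairs a given compute_cross_matrix selects, only its upper triangle needs to be read. A hypothetical helper (cross_pairs is not part of the package) listing the (c1, c2) pairs, c1 <= c2, for which statistics would be computed according to the rules above:

```python
import torch


def cross_pairs(compute_cross_matrix):
    """List the (c1, c2) pairs, c1 <= c2, selected by an (Nc, Nc) boolean matrix.

    Only the upper triangle (c1 <= c2) is read; entries with c1 > c2 are ignored.
    """
    Nc = compute_cross_matrix.shape[0]
    return [(c1, c2) for c1 in range(Nc) for c2 in range(c1, Nc)
            if bool(compute_cross_matrix[c1, c2])]


Nc = 3
all_pairs = torch.ones(Nc, Nc, dtype=torch.bool)   # same selection as compute_cross_matrix=None
auto_only = torch.eye(Nc, dtype=torch.bool)        # auto-statistics of each channel only
some = auto_only.clone()
some[0, 1] = True                                  # add the cross-statistics of channels 0 and 1
print(cross_pairs(some))  # [(0, 0), (0, 1), (1, 1), (2, 2)]
```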

Output
  • data_st (ST_Statistics instance, or 1D array) – ST statistics of data, as a flattened array if flatten=True

classmethod from_ST_Statistics(st_stat)[source]

Alternative constructor, which generates the ST operator used to compute a given set of ST statistics.

Parameters:

st_stat (ST_Statistics) – instance whose parameters have to be reproduced

Remark and to do: an ST_Statistics instance cannot transmit the flatten parameter, since with flatten=True a 1D array would have been returned instead of an ST_Statistics instance. It is not clear to me how to deal with this point.

STL_main.ST_Statistics module

Main structure of STL

Tentative proposal by EA

class STL_main.ST_Statistics.ST_Statistics(DataClass, N0, Nb, Nc, wavelet_op, SC, has_fewer_convolutions, compute_cross_matrix, compute_PS, n_bins, standardized, mean_pre_std, std_pre_std)[source]

Bases: object

Class whose instances correspond to a set of scattering statistics. The set of statistics is built by the apply method of ST_Operator, which uses the __init__ method. This class is DT-independent.

This class contains methods to handle ST statistics in a unified manner. Most of these methods can be applied directly through the ST_Operator implementation.

When used in a loss, a 1D array can be returned using the for_loss method. It works for any type of ST_Statistics. It can use a mask on the ST coefficients, which is option-dependent.

Parameters:
  # Scattering Transform

  • DT – Type of data (1d, 2d planar, HealPix, 3d)

  • N0 – initial size of array (can be multiple dimensions)

  • J – number of scales

  • L – number of orientations

  • WType – type of wavelets

  • SC – type of ST coefficients (“ScatCov”, “WPH”)

  # Data array parameters

  • Nb – size of batch

  • Nc – number of channels

- parent parameters (DT,N0,J,L,WType,SC,Nb,Nc)
# Additional transform/compression
- norm

type of norm (“self”, “from_ref”)

Type:

str

- S2_ref_sqrt_chan_diag

array of reference S2 coefficients (normalized by sqrt of diagonal over channels)

Type:

array

- iso

keep only isotropic coefficients

Type:

bool

- angular_ft

perform angular fourier transform on the ST statistics

Type:

bool

- scale_ft

perform scale cosine transform on the ST statistics

Type:

bool

- flatten

only return a 1D-array and not a ST_Statistics instance

Type:

bool

- mask_st

mask to be applied when flattening the ST statistics

Type:

list of positions

# ST statistics
- S1, S2, S2p, S3, S4
Type:

array of relevant size to store the ST statistics

# Power Spectrum
- PS

whether power spectrum coefficients are computed

Type:

bool

plot_coeff(b: int = 0, c: int = 0, new_figure: bool = True)[source]

Reproduce the classical S1/S2/S3/S4 scattering plot for a given (batch, channel).

Parameters:
  • b (int) – Batch index (0 <= b < Nb).

  • c (int) – Channel index (0 <= c < Nc).

  • new_figure (bool) – If True, create a new figure. If False, plot into the existing one (useful to overlay multiple ST_Statistics objects on the same panels).

select(param)[source]

Select the statistics specified by param and return them as a tensor.

to_angular_ft(harmonics_angle=None)[source]

Angular harmonic transform on the ST statistics.
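An angular harmonic transform can be read as a discrete Fourier transform along the orientation axis. A minimal sketch, assuming orientations lie on the last axis and that harmonics_angle plays the role of a number of harmonics to keep; both are assumptions, and angular_ft_sketch is a hypothetical name.

```python
import torch


def angular_ft_sketch(s, n_harmonics=None):
    """Discrete Fourier transform over the orientation axis (assumed last, of size L)."""
    coeffs = torch.fft.fft(s, dim=-1) / s.shape[-1]   # harmonic 0 is the orientation average
    return coeffs if n_harmonics is None else coeffs[..., :n_harmonics]
```

For isotropic statistics (constant over orientations), only harmonic 0 is non-zero, which is why this transform is a natural compression step.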

to_flatten(keep_batch_dim=False, mask_st=None, mean_along_batch=False, keepnans=False, flatten_complex=False)[source]

Produce either a 1d array that can be used for loss constructions or a 2d array, keeping the batch dimension, if keep_batch_dim is True.

A mask can be used to select the coefficients from the initial 1d array.

Parameters:
  • keep_batch_dim – if True, the output will have shape [Nb, n_coeff] instead of [Nb * n_coeff]

  • mask_st – mask for ST coefficients, applied after the initial flattening

  • flatten_complex – if True, complex coefficients will be flattened into two separate real numbers (real and imaginary parts). Since S1, S2 and PS are real, their null imaginary part won’t be included in the flattened statistics tensor.

Output:
  • st_flatten – flattened ST statistics
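A minimal sketch of the flattening logic on a list of per-batch tensors, each of shape (Nb, …). Assumptions: a float64/complex128 setting, a boolean mask_st over the flattened coefficient axis, and interleaved (real, imag) pairs when flatten_complex is true. flatten_sketch is a hypothetical name.

```python
import torch


def flatten_sketch(stats, keep_batch_dim=False, mask_st=None, flatten_complex=False):
    """Flatten a list of statistics tensors, each of shape (Nb, ...), into (Nb, n_coeff)."""
    Nb = stats[0].shape[0]
    any_complex = any(s.is_complex() for s in stats)
    parts = []
    for s in stats:
        s = s.reshape(Nb, -1)
        if s.is_complex() and flatten_complex:
            s = torch.view_as_real(s).reshape(Nb, -1)   # interleaved (real, imag) pairs
        elif any_complex and not flatten_complex:
            s = s.to(torch.complex128)                  # common complex dtype for torch.cat
        parts.append(s)
    flat = torch.cat(parts, dim=1)
    if mask_st is not None:
        flat = flat[:, mask_st]                         # boolean mask over the n_coeff axis
    return flat if keep_batch_dim else flat.reshape(-1)
```

Real inputs such as S1 stay real when flatten_complex is true, so their zero imaginary parts never enter the output.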

to_iso()[source]

Isotropize the set of ST statistics

Note: S2_ref_sqrt_chan_diag is not isotropized since it is used before this step.

Note: if self.PS = True, PS coefficients are already isotropized in PS_operator.

EA: could probably be better vectorized, to be done.
EA: to be done properly with the backend.
EA: Sihao used .real for S3 and S4, to consider.
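Isotropization can be sketched as averaging over absolute orientation. For one orientation index this is a plain mean; for a pair (l1, l2), as in S3/S4-type coefficients, one averages over l1 at fixed relative angle Δl = (l2 - l1) mod L. The axis layout is an assumption, and iso_single / iso_pair are hypothetical names.

```python
import torch


def iso_single(s):
    # statistics with one orientation axis (assumed last): average over it
    return s.mean(dim=-1)


def iso_pair(s):
    """Isotropize a tensor whose last two dims are orientations (l1, l2), both of size L.

    Returns shape (..., L): the average over l1 at fixed delta = (l2 - l1) mod L.
    """
    L = s.shape[-1]
    l1 = torch.arange(L).unsqueeze(1)          # (L, 1)
    delta = torch.arange(L).unsqueeze(0)       # (1, L)
    l2 = (l1 + delta) % L                      # (L, L): l2 for each (l1, delta)
    gathered = s[..., l1, l2]                  # (..., L, L) indexed by (l1, delta)
    return gathered.mean(dim=-2)
```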

to_norm(norm_type=None, S2_ref_sqrt_chan_diag=None, PS_ref_sqrt_chan_diag=None, var_ref=None, norm_batch_mean=True)[source]

Normalize the ST statistics.

Parameters:
  • norm_type (str) – type of norm (“self”, “from_ref”)

  • S2_ref_sqrt_chan_diag – if self.SC = “ScatCov”, array of reference S2 coefficients if “from_ref” (normalized by sqrt of diagonal over channels)

  • PS_ref_sqrt_chan_diag – if self.PS = True, array of reference Power Spectrum coefficients if “from_ref”

  • var_ref – array of reference variance if “from_ref”

  • norm_batch_mean – Used with the “self” normalization type. If True, the reference coefficients are averaged over the batch dimension (dim=0).
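The "sqrt of diagonal over channels" normalization can be illustrated on a cross-channel S2 array: each S2(c1, c2) is divided by sqrt(S2(c1, c1) · S2(c2, c2)), so that auto-terms become 1. This is a sketch under an assumed layout (…, Nc, Nc, J); normalize_by_chan_diag is a hypothetical name, and the library's exact normalization may differ.

```python
import torch


def normalize_by_chan_diag(s2):
    """s2: (..., Nc, Nc, J) cross-channel S2; divide by sqrt(S2(c1,c1) * S2(c2,c2))."""
    diag = torch.diagonal(s2, dim1=-3, dim2=-2)                # (..., J, Nc)
    diag = diag.movedim(-1, -2)                                # (..., Nc, J)
    ref = torch.sqrt(diag.unsqueeze(-2) * diag.unsqueeze(-3))  # (..., Nc, Nc, J)
    return s2 / ref
```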

to_scale_ft(harmonics_scale=None, dj=None, harmonics_angle=None)[source]

Scale cosine transform on the ST statistics.
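A scale cosine transform can be sketched as an orthonormal DCT-II along the scale axis, built as an explicit matrix since torch.fft has no DCT. The axis choice (last, of size J) and the role of harmonics_scale as a number of coefficients to keep are assumptions; dct2_matrix and scale_ft_sketch are hypothetical names, and only real inputs are handled.

```python
import math
import torch


def dct2_matrix(J, dtype=torch.float64):
    # orthonormal DCT-II matrix: C[k, j] = a_k * cos(pi * (j + 1/2) * k / J)
    j = torch.arange(J, dtype=dtype)
    k = j.unsqueeze(1)
    C = torch.cos(math.pi * (j + 0.5) * k / J) * math.sqrt(2.0 / J)
    C[0] = C[0] / math.sqrt(2.0)
    return C


def scale_ft_sketch(s, n_harmonics=None):
    """Cosine transform of real statistics along the scale axis (assumed last, size J)."""
    C = dct2_matrix(s.shape[-1], s.dtype)
    out = s @ C.T
    return out if n_harmonics is None else out[..., :n_harmonics]
```

Coefficients that vary smoothly with scale concentrate in the first few cosine modes, which is what makes truncation a useful compression.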

STL_main.SphericalStencil module

STL_main.StlData module

STL_main.Synthesis module

class STL_main.Synthesis.ScatteringMatchModel(st_op, DataClass, pbc, init_shape, init_map, device, dtype, has_fewer_convolutions, compute_cross_matrix, compute_PS, keep_batch_dim, mean_field, prefilter_Nyquist, adhoc_weights)[source]

Bases: Module

forward()[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

STL_main.Synthesis.apply_nyquist_filter(tensor, plot=False)[source]

Apply a low-pass filter to an input tensor, keeping only frequencies within the Nyquist radius.

Parameters:

tensor (torch.Tensor) – Input tensor in real space of shape (…, N, M) where N and M are the spatial dimensions

Returns:

Filtered tensor in real space of the same shape as input, with high frequencies removed

Return type:

torch.Tensor
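A minimal sketch of such a filter with torch.fft, assuming "within the Nyquist radius" means keeping the modes whose radial frequency is at most 0.5 cycles per pixel (a disk in the Fourier plane). nyquist_lowpass_sketch is a hypothetical name, and the plot option is omitted.

```python
import torch


def nyquist_lowpass_sketch(x):
    """Zero all Fourier modes of a real x (..., N, M) with radial frequency above 0.5 cycles/pixel."""
    nx, ny = x.shape[-2:]
    fx = torch.fft.fftfreq(nx, dtype=x.dtype).unsqueeze(1)   # (N, 1)
    fy = torch.fft.fftfreq(ny, dtype=x.dtype).unsqueeze(0)   # (1, M)
    keep = ((fx ** 2 + fy ** 2) <= 0.25).to(x.dtype)          # disk of radius 0.5
    return torch.fft.ifft2(torch.fft.fft2(x) * keep).real
```

The corner modes, such as a pixel-scale checkerboard at frequency (0.5, 0.5), lie outside the disk and are removed.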

STL_main.Synthesis.optimize_from_maps(target, st_op_target, st_op_running, nbatch=1, pbc_running=True, running_shape=None, init_running=None, has_fewer_convolutions=False, compute_cross_matrix=None, compute_PS=False, mean_field=True, lr=1.0, max_iter=100, history_size=50, print_iter=10, verbose=True, seed=None, prefilter_Nyquist=True, adhoc_weights={'S3': 3.5, 'S4': 12.25})[source]
STL_main.Synthesis.optimize_from_stats(target_stats, st_op_running, nbatch, running_shape=None, pbc_running=True, init_running=None, mean_field=True, lr=1.0, max_iter=100, history_size=50, print_iter=10, verbose=True, seed=None, prefilter_Nyquist=True, adhoc_weights={'S3': 3.5, 'S4': 12.25})[source]

Notes:
  • Since the loss function is computed on a per-map basis, the number of maps (Nb) to be synthesized must match the number of maps in the target statistics, and mean_field (comparing statistics averaged over the batch dimension, to handle synthesis with different batch sizes) is no longer relevant.

  • Since the statistics are already computed, one cannot specify a running shape different from the target shape.

STL_main.Synthesis.optimize_lbfgs(model, loss_fn, lr, max_iter, history_size, verbose, print_iter)[source]
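optimize_lbfgs has no docstring. The general pattern it names can be sketched with torch.optim.LBFGS and a closure that recomputes the loss, using the lr, max_iter and history_size defaults of optimize_from_maps. lbfgs_fit and MapModel are hypothetical stand-ins, the strong-Wolfe line search is an assumed choice, and the toy loss is a plain squared error rather than a scattering-statistics loss.

```python
import torch


def lbfgs_fit(model, loss_fn, lr=1.0, max_iter=100, history_size=50):
    # one LBFGS.step runs up to max_iter inner iterations, re-evaluating the closure
    opt = torch.optim.LBFGS(model.parameters(), lr=lr, max_iter=max_iter,
                            history_size=history_size, line_search_fn="strong_wolfe")

    def closure():
        opt.zero_grad()
        loss = loss_fn(model())
        loss.backward()
        return loss

    opt.step(closure)
    with torch.no_grad():
        return loss_fn(model()).item()


class MapModel(torch.nn.Module):
    # the running map is the only trainable parameter; forward() returns it
    def __init__(self, shape):
        super().__init__()
        self.x = torch.nn.Parameter(torch.zeros(shape, dtype=torch.float64))

    def forward(self):
        return self.x
```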
STL_main.Synthesis.reweight(stats, weights)[source]
STL_main.Synthesis.synthesize_from_maps(data_target, nbatch, pbc_running, running_shape=None, init_running=None, running_mask=None, has_fewer_convolutions=False, compute_cross_matrix=None, mean_field=True, **optim_kwargs)[source]

User-friendly wrapper to synthesize field maps from target maps.

Parameters:
  • running_shape (tuple of int, optional) – Default is None. If None, the running field has the same shape as the target.

  • running_mask (torch.BoolTensor, optional) – Default is None. If None, the same mask as the target is used. If a new mask is provided, it must be a boolean tensor with shape matching the running field.

  • mean_field (bool, optional) – Default is True. Default value allows one to perform synthesis between N target samples and M running samples (with N different from or equal to M) while matching statistics computed from the batch-averaged field.

Notes

  • The Power Spectrum is optimized by default whenever possible (i.e., when no NaN values are present in either the target or the running data).

  • Within this user-level wrapper, an ST operator is created for both the target and the running map, and then passed to the mid-level wrapper. This is particularly useful for syntheses involving NaN values, as it allows the use of two distinct masks: one for the target and one for the running map. For other types of syntheses, the same ST operator is used for both.

STL_main.Synthesis.synthesize_from_stats(target_stats, nbatch, pbc_running, running_shape=None, init_running=None, running_mask=None, mean_field=True, **optim_kwargs)[source]
Notes
  • Parameters such as compute_cross_matrix and has_fewer_convolutions are not arguments of this wrapper; they are specified when computing the target statistics.

STL_main.WaveletTransform module

STL_main.torch_backend module

Torch backend for STL.

Provides a minimal API used across the codebase, with optional device selection (CPU / GPU) via a device argument.

Exposed functions

  • from_numpy(x, device=None, dtype=None)

  • zeros(shape, device=None, dtype=torch.float32)

  • mean(x, dim)

  • dim(x)

  • shape(x, axis=None)

  • nan

STL_main.torch_backend.dim(x) int[source]

Return the number of dimensions of a tensor-like object.

STL_main.torch_backend.eye(n, device=None, dtype=None)[source]

Return an (n, n) identity matrix on CPU or GPU depending on device.

Parameters:
  • n (int) – number of rows and columns

  • device (None, str, or torch.device)

  • dtype (torch.dtype)

STL_main.torch_backend.maskmean(x, dim=(-2, -1), mask=None)[source]

Compute the mean of x along the given dims, optionally masked. If mask is given, it must satisfy mask.shape == x.shape[-mask.ndim:]; the masked dimensions are flattened, and they must be included in dim.

Parameters:
  • x – input tensor

  • mask – boolean tensor, same shape as the last dimensions of x

  • dim – int or tuple of ints along which to compute the mean.

Returns:

Tensor with the mean
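A minimal sketch of the masked mean described above, on bare tensors. maskmean_sketch is a hypothetical name; the check that the masked dimensions belong to dim follows the docstring, while raising ValueError on failure is an assumed choice.

```python
import torch


def maskmean_sketch(x, dim=(-2, -1), mask=None):
    """Mean of x over `dim`; if given, the boolean mask covers the trailing mask.ndim dims of x."""
    if mask is None:
        return x.mean(dim=dim)
    dims = {d % x.ndim for d in ((dim,) if isinstance(dim, int) else dim)}
    first = x.ndim - mask.ndim                      # first masked dimension
    if not set(range(first, x.ndim)) <= dims:
        raise ValueError("masked dimensions must be included in dim")
    flat = x.reshape(*x.shape[:first], -1)[..., mask.reshape(-1)]   # keep masked entries only
    out = flat.mean(dim=-1)
    rest = sorted(d for d in dims if d < first)     # remaining unmasked dims to average
    return out.mean(dim=rest) if rest else out
```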

STL_main.torch_backend.ones(shape, device=None, dtype=None)[source]

Return a tensor of ones on CPU or GPU depending on device.

Parameters:
  • shape (tuple or list)

  • device (None, str, or torch.device)

  • dtype (torch.dtype)

STL_main.torch_backend.shape(x, axis=None)[source]

Return the shape of x, or the size along a given axis.

Parameters:
  • x (torch.Tensor or np.ndarray)

  • axis (int or None) – If None, return the full shape tuple. Otherwise, return the size along the given axis.

STL_main.torch_backend.to_torch_tensor(array)[source]

Transform input array (NumPy or PyTorch) into a PyTorch tensor.

Parameters:

array (np.ndarray or torch.Tensor) – Input array to be converted.

Returns:

Converted PyTorch tensor.

Return type:

torch.Tensor

STL_main.torch_backend.zeros(shape, device=None, dtype=None)[source]

Return a tensor of zeros on CPU or GPU depending on device.

Parameters:
  • shape (tuple or list)

  • device (None, str, or torch.device)

  • dtype (torch.dtype)

Module contents