# optimize_scattering_core
# optimize_from_maps
# optimize_from_stats
import time
import numpy as np
import torch
import torch.nn as nn
from torch import device, nn
from torch.optim import LBFGS
# === Learnable field model ===
[docs]
class ScatteringMatchModel(nn.Module):
    """Hold the running field ``u`` and map it to flattened scattering statistics.

    The field ``self.u`` is a plain tensor with ``requires_grad=True`` (it is
    not registered as an ``nn.Parameter``); the optimizer is expected to be
    given ``model.u`` directly.

    Parameters
    ----------
    st_op : object
        Scattering operator. Must expose ``wavelet_op.mask_full_res`` and an
        ``apply(...)`` method accepting the keywords used in ``forward``.
    DataClass : callable
        Called as ``DataClass(u, pbc=pbc)`` to wrap the field before it is
        handed to ``st_op.apply``.
    pbc : object
        Forwarded unchanged to ``DataClass``.
    init_shape : tuple of int
        Shape of the learnable field.
    init_map : array-like or None
        Initial field. ``None`` draws Gaussian noise of shape ``init_shape``;
        otherwise the map is broadcast to ``init_shape``.
    device, dtype
        Device and dtype of the learnable field.
    has_fewer_convolutions, compute_cross_matrix, compute_PS
        Forwarded unchanged to ``st_op.apply``.
    keep_batch_dim, mean_field
        Forwarded to ``to_flatten`` as ``keep_batch_dim`` and
        ``mean_along_batch`` respectively.
    prefilter_Nyquist : bool
        If true, the initial field is passed through ``apply_nyquist_filter``.
    adhoc_weights : dict or None
        Optional ``{coefficient_label: weight}`` mapping applied with
        ``reweight`` before flattening.
    """

    def __init__(
        self,
        st_op,
        DataClass,
        pbc,
        init_shape,
        init_map,
        device,
        dtype,
        has_fewer_convolutions,
        compute_cross_matrix,
        compute_PS,
        keep_batch_dim,
        mean_field,
        prefilter_Nyquist,
        adhoc_weights,
    ):
        super().__init__()

        # === Field configuration ===
        self.st_op = st_op
        self.DataClass = DataClass
        self.pbc = pbc
        self.init_shape = init_shape
        self.init_map = init_map
        self.device = device
        self.dtype = dtype

        # === Stats configuration ===
        self.has_fewer_convolutions = has_fewer_convolutions
        self.compute_cross_matrix = compute_cross_matrix
        self.compute_PS = compute_PS
        self.keep_batch_dim = keep_batch_dim
        self.mean_field = mean_field
        self.adhoc_weights = adhoc_weights

        # === Initialize learnable field u ===
        if self.init_map is None:
            self.u = torch.randn(init_shape, device=device, dtype=dtype)
        else:
            # expand() only returns a stride-0 view: several elements of the
            # result share one memory location, which breaks the in-place mask
            # write below and the optimizer's in-place updates. Materialize an
            # independent contiguous copy. detach() keeps u a leaf tensor even
            # if the caller's map carries autograd history.
            self.u = (
                torch.as_tensor(self.init_map)
                .detach()
                .to(device=device, dtype=dtype)
                .expand(init_shape)
                .clone(memory_format=torch.contiguous_format)
            )

        if prefilter_Nyquist:
            print("Prefiltering initial map to remove frequencies above Nyquist")
            assert (
                not self.u.isnan().any()
            ), "Cannot apply Nyquist filter on intial map with NaNs. Either remove NaNs from the initial map or specify prefilter_Nyquist=False."
            self.u = apply_nyquist_filter(self.u)

        # === Apply mask constraints ===
        self.mask_full_res = st_op.wavelet_op.mask_full_res
        if self.mask_full_res is not None:
            # As a safeguard, put huge values under the mask: they should
            # produce aberrant statistics if the masked pixels are ever used.
            self.u[..., self.mask_full_res.array] = 1e10

        self.u.requires_grad_()

        if self.mask_full_res is not None:

            def freeze_hook(grad):
                # Zero the gradient on masked pixels so they are never updated.
                return grad * (~self.mask_full_res.array)

            self.u.register_hook(freeze_hook)
            print(
                "NaN detected in the running synthesis mask, the synthesis takes it into account"
            )

    def forward(self):
        """Return the flattened scattering statistics of the current field ``u``."""
        # === Build data class ===
        DC_u = self.DataClass(self.u, pbc=self.pbc)

        # === Compute scattering statistics ===
        # norm="load_ref": the reference normalization is expected to be
        # already stored on the operator by the caller.
        st_u = self.st_op.apply(
            DC_u,
            has_fewer_convolutions=self.has_fewer_convolutions,
            compute_cross_matrix=self.compute_cross_matrix,
            compute_PS=self.compute_PS,
            norm="load_ref",
        )

        # === Re-weight statistics ===
        if self.adhoc_weights is not None:
            reweight(st_u, self.adhoc_weights)

        # === Flatten statistics ===
        s_flat_u = st_u.to_flatten(
            keep_batch_dim=self.keep_batch_dim,
            mean_along_batch=self.mean_field,
            keepnans=False,
        )
        return s_flat_u
[docs]
def reweight(stats, weights):
    """Scale selected coefficient attributes of ``stats`` in place.

    Parameters
    ----------
    stats : object
        Container whose attributes named after the keys of ``weights`` hold
        the coefficients to rescale.
    weights : dict
        Mapping ``{attribute_name: factor}``.

    Raises
    ------
    ValueError
        If ``stats`` has no attribute for one of the labels. Labels processed
        before the offending one have already been rescaled.
    """
    absent = object()
    for label, factor in weights.items():
        values = getattr(stats, label, absent)
        if values is absent:
            raise ValueError(f"Unsupported coefficient label: {label}")
        # Augmented assignment: mutable containers (tensors, arrays) are
        # updated in place, so the attribute itself is never rebound.
        values *= factor
# === LBFGS optimization (low level)===
[docs]
def optimize_lbfgs(model, loss_fn, lr, max_iter, history_size, verbose, print_iter):
    """Minimize ``loss_fn(model())`` over ``model.u`` with a single L-BFGS step call.

    Parameters
    ----------
    model : callable
        Object exposing the tensor to optimize as ``model.u``, a ``device``
        attribute (``torch.device`` or device string), and returning the
        quantity fed to ``loss_fn`` when called with no argument.
    loss_fn : callable
        Maps the model output to a scalar tensor.
    lr, max_iter, history_size
        Forwarded to ``torch.optim.LBFGS``.
    verbose : bool
        If true, print the loss every ``print_iter`` closure evaluations.
    print_iter : int
        Printing period. A falsy value (``0`` / ``None``) disables the
        periodic printing instead of raising ``ZeroDivisionError``.

    Returns
    -------
    torch.Tensor
        ``model.u`` detached from the autograd graph.
    """
    optimizer = LBFGS(
        [model.u],
        lr=lr,
        max_iter=max_iter,
        history_size=history_size,
        line_search_fn="strong_wolfe",
        tolerance_grad=1e-12,
        tolerance_change=1e-15,
    )
    loss_history = []

    def closure():
        optimizer.zero_grad()
        output = model()
        loss = loss_fn(output)
        loss.backward()
        loss_value = loss.item()
        loss_history.append(loss_value)
        # The counter is the number of closure evaluations; the line search
        # may evaluate the closure several times per L-BFGS iteration.
        n_eval = len(loss_history)
        if verbose and print_iter and n_eval % print_iter == 0:
            print(f"[LBFGS] iter {n_eval}, loss = {loss_value:.6e}")
        return loss

    start = time.perf_counter()
    optimizer.step(closure)
    end = time.perf_counter()

    # torch.device(...) accepts both torch.device objects and strings such as
    # "cuda:0", so the check works whichever form the model stores.
    if torch.device(model.device).type == "cuda":
        torch.cuda.empty_cache()

    print(f"{len(loss_history)} iterations of synthesis.")
    print(f"Execution time: {end - start:.3f} s")
    u_opt = model.u.detach()
    return u_opt
# === Optimization function for synthesis from target maps (mid level) ===
[docs]
def optimize_from_maps(
    target,
    st_op_target,
    st_op_running,
    nbatch=1,
    pbc_running=True,
    running_shape=None,
    init_running=None,
    has_fewer_convolutions=False,
    compute_cross_matrix=None,
    compute_PS=False,
    mean_field=True,
    lr=1.0,
    max_iter=100,
    history_size=50,
    print_iter=10,
    verbose=True,
    seed=None,
    prefilter_Nyquist=True,
    adhoc_weights={"S3": 3.5, "S4": 3.5**2},
):
    """Synthesize maps whose scattering statistics match those of ``target``.

    Parameters
    ----------
    target : data-class instance
        Target data; ``target.array`` must be a 2D, 3D or 4D tensor. It is
        reshaped to 4D by prepending singleton dimensions.
    st_op_target, st_op_running : object
        Scattering operators used for the target and for the running field.
        The reference normalization stored on ``st_op_target`` while computing
        the target statistics is copied onto ``st_op_running``.
    nbatch : int
        Number of maps to synthesize.
    pbc_running : object
        Forwarded to the data class wrapping the running field.
    running_shape : tuple of int, optional
        Spatial shape ``(H, W)`` of the running field. Defaults to the
        target's last two dimensions.
    init_running : array-like, optional
        Initial running field; Gaussian noise when ``None``.
    has_fewer_convolutions, compute_cross_matrix, compute_PS
        Forwarded to both scattering operators.
    mean_field : bool
        Forwarded as ``mean_along_batch`` / ``norm_batch_mean``. When false,
        the target batch size must equal ``nbatch``.
    lr, max_iter, history_size, print_iter, verbose
        Forwarded to ``optimize_lbfgs``.
    seed : int, optional
        Seed passed to ``torch.manual_seed``.
    prefilter_Nyquist : bool
        If true, low-pass filter the target (when it has no NaN) and the
        initial running field with ``apply_nyquist_filter``.
    adhoc_weights : dict or None
        ``{coefficient_label: weight}`` applied to both target and running
        statistics; ``None`` disables re-weighting. The default dict is only
        read, never mutated.

    Returns
    -------
    torch.Tensor
        Synthesized field, unstandardized with the target mean/std, with NaN
        under the running mask. The channel dimension is dropped for 2D
        targets and the batch dimension is dropped when ``nbatch == 1``.

    Raises
    ------
    ValueError
        If ``target.array`` is not 2D/3D/4D, or if ``mean_field`` is false and
        the batch sizes differ.
    """
    # Set random seed
    if seed is not None:
        torch.manual_seed(seed)

    # ------- Set homogeneous configuration for device and dtype -------
    device = st_op_running.wavelet_op.device
    dtype = st_op_running.wavelet_op.dtype
    print("Running synthesis on device:", device, "dtype:", dtype)
    if target.array.isnan().any():
        print("NaN detected in the target, the synthesis takes it into account")

    # ------- Determine initial shape for u (from target) -------
    input_dim = target.array.ndim
    if input_dim not in (2, 3, 4):
        raise ValueError("target.array must be 2D, 3D or 4D tensor")
    # Left-pad with singleton dims: (H,W)->(1,1,H,W), (C,H,W)->(1,C,H,W).
    target_shape = (1,) * (4 - input_dim) + tuple(target.array.shape)
    if running_shape is None:
        spatial_shape = target_shape[-2:]
    else:
        assert len(running_shape) == 2, "running_shape should be a tuple of (H,W)"
        spatial_shape = tuple(running_shape)
    init_shape = (nbatch, target_shape[1], *spatial_shape)
    print("Initial shape for u:", init_shape)

    # Compare true batch sizes: target.array.shape[0] would be H (2D target)
    # or C (3D target) rather than the batch dimension.
    if not mean_field and target_shape[0] != init_shape[0]:
        raise ValueError(
            "If mean_field is False, target and running batch sizes should match"
        )

    with torch.no_grad():
        # ------- Standardize target -------
        l_target = target.copy(empty=False)
        l_target.array = l_target.array.reshape(target_shape)
        if prefilter_Nyquist:
            if l_target.array.isnan().any():
                print(
                    "WARNING: prefiltering target above Nyquist is asked but target has NaNs. Only initial noise will be filtered."
                )
            else:
                print("Prefiltering target to remove frequencies above Nyquist")
                l_target.array = apply_nyquist_filter(l_target.array)
        l_target, mean_target, std_target = st_op_target.wavelet_op.standardize(
            l_target, mean_field=mean_field, inplace=True
        )

        # ------- Compute target stats -------
        # norm="store_ref" makes the target operator record the reference
        # normalization that is transferred to the running operator below.
        target_stats = st_op_target.apply(
            l_target,
            has_fewer_convolutions=has_fewer_convolutions,
            compute_cross_matrix=compute_cross_matrix,
            compute_PS=compute_PS,
            norm="store_ref",
            norm_batch_mean=mean_field,
        )
        if adhoc_weights is not None:
            reweight(target_stats, adhoc_weights)
        target_stats = target_stats.to_flatten(
            mean_along_batch=mean_field, keepnans=False
        )
        target_stats = target_stats.detach()
        print("Synthesis on {:} ST coefficients".format(target_stats.nelement()))

    # ------- Transfer reference normalization from target to running operator -------
    st_op_running.S2_ref_sqrt_chan_diag = st_op_target.S2_ref_sqrt_chan_diag
    st_op_running.var_ref = st_op_target.var_ref
    if compute_PS:
        st_op_running.PS_ref_sqrt_chan_diag = st_op_target.PS_ref_sqrt_chan_diag

    # ------- Build model -------
    model = ScatteringMatchModel(
        st_op=st_op_running,
        DataClass=target.__class__,
        pbc=pbc_running,
        init_shape=init_shape,
        init_map=init_running,
        has_fewer_convolutions=has_fewer_convolutions,
        compute_cross_matrix=compute_cross_matrix,
        compute_PS=compute_PS,
        keep_batch_dim=False,
        mean_field=mean_field,
        device=device,
        dtype=dtype,
        prefilter_Nyquist=prefilter_Nyquist,
        adhoc_weights=adhoc_weights,
    )

    # ------- Launch optimization -------
    def loss_fn(s_flat_u):
        # Squared modulus so the loss stays real if the statistics are complex.
        return ((s_flat_u - target_stats).abs() ** 2).sum()

    u_opt = optimize_lbfgs(
        model=model,
        loss_fn=loss_fn,
        lr=lr,
        max_iter=max_iter,
        history_size=history_size,
        verbose=verbose,
        print_iter=print_iter,
    )

    # ------- Post-process optimized u: unstandardize, apply mask constraints, reshape -------
    DC_u_opt = target.__class__(array=u_opt, pbc=pbc_running)
    st_op_running.wavelet_op.unstandardize(
        DC_u_opt, mean=mean_target, std=std_target, inplace=True
    )
    u_opt = DC_u_opt.array
    running_mask = st_op_running.wavelet_op.mask_full_res
    if running_mask is not None:
        u_opt[..., running_mask.array] = torch.nan
    if input_dim == 2:
        u_opt = u_opt[:, 0, ...]  # remove channel dim
    if nbatch == 1:
        u_opt = u_opt[0]  # remove batch dim
    return u_opt
# === Optimization function for synthesis from target stats (mid level) ===
[docs]
def optimize_from_stats(
    target_stats,
    st_op_running,
    nbatch,
    running_shape=None,
    pbc_running=True,
    init_running=None,
    mean_field=True,
    lr=1.0,
    max_iter=100,
    history_size=50,
    print_iter=10,
    verbose=True,
    seed=None,
    prefilter_Nyquist=True,
    adhoc_weights={"S3": 3.5, "S4": 3.5**2},
):
    """Synthesize maps whose scattering statistics match precomputed ``target_stats``.

    Parameters
    ----------
    target_stats : statistics object
        Precomputed, normalized target statistics. Attributes read here:
        ``Nb``, ``Nc``, ``N0``, ``DataClass``, ``has_fewer_convolutions``,
        ``compute_cross_matrix``, ``standardized``, ``mean_pre_std``,
        ``std_pre_std``, ``S2_ref_sqrt_chan_diag``, ``var_ref`` and, when the
        power spectrum is used, ``PS_ref_sqrt_chan_diag``.
    st_op_running : object
        Scattering operator applied to the running field; the reference
        normalization of ``target_stats`` is copied onto it.
    nbatch : int
        Number of maps to synthesize. Must equal ``target_stats.Nb`` when
        ``mean_field`` is false.
    running_shape : tuple of int, optional
        Spatial shape of the running field; defaults to ``target_stats.N0``.
    pbc_running : object
        Forwarded to the data class wrapping the running field.
    init_running : array-like, optional
        Initial running field; Gaussian noise when ``None``.
    mean_field : bool
        If true, statistics are averaged over the batch before comparison.
    lr, max_iter, history_size, print_iter, verbose
        Forwarded to ``optimize_lbfgs``.
    seed : int, optional
        Seed passed to ``torch.manual_seed``.
    prefilter_Nyquist : bool
        If true, the initial running field is low-pass filtered.
    adhoc_weights : dict or None
        ``{coefficient_label: weight}`` applied to target and running
        statistics; ``None`` disables re-weighting. The default dict is only
        read, never mutated.

    Returns
    -------
    torch.Tensor
        Synthesized field, with NaN under the running mask. The channel
        dimension is dropped when ``Nc == 1`` and the batch dimension when
        ``nbatch == 1``.

    Raises
    ------
    ValueError
        If ``mean_field`` is false and ``nbatch != target_stats.Nb``.

    Notes
    -----
    - NOTE(review): ``reweight`` rescales the coefficients of
      ``target_stats`` in place, so reusing the same object for a second call
      with ``adhoc_weights`` set re-applies the weights. Pass a fresh copy if
      the statistics must be reused.
    - The pre-standardization mean/std are no longer overwritten on
      ``target_stats``; batch-averaged values are kept in local variables.
    """
    # Set random seed
    if seed is not None:
        torch.manual_seed(seed)

    # ------- Set homogeneous configuration for device and dtype -------
    device = st_op_running.wavelet_op.device
    dtype = st_op_running.wavelet_op.dtype
    print("Running synthesis on device:", device, "dtype:", dtype)

    # ------- Determine initial shape for u (from target stats) -------
    Nb, Nc, N, M = target_stats.Nb, target_stats.Nc, *target_stats.N0
    if nbatch != Nb and not mean_field:
        raise ValueError(
            f"If mean_field is False, target batch size (Nb={Nb}) should match running batch size (nbatch={nbatch})"
        )
    if running_shape is not None:
        init_shape = (nbatch, Nc, *running_shape)
    else:
        init_shape = (nbatch, Nc, N, M)
    print(f"Initial shape for u: {init_shape}")

    # ------- Target stats Processing -------
    if adhoc_weights is not None:
        reweight(target_stats, adhoc_weights)
    target_stats_flat = target_stats.to_flatten(
        keep_batch_dim=True,
        mean_along_batch=mean_field,
        keepnans=False,
    )

    # ------- Transfer reference normalization from target stats to running operator -------
    # Those reference normalization attributes have been stored during normalisation of the target stats.
    assert (
        target_stats.S2_ref_sqrt_chan_diag is not None
    ), "target_stats should be normalized to perform reference normalization attributes transfer"
    st_op_running.S2_ref_sqrt_chan_diag = target_stats.S2_ref_sqrt_chan_diag
    st_op_running.var_ref = target_stats.var_ref
    if st_op_running.compute_PS:
        st_op_running.PS_ref_sqrt_chan_diag = target_stats.PS_ref_sqrt_chan_diag

    # ------- Build model -------
    model = ScatteringMatchModel(
        st_op=st_op_running,
        DataClass=target_stats.DataClass,
        pbc=pbc_running,
        init_shape=init_shape,
        init_map=init_running,
        has_fewer_convolutions=target_stats.has_fewer_convolutions,
        compute_cross_matrix=target_stats.compute_cross_matrix,
        compute_PS=st_op_running.compute_PS,
        keep_batch_dim=True,
        mean_field=mean_field,
        device=device,
        dtype=dtype,
        prefilter_Nyquist=prefilter_Nyquist,
        adhoc_weights=adhoc_weights,
    )

    # ------- Launch optimization -------
    def loss_fn(s_flat_u):
        loss = ((s_flat_u - target_stats_flat).abs() ** 2).sum()
        return loss if not mean_field else loss / Nb

    u_opt = optimize_lbfgs(
        model=model,
        loss_fn=loss_fn,
        lr=lr,
        max_iter=max_iter,
        history_size=history_size,
        verbose=verbose,
        print_iter=print_iter,
    )

    # ------- Post-process optimized u: unstandardize, apply mask constraints -------
    if target_stats.standardized:
        # Average the pre-standardization moments over the batch when
        # mean_field is set. Kept in locals so the caller's target_stats is
        # left untouched and can be reused for another synthesis.
        # NOTE(review): assumes dim 0 of mean_pre_std/std_pre_std is the batch
        # dimension - confirm against the statistics class.
        if mean_field:
            mean_pre_std = target_stats.mean_pre_std.mean(dim=0)
            std_pre_std = target_stats.std_pre_std.mean(dim=0)
        else:
            mean_pre_std = target_stats.mean_pre_std
            std_pre_std = target_stats.std_pre_std
        DC_u_opt = target_stats.DataClass(u_opt, pbc=pbc_running)
        st_op_running.wavelet_op.unstandardize(
            DC_u_opt,
            mean=mean_pre_std,
            std=std_pre_std,
            inplace=True,
        )
        u_opt = DC_u_opt.array

    running_mask = st_op_running.wavelet_op.mask_full_res
    if running_mask is not None:
        u_opt[..., running_mask.array] = torch.nan
    if Nc == 1:
        u_opt = u_opt[:, 0, ...]  # remove channel dim
    if nbatch == 1:
        u_opt = u_opt[0]  # remove batch dim
    return u_opt
#######################################################################################
# ------- Pre/Post-processing functions -------
[docs]
def apply_nyquist_filter(tensor, plot=False):
    """
    Apply a low-pass filter to an input tensor, keeping only frequencies within the Nyquist radius.

    The mask is a disc in integer-frequency (FFT index) space of radius
    ``min(N, M) / 2``; modes lying exactly on the radius are kept.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor in real space of shape (..., N, M) where N and M are the spatial dimensions
    plot : bool, optional
        If True, display the Fourier modulus of the first filtered 2D map
        with matplotlib (imported lazily). Default is False.

    Returns
    -------
    torch.Tensor
        Filtered tensor in real space of the same shape as input, with high frequencies removed
    """
    dim = (-2, -1)

    # Compute frequency grids (multiplying by N/M turns cycles-per-sample
    # into integer mode indices).
    N, M = tensor.shape[-2:]
    fx = N * torch.fft.fftfreq(N, d=1.0, device=tensor.device)
    fy = M * torch.fft.fftfreq(M, d=1.0, device=tensor.device)
    FX, FY = torch.meshgrid(fx, fy, indexing="ij")

    # Create low-pass mask based on Nyquist radius
    r2 = FX**2 + FY**2
    nyquist_radius = min(N, M) / 2
    mask = r2 <= nyquist_radius**2

    # Apply mask in Fourier space
    tensor_fft = torch.fft.fft2(tensor, dim=dim)
    tensor_fft[..., ~mask] = 0.0

    # Inverse Fourier transform to get the filtered tensor in real space
    tensor_filtered = torch.fft.ifft2(tensor_fft, dim=dim)
    if not tensor.is_complex():
        tensor_filtered = tensor_filtered.real

    if plot:
        import matplotlib.pyplot as plt

        # Show the spectrum of the *filtered* result (the title says so), and
        # pick the first 2D map whatever the number of leading dimensions.
        first_map = tensor_filtered.detach().reshape(-1, N, M)[0]
        plt.imshow(
            np.abs(torch.fft.fftshift(torch.fft.fft2(first_map)).cpu().numpy()),
            cmap="plasma",
            norm="log",
        )
        plt.colorbar()
        plt.title("Nyquist filtered tensor (Fourier)")
        plt.show()
    return tensor_filtered
# === User-friendly wrapper for synthesis from target maps (high level) ===
[docs]
def synthesize_from_maps(
    data_target,
    nbatch,
    pbc_running,
    running_shape=None,
    init_running=None,
    running_mask=None,
    has_fewer_convolutions=False,
    compute_cross_matrix=None,
    mean_field=True,
    **optim_kwargs,
):
    """
    User-friendly wrapper to synthesize field maps from target maps.

    Parameters
    ----------
    data_target : data-class instance
        Target data. Its class is reused to build the running data object.
    nbatch : int
        Number of maps to synthesize.
    pbc_running : object
        Forwarded to the running data object and to ``optimize_from_maps``.
    running_shape : tuple of int, optional
        Default is None. If None, the running field has the same shape as the target.
    init_running : array-like, optional
        Initial running field. When given, Nyquist prefiltering defaults to
        enabled only if it contains no NaN.
    running_mask : torch.BoolTensor, optional
        Default is None. If None, the same mask as the target is used.
        If a new mask is provided, it must be a boolean tensor with shape matching the running field.
    has_fewer_convolutions, compute_cross_matrix
        Forwarded to the scattering operators / ``optimize_from_maps``.
    mean_field : bool, optional
        Default is True. Default value allows one to perform synthesis between N target samples and M running samples (with N different from or
        equal to M) while matching statistics computed from the batch-averaged field.
    **optim_kwargs
        Overrides for the default optimization parameters (``max_iter``,
        ``lr``, ``history_size``, ``print_iter``, ``verbose``, ``seed``,
        ``prefilter_Nyquist``, ``adhoc_weights``).

    Returns
    -------
    torch.Tensor
        Output of ``optimize_from_maps``.

    Raises
    ------
    ValueError
        If ``running_mask`` does not match the running field's spatial shape.

    Notes
    -----
    - The Power Spectrum is optimized by default whenever possible (i.e., when no NaN values are present in either the target or the running data).
    - Within this user-level wrapper, an ST operator is created for both the target and the running map, and then passed to the mid-level wrapper.
      This is particularly useful for syntheses involving NaN values, as it allows the use of two distinct masks: one for the target and one for the running map.
    """
    if running_mask is None:
        # Same mask for running and target
        array = (
            np.zeros(running_shape) if running_shape is not None else data_target.array
        )
        data_running = data_target.__class__(array=array, pbc=pbc_running)
    else:
        # Compare as plain tuples so a list running_shape or a torch.Size does
        # not produce a spurious mismatch.
        mask_shape = tuple(running_mask.shape)
        if running_shape is None:
            if tuple(data_target.array.shape[-2:]) != mask_shape:
                raise ValueError("running_mask shape should match target array shape")
        elif tuple(running_shape) != mask_shape:
            raise ValueError("running_mask shape should match running_shape")
        data_running = data_target.__class__(array=running_mask, pbc=pbc_running)

    # Select J used for synthesis
    J_target = data_target.get_wavelet_op().J - (not data_target.pbc)
    J_running = data_running.get_wavelet_op().J - (not data_running.pbc)
    J = min(J_target, J_running)
    n_bins_target = data_target.get_CS_op().n_bins
    n_bins_running = data_running.get_CS_op().n_bins
    n_bins = min(n_bins_target, n_bins_running)
    if J_target != J_running:
        print(
            f"Warning: target.J = {J_target}, running.J = {J_running}. Synthesis will use J = {J}."
        )

    # Get scattering operators for target and running data with selected J
    st_op_target = data_target.get_ST_op(
        J=J, n_bins=n_bins, has_fewer_convolutions=has_fewer_convolutions
    )
    st_op_running = data_running.get_ST_op(
        J=J,
        n_bins=n_bins,
        has_fewer_convolutions=has_fewer_convolutions,
        replace_nan_value=None,
    )

    # Disable power spectrum optimization if NaN values are present in target and/or running data
    target_has_nan = data_target.array.isnan().any()
    running_has_nan = data_running.array.isnan().any()
    if target_has_nan or running_has_nan:
        print(
            "⚠️ Warning: NaN detected in target and/or running data.\n"
            "Power spectrum optimization is disabled because its computation is not yet implemented for NaN values in any dataclass. \n"
        )
    compute_PS = not (target_has_nan or running_has_nan)

    # Prefilter by default unless the user-provided initial map contains NaN.
    # isnan must be *called*; as_tensor lets array-likes other than tensors
    # (e.g. numpy arrays) be checked too.
    if init_running is None:
        prefilter_default = True
    else:
        prefilter_default = not bool(torch.as_tensor(init_running).isnan().any())

    # Set default optimization parameters and update with user-provided values
    optim_params = dict(
        max_iter=100,
        lr=1.0,
        history_size=50,
        print_iter=10,
        verbose=True,
        seed=26,
        prefilter_Nyquist=prefilter_default,
        adhoc_weights={"S3": 3.5, "S4": 3.5**2},
    )
    optim_params.update(optim_kwargs)

    # Run optimization
    u_opt = optimize_from_maps(
        target=data_target,
        st_op_target=st_op_target,
        st_op_running=st_op_running,
        nbatch=nbatch,
        pbc_running=pbc_running,
        running_shape=running_shape,
        init_running=init_running,
        mean_field=mean_field,
        has_fewer_convolutions=has_fewer_convolutions,
        compute_cross_matrix=compute_cross_matrix,
        compute_PS=compute_PS,
        **optim_params,
    )
    return u_opt
# === User-friendly wrapper for synthesis from target statistics (high level) ===
[docs]
def synthesize_from_stats(
    target_stats,
    nbatch,
    pbc_running,
    running_shape=None,
    init_running=None,
    running_mask=None,
    mean_field=True,
    **optim_kwargs,
):
    """
    User-friendly wrapper to synthesize field maps from precomputed target statistics.

    Parameters
    ----------
    target_stats : statistics object
        Precomputed target statistics. Attributes read here: ``DataClass``,
        ``N0``, ``mask_full_res``, ``device``, ``dtype``, ``compute_PS``,
        ``J``, ``n_bins`` and ``has_fewer_convolutions``.
    nbatch : int
        Number of maps to synthesize.
    pbc_running : object
        Forwarded to the running data object and to ``optimize_from_stats``.
    running_shape : tuple of int, optional
        Spatial shape of the running field; defaults to ``target_stats.N0``.
    init_running : array-like, optional
        Initial running field. When given, Nyquist prefiltering defaults to
        enabled only if it contains no NaN.
    running_mask : torch.Tensor, optional
        Mask for the running field; its shape must equal ``target_stats.N0``.
        If None, the mask stored in ``target_stats`` (if any) is reused unless
        ``running_shape`` is given.
    mean_field : bool, optional
        Forwarded to ``optimize_from_stats``.
    **optim_kwargs
        Overrides for the default optimization parameters (``max_iter``,
        ``lr``, ``history_size``, ``print_iter``, ``verbose``, ``seed``,
        ``prefilter_Nyquist``, ``adhoc_weights``).

    Returns
    -------
    torch.Tensor
        Output of ``optimize_from_stats``.

    Raises
    ------
    ValueError
        If ``running_mask`` does not have shape ``target_stats.N0``.

    Notes
    -----
    - Parameters such as `compute_cross_matrix` and `has_fewer_convolutions` are not specified as arguments of this wrapper,
      but rather during the computation of the target statistics.
    - Side effect: ``target_stats.compute_PS`` is set to False on the caller's
      object when the running data contains NaN.
    """
    if running_mask is None:
        if running_shape is not None:
            array = torch.zeros(running_shape)
        else:
            if target_stats.mask_full_res is None:
                array = torch.zeros(target_stats.N0)
            else:
                # NaN under the target mask, zeros elsewhere.
                array = torch.where(target_stats.mask_full_res.array, torch.nan, 0.0)
        array = array.to(device=target_stats.device, dtype=target_stats.dtype)
        data_running = target_stats.DataClass(array=array, pbc=pbc_running)
    else:
        # Compare as plain tuples so torch.Size / list / tuple all agree.
        if tuple(running_mask.shape) != tuple(target_stats.N0):
            raise ValueError("running_mask shape should match target_stats N0")
        data_running = target_stats.DataClass(array=running_mask, pbc=pbc_running)

    running_has_nan = data_running.array.isnan().any()
    if not target_stats.compute_PS or running_has_nan:
        print(
            "⚠️ Warning: Power spectrum optimization is disabled because it is not implemented for NaN values in any dataclass"
            " and/or because Power spectrum computation has not been included in target_stats.\n"
        )
        # Remove power spectrum from target_stats if not optimizable on the running side
        if target_stats.compute_PS:
            target_stats.compute_PS = False
    compute_PS = target_stats.compute_PS and not running_has_nan

    # Get scattering operator for running data
    st_op_running = data_running.get_ST_op(
        J=target_stats.J,
        n_bins=target_stats.n_bins,
        has_fewer_convolutions=target_stats.has_fewer_convolutions,
        compute_PS=compute_PS,
        replace_nan_value=None,
    )

    # Prefilter by default unless the user-provided initial map contains NaN.
    # isnan must be *called*; as_tensor lets array-likes other than tensors
    # (e.g. numpy arrays) be checked too.
    if init_running is None:
        prefilter_default = True
    else:
        prefilter_default = not bool(torch.as_tensor(init_running).isnan().any())

    # Set default optimization parameters and update with user-provided values
    optim_params = dict(
        max_iter=100,
        lr=1.0,
        history_size=50,
        print_iter=10,
        verbose=True,
        seed=26,
        prefilter_Nyquist=prefilter_default,
        adhoc_weights={"S3": 3.5, "S4": 3.5**2},
    )
    optim_params.update(optim_kwargs)

    # Run optimization
    u_opt = optimize_from_stats(
        target_stats=target_stats,
        st_op_running=st_op_running,
        nbatch=nbatch,
        running_shape=running_shape,
        pbc_running=pbc_running,
        init_running=init_running,
        mean_field=mean_field,
        **optim_params,
    )
    return u_opt