# Source code for STL_main.ST_Statistics

# -*- coding: utf-8 -*-
"""
Main structure of STL

Tentative proposal by EA
"""

import matplotlib.pyplot as plt
import numpy as np
import torch as bk  # mean, zeros

###############################################################################
###############################################################################


class ST_Statistics:
    """
    Set of scattering-transform (ST) statistics.

    Instances are built by the ST_operator, which calls ``__init__`` and then
    fills the statistics attributes. This class is independent of the data
    type (DT): it only manipulates the statistics tensors.

    The methods allow to deal with ST statistics in a unified manner
    (normalization, isotropization, angular Fourier transform, scale cosine
    transform). Most of them can also be applied directly through the
    ST_operator implementation. For use in a loss, a 1D tensor (or a 2D tensor
    keeping the batch dimension) can be obtained with ``to_flatten``, for any
    type of ST statistics, optionally selecting coefficients with a mask.

    Parameters
    ----------
    - DataClass :
        data class associated with the statistics (stored as is)
    - N0 : tuple
        initial size of the data array (stored; not used in this class)
    - Nb : int
        batch size
    - Nc : int
        number of channels
    - wavelet_op :
        wavelet operator; its J, L, WType, mask_full_res, device and dtype
        attributes are copied onto the instance
    - SC : str
        type of ST coefficients ("ScatCov", "WPH")
    - has_fewer_convolutions : bool
        operator option (stored as is)
    - compute_cross_matrix : bool
        operator option (stored as is)
    - compute_PS : bool
        whether power spectrum coefficients (PS) are part of the statistics
    - n_bins : int
        number of power spectrum bins
    - standardized : bool
        whether the data were standardized before computing the statistics
    - mean_pre_std, std_pre_std :
        mean and standard deviation of the data before standardization

    Attributes
    ----------
    - parent parameters (DataClass, N0, Nb, Nc, wavelet_op, SC, ...)
    - J, L, WType : number of scales, number of orientations and wavelet
      type, read from wavelet_op
    # Additional transform/compression state
    - norm : bool
        whether the statistics were normalized (see to_norm)
    - S2_ref_sqrt_chan_diag : tensor or None
        sqrt of the channel diagonal of the reference S2 coefficients
    - PS_ref_sqrt_chan_diag : tensor or None
        sqrt of the channel diagonal of the reference PS coefficients
    - iso : bool
        whether only isotropic coefficients are kept (see to_iso)
    - angular_ft : bool
        whether an angular Fourier transform was applied (see to_angular_ft)
    - scale_ft : bool
        whether a scale cosine transform was applied (see to_scale_ft)
    - flatten : bool
        flattening flag (initialized to False)
    - mask_st : None
        placeholder, not used by to_flatten for now
    # ST statistics (filled outside this class, e.g. by the ST_operator)
    - mean, var : data mean and variance
    - S1, S2 : [Nb, Nc, Nc, J, L]
    - S3 : [Nb, Nc, Nc, J, J, L, L]
    - S4 : [Nb, Nc, Nc, J, J, J, L, L, L]
    # Power Spectrum
    - PS : [Nb, Nc, Nc, n_bins], present if compute_PS is True
    """

    def __init__(
        self,
        DataClass,
        N0,  # NOTE(review): stored but not used in this class
        Nb,
        Nc,
        wavelet_op,
        SC,
        has_fewer_convolutions,
        compute_cross_matrix,
        compute_PS,
        n_bins,
        standardized,
        mean_pre_std,
        std_pre_std,
    ):
        """Constructor, see the class docstring for the parameters."""
        # Main parameters
        self.DataClass = DataClass
        self.N0 = N0
        self.Nb = Nb
        self.Nc = Nc

        # Wavelet transform related parameters
        self.wavelet_op = wavelet_op
        self.J = self.wavelet_op.J
        self.L = self.wavelet_op.L
        self.WType = self.wavelet_op.WType
        self.mask_full_res = self.wavelet_op.mask_full_res
        self.device = self.wavelet_op.device
        self.dtype = self.wavelet_op.dtype

        # Scattering transform related parameters
        self.SC = SC
        self.S2_ref_sqrt_chan_diag = None
        self.has_fewer_convolutions = has_fewer_convolutions
        self.compute_cross_matrix = compute_cross_matrix

        # Power spectrum computation
        self.compute_PS = compute_PS
        self.n_bins = n_bins
        self.PS_ref_sqrt_chan_diag = None

        # Mean and variance of data if standardization is applied
        self.standardized = standardized
        self.mean_pre_std = mean_pre_std
        self.std_pre_std = std_pre_std

        # Additional transform/compression state. Initialized to False/None;
        # updated when the corresponding methods are called, either by the
        # scattering operator or independently.
        self.norm = False
        self.iso = False
        self.angular_ft = False
        self.scale_ft = False
        self.flatten = False
        self.mask_st = None  # Not used in the flatten method for now

    @staticmethod
    def _get_sqrt_chan_diag(stat_ref, norm_batch_mean):
        """
        Prepare S2_ref with shape [Nb,Nc,Nc,J,L] or PS_ref with shape
        [Nb,Nc,Nc,n_bins]: keep its channel diagonal, average over the batch
        dimension if norm_batch_mean is True, and apply sqrt.

        Returns a tensor of shape [Nb,Nc,J,L] or [Nb,Nc,n_bins] (Nb becomes 1
        when norm_batch_mean is True).
        """
        # diagonal() appends the diagonal axis last; move it back to dim 1
        stat_ref_chan_diag = stat_ref.diagonal(dim1=1, dim2=2).movedim(-1, 1)
        if norm_batch_mean:
            stat_ref_chan_diag = stat_ref_chan_diag.mean(dim=0, keepdim=True)
        return bk.sqrt(stat_ref_chan_diag)

    def _normalize_scatcov(self):
        """
        Normalize the ScatCov statistics S1, S2, S3, S4 in place using
        self.S2_ref_sqrt_chan_diag (shape [Nb,Nc,J,L]).
        """
        ref = self.S2_ref_sqrt_chan_diag
        # S1: [Nb,Nc,Nc,J1,L1]
        self.S1 = self.S1 / bk.sqrt(ref[:, :, None] * ref[:, None, :])
        # S2: [Nb,Nc,Nc,J1,L1]
        self.S2 = self.S2 / (ref[:, :, None] * ref[:, None, :])
        # S3: [Nb,Nc,Nc,J1,J2,L1,L2]; first channel at (J1,L1), second at (J2,L2)
        self.S3 = self.S3 / (
            ref[:, :, None, :, None, :, None] * ref[:, None, :, None, :, None, :]
        )
        # S4: [Nb,Nc,Nc,J1,J2,J3,L1,L2,L3]; normalized at (J1,L1) and (J2,L2)
        self.S4 = self.S4 / (
            ref[:, :, None, :, None, None, :, None, None]
            * ref[:, None, :, None, :, None, None, :, None]
        )

    def to_norm(
        self,
        norm_type=None,
        S2_ref_sqrt_chan_diag=None,
        PS_ref_sqrt_chan_diag=None,
        var_ref=None,
        norm_batch_mean=True,
    ):
        """
        Normalize the ST statistics in place.

        Parameters
        ----------
        - norm_type : str
            type of norm ("self", "from_ref")
        - S2_ref_sqrt_chan_diag : tensor
            required with "from_ref" if self.SC == "ScatCov": sqrt of the
            channel diagonal of the reference S2 coefficients
        - PS_ref_sqrt_chan_diag : tensor
            required with "from_ref" if self.compute_PS is True: sqrt of the
            channel diagonal of the reference Power Spectrum coefficients
        - var_ref : tensor
            required with "from_ref": reference variance
        - norm_batch_mean : bool, default True
            Used with the "self" normalization type. If True, the reference
            coefficients are averaged over the batch dimension (dim=0).

        Returns
        -------
        - self

        Raises
        ------
        - Exception if called after to_iso / to_angular_ft / to_scale_ft,
          without norm_type, or on already normalized statistics.
        - ValueError for an unknown norm_type, or when "from_ref" is missing
          a required reference (checked before any state is modified).
        """
        # Normalization must be the first transform applied
        if self.iso:
            raise Exception("Normalization can only be done before isotropization")
        if self.angular_ft:
            raise Exception("Normalization can only be done before angular ft")
        if self.scale_ft:
            raise Exception("Normalization can only be done before scale_ft")
        if norm_type is None:
            raise Exception("No normalization type specified")
        if self.norm:
            raise Exception("ST statistics are already normalized")

        if norm_type == "self":
            # The statistics are their own reference
            var_ref = self.var * 1.0
            self.var_ref = (
                var_ref.mean(dim=0, keepdim=True) if norm_batch_mean else var_ref
            )
            if self.SC == "ScatCov" and self.S2_ref_sqrt_chan_diag is None:
                # S2 [Nb,Nc,Nc,J,L] -> sqrt of its channel diagonal
                self.S2_ref_sqrt_chan_diag = self._get_sqrt_chan_diag(
                    self.S2, norm_batch_mean=norm_batch_mean
                )
            if self.compute_PS and self.PS_ref_sqrt_chan_diag is None:
                # PS [Nb,Nc,Nc,n_bins] -> sqrt of its channel diagonal
                self.PS_ref_sqrt_chan_diag = self._get_sqrt_chan_diag(
                    self.PS, norm_batch_mean=norm_batch_mean
                )
        elif norm_type == "from_ref":
            # Validate the references before modifying any state
            missing = []
            if var_ref is None:
                missing.append("var_ref")
            if self.SC == "ScatCov" and S2_ref_sqrt_chan_diag is None:
                missing.append("S2_ref_sqrt_chan_diag")
            if self.compute_PS and PS_ref_sqrt_chan_diag is None:
                missing.append("PS_ref_sqrt_chan_diag")
            if missing:
                raise ValueError(
                    "'from_ref' normalization requires: " + ", ".join(missing)
                )
            self.var_ref = var_ref
            if self.SC == "ScatCov":
                self.S2_ref_sqrt_chan_diag = S2_ref_sqrt_chan_diag
            if self.compute_PS:
                self.PS_ref_sqrt_chan_diag = PS_ref_sqrt_chan_diag
        else:
            raise ValueError(
                f"Unknown normalization type {norm_type!r}; "
                "expected 'self' or 'from_ref'"
            )

        # Perform normalization
        self.var = self.var / self.var_ref
        if self.SC == "ScatCov":
            self._normalize_scatcov()
        if self.compute_PS:
            ref = self.PS_ref_sqrt_chan_diag
            # [Nb, Nc, Nc, n_bins]
            self.PS = self.PS / (ref[:, :, None, :] * ref[:, None, :, :])

        self.norm = True
        return self

    def to_iso(self):
        """
        Isotropize the set of ST statistics in place.

        S1 and S2 are averaged over orientations. S3 and S4 are averaged over
        global rotations, keeping only orientation differences relative to l1:
            S3iso[..., d]      = mean_l1 S3[..., l1, (l1 + d) % L]
            S4iso[..., d2, d3] = mean_l1 S4[..., l1, (l1 + d2) % L, (l1 + d3) % L]

        Note: S2_ref_sqrt_chan_diag is not isotropized since it is used before
        this step.
        Note: PS coefficients are not modified here (per EA, they are already
        isotropized in PS_operator).
        EA: Sihao used .real for S3 and S4, to consider.

        Returns
        -------
        - self

        Raises
        ------
        - Exception if called after to_angular_ft / to_scale_ft, or twice.
        """
        if self.angular_ft:
            raise Exception("Isotropization can only be done before angular ft")
        if self.scale_ft:
            raise Exception("Isotropization can only be done before scale_ft")
        if self.iso:
            raise Exception("ST statistics are already isotropized")

        if self.SC == "ScatCov":
            L = self.L
            # S1, S2: (Nb,Nc,Nc,J,L) -> (Nb,Nc,Nc,J)
            self.S1 = bk.mean(self.S1, -1)
            self.S2 = bk.mean(self.S2, -1)

            # S3: (Nb,Nc,Nc,J,J,L,L) -> (Nb,Nc,Nc,J,J,L)
            ell = bk.arange(L, device=self.S3.device)
            l1, d = ell[:, None], ell[None, :]
            # Gathered tensor has trailing dims (l1, d); average over l1
            self.S3 = self.S3[..., l1, (l1 + d) % L].sum(dim=-2) / L

            # S4: (Nb,Nc,Nc,J,J,J,L,L,L) -> (Nb,Nc,Nc,J,J,J,L,L)
            ell = ell.to(self.S4.device)
            l1 = ell[:, None, None]
            d2 = ell[None, :, None]
            d3 = ell[None, None, :]
            # Gathered tensor has trailing dims (l1, d2, d3); average over l1
            self.S4 = (
                self.S4[..., l1, (l1 + d2) % L, (l1 + d3) % L].sum(dim=-3) / L
            )

        self.iso = True
        return self

    def to_angular_ft(self, harmonics_angle=None):
        """
        Angular harmonic transform of the ST statistics, in place.

        An orthonormal FFT is applied over the orientation axes of S3 and S4,
        and only the first ``harmonics_angle`` harmonics are kept along each
        transformed axis.

        Parameters
        ----------
        - harmonics_angle : int, default self.L
            number of angular harmonics kept per orientation axis

        Returns
        -------
        - self

        Raises
        ------
        - Exception if called after to_scale_ft, or twice.
        """
        if harmonics_angle is None:
            harmonics_angle = self.L
        if self.scale_ft:
            raise Exception("Angular_tf can only be done before scale_ft")
        if self.angular_ft:
            raise Exception("Angular ft has already been applied")

        if self.SC == "ScatCov":
            h = harmonics_angle
            if self.iso:
                # S3: (Nb,Nc,Nc,J,J,L); S4: (Nb,Nc,Nc,J,J,J,L,L)
                S3f = bk.fft.fftn(self.S3, norm="ortho", dim=(-1,))[..., :h]
                S4f = bk.fft.fftn(self.S4, norm="ortho", dim=(-2, -1))[..., :h, :h]
            else:
                # S3: (Nb,Nc,Nc,J,J,L,L); S4: (Nb,Nc,Nc,J,J,J,L,L,L)
                S3f = bk.fft.fftn(self.S3, norm="ortho", dim=(-2, -1))[..., :h, :h]
                S4f = bk.fft.fftn(self.S4, norm="ortho", dim=(-3, -2, -1))[
                    ..., :h, :h, :h
                ]
            self.S3 = S3f
            self.S4 = S4f

        self.angular_ft = True
        return self

    @staticmethod
    def _cosine_basis(N, device, dtype):
        """
        Orthonormal DCT-II basis F[k, n] = c_k cos(k pi (n + 0.5) / N),
        with c_0 = sqrt(1/N) and c_k = sqrt(2/N) for 0 < k < N.
        """
        if N == 0:
            return bk.zeros((0, 0), device=device, dtype=dtype)
        ts = bk.linspace(0, bk.pi * (N - 1) / N, N, device=device, dtype=dtype) + (
            0.5 * bk.pi / N
        )
        indices = bk.stack([k * ts for k in range(N)], dim=0)
        F = bk.cos(indices)
        F[1:, :] *= bk.sqrt(bk.tensor(2 / N, device=device, dtype=dtype))
        F[0, :] *= bk.sqrt(bk.tensor(1 / N, device=device, dtype=dtype))
        return F

    @staticmethod
    def _nan_scalar(x):
        """0-dim NaN tensor with the dtype/device of x (NaN+NaNj if complex)."""
        value = complex(float("nan"), float("nan")) if bk.is_complex(x) else float("nan")
        return bk.tensor(value, dtype=x.dtype, device=x.device)

    @staticmethod
    def _dct_ignoring_nan(x, axis):
        """
        Apply the orthonormal DCT-II along dimension ``axis`` of x.

        NaN entries (in the real or imaginary part) are replaced by 0 for the
        transform, and set back to NaN at the same positions in the output.
        """
        N = x.shape[axis]
        F = ST_Statistics._cosine_basis(N, device=x.device, dtype=x.real.dtype)
        nan_mask = bk.isnan(x)

        def _transform(part):
            # Move `axis` last, compute out_i = sum_j F_ij part_j, move back
            part = bk.nan_to_num(part, nan=0.0).movedim(axis, -1)
            return bk.matmul(part, F.transpose(0, 1)).movedim(-1, axis).contiguous()

        if bk.is_complex(x):
            out = bk.complex(_transform(x.real), _transform(x.imag))
            out[nan_mask] = complex(float("nan"), float("nan"))
        else:
            out = _transform(x)
            out[nan_mask] = float("nan")
        return out

    def to_scale_ft(self, harmonics_scale=None, dj=None, harmonics_angle=None):
        """
        Scale cosine transform of the (isotropized) ST statistics, in place.

        S3 and S4 are first reparametrized from absolute scales to scale
        differences relative to j1:
            S3[..., j1, dj2, :]         = S3[..., j1, (j1 + dj2) % J, :]
            S4[..., j1, dj2, dj3, :, :] = S4[..., j1, (j1 + dj2) % J,
                                                  (j1 + dj3) % J, :, :]
        Coefficients with dj2 >= dj (and, for S4, dj3 >= dj or
        (dj3 - dj2) % J >= dj) are set to NaN. An orthonormal DCT is then
        applied along the j1 axis, and only the first ``harmonics_scale``
        harmonics are kept. Because the NaN pattern does not depend on j1,
        NaN coefficients stay at the same (dj2, dj3) positions after the DCT.

        Parameters
        ----------
        - harmonics_scale : int, default self.J
            number of scale harmonics kept
        - dj : int, default self.J + 1
            scale differences >= dj are discarded (set to NaN)
        - harmonics_angle : int, optional
            kept for backward compatibility; the angular sizes are read from
            the S3/S4 shapes

        Returns
        -------
        - self

        Raises
        ------
        - Exception if called twice.
        - NotImplementedError if self.SC != "ScatCov" or the statistics are
          not isotropized.
        """
        if self.scale_ft:
            raise Exception("Scale ft has already been applied")
        J = self.J
        if harmonics_scale is None:
            harmonics_scale = J
        if dj is None:
            dj = J + 1
        if self.SC != "ScatCov":
            raise NotImplementedError(
                "to_scale_ft is only implemented for SC='ScatCov'"
            )
        if not self.iso:
            raise NotImplementedError(
                "to_scale_ft is only implemented for isotropized statistics"
            )

        S3, S4 = self.S3, self.S4

        # --- S3: (...,J,J,L') -> (..., j1, dj2, L')
        j = bk.arange(J, device=S3.device)
        j1, d2 = j[:, None], j[None, :]
        S3_reparam = S3[..., j1, (j1 + d2) % J, :]
        S3_invalid = (d2 >= dj).expand(J, J)[..., None]  # [J(j1), J(dj2), 1]
        S3_reparam = bk.where(S3_invalid, self._nan_scalar(S3), S3_reparam)

        # --- S4: (...,J,J,J,L',L') -> (..., j1, dj2, dj3, L', L')
        j = j.to(S4.device)
        j1 = j[:, None, None]
        d2 = j[None, :, None]
        d3 = j[None, None, :]
        S4_reparam = S4[..., j1, (j1 + d2) % J, (j1 + d3) % J, :, :]
        # (j3 - j2) % J == (dj3 - dj2) % J, independent of j1
        S4_invalid = (d2 >= dj) | (d3 >= dj) | ((d3 - d2) % J >= dj)
        S4_reparam = bk.where(
            S4_invalid[..., None, None], self._nan_scalar(S4), S4_reparam
        )

        # DCT along the j1 axis, then keep the lowest scale harmonics
        S3_cos = self._dct_ignoring_nan(S3_reparam, axis=-3)
        S4_cos = self._dct_ignoring_nan(S4_reparam, axis=-5)
        self.S3 = S3_cos[..., :harmonics_scale, :, :]
        self.S4 = S4_cos[..., :harmonics_scale, :, :, :, :]

        self.scale_ft = True
        return self

    def to_flatten(
        self,
        keep_batch_dim=False,
        mask_st=None,
        mean_along_batch=False,
        keepnans=False,
        flatten_complex=False,
    ):
        """
        Produce either a 1d array usable for loss constructions, or a 2d
        array keeping the batch dimension if keep_batch_dim is True.

        Statistics are concatenated in the order mean, var, then S1, S2, S3,
        S4 (if self.SC == "ScatCov") and PS (if self.compute_PS). A mask can
        be used to select coefficients after NaN removal.

        Parameters
        ----------
        - keep_batch_dim : bool, default False
            if True, the output has shape [Nb, n_coeff] instead of
            [Nb * n_coeff]
        - mask_st : binary 1d array, optional
            mask over the flattened coefficients (after NaN removal). With
            keep_batch_dim=True, a 2d [Nb, n_coeff] mask with identical rows
            is also accepted.
        - mean_along_batch : bool, default False
            average each statistic over the batch dimension first
        - keepnans : bool, default False
            if False, NaN coefficients are removed
        - flatten_complex : bool, default False
            if True, complex coefficients are flattened into two separate
            real numbers (real and imaginary parts). Since S1, S2 and PS are
            real, their null imaginary part is not included.

        Output
        ----------
        - st_flatten : 1d array if keep_batch_dim is False, else 2d array
          with shape [Nb, n_coeff]. Also stored as self.st_flatten.

        Raises
        ------
        - ValueError if mask_st has the wrong length.
        - AssertionError if, with keep_batch_dim=True, NaN positions (when
          removed) or a 2d mask differ across batch elements.
        """
        stats = [self.mean, self.var]  # always include mean and variance
        stats_names = ["mean", "var"]
        if self.SC == "ScatCov":
            stats += [self.S1, self.S2, self.S3, self.S4]
            stats_names += ["S1", "S2", "S3", "S4"]
        if self.compute_PS:
            stats.append(self.PS)
            stats_names.append("PS")

        if mean_along_batch:
            stats = [bk.mean(s, dim=0, keepdim=True) for s in stats]

        # Flatten each statistic, optionally removing NaNs
        flattened_list = []
        for S, S_name in zip(stats, stats_names):
            if flatten_complex and bk.is_complex(S):
                S = bk.view_as_real(S)  # [..., 2]: real and imaginary parts
                if S_name in ("S1", "S2", "PS"):
                    S = S[..., 0]  # real statistics: drop the null imaginary part
            S_flat = S.reshape(S.shape[0], -1) if keep_batch_dim else S.reshape(-1)

            if keepnans:
                flattened_list.append(S_flat)
                continue

            nan_mask = bk.isnan(S_flat)
            if not keep_batch_dim:
                flattened_list.append(S_flat[~nan_mask])
            else:
                # A common column selection requires identical NaN positions
                assert bk.all(nan_mask == nan_mask[0]), (
                    "NaNs must be at the same indices across all batch elements "
                    "when keep_batch_dim=True and keepnans=False"
                )
                flattened_list.append(S_flat[:, ~nan_mask[0]])

        st_flatten = bk.cat(flattened_list, dim=1 if keep_batch_dim else 0)

        # Optional mask after NaN removal
        if mask_st is not None:
            mask_st = bk.as_tensor(mask_st, dtype=bk.bool, device=st_flatten.device)
            if not keep_batch_dim:
                if mask_st.numel() != st_flatten.numel():
                    raise ValueError(
                        f"mask_st length ({mask_st.numel()}) does not match "
                        f"flattened statistic length ({st_flatten.numel()})."
                    )
                st_flatten = st_flatten[mask_st]
            else:
                if mask_st.dim() > 1:
                    # One mask per batch element: they must all agree
                    assert bk.all(mask_st == mask_st[0]), (
                        "mask_st must be identical across all batch elements "
                        "when keep_batch_dim=True"
                    )
                    mask_st = mask_st[0]
                if mask_st.numel() != st_flatten.shape[1]:
                    raise ValueError(
                        f"mask_st length ({mask_st.numel()}) does not match "
                        f"flattened statistic length ({st_flatten.shape[1]})."
                    )
                st_flatten = st_flatten[:, mask_st]

        self.st_flatten = st_flatten
        return st_flatten

    def select(self, param):
        """
        Select statistics and give them as a tensor in output.

        NOTE(review): placeholder - ``param`` is ignored and the method
        currently always returns 1.
        """
        output = 1
        return output

    def _to_np(self, x):
        """Convert a torch tensor (detached, on CPU) or array-like to numpy."""
        if isinstance(x, bk.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def _select_bc(self, x, b, c, c2, base_ndim):
        """
        Return x[b, c, c2] (cross-channel layout, x.ndim == base_ndim + 3)
        or x[b, c] (x.ndim == base_ndim + 2) as a numpy array.
        """
        if x.ndim == base_ndim + 3:
            x = x[b, c, c2]
        elif x.ndim == base_ndim + 2:
            x = x[b, c]
        else:
            raise ValueError(
                f"Unexpected statistic rank {x.ndim} (expected "
                f"{base_ndim + 2} or {base_ndim + 3})."
            )
        return self._to_np(x)

    def plot_coeff(self, b: int = 0, c: int = 0, new_figure: bool = True, c2=None):
        """
        Reproduce the classical S1/S2/S3/S4 scattering plot for a given
        (batch, channel) pair. Real parts of the coefficients are plotted.

        Parameters
        ----------
        b : int
            Batch index (0 <= b < Nb).
        c : int
            Channel index (0 <= c < Nc).
        new_figure : bool
            If True, create a new figure. If False, plot into the existing one
            (useful to overlay multiple ST_statistic objects on the same
            panels).
        c2 : int, optional
            Second channel index for cross-channel statistics
            ([Nb, Nc, Nc, ...] layout). Defaults to c (auto statistics).

        Raises
        ------
        ValueError
            if S1 is missing, the statistics were isotropized or scale
            transformed, or have an unexpected rank.
        """
        if getattr(self, "S1", None) is None:
            raise ValueError("S1 is None; nothing to plot.")
        if self.iso or self.scale_ft:
            raise ValueError(
                "plot_coeff requires non-isotropized statistics without scale ft."
            )
        if c2 is None:
            c2 = c

        # ---- extract S1..S4 for one (b, c, c2) as numpy ----
        S1 = np.real(self._select_bc(self.S1, b, c, c2, base_ndim=2))  # (J, L)
        S2 = np.real(self._select_bc(self.S2, b, c, c2, base_ndim=2))  # (J, L)
        S3_full = self._select_bc(self.S3, b, c, c2, base_ndim=4)  # (J, J, L, L)
        S4_full = self._select_bc(self.S4, b, c, c2, base_ndim=6)  # (J, J, J, L, L, L)
        J, N_orient = S1.shape

        # ---- compact index arrays: j1 <= j2 (S3), j1 <= j2 <= j3 (S4) ----
        # Ordered by j3, then j2, then j1 (outermost first)
        s3_idx = np.array(
            [(ja, jb) for jb in range(J) for ja in range(jb + 1)], dtype=int
        ).reshape(-1, 2)
        j1_s3, j2_s3 = s3_idx[:, 0], s3_idx[:, 1]
        s4_idx = np.array(
            [
                (ja, jb, jc)
                for jc in range(J)
                for jb in range(jc + 1)
                for ja in range(jb + 1)
            ],
            dtype=int,
        ).reshape(-1, 3)
        j1_s4, j2_s4, j3_s4 = s4_idx[:, 0], s4_idx[:, 1], s4_idx[:, 2]

        S3 = np.real(S3_full[j1_s3, j2_s3])  # (n_s3, L, L)
        S4 = np.real(S4_full[j1_s4, j2_s4, j3_s4])  # (n_s4, L, L, L)
        n1 = min(4, N_orient)
        n3 = min(4, S3.shape[-1])
        n4 = min(4, S4.shape[-1])

        color = ["b", "r", "orange", "pink"]
        symbol = ["", ":", "-", "."]

        if new_figure:
            plt.figure(figsize=(16, 12))

        # ----- S1 -----
        plt.subplot(2, 2, 1)
        for k in range(n1):
            plt.plot(S1[:, k], color=color[k % len(color)], label=rf"$\Theta = {k}$")
        plt.legend(frameon=False, ncol=2)
        plt.xlabel(r"$J_1$")
        plt.ylabel(r"$S_1$")
        plt.yscale("log")

        # ----- S2 -----
        plt.subplot(2, 2, 2)
        for k in range(n1):
            plt.plot(S2[:, k], color=color[k % len(color)], label=rf"$\Theta = {k}$")
        plt.xlabel(r"$J_1$")
        plt.ylabel(r"$S_2$")
        plt.yscale("log")

        # ----- S3 -----
        plt.subplot(2, 2, 3)
        # offsets separating groups of constant j1
        nidx = np.concatenate(
            [np.zeros([1], dtype=int), np.cumsum(np.bincount(j1_s3, minlength=J))],
            axis=0,
        )
        l_pos = []
        l_name = []
        for i in np.unique(j1_s3):
            idx = np.where(j1_s3 == i)[0]
            x = j2_s3[idx] + nidx[i]
            for k in range(n3):
                for l in range(n3):
                    kw = {"label": rf"$\Theta = {k},{l}$"} if i == 0 else {}
                    plt.plot(
                        x,
                        S3[idx, k, l],
                        symbol[l % len(symbol)],
                        color=color[k % len(color)],
                        **kw,
                    )
            l_pos += list(x)
            l_name += [f"{j1_s3[m]},{j2_s3[m]}" for m in idx]
        plt.legend(frameon=False, ncol=2)
        plt.xticks(l_pos, l_name, fontsize=6)
        plt.xlabel(r"$j_{1},j_{2}$", fontsize=9)
        plt.ylabel(r"$S_{3}$", fontsize=9)

        # ----- S4 -----
        plt.subplot(2, 2, 4)
        # Unique x position per (j1, j2, j3), sorted by j1 then j2 then j3,
        # with a gap between groups of constant j1
        order = np.lexsort((j3_s4, j2_s4, j1_s4))
        x_s4 = np.empty(len(order), dtype=int)
        pos = 0
        for rank, n in enumerate(order):
            if rank > 0 and j1_s4[n] != j1_s4[order[rank - 1]]:
                pos += 1
            x_s4[n] = pos
            pos += 1

        l_pos = []
        l_name = []
        for i in np.unique(j1_s4):
            for j in np.unique(j2_s4):
                idx = np.where((j1_s4 == i) & (j2_s4 == j))[0]
                if idx.size == 0:
                    continue
                x = x_s4[idx]
                for k in range(n4):
                    for l in range(n4):
                        for m in range(n4):
                            kw = (
                                {"label": rf"$\Theta = {k},{l},{m}$"}
                                if (i == 0 and j == 0 and m == 0)
                                else {}
                            )
                            plt.plot(
                                x,
                                S4[idx, k, l, m],
                                symbol[l % len(symbol)],
                                color=color[k % len(color)],
                                **kw,
                            )
                l_pos += list(x)
                l_name += [f"{j1_s4[n]},{j2_s4[n]},{j3_s4[n]}" for n in idx]
        plt.legend(frameon=False, ncol=2)
        plt.xticks(l_pos, l_name, fontsize=6, rotation=90)
        plt.xlabel(r"$j_{1},j_{2},j_{3}$", fontsize=9)
        plt.ylabel(r"$S_{4}$", fontsize=9)
        plt.tight_layout()
        plt.show()