# -*- coding: utf-8 -*-
"""
Torch backend for STL.
Provides a minimal API used across the codebase, with optional device
selection (CPU / GPU) via a `device` argument.
Exposed functions
-----------------
- from_numpy(x, device=None, dtype=None)
- zeros(shape, device=None, dtype=torch.float32)
- mean(x, dim)
- dim(x)
- shape(x, axis=None)
- nan
"""
import numpy as np
import torch
# ---------------------------------------------------------------------
# Device handling
# ---------------------------------------------------------------------
# Global default device: GPU if available, else CPU
_DEFAULT_DEVICE = torch.device(
"cuda"
if torch.cuda.is_available()
else (torch.device("mps") if torch.backends.mps.is_available() else "cpu")
)
_DEFAULT_DTYPE = (
torch.float64 if _DEFAULT_DEVICE != torch.device("mps") else torch.float32
)
_DEFAULT_COMPLEX_DTYPE = (
torch.complex128 if _DEFAULT_DTYPE == torch.float64 else torch.complex64
)
def _get_device(device=None) -> torch.device:
"""
Resolve a device spec into a torch.device.
device:
- None -> _DEFAULT_DEVICE
- "gpu" -> "cuda" if available, else ("mps" if available, else "cpu")
- "cuda", "cuda:0", "cpu", ... -> passed to torch.device
- torch.device -> returned as-is
"""
if device is None:
return _DEFAULT_DEVICE
if isinstance(device, torch.device):
return device
if isinstance(device, str):
d = device.lower()
if d in ("gpu", "cuda", "mps"):
return torch.device(
"cuda"
if torch.cuda.is_available()
else (
torch.device("mps") if torch.backends.mps.is_available() else "cpu"
)
)
if d == "cpu":
return torch.device("cpu")
# Let torch.device handle strings like "cuda:0"
return torch.device(device)
# Fallback: let torch.device figure it out / raise
return torch.device(device)
def _get_dtype(dtype=None, device=None) -> torch.dtype:
"""
Adapts input dtype to device.
Indeed, so far, MPS is not compatible with float64.
So if input dtype is float64 and device is MPS, return float32.
:param dtype: input torch dtype
:param device: input torch device
:return: closest torch dtype compatible with device
:rtype: torch dtype
"""
dtype = _DEFAULT_DTYPE if dtype is None else dtype
if device == torch.device("mps"):
if dtype == torch.float64:
print(
"WARNING: torch.float64 not supported on MPS device. Switching to torch.float32."
)
dtype = torch.float32
else:
pass # will probably need to be handled in the future
return dtype
# ---------------------------------------------------------------------
# Core API
# ---------------------------------------------------------------------
def _from_numpy(x: np.ndarray, device=None, dtype=None):
"""
Convert a NumPy array to a torch.Tensor on the requested device.
Parameters
----------
x : np.ndarray
device : None, str, or torch.device
If "gpu"/"cuda"/"mps" -> CUDA (if available), else MPS (if available), otherwise CPU.
If "cpu" -> CPU.
If None -> default device (_DEFAULT_DEVICE).
dtype : torch.dtype or None
If not None, cast to this dtype.
"""
t = torch.from_numpy(x)
dtype = t.dtype if dtype is None else dtype
device = _get_device(
device
) # returns device if device is not None else finds a suitable device
fix_dtype = _get_dtype(
dtype=dtype, device=device
) # matches the input array dtype to the device
return t.to(device=device, dtype=fix_dtype)
[docs]
def to_torch_tensor(array):
"""
Transform input array (NumPy or PyTorch) into a PyTorch tensor.
Parameters
----------
array : np.ndarray or torch.Tensor
Input array to be converted.
Returns
-------
torch.Tensor
Converted PyTorch tensor.
"""
if isinstance(array, np.ndarray):
return _from_numpy(array)
elif isinstance(array, torch.Tensor):
device = _get_device() # Choose device: use GPU if available, otherwise CPU
fix_dtype = _get_dtype(
dtype=array.dtype, device=device
) # matches the input tensor dtype to the device
return array.to(device=device, dtype=fix_dtype)
else:
raise TypeError(f"Unsupported array type: {type(array)}")
[docs]
def zeros(shape, device=None, dtype=None):
"""
Return a tensor of zeros on CPU or GPU depending on `device`.
Parameters
----------
shape : tuple or list
device : None, str, or torch.device
dtype : torch.dtype
"""
device = _get_device(device)
fix_dtype = _get_dtype(dtype=dtype, device=device)
return torch.zeros(shape, dtype=fix_dtype, device=device)
[docs]
def ones(shape, device=None, dtype=None):
"""
Return a tensor of ones on CPU or GPU depending on `device`.
Parameters
----------
shape : tuple or list
device : None, str, or torch.device
dtype : torch.dtype
"""
device = _get_device(device)
fix_dtype = _get_dtype(dtype=dtype, device=device)
return torch.ones(shape, dtype=fix_dtype, device=device)
[docs]
def eye(n, device=None, dtype=None):
"""
Return a tensor of ones on CPU or GPU depending on `device`.
Parameters
----------
shape : tuple or list
device : None, str, or torch.device
dtype : torch.dtype
"""
device = _get_device(device)
fix_dtype = _get_dtype(dtype=dtype, device=device)
return torch.eye(n, dtype=fix_dtype, device=device)
[docs]
def maskmean(x, dim=(-2, -1), mask=None):
"""
Compute the mean of x along given dims, optionally masked.
If mask is given, assumes mask.shape == x.shape[-mask.ndim:],
flattens the masked dimensions and makes sure these were included in the input: dim.
Args:
x: input tensor
mask: boolean tensor, same shape as the last dimensions of x
dim: int or tuple of ints along which to compute the mean.
Returns:
Tensor with the mean
"""
if mask is None:
return x.mean() if dim is None else x.mean(dim=dim)
else:
assert dim == (-2, -1)
assert (
mask.shape[-len(dim) :] == x.shape[-len(dim) :]
) # check mask shape matches x shape on masked dims
# assert x[..., ~mask].isnan().sum() == 0 ############################################### sanity check to remove in the future
x_masked = torch.where(~mask, x, 0.0)
count = (
(~mask).sum() if dim is None else (~mask).sum(dim=dim)
) ########################################## can be improved in the future by normalizing masks once for all beforehand
# assert count != 0, "mask is full of NaNs, cannot compute mean"
return (x_masked.sum() if dim is None else x_masked.sum(dim=dim)) / count
[docs]
def dim(x) -> int:
"""
Return the number of dimensions of a tensor-like object.
"""
if hasattr(x, "dim"):
return x.dim()
return np.array(x).ndim
[docs]
def shape(x, axis=None):
"""
Return the shape of `x`, or the size along a given axis.
Parameters
----------
x : torch.Tensor or np.ndarray
axis : int or None
If None, return the full shape tuple.
Otherwise, return the size along the given axis.
"""
s = x.shape
return s if axis is None else s[axis]
# ---------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------
# Scalar NaN; e.g. bk.zeros(...)+bk.nan -> NaN-filled tensor
nan = torch.nan