sun_diffusion.utils module
Miscellaneous utilities.
- sun_diffusion.utils.grab(x)
If x is a torch.Tensor, detaches it from the computational graph, moves it to CPU memory, and returns it as a NumPy NDArray. Any other input is returned as is.
- Parameters:
x (Tensor) – PyTorch tensor to be detached
- Return type:
NDArray | Any
- Returns:
(NDArray) Detached NumPy array, or the unchanged input if x is not a tensor
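The behaviour described above can be sketched in a few lines. This is a hypothetical re-implementation written from the docstring, not the sun_diffusion source; it treats PyTorch as an optional import so the pass-through branch also runs without it.

```python
import numpy as np

try:
    import torch  # optional; only the tensor branch needs it
except ImportError:
    torch = None


def grab(x):
    """Hypothetical re-implementation of grab(); not the sun_diffusion source."""
    if torch is not None and isinstance(x, torch.Tensor):
        # detach() drops the autograd history; cpu() moves GPU data to host memory
        return x.detach().cpu().numpy()
    # Anything that is not a tensor is returned untouched
    return x


if torch is not None:
    t = torch.ones(3, requires_grad=True) * 2.0  # part of an autograd graph
    arr = grab(t)
    print(isinstance(arr, np.ndarray), arr.tolist())  # True [2.0, 2.0, 2.0]
print(grab([1, 2, 3]))  # [1, 2, 3]
```

Calling .numpy() directly on a tensor that requires grad raises an error, which is why the detach step comes first.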
- sun_diffusion.utils.wrap(theta)
Wraps a non-compact input variable into the compact interval \([-\pi, \pi]\).
- Parameters:
theta (Tensor, NDArray) – Non-compact, real-valued input
- Return type:
Tensor | NDArray
- Returns:
Input folded into \([-\pi, \pi]\)
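One common way to implement this folding is a shifted modulo, which works unchanged on NumPy arrays and PyTorch tensors because % follows the sign of the divisor for both. The sketch below assumes that approach; it is not taken from the package. Note that this particular formula lands in the half-open interval \([-\pi, \pi)\), so \(+\pi\) maps to \(-\pi\); the docstring does not say which endpoint the library keeps.

```python
import math

import numpy as np


def wrap(theta):
    """Sketch: fold real-valued angles into [-pi, pi) via a shifted modulo."""
    # (theta + pi) % (2 pi) lies in [0, 2 pi) for arrays and tensors alike,
    # so subtracting pi gives a value in [-pi, pi)
    return (theta + math.pi) % (2 * math.pi) - math.pi


angles = np.array([0.0, 3 * math.pi / 2, -3 * math.pi / 2, 7.0])
# Each output differs from its input by a multiple of 2 pi
print(wrap(angles))
```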
- sun_diffusion.utils.roll(x, shifts, dims)
Wrapper around the roll function that works with both NumPy arrays and PyTorch tensors.
Rolls x along the dimension or dimensions given in dims by the corresponding amount(s) given in shifts.
- Parameters:
x (Tensor, NDArray) – Array or tensor to roll
shifts (int, tuple) – Shift amount(s), where negative values shift right
dims (int, tuple) – Axes or dims along which to shift
- Return type:
NDArray | Tensor
- Returns:
Rolled array / tensor
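A minimal sketch of such a wrapper, assuming it simply dispatches on the input type to np.roll or torch.roll (both accept an int or a tuple for the shift and the axis). Whether the library passes shifts through unchanged is an assumption here; with NumPy's and PyTorch's own convention, a shift of -1 places x[i+1] at index i.

```python
import numpy as np

try:
    import torch  # optional; only needed for tensor inputs
except ImportError:
    torch = None


def roll(x, shifts, dims):
    """Hypothetical sketch: torch.roll for tensors, np.roll for everything else."""
    if torch is not None and isinstance(x, torch.Tensor):
        return torch.roll(x, shifts, dims)
    # np.roll calls these arguments shift/axis; tuples roll several axes at once
    return np.roll(x, shifts, axis=dims)


x = np.arange(5)
print(roll(x, -1, 0))  # [1 2 3 4 0]: index i now holds x[i+1]
print(roll(x, 1, 0))   # [4 0 1 2 3]: index i now holds x[i-1]
```

Passing tuples, e.g. roll(grid, (1, -1), (0, 1)), shifts a 2D grid along both axes in one call.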
- sun_diffusion.utils.logsumexp(x, axis=0)
Computes \(\log \sum_i \exp(x_i)\) along axis (default 0); accepts either a NumPy array or a PyTorch tensor.
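NumPy itself has no logsumexp, so a bi-compatible version needs its own NumPy branch. The sketch below assumes tensors go to torch.logsumexp and arrays use the usual max-shift trick; it is an illustration, not the package's code.

```python
import numpy as np

try:
    import torch  # optional; only needed for tensor inputs
except ImportError:
    torch = None


def logsumexp(x, axis=0):
    """Hypothetical sketch: log(sum(exp(x))) along axis for arrays or tensors."""
    if torch is not None and isinstance(x, torch.Tensor):
        return torch.logsumexp(x, dim=axis)
    x = np.asarray(x, dtype=float)
    m = np.max(x, axis=axis, keepdims=True)
    # Guard against an all -inf (or inf) slice, where x - m would be nan
    m = np.where(np.isfinite(m), m, 0.0)
    # Subtracting the max keeps exp() from overflowing for large inputs
    out = m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


print(logsumexp(np.array([1000.0, 1000.0])))  # 1000 + log 2, no overflow
```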
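- sun_diffusion.utils.logsumexp_signed(x, signs, axis)
Computes the log-sum-exp of x, accounting for element signs, i.e. \(\log\left|\sum_i s_i e^{x_i}\right|\) together with the sign of the sum, where \(s_i\) are the entries of signs.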
- Parameters:
x (Tensor) – Logarithms of absolute values
signs (Tensor) – Signs for each element of x (+1 / -1)
axis (int) – Axis along which to compute the log-sum-exp
- Returns:
out_logs (Tensor) – Logarithm of the absolute value of the signed sum
out_signs (Tensor) – Sign of the sum (+1 or -1)
- Return type:
tuple[Tensor, Tensor]
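The signed log-sum-exp can be sketched by factoring out the largest exponent, summing the signed terms, and splitting the result into a log-magnitude and a sign. The sketch uses NumPy for portability even though the documented function takes Tensors; a PyTorch version would presumably swap in torch.amax, torch.sign and related calls, which is an assumption about the actual implementation.

```python
import numpy as np


def logsumexp_signed(x, signs, axis):
    """Hypothetical NumPy sketch: (log|sum_i s_i e^{x_i}|, sign of the sum)."""
    x = np.asarray(x, dtype=float)
    signs = np.asarray(signs, dtype=float)
    # Factor out the largest exponent so exp() stays in range
    m = np.max(x, axis=axis, keepdims=True)
    total = np.sum(signs * np.exp(x - m), axis=axis, keepdims=True)
    out_signs = np.sign(total)
    # A sum that cancels exactly gives sign 0 and log|0| = -inf
    with np.errstate(divide="ignore"):
        out_logs = m + np.log(np.abs(total))
    return np.squeeze(out_logs, axis=axis), np.squeeze(out_signs, axis=axis)


# 3 - 1 = 2: expect log 2 with a positive sign
logs, sgn = logsumexp_signed(np.log([3.0, 1.0]), [1.0, -1.0], axis=0)
print(logs, sgn)
```

SciPy's scipy.special.logsumexp offers the same computation through its b and return_sign arguments, if a tested reference is wanted.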
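- sun_diffusion.utils.compute_kl_div(logp, logq)
Computes a Monte Carlo estimate of the reverse KL divergence \(D_{\rm KL}(q\,\|\,p) = \mathbb{E}_{x\sim q}[\log q(x) - \log p(x)]\), assuming the samples x were drawn from the model q.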
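A sketch of this estimator, assuming logp and logq are per-sample log-densities evaluated on the same batch x ~ q; the helper normal_logpdf is hypothetical and only used for the check. For two unit-variance Gaussians the exact reverse KL is \((\mu_q - \mu_p)^2/2\), which gives a simple sanity test. If logp is only known up to an additive constant, the estimate is offset by that constant.

```python
import numpy as np


def compute_kl_div(logp, logq):
    """Hypothetical sketch: Monte Carlo estimate of KL(q || p) from x ~ q."""
    # Each sample contributes log q(x) - log p(x); average over the batch
    return np.mean(logq - logp)


def normal_logpdf(x, mu):
    # Unit-variance Gaussian log-density (illustration only)
    return -0.5 * (x - mu) ** 2 - 0.5 * np.log(2 * np.pi)


rng = np.random.default_rng(0)
x = rng.standard_normal(100_000)  # samples from the model q = N(0, 1)
kl = compute_kl_div(normal_logpdf(x, 1.0), normal_logpdf(x, 0.0))  # target p = N(1, 1)
print(kl)  # close to the exact value 0.5
```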
- sun_diffusion.utils.compute_ess(logp, logq)
Computes the effective sample size (ESS) from target and model log-likelihoods.
The ESS is given by
\[{\rm ESS} = \frac{\mathbb{E}[w_i]^2}{\mathbb{E}[w_i^2]} \in [0,1]\]
where \(w_i = \exp(\log{p}_i - \log{q}_i)\).
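Replacing both expectations in the formula above with batch means gives a simple estimator. The sketch below assumes that is what the function computes and rescales the weights by the largest one before exponentiating; the ratio is unchanged by any common factor, which also means an unnormalized target density is fine.

```python
import numpy as np


def compute_ess(logp, logq):
    """Hypothetical sketch: mean(w)^2 / mean(w^2) with w = exp(logp - logq)."""
    logw = np.asarray(logp, dtype=float) - np.asarray(logq, dtype=float)
    # Dividing every weight by the largest one avoids overflow and cancels in the ratio
    w = np.exp(logw - np.max(logw))
    return np.mean(w) ** 2 / np.mean(w ** 2)


print(compute_ess(np.zeros(4), np.zeros(4)))  # 1.0: model matches target exactly
# One weight dominates the other three, so roughly 1 of 4 samples is "effective"
print(compute_ess(np.array([0.0, -50.0, -50.0, -50.0]), np.zeros(4)))
```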