sun_diffusion.utils module

Miscellaneous utilities.

sun_diffusion.utils.grab(x)

If x is a torch.Tensor, detaches it from the computational graph, moves it to CPU memory, and returns it as a NumPy NDArray. Any other input is returned unchanged.

Parameters:

x (Tensor) – PyTorch tensor to be detached

Return type:

NDArray | Any

Returns:

(NDArray) Detached NumPy array, or x unchanged if it is not a tensor
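The behavior described above can be sketched as follows. This is a minimal illustration, not the package's source; it assumes PyTorch is optional, so a missing torch install simply makes grab a pass-through.

```python
import numpy as np

try:
    import torch
except ImportError:  # without torch, grab() passes every input through unchanged
    torch = None


def grab(x):
    """Hypothetical sketch: tensor -> detached CPU NumPy array; anything else as is."""
    if torch is not None and isinstance(x, torch.Tensor):
        # detach() drops the autograd history, cpu() moves the data to host
        # memory if needed, and numpy() exposes that host buffer as an ndarray
        return x.detach().cpu().numpy()
    return x
```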

sun_diffusion.utils.wrap(theta)

Wraps a non-compact input variable into the compact interval \([-\pi, \pi]\).

Parameters:

theta (Tensor, NDArray) – Non-compact, real-valued input

Return type:

Tensor | NDArray

Returns:

Input folded into \([-\pi, \pi]\)
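A common way to perform this fold is to shift by \(\pi\), reduce modulo \(2\pi\), and shift back. The sketch below assumes that approach; it is not taken from the package. It lands in the half-open interval \([-\pi, \pi)\), so an input of exactly \(\pi\) maps to \(-\pi\); how the library treats that endpoint is not stated here.

```python
import math

import numpy as np


def wrap(theta):
    """Hypothetical sketch: fold a real angle into [-pi, pi)."""
    # Shift by pi, reduce modulo 2*pi into [0, 2*pi), then shift back by -pi.
    # Only arithmetic operators are used, so floats, NumPy arrays and torch
    # tensors all work (torch's % follows the same sign convention as Python's).
    return (theta + math.pi) % (2 * math.pi) - math.pi


angles = wrap(np.array([0.0, 2 * math.pi + 0.1, -7.0]))
```

Here angles is approximately [0.0, 0.1, 2π - 7].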

sun_diffusion.utils.roll(x, shifts, dims)

Bi-compatible wrapper for the roll function in NumPy and PyTorch.

Rolls a NumPy array or PyTorch tensor along the dimension(s) given in dims by the corresponding amount(s) given in shifts.

Parameters:
  • x (Tensor, NDArray) – Array or tensor to roll

  • shifts (int, tuple) – Shift amount(s); as in np.roll and torch.roll, positive values move entries toward higher indices

  • dims (int, tuple) – Axis or axes along which to shift

Return type:

NDArray | Tensor

Returns:

Rolled array / tensor
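A minimal sketch of such a wrapper, assuming it simply dispatches on the input type (this is an illustration, not the package's code). np.roll names its arguments shift and axis, while torch.roll uses shifts and dims; with this pass-through, a positive shift moves entries toward higher indices.

```python
import numpy as np

try:
    import torch
except ImportError:  # NumPy-only environments
    torch = None


def roll(x, shifts, dims):
    """Hypothetical sketch: dispatch to torch.roll or np.roll by input type."""
    if torch is not None and isinstance(x, torch.Tensor):
        # torch.roll accepts ints or tuples for shifts / dims
        return torch.roll(x, shifts=shifts, dims=dims)
    # np.roll calls the same arguments shift / axis
    return np.roll(x, shift=shifts, axis=dims)
```

For example, roll(np.arange(4), 1, 0) gives [3, 0, 1, 2], and roll(np.arange(4), -1, 0) gives [1, 2, 3, 0].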

sun_diffusion.utils.logsumexp(x, axis=0)

Bi-compatible log-sum-exp for NumPy arrays and PyTorch tensors: computes \(\log\sum_i \exp(x_i)\) along axis.
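One way to provide this, sketched under the assumption that tensors go to torch.logsumexp and arrays use the usual max-subtraction trick (NumPy has no logsumexp of its own; scipy.special.logsumexp is the SciPy equivalent). This is illustrative, not the package's implementation.

```python
import numpy as np

try:
    import torch
except ImportError:  # allow NumPy-only use
    torch = None


def logsumexp(x, axis=0):
    """Hypothetical sketch: log(sum(exp(x))) along `axis` for arrays or tensors."""
    if torch is not None and isinstance(x, torch.Tensor):
        return torch.logsumexp(x, dim=axis)
    x = np.asarray(x)
    # Subtract the max so exp() cannot overflow; guard slices whose max is +/-inf
    m = np.max(x, axis=axis, keepdims=True)
    m = np.where(np.isfinite(m), m, 0.0)
    out = m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)
```

The max subtraction is what keeps inputs such as [1000.0, 1000.0] finite: the result is 1000 + log 2 rather than an overflow.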

sun_diffusion.utils.logsumexp_signed(x, signs, axis)

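Computes the log-sum-exp of x, accounting for element signs: returns \(\log\left|\sum_i s_i e^{x_i}\right|\) along axis, together with the sign of \(\sum_i s_i e^{x_i}\), where \(s_i\) are the entries of signs.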

Parameters:
  • x (Tensor) – Logarithms of absolute values

  • signs (Tensor) – Signs for each element of x (+1 / -1)

  • axis (int) – Axis along which to compute the log-sum-exp

Returns:

  • out_logs (Tensor) – Logarithm of the absolute value of the signed sum

  • out_signs (Tensor) – Sign of the sum (+1 or -1)

Return type:

tuple[Tensor, Tensor]
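The signed reduction can be sketched in NumPy as below. This is a hypothetical transcription of the computation (the documented function takes torch Tensors), assuming the same max-subtraction trick as an unsigned log-sum-exp. SciPy offers the same quantity via scipy.special.logsumexp(x, b=signs, return_sign=True).

```python
import numpy as np


def logsumexp_signed(x, signs, axis):
    """Hypothetical NumPy sketch: log|sum_i s_i exp(x_i)| and its sign."""
    x = np.asarray(x, dtype=float)
    # Factor out the largest log-magnitude so exp() stays in range
    m = np.max(x, axis=axis, keepdims=True)
    m = np.where(np.isfinite(m), m, 0.0)
    total = np.sum(np.asarray(signs) * np.exp(x - m), axis=axis, keepdims=True)
    # An exact cancellation (total == 0) gives log 0 = -inf and sign 0
    out_logs = np.squeeze(m + np.log(np.abs(total)), axis=axis)
    out_signs = np.squeeze(np.sign(total), axis=axis)
    return out_logs, out_signs
```

For instance, x = log([3, 1]) with signs [+1, -1] represents 3 - 1 = 2, giving (log 2, +1); swapping the magnitudes gives 1 - 3 = -2, i.e. (log 2, -1).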

sun_diffusion.utils.compute_kl_div(logp, logq)

Estimates the reverse KL divergence \(D_{\rm KL}(q\,\|\,p) = \mathbb{E}_{x\sim q}[\log q(x) - \log p(x)]\) from log-likelihoods evaluated on model samples x~q.
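A minimal Monte Carlo sketch of this estimator, assuming logp and logq are arrays evaluated on the same batch of model samples (illustrative, not the package's code). If logp is only known up to a normalizing constant, \(\log\tilde p = \log p + \log Z\), the estimate is shifted by \(-\log Z\).

```python
import numpy as np


def compute_kl_div(logp, logq):
    """Hypothetical sketch: Monte Carlo estimate of the reverse KL divergence."""
    # Both arrays are assumed to be evaluated on the same samples x_i ~ q, so
    # KL(q || p) = E_q[log q - log p] is estimated by the sample mean.
    return float(np.mean(np.asarray(logq) - np.asarray(logp)))
```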

sun_diffusion.utils.compute_ess(logp, logq)

Computes the normalized effective sample size (ESS) from target and model log-likelihoods (logp, logq) evaluated on model samples.

The ESS is given by

\[{\rm ESS} = \frac{\mathbb{E}[w_i]^2}{\mathbb{E}[w_i^2]} \in [0,1]\]

where \(w_i = \exp(\log{p}_i - \log{q}_i)\).
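Over \(N\) samples the ratio simplifies to \(S_1^2 / (N S_2)\) with \(S_1 = \sum_i w_i\) and \(S_2 = \sum_i w_i^2\), which can be evaluated entirely in log space to avoid overflow. The sketch below assumes 1-D inputs (it flattens anything else) and is an illustration, not the package's implementation.

```python
import numpy as np


def compute_ess(logp, logq):
    """Hypothetical sketch of the normalized ESS, computed in log space."""
    # log importance weights: log w_i = log p_i - log q_i
    logw = (np.asarray(logp) - np.asarray(logq)).ravel()
    n = logw.size
    # log(sum_i w_i) and log(sum_i w_i^2) via a stable log-add-exp reduction
    log_s1 = np.logaddexp.reduce(logw)
    log_s2 = np.logaddexp.reduce(2.0 * logw)
    # E[w]^2 / E[w^2] = (S1/N)^2 / (S2/N) = S1^2 / (N * S2)
    return float(np.exp(2.0 * log_s1 - log_s2 - np.log(n)))
```

Equal weights give an ESS of 1, while a single dominant weight among \(N\) samples gives about \(1/N\). Because the ratio is unchanged when every \(w_i\) is scaled by the same constant, logp may be unnormalized.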