mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-14 18:25:35 +02:00
90 lines
3.4 KiB
Python
90 lines
3.4 KiB
Python
# Adapted from DiT
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
# --------------------------------------------------------
|
|
# References:
|
|
# DiT: https://github.com/facebookresearch/DiT/tree/main
|
|
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
|
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
|
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
|
# --------------------------------------------------------
|
|
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Element-wise KL divergence KL(N1 || N2) between two diagonal Gaussians.

    Each Gaussian is described by its mean and log-variance. Arguments may be
    Tensors or plain Python scalars and are broadcast against each other, but
    at least one of them has to be a Tensor.

    :param mean1: mean of the first Gaussian.
    :param logvar1: log-variance of the first Gaussian.
    :param mean2: mean of the second Gaussian.
    :param logvar2: log-variance of the second Gaussian.
    :return: a Tensor holding the KL divergence for every broadcast element.
    """
    # The first Tensor argument serves as the dtype/device template for any
    # scalar log-variances converted below.
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, torch.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    # torch.exp() only accepts Tensors, so scalar log-variances cannot rely on
    # broadcasting and must be converted explicitly.
    if not isinstance(logvar1, torch.Tensor):
        logvar1 = torch.tensor(logvar1).to(reference)
    if not isinstance(logvar2, torch.Tensor):
        logvar2 = torch.tensor(logvar2).to(reference)

    variance_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + scaled_sq_diff)
|
|
|
|
|
|
def approx_standard_normal_cdf(x):
    """
    Cheap tanh-based approximation of the standard normal CDF.

    Evaluates 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    element-wise.

    :param x: Tensor of points at which to evaluate the CDF.
    :return: a Tensor with the approximate CDF value for every element of x.
    """
    cubic_term = 0.044715 * torch.pow(x, 3)
    tanh_arg = np.sqrt(2.0 / torch.pi) * (x + cubic_term)
    return 0.5 * (1.0 + torch.tanh(tanh_arg))
|
|
|
|
|
|
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Score targets under a continuous Gaussian, in nats.

    The targets are standardised as (x - means) * exp(-log_scales) and the
    result is the standard normal log-density of those standardised values.

    NOTE(review): the log-density is taken under a unit Gaussian and
    ``log_scales`` is not subtracted afterwards, so the returned value differs
    from the full N(means, exp(log_scales)) log-density by ``log_scales`` --
    confirm this is intended before relying on absolute values.

    :param x: the targets.
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    standardized = (x - means) * torch.exp(-log_scales)
    unit_gaussian = torch.distributions.Normal(torch.zeros_like(x), torch.ones_like(x))
    return unit_gaussian.log_prob(standardized)
|
|
|
|
|
|
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Log-likelihood of a Gaussian discretized onto the pixel grid of an image.

    Each target is assigned the Gaussian mass falling inside the bin
    [x - 1/255, x + 1/255]. Targets below -0.999 receive all mass under the
    upper bin edge, and targets above 0.999 receive all mass over the lower
    bin edge. Every probability is clamped to at least 1e-12 before the log.

    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape

    residual = x - means
    inv_std = torch.exp(-log_scales)
    half_bin = 1.0 / 255.0

    # Approximate CDF at the upper and lower edge of each target's bin.
    upper_cdf = approx_standard_normal_cdf(inv_std * (residual + half_bin))
    lower_cdf = approx_standard_normal_cdf(inv_std * (residual - half_bin))

    # Clamping keeps torch.log away from zero or negative arguments.
    log_mass_below_upper = torch.log(upper_cdf.clamp(min=1e-12))
    log_mass_above_lower = torch.log((1.0 - lower_cdf).clamp(min=1e-12))
    log_mass_in_bin = torch.log((upper_cdf - lower_cdf).clamp(min=1e-12))

    # Pick the open-ended tail at either extreme, the bin mass everywhere else.
    high_or_interior = torch.where(x > 0.999, log_mass_above_lower, log_mass_in_bin)
    log_probs = torch.where(x < -0.999, log_mass_below_upper, high_or_interior)

    assert log_probs.shape == x.shape
    return log_probs
|