mirror of https://github.com/hpcaitech/Open-Sora.git (synced 2026-04-11 05:13:31 +02:00)

[feature] apply padding mask to loss (#7)

commit 603b193d38
parent 6f30a56d52
@@ -13,10 +13,12 @@ import torch as th
 from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl


-def mean_flat(tensor):
+def mean_flat(tensor, mask=None):
     """
     Take the mean over all non-batch dimensions.
     """
+    if mask is not None:
+        tensor = tensor * mask
     return tensor.mean(dim=list(range(1, len(tensor.shape))))


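As a quick illustration of the new `mask` argument, here is a self-contained sketch of the updated `mean_flat` applied to a toy loss tensor; the [B, C, T, H, W] layout and the values are assumptions for the example, not taken from the repo:

```python
import torch

def mean_flat(tensor, mask=None):
    # Zero out masked positions, then average over all non-batch dimensions.
    if mask is not None:
        tensor = tensor * mask
    return tensor.mean(dim=list(range(1, len(tensor.shape))))

# Hypothetical per-element loss for two samples with layout [B, C, T, H, W];
# the second sample's last frame is treated as padding (mask == 0 there).
loss = torch.ones(2, 4, 3, 8, 8)
mask = torch.ones(2, 1, 3, 1, 1)
mask[1, :, -1] = 0.0
print(mean_flat(loss))        # tensor([1., 1.])
print(mean_flat(loss, mask))  # tensor([1.0000, 0.6667]); padded frames contribute 0
```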
@@ -45,7 +47,9 @@ class ModelVarType(enum.Enum):

 class LossType(enum.Enum):
     MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
-    RESCALED_MSE = enum.auto()  # use raw MSE loss (with RESCALED_KL when learning variances)
+    RESCALED_MSE = (
+        enum.auto()
+    )  # use raw MSE loss (with RESCALED_KL when learning variances)
     KL = enum.auto()  # use the variational lower-bound
     RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB

@@ -56,7 +60,9 @@ class LossType(enum.Enum):
 def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
     betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
     warmup_time = int(num_diffusion_timesteps * warmup_frac)
-    betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
+    betas[:warmup_time] = np.linspace(
+        beta_start, beta_end, warmup_time, dtype=np.float64
+    )
     return betas


@@ -76,7 +82,9 @@ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_time
             ** 2
         )
     elif beta_schedule == "linear":
-        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+        betas = np.linspace(
+            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
+        )
     elif beta_schedule == "warmup10":
         betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
     elif beta_schedule == "warmup50":
@@ -84,7 +92,9 @@ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_time
     elif beta_schedule == "const":
         betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
     elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
-        betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
+        betas = 1.0 / np.linspace(
+            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
+        )
     else:
         raise NotImplementedError(beta_schedule)
     assert betas.shape == (num_diffusion_timesteps,)
@@ -173,7 +183,9 @@ class GaussianDiffusion:
         self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

         # calculations for posterior q(x_{t-1} | x_t, x_0)
-        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+        self.posterior_variance = (
+            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+        )
         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
         self.posterior_log_variance_clipped = (
             np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
@@ -181,8 +193,14 @@ class GaussianDiffusion:
             else np.array([])
         )

-        self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
-        self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
+        self.posterior_mean_coef1 = (
+            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+        )
+        self.posterior_mean_coef2 = (
+            (1.0 - self.alphas_cumprod_prev)
+            * np.sqrt(alphas)
+            / (1.0 - self.alphas_cumprod)
+        )

     def q_mean_variance(self, x_start, t):
         """
@@ -191,9 +209,13 @@ class GaussianDiffusion:
         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
         :return: A tuple (mean, variance, log_variance), all of x_start's shape.
         """
-        mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        mean = (
+            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        )
         variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
-        log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        log_variance = _extract_into_tensor(
+            self.log_one_minus_alphas_cumprod, t, x_start.shape
+        )
         return mean, variance, log_variance

     def q_sample(self, x_start, t, noise=None):
@@ -210,7 +232,8 @@ class GaussianDiffusion:
         assert noise.shape == x_start.shape
         return (
             _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
-            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+            * noise
         )

     def q_posterior_mean_variance(self, x_start, x_t, t):
@@ -224,7 +247,9 @@ class GaussianDiffusion:
             + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
         )
         posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+        posterior_log_variance_clipped = _extract_into_tensor(
+            self.posterior_log_variance_clipped, t, x_t.shape
+        )
         assert (
             posterior_mean.shape[0]
             == posterior_variance.shape[0]
@@ -233,7 +258,9 @@ class GaussianDiffusion:
         )
         return posterior_mean, posterior_variance, posterior_log_variance_clipped

-    def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
+    def p_mean_variance(
+        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
+    ):
         """
         Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
         the initial x, x_0.
@@ -267,7 +294,9 @@ class GaussianDiffusion:
         if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
             assert model_output.shape == (B, C * 2, *x.shape[2:])
             model_output, model_var_values = th.split(model_output, C, dim=1)
-            min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
+            min_log = _extract_into_tensor(
+                self.posterior_log_variance_clipped, t, x.shape
+            )
             max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
             # The model_var_values is [-1, 1] for [min_var, max_var].
             frac = (model_var_values + 1) / 2
@@ -299,10 +328,16 @@ class GaussianDiffusion:
         if self.model_mean_type == ModelMeanType.START_X:
             pred_xstart = process_xstart(model_output)
         else:
-            pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
-        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+            pred_xstart = process_xstart(
+                self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+            )
+        model_mean, _, _ = self.q_posterior_mean_variance(
+            x_start=pred_xstart, x_t=x, t=t
+        )

-        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+        assert (
+            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+        )
         return {
             "mean": model_mean,
             "variance": model_variance,
@@ -320,7 +355,8 @@ class GaussianDiffusion:

     def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
         return (
-            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
+            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+            - pred_xstart
         ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

     def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
@@ -332,7 +368,9 @@ class GaussianDiffusion:
         This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
         """
         gradient = cond_fn(x, t, **model_kwargs)
-        new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+        new_mean = (
+            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+        )
         return new_mean

     def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
@@ -350,7 +388,9 @@ class GaussianDiffusion:

         out = p_mean_var.copy()
         out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
-        out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
+        out["mean"], _, _ = self.q_posterior_mean_variance(
+            x_start=out["pred_xstart"], x_t=x, t=t
+        )
         return out

     def p_sample(
@@ -388,9 +428,13 @@ class GaussianDiffusion:
             model_kwargs=model_kwargs,
         )
         noise = th.randn_like(x)
-        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))  # no noise when t == 0
+        nonzero_mask = (
+            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+        )  # no noise when t == 0
         if cond_fn is not None:
-            out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
+            out["mean"] = self.condition_mean(
+                cond_fn, out, x, t, model_kwargs=model_kwargs
+            )
         sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
         return {"sample": sample, "pred_xstart": out["pred_xstart"]}

@@ -520,11 +564,20 @@ class GaussianDiffusion:

         alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
         alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
-        sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+        sigma = (
+            eta
+            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+        )
         # Equation 12.
         noise = th.randn_like(x)
-        mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
-        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))  # no noise when t == 0
+        mean_pred = (
+            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+            + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
+        )
+        nonzero_mask = (
+            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+        )  # no noise when t == 0
         sample = mean_pred + nonzero_mask * sigma * noise
         return {"sample": sample, "pred_xstart": out["pred_xstart"]}

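For context, the DDIM update that this hunk only re-wraps (Eq. 12 in Song et al., "Denoising Diffusion Implicit Models") is, with `alpha_bar` as $\bar\alpha_t$, `alpha_bar_prev` as $\bar\alpha_{t-1}$, `out["pred_xstart"]` as $\hat{x}_0$, and `eps` as $\epsilon_\theta(x_t)$:

```latex
\sigma_t = \eta \sqrt{\tfrac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}}
           \sqrt{1-\tfrac{\bar\alpha_t}{\bar\alpha_{t-1}}},
\qquad
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0
        + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\epsilon_\theta(x_t)
        + \sigma_t z, \quad z \sim \mathcal{N}(0, I).
```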
@@ -556,12 +609,16 @@ class GaussianDiffusion:
         # Usually our model outputs epsilon, but we re-derive it
         # in case we used x_start or x_prev prediction.
         eps = (
-            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]
+            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+            - out["pred_xstart"]
         ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
         alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

         # Equation 12. reversed
-        mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
+        mean_pred = (
+            out["pred_xstart"] * th.sqrt(alpha_bar_next)
+            + th.sqrt(1 - alpha_bar_next) * eps
+        )

         return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}

@@ -647,7 +704,9 @@ class GaussianDiffusion:
                 yield out
                 img = out["sample"]

-    def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
+    def _vb_terms_bpd(
+        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None, mask=None
+    ):
         """
         Get a term for the variational lower-bound.
         The resulting units are bits (rather than nats, as one might expect).
@@ -656,23 +715,38 @@ class GaussianDiffusion:
                  - 'output': a shape [N] tensor of NLLs or KLs.
                  - 'pred_xstart': the x_0 predictions.
         """
-        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
-        out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
-        kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
-        kl = mean_flat(kl) / np.log(2.0)
+        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+            x_start=x_start, x_t=x_t, t=t
+        )
+        out = self.p_mean_variance(
+            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+        )
+        kl = normal_kl(
+            true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
+        )
+        kl = mean_flat(kl, mask=mask) / np.log(2.0)

         decoder_nll = -discretized_gaussian_log_likelihood(
             x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
         )
         assert decoder_nll.shape == x_start.shape
-        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+        decoder_nll = mean_flat(decoder_nll, mask=mask) / np.log(2.0)

         # At the first timestep return the decoder NLL,
         # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
         output = th.where((t == 0), decoder_nll, kl)
         return {"output": output, "pred_xstart": out["pred_xstart"]}

-    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
+    def _expand_mask(self, mask, ndim: int):
+        assert mask.ndim == 2
+        # [B, S] -> [B, 1, S, ...]
+        mask = mask.unsqueeze(1)
+        mask = mask.view(*mask.shape, *([1] * (ndim - mask.ndim)))
+        return mask
+
+    def training_losses(
+        self, model, x_start, t, model_kwargs=None, noise=None, mask=None
+    ):
         """
         Compute training losses for a single timestep.
         :param model: the model to evaluate loss on.
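A short sketch of what the new `_expand_mask` helper does to a frame-level padding mask; the method body is copied from the hunk above, while the [B, C, S, H, W] latent layout and the shapes are assumptions for illustration:

```python
import torch

def _expand_mask(mask, ndim: int):
    assert mask.ndim == 2
    # [B, S] -> [B, 1, S, 1, ..., 1] so it broadcasts against the latent tensor.
    mask = mask.unsqueeze(1)
    mask = mask.view(*mask.shape, *([1] * (ndim - mask.ndim)))
    return mask

# Assumed latent of shape [B, C, S, H, W]; 1 marks real frames, 0 marks padding.
x_start = torch.randn(2, 4, 16, 8, 8)
mask = torch.ones(2, 16)
mask[1, 10:] = 0.0                    # second sample has only 10 real frames
expanded = _expand_mask(mask, x_start.ndim)
print(expanded.shape)                 # torch.Size([2, 1, 16, 1, 1])
print((x_start * expanded).shape)     # torch.Size([2, 4, 16, 8, 8])
```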
@@ -692,6 +766,9 @@ class GaussianDiffusion:

         terms = {}

+        if mask is not None:
+            mask = self._expand_mask(mask, x_start.ndim)
+
         if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
             terms["loss"] = self._vb_terms_bpd(
                 model=model,
@@ -700,6 +777,7 @@ class GaussianDiffusion:
                 t=t,
                 clip_denoised=False,
                 model_kwargs=model_kwargs,
+                mask=mask,
             )["output"]
             if self.loss_type == LossType.RESCALED_KL:
                 terms["loss"] *= self.num_timesteps
@@ -722,6 +800,7 @@ class GaussianDiffusion:
                     x_t=x_t,
                     t=t,
                     clip_denoised=False,
+                    mask=mask,
                 )["output"]
                 if self.loss_type == LossType.RESCALED_MSE:
                     # Divide by 1000 for equivalence with initial implementation.
@@ -729,12 +808,14 @@ class GaussianDiffusion:
                     terms["vb"] *= self.num_timesteps / 1000.0

             target = {
-                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
+                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+                    x_start=x_start, x_t=x_t, t=t
+                )[0],
                 ModelMeanType.START_X: x_start,
                 ModelMeanType.EPSILON: noise,
             }[self.model_mean_type]
             assert model_output.shape == target.shape == x_start.shape
-            terms["mse"] = mean_flat((target - model_output) ** 2)
+            terms["mse"] = mean_flat((target - model_output) ** 2, mask=mask)
             if "vb" in terms:
                 terms["loss"] = terms["mse"] + terms["vb"]
             else:
@@ -755,7 +836,9 @@ class GaussianDiffusion:
         batch_size = x_start.shape[0]
         t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
         qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
-        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+        kl_prior = normal_kl(
+            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+        )
         return mean_flat(kl_prior) / np.log(2.0)

     def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
train.py (5 changed lines)
@@ -135,13 +135,16 @@ def main(args):
         for step, batch in enumerate(dataloader):
             batch = {k: v.to(get_current_device()) for k, v in batch.items()}
             video_inputs = batch.pop("video_latent_states")
+            mask = batch.pop("video_padding_mask")
             t = torch.randint(
                 0,
                 diffusion.num_timesteps,
                 (video_inputs.shape[0],),
                 device=video_inputs.device,
             )
-            loss_dict = diffusion.training_losses(model, video_inputs, t, batch)
+            loss_dict = diffusion.training_losses(
+                model, video_inputs, t, batch, mask=mask
+            )
             loss = loss_dict["loss"].mean()
             booster.backward(loss, opt)
             opt.step()
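The training loop above assumes the dataloader already yields a `video_padding_mask` of shape [B, S] alongside the latents. A hypothetical collate function (illustration only, not the repo's actual dataset code) showing how such a mask could be produced for variable-length clips:

```python
import torch
import torch.nn.functional as F

def collate_videos(latents):
    # Hypothetical collate_fn: latents is a list of [C, T_i, H, W] tensors
    # with varying temporal length T_i. Pad to the longest clip and record
    # which frames are real (1) versus padding (0).
    max_t = max(x.shape[1] for x in latents)
    padded, masks = [], []
    for x in latents:
        pad_t = max_t - x.shape[1]
        padded.append(F.pad(x, (0, 0, 0, 0, 0, pad_t)))  # pad the T dimension
        masks.append(torch.cat([torch.ones(x.shape[1]), torch.zeros(pad_t)]))
    return {
        "video_latent_states": torch.stack(padded),  # [B, C, max_t, H, W]
        "video_padding_mask": torch.stack(masks),    # [B, max_t]
    }
```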