diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py
index cf734a6..c8e0b25 100644
--- a/diffusion/gaussian_diffusion.py
+++ b/diffusion/gaussian_diffusion.py
@@ -13,10 +13,12 @@ import torch as th
 from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
 
 
-def mean_flat(tensor):
+def mean_flat(tensor, mask=None):
     """
     Take the mean over all non-batch dimensions.
     """
+    if mask is not None:
+        tensor = tensor * mask
     return tensor.mean(dim=list(range(1, len(tensor.shape))))
 
 
@@ -45,7 +47,9 @@ class ModelVarType(enum.Enum):
 
 class LossType(enum.Enum):
     MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
-    RESCALED_MSE = enum.auto()  # use raw MSE loss (with RESCALED_KL when learning variances)
+    RESCALED_MSE = (
+        enum.auto()
+    )  # use raw MSE loss (with RESCALED_KL when learning variances)
     KL = enum.auto()  # use the variational lower-bound
     RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB
 
@@ -56,7 +60,9 @@ class LossType(enum.Enum):
 def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
     betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
     warmup_time = int(num_diffusion_timesteps * warmup_frac)
-    betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
+    betas[:warmup_time] = np.linspace(
+        beta_start, beta_end, warmup_time, dtype=np.float64
+    )
     return betas
 
 
@@ -76,7 +82,9 @@ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_time
             ** 2
         )
     elif beta_schedule == "linear":
-        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+        betas = np.linspace(
+            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
+        )
     elif beta_schedule == "warmup10":
         betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
     elif beta_schedule == "warmup50":
@@ -84,7 +92,9 @@ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_time
     elif beta_schedule == "const":
         betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
     elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
-        betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
+        betas = 1.0 / np.linspace(
+            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
+        )
     else:
         raise NotImplementedError(beta_schedule)
     assert betas.shape == (num_diffusion_timesteps,)
@@ -173,7 +183,9 @@ class GaussianDiffusion:
         self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
 
         # calculations for posterior q(x_{t-1} | x_t, x_0)
-        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+        self.posterior_variance = (
+            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+        )
         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
         self.posterior_log_variance_clipped = (
             np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
@@ -181,8 +193,14 @@ class GaussianDiffusion:
             else np.array([])
         )
 
-        self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
-        self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
+        self.posterior_mean_coef1 = (
+            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+        )
+        self.posterior_mean_coef2 = (
+            (1.0 - self.alphas_cumprod_prev)
+            * np.sqrt(alphas)
+            / (1.0 - self.alphas_cumprod)
+        )
 
     def q_mean_variance(self, x_start, t):
         """
@@ -191,9 +209,13 @@ class GaussianDiffusion:
         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
         :return: A tuple (mean, variance, log_variance), all of x_start's shape.
         """
-        mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        mean = (
+            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        )
         variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
-        log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        log_variance = _extract_into_tensor(
+            self.log_one_minus_alphas_cumprod, t, x_start.shape
+        )
         return mean, variance, log_variance
 
     def q_sample(self, x_start, t, noise=None):
@@ -210,7 +232,8 @@ class GaussianDiffusion:
         assert noise.shape == x_start.shape
         return (
             _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
-            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+            * noise
         )
 
     def q_posterior_mean_variance(self, x_start, x_t, t):
@@ -224,7 +247,9 @@ class GaussianDiffusion:
             + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
         )
         posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+        posterior_log_variance_clipped = _extract_into_tensor(
+            self.posterior_log_variance_clipped, t, x_t.shape
+        )
         assert (
             posterior_mean.shape[0]
             == posterior_variance.shape[0]
@@ -233,7 +258,9 @@ class GaussianDiffusion:
         )
         return posterior_mean, posterior_variance, posterior_log_variance_clipped
 
-    def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
+    def p_mean_variance(
+        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
+    ):
         """
         Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
         the initial x, x_0.
@@ -267,7 +294,9 @@ class GaussianDiffusion:
         if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
             assert model_output.shape == (B, C * 2, *x.shape[2:])
             model_output, model_var_values = th.split(model_output, C, dim=1)
-            min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
+            min_log = _extract_into_tensor(
+                self.posterior_log_variance_clipped, t, x.shape
+            )
             max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
             # The model_var_values is [-1, 1] for [min_var, max_var].
             frac = (model_var_values + 1) / 2
@@ -299,10 +328,16 @@ class GaussianDiffusion:
         if self.model_mean_type == ModelMeanType.START_X:
             pred_xstart = process_xstart(model_output)
         else:
-            pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
-        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+            pred_xstart = process_xstart(
+                self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+            )
+        model_mean, _, _ = self.q_posterior_mean_variance(
+            x_start=pred_xstart, x_t=x, t=t
+        )
 
-        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+        assert (
+            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+        )
         return {
             "mean": model_mean,
             "variance": model_variance,
@@ -320,7 +355,8 @@ class GaussianDiffusion:
 
     def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
         return (
-            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
+            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+            - pred_xstart
         ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
 
     def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
@@ -332,7 +368,9 @@ class GaussianDiffusion:
         This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
         """
         gradient = cond_fn(x, t, **model_kwargs)
-        new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+        new_mean = (
+            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+        )
         return new_mean
 
     def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
@@ -350,7 +388,9 @@ class GaussianDiffusion:
 
         out = p_mean_var.copy()
         out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
-        out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
+        out["mean"], _, _ = self.q_posterior_mean_variance(
+            x_start=out["pred_xstart"], x_t=x, t=t
+        )
         return out
 
     def p_sample(
@@ -388,9 +428,13 @@ class GaussianDiffusion:
             model_kwargs=model_kwargs,
         )
         noise = th.randn_like(x)
-        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))  # no noise when t == 0
+        nonzero_mask = (
+            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+        )  # no noise when t == 0
         if cond_fn is not None:
-            out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
+            out["mean"] = self.condition_mean(
+                cond_fn, out, x, t, model_kwargs=model_kwargs
+            )
         sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
         return {"sample": sample, "pred_xstart": out["pred_xstart"]}
 
@@ -520,11 +564,20 @@ class GaussianDiffusion:
 
         alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
         alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
-        sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+        sigma = (
+            eta
+            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+        )
         # Equation 12.
         noise = th.randn_like(x)
-        mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
-        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))  # no noise when t == 0
+        mean_pred = (
+            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+            + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
+        )
+        nonzero_mask = (
+            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+        )  # no noise when t == 0
         sample = mean_pred + nonzero_mask * sigma * noise
         return {"sample": sample, "pred_xstart": out["pred_xstart"]}
 
@@ -556,12 +609,16 @@ class GaussianDiffusion:
         # Usually our model outputs epsilon, but we re-derive it
         # in case we used x_start or x_prev prediction.
         eps = (
-            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]
+            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+            - out["pred_xstart"]
         ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
         alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
 
         # Equation 12. reversed
-        mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
+        mean_pred = (
+            out["pred_xstart"] * th.sqrt(alpha_bar_next)
+            + th.sqrt(1 - alpha_bar_next) * eps
+        )
 
         return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
 
@@ -647,7 +704,9 @@ class GaussianDiffusion:
                 yield out
                 img = out["sample"]
 
-    def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
+    def _vb_terms_bpd(
+        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None, mask=None
+    ):
         """
         Get a term for the variational lower-bound.
         The resulting units are bits (rather than nats, as one might expect).
@@ -656,23 +715,38 @@ class GaussianDiffusion:
                  - 'output': a shape [N] tensor of NLLs or KLs.
                  - 'pred_xstart': the x_0 predictions.
         """
-        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
-        out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
-        kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
-        kl = mean_flat(kl) / np.log(2.0)
+        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+            x_start=x_start, x_t=x_t, t=t
+        )
+        out = self.p_mean_variance(
+            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+        )
+        kl = normal_kl(
+            true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
+        )
+        kl = mean_flat(kl, mask=mask) / np.log(2.0)
         decoder_nll = -discretized_gaussian_log_likelihood(
             x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
         )
         assert decoder_nll.shape == x_start.shape
-        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+        decoder_nll = mean_flat(decoder_nll, mask=mask) / np.log(2.0)
         # At the first timestep return the decoder NLL,
         # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
         output = th.where((t == 0), decoder_nll, kl)
         return {"output": output, "pred_xstart": out["pred_xstart"]}
 
-    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
+    def _expand_mask(self, mask, ndim: int):
+        assert mask.ndim == 2
+        # [B, S] -> [B, 1, S, ...]
+        mask = mask.unsqueeze(1)
+        mask = mask.view(*mask.shape, *([1] * (ndim - mask.ndim)))
+        return mask
+
+    def training_losses(
+        self, model, x_start, t, model_kwargs=None, noise=None, mask=None
+    ):
         """
         Compute training losses for a single timestep.
         :param model: the model to evaluate loss on.
         :param x_start: the [N x C x ...] tensor of inputs.
         :param t: a batch of timestep indices.
@@ -692,6 +766,9 @@ class GaussianDiffusion:
 
         terms = {}
 
+        if mask is not None:
+            mask = self._expand_mask(mask, x_start.ndim)
+
         if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
             terms["loss"] = self._vb_terms_bpd(
                 model=model,
@@ -700,6 +777,7 @@ class GaussianDiffusion:
                 t=t,
                 clip_denoised=False,
                 model_kwargs=model_kwargs,
+                mask=mask,
             )["output"]
             if self.loss_type == LossType.RESCALED_KL:
                 terms["loss"] *= self.num_timesteps
@@ -722,6 +800,7 @@ class GaussianDiffusion:
                 x_t=x_t,
                 t=t,
                 clip_denoised=False,
+                mask=mask,
             )["output"]
             if self.loss_type == LossType.RESCALED_MSE:
                 # Divide by 1000 for equivalence with initial implementation.
@@ -729,12 +808,14 @@ class GaussianDiffusion:
                 terms["vb"] *= self.num_timesteps / 1000.0
 
             target = {
-                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
+                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+                    x_start=x_start, x_t=x_t, t=t
+                )[0],
                 ModelMeanType.START_X: x_start,
                 ModelMeanType.EPSILON: noise,
             }[self.model_mean_type]
             assert model_output.shape == target.shape == x_start.shape
-            terms["mse"] = mean_flat((target - model_output) ** 2)
+            terms["mse"] = mean_flat((target - model_output) ** 2, mask=mask)
             if "vb" in terms:
                 terms["loss"] = terms["mse"] + terms["vb"]
             else:
@@ -755,7 +836,9 @@ class GaussianDiffusion:
         batch_size = x_start.shape[0]
         t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
         qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
-        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+        kl_prior = normal_kl(
+            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+        )
         return mean_flat(kl_prior) / np.log(2.0)
 
     def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
diff --git a/train.py b/train.py
index 9fec3ee..5635c99 100644
--- a/train.py
+++ b/train.py
@@ -135,13 +135,16 @@ def main(args):
     for step, batch in enumerate(dataloader):
         batch = {k: v.to(get_current_device()) for k, v in batch.items()}
         video_inputs = batch.pop("video_latent_states")
+        mask = batch.pop("video_padding_mask")
         t = torch.randint(
             0,
             diffusion.num_timesteps,
             (video_inputs.shape[0],),
             device=video_inputs.device,
         )
-        loss_dict = diffusion.training_losses(model, video_inputs, t, batch)
+        loss_dict = diffusion.training_losses(
+            model, video_inputs, t, batch, mask=mask
+        )
         loss = loss_dict["loss"].mean()
         booster.backward(loss, opt)
         opt.step()
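
For context, here is a minimal, self-contained sketch of the masking mechanics this patch introduces: a per-frame `video_padding_mask` of shape `[B, T]` (as popped in `train.py`) is expanded to `[B, 1, T, 1, 1]` by `GaussianDiffusion._expand_mask` so it broadcasts over video latents of shape `[B, C, T, H, W]`, and the patched `mean_flat` zeroes out padded positions before averaging. The free-standing helper names and toy shapes below are illustrative only and not part of the patch; note also that the masked `mean_flat` still divides by the full element count rather than by `mask.sum()`, so padded frames simply contribute zeros to the per-sample loss.

```python
import torch


def mean_flat(tensor, mask=None):
    # Mirrors the patched mean_flat: zero out masked elements,
    # then average over all non-batch dimensions.
    if mask is not None:
        tensor = tensor * mask
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def expand_mask(mask, ndim):
    # Mirrors GaussianDiffusion._expand_mask: [B, S] -> [B, 1, S, 1, ..., 1]
    assert mask.ndim == 2
    mask = mask.unsqueeze(1)
    return mask.view(*mask.shape, *([1] * (ndim - mask.ndim)))


# Toy video latents: batch of 2 clips, 4 channels, 6 frames, 8x8 latent grid.
x_start = torch.randn(2, 4, 6, 8, 8)
model_output = torch.randn_like(x_start)

# video_padding_mask as popped in train.py: 1 = real frame, 0 = padded frame.
video_padding_mask = torch.ones(2, 6)
video_padding_mask[1, -2:] = 0  # second clip has two padded frames

mask = expand_mask(video_padding_mask, x_start.ndim)        # [2, 1, 6, 1, 1]
mse = mean_flat((x_start - model_output) ** 2, mask=mask)   # shape [2]
print(mse.shape)
```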