From 403d21b978097944b25ce375073a6b1c1a574a8f Mon Sep 17 00:00:00 2001
From: Zangwei Zheng
Date: Sat, 23 Mar 2024 20:46:27 +0800
Subject: [PATCH] support iddpm inference with mask

---
Review notes (commentary placed after the "---" separator; not part of the
commit message or of the diff):
* Line structure was restored from a whitespace-collapsed copy. Blank context
  rows were placed so that every hunk matches the line counts in its "@@"
  header; indentation follows conventional 4-space Python nesting and should
  be verified against the target tree (or applied with whitespace-tolerant
  options).
* DMP_SOLVER.sample() gains a `mask` parameter, but this patch adds no code
  that reads it, while the example config switches the scheduler to
  "dpm-solver" and still sets mask_strategy. Presumably the mask has no effect
  on that path -- confirm.
* mean_flat(tensor, mask) divides by mask.sum(dim=1) * (c*h*w); a batch row
  whose mask sums to zero yields a 0/0 division. It also validates shapes with
  `assert`, which is skipped under `python -O`.
* p_sample() duplicates the mask with mask.repeat(2, 1) when its batch size
  differs from x (marked HACK). Presumably this mirrors the
  torch.cat([z, z], 0) doubling in IDDPM.sample -- confirm; any other batch
  mismatch would fail later inside th.where.

 .../inference_long/16x256x256-long.py         |  8 +++----
 opensora/schedulers/dpms/__init__.py          |  1 +
 opensora/schedulers/iddpm/__init__.py         |  2 ++
 .../schedulers/iddpm/gaussian_diffusion.py    | 23 +++++++++++++++++--
 4 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/configs/opensora/inference_long/16x256x256-long.py b/configs/opensora/inference_long/16x256x256-long.py
index 42967e7..de3c771 100644
--- a/configs/opensora/inference_long/16x256x256-long.py
+++ b/configs/opensora/inference_long/16x256x256-long.py
@@ -22,8 +22,8 @@ text_encoder = dict(
     model_max_length=120,
 )
 scheduler = dict(
-    type="iddpm",
-    # type="dpm-solver",
+    # type="iddpm",
+    type="dpm-solver",
     num_sampling_steps=100,
     cfg_scale=7.0,
 )
@@ -39,9 +39,9 @@ loop = 10
 condition_frame_length = 4
 reference_path = ["assets/images/condition/wave.png"]
 mask_strategy = ["0,0,0,1,0"]  # valid when reference_path is not None
-# loop id, ref id, ref start, length, target start
+# (loop id, ref id, ref start, length, target start)
 
 # Others
 batch_size = 2
 seed = 42
-save_dir = "./samples/"
+save_dir = "./outputs/samples/"
diff --git a/opensora/schedulers/dpms/__init__.py b/opensora/schedulers/dpms/__init__.py
index 084224b..6a1502c 100644
--- a/opensora/schedulers/dpms/__init__.py
+++ b/opensora/schedulers/dpms/__init__.py
@@ -21,6 +21,7 @@ class DMP_SOLVER:
         prompts,
         device,
         additional_args=None,
+        mask=None,
     ):
         n = len(prompts)
         model_args = text_encoder.encode(prompts)
diff --git a/opensora/schedulers/iddpm/__init__.py b/opensora/schedulers/iddpm/__init__.py
index 14d6027..f1fd49e 100644
--- a/opensora/schedulers/iddpm/__init__.py
+++ b/opensora/schedulers/iddpm/__init__.py
@@ -58,6 +58,7 @@ class IDDPM(SpacedDiffusion):
         prompts,
         device,
         additional_args=None,
+        mask=None,
     ):
         n = len(prompts)
         z = torch.cat([z, z], 0)
@@ -76,6 +77,7 @@ class IDDPM(SpacedDiffusion):
             model_kwargs=model_args,
             progress=True,
             device=device,
+            mask=mask,
         )
         samples, _ = samples.chunk(2, dim=0)
         return samples
diff --git a/opensora/schedulers/iddpm/gaussian_diffusion.py b/opensora/schedulers/iddpm/gaussian_diffusion.py
index 4a74592..173b851 100644
--- a/opensora/schedulers/iddpm/gaussian_diffusion.py
+++ b/opensora/schedulers/iddpm/gaussian_diffusion.py
@@ -15,15 +15,24 @@ import math
 
 import numpy as np
 import torch as th
+from einops import rearrange
 
 from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
 
 
-def mean_flat(tensor):
+def mean_flat(tensor, mask=None):
     """
     Take the mean over all non-batch dimensions.
     """
-    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+    if mask is None:
+        return tensor.mean(dim=list(range(1, len(tensor.shape))))
+    else:
+        assert tensor.dim() == 5
+        assert tensor.shape[2] == mask.shape[1]
+        tensor = rearrange(tensor, "b c t h w -> b t (c h w)")
+        denom = mask.sum(dim=1) * tensor.shape[-1]
+        loss = (tensor * mask.unsqueeze(2)).sum(dim=1).sum(dim=1) / denom
+        return loss
 
 
 class ModelMeanType(enum.Enum):
@@ -368,6 +377,7 @@ class GaussianDiffusion:
         denoised_fn=None,
         cond_fn=None,
         model_kwargs=None,
+        mask=None,
     ):
         """
         Sample x_{t-1} from the model at the given timestep.
@@ -398,6 +408,11 @@ class GaussianDiffusion:
         if cond_fn is not None:
             out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
         sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+        if mask is not None:
+            if mask.shape[0] != x.shape[0]:
+                mask = mask.repeat(2, 1)  # HACK
+            sample = th.where(mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1), sample, x)
+
         return {"sample": sample, "pred_xstart": out["pred_xstart"]}
 
     def p_sample_loop(
@@ -411,6 +426,7 @@ class GaussianDiffusion:
         model_kwargs=None,
         device=None,
         progress=False,
+        mask=None,
     ):
         """
         Generate samples from the model.
@@ -441,6 +457,7 @@ class GaussianDiffusion:
             model_kwargs=model_kwargs,
             device=device,
             progress=progress,
+            mask=mask,
         ):
             final = sample
         return final["sample"]
@@ -456,6 +473,7 @@ class GaussianDiffusion:
         model_kwargs=None,
         device=None,
         progress=False,
+        mask=None,
     ):
         """
         Generate samples from the model and yield intermediate samples from
@@ -490,6 +508,7 @@ class GaussianDiffusion:
                     denoised_fn=denoised_fn,
                     cond_fn=cond_fn,
                     model_kwargs=model_kwargs,
+                    mask=mask,
                 )
                 yield out
                 img = out["sample"]