support iddpm inference with mask

This commit is contained in:
Zangwei Zheng 2024-03-23 20:46:27 +08:00
parent 98e62a7c57
commit 403d21b978
4 changed files with 28 additions and 6 deletions

View file

@ -22,8 +22,8 @@ text_encoder = dict(
model_max_length=120,
)
scheduler = dict(
type="iddpm",
# type="dpm-solver",
# type="iddpm",
type="dpm-solver",
num_sampling_steps=100,
cfg_scale=7.0,
)
@ -39,9 +39,9 @@ loop = 10
condition_frame_length = 4
reference_path = ["assets/images/condition/wave.png"]
mask_strategy = ["0,0,0,1,0"] # valid when reference_path is not None
# loop id, ref id, ref start, length, target start
# (loop id, ref id, ref start, length, target start)
# Others
batch_size = 2
seed = 42
save_dir = "./samples/"
save_dir = "./outputs/samples/"

View file

@ -21,6 +21,7 @@ class DMP_SOLVER:
prompts,
device,
additional_args=None,
mask=None,
):
n = len(prompts)
model_args = text_encoder.encode(prompts)

View file

@ -58,6 +58,7 @@ class IDDPM(SpacedDiffusion):
prompts,
device,
additional_args=None,
mask=None,
):
n = len(prompts)
z = torch.cat([z, z], 0)
@ -76,6 +77,7 @@ class IDDPM(SpacedDiffusion):
model_kwargs=model_args,
progress=True,
device=device,
mask=mask,
)
samples, _ = samples.chunk(2, dim=0)
return samples

View file

@ -15,15 +15,24 @@ import math
import numpy as np
import torch as th
from einops import rearrange
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
def mean_flat(tensor, mask=None):
    """
    Take the mean over all non-batch dimensions.

    When ``mask`` is given, ``tensor`` must be a 5-D video batch laid out as
    (batch, channel, time, height, width) and ``mask`` a (batch, time)
    indicator/weight tensor; the mean is then taken only over the frames the
    mask selects, so padded frames do not dilute the loss.

    :param tensor: input tensor. Any shape when ``mask`` is None; otherwise
        exactly 5-D with the time axis at dim 2.
    :param mask: optional (batch, time) tensor. NOTE(review): a row summing
        to zero divides by zero — callers are assumed to pass at least one
        active frame per sample; confirm against call sites.
    :return: 1-D tensor of per-sample means, shape (batch,).
    """
    if mask is None:
        return tensor.mean(dim=list(range(1, len(tensor.shape))))
    assert tensor.dim() == 5
    assert tensor.shape[2] == mask.shape[1]
    # Flatten each frame: (b, c, t, h, w) -> (b, t, c*h*w).
    # Native permute+reshape is equivalent to
    # rearrange(tensor, "b c t h w -> b t (c h w)") without needing einops.
    b, c, t, h, w = tensor.shape
    tensor = tensor.permute(0, 2, 1, 3, 4).reshape(b, t, c * h * w)
    # Per-sample element count: active frames * elements per frame.
    denom = mask.sum(dim=1) * tensor.shape[-1]
    # Zero out unselected frames, then average over the surviving elements.
    loss = (tensor * mask.unsqueeze(2)).sum(dim=1).sum(dim=1) / denom
    return loss
class ModelMeanType(enum.Enum):
@ -368,6 +377,7 @@ class GaussianDiffusion:
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
mask=None,
):
"""
Sample x_{t-1} from the model at the given timestep.
@ -398,6 +408,11 @@ class GaussianDiffusion:
if cond_fn is not None:
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
if mask is not None:
if mask.shape[0] != x.shape[0]:
mask = mask.repeat(2, 1) # HACK
sample = th.where(mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1), sample, x)
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def p_sample_loop(
@ -411,6 +426,7 @@ class GaussianDiffusion:
model_kwargs=None,
device=None,
progress=False,
mask=None,
):
"""
Generate samples from the model.
@ -441,6 +457,7 @@ class GaussianDiffusion:
model_kwargs=model_kwargs,
device=device,
progress=progress,
mask=mask,
):
final = sample
return final["sample"]
@ -456,6 +473,7 @@ class GaussianDiffusion:
model_kwargs=None,
device=None,
progress=False,
mask=None,
):
"""
Generate samples from the model and yield intermediate samples from
@ -490,6 +508,7 @@ class GaussianDiffusion:
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
mask=mask,
)
yield out
img = out["sample"]