mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-15 03:15:20 +02:00
support iddpm inference with mask
This commit is contained in:
parent
98e62a7c57
commit
403d21b978
|
|
@ -22,8 +22,8 @@ text_encoder = dict(
|
|||
model_max_length=120,
|
||||
)
|
||||
scheduler = dict(
|
||||
type="iddpm",
|
||||
# type="dpm-solver",
|
||||
# type="iddpm",
|
||||
type="dpm-solver",
|
||||
num_sampling_steps=100,
|
||||
cfg_scale=7.0,
|
||||
)
|
||||
|
|
@ -39,9 +39,9 @@ loop = 10
|
|||
condition_frame_length = 4
|
||||
reference_path = ["assets/images/condition/wave.png"]
|
||||
mask_strategy = ["0,0,0,1,0"] # valid when reference_path is not None
|
||||
# loop id, ref id, ref start, length, target start
|
||||
# (loop id, ref id, ref start, length, target start)
|
||||
|
||||
# Others
|
||||
batch_size = 2
|
||||
seed = 42
|
||||
save_dir = "./samples/"
|
||||
save_dir = "./outputs/samples/"
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ class DMP_SOLVER:
|
|||
prompts,
|
||||
device,
|
||||
additional_args=None,
|
||||
mask=None,
|
||||
):
|
||||
n = len(prompts)
|
||||
model_args = text_encoder.encode(prompts)
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ class IDDPM(SpacedDiffusion):
|
|||
prompts,
|
||||
device,
|
||||
additional_args=None,
|
||||
mask=None,
|
||||
):
|
||||
n = len(prompts)
|
||||
z = torch.cat([z, z], 0)
|
||||
|
|
@ -76,6 +77,7 @@ class IDDPM(SpacedDiffusion):
|
|||
model_kwargs=model_args,
|
||||
progress=True,
|
||||
device=device,
|
||||
mask=mask,
|
||||
)
|
||||
samples, _ = samples.chunk(2, dim=0)
|
||||
return samples
|
||||
|
|
|
|||
|
|
@ -15,15 +15,24 @@ import math
|
|||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from einops import rearrange
|
||||
|
||||
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
|
||||
|
||||
|
||||
def mean_flat(tensor, mask=None):
    """
    Take the mean over all non-batch dimensions.

    Without ``mask`` this is a plain mean over every axis except axis 0.

    With ``mask`` the input must be 5-D, laid out as (b, c, t, h, w), and
    ``mask`` must be (b, t): one weight per frame. Each frame's elements are
    multiplied by that frame's weight, and the sum is divided by
    ``mask.sum(dim=1) * (c * h * w)``, i.e. the mean is taken over the
    selected frames only.

    :param tensor: the values to average; axis 0 is the batch axis.
    :param mask: optional (b, t) per-frame weights (bool or numeric).
    :return: a 1-D tensor with one value per batch element. A batch element
             whose mask row sums to zero yields 0 rather than NaN.
    """
    if mask is None:
        return tensor.mean(dim=list(range(1, len(tensor.shape))))

    assert tensor.dim() == 5, f"masked mean expects a 5-D (b, c, t, h, w) tensor, got {tensor.dim()}-D"
    assert (
        tensor.shape[2] == mask.shape[1]
    ), f"mask has {mask.shape[1]} frames but tensor has {tensor.shape[2]}"

    # Broadcast the per-frame weights over the channel and spatial axes:
    # (b, t) -> (b, 1, t, 1, 1).
    frame_weights = mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
    weighted_sum = (tensor * frame_weights).sum(dim=(1, 2, 3, 4))

    elems_per_frame = tensor.shape[1] * tensor.shape[3] * tensor.shape[4]
    denom = mask.sum(dim=1) * elems_per_frame

    # A row with no selected frames has a zero numerator too; swap the zero
    # denominator for 1 so the result is 0 rather than 0 / 0 = NaN, which
    # would otherwise contaminate any later reduction over the batch.
    denom = denom.masked_fill(denom == 0, 1)
    return weighted_sum / denom
|
||||
|
||||
|
||||
class ModelMeanType(enum.Enum):
|
||||
|
|
@ -368,6 +377,7 @@ class GaussianDiffusion:
|
|||
denoised_fn=None,
|
||||
cond_fn=None,
|
||||
model_kwargs=None,
|
||||
mask=None,
|
||||
):
|
||||
"""
|
||||
Sample x_{t-1} from the model at the given timestep.
|
||||
|
|
@ -398,6 +408,11 @@ class GaussianDiffusion:
|
|||
if cond_fn is not None:
|
||||
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
||||
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
||||
if mask is not None:
|
||||
if mask.shape[0] != x.shape[0]:
|
||||
mask = mask.repeat(2, 1) # HACK
|
||||
sample = th.where(mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1), sample, x)
|
||||
|
||||
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
def p_sample_loop(
|
||||
|
|
@ -411,6 +426,7 @@ class GaussianDiffusion:
|
|||
model_kwargs=None,
|
||||
device=None,
|
||||
progress=False,
|
||||
mask=None,
|
||||
):
|
||||
"""
|
||||
Generate samples from the model.
|
||||
|
|
@ -441,6 +457,7 @@ class GaussianDiffusion:
|
|||
model_kwargs=model_kwargs,
|
||||
device=device,
|
||||
progress=progress,
|
||||
mask=mask,
|
||||
):
|
||||
final = sample
|
||||
return final["sample"]
|
||||
|
|
@ -456,6 +473,7 @@ class GaussianDiffusion:
|
|||
model_kwargs=None,
|
||||
device=None,
|
||||
progress=False,
|
||||
mask=None,
|
||||
):
|
||||
"""
|
||||
Generate samples from the model and yield intermediate samples from
|
||||
|
|
@ -490,6 +508,7 @@ class GaussianDiffusion:
|
|||
denoised_fn=denoised_fn,
|
||||
cond_fn=cond_fn,
|
||||
model_kwargs=model_kwargs,
|
||||
mask=mask,
|
||||
)
|
||||
yield out
|
||||
img = out["sample"]
|
||||
|
|
|
|||
Loading…
Reference in a new issue