import math
import os
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace

import torch
from einops import rearrange, repeat
from mmengine.config import Config
from peft import PeftModel
from torch import Tensor, nn

from opensora.datasets.aspect import get_image_size
from opensora.models.mmdit.model import MMDiTModel
from opensora.models.text.conditioner import HFEmbedder
from opensora.registry import MODELS, build_module
from opensora.utils.inference import (
    SamplingMethod,
    collect_references_batch,
    prepare_inference_condition,
)

# ======================================================
# Sampling Options
# ======================================================

@dataclass
class SamplingOption:
    # The width of the image/video.
    width: int | None = None

    # The height of the image/video.
    height: int | None = None

    # The resolution of the image/video. If provided, it overrides the height and width.
    resolution: str | None = None

    # The aspect ratio of the image/video. If provided, it overrides the height and width.
    aspect_ratio: str | None = None

    # The number of frames.
    num_frames: int = 1

    # The number of sampling steps.
    num_steps: int = 50

    # The classifier-free guidance scale (text).
    guidance: float = 4.0

    # Whether to use oscillation for text guidance.
    text_osci: bool = False

    # The classifier-free guidance scale (image), i.e. the guidance scale on the condition for i2v and v2v.
    guidance_img: float | None = None

    # Whether to use oscillation for image guidance.
    image_osci: bool = False

    # Whether to use temporal scaling for image guidance.
    scale_temporal_osci: bool = False

    # The seed for the random number generator.
    seed: int | None = None

    # Whether to shift the schedule.
    shift: bool = True

    # The sampling method.
    method: str | SamplingMethod = SamplingMethod.I2V

    # The temporal reduction factor of the autoencoder.
    temporal_reduction: int = 1

    # Whether the VAE is causal.
    is_causal_vae: bool = False

    # The flow-matching shift parameter; overrides the estimated shift when set.
    flow_shift: float | None = None


def sanitize_sampling_option(sampling_option: SamplingOption) -> SamplingOption:
    """
    Sanitize the sampling options.

    Args:
        sampling_option (SamplingOption): The sampling options.

    Returns:
        SamplingOption: The sanitized sampling options.
    """
    if (
        sampling_option.resolution is not None
        or sampling_option.aspect_ratio is not None
    ):
        assert (
            sampling_option.resolution is not None
            and sampling_option.aspect_ratio is not None
        ), "Both resolution and aspect ratio must be provided"
        resolution = sampling_option.resolution
        aspect_ratio = sampling_option.aspect_ratio
        height, width = get_image_size(resolution, aspect_ratio, training=False)
    else:
        assert (
            sampling_option.height is not None and sampling_option.width is not None
        ), "Both height and width must be provided"
        height, width = sampling_option.height, sampling_option.width

    # round the spatial size up to the next multiple of 16
    height = math.ceil(height / 16) * 16
    width = math.ceil(width / 16) * 16
    replace_dict = dict(height=height, width=width)

    if isinstance(sampling_option.method, str):
        method = SamplingMethod(sampling_option.method)
        replace_dict["method"] = method

    return replace(sampling_option, **replace_dict)
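

# A minimal usage sketch (illustration only, not part of the module): the
# concrete numbers below are hypothetical and simply show the rounding to
# multiples of 16 performed by sanitize_sampling_option.
def _example_sanitize_sampling_option() -> SamplingOption:
    opt = SamplingOption(width=1000, height=562, num_frames=33)
    opt = sanitize_sampling_option(opt)
    assert (opt.width, opt.height) == (1008, 576)  # both rounded up to multiples of 16
    return opt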


def get_oscillation_gs(guidance_scale: float, i: int, force_num: int = 10) -> float:
    """
    Get the oscillating guidance scale for classifier-free guidance.

    Args:
        guidance_scale: The original guidance scale.
        i: The current denoising step.
        force_num: The number of initial steps during which oscillation is not applied.
    """
    if i < force_num or (i >= force_num and i % 2 == 0):
        gs = guidance_scale
    else:
        gs = 1.0
    return gs
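

# A small illustration (not part of the module): after the first `force_num`
# steps, the scale oscillates between the original value and 1.0 on odd steps.
def _example_oscillation_pattern() -> list[float]:
    pattern = [get_oscillation_gs(4.0, i, force_num=3) for i in range(6)]
    assert pattern == [4.0, 4.0, 4.0, 1.0, 4.0, 1.0]
    return pattern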


# ======================================================
# Denoising
# ======================================================


class Denoiser(ABC):
    @abstractmethod
    def denoise(self, model: MMDiTModel, **kwargs) -> Tensor:
        """Denoise the input."""

    @abstractmethod
    def prepare_guidance(
        self,
        text: list[str],
        optional_models: dict[str, nn.Module],
        device: torch.device,
        dtype: torch.dtype,
        **kwargs,
    ) -> tuple[list[str], dict[str, Tensor]]:
        """Prepare the guidance inputs for the model. This method may alter `text`."""


class I2VDenoiser(Denoiser):
    def denoise(self, model: MMDiTModel, **kwargs) -> Tensor:
        img = kwargs.pop("img")
        timesteps = kwargs.pop("timesteps")
        guidance = kwargs.pop("guidance")
        guidance_img = kwargs.pop("guidance_img")

        # conditional reference arguments
        masks = kwargs.pop("masks")
        masked_ref = kwargs.pop("masked_ref")
        kwargs.pop("sigma_min")

        # oscillation guidance
        text_osci = kwargs.pop("text_osci", False)
        image_osci = kwargs.pop("image_osci", False)
        scale_temporal_osci = kwargs.pop("scale_temporal_osci", False)

        # patch size
        patch_size = kwargs.pop("patch_size", 2)

        guidance_vec = torch.full(
            (img.shape[0],), guidance, device=img.device, dtype=img.dtype
        )
        for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
            # timesteps
            t_vec = torch.full(
                (img.shape[0],), t_curr, dtype=img.dtype, device=img.device
            )
            # masked_ref has shape (B, C, T, H, W)
            b, c, t, h, w = masked_ref.size()
            cond = torch.cat((masks, masked_ref), dim=1)
            cond = pack(cond, patch_size=patch_size)
            # three streams: text + image conditioned, image-only, and unconditional
            kwargs["cond"] = torch.cat([cond, cond, torch.zeros_like(cond)], dim=0)

            # forward preparation: replicate the current latent for all three streams
            cond_x = img[: len(img) // 3]
            img = torch.cat([cond_x, cond_x, cond_x], dim=0)

            # forward
            pred = model(
                img=img,
                **kwargs,
                timesteps=t_vec,
                guidance=guidance_vec,
            )

            # prepare guidance scales
            text_gs = get_oscillation_gs(guidance, i) if text_osci else guidance
            image_gs = (
                get_oscillation_gs(guidance_img, i) if image_osci else guidance_img
            )
            cond, uncond, uncond_2 = pred.chunk(3, dim=0)
            if image_gs > 1.0 and scale_temporal_osci:
                # image_gs decreases with each denoising step
                step_upper_image_gs = torch.linspace(image_gs, 1.0, len(timesteps))[i]
                # image_gs increases along the temporal axis of the latent video
                image_gs = torch.linspace(1.0, step_upper_image_gs, t)[
                    None, None, :, None, None
                ].repeat(b, c, 1, h, w)
                image_gs = pack(image_gs, patch_size=patch_size).to(
                    cond.device, cond.dtype
                )

            # combine the three predictions (dual classifier-free guidance)
            pred = uncond_2 + image_gs * (uncond - uncond_2) + text_gs * (cond - uncond)
            pred = torch.cat([pred, pred, pred], dim=0)

            # Euler step of the rectified flow ODE
            img = img + (t_prev - t_curr) * pred

        # keep a single stream of the triplicated batch
        img = img[: len(img) // 3]

        return img

    def prepare_guidance(
        self,
        text: list[str],
        optional_models: dict[str, nn.Module],
        device: torch.device,
        dtype: torch.dtype,
        **kwargs,
    ) -> tuple[list[str], dict[str, Tensor]]:
        ret = {}

        neg = kwargs.get("neg", None)
        ret["guidance_img"] = kwargs.pop("guidance_img")

        # text: append the negative prompts twice, once for each unconditional stream
        if neg is None:
            neg = [""] * len(text)
        text = text + neg + neg
        return text, ret
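

# A toy numeric check (illustration only): the dual-guidance combination used
# in I2VDenoiser.denoise reduces to the text-conditioned prediction when both
# scales are 1.0, and extrapolates away from the unconditional streams otherwise.
def _example_dual_cfg_combination() -> None:
    cond = torch.tensor([2.0])      # text + image conditioned prediction
    uncond = torch.tensor([1.0])    # image-only (text-unconditional) prediction
    uncond_2 = torch.tensor([0.0])  # fully unconditional prediction
    text_gs, image_gs = 7.5, 3.0
    pred = uncond_2 + image_gs * (uncond - uncond_2) + text_gs * (cond - uncond)
    assert torch.allclose(pred, torch.tensor([10.5]))  # 0 + 3*1 + 7.5*1
    # with both scales at 1.0 the combination is just `cond`
    pred = uncond_2 + 1.0 * (uncond - uncond_2) + 1.0 * (cond - uncond)
    assert torch.allclose(pred, cond)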


class DistilledDenoiser(Denoiser):
    def denoise(self, model: MMDiTModel, **kwargs) -> Tensor:
        img = kwargs.pop("img")
        timesteps = kwargs.pop("timesteps")
        guidance = kwargs.pop("guidance")

        guidance_vec = torch.full(
            (img.shape[0],), guidance, device=img.device, dtype=img.dtype
        )
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            # timesteps
            t_vec = torch.full(
                (img.shape[0],), t_curr, dtype=img.dtype, device=img.device
            )
            # forward
            pred = model(
                img=img,
                **kwargs,
                timesteps=t_vec,
                guidance=guidance_vec,
            )
            # Euler step of the rectified flow ODE
            img = img + (t_prev - t_curr) * pred
        return img

    def prepare_guidance(
        self,
        text: list[str],
        optional_models: dict[str, nn.Module],
        device: torch.device,
        dtype: torch.dtype,
        **kwargs,
    ) -> tuple[list[str], dict[str, Tensor]]:
        # the distilled model needs no negative prompts or extra guidance inputs
        return text, {}


SamplingMethodDict = {
    SamplingMethod.I2V: I2VDenoiser(),
    SamplingMethod.DISTILLED: DistilledDenoiser(),
}


# ======================================================
# Timesteps
# ======================================================


def time_shift(alpha: float, t: Tensor) -> Tensor:
    # warp the schedule toward t=1 (more steps at high noise) for alpha > 1
    return alpha * t / (1 + (alpha - 1) * t)


def get_res_lin_function(
    x1: float = 256, y1: float = 1, x2: float = 4096, y2: float = 3
) -> callable:
    # linear interpolation between (x1, y1) and (x2, y2), used to scale the
    # schedule shift with the image sequence length
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b
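

# A quick numeric illustration (not part of the module): with alpha = 2 the
# midpoint t = 0.5 is shifted to 2/3, so more of the schedule is spent at high
# noise levels; alpha = 1 leaves the schedule unchanged.
def _example_time_shift() -> None:
    t = torch.tensor([0.0, 0.5, 1.0])
    assert torch.allclose(time_shift(2.0, t), torch.tensor([0.0, 2.0 / 3.0, 1.0]))
    assert torch.allclose(time_shift(1.0, t), t)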


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    num_frames: int,
    shift_alpha: float | None = None,
    base_shift: float = 1,
    max_shift: float = 3,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shift the schedule to favor high timesteps for higher-resolution inputs
    if shift:
        if shift_alpha is None:
            # estimate the shift via linear interpolation between two reference points
            # spatial scale
            shift_alpha = get_res_lin_function(y1=base_shift, y2=max_shift)(
                image_seq_len
            )
            # temporal scale
            shift_alpha *= math.sqrt(num_frames)
        # calculate shifted timesteps
        timesteps = time_shift(shift_alpha, timesteps)

    return timesteps.tolist()
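

# An illustrative call (hypothetical numbers): a 10-step schedule for a latent
# with 1024 spatial tokens and 8 latent frames. The list runs from 1.0 down to
# 0.0 and, with shifting enabled, is denser near t = 1.
def _example_get_schedule() -> list[float]:
    ts = get_schedule(10, image_seq_len=1024, num_frames=8)
    assert len(ts) == 11 and ts[0] == 1.0 and ts[-1] == 0.0
    return ts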


def get_noise(
    num_samples: int,
    height: int,
    width: int,
    num_frames: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
    patch_size: int = 2,
    channel: int = 16,
) -> Tensor:
    """
    Generate a noise tensor.

    Args:
        num_samples (int): Number of samples.
        height (int): Height of the image/video in pixels.
        width (int): Width of the image/video in pixels.
        num_frames (int): Number of latent frames.
        device (torch.device): Device to put the noise tensor on.
        dtype (torch.dtype): Data type of the noise tensor.
        seed (int): Seed for the random number generator.
        patch_size (int): Patch size used for packing. Defaults to 2.
        channel (int): Number of latent channels. Defaults to 16.

    Returns:
        Tensor: The noise tensor.
    """
    # spatial compression factor of the autoencoder
    D = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
    return torch.randn(
        num_samples,
        channel,
        num_frames,
        # round up so the spatial size is divisible by the patch size for packing
        patch_size * math.ceil(height / D),
        patch_size * math.ceil(width / D),
        device=device,
        dtype=dtype,
        generator=torch.Generator(device=device).manual_seed(seed),
    )
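

# A shape sketch (illustration only, on CPU with hypothetical sizes): for a
# 576x1008 video with 8 latent frames and the default 16x spatial compression,
# the latent grid is 36x63 and packing doubles each spatial side.
def _example_get_noise_shape() -> Tensor:
    z = get_noise(1, 576, 1008, 8, torch.device("cpu"), torch.float32, seed=42)
    assert z.shape == (1, 16, 8, 72, 126)  # (B, C, T, 2*ceil(576/16), 2*ceil(1008/16))
    return z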


def pack(x: Tensor, patch_size: int = 2) -> Tensor:
    # patchify: (B, C, T, H, W) -> (B, T*h*w, C*ph*pw) token sequence
    return rearrange(
        x, "b c t (h ph) (w pw) -> b (t h w) (c ph pw)", ph=patch_size, pw=patch_size
    )


def unpack(
    x: Tensor, height: int, width: int, num_frames: int, patch_size: int = 2
) -> Tensor:
    # inverse of `pack`: token sequence -> (B, C, T, H, W) latent
    D = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
    return rearrange(
        x,
        "b (t h w) (c ph pw) -> b c t (h ph) (w pw)",
        h=math.ceil(height / D),
        w=math.ceil(width / D),
        t=num_frames,
        ph=patch_size,
        pw=patch_size,
    )
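

# A round-trip sketch (illustration only): `unpack` inverts `pack` given the
# original pixel height/width and frame count; sizes here are hypothetical and
# assume the default 16x spatial compression.
def _example_pack_unpack_roundtrip() -> None:
    x = torch.randn(1, 16, 8, 72, 126)  # latent for a 576x1008, 8-frame video
    tokens = pack(x)  # (1, 8*36*63, 16*2*2)
    assert tokens.shape == (1, 18144, 64)
    assert torch.equal(unpack(tokens, 576, 1008, 8), x)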


# ======================================================
# Prepare
# ======================================================


def prepare(
    t5: HFEmbedder,
    clip: HFEmbedder,
    img: Tensor,
    prompt: str | list[str],
    seq_align: int = 1,
    patch_size: int = 2,
) -> dict[str, Tensor]:
    """
    Prepare the input for the model.

    Args:
        t5 (HFEmbedder): The T5 model.
        clip (HFEmbedder): The CLIP model.
        img (Tensor): The image tensor.
        prompt (str | list[str]): The prompt(s).
        seq_align (int): Alignment of the text sequence length. Defaults to 1.
        patch_size (int): Patch size used for packing. Defaults to 2.

    Returns:
        dict[str, Tensor]: The input dictionary.

    img_ids: used later for positional embedding in the T, H, W dimensions
    txt_ids: reserved for positional embedding, but set to 0 for now since our text encoder already encodes positional information
    """
    bs, c, t, h, w = img.shape
    device, dtype = img.device, img.dtype
    if isinstance(prompt, str):
        prompt = [prompt]
    bs = len(prompt)

    img = rearrange(
        img, "b c t (h ph) (w pw) -> b (t h w) (c ph pw)", ph=patch_size, pw=patch_size
    )
    if img.shape[0] != bs:
        img = repeat(img, "b ... -> (repeat b) ...", repeat=bs // img.shape[0])

    # positional ids over (T, H/ph, W/pw)
    img_ids = torch.zeros(t, h // patch_size, w // patch_size, 3)
    img_ids[..., 0] = img_ids[..., 0] + torch.arange(t)[:, None, None]
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // patch_size)[None, :, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // patch_size)[None, None, :]
    img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

    # encode the tokenized prompts
    txt = t5(prompt, added_tokens=img_ids.shape[1], seq_align=seq_align)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return {
        "img": img,
        "img_ids": img_ids.to(device, dtype),
        "txt": txt.to(device, dtype),
        "txt_ids": txt_ids.to(device, dtype),
        "y_vec": vec.to(device, dtype),
    }


def prepare_ids(
    img: Tensor,
    t5_embedding: Tensor,
    clip_embedding: Tensor,
) -> dict[str, Tensor]:
    """
    Prepare the input for the model from precomputed text embeddings.

    Args:
        img (Tensor): The image tensor.
        t5_embedding (Tensor): The T5 embedding.
        clip_embedding (Tensor): The CLIP embedding.

    Returns:
        dict[str, Tensor]: The input dictionary.

    img_ids: used later for positional embedding in the T, H, W dimensions
    txt_ids: reserved for positional embedding, but set to 0 for now since our text encoder already encodes positional information
    """
    bs, c, t, h, w = img.shape
    device, dtype = img.device, img.dtype

    img = rearrange(img, "b c t (h ph) (w pw) -> b (t h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] != bs:
        img = repeat(img, "b ... -> (repeat b) ...", repeat=bs // img.shape[0])

    # positional ids over (T, H/2, W/2)
    img_ids = torch.zeros(t, h // 2, w // 2, 3)
    img_ids[..., 0] = img_ids[..., 0] + torch.arange(t)[:, None, None]
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[None, :, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, None, :]
    img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

    # broadcast the precomputed prompt embeddings across the batch
    if t5_embedding.shape[0] == 1 and bs > 1:
        t5_embedding = repeat(t5_embedding, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, t5_embedding.shape[1], 3)

    if clip_embedding.shape[0] == 1 and bs > 1:
        clip_embedding = repeat(clip_embedding, "1 ... -> bs ...", bs=bs)

    return {
        "img": img,
        "img_ids": img_ids.to(device, dtype),
        "txt": t5_embedding.to(device, dtype),
        "txt_ids": txt_ids.to(device, dtype),
        "y_vec": clip_embedding.to(device, dtype),
    }
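

# A shape sketch for prepare_ids (illustration only; every size below is
# hypothetical): random latents and dummy T5/CLIP embeddings stand in for
# real encoder outputs.
def _example_prepare_ids() -> dict[str, Tensor]:
    img = torch.randn(2, 16, 4, 36, 62)       # (B, C, T, H, W) latent
    t5_embedding = torch.randn(1, 512, 4096)  # broadcast to the batch below
    clip_embedding = torch.randn(1, 768)
    inp = prepare_ids(img, t5_embedding, clip_embedding)
    assert inp["img"].shape == (2, 4 * 18 * 31, 16 * 4)  # packed tokens
    assert inp["img_ids"].shape == (2, 4 * 18 * 31, 3)   # (t, h, w) position ids
    assert inp["txt"].shape == (2, 512, 4096)            # repeated per sample
    return inp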


def prepare_models(
    cfg: Config,
    device: torch.device,
    dtype: torch.dtype,
    offload_model: bool = False,
) -> tuple[nn.Module, nn.Module, nn.Module, nn.Module, dict[str, nn.Module]]:
    """
    Prepare models for inference.

    Args:
        cfg (Config): The configuration object.
        device (torch.device): The device to use.
        dtype (torch.dtype): The data type to use.
        offload_model (bool): Whether to build the diffusion model and autoencoder on CPU when the optional img_flux models are configured. Defaults to False.

    Returns:
        tuple[nn.Module, nn.Module, nn.Module, nn.Module, dict[str, nn.Module]]: The diffusion model, the autoencoder, the T5 model, the CLIP model, and the optional models.
    """
    model_device = (
        "cpu" if offload_model and cfg.get("img_flux", None) is not None else device
    )

    model = build_module(
        cfg.model, MODELS, device_map=model_device, torch_dtype=dtype
    ).eval()
    model_ae = build_module(
        cfg.ae, MODELS, device_map=model_device, torch_dtype=dtype
    ).eval()
    model_t5 = build_module(cfg.t5, MODELS, device_map=device, torch_dtype=dtype).eval()
    model_clip = build_module(
        cfg.clip, MODELS, device_map=device, torch_dtype=dtype
    ).eval()
    if cfg.get("pretrained_lora_path", None) is not None:
        model = PeftModel.from_pretrained(
            model, cfg.pretrained_lora_path, is_trainable=False
        )

    # optional models
    optional_models = {}
    if cfg.get("img_flux", None) is not None:
        model_img_flux = build_module(
            cfg.img_flux, MODELS, device_map=device, torch_dtype=dtype
        ).eval()
        model_ae_img_flux = build_module(
            cfg.img_flux_ae, MODELS, device_map=device, torch_dtype=dtype
        ).eval()
        optional_models["img_flux"] = model_img_flux
        optional_models["img_flux_ae"] = model_ae_img_flux

    return model, model_ae, model_t5, model_clip, optional_models


def prepare_api(
    model: nn.Module,
    model_ae: nn.Module,
    model_t5: nn.Module,
    model_clip: nn.Module,
    optional_models: dict[str, nn.Module],
) -> callable:
    """
    Prepare the API function for inference.

    Args:
        model (nn.Module): The diffusion model.
        model_ae (nn.Module): The autoencoder model.
        model_t5 (nn.Module): The T5 model.
        model_clip (nn.Module): The CLIP model.
        optional_models (dict[str, nn.Module]): The optional models.

    Returns:
        callable: The API function for inference.
    """

    @torch.inference_mode()
    def api_fn(
        opt: SamplingOption,
        cond_type: str = "t2v",
        seed: int | None = None,
        sigma_min: float = 1e-5,
        text: list[str] | None = None,
        neg: list[str] | None = None,
        patch_size: int = 2,
        channel: int = 16,
        **kwargs,
    ):
        """
        The API function for inference.

        Args:
            opt (SamplingOption): The sampling options.
            text (list[str], optional): The text prompts. Defaults to None.
            neg (list[str], optional): The negative text prompts. Defaults to None.

        Returns:
            torch.Tensor: The generated images.
        """
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype

        # a seed passed directly overrides opt.seed
        if seed is None:
            # fall back to opt.seed, then to a random seed
            seed = opt.seed if opt.seed is not None else random.randint(0, 2**32 - 1)
        if opt.is_causal_vae:
            num_frames = (
                1
                if opt.num_frames == 1
                else (opt.num_frames - 1) // opt.temporal_reduction + 1
            )
        else:
            num_frames = (
                1 if opt.num_frames == 1 else opt.num_frames // opt.temporal_reduction
            )

        z = get_noise(
            len(text),
            opt.height,
            opt.width,
            num_frames,
            device,
            dtype,
            seed,
            patch_size=patch_size,
            channel=channel // (patch_size**2),
        )
        denoiser = SamplingMethodDict[opt.method]

        # i2v reference conditions
        references = [None] * len(text)
        if cond_type != "t2v" and "ref" in kwargs:
            reference_path_list = kwargs.pop("ref")
            references = collect_references_batch(
                reference_path_list,
                cond_type,
                model_ae,
                (opt.height, opt.width),
                is_causal=opt.is_causal_vae,
            )
        elif cond_type != "t2v":
            print(
                "Your csv file doesn't have a ref column or it is not processed properly. Defaulting to cond_type t2v!"
            )
            cond_type = "t2v"

        # build the (possibly shifted) timestep schedule
        timesteps = get_schedule(
            opt.num_steps,
            (z.shape[-1] * z.shape[-2]) // patch_size**2,
            num_frames,
            shift=opt.shift,
            shift_alpha=opt.flow_shift,
        )

        # prepare classifier-free guidance data (method specific)
        text, additional_inp = denoiser.prepare_guidance(
            text=text,
            optional_models=optional_models,
            device=device,
            dtype=dtype,
            neg=neg,
            guidance_img=opt.guidance_img,
        )

        inp = prepare(model_t5, model_clip, z, prompt=text, patch_size=patch_size)
        inp.update(additional_inp)

        if opt.method in [SamplingMethod.I2V]:
            # prepare masked references for image/video conditioning
            masks, masked_ref = prepare_inference_condition(
                z, cond_type, ref_list=references, causal=opt.is_causal_vae
            )
            inp["masks"] = masks
            inp["masked_ref"] = masked_ref
            inp["sigma_min"] = sigma_min

        x = denoiser.denoise(
            model,
            **inp,
            timesteps=timesteps,
            guidance=opt.guidance,
            text_osci=opt.text_osci,
            image_osci=opt.image_osci,
            scale_temporal_osci=(
                opt.scale_temporal_osci and "i2v" in cond_type
            ),  # don't use temporal oscillation for v2v or t2v
            flow_shift=opt.flow_shift,
            patch_size=patch_size,
        )

        x = unpack(x, opt.height, opt.width, num_frames, patch_size=patch_size)

        # replace latent frames with the encoded reference for image conditioning
        if cond_type == "i2v_head":
            x[0, :, :1] = references[0][0]
        elif cond_type == "i2v_tail":
            x[0, :, -1:] = references[0][0]
        elif cond_type == "i2v_loop":
            x[0, :, :1] = references[0][0]
            x[0, :, -1:] = references[0][1]

        x = model_ae.decode(x)
        x = x[:, :, : opt.num_frames]  # trim any extra decoded frames

        # remove the duplicated padding frames of a non-causal VAE
        if not opt.is_causal_vae:
            if cond_type == "i2v_head":
                pad_len = model_ae.compression[0] - 1
                x = x[:, :, pad_len:]
            elif cond_type == "i2v_tail":
                pad_len = model_ae.compression[0] - 1
                x = x[:, :, :-pad_len]
            elif cond_type == "i2v_loop":
                pad_len = model_ae.compression[0] - 1
                x = x[:, :, pad_len:-pad_len]

        return x

    return api_fn
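

# An end-to-end usage sketch (illustration only; the config path, resolution,
# aspect ratio, and prompt are hypothetical, and this function is never called
# at import time):
def _example_inference() -> Tensor:
    cfg = Config.fromfile("configs/diffusion/inference/t2v.py")  # hypothetical path
    device, dtype = torch.device("cuda"), torch.bfloat16
    model, model_ae, model_t5, model_clip, optional_models = prepare_models(
        cfg, device, dtype
    )
    api_fn = prepare_api(model, model_ae, model_t5, model_clip, optional_models)
    opt = sanitize_sampling_option(
        SamplingOption(resolution="256px", aspect_ratio="16:9", num_frames=33)
    )
    return api_fn(opt, cond_type="t2v", text=["a drone shot over a coastline"])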