Open-Sora/opensora/utils/sampling.py
2025-03-19 17:33:32 +08:00

1092 lines
38 KiB
Python

import math
import os
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
import json
import subprocess
from collections import defaultdict
import sys
import torch
import torchvision
from einops import rearrange, repeat
from mmengine.config import Config
from peft import PeftModel
from torch import Tensor, nn
import torch.distributed as dist
from opensora.datasets.aspect import get_image_size
from opensora.models.mmdit.model import MMDiTModel
from opensora.models.text.conditioner import HFEmbedder
from opensora.registry import MODELS, build_module
from opensora.utils.inference import (
SamplingMethod,
collect_references_batch,
prepare_inference_condition,
)
# ======================================================
# Sampling Options
# ======================================================
@dataclass
class SamplingOption:
    """Options controlling diffusion sampling.

    When both ``resolution`` and ``aspect_ratio`` are set they override
    ``height``/``width`` (see ``sanitize_sampling_option``).
    """

    # The width of the image/video.
    width: int | None = None
    # The height of the image/video.
    height: int | None = None
    # The resolution of the image/video. If provided, it will override the height and width.
    resolution: str | None = None
    # The aspect ratio of the image/video. If provided, it will override the height and width.
    aspect_ratio: str | None = None
    # The number of frames.
    num_frames: int = 1
    # The number of sampling steps.
    num_steps: int = 50
    # The classifier-free guidance (text).
    guidance: float = 4.0
    # Use oscillation for text guidance.
    text_osci: bool = False
    # The classifier-free guidance (image), or for the guidance on condition for i2v and v2v.
    guidance_img: float | None = None
    # Use oscillation for image guidance.
    image_osci: bool = False
    # Use temporal scaling for image guidance.
    scale_temporal_osci: bool = False
    # The seed for the random number generator.
    seed: int | None = None
    # Whether to shift the schedule.
    shift: bool = True
    # The sampling method.
    method: str | SamplingMethod = SamplingMethod.I2V
    # Temporal reduction factor of the autoencoder.
    temporal_reduction: int = 1
    # Whether the autoencoder is a causal VAE.
    is_causal_vae: bool = False
    # Flow shift (overrides the estimated schedule shift when set).
    flow_shift: float | None = None
    # Whether to perform inference-time scaling (subtree search).
    do_inference_scaling: bool | None = False
    # Number of candidate subtrees per scaling step.
    num_subtree: int = 3
    # Backward scale for inference scaling.
    backward_scale: float = 1.0
    # Forward scale for inference scaling.
    forward_scale: float = 1.0
    # Denoising steps at which scaling is applied (None -> denoiser default).
    # Annotation fixed: the default is None, so the type is `list | None`.
    scaling_steps: list | None = None
    # GPU ids used by the VBench evaluation subprocess.
    vbench_gpus: list | None = None
    # VBench dimensions to evaluate.
    vbench_dimension_list: list | None = None
# Per-metric min/max bounds used to normalize raw VBench scores into [0, 1].
NORMALIZE_DIC = {
    "subject consistency": {"Min": 0.1462, "Max": 1.0},
    "background consistency": {"Min": 0.2615, "Max": 1.0},
    "motion smoothness": {"Min": 0.706, "Max": 0.9975},
    "dynamic degree": {"Min": 0.0, "Max": 1.0},
    "aesthetic quality": {"Min": 0.0, "Max": 1.0},
    "imaging quality": {"Min": 0.0, "Max": 1.0},
}

# Relative weight of each metric in the aggregate score.
DIM_WEIGHT = {
    "subject consistency": 1,
    "background consistency": 1,
    "motion smoothness": 1,
    "aesthetic quality": 1,
    "imaging quality": 1,
    "dynamic degree": 0.5,
}


def find_highest_score_video(data):
    """Return the index of the best-scoring video in a VBench result dict.

    ``data`` maps metric names to ``[summary, per_video_entries]``; each entry
    carries a ``video_path`` like ``.../<index>.mp4`` and a raw score under
    ``video_results`` or ``cor_num_per_video``.  Scores are min-max normalized
    per metric, weighted by ``DIM_WEIGHT`` and averaged per video.  Returns -1
    when no usable score exists; ties resolve to the smallest index.
    """
    per_video = defaultdict(dict)
    for metric, payload in data.items():
        if not isinstance(payload, list) or len(payload) < 2:
            continue
        if metric not in NORMALIZE_DIC:
            continue
        bounds = NORMALIZE_DIC[metric]
        lo, hi = bounds["Min"], bounds["Max"]
        weight = DIM_WEIGHT[metric]
        for entry in payload[1]:
            try:
                stem = entry["video_path"].split("/")[-1]
                index = int(stem.split(".")[0])
                if "video_results" in entry:
                    raw = entry["video_results"]
                elif "cor_num_per_video" in entry:
                    raw = entry["cor_num_per_video"]
                else:
                    continue
                per_video[index][metric] = (raw - lo) / (hi - lo) * weight
            except (KeyError, ValueError, IndexError):
                # Skip malformed entries rather than failing the whole pick.
                continue
    totals = {
        vid: sum(scores.values()) / sum(DIM_WEIGHT[name] for name in scores)
        for vid, scores in per_video.items()
        if scores
    }
    if not totals:
        return -1
    best = max(totals.values())
    winners = [vid for vid, score in totals.items() if score == best]
    return min(winners) if winners else -1
def sanitize_sampling_option(sampling_option: SamplingOption) -> SamplingOption:
    """
    Sanitize the sampling options.

    Resolution/aspect-ratio (both required together) take precedence over an
    explicit height/width pair; the resulting sizes are rounded up to the next
    multiple of 16, and a string ``method`` is coerced to ``SamplingMethod``.

    Args:
        sampling_option (SamplingOption): The sampling options.

    Returns:
        SamplingOption: The sanitized sampling options.
    """
    has_res = sampling_option.resolution is not None
    has_ar = sampling_option.aspect_ratio is not None
    if has_res or has_ar:
        assert has_res and has_ar, "Both resolution and aspect ratio must be provided"
        height, width = get_image_size(
            sampling_option.resolution, sampling_option.aspect_ratio, training=False
        )
    else:
        assert (
            sampling_option.height is not None and sampling_option.width is not None
        ), "Both height and width must be provided"
        height, width = sampling_option.height, sampling_option.width

    # Round spatial sizes up to the next multiple of 16 (ceil division).
    height = -(-height // 16) * 16
    width = -(-width // 16) * 16

    updates = {"height": height, "width": width}
    if isinstance(sampling_option.method, str):
        updates["method"] = SamplingMethod(sampling_option.method)
    return replace(sampling_option, **updates)
def get_oscillation_gs(guidance_scale: float, i: int, force_num=10):
    """
    get oscillation guidance for cfg.

    Full guidance is kept for the first ``force_num`` steps; afterwards it
    alternates between full guidance (even steps) and no guidance (odd steps).

    Args:
        guidance_scale: original guidance value
        i: denoising step
        force_num: before which don't apply oscillation
    """
    use_full = i < force_num or i % 2 == 0
    return guidance_scale if use_full else 1.0
# ======================================================
# Denoising
# ======================================================
class Denoiser(ABC):
    """Strategy interface for a sampling loop over an MMDiT diffusion model."""

    @abstractmethod
    def denoise(self, model: MMDiTModel, **kwargs) -> Tensor:
        """Denoise the input."""

    @abstractmethod
    def prepare_guidance(
        self,
        text: list[str],
        optional_models: dict[str, nn.Module],
        device: torch.device,
        dtype: torch.dtype,
        **kwargs,
    # NOTE(review): the concrete implementations below return a
    # (text, dict) tuple — this return annotation looks stale; confirm.
    ) -> dict[str, Tensor]:
        """Prepare the guidance for the model. This method will alter text."""
class I2VDenoiser(Denoiser):
    """I2V sampler: Euler steps with dual classifier-free guidance.

    The batch is three stacked copies laid out as
    [text+image conditioned, image-only conditioned, unconditioned];
    each step mixes the three predictions with text/image guidance scales.
    """

    def denoise(self, model: MMDiTModel, **kwargs) -> Tensor:
        """Run the guided Euler loop; returns the clean latent (first third of the batch)."""
        img = kwargs.pop("img")
        timesteps = kwargs.pop("timesteps")
        guidance = kwargs.pop("guidance")
        guidance_img = kwargs.pop("guidance_img")
        # cond ref arguments
        masks = kwargs.pop("masks")
        masked_ref = kwargs.pop("masked_ref")
        # popped so it is not forwarded to the model; unused in this sampler
        kwargs.pop("sigma_min")
        # oscillation guidance
        text_osci = kwargs.pop("text_osci", False)
        image_osci = kwargs.pop("image_osci", False)
        scale_temporal_osci = kwargs.pop("scale_temporal_osci", False)
        # patch size
        patch_size = kwargs.pop("patch_size", 2)
        # NOTE(review): any kwargs not popped above (e.g. flow_shift and the
        # inference-scaling params passed by prepare_api) remain in `kwargs`
        # and are forwarded to model(...) — relies on the model tolerating
        # extra keyword arguments; confirm.
        guidance_vec = torch.full(
            (img.shape[0],), guidance, device=img.device, dtype=img.dtype
        )
        for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
            # timesteps
            t_vec = torch.full(
                (img.shape[0],), t_curr, dtype=img.dtype, device=img.device
            )
            # NOTE(review): names suggest (b, c, t, h, w); the last two are
            # bound as (w, h). Benign for the temporal-only oscillation below
            # (values vary only along t), but confirm before reusing h/w.
            b, c, t, w, h = masked_ref.size()
            cond = torch.cat((masks, masked_ref), dim=1)
            cond = pack(cond, patch_size=patch_size)
            # conditioning for the [cond, cond, uncond] sub-batches
            kwargs["cond"] = torch.cat([cond, cond, torch.zeros_like(cond)], dim=0)
            # forward preparation: keep one latent copy, re-tile to 3-way batch
            cond_x = img[: len(img) // 3]
            img = torch.cat([cond_x, cond_x, cond_x], dim=0)
            # forward
            pred = model(
                img=img,
                **kwargs,
                timesteps=t_vec,
                guidance=guidance_vec,
            )
            # prepare guidance
            text_gs = get_oscillation_gs(guidance, i) if text_osci else guidance
            image_gs = (
                get_oscillation_gs(guidance_img, i) if image_osci else guidance_img
            )
            cond, uncond, uncond_2 = pred.chunk(3, dim=0)
            if image_gs > 1.0 and scale_temporal_osci:
                # image_gs decrease with each denoising step
                step_upper_image_gs = torch.linspace(image_gs, 1.0, len(timesteps))[i]
                # image_gs increase along the temporal axis of the latent video
                image_gs = torch.linspace(1.0, step_upper_image_gs, t)[
                    None, None, :, None, None
                ].repeat(b, c, 1, h, w)
                image_gs = pack(image_gs, patch_size=patch_size).to(cond.device, cond.dtype)
            # update: uncond + image guidance + text guidance, then Euler step
            pred = uncond_2 + image_gs * (uncond - uncond_2) + text_gs * (cond - uncond)
            pred = torch.cat([pred, pred, pred], dim=0)
            img = img + (t_prev - t_curr) * pred
        img = img[: len(img) // 3]
        return img

    def prepare_guidance(
        self,
        text: list[str],
        optional_models: dict[str, nn.Module],
        device: torch.device,
        dtype: torch.dtype,
        **kwargs,
    ) -> tuple[list[str], dict[str, Tensor]]:
        """Triple the prompts as [text, neg, neg] and pass guidance_img through."""
        ret = {}
        neg = kwargs.get("neg", None)
        ret["guidance_img"] = kwargs.pop("guidance_img")
        # text
        if neg is None:
            neg = [""] * len(text)
        text = text + neg + neg
        return text, ret
class I2VScalingDenoiser(Denoiser):
    """I2V sampler with inference-time scaling (best-of-N subtree search).

    At each step listed in ``scaling_steps`` it draws ``num_subtree`` re-noised
    candidates, previews each with a one-step jump to the final timestep,
    decodes the previews, scores them with a VBench subprocess (rank 0 only),
    and continues denoising from the highest-scoring candidate.  Requires an
    initialized torch.distributed process group and a writable ``temp/`` dir.
    """

    def denoise(self, model: MMDiTModel, **kwargs) -> Tensor:
        img = kwargs.pop("img")
        timesteps = kwargs.pop("timesteps")
        guidance = kwargs.pop("guidance")
        guidance_img = kwargs.pop("guidance_img")
        # NOTE(review): `vbench` is popped but never used below.
        vbench = kwargs.pop("vbench", None)
        do_inference_scaling = kwargs.pop("do_inference_scaling", True)
        num_subtree = kwargs.pop("num_subtree", 3)
        model_ae = kwargs.pop("model_ae", None)
        height = kwargs.pop("height", None)
        width = kwargs.pop("width", None)
        num_frames = kwargs.pop("num_frames", None)
        patch_size = kwargs.pop("patch_size", None)
        prompt = kwargs.pop("prompt", None)
        backward_scale = kwargs.pop("backward_scale", 1)
        forward_scale = kwargs.pop("forward_scale", 1)
        # default: apply scaling at every step except the first
        scaling_steps = kwargs.pop("scaling_steps", range(len(timesteps))[1:])
        vbench_dimension_list = kwargs.pop("vbench_dimension_list", ["overall_consistency"])
        vbench_gpus = kwargs.pop("vbench_gpus", [4,5,6,7])
        vbench_gpus = vbench_gpus[:num_subtree]
        # cond ref arguments
        masks = kwargs.pop("masks")
        masked_ref = kwargs.pop("masked_ref")
        # popped so it is not forwarded to the model; unused in this sampler
        kwargs.pop("sigma_min")
        # oscillation guidance
        text_osci = kwargs.pop("text_osci", False)
        image_osci = kwargs.pop("image_osci", False)
        scale_temporal_osci = kwargs.pop("scale_temporal_osci", False)
        # only the first prompt is evaluated / used for file naming
        prompt = [prompt[0]]
        prompt_short = sanitize_filename(prompt[0])
        save_dir = f'temp/{prompt_short}'
        os.makedirs(save_dir, exist_ok=True)

        def save_video(tensor, filename):
            # Write a decoded clip to disk and return the written path.
            # tensor shape: [B, C, T, H, W]
            # Normalize to [0, 255]
            video = ((tensor + 1) * 127.5).clamp(0, 255).to(torch.uint8)
            # Ensure output directory exists
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            if video.shape[2] == 1: # Single frame - save as image
                # Save the single frame for evaluation
                frame = video[0, :, 0] # [C, H, W], as write_png expects
                image_path = filename.replace('.mp4', '.png')
                torchvision.io.write_png(frame.cpu(), image_path)
                return image_path
            else:
                # Also save full video
                # NOTE(review): `[20:]` silently drops the first 20 frames of
                # the evaluated clip — confirm this offset is intentional.
                video = video[0].permute(1, 2, 3, 0)[20:].contiguous() # [T, H, W, C]
                torchvision.io.write_video(filename, video.cpu(), fps=8)
                return filename

        guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
        for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
            # timesteps
            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
            # NOTE(review): names suggest (b, c, t, h, w); the last two are
            # bound as (w, h). Benign for the temporal-only oscillation below,
            # but confirm before reusing h/w spatially.
            b, c, t, w, h = masked_ref.size()
            cond = torch.cat((masks, masked_ref), dim=1)
            # NOTE(review): pack() is called without patch_size even though a
            # patch_size kwarg was popped above — relies on the default of 2.
            cond = pack(cond)
            kwargs["cond"] = torch.cat([cond, cond, torch.zeros_like(cond)], dim=0)
            # forward preparation: keep one latent copy, re-tile to the 3-way
            # guidance batch [cond, image-only cond, uncond]
            cond_x = img[: len(img) // 3]
            noise_shape = cond_x.shape
            img = torch.cat([cond_x, cond_x, cond_x], dim=0)
            if i not in scaling_steps or not do_inference_scaling:
                # plain guided Euler step (same scheme as I2VDenoiser)
                # forward
                pred = model(
                    img=img,
                    **kwargs,
                    timesteps=t_vec,
                    guidance=guidance_vec,
                )
                # prepare guidance
                text_gs = get_oscillation_gs(guidance, i) if text_osci else guidance
                image_gs = get_oscillation_gs(guidance_img, i) if image_osci else guidance_img
                cond, uncond, uncond_2 = pred.chunk(3, dim=0)
                if image_gs > 1.0 and scale_temporal_osci:
                    # image_gs decrease with each denoising step
                    step_upper_image_gs = torch.linspace(image_gs, 1.0, len(timesteps))[i]
                    # image_gs increase along the temporal axis of the latent video
                    image_gs = torch.linspace(1.0, step_upper_image_gs, t)[None, None, :, None, None].repeat(b, c, 1, h, w)
                    image_gs = pack(image_gs).to(cond.device, cond.dtype)
                # update
                pred = uncond_2 + image_gs * (uncond - uncond_2) + text_gs * (cond - uncond)
                pred = torch.cat([pred, pred, pred], dim=0)
                img = img + (t_prev - t_curr) * pred
            else:
                subtrees = []
                subtree_onesteps = []
                # Generate multiple subtrees
                for subtree_idx in range(num_subtree):
                    # re-noise only the conditional third of the batch
                    noise = torch.randn(noise_shape, device=cond_x.device, dtype=cond_x.dtype)
                    zeros = torch.zeros(noise_shape, device=cond_x.device, dtype=cond_x.dtype)
                    noise = torch.cat([noise, zeros, zeros], dim=0)
                    subtree = img - (timesteps[i+1] - t_curr) * forward_scale * noise
                    t_subtree = t_curr
                    t_subtree_prev = t_subtree + (timesteps[i+1] - t_curr) * backward_scale
                    t_subtree_vec = torch.full((img.shape[0],), t_subtree, dtype=cond.dtype, device=cond.device)
                    subtree_noise_pred = model(
                        img=subtree,
                        **kwargs,
                        timesteps=t_subtree_vec,
                        guidance=guidance_vec,
                    )
                    # prepare guidance
                    text_gs = get_oscillation_gs(guidance, i) if text_osci else guidance
                    image_gs = get_oscillation_gs(guidance_img, i) if image_osci else guidance_img
                    cond, uncond, uncond_2 = subtree_noise_pred.chunk(3, dim=0)
                    if image_gs > 1.0 and scale_temporal_osci:
                        # image_gs decrease with each denoising step
                        step_upper_image_gs = torch.linspace(image_gs, 1.0, len(timesteps))[i]
                        # image_gs increase along the temporal axis of the latent video
                        image_gs = torch.linspace(1.0, step_upper_image_gs, t)[None, None, :, None, None].repeat(b, c, 1, h, w)
                        image_gs = pack(image_gs).to(cond.device, cond.dtype)
                    # update
                    subtree_noise_pred = uncond_2 + image_gs * (uncond - uncond_2) + text_gs * (cond - uncond)
                    subtree_noise_pred = torch.cat([subtree_noise_pred, subtree_noise_pred, subtree_noise_pred], dim=0)
                    subtree = subtree + (t_subtree_prev - t_subtree) * subtree_noise_pred
                    # Next timestep prediction
                    t_subtree_prev_vec = torch.full((img.shape[0],), t_subtree_prev, dtype=cond.dtype, device=cond.device)
                    subtree_onestep_noise_pred = model(
                        img=subtree,
                        **kwargs,
                        timesteps=t_subtree_prev_vec,
                        guidance=guidance_vec,
                    )
                    # prepare guidance
                    text_gs = get_oscillation_gs(guidance, i) if text_osci else guidance
                    image_gs = get_oscillation_gs(guidance_img, i) if image_osci else guidance_img
                    # NOTE(review): this chunks the PREVIOUS prediction
                    # (`subtree_noise_pred`, already guided and tiled, so the
                    # three chunks are identical and the combination below is a
                    # no-op), while the one-step jump uses the raw, un-guided
                    # `subtree_onestep_noise_pred`.  This looks like it was
                    # meant to chunk `subtree_onestep_noise_pred` and use its
                    # guided version for the preview — confirm.
                    cond, uncond, uncond_2 = subtree_noise_pred.chunk(3, dim=0)
                    if image_gs > 1.0 and scale_temporal_osci:
                        # image_gs decrease with each denoising step
                        step_upper_image_gs = torch.linspace(image_gs, 1.0, len(timesteps))[i]
                        # image_gs increase along the temporal axis of the latent video
                        image_gs = torch.linspace(1.0, step_upper_image_gs, t)[None, None, :, None, None].repeat(b, c, 1, h, w)
                        image_gs = pack(image_gs).to(cond.device, cond.dtype)
                    # update
                    subtree_noise_pred = uncond_2 + image_gs * (uncond - uncond_2) + text_gs * (cond - uncond)
                    subtree_noise_pred = torch.cat([subtree_noise_pred, subtree_noise_pred, subtree_noise_pred], dim=0)
                    # img = img + (t_prev - t_curr) * pred
                    # one-step jump straight to the final timestep for a preview
                    subtree_onestep = subtree + (timesteps[-1] - t_subtree_prev) * subtree_onestep_noise_pred
                    # Save subtree results
                    if model_ae is not None:
                        unpacked_subtree_onestep = unpack(subtree_onestep, height, width, num_frames, patch_size)[: len(img) // 3]
                        decoded_onestep = model_ae.decode(unpacked_subtree_onestep)
                        # Now returns path to middle frame image for evaluation
                        os.makedirs(f"{save_dir}/{i}_subtree", exist_ok=True)
                        decoded_onestep_path = save_video(decoded_onestep, f"{save_dir}/{i}_subtree/{subtree_idx}.mp4")
                    # Store results
                    subtrees.append(subtree)
                    # NOTE(review): decoded_onestep_path is unbound when
                    # model_ae is None — this branch effectively requires it.
                    subtree_onesteps.append(decoded_onestep_path)
                dist.barrier()
                if dist.get_rank() == 0:
                    videos_path = f"{save_dir}/{i}_subtree"
                    output_path = f"{save_dir}/{i}_subtree"
                    # NOTE(review): shared hard-coded path; concurrent runs with
                    # different prompts would race on this file.
                    prompt_file = "temp/prompt.json" # hard coded for now
                    with open(prompt_file, "w") as fp:
                        prompt_json = {f"{index}.mp4": prompt[0] for index in range(num_subtree)}
                        json.dump(prompt_json, fp)
                    python_path = os.path.dirname(sys.executable)
                    # restrict the child process to a minimal environment and
                    # the GPUs reserved for evaluation
                    minimal_env = {
                        "PATH": f"{python_path}:/usr/local/bin:/usr/bin:/bin",
                        "CUDA_VISIBLE_DEVICES": ",".join([str(item) for item in vbench_gpus])
                    }
                    cmd_args = [
                        'vbench',
                        'evaluate',
                        '--dimension',
                        ' '.join(vbench_dimension_list),
                        '--videos_path',
                        videos_path,
                        '--mode',
                        'custom_input',
                        '--prompt_file',
                        prompt_file,
                        '--output_path',
                        output_path,
                        '--ngpus',
                        str(len(vbench_gpus))
                    ]
                    subprocess.run(cmd_args, check=True, env=minimal_env)
                dist.barrier()
                # Find the latest evaluation results file
                subtree_dir = f"{save_dir}/{i}_subtree"
                eval_files = [f for f in os.listdir(subtree_dir) if f.startswith('results_') and f.endswith('_eval_results.json')]
                if not eval_files:
                    raise FileNotFoundError(f"No evaluation results found in {subtree_dir}")
                # Sort by timestamp in filename to get the latest one
                latest_eval_file = sorted(eval_files)[-1]
                eval_results_path = os.path.join(subtree_dir, latest_eval_file)
                with open(eval_results_path, 'r') as f:
                    eval_results = json.load(f)
                # NOTE(review): find_highest_score_video returns -1 when no
                # usable score exists, which silently selects the LAST subtree.
                best_idx = find_highest_score_video(eval_results)
                img = subtrees[best_idx]
        img = img[: len(img) // 3]
        return img

    def prepare_guidance(
        self, text: list[str], optional_models: dict[str, nn.Module], device: torch.device, dtype: torch.dtype, **kwargs
    ) -> tuple[list[str], dict[str, Tensor]]:
        """Triple the prompts as [text, neg, neg] and pass guidance_img through."""
        ret = {}
        neg = kwargs.get("neg", None)
        ret["guidance_img"] = kwargs.pop("guidance_img")
        # text
        if neg is None:
            neg = [""] * len(text)
        text = text + neg + neg
        return text, ret
class DistilledDenoiser(Denoiser):
    """Sampler for a guidance-distilled model: plain Euler steps, no CFG batch."""

    def denoise(self, model: MMDiTModel, **kwargs) -> Tensor:
        """Integrate the model's velocity field along the timestep schedule."""
        latent = kwargs.pop("img")
        schedule = kwargs.pop("timesteps")
        guidance_value = kwargs.pop("guidance")
        batch = latent.shape[0]
        guidance_vec = torch.full(
            (batch,), guidance_value, device=latent.device, dtype=latent.dtype
        )
        for step_from, step_to in zip(schedule[:-1], schedule[1:]):
            # current timestep, broadcast over the batch
            t_vec = torch.full(
                (batch,), step_from, dtype=latent.dtype, device=latent.device
            )
            velocity = model(
                img=latent,
                **kwargs,
                timesteps=t_vec,
                guidance=guidance_vec,
            )
            # Euler update toward the next scheduled timestep.
            latent = latent + (step_to - step_from) * velocity
        return latent

    def prepare_guidance(
        self,
        text: list[str],
        optional_models: dict[str, nn.Module],
        device: torch.device,
        dtype: torch.dtype,
        **kwargs,
    ) -> tuple[list[str], dict[str, Tensor]]:
        """Distilled models need no CFG duplication: prompts pass through unchanged."""
        return text, {}
# Maps each SamplingMethod to the Denoiser strategy that implements it.
SamplingMethodDict = {
    SamplingMethod.I2V: I2VDenoiser(),
    SamplingMethod.I2VINFERENCESCALING: I2VScalingDenoiser(),
    SamplingMethod.DISTILLED: DistilledDenoiser(),
}
# ======================================================
# Timesteps
# ======================================================
def time_shift(alpha: float, t: Tensor) -> Tensor:
    """Warp timesteps t in [0, 1] toward 1 by factor alpha (identity at alpha=1)."""
    scaled = alpha * t
    return scaled / (1 + (alpha - 1) * t)
def get_res_lin_function(
    x1: float = 256, y1: float = 1, x2: float = 4096, y2: float = 3
) -> callable:
    """Return the linear map passing through (x1, y1) and (x2, y2)."""
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def linear(x):
        return slope * x + intercept

    return linear
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    num_frames: int,
    shift_alpha: float | None = None,
    base_shift: float = 1,
    max_shift: float = 3,
    shift: bool = True,
) -> list[float]:
    """Build the descending timestep schedule from 1 to 0.

    Returns ``num_steps + 1`` floats (the extra endpoint is 0).  When ``shift``
    is enabled and no ``shift_alpha`` is given, the shift factor is estimated
    linearly from the spatial token count and scaled by sqrt(num_frames),
    biasing the schedule toward high timesteps for high-signal inputs.
    """
    # extra step for zero
    schedule = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return schedule.tolist()
    if shift_alpha is None:
        # estimate mu based on linear estimation between two points
        # (spatial scale)
        shift_alpha = get_res_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        # temporal scale
        shift_alpha *= math.sqrt(num_frames)
    # calculate shifted timesteps
    return time_shift(shift_alpha, schedule).tolist()
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    num_frames: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
    patch_size: int = 2,
    channel: int = 16,
) -> Tensor:
    """
    Generate a seeded Gaussian noise latent.

    Spatial dims are ``patch_size * ceil(size / D)`` where ``D`` is the VAE's
    spatial compression (env ``AE_SPATIAL_COMPRESSION``, default 16), so the
    tensor can later be packed into patch tokens.

    Args:
        num_samples (int): Number of samples.
        height (int): Height of the noise tensor.
        width (int): Width of the noise tensor.
        num_frames (int): Number of frames.
        device (torch.device): Device to put the noise tensor on.
        dtype (torch.dtype): Data type of the noise tensor.
        seed (int): Seed for the random number generator.

    Returns:
        Tensor: The noise tensor of shape
        (num_samples, channel, num_frames, H', W').
    """
    compression = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
    spatial = (
        # allow for packing
        patch_size * math.ceil(height / compression),
        patch_size * math.ceil(width / compression),
    )
    generator = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(
        num_samples,
        channel,
        num_frames,
        *spatial,
        device=device,
        dtype=dtype,
        generator=generator,
    )
def pack(x: Tensor, patch_size: int = 2) -> Tensor:
    """Fold (b, c, t, h*p, w*p) latents into patch tokens (b, t*h*w, c*p*p).

    Pure-torch equivalent of the einops pattern
    ``b c t (h ph) (w pw) -> b (t h w) (c ph pw)``.
    """
    b, c, t, hp, wp = x.shape
    p = patch_size
    h, w = hp // p, wp // p
    x = x.reshape(b, c, t, h, p, w, p)
    # reorder to (b, t, h, w, c, ph, pw) then flatten token / feature groups
    x = x.permute(0, 2, 3, 5, 1, 4, 6)
    return x.reshape(b, t * h * w, c * p * p)
def unpack(
    x: Tensor, height: int, width: int, num_frames: int, patch_size: int = 2
) -> Tensor:
    """Inverse of ``pack``: tokens (b, t*h*w, c*p*p) -> latents (b, c, t, h*p, w*p).

    Pure-torch equivalent of the einops pattern
    ``b (t h w) (c ph pw) -> b c t (h ph) (w pw)``; the spatial grid is derived
    from the pixel size and the VAE compression (env ``AE_SPATIAL_COMPRESSION``,
    default 16).
    """
    D = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
    h = math.ceil(height / D)
    w = math.ceil(width / D)
    t = num_frames
    p = patch_size
    b = x.shape[0]
    c = x.shape[2] // (p * p)
    x = x.reshape(b, t, h, w, c, p, p)
    # reorder to (b, c, t, h, ph, w, pw) then merge the patch axes back
    x = x.permute(0, 4, 1, 2, 5, 3, 6)
    return x.reshape(b, c, t, h * p, w * p)
# ======================================================
# Prepare
# ======================================================
def prepare(
    t5,
    clip: HFEmbedder,
    img: Tensor,
    prompt: str | list[str],
    seq_align: int = 1,
    patch_size: int = 2,
) -> dict[str, Tensor]:
    """
    Prepare the input for the model: patchify the latent, build positional
    ids, and embed the prompts with the T5 and CLIP encoders.

    Args:
        t5 (HFEmbedder): The T5 model.
        clip (HFEmbedder): The CLIP model.
        img (Tensor): The latent tensor (b, c, t, h, w).
        prompt (str | list[str]): The prompt(s).

    Returns:
        dict[str, Tensor]: The input dictionary.
        img_ids: used for positional embedding in T,H,W dimensions later
        text_ids: for positional embedding, but set to 0 for now since our text encoder already encodes positional information
    """
    b0, c, t, h, w = img.shape
    device, dtype = img.device, img.dtype
    if isinstance(prompt, str):
        prompt = [prompt]
    bs = len(prompt) if b0 != len(prompt) else b0
    # patchify: b c t (h p) (w p) -> b (t h w) (c p p)  [torch equivalent of einops]
    p = patch_size
    gh, gw = h // p, w // p
    img = (
        img.reshape(b0, c, t, gh, p, gw, p)
        .permute(0, 2, 3, 5, 1, 4, 6)
        .reshape(b0, t * gh * gw, c * p * p)
    )
    if img.shape[0] != bs:
        img = img.repeat(bs // img.shape[0], 1, 1)
    img_ids = torch.zeros(t, gh, gw, 3)
    img_ids[..., 0] += torch.arange(t)[:, None, None]
    img_ids[..., 1] += torch.arange(gh)[None, :, None]
    img_ids[..., 2] += torch.arange(gw)[None, None, :]
    img_ids = img_ids.reshape(1, t * gh * gw, 3).repeat(bs, 1, 1)
    # Encode the tokenized prompts
    txt = t5(prompt, added_tokens=img_ids.shape[1], seq_align=seq_align)
    if txt.shape[0] == 1 and bs > 1:
        txt = txt.repeat(bs, *([1] * (txt.dim() - 1)))
    txt_ids = torch.zeros(bs, txt.shape[1], 3)
    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = vec.repeat(bs, *([1] * (vec.dim() - 1)))
    return {
        "img": img,
        "img_ids": img_ids.to(device, dtype),
        "txt": txt.to(device, dtype),
        "txt_ids": txt_ids.to(device, dtype),
        "y_vec": vec.to(device, dtype),
    }
def prepare_ids(
    img: Tensor,
    t5_embedding: Tensor,
    clip_embedding: Tensor,
) -> dict[str, Tensor]:
    """
    Prepare the input for the model from precomputed text embeddings.

    Args:
        img (Tensor): The latent tensor (b, c, t, h, w).
        t5_embedding (Tensor): The T5 embedding.
        clip_embedding (Tensor): The CLIP embedding.

    Returns:
        dict[str, Tensor]: The input dictionary.
        img_ids: used for positional embedding in T,H,W dimensions later
        text_ids: for positional embedding, but set to 0 for now since our text encoder already encodes positional information
    """
    bs, c, t, h, w = img.shape
    device, dtype = img.device, img.dtype
    # patchify with a fixed 2x2 patch: b c t (h 2) (w 2) -> b (t h w) (c 2 2)
    gh, gw = h // 2, w // 2
    img = (
        img.reshape(bs, c, t, gh, 2, gw, 2)
        .permute(0, 2, 3, 5, 1, 4, 6)
        .reshape(bs, t * gh * gw, c * 4)
    )
    if img.shape[0] != bs:
        img = img.repeat(bs // img.shape[0], 1, 1)
    img_ids = torch.zeros(t, gh, gw, 3)
    img_ids[..., 0] += torch.arange(t)[:, None, None]
    img_ids[..., 1] += torch.arange(gh)[None, :, None]
    img_ids[..., 2] += torch.arange(gw)[None, None, :]
    img_ids = img_ids.reshape(1, t * gh * gw, 3).repeat(bs, 1, 1)
    # Broadcast singleton text embeddings over the batch.
    if t5_embedding.shape[0] == 1 and bs > 1:
        t5_embedding = t5_embedding.repeat(bs, *([1] * (t5_embedding.dim() - 1)))
    txt_ids = torch.zeros(bs, t5_embedding.shape[1], 3)
    if clip_embedding.shape[0] == 1 and bs > 1:
        clip_embedding = clip_embedding.repeat(bs, *([1] * (clip_embedding.dim() - 1)))
    return {
        "img": img,
        "img_ids": img_ids.to(device, dtype),
        "txt": t5_embedding.to(device, dtype),
        "txt_ids": txt_ids.to(device, dtype),
        "y_vec": clip_embedding.to(device, dtype),
    }
def prepare_models(
    cfg: Config,
    device: torch.device,
    dtype: torch.dtype,
    offload_model: bool = False,
) -> tuple[nn.Module, nn.Module, nn.Module, nn.Module, dict[str, nn.Module]]:
    """
    Prepare models for inference.

    Args:
        cfg (Config): The configuration object.
        device (torch.device): The device to use.
        dtype (torch.dtype): The data type to use.
        offload_model (bool): When an ``img_flux`` model is configured, build
            the diffusion model and AE on CPU so they can be offloaded.

    Returns:
        tuple[nn.Module, nn.Module, nn.Module, nn.Module, dict[str, nn.Module]]:
        the diffusion model, the autoencoder, the T5 model, the CLIP model,
        and the optional models.
    """
    offload_to_cpu = offload_model and cfg.get("img_flux", None) is not None
    model_device = "cpu" if offload_to_cpu else device
    model = build_module(
        cfg.model, MODELS, device_map=model_device, torch_dtype=dtype
    ).eval()
    model_ae = build_module(
        cfg.ae, MODELS, device_map=model_device, torch_dtype=dtype
    ).eval()
    model_t5 = build_module(cfg.t5, MODELS, device_map=device, torch_dtype=dtype).eval()
    model_clip = build_module(
        cfg.clip, MODELS, device_map=device, torch_dtype=dtype
    ).eval()
    # Optionally wrap the diffusion model with pretrained LoRA weights.
    if cfg.get("pretrained_lora_path", None) is not None:
        model = PeftModel.from_pretrained(
            model, cfg.pretrained_lora_path, is_trainable=False
        )
    # optional models: an auxiliary image FLUX model and its autoencoder
    optional_models = {}
    if cfg.get("img_flux", None) is not None:
        optional_models["img_flux"] = build_module(
            cfg.img_flux, MODELS, device_map=device, torch_dtype=dtype
        ).eval()
        optional_models["img_flux_ae"] = build_module(
            cfg.img_flux_ae, MODELS, device_map=device, torch_dtype=dtype
        ).eval()
    return model, model_ae, model_t5, model_clip, optional_models
def prepare_api(
    model: nn.Module,
    model_ae: nn.Module,
    model_t5: nn.Module,
    model_clip: nn.Module,
    optional_models: dict[str, nn.Module],
) -> callable:
    """
    Prepare the API function for inference.

    Args:
        model (nn.Module): The diffusion model.
        model_ae (nn.Module): The autoencoder model.
        model_t5 (nn.Module): The T5 model.
        model_clip (nn.Module): The CLIP model.
        optional_models (dict[str, nn.Module]): Extra models some samplers use.

    Returns:
        callable: The API function for inference.
    """

    @torch.inference_mode()
    def api_fn(
        opt: SamplingOption,
        cond_type: str = "t2v",
        seed: int = None,
        sigma_min: float = 1e-5,
        text: list[str] = None,
        neg: list[str] = None,
        patch_size: int = 2,
        channel: int = 16,
        **kwargs,
    ):
        """
        The API function for inference.

        Args:
            opt (SamplingOption): The sampling options.
            text (list[str], optional): The text prompts. Defaults to None.
            neg (list[str], optional): The negative text prompts. Defaults to None.

        Returns:
            torch.Tensor: The generated images.
        """
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype
        # passing seed will overwrite opt seed
        if seed is None:
            # random seed if not provided
            seed = opt.seed if opt.seed is not None else random.randint(0, 2**32 - 1)
        # number of latent frames after the AE's temporal reduction
        if opt.is_causal_vae:
            num_frames = (
                1
                if opt.num_frames == 1
                else (opt.num_frames - 1) // opt.temporal_reduction + 1
            )
        else:
            num_frames = (
                1 if opt.num_frames == 1 else opt.num_frames // opt.temporal_reduction
            )
        # NOTE(review): channel is divided by patch_size**2 — presumably
        # because get_noise allocates patch-expanded spatial dims; confirm.
        z = get_noise(
            len(text),
            opt.height,
            opt.width,
            num_frames,
            device,
            dtype,
            seed,
            patch_size=patch_size,
            channel=channel // (patch_size**2),
        )
        denoiser = SamplingMethodDict[opt.method]
        # i2v reference conditions
        references = [None] * len(text)
        if cond_type != "t2v" and "ref" in kwargs:
            reference_path_list = kwargs.pop("ref")
            references = collect_references_batch(
                reference_path_list,
                cond_type,
                model_ae,
                (opt.height, opt.width),
                is_causal=opt.is_causal_vae,
            )
        elif cond_type != "t2v":
            print(
                "your csv file doesn't have a ref column or is not processed properly. will default to cond_type t2v!"
            )
            cond_type = "t2v"
        # timestep editing
        timesteps = get_schedule(
            opt.num_steps,
            (z.shape[-1] * z.shape[-2]) // patch_size**2,
            num_frames,
            shift=opt.shift,
            shift_alpha=opt.flow_shift,
        )
        # prepare classifier-free guidance data (method specific)
        text, additional_inp = denoiser.prepare_guidance(
            text=text,
            optional_models=optional_models,
            device=device,
            dtype=dtype,
            neg=neg,
            guidance_img=opt.guidance_img,
        )
        inp = prepare(model_t5, model_clip, z, prompt=text, patch_size=patch_size)
        inp.update(additional_inp)
        if opt.method in [SamplingMethod.I2V, SamplingMethod.I2VINFERENCESCALING]:
            # prepare references
            masks, masked_ref = prepare_inference_condition(
                z, cond_type, ref_list=references, causal=opt.is_causal_vae
            )
            inp["masks"] = masks
            inp["masked_ref"] = masked_ref
            inp["sigma_min"] = sigma_min
        # NOTE(review): all scaling/vbench kwargs are passed to every denoiser;
        # I2VDenoiser leaves the un-popped ones in kwargs and forwards them to
        # the model — relies on the model tolerating extra kwargs; confirm.
        x = denoiser.denoise(
            model,
            **inp,
            timesteps=timesteps,
            guidance=opt.guidance,
            text_osci=opt.text_osci,
            image_osci=opt.image_osci,
            scale_temporal_osci=(
                opt.scale_temporal_osci and "i2v" in cond_type
            ),  # don't use temporal osci for v2v or t2v
            flow_shift=opt.flow_shift,
            patch_size=patch_size,
            # inference scaling parameters
            do_inference_scaling=opt.do_inference_scaling,
            num_subtree=opt.num_subtree,
            model_ae=model_ae,
            height=opt.height,
            width=opt.width,
            backward_scale = opt.backward_scale,
            forward_scale = opt.forward_scale,
            scaling_steps = opt.scaling_steps,
            num_frames=num_frames,
            prompt=text,
            vbench_dimension_list=opt.vbench_dimension_list,
            vbench_gpus=opt.vbench_gpus
        )
        x = unpack(x, opt.height, opt.width, num_frames, patch_size=patch_size)
        # replace for image condition
        # NOTE(review): writes only into batch element 0 — assumes
        # single-sample batches for i2v conditioning; confirm for bs > 1.
        if cond_type == "i2v_head":
            x[0, :, :1] = references[0][0]
        elif cond_type == "i2v_tail":
            x[0, :, -1:] = references[0][0]
        elif cond_type == "i2v_loop":
            x[0, :, :1] = references[0][0]
            x[0, :, -1:] = references[0][1]
        x = model_ae.decode(x)
        x = x[:, :, : opt.num_frames]  # image
        # remove the duplicate frames
        if not opt.is_causal_vae:
            if cond_type == "i2v_head":
                pad_len = model_ae.compression[0] - 1
                x = x[:, :, pad_len:]
            elif cond_type == "i2v_tail":
                pad_len = model_ae.compression[0] - 1
                x = x[:, :, :-pad_len]
            elif cond_type == "i2v_loop":
                pad_len = model_ae.compression[0] - 1
                x = x[:, :, pad_len:-pad_len]
        return x

    return api_fn
def sanitize_filename(prompt):
    """Sanitize the prompt to create a valid filename.

    Strips surrounding whitespace, replaces filesystem-unsafe characters with
    underscores, collapses whitespace runs into single underscores, truncates
    to 30 characters, and falls back to "default" for empty results.
    """
    unsafe = '<>:"/\\|?*\n\r\t'
    # One-pass replacement of every unsafe character with '_'.
    cleaned = prompt.strip().translate(str.maketrans(unsafe, "_" * len(unsafe)))
    # Collapse whitespace runs into single underscores.
    cleaned = "_".join(cleaned.split())
    # Limit length and ensure it's not empty.
    cleaned = cleaned[:30] if cleaned else "default"
    # Remove leading/trailing special characters.
    cleaned = cleaned.strip("._-")
    return cleaned or "default"