# Open-Sora/opensora/utils/sampling.py
#
# 1063 lines
# 37 KiB
# Python
# Raw Normal View History

import math
import os
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
# 2025-03-17 08:37:03 +01:00
import json
import subprocess
from collections import defaultdict
import torch
# 2025-03-17 08:37:03 +01:00
import torchvision
from einops import rearrange, repeat
from mmengine.config import Config
from peft import PeftModel
from torch import Tensor, nn
# 2025-03-17 08:37:03 +01:00
import torch.distributed as dist
from opensora.datasets.aspect import get_image_size
from opensora.models.mmdit.model import MMDiTModel
from opensora.models.text.conditioner import HFEmbedder
from opensora.registry import MODELS, build_module
from opensora.utils.inference import (
SamplingMethod,
collect_references_batch,
prepare_inference_condition,
)
# ======================================================
# Sampling Options
# ======================================================
@dataclass
class SamplingOption:
    """User-facing options that control one sampling run.

    Options whose default is ``None`` are annotated as optional so the
    annotations match the defaults.
    """

    # The width of the image/video.
    width: int | None = None
    # The height of the image/video.
    height: int | None = None
    # The resolution of the image/video. If provided, it will override the height and width.
    resolution: str | None = None
    # The aspect ratio of the image/video. If provided, it will override the height and width.
    aspect_ratio: str | None = None
    # The number of frames.
    num_frames: int = 1
    # The number of sampling steps.
    num_steps: int = 50
    # The classifier-free guidance (text).
    guidance: float = 4.0
    # use oscillation for text guidance
    text_osci: bool = False
    # The classifier-free guidance (image), or for the guidance on condition for i2v and v2v
    guidance_img: float | None = None
    # use oscillation for image guidance
    image_osci: bool = False
    # use temporal scaling for image guidance
    scale_temporal_osci: bool = False
    # The seed for the random number generator.
    seed: int | None = None
    # Whether to shift the schedule.
    shift: bool = True
    # The sampling method.
    method: str | SamplingMethod = SamplingMethod.I2V
    # Temporal reduction
    temporal_reduction: int = 1
    # is causal vae
    is_causal_vae: bool = False
    # flow shift
    flow_shift: float | None = None

    # ---- inference-scaling options ----
    # Whether to run the inference-scaling (subtree search) steps.
    do_inference_scaling: bool | None = False
    # Number of candidate subtrees generated per scaling step.
    num_subtree: int = 3
    # Multiplier on the step taken back towards the current timestep.
    backward_scale: float = 1.0
    # Multiplier on the noise/timestep offset applied before re-denoising.
    forward_scale: float = 1.0
    # Step indices on which scaling is applied; None lets the denoiser pick its default.
    scaling_steps: list[int] | None = None
    # GPU ids exported to the vbench subprocess; None lets the denoiser pick its default.
    vbench_gpus: list[int] | None = None
    # VBench dimensions to evaluate; None lets the denoiser pick its default.
    vbench_dimension_list: list[str] | None = None
def find_highest_score_video(data):
    """Return the index of the video with the highest mean normalized score.

    Args:
        data: Mapping from metric name to a list whose second element is a
            list of per-video entries. Each entry is expected to be a dict
            carrying ``video_path`` (file name ``<index>.<ext>``) and the
            metric-specific result field.

    Returns:
        int: The winning video index. Ties are broken in favour of the lowest
        index. ``-1`` is returned when no usable score was found.
    """
    # How to turn one raw entry into a comparable score, per metric.
    normalization_rules = {
        "subject_consistency": lambda e: e["video_results"],
        "background_consistency": lambda e: e["video_results"],
        "temporal_flickering": lambda e: e["video_results"],
        "motion_smoothness": lambda e: e["video_results"],
        "dynamic_degree": lambda e: 1.0 if e["video_results"] else 0.0,
        "aesthetic_quality": lambda e: e["video_results"],
        "imaging_quality": lambda e: e["video_results"] / 100,
        "human_action": lambda e: e["cor_num_per_video"],
        "temporal_style": lambda e: e["video_results"],
        "overall_consistency": lambda e: e["video_results"],
    }

    video_scores = defaultdict(list)
    for metric_name, metric_data in data.items():
        process_rule = normalization_rules.get(metric_name)
        if process_rule is None:
            continue
        if not isinstance(metric_data, list) or len(metric_data) < 2:
            continue
        entries = metric_data[1]
        if not isinstance(entries, (list, tuple)):
            continue

        for entry in entries:
            try:
                filename = entry["video_path"].split("/")[-1]
                video_index = int(filename.split(".")[0])
                score = float(process_rule(entry))
            except (KeyError, ValueError, IndexError, TypeError, AttributeError):
                # Best effort: a malformed entry (wrong type, missing field,
                # non-numeric result) must not abort scoring of the others.
                continue
            video_scores[video_index].append(score)

    if not video_scores:
        return -1

    avg_scores = {
        vid: sum(scores) / len(scores) for vid, scores in video_scores.items()
    }
    max_score = max(avg_scores.values())
    # default=-1 covers the case where nothing compares equal to the maximum
    # (e.g. NaN scores).
    return min(
        (vid for vid, score in avg_scores.items() if score == max_score),
        default=-1,
    )
def sanitize_sampling_option(sampling_option: SamplingOption) -> SamplingOption:
    """
    Resolve the output size and sampling method of a set of options.

    The size comes from ``resolution`` + ``aspect_ratio`` when either is set
    (both are then required), otherwise from ``height`` + ``width``. Both
    dimensions are rounded up to a multiple of 16, and a string ``method`` is
    converted to ``SamplingMethod``.

    Args:
        sampling_option (SamplingOption): The sampling options.

    Returns:
        SamplingOption: A copy of the options with the resolved fields.
    """
    resolution = sampling_option.resolution
    aspect_ratio = sampling_option.aspect_ratio

    if resolution is None and aspect_ratio is None:
        assert (
            sampling_option.height is not None and sampling_option.width is not None
        ), "Both height and width must be provided"
        height, width = sampling_option.height, sampling_option.width
    else:
        assert (
            resolution is not None and aspect_ratio is not None
        ), "Both resolution and aspect ratio must be provided"
        height, width = get_image_size(resolution, aspect_ratio, training=False)

    # Ceil-divide by 16, then scale back: rounds up to the next multiple of 16.
    updates = {
        "height": -(-height // 16) * 16,
        "width": -(-width // 16) * 16,
    }

    if isinstance(sampling_option.method, str):
        updates["method"] = SamplingMethod(sampling_option.method)

    return replace(sampling_option, **updates)
def get_oscillation_gs(guidance_scale: float, i: int, force_num=10):
    """
    Return the guidance scale to use at denoising step ``i`` under oscillation.

    The original scale is used for every step before ``force_num`` and for
    every even step afterwards; the remaining (odd) steps use 1.0.

    Args:
        guidance_scale: original guidance value
        i: denoising step
        force_num: number of initial steps on which oscillation is not applied
    """
    in_forced_window = i < force_num
    on_even_step = i % 2 == 0
    return guidance_scale if (in_forced_window or on_even_step) else 1.0
# ======================================================
# Denoising
# ======================================================
class Denoiser(ABC):
    """Interface shared by all sampling-method-specific denoisers."""

    @abstractmethod
    def denoise(self, model: MMDiTModel, **kwargs) -> Tensor:
        """Denoise the input.

        Args:
            model (MMDiTModel): The diffusion model.
            **kwargs: Method-specific inputs; whatever the implementation does
                not consume is forwarded to the model call.

        Returns:
            Tensor: The denoised latent.
        """

    @abstractmethod
    def prepare_guidance(
        self,
        text: list[str],
        optional_models: dict[str, nn.Module],
        device: torch.device,
        dtype: torch.dtype,
        **kwargs,
    ) -> tuple[list[str], dict[str, Tensor]]:
        """Prepare the guidance for the model. This method will alter text.

        Returns:
            tuple[list[str], dict[str, Tensor]]: The (possibly extended) text
            prompts and a dict of additional inputs for ``denoise``.
        """
class I2VDenoiser(Denoiser):
    """Denoiser for reference-conditioned sampling with two guidance scales.

    The batch is processed as three stacked copies: (prompt, reference),
    (negative prompt, reference) and (negative prompt, zeroed reference).
    Their predictions are combined with a text and an image guidance scale.
    """

    def denoise(self, model: MMDiTModel, **kwargs) -> Tensor:
        """Run the sampling loop and return the first third of the batch.

        Consumed kwargs: ``img``, ``timesteps``, ``guidance``,
        ``guidance_img``, ``masks``, ``masked_ref``, ``sigma_min`` (discarded),
        ``text_osci``, ``image_osci``, ``scale_temporal_osci``, ``patch_size``.
        Everything else is forwarded to ``model``.
        """
        img = kwargs.pop("img")
        timesteps = kwargs.pop("timesteps")
        guidance = kwargs.pop("guidance")
        guidance_img = kwargs.pop("guidance_img")

        # cond ref arguments
        masks = kwargs.pop("masks")
        masked_ref = kwargs.pop("masked_ref")
        kwargs.pop("sigma_min")

        # oscillation guidance
        text_osci = kwargs.pop("text_osci", False)
        image_osci = kwargs.pop("image_osci", False)
        scale_temporal_osci = kwargs.pop("scale_temporal_osci", False)

        # patch size
        patch_size = kwargs.pop("patch_size", 2)

        guidance_vec = torch.full(
            (img.shape[0],), guidance, device=img.device, dtype=img.dtype
        )

        # The reference conditioning does not depend on the step, so it is
        # built once: real reference for the first two copies, zeros for the
        # fully unconditional copy.
        b, c, t, h, w = masked_ref.size()
        ref_cond = pack(torch.cat((masks, masked_ref), dim=1), patch_size=patch_size)
        kwargs["cond"] = torch.cat(
            [ref_cond, ref_cond, torch.zeros_like(ref_cond)], dim=0
        )

        for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
            # timesteps
            t_vec = torch.full(
                (img.shape[0],), t_curr, dtype=img.dtype, device=img.device
            )

            # forward preparation: all three copies start from the same latent
            cond_x = img[: len(img) // 3]
            img = torch.cat([cond_x, cond_x, cond_x], dim=0)

            # forward
            pred = model(
                img=img,
                **kwargs,
                timesteps=t_vec,
                guidance=guidance_vec,
            )

            # prepare guidance
            text_gs = get_oscillation_gs(guidance, i) if text_osci else guidance
            image_gs = (
                get_oscillation_gs(guidance_img, i) if image_osci else guidance_img
            )
            cond, uncond, uncond_2 = pred.chunk(3, dim=0)
            if image_gs > 1.0 and scale_temporal_osci:
                # image_gs decrease with each denoising step
                step_upper_image_gs = torch.linspace(image_gs, 1.0, len(timesteps))[i]
                # image_gs increase along the temporal axis of the latent video
                image_gs = torch.linspace(1.0, step_upper_image_gs, t)[
                    None, None, :, None, None
                ].repeat(b, c, 1, h, w)
                image_gs = pack(image_gs, patch_size=patch_size).to(
                    cond.device, cond.dtype
                )

            # update
            pred = uncond_2 + image_gs * (uncond - uncond_2) + text_gs * (cond - uncond)
            pred = torch.cat([pred, pred, pred], dim=0)
            img = img + (t_prev - t_curr) * pred

        img = img[: len(img) // 3]
        return img

    def prepare_guidance(
        self,
        text: list[str],
        optional_models: dict[str, nn.Module],
        device: torch.device,
        dtype: torch.dtype,
        **kwargs,
    ) -> tuple[list[str], dict[str, Tensor]]:
        """Triple the prompts (prompt, negative, negative) and pass on ``guidance_img``.

        ``neg`` defaults to empty strings when not given.
        """
        ret = {}
        neg = kwargs.get("neg", None)
        ret["guidance_img"] = kwargs.pop("guidance_img")

        # text
        if neg is None:
            neg = [""] * len(text)
        text = text + neg + neg
        return text, ret
# 2025-03-17 08:37:03 +01:00
class I2VScalingDenoiser(Denoiser):
    """Reference-conditioned denoiser with inference-time scaling.

    On the selected steps, ``num_subtree`` noise-perturbed candidates are
    re-denoised, each is projected to the final timestep with one extra model
    call, decoded with the autoencoder, written to disk and scored by the
    external ``vbench`` command; the best-scoring candidate becomes the
    current latent. All other steps use the plain guided update.
    """

    def denoise(self, model: MMDiTModel, **kwargs) -> Tensor:
        """Run the sampling loop and return the first third of the batch.

        Consumed kwargs (beyond those of the plain I2V loop): ``vbench``
        (ignored), ``do_inference_scaling``, ``num_subtree``, ``model_ae``,
        ``height``, ``width``, ``num_frames``, ``prompt``, ``backward_scale``,
        ``forward_scale``, ``scaling_steps``, ``vbench_dimension_list``,
        ``vbench_gpus``. Everything else is forwarded to ``model``.

        Raises:
            ValueError: If a scaling step is reached without ``model_ae``.
            FileNotFoundError: If vbench produced no result file.
        """
        img = kwargs.pop("img")
        timesteps = kwargs.pop("timesteps")
        guidance = kwargs.pop("guidance")
        guidance_img = kwargs.pop("guidance_img")
        # Accepted for compatibility but unused; popped so it never reaches the model.
        kwargs.pop("vbench", None)
        do_inference_scaling = kwargs.pop("do_inference_scaling", True)
        num_subtree = kwargs.pop("num_subtree", 3)
        model_ae = kwargs.pop("model_ae", None)
        height = kwargs.pop("height", None)
        width = kwargs.pop("width", None)
        num_frames = kwargs.pop("num_frames", None)
        patch_size = kwargs.pop("patch_size", None)
        prompt = kwargs.pop("prompt", None)
        backward_scale = kwargs.pop("backward_scale", 1)
        forward_scale = kwargs.pop("forward_scale", 1)
        scaling_steps = kwargs.pop("scaling_steps", None)
        vbench_dimension_list = kwargs.pop("vbench_dimension_list", None)
        vbench_gpus = kwargs.pop("vbench_gpus", None)

        # Callers may pass these keys explicitly as None (the option defaults
        # are None), so the fallbacks cannot live in ``pop``'s default.
        if patch_size is None:
            patch_size = 2
        if scaling_steps is None:
            scaling_steps = range(1, len(timesteps))
        if vbench_dimension_list is None:
            vbench_dimension_list = ["overall_consistency"]
        if vbench_gpus is None:
            vbench_gpus = [4, 5, 6, 7]
        vbench_gpus = vbench_gpus[:num_subtree]

        # cond ref arguments
        masks = kwargs.pop("masks")
        masked_ref = kwargs.pop("masked_ref")
        kwargs.pop("sigma_min")

        # oscillation guidance
        text_osci = kwargs.pop("text_osci", False)
        image_osci = kwargs.pop("image_osci", False)
        scale_temporal_osci = kwargs.pop("scale_temporal_osci", False)

        # Only the first prompt is used for naming and scoring.
        prompt_text = prompt[0]
        prompt_short = prompt_text[:30].replace(" ", "_")
        save_dir = f'temp/{prompt_short}'
        os.makedirs(save_dir, exist_ok=True)

        guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

        # The reference conditioning is step-independent: build it once.
        b, c, t, h, w = masked_ref.size()
        ref_cond = pack(torch.cat((masks, masked_ref), dim=1), patch_size=patch_size)
        kwargs["cond"] = torch.cat([ref_cond, ref_cond, torch.zeros_like(ref_cond)], dim=0)

        def apply_guidance(raw_pred: Tensor, step: int) -> Tensor:
            """Combine the three stacked predictions and re-stack the result."""
            text_gs = get_oscillation_gs(guidance, step) if text_osci else guidance
            image_gs = get_oscillation_gs(guidance_img, step) if image_osci else guidance_img
            cond, uncond, uncond_2 = raw_pred.chunk(3, dim=0)
            if image_gs > 1.0 and scale_temporal_osci:
                # image_gs decrease with each denoising step
                step_upper_image_gs = torch.linspace(image_gs, 1.0, len(timesteps))[step]
                # image_gs increase along the temporal axis of the latent video
                image_gs = torch.linspace(1.0, step_upper_image_gs, t)[None, None, :, None, None].repeat(
                    b, c, 1, h, w
                )
                image_gs = pack(image_gs, patch_size=patch_size).to(cond.device, cond.dtype)
            combined = uncond_2 + image_gs * (uncond - uncond_2) + text_gs * (cond - uncond)
            return torch.cat([combined, combined, combined], dim=0)

        for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
            # timesteps
            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

            # forward preparation
            cond_x = img[: len(img) // 3]
            noise_shape = cond_x.shape
            img = torch.cat([cond_x, cond_x, cond_x], dim=0)

            # Step 0 has no preceding timestep to step back to
            # (timesteps[i - 1] would wrap around to the final timestep).
            use_scaling = bool(do_inference_scaling) and i > 0 and i in scaling_steps

            if not use_scaling:
                pred = apply_guidance(
                    model(img=img, **kwargs, timesteps=t_vec, guidance=guidance_vec), i
                )
                img = img + (t_prev - t_curr) * pred
                continue

            if model_ae is None:
                raise ValueError("model_ae is required to decode and score subtrees during inference scaling")

            step_back = t_curr - timesteps[i - 1]
            t_subtree = t_curr - step_back * forward_scale
            t_subtree_prev = t_subtree + step_back * backward_scale
            step_dir = f"{save_dir}/{i}_subtree"

            subtrees = []
            # Generate multiple subtrees
            for subtree_idx in range(num_subtree):
                # Only the first (prompt-conditioned) copy is perturbed.
                noise = torch.randn(noise_shape, device=cond_x.device, dtype=cond_x.dtype)
                zeros = torch.zeros(noise_shape, device=cond_x.device, dtype=cond_x.dtype)
                noise = torch.cat([noise, zeros, zeros], dim=0)
                subtree = img - step_back * forward_scale * noise

                t_subtree_vec = torch.full((img.shape[0],), t_subtree, dtype=img.dtype, device=img.device)
                subtree_pred = apply_guidance(
                    model(img=subtree, **kwargs, timesteps=t_subtree_vec, guidance=guidance_vec), i
                )
                subtree = subtree + (t_subtree_prev - t_subtree) * subtree_pred

                # Look-ahead: one guided step straight to the final timestep.
                t_subtree_prev_vec = torch.full(
                    (img.shape[0],), t_subtree_prev, dtype=img.dtype, device=img.device
                )
                onestep_pred = apply_guidance(
                    model(img=subtree, **kwargs, timesteps=t_subtree_prev_vec, guidance=guidance_vec), i
                )
                subtree_onestep = subtree + (timesteps[-1] - t_subtree_prev) * onestep_pred

                # Decode and save the look-ahead sample for scoring
                unpacked_subtree_onestep = unpack(subtree_onestep, height, width, num_frames, patch_size)[
                    : len(img) // 3
                ]
                decoded_onestep = model_ae.decode(unpacked_subtree_onestep)
                os.makedirs(step_dir, exist_ok=True)
                self._save_video(decoded_onestep, f"{step_dir}/{subtree_idx}.mp4")

                subtrees.append(subtree)

            best_idx = self._select_best_subtree(
                step_dir, prompt_text, num_subtree, vbench_dimension_list, vbench_gpus
            )
            # NOTE(review): best_idx is -1 when no score could be parsed, which
            # selects the last subtree.
            img = subtrees[best_idx]

        img = img[: len(img) // 3]
        return img

    @staticmethod
    def _save_video(tensor: Tensor, filename: str) -> str:
        """Write the first sample of ``tensor`` to disk and return the path written.

        A single-frame sample is saved as PNG (``.mp4`` replaced by ``.png``),
        anything else as an MP4 at 8 fps.
        """
        # tensor shape: [B, C, T, H, W]; assumes values in [-1, 1] - TODO confirm
        video = ((tensor + 1) * 127.5).clamp(0, 255).to(torch.uint8)
        # Ensure output directory exists
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        if video.shape[2] == 1:  # Single frame - save as image
            frame = video[0, :, 0]  # [C, H, W]
            image_path = filename.replace('.mp4', '.png')
            torchvision.io.write_png(frame.cpu(), image_path)
            return image_path
        # [C, T, H, W] -> [T, H, W, C].
        # NOTE(review): the first 20 frames are dropped; a clip of 20 frames
        # or fewer becomes empty - confirm this offset is intended.
        video = video[0].permute(1, 2, 3, 0)[20:].contiguous()
        torchvision.io.write_video(filename, video.cpu(), fps=8)
        return filename

    @staticmethod
    def _select_best_subtree(
        step_dir: str,
        prompt_text: str,
        num_subtree: int,
        vbench_dimension_list: list,
        vbench_gpus: list,
    ) -> int:
        """Score the saved candidates with vbench (rank 0 only) and return the best index."""
        dist.barrier()
        if dist.get_rank() == 0:
            prompt_file = "temp/prompt.json"
            with open(prompt_file, "w") as fp:
                prompt_json = {f"{index}.mp4": prompt_text for index in range(num_subtree)}
                json.dump(prompt_json, fp)
            # NOTE(review): PATH contains a machine-specific directory.
            minimal_env = {
                "PATH": "/usr/local/bin:/usr/bin:/bin:/mnt/jfs-hdd/home/huangshijie/opensora_vbench/bin",
                "CUDA_VISIBLE_DEVICES": ",".join([str(item) for item in vbench_gpus]),
            }
            cmd_args = [
                'vbench',
                'evaluate',
                '--dimension',
                ','.join(vbench_dimension_list),
                '--videos_path',
                step_dir,
                '--mode',
                'custom_input',
                '--prompt_file',
                prompt_file,
                '--output_path',
                step_dir,
                '--ngpus',
                str(len(vbench_gpus)),
            ]
            subprocess.run(cmd_args, check=True, env=minimal_env)
        dist.barrier()

        # Find the latest evaluation results file
        eval_files = [
            f for f in os.listdir(step_dir) if f.startswith('results_') and f.endswith('_eval_results.json')
        ]
        if not eval_files:
            raise FileNotFoundError(f"No evaluation results found in {step_dir}")
        # Sort by timestamp in filename to get the latest one
        latest_eval_file = sorted(eval_files)[-1]
        with open(os.path.join(step_dir, latest_eval_file), 'r') as f:
            eval_results = json.load(f)
        return find_highest_score_video(eval_results)

    def prepare_guidance(
        self, text: list[str], optional_models: dict[str, nn.Module], device: torch.device, dtype: torch.dtype, **kwargs
    ) -> tuple[list[str], dict[str, Tensor]]:
        """Triple the prompts (prompt, negative, negative) and pass on ``guidance_img``.

        ``neg`` defaults to empty strings when not given.
        """
        ret = {}
        neg = kwargs.get("neg", None)
        ret["guidance_img"] = kwargs.pop("guidance_img")

        # text
        if neg is None:
            neg = [""] * len(text)
        text = text + neg + neg
        return text, ret
class DistilledDenoiser(Denoiser):
    """Denoiser that takes one model call per step, without combining several predictions."""

    def denoise(self, model: MMDiTModel, **kwargs) -> Tensor:
        """Integrate the model prediction over consecutive timestep pairs.

        ``img``, ``timesteps`` and ``guidance`` are consumed; all remaining
        kwargs are forwarded to ``model`` on every step.
        """
        latents = kwargs.pop("img")
        schedule = kwargs.pop("timesteps")
        guidance_scale = kwargs.pop("guidance")

        guidance_vec = torch.full(
            (latents.shape[0],),
            guidance_scale,
            device=latents.device,
            dtype=latents.dtype,
        )

        for t_now, t_next in zip(schedule[:-1], schedule[1:]):
            # one timestep value per batch element
            step_vec = torch.full(
                (latents.shape[0],),
                t_now,
                dtype=latents.dtype,
                device=latents.device,
            )
            model_out = model(
                img=latents,
                **kwargs,
                timesteps=step_vec,
                guidance=guidance_vec,
            )
            # move from t_now to t_next along the prediction
            latents = latents + (t_next - t_now) * model_out

        return latents

    def prepare_guidance(
        self,
        text: list[str],
        optional_models: dict[str, nn.Module],
        device: torch.device,
        dtype: torch.dtype,
        **kwargs,
    ) -> tuple[list[str], dict[str, Tensor]]:
        """Return the prompts unchanged together with no additional inputs."""
        return text, {}
# Maps each sampling method to the denoiser instance that implements it.
SamplingMethodDict = {
    SamplingMethod.I2V: I2VDenoiser(),
    SamplingMethod.I2VINFERENCESCALING: I2VScalingDenoiser(),
    SamplingMethod.DISTILLED: DistilledDenoiser(),
}
# ======================================================
# Timesteps
# ======================================================
def time_shift(alpha: float, t: Tensor) -> Tensor:
    """Warp timesteps ``t`` by ``alpha * t / (1 + (alpha - 1) * t)``.

    ``alpha == 1`` leaves ``t`` unchanged; ``t == 0`` and ``t == 1`` are fixed points.
    """
    numerator = alpha * t
    denominator = 1 + (alpha - 1) * t
    return numerator / denominator
def get_res_lin_function(
    x1: float = 256, y1: float = 1, x2: float = 4096, y2: float = 3
) -> callable:
    """Return the linear function passing through ``(x1, y1)`` and ``(x2, y2)``."""
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x):
        return slope * x + intercept

    return line
def get_schedule(
num_steps: int,
image_seq_len: int,
num_frames: int,
shift_alpha: float | None = None,
base_shift: float = 1,
max_shift: float = 3,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
if shift_alpha is None:
# estimate mu based on linear estimation between two points
# spatial scale
shift_alpha = get_res_lin_function(y1=base_shift, y2=max_shift)(
image_seq_len
)
# temporal scale
shift_alpha *= math.sqrt(num_frames)
# calculate shifted timesteps
timesteps = time_shift(shift_alpha, timesteps)
return timesteps.tolist()
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    num_frames: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
    patch_size: int = 2,
    channel: int = 16,
) -> Tensor:
    """
    Draw a seeded standard-normal tensor of shape
    ``(num_samples, channel, num_frames, H, W)``.

    ``H`` and ``W`` are ``patch_size * ceil(size / D)``, where ``D`` is read
    from the ``AE_SPATIAL_COMPRESSION`` environment variable (default 16).

    Args:
        num_samples (int): Number of samples.
        height (int): Height in pixels.
        width (int): Width in pixels.
        num_frames (int): Number of frames.
        device (torch.device): Device to put the noise tensor on.
        dtype (torch.dtype): Data type of the noise tensor.
        seed (int): Seed for the random number generator.
        patch_size (int): Patch size; keeps both spatial sizes divisible for packing.
        channel (int): Number of channels of the noise tensor.

    Returns:
        Tensor: The noise tensor.
    """
    compression = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
    # multiples of patch_size so the result can be packed
    noise_h = patch_size * math.ceil(height / compression)
    noise_w = patch_size * math.ceil(width / compression)

    rng = torch.Generator(device=device).manual_seed(seed)
    shape = (num_samples, channel, num_frames, noise_h, noise_w)
    return torch.randn(shape, device=device, dtype=dtype, generator=rng)
def pack(x: Tensor, patch_size: int = 2) -> Tensor:
    """Flatten ``(b, c, t, H, W)`` into a token sequence ``(b, t*h*w, c*ph*pw)``.

    Each ``patch_size x patch_size`` spatial patch becomes one token.
    """
    pattern = "b c t (h ph) (w pw) -> b (t h w) (c ph pw)"
    return rearrange(x, pattern, ph=patch_size, pw=patch_size)
def unpack(
    x: Tensor, height: int, width: int, num_frames: int, patch_size: int = 2
) -> Tensor:
    """Inverse of ``pack``: turn ``(b, t*h*w, c*ph*pw)`` back into ``(b, c, t, H, W)``.

    The token grid is ``ceil(height / D) x ceil(width / D)`` per frame, where
    ``D`` is read from ``AE_SPATIAL_COMPRESSION`` (default 16).
    """
    compression = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
    grid_h = math.ceil(height / compression)
    grid_w = math.ceil(width / compression)
    pattern = "b (t h w) (c ph pw) -> b c t (h ph) (w pw)"
    return rearrange(
        x, pattern, h=grid_h, w=grid_w, t=num_frames, ph=patch_size, pw=patch_size
    )
# ======================================================
# Prepare
# ======================================================
def prepare(
    t5,
    clip: HFEmbedder,
    img: Tensor,
    prompt: str | list[str],
    seq_align: int = 1,
    patch_size: int = 2,
) -> dict[str, Tensor]:
    """
    Build the model inputs from a latent tensor and text prompt(s).

    The batch size of the result follows the number of prompts; the packed
    latent is tiled when it has fewer samples than there are prompts.

    Args:
        t5 (HFEmbedder): The T5 model.
        clip (HFEmbedder): The CLIP model.
        img (Tensor): The image tensor, shape (b, c, t, h, w).
        prompt (str | list[str]): The prompt(s).
        seq_align (int): Forwarded to the T5 embedder.
        patch_size (int): Spatial patch size used for packing.

    Returns:
        dict[str, Tensor]: The input dictionary.
        img_ids: used for positional embedding in T,H,W dimensions later
        text_ids: for positional embedding, but set to 0 for now since our text encoder already encodes positional information
    """
    _, _, t, h, w = img.shape
    device, dtype = img.device, img.dtype

    prompts = [prompt] if isinstance(prompt, str) else prompt
    bs = len(prompts)

    tokens = rearrange(
        img, "b c t (h ph) (w pw) -> b (t h w) (c ph pw)", ph=patch_size, pw=patch_size
    )
    if tokens.shape[0] != bs:
        tokens = repeat(tokens, "b ... -> (repeat b) ...", repeat=bs // tokens.shape[0])

    # (t, h, w) coordinate of every token, one coordinate per last-dim slot
    grid_h, grid_w = h // patch_size, w // patch_size
    img_ids = torch.zeros(t, grid_h, grid_w, 3)
    img_ids[..., 0] = torch.arange(t)[:, None, None]
    img_ids[..., 1] = torch.arange(grid_h)[None, :, None]
    img_ids[..., 2] = torch.arange(grid_w)[None, None, :]
    img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

    # Encode the tokenized prompts
    txt = t5(prompts, added_tokens=img_ids.shape[1], seq_align=seq_align)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompts)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return {
        "img": tokens,
        "img_ids": img_ids.to(device, dtype),
        "txt": txt.to(device, dtype),
        "txt_ids": txt_ids.to(device, dtype),
        "y_vec": vec.to(device, dtype),
    }
def prepare_ids(
    img: Tensor,
    t5_embedding: Tensor,
    clip_embedding: Tensor,
    patch_size: int = 2,
) -> dict[str, Tensor]:
    """
    Build the model inputs from a latent tensor and precomputed text embeddings.

    Args:
        img (Tensor): The image tensor, shape (b, c, t, h, w).
        t5_embedding (Tensor): The T5 embedding.
        clip_embedding (Tensor): The CLIP embedding.
        patch_size (int): Spatial patch size used for packing. Defaults to 2,
            the value that was previously hard-coded.

    Returns:
        dict[str, Tensor]: The input dictionary.
        img_ids: used for positional embedding in T,H,W dimensions later
        text_ids: for positional embedding, but set to 0 for now since our text encoder already encodes positional information
    """
    bs, c, t, h, w = img.shape
    device, dtype = img.device, img.dtype

    # Packing keeps the batch dimension, so no tiling is needed here.
    img = rearrange(
        img, "b c t (h ph) (w pw) -> b (t h w) (c ph pw)", ph=patch_size, pw=patch_size
    )

    img_ids = torch.zeros(t, h // patch_size, w // patch_size, 3)
    img_ids[..., 0] = img_ids[..., 0] + torch.arange(t)[:, None, None]
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // patch_size)[None, :, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // patch_size)[None, None, :]
    img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

    # Broadcast single embeddings across the batch
    if t5_embedding.shape[0] == 1 and bs > 1:
        t5_embedding = repeat(t5_embedding, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, t5_embedding.shape[1], 3)

    if clip_embedding.shape[0] == 1 and bs > 1:
        clip_embedding = repeat(clip_embedding, "1 ... -> bs ...", bs=bs)

    return {
        "img": img,
        "img_ids": img_ids.to(device, dtype),
        "txt": t5_embedding.to(device, dtype),
        "txt_ids": txt_ids.to(device, dtype),
        "y_vec": clip_embedding.to(device, dtype),
    }
def prepare_models(
    cfg: Config,
    device: torch.device,
    dtype: torch.dtype,
    offload_model: bool = False,
) -> tuple[nn.Module, nn.Module, nn.Module, nn.Module, dict[str, nn.Module]]:
    """
    Build every model needed for inference and put each in eval mode.

    Args:
        cfg (Config): The configuration object.
        device (torch.device): The device to use.
        dtype (torch.dtype): The data type to use.
        offload_model (bool): When set and ``cfg.img_flux`` is configured, the
            diffusion model and autoencoder are built on CPU instead of ``device``.

    Returns:
        tuple[nn.Module, nn.Module, nn.Module, nn.Module, dict[str, nn.Module]]: The models. They are the diffusion model, the autoencoder model, the T5 model, the CLIP model, and the optional models.
    """

    def build_eval(spec, target):
        # Instantiate from the registry and switch to eval mode.
        return build_module(spec, MODELS, device_map=target, torch_dtype=dtype).eval()

    has_img_flux = cfg.get("img_flux", None) is not None
    main_device = "cpu" if offload_model and has_img_flux else device

    model = build_eval(cfg.model, main_device)
    model_ae = build_eval(cfg.ae, main_device)
    model_t5 = build_eval(cfg.t5, device)
    model_clip = build_eval(cfg.clip, device)

    lora_path = cfg.get("pretrained_lora_path", None)
    if lora_path is not None:
        model = PeftModel.from_pretrained(
            model, cfg.pretrained_lora_path, is_trainable=False
        )

    # optional models
    optional_models = {}
    if has_img_flux:
        flux_model = build_eval(cfg.img_flux, device)
        flux_ae = build_eval(cfg.img_flux_ae, device)
        optional_models["img_flux"] = flux_model
        optional_models["img_flux_ae"] = flux_ae

    return model, model_ae, model_t5, model_clip, optional_models
def prepare_api(
    model: nn.Module,
    model_ae: nn.Module,
    model_t5: nn.Module,
    model_clip: nn.Module,
    optional_models: dict[str, nn.Module],
) -> callable:
    """
    Prepare the API function for inference.

    Args:
        model (nn.Module): The diffusion model.
        model_ae (nn.Module): The autoencoder model.
        model_t5 (nn.Module): The T5 model.
        model_clip (nn.Module): The CLIP model.
        optional_models (dict[str, nn.Module]): Extra models handed to the
            denoiser's ``prepare_guidance``.

    Returns:
        callable: The API function for inference.
    """

    @torch.inference_mode()
    def api_fn(
        opt: SamplingOption,
        cond_type: str = "t2v",
        seed: int | None = None,
        sigma_min: float = 1e-5,
        text: list[str] | None = None,
        neg: list[str] | None = None,
        patch_size: int = 2,
        channel: int = 16,
        **kwargs,
    ):
        """
        The API function for inference.

        Args:
            opt (SamplingOption): The sampling options.
            cond_type (str): Conditioning type; anything other than "t2v"
                needs a ``ref`` keyword argument, otherwise it falls back to "t2v".
            seed (int, optional): Overrides ``opt.seed``; random when both are None.
            sigma_min (float): Passed through to the I2V denoisers.
            text (list[str]): The text prompts (required).
            neg (list[str], optional): The negative text prompts. Defaults to None.
            patch_size (int): Spatial patch size used for packing.
            channel (int): Packed channel count; the noise tensor uses
                ``channel // patch_size**2`` channels.

        Returns:
            torch.Tensor: The decoded output of the autoencoder.

        Raises:
            ValueError: If ``text`` is not provided.
        """
        if text is None:
            raise ValueError("text prompts must be provided")

        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype

        # passing seed will overwrite opt seed
        if seed is None:
            # random seed if not provided
            seed = opt.seed if opt.seed is not None else random.randint(0, 2**32 - 1)

        # number of latent frames after temporal compression
        if opt.is_causal_vae:
            num_frames = (
                1
                if opt.num_frames == 1
                else (opt.num_frames - 1) // opt.temporal_reduction + 1
            )
        else:
            num_frames = (
                1 if opt.num_frames == 1 else opt.num_frames // opt.temporal_reduction
            )

        z = get_noise(
            len(text),
            opt.height,
            opt.width,
            num_frames,
            device,
            dtype,
            seed,
            patch_size=patch_size,
            channel=channel // (patch_size**2),
        )
        denoiser = SamplingMethodDict[opt.method]

        # i2v reference conditions
        references = [None] * len(text)
        if cond_type != "t2v" and "ref" in kwargs:
            reference_path_list = kwargs.pop("ref")
            references = collect_references_batch(
                reference_path_list,
                cond_type,
                model_ae,
                (opt.height, opt.width),
                is_causal=opt.is_causal_vae,
            )
        elif cond_type != "t2v":
            print(
                "your csv file doesn't have a ref column or is not processed properly. will default to cond_type t2v!"
            )
            cond_type = "t2v"

        # timestep editing
        timesteps = get_schedule(
            opt.num_steps,
            (z.shape[-1] * z.shape[-2]) // patch_size**2,
            num_frames,
            shift=opt.shift,
            shift_alpha=opt.flow_shift,
        )

        # prepare classifier-free guidance data (method specific)
        text, additional_inp = denoiser.prepare_guidance(
            text=text,
            optional_models=optional_models,
            device=device,
            dtype=dtype,
            neg=neg,
            guidance_img=opt.guidance_img,
        )

        inp = prepare(model_t5, model_clip, z, prompt=text, patch_size=patch_size)
        inp.update(additional_inp)

        if opt.method in [SamplingMethod.I2V, SamplingMethod.I2VINFERENCESCALING]:
            # prepare references
            masks, masked_ref = prepare_inference_condition(
                z, cond_type, ref_list=references, causal=opt.is_causal_vae
            )
            inp["masks"] = masks
            inp["masked_ref"] = masked_ref
            inp["sigma_min"] = sigma_min

        denoise_kwargs = dict(
            timesteps=timesteps,
            guidance=opt.guidance,
            text_osci=opt.text_osci,
            image_osci=opt.image_osci,
            # don't use temporal osci for v2v or t2v
            scale_temporal_osci=(opt.scale_temporal_osci and "i2v" in cond_type),
            flow_shift=opt.flow_shift,
            patch_size=patch_size,
        )
        if opt.method == SamplingMethod.I2VINFERENCESCALING:
            # Inference-scaling parameters are only consumed by the scaling
            # denoiser; the other denoisers would forward them to the model.
            denoise_kwargs.update(
                do_inference_scaling=opt.do_inference_scaling,
                num_subtree=opt.num_subtree,
                model_ae=model_ae,
                height=opt.height,
                width=opt.width,
                backward_scale=opt.backward_scale,
                forward_scale=opt.forward_scale,
                num_frames=num_frames,
                prompt=text,
            )
            # Leave unset options out so the denoiser's own defaults apply.
            optional_scaling = dict(
                scaling_steps=opt.scaling_steps,
                vbench_dimension_list=opt.vbench_dimension_list,
                vbench_gpus=opt.vbench_gpus,
            )
            denoise_kwargs.update(
                {key: value for key, value in optional_scaling.items() if value is not None}
            )

        x = denoiser.denoise(model, **inp, **denoise_kwargs)
        x = unpack(x, opt.height, opt.width, num_frames, patch_size=patch_size)

        # replace for image condition
        if cond_type == "i2v_head":
            x[0, :, :1] = references[0][0]
        elif cond_type == "i2v_tail":
            x[0, :, -1:] = references[0][0]
        elif cond_type == "i2v_loop":
            x[0, :, :1] = references[0][0]
            x[0, :, -1:] = references[0][1]

        x = model_ae.decode(x)
        x = x[:, :, : opt.num_frames]  # image

        # remove the duplicate frames
        if not opt.is_causal_vae:
            if cond_type == "i2v_head":
                pad_len = model_ae.compression[0] - 1
                x = x[:, :, pad_len:]
            elif cond_type == "i2v_tail":
                pad_len = model_ae.compression[0] - 1
                x = x[:, :, :-pad_len]
            elif cond_type == "i2v_loop":
                pad_len = model_ae.compression[0] - 1
                x = x[:, :, pad_len:-pad_len]

        return x

    return api_fn