inference scaling init

nicolaus 2025-03-17 15:37:03 +08:00
parent 4c3c1740a3
commit bb64366a85
6 changed files with 9518 additions and 1 deletion

3
.gitignore vendored

@@ -196,3 +196,6 @@ exps
ckpts
flash-attention
datasets
# inference scaling
temp/

9132
assets/VBench_full_info.json Normal file

File diff suppressed because it is too large.

View file

@@ -0,0 +1,28 @@
_base_ = [ # inherit grammar from mmengine
"256px.py",
"plugins/t2i2v.py",
"plugins/tp.py", # use tensor parallel
]
sampling_option = dict(
resolution="256px", # 256px or 768px
aspect_ratio="16:9", # 9:16 or 16:9 or 1:1
num_frames=129, # number of frames
num_steps=50, # number of steps
shift=True,
temporal_reduction=4,
is_causal_vae=True,
guidance=7.5, # guidance for text-to-video
guidance_img=3.0, # guidance for image-to-video
text_osci=True, # enable text guidance oscillation
image_osci=True, # enable image guidance oscillation
scale_temporal_osci=True,
method="i2v_inference_scaling", # hard-coded for now
vbench_dimension_list=['subject_consistency'],
do_inference_scaling=True,
num_subtree=3,
backward_scale=0.78,
forward_scale=0.83,
scaling_steps=[1,2,4,7,9,15,20],
seed=None, # random seed for z
vbench_gpus=[4,5,6,7]
)

View file

@@ -0,0 +1,17 @@
_base_ = [ # inherit grammar from mmengine
"256px.py",
"plugins/sp.py", # use sequence parallel
"plugins/t2i2v.py",
]
sampling_option = dict(
resolution="768px", # 256px or 768px
method="i2v_inference_scaling", # hard-coded for now
vbench_dimension_list=['subject_consistency'],
do_inference_scaling=True,
num_subtree=3,
backward_scale=0.78,
forward_scale=0.83,
scaling_steps=[1,2,4,7,9,15,20],
seed=None, # random seed for z
vbench_gpus=[4,5,6,7]
)
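A minimal sketch of how these config files might be loaded, assuming standard mmengine behavior: the _base_ entries are resolved relative to the config file and merged key by key, so the 768px variant only overrides resolution and the plugin list. The config path below is hypothetical; the actual inference entrypoint is not part of this diff.

from mmengine.config import Config

# Hypothetical path to the 768px config added in this commit.
cfg = Config.fromfile("configs/diffusion/inference/768px_inference_scaling.py")

opt = cfg.sampling_option
print(opt.resolution)                             # "768px"
print(opt.do_inference_scaling, opt.num_subtree)  # True 3
print(opt.scaling_steps)                          # [1, 2, 4, 7, 9, 15, 20]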

View file

@@ -16,6 +16,7 @@ from opensora.utils.prompt_refine import refine_prompts
class SamplingMethod(Enum):
I2V = "i2v" # for open sora video generation
DISTILLED = "distill" # for flux image generation
I2VINFERENCESCALING = "i2v_inference_scaling" # for open sora video generation with inference scaling
def create_tmp_csv(save_dir: str, prompt: str, ref: str = None, create=True) -> str:

View file

@@ -3,12 +3,17 @@ import os
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
import json
import subprocess
from collections import defaultdict
import torch
import torchvision
from einops import rearrange, repeat
from mmengine.config import Config
from peft import PeftModel
from torch import Tensor, nn
import torch.distributed as dist
from opensora.datasets.aspect import get_image_size
from opensora.models.mmdit.model import MMDiTModel
@@ -78,6 +83,85 @@ class SamplingOption:
# flow shift
flow_shift: float | None = None
# do_inference_scaling
do_inference_scaling: bool | None = False
# num_subtree
num_subtree: int = 3
# backward_scale
backward_scale: float = 1.0
# forward_scale
forward_scale: float = 1.0
# scaling_steps
scaling_steps: list | None = None
# vbench_gpus
vbench_gpus: list | None = None
# vbench_dimension_list
vbench_dimension_list: list | None = None
def find_highest_score_video(data):
video_scores = defaultdict(list)
normalization_rules = {
"subject_consistency": lambda e: e["video_results"],
"background_consistency": lambda e: e["video_results"],
"temporal_flickering": lambda e: e["video_results"],
"motion_smoothness": lambda e: e["video_results"],
"dynamic_degree": lambda e: 1.0 if e["video_results"] else 0.0,
"aesthetic_quality": lambda e: e["video_results"],
"imaging_quality": lambda e: e["video_results"] / 100,
"human_action": lambda e: e["cor_num_per_video"],
"temporal_style": lambda e: e["video_results"],
"overall_consistency": lambda e: e["video_results"]
}
for metric_name, metric_data in data.items():
if not isinstance(metric_data, list) or len(metric_data) < 2:
continue
process_rule = normalization_rules.get(metric_name)
if not process_rule:
continue
for entry in metric_data[1]:
try:
path_parts = entry["video_path"].split("/")
filename = path_parts[-1]
video_index = int(filename.split(".")[0])
score = process_rule(entry)
video_scores[video_index].append(score)
except (KeyError, ValueError, IndexError):
continue
avg_scores = {}
for vid, scores in video_scores.items():
if len(scores) == 0:
avg_scores[vid] = 0.0
continue
avg_scores[vid] = sum(scores) / len(scores)
if not avg_scores:
return -1
max_score = max(avg_scores.values())
candidates = sorted(
[vid for vid, score in avg_scores.items() if score == max_score]
)
return candidates[0] if candidates else -1
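For reference, a small worked example of the input find_highest_score_video expects; the structure (each dimension maps to [aggregate_score, per_video_entries]) is inferred from the parsing logic above and from the VBench results file read later in denoise, and all numbers and paths are made up:

eval_results = {
    "subject_consistency": [
        0.95,
        [
            {"video_path": "temp/a_prompt/3_subtree/0.mp4", "video_results": 0.93},
            {"video_path": "temp/a_prompt/3_subtree/1.mp4", "video_results": 0.97},
            {"video_path": "temp/a_prompt/3_subtree/2.mp4", "video_results": 0.95},
        ],
    ],
    "imaging_quality": [
        68.0,
        [
            {"video_path": "temp/a_prompt/3_subtree/0.mp4", "video_results": 70.0},
            {"video_path": "temp/a_prompt/3_subtree/1.mp4", "video_results": 72.0},
            {"video_path": "temp/a_prompt/3_subtree/2.mp4", "video_results": 66.0},
        ],
    ],
}

# imaging_quality is normalized to [0, 1], scores are averaged per video index,
# and the lowest index wins ties.
best = find_highest_score_video(eval_results)  # -> 1 (avg 0.845 vs 0.815 and 0.805)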
def sanitize_sampling_option(sampling_option: SamplingOption) -> SamplingOption:
"""
@@ -245,6 +329,244 @@ class I2VDenoiser(Denoiser):
return text, ret
class I2VScalingDenoiser(Denoiser):
def denoise(self, model: MMDiTModel, **kwargs) -> Tensor:
img = kwargs.pop("img")
timesteps = kwargs.pop("timesteps")
guidance = kwargs.pop("guidance")
guidance_img = kwargs.pop("guidance_img")
vbench = kwargs.pop("vbench", None)
do_inference_scaling = kwargs.pop("do_inference_scaling", True)
num_subtree = kwargs.pop("num_subtree", 3)
model_ae = kwargs.pop("model_ae", None)
height = kwargs.pop("height", None)
width = kwargs.pop("width", None)
num_frames = kwargs.pop("num_frames", None)
patch_size = kwargs.pop("patch_size", None)
prompt = kwargs.pop("prompt", None)
backward_scale = kwargs.pop("backward_scale", 1)
forward_scale = kwargs.pop("forward_scale", 1)
scaling_steps = kwargs.pop("scaling_steps", range(len(timesteps))[1:])
vbench_dimension_list = kwargs.pop("vbench_dimension_list", ["overall_consistency"])
vbench_gpus = kwargs.pop("vbench_gpus", [4,5,6,7])
vbench_gpus = vbench_gpus[:num_subtree]
# cond ref arguments
masks = kwargs.pop("masks")
masked_ref = kwargs.pop("masked_ref")
kwargs.pop("sigma_min")
# oscillation guidance
text_osci = kwargs.pop("text_osci", False)
image_osci = kwargs.pop("image_osci", False)
scale_temporal_osci = kwargs.pop("scale_temporal_osci", False)
prompt = [prompt[0]]
prompt_short = prompt[0][:30].replace(" ", "_")
save_dir = f'temp/{prompt_short}'
os.makedirs(save_dir, exist_ok=True)
def save_video(tensor, filename):
# tensor shape: [B, C, T, H, W]
# Normalize to [0, 255]
video = ((tensor + 1) * 127.5).clamp(0, 255).to(torch.uint8)
# Ensure output directory exists
os.makedirs(os.path.dirname(filename), exist_ok=True)
if video.shape[2] == 1: # Single frame - save as image
# Save the single frame as a PNG for evaluation
frame = video[0, :, 0] # [C, H, W]
image_path = filename.replace('.mp4', '.png')
torchvision.io.write_png(frame.cpu(), image_path)
return image_path
else:
# Save the full clip, dropping the first 20 frames
video = video[0].permute(1, 2, 3, 0)[20:].contiguous() # [T, H, W, C]
torchvision.io.write_video(filename, video.cpu(), fps=8)
return filename
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
# timesteps
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
b, c, t, w, h = masked_ref.size()
cond = torch.cat((masks, masked_ref), dim=1)
cond = pack(cond)
kwargs["cond"] = torch.cat([cond, cond, torch.zeros_like(cond)], dim=0)
# forward preparation
cond_x = img[: len(img) // 3]
noise_shape = cond_x.shape
img = torch.cat([cond_x, cond_x, cond_x], dim=0)
if i not in scaling_steps or not do_inference_scaling:
# forward
pred = model(
img=img,
**kwargs,
timesteps=t_vec,
guidance=guidance_vec,
)
# prepare guidance
text_gs = get_oscillation_gs(guidance, i) if text_osci else guidance
image_gs = get_oscillation_gs(guidance_img, i) if image_osci else guidance_img
cond, uncond, uncond_2 = pred.chunk(3, dim=0)
if image_gs > 1.0 and scale_temporal_osci:
# image_gs decreases with each denoising step
step_upper_image_gs = torch.linspace(image_gs, 1.0, len(timesteps))[i]
# image_gs increases along the temporal axis of the latent video
image_gs = torch.linspace(1.0, step_upper_image_gs, t)[None, None, :, None, None].repeat(b, c, 1, h, w)
image_gs = pack(image_gs).to(cond.device, cond.dtype)
# update
pred = uncond_2 + image_gs * (uncond - uncond_2) + text_gs * (cond - uncond)
pred = torch.cat([pred, pred, pred], dim=0)
img = img + (t_prev - t_curr) * pred
else:
subtrees = []
subtree_onesteps = []
# Generate multiple subtrees
for subtree_idx in range(num_subtree):
noise = torch.randn(noise_shape, device=cond_x.device, dtype=cond_x.dtype)
zeros = torch.zeros(noise_shape, device=cond_x.device, dtype=cond_x.dtype)
noise = torch.cat([noise, zeros, zeros], dim=0)
subtree = img - (t_curr - timesteps[i-1]) * forward_scale * noise
t_subtree = t_curr - (t_curr - timesteps[i-1]) * forward_scale
t_subtree_prev = t_subtree + (t_curr - timesteps[i-1]) * backward_scale
t_subtree_vec = torch.full((img.shape[0],), t_subtree, dtype=cond.dtype, device=cond.device)
subtree_noise_pred = model(
img=subtree,
**kwargs,
timesteps=t_subtree_vec,
guidance=guidance_vec,
)
# prepare guidance
text_gs = get_oscillation_gs(guidance, i) if text_osci else guidance
image_gs = get_oscillation_gs(guidance_img, i) if image_osci else guidance_img
cond, uncond, uncond_2 = subtree_noise_pred.chunk(3, dim=0)
if image_gs > 1.0 and scale_temporal_osci:
# image_gs decreases with each denoising step
step_upper_image_gs = torch.linspace(image_gs, 1.0, len(timesteps))[i]
# image_gs increases along the temporal axis of the latent video
image_gs = torch.linspace(1.0, step_upper_image_gs, t)[None, None, :, None, None].repeat(b, c, 1, h, w)
image_gs = pack(image_gs).to(cond.device, cond.dtype)
# update
subtree_noise_pred = uncond_2 + image_gs * (uncond - uncond_2) + text_gs * (cond - uncond)
subtree_noise_pred = torch.cat([subtree_noise_pred, subtree_noise_pred, subtree_noise_pred], dim=0)
subtree = subtree + (t_subtree_prev - t_subtree) * subtree_noise_pred
# Next timestep prediction
t_subtree_prev_vec = torch.full((img.shape[0],), t_subtree_prev, dtype=cond.dtype, device=cond.device)
subtree_onestep_noise_pred = model(
img=subtree,
**kwargs,
timesteps=t_subtree_prev_vec,
guidance=guidance_vec,
)
# prepare guidance
text_gs = get_oscillation_gs(guidance, i) if text_osci else guidance
image_gs = get_oscillation_gs(guidance_img, i) if image_osci else guidance_img
cond, uncond, uncond_2 = subtree_onestep_noise_pred.chunk(3, dim=0)
if image_gs > 1.0 and scale_temporal_osci:
# image_gs decreases with each denoising step
step_upper_image_gs = torch.linspace(image_gs, 1.0, len(timesteps))[i]
# image_gs increases along the temporal axis of the latent video
image_gs = torch.linspace(1.0, step_upper_image_gs, t)[None, None, :, None, None].repeat(b, c, 1, h, w)
image_gs = pack(image_gs).to(cond.device, cond.dtype)
# update: apply guidance to the one-step prediction before the preview jump
subtree_onestep_noise_pred = uncond_2 + image_gs * (uncond - uncond_2) + text_gs * (cond - uncond)
subtree_onestep_noise_pred = torch.cat([subtree_onestep_noise_pred, subtree_onestep_noise_pred, subtree_onestep_noise_pred], dim=0)
# img = img + (t_prev - t_curr) * pred
subtree_onestep = subtree + (timesteps[-1] - t_subtree_prev) * subtree_onestep_noise_pred
# Save subtree results
if model_ae is not None:
unpacked_subtree_onestep = unpack(subtree_onestep, height, width, num_frames, patch_size)[: len(img) // 3]
decoded_onestep = model_ae.decode(unpacked_subtree_onestep)
# Now returns path to middle frame image for evaluation
os.makedirs(f"{save_dir}/{i}_subtree", exist_ok=True)
decoded_onestep_path = save_video(decoded_onestep, f"{save_dir}/{i}_subtree/{subtree_idx}.mp4")
# Store results
subtrees.append(subtree)
subtree_onesteps.append(decoded_onestep_path)
dist.barrier()
if dist.get_rank() == 0:
videos_path = f"{save_dir}/{i}_subtree"
output_path = f"{save_dir}/{i}_subtree"
prompt_file = "temp/prompt.json"
with open(prompt_file, "w") as fp:
prompt_json = {f"{index}.mp4": prompt[0] for index in range(num_subtree)}
json.dump(prompt_json, fp)
minimal_env = {
"PATH": "/usr/local/bin:/usr/bin:/bin:/mnt/jfs-hdd/home/huangshijie/opensora_vbench/bin",
"CUDA_VISIBLE_DEVICES": ",".join([str(item) for item in vbench_gpus])
}
cmd_args = [
'vbench',
'evaluate',
'--dimension',
','.join(vbench_dimension_list),
'--videos_path',
videos_path,
'--mode',
'custom_input',
'--prompt_file',
prompt_file,
'--output_path',
output_path,
'--ngpus',
str(len(vbench_gpus))
]
subprocess.run(cmd_args, check=True, env=minimal_env)
dist.barrier()
# Find the latest evaluation results file
subtree_dir = f"{save_dir}/{i}_subtree"
eval_files = [f for f in os.listdir(subtree_dir) if f.startswith('results_') and f.endswith('_eval_results.json')]
if not eval_files:
raise FileNotFoundError(f"No evaluation results found in {subtree_dir}")
# Sort by timestamp in filename to get the latest one
latest_eval_file = sorted(eval_files)[-1]
eval_results_path = os.path.join(subtree_dir, latest_eval_file)
with open(eval_results_path, 'r') as f:
eval_results = json.load(f)
best_idx = find_highest_score_video(eval_results)
img = subtrees[best_idx]
img = img[: len(img) // 3]
return img
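A numeric illustration of the subtree step above (values are made up; the schedule decreases from 1 toward 0, matching the Euler update img = img + (t_prev - t_curr) * pred used in the non-scaling branch):

# timesteps[i-1] = 0.90, timesteps[i] = t_curr = 0.85
t_prev_step, t_curr = 0.90, 0.85
forward_scale, backward_scale = 0.83, 0.78

# 1) re-noise: subtree = img - (t_curr - t_prev_step) * forward_scale * noise
#    moves part of the way back toward the noisier previous level.
t_subtree = t_curr - (t_curr - t_prev_step) * forward_scale            # 0.8915
# 2) re-denoise with a fresh model call from t_subtree down to t_subtree_prev.
t_subtree_prev = t_subtree + (t_curr - t_prev_step) * backward_scale   # 0.8525
# 3) one-step preview: jump straight to timesteps[-1] (~0), decode with the VAE,
#    score the num_subtree candidates with VBench, and keep the best subtree latent.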
def prepare_guidance(
self, text: list[str], optional_models: dict[str, nn.Module], device: torch.device, dtype: torch.dtype, **kwargs
) -> tuple[list[str], dict[str, Tensor]]:
ret = {}
neg = kwargs.get("neg", None)
ret["guidance_img"] = kwargs.pop("guidance_img")
# text
if neg is None:
neg = [""] * len(text)
text = text + neg + neg
return text, ret
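The three-way guidance used throughout the denoiser combines a fully unconditional prediction (uncond_2), an image-conditioned but text-negative prediction (uncond), and the text-and-image-conditioned prediction (cond), which is why prepare_guidance triples the text batch as text + neg + neg. A toy sketch with placeholder tensors (shapes and values are arbitrary):

import torch

# Stand-ins for the three chunks of the model output over the tripled batch.
cond, uncond, uncond_2 = torch.randn(3, 1, 8)
text_gs, image_gs = 7.5, 3.0  # guidance and guidance_img from the configs

# Image guidance pulls the unconditional prediction toward the image-conditioned
# one; text guidance then pushes toward the text-conditioned prediction.
pred = uncond_2 + image_gs * (uncond - uncond_2) + text_gs * (cond - uncond)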
class DistilledDenoiser(Denoiser):
def denoise(self, model: MMDiTModel, **kwargs) -> Tensor:
img = kwargs.pop("img")
@@ -283,6 +605,7 @@ class DistilledDenoiser(Denoiser):
SamplingMethodDict = {
SamplingMethod.I2V: I2VDenoiser(),
SamplingMethod.I2VINFERENCESCALING: I2VScalingDenoiser(),
SamplingMethod.DISTILLED: DistilledDenoiser(),
}
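The method string set in the new configs resolves to the new denoiser through the enum, roughly as follows (the import path is hypothetical; both names are defined in the sampling module patched in this diff):

# from opensora... import SamplingMethod, SamplingMethodDict  # hypothetical path

method = SamplingMethod("i2v_inference_scaling")  # matches the configs above
denoiser = SamplingMethodDict[method]             # -> I2VScalingDenoiser()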
@@ -672,7 +995,7 @@ def prepare_api(
inp = prepare(model_t5, model_clip, z, prompt=text, patch_size=patch_size)
inp.update(additional_inp)
if opt.method in [SamplingMethod.I2V]:
if opt.method in [SamplingMethod.I2V, SamplingMethod.I2VINFERENCESCALING]:
# prepare references
masks, masked_ref = prepare_inference_condition(
z, cond_type, ref_list=references, causal=opt.is_causal_vae
@@ -693,6 +1016,19 @@ def prepare_api(
), # don't use temporal osci for v2v or t2v
flow_shift=opt.flow_shift,
patch_size=patch_size,
# inference scaling parameters
do_inference_scaling=opt.do_inference_scaling,
num_subtree=opt.num_subtree,
model_ae=model_ae,
height=opt.height,
width=opt.width,
backward_scale=opt.backward_scale,
forward_scale=opt.forward_scale,
scaling_steps=opt.scaling_steps,
num_frames=num_frames,
prompt=text,
vbench_dimension_list=opt.vbench_dimension_list,
vbench_gpus=opt.vbench_gpus
)
x = unpack(x, opt.height, opt.width, num_frames, patch_size=patch_size)