Mirror of https://github.com/hpcaitech/Open-Sora.git, synced 2026-04-14 18:25:35 +02:00
inference scaling init
commit bb64366a85 (parent 4c3c1740a3)
.gitignore (vendored): 3 additions

@@ -196,3 +196,6 @@ exps
 ckpts
 flash-attention
 datasets
+
+# inference scaling
+temp/
assets/VBench_full_info.json: 9132 additions (new file)
File diff suppressed because it is too large.
configs/diffusion/inference/256px_inference_scaling.py: 28 additions (new file)

@@ -0,0 +1,28 @@
+_base_ = [  # inherit config grammar from mmengine
+    "256px.py",
+    "plugins/t2i2v.py",
+    "plugins/tp.py",  # use tensor parallelism
+]
+sampling_option = dict(
+    resolution="256px",  # 256px or 768px
+    aspect_ratio="16:9",  # 9:16, 16:9, or 1:1
+    num_frames=129,  # number of frames
+    num_steps=50,  # number of denoising steps
+    shift=True,
+    temporal_reduction=4,
+    is_causal_vae=True,
+    guidance=7.5,  # guidance scale for text-to-video
+    guidance_img=3.0,  # guidance scale for image-to-video
+    text_osci=True,  # enable text guidance oscillation
+    image_osci=True,  # enable image guidance oscillation
+    scale_temporal_osci=True,
+    method="i2v_inference_scaling",  # hard-coded for now
+    vbench_dimension_list=["subject_consistency"],
+    do_inference_scaling=True,
+    num_subtree=3,
+    backward_scale=0.78,
+    forward_scale=0.83,
+    scaling_steps=[1, 2, 4, 7, 9, 15, 20],
+    seed=None,  # random seed for z
+    vbench_gpus=[4, 5, 6, 7],
+)
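As a quick orientation (not part of this commit), the merged config can be inspected with mmengine, which resolves the _base_ inheritance chain; the path is assumed to be relative to the repo root:

    # Minimal sketch: load the new config and read the inference-scaling knobs.
    from mmengine.config import Config

    cfg = Config.fromfile("configs/diffusion/inference/256px_inference_scaling.py")
    print(cfg.sampling_option["num_subtree"])    # 3
    print(cfg.sampling_option["scaling_steps"])  # [1, 2, 4, 7, 9, 15, 20]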
configs/diffusion/inference/768px_inference_scaling.py: 17 additions (new file)

@@ -0,0 +1,17 @@
+_base_ = [  # inherit config grammar from mmengine
+    "256px.py",
+    "plugins/sp.py",  # use sequence parallelism
+    "plugins/t2i2v.py",
+]
+sampling_option = dict(
+    resolution="768px",  # 256px or 768px
+    method="i2v_inference_scaling",  # hard-coded for now
+    vbench_dimension_list=["subject_consistency"],
+    do_inference_scaling=True,
+    num_subtree=3,
+    backward_scale=0.78,
+    forward_scale=0.83,
+    scaling_steps=[1, 2, 4, 7, 9, 15, 20],
+    seed=None,  # random seed for z
+    vbench_gpus=[4, 5, 6, 7],
+)
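Both configs ship backward_scale=0.78 and forward_scale=0.83, which control the re-noise/denoise jumps of the subtree search added later in this diff. A numeric sketch of that arithmetic (timestep values chosen for illustration only):

    # forward_scale re-noises a candidate part-way back toward the previous
    # timestep; backward_scale then denoises a slightly shorter distance, so
    # each candidate lands just above the current timestep.
    t_prev_step, t_curr = 0.96, 0.92                  # timesteps[i - 1], timesteps[i]
    forward_scale, backward_scale = 0.83, 0.78
    dt = t_prev_step - t_curr                         # 0.04
    t_subtree = t_curr + forward_scale * dt           # 0.9532
    t_subtree_prev = t_subtree - backward_scale * dt  # 0.9220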
@@ -16,6 +16,7 @@ from opensora.utils.prompt_refine import refine_prompts
 class SamplingMethod(Enum):
     I2V = "i2v"  # for open sora video generation
     DISTILLED = "distill"  # for flux image generation
+    I2VINFERENCESCALING = "i2v_inference_scaling"  # for open sora video generation with inference scaling


 def create_tmp_csv(save_dir: str, prompt: str, ref: str = None, create=True) -> str:
@@ -3,12 +3,17 @@ import os
 import random
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, replace
+import json
+import subprocess
+from collections import defaultdict

 import torch
+import torchvision
 from einops import rearrange, repeat
 from mmengine.config import Config
 from peft import PeftModel
 from torch import Tensor, nn
+import torch.distributed as dist

 from opensora.datasets.aspect import get_image_size
 from opensora.models.mmdit.model import MMDiTModel
@@ -78,6 +83,85 @@ class SamplingOption:
     # flow shift
     flow_shift: float | None = None

+    # whether to run the inference-time scaling (subtree search) loop
+    do_inference_scaling: bool | None = False
+
+    # number of candidate subtrees sampled at each scaling step
+    num_subtree: int = 3
+
+    # fraction of the re-noised step that is denoised back per candidate
+    backward_scale: float = 1.0
+
+    # fraction of the previous step re-noised into each candidate
+    forward_scale: float = 1.0
+
+    # denoising step indices at which the subtree search is applied
+    scaling_steps: list = None
+
+    # GPUs handed to the VBench evaluation subprocess
+    vbench_gpus: list = None
+
+    # VBench dimensions used to score candidate videos
+    vbench_dimension_list: list = None
+
+
+def find_highest_score_video(data):
+    video_scores = defaultdict(list)
+    normalization_rules = {
+        "subject_consistency": lambda e: e["video_results"],
+        "background_consistency": lambda e: e["video_results"],
+        "temporal_flickering": lambda e: e["video_results"],
+        "motion_smoothness": lambda e: e["video_results"],
+
+        "dynamic_degree": lambda e: 1.0 if e["video_results"] else 0.0,
+
+        "aesthetic_quality": lambda e: e["video_results"],
+        "imaging_quality": lambda e: e["video_results"] / 100,
+
+        "human_action": lambda e: e["cor_num_per_video"],
+
+        "temporal_style": lambda e: e["video_results"],
+        "overall_consistency": lambda e: e["video_results"]
+    }
+
+    for metric_name, metric_data in data.items():
+        if not isinstance(metric_data, list) or len(metric_data) < 2:
+            continue
+
+        process_rule = normalization_rules.get(metric_name)
+        if not process_rule:
+            continue
+
+        for entry in metric_data[1]:
+            try:
+                path_parts = entry["video_path"].split("/")
+                filename = path_parts[-1]
+                video_index = int(filename.split(".")[0])
+
+                score = process_rule(entry)
+                video_scores[video_index].append(score)
+
+            except (KeyError, ValueError, IndexError):
+                continue
+
+    avg_scores = {}
+    for vid, scores in video_scores.items():
+        if len(scores) == 0:
+            avg_scores[vid] = 0.0
+            continue
+
+        avg_scores[vid] = sum(scores) / len(scores)
+
+    if not avg_scores:
+        return -1
+
+    max_score = max(avg_scores.values())
+    candidates = sorted(
+        [vid for vid, score in avg_scores.items() if score == max_score]
+    )
+
+    return candidates[0] if candidates else -1
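For reference, a minimal call sketch. The shape of the VBench results JSON is inferred from the parsing logic above (each dimension maps to [aggregate_score, per_video_entries]) and is an assumption, not something this diff confirms:

    # Hypothetical VBench output for three subtree candidates 0.mp4/1.mp4/2.mp4.
    data = {
        "subject_consistency": [
            0.92,
            [
                {"video_path": "temp/demo/3_subtree/0.mp4", "video_results": 0.93},
                {"video_path": "temp/demo/3_subtree/1.mp4", "video_results": 0.89},
                {"video_path": "temp/demo/3_subtree/2.mp4", "video_results": 0.95},
            ],
        ],
    }
    assert find_highest_score_video(data) == 2  # candidate 2 has the best average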


 def sanitize_sampling_option(sampling_option: SamplingOption) -> SamplingOption:
     """
@@ -245,6 +329,244 @@ class I2VDenoiser(Denoiser):
         return text, ret


+class I2VScalingDenoiser(Denoiser):
+    def denoise(self, model: MMDiTModel, **kwargs) -> Tensor:
+        img = kwargs.pop("img")
+        timesteps = kwargs.pop("timesteps")
+        guidance = kwargs.pop("guidance")
+        guidance_img = kwargs.pop("guidance_img")
+        vbench = kwargs.pop("vbench", None)
+        do_inference_scaling = kwargs.pop("do_inference_scaling", True)
+        num_subtree = kwargs.pop("num_subtree", 3)
+        model_ae = kwargs.pop("model_ae", None)
+        height = kwargs.pop("height", None)
+        width = kwargs.pop("width", None)
+        num_frames = kwargs.pop("num_frames", None)
+        patch_size = kwargs.pop("patch_size", None)
+        prompt = kwargs.pop("prompt", None)
+        backward_scale = kwargs.pop("backward_scale", 1)
+        forward_scale = kwargs.pop("forward_scale", 1)
+        scaling_steps = kwargs.pop("scaling_steps", range(len(timesteps))[1:])
+        vbench_dimension_list = kwargs.pop("vbench_dimension_list", ["overall_consistency"])
+        vbench_gpus = kwargs.pop("vbench_gpus", [4, 5, 6, 7])
+
+        vbench_gpus = vbench_gpus[:num_subtree]
+
+        # cond ref arguments
+        masks = kwargs.pop("masks")
+        masked_ref = kwargs.pop("masked_ref")
+        kwargs.pop("sigma_min")
+
+        # oscillation guidance
+        text_osci = kwargs.pop("text_osci", False)
+        image_osci = kwargs.pop("image_osci", False)
+        scale_temporal_osci = kwargs.pop("scale_temporal_osci", False)
+
+        prompt = [prompt[0]]
+        prompt_short = prompt[0][:30].replace(" ", "_")
+        save_dir = f"temp/{prompt_short}"
+        os.makedirs(save_dir, exist_ok=True)
+        def save_video(tensor, filename):
+            # tensor shape: [B, C, T, H, W]
+            # normalize from [-1, 1] to [0, 255]
+            video = ((tensor + 1) * 127.5).clamp(0, 255).to(torch.uint8)
+
+            # ensure the output directory exists
+            os.makedirs(os.path.dirname(filename), exist_ok=True)
+            if video.shape[2] == 1:  # single frame: save as an image
+                # save the only frame for evaluation
+                frame = video[0, :, 0]  # [C, H, W]
+                image_path = filename.replace(".mp4", ".png")
+                torchvision.io.write_png(frame.cpu(), image_path)
+                return image_path
+            else:
+                # save the full video, dropping the first 20 frames
+                video = video[0].permute(1, 2, 3, 0)[20:].contiguous()  # [T, H, W, C]
+                torchvision.io.write_video(filename, video.cpu(), fps=8)
+                return filename
+
+        guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+        for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
+            # timesteps
+            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+            b, c, t, w, h = masked_ref.size()
+            cond = torch.cat((masks, masked_ref), dim=1)
+            cond = pack(cond)
+            kwargs["cond"] = torch.cat([cond, cond, torch.zeros_like(cond)], dim=0)
+
+            # forward preparation
+            cond_x = img[: len(img) // 3]
+
+            noise_shape = cond_x.shape
+            img = torch.cat([cond_x, cond_x, cond_x], dim=0)
+
+            if i not in scaling_steps or not do_inference_scaling:
+                # forward
+                pred = model(
+                    img=img,
+                    **kwargs,
+                    timesteps=t_vec,
+                    guidance=guidance_vec,
+                )
+
+                # prepare guidance
+                text_gs = get_oscillation_gs(guidance, i) if text_osci else guidance
+                image_gs = get_oscillation_gs(guidance_img, i) if image_osci else guidance_img
+                cond, uncond, uncond_2 = pred.chunk(3, dim=0)
+                if image_gs > 1.0 and scale_temporal_osci:
+                    # image_gs decreases with each denoising step
+                    step_upper_image_gs = torch.linspace(image_gs, 1.0, len(timesteps))[i]
+                    # image_gs increases along the temporal axis of the latent video
+                    image_gs = torch.linspace(1.0, step_upper_image_gs, t)[None, None, :, None, None].repeat(b, c, 1, h, w)
+                    image_gs = pack(image_gs).to(cond.device, cond.dtype)
+
+                # update
+                pred = uncond_2 + image_gs * (uncond - uncond_2) + text_gs * (cond - uncond)
+                pred = torch.cat([pred, pred, pred], dim=0)
+
+                img = img + (t_prev - t_curr) * pred
+            else:
+                subtrees = []
+                subtree_onesteps = []
+
+                # generate multiple subtree candidates
+                for subtree_idx in range(num_subtree):
+                    noise = torch.randn(noise_shape, device=cond_x.device, dtype=cond_x.dtype)
+                    zeros = torch.zeros(noise_shape, device=cond_x.device, dtype=cond_x.dtype)
+                    noise = torch.cat([noise, zeros, zeros], dim=0)
+                    subtree = img - (t_curr - timesteps[i - 1]) * forward_scale * noise
+                    t_subtree = t_curr - (t_curr - timesteps[i - 1]) * forward_scale
+                    t_subtree_prev = t_subtree + (t_curr - timesteps[i - 1]) * backward_scale
+                    t_subtree_vec = torch.full((img.shape[0],), t_subtree, dtype=cond.dtype, device=cond.device)
+                    subtree_noise_pred = model(
+                        img=subtree,
+                        **kwargs,
+                        timesteps=t_subtree_vec,
+                        guidance=guidance_vec,
+                    )
+                    # prepare guidance
+                    text_gs = get_oscillation_gs(guidance, i) if text_osci else guidance
+                    image_gs = get_oscillation_gs(guidance_img, i) if image_osci else guidance_img
+                    cond, uncond, uncond_2 = subtree_noise_pred.chunk(3, dim=0)
+                    if image_gs > 1.0 and scale_temporal_osci:
+                        # image_gs decreases with each denoising step
+                        step_upper_image_gs = torch.linspace(image_gs, 1.0, len(timesteps))[i]
+                        # image_gs increases along the temporal axis of the latent video
+                        image_gs = torch.linspace(1.0, step_upper_image_gs, t)[None, None, :, None, None].repeat(b, c, 1, h, w)
+                        image_gs = pack(image_gs).to(cond.device, cond.dtype)
+
+                    # update
+                    subtree_noise_pred = uncond_2 + image_gs * (uncond - uncond_2) + text_gs * (cond - uncond)
+                    subtree_noise_pred = torch.cat([subtree_noise_pred, subtree_noise_pred, subtree_noise_pred], dim=0)
+
+                    subtree = subtree + (t_subtree_prev - t_subtree) * subtree_noise_pred
+
+                    # next timestep prediction
+                    t_subtree_prev_vec = torch.full((img.shape[0],), t_subtree_prev, dtype=cond.dtype, device=cond.device)
+                    subtree_onestep_noise_pred = model(
+                        img=subtree,
+                        **kwargs,
+                        timesteps=t_subtree_prev_vec,
+                        guidance=guidance_vec,
+                    )
+                    # prepare guidance for the one-step prediction
+                    text_gs = get_oscillation_gs(guidance, i) if text_osci else guidance
+                    image_gs = get_oscillation_gs(guidance_img, i) if image_osci else guidance_img
+                    cond, uncond, uncond_2 = subtree_onestep_noise_pred.chunk(3, dim=0)
+                    if image_gs > 1.0 and scale_temporal_osci:
+                        # image_gs decreases with each denoising step
+                        step_upper_image_gs = torch.linspace(image_gs, 1.0, len(timesteps))[i]
+                        # image_gs increases along the temporal axis of the latent video
+                        image_gs = torch.linspace(1.0, step_upper_image_gs, t)[None, None, :, None, None].repeat(b, c, 1, h, w)
+                        image_gs = pack(image_gs).to(cond.device, cond.dtype)
+
+                    # update
+                    subtree_onestep_noise_pred = uncond_2 + image_gs * (uncond - uncond_2) + text_gs * (cond - uncond)
+                    subtree_onestep_noise_pred = torch.cat([subtree_onestep_noise_pred, subtree_onestep_noise_pred, subtree_onestep_noise_pred], dim=0)
+
+                    # one-step preview of the final latent for this candidate
+                    subtree_onestep = subtree + (timesteps[-1] - t_subtree_prev) * subtree_onestep_noise_pred
+                    # save subtree results
+                    if model_ae is not None:
+                        unpacked_subtree_onestep = unpack(subtree_onestep, height, width, num_frames, patch_size)[: len(img) // 3]
+                        decoded_onestep = model_ae.decode(unpacked_subtree_onestep)
+                        # save_video returns the path to the written video (or frame image)
+                        os.makedirs(f"{save_dir}/{i}_subtree", exist_ok=True)
+                        decoded_onestep_path = save_video(decoded_onestep, f"{save_dir}/{i}_subtree/{subtree_idx}.mp4")
+
+                    # store results
+                    subtrees.append(subtree)
+                    subtree_onesteps.append(decoded_onestep_path)
+
+                dist.barrier()
+                if dist.get_rank() == 0:
+                    videos_path = f"{save_dir}/{i}_subtree"
+                    output_path = f"{save_dir}/{i}_subtree"
+
+                    prompt_file = "temp/prompt.json"
+                    with open(prompt_file, "w") as fp:
+                        prompt_json = {f"{index}.mp4": prompt[0] for index in range(num_subtree)}
+                        json.dump(prompt_json, fp)
+
+                    minimal_env = {
+                        "PATH": "/usr/local/bin:/usr/bin:/bin:/mnt/jfs-hdd/home/huangshijie/opensora_vbench/bin",
+                        "CUDA_VISIBLE_DEVICES": ",".join([str(item) for item in vbench_gpus]),
+                    }
+                    cmd_args = [
+                        "vbench",
+                        "evaluate",
+                        "--dimension",
+                        ",".join(vbench_dimension_list),
+                        "--videos_path",
+                        videos_path,
+                        "--mode",
+                        "custom_input",
+                        "--prompt_file",
+                        prompt_file,
+                        "--output_path",
+                        output_path,
+                        "--ngpus",
+                        str(len(vbench_gpus)),
+                    ]
+                    subprocess.run(cmd_args, check=True, env=minimal_env)
+
+                dist.barrier()
+                # find the latest evaluation results file
+                subtree_dir = f"{save_dir}/{i}_subtree"
+                eval_files = [f for f in os.listdir(subtree_dir) if f.startswith("results_") and f.endswith("_eval_results.json")]
+                if not eval_files:
+                    raise FileNotFoundError(f"No evaluation results found in {subtree_dir}")
+                # sort by the timestamp in the filename to get the latest one
+                latest_eval_file = sorted(eval_files)[-1]
+                eval_results_path = os.path.join(subtree_dir, latest_eval_file)
+
+                with open(eval_results_path, "r") as f:
+                    eval_results = json.load(f)
+                best_idx = find_highest_score_video(eval_results)
+                img = subtrees[best_idx]  # a best_idx of -1 falls back to the last candidate
+
+        img = img[: len(img) // 3]
+
+        return img
+
+    def prepare_guidance(
+        self, text: list[str], optional_models: dict[str, nn.Module], device: torch.device, dtype: torch.dtype, **kwargs
+    ) -> tuple[list[str], dict[str, Tensor]]:
+        ret = {}
+
+        neg = kwargs.get("neg", None)
+        ret["guidance_img"] = kwargs.pop("guidance_img")
+
+        # text
+        if neg is None:
+            neg = [""] * len(text)
+        text = text + neg + neg
+        return text, ret
+
+
 class DistilledDenoiser(Denoiser):
     def denoise(self, model: MMDiTModel, **kwargs) -> Tensor:
         img = kwargs.pop("img")
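A standalone sketch of the temporal guidance ramp used in denoise() above, with shapes simplified to one dimension (the real code repeats the ramp over [B, C, T, H, W] and packs it):

    import torch

    num_steps, t_frames, i = 50, 33, 10
    image_gs = 3.0
    # the per-step upper bound decays toward 1.0 as denoising progresses
    step_upper = torch.linspace(image_gs, 1.0, num_steps)[i]
    # the per-frame scale grows from 1.0 to that bound along the temporal axis
    ramp = torch.linspace(1.0, float(step_upper), t_frames)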
@@ -283,6 +605,7 @@ class DistilledDenoiser(Denoiser):

 SamplingMethodDict = {
     SamplingMethod.I2V: I2VDenoiser(),
+    SamplingMethod.I2VINFERENCESCALING: I2VScalingDenoiser(),
     SamplingMethod.DISTILLED: DistilledDenoiser(),
 }
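This dict is how a sampling run picks its denoiser; a dispatch sketch using only names defined in this diff:

    # Resolve the config's method string to the new denoiser.
    method = SamplingMethod("i2v_inference_scaling")
    denoiser = SamplingMethodDict[method]  # the I2VScalingDenoiser instance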
@@ -672,7 +995,7 @@ def prepare_api(
     inp = prepare(model_t5, model_clip, z, prompt=text, patch_size=patch_size)
     inp.update(additional_inp)

-    if opt.method in [SamplingMethod.I2V]:
+    if opt.method in [SamplingMethod.I2V, SamplingMethod.I2VINFERENCESCALING]:
         # prepare references
         masks, masked_ref = prepare_inference_condition(
             z, cond_type, ref_list=references, causal=opt.is_causal_vae

@@ -693,6 +1016,19 @@ def prepare_api(
         ),  # don't use temporal osci for v2v or t2v
         flow_shift=opt.flow_shift,
         patch_size=patch_size,
+        # inference scaling parameters
+        do_inference_scaling=opt.do_inference_scaling,
+        num_subtree=opt.num_subtree,
+        model_ae=model_ae,
+        height=opt.height,
+        width=opt.width,
+        backward_scale=opt.backward_scale,
+        forward_scale=opt.forward_scale,
+        scaling_steps=opt.scaling_steps,
+        num_frames=num_frames,
+        prompt=text,
+        vbench_dimension_list=opt.vbench_dimension_list,
+        vbench_gpus=opt.vbench_gpus
     )

     x = unpack(x, opt.height, opt.width, num_frames, patch_size=patch_size)
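A back-of-the-envelope cost estimate (assumptions: the denoise loop runs once per timestep pair, and each model call counts as one forward), using the shipped config values:

    num_steps = 50
    scaling_steps = [1, 2, 4, 7, 9, 15, 20]
    num_subtree = 3
    plain = num_steps - len(scaling_steps)         # 43 ordinary steps, 1 forward each
    scaled = len(scaling_steps) * num_subtree * 2  # 42 forwards: subtree + preview
    print(plain + scaled)                          # 85 forwards vs. 50 without scaling

Each scaling step also shells out to a VBench evaluation, so wall-clock overhead is larger than the forward count alone suggests.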