# Open-Sora/scripts/inference.py
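"""Batched text-to-video sampling entry point for Open-Sora.

Illustrative launch command (a minimal sketch; the config path and flag names
are examples, and parse_configs merges the config file with any command-line
overrides):

    torchrun --standalone --nproc_per_node 1 scripts/inference.py \
        configs/opensora-v1-2/inference/sample.py \
        --prompt "a beautiful waterfall"
"""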

import os
from pprint import pformat

import colossalai
import torch
import torch.distributed as dist
from colossalai.cluster import DistCoordinator
from mmengine.runner import set_random_seed
from tqdm import tqdm

from opensora.acceleration.parallel_states import set_sequence_parallel_group
from opensora.datasets import save_sample
from opensora.datasets.aspect import get_image_size, get_num_frames
from opensora.models.text_encoder.t5 import text_preprocessing
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.inference_utils import (
    append_generated,
    apply_mask_strategy,
    collect_references_batch,
    extract_json_from_prompts,
    extract_prompts_loop,
    get_save_path_name,
    load_prompts,
    prepare_multi_resolution_info,
)
from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype


def main():
    torch.set_grad_enabled(False)

    # ======================================================
    # configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=False)

    # == device and dtype ==
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg_dtype = cfg.get("dtype", "fp32")
    assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg_dtype)
    torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 tensor cores on Ampere+ GPUs
    torch.backends.cudnn.allow_tf32 = True

    # == init distributed env ==
    if is_distributed():
        colossalai.launch_from_torch({})
        coordinator = DistCoordinator()
        # with more than one rank, split each sample's token sequence across GPUs
        enable_sequence_parallelism = coordinator.world_size > 1
        if enable_sequence_parallelism:
            set_sequence_parallel_group(dist.group.WORLD)
    else:
        coordinator = None
        enable_sequence_parallelism = False
    set_random_seed(seed=cfg.get("seed", 1024))

    # == init logger ==
    logger = create_logger()
    logger.info("Inference configuration:\n %s", pformat(cfg.to_dict()))
    verbose = cfg.get("verbose", 1)
    progress_wrap = tqdm if verbose == 1 else (lambda x: x)

    # ======================================================
    # build model & load weights
    # ======================================================
    logger.info("Building models...")

    # == build text-encoder and vae ==
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()

    # == prepare video size ==
    image_size = cfg.get("image_size", None)
    if image_size is None:
        resolution = cfg.get("resolution", None)
        aspect_ratio = cfg.get("aspect_ratio", None)
        assert (
            resolution is not None and aspect_ratio is not None
        ), "resolution and aspect_ratio must be provided if image_size is not provided"
        image_size = get_image_size(resolution, aspect_ratio)
    num_frames = get_num_frames(cfg.num_frames)

    # == build diffusion model ==
    input_size = (num_frames, *image_size)
    latent_size = vae.get_latent_size(input_size)
    model = (
        build_module(
            cfg.model,
            MODELS,
            input_size=latent_size,
            in_channels=vae.out_channels,
            caption_channels=text_encoder.output_dim,
            model_max_length=text_encoder.model_max_length,
            enable_sequence_parallelism=enable_sequence_parallelism,
        )
        .to(device, dtype)
        .eval()
    )
    text_encoder.y_embedder = model.y_embedder  # HACK: for classifier-free guidance
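    # Sharing y_embedder presumably lets the text encoder emit the model's learned
    # null-caption embedding for the unconditional branch of classifier-free
    # guidance, per the "HACK" note above.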

    # == build scheduler ==
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # ======================================================
    # inference
    # ======================================================
    # == load prompts ==
    prompts = cfg.get("prompt", None)
    start_idx = cfg.get("start_index", 0)
    if prompts is None:
        assert cfg.get("prompt_path", None) is not None, "Prompt or prompt_path must be provided"
        prompts = load_prompts(cfg.prompt_path, start_idx, cfg.get("end_index", None))

    # == prepare reference ==
    reference_path = cfg.get("reference_path", [""] * len(prompts))
    mask_strategy = cfg.get("mask_strategy", [""] * len(prompts))
    assert len(reference_path) == len(prompts), "reference_path must have the same length as prompts"
    assert len(mask_strategy) == len(prompts), "mask_strategy must have the same length as prompts"
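
    # References make the run image/video-conditioned: collect_references_batch
    # encodes each reference with the VAE, and apply_mask_strategy later decides
    # which latent frames those encodings overwrite. The exact per-prompt
    # mask-strategy string format is documented upstream; empty strings (the
    # default here) mean plain text-to-video.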

    # == prepare arguments ==
    fps = cfg.fps
    save_fps = fps // cfg.get("frame_interval", 1)
    multi_resolution = cfg.get("multi_resolution", None)
    batch_size = cfg.get("batch_size", 1)
    num_sample = cfg.get("num_sample", 1)
    loop = cfg.get("loop", 1)
    condition_frame_length = cfg.get("condition_frame_length", 5)
    align = cfg.get("align", None)

    save_dir = cfg.save_dir
    os.makedirs(save_dir, exist_ok=True)
    sample_name = cfg.get("sample_name", None)
    prompt_as_path = cfg.get("prompt_as_path", False)

    # == Iter over all samples ==
    for i in progress_wrap(range(0, len(prompts), batch_size)):
        # == prepare batch prompts ==
        batch_prompts = prompts[i : i + batch_size]
        ms = mask_strategy[i : i + batch_size]
        refs = reference_path[i : i + batch_size]
        batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms)
        refs = collect_references_batch(refs, vae, image_size)

        # == multi-resolution info ==
        model_args = prepare_multi_resolution_info(
            multi_resolution, len(batch_prompts), image_size, num_frames, fps, device, dtype
        )
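        # `model_args` carries the target image size, frame count, and fps so a
        # multi-resolution-trained model can condition on them; exactly which keys
        # it holds depends on prepare_multi_resolution_info and the configured
        # multi_resolution mode.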

        # == Iter over number of sampling for one prompt ==
        for k in range(num_sample):
            # == prepare save paths ==
            save_paths = [
                get_save_path_name(
                    save_dir,
                    sample_name=sample_name,
                    sample_idx=start_idx + idx,
                    prompt=batch_prompts[idx],
                    prompt_as_path=prompt_as_path,
                    num_sample=num_sample,
                    k=k,
                )
                for idx in range(len(batch_prompts))
            ]

            # NOTE: skip if the sample already exists; useful when resuming
            # a VBench sampling run
            if prompt_as_path and all_exists(save_paths):
                continue
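
            # Long videos are generated in `loop` segments: each segment after the
            # first is conditioned on the last `condition_frame_length` frames of
            # the previous clip (append_generated), and those overlapping frames
            # are trimmed again when the clips are concatenated for saving.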

            # == Iter over loop generation ==
            video_clips = []
            for loop_i in range(loop):
                batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i)
                batch_prompts_cleaned = [text_preprocessing(prompt) for prompt in batch_prompts_loop]

                # == loop ==
                if loop_i > 0:
                    # condition this segment on the tail of the previous clip
                    refs, ms = append_generated(vae, video_clips[-1], refs, ms, loop_i, condition_frame_length)

                # == sampling ==
                # start from Gaussian noise in latent space; `masks` pins the latent
                # frames that apply_mask_strategy filled in from the references
                z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)
                masks = apply_mask_strategy(z, refs, ms, loop_i, align=align)
                samples = scheduler.sample(
                    model,
                    text_encoder,
                    z=z,
                    prompts=batch_prompts_cleaned,
                    device=device,
                    additional_args=model_args,
                    progress=verbose >= 2,
                    mask=masks,
                )
                samples = vae.decode(samples.to(dtype), num_frames=num_frames)
                video_clips.append(samples)

            # == save samples ==
            if is_main_process():
                for idx, batch_prompt in enumerate(batch_prompts):
                    if verbose >= 2:
                        logger.info("Prompt: %s", batch_prompt)
                    save_path = save_paths[idx]

                    # stitch the loop clips, dropping the conditioning frames that
                    # each later clip repeats from its predecessor
                    video = [video_clips[clip_i][idx] for clip_i in range(loop)]
                    for clip_i in range(1, loop):
                        video[clip_i] = video[clip_i][:, condition_frame_length:]
                    video = torch.cat(video, dim=1)
                    save_sample(
                        video,
                        fps=save_fps,
                        save_path=save_path,
                        verbose=verbose >= 2,
                    )
        start_idx += len(batch_prompts)
logger.info("Inference finished.")
logger.info("Saved %s samples to %s", start_idx, save_dir)
if __name__ == "__main__":
main()