Open-Sora/scripts/inference.py
2024-05-12 05:15:06 +00:00

191 lines
7.5 KiB
Python

import os
import colossalai
import torch
import torch.distributed as dist
from colossalai.cluster import DistCoordinator
from mmengine.runner import set_random_seed
from tqdm import tqdm
from opensora.acceleration.parallel_states import set_sequence_parallel_group
from opensora.datasets import IMG_FPS, save_sample
from opensora.models.text_encoder.t5 import text_preprocessing
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import create_logger, is_distributed, is_main_process, to_torch_dtype
def _multi_resolution_args(cfg, device, dtype):
    """Build the extra per-sample conditioning kwargs for multi-resolution models.

    Args:
        cfg: parsed inference config; reads ``multi_resolution``, ``image_size``,
            ``batch_size``, ``num_frames`` and (for STDiT2) ``fps``.
        device: device the tensors are created on.
        dtype: torch dtype of the created tensors.

    Returns:
        dict: ``{"data_info": {"ar": ..., "hw": ...}}`` for ``"PixArtMS"``;
        ``{"height", "width", "num_frames", "ar", "fps"}`` tensors for ``"STDiT2"``;
        an empty dict for any other value. Every tensor has ``cfg.batch_size`` rows.

    Side effect: for ``"STDiT2"`` with ``cfg.num_frames == 1`` this overwrites
    ``cfg.fps`` with ``IMG_FPS``; ``cfg.fps`` is read again when saving samples.
    """
    model_args = dict()
    batch_size = cfg.batch_size
    if cfg.multi_resolution == "PixArtMS":
        image_size = cfg.image_size
        hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(batch_size, 1)
        ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(batch_size, 1)
        model_args["data_info"] = dict(ar=ar, hw=hw)
    elif cfg.multi_resolution == "STDiT2":
        image_size = cfg.image_size
        if cfg.num_frames == 1:
            cfg.fps = IMG_FPS

        def per_sample(value):
            # One scalar condition value repeated for every sample in the batch.
            return torch.tensor([value], device=device, dtype=dtype).repeat(batch_size)

        model_args["height"] = per_sample(image_size[0])
        model_args["width"] = per_sample(image_size[1])
        model_args["num_frames"] = per_sample(cfg.num_frames)
        model_args["ar"] = per_sample(image_size[0] / image_size[1])
        model_args["fps"] = per_sample(cfg.fps)
    return model_args


def _trim_model_args(model_args, batch_len):
    """Return a copy of ``model_args`` with every tensor cut to its first ``batch_len`` rows.

    Nested dicts (e.g. PixArtMS ``data_info``) are trimmed recursively; non-tensor
    values are passed through unchanged. The input dict is not mutated.
    """
    trimmed = {}
    for key, value in model_args.items():
        if isinstance(value, dict):
            trimmed[key] = _trim_model_args(value, batch_len)
        elif isinstance(value, torch.Tensor):
            trimmed[key] = value[:batch_len]
        else:
            trimmed[key] = value
    return trimmed


def _sample_save_path(save_dir, sample_name, suffix, sample_k, num_sample):
    """Return the output path (without file extension) for one generated sample.

    A ``-{sample_k}`` tag is appended only when several samples are drawn per prompt.
    """
    path = os.path.join(save_dir, f"{sample_name}{suffix}")
    if num_sample != 1:
        path = f"{path}-{sample_k}"
    return path


def main():
    """Generate samples for every prompt in the config and save them to disk.

    Pipeline:

    1. Parse the inference config.
    2. Optionally initialise the distributed runtime. Sequence parallelism is
       enabled when more than one process is launched.
    3. Build the VAE, text encoder, diffusion model and scheduler from the registries.
    4. Draw ``cfg.num_sample`` outputs for each batch of ``cfg.batch_size`` prompts.

    Only the main process writes files, into ``cfg.save_dir``.

    Raises:
        ValueError: if the configured dtype is not one of "fp16", "bf16", "fp32".
    """
    torch.set_grad_enabled(False)

    # ======================================================
    # 1. configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=False)

    # == device and dtype ==
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg_dtype = cfg.get("dtype", "fp32")
    if cfg_dtype not in ("fp16", "bf16", "fp32"):
        raise ValueError(f"Unknown mixed precision {cfg_dtype}")
    # Use the validated value (including its fp32 default) instead of re-reading
    # cfg.dtype, which is absent when the config omits it.
    dtype = to_torch_dtype(cfg_dtype)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # == init distributed env ==
    if is_distributed():
        colossalai.launch_from_torch({})
        coordinator = DistCoordinator()
        enable_sequence_parallelism = coordinator.world_size > 1
        if enable_sequence_parallelism:
            set_sequence_parallel_group(dist.group.WORLD)
    else:
        enable_sequence_parallelism = False
    set_random_seed(seed=cfg.seed)

    # == init logger ==
    create_logger()
    verbose = cfg.get("verbose", 1)
    # Show the resolved config once, not once per rank, and stay quiet at verbose=0.
    if verbose >= 1 and is_main_process():
        print(cfg)

    # ======================================================
    # 2. runtime variables
    # ======================================================
    prompts = cfg.prompt

    # ======================================================
    # 3. build model & load weights
    # ======================================================
    # 3.1. build model
    input_size = (cfg.num_frames, *cfg.image_size)
    vae = build_module(cfg.vae, MODELS)
    latent_size = vae.get_latent_size(input_size)
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)  # T5 must be fp32
    model = build_module(
        cfg.model,
        MODELS,
        input_size=latent_size,
        in_channels=vae.out_channels,
        caption_channels=text_encoder.output_dim,
        model_max_length=text_encoder.model_max_length,
        enable_sequence_parallelism=enable_sequence_parallelism,
    )
    text_encoder.y_embedder = model.y_embedder  # hack for classifier-free guidance

    # 3.2. move to device & eval
    vae = vae.to(device, dtype).eval()
    model = model.to(device, dtype).eval()

    # 3.3. build scheduler
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # 3.4. support for multi-resolution
    model_args = _multi_resolution_args(cfg, device, dtype)

    # ======================================================
    # 4. inference
    # ======================================================
    sample_idx = cfg.get("start_index", 0)
    if cfg.sample_name is not None:
        sample_name = cfg.sample_name
    elif cfg.prompt_as_path:
        sample_name = ""
    else:
        sample_name = "sample"
    save_dir = cfg.save_dir
    os.makedirs(save_dir, exist_ok=True)

    # 4.1. batch generation
    progress_wrap = tqdm if verbose == 1 else (lambda x: x)
    for i in progress_wrap(range(0, len(prompts), cfg.batch_size)):
        # 4.2 sample in hidden space
        batch_prompts_raw = prompts[i : i + cfg.batch_size]
        batch_prompts = [text_preprocessing(prompt) for prompt in batch_prompts_raw]
        # The final batch may be shorter than cfg.batch_size. Trim every per-sample
        # conditioning tensor (all built with cfg.batch_size rows) to match.
        batch_model_args = _trim_model_args(model_args, len(batch_prompts_raw))

        # 4.3. diffusion sampling; generate multiple samples for each prompt
        for k in range(cfg.num_sample):
            # Skip if every sample of this batch already exists; this is useful
            # for resuming VBench sampling.
            # NOTE(review): only ".mp4" files are checked; confirm the extension
            # save_sample uses for single-frame (image) outputs.
            if cfg.prompt_as_path and all(
                os.path.exists(f"{_sample_save_path(save_dir, sample_name, prompt, k, cfg.num_sample)}.mp4")
                for prompt in batch_prompts_raw
            ):
                continue

            # sampling
            z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)
            samples = scheduler.sample(
                model,
                text_encoder,
                z=z,
                prompts=batch_prompts,
                device=device,
                additional_args=batch_model_args,
                progress=verbose >= 2,
            )
            samples = vae.decode(samples.to(dtype), num_frames=cfg.num_frames)

            # 4.4. save samples
            if is_main_process():
                for idx, sample in enumerate(samples):
                    if verbose >= 2:
                        print(f"Prompt: {batch_prompts_raw[idx]}")
                    if cfg.prompt_as_path:
                        suffix = batch_prompts_raw[idx]
                    else:
                        suffix = f"_{sample_idx + idx}"
                    save_sample(
                        sample,
                        fps=cfg.fps // cfg.frame_interval,
                        save_path=_sample_save_path(save_dir, sample_name, suffix, k, cfg.num_sample),
                        verbose=verbose >= 2,
                    )
        # Each prompt in the batch consumes one index; all k-samples share it.
        sample_idx += len(batch_prompts_raw)


if __name__ == "__main__":
    main()