"""Sampling / inference entry point: build text encoder, VAE, diffusion model and scheduler, then generate and save videos."""

import copy
import os
from pprint import pformat

import colossalai
import torch
import torch.distributed as dist
from colossalai.cluster import DistCoordinator
from mmengine.runner import set_random_seed
from tqdm import tqdm

from opensora.acceleration.parallel_states import set_sequence_parallel_group
from opensora.datasets import save_sample
from opensora.datasets.aspect import get_image_size, get_num_frames
from opensora.models.text_encoder.t5 import text_preprocessing
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.inference_utils import (
    append_generated,
    apply_mask_strategy,
    collect_references_batch,
    extract_json_from_prompts,
    extract_prompts_loop,
    get_save_path_name,
    load_prompts,
    prepare_multi_resolution_info,
)
from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype


def _resolve_image_size(cfg):
    """Return ``cfg.image_size`` if set, otherwise derive it from ``cfg.resolution`` and ``cfg.aspect_ratio``."""
    image_size = cfg.get("image_size", None)
    if image_size is not None:
        return image_size
    resolution = cfg.get("resolution", None)
    aspect_ratio = cfg.get("aspect_ratio", None)
    assert (
        resolution is not None and aspect_ratio is not None
    ), "resolution and aspect_ratio must be provided if image_size is not provided"
    return get_image_size(resolution, aspect_ratio)


def _merge_loop_clips(video_clips, idx, condition_frame_length):
    """Concatenate sample ``idx`` of every loop clip along dim 1, dropping the first ``condition_frame_length``
    entries (along dim 1) of every clip after the first.

    Args:
        video_clips: one decoded batch per loop iteration, in generation order.
        idx: index of the sample within each batch.
        condition_frame_length: number of leading dim-1 entries to drop from clips 1..N-1.
            Presumably these are the frames the clip was conditioned on (repeats of the
            previous clip's tail) — TODO confirm against append_generated.

    Returns:
        A single tensor, the per-loop pieces joined with ``torch.cat(..., dim=1)``.
    """
    pieces = [video_clips[0][idx]]
    for clip in video_clips[1:]:
        pieces.append(clip[idx][:, condition_frame_length:])
    return torch.cat(pieces, dim=1)


def main():
    """Run inference as configured by ``parse_configs(training=False)`` and save the generated samples under ``cfg.save_dir``."""
    torch.set_grad_enabled(False)
    # ======================================================
    # configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=False)

    # == device and dtype ==
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Validate and use the SAME value. Previously the check defaulted to "fp32" while the
    # dtype actually used defaulted to "bf16"; "bf16" is kept so unset configs behave as before.
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg_dtype)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # == init distributed env ==
    if is_distributed():
        colossalai.launch_from_torch({})
        coordinator = DistCoordinator()
        # With more than one rank, all ranks cooperate on each sample via sequence parallelism.
        enable_sequence_parallelism = coordinator.world_size > 1
        if enable_sequence_parallelism:
            set_sequence_parallel_group(dist.group.WORLD)
    else:
        coordinator = None
        enable_sequence_parallelism = False
    set_random_seed(seed=cfg.get("seed", 1024))

    # == init logger ==
    logger = create_logger()
    logger.info("Inference configuration:\n %s", pformat(cfg.to_dict()))
    verbose = cfg.get("verbose", 1)
    # verbose == 1: outer progress bar only; verbose >= 2: per-step sampler progress and per-sample logs.
    progress_wrap = tqdm if verbose == 1 else (lambda x: x)

    # ======================================================
    # build model & load weights
    # ======================================================
    logger.info("Building models...")
    # == build text-encoder and vae ==
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()

    # == prepare video size ==
    image_size = _resolve_image_size(cfg)
    num_frames = get_num_frames(cfg.num_frames)

    # == build diffusion model ==
    input_size = (num_frames, *image_size)
    latent_size = vae.get_latent_size(input_size)
    model = (
        build_module(
            cfg.model,
            MODELS,
            input_size=latent_size,
            in_channels=vae.out_channels,
            caption_channels=text_encoder.output_dim,
            model_max_length=text_encoder.model_max_length,
            enable_sequence_parallelism=enable_sequence_parallelism,
        )
        .to(device, dtype)
        .eval()
    )
    text_encoder.y_embedder = model.y_embedder  # HACK: for classifier-free guidance

    # == build scheduler ==
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # ======================================================
    # inference
    # ======================================================
    # == load prompts ==
    prompts = cfg.get("prompt", None)
    start_idx = cfg.get("start_index", 0)
    if prompts is None:
        assert cfg.get("prompt_path", None) is not None, "Prompt or prompt_path must be provided"
        prompts = load_prompts(cfg.prompt_path, start_idx, cfg.get("end_index", None))

    # == prepare reference ==
    reference_path = cfg.get("reference_path", [""] * len(prompts))
    mask_strategy = cfg.get("mask_strategy", [""] * len(prompts))
    assert len(reference_path) == len(prompts), "Length of reference must be the same as prompts"
    assert len(mask_strategy) == len(prompts), "Length of mask_strategy must be the same as prompts"

    # == prepare arguments ==
    fps = cfg.fps
    save_fps = fps // cfg.get("frame_interval", 1)
    multi_resolution = cfg.get("multi_resolution", None)
    batch_size = cfg.get("batch_size", 1)
    num_sample = cfg.get("num_sample", 1)
    loop = cfg.get("loop", 1)
    condition_frame_length = cfg.get("condition_frame_length", 5)
    align = cfg.get("align", None)

    save_dir = cfg.save_dir
    os.makedirs(save_dir, exist_ok=True)
    sample_name = cfg.get("sample_name", None)
    prompt_as_path = cfg.get("prompt_as_path", False)

    # == Iter over all samples ==
    for i in progress_wrap(range(0, len(prompts), batch_size)):
        # == prepare batch prompts ==
        batch_prompts = prompts[i : i + batch_size]
        ms = mask_strategy[i : i + batch_size]
        refs = reference_path[i : i + batch_size]

        batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms)
        refs = collect_references_batch(refs, vae, image_size)

        # == multi-resolution info ==
        model_args = prepare_multi_resolution_info(
            multi_resolution, len(batch_prompts), image_size, num_frames, fps, device, dtype
        )

        # == Iter over number of sampling for one prompt ==
        for k in range(num_sample):
            # == prepare save paths ==
            save_paths = [
                get_save_path_name(
                    save_dir,
                    sample_name=sample_name,
                    sample_idx=start_idx + idx,
                    prompt=batch_prompts[idx],
                    prompt_as_path=prompt_as_path,
                    num_sample=num_sample,
                    k=k,
                )
                for idx in range(len(batch_prompts))
            ]

            # NOTE: Skip if the sample already exists
            # This is useful for resuming sampling VBench
            if prompt_as_path and all_exists(save_paths):
                continue

            # append_generated() returns updated refs/ms that are rebound inside the loop below.
            # Without a per-sample copy, the conditioning appended while generating sample k
            # would still be present when sample k+1 starts. This includes any in-place mutation
            # append_generated performs. append_generated is only called when loop > 1, so the
            # copy is skipped otherwise.
            # NOTE(review): assumes refs holds only deepcopy-able objects (e.g. tensors / lists) — confirm.
            if loop > 1:
                sample_refs, sample_ms = copy.deepcopy(refs), copy.deepcopy(ms)
            else:
                sample_refs, sample_ms = refs, ms

            # == Iter over loop generation ==
            video_clips = []
            for loop_i in range(loop):
                batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i)
                batch_prompts_cleaned = [text_preprocessing(prompt) for prompt in batch_prompts_loop]

                # == loop ==
                # Each loop after the first is conditioned on the previously generated clip.
                if loop_i > 0:
                    sample_refs, sample_ms = append_generated(
                        vae, video_clips[-1], sample_refs, sample_ms, loop_i, condition_frame_length
                    )

                # == sampling ==
                z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)
                masks = apply_mask_strategy(z, sample_refs, sample_ms, loop_i, align=align)
                samples = scheduler.sample(
                    model,
                    text_encoder,
                    z=z,
                    prompts=batch_prompts_cleaned,
                    device=device,
                    additional_args=model_args,
                    progress=verbose >= 2,
                    mask=masks,
                )
                samples = vae.decode(samples.to(dtype), num_frames=num_frames)
                video_clips.append(samples)

            # == save samples ==
            # Only the main process writes files; other ranks (if any) just took part in sampling.
            if is_main_process():
                for idx, batch_prompt in enumerate(batch_prompts):
                    if verbose >= 2:
                        logger.info("Prompt: %s", batch_prompt)
                    video = _merge_loop_clips(video_clips, idx, condition_frame_length)
                    save_sample(
                        video,
                        fps=save_fps,
                        save_path=save_paths[idx],
                        verbose=verbose >= 2,
                    )
        start_idx += len(batch_prompts)
    logger.info("Inference finished.")
    logger.info("Saved %s samples to %s", start_idx, save_dir)


if __name__ == "__main__":
    main()