Open-Sora/sample.py
Hongxin Liu e02f6286de
[hotfix] fix sample and update training script (#15)
* [hotfix] fix sample

* [hotfix] fix sample
2024-02-28 15:01:53 +08:00

144 lines
4.6 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Sample new images from a pre-trained DiT.
"""
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import argparse
from colossalai.utils import get_current_device
from torchvision.io import write_video
from transformers import AutoModel, AutoTokenizer, CLIPTextModel
from open_sora.diffusion import create_diffusion
from open_sora.modeling import DiT_models
from open_sora.utils.data import col2video, unnormalize_video
def main(args):
# Setup PyTorch:
torch.manual_seed(args.seed)
torch.set_grad_enabled(False)
device = get_current_device()
if len(args.vqvae) > 0:
vqvae = (
AutoModel.from_pretrained(args.vqvae, trust_remote_code=True)
.to(device)
.eval()
)
in_channels = vqvae.embedding_dim
w_h_factor = 4
t_factor = 2
else:
# disable VQ-VAE if not provided, just use raw video frames
vqvae = None
in_channels = 3
w_h_factor = 1
t_factor = 1
text_model = CLIPTextModel.from_pretrained(args.text_model).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(args.text_model)
model = DiT_models[args.model](in_channels=in_channels).to(device).eval()
patch_size = model.patch_size
model.load_state_dict(torch.load(args.ckpt))
diffusion = create_diffusion(str(args.num_sampling_steps))
# Create sampling noise:
text_inputs = tokenizer(args.text, return_tensors="pt")
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
text_latent_states = text_model(**text_inputs).last_hidden_state
num_frames = args.fps * args.sec
z = torch.randn(
1,
(args.height // patch_size // w_h_factor)
* (args.width // patch_size // w_h_factor)
* (num_frames // t_factor),
in_channels,
patch_size,
patch_size,
device=device,
)
# Setup classifier-free guidance:
model_kwargs = {}
z = torch.cat([z, z], 0)
model_kwargs["text_latent_states"] = torch.cat(
[text_latent_states, torch.zeros_like(text_latent_states)], 0
)
model_kwargs["cfg_scale"] = args.cfg_scale
model_kwargs["attention_mask"] = torch.ones(
2, 1, z.shape[1], text_latent_states.shape[1], device=device, dtype=torch.int
)
# Sample images:
samples = diffusion.p_sample_loop(
model.forward_with_cfg,
z.shape,
z,
clip_denoised=False,
model_kwargs=model_kwargs,
progress=True,
device=device,
)
samples, _ = samples.chunk(2, dim=0) # Remove null class samples
samples = col2video(
samples.squeeze(),
(
num_frames // t_factor,
in_channels,
args.height // w_h_factor,
args.width // w_h_factor,
),
)
if vqvae is not None:
# [T, C, H, W] -> [B, C, T, H, W]
samples = samples.permute(1, 0, 2, 3).unsqueeze(0)
samples = vqvae.decode_from_embeddings(samples)
# [B, C, T, H, W] -> [T, H, W, C]
samples = samples.squeeze(0).permute(1, 2, 3, 0)
else:
# [T, C, H, W] -> [T, H, W, C]
samples = samples.permute(0, 2, 3, 1)
samples = unnormalize_video(samples).to(torch.uint8)
write_video("sample.mp4", samples.cpu(), args.fps)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8"
)
parser.add_argument(
"--text",
type=str,
default="a cartoon animals runs through an ice cave in a video game",
)
parser.add_argument("--cfg-scale", type=float, default=4.0)
parser.add_argument("--num-sampling-steps", type=int, default=250)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument(
"--ckpt",
type=str,
required=True,
help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).",
)
parser.add_argument("--vqvae", default="hpcai-tech/vqvae")
parser.add_argument(
"--text_model", type=str, default="openai/clip-vit-base-patch32"
)
parser.add_argument("--width", type=int, default=480)
parser.add_argument("--height", type=int, default=320)
parser.add_argument("--fps", type=int, default=15)
parser.add_argument("--sec", type=int, default=8)
args = parser.parse_args()
main(args)