# Open-Sora/scripts/sample/latte/sample_t2v.py

# All rights reserved.
# Copyright 2024 Vchitect/Latte
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# modified from https://github.com/Vchitect/Latte/blob/main/sample/sample_t2v.py
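"""Sample text-to-video generations with a pretrained Latte T2V model.

Loads the LatteT2V transformer, a (temporal) VAE, and a T5 text encoder,
builds the diffusers scheduler selected on the command line, and writes one
MP4 per prompt plus a combined grid video to the output directory.
"""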
import argparse
import os
import sys

import torch
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from diffusers.schedulers import (
DDIMScheduler,
DDPMScheduler,
DEISMultistepScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
PNDMScheduler,
)
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from transformers import T5EncoderModel, T5Tokenizer

# make the sibling Latte modules (pipeline_videogen, utils, download) importable
# when this script is run directly
sys.path.append(os.path.split(sys.path[0])[0])

import imageio
from pipeline_videogen import VideoGenPipeline
from utils import save_video_grid
from download import find_model
from open_sora.modeling import LatteT2V


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True, help="The path to the pretrained model files")
parser.add_argument("--checkpoint", type=str, required=True, help="The path to the t2v.pt file.")
parser.add_argument("--output_path", type=str, required=True, help="The path to save the output")
# generation configs
parser.add_argument(
"--text_prompt", type=str, nargs="+", required=True, help="The text prompt to generate the video."
)
parser.add_argument("--video_length", type=int, default=16, help="The number of frames in the video.")
parser.add_argument("--image_height", type=int, default=256, help="The size of the generated images.")
parser.add_argument("--image_width", type=int, default=256, help="The size of the generated images.")
parser.add_argument("--guidance_scale", type=float, default=7.5, help="The scale of the guidance loss.")
parser.add_argument("--sample_method", type=str, default="PNDM", help="The sampling method to use.")
parser.add_argument("--num_sampling_steps", type=int, default=50, help="The number of sampling steps.")
    parser.add_argument(
        "--enable_temporal_attentions",
        action="store_true",
        default=True,  # note: with default=True, this store_true flag is effectively always on
        help="Whether to enable temporal attentions.",
    )
    parser.add_argument(
        "--enable_vae_temporal_decoder",
        action="store_true",
        default=True,  # note: with default=True, this store_true flag is effectively always on
        help="Whether to enable the VAE temporal decoder.",
    )
# Scheduler configs
parser.add_argument("--beta_start", type=float, default=0.0001)
parser.add_argument("--beta_end", type=float, default=0.02)
parser.add_argument("--beta_schedule", type=str, default="linear")
parser.add_argument("--variance_type", type=str, default="learned_range")
args = parser.parse_args()
return args
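

# Example invocation; the paths below are placeholders for your local files,
# not part of the repository:
#   python sample_t2v.py \
#       --model_path /path/to/latte_model_dir \
#       --checkpoint /path/to/t2v.pt \
#       --output_path ./samples \
#       --text_prompt "a cat playing with a ball" "a boat sailing at sunset"
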
def main(args):
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
transformer_model = LatteT2V.from_pretrained_2d(
args.model_path, subfolder="transformer", video_length=args.video_length
).to(device, dtype=torch.float16)
state_dict = find_model(args.checkpoint)
transformer_model.load_state_dict(state_dict["model"])
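    # decode latents either with the temporal VAE decoder or the standard image VAE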
if args.enable_vae_temporal_decoder:
vae = AutoencoderKLTemporalDecoder.from_pretrained(
args.model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16
).to(device)
else:
vae = AutoencoderKL.from_pretrained(args.model_path, subfolder="vae", torch_dtype=torch.float16).to(device)
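    # T5 tokenizer and text encoder provide the text conditioning for the transformer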
tokenizer = T5Tokenizer.from_pretrained(args.model_path, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(
args.model_path, subfolder="text_encoder", torch_dtype=torch.float16
).to(device)
# set eval mode
transformer_model.eval()
vae.eval()
text_encoder.eval()
if args.sample_method == "DDIM":
scheduler = DDIMScheduler.from_pretrained(
args.model_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule,
variance_type=args.variance_type,
)
elif args.sample_method == "EulerDiscrete":
scheduler = EulerDiscreteScheduler.from_pretrained(
args.model_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule,
variance_type=args.variance_type,
)
elif args.sample_method == "DDPM":
scheduler = DDPMScheduler.from_pretrained(
args.model_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule,
variance_type=args.variance_type,
)
elif args.sample_method == "DPMSolverMultistep":
scheduler = DPMSolverMultistepScheduler.from_pretrained(
args.model_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule,
variance_type=args.variance_type,
)
elif args.sample_method == "DPMSolverSinglestep":
scheduler = DPMSolverSinglestepScheduler.from_pretrained(
args.model_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule,
variance_type=args.variance_type,
)
elif args.sample_method == "PNDM":
scheduler = PNDMScheduler.from_pretrained(
args.model_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule,
variance_type=args.variance_type,
)
elif args.sample_method == "HeunDiscrete":
scheduler = HeunDiscreteScheduler.from_pretrained(
args.model_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule,
variance_type=args.variance_type,
)
elif args.sample_method == "EulerAncestralDiscrete":
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
args.model_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule,
variance_type=args.variance_type,
)
elif args.sample_method == "DEISMultistep":
scheduler = DEISMultistepScheduler.from_pretrained(
args.model_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule,
variance_type=args.variance_type,
)
elif args.sample_method == "KDPM2AncestralDiscrete":
scheduler = KDPM2AncestralDiscreteScheduler.from_pretrained(
args.model_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule,
variance_type=args.variance_type,
)
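    # assemble the diffusers-style text-to-video pipeline from the loaded components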
videogen_pipeline = VideoGenPipeline(
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, transformer=transformer_model
).to(device)
# videogen_pipeline.enable_xformers_memory_efficient_attention()
    os.makedirs(args.output_path, exist_ok=True)
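    # sample one video per prompt, save it, and keep it for the final grid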
video_grids = []
for prompt in args.text_prompt:
print("Processing the ({}) prompt".format(prompt))
videos = videogen_pipeline(
prompt,
video_length=args.video_length,
height=args.image_height,
width=args.image_width,
num_inference_steps=args.num_sampling_steps,
guidance_scale=args.guidance_scale,
enable_temporal_attentions=args.enable_temporal_attentions,
num_images_per_prompt=1,
mask_feature=True,
enable_vae_temporal_decoder=args.enable_vae_temporal_decoder,
).video
        try:
            save_path = os.path.join(args.output_path, prompt.replace(" ", "_") + "_webv-imageio.mp4")
            imageio.mimwrite(save_path, videos[0], fps=8, quality=9)  # quality ranges from 0 (lowest) to 10 (highest)
        except Exception as e:
            print("Error when saving {}: {}".format(prompt, e))
video_grids.append(videos)
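    # concatenate all sampled videos and tile them into a single grid video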
video_grids = torch.cat(video_grids, dim=0)
video_grids = save_video_grid(video_grids)
save_path = os.path.join(args.output_path, "grid.mp4")
imageio.mimwrite(save_path, video_grids, fps=8, quality=5)
print("save path {}".format(abspath(args.output_path)))
# save_videos_grid(video, f"./{prompt}.gif")
if __name__ == "__main__":
args = parse_args()
main(args)