mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-26 10:59:41 +02:00
243 lines
9.4 KiB
Python
243 lines
9.4 KiB
Python
# All rights reserved.
|
|
# Copyright 2024 Vchitect/Latte
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
# modified from https://github.com/Vchitect/Latte/blob/main/sample/sample_t2v.py
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
|
|
import torch
|
|
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
|
|
from diffusers.schedulers import (
|
|
DDIMScheduler,
|
|
DDPMScheduler,
|
|
DEISMultistepScheduler,
|
|
DPMSolverMultistepScheduler,
|
|
EulerAncestralDiscreteScheduler,
|
|
EulerDiscreteScheduler,
|
|
HeunDiscreteScheduler,
|
|
KDPM2AncestralDiscreteScheduler,
|
|
PNDMScheduler,
|
|
)
|
|
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
|
|
from transformers import T5EncoderModel, T5Tokenizer
|
|
|
|
sys.path.append(os.path.split(sys.path[0])[0])
|
|
import imageio
|
|
from pipeline_videogen import VideoGenPipeline
|
|
from utils import save_video_grid
|
|
|
|
from download import find_model
|
|
from open_sora.modeling import LatteT2V
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for Latte text-to-video sampling.

    Args:
        argv: optional list of argument strings; ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace with model/checkpoint/output paths, generation
        settings and scheduler overrides.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="The path to the pretrained model files")
    parser.add_argument("--checkpoint", type=str, required=True, help="The path to the t2v.pt file.")
    parser.add_argument("--output_path", type=str, required=True, help="The path to save the output")

    # generation configs
    parser.add_argument(
        "--text_prompt", type=str, nargs="+", required=True, help="The text prompt to generate the video."
    )
    parser.add_argument("--video_length", type=int, default=16, help="The number of frames in the video.")
    parser.add_argument("--image_height", type=int, default=256, help="The height of the generated video frames.")
    parser.add_argument("--image_width", type=int, default=256, help="The width of the generated video frames.")
    parser.add_argument("--guidance_scale", type=float, default=7.5, help="The scale of the guidance loss.")
    # Must stay in sync with the schedulers that main() knows how to build;
    # rejecting unknown names here fails fast instead of after model loading.
    parser.add_argument(
        "--sample_method",
        type=str,
        default="PNDM",
        choices=[
            "DDIM",
            "EulerDiscrete",
            "DDPM",
            "DPMSolverMultistep",
            "DPMSolverSinglestep",
            "PNDM",
            "HeunDiscrete",
            "EulerAncestralDiscrete",
            "DEISMultistep",
            "KDPM2AncestralDiscrete",
        ],
        help="The sampling method to use.",
    )
    parser.add_argument("--num_sampling_steps", type=int, default=50, help="The number of sampling steps.")

    # Both features default to on. A store_true flag with default=True can never
    # produce False, so each one gets a paired store_false switch writing the same
    # dest. The last flag given on the command line wins.
    parser.add_argument(
        "--enable_temporal_attentions", action="store_true", default=True, help="Whether to enable temporal attentions."
    )
    parser.add_argument(
        "--disable_temporal_attentions",
        dest="enable_temporal_attentions",
        action="store_false",
        help="Disable temporal attentions.",
    )
    parser.add_argument(
        "--enable_vae_temporal_decoder",
        action="store_true",
        default=True,
        help="Whether to enable the VAE temporal decoder.",
    )
    parser.add_argument(
        "--disable_vae_temporal_decoder",
        dest="enable_vae_temporal_decoder",
        action="store_false",
        help="Use the plain VAE decoder instead of the temporal decoder.",
    )

    # Scheduler configs (forwarded verbatim to the scheduler's from_pretrained)
    parser.add_argument("--beta_start", type=float, default=0.0001)
    parser.add_argument("--beta_end", type=float, default=0.02)
    parser.add_argument("--beta_schedule", type=str, default="linear")
    parser.add_argument("--variance_type", type=str, default="learned_range")

    return parser.parse_args(argv)
|
|
|
|
|
|
def _build_scheduler(args):
    """Instantiate the diffusers scheduler named by ``args.sample_method``.

    Every supported scheduler is loaded from the ``scheduler`` subfolder of
    ``args.model_path`` with the same beta/variance overrides from the CLI.

    Raises:
        ValueError: if ``args.sample_method`` is not one of the supported names.
    """
    # Built inside the function so that merely importing this module never
    # touches the diffusers scheduler classes.
    scheduler_classes = {
        "DDIM": DDIMScheduler,
        "EulerDiscrete": EulerDiscreteScheduler,
        "DDPM": DDPMScheduler,
        "DPMSolverMultistep": DPMSolverMultistepScheduler,
        "DPMSolverSinglestep": DPMSolverSinglestepScheduler,
        "PNDM": PNDMScheduler,
        "HeunDiscrete": HeunDiscreteScheduler,
        "EulerAncestralDiscrete": EulerAncestralDiscreteScheduler,
        "DEISMultistep": DEISMultistepScheduler,
        "KDPM2AncestralDiscrete": KDPM2AncestralDiscreteScheduler,
    }
    try:
        scheduler_cls = scheduler_classes[args.sample_method]
    except KeyError:
        raise ValueError(
            "Unknown sample_method {!r}; expected one of {}".format(args.sample_method, sorted(scheduler_classes))
        ) from None
    return scheduler_cls.from_pretrained(
        args.model_path,
        subfolder="scheduler",
        beta_start=args.beta_start,
        beta_end=args.beta_end,
        beta_schedule=args.beta_schedule,
        variance_type=args.variance_type,
    )


def main(args):
    """Generate one video per text prompt with Latte-T2V and save the results.

    Loads the transformer, VAE, T5 tokenizer/text encoder and scheduler from
    ``args.model_path``, with the models in float16. It then runs the pipeline
    for each prompt in ``args.text_prompt`` and writes files into
    ``args.output_path``: one ``<prompt>_webv-imageio.mp4`` per prompt, plus a
    combined ``grid.mp4``.

    Args:
        args: namespace produced by ``parse_args``.

    Raises:
        ValueError: if ``args.sample_method`` names an unsupported scheduler.
    """
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    transformer_model = LatteT2V.from_pretrained_2d(
        args.model_path, subfolder="transformer", video_length=args.video_length
    ).to(device, dtype=torch.float16)
    state_dict = find_model(args.checkpoint)
    transformer_model.load_state_dict(state_dict["model"])

    if args.enable_vae_temporal_decoder:
        vae = AutoencoderKLTemporalDecoder.from_pretrained(
            args.model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16
        ).to(device)
    else:
        vae = AutoencoderKL.from_pretrained(args.model_path, subfolder="vae", torch_dtype=torch.float16).to(device)
    tokenizer = T5Tokenizer.from_pretrained(args.model_path, subfolder="tokenizer")
    text_encoder = T5EncoderModel.from_pretrained(
        args.model_path, subfolder="text_encoder", torch_dtype=torch.float16
    ).to(device)

    # set eval mode
    transformer_model.eval()
    vae.eval()
    text_encoder.eval()

    scheduler = _build_scheduler(args)

    videogen_pipeline = VideoGenPipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, transformer=transformer_model
    ).to(device)
    # videogen_pipeline.enable_xformers_memory_efficient_attention()

    # exist_ok=True already tolerates an existing directory.
    os.makedirs(args.output_path, exist_ok=True)

    video_grids = []
    for prompt in args.text_prompt:
        print("Processing the ({}) prompt".format(prompt))
        videos = videogen_pipeline(
            prompt,
            video_length=args.video_length,
            height=args.image_height,
            width=args.image_width,
            num_inference_steps=args.num_sampling_steps,
            guidance_scale=args.guidance_scale,
            enable_temporal_attentions=args.enable_temporal_attentions,
            num_images_per_prompt=1,
            mask_feature=True,
            enable_vae_temporal_decoder=args.enable_vae_temporal_decoder,
        ).video
        # A "/" in the prompt would otherwise be read as a directory separator
        # and point the save path into a subdirectory that does not exist.
        file_stem = prompt.replace(" ", "_").replace("/", "_")
        save_path = os.path.join(args.output_path, file_stem + "_webv-imageio.mp4")
        try:
            # NOTE(review): videos[0] is presumably this prompt's single clip in a
            # layout imageio accepts — TODO confirm against VideoGenPipeline.
            imageio.mimwrite(save_path, videos[0], fps=8, quality=9)  # highest quality is 10, lowest is 0
        except Exception as exc:
            # Best-effort per-prompt save: report it and keep going so the grid is still produced.
            print("Error when saving {}: {}".format(prompt, exc))
        video_grids.append(videos)

    video_grids = torch.cat(video_grids, dim=0)
    video_grids = save_video_grid(video_grids)

    save_path = os.path.join(args.output_path, "grid.mp4")
    imageio.mimwrite(save_path, video_grids, fps=8, quality=5)
    print("save path {}".format(os.path.abspath(args.output_path)))
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI options and hand them straight to main().
    main(parse_args())