Open-Sora/opensora/utils/inference_utils.py
2024-05-13 06:33:12 +00:00

50 lines
2 KiB
Python

import os
import torch
from opensora.datasets import IMG_FPS
def prepare_multi_resolution_info(info_type, batch_size, image_size, num_frames, fps, device, dtype):
if info_type is None:
return dict()
elif info_type == "PixArtMS":
hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(batch_size, 1)
ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(batch_size, 1)
return dict(ar=ar, hw=hw)
elif info_type in ["STDiT2", "OpenSora"]:
height = torch.tensor([image_size[0]], device=device, dtype=dtype).repeat(batch_size)
width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(batch_size)
num_frames = torch.tensor([num_frames], device=device, dtype=dtype).repeat(batch_size)
ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(batch_size)
fps = fps if num_frames > 1 else IMG_FPS
fps = torch.tensor([fps], device=device, dtype=dtype).repeat(batch_size)
return dict(height=height, width=width, num_frames=num_frames, ar=ar, fps=fps)
else:
raise NotImplementedError
def load_prompts(prompt_path, start_idx=None, end_idx=None):
with open(prompt_path, "r") as f:
prompts = [line.strip() for line in f.readlines()]
prompts = prompts[start_idx:end_idx]
return prompts
def get_save_path_name(
save_dir,
sample_name=None, # prefix
sample_idx=None, # sample index
prompt=None, # used prompt
prompt_as_path=False, # use prompt as path
num_sample=1, # number of samples to generate for one prompt
k=None, # kth sample
):
if sample_name is None:
sample_name = "" if prompt_as_path else "sample"
sample_name_suffix = prompt if prompt_as_path else f"_{sample_idx}"
save_path = os.path.join(save_dir, f"{sample_name}{sample_name_suffix}")
if num_sample != 1:
save_path = f"{save_path}-{k}"
return save_path