mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-05 19:49:35 +02:00
50 lines
2 KiB
Python
50 lines
2 KiB
Python
import os
|
|
|
|
import torch
|
|
|
|
from opensora.datasets import IMG_FPS
|
|
|
|
|
|
def prepare_multi_resolution_info(info_type, batch_size, image_size, num_frames, fps, device, dtype):
|
|
if info_type is None:
|
|
return dict()
|
|
elif info_type == "PixArtMS":
|
|
hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(batch_size, 1)
|
|
ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(batch_size, 1)
|
|
return dict(ar=ar, hw=hw)
|
|
elif info_type in ["STDiT2", "OpenSora"]:
|
|
height = torch.tensor([image_size[0]], device=device, dtype=dtype).repeat(batch_size)
|
|
width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(batch_size)
|
|
num_frames = torch.tensor([num_frames], device=device, dtype=dtype).repeat(batch_size)
|
|
ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(batch_size)
|
|
fps = fps if num_frames > 1 else IMG_FPS
|
|
fps = torch.tensor([fps], device=device, dtype=dtype).repeat(batch_size)
|
|
return dict(height=height, width=width, num_frames=num_frames, ar=ar, fps=fps)
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
|
|
def load_prompts(prompt_path, start_idx=None, end_idx=None):
|
|
with open(prompt_path, "r") as f:
|
|
prompts = [line.strip() for line in f.readlines()]
|
|
prompts = prompts[start_idx:end_idx]
|
|
return prompts
|
|
|
|
|
|
def get_save_path_name(
|
|
save_dir,
|
|
sample_name=None, # prefix
|
|
sample_idx=None, # sample index
|
|
prompt=None, # used prompt
|
|
prompt_as_path=False, # use prompt as path
|
|
num_sample=1, # number of samples to generate for one prompt
|
|
k=None, # kth sample
|
|
):
|
|
if sample_name is None:
|
|
sample_name = "" if prompt_as_path else "sample"
|
|
sample_name_suffix = prompt if prompt_as_path else f"_{sample_idx}"
|
|
save_path = os.path.join(save_dir, f"{sample_name}{sample_name_suffix}")
|
|
if num_sample != 1:
|
|
save_path = f"{save_path}-{k}"
|
|
return save_path
|