mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 13:14:44 +02:00
merge mask-related utils
This commit is contained in:
parent
24c707bfb9
commit
7d27f5553e
|
|
@ -28,8 +28,11 @@ scheduler = dict(
|
|||
)
|
||||
dtype = "fp16"
|
||||
|
||||
# Condition
|
||||
prompt_path = "./assets/texts/t2v_samples.txt"
|
||||
prompt = None # prompt has higher priority than prompt_path
|
||||
|
||||
# Others
|
||||
batch_size = 1
|
||||
seed = 42
|
||||
prompt_path = "./assets/texts/t2v_samples.txt"
|
||||
save_dir = "./outputs/samples/"
|
||||
|
|
|
|||
46
configs/opensora/inference_long/16x256x256_long.py
Normal file
46
configs/opensora/inference_long/16x256x256_long.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
num_frames = 16
|
||||
fps = 24 // 3
|
||||
image_size = (256, 256)
|
||||
|
||||
# Define model
|
||||
model = dict(
|
||||
type="STDiT-XL/2",
|
||||
space_scale=0.5,
|
||||
time_scale=1.0,
|
||||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
from_pretrained="PRETRAINED_MODEL",
|
||||
)
|
||||
vae = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="stabilityai/sd-vae-ft-ema",
|
||||
micro_batch_size=4,
|
||||
)
|
||||
text_encoder = dict(
|
||||
type="t5",
|
||||
from_pretrained="./pretrained_models/t5_ckpts",
|
||||
model_max_length=120,
|
||||
)
|
||||
scheduler = dict(
|
||||
type="iddpm",
|
||||
num_sampling_steps=100,
|
||||
cfg_scale=7.0,
|
||||
)
|
||||
dtype = "fp16"
|
||||
|
||||
# Condition
|
||||
prompt_path = None
|
||||
prompt = [
|
||||
"In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave."
|
||||
]
|
||||
|
||||
loop = 10
|
||||
condition_frame_length = 4
|
||||
reference_path = ["assets/images/condition/wave.png"]
|
||||
mask_strategy = ["0,0,0,1,0"] # valid when reference_path is not None
|
||||
# loop id, ref id, ref start, length, target start
|
||||
|
||||
# Others
|
||||
batch_size = 2
|
||||
seed = 42
|
||||
save_dir = "./samples/"
|
||||
|
|
@ -8,14 +8,14 @@ import torchvision.transforms as transforms
|
|||
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
|
||||
|
||||
from . import video_transforms
|
||||
from .utils import center_crop_arr
|
||||
from .utils import center_crop_arr, VID_EXTENSIONS
|
||||
|
||||
|
||||
def get_transforms_video(resolution=256):
|
||||
transform_video = transforms.Compose(
|
||||
[
|
||||
video_transforms.ToTensorVideo(), # TCHW
|
||||
video_transforms.RandomHorizontalFlipVideo(),
|
||||
# video_transforms.RandomHorizontalFlipVideo(),
|
||||
video_transforms.UCFCenterCropVideo(resolution),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
||||
]
|
||||
|
|
@ -27,7 +27,7 @@ def get_transforms_image(image_size=256):
|
|||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
# transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
||||
]
|
||||
|
|
@ -58,7 +58,7 @@ class DatasetFromCSV(torch.utils.data.Dataset):
|
|||
self.samples = list(reader)
|
||||
|
||||
ext = self.samples[0][0].split(".")[-1]
|
||||
if ext.lower() in ("mp4", "avi", "mov", "mkv"):
|
||||
if ext.lower() in VID_EXTENSIONS:
|
||||
self.is_video = True
|
||||
else:
|
||||
assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
|
||||
|
|
|
|||
|
|
@ -11,6 +11,36 @@ from torch.utils.data.distributed import DistributedSampler
|
|||
from torchvision.io import write_video
|
||||
from torchvision.utils import save_image
|
||||
|
||||
VID_EXTENSIONS = ("mp4", "avi", "mov", "mkv")
|
||||
|
||||
|
||||
def read_image_from_path(path, transform=None, num_frames=1, image_size=256):
|
||||
image = pil_loader(path)
|
||||
if transform is None:
|
||||
transform = get_transforms_image(image_size)
|
||||
image = transform(image)
|
||||
video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
|
||||
video = video.permute(1, 0, 2, 3)
|
||||
return video
|
||||
|
||||
|
||||
def read_video_from_path(path, transform=None, image_size=256):
|
||||
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
if transform is None:
|
||||
transform = get_transforms_video(image_size)
|
||||
video = transform(vframes) # T C H W
|
||||
video = video.permute(1, 0, 2, 3)
|
||||
return video
|
||||
|
||||
|
||||
def read_from_path(path, image_size):
|
||||
ext = path.split(".")[-1]
|
||||
if ext.lower() in VID_EXTENSIONS:
|
||||
return read_video_from_path(path, image_size=image_size)
|
||||
else:
|
||||
assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
|
||||
return read_image_from_path(path, image_size=image_size)
|
||||
|
||||
|
||||
def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1)):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -56,6 +56,14 @@ class VideoAutoencoderKL(nn.Module):
|
|||
input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
|
||||
return input_size
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class VideoAutoencoderKLTemporalDecoder(nn.Module):
|
||||
|
|
@ -80,3 +88,11 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module):
|
|||
assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
|
||||
input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
|
||||
return input_size
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
|
|
|||
|
|
@ -7,6 +7,12 @@ from mmengine.config import Config
|
|||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
|
||||
def load_prompts(prompt_path):
|
||||
with open(prompt_path, "r") as f:
|
||||
prompts = [line.strip() for line in f.readlines()]
|
||||
return prompts
|
||||
|
||||
|
||||
def parse_args(training=False):
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
|
|
@ -47,12 +53,22 @@ def merge_args(cfg, args, training=False):
|
|||
cfg.scheduler["cfg_scale"] = args.cfg_scale
|
||||
args.cfg_scale = None
|
||||
|
||||
if "multi_resolution" not in cfg:
|
||||
cfg["multi_resolution"] = False
|
||||
for k, v in vars(args).items():
|
||||
if k in cfg and v is not None:
|
||||
cfg[k] = v
|
||||
|
||||
if "reference_path" not in cfg:
|
||||
cfg["reference_path"] = None
|
||||
if "loop" not in cfg:
|
||||
cfg["loop"] = 1
|
||||
if "multi_resolution" not in cfg:
|
||||
cfg["multi_resolution"] = False
|
||||
if "mask_ratios" not in cfg:
|
||||
cfg["mask_ratios"] = None
|
||||
if "prompt" not in cfg or cfg["prompt"] is None:
|
||||
assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"
|
||||
cfg["prompt"] = load_prompts(cfg["prompt_path"])
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -13,12 +13,6 @@ from opensora.acceleration.parallel_states import set_sequence_parallel_group
|
|||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
|
||||
def load_prompts(prompt_path):
|
||||
with open(prompt_path, "r") as f:
|
||||
prompts = [line.strip() for line in f.readlines()]
|
||||
return prompts
|
||||
|
||||
|
||||
def main():
|
||||
# ======================================================
|
||||
# 1. cfg and init distributed env
|
||||
|
|
@ -31,7 +25,7 @@ def main():
|
|||
coordinator = DistCoordinator()
|
||||
|
||||
if coordinator.world_size > 1:
|
||||
set_sequence_parallel_group(dist.group.WORLD)
|
||||
set_sequence_parallel_group(dist.group.WORLD)
|
||||
enable_sequence_parallelism = True
|
||||
else:
|
||||
enable_sequence_parallelism = False
|
||||
|
|
@ -45,7 +39,7 @@ def main():
|
|||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
dtype = to_torch_dtype(cfg.dtype)
|
||||
set_random_seed(seed=cfg.seed)
|
||||
prompts = load_prompts(cfg.prompt_path)
|
||||
prompts = cfg.prompt
|
||||
|
||||
# ======================================================
|
||||
# 3. build model & load weights
|
||||
|
|
|
|||
0
scripts/inference_long.py
Normal file
0
scripts/inference_long.py
Normal file
Loading…
Reference in a new issue