merge mask-related utils

Zangwei Zheng 2024-03-23 16:32:51 +08:00
parent 24c707bfb9
commit 7d27f5553e
8 changed files with 120 additions and 15 deletions

View file

@@ -28,8 +28,11 @@ scheduler = dict(
 )
 dtype = "fp16"
+# Condition
+prompt_path = "./assets/texts/t2v_samples.txt"
+prompt = None  # prompt has higher priority than prompt_path
 # Others
 batch_size = 1
 seed = 42
-prompt_path = "./assets/texts/t2v_samples.txt"
 save_dir = "./outputs/samples/"

View file

@@ -0,0 +1,46 @@
+num_frames = 16
+fps = 24 // 3
+image_size = (256, 256)
+
+# Define model
+model = dict(
+    type="STDiT-XL/2",
+    space_scale=0.5,
+    time_scale=1.0,
+    enable_flashattn=True,
+    enable_layernorm_kernel=True,
+    from_pretrained="PRETRAINED_MODEL",
+)
+vae = dict(
+    type="VideoAutoencoderKL",
+    from_pretrained="stabilityai/sd-vae-ft-ema",
+    micro_batch_size=4,
+)
+text_encoder = dict(
+    type="t5",
+    from_pretrained="./pretrained_models/t5_ckpts",
+    model_max_length=120,
+)
+scheduler = dict(
+    type="iddpm",
+    num_sampling_steps=100,
+    cfg_scale=7.0,
+)
+dtype = "fp16"
+
+# Condition
+prompt_path = None
+prompt = [
+    "In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave."
+]
+loop = 10
+condition_frame_length = 4
+reference_path = ["assets/images/condition/wave.png"]
+mask_strategy = ["0,0,0,1,0"]  # valid when reference_path is not None
+# loop id, ref id, ref start, length, target start
+
+# Others
+batch_size = 2
+seed = 42
+save_dir = "./samples/"

View file

@@ -8,14 +8,14 @@ import torchvision.transforms as transforms
 from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader

 from . import video_transforms
-from .utils import center_crop_arr
+from .utils import center_crop_arr, VID_EXTENSIONS


 def get_transforms_video(resolution=256):
     transform_video = transforms.Compose(
         [
             video_transforms.ToTensorVideo(),  # TCHW
-            video_transforms.RandomHorizontalFlipVideo(),
+            # video_transforms.RandomHorizontalFlipVideo(),
             video_transforms.UCFCenterCropVideo(resolution),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
         ]
@@ -27,7 +27,7 @@ def get_transforms_image(image_size=256):
     transform = transforms.Compose(
         [
             transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
-            transforms.RandomHorizontalFlip(),
+            # transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
         ]
@@ -58,7 +58,7 @@ class DatasetFromCSV(torch.utils.data.Dataset):
             self.samples = list(reader)

         ext = self.samples[0][0].split(".")[-1]
-        if ext.lower() in ("mp4", "avi", "mov", "mkv"):
+        if ext.lower() in VID_EXTENSIONS:
            self.is_video = True
         else:
             assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"

View file

@@ -11,6 +11,36 @@ from torch.utils.data.distributed import DistributedSampler
 from torchvision.io import write_video
 from torchvision.utils import save_image

+VID_EXTENSIONS = ("mp4", "avi", "mov", "mkv")
+
+
+def read_image_from_path(path, transform=None, num_frames=1, image_size=256):
+    image = pil_loader(path)
+    if transform is None:
+        transform = get_transforms_image(image_size)
+    image = transform(image)
+    video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
+    video = video.permute(1, 0, 2, 3)
+    return video
+
+
+def read_video_from_path(path, transform=None, image_size=256):
+    vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
+    if transform is None:
+        transform = get_transforms_video(image_size)
+    video = transform(vframes)  # T C H W
+    video = video.permute(1, 0, 2, 3)
+    return video
+
+
+def read_from_path(path, image_size):
+    ext = path.split(".")[-1]
+    if ext.lower() in VID_EXTENSIONS:
+        return read_video_from_path(path, image_size=image_size)
+    else:
+        assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
+        return read_image_from_path(path, image_size=image_size)
+
+
 def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1)):
     """

View file

@@ -56,6 +56,14 @@ class VideoAutoencoderKL(nn.Module):
             input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
         return input_size

+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+

 @MODELS.register_module()
 class VideoAutoencoderKLTemporalDecoder(nn.Module):
@@ -80,3 +88,11 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module):
             assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
             input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
         return input_size
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
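Both autoencoders gain the standard device/dtype property pattern: next(self.parameters()) grabs an arbitrary parameter and reports its placement, so callers can move conditioning tensors to wherever the weights live without threading device and dtype through separately. A self-contained toy demonstrating the same pattern (TinyModule is illustrative, not from the repo):

import torch
import torch.nn as nn

class TinyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

m = TinyModule().to(torch.float16)
x = torch.randn(2, 4).to(m.device, m.dtype)  # follow the module's placement
print(m.device, m.dtype)  # cpu torch.float16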

View file

@@ -7,6 +7,12 @@ from mmengine.config import Config
 from torch.utils.tensorboard import SummaryWriter


+def load_prompts(prompt_path):
+    with open(prompt_path, "r") as f:
+        prompts = [line.strip() for line in f.readlines()]
+    return prompts
+
+
 def parse_args(training=False):
     parser = argparse.ArgumentParser()
@@ -47,12 +53,22 @@ def merge_args(cfg, args, training=False):
         cfg.scheduler["cfg_scale"] = args.cfg_scale
         args.cfg_scale = None
-    if "multi_resolution" not in cfg:
-        cfg["multi_resolution"] = False

     for k, v in vars(args).items():
         if k in cfg and v is not None:
             cfg[k] = v

+    if "reference_path" not in cfg:
+        cfg["reference_path"] = None
+    if "loop" not in cfg:
+        cfg["loop"] = 1
+    if "multi_resolution" not in cfg:
+        cfg["multi_resolution"] = False
+    if "mask_ratios" not in cfg:
+        cfg["mask_ratios"] = None
+    if "prompt" not in cfg or cfg["prompt"] is None:
+        assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"
+        cfg["prompt"] = load_prompts(cfg["prompt_path"])
+
     return cfg
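After merge_args, downstream code can rely on cfg.prompt being populated: an explicit prompt list wins, otherwise load_prompts reads one prompt per line from prompt_path. A small sketch of that precedence; the file path and contents here are made up:

# One prompt per line, stripped, as in load_prompts above.
with open("/tmp/prompts.txt", "w") as f:
    f.write("a red panda eating bamboo\na city street in the rain\n")

def load_prompts(prompt_path):
    with open(prompt_path, "r") as f:
        return [line.strip() for line in f.readlines()]

prompt = None  # as in the t2v config: an explicit prompt takes priority
prompts = prompt if prompt is not None else load_prompts("/tmp/prompts.txt")
print(prompts)  # ['a red panda eating bamboo', 'a city street in the rain']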

View file

@@ -13,12 +13,6 @@ from opensora.acceleration.parallel_states import set_sequence_parallel_group
 from colossalai.cluster import DistCoordinator


-def load_prompts(prompt_path):
-    with open(prompt_path, "r") as f:
-        prompts = [line.strip() for line in f.readlines()]
-    return prompts
-
-
 def main():
     # ======================================================
     # 1. cfg and init distributed env
@@ -31,7 +25,7 @@ def main():
     coordinator = DistCoordinator()

     if coordinator.world_size > 1:
-        set_sequence_parallel_group(dist.group.WORLD)
+        set_sequence_parallel_group(dist.group.WORLD)
         enable_sequence_parallelism = True
     else:
         enable_sequence_parallelism = False
@@ -45,7 +39,7 @@ def main():
     device = "cuda" if torch.cuda.is_available() else "cpu"
     dtype = to_torch_dtype(cfg.dtype)
     set_random_seed(seed=cfg.seed)
-    prompts = load_prompts(cfg.prompt_path)
+    prompts = cfg.prompt

     # ======================================================
     # 3. build model & load weights
