From 01728dc28aa810cf467edeedc4483c4196399ae6 Mon Sep 17 00:00:00 2001
From: Zangwei Zheng
Date: Tue, 26 Mar 2024 16:50:36 +0800
Subject: [PATCH] refactor datasets

---
 configs/opensora-v1-1/inference/Vx256x256.py | 38 +++++++++++++++
 configs/opensora-v1-1/train/Vx256x256.py     | 50 +++++++++++++++++++
 configs/opensora/train/16x256x256-spee.py    |  2 +-
 opensora/datasets/datasets.py                | 51 ++++++++++++--------
 opensora/models/stdit/stdit2.py              | 26 ++++------
 scripts/train.py                             | 12 ++---
 6 files changed, 131 insertions(+), 48 deletions(-)
 create mode 100644 configs/opensora-v1-1/inference/Vx256x256.py
 create mode 100644 configs/opensora-v1-1/train/Vx256x256.py

diff --git a/configs/opensora-v1-1/inference/Vx256x256.py b/configs/opensora-v1-1/inference/Vx256x256.py
new file mode 100644
index 0000000..f609c37
--- /dev/null
+++ b/configs/opensora-v1-1/inference/Vx256x256.py
@@ -0,0 +1,38 @@
+num_frames = None
+fps = 24 // 3
+image_size = (256, 256)
+
+# Define model
+model = dict(
+    type="STDiT2-XL/2",
+    space_scale=0.5,
+    from_pretrained="PixArt-XL-2-1024-MS.pth",
+    enable_flashattn=True,
+    enable_layernorm_kernel=True,
+)
+vae = dict(
+    type="VideoAutoencoderKL",
+    from_pretrained="stabilityai/sd-vae-ft-ema",
+    micro_batch_size=4,
+)
+text_encoder = dict(
+    type="t5",
+    from_pretrained="DeepFloyd/t5-v1_1-xxl",
+    model_max_length=120,
+)
+scheduler = dict(
+    type="iddpm",
+    num_sampling_steps=100,
+    cfg_scale=7.0,
+    cfg_channel=3,  # or None
+)
+dtype = "fp16"
+
+# Condition
+prompt_path = "./assets/texts/t2v_samples.txt"
+prompt = None  # prompt has higher priority than prompt_path
+
+# Others
+batch_size = 1
+seed = 42
+save_dir = "./outputs/samples/"
diff --git a/configs/opensora-v1-1/train/Vx256x256.py b/configs/opensora-v1-1/train/Vx256x256.py
new file mode 100644
index 0000000..c2bd76d
--- /dev/null
+++ b/configs/opensora-v1-1/train/Vx256x256.py
@@ -0,0 +1,50 @@
+# Define dataset
+data_path = "CSV_PATH"
+num_frames = 16
+frame_interval = 3
+image_size = (256, 256)
+
+# Define acceleration
+num_workers = 4
+dtype = "bf16"
+grad_checkpoint = True
+plugin = "zero2"
+sp_size = 1
+
+# Define model
+model = dict(
+    type="STDiT2-XL/2",
+    space_scale=0.5,
+    from_pretrained="PixArt-XL-2-1024-MS.pth",
+    enable_flashattn=True,
+    enable_layernorm_kernel=True,
+)
+vae = dict(
+    type="VideoAutoencoderKL",
+    from_pretrained="stabilityai/sd-vae-ft-ema",
+    micro_batch_size=4,
+)
+text_encoder = dict(
+    type="t5",
+    from_pretrained="DeepFloyd/t5-v1_1-xxl",
+    model_max_length=120,
+    shardformer=True,
+)
+scheduler = dict(
+    type="iddpm-speed",
+    timestep_respacing="",
+)
+
+# Others
+seed = 42
+outputs = "outputs"
+wandb = False
+
+epochs = 1000
+log_every = 10
+ckpt_every = 1000
+load = None
+
+batch_size = 8
+lr = 2e-5
+grad_clip = 1.0
diff --git a/configs/opensora/train/16x256x256-spee.py b/configs/opensora/train/16x256x256-spee.py
index 12f06c9..729a382 100644
--- a/configs/opensora/train/16x256x256-spee.py
+++ b/configs/opensora/train/16x256x256-spee.py
@@ -23,7 +23,7 @@ model = dict(
     enable_flashattn=True,
     enable_layernorm_kernel=True,
 )
-mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
+# mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
 vae = dict(
     type="VideoAutoencoderKL",
     from_pretrained="stabilityai/sd-vae-ft-ema",
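
Note: the two Vx256x256 files above are plain-Python configs; every top-level
assignment becomes a config field. A minimal sketch of how such a file can be
consumed, assuming an mmengine-style Config.fromfile loader (the loader itself
is not part of this patch):

    from mmengine.config import Config

    # Each top-level assignment in the config file becomes an attribute.
    cfg = Config.fromfile("configs/opensora-v1-1/train/Vx256x256.py")
    print(cfg.num_frames, cfg.image_size)  # 16 (256, 256)
    print(cfg.model["type"])               # STDiT2-XL/2
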
diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py
index 3f2b71c..fde74e5 100644
--- a/opensora/datasets/datasets.py
+++ b/opensora/datasets/datasets.py
@@ -8,6 +8,7 @@ from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
 
 from . import video_transforms
 from .utils import VID_EXTENSIONS
+from .utils import get_transforms_image, get_transforms_video
 
 
 class DatasetFromCSV(torch.utils.data.Dataset):
@@ -24,53 +25,61 @@ class DatasetFromCSV(torch.utils.data.Dataset):
         csv_path,
         num_frames=16,
         frame_interval=1,
-        transform=None,
-        root=None,
+        image_size=(256, 256),
     ):
         self.csv_path = csv_path
         self.data = pd.read_csv(csv_path)
-
-        ext = self.data["path"][0].split(".")[-1]
-        if ext.lower() in VID_EXTENSIONS:
-            self.is_video = True
-        else:
-            assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
-            self.is_video = False
-
-        self.transform = transform
         self.num_frames = num_frames
         self.frame_interval = frame_interval
         self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
-        self.root = root
+        self.transforms = {
+            "image": get_transforms_image(image_size[0]),
+            "video": get_transforms_video(image_size[0]),
+        }
+
+    def get_type(self, path):
+        ext = path.split(".")[-1]
+        if ext.lower() in VID_EXTENSIONS:
+            return "video"
+        else:
+            assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
+            return "image"
 
     def getitem(self, index):
         sample = self.data.iloc[index]
         path = sample["path"]
-        if self.root:
-            path = os.path.join(self.root, path)
         text = sample["text"]
+        file_type = self.get_type(path)
 
-        if self.is_video:
-            vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
-            total_frames = len(vframes)
+        if file_type == "video":
+            # loading
+            vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
 
             # Sampling video frames
+            total_frames = len(vframes)
             start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
             assert (
                 end_frame_ind - start_frame_ind >= self.num_frames
             ), f"{path} with index {index} has not enough frames."
             frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
-            video = vframes[frame_indice]
-            video = self.transform(video)  # T C H W
+            video = vframes[frame_indice]
+
+            # transform
+            transform = self.transforms["video"]
+            video = transform(video)  # T C H W
         else:
+            # loading
             image = pil_loader(path)
-            image = self.transform(image)
+
+            # transform
+            transform = self.transforms["image"]
+            image = transform(image)
+
+            # repeat
             video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)
 
         # TCHW -> CTHW
         video = video.permute(1, 0, 2, 3)
-
         return {"video": video, "text": text}
 
     def __getitem__(self, index):
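
Note: with the refactor above, DatasetFromCSV builds its image/video transforms
internally from image_size, and resolves each row's type (image vs. video) per
sample instead of once from the first row of the CSV. A minimal usage sketch
(the CSV path is hypothetical; the CSV needs "path" and "text" columns):

    from torch.utils.data import DataLoader
    from opensora.datasets.datasets import DatasetFromCSV

    dataset = DatasetFromCSV(
        "data/videos.csv",      # hypothetical CSV with 'path' and 'text' columns
        num_frames=16,
        frame_interval=3,
        image_size=(256, 256),
    )
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    print(batch["video"].shape)  # torch.Size([8, 3, 16, 256, 256]); each sample is C T H W
    print(batch["text"][0])
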
diff --git a/opensora/models/stdit/stdit2.py b/opensora/models/stdit/stdit2.py
index 7c41388..c82962f 100644
--- a/opensora/models/stdit/stdit2.py
+++ b/opensora/models/stdit/stdit2.py
@@ -195,7 +195,6 @@ class STDiT2(nn.Module):
         model_max_length=120,
         dtype=torch.float32,
         space_scale=1.0,
-        time_scale=1.0,
         freeze=None,
         enable_flashattn=False,
         enable_layernorm_kernel=False,
@@ -206,12 +205,6 @@ class STDiT2(nn.Module):
         self.in_channels = in_channels
         self.out_channels = in_channels * 2 if pred_sigma else in_channels
         self.hidden_size = hidden_size
-        self.patch_size = patch_size
-        self.input_size = input_size
-        num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
-        self.num_patches = num_patches
-        self.num_temporal = input_size[0] // patch_size[0]
-        self.num_spatial = num_patches // self.num_temporal
         self.num_heads = num_heads
         self.dtype = dtype
         self.no_temporal_pos_emb = no_temporal_pos_emb
@@ -220,7 +213,15 @@ class STDiT2(nn.Module):
         self.enable_flashattn = enable_flashattn
         self.enable_layernorm_kernel = enable_layernorm_kernel
         self.space_scale = space_scale
-        self.time_scale = time_scale
+
+        # support dynamic input
+        self.patch_size = patch_size
+        self.input_size = input_size
+        if input_size[0] is None:
+            self.num_temporal = None
+        else:
+            self.num_temporal = input_size[0] // patch_size[0]
+        self.num_spatial = np.prod(input_size[1:]) // np.prod(patch_size[1:])
 
         self.register_buffer("pos_embed", self.get_spatial_pos_embed())
         # self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
@@ -420,15 +421,6 @@ class STDiT2(nn.Module):
         pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
         return pos_embed
 
-    def get_temporal_pos_embed(self):
-        pos_embed = get_1d_sincos_pos_embed(
-            self.hidden_size,
-            self.input_size[0] // self.patch_size[0],
-            scale=self.time_scale,
-        )
-        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
-        return pos_embed
-
     def freeze_not_temporal(self):
         for n, p in self.named_parameters():
             if "attn_temp" not in n:
diff --git a/scripts/train.py b/scripts/train.py
index 1c61d6e..63be9f7 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -90,16 +90,10 @@ def main():
     # 3. build dataset and dataloader
     # ======================================================
     dataset = DatasetFromCSV(
-        cfg.data_path,
-        # TODO: change transforms
-        transform=(
-            get_transforms_video(cfg.image_size[0])
-            if not cfg.use_image_transform
-            else get_transforms_image(cfg.image_size[0])
-        ),
+        csv_path=cfg.data_path,
         num_frames=cfg.num_frames,
         frame_interval=cfg.frame_interval,
-        root=cfg.root,
+        image_size=cfg.image_size,
     )
 
     # TODO: use plugin's prepare dataloader
@@ -222,7 +216,7 @@ def main():
         x = vae.encode(x)  # [B, C, T, H/P, W/P]
         # Prepare text inputs
         model_args = text_encoder.encode(y)
-
+
         # Mask
         if cfg.mask_ratios is not None:
             mask = mask_generator.get_masks(x)
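
Note: the stdit2.py hunks replace static patch-count bookkeeping with a dynamic
variant: when the temporal extent is unknown (input_size[0] is None, as in the
Vx256x256 inference config where num_frames = None), num_temporal is left unset
and resolved at runtime, while the spatial patch count stays static. A
pure-Python restatement of that logic (patch_counts is an illustrative helper,
not a function in this patch):

    import numpy as np

    def patch_counts(input_size, patch_size):
        # Temporal patch count is deferred when the frame count is unknown.
        num_temporal = None if input_size[0] is None else input_size[0] // patch_size[0]
        # Spatial patch count: (H * W) // (pH * pW), always known.
        num_spatial = int(np.prod(input_size[1:]) // np.prod(patch_size[1:]))
        return num_temporal, num_spatial

    print(patch_counts((None, 32, 32), (1, 2, 2)))  # (None, 256)
    print(patch_counts((16, 32, 32), (1, 2, 2)))    # (16, 256)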