From 6762476261d0dcf030b7639aba17948a277d8886 Mon Sep 17 00:00:00 2001 From: Zangwei Zheng Date: Wed, 3 Apr 2024 17:25:41 +0800 Subject: [PATCH] update mask & config --- configs/opensora-v1-1/train/benchmark.py | 17 +++++++++++------ configs/opensora-v1-1/train/video.py | 15 +++++++++++---- opensora/models/stdit/stdit2.py | 2 +- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/configs/opensora-v1-1/train/benchmark.py b/configs/opensora-v1-1/train/benchmark.py index d45c2a5..669ce14 100644 --- a/configs/opensora-v1-1/train/benchmark.py +++ b/configs/opensora-v1-1/train/benchmark.py @@ -3,12 +3,17 @@ dataset = dict( type="VariableVideoTextDataset", data_path=None, num_frames=None, - frame_interval=3, + frame_interval=1, image_size=(None, None), transform_name="resize_crop", ) bucket_config = { - "240p": {16: (1.0, 16)}, # 4.6s/it + # "240p": {128: (1.0, 2)}, # 4.28s/it + # "240p": {64: (1.0, 4)}, + # "240p": {32: (1.0, 8)}, # 4.6s/it + # "240p": {16: (1.0, 16)}, # 4.6s/it + # "480p": {16: (1.0, 4)}, # 4.6s/it + "720p": {16: (1.0, 2)}, # "256": {1: (1.0, 256)}, # 4.5s/it # "512": {1: (1.0, 96)}, # 4.7s/it # "512": {1: (1.0, 128)}, # 6.3s/it @@ -18,10 +23,10 @@ bucket_config = { # "1080p": {1: (1.0, 16)}, # 8.6s/it # "1080p": {1: (1.0, 8)}, # 4.4s/it } -mask_ratios = { - "mask_no": 0.0, - "mask_random": 1.0, -} +# mask_ratios = { +# "mask_no": 0.0, +# "mask_random": 1.0, +# } # Define acceleration num_workers = 4 diff --git a/configs/opensora-v1-1/train/video.py b/configs/opensora-v1-1/train/video.py index c79cf35..285ae84 100644 --- a/configs/opensora-v1-1/train/video.py +++ b/configs/opensora-v1-1/train/video.py @@ -3,19 +3,26 @@ dataset = dict( type="VariableVideoTextDataset", data_path=None, num_frames=None, - frame_interval=3, + frame_interval=1, image_size=(None, None), transform_name="resize_crop", ) bucket_config = { # 6s/it - "240p": {16: (1.0, 16)}, + "240p": {16: (1.0, 16), 32: (1.0, 8), 64: (1.0, 4), 128: (1.0, 2)}, "256": {1: (1.0, 256)}, "512": {1: (1.0, 80)}, - "480p": {1: (1.0, 52), 16: (1.0, 4)}, - "720p": {16: (1.0, 2)}, + "480p": {1: (1.0, 52), 16: (0.5, 4), 32: (0.0, None)}, + "720p": {16: (1.0, 2), 32: (0.0, None)}, # No examples now "1024": {1: (1.0, 20)}, "1080p": {1: (1.0, 8)}, } +mask_ratios = { + "mask_no": 0.9, + "mask_random": 0.06, + "mask_head": 0.01, + "mask_tail": 0.01, + "mask_head_tail": 0.02, +} # Define acceleration num_workers = 4 diff --git a/opensora/models/stdit/stdit2.py b/opensora/models/stdit/stdit2.py index e494209..506ef1e 100644 --- a/opensora/models/stdit/stdit2.py +++ b/opensora/models/stdit/stdit2.py @@ -90,7 +90,7 @@ class STDiT2Block(nn.Module): # x: [B, (T, S), C] # mased_x: [B, (T, S), C] # x_mask: [B, T] - x = rearrange(x, "B (T S) C -> B T S C", T=self.T, S=self.S) + x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S) masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S) x = torch.where(x_mask[:, :, None, None], x, masked_x) x = rearrange(x, "B T S C -> B (T S) C")