update mask & config

This commit is contained in:
Zangwei Zheng 2024-04-03 17:25:41 +08:00
parent 3fcda91877
commit 6762476261
3 changed files with 23 additions and 11 deletions

View file

@ -3,12 +3,17 @@ dataset = dict(
type="VariableVideoTextDataset",
data_path=None,
num_frames=None,
frame_interval=3,
frame_interval=1,
image_size=(None, None),
transform_name="resize_crop",
)
bucket_config = {
"240p": {16: (1.0, 16)}, # 4.6s/it
# "240p": {128: (1.0, 2)}, # 4.28s/it
# "240p": {64: (1.0, 4)},
# "240p": {32: (1.0, 8)}, # 4.6s/it
# "240p": {16: (1.0, 16)}, # 4.6s/it
# "480p": {16: (1.0, 4)}, # 4.6s/it
"720p": {16: (1.0, 2)},
# "256": {1: (1.0, 256)}, # 4.5s/it
# "512": {1: (1.0, 96)}, # 4.7s/it
# "512": {1: (1.0, 128)}, # 6.3s/it
@ -18,10 +23,10 @@ bucket_config = {
# "1080p": {1: (1.0, 16)}, # 8.6s/it
# "1080p": {1: (1.0, 8)}, # 4.4s/it
}
mask_ratios = {
"mask_no": 0.0,
"mask_random": 1.0,
}
# mask_ratios = {
# "mask_no": 0.0,
# "mask_random": 1.0,
# }
# Define acceleration
num_workers = 4

View file

@ -3,19 +3,26 @@ dataset = dict(
type="VariableVideoTextDataset",
data_path=None,
num_frames=None,
frame_interval=3,
frame_interval=1,
image_size=(None, None),
transform_name="resize_crop",
)
bucket_config = { # 6s/it
"240p": {16: (1.0, 16)},
"240p": {16: (1.0, 16), 32: (1.0, 8), 64: (1.0, 4), 128: (1.0, 2)},
"256": {1: (1.0, 256)},
"512": {1: (1.0, 80)},
"480p": {1: (1.0, 52), 16: (1.0, 4)},
"720p": {16: (1.0, 2)},
"480p": {1: (1.0, 52), 16: (0.5, 4), 32: (0.0, None)},
"720p": {16: (1.0, 2), 32: (0.0, None)}, # No examples now
"1024": {1: (1.0, 20)},
"1080p": {1: (1.0, 8)},
}
mask_ratios = {
"mask_no": 0.9,
"mask_random": 0.06,
"mask_head": 0.01,
"mask_tail": 0.01,
"mask_head_tail": 0.02,
}
# Define acceleration
num_workers = 4

View file

@ -90,7 +90,7 @@ class STDiT2Block(nn.Module):
# x: [B, (T, S), C]
# masked_x: [B, (T, S), C]
# x_mask: [B, T]
x = rearrange(x, "B (T S) C -> B T S C", T=self.T, S=self.S)
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
x = torch.where(x_mask[:, :, None, None], x, masked_x)
x = rearrange(x, "B T S C -> B (T S) C")