mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 13:14:44 +02:00
update mask & config
This commit is contained in:
parent
3fcda91877
commit
6762476261
|
|
@ -3,12 +3,17 @@ dataset = dict(
|
|||
type="VariableVideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=None,
|
||||
frame_interval=3,
|
||||
frame_interval=1,
|
||||
image_size=(None, None),
|
||||
transform_name="resize_crop",
|
||||
)
|
||||
bucket_config = {
|
||||
"240p": {16: (1.0, 16)}, # 4.6s/it
|
||||
# "240p": {128: (1.0, 2)}, # 4.28s/it
|
||||
# "240p": {64: (1.0, 4)},
|
||||
# "240p": {32: (1.0, 8)}, # 4.6s/it
|
||||
# "240p": {16: (1.0, 16)}, # 4.6s/it
|
||||
# "480p": {16: (1.0, 4)}, # 4.6s/it
|
||||
"720p": {16: (1.0, 2)},
|
||||
# "256": {1: (1.0, 256)}, # 4.5s/it
|
||||
# "512": {1: (1.0, 96)}, # 4.7s/it
|
||||
# "512": {1: (1.0, 128)}, # 6.3s/it
|
||||
|
|
@ -18,10 +23,10 @@ bucket_config = {
|
|||
# "1080p": {1: (1.0, 16)}, # 8.6s/it
|
||||
# "1080p": {1: (1.0, 8)}, # 4.4s/it
|
||||
}
|
||||
mask_ratios = {
|
||||
"mask_no": 0.0,
|
||||
"mask_random": 1.0,
|
||||
}
|
||||
# mask_ratios = {
|
||||
# "mask_no": 0.0,
|
||||
# "mask_random": 1.0,
|
||||
# }
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
|
|
|
|||
|
|
@ -3,19 +3,26 @@ dataset = dict(
|
|||
type="VariableVideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=None,
|
||||
frame_interval=3,
|
||||
frame_interval=1,
|
||||
image_size=(None, None),
|
||||
transform_name="resize_crop",
|
||||
)
|
||||
bucket_config = { # 6s/it
|
||||
"240p": {16: (1.0, 16)},
|
||||
"240p": {16: (1.0, 16), 32: (1.0, 8), 64: (1.0, 4), 128: (1.0, 2)},
|
||||
"256": {1: (1.0, 256)},
|
||||
"512": {1: (1.0, 80)},
|
||||
"480p": {1: (1.0, 52), 16: (1.0, 4)},
|
||||
"720p": {16: (1.0, 2)},
|
||||
"480p": {1: (1.0, 52), 16: (0.5, 4), 32: (0.0, None)},
|
||||
"720p": {16: (1.0, 2), 32: (0.0, None)}, # No examples now
|
||||
"1024": {1: (1.0, 20)},
|
||||
"1080p": {1: (1.0, 8)},
|
||||
}
|
||||
mask_ratios = {
|
||||
"mask_no": 0.9,
|
||||
"mask_random": 0.06,
|
||||
"mask_head": 0.01,
|
||||
"mask_tail": 0.01,
|
||||
"mask_head_tail": 0.02,
|
||||
}
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class STDiT2Block(nn.Module):
|
|||
# x: [B, (T, S), C]
|
||||
# masked_x: [B, (T, S), C]
|
||||
# x_mask: [B, T]
|
||||
x = rearrange(x, "B (T S) C -> B T S C", T=self.T, S=self.S)
|
||||
x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
|
||||
masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
|
||||
x = torch.where(x_mask[:, :, None, None], x, masked_x)
|
||||
x = rearrange(x, "B T S C -> B (T S) C")
|
||||
|
|
|
|||
Loading…
Reference in a new issue