* support fps

* update fps
This commit is contained in:
Zheng Zangwei (Alex Zheng) 2024-04-19 13:18:52 +08:00 committed by GitHub
parent 0e6c15d50b
commit d2a782efac
7 changed files with 34 additions and 7 deletions

View file

@ -1,5 +1,6 @@
num_frames = 16
fps = 24 // 3
frame_interval = 3
fps = 24
image_size = (240, 426)
multi_resolution = "STDiT2"

View file

@ -1,5 +1,6 @@
num_frames = 16
fps = 24 // 3
frame_interval = 3
fps = 24
image_size = (240, 426)
multi_resolution = "STDiT2"

View file

@ -9,6 +9,8 @@ from opensora.registry import DATASETS
from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video, read_file, temporal_random_crop
IMG_FPS = 120
@DATASETS.register_module()
class VideoTextDataset(torch.utils.data.Dataset):
@ -131,9 +133,12 @@ class VariableVideoTextDataset(VideoTextDataset):
file_type = self.get_type(path)
ar = width / height
video_fps = 24 # default fps
if file_type == "video":
# loading
vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
if "video_fps" in infos:
video_fps = infos["video_fps"]
# Sampling video frames
video = temporal_random_crop(vframes, num_frames, self.frame_interval)
@ -144,6 +149,7 @@ class VariableVideoTextDataset(VideoTextDataset):
else:
# loading
image = pil_loader(path)
video_fps = IMG_FPS
# transform
transform = get_transforms_image(self.transform_name, (height, width))
@ -154,4 +160,12 @@ class VariableVideoTextDataset(VideoTextDataset):
# TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
return {"video": video, "text": text, "num_frames": num_frames, "height": height, "width": width, "ar": ar}
return {
"video": video,
"text": text,
"num_frames": num_frames,
"height": height,
"width": width,
"ar": ar,
"fps": video_fps,
}

View file

@ -250,6 +250,7 @@ class STDiT2(nn.Module):
self.csize_embedder = SizeEmbedder(self.hidden_size // 3)
self.ar_embedder = SizeEmbedder(self.hidden_size // 3)
self.fl_embedder = SizeEmbedder(self.hidden_size) # new
self.fps_embedder = SizeEmbedder(self.hidden_size) # new
# init model
self.initialize_weights()
@ -281,7 +282,9 @@ class STDiT2(nn.Module):
W = W // self.patch_size[2]
return (T, H, W)
def forward(self, x, timestep, y, mask=None, x_mask=None, num_frames=None, height=None, width=None, ar=None):
def forward(
self, x, timestep, y, mask=None, x_mask=None, num_frames=None, height=None, width=None, ar=None, fps=None
):
"""
Forward pass of STDiT.
Args:
@ -311,7 +314,9 @@ class STDiT2(nn.Module):
# 3. get number of frames
fl = num_frames.unsqueeze(1)
fps = fps.unsqueeze(1)
fl = self.fl_embedder(fl, B)
fl = fl + self.fps_embedder(fps, B)
# === get dynamic shape size ===
_, _, Tx, Hx, Wx = x.size()

View file

@ -84,6 +84,8 @@ def merge_args(cfg, args, training=False):
cfg["reference_path"] = None
if "loop" not in cfg:
cfg["loop"] = 1
if "frame_interval" not in cfg:
cfg["frame_interval"] = 3
# - Prompt handling
if "prompt" not in cfg or cfg["prompt"] is None:
assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"

View file

@ -150,10 +150,12 @@ def main():
width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size)
num_frames = torch.tensor([cfg.num_frames], device=device, dtype=dtype).repeat(cfg.batch_size)
ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size)
fps = torch.tensor([cfg.fps], device=device, dtype=dtype).repeat(cfg.batch_size)
model_args["height"] = height
model_args["width"] = width
model_args["num_frames"] = num_frames
model_args["ar"] = ar
model_args["fps"] = fps
# 3.5 reference
if cfg.reference_path is not None:
@ -220,7 +222,7 @@ def main():
video = torch.cat(video_clips_i, dim=1)
print(f"Prompt: {prompts[i + idx]}")
save_path = os.path.join(save_dir, f"{sample_name}_{sample_idx}")
save_sample(video, fps=cfg.fps, save_path=save_path)
save_sample(video, fps=cfg.fps // cfg.frame_interval, save_path=save_path)
sample_idx += 1

View file

@ -87,10 +87,12 @@ def main():
width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size)
num_frames = torch.tensor([cfg.num_frames], device=device, dtype=dtype).repeat(cfg.batch_size)
ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size)
fps = torch.tensor([cfg.fps], device=device, dtype=dtype).repeat(cfg.batch_size)
model_args["height"] = height
model_args["width"] = width
model_args["num_frames"] = num_frames
model_args["ar"] = ar
model_args["fps"] = fps
# ======================================================
# 4. inference
@ -122,7 +124,7 @@ def main():
for idx, sample in enumerate(samples):
print(f"Prompt: {batch_prompts[idx]}")
save_path = os.path.join(save_dir, f"{sample_name}_{sample_idx}")
save_sample(sample, fps=cfg.fps, save_path=save_path)
save_sample(sample, fps=cfg.fps // cfg.frame_interval, save_path=save_path)
sample_idx += 1