diff --git a/configs/opensora-v1-1/inference/sample-ref.py b/configs/opensora-v1-1/inference/sample-ref.py
index 9fce4a0..5099164 100644
--- a/configs/opensora-v1-1/inference/sample-ref.py
+++ b/configs/opensora-v1-1/inference/sample-ref.py
@@ -1,5 +1,6 @@
 num_frames = 16
-fps = 24 // 3
+frame_interval = 3
+fps = 24
 image_size = (240, 426)
 multi_resolution = "STDiT2"
 
diff --git a/configs/opensora-v1-1/inference/sample.py b/configs/opensora-v1-1/inference/sample.py
index 0d16f51..7e5185a 100644
--- a/configs/opensora-v1-1/inference/sample.py
+++ b/configs/opensora-v1-1/inference/sample.py
@@ -1,5 +1,6 @@
 num_frames = 16
-fps = 24 // 3
+frame_interval = 3
+fps = 24
 image_size = (240, 426)
 multi_resolution = "STDiT2"
 
diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py
index ea0a61c..4c2b316 100644
--- a/opensora/datasets/datasets.py
+++ b/opensora/datasets/datasets.py
@@ -9,6 +9,8 @@ from opensora.registry import DATASETS
 from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video, read_file, temporal_random_crop
 
+IMG_FPS = 120
+
 
 
 @DATASETS.register_module()
 class VideoTextDataset(torch.utils.data.Dataset):
@@ -131,9 +133,12 @@ class VariableVideoTextDataset(VideoTextDataset):
         file_type = self.get_type(path)
         ar = width / height
 
+        video_fps = 24  # default fps
         if file_type == "video":
             # loading
-            vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
+            vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
+            if "video_fps" in infos:
+                video_fps = infos["video_fps"]
 
             # Sampling video frames
             video = temporal_random_crop(vframes, num_frames, self.frame_interval)
@@ -144,6 +149,7 @@ class VariableVideoTextDataset(VideoTextDataset):
         else:
             # loading
             image = pil_loader(path)
+            video_fps = IMG_FPS
 
             # transform
             transform = get_transforms_image(self.transform_name, (height, width))
@@ -154,4 +160,12 @@ class VariableVideoTextDataset(VideoTextDataset):
 
         # TCHW -> CTHW
         video = video.permute(1, 0, 2, 3)
-        return {"video": video, "text": text, "num_frames": num_frames, "height": height, "width": width, "ar": ar}
+        return {
+            "video": video,
+            "text": text,
+            "num_frames": num_frames,
+            "height": height,
+            "width": width,
+            "ar": ar,
+            "fps": video_fps,
+        }
diff --git a/opensora/models/stdit/stdit2.py b/opensora/models/stdit/stdit2.py
index 506ef1e..5e53db4 100644
--- a/opensora/models/stdit/stdit2.py
+++ b/opensora/models/stdit/stdit2.py
@@ -250,6 +250,7 @@ class STDiT2(nn.Module):
         self.csize_embedder = SizeEmbedder(self.hidden_size // 3)
         self.ar_embedder = SizeEmbedder(self.hidden_size // 3)
         self.fl_embedder = SizeEmbedder(self.hidden_size)  # new
+        self.fps_embedder = SizeEmbedder(self.hidden_size)  # new
 
         # init model
         self.initialize_weights()
@@ -281,7 +282,9 @@ class STDiT2(nn.Module):
             W = W // self.patch_size[2]
         return (T, H, W)
 
-    def forward(self, x, timestep, y, mask=None, x_mask=None, num_frames=None, height=None, width=None, ar=None):
+    def forward(
+        self, x, timestep, y, mask=None, x_mask=None, num_frames=None, height=None, width=None, ar=None, fps=None
+    ):
         """
         Forward pass of STDiT.
         Args:
@@ -311,7 +314,9 @@ class STDiT2(nn.Module):
 
         # 3. get number of frames
         fl = num_frames.unsqueeze(1)
+        fps = fps.unsqueeze(1)
         fl = self.fl_embedder(fl, B)
+        fl = fl + self.fps_embedder(fps, B)
 
         # === get dynamic shape size ===
         _, _, Tx, Hx, Wx = x.size()
diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py
index 26ca2bc..0839185 100644
--- a/opensora/utils/config_utils.py
+++ b/opensora/utils/config_utils.py
@@ -84,6 +84,8 @@ def merge_args(cfg, args, training=False):
         cfg["reference_path"] = None
     if "loop" not in cfg:
         cfg["loop"] = 1
+    if "frame_interval" not in cfg:
+        cfg["frame_interval"] = 3
     # - Prompt handling
     if "prompt" not in cfg or cfg["prompt"] is None:
         assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"
diff --git a/scripts/inference-long.py b/scripts/inference-long.py
index 9c7663d..33e3642 100644
--- a/scripts/inference-long.py
+++ b/scripts/inference-long.py
@@ -150,10 +150,12 @@ def main():
     width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size)
     num_frames = torch.tensor([cfg.num_frames], device=device, dtype=dtype).repeat(cfg.batch_size)
     ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size)
+    fps = torch.tensor([cfg.fps], device=device, dtype=dtype).repeat(cfg.batch_size)
     model_args["height"] = height
     model_args["width"] = width
     model_args["num_frames"] = num_frames
     model_args["ar"] = ar
+    model_args["fps"] = fps
 
     # 3.5 reference
     if cfg.reference_path is not None:
@@ -220,7 +222,7 @@ def main():
             video = torch.cat(video_clips_i, dim=1)
             print(f"Prompt: {prompts[i + idx]}")
             save_path = os.path.join(save_dir, f"{sample_name}_{sample_idx}")
-            save_sample(video, fps=cfg.fps, save_path=save_path)
+            save_sample(video, fps=cfg.fps // cfg.frame_interval, save_path=save_path)
             sample_idx += 1
 
 
diff --git a/scripts/inference.py b/scripts/inference.py
index bf2f8e1..7a66120 100644
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -87,10 +87,12 @@ def main():
     width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size)
     num_frames = torch.tensor([cfg.num_frames], device=device, dtype=dtype).repeat(cfg.batch_size)
     ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size)
+    fps = torch.tensor([cfg.fps], device=device, dtype=dtype).repeat(cfg.batch_size)
     model_args["height"] = height
     model_args["width"] = width
     model_args["num_frames"] = num_frames
     model_args["ar"] = ar
+    model_args["fps"] = fps
 
     # ======================================================
     # 4. inference
@@ -122,7 +124,7 @@ def main():
         for idx, sample in enumerate(samples):
             print(f"Prompt: {batch_prompts[idx]}")
             save_path = os.path.join(save_dir, f"{sample_name}_{sample_idx}")
-            save_sample(sample, fps=cfg.fps, save_path=save_path)
+            save_sample(sample, fps=cfg.fps // cfg.frame_interval, save_path=save_path)
             sample_idx += 1
 
 