mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
parent
0e6c15d50b
commit
d2a782efac
|
|
@ -1,5 +1,6 @@
|
|||
num_frames = 16
|
||||
fps = 24 // 3
|
||||
frame_interval = 3
|
||||
fps = 24
|
||||
image_size = (240, 426)
|
||||
multi_resolution = "STDiT2"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
num_frames = 16
|
||||
fps = 24 // 3
|
||||
frame_interval = 3
|
||||
fps = 24
|
||||
image_size = (240, 426)
|
||||
multi_resolution = "STDiT2"
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from opensora.registry import DATASETS
|
|||
|
||||
from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video, read_file, temporal_random_crop
|
||||
|
||||
IMG_FPS = 120
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class VideoTextDataset(torch.utils.data.Dataset):
|
||||
|
|
@ -131,9 +133,12 @@ class VariableVideoTextDataset(VideoTextDataset):
|
|||
file_type = self.get_type(path)
|
||||
ar = width / height
|
||||
|
||||
video_fps = 24 # default fps
|
||||
if file_type == "video":
|
||||
# loading
|
||||
vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
if "video_fps" in infos:
|
||||
video_fps = infos["video_fps"]
|
||||
|
||||
# Sampling video frames
|
||||
video = temporal_random_crop(vframes, num_frames, self.frame_interval)
|
||||
|
|
@ -144,6 +149,7 @@ class VariableVideoTextDataset(VideoTextDataset):
|
|||
else:
|
||||
# loading
|
||||
image = pil_loader(path)
|
||||
video_fps = IMG_FPS
|
||||
|
||||
# transform
|
||||
transform = get_transforms_image(self.transform_name, (height, width))
|
||||
|
|
@ -154,4 +160,12 @@ class VariableVideoTextDataset(VideoTextDataset):
|
|||
|
||||
# TCHW -> CTHW
|
||||
video = video.permute(1, 0, 2, 3)
|
||||
return {"video": video, "text": text, "num_frames": num_frames, "height": height, "width": width, "ar": ar}
|
||||
return {
|
||||
"video": video,
|
||||
"text": text,
|
||||
"num_frames": num_frames,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"ar": ar,
|
||||
"fps": video_fps,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -250,6 +250,7 @@ class STDiT2(nn.Module):
|
|||
self.csize_embedder = SizeEmbedder(self.hidden_size // 3)
|
||||
self.ar_embedder = SizeEmbedder(self.hidden_size // 3)
|
||||
self.fl_embedder = SizeEmbedder(self.hidden_size) # new
|
||||
self.fps_embedder = SizeEmbedder(self.hidden_size) # new
|
||||
|
||||
# init model
|
||||
self.initialize_weights()
|
||||
|
|
@ -281,7 +282,9 @@ class STDiT2(nn.Module):
|
|||
W = W // self.patch_size[2]
|
||||
return (T, H, W)
|
||||
|
||||
def forward(self, x, timestep, y, mask=None, x_mask=None, num_frames=None, height=None, width=None, ar=None):
|
||||
def forward(
|
||||
self, x, timestep, y, mask=None, x_mask=None, num_frames=None, height=None, width=None, ar=None, fps=None
|
||||
):
|
||||
"""
|
||||
Forward pass of STDiT.
|
||||
Args:
|
||||
|
|
@ -311,7 +314,9 @@ class STDiT2(nn.Module):
|
|||
|
||||
# 3. get number of frames
|
||||
fl = num_frames.unsqueeze(1)
|
||||
fps = fps.unsqueeze(1)
|
||||
fl = self.fl_embedder(fl, B)
|
||||
fl = fl + self.fps_embedder(fps, B)
|
||||
|
||||
# === get dynamic shape size ===
|
||||
_, _, Tx, Hx, Wx = x.size()
|
||||
|
|
|
|||
|
|
@ -84,6 +84,8 @@ def merge_args(cfg, args, training=False):
|
|||
cfg["reference_path"] = None
|
||||
if "loop" not in cfg:
|
||||
cfg["loop"] = 1
|
||||
if "frame_interval" not in cfg:
|
||||
cfg["frame_interval"] = 3
|
||||
# - Prompt handling
|
||||
if "prompt" not in cfg or cfg["prompt"] is None:
|
||||
assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"
|
||||
|
|
|
|||
|
|
@ -150,10 +150,12 @@ def main():
|
|||
width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size)
|
||||
num_frames = torch.tensor([cfg.num_frames], device=device, dtype=dtype).repeat(cfg.batch_size)
|
||||
ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size)
|
||||
fps = torch.tensor([cfg.fps], device=device, dtype=dtype).repeat(cfg.batch_size)
|
||||
model_args["height"] = height
|
||||
model_args["width"] = width
|
||||
model_args["num_frames"] = num_frames
|
||||
model_args["ar"] = ar
|
||||
model_args["fps"] = fps
|
||||
|
||||
# 3.5 reference
|
||||
if cfg.reference_path is not None:
|
||||
|
|
@ -220,7 +222,7 @@ def main():
|
|||
video = torch.cat(video_clips_i, dim=1)
|
||||
print(f"Prompt: {prompts[i + idx]}")
|
||||
save_path = os.path.join(save_dir, f"{sample_name}_{sample_idx}")
|
||||
save_sample(video, fps=cfg.fps, save_path=save_path)
|
||||
save_sample(video, fps=cfg.fps // cfg.frame_interval, save_path=save_path)
|
||||
sample_idx += 1
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -87,10 +87,12 @@ def main():
|
|||
width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size)
|
||||
num_frames = torch.tensor([cfg.num_frames], device=device, dtype=dtype).repeat(cfg.batch_size)
|
||||
ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size)
|
||||
fps = torch.tensor([cfg.fps], device=device, dtype=dtype).repeat(cfg.batch_size)
|
||||
model_args["height"] = height
|
||||
model_args["width"] = width
|
||||
model_args["num_frames"] = num_frames
|
||||
model_args["ar"] = ar
|
||||
model_args["fps"] = fps
|
||||
|
||||
# ======================================================
|
||||
# 4. inference
|
||||
|
|
@ -122,7 +124,7 @@ def main():
|
|||
for idx, sample in enumerate(samples):
|
||||
print(f"Prompt: {batch_prompts[idx]}")
|
||||
save_path = os.path.join(save_dir, f"{sample_name}_{sample_idx}")
|
||||
save_sample(sample, fps=cfg.fps, save_path=save_path)
|
||||
save_sample(sample, fps=cfg.fps // cfg.frame_interval, save_path=save_path)
|
||||
sample_idx += 1
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue