mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-02-22 21:43:19 +01:00
read video
This commit is contained in:
parent
9370fbe02d
commit
8a686660bc
|
|
@ -8,7 +8,7 @@ dataset = dict(
|
|||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
num_workers = 0
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
|
|||
from opensora.registry import DATASETS
|
||||
|
||||
from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video, read_file, temporal_random_crop
|
||||
from .read_video import read_video
|
||||
|
||||
IMG_FPS = 120
|
||||
|
||||
|
|
@ -67,7 +68,7 @@ class VideoTextDataset(torch.utils.data.Dataset):
|
|||
|
||||
if file_type == "video":
|
||||
# loading
|
||||
vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
vframes, _, infos = read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
|
||||
if "video_fps" in infos:
|
||||
video_fps = infos["video_fps"]
|
||||
|
|
@ -145,7 +146,7 @@ class VariableVideoTextDataset(VideoTextDataset):
|
|||
video_fps = 24 # default fps
|
||||
if file_type == "video":
|
||||
# loading
|
||||
vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
vframes, _, infos = read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
if "video_fps" in infos:
|
||||
video_fps = infos["video_fps"]
|
||||
|
||||
|
|
|
|||
129
opensora/datasets/read_video.py
Normal file
129
opensora/datasets/read_video.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
import math
|
||||
import os
|
||||
from fractions import Fraction
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision.io.video import (
|
||||
_align_audio_frames,
|
||||
_check_av_available,
|
||||
_log_api_usage_once,
|
||||
_read_from_stream,
|
||||
_video_opt,
|
||||
)
|
||||
|
||||
|
||||
def read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """
    Reads a video from a file, returning both the video frames and the audio frames.

    Args:
        filename (str): path to the video file
        start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The start presentation time of the video
        end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The end presentation time
        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
            either 'pts' or 'sec'. Defaults to 'pts'.
        output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".

    Returns:
        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)

    Raises:
        RuntimeError: if ``filename`` does not exist.
        ValueError: if ``output_format`` is not "THWC"/"TCHW", or ``end_pts`` < ``start_pts``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video)

    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    from torchvision import get_video_backend

    if not os.path.exists(filename):
        # FIX: the message previously contained no placeholder, so the offending
        # path was never reported; interpolate the filename as upstream torchvision does.
        raise RuntimeError(f"File not found: {filename}")

    if get_video_backend() != "pyav":
        # Non-pyav backend: delegate entirely to torchvision's native reader.
        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
    else:
        _check_av_available()

        if end_pts is None:
            end_pts = float("inf")

        if end_pts < start_pts:
            raise ValueError(
                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
            )

        info = {}
        video_frames = []
        audio_frames = []
        audio_timebase = _video_opt.default_timebase

        try:
            with av.open(filename, metadata_errors="ignore") as container:
                if container.streams.audio:
                    audio_timebase = container.streams.audio[0].time_base
                if container.streams.video:
                    video_frames = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.video[0],
                        {"video": 0},
                    )
                    video_fps = container.streams.video[0].average_rate
                    # guard against potentially corrupted files
                    if video_fps is not None:
                        info["video_fps"] = float(video_fps)

                if container.streams.audio:
                    audio_frames = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.audio[0],
                        {"audio": 0},
                    )
                    info["audio_fps"] = container.streams.audio[0].rate

        except av.AVError:
            # Deliberate best-effort: a corrupted container yields empty
            # frame tensors rather than an exception (matches upstream torchvision).
            # TODO raise a warning?
            pass

        vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
        aframes_list = [frame.to_ndarray() for frame in audio_frames]

        if vframes_list:
            vframes = torch.as_tensor(np.stack(vframes_list))
        else:
            # Empty sentinel shape [0, 1, 1, 3] keeps downstream THWC/TCHW permutes valid.
            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

        if aframes_list:
            aframes = np.concatenate(aframes_list, 1)
            aframes = torch.as_tensor(aframes)
            if pts_unit == "sec":
                # Convert seconds to pts in the audio stream's timebase before trimming.
                start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
                if end_pts != float("inf"):
                    end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
            aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
        else:
            aframes = torch.empty((1, 0), dtype=torch.float32)

    if output_format == "TCHW":
        # [T,H,W,C] --> [T,C,H,W]
        vframes = vframes.permute(0, 3, 1, 2)

    return vframes, aframes, info
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
colossalai>=0.3.7
|
||||
mmengine>=0.10.3
|
||||
pandas>=2.2.2
|
||||
pandas>=2.0.3
|
||||
timm==0.9.16
|
||||
rotary_embedding_torch==0.5.3
|
||||
ftfy>=6.2.0 # for t5
|
||||
|
|
@ -18,9 +18,9 @@ ipywidgets>=8.1.2
|
|||
|
||||
# [training]
|
||||
wandb>=0.17.0
|
||||
tensorboard>=2.16.2
|
||||
tensorboard>=2.14.0
|
||||
pandarallel>=1.6.5
|
||||
pyarrow>=16.1.0 # for parquet
|
||||
|
||||
# [dev]
|
||||
pre-commit>=3.7.1
|
||||
pre-commit>=3.5.0
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ def main():
|
|||
sp_size=cfg.get("sp_size", 1),
|
||||
)
|
||||
booster = Booster(plugin=plugin)
|
||||
torch.set_num_threads(1)
|
||||
|
||||
# ======================================================
|
||||
# 2. build dataset and dataloader
|
||||
|
|
@ -103,6 +104,8 @@ def main():
|
|||
**dataloader_args,
|
||||
)
|
||||
num_steps_per_epoch = len(dataloader)
|
||||
dataiter = iter(dataloader)
|
||||
next(dataiter)
|
||||
|
||||
# ======================================================
|
||||
# 3. build model
|
||||
|
|
|
|||
Loading…
Reference in a new issue