read video

This commit is contained in:
zhengzangw 2024-05-30 08:09:03 +00:00
parent 9370fbe02d
commit 8a686660bc
6 changed files with 139 additions and 6 deletions

BIN
.swo Normal file

Binary file not shown.

View file

@ -8,7 +8,7 @@ dataset = dict(
)
# Define acceleration
num_workers = 4
num_workers = 0
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"

View file

@ -9,6 +9,7 @@ from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from opensora.registry import DATASETS
from .utils import VID_EXTENSIONS, get_transforms_image, get_transforms_video, read_file, temporal_random_crop
from .read_video import read_video
IMG_FPS = 120
@ -67,7 +68,7 @@ class VideoTextDataset(torch.utils.data.Dataset):
if file_type == "video":
# loading
vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
vframes, _, infos = read_video(filename=path, pts_unit="sec", output_format="TCHW")
if "video_fps" in infos:
video_fps = infos["video_fps"]
@ -145,7 +146,7 @@ class VariableVideoTextDataset(VideoTextDataset):
video_fps = 24 # default fps
if file_type == "video":
# loading
vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
vframes, _, infos = read_video(filename=path, pts_unit="sec", output_format="TCHW")
if "video_fps" in infos:
video_fps = infos["video_fps"]

View file

@ -0,0 +1,129 @@
import math
import os
from fractions import Fraction
from typing import Any, Dict, List, Optional, Tuple, Union
import av
import numpy as np
import torch
from torchvision.io.video import (
_align_audio_frames,
_check_av_available,
_log_api_usage_once,
_read_from_stream,
_video_opt,
)
def read_video(
filename: str,
start_pts: Union[float, Fraction] = 0,
end_pts: Optional[Union[float, Fraction]] = None,
pts_unit: str = "pts",
output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""
Reads a video from a file, returning both the video frames and the audio frames
Args:
filename (str): path to the video file
start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
The start presentation time of the video
end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
The end presentation time
pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
either 'pts' or 'sec'. Defaults to 'pts'.
output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
Returns:
vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(read_video)
output_format = output_format.upper()
if output_format not in ("THWC", "TCHW"):
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
from torchvision import get_video_backend
if not os.path.exists(filename):
raise RuntimeError(f"File not found: {filename}")
if get_video_backend() != "pyav":
vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
else:
_check_av_available()
if end_pts is None:
end_pts = float("inf")
if end_pts < start_pts:
raise ValueError(
f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
)
info = {}
video_frames = []
audio_frames = []
audio_timebase = _video_opt.default_timebase
try:
with av.open(filename, metadata_errors="ignore") as container:
if container.streams.audio:
audio_timebase = container.streams.audio[0].time_base
if container.streams.video:
video_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
video_fps = container.streams.video[0].average_rate
# guard against potentially corrupted files
if video_fps is not None:
info["video_fps"] = float(video_fps)
if container.streams.audio:
audio_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.audio[0],
{"audio": 0},
)
info["audio_fps"] = container.streams.audio[0].rate
except av.AVError:
# TODO raise a warning?
pass
vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
aframes_list = [frame.to_ndarray() for frame in audio_frames]
if vframes_list:
vframes = torch.as_tensor(np.stack(vframes_list))
else:
vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
if aframes_list:
aframes = np.concatenate(aframes_list, 1)
aframes = torch.as_tensor(aframes)
if pts_unit == "sec":
start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
if end_pts != float("inf"):
end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
else:
aframes = torch.empty((1, 0), dtype=torch.float32)
if output_format == "TCHW":
# [T,H,W,C] --> [T,C,H,W]
vframes = vframes.permute(0, 3, 1, 2)
return vframes, aframes, info

View file

@ -1,6 +1,6 @@
colossalai>=0.3.7
mmengine>=0.10.3
pandas>=2.2.2
pandas>=2.0.3
timm==0.9.16
rotary_embedding_torch==0.5.3
ftfy>=6.2.0 # for t5
@ -18,9 +18,9 @@ ipywidgets>=8.1.2
# [training]
wandb>=0.17.0
tensorboard>=2.16.2
tensorboard>=2.14.0
pandarallel>=1.6.5
pyarrow>=16.1.0 # for parquet
# [dev]
pre-commit>=3.7.1
pre-commit>=3.5.0

View file

@ -77,6 +77,7 @@ def main():
sp_size=cfg.get("sp_size", 1),
)
booster = Booster(plugin=plugin)
torch.set_num_threads(1)
# ======================================================
# 2. build dataset and dataloader
@ -103,6 +104,8 @@ def main():
**dataloader_args,
)
num_steps_per_epoch = len(dataloader)
dataiter = iter(dataloader)
next(dataiter)
# ======================================================
# 3. build model