mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
[feat] reduce memory leakage in dataloader and pyav
This commit is contained in:
parent
f3c5f5f533
commit
dec17bd990
58
configs/opensora-v1-2/train/demo_360p.py
Normal file
58
configs/opensora-v1-2/train/demo_360p.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
# Dataset settings
|
||||
dataset = dict(
|
||||
type="VariableVideoTextDataset",
|
||||
transform_name="resize_crop",
|
||||
)
|
||||
|
||||
# webvid
|
||||
bucket_config = {"360p": {102: (1.0, 5)}}
|
||||
grad_checkpoint = True
|
||||
|
||||
# Acceleration settings
|
||||
num_workers = 8
|
||||
num_bucket_build_workers = 16
|
||||
dtype = "bf16"
|
||||
plugin = "zero2"
|
||||
|
||||
# Model settings
|
||||
model = dict(
|
||||
type="STDiT3-XL/2",
|
||||
from_pretrained=None,
|
||||
qk_norm=True,
|
||||
enable_flash_attn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
freeze_y_embedder=True,
|
||||
)
|
||||
vae = dict(
|
||||
type="OpenSoraVAE_V1_2",
|
||||
from_pretrained="hpcai-tech/OpenSora-VAE-v1.2",
|
||||
micro_frame_size=17,
|
||||
micro_batch_size=4,
|
||||
)
|
||||
text_encoder = dict(
|
||||
type="t5",
|
||||
from_pretrained="DeepFloyd/t5-v1_1-xxl",
|
||||
model_max_length=300,
|
||||
shardformer=True,
|
||||
)
|
||||
scheduler = dict(
|
||||
type="rflow",
|
||||
use_timestep_transform=True,
|
||||
sample_method="logit-normal",
|
||||
)
|
||||
|
||||
# Log settings
|
||||
seed = 42
|
||||
outputs = "outputs"
|
||||
wandb = False
|
||||
epochs = 1000
|
||||
log_every = 10
|
||||
ckpt_every = 200
|
||||
|
||||
# optimization settings
|
||||
load = None
|
||||
grad_clip = 1.0
|
||||
lr = 1e-4
|
||||
ema_decay = 0.99
|
||||
adam_eps = 1e-15
|
||||
warmup_steps = 1000
|
||||
|
|
@ -9,7 +9,7 @@ bucket_config = {"480p": {51: (0.5, 5)}}
|
|||
grad_checkpoint = True
|
||||
|
||||
# Acceleration settings
|
||||
num_workers = 0
|
||||
num_workers = 8
|
||||
num_bucket_build_workers = 16
|
||||
dtype = "bf16"
|
||||
plugin = "zero2"
|
||||
|
|
@ -41,21 +41,6 @@ scheduler = dict(
|
|||
sample_method="logit-normal",
|
||||
)
|
||||
|
||||
# Mask settings
|
||||
# 25%
|
||||
mask_ratios = {
|
||||
"random": 0.01,
|
||||
"intepolate": 0.002,
|
||||
"quarter_random": 0.002,
|
||||
"quarter_head": 0.002,
|
||||
"quarter_tail": 0.002,
|
||||
"quarter_head_tail": 0.002,
|
||||
"image_random": 0.0,
|
||||
"image_head": 0.22,
|
||||
"image_tail": 0.005,
|
||||
"image_head_tail": 0.005,
|
||||
}
|
||||
|
||||
# Log settings
|
||||
seed = 42
|
||||
outputs = "outputs"
|
||||
|
|
@ -34,6 +34,7 @@ def prepare_dataloader(
|
|||
process_group: Optional[ProcessGroup] = None,
|
||||
bucket_config=None,
|
||||
num_bucket_build_workers=1,
|
||||
prefetch_factor=None,
|
||||
**kwargs,
|
||||
):
|
||||
_kwargs = kwargs.copy()
|
||||
|
|
@ -57,6 +58,7 @@ def prepare_dataloader(
|
|||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn_default,
|
||||
prefetch_factor=prefetch_factor,
|
||||
**_kwargs,
|
||||
),
|
||||
batch_sampler,
|
||||
|
|
@ -79,6 +81,7 @@ def prepare_dataloader(
|
|||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn_default,
|
||||
prefetch_factor=prefetch_factor,
|
||||
**_kwargs,
|
||||
),
|
||||
sampler,
|
||||
|
|
@ -98,6 +101,7 @@ def prepare_dataloader(
|
|||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn_batch,
|
||||
prefetch_factor=prefetch_factor,
|
||||
**_kwargs,
|
||||
),
|
||||
sampler,
|
||||
|
|
|
|||
|
|
@ -151,9 +151,11 @@ class VariableVideoTextDataset(VideoTextDataset):
|
|||
|
||||
# Sampling video frames
|
||||
video = temporal_random_crop(vframes, num_frames, self.frame_interval)
|
||||
video = video.clone()
|
||||
del vframes
|
||||
|
||||
video_fps = video_fps // self.frame_interval
|
||||
|
||||
|
||||
# transform
|
||||
transform = get_transforms_video(self.transform_name, (height, width))
|
||||
video = transform(video) # T C H W
|
||||
|
|
@ -169,7 +171,7 @@ class VariableVideoTextDataset(VideoTextDataset):
|
|||
# repeat
|
||||
video = image.unsqueeze(0)
|
||||
|
||||
# TCHW -> CTHW
|
||||
# # TCHW -> CTHW
|
||||
video = video.permute(1, 0, 2, 3)
|
||||
ret = {
|
||||
"video": video,
|
||||
|
|
|
|||
|
|
@ -1,20 +1,19 @@
|
|||
import gc
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from fractions import Fraction
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import av
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision.io.video import (
|
||||
_align_audio_frames,
|
||||
_check_av_available,
|
||||
_log_api_usage_once,
|
||||
_read_from_stream,
|
||||
_video_opt,
|
||||
)
|
||||
from torchvision import get_video_backend
|
||||
from torchvision.io.video import _check_av_available
|
||||
|
||||
MAX_NUM_FRAMES = 2500
|
||||
|
||||
|
||||
def read_video_av(
|
||||
|
|
@ -27,6 +26,13 @@ def read_video_av(
|
|||
"""
|
||||
Reads a video from a file, returning both the video frames and the audio frames
|
||||
|
||||
This method is modified from torchvision.io.video.read_video, with the following changes:
|
||||
|
||||
1. will not extract audio frames and return empty for aframes
|
||||
2. remove checks and only support pyav
|
||||
3. add container.close() and gc.collect() to avoid thread leakage
|
||||
4. try our best to avoid memory leak
|
||||
|
||||
Args:
|
||||
filename (str): path to the video file
|
||||
start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
|
||||
|
|
@ -42,99 +48,162 @@ def read_video_av(
|
|||
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
|
||||
info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
|
||||
"""
|
||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
||||
_log_api_usage_once(read_video)
|
||||
|
||||
# format
|
||||
output_format = output_format.upper()
|
||||
if output_format not in ("THWC", "TCHW"):
|
||||
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
|
||||
|
||||
from torchvision import get_video_backend
|
||||
|
||||
# file existence
|
||||
if not os.path.exists(filename):
|
||||
raise RuntimeError(f"File not found: {filename}")
|
||||
# backend check
|
||||
assert get_video_backend() == "pyav", "pyav backend is required for read_video_av"
|
||||
_check_av_available()
|
||||
# end_pts check
|
||||
if end_pts is None:
|
||||
end_pts = float("inf")
|
||||
if end_pts < start_pts:
|
||||
raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
|
||||
|
||||
if get_video_backend() != "pyav":
|
||||
vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
|
||||
else:
|
||||
_check_av_available()
|
||||
# == get video info ==
|
||||
info = {}
|
||||
# TODO: creating an container leads to memory leak (1G for 8 workers 1 GPU)
|
||||
container = av.open(filename, metadata_errors="ignore")
|
||||
# fps
|
||||
video_fps = container.streams.video[0].average_rate
|
||||
# guard against potentially corrupted files
|
||||
if video_fps is not None:
|
||||
info["video_fps"] = float(video_fps)
|
||||
iter_video = container.decode(**{"video": 0})
|
||||
frame = next(iter_video).to_rgb().to_ndarray()
|
||||
height, width = frame.shape[:2]
|
||||
total_frames = container.streams.video[0].frames
|
||||
if total_frames == 0:
|
||||
total_frames = MAX_NUM_FRAMES
|
||||
warnings.warn(f"total_frames is 0, using {MAX_NUM_FRAMES} as a fallback")
|
||||
container.close()
|
||||
del container
|
||||
|
||||
if end_pts is None:
|
||||
end_pts = float("inf")
|
||||
# HACK: must create before iterating stream
|
||||
# use np.zeros will not actually allocate memory
|
||||
# use np.ones will lead to a little memory leak
|
||||
video_frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8)
|
||||
|
||||
if end_pts < start_pts:
|
||||
raise ValueError(
|
||||
f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
|
||||
)
|
||||
|
||||
info = {}
|
||||
video_frames = []
|
||||
audio_frames = []
|
||||
audio_timebase = _video_opt.default_timebase
|
||||
|
||||
container = av.open(filename, metadata_errors="ignore")
|
||||
try:
|
||||
if container.streams.audio:
|
||||
audio_timebase = container.streams.audio[0].time_base
|
||||
if container.streams.video:
|
||||
video_frames = _read_from_stream(
|
||||
container,
|
||||
start_pts,
|
||||
end_pts,
|
||||
pts_unit,
|
||||
container.streams.video[0],
|
||||
{"video": 0},
|
||||
)
|
||||
video_fps = container.streams.video[0].average_rate
|
||||
# guard against potentially corrupted files
|
||||
if video_fps is not None:
|
||||
info["video_fps"] = float(video_fps)
|
||||
|
||||
if container.streams.audio:
|
||||
audio_frames = _read_from_stream(
|
||||
container,
|
||||
start_pts,
|
||||
end_pts,
|
||||
pts_unit,
|
||||
container.streams.audio[0],
|
||||
{"audio": 0},
|
||||
)
|
||||
info["audio_fps"] = container.streams.audio[0].rate
|
||||
except av.AVError:
|
||||
# TODO raise a warning?
|
||||
pass
|
||||
finally:
|
||||
container.close()
|
||||
del container
|
||||
# NOTE: manually garbage collect to close pyav threads
|
||||
gc.collect()
|
||||
|
||||
vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
|
||||
aframes_list = [frame.to_ndarray() for frame in audio_frames]
|
||||
|
||||
if vframes_list:
|
||||
vframes = torch.as_tensor(np.stack(vframes_list))
|
||||
else:
|
||||
vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
|
||||
|
||||
if aframes_list:
|
||||
aframes = np.concatenate(aframes_list, 1)
|
||||
aframes = torch.as_tensor(aframes)
|
||||
if pts_unit == "sec":
|
||||
start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
|
||||
if end_pts != float("inf"):
|
||||
end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
|
||||
aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
|
||||
else:
|
||||
aframes = torch.empty((1, 0), dtype=torch.float32)
|
||||
# == read ==
|
||||
# TODO: The reading has memory leak (4G for 8 workers 1 GPU)
|
||||
container = av.open(filename, metadata_errors="ignore")
|
||||
assert container.streams.video is not None
|
||||
video_frames = _read_from_stream(
|
||||
video_frames,
|
||||
container,
|
||||
start_pts,
|
||||
end_pts,
|
||||
pts_unit,
|
||||
container.streams.video[0],
|
||||
{"video": 0},
|
||||
)
|
||||
|
||||
vframes = torch.from_numpy(video_frames).clone()
|
||||
del video_frames
|
||||
if output_format == "TCHW":
|
||||
# [T,H,W,C] --> [T,C,H,W]
|
||||
vframes = vframes.permute(0, 3, 1, 2)
|
||||
|
||||
aframes = torch.empty((1, 0), dtype=torch.float32)
|
||||
return vframes, aframes, info
|
||||
|
||||
|
||||
def _read_from_stream(
|
||||
video_frames,
|
||||
container: "av.container.Container",
|
||||
start_offset: float,
|
||||
end_offset: float,
|
||||
pts_unit: str,
|
||||
stream: "av.stream.Stream",
|
||||
stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
|
||||
) -> List["av.frame.Frame"]:
|
||||
|
||||
if pts_unit == "sec":
|
||||
# TODO: we should change all of this from ground up to simply take
|
||||
# sec and convert to MS in C++
|
||||
start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
|
||||
if end_offset != float("inf"):
|
||||
end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
|
||||
else:
|
||||
warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
|
||||
|
||||
should_buffer = True
|
||||
max_buffer_size = 5
|
||||
if stream.type == "video":
|
||||
# DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
|
||||
# so need to buffer some extra frames to sort everything
|
||||
# properly
|
||||
extradata = stream.codec_context.extradata
|
||||
# overly complicated way of finding if `divx_packed` is set, following
|
||||
# https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
|
||||
if extradata and b"DivX" in extradata:
|
||||
# can't use regex directly because of some weird characters sometimes...
|
||||
pos = extradata.find(b"DivX")
|
||||
d = extradata[pos:]
|
||||
o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
|
||||
if o is None:
|
||||
o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
|
||||
if o is not None:
|
||||
should_buffer = o.group(3) == b"p"
|
||||
seek_offset = start_offset
|
||||
# some files don't seek to the right location, so better be safe here
|
||||
seek_offset = max(seek_offset - 1, 0)
|
||||
if should_buffer:
|
||||
# FIXME this is kind of a hack, but we will jump to the previous keyframe
|
||||
# so this will be safe
|
||||
seek_offset = max(seek_offset - max_buffer_size, 0)
|
||||
try:
|
||||
# TODO check if stream needs to always be the video stream here or not
|
||||
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
|
||||
except av.AVError:
|
||||
# TODO add some warnings in this case
|
||||
# print("Corrupted file?", container.name)
|
||||
return []
|
||||
|
||||
# == main ==
|
||||
buffer_count = 0
|
||||
frames_pts = []
|
||||
cnt = 0
|
||||
for _idx, frame in enumerate(container.decode(**stream_name)):
|
||||
frames_pts.append(frame.pts)
|
||||
video_frames[cnt] = frame.to_rgb().to_ndarray()
|
||||
cnt += 1
|
||||
if cnt >= len(video_frames):
|
||||
break
|
||||
if frame.pts >= end_offset:
|
||||
if should_buffer and buffer_count < max_buffer_size:
|
||||
buffer_count += 1
|
||||
continue
|
||||
break
|
||||
|
||||
# garbage collection for thread leakage
|
||||
container.close()
|
||||
del container
|
||||
# NOTE: manually garbage collect to close pyav threads
|
||||
gc.collect()
|
||||
|
||||
# ensure that the results are sorted wrt the pts
|
||||
# NOTE: here we assert frames_pts is sorted
|
||||
start_ptr = 0
|
||||
end_ptr = cnt
|
||||
while start_ptr < end_ptr and frames_pts[start_ptr] < start_offset:
|
||||
start_ptr += 1
|
||||
while start_ptr < end_ptr and frames_pts[end_ptr - 1] > end_offset:
|
||||
end_ptr -= 1
|
||||
if start_offset > 0 and start_offset not in frames_pts[start_ptr:end_ptr]:
|
||||
# if there is no frame that exactly matches the pts of start_offset
|
||||
# add the last frame smaller than start_offset, to guarantee that
|
||||
# we will have all the necessary data. This is most useful for audio
|
||||
if start_ptr > 0:
|
||||
start_ptr -= 1
|
||||
result = video_frames[start_ptr:end_ptr].copy()
|
||||
return result
|
||||
|
||||
|
||||
def read_video_cv2(video_path):
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
|
||||
|
|
@ -181,8 +250,3 @@ def read_video(video_path, backend="av"):
|
|||
raise ValueError
|
||||
|
||||
return vframes, vinfo
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
vframes, vinfo = read_video("./data/colors/9.mp4", backend="cv2")
|
||||
x = 0
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ def main():
|
|||
drop_last=True,
|
||||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
prefetch_factor=cfg.get("prefetch_factor", None),
|
||||
)
|
||||
dataloader, sampler = prepare_dataloader(
|
||||
bucket_config=cfg.get("bucket_config", None),
|
||||
|
|
@ -247,7 +248,6 @@ def main():
|
|||
with Timer("move data") as move_data_t:
|
||||
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
|
||||
y = batch.pop("text")
|
||||
timer_list.append(move_data_t)
|
||||
|
||||
# == visual and text encoding ==
|
||||
with Timer("encode") as encode_t:
|
||||
|
|
|
|||
|
|
@ -1,73 +0,0 @@
|
|||
"""
|
||||
Implementation of Net2Net (http://arxiv.org/abs/1511.05641)
|
||||
Numpy modules for Net2Net
|
||||
- Net2Wider
|
||||
- Net2Deeper
|
||||
|
||||
Written by Kyunghyun Paeng
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def net2net(teach_param, stu_param):
|
||||
# teach param with shape (a, b)
|
||||
# stu param with shape (c, d)
|
||||
# net to net (a, b) -> (c, d) where c >= a and d >= b
|
||||
teach_param_shape = teach_param.shape
|
||||
stu_param_shape = stu_param.shape
|
||||
|
||||
if len(stu_param_shape) > 2:
|
||||
teach_param = teach_param.reshape(teach_param_shape[0], -1)
|
||||
stu_param = stu_param.reshape(stu_param_shape[0], -1)
|
||||
|
||||
assert len(stu_param.shape) == 1 or len(stu_param.shape) == 2, "teach_param and stu_param must be 2-dim array"
|
||||
assert len(teach_param_shape) == len(stu_param_shape), "teach_param and stu_param must have same dimension"
|
||||
|
||||
if len(teach_param_shape) == 1:
|
||||
stu_param[: teach_param_shape[0]] = teach_param
|
||||
elif len(teach_param_shape) == 2:
|
||||
stu_param[: teach_param_shape[0], : teach_param_shape[1]] = teach_param
|
||||
else:
|
||||
breakpoint()
|
||||
|
||||
if stu_param.shape != stu_param_shape:
|
||||
stu_param = stu_param.reshape(stu_param_shape)
|
||||
|
||||
return stu_param
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""Net2Net Class Test"""
|
||||
|
||||
import torch
|
||||
|
||||
from opensora.models.pixart import PixArt_1B_2
|
||||
|
||||
model = PixArt_1B_2(no_temporal_pos_emb=True, space_scale=4, enable_flash_attn=True, enable_layernorm_kernel=True)
|
||||
print("load model done")
|
||||
|
||||
ckpt = torch.load("/home/zhouyukun/projs/opensora/pretrained_models/PixArt-Sigma-XL-2-2K-MS.pth")
|
||||
print("load ckpt done")
|
||||
|
||||
ckpt = ckpt["state_dict"]
|
||||
ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
|
||||
|
||||
missing_keys = []
|
||||
for name, module in model.named_parameters():
|
||||
if name in ckpt:
|
||||
teach_param = ckpt[name].data
|
||||
stu_param = module.data
|
||||
stu_param = net2net(teach_param, stu_param)
|
||||
|
||||
module.data = stu_param
|
||||
|
||||
print("processing layer: ", name, "shape: ", module.size())
|
||||
|
||||
else:
|
||||
# print("Missing key: ", name)
|
||||
missing_keys.append(name)
|
||||
|
||||
print(missing_keys)
|
||||
|
||||
breakpoint()
|
||||
torch.save({"state_dict": model.state_dict()}, "PixArt-1B-2.pth")
|
||||
Loading…
Reference in a new issue