[feat] reduce memory leakage in dataloader and pyav

This commit is contained in:
zhengzangw 2024-06-21 18:23:30 +00:00
parent f3c5f5f533
commit dec17bd990
8 changed files with 222 additions and 182 deletions

View file

@ -0,0 +1,58 @@
# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
dataset = {
    "type": "VariableVideoTextDataset",
    "transform_name": "resize_crop",
}

# webvid
# NOTE(review): presumably {resolution: {num_frames: (probability, batch_size)}}
# -- confirm against the bucket sampler before relying on this reading.
bucket_config = {
    "360p": {
        102: (1.0, 5),
    },
}
grad_checkpoint = True

# ---------------------------------------------------------------------------
# Acceleration
# ---------------------------------------------------------------------------
num_workers = 8
num_bucket_build_workers = 16
dtype = "bf16"
plugin = "zero2"

# ---------------------------------------------------------------------------
# Model, VAE, text encoder and scheduler
# ---------------------------------------------------------------------------
model = {
    "type": "STDiT3-XL/2",
    "from_pretrained": None,
    "qk_norm": True,
    "enable_flash_attn": True,
    "enable_layernorm_kernel": True,
    "freeze_y_embedder": True,
}
vae = {
    "type": "OpenSoraVAE_V1_2",
    "from_pretrained": "hpcai-tech/OpenSora-VAE-v1.2",
    "micro_frame_size": 17,
    "micro_batch_size": 4,
}
text_encoder = {
    "type": "t5",
    "from_pretrained": "DeepFloyd/t5-v1_1-xxl",
    "model_max_length": 300,
    "shardformer": True,
}
scheduler = {
    "type": "rflow",
    "use_timestep_transform": True,
    "sample_method": "logit-normal",
}

# ---------------------------------------------------------------------------
# Logging / checkpointing
# ---------------------------------------------------------------------------
seed = 42
outputs = "outputs"
wandb = False
epochs = 1000
log_every = 10
ckpt_every = 200

# ---------------------------------------------------------------------------
# Optimization
# ---------------------------------------------------------------------------
load = None
grad_clip = 1.0
lr = 1e-4
ema_decay = 0.99
adam_eps = 1e-15
warmup_steps = 1000

View file

@ -9,7 +9,7 @@ bucket_config = {"480p": {51: (0.5, 5)}}
grad_checkpoint = True
# Acceleration settings
num_workers = 0
num_workers = 8
num_bucket_build_workers = 16
dtype = "bf16"
plugin = "zero2"
@ -41,21 +41,6 @@ scheduler = dict(
sample_method="logit-normal",
)
# Mask settings
# 25%
mask_ratios = {
"random": 0.01,
"intepolate": 0.002,
"quarter_random": 0.002,
"quarter_head": 0.002,
"quarter_tail": 0.002,
"quarter_head_tail": 0.002,
"image_random": 0.0,
"image_head": 0.22,
"image_tail": 0.005,
"image_head_tail": 0.005,
}
# Log settings
seed = 42
outputs = "outputs"

View file

@ -34,6 +34,7 @@ def prepare_dataloader(
process_group: Optional[ProcessGroup] = None,
bucket_config=None,
num_bucket_build_workers=1,
prefetch_factor=None,
**kwargs,
):
_kwargs = kwargs.copy()
@ -57,6 +58,7 @@ def prepare_dataloader(
pin_memory=pin_memory,
num_workers=num_workers,
collate_fn=collate_fn_default,
prefetch_factor=prefetch_factor,
**_kwargs,
),
batch_sampler,
@ -79,6 +81,7 @@ def prepare_dataloader(
pin_memory=pin_memory,
num_workers=num_workers,
collate_fn=collate_fn_default,
prefetch_factor=prefetch_factor,
**_kwargs,
),
sampler,
@ -98,6 +101,7 @@ def prepare_dataloader(
pin_memory=pin_memory,
num_workers=num_workers,
collate_fn=collate_fn_batch,
prefetch_factor=prefetch_factor,
**_kwargs,
),
sampler,

View file

@ -151,9 +151,11 @@ class VariableVideoTextDataset(VideoTextDataset):
# Sampling video frames
video = temporal_random_crop(vframes, num_frames, self.frame_interval)
video = video.clone()
del vframes
video_fps = video_fps // self.frame_interval
# transform
transform = get_transforms_video(self.transform_name, (height, width))
video = transform(video) # T C H W
@ -169,7 +171,7 @@ class VariableVideoTextDataset(VideoTextDataset):
# repeat
video = image.unsqueeze(0)
# TCHW -> CTHW
# # TCHW -> CTHW
video = video.permute(1, 0, 2, 3)
ret = {
"video": video,

View file

@ -1,20 +1,19 @@
import gc
import math
import os
import re
import warnings
from fractions import Fraction
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import av
import cv2
import numpy as np
import torch
from torchvision.io.video import (
_align_audio_frames,
_check_av_available,
_log_api_usage_once,
_read_from_stream,
_video_opt,
)
from torchvision import get_video_backend
from torchvision.io.video import _check_av_available
MAX_NUM_FRAMES = 2500
def read_video_av(
@ -27,6 +26,13 @@ def read_video_av(
"""
Reads a video from a file, returning both the video frames and the audio frames
This method is modified from torchvision.io.video.read_video, with the following changes:
1. will not extract audio frames and return empty for aframes
2. remove checks and only support pyav
3. add container.close() and gc.collect() to avoid thread leakage
4. try our best to avoid memory leak
Args:
filename (str): path to the video file
start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
@ -42,99 +48,162 @@ def read_video_av(
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(read_video)
# format
output_format = output_format.upper()
if output_format not in ("THWC", "TCHW"):
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
from torchvision import get_video_backend
# file existence
if not os.path.exists(filename):
raise RuntimeError(f"File not found: {filename}")
# backend check
assert get_video_backend() == "pyav", "pyav backend is required for read_video_av"
_check_av_available()
# end_pts check
if end_pts is None:
end_pts = float("inf")
if end_pts < start_pts:
raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
if get_video_backend() != "pyav":
vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
else:
_check_av_available()
# == get video info ==
info = {}
# TODO: creating a container leads to a memory leak (1G for 8 workers 1 GPU)
container = av.open(filename, metadata_errors="ignore")
# fps
video_fps = container.streams.video[0].average_rate
# guard against potentially corrupted files
if video_fps is not None:
info["video_fps"] = float(video_fps)
iter_video = container.decode(**{"video": 0})
frame = next(iter_video).to_rgb().to_ndarray()
height, width = frame.shape[:2]
total_frames = container.streams.video[0].frames
if total_frames == 0:
total_frames = MAX_NUM_FRAMES
warnings.warn(f"total_frames is 0, using {MAX_NUM_FRAMES} as a fallback")
container.close()
del container
if end_pts is None:
end_pts = float("inf")
# HACK: must create before iterating stream
# use np.zeros will not actually allocate memory
# use np.ones will lead to a little memory leak
video_frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8)
if end_pts < start_pts:
raise ValueError(
f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
)
info = {}
video_frames = []
audio_frames = []
audio_timebase = _video_opt.default_timebase
container = av.open(filename, metadata_errors="ignore")
try:
if container.streams.audio:
audio_timebase = container.streams.audio[0].time_base
if container.streams.video:
video_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
video_fps = container.streams.video[0].average_rate
# guard against potentially corrupted files
if video_fps is not None:
info["video_fps"] = float(video_fps)
if container.streams.audio:
audio_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.audio[0],
{"audio": 0},
)
info["audio_fps"] = container.streams.audio[0].rate
except av.AVError:
# TODO raise a warning?
pass
finally:
container.close()
del container
# NOTE: manually garbage collect to close pyav threads
gc.collect()
vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
aframes_list = [frame.to_ndarray() for frame in audio_frames]
if vframes_list:
vframes = torch.as_tensor(np.stack(vframes_list))
else:
vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
if aframes_list:
aframes = np.concatenate(aframes_list, 1)
aframes = torch.as_tensor(aframes)
if pts_unit == "sec":
start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
if end_pts != float("inf"):
end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
else:
aframes = torch.empty((1, 0), dtype=torch.float32)
# == read ==
# TODO: The reading has memory leak (4G for 8 workers 1 GPU)
container = av.open(filename, metadata_errors="ignore")
assert container.streams.video is not None
video_frames = _read_from_stream(
video_frames,
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
vframes = torch.from_numpy(video_frames).clone()
del video_frames
if output_format == "TCHW":
# [T,H,W,C] --> [T,C,H,W]
vframes = vframes.permute(0, 3, 1, 2)
aframes = torch.empty((1, 0), dtype=torch.float32)
return vframes, aframes, info
def _read_from_stream(
video_frames,
container: "av.container.Container",
start_offset: float,
end_offset: float,
pts_unit: str,
stream: "av.stream.Stream",
stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
) -> List["av.frame.Frame"]:
if pts_unit == "sec":
# TODO: we should change all of this from ground up to simply take
# sec and convert to MS in C++
start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
if end_offset != float("inf"):
end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
else:
warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
should_buffer = True
max_buffer_size = 5
if stream.type == "video":
# DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
# so need to buffer some extra frames to sort everything
# properly
extradata = stream.codec_context.extradata
# overly complicated way of finding if `divx_packed` is set, following
# https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
if extradata and b"DivX" in extradata:
# can't use regex directly because of some weird characters sometimes...
pos = extradata.find(b"DivX")
d = extradata[pos:]
o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
if o is None:
o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
if o is not None:
should_buffer = o.group(3) == b"p"
seek_offset = start_offset
# some files don't seek to the right location, so better be safe here
seek_offset = max(seek_offset - 1, 0)
if should_buffer:
# FIXME this is kind of a hack, but we will jump to the previous keyframe
# so this will be safe
seek_offset = max(seek_offset - max_buffer_size, 0)
try:
# TODO check if stream needs to always be the video stream here or not
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
except av.AVError:
# TODO add some warnings in this case
# print("Corrupted file?", container.name)
return []
# == main ==
buffer_count = 0
frames_pts = []
cnt = 0
for _idx, frame in enumerate(container.decode(**stream_name)):
frames_pts.append(frame.pts)
video_frames[cnt] = frame.to_rgb().to_ndarray()
cnt += 1
if cnt >= len(video_frames):
break
if frame.pts >= end_offset:
if should_buffer and buffer_count < max_buffer_size:
buffer_count += 1
continue
break
# garbage collection for thread leakage
container.close()
del container
# NOTE: manually garbage collect to close pyav threads
gc.collect()
# ensure that the results are sorted wrt the pts
# NOTE: here we assert frames_pts is sorted
start_ptr = 0
end_ptr = cnt
while start_ptr < end_ptr and frames_pts[start_ptr] < start_offset:
start_ptr += 1
while start_ptr < end_ptr and frames_pts[end_ptr - 1] > end_offset:
end_ptr -= 1
if start_offset > 0 and start_offset not in frames_pts[start_ptr:end_ptr]:
# if there is no frame that exactly matches the pts of start_offset
# add the last frame smaller than start_offset, to guarantee that
# we will have all the necessary data. This is most useful for audio
if start_ptr > 0:
start_ptr -= 1
result = video_frames[start_ptr:end_ptr].copy()
return result
def read_video_cv2(video_path):
cap = cv2.VideoCapture(video_path)
@ -181,8 +250,3 @@ def read_video(video_path, backend="av"):
raise ValueError
return vframes, vinfo
if __name__ == "__main__":
vframes, vinfo = read_video("./data/colors/9.mp4", backend="cv2")
x = 0

View file

@ -98,6 +98,7 @@ def main():
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
prefetch_factor=cfg.get("prefetch_factor", None),
)
dataloader, sampler = prepare_dataloader(
bucket_config=cfg.get("bucket_config", None),
@ -247,7 +248,6 @@ def main():
with Timer("move data") as move_data_t:
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
y = batch.pop("text")
timer_list.append(move_data_t)
# == visual and text encoding ==
with Timer("encode") as encode_t:

View file

@ -1,73 +0,0 @@
"""
Implementation of Net2Net (http://arxiv.org/abs/1511.05641)
Numpy modules for Net2Net
- Net2Wider
- Net2Deeper
Written by Kyunghyun Paeng
"""
def net2net(teach_param, stu_param):
    """Copy a teacher parameter into the leading corner of a student parameter.

    For a teacher of shape ``(a, b, ...)`` and a student of shape
    ``(c, d, ...)`` with the same number of dimensions, the teacher values are
    written in place into ``stu_param[:a, :b, ...]``; the remaining student
    entries are left untouched.

    Args:
        teach_param: source array/tensor; every dimension is expected to be
            no larger than the matching student dimension.
        stu_param: destination array/tensor; modified in place.

    Returns:
        ``stu_param`` itself, after the copy.

    Raises:
        AssertionError: if the two parameters differ in number of dimensions.
    """
    teach_shape = tuple(teach_param.shape)
    stu_shape = tuple(stu_param.shape)
    assert len(teach_shape) == len(stu_shape), "teach_param and stu_param must have same dimension"

    # Slice every axis by the teacher's extent. For 1-D and 2-D inputs this is
    # exactly the previous behaviour; for higher ranks it replaces a code path
    # that flattened trailing axes and then fell into a debugger breakpoint.
    corner = tuple(slice(0, size) for size in teach_shape)
    stu_param[corner] = teach_param
    return stu_param
if __name__ == "__main__":
    # Net2Net Class Test: widen a pretrained PixArt checkpoint into a
    # PixArt_1B_2 model and save the resulting state dict.
    import torch

    from opensora.models.pixart import PixArt_1B_2

    model = PixArt_1B_2(no_temporal_pos_emb=True, space_scale=4, enable_flash_attn=True, enable_layernorm_kernel=True)
    print("load model done")

    ckpt = torch.load("/home/zhouyukun/projs/opensora/pretrained_models/PixArt-Sigma-XL-2-2K-MS.pth")
    print("load ckpt done")

    ckpt = ckpt["state_dict"]
    # add an extra axis at position 2 to the patch-embedding weight before
    # copying it into the model
    ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)

    missing_keys = []
    for name, module in model.named_parameters():
        if name in ckpt:
            teach_param = ckpt[name].data
            stu_param = module.data
            stu_param = net2net(teach_param, stu_param)
            module.data = stu_param
            print("processing layer: ", name, "shape: ", module.size())
        else:
            # parameters absent from the checkpoint keep their initial values
            missing_keys.append(name)

    print(missing_keys)
    # the leftover debugger breakpoint was removed so the script runs to
    # completion and actually writes the checkpoint
    torch.save({"state_dict": model.state_dict()}, "PixArt-1B-2.pth")