diff --git a/configs/opensora-v1-2/train/demo_360p.py b/configs/opensora-v1-2/train/demo_360p.py
new file mode 100644
index 0000000..e27bd3c
--- /dev/null
+++ b/configs/opensora-v1-2/train/demo_360p.py
@@ -0,0 +1,58 @@
+# Dataset settings
+dataset = dict(
+    type="VariableVideoTextDataset",
+    transform_name="resize_crop",
+)
+
+# webvid
+bucket_config = {"360p": {102: (1.0, 5)}}
+grad_checkpoint = True
+
+# Acceleration settings
+num_workers = 8
+num_bucket_build_workers = 16
+dtype = "bf16"
+plugin = "zero2"
+
+# Model settings
+model = dict(
+    type="STDiT3-XL/2",
+    from_pretrained=None,
+    qk_norm=True,
+    enable_flash_attn=True,
+    enable_layernorm_kernel=True,
+    freeze_y_embedder=True,
+)
+vae = dict(
+    type="OpenSoraVAE_V1_2",
+    from_pretrained="hpcai-tech/OpenSora-VAE-v1.2",
+    micro_frame_size=17,
+    micro_batch_size=4,
+)
+text_encoder = dict(
+    type="t5",
+    from_pretrained="DeepFloyd/t5-v1_1-xxl",
+    model_max_length=300,
+    shardformer=True,
+)
+scheduler = dict(
+    type="rflow",
+    use_timestep_transform=True,
+    sample_method="logit-normal",
+)
+
+# Log settings
+seed = 42
+outputs = "outputs"
+wandb = False
+epochs = 1000
+log_every = 10
+ckpt_every = 200
+
+# optimization settings
+load = None
+grad_clip = 1.0
+lr = 1e-4
+ema_decay = 0.99
+adam_eps = 1e-15
+warmup_steps = 1000
diff --git a/configs/opensora-v1-2/train/stage3_480p.py b/configs/opensora-v1-2/train/demo_480p.py
similarity index 76%
rename from configs/opensora-v1-2/train/stage3_480p.py
rename to configs/opensora-v1-2/train/demo_480p.py
index b4b9ffd..08121c7 100644
--- a/configs/opensora-v1-2/train/stage3_480p.py
+++ b/configs/opensora-v1-2/train/demo_480p.py
@@ -9,7 +9,7 @@ bucket_config = {"480p": {51: (0.5, 5)}}
 grad_checkpoint = True
 
 # Acceleration settings
-num_workers = 0
+num_workers = 8
 num_bucket_build_workers = 16
 dtype = "bf16"
 plugin = "zero2"
@@ -41,21 +41,6 @@ scheduler = dict(
     sample_method="logit-normal",
 )
 
-# Mask settings
-# 25%
-mask_ratios = {
-    "random": 0.01,
-    "intepolate": 0.002,
-    "quarter_random": 0.002,
-    "quarter_head": 0.002,
-    "quarter_tail": 0.002,
-    "quarter_head_tail": 0.002,
-    "image_random": 0.0,
-    "image_head": 0.22,
-    "image_tail": 0.005,
-    "image_head_tail": 0.005,
-}
-
 # Log settings
 seed = 42
 outputs = "outputs"
diff --git a/opensora/datasets/dataloader.py b/opensora/datasets/dataloader.py
index 8bcaed9..15058ac 100644
--- a/opensora/datasets/dataloader.py
+++ b/opensora/datasets/dataloader.py
@@ -34,6 +34,7 @@ def prepare_dataloader(
     process_group: Optional[ProcessGroup] = None,
     bucket_config=None,
     num_bucket_build_workers=1,
+    prefetch_factor=None,
     **kwargs,
 ):
     _kwargs = kwargs.copy()
@@ -57,6 +58,7 @@
                 pin_memory=pin_memory,
                 num_workers=num_workers,
                 collate_fn=collate_fn_default,
+                prefetch_factor=prefetch_factor,
                 **_kwargs,
             ),
             batch_sampler,
@@ -79,6 +81,7 @@
                 pin_memory=pin_memory,
                 num_workers=num_workers,
                 collate_fn=collate_fn_default,
+                prefetch_factor=prefetch_factor,
                 **_kwargs,
             ),
             sampler,
@@ -98,6 +101,7 @@
                 pin_memory=pin_memory,
                 num_workers=num_workers,
                 collate_fn=collate_fn_batch,
+                prefetch_factor=prefetch_factor,
                 **_kwargs,
             ),
             sampler,
diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py
index 34a5dcf..fcf070a 100644
--- a/opensora/datasets/datasets.py
+++ b/opensora/datasets/datasets.py
@@ -151,9 +151,11 @@ class VariableVideoTextDataset(VideoTextDataset):
 
         # Sampling video frames
         video = temporal_random_crop(vframes, num_frames, self.frame_interval)
+        video = video.clone()
+        del vframes
 
         video_fps = video_fps // self.frame_interval
-
+        # transform
         transform = get_transforms_video(self.transform_name, (height, width))
         video = transform(video)  # T C H W
 
@@ -169,7 +171,7 @@ class VariableVideoTextDataset(VideoTextDataset):
         # repeat
         video = image.unsqueeze(0)
 
-        # TCHW -> CTHW
+        # # TCHW -> CTHW
         video = video.permute(1, 0, 2, 3)
         ret = {
             "video": video,
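The `video.clone()` plus `del vframes` pair added above is the dataset-side half of the memory fix: `temporal_random_crop` slices the decoded frame buffer, and a basic slice is a view, so keeping it alive keeps the entire decoded video pinned in worker memory. A minimal sketch of the pattern, with a hypothetical crop helper standing in for the repo's actual one:

```python
import torch

def crop_then_release(vframes: torch.Tensor, num_frames: int) -> torch.Tensor:
    # Basic slicing returns a view; clone() materializes only the cropped
    # frames, so dropping the last reference frees the full decoded buffer.
    start = int(torch.randint(0, vframes.shape[0] - num_frames + 1, (1,)))
    video = vframes[start : start + num_frames].clone()
    del vframes  # the large buffer becomes collectable here
    return video
```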
diff --git a/opensora/datasets/read_video.py b/opensora/datasets/read_video.py
index f988c30..ce88f59 100644
--- a/opensora/datasets/read_video.py
+++ b/opensora/datasets/read_video.py
@@ -1,20 +1,19 @@
 import gc
 import math
 import os
+import re
+import warnings
 from fractions import Fraction
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import av
 import cv2
 import numpy as np
 import torch
-from torchvision.io.video import (
-    _align_audio_frames,
-    _check_av_available,
-    _log_api_usage_once,
-    _read_from_stream,
-    _video_opt,
-)
+from torchvision import get_video_backend
+from torchvision.io.video import _check_av_available
+
+MAX_NUM_FRAMES = 2500
 
 
 def read_video_av(
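Note the import changes above: the private torchvision helpers (`_read_from_stream`, `_video_opt`, the audio-alignment utilities) are gone, and the module now brings in `get_video_backend` plus a `MAX_NUM_FRAMES` fallback. Since the rewritten reader asserts the pyav backend (see the hunks below), callers running under a different default should opt in explicitly; a one-line usage sketch:

```python
import torchvision

# read_video_av below asserts get_video_backend() == "pyav"
torchvision.set_video_backend("pyav")
```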
@@ -27,6 +26,13 @@
     """
     Reads a video from a file, returning both the video frames and the audio frames
 
+    This method is modified from torchvision.io.video.read_video, with the following changes:
+
+    1. audio frames are not extracted; an empty tensor is returned for aframes
+    2. backend checks are removed and only pyav is supported
+    3. container.close() and gc.collect() are called to avoid thread leakage
+    4. memory allocations are reworked to minimize leaks
+
     Args:
         filename (str): path to the video file
         start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
@@ -42,99 +48,162 @@
         aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
         info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(read_video)
-
+    # format
     output_format = output_format.upper()
     if output_format not in ("THWC", "TCHW"):
         raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
-
-    from torchvision import get_video_backend
-
+    # file existence
     if not os.path.exists(filename):
         raise RuntimeError(f"File not found: {filename}")
+    # backend check
+    assert get_video_backend() == "pyav", "pyav backend is required for read_video_av"
+    _check_av_available()
+    # end_pts check
+    if end_pts is None:
+        end_pts = float("inf")
+    if end_pts < start_pts:
+        raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
 
-    if get_video_backend() != "pyav":
-        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
-    else:
-        _check_av_available()
+    # == get video info ==
+    info = {}
+    # TODO: creating a container leaks memory (~1 GB for 8 workers on 1 GPU)
+    container = av.open(filename, metadata_errors="ignore")
+    # fps
+    video_fps = container.streams.video[0].average_rate
+    # guard against potentially corrupted files
+    if video_fps is not None:
+        info["video_fps"] = float(video_fps)
+    iter_video = container.decode(**{"video": 0})
+    frame = next(iter_video).to_rgb().to_ndarray()
+    height, width = frame.shape[:2]
+    total_frames = container.streams.video[0].frames
+    if total_frames == 0:
+        total_frames = MAX_NUM_FRAMES
+        warnings.warn(f"total_frames is 0, using {MAX_NUM_FRAMES} as a fallback")
+    container.close()
+    del container
 
-        if end_pts is None:
-            end_pts = float("inf")
+    # HACK: the buffer must be created before iterating the stream;
+    # np.zeros does not actually allocate the memory, whereas np.ones
+    # would leak a little memory
+    video_frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8)
 
-        if end_pts < start_pts:
-            raise ValueError(
-                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
-            )
-
-        info = {}
-        video_frames = []
-        audio_frames = []
-        audio_timebase = _video_opt.default_timebase
-
-        container = av.open(filename, metadata_errors="ignore")
-        try:
-            if container.streams.audio:
-                audio_timebase = container.streams.audio[0].time_base
-            if container.streams.video:
-                video_frames = _read_from_stream(
-                    container,
-                    start_pts,
-                    end_pts,
-                    pts_unit,
-                    container.streams.video[0],
-                    {"video": 0},
-                )
-                video_fps = container.streams.video[0].average_rate
-                # guard against potentially corrupted files
-                if video_fps is not None:
-                    info["video_fps"] = float(video_fps)
-
-            if container.streams.audio:
-                audio_frames = _read_from_stream(
-                    container,
-                    start_pts,
-                    end_pts,
-                    pts_unit,
-                    container.streams.audio[0],
-                    {"audio": 0},
-                )
-                info["audio_fps"] = container.streams.audio[0].rate
-        except av.AVError:
-            # TODO raise a warning?
-            pass
-        finally:
-            container.close()
-            del container
-            # NOTE: manually garbage collect to close pyav threads
-            gc.collect()
-
-        vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
-        aframes_list = [frame.to_ndarray() for frame in audio_frames]
-
-        if vframes_list:
-            vframes = torch.as_tensor(np.stack(vframes_list))
-        else:
-            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
-
-        if aframes_list:
-            aframes = np.concatenate(aframes_list, 1)
-            aframes = torch.as_tensor(aframes)
-            if pts_unit == "sec":
-                start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
-                if end_pts != float("inf"):
-                    end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
-            aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
-        else:
-            aframes = torch.empty((1, 0), dtype=torch.float32)
+    # == read ==
+    # TODO: the read loop leaks memory (~4 GB for 8 workers on 1 GPU)
+    container = av.open(filename, metadata_errors="ignore")
+    assert container.streams.video is not None
+    video_frames = _read_from_stream(
+        video_frames,
+        container,
+        start_pts,
+        end_pts,
+        pts_unit,
+        container.streams.video[0],
+        {"video": 0},
+    )
+    vframes = torch.from_numpy(video_frames).clone()
+    del video_frames
 
     if output_format == "TCHW":
         # [T,H,W,C] --> [T,C,H,W]
         vframes = vframes.permute(0, 3, 1, 2)
+    aframes = torch.empty((1, 0), dtype=torch.float32)
 
     return vframes, aframes, info
 
 
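As rewritten above, `read_video_av` is now a two-pass reader: a probe pass opens the container just long enough to learn the fps, frame size, and declared frame count, and the decode pass then fills a preallocated `np.zeros` buffer (zeroed pages are committed lazily, which is why `np.zeros` is preferred over `np.ones` in the HACK note). A self-contained sketch of the probe pass, using the same pyav calls:

```python
import av

def probe_video(filename: str):
    # Decode one frame for (height, width), read the declared frame count,
    # then close immediately so pyav's decoder threads exit.
    container = av.open(filename, metadata_errors="ignore")
    frame = next(container.decode(video=0)).to_rgb().to_ndarray()
    total_frames = container.streams.video[0].frames  # 0 when unknown
    container.close()
    return frame.shape[:2], total_frames
```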
+def _read_from_stream(
+    video_frames,
+    container: "av.container.Container",
+    start_offset: float,
+    end_offset: float,
+    pts_unit: str,
+    stream: "av.stream.Stream",
+    stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
+) -> np.ndarray:
+    if pts_unit == "sec":
+        # TODO: we should change all of this from the ground up to simply take
+        # sec and convert to MS in C++
+        start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
+        if end_offset != float("inf"):
+            end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
+    else:
+        warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
+
+    should_buffer = True
+    max_buffer_size = 5
+    if stream.type == "video":
+        # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt),
+        # so we need to buffer some extra frames to sort everything properly
+        extradata = stream.codec_context.extradata
+        # overly complicated way of finding out if `divx_packed` is set, following
+        # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
+        if extradata and b"DivX" in extradata:
+            # can't use a regex directly because of some weird characters sometimes...
+            pos = extradata.find(b"DivX")
+            d = extradata[pos:]
+            o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
+            if o is None:
+                o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
+            if o is not None:
+                should_buffer = o.group(3) == b"p"
+    seek_offset = start_offset
+    # some files don't seek to the right location, so better be safe here
+    seek_offset = max(seek_offset - 1, 0)
+    if should_buffer:
+        # FIXME: this is kind of a hack, but we will jump to the previous keyframe,
+        # so this will be safe
+        seek_offset = max(seek_offset - max_buffer_size, 0)
+    try:
+        # TODO: check whether the stream always needs to be the video stream here
+        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
+    except av.AVError:
+        # TODO: add a warning in this case
+        # print("Corrupted file?", container.name)
+        return video_frames[:0]  # empty slice keeps the ndarray return type
+
+    # == main ==
+    buffer_count = 0
+    frames_pts = []
+    cnt = 0
+    for _idx, frame in enumerate(container.decode(**stream_name)):
+        frames_pts.append(frame.pts)
+        video_frames[cnt] = frame.to_rgb().to_ndarray()
+        cnt += 1
+        if cnt >= len(video_frames):
+            break
+        if frame.pts >= end_offset:
+            if should_buffer and buffer_count < max_buffer_size:
+                buffer_count += 1
+                continue
+            break
+
+    # clean up eagerly to avoid thread leakage
+    container.close()
+    del container
+    # NOTE: manually garbage collect to close pyav threads
+    gc.collect()
+
+    # ensure that the results are sorted w.r.t. the pts
+    # NOTE: frames_pts is assumed to be sorted here
+    start_ptr = 0
+    end_ptr = cnt
+    while start_ptr < end_ptr and frames_pts[start_ptr] < start_offset:
+        start_ptr += 1
+    while start_ptr < end_ptr and frames_pts[end_ptr - 1] > end_offset:
+        end_ptr -= 1
+    if start_offset > 0 and start_offset not in frames_pts[start_ptr:end_ptr]:
+        # if there is no frame that exactly matches the pts of start_offset,
+        # add the last frame smaller than start_offset to guarantee that
+        # we will have all the necessary data; this is most useful for audio
+        if start_ptr > 0:
+            start_ptr -= 1
+    result = video_frames[start_ptr:end_ptr].copy()
+    return result
+
+
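The trickiest part of `_read_from_stream` is the two-pointer trim at its end: it assumes `frames_pts` is sorted, shrinks the window to pts inside `[start_offset, end_offset]`, and then backs up one frame when nothing lands exactly on `start_offset`. A worked trace with made-up pts values:

```python
frames_pts = [0, 512, 1024, 1536, 2048]
start_offset, end_offset = 600, 1600

start_ptr, end_ptr = 0, len(frames_pts)
while start_ptr < end_ptr and frames_pts[start_ptr] < start_offset:
    start_ptr += 1          # skips pts 0 and 512 -> start_ptr == 2
while start_ptr < end_ptr and frames_pts[end_ptr - 1] > end_offset:
    end_ptr -= 1            # drops pts 2048 -> end_ptr == 4
if start_offset > 0 and start_offset not in frames_pts[start_ptr:end_ptr]:
    if start_ptr > 0:
        start_ptr -= 1      # keep the last frame before the requested start

print(frames_pts[start_ptr:end_ptr])  # [512, 1024, 1536]
```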
 def read_video_cv2(video_path):
     cap = cv2.VideoCapture(video_path)
@@ -181,8 +250,3 @@ def read_video(video_path, backend="av"):
         raise ValueError
 
     return vframes, vinfo
-
-
-if __name__ == "__main__":
-    vframes, vinfo = read_video("./data/colors/9.mp4", backend="cv2")
-    x = 0
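End to end, the reader's contract at the call site is unchanged except that `aframes` is always empty. A hedged usage sketch (the file name is illustrative):

```python
from opensora.datasets.read_video import read_video_av

vframes, aframes, info = read_video_av("sample.mp4", pts_unit="sec", output_format="TCHW")
print(vframes.shape, aframes.numel(), info.get("video_fps"))
# e.g. torch.Size([T, 3, H, W]) 0 24.0
```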
diff --git a/scripts/train.py b/scripts/train.py
index 2ebdd41..574fc0b 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -98,6 +98,7 @@ def main():
         drop_last=True,
         pin_memory=True,
         process_group=get_data_parallel_group(),
+        prefetch_factor=cfg.get("prefetch_factor", None),
     )
     dataloader, sampler = prepare_dataloader(
         bucket_config=cfg.get("bucket_config", None),
@@ -247,7 +248,6 @@ def main():
             with Timer("move data") as move_data_t:
                 x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
                 y = batch.pop("text")
-            timer_list.append(move_data_t)
 
             # == visual and text encoding ==
             with Timer("encode") as encode_t:
diff --git a/tools/architecture/__init__.py b/tools/architecture/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/tools/architecture/net2net.py b/tools/architecture/net2net.py
deleted file mode 100644
index d5d7eb3..0000000
--- a/tools/architecture/net2net.py
+++ /dev/null
@@ -1,73 +0,0 @@
-"""
-Implementation of Net2Net (http://arxiv.org/abs/1511.05641)
-Numpy modules for Net2Net
-- Net2Wider
-- Net2Deeper
-
-Written by Kyunghyun Paeng
-
-"""
-
-
-def net2net(teach_param, stu_param):
-    # teach param with shape (a, b)
-    # stu param with shape (c, d)
-    # net to net (a, b) -> (c, d) where c >= a and d >= b
-    teach_param_shape = teach_param.shape
-    stu_param_shape = stu_param.shape
-
-    if len(stu_param_shape) > 2:
-        teach_param = teach_param.reshape(teach_param_shape[0], -1)
-        stu_param = stu_param.reshape(stu_param_shape[0], -1)
-
-    assert len(stu_param.shape) == 1 or len(stu_param.shape) == 2, "teach_param and stu_param must be 2-dim array"
-    assert len(teach_param_shape) == len(stu_param_shape), "teach_param and stu_param must have same dimension"
-
-    if len(teach_param_shape) == 1:
-        stu_param[: teach_param_shape[0]] = teach_param
-    elif len(teach_param_shape) == 2:
-        stu_param[: teach_param_shape[0], : teach_param_shape[1]] = teach_param
-    else:
-        breakpoint()
-
-    if stu_param.shape != stu_param_shape:
-        stu_param = stu_param.reshape(stu_param_shape)
-
-    return stu_param
-
-
-if __name__ == "__main__":
-    """Net2Net Class Test"""
-
-    import torch
-
-    from opensora.models.pixart import PixArt_1B_2
-
-    model = PixArt_1B_2(no_temporal_pos_emb=True, space_scale=4, enable_flash_attn=True, enable_layernorm_kernel=True)
-    print("load model done")
-
-    ckpt = torch.load("/home/zhouyukun/projs/opensora/pretrained_models/PixArt-Sigma-XL-2-2K-MS.pth")
-    print("load ckpt done")
-
-    ckpt = ckpt["state_dict"]
-    ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
-
-    missing_keys = []
-    for name, module in model.named_parameters():
-        if name in ckpt:
-            teach_param = ckpt[name].data
-            stu_param = module.data
-            stu_param = net2net(teach_param, stu_param)
-
-            module.data = stu_param
-
-            print("processing layer: ", name, "shape: ", module.size())
-
-        else:
-            # print("Missing key: ", name)
-            missing_keys.append(name)
-
-    print(missing_keys)
-
-    breakpoint()
-    torch.save({"state_dict": model.state_dict()}, "PixArt-1B-2.pth")
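To exercise the new `prefetch_factor` plumbing end to end, a training config only needs one extra assignment: `scripts/train.py` forwards `cfg.get("prefetch_factor", None)` into `prepare_dataloader`, and PyTorch's `DataLoader` then keeps roughly `prefetch_factor * num_workers` batches staged (its own default is 2 per worker, and the value must stay `None` when `num_workers == 0`). A hypothetical config addition:

```python
# e.g. in configs/opensora-v1-2/train/demo_360p.py (illustrative)
num_workers = 8
prefetch_factor = 2  # consumed via cfg.get("prefetch_factor", None) in scripts/train.py
```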