diff --git a/.gitignore b/.gitignore index cd88597..f0a41f1 100644 --- a/.gitignore +++ b/.gitignore @@ -191,3 +191,6 @@ eval/vae/flolpips/weights/ node_modules/ package-lock.json package.json + + +tools/caption/pllava_dir/PLLaVA/ diff --git a/eval/loss/eval_loss.py b/eval/loss/eval_loss.py index 4df070c..a90d400 100644 --- a/eval/loss/eval_loss.py +++ b/eval/loss/eval_loss.py @@ -7,8 +7,8 @@ from colossalai.cluster import DistCoordinator from mmengine.runner import set_random_seed from tqdm import tqdm -from opensora.acceleration.parallel_states import get_data_parallel_group -from opensora.datasets import prepare_variable_dataloader +from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group +from opensora.datasets.dataloader import prepare_dataloader from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module from opensora.utils.config_utils import parse_configs from opensora.utils.misc import create_logger, to_torch_dtype @@ -43,6 +43,7 @@ def main(): colossalai.launch_from_torch({}) DistCoordinator() set_random_seed(seed=cfg.get("seed", 1024)) + set_data_parallel_group(dist.group.WORLD) # == init logger == logger = create_logger() @@ -101,11 +102,8 @@ def main(): pin_memory=True, process_group=get_data_parallel_group(), ) - dataloader = prepare_variable_dataloader( - bucket_config=bucket_config, - **dataloader_args, - ) - num_batch = dataloader.batch_sampler.get_num_batch() + dataloader, sampler = prepare_dataloader(bucket_config=bucket_config, **dataloader_args) + num_batch = sampler.get_num_batch() num_steps_per_epoch = num_batch // dist.get_world_size() return dataloader, num_steps_per_epoch, num_batch diff --git a/opensora/datasets/sampler.py b/opensora/datasets/sampler.py index bf2c642..00c6ba2 100644 --- a/opensora/datasets/sampler.py +++ b/opensora/datasets/sampler.py @@ -54,8 +54,8 @@ class StatefulDistributedSampler(DistributedSampler): def reset(self) -> None: self.start_index = 0 - def state_dict(self) -> dict: - return {"start_index": self.start_index} + def state_dict(self, step) -> dict: + return {"start_index": step} def load_state_dict(self, state_dict: dict) -> None: self.__dict__.update(state_dict) diff --git a/scripts/inference_vae.py b/scripts/inference_vae.py index b6c8ea3..f4542e3 100644 --- a/scripts/inference_vae.py +++ b/scripts/inference_vae.py @@ -7,7 +7,8 @@ from mmengine.runner import set_random_seed from tqdm import tqdm from opensora.acceleration.parallel_states import get_data_parallel_group -from opensora.datasets import prepare_dataloader, save_sample +from opensora.datasets import save_sample +from opensora.datasets.dataloader import prepare_dataloader from opensora.models.vae.losses import VAELoss from opensora.registry import DATASETS, MODELS, build_module from opensora.utils.config_utils import parse_configs @@ -54,7 +55,6 @@ def main(): drop_last=False, pin_memory=True, process_group=get_data_parallel_group(), - distributed=is_distributed(), ) logger.info("Dataset %s contains %s videos.", cfg.dataset.data_path, len(dataset)) total_batch_size = batch_size * get_world_size() diff --git a/scripts/train_vae.py b/scripts/train_vae.py index e26b675..8913d07 100644 --- a/scripts/train_vae.py +++ b/scripts/train_vae.py @@ -5,7 +5,6 @@ from pprint import pformat import torch import torch.distributed as dist -import wandb from colossalai.booster import Booster from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam @@ -13,9 +12,10 @@ from colossalai.utils import get_current_device, set_seed from einops import rearrange from tqdm import tqdm +import wandb from opensora.acceleration.checkpoint import set_grad_checkpoint from opensora.acceleration.parallel_states import get_data_parallel_group -from opensora.datasets import prepare_dataloader +from opensora.datasets.dataloader import prepare_dataloader from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VAELoss from opensora.registry import DATASETS, MODELS, build_module from opensora.utils.ckpt_utils import load, save @@ -364,6 +364,7 @@ def main(): step=step + 1, global_step=global_step + 1, batch_size=cfg.get("batch_size", None), + sampler=sampler, ) save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}") diff --git a/tools/datasets/datautil.py b/tools/datasets/datautil.py index d4d52a9..2a2921c 100644 --- a/tools/datasets/datautil.py +++ b/tools/datasets/datautil.py @@ -8,6 +8,7 @@ from functools import partial from glob import glob import cv2 +from PIL import Image import numpy as np import pandas as pd import torchvision @@ -48,7 +49,7 @@ def get_video_length(cap, method="header"): return length -def get_info(path): +def get_info_old(path): try: ext = os.path.splitext(path)[1].lower() if ext in IMG_EXTENSIONS: @@ -72,21 +73,78 @@ def get_info(path): return 0, 0, 0, np.nan, np.nan, np.nan -def get_video_info(path): +def get_info(path): try: - vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW") - num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3] - aspect_ratio = height / width - if "video_fps" in infos: - fps = infos["video_fps"] + ext = os.path.splitext(path)[1].lower() + if ext in IMG_EXTENSIONS: + return get_image_info(path) else: - fps = np.nan - resolution = height * width - return num_frames, height, width, aspect_ratio, fps, resolution + return get_video_info(path) except: return 0, 0, 0, np.nan, np.nan, np.nan +def get_image_info(path, backend='pillow'): + if backend == 'pillow': + try: + with open(path, "rb") as f: + img = Image.open(f) + img = img.convert("RGB") + width, height = img.size + num_frames, fps = 1, np.nan + hw = height * width + aspect_ratio = height / width if width > 0 else np.nan + return num_frames, height, width, aspect_ratio, fps, hw + except: + return 0, 0, 0, np.nan, np.nan, np.nan + elif backend == 'cv2': + try: + im = cv2.imread(path) + if im is None: + return 0, 0, 0, np.nan, np.nan, np.nan + height, width = im.shape[:2] + num_frames, fps = 1, np.nan + hw = height * width + aspect_ratio = height / width if width > 0 else np.nan + return num_frames, height, width, aspect_ratio, fps, hw + except: + return 0, 0, 0, np.nan, np.nan, np.nan + else: + raise ValueError + + +def get_video_info(path, backend='torchvision'): + if backend == 'torchvision': + try: + vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW") + num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3] + if "video_fps" in infos: + fps = infos["video_fps"] + else: + fps = np.nan + hw = height * width + aspect_ratio = height / width if width > 0 else np.nan + return num_frames, height, width, aspect_ratio, fps, hw + except: + return 0, 0, 0, np.nan, np.nan, np.nan + elif backend == 'cv2': + try: + cap = cv2.VideoCapture(path) + num_frames, height, width, fps = ( + get_video_length(cap, method="header"), + int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), + int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), + float(cap.get(cv2.CAP_PROP_FPS)), + ) + hw = height * width + aspect_ratio = height / width if width > 0 else np.nan + return num_frames, height, width, aspect_ratio, fps, hw + except: + return 0, 0, 0, np.nan, np.nan, np.nan + else: + raise ValueError + + # ====================================================== # --refine-llm-caption # ====================================================== @@ -547,6 +605,12 @@ def main(args): data = data.drop_duplicates(subset=["text"], keep="first") print(f"Filtered number of samples: {len(data)}.") + # process data + if args.shuffle: + data = data.sample(frac=1).reset_index(drop=True) # shuffle + if args.get_first_n_data is not None: + data = data.head(args.get_first_n_data) + # shard data if args.shard is not None: sharded_data = np.array_split(data, args.shard) @@ -567,7 +631,7 @@ def parse_args(): parser.add_argument("--format", type=str, default="csv", help="output format", choices=["csv", "parquet"]) parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing") parser.add_argument("--num-workers", type=int, default=None, help="number of workers") - parser.add_argument("--seed", type=int, default=None, help="random seed") + parser.add_argument("--seed", type=int, default=42, help="random seed") # special case parser.add_argument("--shard", type=int, default=None, help="shard the dataset") @@ -623,6 +687,10 @@ def parse_args(): parser.add_argument("--flowmin", type=float, default=None, help="filter the dataset by minimum flow score") parser.add_argument("--fpsmax", type=float, default=None, help="filter the dataset by maximum fps") + # data processing + parser.add_argument("--shuffle", default=False, action="store_true", help="shuffle the dataset") + parser.add_argument("--get_first_n_data", type=int, default=None, help="return the first n rows of data") + return parser.parse_args() @@ -697,6 +765,12 @@ def get_output_path(args, input_name): if args.flowmin is not None: name += f"_flowmin{args.flowmin}" + # processing + if args.shuffle: + name += f"_shuffled_seed{args.seed}" + if args.get_first_n_data is not None: + name += f"_first_{args.get_first_n_data}_data" + output_path = os.path.join(dir_path, f"{name}.{args.format}") return output_path