Merge branch 'dev/v1.2' of github.com:hpcaitech/Open-Sora-dev into dev/v1.2

This commit is contained in:
zhengzangw 2024-05-21 02:38:27 +00:00
commit c2d499129f
6 changed files with 100 additions and 24 deletions

3
.gitignore vendored
View file

@ -191,3 +191,6 @@ eval/vae/flolpips/weights/
node_modules/
package-lock.json
package.json
tools/caption/pllava_dir/PLLaVA/

View file

@ -7,8 +7,8 @@ from colossalai.cluster import DistCoordinator
from mmengine.runner import set_random_seed
from tqdm import tqdm
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_variable_dataloader
from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group
from opensora.datasets.dataloader import prepare_dataloader
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import create_logger, to_torch_dtype
@ -43,6 +43,7 @@ def main():
colossalai.launch_from_torch({})
DistCoordinator()
set_random_seed(seed=cfg.get("seed", 1024))
set_data_parallel_group(dist.group.WORLD)
# == init logger ==
logger = create_logger()
@ -101,11 +102,8 @@ def main():
pin_memory=True,
process_group=get_data_parallel_group(),
)
dataloader = prepare_variable_dataloader(
bucket_config=bucket_config,
**dataloader_args,
)
num_batch = dataloader.batch_sampler.get_num_batch()
dataloader, sampler = prepare_dataloader(bucket_config=bucket_config, **dataloader_args)
num_batch = sampler.get_num_batch()
num_steps_per_epoch = num_batch // dist.get_world_size()
return dataloader, num_steps_per_epoch, num_batch

View file

@ -54,8 +54,8 @@ class StatefulDistributedSampler(DistributedSampler):
def reset(self) -> None:
self.start_index = 0
def state_dict(self) -> dict:
return {"start_index": self.start_index}
def state_dict(self, step) -> dict:
return {"start_index": step}
def load_state_dict(self, state_dict: dict) -> None:
self.__dict__.update(state_dict)

View file

@ -7,7 +7,8 @@ from mmengine.runner import set_random_seed
from tqdm import tqdm
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader, save_sample
from opensora.datasets import save_sample
from opensora.datasets.dataloader import prepare_dataloader
from opensora.models.vae.losses import VAELoss
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.config_utils import parse_configs
@ -54,7 +55,6 @@ def main():
drop_last=False,
pin_memory=True,
process_group=get_data_parallel_group(),
distributed=is_distributed(),
)
logger.info("Dataset %s contains %s videos.", cfg.dataset.data_path, len(dataset))
total_batch_size = batch_size * get_world_size()

View file

@ -5,7 +5,6 @@ from pprint import pformat
import torch
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
@ -13,9 +12,10 @@ from colossalai.utils import get_current_device, set_seed
from einops import rearrange
from tqdm import tqdm
import wandb
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader
from opensora.datasets.dataloader import prepare_dataloader
from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VAELoss
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.ckpt_utils import load, save
@ -364,6 +364,7 @@ def main():
step=step + 1,
global_step=global_step + 1,
batch_size=cfg.get("batch_size", None),
sampler=sampler,
)
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")

View file

@ -8,6 +8,7 @@ from functools import partial
from glob import glob
import cv2
from PIL import Image
import numpy as np
import pandas as pd
import torchvision
@ -48,7 +49,7 @@ def get_video_length(cap, method="header"):
return length
def get_info(path):
def get_info_old(path):
try:
ext = os.path.splitext(path)[1].lower()
if ext in IMG_EXTENSIONS:
@ -72,21 +73,78 @@ def get_info(path):
return 0, 0, 0, np.nan, np.nan, np.nan
def get_info(path):
    """Probe a media file and return its basic statistics.

    Dispatches on file extension: images go to ``get_image_info``, everything
    else to ``get_video_info``.

    Args:
        path: filesystem path to an image or video file.

    Returns:
        Tuple ``(num_frames, height, width, aspect_ratio, fps, height * width)``.
        On any failure the sentinel ``(0, 0, 0, nan, nan, nan)`` is returned so
        batch scans can keep going past unreadable samples.
    """
    try:
        ext = os.path.splitext(path)[1].lower()
        if ext in IMG_EXTENSIONS:
            return get_image_info(path)
        return get_video_info(path)
    except Exception:  # best-effort: never let one bad sample abort the scan
        return 0, 0, 0, np.nan, np.nan, np.nan
def get_image_info(path, backend="pillow"):
    """Probe a still image and return its basic statistics.

    Args:
        path: filesystem path to the image file.
        backend: ``"pillow"`` (decode via PIL) or ``"cv2"`` (decode via OpenCV).

    Returns:
        Tuple ``(num_frames, height, width, aspect_ratio, fps, height * width)``
        where ``num_frames`` is always 1 and ``fps`` is always ``nan`` for a
        still image. On any read/decode failure the sentinel
        ``(0, 0, 0, nan, nan, nan)`` is returned instead.

    Raises:
        ValueError: if *backend* is not one of the supported names.
    """
    if backend == "pillow":
        try:
            with open(path, "rb") as f:
                img = Image.open(f)
                img = img.convert("RGB")
            width, height = img.size
            num_frames, fps = 1, np.nan  # still image: one frame, fps undefined
            hw = height * width
            aspect_ratio = height / width if width > 0 else np.nan
            return num_frames, height, width, aspect_ratio, fps, hw
        except Exception:  # best-effort: unreadable files yield the sentinel row
            return 0, 0, 0, np.nan, np.nan, np.nan
    elif backend == "cv2":
        try:
            im = cv2.imread(path)
            if im is None:  # cv2 signals an unreadable file by returning None
                return 0, 0, 0, np.nan, np.nan, np.nan
            height, width = im.shape[:2]
            num_frames, fps = 1, np.nan
            hw = height * width
            aspect_ratio = height / width if width > 0 else np.nan
            return num_frames, height, width, aspect_ratio, fps, hw
        except Exception:
            return 0, 0, 0, np.nan, np.nan, np.nan
    else:
        raise ValueError(f"Unsupported backend: {backend!r} (expected 'pillow' or 'cv2')")
def get_video_info(path, backend="torchvision"):
    """Probe a video file and return its basic statistics.

    Args:
        path: filesystem path to the video file.
        backend: ``"torchvision"`` (decodes all frames to count them) or
            ``"cv2"`` (reads header metadata only — much cheaper).

    Returns:
        Tuple ``(num_frames, height, width, aspect_ratio, fps, height * width)``.
        On any read/decode failure the sentinel ``(0, 0, 0, nan, nan, nan)`` is
        returned instead.

    Raises:
        ValueError: if *backend* is not one of the supported names.
    """
    if backend == "torchvision":
        try:
            vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
            num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3]
            # read_video omits "video_fps" for some containers; fall back to nan
            fps = infos["video_fps"] if "video_fps" in infos else np.nan
            hw = height * width
            aspect_ratio = height / width if width > 0 else np.nan
            return num_frames, height, width, aspect_ratio, fps, hw
        except Exception:  # best-effort: unreadable files yield the sentinel row
            return 0, 0, 0, np.nan, np.nan, np.nan
    elif backend == "cv2":
        cap = None
        try:
            cap = cv2.VideoCapture(path)
            num_frames = get_video_length(cap, method="header")
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            fps = float(cap.get(cv2.CAP_PROP_FPS))
            hw = height * width
            aspect_ratio = height / width if width > 0 else np.nan
            return num_frames, height, width, aspect_ratio, fps, hw
        except Exception:
            return 0, 0, 0, np.nan, np.nan, np.nan
        finally:
            # original leaked the capture handle; always release it
            if cap is not None:
                cap.release()
    else:
        raise ValueError(f"Unsupported backend: {backend!r} (expected 'torchvision' or 'cv2')")
# ======================================================
# --refine-llm-caption
# ======================================================
@ -547,6 +605,12 @@ def main(args):
data = data.drop_duplicates(subset=["text"], keep="first")
print(f"Filtered number of samples: {len(data)}.")
# process data
if args.shuffle:
data = data.sample(frac=1).reset_index(drop=True) # shuffle
if args.get_first_n_data is not None:
data = data.head(args.get_first_n_data)
# shard data
if args.shard is not None:
sharded_data = np.array_split(data, args.shard)
@ -567,7 +631,7 @@ def parse_args():
parser.add_argument("--format", type=str, default="csv", help="output format", choices=["csv", "parquet"])
parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing")
parser.add_argument("--num-workers", type=int, default=None, help="number of workers")
parser.add_argument("--seed", type=int, default=None, help="random seed")
parser.add_argument("--seed", type=int, default=42, help="random seed")
# special case
parser.add_argument("--shard", type=int, default=None, help="shard the dataset")
@ -623,6 +687,10 @@ def parse_args():
parser.add_argument("--flowmin", type=float, default=None, help="filter the dataset by minimum flow score")
parser.add_argument("--fpsmax", type=float, default=None, help="filter the dataset by maximum fps")
# data processing
parser.add_argument("--shuffle", default=False, action="store_true", help="shuffle the dataset")
parser.add_argument("--get_first_n_data", type=int, default=None, help="return the first n rows of data")
return parser.parse_args()
@ -697,6 +765,12 @@ def get_output_path(args, input_name):
if args.flowmin is not None:
name += f"_flowmin{args.flowmin}"
# processing
if args.shuffle:
name += f"_shuffled_seed{args.seed}"
if args.get_first_n_data is not None:
name += f"_first_{args.get_first_n_data}_data"
output_path = os.path.join(dir_path, f"{name}.{args.format}")
return output_path