Merge branch 'dev/v1.2' of github.com:hpcaitech/Open-Sora-dev into dev/v1.2

This commit is contained in:
zhengzangw 2024-05-21 02:38:27 +00:00
commit c2d499129f
6 changed files with 100 additions and 24 deletions

3
.gitignore vendored
View file

@ -191,3 +191,6 @@ eval/vae/flolpips/weights/
node_modules/ node_modules/
package-lock.json package-lock.json
package.json package.json
tools/caption/pllava_dir/PLLaVA/

View file

@ -7,8 +7,8 @@ from colossalai.cluster import DistCoordinator
from mmengine.runner import set_random_seed from mmengine.runner import set_random_seed
from tqdm import tqdm from tqdm import tqdm
from opensora.acceleration.parallel_states import get_data_parallel_group from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group
from opensora.datasets import prepare_variable_dataloader from opensora.datasets.dataloader import prepare_dataloader
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import create_logger, to_torch_dtype from opensora.utils.misc import create_logger, to_torch_dtype
@ -43,6 +43,7 @@ def main():
colossalai.launch_from_torch({}) colossalai.launch_from_torch({})
DistCoordinator() DistCoordinator()
set_random_seed(seed=cfg.get("seed", 1024)) set_random_seed(seed=cfg.get("seed", 1024))
set_data_parallel_group(dist.group.WORLD)
# == init logger == # == init logger ==
logger = create_logger() logger = create_logger()
@ -101,11 +102,8 @@ def main():
pin_memory=True, pin_memory=True,
process_group=get_data_parallel_group(), process_group=get_data_parallel_group(),
) )
dataloader = prepare_variable_dataloader( dataloader, sampler = prepare_dataloader(bucket_config=bucket_config, **dataloader_args)
bucket_config=bucket_config, num_batch = sampler.get_num_batch()
**dataloader_args,
)
num_batch = dataloader.batch_sampler.get_num_batch()
num_steps_per_epoch = num_batch // dist.get_world_size() num_steps_per_epoch = num_batch // dist.get_world_size()
return dataloader, num_steps_per_epoch, num_batch return dataloader, num_steps_per_epoch, num_batch

View file

@ -54,8 +54,8 @@ class StatefulDistributedSampler(DistributedSampler):
def reset(self) -> None: def reset(self) -> None:
self.start_index = 0 self.start_index = 0
def state_dict(self) -> dict: def state_dict(self, step) -> dict:
return {"start_index": self.start_index} return {"start_index": step}
def load_state_dict(self, state_dict: dict) -> None: def load_state_dict(self, state_dict: dict) -> None:
self.__dict__.update(state_dict) self.__dict__.update(state_dict)

View file

@ -7,7 +7,8 @@ from mmengine.runner import set_random_seed
from tqdm import tqdm from tqdm import tqdm
from opensora.acceleration.parallel_states import get_data_parallel_group from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader, save_sample from opensora.datasets import save_sample
from opensora.datasets.dataloader import prepare_dataloader
from opensora.models.vae.losses import VAELoss from opensora.models.vae.losses import VAELoss
from opensora.registry import DATASETS, MODELS, build_module from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.config_utils import parse_configs from opensora.utils.config_utils import parse_configs
@ -54,7 +55,6 @@ def main():
drop_last=False, drop_last=False,
pin_memory=True, pin_memory=True,
process_group=get_data_parallel_group(), process_group=get_data_parallel_group(),
distributed=is_distributed(),
) )
logger.info("Dataset %s contains %s videos.", cfg.dataset.data_path, len(dataset)) logger.info("Dataset %s contains %s videos.", cfg.dataset.data_path, len(dataset))
total_batch_size = batch_size * get_world_size() total_batch_size = batch_size * get_world_size()

View file

@ -5,7 +5,6 @@ from pprint import pformat
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import wandb
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
@ -13,9 +12,10 @@ from colossalai.utils import get_current_device, set_seed
from einops import rearrange from einops import rearrange
from tqdm import tqdm from tqdm import tqdm
import wandb
from opensora.acceleration.checkpoint import set_grad_checkpoint from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader from opensora.datasets.dataloader import prepare_dataloader
from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VAELoss from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VAELoss
from opensora.registry import DATASETS, MODELS, build_module from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.ckpt_utils import load, save from opensora.utils.ckpt_utils import load, save
@ -364,6 +364,7 @@ def main():
step=step + 1, step=step + 1,
global_step=global_step + 1, global_step=global_step + 1,
batch_size=cfg.get("batch_size", None), batch_size=cfg.get("batch_size", None),
sampler=sampler,
) )
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}") save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")

View file

@ -8,6 +8,7 @@ from functools import partial
from glob import glob from glob import glob
import cv2 import cv2
from PIL import Image
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import torchvision import torchvision
@ -48,7 +49,7 @@ def get_video_length(cap, method="header"):
return length return length
def get_info(path): def get_info_old(path):
try: try:
ext = os.path.splitext(path)[1].lower() ext = os.path.splitext(path)[1].lower()
if ext in IMG_EXTENSIONS: if ext in IMG_EXTENSIONS:
@ -72,21 +73,78 @@ def get_info(path):
return 0, 0, 0, np.nan, np.nan, np.nan return 0, 0, 0, np.nan, np.nan, np.nan
def get_video_info(path): def get_info(path):
try: try:
vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW") ext = os.path.splitext(path)[1].lower()
num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3] if ext in IMG_EXTENSIONS:
aspect_ratio = height / width return get_image_info(path)
if "video_fps" in infos:
fps = infos["video_fps"]
else: else:
fps = np.nan return get_video_info(path)
resolution = height * width
return num_frames, height, width, aspect_ratio, fps, resolution
except: except:
return 0, 0, 0, np.nan, np.nan, np.nan return 0, 0, 0, np.nan, np.nan, np.nan
def get_image_info(path, backend='pillow'):
if backend == 'pillow':
try:
with open(path, "rb") as f:
img = Image.open(f)
img = img.convert("RGB")
width, height = img.size
num_frames, fps = 1, np.nan
hw = height * width
aspect_ratio = height / width if width > 0 else np.nan
return num_frames, height, width, aspect_ratio, fps, hw
except:
return 0, 0, 0, np.nan, np.nan, np.nan
elif backend == 'cv2':
try:
im = cv2.imread(path)
if im is None:
return 0, 0, 0, np.nan, np.nan, np.nan
height, width = im.shape[:2]
num_frames, fps = 1, np.nan
hw = height * width
aspect_ratio = height / width if width > 0 else np.nan
return num_frames, height, width, aspect_ratio, fps, hw
except:
return 0, 0, 0, np.nan, np.nan, np.nan
else:
raise ValueError
def get_video_info(path, backend='torchvision'):
if backend == 'torchvision':
try:
vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3]
if "video_fps" in infos:
fps = infos["video_fps"]
else:
fps = np.nan
hw = height * width
aspect_ratio = height / width if width > 0 else np.nan
return num_frames, height, width, aspect_ratio, fps, hw
except:
return 0, 0, 0, np.nan, np.nan, np.nan
elif backend == 'cv2':
try:
cap = cv2.VideoCapture(path)
num_frames, height, width, fps = (
get_video_length(cap, method="header"),
int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
float(cap.get(cv2.CAP_PROP_FPS)),
)
hw = height * width
aspect_ratio = height / width if width > 0 else np.nan
return num_frames, height, width, aspect_ratio, fps, hw
except:
return 0, 0, 0, np.nan, np.nan, np.nan
else:
raise ValueError
# ====================================================== # ======================================================
# --refine-llm-caption # --refine-llm-caption
# ====================================================== # ======================================================
@ -547,6 +605,12 @@ def main(args):
data = data.drop_duplicates(subset=["text"], keep="first") data = data.drop_duplicates(subset=["text"], keep="first")
print(f"Filtered number of samples: {len(data)}.") print(f"Filtered number of samples: {len(data)}.")
# process data
if args.shuffle:
data = data.sample(frac=1).reset_index(drop=True) # shuffle
if args.get_first_n_data is not None:
data = data.head(args.get_first_n_data)
# shard data # shard data
if args.shard is not None: if args.shard is not None:
sharded_data = np.array_split(data, args.shard) sharded_data = np.array_split(data, args.shard)
@ -567,7 +631,7 @@ def parse_args():
parser.add_argument("--format", type=str, default="csv", help="output format", choices=["csv", "parquet"]) parser.add_argument("--format", type=str, default="csv", help="output format", choices=["csv", "parquet"])
parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing") parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing")
parser.add_argument("--num-workers", type=int, default=None, help="number of workers") parser.add_argument("--num-workers", type=int, default=None, help="number of workers")
parser.add_argument("--seed", type=int, default=None, help="random seed") parser.add_argument("--seed", type=int, default=42, help="random seed")
# special case # special case
parser.add_argument("--shard", type=int, default=None, help="shard the dataset") parser.add_argument("--shard", type=int, default=None, help="shard the dataset")
@ -623,6 +687,10 @@ def parse_args():
parser.add_argument("--flowmin", type=float, default=None, help="filter the dataset by minimum flow score") parser.add_argument("--flowmin", type=float, default=None, help="filter the dataset by minimum flow score")
parser.add_argument("--fpsmax", type=float, default=None, help="filter the dataset by maximum fps") parser.add_argument("--fpsmax", type=float, default=None, help="filter the dataset by maximum fps")
# data processing
parser.add_argument("--shuffle", default=False, action="store_true", help="shuffle the dataset")
parser.add_argument("--get_first_n_data", type=int, default=None, help="return the first n rows of data")
return parser.parse_args() return parser.parse_args()
@ -697,6 +765,12 @@ def get_output_path(args, input_name):
if args.flowmin is not None: if args.flowmin is not None:
name += f"_flowmin{args.flowmin}" name += f"_flowmin{args.flowmin}"
# processing
if args.shuffle:
name += f"_shuffled_seed{args.seed}"
if args.get_first_n_data is not None:
name += f"_first_{args.get_first_n_data}_data"
output_path = os.path.join(dir_path, f"{name}.{args.format}") output_path = os.path.join(dir_path, f"{name}.{args.format}")
return output_path return output_path