mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 12:49:38 +02:00
Merge branch 'dev/v1.2' of github.com:hpcaitech/Open-Sora-dev into dev/v1.2
This commit is contained in:
commit
c2d499129f
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -191,3 +191,6 @@ eval/vae/flolpips/weights/
|
|||
node_modules/
|
||||
package-lock.json
|
||||
package.json
|
||||
|
||||
|
||||
tools/caption/pllava_dir/PLLaVA/
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ from colossalai.cluster import DistCoordinator
|
|||
from mmengine.runner import set_random_seed
|
||||
from tqdm import tqdm
|
||||
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets import prepare_variable_dataloader
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group
|
||||
from opensora.datasets.dataloader import prepare_dataloader
|
||||
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.config_utils import parse_configs
|
||||
from opensora.utils.misc import create_logger, to_torch_dtype
|
||||
|
|
@ -43,6 +43,7 @@ def main():
|
|||
colossalai.launch_from_torch({})
|
||||
DistCoordinator()
|
||||
set_random_seed(seed=cfg.get("seed", 1024))
|
||||
set_data_parallel_group(dist.group.WORLD)
|
||||
|
||||
# == init logger ==
|
||||
logger = create_logger()
|
||||
|
|
@ -101,11 +102,8 @@ def main():
|
|||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
)
|
||||
dataloader = prepare_variable_dataloader(
|
||||
bucket_config=bucket_config,
|
||||
**dataloader_args,
|
||||
)
|
||||
num_batch = dataloader.batch_sampler.get_num_batch()
|
||||
dataloader, sampler = prepare_dataloader(bucket_config=bucket_config, **dataloader_args)
|
||||
num_batch = sampler.get_num_batch()
|
||||
num_steps_per_epoch = num_batch // dist.get_world_size()
|
||||
return dataloader, num_steps_per_epoch, num_batch
|
||||
|
||||
|
|
|
|||
|
|
@ -54,8 +54,8 @@ class StatefulDistributedSampler(DistributedSampler):
|
|||
def reset(self) -> None:
|
||||
self.start_index = 0
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
return {"start_index": self.start_index}
|
||||
def state_dict(self, step) -> dict:
|
||||
return {"start_index": step}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
self.__dict__.update(state_dict)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@ from mmengine.runner import set_random_seed
|
|||
from tqdm import tqdm
|
||||
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets import prepare_dataloader, save_sample
|
||||
from opensora.datasets import save_sample
|
||||
from opensora.datasets.dataloader import prepare_dataloader
|
||||
from opensora.models.vae.losses import VAELoss
|
||||
from opensora.registry import DATASETS, MODELS, build_module
|
||||
from opensora.utils.config_utils import parse_configs
|
||||
|
|
@ -54,7 +55,6 @@ def main():
|
|||
drop_last=False,
|
||||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
distributed=is_distributed(),
|
||||
)
|
||||
logger.info("Dataset %s contains %s videos.", cfg.dataset.data_path, len(dataset))
|
||||
total_batch_size = batch_size * get_world_size()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from pprint import pformat
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
|
@ -13,9 +12,10 @@ from colossalai.utils import get_current_device, set_seed
|
|||
from einops import rearrange
|
||||
from tqdm import tqdm
|
||||
|
||||
import wandb
|
||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets import prepare_dataloader
|
||||
from opensora.datasets.dataloader import prepare_dataloader
|
||||
from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VAELoss
|
||||
from opensora.registry import DATASETS, MODELS, build_module
|
||||
from opensora.utils.ckpt_utils import load, save
|
||||
|
|
@ -364,6 +364,7 @@ def main():
|
|||
step=step + 1,
|
||||
global_step=global_step + 1,
|
||||
batch_size=cfg.get("batch_size", None),
|
||||
sampler=sampler,
|
||||
)
|
||||
|
||||
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from functools import partial
|
|||
from glob import glob
|
||||
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torchvision
|
||||
|
|
@ -48,7 +49,7 @@ def get_video_length(cap, method="header"):
|
|||
return length
|
||||
|
||||
|
||||
def get_info(path):
|
||||
def get_info_old(path):
|
||||
try:
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
if ext in IMG_EXTENSIONS:
|
||||
|
|
@ -72,21 +73,78 @@ def get_info(path):
|
|||
return 0, 0, 0, np.nan, np.nan, np.nan
|
||||
|
||||
|
||||
def get_video_info(path):
|
||||
def get_info(path):
|
||||
try:
|
||||
vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3]
|
||||
aspect_ratio = height / width
|
||||
if "video_fps" in infos:
|
||||
fps = infos["video_fps"]
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
if ext in IMG_EXTENSIONS:
|
||||
return get_image_info(path)
|
||||
else:
|
||||
fps = np.nan
|
||||
resolution = height * width
|
||||
return num_frames, height, width, aspect_ratio, fps, resolution
|
||||
return get_video_info(path)
|
||||
except:
|
||||
return 0, 0, 0, np.nan, np.nan, np.nan
|
||||
|
||||
|
||||
def get_image_info(path, backend='pillow'):
|
||||
if backend == 'pillow':
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
img = Image.open(f)
|
||||
img = img.convert("RGB")
|
||||
width, height = img.size
|
||||
num_frames, fps = 1, np.nan
|
||||
hw = height * width
|
||||
aspect_ratio = height / width if width > 0 else np.nan
|
||||
return num_frames, height, width, aspect_ratio, fps, hw
|
||||
except:
|
||||
return 0, 0, 0, np.nan, np.nan, np.nan
|
||||
elif backend == 'cv2':
|
||||
try:
|
||||
im = cv2.imread(path)
|
||||
if im is None:
|
||||
return 0, 0, 0, np.nan, np.nan, np.nan
|
||||
height, width = im.shape[:2]
|
||||
num_frames, fps = 1, np.nan
|
||||
hw = height * width
|
||||
aspect_ratio = height / width if width > 0 else np.nan
|
||||
return num_frames, height, width, aspect_ratio, fps, hw
|
||||
except:
|
||||
return 0, 0, 0, np.nan, np.nan, np.nan
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
|
||||
def get_video_info(path, backend='torchvision'):
|
||||
if backend == 'torchvision':
|
||||
try:
|
||||
vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
|
||||
num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3]
|
||||
if "video_fps" in infos:
|
||||
fps = infos["video_fps"]
|
||||
else:
|
||||
fps = np.nan
|
||||
hw = height * width
|
||||
aspect_ratio = height / width if width > 0 else np.nan
|
||||
return num_frames, height, width, aspect_ratio, fps, hw
|
||||
except:
|
||||
return 0, 0, 0, np.nan, np.nan, np.nan
|
||||
elif backend == 'cv2':
|
||||
try:
|
||||
cap = cv2.VideoCapture(path)
|
||||
num_frames, height, width, fps = (
|
||||
get_video_length(cap, method="header"),
|
||||
int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
|
||||
int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
|
||||
float(cap.get(cv2.CAP_PROP_FPS)),
|
||||
)
|
||||
hw = height * width
|
||||
aspect_ratio = height / width if width > 0 else np.nan
|
||||
return num_frames, height, width, aspect_ratio, fps, hw
|
||||
except:
|
||||
return 0, 0, 0, np.nan, np.nan, np.nan
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
|
||||
# ======================================================
|
||||
# --refine-llm-caption
|
||||
# ======================================================
|
||||
|
|
@ -547,6 +605,12 @@ def main(args):
|
|||
data = data.drop_duplicates(subset=["text"], keep="first")
|
||||
print(f"Filtered number of samples: {len(data)}.")
|
||||
|
||||
# process data
|
||||
if args.shuffle:
|
||||
data = data.sample(frac=1).reset_index(drop=True) # shuffle
|
||||
if args.get_first_n_data is not None:
|
||||
data = data.head(args.get_first_n_data)
|
||||
|
||||
# shard data
|
||||
if args.shard is not None:
|
||||
sharded_data = np.array_split(data, args.shard)
|
||||
|
|
@ -567,7 +631,7 @@ def parse_args():
|
|||
parser.add_argument("--format", type=str, default="csv", help="output format", choices=["csv", "parquet"])
|
||||
parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing")
|
||||
parser.add_argument("--num-workers", type=int, default=None, help="number of workers")
|
||||
parser.add_argument("--seed", type=int, default=None, help="random seed")
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed")
|
||||
|
||||
# special case
|
||||
parser.add_argument("--shard", type=int, default=None, help="shard the dataset")
|
||||
|
|
@ -623,6 +687,10 @@ def parse_args():
|
|||
parser.add_argument("--flowmin", type=float, default=None, help="filter the dataset by minimum flow score")
|
||||
parser.add_argument("--fpsmax", type=float, default=None, help="filter the dataset by maximum fps")
|
||||
|
||||
# data processing
|
||||
parser.add_argument("--shuffle", default=False, action="store_true", help="shuffle the dataset")
|
||||
parser.add_argument("--get_first_n_data", type=int, default=None, help="return the first n rows of data")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
|
@ -697,6 +765,12 @@ def get_output_path(args, input_name):
|
|||
if args.flowmin is not None:
|
||||
name += f"_flowmin{args.flowmin}"
|
||||
|
||||
# processing
|
||||
if args.shuffle:
|
||||
name += f"_shuffled_seed{args.seed}"
|
||||
if args.get_first_n_data is not None:
|
||||
name += f"_first_{args.get_first_n_data}_data"
|
||||
|
||||
output_path = os.path.join(dir_path, f"{name}.{args.format}")
|
||||
return output_path
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue