Mirror of https://github.com/hpcaitech/Open-Sora.git, synced 2026-04-10 12:49:38 +02:00

Merge branch 'dev/v1.2' of github.com:hpcaitech/Open-Sora-dev into dev/v1.2

Commit c2d499129f
.gitignore (vendored): 3 additions

@@ -191,3 +191,6 @@ eval/vae/flolpips/weights/
 node_modules/
 package-lock.json
 package.json
+
+
+tools/caption/pllava_dir/PLLaVA/
@@ -7,8 +7,8 @@ from colossalai.cluster import DistCoordinator
 from mmengine.runner import set_random_seed
 from tqdm import tqdm
 
-from opensora.acceleration.parallel_states import get_data_parallel_group
-from opensora.datasets import prepare_variable_dataloader
+from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group
+from opensora.datasets.dataloader import prepare_dataloader
 from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
 from opensora.utils.config_utils import parse_configs
 from opensora.utils.misc import create_logger, to_torch_dtype
@@ -43,6 +43,7 @@ def main():
     colossalai.launch_from_torch({})
     DistCoordinator()
     set_random_seed(seed=cfg.get("seed", 1024))
+    set_data_parallel_group(dist.group.WORLD)
 
     # == init logger ==
     logger = create_logger()
@@ -101,11 +102,8 @@ def main():
         pin_memory=True,
         process_group=get_data_parallel_group(),
     )
-    dataloader = prepare_variable_dataloader(
-        bucket_config=bucket_config,
-        **dataloader_args,
-    )
-    num_batch = dataloader.batch_sampler.get_num_batch()
+    dataloader, sampler = prepare_dataloader(bucket_config=bucket_config, **dataloader_args)
+    num_batch = sampler.get_num_batch()
     num_steps_per_epoch = num_batch // dist.get_world_size()
     return dataloader, num_steps_per_epoch, num_batch
 
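For reference, the new call pattern in one place: prepare_dataloader now returns the sampler alongside the dataloader, and the batch count is read from that sampler. This is a minimal sketch, not the script's full setup; bucket_config and dataloader_args are assumed to be built from the parsed config as in the surrounding code.

import torch.distributed as dist

from opensora.datasets.dataloader import prepare_dataloader

# bucket_config and dataloader_args are assumed to come from the script's config
dataloader, sampler = prepare_dataloader(bucket_config=bucket_config, **dataloader_args)

# the batch count is now queried from the returned sampler rather than from
# dataloader.batch_sampler, and the per-rank step count follows as before
num_batch = sampler.get_num_batch()
num_steps_per_epoch = num_batch // dist.get_world_size()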
@@ -54,8 +54,8 @@ class StatefulDistributedSampler(DistributedSampler):
     def reset(self) -> None:
         self.start_index = 0
 
-    def state_dict(self) -> dict:
-        return {"start_index": self.start_index}
+    def state_dict(self, step) -> dict:
+        return {"start_index": step}
 
     def load_state_dict(self, state_dict: dict) -> None:
         self.__dict__.update(state_dict)
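With this change the caller decides which index to persist instead of relying on the sampler's internal counter. A rough checkpoint/resume sketch, assuming a sampler instance and a consumed_samples counter tracked by the training loop (both names are placeholders):

import torch

# checkpoint time: record the resume position chosen by the caller
state = sampler.state_dict(consumed_samples)   # -> {"start_index": consumed_samples}
torch.save(state, "sampler_state.pt")

# resume time: load_state_dict writes start_index back into the sampler,
# so iteration skips the samples already seen in this epoch
sampler.load_state_dict(torch.load("sampler_state.pt"))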
@@ -7,7 +7,8 @@ from mmengine.runner import set_random_seed
 from tqdm import tqdm
 
 from opensora.acceleration.parallel_states import get_data_parallel_group
-from opensora.datasets import prepare_dataloader, save_sample
+from opensora.datasets import save_sample
+from opensora.datasets.dataloader import prepare_dataloader
 from opensora.models.vae.losses import VAELoss
 from opensora.registry import DATASETS, MODELS, build_module
 from opensora.utils.config_utils import parse_configs
@@ -54,7 +55,6 @@ def main():
         drop_last=False,
         pin_memory=True,
         process_group=get_data_parallel_group(),
-        distributed=is_distributed(),
     )
     logger.info("Dataset %s contains %s videos.", cfg.dataset.data_path, len(dataset))
     total_batch_size = batch_size * get_world_size()
@@ -5,7 +5,6 @@ from pprint import pformat
 
 import torch
 import torch.distributed as dist
-import wandb
 from colossalai.booster import Booster
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
@@ -13,9 +12,10 @@ from colossalai.utils import get_current_device, set_seed
 from einops import rearrange
 from tqdm import tqdm
 
+import wandb
 from opensora.acceleration.checkpoint import set_grad_checkpoint
 from opensora.acceleration.parallel_states import get_data_parallel_group
-from opensora.datasets import prepare_dataloader
+from opensora.datasets.dataloader import prepare_dataloader
 from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VAELoss
 from opensora.registry import DATASETS, MODELS, build_module
 from opensora.utils.ckpt_utils import load, save
@@ -364,6 +364,7 @@ def main():
                     step=step + 1,
                     global_step=global_step + 1,
                     batch_size=cfg.get("batch_size", None),
+                    sampler=sampler,
                 )
 
                 save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
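The body of save is not part of this diff, but passing the sampler presumably lets the checkpoint record the mid-epoch position alongside the model and optimizer state. A hypothetical helper illustrating the idea; this is not the actual ckpt_utils code and all names here are made up:

import os

import torch

def save_sampler_state(save_dir, sampler, consumed_samples):
    # hypothetical: persist the sampler's resume index next to the checkpoint
    # so a restarted run can continue from the same point in the epoch
    torch.save(sampler.state_dict(consumed_samples), os.path.join(save_dir, "sampler_state.pt"))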
@@ -8,6 +8,7 @@ from functools import partial
 from glob import glob
 
 import cv2
+from PIL import Image
 import numpy as np
 import pandas as pd
 import torchvision
@@ -48,7 +49,7 @@ def get_video_length(cap, method="header"):
     return length
 
 
-def get_info(path):
+def get_info_old(path):
     try:
         ext = os.path.splitext(path)[1].lower()
         if ext in IMG_EXTENSIONS:
@@ -72,21 +73,78 @@ def get_info(path):
         return 0, 0, 0, np.nan, np.nan, np.nan
 
 
-def get_video_info(path):
+def get_info(path):
     try:
-        vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
-        num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3]
-        aspect_ratio = height / width
-        if "video_fps" in infos:
-            fps = infos["video_fps"]
+        ext = os.path.splitext(path)[1].lower()
+        if ext in IMG_EXTENSIONS:
+            return get_image_info(path)
         else:
-            fps = np.nan
-        resolution = height * width
-        return num_frames, height, width, aspect_ratio, fps, resolution
+            return get_video_info(path)
     except:
         return 0, 0, 0, np.nan, np.nan, np.nan
 
 
+def get_image_info(path, backend='pillow'):
+    if backend == 'pillow':
+        try:
+            with open(path, "rb") as f:
+                img = Image.open(f)
+                img = img.convert("RGB")
+                width, height = img.size
+                num_frames, fps = 1, np.nan
+                hw = height * width
+                aspect_ratio = height / width if width > 0 else np.nan
+                return num_frames, height, width, aspect_ratio, fps, hw
+        except:
+            return 0, 0, 0, np.nan, np.nan, np.nan
+    elif backend == 'cv2':
+        try:
+            im = cv2.imread(path)
+            if im is None:
+                return 0, 0, 0, np.nan, np.nan, np.nan
+            height, width = im.shape[:2]
+            num_frames, fps = 1, np.nan
+            hw = height * width
+            aspect_ratio = height / width if width > 0 else np.nan
+            return num_frames, height, width, aspect_ratio, fps, hw
+        except:
+            return 0, 0, 0, np.nan, np.nan, np.nan
+    else:
+        raise ValueError
+
+
+def get_video_info(path, backend='torchvision'):
+    if backend == 'torchvision':
+        try:
+            vframes, _, infos = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
+            num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3]
+            if "video_fps" in infos:
+                fps = infos["video_fps"]
+            else:
+                fps = np.nan
+            hw = height * width
+            aspect_ratio = height / width if width > 0 else np.nan
+            return num_frames, height, width, aspect_ratio, fps, hw
+        except:
+            return 0, 0, 0, np.nan, np.nan, np.nan
+    elif backend == 'cv2':
+        try:
+            cap = cv2.VideoCapture(path)
+            num_frames, height, width, fps = (
+                get_video_length(cap, method="header"),
+                int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
+                int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
+                float(cap.get(cv2.CAP_PROP_FPS)),
+            )
+            hw = height * width
+            aspect_ratio = height / width if width > 0 else np.nan
+            return num_frames, height, width, aspect_ratio, fps, hw
+        except:
+            return 0, 0, 0, np.nan, np.nan, np.nan
+    else:
+        raise ValueError
+
+
 # ======================================================
 # --refine-llm-caption
 # ======================================================
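A short usage sketch of the refactored helpers defined in the hunk above; the file paths are placeholders. get_info dispatches on the extension, images report num_frames=1 with fps=nan, and both helpers accept a cv2 backend as an alternative:

# dispatches to get_image_info or get_video_info based on the file extension
num_frames, height, width, aspect_ratio, fps, hw = get_info("samples/example.jpg")

# explicit backend selection (defaults are 'pillow' for images, 'torchvision' for videos)
print(get_image_info("samples/example.jpg", backend='cv2'))
print(get_video_info("samples/example.mp4", backend='cv2'))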
@@ -547,6 +605,12 @@ def main(args):
     data = data.drop_duplicates(subset=["text"], keep="first")
     print(f"Filtered number of samples: {len(data)}.")
 
+    # process data
+    if args.shuffle:
+        data = data.sample(frac=1).reset_index(drop=True)  # shuffle
+    if args.get_first_n_data is not None:
+        data = data.head(args.get_first_n_data)
+
     # shard data
     if args.shard is not None:
         sharded_data = np.array_split(data, args.shard)
@@ -567,7 +631,7 @@ def parse_args():
     parser.add_argument("--format", type=str, default="csv", help="output format", choices=["csv", "parquet"])
     parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing")
     parser.add_argument("--num-workers", type=int, default=None, help="number of workers")
-    parser.add_argument("--seed", type=int, default=None, help="random seed")
+    parser.add_argument("--seed", type=int, default=42, help="random seed")
 
     # special case
     parser.add_argument("--shard", type=int, default=None, help="shard the dataset")
@@ -623,6 +687,10 @@ def parse_args():
     parser.add_argument("--flowmin", type=float, default=None, help="filter the dataset by minimum flow score")
     parser.add_argument("--fpsmax", type=float, default=None, help="filter the dataset by maximum fps")
 
+    # data processing
+    parser.add_argument("--shuffle", default=False, action="store_true", help="shuffle the dataset")
+    parser.add_argument("--get_first_n_data", type=int, default=None, help="return the first n rows of data")
+
     return parser.parse_args()
 
 
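As a usage example, running the tool with --shuffle --get_first_n_data 1000 shuffles the metadata rows and then keeps only the first 1000. The --seed option now defaults to 42 instead of None, and the seed value is recorded in the output file name, as the naming change below shows.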
@@ -697,6 +765,12 @@ def get_output_path(args, input_name):
     if args.flowmin is not None:
         name += f"_flowmin{args.flowmin}"
 
+    # processing
+    if args.shuffle:
+        name += f"_shuffled_seed{args.seed}"
+    if args.get_first_n_data is not None:
+        name += f"_first_{args.get_first_n_data}_data"
+
     output_path = os.path.join(dir_path, f"{name}.{args.format}")
     return output_path
 
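Worked example of the new naming: with the default seed and csv output, an input whose base name is meta, processed with --shuffle --get_first_n_data 1000 and no other filters, would be written as meta_shuffled_seed42_first_1000_data.csv.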