diff --git a/.gitignore b/.gitignore index cd88597..f0a41f1 100644 --- a/.gitignore +++ b/.gitignore @@ -191,3 +191,6 @@ eval/vae/flolpips/weights/ node_modules/ package-lock.json package.json + + +tools/caption/pllava_dir/PLLaVA/ diff --git a/eval/loss/eval_loss.py b/eval/loss/eval_loss.py index 4df070c..a90d400 100644 --- a/eval/loss/eval_loss.py +++ b/eval/loss/eval_loss.py @@ -7,8 +7,8 @@ from colossalai.cluster import DistCoordinator from mmengine.runner import set_random_seed from tqdm import tqdm -from opensora.acceleration.parallel_states import get_data_parallel_group -from opensora.datasets import prepare_variable_dataloader +from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group +from opensora.datasets.dataloader import prepare_dataloader from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module from opensora.utils.config_utils import parse_configs from opensora.utils.misc import create_logger, to_torch_dtype @@ -43,6 +43,7 @@ def main(): colossalai.launch_from_torch({}) DistCoordinator() set_random_seed(seed=cfg.get("seed", 1024)) + set_data_parallel_group(dist.group.WORLD) # == init logger == logger = create_logger() @@ -101,11 +102,8 @@ def main(): pin_memory=True, process_group=get_data_parallel_group(), ) - dataloader = prepare_variable_dataloader( - bucket_config=bucket_config, - **dataloader_args, - ) - num_batch = dataloader.batch_sampler.get_num_batch() + dataloader, sampler = prepare_dataloader(bucket_config=bucket_config, **dataloader_args) + num_batch = sampler.get_num_batch() num_steps_per_epoch = num_batch // dist.get_world_size() return dataloader, num_steps_per_epoch, num_batch diff --git a/opensora/datasets/sampler.py b/opensora/datasets/sampler.py index bf2c642..00c6ba2 100644 --- a/opensora/datasets/sampler.py +++ b/opensora/datasets/sampler.py @@ -54,8 +54,8 @@ class StatefulDistributedSampler(DistributedSampler): def reset(self) -> None: self.start_index = 0 - def state_dict(self) -> dict: - return {"start_index": self.start_index} + def state_dict(self, step) -> dict: + return {"start_index": step} def load_state_dict(self, state_dict: dict) -> None: self.__dict__.update(state_dict) diff --git a/scripts/inference_vae.py b/scripts/inference_vae.py index b6c8ea3..f4542e3 100644 --- a/scripts/inference_vae.py +++ b/scripts/inference_vae.py @@ -7,7 +7,8 @@ from mmengine.runner import set_random_seed from tqdm import tqdm from opensora.acceleration.parallel_states import get_data_parallel_group -from opensora.datasets import prepare_dataloader, save_sample +from opensora.datasets import save_sample +from opensora.datasets.dataloader import prepare_dataloader from opensora.models.vae.losses import VAELoss from opensora.registry import DATASETS, MODELS, build_module from opensora.utils.config_utils import parse_configs @@ -54,7 +55,6 @@ def main(): drop_last=False, pin_memory=True, process_group=get_data_parallel_group(), - distributed=is_distributed(), ) logger.info("Dataset %s contains %s videos.", cfg.dataset.data_path, len(dataset)) total_batch_size = batch_size * get_world_size() diff --git a/scripts/train_vae.py b/scripts/train_vae.py index e26b675..8913d07 100644 --- a/scripts/train_vae.py +++ b/scripts/train_vae.py @@ -5,7 +5,6 @@ from pprint import pformat import torch import torch.distributed as dist -import wandb from colossalai.booster import Booster from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam @@ -13,9 +12,10 @@ from colossalai.utils import get_current_device, set_seed from einops import rearrange from tqdm import tqdm +import wandb from opensora.acceleration.checkpoint import set_grad_checkpoint from opensora.acceleration.parallel_states import get_data_parallel_group -from opensora.datasets import prepare_dataloader +from opensora.datasets.dataloader import prepare_dataloader from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VAELoss from opensora.registry import DATASETS, MODELS, build_module from opensora.utils.ckpt_utils import load, save @@ -364,6 +364,7 @@ def main(): step=step + 1, global_step=global_step + 1, batch_size=cfg.get("batch_size", None), + sampler=sampler, ) save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}") diff --git a/tools/datasets/datautil.py b/tools/datasets/datautil.py index 68bf5be..2a2921c 100644 --- a/tools/datasets/datautil.py +++ b/tools/datasets/datautil.py @@ -605,6 +605,12 @@ def main(args): data = data.drop_duplicates(subset=["text"], keep="first") print(f"Filtered number of samples: {len(data)}.") + # process data + if args.shuffle: + data = data.sample(frac=1).reset_index(drop=True) # shuffle + if args.get_first_n_data is not None: + data = data.head(args.get_first_n_data) + # shard data if args.shard is not None: sharded_data = np.array_split(data, args.shard) @@ -625,7 +631,7 @@ def parse_args(): parser.add_argument("--format", type=str, default="csv", help="output format", choices=["csv", "parquet"]) parser.add_argument("--disable-parallel", action="store_true", help="disable parallel processing") parser.add_argument("--num-workers", type=int, default=None, help="number of workers") - parser.add_argument("--seed", type=int, default=None, help="random seed") + parser.add_argument("--seed", type=int, default=42, help="random seed") # special case parser.add_argument("--shard", type=int, default=None, help="shard the dataset") @@ -681,6 +687,10 @@ def parse_args(): parser.add_argument("--flowmin", type=float, default=None, help="filter the dataset by minimum flow score") parser.add_argument("--fpsmax", type=float, default=None, help="filter the dataset by maximum fps") + # data processing + parser.add_argument("--shuffle", default=False, action="store_true", help="shuffle the dataset") + parser.add_argument("--get_first_n_data", type=int, default=None, help="return the first n rows of data") + return parser.parse_args() @@ -755,6 +765,12 @@ def get_output_path(args, input_name): if args.flowmin is not None: name += f"_flowmin{args.flowmin}" + # processing + if args.shuffle: + name += f"_shuffled_seed{args.seed}" + if args.get_first_n_data is not None: + name += f"_first_{args.get_first_n_data}_data" + output_path = os.path.join(dir_path, f"{name}.{args.format}") return output_path