diff --git a/README.md b/README.md
index 3ca35ef..2cdb604 100644
--- a/README.md
+++ b/README.md
@@ -156,11 +156,11 @@ conda activate opensora
 
 # install torch, torchvision and xformers
 pip install -r requirements/requirements_cu121.txt
 
-# install this project
+# clone the repository
 git clone https://github.com/hpcaitech/Open-Sora
 cd Open-Sora
 
-# the default installation is for inference only
+# install Open-Sora
 pip install -v .
 ```
diff --git a/configs/opensora-v1-2/misc/extract.py b/configs/opensora-v1-2/misc/extract.py
new file mode 100644
index 0000000..fa1c436
--- /dev/null
+++ b/configs/opensora-v1-2/misc/extract.py
@@ -0,0 +1,61 @@
+# Dataset settings
+dataset = dict(
+    type="VariableVideoTextDataset",
+    transform_name="resize_crop",
+)
+
+# webvid
+bucket_config = {  # 12s/it
+    "144p": {1: (1.0, 475), 51: (1.0, 51), 102: ((1.0, 0.33), 27), 204: ((1.0, 0.1), 13), 408: ((1.0, 0.1), 6)},
+    # ---
+    "256": {1: (0.4, 297), 51: (0.5, 20), 102: ((0.5, 0.33), 10), 204: ((0.5, 0.1), 5), 408: ((0.5, 0.1), 2)},
+    "240p": {1: (0.3, 297), 51: (0.4, 20), 102: ((0.4, 0.33), 10), 204: ((0.4, 0.1), 5), 408: ((0.4, 0.1), 2)},
+    # ---
+    "360p": {1: (0.2, 141), 51: (0.15, 8), 102: ((0.15, 0.33), 4), 204: ((0.15, 0.1), 2), 408: ((0.15, 0.1), 1)},
+    "512": {1: (0.1, 141)},
+    # ---
+    "480p": {1: (0.1, 89)},
+    # ---
+    "720p": {1: (0.05, 36)},
+    "1024": {1: (0.05, 36)},
+    # ---
+    "1080p": {1: (0.1, 5)},
+    # ---
+    "2048": {1: (0.1, 5)},
+}
+
+# Acceleration settings
+num_workers = 8
+num_bucket_build_workers = 16
+dtype = "bf16"
+seed = 42
+outputs = "outputs"
+wandb = False
+
+
+# Model settings
+model = dict(
+    type="STDiT3-XL/2",
+    from_pretrained=None,
+    qk_norm=True,
+    enable_flash_attn=True,
+    enable_layernorm_kernel=True,
+)
+vae = dict(
+    type="OpenSoraVAE_V1_2",
+    from_pretrained="pretrained_models/vae-pipeline",
+    micro_frame_size=17,
+    micro_batch_size=4,
+)
+text_encoder = dict(
+    type="t5",
+    from_pretrained="DeepFloyd/t5-v1_1-xxl",
+    model_max_length=300,
+    shardformer=True,
+    local_files_only=True,
+)
+
+# Feature extraction settings
+save_dir = "outputs/feat"  # required by scripts/misc/extract_feat.py; placeholder added during editing, adjust to your path
+save_text_features = True
+save_compressed_text_features = True
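A note on the bucket table above: each top-level key is a resolution bucket, each inner key a clip length in frames, and each value pairs a sampling probability with a batch size; a probability written as a pair (e.g. `(1.0, 0.33)`) carries a second probability whose exact handling lives in `opensora/datasets/bucket.py`. A minimal structural sketch — the `check_bucket_config` helper below is illustrative only, not part of the repo:

```python
def check_bucket_config(bucket_config):
    """Validate the {resolution: {num_frames: (prob, batch_size)}} layout."""
    for resolution, frames in bucket_config.items():
        for num_frames, (prob, batch_size) in frames.items():
            assert isinstance(num_frames, int) and num_frames >= 1
            assert isinstance(batch_size, int) and batch_size >= 1
            # prob is either a float or a (float, float) pair
            probs = prob if isinstance(prob, tuple) else (prob,)
            assert all(0.0 <= p <= 1.0 for p in probs)

check_bucket_config({"144p": {1: (1.0, 475), 102: ((1.0, 0.33), 27)}})
```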
diff --git a/configs/opensora-v1-2/train/stage1.py b/configs/opensora-v1-2/train/stage1.py
index 242784e..21c404b 100644
--- a/configs/opensora-v1-2/train/stage1.py
+++ b/configs/opensora-v1-2/train/stage1.py
@@ -59,6 +59,7 @@ model = dict(
     qk_norm=True,
     enable_flash_attn=True,
     enable_layernorm_kernel=True,
+    freeze_y_embedder=True,
 )
 vae = dict(
     type="OpenSoraVAE_V1_2",
diff --git a/opensora/models/stdit/stdit3.py b/opensora/models/stdit/stdit3.py
index 14d5ad7..fd48ad1 100644
--- a/opensora/models/stdit/stdit3.py
+++ b/opensora/models/stdit/stdit3.py
@@ -169,6 +169,8 @@ class STDiT3Config(PretrainedConfig):
         enable_layernorm_kernel=False,
         enable_sequence_parallelism=False,
         only_train_temporal=False,
+        freeze_y_embedder=False,
+        skip_y_embedder=False,
         **kwargs,
     ):
         self.input_size = input_size
@@ -189,6 +191,8 @@ class STDiT3Config(PretrainedConfig):
         self.enable_layernorm_kernel = enable_layernorm_kernel
         self.enable_sequence_parallelism = enable_sequence_parallelism
         self.only_train_temporal = only_train_temporal
+        self.freeze_y_embedder = freeze_y_embedder
+        self.skip_y_embedder = skip_y_embedder
         super().__init__(**kwargs)
@@ -284,6 +288,10 @@ class STDiT3(PreTrainedModel):
                 for param in block.parameters():
                     param.requires_grad = True
 
+        if config.freeze_y_embedder:
+            for param in self.y_embedder.parameters():
+                param.requires_grad = False
+
     def initialize_weights(self):
         # Initialize transformer layers:
         def _basic_init(module):
@@ -319,6 +327,19 @@ class STDiT3(PreTrainedModel):
             W = W // self.patch_size[2]
         return (T, H, W)
 
+    def encode_text(self, y, mask=None):
+        y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]
+        if mask is not None:
+            if mask.shape[0] != y.shape[0]:
+                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
+            mask = mask.squeeze(1).squeeze(1)
+            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, self.hidden_size)
+            y_lens = mask.sum(dim=1).tolist()
+        else:
+            y_lens = [y.shape[2]] * y.shape[0]
+            y = y.squeeze(1).view(1, -1, self.hidden_size)
+        return y, y_lens
+
     def forward(self, x, timestep, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs):
         dtype = self.x_embedder.proj.weight.dtype
         B = x.size(0)
@@ -348,16 +369,10 @@ class STDiT3(PreTrainedModel):
         t0_mlp = self.t_block(t0)
 
         # === get y embed ===
-        y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]
-        if mask is not None:
-            if mask.shape[0] != y.shape[0]:
-                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
-            mask = mask.squeeze(1).squeeze(1)
-            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, self.hidden_size)
-            y_lens = mask.sum(dim=1).tolist()
+        if self.config.skip_y_embedder:
+            y_lens = mask.tolist()
         else:
-            y_lens = [y.shape[2]] * y.shape[0]
-            y = y.squeeze(1).view(1, -1, self.hidden_size)
+            y, y_lens = self.encode_text(y, mask)
 
         # === get x embed ===
         x = self.x_embedder(x)  # [B, N, C]
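The refactor above pulls caption embedding out of `forward` into `encode_text`, so text features can be precomputed offline (see `scripts/misc/extract_feat.py` below) and fed back in with `skip_y_embedder=True`, in which case `mask` holds per-sample token counts (hence `mask.tolist()`). A sketch of the packing `encode_text` performs; `B`, `N_token`, and `C` are hypothetical values for illustration:

```python
import torch

B, N_token, C = 2, 300, 1152          # hypothetical: batch, T5 tokens, hidden size
y = torch.randn(B, 1, N_token, C)     # what y_embedder returns: [B, 1, N_token, C]
mask = torch.zeros(B, N_token, dtype=torch.long)
mask[0, :120] = 1                     # caption 0: 120 valid tokens
mask[1, :45] = 1                      # caption 1: 45 valid tokens

# encode_text drops padded tokens and packs the survivors into one
# [1, sum(y_lens), C] sequence; y_lens records where each caption ends.
y_packed = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, C)
y_lens = mask.sum(dim=1).tolist()     # [120, 45]
assert y_packed.shape == (1, sum(y_lens), C)
```

The packed layout means downstream cross-attention never attends over padding tokens.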
diff --git a/opensora/utils/misc.py b/opensora/utils/misc.py
index ae4e3fd..4be047c 100644
--- a/opensora/utils/misc.py
+++ b/opensora/utils/misc.py
@@ -378,3 +378,34 @@ class Timer:
         self.end_time = time.time()
         if self.log:
             print(f"Elapsed time for {self.name}: {self.elapsed_time:.2f} s")
+
+
+def get_tensor_memory(tensor, human_readable=True):
+    size = tensor.element_size() * tensor.nelement()
+    if human_readable:
+        size = format_numel_str(size)
+    return size
+
+
+class FeatureSaver:
+    def __init__(self, save_dir, bin_size=10):
+        self.save_dir = save_dir
+        self.bin_size = bin_size
+
+        self.data_list = []
+        self.cnt = 0
+
+    def update(self, data):
+        self.data_list.append(data)
+        self.cnt += 1
+        if len(self.data_list) == self.bin_size:
+            self.save()
+
+    def save(self):
+        # NOTE: flush logic added during editing; the numbered-.bin layout is an
+        # assumption that mirrors the removed extract_vae_feat.py script.
+        if not self.data_list:
+            return
+        bin_idx = (self.cnt - 1) // self.bin_size
+        torch.save(self.data_list, os.path.join(self.save_dir, f"{bin_idx}.bin"))
+        self.data_list = []
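A minimal usage sketch for `FeatureSaver`, assuming the flush behavior added in `save()` above (numbered `.bin` files under `save_dir`; the path is hypothetical):

```python
import os
import torch
from opensora.utils.misc import FeatureSaver

save_dir = "outputs/feat"             # hypothetical path
os.makedirs(save_dir, exist_ok=True)  # FeatureSaver does not create directories
saver = FeatureSaver(save_dir, bin_size=10)
for _ in range(25):
    saver.update({"x": torch.randn(1, 4, 16, 32, 32), "fps": torch.tensor([24.0])})
# 0.bin and 1.bin were flushed automatically after the 10th and 20th update;
# flush the 5 leftover entries explicitly:
saver.save()                          # -> 2.bin
```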
diff --git a/scripts/misc/extract_feat.py b/scripts/misc/extract_feat.py
new file mode 100644
index 0000000..476e203
--- /dev/null
+++ b/scripts/misc/extract_feat.py
@@ -0,0 +1,173 @@
+import os
+from datetime import timedelta
+from pprint import pformat
+
+import torch
+import torch.distributed as dist
+from colossalai.utils import set_seed
+from tqdm import tqdm
+
+from opensora.acceleration.parallel_states import get_data_parallel_group
+from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
+from opensora.datasets.utils import collate_fn_ignore_none
+from opensora.registry import DATASETS, MODELS, build_module
+from opensora.utils.config_utils import parse_configs
+from opensora.utils.misc import FeatureSaver, create_logger, format_numel_str, get_model_numel, to_torch_dtype
+
+DEFAULT_DATASET_NAME = "VideoTextDataset"
+
+
+def main():
+    torch.set_grad_enabled(False)
+    # ======================================================
+    # 1. configs & runtime variables
+    # ======================================================
+    # == parse configs ==
+    cfg = parse_configs(training=False)
+
+    # == device and dtype ==
+    assert torch.cuda.is_available(), "Feature extraction currently requires at least one GPU."
+    device = "cuda"
+    cfg_dtype = cfg.get("dtype", "bf16")
+    assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
+    dtype = to_torch_dtype(cfg_dtype)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+
+    # == init distributed env ==
+    # NOTE: dist.barrier() and the data-parallel dataloader need an initialized
+    # process group; this mirrors the removed extract_vae_feat.py script
+    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
+    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
+    set_seed(cfg.get("seed", 1024))
+
+    # == init logger ==
+    logger = create_logger()
+    logger.info("Configuration:\n %s", pformat(cfg.to_dict()))
+
+    # ======================================================
+    # 2. build dataset and dataloader
+    # ======================================================
+    logger.info("Building dataset...")
+    # == build dataset ==
+    dataset = build_module(cfg.dataset, DATASETS)
+    logger.info("Dataset contains %s samples.", len(dataset))
+
+    # == build dataloader ==
+    dataloader_args = dict(
+        dataset=dataset,
+        batch_size=cfg.get("batch_size", None),
+        num_workers=cfg.get("num_workers", 4),
+        seed=cfg.get("seed", 1024),
+        shuffle=True,
+        drop_last=True,
+        pin_memory=True,
+        process_group=get_data_parallel_group(),
+        collate_fn=collate_fn_ignore_none,
+    )
+    if cfg.dataset.type == DEFAULT_DATASET_NAME:
+        dataloader = prepare_dataloader(**dataloader_args)
+        total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
+        logger.info("Total batch size: %s", total_batch_size)
+        num_steps_per_epoch = len(dataloader)
+    else:
+        dataloader = prepare_variable_dataloader(
+            bucket_config=cfg.get("bucket_config", None),
+            num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
+            **dataloader_args,
+        )
+        num_steps_per_epoch = dataloader.batch_sampler.get_num_batch()
+
+    # ======================================================
+    # 3. build model
+    # ======================================================
+    logger.info("Building models...")
+    # == build text-encoder and vae ==
+    text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype)
+    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
+
+    # == build diffusion model ==
+    input_size = (dataset.num_frames, *dataset.image_size)
+    latent_size = vae.get_latent_size(input_size)
+    model = (
+        build_module(
+            cfg.model,
+            MODELS,
+            input_size=latent_size,
+            in_channels=vae.out_channels,
+            caption_channels=text_encoder.output_dim,
+            model_max_length=text_encoder.model_max_length,
+        )
+        .to(device, dtype)
+        .eval()
+    )
+    model_numel, model_numel_trainable = get_model_numel(model)
+    logger.info(
+        "[Diffusion] Trainable model params: %s, Total model params: %s",
+        format_numel_str(model_numel_trainable),
+        format_numel_str(model_numel),
+    )
+
+    # =======================================================
+    # 4. preparation for feature extraction
+    # =======================================================
+    # == global variables ==
+    start_step = sampler_start_idx = 0
+    logger.info("Extracting features for %s steps per epoch", num_steps_per_epoch)
+
+    if cfg.dataset.type == DEFAULT_DATASET_NAME:
+        dataloader.sampler.set_start_index(sampler_start_idx)
+
+    # =======================================================
+    # 5. feature extraction loop
+    # =======================================================
+    dist.barrier()
+    for epoch in range(1):
+        # == set dataloader to new epoch ==
+        if cfg.dataset.type == DEFAULT_DATASET_NAME:
+            dataloader.sampler.set_epoch(epoch)
+        dataloader_iter = iter(dataloader)
+        logger.info("Beginning epoch %s...", epoch)
+
+        # == feature extraction in an epoch ==
+        assert cfg.get("save_dir", None) is not None, "Please specify the save_dir in the config file."
+        os.makedirs(cfg.save_dir, exist_ok=True)
+        saver = FeatureSaver(cfg.save_dir)
+        save_text_features = cfg.get("save_text_features", False)
+        save_compressed_text_features = cfg.get("save_compressed_text_features", False)
+
+        with tqdm(
+            enumerate(dataloader_iter, start=start_step),
+            desc=f"Epoch {epoch}",
+            initial=start_step,
+            total=num_steps_per_epoch,
+        ) as pbar:
+            for step, batch in pbar:
+                x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
+                y = batch.pop("text")
+
+                x = vae.encode(x).cpu()  # [B, C, T, H/P, W/P]
+                fps = batch["fps"].to(dtype)
+                batch_dict = {"x": x, "fps": fps}
+
+                if save_text_features:
+                    text_infos = text_encoder.encode(y)
+                    y_feat = text_infos["y"]
+                    y_mask = text_infos["mask"]
+                    if not save_compressed_text_features:
+                        y_feat = y_feat.cpu()
+                        y_mask = y_mask.cpu()
+                    else:
+                        # pack valid tokens only; y_mask becomes per-sample token counts
+                        y_feat, y_mask = model.encode_text(y_feat, y_mask)
+                        y_feat = y_feat.cpu()
+                        y_mask = torch.tensor(y_mask)
+                    batch_dict.update({"y": y_feat, "mask": y_mask})
+
+                saver.update(batch_dict)
+
+        saver.save()  # flush the last, possibly partial, bin
+
+
+if __name__ == "__main__":
+    main()
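Each saved bin is a `torch.save`d list of per-step dicts, so downstream consumers can read features back without touching the extraction pipeline. A hedged read-back sketch (the path assumes the `FeatureSaver.save` layout above):

```python
import torch

data_list = torch.load("outputs/feat/0.bin", map_location="cpu")  # hypothetical path
sample = data_list[0]
print(sample["x"].shape)       # VAE latents: [B, C, T, H/P, W/P]
print(sample["fps"])           # per-sample fps, cast to the extraction dtype
if "y" in sample:              # present when save_text_features=True
    print(sample["y"].shape)   # T5 features, packed if compression was enabled
    print(sample["mask"])      # attention mask, or per-sample token counts
```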
diff --git a/scripts/misc/extract_vae_feat.py b/scripts/misc/extract_vae_feat.py
deleted file mode 100644
index f99d233..0000000
--- a/scripts/misc/extract_vae_feat.py
+++ /dev/null
@@ -1,284 +0,0 @@
-import os
-from copy import deepcopy
-from datetime import timedelta
-from pprint import pformat
-
-import torch
-import torch.distributed as dist
-from colossalai.booster import Booster
-from colossalai.cluster import DistCoordinator
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device, set_seed
-from tqdm import tqdm
-
-import wandb
-from opensora.acceleration.checkpoint import set_grad_checkpoint
-from opensora.acceleration.parallel_states import get_data_parallel_group
-from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
-from opensora.datasets.utils import collate_fn_ignore_none
-from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
-from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
-from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
-from opensora.utils.misc import (
-    all_reduce_mean,
-    create_logger,
-    create_tensorboard_writer,
-    format_numel_str,
-    get_model_numel,
-    requires_grad,
-    to_torch_dtype,
-)
-from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema
-
-DEFAULT_DATASET_NAME = "VideoTextDataset"
-
-
-def main():
-    # ======================================================
-    # 1. configs & runtime variables
-    # ======================================================
-    # == parse configs ==
-    cfg = parse_configs(training=False)
-
-    # == device and dtype ==
-    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
-    cfg_dtype = cfg.get("dtype", "bf16")
-    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
-    dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
-
-    # == colossalai init distributed training ==
-    # NOTE: A very large timeout is set to avoid some processes exit early
-    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
-    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
-    set_seed(cfg.get("seed", 1024))
-    coordinator = DistCoordinator()
-    device = get_current_device()
-
-    # == init exp_dir ==
-    exp_name, exp_dir = define_experiment_workspace(cfg)
-    coordinator.block_all()
-    if coordinator.is_master():
-        os.makedirs(exp_dir, exist_ok=True)
-        save_training_config(cfg.to_dict(), exp_dir)
-    coordinator.block_all()
-
-    # == init logger, tensorboard & wandb ==
-    logger = create_logger(exp_dir)
-    logger.info("Experiment directory created at %s", exp_dir)
-    logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
-    if coordinator.is_master():
-        tb_writer = create_tensorboard_writer(exp_dir)
-        if cfg.get("wandb", False):
-            wandb.init(project="minisora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")
-
-    # == init ColossalAI booster ==
-    plugin = create_colossalai_plugin(
-        plugin=cfg.get("plugin", "zero2"),
-        dtype=cfg_dtype,
-        grad_clip=cfg.get("grad_clip", 0),
-        sp_size=cfg.get("sp_size", 1),
-    )
-    booster = Booster(plugin=plugin)
-
-    # ======================================================
-    # 2. build dataset and dataloader
-    # ======================================================
-    logger.info("Building dataset...")
-    # == build dataset ==
-    dataset = build_module(cfg.dataset, DATASETS)
-    logger.info("Dataset contains %s samples.", len(dataset))
-
-    # == build dataloader ==
-    dataloader_args = dict(
-        dataset=dataset,
-        batch_size=cfg.get("batch_size", None),
-        num_workers=cfg.get("num_workers", 4),
-        seed=cfg.get("seed", 1024),
-        shuffle=True,
-        drop_last=True,
-        pin_memory=True,
-        process_group=get_data_parallel_group(),
-        collate_fn=collate_fn_ignore_none,
-    )
-    if cfg.dataset.type == DEFAULT_DATASET_NAME:
-        dataloader = prepare_dataloader(**dataloader_args)
-        total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
-        logger.info("Total batch size: %s", total_batch_size)
-        num_steps_per_epoch = len(dataloader)
-        sampler_to_io = None
-    else:
-        dataloader = prepare_variable_dataloader(
-            bucket_config=cfg.get("bucket_config", None),
-            num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
-            **dataloader_args,
-        )
-        num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
-        sampler_to_io = None if cfg.get("start_from_scratch ", False) else dataloader.batch_sampler
-
-    # ======================================================
-    # 3. build model
-    # ======================================================
-    logger.info("Building models...")
-    # == build text-encoder and vae ==
-    text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype)
-    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
-
-    # == build diffusion model ==
-    input_size = (dataset.num_frames, *dataset.image_size)
-    latent_size = vae.get_latent_size(input_size)
-    model = (
-        build_module(
-            cfg.model,
-            MODELS,
-            input_size=latent_size,
-            in_channels=vae.out_channels,
-            caption_channels=text_encoder.output_dim,
-            model_max_length=text_encoder.model_max_length,
-        )
-        .to(device, dtype)
-        .train()
-    )
-    model_numel, model_numel_trainable = get_model_numel(model)
-    logger.info(
-        "[Diffusion] Trainable model params: %s, Total model params: %s",
-        format_numel_str(model_numel_trainable),
-        format_numel_str(model_numel),
-    )
-
-    # == build ema for diffusion model ==
-    ema = deepcopy(model).to(torch.float32).to(device)
-    requires_grad(ema, False)
-    ema_shape_dict = record_model_param_shape(ema)
-    ema.eval()
-    update_ema(ema, model, decay=0, sharded=False)
-
-    # == setup loss function, build scheduler ==
-    scheduler = build_module(cfg.scheduler, SCHEDULERS)
-
-    # == setup optimizer ==
-    optimizer = HybridAdam(
-        filter(lambda p: p.requires_grad, model.parameters()),
-        adamw_mode=True,
-        lr=cfg.get("lr", 1e-4),
-        weight_decay=cfg.get("weight_decay", 0),
-        eps=cfg.get("adam_eps", 1e-8),
-    )
-    lr_scheduler = None
-
-    # == additional preparation ==
-    if cfg.get("grad_checkpoint", False):
-        set_grad_checkpoint(model)
-    if cfg.get("mask_ratios", None) is not None:
-        mask_generator = MaskGenerator(cfg.mask_ratios)
-
-    # =======================================================
-    # 4. distributed training preparation with colossalai
-    # =======================================================
-    logger.info("Preparing for distributed training...")
-    # == boosting ==
-    # NOTE: we set dtype first to make initialization of model consistent with the dtype; then reset it to the fp32 as we make diffusion scheduler in fp32
-    torch.set_default_dtype(dtype)
-    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
-        model=model,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-        dataloader=dataloader,
-    )
-    torch.set_default_dtype(torch.float)
-    logger.info("Boosting model for distributed training")
-
-    # == global variables ==
-    cfg_epochs = cfg.get("epochs", 1000)
-    start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
-    running_loss = 0.0
-    logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)
-
-    # == resume ==
-    if cfg.get("load", None) is not None:
-        logger.info("Loading checkpoint")
-        ret = load(
-            booster,
-            cfg.load,
-            model=model,
-            ema=ema,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            sampler=sampler_to_io,
-        )
-        if not cfg.get("start_from_scratch ", False):
-            start_epoch, start_step, sampler_start_idx = ret
-        logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
-    if cfg.dataset.type == DEFAULT_DATASET_NAME:
-        dataloader.sampler.set_start_index(sampler_start_idx)
-
-    model_sharding(ema)
-
-    # =======================================================
-    # 5. training loop
-    # =======================================================
-    dist.barrier()
-    for epoch in range(1):
-        # == set dataloader to new epoch ==
-        if cfg.dataset.type == DEFAULT_DATASET_NAME:
-            dataloader.sampler.set_epoch(epoch)
-        dataloader_iter = iter(dataloader)
-        logger.info("Beginning epoch %s...", epoch)
-
-        # == training loop in an epoch ==
-        save_root = f'/mnt/nfs-207/sora_data/webvid-10M/feat'
-        data_list = []
-        cnt = 0
-        max_len = 10
-
-        with tqdm(
-            enumerate(dataloader_iter, start=start_step),
-            desc=f"Epoch {epoch}",
-            disable=not coordinator.is_master(),
-            initial=start_step,
-            total=num_steps_per_epoch,
-        ) as pbar:
-            for step, batch in pbar:
-                x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
-                y = batch.pop("text")
-
-                # == visual and text encoding ==
-                with torch.no_grad():
-                    # Prepare visual inputs
-                    x = vae.encode(x)  # [B, C, T, H/P, W/P]
-                    # Prepare text inputs
-
-                model_args = {}
-                # == video meta info ==
-                for k, v in batch.items():
-                    model_args[k] = v.to(device, dtype)
-
-                # keys to save: ['x', 'y', 'mask', 'fps']
-                data_i = {
-                    'x': x.detach().cpu(),
-                    'fps': model_args['fps'].detach().cpu(),
-                }
-
-                save_text_feat = [True, False]
-                if save_text_feat:
-                    with torch.no_grad():
-                        model_args_text = text_encoder.encode(y)
-                        data_i.update({
-                            'y': model_args_text['y'].detach().cpu(),
-                            'mask': model_args_text['mask'].detach().cpu(),
-                        })
-                        breakpoint()
-                else:
-                    data_i['text'] = y
-
-                data_list.append(data_i)
-
-                if len(data_list) == max_len:
-                    save_path = os.path.join(save_root, f'data/{cnt}.bin')
-                    torch.save(data_list, save_path)
-
-                    data_list = []
-                    cnt += 1
-
-
-if __name__ == "__main__":
-    main()