diff --git a/configs/opensora/inference_long/16x512x512-long.py b/configs/opensora/inference_long/16x512x512-long.py index 2bbb766..58c9e7a 100644 --- a/configs/opensora/inference_long/16x512x512-long.py +++ b/configs/opensora/inference_long/16x512x512-long.py @@ -37,7 +37,7 @@ prompt = [ "In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave." ] -loop = 10 +loop = 5 condition_frame_length = 4 reference_path = ["assets/images/condition/wave.png"] mask_strategy = ["0,0,0,1,0"] # valid when reference_path is not None diff --git a/configs/opensora/train/16x512x512-mask.py b/configs/opensora/train/16x512x512-mask.py new file mode 100644 index 0000000..7c8ffc1 --- /dev/null +++ b/configs/opensora/train/16x512x512-mask.py @@ -0,0 +1,56 @@ +num_frames = 16 +frame_interval = 3 +image_size = (512, 512) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = False +num_workers = 4 + +# Define acceleration +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" +sp_size = 1 + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=1.0, + time_scale=1.0, + use_x_mask=True, + from_pretrained=None, + enable_flashattn=True, + enable_layernorm_kernel=True, +) +mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07] +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=128, +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="iddpm", + timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 1000 +log_every = 10 +ckpt_every = 500 +load = None + +batch_size = 8 +lr = 2e-5 +grad_clip = 1.0 diff --git a/opensora/schedulers/iddpm/gaussian_diffusion.py b/opensora/schedulers/iddpm/gaussian_diffusion.py index 914d4d3..da233c3 100644 --- a/opensora/schedulers/iddpm/gaussian_diffusion.py +++ b/opensora/schedulers/iddpm/gaussian_diffusion.py @@ -395,7 +395,6 @@ class GaussianDiffusion: - 'sample': a random sample from the model. - 'pred_xstart': a prediction of x_0. """ - model_kwargs['x_mask'] = mask out = self.p_mean_variance( model, x, @@ -673,7 +672,7 @@ class GaussianDiffusion: yield out img = out["sample"] - def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None): + def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None, mask=None): """ Get a term for the variational lower-bound. The resulting units are bits (rather than nats, as one might expect). @@ -685,20 +684,20 @@ class GaussianDiffusion: true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs) kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) - kl = mean_flat(kl) / np.log(2.0) + kl = mean_flat(kl, mask=mask) / np.log(2.0) decoder_nll = -discretized_gaussian_log_likelihood( x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] ) assert decoder_nll.shape == x_start.shape - decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + decoder_nll = mean_flat(decoder_nll, mask=mask) / np.log(2.0) # At the first timestep return the decoder NLL, # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) output = th.where((t == 0), decoder_nll, kl) return {"output": output, "pred_xstart": out["pred_xstart"]} - def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, mask=None): """ Compute training losses for a single timestep. :param model: the model to evaluate loss on. @@ -715,10 +714,13 @@ class GaussianDiffusion: if noise is None: noise = th.randn_like(x_start) x_t = self.q_sample(x_start, t, noise=noise) + if mask is not None: + x_t = th.where(mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1), x_t, x_start) terms = {} if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + assert mask is None, "mask not supported for KL loss" terms["loss"] = self._vb_terms_bpd( model=model, x_start=x_start, @@ -748,6 +750,7 @@ class GaussianDiffusion: x_t=x_t, t=t, clip_denoised=False, + mask=mask, )["output"] if self.loss_type == LossType.RESCALED_MSE: # Divide by 1000 for equivalence with initial implementation. @@ -760,7 +763,7 @@ class GaussianDiffusion: ModelMeanType.EPSILON: noise, }[self.model_mean_type] assert model_output.shape == target.shape == x_start.shape - terms["mse"] = mean_flat((target - model_output) ** 2) + terms["mse"] = mean_flat((target - model_output) ** 2, mask=mask) if "vb" in terms: terms["loss"] = terms["mse"] + terms["vb"] else: diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py index 89c7fe8..350cc8e 100644 --- a/opensora/utils/config_utils.py +++ b/opensora/utils/config_utils.py @@ -48,26 +48,27 @@ def merge_args(cfg, args, training=False): cfg.model["from_pretrained"] = args.ckpt_path args.ckpt_path = None - if not training: - if args.cfg_scale is not None: - cfg.scheduler["cfg_scale"] = args.cfg_scale - args.cfg_scale = None - for k, v in vars(args).items(): if k in cfg and v is not None: cfg[k] = v - if "reference_path" not in cfg: - cfg["reference_path"] = None - if "loop" not in cfg: - cfg["loop"] = 1 + if not training: + # Inference only + if "reference_path" not in cfg: + cfg["reference_path"] = None + if "loop" not in cfg: + cfg["loop"] = 1 + if "prompt" not in cfg or cfg["prompt"] is None: + assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided" + cfg["prompt"] = load_prompts(cfg["prompt_path"]) + else: + # Training only + if "mask_ratios" not in cfg: + cfg["mask_ratios"] = None + + # Both training and inference if "multi_resolution" not in cfg: cfg["multi_resolution"] = False - if "mask_ratios" not in cfg: - cfg["mask_ratios"] = None - if "prompt" not in cfg or cfg["prompt"] is None: - assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided" - cfg["prompt"] = load_prompts(cfg["prompt_path"]) return cfg diff --git a/opensora/utils/train_utils.py b/opensora/utils/train_utils.py index f846043..4a64aa9 100644 --- a/opensora/utils/train_utils.py +++ b/opensora/utils/train_utils.py @@ -1,3 +1,4 @@ +import random from collections import OrderedDict import torch @@ -29,3 +30,45 @@ def update_ema( else: param_data = param.data ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay) + + +class MaskGenerator: + def __init__(self, mask_ratios): + self.mask_name = ["mask_no", "mask_random", "mask_head", "mask_tail", "mask_head_tail"] + self.mask_prob = mask_ratios + print(self.mask_prob) + self.mask_acc_prob = [sum(self.mask_prob[: i + 1]) for i in range(len(self.mask_prob))] + + def get_mask(self, x): + mask_type = random.random() + for i, acc_prob in enumerate(self.mask_acc_prob): + if mask_type <= acc_prob: + mask_name = self.mask_name[i] + break + + mask = torch.ones(x.shape[2], dtype=torch.bool, device=x.device) + if mask_name == "mask_random": + random_size = random.randint(1, 4) + random_pos = random.randint(0, x.shape[2] - random_size) + mask[random_pos : random_pos + random_size] = 0 + return mask + elif mask_name == "mask_head": + random_size = random.randint(1, 4) + mask[:random_size] = 0 + elif mask_name == "mask_tail": + random_size = random.randint(1, 4) + mask[-random_size:] = 0 + elif mask_name == "mask_head_tail": + random_size = random.randint(1, 4) + mask[:random_size] = 0 + mask[-random_size:] = 0 + + return mask + + def get_masks(self, x): + masks = [] + for _ in range(len(x)): + mask = self.get_mask(x) + masks.append(mask) + masks = torch.stack(masks, dim=0) + return masks diff --git a/scripts/inference_long.py b/scripts/inference_long.py index 7f59f6c..4fddaf0 100644 --- a/scripts/inference_long.py +++ b/scripts/inference_long.py @@ -176,6 +176,7 @@ def main(): j ] += f";{loop_i},{len(refs)-1},-{cfg.condition_frame_length},{cfg.condition_frame_length},0" masks = apply_mask_strategy(z, refs_x, mask_strategy, loop_i) + model_args["x_mask"] = masks # 4.6. diffusion sampling samples = scheduler.sample( diff --git a/scripts/train.py b/scripts/train.py index 9f611b7..1c61d6e 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -28,7 +28,7 @@ from opensora.utils.config_utils import ( save_training_config, ) from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype -from opensora.utils.train_utils import update_ema +from opensora.utils.train_utils import update_ema, MaskGenerator def main(): @@ -168,6 +168,8 @@ def main(): model.train() update_ema(ema, model, decay=0, sharded=False) ema.eval() + if cfg.mask_ratios is not None: + mask_generator = MaskGenerator(cfg.mask_ratios) # ======================================================= # 5. boost model for distributed training with colossalai @@ -214,15 +216,23 @@ def main(): x = batch["video"].to(device, dtype) # [B, C, T, H, W] y = batch["text"] + # Visual and text encoding with torch.no_grad(): # Prepare visual inputs x = vae.encode(x) # [B, C, T, H/P, W/P] # Prepare text inputs model_args = text_encoder.encode(y) + + # Mask + if cfg.mask_ratios is not None: + mask = mask_generator.get_masks(x) + model_args["x_mask"] = mask + else: + mask = None # Diffusion t = torch.randint(0, scheduler.num_timesteps, (x.shape[0],), device=device) - loss_dict = scheduler.training_losses(model, x, t, model_args) + loss_dict = scheduler.training_losses(model, x, t, model_args, mask=mask) # Backward & update loss = loss_dict["loss"].mean()