complete masked training

This commit is contained in:
Zangwei Zheng 2024-03-23 22:06:19 +08:00
parent 9c81cd61bc
commit 4e6e17d800
7 changed files with 137 additions and 23 deletions

View file

@ -37,7 +37,7 @@ prompt = [
"In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave."
]
loop = 10
loop = 5
condition_frame_length = 4
reference_path = ["assets/images/condition/wave.png"]
mask_strategy = ["0,0,0,1,0"] # valid when reference_path is not None

View file

@ -0,0 +1,56 @@
# Training configuration for masked (conditional-frame) STDiT training.
# Flat-assignment config file consumed via merge_args/parse_configs.

# Clip sampling: frames per sample, stride between sampled frames, spatial size.
num_frames = 16
frame_interval = 3
image_size = (512, 512)
# Define dataset
root = None
data_path = "CSV_PATH"  # placeholder — replace with the actual CSV annotation path
use_image_transform = False
num_workers = 4
# Define acceleration
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"  # ColossalAI booster plugin (ZeRO stage 2)
sp_size = 1  # sequence-parallel group size
# Define model
model = dict(
    type="STDiT-XL/2",
    space_scale=1.0,
    time_scale=1.0,
    use_x_mask=True,  # enable the frame-mask pathway added by this commit
    from_pretrained=None,
    enable_flashattn=True,
    enable_layernorm_kernel=True,
)
# Sampling probabilities for the five mask types, in MaskGenerator.mask_name
# order: [mask_no, mask_random, mask_head, mask_tail, mask_head_tail].
# Presumably intended to sum to 1.0 (0.5 + 0.29 + 3*0.07) — verify.
mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
vae = dict(
    type="VideoAutoencoderKL",
    from_pretrained="stabilityai/sd-vae-ft-ema",
    micro_batch_size=128,  # frames encoded per VAE forward pass
)
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=120,  # max text tokens
    shardformer=True,
)
scheduler = dict(
    type="iddpm",
    timestep_respacing="",  # empty string = full timestep schedule
)
# Others
seed = 42
outputs = "outputs"
wandb = False
epochs = 1000
log_every = 10
ckpt_every = 500  # checkpoint interval, in steps
load = None  # resume path (None = train from scratch / from_pretrained)
batch_size = 8
lr = 2e-5
grad_clip = 1.0

View file

@ -395,7 +395,6 @@ class GaussianDiffusion:
- 'sample': a random sample from the model.
- 'pred_xstart': a prediction of x_0.
"""
model_kwargs['x_mask'] = mask
out = self.p_mean_variance(
model,
x,
@ -673,7 +672,7 @@ class GaussianDiffusion:
yield out
img = out["sample"]
def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None, mask=None):
"""
Get a term for the variational lower-bound.
The resulting units are bits (rather than nats, as one might expect).
@ -685,20 +684,20 @@ class GaussianDiffusion:
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
kl = mean_flat(kl) / np.log(2.0)
kl = mean_flat(kl, mask=mask) / np.log(2.0)
decoder_nll = -discretized_gaussian_log_likelihood(
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
)
assert decoder_nll.shape == x_start.shape
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
decoder_nll = mean_flat(decoder_nll, mask=mask) / np.log(2.0)
# At the first timestep return the decoder NLL,
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
output = th.where((t == 0), decoder_nll, kl)
return {"output": output, "pred_xstart": out["pred_xstart"]}
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, mask=None):
"""
Compute training losses for a single timestep.
:param model: the model to evaluate loss on.
@ -715,10 +714,13 @@ class GaussianDiffusion:
if noise is None:
noise = th.randn_like(x_start)
x_t = self.q_sample(x_start, t, noise=noise)
if mask is not None:
x_t = th.where(mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1), x_t, x_start)
terms = {}
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
assert mask is None, "mask not supported for KL loss"
terms["loss"] = self._vb_terms_bpd(
model=model,
x_start=x_start,
@ -748,6 +750,7 @@ class GaussianDiffusion:
x_t=x_t,
t=t,
clip_denoised=False,
mask=mask,
)["output"]
if self.loss_type == LossType.RESCALED_MSE:
# Divide by 1000 for equivalence with initial implementation.
@ -760,7 +763,7 @@ class GaussianDiffusion:
ModelMeanType.EPSILON: noise,
}[self.model_mean_type]
assert model_output.shape == target.shape == x_start.shape
terms["mse"] = mean_flat((target - model_output) ** 2)
terms["mse"] = mean_flat((target - model_output) ** 2, mask=mask)
if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"]
else:

View file

@ -48,26 +48,27 @@ def merge_args(cfg, args, training=False):
cfg.model["from_pretrained"] = args.ckpt_path
args.ckpt_path = None
if not training:
if args.cfg_scale is not None:
cfg.scheduler["cfg_scale"] = args.cfg_scale
args.cfg_scale = None
for k, v in vars(args).items():
if k in cfg and v is not None:
cfg[k] = v
if "reference_path" not in cfg:
cfg["reference_path"] = None
if "loop" not in cfg:
cfg["loop"] = 1
if not training:
# Inference only
if "reference_path" not in cfg:
cfg["reference_path"] = None
if "loop" not in cfg:
cfg["loop"] = 1
if "prompt" not in cfg or cfg["prompt"] is None:
assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"
cfg["prompt"] = load_prompts(cfg["prompt_path"])
else:
# Training only
if "mask_ratios" not in cfg:
cfg["mask_ratios"] = None
# Both training and inference
if "multi_resolution" not in cfg:
cfg["multi_resolution"] = False
if "mask_ratios" not in cfg:
cfg["mask_ratios"] = None
if "prompt" not in cfg or cfg["prompt"] is None:
assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"
cfg["prompt"] = load_prompts(cfg["prompt_path"])
return cfg

View file

@ -1,3 +1,4 @@
import random
from collections import OrderedDict
import torch
@ -29,3 +30,45 @@ def update_ema(
else:
param_data = param.data
ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
class MaskGenerator:
    """Sample per-frame boolean masks for masked video training.

    Each call draws one of five mask types according to ``mask_ratios``
    (aligned with ``mask_name``): keep all frames, zero a random interior
    span, zero the head, zero the tail, or zero both head and tail.
    Masked positions are ``False``; kept positions are ``True``.
    """

    def __init__(self, mask_ratios):
        """mask_ratios: list of 5 probabilities, one per entry of mask_name."""
        self.mask_name = ["mask_no", "mask_random", "mask_head", "mask_tail", "mask_head_tail"]
        self.mask_prob = mask_ratios
        print(self.mask_prob)  # debug: log the configured ratios once
        # Cumulative probabilities for inverse-transform sampling of the type.
        self.mask_acc_prob = [sum(self.mask_prob[: i + 1]) for i in range(len(self.mask_prob))]

    def get_mask(self, x):
        """Return a bool mask of shape (x.shape[2],) on x's device.

        x is assumed to be a latent batch laid out (B, C, T, ...), so the
        mask covers the temporal dimension — confirm against the caller.
        """
        draw = random.random()
        # Fallback to the last type: if the ratios sum to slightly less than
        # 1.0 (float rounding, or a misconfigured list), the draw can exceed
        # every cumulative bound and the loop below never breaks. Without a
        # default, mask_name would be unbound and raise NameError.
        mask_name = self.mask_name[-1]
        for idx, bound in enumerate(self.mask_acc_prob):
            if draw <= bound:
                mask_name = self.mask_name[idx]
                break

        num_frames = x.shape[2]
        mask = torch.ones(num_frames, dtype=torch.bool, device=x.device)
        if mask_name == "mask_random":
            span = random.randint(1, 4)
            start = random.randint(0, num_frames - span)
            mask[start : start + span] = 0
        elif mask_name == "mask_head":
            mask[: random.randint(1, 4)] = 0
        elif mask_name == "mask_tail":
            mask[-random.randint(1, 4) :] = 0
        elif mask_name == "mask_head_tail":
            span = random.randint(1, 4)
            mask[:span] = 0
            mask[-span:] = 0
        # "mask_no" keeps the all-ones mask.
        return mask

    def get_masks(self, x):
        """Return a (len(x), x.shape[2]) bool tensor — one mask per sample."""
        return torch.stack([self.get_mask(x) for _ in range(len(x))], dim=0)

View file

@ -176,6 +176,7 @@ def main():
j
] += f";{loop_i},{len(refs)-1},-{cfg.condition_frame_length},{cfg.condition_frame_length},0"
masks = apply_mask_strategy(z, refs_x, mask_strategy, loop_i)
model_args["x_mask"] = masks
# 4.6. diffusion sampling
samples = scheduler.sample(

View file

@ -28,7 +28,7 @@ from opensora.utils.config_utils import (
save_training_config,
)
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
from opensora.utils.train_utils import update_ema
from opensora.utils.train_utils import update_ema, MaskGenerator
def main():
@ -168,6 +168,8 @@ def main():
model.train()
update_ema(ema, model, decay=0, sharded=False)
ema.eval()
if cfg.mask_ratios is not None:
mask_generator = MaskGenerator(cfg.mask_ratios)
# =======================================================
# 5. boost model for distributed training with colossalai
@ -214,15 +216,23 @@ def main():
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
y = batch["text"]
# Visual and text encoding
with torch.no_grad():
# Prepare visual inputs
x = vae.encode(x) # [B, C, T, H/P, W/P]
# Prepare text inputs
model_args = text_encoder.encode(y)
# Mask
if cfg.mask_ratios is not None:
mask = mask_generator.get_masks(x)
model_args["x_mask"] = mask
else:
mask = None
# Diffusion
t = torch.randint(0, scheduler.num_timesteps, (x.shape[0],), device=device)
loss_dict = scheduler.training_losses(model, x, t, model_args)
loss_dict = scheduler.training_losses(model, x, t, model_args, mask=mask)
# Backward & update
loss = loss_dict["loss"].mean()