mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 13:14:44 +02:00
complete masked training
This commit is contained in:
parent
9c81cd61bc
commit
4e6e17d800
|
|
@ -37,7 +37,7 @@ prompt = [
|
|||
"In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave."
|
||||
]
|
||||
|
||||
loop = 10
|
||||
loop = 5
|
||||
condition_frame_length = 4
|
||||
reference_path = ["assets/images/condition/wave.png"]
|
||||
mask_strategy = ["0,0,0,1,0"] # valid when reference_path is not None
|
||||
|
|
|
|||
56
configs/opensora/train/16x512x512-mask.py
Normal file
56
configs/opensora/train/16x512x512-mask.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
# Training config: 16 frames at 512x512 with masked training enabled.

# Input clip: frame count, frame interval and spatial size
num_frames = 16
frame_interval = 3
image_size = (512, 512)

# Dataset
root = None
data_path = "CSV_PATH"  # presumably a placeholder for the real CSV file; confirm before launching
use_image_transform = False
num_workers = 4

# Acceleration
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1

# Model
model = {
    "type": "STDiT-XL/2",
    "space_scale": 1.0,
    "time_scale": 1.0,
    "use_x_mask": True,
    "from_pretrained": None,
    "enable_flashattn": True,
    "enable_layernorm_kernel": True,
}
# Passed to MaskGenerator by the training script; one probability per mask
# type, in the order: no mask, random, head, tail, head + tail.
mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
vae = {
    "type": "VideoAutoencoderKL",
    "from_pretrained": "stabilityai/sd-vae-ft-ema",
    "micro_batch_size": 128,
}
text_encoder = {
    "type": "t5",
    "from_pretrained": "DeepFloyd/t5-v1_1-xxl",
    "model_max_length": 120,
    "shardformer": True,
}
scheduler = {
    "type": "iddpm",
    "timestep_respacing": "",
}

# Run bookkeeping
seed = 42
outputs = "outputs"
wandb = False

epochs = 1000
log_every = 10
ckpt_every = 500
load = None

# Optimisation
batch_size = 8
lr = 2e-5
grad_clip = 1.0
|
||||
|
|
@ -395,7 +395,6 @@ class GaussianDiffusion:
|
|||
- 'sample': a random sample from the model.
|
||||
- 'pred_xstart': a prediction of x_0.
|
||||
"""
|
||||
model_kwargs['x_mask'] = mask
|
||||
out = self.p_mean_variance(
|
||||
model,
|
||||
x,
|
||||
|
|
@ -673,7 +672,7 @@ class GaussianDiffusion:
|
|||
yield out
|
||||
img = out["sample"]
|
||||
|
||||
def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
|
||||
def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None, mask=None):
|
||||
"""
|
||||
Get a term for the variational lower-bound.
|
||||
The resulting units are bits (rather than nats, as one might expect).
|
||||
|
|
@ -685,20 +684,20 @@ class GaussianDiffusion:
|
|||
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
|
||||
out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
|
||||
kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
|
||||
kl = mean_flat(kl) / np.log(2.0)
|
||||
kl = mean_flat(kl, mask=mask) / np.log(2.0)
|
||||
|
||||
decoder_nll = -discretized_gaussian_log_likelihood(
|
||||
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
|
||||
)
|
||||
assert decoder_nll.shape == x_start.shape
|
||||
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
|
||||
decoder_nll = mean_flat(decoder_nll, mask=mask) / np.log(2.0)
|
||||
|
||||
# At the first timestep return the decoder NLL,
|
||||
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
||||
output = th.where((t == 0), decoder_nll, kl)
|
||||
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
|
||||
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, mask=None):
|
||||
"""
|
||||
Compute training losses for a single timestep.
|
||||
:param model: the model to evaluate loss on.
|
||||
|
|
@ -715,10 +714,13 @@ class GaussianDiffusion:
|
|||
if noise is None:
|
||||
noise = th.randn_like(x_start)
|
||||
x_t = self.q_sample(x_start, t, noise=noise)
|
||||
if mask is not None:
|
||||
x_t = th.where(mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1), x_t, x_start)
|
||||
|
||||
terms = {}
|
||||
|
||||
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
||||
assert mask is None, "mask not supported for KL loss"
|
||||
terms["loss"] = self._vb_terms_bpd(
|
||||
model=model,
|
||||
x_start=x_start,
|
||||
|
|
@ -748,6 +750,7 @@ class GaussianDiffusion:
|
|||
x_t=x_t,
|
||||
t=t,
|
||||
clip_denoised=False,
|
||||
mask=mask,
|
||||
)["output"]
|
||||
if self.loss_type == LossType.RESCALED_MSE:
|
||||
# Divide by 1000 for equivalence with initial implementation.
|
||||
|
|
@ -760,7 +763,7 @@ class GaussianDiffusion:
|
|||
ModelMeanType.EPSILON: noise,
|
||||
}[self.model_mean_type]
|
||||
assert model_output.shape == target.shape == x_start.shape
|
||||
terms["mse"] = mean_flat((target - model_output) ** 2)
|
||||
terms["mse"] = mean_flat((target - model_output) ** 2, mask=mask)
|
||||
if "vb" in terms:
|
||||
terms["loss"] = terms["mse"] + terms["vb"]
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -48,26 +48,27 @@ def merge_args(cfg, args, training=False):
|
|||
cfg.model["from_pretrained"] = args.ckpt_path
|
||||
args.ckpt_path = None
|
||||
|
||||
if not training:
|
||||
if args.cfg_scale is not None:
|
||||
cfg.scheduler["cfg_scale"] = args.cfg_scale
|
||||
args.cfg_scale = None
|
||||
|
||||
for k, v in vars(args).items():
|
||||
if k in cfg and v is not None:
|
||||
cfg[k] = v
|
||||
|
||||
if "reference_path" not in cfg:
|
||||
cfg["reference_path"] = None
|
||||
if "loop" not in cfg:
|
||||
cfg["loop"] = 1
|
||||
if not training:
|
||||
# Inference only
|
||||
if "reference_path" not in cfg:
|
||||
cfg["reference_path"] = None
|
||||
if "loop" not in cfg:
|
||||
cfg["loop"] = 1
|
||||
if "prompt" not in cfg or cfg["prompt"] is None:
|
||||
assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"
|
||||
cfg["prompt"] = load_prompts(cfg["prompt_path"])
|
||||
else:
|
||||
# Training only
|
||||
if "mask_ratios" not in cfg:
|
||||
cfg["mask_ratios"] = None
|
||||
|
||||
# Both training and inference
|
||||
if "multi_resolution" not in cfg:
|
||||
cfg["multi_resolution"] = False
|
||||
if "mask_ratios" not in cfg:
|
||||
cfg["mask_ratios"] = None
|
||||
if "prompt" not in cfg or cfg["prompt"] is None:
|
||||
assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"
|
||||
cfg["prompt"] = load_prompts(cfg["prompt_path"])
|
||||
|
||||
return cfg
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import random
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
|
@ -29,3 +30,45 @@ def update_ema(
|
|||
else:
|
||||
param_data = param.data
|
||||
ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
|
||||
|
||||
|
||||
class MaskGenerator:
    """Draw random boolean masks along dimension 2 of a batch tensor.

    Every mask starts as all ``True``; depending on the strategy drawn, a
    span of 1 to 4 positions is set to ``False``:

    * ``mask_no``        - nothing is cleared
    * ``mask_random``    - one contiguous span at a random offset
    * ``mask_head``      - the first positions
    * ``mask_tail``      - the last positions
    * ``mask_head_tail`` - the same number of positions at both ends

    NOTE(review): the training loop shown with this change feeds latents
    commented as ``[B, C, T, H/P, W/P]``, so dimension 2 is presumably the
    frame axis and ``False`` entries mark frames kept un-noised - confirm
    against the caller.

    Args:
        mask_ratios: probabilities for the strategies above, in that order.
            Must be non-negative, hold at most one entry per strategy and
            sum to 1.

    Raises:
        ValueError: if ``mask_ratios`` violates the constraints above.
    """

    def __init__(self, mask_ratios):
        self.mask_name = ["mask_no", "mask_random", "mask_head", "mask_tail", "mask_head_tail"]
        self.mask_prob = mask_ratios

        probs = list(mask_ratios)
        if not probs or len(probs) > len(self.mask_name):
            raise ValueError(
                f"mask_ratios must hold between 1 and {len(self.mask_name)} values, got {len(probs)}"
            )
        if any(p < 0 for p in probs):
            raise ValueError(f"mask_ratios must be non-negative, got {probs}")
        if abs(sum(probs) - 1.0) > 1e-6:
            raise ValueError(f"mask_ratios must sum to 1, got {sum(probs)}")

        # Cumulative distribution used for inverse-CDF sampling in get_mask.
        self.mask_acc_prob = []
        running = 0
        for p in probs:
            running = running + p
            self.mask_acc_prob.append(running)

    def get_mask(self, x):
        """Return one boolean mask of length ``x.shape[2]`` on ``x.device``."""
        draw = random.random()
        # Default to the last strategy: float rounding can leave the final
        # cumulative probability a hair below 1, so a draw close to 1 may
        # match no bucket in the loop below.
        mask_name = self.mask_name[len(self.mask_acc_prob) - 1]
        for name, acc_prob in zip(self.mask_name, self.mask_acc_prob):
            if draw <= acc_prob:
                mask_name = name
                break

        length = x.shape[2]
        mask = torch.ones(length, dtype=torch.bool, device=x.device)

        # Never ask for a span longer than the axis itself; otherwise the
        # offset range for "mask_random" would be empty and randint raises.
        max_span = min(4, length)
        if mask_name == "mask_no" or max_span < 1:
            return mask

        span = random.randint(1, max_span)
        if mask_name == "mask_random":
            start = random.randint(0, length - span)
            mask[start : start + span] = 0
        elif mask_name == "mask_head":
            mask[:span] = 0
        elif mask_name == "mask_tail":
            mask[-span:] = 0
        elif mask_name == "mask_head_tail":
            mask[:span] = 0
            mask[-span:] = 0

        return mask

    def get_masks(self, x):
        """Return a ``[len(x), x.shape[2]]`` stack of independently drawn masks."""
        return torch.stack([self.get_mask(x) for _ in range(len(x))], dim=0)
|
||||
|
|
|
|||
|
|
@ -176,6 +176,7 @@ def main():
|
|||
j
|
||||
] += f";{loop_i},{len(refs)-1},-{cfg.condition_frame_length},{cfg.condition_frame_length},0"
|
||||
masks = apply_mask_strategy(z, refs_x, mask_strategy, loop_i)
|
||||
model_args["x_mask"] = masks
|
||||
|
||||
# 4.6. diffusion sampling
|
||||
samples = scheduler.sample(
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from opensora.utils.config_utils import (
|
|||
save_training_config,
|
||||
)
|
||||
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
|
||||
from opensora.utils.train_utils import update_ema
|
||||
from opensora.utils.train_utils import update_ema, MaskGenerator
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -168,6 +168,8 @@ def main():
|
|||
model.train()
|
||||
update_ema(ema, model, decay=0, sharded=False)
|
||||
ema.eval()
|
||||
if cfg.mask_ratios is not None:
|
||||
mask_generator = MaskGenerator(cfg.mask_ratios)
|
||||
|
||||
# =======================================================
|
||||
# 5. boost model for distributed training with colossalai
|
||||
|
|
@ -214,15 +216,23 @@ def main():
|
|||
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
|
||||
y = batch["text"]
|
||||
|
||||
# Visual and text encoding
|
||||
with torch.no_grad():
|
||||
# Prepare visual inputs
|
||||
x = vae.encode(x) # [B, C, T, H/P, W/P]
|
||||
# Prepare text inputs
|
||||
model_args = text_encoder.encode(y)
|
||||
|
||||
# Mask
|
||||
if cfg.mask_ratios is not None:
|
||||
mask = mask_generator.get_masks(x)
|
||||
model_args["x_mask"] = mask
|
||||
else:
|
||||
mask = None
|
||||
|
||||
# Diffusion
|
||||
t = torch.randint(0, scheduler.num_timesteps, (x.shape[0],), device=device)
|
||||
loss_dict = scheduler.training_losses(model, x, t, model_args)
|
||||
loss_dict = scheduler.training_losses(model, x, t, model_args, mask=mask)
|
||||
|
||||
# Backward & update
|
||||
loss = loss_dict["loss"].mean()
|
||||
|
|
|
|||
Loading…
Reference in a new issue