This commit is contained in:
Zangwei Zheng 2024-03-26 17:15:59 +08:00
parent 552b7e8f79
commit 3a0b85456c
3 changed files with 17 additions and 19 deletions

View file

@@ -1,14 +1,14 @@
num_frames = 16
frame_interval = 3
image_size = (256, 256)
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=16,
frame_interval=3,
image_size=(256, 256),
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
@@ -23,7 +23,7 @@ model = dict(
enable_flashattn=True,
enable_layernorm_kernel=True,
)
# mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
vae = dict(
type="VideoAutoencoderKL",
from_pretrained="stabilityai/sd-vae-ft-ema",

View file

@@ -684,16 +684,13 @@ class GaussianDiffusion:
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
kl = mean_flat(kl, mask=mask) / np.log(2.0)
decoder_nll = -discretized_gaussian_log_likelihood(
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
)
assert decoder_nll.shape == x_start.shape
if mask is not None:
kl = th.where(mask[:, None, :, None, None], kl, decoder_nll)
kl = mean_flat(kl) / np.log(2.0)
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
decoder_nll = mean_flat(decoder_nll, mask=mask) / np.log(2.0)
# At the first timestep return the decoder NLL,
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
@@ -769,10 +766,10 @@ class GaussianDiffusion:
}[self.model_mean_type]
assert model_output.shape == target.shape == x_start.shape
if weights is None:
terms["mse"] = mean_flat((target - model_output) ** 2)
terms["mse"] = mean_flat((target - model_output) ** 2, mask=mask)
else:
weight = _extract_into_tensor(weights, t, target.shape)
terms["mse"] = mean_flat(weight * (target - model_output) ** 2)
terms["mse"] = mean_flat(weight * (target - model_output) ** 2, mask=mask)
if "vb" in terms:
terms["loss"] = terms["mse"] + terms["vb"]
else:

View file

@@ -47,9 +47,7 @@ def merge_args(cfg, args, training=False):
if args.ckpt_path is not None:
cfg.model["from_pretrained"] = args.ckpt_path
args.ckpt_path = None
if args.data_path is not None:
cfg.dataset["data_path"] = args.data_path
args.data_path = None
for k, v in vars(args).items():
if k in cfg and v is not None:
@@ -66,6 +64,9 @@ def merge_args(cfg, args, training=False):
cfg["prompt"] = load_prompts(cfg["prompt_path"])
else:
# Training only
if args.data_path is not None:
cfg.dataset["data_path"] = args.data_path
args.data_path = None
if "mask_ratios" not in cfg:
cfg["mask_ratios"] = None