This commit is contained in:
Shen-Chenhui 2024-04-27 20:57:48 +08:00
parent 0198ea8a52
commit 52e869079c
2 changed files with 24 additions and 18 deletions

View file

@ -88,12 +88,8 @@ magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
3-6 epochs for pexel, from pexel observation its correct
'''
epochs = 10
log_every = 1
ckpt_every = 500
load = None
batch_size = 4
batch_size = 1
lr = 1e-4
grad_clip = 1.0

View file

@ -9,7 +9,7 @@ from mmengine.runner import set_random_seed
from opensora.acceleration.parallel_states import set_sequence_parallel_group
from opensora.datasets import save_sample
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.config_utils import parse_configs, load_json
from opensora.utils.misc import to_torch_dtype
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
from opensora.registry import DATASETS, MODELS, build_module
@ -19,7 +19,7 @@ from opensora.acceleration.parallel_states import (
set_sequence_parallel_group,
)
from tqdm import tqdm
from opensora.models.vae.vae_3d_v2 import VEALoss, DiscriminatorLoss, AdversarialLoss, pad_at_dim
from opensora.models.vae.vae_3d_v2 import VEALoss, DiscriminatorLoss, AdversarialLoss, LeCamEMA, pad_at_dim
from einops import rearrange
from colossalai.utils import get_current_device
@ -129,9 +129,24 @@ def main():
lecam_loss_weight = cfg.lecam_loss_weight,
gradient_penalty_loss_weight = cfg.gradient_penalty_loss_weight,
)
# LeCam EMA for discriminator
lecam_path = os.path.join(cfg.load, "lecam_states.json")
if cfg.lecam_loss_weight is not None and os.path.exists(lecam_path):
lecam_state = load_json(lecam_path)
lecam_ema_real, lecam_ema_fake = lecam_state["lecam_ema_real"], lecam_state["lecam_ema_fake"]
lecam_ema = LeCamEMA(
decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device
)
else:
lecam_ema = LeCamEMA(
decay=cfg.ema_decay, dtype=dtype, device=device
)
running_loss = 0.0
running_nll = 0.0
running_disc_loss = 0.0
loss_steps = 0
@ -142,9 +157,6 @@ def main():
disc_time_padding = 0
video_contains_first_frame = cfg.video_contains_first_frame
lecam_ema_real = torch.tensor(0.0)
lecam_ema_fake = torch.tensor(0.0)
total_steps = len(dataloader)
if cfg.max_test_samples > 0:
total_steps = min(int(cfg.max_test_samples//cfg.batch_size), total_steps)
@ -195,7 +207,7 @@ def main():
)
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
fake_logits = discriminator(fake_video.contiguous()) # TODO: take out contiguous?
fake_logits = discriminator(fake_video.contiguous())
adversarial_loss = adversarial_loss_fn(
fake_logits,
nll_loss,
@ -217,6 +229,8 @@ def main():
real_logits = discriminator(real_video.contiguous().detach())
fake_logits = discriminator(fake_video.contiguous().detach())
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
real_logits,
fake_logits,
@ -228,15 +242,10 @@ def main():
disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
if cfg.ema_decay is not None:
# SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc.
lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(real_logits.clone().detach())
lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(fake_logits.clone().detach())
loss_steps += 1
running_disc_loss = disc_loss.item()/loss_steps + disc_loss * ((loss_steps - 1) / loss_steps)
running_disc_loss = disc_loss.item()/loss_steps + running_disc_loss * ((loss_steps - 1) / loss_steps)
running_loss = vae_loss.item()/ loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
# ===== Spatial VAE =====
@ -266,6 +275,7 @@ def main():
if cfg.calc_loss:
print("test vae loss:", running_loss)
print("test nll loss:", running_nll)
print("test disc loss:", running_disc_loss)