diff --git a/configs/vae_magvit_v2/train/pipeline_16x128x128.py b/configs/vae_magvit_v2/train/pipeline_16x128x128.py
index 6da9ee1..8b2dd77 100644
--- a/configs/vae_magvit_v2/train/pipeline_16x128x128.py
+++ b/configs/vae_magvit_v2/train/pipeline_16x128x128.py
@@ -21,6 +21,7 @@ sp_size = 1
 vae_2d = dict(
     type="VideoAutoencoderKL",
     from_pretrained="stabilityai/sd-vae-ft-ema",
+    # SDXL
 )
 
 model = dict(
@@ -82,7 +83,7 @@
 magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
 3-6 epochs for pexel, from pexel observation its correct
 '''
-epochs = 200
+epochs = 1000
 log_every = 1
 ckpt_every = 200
 load = None
diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py
index fe97dda..66f8f3e 100644
--- a/opensora/models/vae/vae_3d_v2.py
+++ b/opensora/models/vae/vae_3d_v2.py
@@ -1094,6 +1094,7 @@ class AdversarialLoss(nn.Module):
     ):
         # NOTE: following MAGVIT to allow non_saturating
         assert self.generator_loss_type in ["hinge", "vanilla", "non-saturating"]
+
         if self.generator_loss_type == "hinge":
             gen_loss = -torch.mean(fake_logits)
         elif self.generator_loss_type == "non-saturating":
diff --git a/scripts/train-vae-v2.py b/scripts/train-vae-v2.py
index d9900ef..e007878 100644
--- a/scripts/train-vae-v2.py
+++ b/scripts/train-vae-v2.py
@@ -425,6 +425,8 @@ def main():
                     "num_samples": global_step * total_batch_size,
                     "epoch": epoch,
                     "loss": vae_loss.item(),
+                    "kl_loss": weighted_kl_loss.item(),
+                    "gen_adv_loss": adversarial_loss.item(),
                     "disc_loss": disc_loss.item(),
                     "avg_loss": avg_loss,
                 },