This commit is contained in:
Shen-Chenhui 2024-04-19 17:49:50 +08:00
parent 4c46972b64
commit 96a42e08db
3 changed files with 5 additions and 1 deletions

View file

@ -21,6 +21,7 @@ sp_size = 1
# Config dict for the `vae_2d` component: `type` names the class to build
# ("VideoAutoencoderKL") and `from_pretrained` names the checkpoint to load.
vae_2d = dict(
type="VideoAutoencoderKL",
from_pretrained="stabilityai/sd-vae-ft-ema",
# SDXL
# NOTE(review): the bare "SDXL" marker above is unexplained here; presumably a
# reminder about an SDXL VAE checkpoint as an alternative -- confirm intent.
)
model = dict(
@ -82,7 +83,7 @@ magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
3-6 epochs for pexel; from pexel observation it's correct
'''
epochs = 200
epochs = 1000
log_every = 1
ckpt_every = 200
load = None

View file

@ -1094,6 +1094,7 @@ class AdversarialLoss(nn.Module):
):
# NOTE: following MAGVIT to allow non_saturating
assert self.generator_loss_type in ["hinge", "vanilla", "non-saturating"]
if self.generator_loss_type == "hinge":
gen_loss = -torch.mean(fake_logits)
elif self.generator_loss_type == "non-saturating":

View file

@ -425,6 +425,8 @@ def main():
"num_samples": global_step * total_batch_size,
"epoch": epoch,
"loss": vae_loss.item(),
"kl_loss": weighted_kl_loss.item(),
"gen_adv_loss": adversarial_loss.item(),
"disc_loss": disc_loss.item(),
"avg_loss": avg_loss,
},