mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-12 05:46:22 +02:00
config
This commit is contained in:
parent
4c46972b64
commit
96a42e08db
|
|
@ -21,6 +21,7 @@ sp_size = 1
|
|||
vae_2d = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="stabilityai/sd-vae-ft-ema",
|
||||
# SDXL
|
||||
)
|
||||
|
||||
model = dict(
|
||||
|
|
@ -82,7 +83,7 @@ magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
|
|||
3-6 epochs for pexel, from pexel observation its correct
|
||||
'''
|
||||
|
||||
epochs = 200
|
||||
epochs = 1000
|
||||
log_every = 1
|
||||
ckpt_every = 200
|
||||
load = None
|
||||
|
|
|
|||
|
|
@ -1094,6 +1094,7 @@ class AdversarialLoss(nn.Module):
|
|||
):
|
||||
# NOTE: following MAGVIT to allow non_saturating
|
||||
assert self.generator_loss_type in ["hinge", "vanilla", "non-saturating"]
|
||||
|
||||
if self.generator_loss_type == "hinge":
|
||||
gen_loss = -torch.mean(fake_logits)
|
||||
elif self.generator_loss_type == "non-saturating":
|
||||
|
|
|
|||
|
|
@ -425,6 +425,8 @@ def main():
|
|||
"num_samples": global_step * total_batch_size,
|
||||
"epoch": epoch,
|
||||
"loss": vae_loss.item(),
|
||||
"kl_loss": weighted_kl_loss.item(),
|
||||
"gen_adv_loss": adversarial_loss.item(),
|
||||
"disc_loss": disc_loss.item(),
|
||||
"avg_loss": avg_loss,
|
||||
},
|
||||
|
|
|
|||
Loading…
Reference in a new issue