mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-21 11:59:01 +02:00
debug
This commit is contained in:
parent
0198ea8a52
commit
52e869079c
|
|
@ -88,12 +88,8 @@ magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
|
|||
3-6 epochs for pexel, from pexel observation its correct
|
||||
'''
|
||||
|
||||
epochs = 10
|
||||
log_every = 1
|
||||
ckpt_every = 500
|
||||
load = None
|
||||
|
||||
batch_size = 4
|
||||
batch_size = 1
|
||||
lr = 1e-4
|
||||
grad_clip = 1.0
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from mmengine.runner import set_random_seed
|
|||
from opensora.acceleration.parallel_states import set_sequence_parallel_group
|
||||
from opensora.datasets import save_sample
|
||||
from opensora.registry import MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.config_utils import parse_configs
|
||||
from opensora.utils.config_utils import parse_configs, load_json
|
||||
from opensora.utils.misc import to_torch_dtype
|
||||
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
|
||||
from opensora.registry import DATASETS, MODELS, build_module
|
||||
|
|
@ -19,7 +19,7 @@ from opensora.acceleration.parallel_states import (
|
|||
set_sequence_parallel_group,
|
||||
)
|
||||
from tqdm import tqdm
|
||||
from opensora.models.vae.vae_3d_v2 import VEALoss, DiscriminatorLoss, AdversarialLoss, pad_at_dim
|
||||
from opensora.models.vae.vae_3d_v2 import VEALoss, DiscriminatorLoss, AdversarialLoss, LeCamEMA, pad_at_dim
|
||||
|
||||
from einops import rearrange
|
||||
from colossalai.utils import get_current_device
|
||||
|
|
@ -129,9 +129,24 @@ def main():
|
|||
lecam_loss_weight = cfg.lecam_loss_weight,
|
||||
gradient_penalty_loss_weight = cfg.gradient_penalty_loss_weight,
|
||||
)
|
||||
|
||||
# LeCam EMA for discriminator
|
||||
lecam_path = os.path.join(cfg.load, "lecam_states.json")
|
||||
if cfg.lecam_loss_weight is not None and os.path.exists(lecam_path):
|
||||
lecam_state = load_json(lecam_path)
|
||||
lecam_ema_real, lecam_ema_fake = lecam_state["lecam_ema_real"], lecam_state["lecam_ema_fake"]
|
||||
lecam_ema = LeCamEMA(
|
||||
decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device
|
||||
)
|
||||
else:
|
||||
lecam_ema = LeCamEMA(
|
||||
decay=cfg.ema_decay, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
|
||||
|
||||
running_loss = 0.0
|
||||
running_nll = 0.0
|
||||
running_disc_loss = 0.0
|
||||
loss_steps = 0
|
||||
|
||||
|
|
@ -142,9 +157,6 @@ def main():
|
|||
disc_time_padding = 0
|
||||
video_contains_first_frame = cfg.video_contains_first_frame
|
||||
|
||||
lecam_ema_real = torch.tensor(0.0)
|
||||
lecam_ema_fake = torch.tensor(0.0)
|
||||
|
||||
total_steps = len(dataloader)
|
||||
if cfg.max_test_samples > 0:
|
||||
total_steps = min(int(cfg.max_test_samples//cfg.batch_size), total_steps)
|
||||
|
|
@ -195,7 +207,7 @@ def main():
|
|||
)
|
||||
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
fake_logits = discriminator(fake_video.contiguous()) # TODO: take out contiguous?
|
||||
fake_logits = discriminator(fake_video.contiguous())
|
||||
adversarial_loss = adversarial_loss_fn(
|
||||
fake_logits,
|
||||
nll_loss,
|
||||
|
|
@ -217,6 +229,8 @@ def main():
|
|||
real_logits = discriminator(real_video.contiguous().detach())
|
||||
|
||||
fake_logits = discriminator(fake_video.contiguous().detach())
|
||||
|
||||
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
|
||||
weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
|
||||
real_logits,
|
||||
fake_logits,
|
||||
|
|
@ -228,15 +242,10 @@ def main():
|
|||
|
||||
disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
|
||||
|
||||
if cfg.ema_decay is not None:
|
||||
# SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc.
|
||||
lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(real_logits.clone().detach())
|
||||
lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(fake_logits.clone().detach())
|
||||
|
||||
|
||||
loss_steps += 1
|
||||
running_disc_loss = disc_loss.item()/loss_steps + disc_loss * ((loss_steps - 1) / loss_steps)
|
||||
running_disc_loss = disc_loss.item()/loss_steps + running_disc_loss * ((loss_steps - 1) / loss_steps)
|
||||
running_loss = vae_loss.item()/ loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
|
||||
running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
|
||||
|
||||
|
||||
# ===== Spatial VAE =====
|
||||
|
|
@ -266,6 +275,7 @@ def main():
|
|||
|
||||
if cfg.calc_loss:
|
||||
print("test vae loss:", running_loss)
|
||||
print("test nll loss:", running_nll)
|
||||
print("test disc loss:", running_disc_loss)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue