mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
debug
This commit is contained in:
parent
fb0f59171c
commit
79bff13099
|
|
@@ -58,8 +58,9 @@ discriminator = dict(
|
|||
kl_loss_weight = 0.000001
|
||||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
discriminator_factor = 1.0
|
||||
discriminator_loss_weight = 0.5 # TODO: adjust value
|
||||
discriminator_loss="hinge"
|
||||
discriminator_start = 1 # 50001 TODO: change to correct val, debug use 1 for now
|
||||
discriminator_start = -1 # 50001 TODO: change to correct val, debug use -1 for now
|
||||
|
||||
# Others
|
||||
seed = 42
|
||||
|
|
|
|||
|
|
@@ -941,15 +941,6 @@ class VEALoss(nn.Module):
|
|||
)
|
||||
vgg.classifier = Sequential(*vgg.classifier[:-2])
|
||||
self.vgg = vgg.to(device, dtype)
|
||||
|
||||
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
|
||||
breakpoint() # TODO: scrutinize
|
||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
||||
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
||||
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
||||
d_weight = d_weight * self.discriminator_weight
|
||||
return d_weight
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
@@ -1026,10 +1017,20 @@ class AdversarialLoss(nn.Module):
|
|||
self,
|
||||
discriminator_factor = 1.0,
|
||||
discriminator_start = 50001,
|
||||
discriminator_loss_weight = 0.5,
|
||||
):
|
||||
super().__init__()
|
||||
self.discriminator_factor = discriminator_factor
|
||||
self.discriminator_start = discriminator_start
|
||||
self.discriminator_loss_weight = discriminator_loss_weight
|
||||
|
||||
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
|
||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
||||
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
||||
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
||||
d_weight = d_weight * self.discriminator_loss_weight
|
||||
return d_weight
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
@@ -1038,17 +1039,16 @@ class AdversarialLoss(nn.Module):
|
|||
last_layer,
|
||||
global_step,
|
||||
is_training = True,
|
||||
):
|
||||
):
|
||||
gan_loss = -torch.mean(fake_logits)
|
||||
if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
|
||||
try:
|
||||
d_weight = self.calculate_adaptive_weight(nll_loss, gan_loss, last_layer)
|
||||
except RuntimeError:
|
||||
assert not is_training
|
||||
d_weight = torch.tensor(0.0)
|
||||
gan_loss = -torch.mean(fake_logits)
|
||||
else:
|
||||
d_weight = torch.tensor(0.0)
|
||||
gan_loss = torch.tensor(0.0)
|
||||
|
||||
disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
|
||||
weighted_gan_loss = d_weight * disc_factor * gan_loss
|
||||
|
|
|
|||
|
|
@@ -271,10 +271,6 @@ def main():
|
|||
else:
|
||||
video = x
|
||||
|
||||
# padded videos for GAN
|
||||
if global_step > cfg.discriminator_start:
|
||||
real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
|
||||
# ====== VAE ======
|
||||
optimizer.zero_grad()
|
||||
|
|
@@ -290,8 +286,11 @@ def main():
|
|||
split = "train"
|
||||
)
|
||||
vae_loss = nll_loss
|
||||
|
||||
# adversarial loss
|
||||
if global_step > cfg.discriminator_start:
|
||||
# padded videos for GAN
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
fake_logits = discriminator(fake_video.contiguous())
|
||||
adversarial_loss = adversarial_loss_fn(
|
||||
fake_logits,
|
||||
|
|
@@ -313,6 +312,7 @@ def main():
|
|||
disc_optimizer.zero_grad()
|
||||
# if video_contains_first_frame:
|
||||
# Since we don't have enough T frames, pad anyways
|
||||
real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
real_logits = discriminator(real_video.contiguous().detach())
|
||||
fake_logits = discriminator(fake_video.contiguous().detach())
|
||||
disc_loss = disc_loss_fn(real_logits, fake_logits, global_step)
|
||||
|
|
|
|||
Loading…
Reference in a new issue