This commit is contained in:
Shen-Chenhui 2024-04-13 10:51:56 +08:00
parent fb0f59171c
commit 79bff13099
3 changed files with 18 additions and 17 deletions

View file

@ -58,8 +58,9 @@ discriminator = dict(
kl_loss_weight = 0.000001
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
discriminator_factor = 1.0
discriminator_loss_weight = 0.5 # TODO: adjust value
discriminator_loss="hinge"
discriminator_start = 1 # 50001 TODO: change to correct val, debug use 1 for now
discriminator_start = -1 # 50001 TODO: change to correct val, debug use -1 for now
# Others
seed = 42

View file

@ -941,15 +941,6 @@ class VEALoss(nn.Module):
)
vgg.classifier = Sequential(*vgg.classifier[:-2])
self.vgg = vgg.to(device, dtype)
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
    """Balance the adversarial loss against the reconstruction loss.

    Follows the VQGAN adaptive-weight scheme: the generator (GAN) loss is
    rescaled so that its gradient magnitude at the decoder's last layer is
    comparable to that of the reconstruction (NLL) loss.

    Args:
        nll_loss: scalar reconstruction loss; must be part of the autograd
            graph reaching ``last_layer``.
        g_loss: scalar generator/adversarial loss on the same graph.
        last_layer: parameter tensor of the decoder's final layer used as
            the probe point for gradient norms.

    Returns:
        A detached scalar weight, clamped to ``[0, 1e4]`` and scaled by
        ``self.discriminator_weight``.
    """
    # NOTE: removed leftover `breakpoint()` debug call — it would freeze
    # every training step waiting for an interactive debugger.
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    # Ratio of gradient norms; 1e-4 epsilon guards against division by zero
    # when the generator gradients vanish.
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
    d_weight = d_weight * self.discriminator_weight
    return d_weight
def forward(
self,
@ -1026,10 +1017,20 @@ class AdversarialLoss(nn.Module):
self,
discriminator_factor = 1.0,
discriminator_start = 50001,
discriminator_loss_weight = 0.5,
):
super().__init__()
self.discriminator_factor = discriminator_factor
self.discriminator_start = discriminator_start
self.discriminator_loss_weight = discriminator_loss_weight
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
    """Compute the adaptive discriminator weight (VQGAN-style).

    Rescales the generator loss so its gradient norm at ``last_layer``
    matches that of the reconstruction loss, then applies
    ``self.discriminator_loss_weight``. The result is detached from the
    autograd graph.
    """
    grad_of = torch.autograd.grad
    rec_grads = grad_of(nll_loss, last_layer, retain_graph=True)[0]
    gen_grads = grad_of(g_loss, last_layer, retain_graph=True)[0]
    # Epsilon in the denominator keeps the ratio finite when the
    # generator gradients are (near) zero.
    ratio = torch.norm(rec_grads) / (torch.norm(gen_grads) + 1e-4)
    clamped = torch.clamp(ratio, 0.0, 1e4).detach()
    return clamped * self.discriminator_loss_weight
def forward(
self,
@ -1038,17 +1039,16 @@ class AdversarialLoss(nn.Module):
last_layer,
global_step,
is_training = True,
):
):
gan_loss = -torch.mean(fake_logits)
if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
try:
d_weight = self.calculate_adaptive_weight(nll_loss, gan_loss, last_layer)
except RuntimeError:
assert not is_training
d_weight = torch.tensor(0.0)
gan_loss = -torch.mean(fake_logits)
else:
d_weight = torch.tensor(0.0)
gan_loss = torch.tensor(0.0)
disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
weighted_gan_loss = d_weight * disc_factor * gan_loss

View file

@ -271,10 +271,6 @@ def main():
else:
video = x
# padded videos for GAN
if global_step > cfg.discriminator_start:
real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
# ====== VAE ======
optimizer.zero_grad()
@ -290,8 +286,11 @@ def main():
split = "train"
)
vae_loss = nll_loss
# adversarial loss
if global_step > cfg.discriminator_start:
# padded videos for GAN
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
fake_logits = discriminator(fake_video.contiguous())
adversarial_loss = adversarial_loss_fn(
fake_logits,
@ -313,6 +312,7 @@ def main():
disc_optimizer.zero_grad()
# if video_contains_first_frame:
# Since we don't have enough T frames, pad anyways
real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
real_logits = discriminator(real_video.contiguous().detach())
fake_logits = discriminator(fake_video.contiguous().detach())
disc_loss = disc_loss_fn(real_logits, fake_logits, global_step)