Shen-Chenhui 2024-04-12 15:20:50 +08:00
parent e004d43f69
commit a97c70fab2
2 changed files with 26 additions and 22 deletions


@@ -34,12 +34,12 @@ model = dict(
    separate_first_frame_encoding = False,
    kl_loss_weight = 0.000001,
    perceptual_loss_weight = 0.1, # VGG-based; applied when not None and > 0
    adversarial_loss_weight = 1.0,
    discriminator_factor = 1.0,
    discriminator_in_channels = 3,
    discriminator_filters = 128,
    discriminator_channel_multipliers = (2,4,4,4,4),
    discriminator_loss = "hinge",
    discriminator_start = 50001,
    discriminator_weight = 0.5,
)
# Others


@@ -65,6 +65,7 @@ def hinge_d_loss(logits_real, logits_fake):
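    # standard hinge loss for the discriminator:
    #   L_D = 0.5 * ( E[relu(1 - D(real))] + E[relu(1 + D(fake))] )
    # torch.mean reduces over the batch of logits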
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    breakpoint() # TODO: CHECK mean rather than sum
    return d_loss

def vanilla_d_loss(logits_real, logits_fake):
@@ -94,20 +95,6 @@ def SameConv2d(dim_in, dim_out, kernel_size):
    return nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding)

def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
    breakpoint() # TODO: scrutinize
    if last_layer is not None:
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    else:
        nll_grads = torch.autograd.grad(nll_loss, self.get_last_layer(), retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, self.get_last_layer(), retain_graph=True)[0]

    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
    d_weight = d_weight * self.discriminator_weight
    return d_weight

def adopt_weight(weight, global_step, threshold=0, value=0.):
    if global_step < threshold:
        weight = value
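    # zeroes the weight until global_step reaches `threshold`; used below to
    # keep the GAN terms switched off until discriminator_start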
@@ -619,6 +606,7 @@ class VAE_3D_V2(nn.Module):
        discriminator_filters = 128,
        discriminator_channel_multipliers = (2,4,4,4,4),
        discriminator_loss = "hinge",
        discriminator_start = 50001,
        num_groups = 32, # for nn.GroupNorm
        # conv_downsample = False,
        # upsample = "nearest+conv", # options: "deconv", "nearest+conv"
@@ -714,6 +702,7 @@ class VAE_3D_V2(nn.Module):
        # Adversarial Loss
        self.discriminator_factor = discriminator_factor
        self.discriminator_start = discriminator_start
        self.discriminator = None
        if discriminator_factor is not None and discriminator_factor > 0:
            self.discriminator = StyleGANDiscriminator(
@@ -728,9 +717,9 @@ class VAE_3D_V2(nn.Module):
            )
            if discriminator_loss == "hinge":
                self.calc_disc_loss = hinge_d_loss
                self.disc_loss_fn = hinge_d_loss
            elif discriminator_loss == "vanilla":
                self.calc_disc_loss = vanilla_d_loss
                self.disc_loss_fn = vanilla_d_loss
            else:
                raise ValueError(f"Unknown GAN loss '{discriminator_loss}'.")
@@ -797,6 +786,21 @@ class VAE_3D_V2(nn.Module):
    def get_last_layer(self):
        return self.conv_out.weight

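    # VQGAN-style adaptive weight: scale the GAN loss by the ratio of the
    # reconstruction (NLL) gradient norm to the GAN gradient norm at the
    # decoder's last layer, so neither objective dominates the other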
    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        breakpoint() # TODO: scrutinize
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.get_last_layer(), retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.get_last_layer(), retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def parameters(self):
        return [
            *self.conv_in.parameters(),
@@ -896,7 +900,7 @@ class VAE_3D_V2(nn.Module):
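        # generator-side GAN term: d_weight and gan_loss default to zero, and
        # disc_factor stays 0 until global_step reaches discriminator_start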
        d_weight = torch.tensor(0.0)
        gan_loss = torch.tensor(0.0)
        disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_iter_start)
        disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
        weighted_gan_loss = d_weight * disc_factor * gan_loss
        breakpoint()
        total_loss += weighted_gan_loss
@@ -922,7 +926,7 @@ class VAE_3D_V2(nn.Module):
        # real_logits = self.discriminator(real_video.contiguous().detach())
        # fake_logits = self.discriminator(fake_video.contiguous().detach())
        # disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_iter_start)
        # weight_discriminator_loss = disc_factor * self.calc_disc_loss(real_logits, fake_logits)
        # weight_discriminator_loss = disc_factor * self.disc_loss_fn(real_logits, fake_logits)
        # else:
        #     weight_discriminator_loss = 0
@@ -949,8 +953,8 @@ class VAE_3D_V2(nn.Module):
            fake_video = pad_at_dim(recon_video, (self.discr_time_padding, 0), value = 0., dim = 2)
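            # logits are computed on detached tensors, so the discriminator
            # update sends no gradients back into the generator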
            real_logits = self.discriminator(real_video.contiguous().detach())
            fake_logits = self.discriminator(fake_video.contiguous().detach())
            disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_iter_start)
            weight_discriminator_loss = disc_factor * self.calc_disc_loss(real_logits, fake_logits)
            disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
            weight_discriminator_loss = disc_factor * self.disc_loss_fn(real_logits, fake_logits)
        else:
            weight_discriminator_loss = 0