mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
debug
This commit is contained in:
parent
e004d43f69
commit
a97c70fab2
|
|
@ -34,12 +34,12 @@ model = dict(
|
|||
separate_first_frame_encoding = False,
|
||||
kl_loss_weight = 0.000001,
|
||||
perceptual_loss_weight = 0.1, # use vgg is not None and more than 0
|
||||
adversarial_loss_weight = 1.0,
|
||||
discriminator_factor = 1.0,
|
||||
discriminator_in_channels = 3,
|
||||
discriminator_filters = 128,
|
||||
discriminator_channel_multipliers = (2,4,4,4,4),
|
||||
discriminator_loss="hinge",
|
||||
discriminator_start = 50001,
|
||||
discriminator_weight = 0.5,
|
||||
)
|
||||
|
||||
# Others
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ def hinge_d_loss(logits_real, logits_fake):
|
|||
loss_real = torch.mean(F.relu(1. - logits_real))
|
||||
loss_fake = torch.mean(F.relu(1. + logits_fake))
|
||||
d_loss = 0.5 * (loss_real + loss_fake)
|
||||
breakpoint() # TODO: CHECK mean rather than sum
|
||||
return d_loss
|
||||
|
||||
def vanilla_d_loss(logits_real, logits_fake):
|
||||
|
|
@ -94,20 +95,6 @@ def SameConv2d(dim_in, dim_out, kernel_size):
|
|||
return nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding)
|
||||
|
||||
|
||||
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
||||
breakpoint() # TODO: scrutinize
|
||||
if last_layer is not None:
|
||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
||||
else:
|
||||
nll_grads = torch.autograd.grad(nll_loss, self.get_last_layer(), retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, self.get_last_layer(), retain_graph=True)[0]
|
||||
|
||||
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
||||
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
||||
d_weight = d_weight * self.discriminator_weight
|
||||
return d_weight
|
||||
|
||||
def adopt_weight(weight, global_step, threshold=0, value=0.):
|
||||
if global_step < threshold:
|
||||
weight = value
|
||||
|
|
@ -619,6 +606,7 @@ class VAE_3D_V2(nn.Module):
|
|||
discriminator_filters = 128,
|
||||
discriminator_channel_multipliers = (2,4,4,4,4),
|
||||
discriminator_loss="hinge",
|
||||
discriminator_start=50001,
|
||||
num_groups = 32, # for nn.GroupNorm
|
||||
# conv_downsample = False,
|
||||
# upsample = "nearest+conv", # options: "deconv", "nearest+conv"
|
||||
|
|
@ -714,6 +702,7 @@ class VAE_3D_V2(nn.Module):
|
|||
|
||||
# Adversarial Loss
|
||||
self.discriminator_factor = discriminator_factor
|
||||
self.discriminator_start = discriminator_start
|
||||
self.discriminator = None
|
||||
if discriminator_factor is not None and discriminator_factor > 0:
|
||||
self.discriminator = StyleGANDiscriminator(
|
||||
|
|
@ -728,9 +717,9 @@ class VAE_3D_V2(nn.Module):
|
|||
)
|
||||
|
||||
if discriminator_loss == "hinge":
|
||||
self.calc_disc_loss = hinge_d_loss
|
||||
self.disc_loss_fn = hinge_d_loss
|
||||
elif discriminator_loss == "vanilla":
|
||||
self.calc_disc_loss = vanilla_d_loss
|
||||
self.disc_loss_fn = vanilla_d_loss
|
||||
else:
|
||||
raise ValueError(f"Unknown GAN loss '{discriminator_loss}'.")
|
||||
|
||||
|
|
@ -797,6 +786,21 @@ class VAE_3D_V2(nn.Module):
|
|||
def get_last_layer(self):
|
||||
return self.conv_out.weight
|
||||
|
||||
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
||||
breakpoint() # TODO: scrutinize
|
||||
if last_layer is not None:
|
||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
||||
else:
|
||||
nll_grads = torch.autograd.grad(nll_loss, self.get_last_layer(), retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, self.get_last_layer(), retain_graph=True)[0]
|
||||
|
||||
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
||||
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
||||
d_weight = d_weight * self.discriminator_weight
|
||||
|
||||
return d_weight
|
||||
|
||||
def parameters(self):
|
||||
return [
|
||||
*self.conv_in.parameters(),
|
||||
|
|
@ -896,7 +900,7 @@ class VAE_3D_V2(nn.Module):
|
|||
d_weight = torch.tensor(0.0)
|
||||
gan_loss = torch.tensor(0.0)
|
||||
|
||||
disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_iter_start)
|
||||
disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
|
||||
weighted_gan_loss = d_weight * disc_factor * gan_loss
|
||||
breakpoint()
|
||||
total_loss += weighted_gan_loss
|
||||
|
|
@ -922,7 +926,7 @@ class VAE_3D_V2(nn.Module):
|
|||
# real_logits = self.discriminator(real_video.contiguous.detach())
|
||||
# fake_logits = self.discriminator(fake_video.contiguous.detach())
|
||||
# disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_iter_start)
|
||||
# weight_discriminator_loss = disc_factor * self.calc_disc_loss(real_logits, fake_logits)
|
||||
# weight_discriminator_loss = disc_factor * self.disc_loss_fn(real_logits, fake_logits)
|
||||
# else:
|
||||
# weight_discriminator_loss = 0
|
||||
|
||||
|
|
@ -949,8 +953,8 @@ class VAE_3D_V2(nn.Module):
|
|||
fake_video = pad_at_dim(recon_video, (self.discr_time_padding, 0), value = 0., dim = 2)
|
||||
real_logits = self.discriminator(real_video.contiguous.detach())
|
||||
fake_logits = self.discriminator(fake_video.contiguous.detach())
|
||||
disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_iter_start)
|
||||
weight_discriminator_loss = disc_factor * self.calc_disc_loss(real_logits, fake_logits)
|
||||
disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
|
||||
weight_discriminator_loss = disc_factor * self.disc_loss_fn(real_logits, fake_logits)
|
||||
else:
|
||||
weight_discriminator_loss = 0
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue