added disc

This commit is contained in:
Shen-Chenhui 2024-05-02 09:48:30 +00:00
parent bcb997502e
commit d8fab103f4
3 changed files with 9 additions and 9 deletions

View file

@ -42,13 +42,13 @@ discriminator = dict(
)
# discriminator hyper-parames TODO
discriminator_factor=1,
discriminator_factor=1
discriminator_start=-1
generator_factor=0.5,
generator_loss_type="hinge",
discriminator_loss_type="hinge",
lecam_loss_weight=None,
gradient_penalty_loss_weight=None,
generator_factor=0.5
generator_loss_type="hinge"
discriminator_loss_type="hinge"
lecam_loss_weight=None
gradient_penalty_loss_weight=None
# loss weights
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0

View file

@ -153,7 +153,7 @@ class VideoAutoencoderPipeline(nn.Module):
return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size))
def get_temporal_last_layer(self):
return self.temporal_vae.last_layer[0]
return self.temporal_vae.decoder.conv_out.conv.weight
@property
def device(self):

View file

@ -277,9 +277,9 @@ def main():
vae_loss += adversarial_loss
# Backward & update
optimizer.zero_grad()
booster.backward(loss=vae_loss, optimizer=optimizer)
optimizer.step()
optimizer.zero_grad()
# Adversarial Discriminator loss
@ -295,9 +295,9 @@ def main():
)
disc_loss = weighted_d_adversarial_loss
# Backward & update
disc_optimizer.zero_grad()
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
disc_optimizer.step()
disc_optimizer.zero_grad()
all_reduce_mean(disc_loss)
running_disc_loss += disc_loss.item()