mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 12:49:38 +02:00
added disc
This commit is contained in:
parent
bcb997502e
commit
d8fab103f4
|
|
@ -42,13 +42,13 @@ discriminator = dict(
|
|||
)
|
||||
|
||||
# discriminator hyper-parames TODO
|
||||
discriminator_factor=1,
|
||||
discriminator_factor=1
|
||||
discriminator_start=-1
|
||||
generator_factor=0.5,
|
||||
generator_loss_type="hinge",
|
||||
discriminator_loss_type="hinge",
|
||||
lecam_loss_weight=None,
|
||||
gradient_penalty_loss_weight=None,
|
||||
generator_factor=0.5
|
||||
generator_loss_type="hinge"
|
||||
discriminator_loss_type="hinge"
|
||||
lecam_loss_weight=None
|
||||
gradient_penalty_loss_weight=None
|
||||
|
||||
# loss weights
|
||||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
|
|
|
|||
|
|
@ -153,7 +153,7 @@ class VideoAutoencoderPipeline(nn.Module):
|
|||
return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size))
|
||||
|
||||
def get_temporal_last_layer(self):
|
||||
return self.temporal_vae.last_layer[0]
|
||||
return self.temporal_vae.decoder.conv_out.conv.weight
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
|
|
|
|||
|
|
@ -277,9 +277,9 @@ def main():
|
|||
vae_loss += adversarial_loss
|
||||
|
||||
# Backward & update
|
||||
optimizer.zero_grad()
|
||||
booster.backward(loss=vae_loss, optimizer=optimizer)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
||||
# Adversarial Discriminator loss
|
||||
|
|
@ -295,9 +295,9 @@ def main():
|
|||
)
|
||||
disc_loss = weighted_d_adversarial_loss
|
||||
# Backward & update
|
||||
disc_optimizer.zero_grad()
|
||||
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
|
||||
disc_optimizer.step()
|
||||
disc_optimizer.zero_grad()
|
||||
all_reduce_mean(disc_loss)
|
||||
running_disc_loss += disc_loss.item()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue