This commit is contained in:
Shen-Chenhui 2024-04-19 10:29:25 +08:00
parent 0da15473d0
commit 0ad337178a
2 changed files with 6 additions and 2 deletions

View file

@ -438,6 +438,7 @@ class StyleGANDiscriminatorBlur(nn.Module):
time_scaled = num_frames // scale_factor + 1
else:
time_scaled = num_frames / scale_factor
assert self.input_size[0] % scale_factor == 0, f"image width {self.input_size[0]} is not divisible by scale factor {scale_factor}"
assert self.input_size[1] % scale_factor == 0, f"image height {self.input_size[1]} is not divisible by scale factor {scale_factor}"
w_scaled, h_scaled = self.input_size[0] / scale_factor, self.input_size[1] / scale_factor

View file

@ -255,7 +255,10 @@ def main():
# calculate discriminator_time_padding
disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
disc_time_padding = disc_time_downsample_factor - cfg.num_frames % disc_time_downsample_factor
if cfg.num_frames % disc_time_downsample_factor != 0:
disc_time_padding = disc_time_downsample_factor - cfg.num_frames % disc_time_downsample_factor
else:
disc_time_padding = 0
video_contains_first_frame = cfg.video_contains_first_frame
lecam_ema_real = torch.tensor(0.0)
@ -340,7 +343,7 @@ def main():
if global_step > cfg.discriminator_start:
# padded videos for GAN
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
fake_logits = discriminator(fake_video.contiguous()) # TODO: take out contiguous?
fake_logits = discriminator(fake_video.contiguous())
adversarial_loss = adversarial_loss_fn(
fake_logits,
nll_loss,