diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py index d81af23..44983e0 100644 --- a/opensora/models/vae/vae_3d_v2.py +++ b/opensora/models/vae/vae_3d_v2.py @@ -438,6 +438,7 @@ class StyleGANDiscriminatorBlur(nn.Module): time_scaled = num_frames // scale_factor + 1 else: time_scaled = num_frames / scale_factor + assert self.input_size[0] % scale_factor == 0, f"image width {self.input_size[0]} is not divisible by scale factor {scale_factor}" assert self.input_size[1] % scale_factor == 0, f"image height {self.input_size[1]} is not divisible by scale factor {scale_factor}" w_scaled, h_scaled = self.input_size[0] / scale_factor, self.input_size[1] / scale_factor diff --git a/scripts/train-vae-v2.py b/scripts/train-vae-v2.py index 3d10db0..8ff7045 100644 --- a/scripts/train-vae-v2.py +++ b/scripts/train-vae-v2.py @@ -255,7 +255,10 @@ def main(): # calculate discriminator_time_padding disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers) - disc_time_padding = disc_time_downsample_factor - cfg.num_frames % disc_time_downsample_factor + if cfg.num_frames % disc_time_downsample_factor != 0: + disc_time_padding = disc_time_downsample_factor - cfg.num_frames % disc_time_downsample_factor + else: + disc_time_padding = 0 video_contains_first_frame = cfg.video_contains_first_frame lecam_ema_real = torch.tensor(0.0) @@ -340,7 +343,7 @@ def main(): if global_step > cfg.discriminator_start: # padded videos for GAN fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2) - fake_logits = discriminator(fake_video.contiguous()) # TODO: take out contiguous? + fake_logits = discriminator(fake_video.contiguous()) adversarial_loss = adversarial_loss_fn( fake_logits, nll_loss,