mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
debug
This commit is contained in:
parent
0da15473d0
commit
0ad337178a
|
|
@ -438,6 +438,7 @@ class StyleGANDiscriminatorBlur(nn.Module):
|
|||
time_scaled = num_frames // scale_factor + 1
|
||||
else:
|
||||
time_scaled = num_frames / scale_factor
|
||||
|
||||
assert self.input_size[0] % scale_factor == 0, f"image width {self.input_size[0]} is not divisible by scale factor {scale_factor}"
|
||||
assert self.input_size[1] % scale_factor == 0, f"image height {self.input_size[1]} is not divisible by scale factor {scale_factor}"
|
||||
w_scaled, h_scaled = self.input_size[0] / scale_factor, self.input_size[1] / scale_factor
|
||||
|
|
|
|||
|
|
@ -255,7 +255,10 @@ def main():
|
|||
|
||||
# calculate discriminator_time_padding
|
||||
disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
|
||||
disc_time_padding = disc_time_downsample_factor - cfg.num_frames % disc_time_downsample_factor
|
||||
if cfg.num_frames % disc_time_downsample_factor != 0:
|
||||
disc_time_padding = disc_time_downsample_factor - cfg.num_frames % disc_time_downsample_factor
|
||||
else:
|
||||
disc_time_padding = 0
|
||||
video_contains_first_frame = cfg.video_contains_first_frame
|
||||
|
||||
lecam_ema_real = torch.tensor(0.0)
|
||||
|
|
@ -340,7 +343,7 @@ def main():
|
|||
if global_step > cfg.discriminator_start:
|
||||
# padded videos for GAN
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
fake_logits = discriminator(fake_video.contiguous()) # TODO: take out contiguous?
|
||||
fake_logits = discriminator(fake_video.contiguous())
|
||||
adversarial_loss = adversarial_loss_fn(
|
||||
fake_logits,
|
||||
nll_loss,
|
||||
|
|
|
|||
Loading…
Reference in a new issue