fixed inference

This commit is contained in:
Shen-Chenhui 2024-04-27 17:21:45 +08:00
parent 6fb4e3cd22
commit 0198ea8a52

View file

@ -66,7 +66,7 @@ def main():
pin_memory=True,
process_group=get_data_parallel_group(),
)
print(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})")
print(f"Dataset contains {len(dataset):,} videos ({cfg.dataset.data_path})")
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
print(f"Total batch size: {total_batch_size}")
@ -136,8 +136,8 @@ def main():
loss_steps = 0
disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
if cfg.datasets.num_frames % disc_time_downsample_factor != 0:
disc_time_padding = disc_time_downsample_factor - cfg.datasets.num_frames % disc_time_downsample_factor
if cfg.dataset.num_frames % disc_time_downsample_factor != 0:
disc_time_padding = disc_time_downsample_factor - cfg.dataset.num_frames % disc_time_downsample_factor
else:
disc_time_padding = 0
video_contains_first_frame = cfg.video_contains_first_frame