mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-20 17:35:58 +02:00
fixed inference
This commit is contained in:
parent
6fb4e3cd22
commit
0198ea8a52
|
|
@ -66,7 +66,7 @@ def main():
|
|||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
)
|
||||
print(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})")
|
||||
print(f"Dataset contains {len(dataset):,} videos ({cfg.dataset.data_path})")
|
||||
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
|
||||
print(f"Total batch size: {total_batch_size}")
|
||||
|
|
@ -136,8 +136,8 @@ def main():
|
|||
loss_steps = 0
|
||||
|
||||
disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
|
||||
if cfg.datasets.num_frames % disc_time_downsample_factor != 0:
|
||||
disc_time_padding = disc_time_downsample_factor - cfg.datasets.num_frames % disc_time_downsample_factor
|
||||
if cfg.dataset.num_frames % disc_time_downsample_factor != 0:
|
||||
disc_time_padding = disc_time_downsample_factor - cfg.dataset.num_frames % disc_time_downsample_factor
|
||||
else:
|
||||
disc_time_padding = 0
|
||||
video_contains_first_frame = cfg.video_contains_first_frame
|
||||
|
|
|
|||
Loading…
Reference in a new issue