mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-14 18:25:35 +02:00
dataset trainscript
This commit is contained in:
parent
b5a3334caf
commit
2c391a636b
|
|
@ -58,7 +58,7 @@ def main():
|
|||
if exp_dir is None:
|
||||
experiment_index = len(glob(f"{cfg.outputs}/*")) - 1
|
||||
model_name = cfg.model["type"].replace("/", "-")
|
||||
exp_name = f"{experiment_index:03d}-F{cfg.num_frames}S{cfg.frame_interval}-{model_name}"
|
||||
exp_name = f"{experiment_index:03d}-F{cfg.dataset.num_frames}S{cfg.dataset.frame_interval}-{model_name}"
|
||||
exp_dir = f"{cfg.outputs}/{exp_name}"
|
||||
assert os.path.exists(exp_dir)
|
||||
|
||||
|
|
@ -267,8 +267,8 @@ def main():
|
|||
|
||||
# calculate discriminator_time_padding
|
||||
disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
|
||||
if cfg.num_frames % disc_time_downsample_factor != 0:
|
||||
disc_time_padding = disc_time_downsample_factor - cfg.num_frames % disc_time_downsample_factor
|
||||
if cfg.dataset.num_frames % disc_time_downsample_factor != 0:
|
||||
disc_time_padding = disc_time_downsample_factor - cfg.dataset.num_frames % disc_time_downsample_factor
|
||||
else:
|
||||
disc_time_padding = 0
|
||||
video_contains_first_frame = cfg.video_contains_first_frame
|
||||
|
|
@ -314,8 +314,8 @@ def main():
|
|||
# supprt for image or video inputs
|
||||
assert x.ndim in {4, 5}, f"received input of {x.ndim} dimensions" # either image or video
|
||||
assert (
|
||||
x.shape[-2:] == cfg.image_size
|
||||
), f"received input size {x.shape[-2:]}, but config image size is {cfg.image_size}"
|
||||
x.shape[-2:] == cfg.dataset.image_size
|
||||
), f"received input size {x.shape[-2:]}, but config image size is {cfg.dataset.image_size}"
|
||||
is_image = x.ndim == 4
|
||||
if is_image:
|
||||
video = rearrange(x, "b c ... -> b c 1 ...")
|
||||
|
|
|
|||
Loading…
Reference in a new issue