dataset trainscript

This commit is contained in:
Shen-Chenhui 2024-04-27 14:55:47 +08:00
parent b5a3334caf
commit 2c391a636b

View file

@ -58,7 +58,7 @@ def main():
if exp_dir is None:
experiment_index = len(glob(f"{cfg.outputs}/*")) - 1
model_name = cfg.model["type"].replace("/", "-")
exp_name = f"{experiment_index:03d}-F{cfg.num_frames}S{cfg.frame_interval}-{model_name}"
exp_name = f"{experiment_index:03d}-F{cfg.dataset.num_frames}S{cfg.dataset.frame_interval}-{model_name}"
exp_dir = f"{cfg.outputs}/{exp_name}"
assert os.path.exists(exp_dir)
@ -267,8 +267,8 @@ def main():
# calculate discriminator_time_padding
disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
if cfg.num_frames % disc_time_downsample_factor != 0:
disc_time_padding = disc_time_downsample_factor - cfg.num_frames % disc_time_downsample_factor
if cfg.dataset.num_frames % disc_time_downsample_factor != 0:
disc_time_padding = disc_time_downsample_factor - cfg.dataset.num_frames % disc_time_downsample_factor
else:
disc_time_padding = 0
video_contains_first_frame = cfg.video_contains_first_frame
@ -314,8 +314,8 @@ def main():
# supprt for image or video inputs
assert x.ndim in {4, 5}, f"received input of {x.ndim} dimensions" # either image or video
assert (
x.shape[-2:] == cfg.image_size
), f"received input size {x.shape[-2:]}, but config image size is {cfg.image_size}"
x.shape[-2:] == cfg.dataset.image_size
), f"received input size {x.shape[-2:]}, but config image size is {cfg.dataset.image_size}"
is_image = x.ndim == 4
if is_image:
video = rearrange(x, "b c ... -> b c 1 ...")