diff --git a/configs/opensora-v1-2/inference/sample.py b/configs/opensora-v1-2/inference/sample.py index a4b592f..abc43b5 100644 --- a/configs/opensora-v1-2/inference/sample.py +++ b/configs/opensora-v1-2/inference/sample.py @@ -1,4 +1,5 @@ -image_size = (240, 426) +resolution = "240p" +aspect_ratio = "9:16" num_frames = 51 fps = 24 frame_interval = 1 @@ -30,22 +31,10 @@ model = dict( enable_layernorm_kernel=True, ) vae = dict( - type="VideoAutoencoderPipeline", + type="OpenSoraVAE_V1_2", from_pretrained="pretrained_models/vae-pipeline", - shift=(-0.10, 0.34, 0.27, 0.98), - scale=(3.85, 2.32, 2.33, 3.06), micro_frame_size=17, - vae_2d=dict( - type="VideoAutoencoderKL", - from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", - subfolder="vae", - micro_batch_size=4, - local_files_only=True, - ), - vae_temporal=dict( - type="VAE_Temporal_SD", - from_pretrained=None, - ), + micro_batch_size=4, ) text_encoder = dict( type="t5", diff --git a/configs/opensora-v1-2/train/adapt.py b/configs/opensora-v1-2/train/adapt.py index 7088dfe..94a6ca4 100644 --- a/configs/opensora-v1-2/train/adapt.py +++ b/configs/opensora-v1-2/train/adapt.py @@ -36,22 +36,10 @@ model = dict( enable_layernorm_kernel=True, ) vae = dict( - type="VideoAutoencoderPipeline", + type="OpenSoraVAE_V1_2", from_pretrained="pretrained_models/vae-pipeline", - shift=(-0.10, 0.34, 0.27, 0.98), - scale=(3.85, 2.32, 2.33, 3.06), micro_frame_size=17, - vae_2d=dict( - type="VideoAutoencoderKL", - from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", - subfolder="vae", - micro_batch_size=4, - local_files_only=True, - ), - vae_temporal=dict( - type="VAE_Temporal_SD", - from_pretrained=None, - ), + micro_batch_size=4, ) text_encoder = dict( type="t5", diff --git a/configs/opensora-v1-2/train/eval.py b/configs/opensora-v1-2/train/eval.py deleted file mode 100644 index 68686ae..0000000 --- a/configs/opensora-v1-2/train/eval.py +++ /dev/null @@ -1,88 +0,0 @@ -# Dataset settings -dataset = dict( - type="VariableVideoTextDataset", - transform_name="resize_crop", - frame_interval=1, -) - -bucket_config = { # 20s/it - "1024": {1: (1.0, 1)}, -} - -grad_checkpoint = True -batch_size = None - -# Acceleration settings -num_workers = 8 -num_bucket_build_workers = 16 -dtype = "bf16" -plugin = "zero2" - -# Model settings -model = dict( - type="STDiT3-XL/2", - from_pretrained=None, - qk_norm=True, - enable_flash_attn=True, - enable_layernorm_kernel=True, -) -vae = dict( - type="VideoAutoencoderPipeline", - from_pretrained="pretrained_models/vae-pipeline", - micro_frame_size=17, - shift=(-0.10, 0.34, 0.27, 0.98), - scale=(3.85, 2.32, 2.33, 3.06), - vae_2d=dict( - type="VideoAutoencoderKL", - from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", - subfolder="vae", - micro_batch_size=4, - local_files_only=True, - ), - vae_temporal=dict( - type="VAE_Temporal_SD", - from_pretrained=None, - ), -) -text_encoder = dict( - type="t5", - from_pretrained="DeepFloyd/t5-v1_1-xxl", - model_max_length=300, - shardformer=True, - local_files_only=True, -) -scheduler = dict( - type="rflow", - use_discrete_timesteps=False, - use_timestep_transform=False, - # sample_method="logit-normal", -) - -# Mask settings -# mask_ratios = { -# "random": 0.1, -# "intepolate": 0.01, -# "quarter_random": 0.01, -# "quarter_head": 0.01, -# "quarter_tail": 0.01, -# "quarter_head_tail": 0.01, -# "image_random": 0.05, -# "image_head": 0.1, -# "image_tail": 0.05, -# "image_head_tail": 0.05, -# } - -# Log settings -seed = 42 -outputs = "outputs" -wandb = False 
-epochs = 1000 -log_every = 10 -ckpt_every = 500 - -# optimization settings -load = None -grad_clip = 1.0 -lr = 2e-4 -ema_decay = 0.99 -adam_eps = 1e-15 diff --git a/configs/opensora-v1-2/train/stage1.py b/configs/opensora-v1-2/train/stage1.py index 531a485..59e5b90 100644 --- a/configs/opensora-v1-2/train/stage1.py +++ b/configs/opensora-v1-2/train/stage1.py @@ -61,22 +61,10 @@ model = dict( enable_layernorm_kernel=True, ) vae = dict( - type="VideoAutoencoderPipeline", + type="OpenSoraVAE_V1_2", from_pretrained="pretrained_models/vae-pipeline", - shift=(-0.10, 0.34, 0.27, 0.98), - scale=(3.85, 2.32, 2.33, 3.06), micro_frame_size=17, - vae_2d=dict( - type="VideoAutoencoderKL", - from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", - subfolder="vae", - micro_batch_size=4, - local_files_only=True, - ), - vae_temporal=dict( - type="VAE_Temporal_SD", - from_pretrained=None, - ), + micro_batch_size=4, ) text_encoder = dict( type="t5", diff --git a/configs/vae/inference/image.py b/configs/vae/inference/image.py index e01c75c..cb25757 100644 --- a/configs/vae/inference/image.py +++ b/configs/vae/inference/image.py @@ -1,44 +1,32 @@ -num_frames = 1 -frame_interval = 1 -fps = 24 image_size = (256, 256) +num_frames = 1 + +dtype = "bf16" +batch_size = 1 +seed = 42 +save_dir = "samples/vae_image" +cal_stats = True +log_stats_every = 100 # Define dataset dataset = dict( type="VideoTextDataset", data_path=None, num_frames=num_frames, - frame_interval=1, image_size=image_size, ) +num_samples = 100 num_workers = 4 -max_test_samples = None # Define model model = dict( - type="VideoAutoencoderPipeline", - freeze_vae_2d=True, + type="OpenSoraVAE_V1_2", + from_pretrained="pretrained_models/vae-pipeline", + micro_frame_size=None, + micro_batch_size=4, cal_loss=True, - vae_2d=dict( - type="VideoAutoencoderKL", - from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", - subfolder="vae", - micro_batch_size=4, - local_files_only=True, - ), - vae_temporal=dict( - type="VAE_Temporal_SD", - from_pretrained=None, - ), ) -dtype = "bf16" # loss weights perceptual_loss_weight = 0.1 # use vgg is not None and more than 0 kl_loss_weight = 1e-6 -calc_std = True - -# Others -batch_size = 1 -seed = 42 -save_dir = "samples/vae_image" diff --git a/configs/vae/inference/video.py b/configs/vae/inference/video.py index 03e2787..4697a2e 100644 --- a/configs/vae/inference/video.py +++ b/configs/vae/inference/video.py @@ -1,45 +1,32 @@ -num_frames = 17 -frame_interval = 1 -fps = 24 image_size = (256, 256) +num_frames = 17 + +dtype = "bf16" +batch_size = 1 +seed = 42 +save_dir = "samples/vae_video" +cal_stats = True +log_stats_every = 100 # Define dataset dataset = dict( type="VideoTextDataset", data_path=None, num_frames=num_frames, - frame_interval=1, image_size=image_size, ) +num_samples = 100 num_workers = 4 -max_test_samples = None # Define model model = dict( - type="VideoAutoencoderPipeline", - from_pretrained=None, + type="OpenSoraVAE_V1_2", + from_pretrained="pretrained_models/vae-pipeline", + micro_frame_size=None, + micro_batch_size=4, cal_loss=True, - micro_frame_size=16, - vae_2d=dict( - type="VideoAutoencoderKL", - from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", - subfolder="vae", - micro_batch_size=4, - local_files_only=True, - ), - vae_temporal=dict( - type="VAE_Temporal_SD", - from_pretrained=None, - ), ) -dtype = "bf16" # loss weights perceptual_loss_weight = 0.1 # use vgg is not None and more than 0 kl_loss_weight = 1e-6 -calc_std = True - -# Others -batch_size = 1 -seed = 42
-save_dir = "samples/vae_video" diff --git a/notebooks/inference.ipynb b/notebooks/inference.ipynb new file mode 100644 index 0000000..e69de29 diff --git a/opensora/datasets/aspect.py b/opensora/datasets/aspect.py index a4a79ae..952e7a4 100644 --- a/opensora/datasets/aspect.py +++ b/opensora/datasets/aspect.py @@ -1,28 +1,7 @@ import math + # computation -AR = ( - 3 / 8, - 9 / 21, - 0.48, - 1 / 2, - 9 / 17, - 1 / 1.85, - 9 / 16, - 5 / 8, - 2 / 3, - 3 / 4, - 1 / 1, - 4 / 3, - 3 / 2, - 16 / 9, - 17 / 9, - 2 / 1, - 1 / 0.48, -) -AR_fraction = (0.375, 0.43, 0.48, 0.50, 0.53, 0.54, 0.56, 0.62, 0.67, 0.75, 1, 1.33, 1.50, 1.78, 1.89, 2, 2.08) - - def get_h_w(a, ts, eps=1e-4): h = (ts * a) ** 0.5 h = h + eps @@ -33,11 +12,39 @@ def get_h_w(a, ts, eps=1e-4): return h, w -def get_aspect_ratios_dict(ts=360 * 640, ars=AR): +def get_aspect_ratios_dict(ars, ts=360 * 640): est = {f"{a:.2f}": get_h_w(a, ts) for a in ars} return est +def get_ar(ratio): + h, w = ratio.split(":") + return int(h) / int(w) + + +ASPECT_RATIO_MAP = { + "3:8": "0.38", + "9:21": "0.43", + "12:25": "0.48", + "1:2": "0.50", + "9:17": "0.53", + "27:50": "0.54", + "9:16": "0.56", + "5:8": "0.62", + "2:3": "0.67", + "3:4": "0.75", + "1:1": "1.00", + "4:3": "1.33", + "3:2": "1.50", + "16:9": "1.78", + "17:9": "1.89", + "2:1": "2.00", + "50:27": "2.08", +} + + +AR = [get_ar(ratio) for ratio in ASPECT_RATIO_MAP.keys()] + # computed from above code # S = 8294400 ASPECT_RATIO_4K = { @@ -454,3 +461,10 @@ ASPECT_RATIOS = { def get_num_pixels(name): return ASPECT_RATIOS[name][0] + + +def get_image_size(resolution, ar_ratio): + ar_key = ASPECT_RATIO_MAP[ar_ratio] + rs_dict = ASPECT_RATIOS[resolution][1] + assert ar_key in rs_dict, f"Aspect ratio {ar_ratio} not found for resolution {resolution}" + return rs_dict[ar_key] diff --git a/opensora/datasets/dataloader.py b/opensora/datasets/dataloader.py index 96c6590..4d002cd 100644 --- a/opensora/datasets/dataloader.py +++ b/opensora/datasets/dataloader.py @@ -19,6 +19,7 @@ def prepare_dataloader( pin_memory=False, num_workers=0, process_group: Optional[ProcessGroup] = None, + distributed=True, **kwargs, ): r""" @@ -43,13 +44,16 @@ def prepare_dataloader( :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
""" _kwargs = kwargs.copy() - process_group = process_group or _get_default_group() - sampler = StatefulDistributedSampler( - dataset, - num_replicas=process_group.size(), - rank=process_group.rank(), - shuffle=shuffle, - ) + if distributed: + process_group = process_group or _get_default_group() + sampler = StatefulDistributedSampler( + dataset, + num_replicas=process_group.size(), + rank=process_group.rank(), + shuffle=shuffle, + ) + else: + sampler = None # Deterministic dataloader def seed_worker(worker_id): diff --git a/opensora/models/vae/vae.py b/opensora/models/vae/vae.py index 2115e63..2420cf1 100644 --- a/opensora/models/vae/vae.py +++ b/opensora/models/vae/vae.py @@ -220,3 +220,39 @@ class VideoAutoencoderPipeline(nn.Module): @property def dtype(self): return next(self.parameters()).dtype + + +@MODELS.register_module() +class OpenSoraVAE_V1_2(VideoAutoencoderPipeline): + def __init__( + self, + micro_batch_size=4, + micro_frame_size=17, + from_pretrained=None, + local_files_only=True, + freeze_vae_2d=False, + cal_loss=False, + ): + vae_2d = dict( + type="VideoAutoencoderKL", + from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", + subfolder="vae", + micro_batch_size=micro_batch_size, + local_files_only=local_files_only, + ) + vae_temporal = dict( + type="VAE_Temporal_SD", + from_pretrained=None, + ) + shift = (-0.10, 0.34, 0.27, 0.98) + scale = (3.85, 2.32, 2.33, 3.06) + super().__init__( + vae_2d, + vae_temporal, + from_pretrained, + freeze_vae_2d=freeze_vae_2d, + cal_loss=cal_loss, + micro_frame_size=micro_frame_size, + shift=shift, + scale=scale, + ) diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py index e0177ed..b8f2af6 100644 --- a/opensora/utils/config_utils.py +++ b/opensora/utils/config_utils.py @@ -24,10 +24,8 @@ def parse_args(training=False): ) parser.add_argument("--batch-size", default=None, type=int, help="batch size") parser.add_argument("--outputs", default=None, type=str, help="the dir to save model weights") - parser.add_argument("--flash-attn", default=None, action=argparse.BooleanOptionalAction, help="enable flash attn") - parser.add_argument( - "--layernorm-kernel", default=None, action=argparse.BooleanOptionalAction, help="enable layernorm kernel" - ) + parser.add_argument("--flash-attn", default=None, type=str2bool, help="enable flash attention") + parser.add_argument("--layernorm-kernel", default=None, type=str2bool, help="enable layernorm kernel") # ====================================================== # Inference @@ -51,6 +49,9 @@ def parse_args(training=False): parser.add_argument("--num-frames", default=None, type=int, help="number of frames") parser.add_argument("--fps", default=None, type=int, help="fps") parser.add_argument("--image-size", default=None, type=int, nargs=2, help="image size") + parser.add_argument("--frame-interval", default=None, type=int, help="frame interval") + parser.add_argument("--resolution", default=None, type=str, help="multi resolution") + parser.add_argument("--aspect-ratio", default=None, type=float, help="aspect ratio") # hyperparameters parser.add_argument("--num-sampling-steps", default=None, type=int, help="sampling steps") @@ -143,3 +144,14 @@ def define_experiment_workspace(cfg, get_last_workspace=False): def save_training_config(cfg, experiment_dir): with open(f"{experiment_dir}/config.txt", "w") as f: json.dump(cfg, f, indent=4) + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in 
("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") diff --git a/opensora/utils/inference_utils.py b/opensora/utils/inference_utils.py index 48b1312..beb07e4 100644 --- a/opensora/utils/inference_utils.py +++ b/opensora/utils/inference_utils.py @@ -132,7 +132,9 @@ def find_nearest_point(value, point, max_value): def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None): masks = [] + no_mask = True for i, mask_strategy in enumerate(mask_strategys): + no_mask = False mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device) mask_strategy = parse_mask_strategy(mask_strategy) for mst in mask_strategy: @@ -154,6 +156,8 @@ def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None): z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length] mask[m_target_start : m_target_start + m_length] = edit_ratio masks.append(mask) + if no_mask: + return None masks = torch.stack(masks) return masks diff --git a/opensora/utils/misc.py b/opensora/utils/misc.py index ea620a4..621bd73 100644 --- a/opensora/utils/misc.py +++ b/opensora/utils/misc.py @@ -26,6 +26,13 @@ def is_main_process(): return not is_distributed() or dist.get_rank() == 0 +def get_world_size(): + if is_distributed(): + return dist.get_world_size() + else: + return 1 + + def create_logger(logging_dir=None): """ Create a logger that writes to a log file and stdout. diff --git a/scripts/inference.py b/scripts/inference.py index 03bf5e7..af591b0 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -10,6 +10,7 @@ from tqdm import tqdm from opensora.acceleration.parallel_states import set_sequence_parallel_group from opensora.datasets import save_sample +from opensora.datasets.aspect import get_image_size from opensora.models.text_encoder.t5 import text_preprocessing from opensora.registry import MODELS, SCHEDULERS, build_module from opensora.utils.config_utils import parse_configs @@ -52,7 +53,7 @@ def main(): else: coordinator = None enable_sequence_parallelism = False - set_random_seed(seed=cfg.seed) + set_random_seed(seed=cfg.get("seed", 1024)) # == init logger == logger = create_logger() @@ -68,8 +69,19 @@ def main(): text_encoder = build_module(cfg.text_encoder, MODELS, device=device) vae = build_module(cfg.vae, MODELS).to(device, dtype).eval() + # == prepare video size == + image_size = cfg.get("image_size", None) + if image_size is None: + resolution = cfg.get("resolution", None) + aspect_ratio = cfg.get("aspect_ratio", None) + assert ( + resolution is not None and aspect_ratio is not None + ), "resolution and aspect_ratio must be provided if image_size is not provided" + image_size = get_image_size(resolution, aspect_ratio) + num_frames = cfg.num_frames + # == build diffusion model == - input_size = (cfg.get("num_frames", None), *cfg.get("image_size", (None, None))) + input_size = (num_frames, *image_size) latent_size = vae.get_latent_size(input_size) model = ( build_module( @@ -106,10 +118,8 @@ def main(): assert len(mask_strategy) == len(prompts), "Length of mask_strategy must be the same as prompts" # == prepare arguments == - image_size = cfg.image_size - num_frames = cfg.num_frames fps = cfg.fps - save_fps = cfg.fps // cfg.get("frame_interval", 1) + save_fps = fps // cfg.get("frame_interval", 1) multi_resolution = cfg.get("multi_resolution", None) batch_size = cfg.get("batch_size", 1) num_sample = cfg.get("num_sample", 1) diff --git a/scripts/inference_vae.py b/scripts/inference_vae.py 
index 240e7bb..cf18d10 100644 --- a/scripts/inference_vae.py +++ b/scripts/inference_vae.py @@ -1,144 +1,145 @@ import os +from pprint import pformat import colossalai import torch -import torch.distributed as dist -from colossalai.cluster import DistCoordinator from mmengine.runner import set_random_seed from tqdm import tqdm from opensora.acceleration.parallel_states import get_data_parallel_group -from opensora.datasets import prepare_dataloader +from opensora.datasets import prepare_dataloader, save_sample from opensora.models.vae.losses import VAELoss from opensora.registry import DATASETS, MODELS, build_module from opensora.utils.config_utils import parse_configs -from opensora.utils.misc import to_torch_dtype +from opensora.utils.misc import create_logger, get_world_size, is_distributed, is_main_process, to_torch_dtype def main(): - # ====================================================== - # 1. cfg and init distributed env - # ====================================================== - cfg = parse_configs(training=False) - print(cfg) - - # init distributed - if os.environ.get("WORLD_SIZE", None): - colossalai.launch_from_torch({}) - coordinator = DistCoordinator() - else: - pass - - # ====================================================== - # 2. runtime variables - # ====================================================== torch.set_grad_enabled(False) + # ====================================================== + # configs & runtime variables + # ====================================================== + # == parse configs == + cfg = parse_configs(training=False) + + # == device and dtype == + device = "cuda" if torch.cuda.is_available() else "cpu" + cfg_dtype = cfg.get("dtype", "fp32") + assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}" + dtype = to_torch_dtype(cfg_dtype) torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True - device = "cuda" if torch.cuda.is_available() else "cpu" - dtype = to_torch_dtype(cfg.dtype) - set_random_seed(seed=cfg.seed) + + # == init distributed env == + if is_distributed(): + colossalai.launch_from_torch({}) + set_random_seed(seed=cfg.get("seed", 1024)) + + # == init logger == + logger = create_logger() + logger.info("Inference configuration:\n %s", pformat(cfg.to_dict())) + verbose = cfg.get("verbose", 1) # ====================================================== - # 3.
build dataset and dataloader + # build dataset and dataloader # ====================================================== + logger.info("Building reconstruction dataset...") dataset = build_module(cfg.dataset, DATASETS) + batch_size = cfg.get("batch_size", 1) dataloader = prepare_dataloader( dataset, - batch_size=cfg.batch_size, - num_workers=cfg.num_workers, + batch_size=batch_size, + num_workers=cfg.get("num_workers", 4), shuffle=False, - drop_last=True, + drop_last=False, pin_memory=True, process_group=get_data_parallel_group(), + distributed=is_distributed(), ) - print(f"Dataset contains {len(dataset):,} videos ({cfg.dataset.data_path})") - total_batch_size = cfg.batch_size * dist.get_world_size() - print(f"Total batch size: {total_batch_size}") + logger.info("Dataset %s contains %s videos.", cfg.dataset.data_path, len(dataset)) + total_batch_size = batch_size * get_world_size() + logger.info("Total batch size: %s", total_batch_size) + + total_steps = len(dataloader) + if cfg.get("num_samples", None) is not None: + total_steps = min(int(cfg.num_samples // cfg.batch_size), total_steps) + logger.info("limiting test dataset to %s", int(cfg.num_samples // cfg.batch_size) * cfg.batch_size) + dataiter = iter(dataloader) # ====================================================== - # 4. build model & load weights + # build model & loss # ====================================================== - # 4.1. build model - model = build_module(cfg.model, MODELS) - model.to(device, dtype).eval() - - # ====================================================== - # 5. inference - # ====================================================== - cfg.save_dir - - # define loss function + logger.info("Building models...") + model = build_module(cfg.model, MODELS).to(device, dtype).eval() vae_loss_fn = VAELoss( logvar_init=cfg.get("logvar_init", 0.0), - perceptual_loss_weight=cfg.perceptual_loss_weight, - kl_loss_weight=cfg.kl_loss_weight, + perceptual_loss_weight=cfg.get("perceptual_loss_weight", 0.1), + kl_loss_weight=cfg.get("kl_loss_weight", 1e-6), device=device, dtype=dtype, ) - # get total number of steps - total_steps = len(dataloader) - if cfg.max_test_samples is not None: - total_steps = min(int(cfg.max_test_samples // cfg.batch_size), total_steps) - print(f"limiting test dataset to {int(cfg.max_test_samples//cfg.batch_size) * cfg.batch_size}") - dataloader_iter = iter(dataloader) - + # ====================================================== + # inference + # ====================================================== + # == global variables == running_loss = running_nll = running_nll_z = 0.0 loss_steps = 0 - - calc_std = cfg.get("calc_std", False) - if calc_std: - running_sum = 0.0 - running_sum_c = torch.zeros(model.out_channels, dtype=torch.float, device=device) - running_var = 0.0 - running_var_c = torch.zeros(model.out_channels, dtype=torch.float, device=device) + cal_stats = cfg.get("cal_stats", False) + if cal_stats: num_samples = 0 + running_sum = running_var = 0.0 + running_sum_c = torch.zeros(model.out_channels, dtype=torch.float, device=device) + running_var_c = torch.zeros(model.out_channels, dtype=torch.float, device=device) + # prepare arguments + save_fps = cfg.get("fps", 24) // cfg.get("frame_interval", 1) + + # Iter over the dataset with tqdm( range(total_steps), - disable=not coordinator.is_master(), + disable=not is_main_process() or verbose < 1, total=total_steps, initial=0, ) as pbar: for step in pbar: - batch = next(dataloader_iter) + batch = next(dataiter) x = batch["video"].to(device, dtype) # [B, 
C, T, H, W] + + # == vae encoding & decoding == + z, posterior, x_z = model.encode(x) + x_rec, x_z_rec = model.decode(z, num_frames=x.size(2)) + x_ref = model.spatial_vae.decode(x_z) + + # == check z shape == input_size = x.shape[2:] latent_size = model.get_latent_size(input_size) + assert list(z.shape[2:]) == latent_size, f"z shape: {z.shape}, latent_size: {latent_size}" - # ===== VAE ===== - z, posterior, x_z = model.encode(x) - - # calc std - if calc_std: + # == calculate stats == + if cal_stats: num_samples += 1 running_sum += z.mean().item() running_var += (z - running_sum / num_samples).pow(2).mean().item() - running_sum_c += z.mean(dim=(0, 2, 3, 4)).float() running_var_c += ( (z - running_sum_c[None, :, None, None, None] / num_samples).pow(2).mean(dim=(0, 2, 3, 4)).float() ) - pbar.set_postfix( - { - "mean": running_sum / num_samples, - "std": (running_var / num_samples) ** 0.5, - } - ) - if num_samples % 100 == 0: - print( - " mean_c ", + if verbose >= 1: + pbar.set_postfix( + { + "mean": running_sum / num_samples, + "std": (running_var / num_samples) ** 0.5, + } + ) + if num_samples % cfg.get("log_stats_every", 100) == 0: + logger.info( + "VAE feature per channel stats: mean %s, std %s", (running_sum_c / num_samples).cpu().tolist(), - "std_c ", (running_var_c / num_samples).sqrt().cpu().tolist(), ) - assert list(z.shape[2:]) == latent_size, f"z shape: {z.shape}, latent_size: {latent_size}" - x_rec, x_z_rec = model.decode(z, num_frames=x.size(2)) - model.spatial_vae.decode(x_z) - - # loss calculation + # == loss calculation == nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior) nll_loss_z, _, _ = vae_loss_fn(x_z, x_z_rec, posterior, no_perceptual=True) vae_loss = weighted_nll_loss + weighted_kl_loss @@ -147,24 +148,24 @@ def main(): running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps) running_nll_z = nll_loss_z.item() / loss_steps + running_nll_z * ((loss_steps - 1) / loss_steps) - # if not use_dist or coordinator.is_master(): - # ori_dir = f"{save_dir}_ori" - # rec_dir = f"{save_dir}_rec" - # ref_dir = f"{save_dir}_ref" - # os.makedirs(ori_dir, exist_ok=True) - # os.makedirs(rec_dir, exist_ok=True) - # os.makedirs(ref_dir, exist_ok=True) - # for idx, vid in enumerate(x): - # pos = step * cfg.batch_size + idx - # save_sample(vid, fps=cfg.fps, save_path=f"{ori_dir}/{pos:03d}") - # save_sample(x_rec[idx], fps=cfg.fps, save_path=f"{rec_dir}/{pos:03d}") - # save_sample(x_ref[idx], fps=cfg.fps, save_path=f"{ref_dir}/{pos:03d}") + # == save samples == + save_dir = cfg.get("save_dir", None) + if is_main_process() and save_dir is not None: + ori_dir = f"{save_dir}_ori" + rec_dir = f"{save_dir}_rec" + ref_dir = f"{save_dir}_spatial" + os.makedirs(ori_dir, exist_ok=True) + os.makedirs(rec_dir, exist_ok=True) + os.makedirs(ref_dir, exist_ok=True) + for idx, vid in enumerate(x): + pos = step * cfg.batch_size + idx + save_sample(vid, fps=save_fps, save_path=f"{ori_dir}/{pos:03d}", verbose=verbose >= 2) + save_sample(x_rec[idx], fps=save_fps, save_path=f"{rec_dir}/{pos:03d}", verbose=verbose >= 2) + save_sample(x_ref[idx], fps=save_fps, save_path=f"{ref_dir}/{pos:03d}", verbose=verbose >= 2) - print("test vae loss:", running_loss) - print("test nll loss:", running_nll) - print("test nll_z loss:", running_nll_z) - if calc_std: - print("z std:", running_std_sum / num_samples) + logger.info("VAE loss: %s", running_loss) + logger.info("VAE nll loss: %s", running_nll) + logger.info("VAE nll_z loss: %s", running_nll_z) if __name__ ==
"__main__": diff --git a/scripts/train_vae.py b/scripts/train_vae.py index 9c5cbec..d12ddc9 100644 --- a/scripts/train_vae.py +++ b/scripts/train_vae.py @@ -230,12 +230,12 @@ def main(): for epoch in range(start_epoch, cfg_epochs): # == set dataloader to new epoch == dataloader.sampler.set_epoch(epoch) - dataloader_iter = iter(dataloader) + dataiter = iter(dataloader) logger.info("Beginning epoch %s...", epoch) # == training loop in an epoch == with tqdm( - enumerate(dataloader_iter, start=start_step), + enumerate(dataiter, start=start_step), desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, @@ -253,7 +253,7 @@ def main(): length = random.randint(1, x.size(2)) x = x[:, :, :length, :, :] - # == vae encoding === + # == vae encoding & decoding === x_rec, x_z_rec, z, posterior, x_z = model(x) # == loss initialization ==