From 9ce3099f4716546155face9d776f38d6664b6aff Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Fri, 19 Apr 2024 11:27:33 +0800 Subject: [PATCH] config trial data --- .../vae_magvit_v2/train/16x128x128_trial.py | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 configs/vae_magvit_v2/train/16x128x128_trial.py diff --git a/configs/vae_magvit_v2/train/16x128x128_trial.py b/configs/vae_magvit_v2/train/16x128x128_trial.py new file mode 100644 index 0000000..9de6e7f --- /dev/null +++ b/configs/vae_magvit_v2/train/16x128x128_trial.py @@ -0,0 +1,82 @@ +num_frames = 16 +frame_interval = 3 +image_size = (128, 128) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = False +num_workers = 4 +video_contains_first_frame = False + +# Define acceleration +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" +sp_size = 1 + + +# Define model + +model = dict( + type="VAE_MAGVIT_V2", + in_out_channels = 3, + latent_embed_dim = 256, + filters = 128, + num_res_blocks = 4, + channel_multipliers = (1, 2, 2, 4), + temporal_downsample = (False, True, True), + num_groups = 32, # for nn.GroupNorm + kl_embed_dim = 64, + activation_fn = 'swish', + separate_first_frame_encoding = False, + disable_space = False, + custom_conv_padding = None +) + + +discriminator = dict( + type="DISCRIMINATOR_3D", + image_size = image_size, + num_frames = num_frames, + in_channels = 3, + filters = 128, + channel_multipliers = (2,4,4,4,4), # (2,4,4,4) for 64x64 resolution +) + + +# loss weights +logvar_init=0.0 +kl_loss_weight = 0.000001 +perceptual_loss_weight = 0.1 # use vgg is not None and more than 0 +discriminator_factor = 1.0 # for discriminator adversarial loss +# discriminator_loss_weight = 0.5 # for generator adversarial loss +generator_factor = 0.1 # SCH: generator adversarial loss, MAGVIT v2 uses 0.1 +lecam_loss_weight = None # NOTE: MAVGIT v2 use 0.001 +discriminator_loss_type="non-saturating" +generator_loss_type="non-saturating" +discriminator_start = 1000 # 50000 NOTE: change to correct val, debug use -1 for now +gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use +ema_decay = 0.999 # ema decay factor for generator + + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +# Training +''' NOTE: +magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128 +==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200], +3-6 epochs for pexel, from pexel observation its correct +''' + +epochs = 100 +log_every = 1 +ckpt_every = 200 +load = None + +batch_size = 4 +lr = 1e-4 +grad_clip = 1.0