From 1171e5b6f9ed67d88d7433e5d72ffbd8377943f4 Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Mon, 29 Apr 2024 07:27:15 +0000 Subject: [PATCH] update config --- .gitignore | 1 + .../inference/17x128x128_pixabay.py | 64 ++++---- .../inference/1x256x256_pixabay.py | 78 +++++----- .../vae_magvit_v2/train/1x256x256_pixabay.py | 8 +- opensora/utils/ckpt_utils.py | 2 +- scripts/inference-vae-v2.py | 140 +++++++----------- 6 files changed, 132 insertions(+), 161 deletions(-) diff --git a/.gitignore b/.gitignore index b9f8121..4d5dc97 100644 --- a/.gitignore +++ b/.gitignore @@ -181,3 +181,4 @@ cache/ hostfile gradio_cached_examples/ wandb/ +taming/ diff --git a/configs/vae_magvit_v2/inference/17x128x128_pixabay.py b/configs/vae_magvit_v2/inference/17x128x128_pixabay.py index 9f6cd50..b264c2e 100644 --- a/configs/vae_magvit_v2/inference/17x128x128_pixabay.py +++ b/configs/vae_magvit_v2/inference/17x128x128_pixabay.py @@ -36,44 +36,44 @@ vae_2d = dict( model = dict( type="VAE_MAGVIT_V2", - in_out_channels = 4, - latent_embed_dim = 64, - filters = 128, - num_res_blocks = 4, - channel_multipliers = (1, 2, 2, 4), - temporal_downsample = (False, True, True), - num_groups = 32, # for nn.GroupNorm - kl_embed_dim = 4, - activation_fn = 'swish', - separate_first_frame_encoding = False, - disable_space = True, - custom_conv_padding = None, - encoder_double_z = True, + in_out_channels=4, + latent_embed_dim=64, + filters=128, + num_res_blocks=4, + channel_multipliers=(1, 2, 2, 4), + temporal_downsample=(False, True, True), + num_groups=32, # for nn.GroupNorm + kl_embed_dim=4, + activation_fn="swish", + separate_first_frame_encoding=False, + disable_space=True, + custom_conv_padding=None, + encoder_double_z=True, ) discriminator = dict( type="DISCRIMINATOR_3D", - image_size = (128, 128), - num_frames = num_frames, - in_channels = 3, - filters = 128, - channel_multipliers = (2,4,4,4,4), + image_size=(128, 128), + num_frames=num_frames, + in_channels=3, + filters=128, + channel_multipliers=(2, 4, 4, 4, 4), # channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution ) -# loss weights -logvar_init=0.0 +# loss weights +logvar_init = 0.0 kl_loss_weight = 0.000001 -perceptual_loss_weight = 0.1 # use vgg is not None and more than 0 -discriminator_factor = 1.0 # for discriminator adversarial loss +perceptual_loss_weight = 0.1 # use vgg is not None and more than 0 +discriminator_factor = 1.0 # for discriminator adversarial loss # discriminator_loss_weight = 0.5 # for generator adversarial loss -generator_factor = 0.1 # for generator adversarial loss -lecam_loss_weight = None # NOTE: not clear in MAGVIT what is the weight -discriminator_loss_type="non-saturating" -generator_loss_type="non-saturating" -discriminator_start = 2500 # 50000 NOTE: change to correct val, debug use -1 for now -gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use +generator_factor = 0.1 # for generator adversarial loss +lecam_loss_weight = None # NOTE: not clear in MAGVIT what is the weight +discriminator_loss_type = "non-saturating" +generator_loss_type = "non-saturating" +discriminator_start = 2500 # 50000 NOTE: change to correct val, debug use -1 for now +gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use ema_decay = 0.999 # ema decay factor for generator @@ -83,15 +83,15 @@ save_dir = "outputs/samples_pixabay_17" wandb = False # Training -''' NOTE: +""" NOTE: magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128 -==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200], +==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200], 3-6 epochs for pexel, from pexel observation its correct -''' +""" batch_size = 1 lr = 1e-4 grad_clip = 1.0 -calc_loss = True \ No newline at end of file +calc_loss = True diff --git a/configs/vae_magvit_v2/inference/1x256x256_pixabay.py b/configs/vae_magvit_v2/inference/1x256x256_pixabay.py index 31a1327..98c162e 100644 --- a/configs/vae_magvit_v2/inference/1x256x256_pixabay.py +++ b/configs/vae_magvit_v2/inference/1x256x256_pixabay.py @@ -1,23 +1,18 @@ num_frames = 1 - image_size = (256, 256) - dataset = dict( type="VideoTextDataset", data_path=None, num_frames=num_frames, - frame_interval=3, + frame_interval=1, image_size=image_size, get_text=False, ) fps = 24 // 3 is_vae = True - -# Define dataset max_test_samples = -1 - # Define acceleration num_workers = 4 dtype = "bf16" @@ -33,67 +28,70 @@ video_contains_first_frame = True vae_2d = dict( type="VideoAutoencoderKL", - from_pretrained="stabilityai/sd-vae-ft-ema", + from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", + subfolder="vae", + micro_batch_size=4, + local_files_only=True, ) model = dict( type="VAE_MAGVIT_V2", - in_out_channels = 4, - latent_embed_dim = 64, - filters = 128, - num_res_blocks = 4, - channel_multipliers = (1, 2, 2, 4), - temporal_downsample = (False, True, True), - num_groups = 32, # for nn.GroupNorm - kl_embed_dim = 4, - activation_fn = 'swish', - separate_first_frame_encoding = False, - disable_space = True, - custom_conv_padding = None, - encoder_double_z = True, + in_out_channels=4, + latent_embed_dim=64, + filters=128, + num_res_blocks=4, + channel_multipliers=(1, 2, 2, 4), + temporal_downsample=(False, True, True), + num_groups=32, # for nn.GroupNorm + kl_embed_dim=4, + activation_fn="swish", + separate_first_frame_encoding=False, + disable_space=True, + custom_conv_padding=None, + encoder_double_z=True, ) discriminator = dict( type="DISCRIMINATOR_3D", - image_size = image_size, - num_frames = num_frames, - in_channels = 3, - filters = 128, - channel_multipliers = (2,4,4,4,4), + image_size=image_size, + num_frames=num_frames, + in_channels=3, + filters=128, + channel_multipliers=(2, 4, 4, 4, 4), # channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution ) -# loss weights -logvar_init=0.0 +# loss weights +logvar_init = 0.0 kl_loss_weight = 0.000001 -perceptual_loss_weight = 0.1 # use vgg is not None and more than 0 -discriminator_factor = 1.0 # for discriminator adversarial loss +perceptual_loss_weight = 0.1 # use vgg is not None and more than 0 +discriminator_factor = 1.0 # for discriminator adversarial loss # discriminator_loss_weight = 0.5 # for generator adversarial loss -generator_factor = 0.1 # for generator adversarial loss -lecam_loss_weight = None # NOTE: not clear in MAGVIT what is the weight -discriminator_loss_type="non-saturating" -generator_loss_type="non-saturating" -discriminator_start = 2500 # 50000 NOTE: change to correct val, debug use -1 for now -gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use +generator_factor = 0.1 # for generator adversarial loss +lecam_loss_weight = None # NOTE: not clear in MAGVIT what is the weight +discriminator_loss_type = "non-saturating" +generator_loss_type = "non-saturating" +discriminator_start = 2500 # 50000 NOTE: change to correct val, debug use -1 for now +gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use ema_decay = 0.999 # ema decay factor for generator # Others seed = 42 -save_dir = "outputs/samples_pixabay_17" +save_dir = "samples/samples_pixabay_17" wandb = False # Training -''' NOTE: +""" NOTE: magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128 -==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200], +==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200], 3-6 epochs for pexel, from pexel observation its correct -''' +""" batch_size = 1 lr = 1e-4 grad_clip = 1.0 -calc_loss = True \ No newline at end of file +calc_loss = True diff --git a/configs/vae_magvit_v2/train/1x256x256_pixabay.py b/configs/vae_magvit_v2/train/1x256x256_pixabay.py index 2c34afa..91d99a9 100644 --- a/configs/vae_magvit_v2/train/1x256x256_pixabay.py +++ b/configs/vae_magvit_v2/train/1x256x256_pixabay.py @@ -48,8 +48,8 @@ model = dict( activation_fn="swish", separate_first_frame_encoding=False, disable_space=True, - encoder_double_z=True, custom_conv_padding=None, + encoder_double_z=True, ) @@ -60,8 +60,8 @@ discriminator = dict( in_channels=3, filters=128, use_pretrained=True, # NOTE: set to False only if we want to disable load - channel_multipliers = (2,4,4,4,4), # (2,4,4,4) for 64x64 resolution - # channel_multipliers=(2, 4, 4), # since on intermediate layer dimension ofs z + channel_multipliers=(2, 4, 4, 4, 4), # (2,4,4,4) for 64x64 resolution + # channel_multipliers=(2, 4, 4), # since on intermediate layer dimension ofs z ) @@ -76,7 +76,7 @@ discriminator_loss_type = "non-saturating" generator_loss_type = "non-saturating" # discriminator_loss_type="hinge" # generator_loss_type="hinge" -discriminator_start = 2000 # 5000 # 8k data / (8*1) = 1000 steps per epoch +discriminator_start = 2000 # 5000 # 8k data / (8*1) = 1000 steps per epoch gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use ema_decay = 0.999 # ema decay factor for generator diff --git a/opensora/utils/ckpt_utils.py b/opensora/utils/ckpt_utils.py index 2b0655b..804de27 100644 --- a/opensora/utils/ckpt_utils.py +++ b/opensora/utils/ckpt_utils.py @@ -256,7 +256,7 @@ def create_logger(logging_dir): return logger -def load_checkpoint(model, ckpt_path, save_as_pt=True, model_name="model"): +def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model"): if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"): state_dict = find_model(ckpt_path, model=model) missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) diff --git a/scripts/inference-vae-v2.py b/scripts/inference-vae-v2.py index ab43079..2fd64fe 100644 --- a/scripts/inference-vae-v2.py +++ b/scripts/inference-vae-v2.py @@ -4,25 +4,16 @@ import colossalai import torch import torch.distributed as dist from colossalai.cluster import DistCoordinator -from mmengine.runner import set_random_seed +from colossalai.utils import get_current_device +from einops import rearrange +from tqdm import tqdm -from opensora.acceleration.parallel_states import set_sequence_parallel_group -from opensora.datasets import save_sample -from opensora.registry import MODELS, SCHEDULERS, build_module +from opensora.acceleration.parallel_states import get_data_parallel_group +from opensora.datasets import prepare_dataloader, save_sample +from opensora.models.vae.vae_3d_v2 import AdversarialLoss, DiscriminatorLoss, LeCamEMA, VEALoss, pad_at_dim +from opensora.registry import DATASETS, MODELS, build_module from opensora.utils.config_utils import parse_configs from opensora.utils.misc import to_torch_dtype -from opensora.datasets import prepare_dataloader, prepare_variable_dataloader -from opensora.registry import DATASETS, MODELS, build_module -from opensora.acceleration.parallel_states import ( - get_data_parallel_group, - set_data_parallel_group, - set_sequence_parallel_group, -) -from tqdm import tqdm -from opensora.models.vae.vae_3d_v2 import VEALoss, DiscriminatorLoss, AdversarialLoss, LeCamEMA, pad_at_dim - -from einops import rearrange -from colossalai.utils import get_current_device def main(): @@ -49,9 +40,6 @@ def main(): device = get_current_device() dtype = to_torch_dtype(cfg.dtype) - - - # ====================================================== # 3. build dataset and dataloader # ====================================================== @@ -101,43 +89,37 @@ def main(): save_dir = cfg.save_dir os.makedirs(save_dir, exist_ok=True) - - # 4.1. batch generation - + # define loss function if cfg.calc_loss: vae_loss_fn = VEALoss( logvar_init=cfg.logvar_init, - perceptual_loss_weight = cfg.perceptual_loss_weight, - kl_loss_weight = cfg.kl_loss_weight, + perceptual_loss_weight=cfg.perceptual_loss_weight, + kl_loss_weight=cfg.kl_loss_weight, device=device, dtype=dtype, ) adversarial_loss_fn = AdversarialLoss( - discriminator_factor = cfg.discriminator_factor, - discriminator_start = cfg.discriminator_start, - generator_factor = cfg.generator_factor, - generator_loss_type = cfg.generator_loss_type, + discriminator_factor=cfg.discriminator_factor, + discriminator_start=cfg.discriminator_start, + generator_factor=cfg.generator_factor, + generator_loss_type=cfg.generator_loss_type, ) disc_loss_fn = DiscriminatorLoss( - discriminator_factor = cfg.discriminator_factor, - discriminator_start = cfg.discriminator_start, - discriminator_loss_type = cfg.discriminator_loss_type, - lecam_loss_weight = cfg.lecam_loss_weight, - gradient_penalty_loss_weight = cfg.gradient_penalty_loss_weight, - ) - - # LeCam EMA for discriminator - - lecam_ema = LeCamEMA( - decay=cfg.ema_decay, dtype=dtype, device=device + discriminator_factor=cfg.discriminator_factor, + discriminator_start=cfg.discriminator_start, + discriminator_loss_type=cfg.discriminator_loss_type, + lecam_loss_weight=cfg.lecam_loss_weight, + gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight, ) - - + # LeCam EMA for discriminator + + lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device) + running_loss = 0.0 running_nll = 0.0 running_disc_loss = 0.0 @@ -152,7 +134,7 @@ def main(): total_steps = len(dataloader) if cfg.max_test_samples > 0: - total_steps = min(int(cfg.max_test_samples//cfg.batch_size), total_steps) + total_steps = min(int(cfg.max_test_samples // cfg.batch_size), total_steps) print(f"limiting test dataset to {int(cfg.max_test_samples//cfg.batch_size) * cfg.batch_size}") dataloader_iter = iter(dataloader) @@ -169,7 +151,7 @@ def main(): is_image = x.ndim == 4 if is_image: - video = rearrange(x, 'b c ... -> b c 1 ...') + video = rearrange(x, "b c ... -> b c 1 ...") video_contains_first_frame = True else: video = x @@ -180,98 +162,88 @@ def main(): video_enc_spatial = vae_2d.encode(video) recon_dec_spatial, posterior = vae( - video_enc_spatial, - video_contains_first_frame = video_contains_first_frame + video_enc_spatial, video_contains_first_frame=video_contains_first_frame ) recon_video = vae_2d.decode(recon_dec_spatial) recon_2d = vae_2d.decode(video_enc_spatial) else: - recon_video, posterior = vae( - video, - video_contains_first_frame = video_contains_first_frame - ) + recon_video, posterior = vae(video, video_contains_first_frame=video_contains_first_frame) if cfg.calc_loss: # ====== Calc Loss ====== # simple nll loss - nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn( - video, - recon_video, - posterior, - split = "eval" - ) + nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(video, recon_video, posterior, split="eval") - fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2) + fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2) fake_logits = discriminator(fake_video.contiguous()) adversarial_loss = adversarial_loss_fn( fake_logits, - nll_loss, + nll_loss, vae.get_last_layer(), - cfg.discriminator_start+1, # Hack to use discriminator - is_training = vae.training, + cfg.discriminator_start + 1, # Hack to use discriminator + is_training=vae.training, ) - + vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss - + # ====== Discriminator Loss ====== - real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2) - fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2) + real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2) + fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2) if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0: real_video = real_video.requires_grad_() - real_logits = discriminator(real_video.contiguous()) # SCH: not detached for now for gradient_penalty calculation + real_logits = discriminator( + real_video.contiguous() + ) # SCH: not detached for now for gradient_penalty calculation else: - real_logits = discriminator(real_video.contiguous().detach()) + real_logits = discriminator(real_video.contiguous().detach()) fake_logits = discriminator(fake_video.contiguous().detach()) lecam_ema_real, lecam_ema_fake = lecam_ema.get() weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn( - real_logits, - fake_logits, - cfg.discriminator_start+1, # Hack to use discriminator - lecam_ema_real = lecam_ema_real, - lecam_ema_fake = lecam_ema_fake, - real_video = real_video if cfg.gradient_penalty_loss_weight is not None else None, + real_logits, + fake_logits, + cfg.discriminator_start + 1, # Hack to use discriminator + lecam_ema_real=lecam_ema_real, + lecam_ema_fake=lecam_ema_fake, + real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None, ) disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss loss_steps += 1 - running_disc_loss = disc_loss.item()/loss_steps + running_disc_loss * ((loss_steps - 1) / loss_steps) - running_loss = vae_loss.item()/ loss_steps + running_loss * ((loss_steps - 1) / loss_steps) + running_disc_loss = disc_loss.item() / loss_steps + running_disc_loss * ((loss_steps - 1) / loss_steps) + running_loss = vae_loss.item() / loss_steps + running_loss * ((loss_steps - 1) / loss_steps) running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps) - # ===== Spatial VAE ===== - if coordinator.is_master(): if cfg.get("use_pipeline") == True: - for idx, (sample_original, sample_pipeline, sample_2d) in enumerate(zip(video, recon_video, recon_2d)): + for idx, (sample_original, sample_pipeline, sample_2d) in enumerate( + zip(video, recon_video, recon_2d) + ): pos = step * cfg.batch_size + idx save_path = os.path.join(save_dir, f"sample_{pos}") - save_sample(sample_original, fps=cfg.fps, save_path=save_path+"_original") - save_sample(sample_2d, fps=cfg.fps, save_path=save_path+"_2d") - save_sample(sample_pipeline, fps=cfg.fps, save_path=save_path+"_pipeline") + save_sample(sample_original, fps=cfg.fps, save_path=save_path + "_original") + save_sample(sample_2d, fps=cfg.fps, save_path=save_path + "_2d") + save_sample(sample_pipeline, fps=cfg.fps, save_path=save_path + "_pipeline") else: for idx, (original, recon) in enumerate(zip(video, recon_video)): pos = step * cfg.batch_size + idx save_path = os.path.join(save_dir, f"sample_{pos}") - save_sample(original, fps=cfg.fps, save_path=save_path+"_original") - save_sample(recon, fps=cfg.fps, save_path=save_path+"_recon") - - + save_sample(original, fps=cfg.fps, save_path=save_path + "_original") + save_sample(recon, fps=cfg.fps, save_path=save_path + "_recon") if cfg.calc_loss: print("test vae loss:", running_loss) print("test nll loss:", running_nll) print("test disc loss:", running_disc_loss) - if __name__ == "__main__": main()