diff --git a/configs/vae/inference/17x256x256.py b/configs/vae/inference/17x256x256.py
new file mode 100644
index 0000000..25c9ff0
--- /dev/null
+++ b/configs/vae/inference/17x256x256.py
@@ -0,0 +1,79 @@
+num_frames = 16
+image_size = (256, 256)
+fps = 24 // 3
+max_test_samples = None
+
+# Define dataset
+dataset = dict(
+    type="VideoTextDataset",
+    data_path=None,
+    num_frames=num_frames,
+    frame_interval=1,
+    image_size=image_size,
+)
+
+# Define acceleration
+num_workers = 4
+dtype = "bf16"
+grad_checkpoint = True
+plugin = "zero2"
+sp_size = 1
+
+
+# Define model
+vae_2d = dict(
+    type="VideoAutoencoderKL",
+    from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
+    subfolder="vae",
+    micro_batch_size=4,
+    local_files_only=True,
+)
+
+model = dict(
+    type="VAE_Temporal_SD",
+)
+
+# discriminator = dict(
+#     type="DISCRIMINATOR_3D",
+#     image_size=image_size,
+#     num_frames=num_frames,
+#     in_channels=3,
+#     filters=128,
+#     channel_multipliers=(2, 4, 4, 4, 4),
+#     # channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
+# )
+
+
+# loss weights
+logvar_init = 0.0
+kl_loss_weight = 0.000001
+perceptual_loss_weight = 0.1  # use VGG perceptual loss if the weight is not None and > 0
+discriminator_factor = 1.0  # for discriminator adversarial loss
+# discriminator_loss_weight = 0.5  # for generator adversarial loss
+generator_factor = 0.1  # for generator adversarial loss
+lecam_loss_weight = None  # NOTE: MAGVIT does not specify this weight
+discriminator_loss_type = "non-saturating"
+generator_loss_type = "non-saturating"
+discriminator_start = 2500  # 50000 NOTE: change to the correct value; use -1 for debugging
+gradient_penalty_loss_weight = None  # 10 # SCH: MAGVIT uses 10; Open-Sora Plan doesn't use it
+ema_decay = 0.999  # EMA decay factor for the generator
+
+
+# Others
+seed = 42
+save_dir = "samples/samples_vae"
+wandb = False
+
+# Training
+""" NOTE:
+MAGVIT uses roughly #samples (K) * epochs ~ 2-5K, with num_frames = 4, reso = 128.
+==> For ours, num_frames = 16, reso = 256, so #samples (K) * epochs ~ [500 - 1200],
+i.e. 3-6 epochs on Pexels, which matches what we observe on Pexels.
+"""
+
+
+batch_size = 1
+lr = 1e-4
+grad_clip = 1.0
+
+calc_loss = True
diff --git a/configs/vae/inference/1x256x256.py b/configs/vae/inference/1x256x256.py
index 98c162e..e9f8a5e 100644
--- a/configs/vae/inference/1x256x256.py
+++ b/configs/vae/inference/1x256x256.py
@@ -1,18 +1,18 @@
 num_frames = 1
-image_size = (256, 256)
+# image_size = (256, 256)
+image_size = (1024, 1024)
+fps = 24 // 3
+max_test_samples = None
+
+# Define dataset
 dataset = dict(
     type="VideoTextDataset",
     data_path=None,
     num_frames=num_frames,
     frame_interval=1,
     image_size=image_size,
-    get_text=False,
 )
-fps = 24 // 3
-is_vae = True
-max_test_samples = -1
-
 # Define acceleration
 num_workers = 4
 dtype = "bf16"
@@ -20,12 +20,8 @@ grad_checkpoint = True
 plugin = "zero2"
 sp_size = 1
 
-use_pipeline = True
-video_contains_first_frame = True
-
 # Define model
-
 vae_2d = dict(
     type="VideoAutoencoderKL",
     from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
@@ -35,31 +31,18 @@ vae_2d = dict(
 )
 
 model = dict(
-    type="VAE_MAGVIT_V2",
-    in_out_channels=4,
-    latent_embed_dim=64,
-    filters=128,
-    num_res_blocks=4,
-    channel_multipliers=(1, 2, 2, 4),
-    temporal_downsample=(False, True, True),
-    num_groups=32,  # for nn.GroupNorm
-    kl_embed_dim=4,
-    activation_fn="swish",
-    separate_first_frame_encoding=False,
-    disable_space=True,
-    custom_conv_padding=None,
-    encoder_double_z=True,
+    type="VAE_Temporal_SD",
 )
 
-discriminator = dict(
-    type="DISCRIMINATOR_3D",
-    image_size=image_size,
-    num_frames=num_frames,
-    in_channels=3,
-    filters=128,
-    channel_multipliers=(2, 4, 4, 4, 4),
-    # channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
-)
+# discriminator = dict(
+#     type="DISCRIMINATOR_3D",
+#     image_size=image_size,
+#     num_frames=num_frames,
+#     in_channels=3,
+#     filters=128,
+#     channel_multipliers=(2, 4, 4, 4, 4),
+#     # channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
+# )
 
 
 # loss weights
@@ -79,7 +62,7 @@ ema_decay = 0.999  # ema decay factor for generator
 
 # Others
 seed = 42
-save_dir = "samples/samples_pixabay_17"
+save_dir = "samples/samples_vae"
 wandb = False
 
 # Training
diff --git a/configs/vae/train/17x256x256.py b/configs/vae/train/17x256x256.py
index 2b24de2..46e3157 100644
--- a/configs/vae/train/17x256x256.py
+++ b/configs/vae/train/17x256x256.py
@@ -64,18 +64,11 @@ seed = 42
 outputs = "outputs"
 wandb = False
 
-# Training
-""" NOTE:
-magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
-==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
-3-6 epochs for pexel, from pexel observation its correct
-"""
-
 epochs = 100
 log_every = 1
 ckpt_every = 1000
 load = None
 
-batch_size = 4
+batch_size = 1
 lr = 1e-5
 grad_clip = 1.0
diff --git a/configs/vae/train/1x256x256.py b/configs/vae/train/1x256x256.py
index 582bbcb..d48df09 100644
--- a/configs/vae/train/1x256x256.py
+++ b/configs/vae/train/1x256x256.py
@@ -61,13 +61,6 @@ seed = 42
 outputs = "outputs"
 wandb = False
 
-# Training
-""" NOTE:
-magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
-==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
-3-6 epochs for pexel, from pexel observation its correct
-"""
-
 epochs = 100
 log_every = 1
 ckpt_every = 1000
diff --git a/opensora/models/vae/__init__.py b/opensora/models/vae/__init__.py
index 78a24fd..f27be47 100644
--- a/opensora/models/vae/__init__.py
+++ b/opensora/models/vae/__init__.py
@@ -1,3 +1,3 @@
 from .discriminator import DISCRIMINATOR_3D
 from .vae import VideoAutoencoderKL, VideoAutoencoderKLTemporalDecoder
-from .vae_3d import VAE_Temporal
+from .vae_temporal import VAE_Temporal
diff --git a/opensora/models/vae/lpips.py b/opensora/models/vae/lpips.py
index d1f00c7..358cfe5 100644
--- a/opensora/models/vae/lpips.py
+++ b/opensora/models/vae/lpips.py
@@ -61,7 +61,7 @@ class LPIPS(nn.Module):
             param.requires_grad = False
 
     def load_from_pretrained(self, name="vgg_lpips"):
-        ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips", root="pretrained_models")
+        ckpt = get_ckpt_path(name, "pretrained_models/taming/modules/autoencoder/lpips")
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
         # print("loaded pretrained LPIPS loss from {}".format(ckpt))
diff --git a/opensora/models/vae/vae_3d.py b/opensora/models/vae/vae_temporal.py
similarity index 95%
rename from opensora/models/vae/vae_3d.py
rename to opensora/models/vae/vae_temporal.py
index db0db61..1c34924 100644
--- a/opensora/models/vae/vae_3d.py
+++ b/opensora/models/vae/vae_temporal.py
@@ -24,23 +24,12 @@ def is_odd(n):
     return not divisible_by(n, 2)
 
 
-def pad_at_dim(t, pad, dim=-1, value=0.0):
+def pad_at_dim(t, pad, dim=-1):
     dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
     zeros = (0, 0) * dims_from_right
     return F.pad(t, (*zeros, *pad), mode="replicate")
 
 
-def pick_video_frame(video, frame_indices):
-    """get frame_indices from the video of [B, C, T, H, W] and return images of [B, C, H, W]"""
-    batch, device = video.shape[0], video.device
-    video = rearrange(video, "b c f ... -> b f c ...")
-    batch_indices = torch.arange(batch, device=device)
-    batch_indices = rearrange(batch_indices, "b -> b 1")
-    images = video[batch_indices, frame_indices]
-    images = rearrange(images, "b 1 c ... -> b c ...")
-    return images
-
-
 def exists(v):
     return v is not None
 
@@ -381,7 +370,7 @@ class VAE_Temporal(nn.Module):
         super().__init__()
 
         self.time_downsample_factor = 2 ** sum(temporal_downsample)
-        self.time_padding = self.time_downsample_factor - 1
+        # self.time_padding = self.time_downsample_factor - 1
         self.patch_size = (self.time_downsample_factor, 1, 1)
 
         # NOTE: following MAGVIT, conv in bias=False in encoder first conv
@@ -420,16 +409,18 @@
         return input_size
 
     def encode(self, x):
-        x = pad_at_dim(x, (self.time_padding, 0), dim=2)
+        time_padding = 0 if x.shape[2] % self.time_downsample_factor == 0 else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor
+        x = pad_at_dim(x, (time_padding, 0), dim=2)
         encoded_feature = self.encoder(x)
         moments = self.quant_conv(encoded_feature).to(x.dtype)
         posterior = DiagonalGaussianDistribution(moments)
         return posterior
 
-    def decode(self, z):
+    def decode(self, z, num_frames=None):
+        time_padding = 0 if num_frames % self.time_downsample_factor == 0 else self.time_downsample_factor - num_frames % self.time_downsample_factor
         z = self.post_quant_conv(z)
         x = self.decoder(z)
-        x = x[:, :, self.time_padding :]
+        x = x[:, :, time_padding:]
         return x
 
     def forward(self, x, sample_posterior=True):
@@ -438,7 +429,7 @@
             z = posterior.sample()
         else:
             z = posterior.mode()
-        recon_video = self.decode(z)
+        recon_video = self.decode(z, num_frames=x.shape[2])
         return recon_video, posterior, z
diff --git a/opensora/utils/ckpt_utils.py b/opensora/utils/ckpt_utils.py
index 804de27..5d66449 100644
--- a/opensora/utils/ckpt_utils.py
+++ b/opensora/utils/ckpt_utils.py
@@ -264,6 +264,7 @@ def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model"):
             print(f"Unexpected keys: {unexpected_keys}")
     elif os.path.isdir(ckpt_path):
         load_from_sharded_state_dict(model, ckpt_path, model_name)
+        print(f"Model checkpoint loaded from {ckpt_path}")
     if save_as_pt:
         save_path = os.path.join(ckpt_path, model_name + "_ckpt.pt")
         torch.save(model.state_dict(), save_path)
diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py
index e59d4ab..96638ef 100644
--- a/opensora/utils/config_utils.py
+++ b/opensora/utils/config_utils.py
@@ -76,7 +76,7 @@ def merge_args(cfg, args, training=False):
         if cfg.get("discriminator") is not None:
             cfg.discriminator["from_pretrained"] = args.ckpt_path
         args.ckpt_path = None
-    if (training or cfg.get("is_vae", False)) and args.data_path is not None:
+    if args.data_path is not None:
         cfg.dataset["data_path"] = args.data_path
         args.data_path = None
     if not training and args.cfg_scale is not None:
@@ -106,9 +106,8 @@ def merge_args(cfg, args, training=False):
         if "prompt_as_path" not in cfg:
             cfg["prompt_as_path"] = False
         # - Prompt handling
-        if not "is_vae" in cfg and ("prompt" not in cfg or cfg["prompt"] is None):
-            if "prompt" not in cfg or cfg["prompt"] is None:
-                assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"
+        if "prompt" not in cfg or cfg["prompt"] is None:
+            if cfg.get("prompt_path", None) is not None:
                 cfg["prompt"] = load_prompts(cfg["prompt_path"])
         if args.start_index is not None and args.end_index is not None:
             cfg["prompt"] = cfg["prompt"][args.start_index : args.end_index]
diff --git a/scripts/inference-vae.py b/scripts/inference-vae.py
index 7011807..4d42533 100644
--- a/scripts/inference-vae.py
+++ b/scripts/inference-vae.py
@@ -4,13 +4,12 @@ import colossalai
 import torch
 import torch.distributed as dist
 from colossalai.cluster import DistCoordinator
-from colossalai.utils import get_current_device
+from mmengine.runner import set_random_seed
 from tqdm import tqdm
 
-from opensora.acceleration.parallel_states import get_data_parallel_group
+from opensora.acceleration.parallel_states import get_data_parallel_group, set_sequence_parallel_group
 from opensora.datasets import prepare_dataloader, save_sample
-from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, LeCamEMA, VAELoss
-from opensora.models.vae.vae_3d import pad_at_dim
+from opensora.models.vae.losses import VAELoss
 from opensora.registry import DATASETS, MODELS, build_module
 from opensora.utils.config_utils import parse_configs
 from opensora.utils.misc import to_torch_dtype
@@ -24,21 +23,32 @@ def main():
     print(cfg)
 
     # init distributed
-    colossalai.launch_from_torch({})
-    coordinator = DistCoordinator()
+    if os.environ.get("WORLD_SIZE", None):
+        use_dist = True
+        colossalai.launch_from_torch({})
+        coordinator = DistCoordinator()
+
+        if coordinator.world_size > 1:
+            set_sequence_parallel_group(dist.group.WORLD)
+        else:
+            pass
+    else:
+        use_dist = False
 
     # ======================================================
     # 2. runtime variables
     # ======================================================
     torch.set_grad_enabled(False)
-    device = get_current_device()
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     dtype = to_torch_dtype(cfg.dtype)
+    set_random_seed(seed=cfg.seed)
 
     # ======================================================
     # 3. build dataset and dataloader
     # ======================================================
     dataset = build_module(cfg.dataset, DATASETS)
-
     dataloader = prepare_dataloader(
         dataset,
         batch_size=cfg.batch_size,
@@ -49,7 +59,6 @@
         process_group=get_data_parallel_group(),
     )
     print(f"Dataset contains {len(dataset):,} videos ({cfg.dataset.data_path})")
-
-    total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
+    total_batch_size = cfg.batch_size * (dist.get_world_size() if use_dist else 1) // cfg.sp_size
     print(f"Total batch size: {total_batch_size}")
 
@@ -57,25 +66,27 @@
     # 4. build model & load weights
     # ======================================================
     # 3.1. build model
-    if cfg.get("use_pipeline") == True:
-        # use 2D VAE, then temporal VAE
+    if cfg.get("vae_2d", None) is not None:
         vae_2d = build_module(cfg.vae_2d, MODELS)
-    vae = build_module(cfg.model, MODELS, device=device)
-    discriminator = build_module(cfg.discriminator, MODELS, device=device)
+        vae_2d.to(device, dtype).eval()
+    model = build_module(
+        cfg.model,
+        MODELS,
+        device=device,
+        dtype=dtype,
+    )
+    # discriminator = build_module(cfg.discriminator, MODELS, device=device)
 
     # 3.2. move to device & eval
-    if cfg.get("use_pipeline") == True:
-        vae_2d.to(device, dtype).eval()
-    vae = vae.to(device, dtype).eval()
-    discriminator = discriminator.to(device, dtype).eval()
+    # discriminator = discriminator.to(device, dtype).eval()
     # 3.4. support for multi-resolution
-    model_args = dict()
-    if cfg.multi_resolution:
-        image_size = cfg.dataset.image_size
-        hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
-        ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
-        model_args["data_info"] = dict(ar=ar, hw=hw)
+    # model_args = dict()
+    # if cfg.multi_resolution:
+    #     image_size = cfg.dataset.image_size
+    #     hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
+    #     ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
+    #     model_args["data_info"] = dict(ar=ar, hw=hw)
 
     # ======================================================
     # 4. inference
@@ -95,46 +106,43 @@
         dtype=dtype,
     )
 
-    adversarial_loss_fn = AdversarialLoss(
-        discriminator_factor=cfg.discriminator_factor,
-        discriminator_start=cfg.discriminator_start,
-        generator_factor=cfg.generator_factor,
-        generator_loss_type=cfg.generator_loss_type,
-    )
+    # adversarial_loss_fn = AdversarialLoss(
+    #     discriminator_factor=cfg.discriminator_factor,
+    #     discriminator_start=cfg.discriminator_start,
+    #     generator_factor=cfg.generator_factor,
+    #     generator_loss_type=cfg.generator_loss_type,
+    # )
 
-    disc_loss_fn = DiscriminatorLoss(
-        discriminator_factor=cfg.discriminator_factor,
-        discriminator_start=cfg.discriminator_start,
-        discriminator_loss_type=cfg.discriminator_loss_type,
-        lecam_loss_weight=cfg.lecam_loss_weight,
-        gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
-    )
+    # disc_loss_fn = DiscriminatorLoss(
+    #     discriminator_factor=cfg.discriminator_factor,
+    #     discriminator_start=cfg.discriminator_start,
+    #     discriminator_loss_type=cfg.discriminator_loss_type,
+    #     lecam_loss_weight=cfg.lecam_loss_weight,
+    #     gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
+    # )
 
-    # LeCam EMA for discriminator
+    # # LeCam EMA for discriminator
 
-    lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
+    # lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
 
     running_loss = 0.0
     running_nll = 0.0
-    running_disc_loss = 0.0
     loss_steps = 0
 
-    disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
-    if cfg.dataset.num_frames % disc_time_downsample_factor != 0:
-        disc_time_padding = disc_time_downsample_factor - cfg.dataset.num_frames % disc_time_downsample_factor
-    else:
-        disc_time_padding = 0
-    video_contains_first_frame = cfg.video_contains_first_frame
+    # disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
+    # if cfg.dataset.num_frames % disc_time_downsample_factor != 0:
+    #     disc_time_padding = disc_time_downsample_factor - cfg.dataset.num_frames % disc_time_downsample_factor
+    # else:
+    #     disc_time_padding = 0
 
     total_steps = len(dataloader)
-    if cfg.max_test_samples > 0:
+    if cfg.max_test_samples is not None:
         total_steps = min(int(cfg.max_test_samples // cfg.batch_size), total_steps)
         print(f"limiting test dataset to {int(cfg.max_test_samples//cfg.batch_size) * cfg.batch_size}")
     dataloader_iter = iter(dataloader)
 
     with tqdm(
         range(total_steps),
-        # desc=f"Avg Loss: {running_loss}",
-        disable=not coordinator.is_master(),
+        disable=use_dist and not coordinator.is_master(),
         total=total_steps,
         initial=0,
@@ -142,95 +150,96 @@
     ) as pbar:
         for step in pbar:
             batch = next(dataloader_iter)
             x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
-            video = x
 
             # ===== Spatial VAE =====
-            if cfg.get("use_pipeline") == True:
-                with torch.no_grad():
-                    video_enc_spatial = vae_2d.encode(video)
if cfg.get("vae_2d", None) is not None: + x_z = vae_2d.encode(x) + x_z_debug = vae_2d.decode(x_z) - recon_dec_spatial, posterior = vae( - video_enc_spatial, video_contains_first_frame=video_contains_first_frame - ) - - recon_video = vae_2d.decode(recon_dec_spatial) - recon_2d = vae_2d.decode(video_enc_spatial) - - else: - recon_video, posterior = vae(video, video_contains_first_frame=video_contains_first_frame) + # ====== VAE ====== + x_z_rec, posterior, z = model(x_z) + x_rec = vae_2d.decode(x_z_rec) if cfg.calc_loss: - # ====== Calc Loss ====== # simple nll loss - nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(video, recon_video, posterior) + nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior) - fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2) - fake_logits = discriminator(fake_video.contiguous()) - adversarial_loss = adversarial_loss_fn( - fake_logits, - nll_loss, - vae.get_last_layer(), - cfg.discriminator_start + 1, # Hack to use discriminator - is_training=vae.training, - ) + # fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2) + # fake_logits = discriminator(fake_video.contiguous()) + # adversarial_loss = adversarial_loss_fn( + # fake_logits, + # nll_loss, + # vae.get_last_layer(), + # cfg.discriminator_start + 1, # Hack to use discriminator + # is_training=vae.training, + # ) - vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss + # vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss + vae_loss = weighted_nll_loss + weighted_kl_loss - # ====== Discriminator Loss ====== - real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2) - fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2) + # # ====== Discriminator Loss ====== + # real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2) + # fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2) - if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0: - real_video = real_video.requires_grad_() - real_logits = discriminator( - real_video.contiguous() - ) # SCH: not detached for now for gradient_penalty calculation - else: - real_logits = discriminator(real_video.contiguous().detach()) + # if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0: + # real_video = real_video.requires_grad_() + # real_logits = discriminator( + # real_video.contiguous() + # ) # SCH: not detached for now for gradient_penalty calculation + # else: + # real_logits = discriminator(real_video.contiguous().detach()) - fake_logits = discriminator(fake_video.contiguous().detach()) + # fake_logits = discriminator(fake_video.contiguous().detach()) - lecam_ema_real, lecam_ema_fake = lecam_ema.get() - weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn( - real_logits, - fake_logits, - cfg.discriminator_start + 1, # Hack to use discriminator - lecam_ema_real=lecam_ema_real, - lecam_ema_fake=lecam_ema_fake, - real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None, - ) + # lecam_ema_real, lecam_ema_fake = lecam_ema.get() + # weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn( + # real_logits, + # fake_logits, + # cfg.discriminator_start + 1, # Hack to use discriminator + # lecam_ema_real=lecam_ema_real, + # lecam_ema_fake=lecam_ema_fake, + # real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None, + 
+                # )
 
-                disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
+                # disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
 
                 loss_steps += 1
-                running_disc_loss = disc_loss.item() / loss_steps + running_disc_loss * ((loss_steps - 1) / loss_steps)
+                # running_disc_loss = disc_loss.item() / loss_steps + running_disc_loss * ((loss_steps - 1) / loss_steps)
                 running_loss = vae_loss.item() / loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
                 running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
 
             # ===== Spatial VAE =====
-            if coordinator.is_master():
-                if cfg.get("use_pipeline") == True:
-                    for idx, (sample_original, sample_pipeline, sample_2d) in enumerate(
-                        zip(video, recon_video, recon_2d)
-                    ):
-                        pos = step * cfg.batch_size + idx
-                        save_path = os.path.join(save_dir, f"sample_{pos}")
-                        save_sample(sample_original, fps=cfg.fps, save_path=save_path + "_original")
-                        save_sample(sample_2d, fps=cfg.fps, save_path=save_path + "_2d")
-                        save_sample(sample_pipeline, fps=cfg.fps, save_path=save_path + "_pipeline")
+            if not use_dist or coordinator.is_master():
+                for idx in range(len(x)):
+                    pos = step * cfg.batch_size + idx
+                    save_path = os.path.join(save_dir, f"sample_{pos}")
+                    save_sample(x[idx], fps=cfg.fps, save_path=save_path + "_original")
+                    save_sample(x_rec[idx], fps=cfg.fps, save_path=save_path + "_pipeline")
+                    if cfg.get("vae_2d", None) is not None:
+                        save_sample(x_z_debug[idx], fps=cfg.fps, save_path=save_path + "_2d")
 
-                else:
-                    for idx, (original, recon) in enumerate(zip(video, recon_video)):
-                        pos = step * cfg.batch_size + idx
-                        save_path = os.path.join(save_dir, f"sample_{pos}")
-                        save_sample(original, fps=cfg.fps, save_path=save_path + "_original")
-                        save_sample(recon, fps=cfg.fps, save_path=save_path + "_recon")
+                # if cfg.get("use_pipeline") == True:
+                #     for idx, (sample_original, sample_pipeline, sample_2d) in enumerate(
+                #         zip(video, recon_video, recon_2d)
+                #     ):
+                #         pos = step * cfg.batch_size + idx
+                #         save_path = os.path.join(save_dir, f"sample_{pos}")
+                #         save_sample(sample_original, fps=cfg.fps, save_path=save_path + "_original")
+                #         save_sample(sample_2d, fps=cfg.fps, save_path=save_path + "_2d")
+                #         save_sample(sample_pipeline, fps=cfg.fps, save_path=save_path + "_pipeline")
 
+                # else:
+                #     for idx, (original, recon) in enumerate(zip(video, recon_video)):
+                #         pos = step * cfg.batch_size + idx
+                #         save_path = os.path.join(save_dir, f"sample_{pos}")
+                #         save_sample(original, fps=cfg.fps, save_path=save_path + "_original")
+                #         save_sample(recon, fps=cfg.fps, save_path=save_path + "_recon")
 
     if cfg.calc_loss:
         print("test vae loss:", running_loss)
         print("test nll loss:", running_nll)
-        print("test disc loss:", running_disc_loss)
+        # print("test disc loss:", running_disc_loss)
 
 
 if __name__ == "__main__":
diff --git a/scripts/train-vae.py b/scripts/train-vae.py
index 5200cba..32f870b 100644
--- a/scripts/train-vae.py
+++ b/scripts/train-vae.py
@@ -1,4 +1,5 @@
 import os
+import random
 from datetime import timedelta
 from pprint import pprint
 
@@ -268,11 +269,14 @@
         ) as pbar:
             for step, batch in pbar:
                 x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
+                if random.random() < 0.5:
+                    x = x[:, :, :1, :, :]
 
                 # ===== Spatial VAE =====
                 if cfg.get("vae_2d", None) is not None:
                     with torch.no_grad():
                         x_z = vae_2d.encode(x)
+                        # vae_2d.decode(x_z)  # debug leftover (fed the commented debug_loss below); result unused
 
                 # ====== VAE ======
                 x_z_rec, posterior, z = model(x_z)
@@ -281,7 +285,8 @@
                 # ====== Generator Loss ======
                 # simple nll loss
                 _, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
-                _, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior)
+                # _, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior)
+                # _, debug_loss, _ = vae_loss_fn(x, x_z_debug, posterior)
                 # _, image_identity_loss, _ = vae_loss_fn(x_z, z, posterior)
 
                 # adversarial_loss = torch.tensor(0.0)
@@ -300,7 +305,10 @@
 
                 # vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss + weighted_z_nll_loss
                 # vae_loss = weighted_nll_loss + weighted_kl_loss + weighted_z_nll_loss + image_identity_loss
-                vae_loss = weighted_nll_loss + weighted_kl_loss + weighted_z_nll_loss
+                # vae_loss = weighted_z_nll_loss + image_identity_loss
+                # vae_loss = weighted_nll_loss + weighted_kl_loss + weighted_z_nll_loss
+                # vae_loss = weighted_z_nll_loss
+                vae_loss = weighted_nll_loss + weighted_kl_loss
 
                 optimizer.zero_grad()
                 # Backward & update
@@ -391,8 +399,9 @@
                         # "lecam_loss": lecam_loss.item(),
                         # "r1_grad_penalty": gradient_penalty_loss.item(),
                         "nll_loss": weighted_nll_loss.item(),
-                        "z_nll_loss": weighted_z_nll_loss.item(),
+                        # "z_nll_loss": weighted_z_nll_loss.item(),
                         # "image_identity_loss": image_identity_loss.item(),
+                        # "debug_loss": debug_loss.item(),
                         "avg_loss": avg_loss,
                     },
                     step=global_step,
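
Review note (not part of the patch): a minimal sketch of the replicate-padding arithmetic that the reworked VAE_Temporal.encode/decode pair relies on. It assumes only PyTorch and the temporal_downsample=(False, True, True) setting from the removed VAE_MAGVIT_V2 config; pad_at_dim is copied verbatim from vae_temporal.py.

import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim=-1):
    # same helper as in vae_temporal.py: (left, right) padding at an arbitrary dim
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), mode="replicate")

time_downsample_factor = 2 ** sum((False, True, True))  # = 4

for num_frames in (17, 16, 1):
    x = torch.randn(1, 3, num_frames, 8, 8)  # [B, C, T, H, W]
    # encode(): left-pad T so it becomes divisible by the downsample factor
    pad = (
        0
        if num_frames % time_downsample_factor == 0
        else time_downsample_factor - num_frames % time_downsample_factor
    )
    x_padded = pad_at_dim(x, (pad, 0), dim=2)
    latent_frames = x_padded.shape[2] // time_downsample_factor
    # decode() recomputes the same pad from num_frames and strips it: x[:, :, pad:]
    print(f"T={num_frames}: pad {pad} -> {x_padded.shape[2]} frames -> {latent_frames} latent frames")

This is why decode() now takes num_frames instead of using the fixed self.time_padding: the amount to strip depends on the input length (3 padded frames for T=17 or T=1, none for T=16), which also lets forward() handle the image/video-mixed batches introduced in train-vae.py.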