diff --git a/configs/vae_magvit_v2/train/17x128x128.py b/configs/vae_magvit_v2/train/17x128x128.py
index 56ee25c..90e5940 100644
--- a/configs/vae_magvit_v2/train/17x128x128.py
+++ b/configs/vae_magvit_v2/train/17x128x128.py
@@ -50,7 +50,7 @@ discriminator = dict(
     num_frames = num_frames,
     in_channels = 3,
     filters = 128,
-    channel_multipliers = (2,4,4,4,4)
+    channel_multipliers = (2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
 )
 
 
@@ -58,10 +58,16 @@ discriminator = dict(
 kl_loss_weight = 0.000001
 perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
 discriminator_factor = 1.0
-discriminator_loss_weight = 0.5 # TODO: adjust value
+discriminator_loss_weight = 0.5
+lecam_loss_weight = 0 # TODO: not clear in MAGVIT what is the weight
 discriminator_loss="hinge"
 discriminator_start = -1 # 50001 TODO: change to correct val, debug use -1 for now
+
+gradient_penalty_loss_weight = 10 # SCH: following MAGVIT config.vqgan.grad_penalty_cost
+ema_decay = 0.999 # ema decay factor for generator
+
+
 
 # Others
 seed = 42
 outputs = "outputs"
diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py
index 0712c99..e85dcda 100644
--- a/opensora/models/vae/vae_3d_v2.py
+++ b/opensora/models/vae/vae_3d_v2.py
@@ -65,7 +65,6 @@ def hinge_d_loss(logits_real, logits_fake):
     loss_real = torch.mean(F.relu(1. - logits_real))
     loss_fake = torch.mean(F.relu(1. + logits_fake))
     d_loss = 0.5 * (loss_real + loss_fake)
-    breakpoint() # TODO: CHECK mean rather than sum
     return d_loss
 
 def vanilla_d_loss(logits_real, logits_fake):
@@ -74,6 +73,40 @@ def vanilla_d_loss(logits_real, logits_fake):
         torch.mean(torch.nn.functional.softplus(logits_fake)))
     return d_loss
 
+# TODO: verify if this is correct implementation
+def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred):
+    """Lecam loss for data-efficient and stable GAN training.
+
+    Described in https://arxiv.org/abs/2104.03310
+
+    Args:
+        real_pred: Prediction (scalar) for the real samples.
+        fake_pred: Prediction for the fake samples.
+        ema_real_pred: EMA prediction (scalar) for the real samples.
+        ema_fake_pred: EMA prediction for the fake samples.
+
+    Returns:
+        Lecam regularization loss (scalar).
+    """
+    assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0
+    lecam_loss = torch.mean(torch.pow(F.relu(real_pred - ema_fake_pred), 2))
+    lecam_loss += torch.mean(torch.pow(F.relu(ema_real_pred - fake_pred), 2))
+    return lecam_loss
+
+def gradient_penalty_fn(images, output):
+    # batch_size = images.shape[0]
+    gradients = torch.autograd.grad(
+        outputs = output,
+        inputs = images,
+        grad_outputs = torch.ones(output.size(), device = images.device),
+        create_graph = True,
+        retain_graph = True,
+        only_inputs = True
+    )[0]
+
+    gradients = rearrange(gradients, 'b ... -> b (...)')
+    return ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
+
 def xavier_uniform_weight_init(m):
     if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
         nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
@@ -806,9 +839,6 @@ class VAE_3D_V2(nn.Module):
     def forward(
         self,
         video,
-        # optimizer_idx,
-        # global_step,
-        # discriminator, # TODO
         sample_posterior=True,
         video_contains_first_frame = True,
         # split = "train",
@@ -1003,13 +1033,15 @@ class VEALoss(nn.Module):
         # breakpoint()
         # total_loss = nll_loss + weighted_gan_loss
 
-        log = {
-            "{}/total_loss".format(split): nll_loss.clone().detach().mean(),
-            "{}/recon_loss".format(split): recon_loss.detach().mean(),
-            "{}/weighted_perceptual_loss".format(split): weighted_perceptual_loss.detach().mean(),
-            "{}/weighted_kl_loss".format(split): weighted_kl_loss.detach().mean(),
-        }
-        return nll_loss, log
+
+        # log = {
+        #     "{}/total_loss".format(split): nll_loss.clone().detach().mean(),
+        #     "{}/recon_loss".format(split): recon_loss.detach().mean(),
+        #     "{}/weighted_perceptual_loss".format(split): weighted_perceptual_loss.detach().mean(),
+        #     "{}/weighted_kl_loss".format(split): weighted_kl_loss.detach().mean(),
+        # }
+
+        return nll_loss
 
 
 class AdversarialLoss(nn.Module):
@@ -1054,6 +1086,8 @@ class AdversarialLoss(nn.Module):
         weighted_gan_loss = d_weight * disc_factor * gan_loss
         return weighted_gan_loss
+
+
 
 
 class DiscriminatorLoss(nn.Module):
     def __init__(
@@ -1061,12 +1095,16 @@ class DiscriminatorLoss(nn.Module):
         discriminator_factor = 1.0,
         discriminator_start = 50001,
         discriminator_loss="hinge",
+        lecam_loss_weight=0,
+        gradient_penalty_loss_weight=10, # SCH: following MAGVIT config.vqgan.grad_penalty_cost
     ):
         super().__init__()
         assert discriminator_loss in ["hinge", "vanilla"]
 
         self.discriminator_factor = discriminator_factor
         self.discriminator_start = discriminator_start
+        self.lecam_loss_weight = lecam_loss_weight
+        self.gradient_penalty_loss_weight = gradient_penalty_loss_weight
 
         if discriminator_loss == "hinge":
             self.disc_loss_fn = hinge_d_loss
@@ -1080,16 +1118,43 @@ class DiscriminatorLoss(nn.Module):
         real_logits,
         fake_logits,
         global_step,
+        lecam_ema_real = None,
+        lecam_ema_fake = None,
+        real_video = None,
+        split = "train",
     ):
         if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
             disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
-            weight_discriminator_loss = disc_factor * self.disc_loss_fn(real_logits, fake_logits)
+            weighted_d_adversarial_loss = disc_factor * self.disc_loss_fn(real_logits, fake_logits)
         else:
-            weight_discriminator_loss = 0
+            weighted_d_adversarial_loss = 0
 
-        breakpoint()
+        lecam_loss = 0.0
+        if self.lecam_loss_weight is not None and self.lecam_loss_weight > 0.0:
+            real_pred = torch.mean(real_logits.clone().detach())
+            fake_pred = torch.mean(fake_logits.clone().detach())
+            lecam_loss = lecam_reg(real_pred, fake_pred,
+                                   lecam_ema_real,
+                                   lecam_ema_fake)
+            lecam_loss = lecam_loss * self.lecam_loss_weight
+
+        gradient_penalty = 0.0
+        if self.gradient_penalty_loss_weight is not None and self.gradient_penalty_loss_weight > 0.0:
+            assert real_video is not None
+            gradient_penalty = gradient_penalty_fn(real_video, real_logits)
+            gradient_penalty *= self.gradient_penalty_loss_weight
+
+        discriminator_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty
 
-        return weight_discriminator_loss
+
+        # log = {
+        #     "{}/discriminator_loss".format(split): discriminator_loss.clone().detach().mean(),
+        #     "{}/d_adversarial_loss".format(split): weighted_d_adversarial_loss.detach().mean(),
+        #     "{}/lecam_loss".format(split): lecam_loss.detach().mean(),
+        #     "{}/gradient_penalty".format(split): gradient_penalty.detach().mean(),
+        # }
+
+        return discriminator_loss
 
 
 @MODELS.register_module("VAE_MAGVIT_V2")
 def VAE_MAGVIT_V2(from_pretrained=None, **kwargs):
diff --git a/scripts/train-vae-v2.py b/scripts/train-vae-v2.py
index 11276e3..adeb112 100644
--- a/scripts/train-vae-v2.py
+++ b/scripts/train-vae-v2.py
@@ -15,6 +15,7 @@ from colossalai.utils import get_current_device
 from tqdm import tqdm
 import os
 from einops import rearrange
+import numpy as np
 
 from opensora.acceleration.checkpoint import set_grad_checkpoint
 from opensora.acceleration.parallel_states import (
@@ -97,7 +98,6 @@ def main():
     # ======================================================
     dataset = DatasetFromCSV(
         cfg.data_path,
-        # TODO: change transforms
         transform=(
             get_transforms_video(cfg.image_size[0])
             if not cfg.use_image_transform
@@ -108,12 +108,6 @@ def main():
         root=cfg.root,
     )
 
-    # TODO: use plugin's prepare dataloader
-    # a batch contains:
-    # {
-    #     "video": torch.Tensor, # [B, C, T, H, W],
-    #     "text": List[str],
-    # }
     dataloader = prepare_dataloader(
         dataset,
         batch_size=cfg.batch_size,
@@ -241,6 +235,9 @@ def main():
     disc_time_padding = disc_time_downsample_factor - cfg.num_frames % disc_time_downsample_factor
     video_contains_first_frame = cfg.video_contains_first_frame
 
+    lecam_ema_real = np.asarray(0)
+    lecam_ema_fake = np.asarray(0)
+
     for epoch in range(start_epoch, cfg.epochs):
         dataloader.sampler.set_epoch(epoch)
         dataloader_iter = iter(dataloader)
@@ -273,13 +270,24 @@ def main():
 
             # ====== VAE ======
+            # this is essential for the first iteration after OOM
+            # optimizer._grad_store.reset_all_gradients()
+            # optimizer._bucket_store.reset_num_elements_in_bucket()
+            # optimizer._bucket_store.grad_to_param_mapping = dict()
+            # optimizer._bucket_store._grad_in_bucket = dict()
+            # optimizer._bucket_store._param_list = []
+            # optimizer._bucket_store._padding_size = []
+            # for rank in range(optimizer._bucket_store._world_size):
+            #     optimizer._bucket_store._grad_in_bucket[rank] = []
+            # optimizer._bucket_store.offset_list = [0]
+            # optimizer.zero_grad()
             optimizer.zero_grad()
             recon_video, posterior = vae(
                 video,
                 video_contains_first_frame = video_contains_first_frame,
             )
 
             # simple nll loss
-            nll_loss, nll_loss_log = nll_loss_fn(
+            nll_loss = nll_loss_fn(
                 video,
                 recon_video,
                 posterior,
@@ -313,12 +321,30 @@
             # if video_contains_first_frame:
             # Since we don't have enough T frames, pad anyways
             real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
+            if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
+                real_video = real_video.requires_grad_()
+
-            real_logits = discriminator(real_video.contiguous().detach())
+            real_logits = discriminator(real_video.contiguous())
             fake_logits = discriminator(fake_video.contiguous().detach())
-            disc_loss = disc_loss_fn(real_logits, fake_logits, global_step)
+            disc_loss = disc_loss_fn(
+                real_logits,
+                fake_logits,
+                global_step,
+                lecam_ema_real = lecam_ema_real,
+                lecam_ema_fake = lecam_ema_fake,
+                real_video = real_video
+            )
+
+            if cfg.ema_decay is not None:
+                # SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc.
+                lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * real_logits.clone().detach().mean()
+                lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * fake_logits.clone().detach().mean()
+
+            # Backward & update
             booster.backward(loss=disc_loss, optimizer=disc_optimizer)
             disc_optimizer.step()
 
+            # Log loss values:
             all_reduce_mean(disc_loss)
             running_disc_loss += disc_loss.item()
 