Add LeCam regularization and gradient-penalty losses to the discriminator

This commit is contained in:
Shen-Chenhui 2024-04-15 11:13:39 +08:00
parent 79bff13099
commit c71e04daaa
3 changed files with 123 additions and 26 deletions

View file

@ -50,7 +50,7 @@ discriminator = dict(
num_frames = num_frames,
in_channels = 3,
filters = 128,
channel_multipliers = (2,4,4,4,4)
channel_multipliers = (2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
)
@ -58,10 +58,16 @@ discriminator = dict(
kl_loss_weight = 0.000001
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
discriminator_factor = 1.0
discriminator_loss_weight = 0.5 # TODO: adjust value
discriminator_loss_weight = 0.5
lecam_loss_weight = 0 # TODO: not clear in MAGVIT what is the weight
discriminator_loss="hinge"
discriminator_start = -1 # 50001 TODO: change to correct val, debug use -1 for now
gradient_penalty_loss_weight = 10 # SCH: following MAGVIT config.vqgan.grad_penalty_cost
ema_decay = 0.999 # ema decay factor for generator
# Others
seed = 42
outputs = "outputs"

View file

@ -65,7 +65,6 @@ def hinge_d_loss(logits_real, logits_fake):
loss_real = torch.mean(F.relu(1. - logits_real))
loss_fake = torch.mean(F.relu(1. + logits_fake))
d_loss = 0.5 * (loss_real + loss_fake)
breakpoint() # TODO: CHECK mean rather than sum
return d_loss
def vanilla_d_loss(logits_real, logits_fake):
@ -74,6 +73,40 @@ def vanilla_d_loss(logits_real, logits_fake):
torch.mean(torch.nn.functional.softplus(logits_fake)))
return d_loss
def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred):
    """LeCam regularization for data-efficient and stable GAN training.

    Described in https://arxiv.org/abs/2104.03310. Penalizes the discriminator
    when its current predictions drift too far from the EMA of its past
    predictions on the opposite class.

    Args:
        real_pred: Prediction (scalar tensor) for the real samples.
        fake_pred: Prediction (scalar tensor) for the fake samples.
        ema_real_pred: EMA prediction (scalar) for the real samples.
        ema_fake_pred: EMA prediction (scalar) for the fake samples.

    Returns:
        LeCam regularization loss (scalar tensor).
    """
    # NOTE: the previous version called `nn.ReLU(x)`, which constructs a ReLU
    # *module* (the argument is interpreted as the `inplace` flag) instead of
    # applying ReLU to `x`; `F.relu` applies the function directly.
    lecam_loss = torch.mean(torch.pow(F.relu(real_pred - ema_fake_pred), 2))
    lecam_loss = lecam_loss + torch.mean(torch.pow(F.relu(ema_real_pred - fake_pred), 2))
    return lecam_loss
def gradient_penalty_fn(images, output):
    """Gradient penalty (WGAN-GP style): mean over batch of (||grad||_2 - 1)^2.

    Args:
        images: Input tensor with ``requires_grad=True``, shape (B, ...).
        output: Discriminator output computed from ``images``.

    Returns:
        Scalar tensor penalty.
    """
    gradients = torch.autograd.grad(
        outputs=output,
        inputs=images,
        # ones_like matches device AND dtype of `output` (the previous
        # torch.ones(...) call only matched the device).
        grad_outputs=torch.ones_like(output),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    # Flatten every non-batch dimension: (B, ...) -> (B, prod(...)).
    # Plain reshape avoids the einops dependency the original used here.
    gradients = gradients.reshape(gradients.shape[0], -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
# Xavier/Glorot-uniform weight init (with ReLU gain) for Conv3d and Linear
# layers; intended to be passed to `module.apply(...)`.
# NOTE(review): only `m.weight` is re-initialized — the bias, if present,
# keeps its default init; confirm that is intended.
def xavier_uniform_weight_init(m):
    if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
@ -806,9 +839,6 @@ class VAE_3D_V2(nn.Module):
def forward(
self,
video,
# optimizer_idx,
# global_step,
# discriminator, # TODO
sample_posterior=True,
video_contains_first_frame = True,
# split = "train",
@ -1003,13 +1033,15 @@ class VEALoss(nn.Module):
# breakpoint()
# total_loss = nll_loss + weighted_gan_loss
log = {
"{}/total_loss".format(split): nll_loss.clone().detach().mean(),
"{}/recon_loss".format(split): recon_loss.detach().mean(),
"{}/weighted_perceptual_loss".format(split): weighted_perceptual_loss.detach().mean(),
"{}/weighted_kl_loss".format(split): weighted_kl_loss.detach().mean(),
}
return nll_loss, log
# log = {
# "{}/total_loss".format(split): nll_loss.clone().detach().mean(),
# "{}/recon_loss".format(split): recon_loss.detach().mean(),
# "{}/weighted_perceptual_loss".format(split): weighted_perceptual_loss.detach().mean(),
# "{}/weighted_kl_loss".format(split): weighted_kl_loss.detach().mean(),
# }
return nll_loss
class AdversarialLoss(nn.Module):
@ -1054,6 +1086,8 @@ class AdversarialLoss(nn.Module):
weighted_gan_loss = d_weight * disc_factor * gan_loss
return weighted_gan_loss
class DiscriminatorLoss(nn.Module):
def __init__(
@ -1061,12 +1095,16 @@ class DiscriminatorLoss(nn.Module):
discriminator_factor = 1.0,
discriminator_start = 50001,
discriminator_loss="hinge",
lecam_loss_weight=0,
gradient_penalty_loss_weight=10, # SCH: following MAGVIT config.vqgan.grad_penalty_cost
):
super().__init__()
assert discriminator_loss in ["hinge", "vanilla"]
self.discriminator_factor = discriminator_factor
self.discriminator_start = discriminator_start
self.lecam_loss_weight = lecam_loss_weight
self.gradient_penalty_loss_weight = gradient_penalty_loss_weight
if discriminator_loss == "hinge":
self.disc_loss_fn = hinge_d_loss
@ -1080,16 +1118,43 @@ class DiscriminatorLoss(nn.Module):
real_logits,
fake_logits,
global_step,
lecam_ema_real = None,
lecam_ema_fake = None,
real_video = None,
split = "train",
):
if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
weight_discriminator_loss = disc_factor * self.disc_loss_fn(real_logits, fake_logits)
weighted_d_adversarial_loss = disc_factor * self.disc_loss_fn(real_logits, fake_logits)
else:
weight_discriminator_loss = 0
weighted_d_adversarial_loss = 0
breakpoint()
lecam_loss = 0.0
if self.lecam_loss_weight is not None and self.lecam_loss_weight > 0.0:
real_pred = np.mean(real_logits.clone().detach())
fake_pred = np.mean(fake_logits.clone().detach())
lecam_loss = lecam_reg(real_pred, fake_pred,
lecam_ema_real,
lecam_ema_fake)
lecam_loss = lecam_loss * self.lecam_loss_weight
gradient_penalty = 0.0
if self.gradient_penalty_loss_weight is not None and self.gradient_penalty_loss_weight > 0.0:
assert real_video is not None
gradient_penalty = gradient_penalty_fn(real_video, real_logits)
gradient_penalty *= self.gradient_penalty_loss_weight
discriminator_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty
return weight_discriminator_loss
# log = {
# "{}/discriminator_loss".format(split): discriminator_loss.clone().detach().mean(),
# "{}/d_adversarial_loss".format(split): weighted_d_adversarial_loss.detach().mean(),
# "{}/lecam_loss".format(split): lecam_loss.detach().mean(),
# "{}/gradient_penalty".format(split): gradient_penalty.detach().mean(),
# }
return discriminator_loss
@MODELS.register_module("VAE_MAGVIT_V2")
def VAE_MAGVIT_V2(from_pretrained=None, **kwargs):

View file

@ -15,6 +15,7 @@ from colossalai.utils import get_current_device
from tqdm import tqdm
import os
from einops import rearrange
import numpy as np
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import (
@ -97,7 +98,6 @@ def main():
# ======================================================
dataset = DatasetFromCSV(
cfg.data_path,
# TODO: change transforms
transform=(
get_transforms_video(cfg.image_size[0])
if not cfg.use_image_transform
@ -108,12 +108,6 @@ def main():
root=cfg.root,
)
# TODO: use plugin's prepare dataloader
# a batch contains:
# {
# "video": torch.Tensor, # [B, C, T, H, W],
# "text": List[str],
# }
dataloader = prepare_dataloader(
dataset,
batch_size=cfg.batch_size,
@ -241,6 +235,9 @@ def main():
disc_time_padding = disc_time_downsample_factor - cfg.num_frames % disc_time_downsample_factor
video_contains_first_frame = cfg.video_contains_first_frame
lecam_ema_real = np.asarray(0)
lecam_ema_fake = np.asarray(0)
for epoch in range(start_epoch, cfg.epochs):
dataloader.sampler.set_epoch(epoch)
dataloader_iter = iter(dataloader)
@ -273,13 +270,24 @@ def main():
# ====== VAE ======
# this is essential for the first iteration after OOM
# optimizer._grad_store.reset_all_gradients()
# optimizer._bucket_store.reset_num_elements_in_bucket()
# optimizer._bucket_store.grad_to_param_mapping = dict()
# optimizer._bucket_store._grad_in_bucket = dict()
# optimizer._bucket_store._param_list = []
# optimizer._bucket_store._padding_size = []
# for rank in range(optimizer._bucket_store._world_size):
# optimizer._bucket_store._grad_in_bucket[rank] = []
# optimizer._bucket_store.offset_list = [0]
# optimizer.zero_grad()
optimizer.zero_grad()
recon_video, posterior = vae(
video,
video_contains_first_frame = video_contains_first_frame,
)
# simple nll loss
nll_loss, nll_loss_log = nll_loss_fn(
nll_loss = nll_loss_fn(
video,
recon_video,
posterior,
@ -313,12 +321,30 @@ def main():
# if video_contains_first_frame:
# Since we don't have enough T frames, pad anyways
real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
real_video = real_video.requires_grad_()
real_logits = discriminator(real_video.contiguous().detach())
fake_logits = discriminator(fake_video.contiguous().detach())
disc_loss = disc_loss_fn(real_logits, fake_logits, global_step)
disc_loss = disc_loss_fn(
real_logits,
fake_logits,
global_step,
lecam_ema_real = lecam_ema_real,
lecam_ema_fake = lecam_ema_fake,
real_video = real_video
)
if cfg.ema_decay is not None:
# SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc.
lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * np.mean(real_logits.clone().detach())
lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * np.mean(fake_logits.clone().detach())
# Backward & update
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
disc_optimizer.step()
# Log loss values:
all_reduce_mean(disc_loss)
running_disc_loss += disc_loss.item()