mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
add lecam and gradient penalty loss to discriminator
This commit is contained in:
parent
79bff13099
commit
c71e04daaa
|
|
@ -50,7 +50,7 @@ discriminator = dict(
|
|||
num_frames = num_frames,
|
||||
in_channels = 3,
|
||||
filters = 128,
|
||||
channel_multipliers = (2,4,4,4,4)
|
||||
channel_multipliers = (2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -58,10 +58,16 @@ discriminator = dict(
|
|||
kl_loss_weight = 0.000001
|
||||
perceptual_loss_weight = 0.1 # VGG perceptual loss is used only when this is not None and greater than 0
|
||||
discriminator_factor = 1.0
|
||||
discriminator_loss_weight = 0.5 # TODO: adjust value
|
||||
discriminator_loss_weight = 0.5
|
||||
lecam_loss_weight = 0 # TODO: not clear in MAGVIT what is the weight
|
||||
discriminator_loss="hinge"
|
||||
discriminator_start = -1 # 50001 TODO: change to correct val, debug use -1 for now
|
||||
|
||||
|
||||
gradient_penalty_loss_weight = 10 # SCH: following MAGVIT config.vqgan.grad_penalty_cost
|
||||
ema_decay = 0.999 # ema decay factor for generator
|
||||
|
||||
|
||||
# Others
|
||||
seed = 42
|
||||
outputs = "outputs"
|
||||
|
|
|
|||
|
|
@ -65,7 +65,6 @@ def hinge_d_loss(logits_real, logits_fake):
|
|||
loss_real = torch.mean(F.relu(1. - logits_real))
|
||||
loss_fake = torch.mean(F.relu(1. + logits_fake))
|
||||
d_loss = 0.5 * (loss_real + loss_fake)
|
||||
breakpoint() # TODO: CHECK mean rather than sum
|
||||
return d_loss
|
||||
|
||||
def vanilla_d_loss(logits_real, logits_fake):
|
||||
|
|
@ -74,6 +73,40 @@ def vanilla_d_loss(logits_real, logits_fake):
|
|||
torch.mean(torch.nn.functional.softplus(logits_fake)))
|
||||
return d_loss
|
||||
|
||||
# TODO: verify if this is correct implementation
|
||||
def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred):
    """LeCam regularization for data-efficient and stable GAN training.

    Described in https://arxiv.org/abs/2104.03310

    The loss is
    ``mean(relu(real_pred - ema_fake_pred) ** 2)
    + mean(relu(ema_real_pred - fake_pred) ** 2)``.

    Args:
        real_pred: Prediction (scalar) for the real samples.
        fake_pred: Prediction for the fake samples.
        ema_real_pred: EMA prediction (scalar) for the real samples.
        ema_fake_pred: EMA prediction (scalar) for the fake samples.

        Each argument may be a torch tensor, a NumPy scalar / 0-d array or a
        plain Python number; non-tensor values are converted to tensors with
        the dtype and device of ``real_pred``.

    Returns:
        LeCam regularization loss as a 0-d torch tensor.

    Raises:
        AssertionError: if ``real_pred`` or ``ema_fake_pred`` is not a scalar.
    """
    # Normalise every input to a floating-point tensor so that mixed
    # tensor / NumPy / Python-number arguments can be combined safely.
    real_pred = torch.as_tensor(real_pred)
    if not real_pred.is_floating_point():
        real_pred = real_pred.float()
    like = {"dtype": real_pred.dtype, "device": real_pred.device}
    fake_pred = torch.as_tensor(fake_pred, **like)
    ema_real_pred = torch.as_tensor(ema_real_pred, **like)
    ema_fake_pred = torch.as_tensor(ema_fake_pred, **like)

    assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0

    # F.relu *applies* the rectifier; nn.ReLU(x) would merely construct a
    # module object and never touch the values.
    lecam_loss = torch.mean(F.relu(real_pred - ema_fake_pred).pow(2))
    # Out-of-place add: keeps autograd happy if the inputs carry a graph.
    lecam_loss = lecam_loss + torch.mean(F.relu(ema_real_pred - fake_pred).pow(2))
    return lecam_loss
|
||||
|
||||
def gradient_penalty_fn(images, output):
    """Gradient penalty on the discriminator output w.r.t. its input.

    Differentiates ``output`` with respect to ``images``, flattens the
    gradient of every sample into a vector and penalises the squared
    deviation of its L2 norm from 1, averaged over the batch.

    Args:
        images: Input tensor whose first dimension is the batch. It must
            require grad and ``output`` must have been computed from it
            (a detached copy will make ``torch.autograd.grad`` fail).
        output: Tensor computed from ``images``.

    Returns:
        0-d tensor holding the mean penalty. The graph is kept
        (``create_graph=True``) so the penalty itself can be back-propagated.
    """
    (gradients,) = torch.autograd.grad(
        outputs=output,
        inputs=images,
        # ones_like guarantees matching shape, dtype and device with `output`.
        grad_outputs=torch.ones_like(output),
        create_graph=True,
        retain_graph=True,
    )

    # Collapse everything except the batch dimension: 'b ... -> b (...)'.
    gradients = gradients.reshape(gradients.shape[0], -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
|
||||
|
||||
def xavier_uniform_weight_init(m):
|
||||
if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
|
||||
|
|
@ -806,9 +839,6 @@ class VAE_3D_V2(nn.Module):
|
|||
def forward(
|
||||
self,
|
||||
video,
|
||||
# optimizer_idx,
|
||||
# global_step,
|
||||
# discriminator, # TODO
|
||||
sample_posterior=True,
|
||||
video_contains_first_frame = True,
|
||||
# split = "train",
|
||||
|
|
@ -1003,13 +1033,15 @@ class VEALoss(nn.Module):
|
|||
# breakpoint()
|
||||
# total_loss = nll_loss + weighted_gan_loss
|
||||
|
||||
log = {
|
||||
"{}/total_loss".format(split): nll_loss.clone().detach().mean(),
|
||||
"{}/recon_loss".format(split): recon_loss.detach().mean(),
|
||||
"{}/weighted_perceptual_loss".format(split): weighted_perceptual_loss.detach().mean(),
|
||||
"{}/weighted_kl_loss".format(split): weighted_kl_loss.detach().mean(),
|
||||
}
|
||||
return nll_loss, log
|
||||
|
||||
# log = {
|
||||
# "{}/total_loss".format(split): nll_loss.clone().detach().mean(),
|
||||
# "{}/recon_loss".format(split): recon_loss.detach().mean(),
|
||||
# "{}/weighted_perceptual_loss".format(split): weighted_perceptual_loss.detach().mean(),
|
||||
# "{}/weighted_kl_loss".format(split): weighted_kl_loss.detach().mean(),
|
||||
# }
|
||||
|
||||
return nll_loss
|
||||
|
||||
|
||||
class AdversarialLoss(nn.Module):
|
||||
|
|
@ -1054,6 +1086,8 @@ class AdversarialLoss(nn.Module):
|
|||
weighted_gan_loss = d_weight * disc_factor * gan_loss
|
||||
|
||||
return weighted_gan_loss
|
||||
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Module):
|
||||
def __init__(
|
||||
|
|
@ -1061,12 +1095,16 @@ class DiscriminatorLoss(nn.Module):
|
|||
discriminator_factor = 1.0,
|
||||
discriminator_start = 50001,
|
||||
discriminator_loss="hinge",
|
||||
lecam_loss_weight=0,
|
||||
gradient_penalty_loss_weight=10, # SCH: following MAGVIT config.vqgan.grad_penalty_cost
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert discriminator_loss in ["hinge", "vanilla"]
|
||||
self.discriminator_factor = discriminator_factor
|
||||
self.discriminator_start = discriminator_start
|
||||
self.lecam_loss_weight = lecam_loss_weight
|
||||
self.gradient_penalty_loss_weight = gradient_penalty_loss_weight
|
||||
|
||||
if discriminator_loss == "hinge":
|
||||
self.disc_loss_fn = hinge_d_loss
|
||||
|
|
@ -1080,16 +1118,43 @@ class DiscriminatorLoss(nn.Module):
|
|||
real_logits,
|
||||
fake_logits,
|
||||
global_step,
|
||||
lecam_ema_real = None,
|
||||
lecam_ema_fake = None,
|
||||
real_video = None,
|
||||
split = "train",
|
||||
):
|
||||
if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
|
||||
disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
|
||||
weight_discriminator_loss = disc_factor * self.disc_loss_fn(real_logits, fake_logits)
|
||||
weighted_d_adversarial_loss = disc_factor * self.disc_loss_fn(real_logits, fake_logits)
|
||||
else:
|
||||
weight_discriminator_loss = 0
|
||||
weighted_d_adversarial_loss = 0
|
||||
|
||||
breakpoint()
|
||||
lecam_loss = 0.0
|
||||
if self.lecam_loss_weight is not None and self.lecam_loss_weight > 0.0:
|
||||
real_pred = np.mean(real_logits.clone().detach())
|
||||
fake_pred = np.mean(fake_logits.clone().detach())
|
||||
lecam_loss = lecam_reg(real_pred, fake_pred,
|
||||
lecam_ema_real,
|
||||
lecam_ema_fake)
|
||||
lecam_loss = lecam_loss * self.lecam_loss_weight
|
||||
|
||||
gradient_penalty = 0.0
|
||||
if self.gradient_penalty_loss_weight is not None and self.gradient_penalty_loss_weight > 0.0:
|
||||
assert real_video is not None
|
||||
gradient_penalty = gradient_penalty_fn(real_video, real_logits)
|
||||
gradient_penalty *= self.gradient_penalty_loss_weight
|
||||
|
||||
discriminator_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty
|
||||
|
||||
return weight_discriminator_loss
|
||||
|
||||
# log = {
|
||||
# "{}/discriminator_loss".format(split): discriminator_loss.clone().detach().mean(),
|
||||
# "{}/d_adversarial_loss".format(split): weighted_d_adversarial_loss.detach().mean(),
|
||||
# "{}/lecam_loss".format(split): lecam_loss.detach().mean(),
|
||||
# "{}/gradient_penalty".format(split): gradient_penalty.detach().mean(),
|
||||
# }
|
||||
|
||||
return discriminator_loss
|
||||
|
||||
@MODELS.register_module("VAE_MAGVIT_V2")
|
||||
def VAE_MAGVIT_V2(from_pretrained=None, **kwargs):
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from colossalai.utils import get_current_device
|
|||
from tqdm import tqdm
|
||||
import os
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
|
||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||
from opensora.acceleration.parallel_states import (
|
||||
|
|
@ -97,7 +98,6 @@ def main():
|
|||
# ======================================================
|
||||
dataset = DatasetFromCSV(
|
||||
cfg.data_path,
|
||||
# TODO: change transforms
|
||||
transform=(
|
||||
get_transforms_video(cfg.image_size[0])
|
||||
if not cfg.use_image_transform
|
||||
|
|
@ -108,12 +108,6 @@ def main():
|
|||
root=cfg.root,
|
||||
)
|
||||
|
||||
# TODO: use plugin's prepare dataloader
|
||||
# a batch contains:
|
||||
# {
|
||||
# "video": torch.Tensor, # [B, C, T, H, W],
|
||||
# "text": List[str],
|
||||
# }
|
||||
dataloader = prepare_dataloader(
|
||||
dataset,
|
||||
batch_size=cfg.batch_size,
|
||||
|
|
@ -241,6 +235,9 @@ def main():
|
|||
disc_time_padding = disc_time_downsample_factor - cfg.num_frames % disc_time_downsample_factor
|
||||
video_contains_first_frame = cfg.video_contains_first_frame
|
||||
|
||||
lecam_ema_real = np.asarray(0)
|
||||
lecam_ema_fake = np.asarray(0)
|
||||
|
||||
for epoch in range(start_epoch, cfg.epochs):
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
dataloader_iter = iter(dataloader)
|
||||
|
|
@ -273,13 +270,24 @@ def main():
|
|||
|
||||
|
||||
# ====== VAE ======
|
||||
# this is essential for the first iteration after OOM
|
||||
# optimizer._grad_store.reset_all_gradients()
|
||||
# optimizer._bucket_store.reset_num_elements_in_bucket()
|
||||
# optimizer._bucket_store.grad_to_param_mapping = dict()
|
||||
# optimizer._bucket_store._grad_in_bucket = dict()
|
||||
# optimizer._bucket_store._param_list = []
|
||||
# optimizer._bucket_store._padding_size = []
|
||||
# for rank in range(optimizer._bucket_store._world_size):
|
||||
# optimizer._bucket_store._grad_in_bucket[rank] = []
|
||||
# optimizer._bucket_store.offset_list = [0]
|
||||
# optimizer.zero_grad()
|
||||
optimizer.zero_grad()
|
||||
recon_video, posterior = vae(
|
||||
video,
|
||||
video_contains_first_frame = video_contains_first_frame,
|
||||
)
|
||||
# simple nll loss
|
||||
nll_loss, nll_loss_log = nll_loss_fn(
|
||||
nll_loss = nll_loss_fn(
|
||||
video,
|
||||
recon_video,
|
||||
posterior,
|
||||
|
|
@ -313,12 +321,30 @@ def main():
|
|||
# if video_contains_first_frame:
|
||||
# Since we don't have enough T frames, pad anyways
|
||||
real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
|
||||
real_video = real_video.requires_grad_()
|
||||
|
||||
real_logits = discriminator(real_video.contiguous().detach())
|
||||
fake_logits = discriminator(fake_video.contiguous().detach())
|
||||
disc_loss = disc_loss_fn(real_logits, fake_logits, global_step)
|
||||
disc_loss = disc_loss_fn(
|
||||
real_logits,
|
||||
fake_logits,
|
||||
global_step,
|
||||
lecam_ema_real = lecam_ema_real,
|
||||
lecam_ema_fake = lecam_ema_fake,
|
||||
real_video = real_video
|
||||
)
|
||||
|
||||
if cfg.ema_decay is not None:
|
||||
# SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc.
|
||||
lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * np.mean(real_logits.clone().detach())
|
||||
lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * np.mean(fake_logits.clone().detach())
|
||||
|
||||
|
||||
# Backward & update
|
||||
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
|
||||
disc_optimizer.step()
|
||||
|
||||
# Log loss values:
|
||||
all_reduce_mean(disc_loss)
|
||||
running_disc_loss += disc_loss.item()
|
||||
|
|
|
|||
Loading…
Reference in a new issue