mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-17 14:25:07 +02:00
move out losses to different file
This commit is contained in:
parent
4d69671663
commit
f81d8648cf
273
opensora/models/vae/losses.py
Normal file
273
opensora/models/vae/losses.py
Normal file
|
|
@ -0,0 +1,273 @@
|
|||
import torch.nn as nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from .lpips import LPIPS
|
||||
from einops import rearrange, repeat
|
||||
|
||||
def hinge_d_loss(logits_real, logits_fake):
    """Hinge loss for the discriminator.

    Args:
        logits_real: Discriminator logits for real samples.
        logits_fake: Discriminator logits for generated samples.

    Returns:
        Scalar tensor: half the sum of ``mean(relu(1 - logits_real))`` and
        ``mean(relu(1 + logits_fake))``.
    """
    real_term = F.relu(1.0 - logits_real).mean()
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
|
||||
|
||||
|
||||
def vanilla_d_loss(logits_real, logits_fake):
    """Softplus-based (vanilla GAN) loss for the discriminator.

    Args:
        logits_real: Discriminator logits for real samples.
        logits_fake: Discriminator logits for generated samples.

    Returns:
        Scalar tensor: half the sum of ``mean(softplus(-logits_real))`` and
        ``mean(softplus(logits_fake))``.
    """
    real_term = F.softplus(-logits_real).mean()
    fake_term = F.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)
|
||||
|
||||
|
||||
# from MAGVIT, used in place of hinge_d_loss
|
||||
def sigmoid_cross_entropy_with_logits(labels, logits):
    """Element-wise sigmoid cross-entropy computed from raw logits.

    With ``x = logits`` and ``z = labels`` this evaluates
    ``max(x, 0) - x * z + log(1 + exp(-abs(x)))``. The exponential is only
    ever taken of a non-positive number.

    Args:
        labels: Target tensor, broadcastable against ``logits``.
        logits: Raw (pre-sigmoid) scores.

    Returns:
        Tensor of per-element losses (no reduction is applied).
    """
    zero = torch.zeros_like(logits)
    is_non_negative = logits >= zero
    # max(x, 0)
    positive_part = torch.where(is_non_negative, logits, zero)
    # -abs(x), selected branch-wise rather than via torch.abs
    minus_magnitude = torch.where(is_non_negative, -logits, logits)
    softplus_tail = torch.log1p(torch.exp(minus_magnitude))
    return positive_part - logits * labels + softplus_tail
|
||||
|
||||
def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred):
    """LeCam-style regularizer built from current and EMA discriminator predictions.

    Args:
        real_pred: Current prediction on real data; must be a 0-dim tensor.
        fake_pred: Current prediction on generated data.
        ema_real_pred: Moving-average prediction on real data.
        ema_fake_pred: Moving-average prediction on generated data; must be a
            0-dim tensor.

    Returns:
        ``mean(relu(real_pred - ema_fake_pred) ** 2)`` plus
        ``mean(relu(ema_real_pred - fake_pred) ** 2)``.
    """
    assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0
    penalty = F.relu(real_pred - ema_fake_pred).pow(2).mean()
    # Accumulate in place so the result keeps the dtype of the first term.
    penalty += F.relu(ema_real_pred - fake_pred).pow(2).mean()
    return penalty
|
||||
|
||||
def gradient_penalty_fn(images, output):
    """Penalize per-sample gradient norms of ``output`` w.r.t. ``images`` that differ from 1.

    Args:
        images: Input tensor that ``output`` was computed from; it must take
            part in the autograd graph. The first dimension indexes samples.
        output: Tensor computed from ``images``.

    Returns:
        Scalar tensor: the mean over samples of ``(||grad||_2 - 1) ** 2``,
        where each sample's gradient is flattened to one vector. The result
        stays differentiable (``create_graph=True``).

    Raises:
        RuntimeError: Propagated from ``torch.autograd.grad`` when ``output``
            is not differentiable with respect to ``images``.
    """
    (gradients,) = torch.autograd.grad(
        outputs=output,
        inputs=images,
        # Seed gradient matches ``output`` in shape, dtype and device.
        grad_outputs=torch.ones_like(output),
        create_graph=True,
        retain_graph=True,
    )

    # Flatten everything but the sample dimension: "b ... -> b (...)".
    gradients = gradients.reshape(gradients.shape[0], -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
|
||||
|
||||
class VEALoss(nn.Module):
    """Reconstruction NLL (L1 plus optional LPIPS) and KL loss for the video VAE.

    Args:
        logvar_init: Initial value of the learnable scalar log-variance.
        perceptual_loss_weight: Weight of the LPIPS term added to the L1
            reconstruction error; ``None`` or a non-positive value disables it.
        kl_loss_weight: Weight of the KL term; ``None`` or a non-positive
            value disables it.
        device: Device the LPIPS network is moved to.
        dtype: ``torch.dtype`` for the LPIPS network, or one of the string
            aliases ``"bf16"``, ``"fp16"``, ``"fp32"``.

    Raises:
        NotImplementedError: If ``dtype`` is a string that is not a known alias.
    """

    # String aliases accepted for the ``dtype`` constructor argument.
    _DTYPE_ALIASES = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }

    def __init__(
        self,
        logvar_init=0.0,
        perceptual_loss_weight=0.1,
        kl_loss_weight=0.000001,
        device="cpu",
        dtype="bf16",
    ):
        super().__init__()

        if isinstance(dtype, str):
            try:
                dtype = self._DTYPE_ALIASES[dtype]
            except KeyError:
                raise NotImplementedError(f"dtype: {dtype}") from None

        # KL Loss
        self.kl_loss_weight = kl_loss_weight
        # Perceptual Loss
        self.perceptual_loss_fn = LPIPS().eval().to(device, dtype)
        self.perceptual_loss_weight = perceptual_loss_weight
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

    def forward(
        self,
        video,
        recon_video,
        posterior,
        nll_weights=None,
        split="train",
    ):
        """Compute the NLL and KL terms for one batch.

        Args:
            video: Ground-truth video laid out as ``(b, c, t, h, w)``.
            recon_video: Reconstructed video with the same layout.
            posterior: Object exposing ``kl()``, which returns a tensor of KL
                values.
            nll_weights: Optional multiplicative weights applied to the
                element-wise NLL before reduction.
            split: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(nll_loss, weighted_nll_loss, weighted_kl_loss)``. Each NLL
            value is the sum over all elements divided by the size of the first
            dimension after frames are folded into the batch. The KL value is
            the integer ``0`` when the KL term is disabled.
        """
        # Fold time into the batch so every frame is treated as an image.
        video = rearrange(video, "b c t h w -> (b t) c h w").contiguous()
        recon_video = rearrange(recon_video, "b c t h w -> (b t) c h w").contiguous()

        # reconstruction loss
        recon_loss = torch.abs(video - recon_video)

        # perceptual loss
        if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
            # The perceptual network is fed 3-channel inputs.
            channels = video.shape[1]
            assert channels in {1, 3, 4}
            if channels == 1:
                input_vgg_input = repeat(video, "b 1 h w -> b c h w", c=3)
                recon_vgg_input = repeat(recon_video, "b 1 h w -> b c h w", c=3)
            elif channels == 4:  # SCH: take the first 3 for perceptual loss calc
                input_vgg_input = video[:, :3]
                recon_vgg_input = recon_video[:, :3]
            else:
                input_vgg_input = video
                recon_vgg_input = recon_video

            perceptual_loss = self.perceptual_loss_fn(input_vgg_input, recon_vgg_input)
            # NOTE(review): assumes the LPIPS output broadcasts against the
            # per-element L1 map - confirm against the LPIPS implementation.
            recon_loss = recon_loss + self.perceptual_loss_weight * perceptual_loss

        nll_loss = recon_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if nll_weights is not None:
            weighted_nll_loss = nll_weights * nll_loss
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]

        # KL Loss
        weighted_kl_loss = 0
        if self.kl_loss_weight is not None and self.kl_loss_weight > 0.0:
            kl_loss = posterior.kl()
            kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
            weighted_kl_loss = kl_loss * self.kl_loss_weight

        return nll_loss, weighted_nll_loss, weighted_kl_loss
|
||||
|
||||
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Return ``value`` while ``global_step`` is below ``threshold``, otherwise ``weight``."""
    return value if global_step < threshold else weight
|
||||
|
||||
class AdversarialLoss(nn.Module):
    """Generator-side adversarial loss scaled by an adaptive, gradient-based weight.

    Args:
        discriminator_factor: Base factor applied to the generator loss;
            ``None`` or a non-positive value disables the adaptive weight.
        discriminator_start: Step before which the factor is replaced by 0
            (via ``adopt_weight``).
        generator_factor: Multiplier applied to the adaptive weight.
        generator_loss_type: ``"hinge"`` or ``"non-saturating"``.
            ``"vanilla"`` passes the initial assertion in ``forward`` but has
            no generator formulation there and raises ``ValueError``.
    """

    def __init__(
        self,
        discriminator_factor=1.0,
        discriminator_start=50001,
        generator_factor=0.5,
        generator_loss_type="non-saturating",
    ):
        super().__init__()
        self.discriminator_factor = discriminator_factor
        self.discriminator_start = discriminator_start
        self.generator_factor = generator_factor
        self.generator_loss_type = generator_loss_type

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
        """Balance the two losses by their gradient norms at ``last_layer``.

        Returns:
            ``||grad(nll_loss)|| / (||grad(g_loss)|| + 1e-4)``, clamped to
            ``[0, 1e4]``, detached, and multiplied by ``generator_factor``.

        Raises:
            RuntimeError: Propagated from ``torch.autograd.grad`` when either
                loss is not differentiable with respect to ``last_layer``.
        """
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight * self.generator_factor

    def forward(
        self,
        fake_logits,
        nll_loss,
        last_layer,
        global_step,
        is_training=True,
    ):
        """Compute the weighted generator adversarial loss.

        Args:
            fake_logits: Discriminator logits for generated samples.
            nll_loss: Reconstruction NLL used to derive the adaptive weight.
            last_layer: Tensor with respect to which both gradients are taken.
            global_step: Current training step.
            is_training: When ``False``, a failed gradient computation falls
                back to a zero adaptive weight instead of raising.

        Returns:
            ``d_weight * disc_factor * gen_loss``.

        Raises:
            ValueError: If the generator loss type has no implementation.
            AssertionError: If the adaptive weight cannot be computed while
                ``is_training`` is true; the original ``RuntimeError`` is
                attached as ``__cause__``.
        """
        # NOTE: following MAGVIT to allow non_saturating
        assert self.generator_loss_type in ["hinge", "vanilla", "non-saturating"]

        if self.generator_loss_type == "hinge":
            gen_loss = -torch.mean(fake_logits)
        elif self.generator_loss_type == "non-saturating":
            gen_loss = torch.mean(
                sigmoid_cross_entropy_with_logits(labels=torch.ones_like(fake_logits), logits=fake_logits)
            )
        else:
            raise ValueError("Generator loss {} not supported".format(self.generator_loss_type))

        if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, gen_loss, last_layer)
            except RuntimeError as err:
                # Raised explicitly rather than via ``assert`` so the check
                # survives ``python -O``; AssertionError is kept as the type
                # for backward compatibility and the real error is chained.
                if is_training:
                    raise AssertionError(
                        "adaptive weight could not be computed during training"
                    ) from err
                d_weight = torch.tensor(0.0)
        else:
            d_weight = torch.tensor(0.0)

        disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
        weighted_gen_loss = d_weight * disc_factor * gen_loss

        return weighted_gen_loss
|
||||
|
||||
|
||||
class LeCamEMA:
    """Exponential moving averages of the discriminator's real and fake predictions.

    Args:
        ema_real: Initial value of the real-prediction average.
        ema_fake: Initial value of the fake-prediction average.
        decay: Fraction of the previous average kept on every update.
        dtype: dtype the stored averages are converted to.
        device: Device the stored averages are moved to.
    """

    def __init__(self, ema_real=0.0, ema_fake=0.0, decay=0.999, dtype=torch.bfloat16, device="cpu"):
        self.decay = decay
        initial_real = torch.tensor(ema_real)
        initial_fake = torch.tensor(ema_fake)
        self.ema_real = initial_real.to(device, dtype)
        self.ema_fake = initial_fake.to(device, dtype)

    def update(self, ema_real, ema_fake):
        """Blend new observations into both running averages."""
        blended_real = self.ema_real * self.decay + ema_real * (1 - self.decay)
        blended_fake = self.ema_fake * self.decay + ema_fake * (1 - self.decay)
        self.ema_real = blended_real
        self.ema_fake = blended_fake

    def get(self):
        """Return the current ``(ema_real, ema_fake)`` pair."""
        return self.ema_real, self.ema_fake
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Module):
    """Discriminator loss with optional LeCam regularization and gradient penalty.

    Args:
        discriminator_factor: Base weight of the adversarial term; ``None`` or
            a non-positive value disables it.
        discriminator_start: Step before which the factor is replaced by 0
            (via ``adopt_weight``).
        discriminator_loss_type: ``"hinge"``, ``"vanilla"`` or
            ``"non-saturating"``.
        lecam_loss_weight: Weight of the LeCam term; ``None`` or a
            non-positive value disables it.
        gradient_penalty_loss_weight: Weight of the gradient penalty; ``None``
            or a non-positive value disables it.
    """

    def __init__(
        self,
        discriminator_factor=1.0,
        discriminator_start=50001,
        discriminator_loss_type="non-saturating",
        lecam_loss_weight=None,
        gradient_penalty_loss_weight=None,  # SCH: following MAGVIT config.vqgan.grad_penalty_cost
    ):
        super().__init__()

        assert discriminator_loss_type in ["hinge", "vanilla", "non-saturating"]
        self.discriminator_factor = discriminator_factor
        self.discriminator_start = discriminator_start
        self.lecam_loss_weight = lecam_loss_weight
        self.gradient_penalty_loss_weight = gradient_penalty_loss_weight
        self.discriminator_loss_type = discriminator_loss_type

    @staticmethod
    def _mean_sigmoid_ce(logits, is_real):
        """Mean sigmoid cross-entropy of ``logits`` against all-ones or all-zeros labels.

        Returns the float ``0.0`` when ``logits`` is ``None`` so that a missing
        side contributes nothing to the loss.
        """
        if logits is None:
            return 0.0
        labels = torch.ones_like(logits) if is_real else torch.zeros_like(logits)
        return torch.mean(sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))

    def forward(
        self,
        real_logits,
        fake_logits,
        global_step,
        lecam_ema_real=None,
        lecam_ema_fake=None,
        real_video=None,
        split="train",
    ):
        """Compute the discriminator loss terms.

        Args:
            real_logits: Discriminator logits for real samples. May be ``None``
                only with the ``"non-saturating"`` loss and LeCam and gradient
                penalty disabled.
            fake_logits: Discriminator logits for generated samples; same
                ``None`` rule as ``real_logits``.
            global_step: Current training step.
            lecam_ema_real: Moving-average real prediction; needed when the
                LeCam term is enabled.
            lecam_ema_fake: Moving-average fake prediction; needed when the
                LeCam term is enabled.
            real_video: Real input the logits were computed from; required when
                the gradient penalty is enabled.
            split: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(weighted_d_adversarial_loss, lecam_loss, gradient_penalty)``.
            The first element is the integer ``0`` when the adversarial term is
            disabled.

        Raises:
            ValueError: If the configured loss type is unknown.
            AssertionError: If the gradient penalty is enabled but
                ``real_video`` is ``None``.
        """
        if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
            disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)

            if self.discriminator_loss_type == "hinge":
                disc_loss = hinge_d_loss(real_logits, fake_logits)
            elif self.discriminator_loss_type == "non-saturating":
                # Means are taken per side so a missing (None) side becomes a
                # plain 0.0 instead of being passed to torch.mean.
                real_loss = self._mean_sigmoid_ce(real_logits, is_real=True)
                fake_loss = self._mean_sigmoid_ce(fake_logits, is_real=False)
                disc_loss = 0.5 * (real_loss + fake_loss)
            elif self.discriminator_loss_type == "vanilla":
                disc_loss = vanilla_d_loss(real_logits, fake_logits)
            else:
                raise ValueError(f"Unknown GAN loss '{self.discriminator_loss_type}'.")

            weighted_d_adversarial_loss = disc_factor * disc_loss

        else:
            weighted_d_adversarial_loss = 0

        lecam_loss = torch.tensor(0.0)
        if self.lecam_loss_weight is not None and self.lecam_loss_weight > 0.0:
            real_pred = torch.mean(real_logits)
            fake_pred = torch.mean(fake_logits)
            lecam_loss = lecam_reg(real_pred, fake_pred, lecam_ema_real, lecam_ema_fake)
            lecam_loss = lecam_loss * self.lecam_loss_weight

        gradient_penalty = torch.tensor(0.0)
        if self.gradient_penalty_loss_weight is not None and self.gradient_penalty_loss_weight > 0.0:
            assert real_video is not None
            gradient_penalty = gradient_penalty_fn(real_video, real_logits)
            gradient_penalty = gradient_penalty * self.gradient_penalty_loss_weight

        return (weighted_d_adversarial_loss, lecam_loss, gradient_penalty)
|
||||
|
|
@ -6,10 +6,9 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import pack, rearrange, repeat, unpack
|
||||
from einops import pack, rearrange, unpack
|
||||
|
||||
from .utils import DiagonalGaussianDistribution
|
||||
from .lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
|
||||
from opensora.registry import MODELS
|
||||
from opensora.utils.ckpt_utils import find_model, load_checkpoint
|
||||
|
||||
|
|
@ -62,115 +61,6 @@ def exists(v):
|
|||
return v is not None
|
||||
|
||||
|
||||
# ============== Generator Adversarial Loss Functions ==============
|
||||
def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred):
|
||||
assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0
|
||||
lecam_loss = torch.mean(torch.pow(nn.ReLU()(real_pred - ema_fake_pred), 2))
|
||||
lecam_loss += torch.mean(torch.pow(nn.ReLU()(ema_real_pred - fake_pred), 2))
|
||||
return lecam_loss
|
||||
|
||||
|
||||
# Open-Sora-Plan
|
||||
# Very bad, do not use
|
||||
def r1_penalty(real_img, real_pred):
|
||||
"""R1 regularization for discriminator. The core idea is to
|
||||
penalize the gradient on real data alone: when the
|
||||
generator distribution produces the true data distribution
|
||||
and the discriminator is equal to 0 on the data manifold, the
|
||||
gradient penalty ensures that the discriminator cannot create
|
||||
a non-zero gradient orthogonal to the data manifold without
|
||||
suffering a loss in the GAN game.
|
||||
|
||||
Ref:
|
||||
Eq. 9 in Which training methods for GANs do actually converge.
|
||||
"""
|
||||
grad_real = torch.autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
|
||||
grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
|
||||
return grad_penalty
|
||||
|
||||
|
||||
# Open-Sora-Plan
|
||||
# Implementation as described by https://arxiv.org/abs/1704.00028 # TODO: checkout the codes
|
||||
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
|
||||
"""Calculate gradient penalty for wgan-gp.
|
||||
|
||||
Args:
|
||||
discriminator (nn.Module): Network for the discriminator.
|
||||
real_data (Tensor): Real input data.
|
||||
fake_data (Tensor): Fake input data.
|
||||
weight (Tensor): Weight tensor. Default: None.
|
||||
|
||||
Returns:
|
||||
Tensor: A tensor for gradient penalty.
|
||||
"""
|
||||
|
||||
batch_size = real_data.size(0)
|
||||
alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
|
||||
|
||||
# interpolate between real_data and fake_data
|
||||
interpolates = alpha * real_data + (1.0 - alpha) * fake_data
|
||||
interpolates = torch.autograd.Variable(interpolates, requires_grad=True)
|
||||
|
||||
disc_interpolates = discriminator(interpolates)
|
||||
gradients = torch.autograd.grad(
|
||||
outputs=disc_interpolates,
|
||||
inputs=interpolates,
|
||||
grad_outputs=torch.ones_like(disc_interpolates),
|
||||
create_graph=True,
|
||||
retain_graph=True,
|
||||
only_inputs=True,
|
||||
)[0]
|
||||
|
||||
if weight is not None:
|
||||
gradients = gradients * weight
|
||||
|
||||
gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
|
||||
if weight is not None:
|
||||
gradients_penalty /= torch.mean(weight)
|
||||
|
||||
return gradients_penalty
|
||||
|
||||
|
||||
def gradient_penalty_fn(images, output):
|
||||
# batch_size = images.shape[0]
|
||||
gradients = torch.autograd.grad(
|
||||
outputs=output,
|
||||
inputs=images,
|
||||
grad_outputs=torch.ones(output.size(), device=images.device),
|
||||
create_graph=True,
|
||||
retain_graph=True,
|
||||
only_inputs=True,
|
||||
)[0]
|
||||
|
||||
gradients = rearrange(gradients, "b ... -> b (...)")
|
||||
return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
|
||||
|
||||
|
||||
# ============== Discriminator Adversarial Loss Functions ==============
|
||||
def hinge_d_loss(logits_real, logits_fake):
|
||||
loss_real = torch.mean(F.relu(1.0 - logits_real))
|
||||
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
|
||||
d_loss = 0.5 * (loss_real + loss_fake)
|
||||
return d_loss
|
||||
|
||||
|
||||
def vanilla_d_loss(logits_real, logits_fake):
|
||||
d_loss = 0.5 * (
|
||||
torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake))
|
||||
)
|
||||
return d_loss
|
||||
|
||||
|
||||
# from MAGVIT, used in place hof hinge_d_loss
|
||||
def sigmoid_cross_entropy_with_logits(labels, logits):
|
||||
# The final formulation is: max(x, 0) - x * z + log(1 + exp(-abs(x)))
|
||||
zeros = torch.zeros_like(logits, dtype=logits.dtype)
|
||||
condition = logits >= zeros
|
||||
relu_logits = torch.where(condition, logits, zeros)
|
||||
neg_abs_logits = torch.where(condition, -logits, logits)
|
||||
return relu_logits - logits * labels + torch.log1p(torch.exp(neg_abs_logits))
|
||||
|
||||
|
||||
def xavier_uniform_weight_init(m):
|
||||
if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain("relu"))
|
||||
|
|
@ -194,10 +84,7 @@ def SameConv2d(dim_in, dim_out, kernel_size):
|
|||
return nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, padding=padding)
|
||||
|
||||
|
||||
def adopt_weight(weight, global_step, threshold=0, value=0.0):
|
||||
if global_step < threshold:
|
||||
weight = value
|
||||
return weight
|
||||
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
|
|
@ -1053,238 +940,7 @@ class VAE_3D_V2(nn.Module): # , ModelMixin
|
|||
return recon_video, posterior
|
||||
|
||||
|
||||
class VEALoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
logvar_init=0.0,
|
||||
perceptual_loss_weight=0.1,
|
||||
kl_loss_weight=0.000001,
|
||||
# vgg=None,
|
||||
# vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
|
||||
device="cpu",
|
||||
dtype="bf16",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if type(dtype) == str:
|
||||
if dtype == "bf16":
|
||||
dtype = torch.bfloat16
|
||||
elif dtype == "fp16":
|
||||
dtype = torch.float16
|
||||
else:
|
||||
raise NotImplementedError(f"dtype: {dtype}")
|
||||
|
||||
# KL Loss
|
||||
self.kl_loss_weight = kl_loss_weight
|
||||
# Perceptual Loss
|
||||
self.perceptual_loss_fn = LPIPS().eval().to(device, dtype)
|
||||
self.perceptual_loss_weight = perceptual_loss_weight
|
||||
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video,
|
||||
recon_video,
|
||||
posterior,
|
||||
nll_weights=None,
|
||||
split="train",
|
||||
):
|
||||
video = rearrange(video, "b c t h w -> (b t) c h w").contiguous()
|
||||
recon_video = rearrange(recon_video, "b c t h w -> (b t) c h w").contiguous()
|
||||
|
||||
# reconstruction loss
|
||||
recon_loss = torch.abs(video - recon_video)
|
||||
|
||||
# perceptual loss
|
||||
if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
|
||||
# handle channels
|
||||
channels = video.shape[1]
|
||||
assert channels in {1, 3, 4}
|
||||
if channels == 1:
|
||||
input_vgg_input = repeat(video, "b 1 h w -> b c h w", c=3)
|
||||
recon_vgg_input = repeat(recon_video, "b 1 h w -> b c h w", c=3)
|
||||
elif channels == 4: # SCH: take the first 3 for perceptual loss calc
|
||||
input_vgg_input = video[:, :3]
|
||||
recon_vgg_input = recon_video[:, :3]
|
||||
else:
|
||||
input_vgg_input = video
|
||||
recon_vgg_input = recon_video
|
||||
|
||||
perceptual_loss = self.perceptual_loss_fn(input_vgg_input, recon_vgg_input)
|
||||
recon_loss = recon_loss + self.perceptual_loss_weight * perceptual_loss
|
||||
|
||||
nll_loss = recon_loss / torch.exp(self.logvar) + self.logvar
|
||||
weighted_nll_loss = nll_loss
|
||||
if nll_weights is not None:
|
||||
weighted_nll_loss = nll_weights * nll_loss
|
||||
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
||||
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
||||
|
||||
|
||||
# KL Loss
|
||||
weighted_kl_loss = 0
|
||||
if self.kl_loss_weight is not None and self.kl_loss_weight > 0.0:
|
||||
kl_loss = posterior.kl()
|
||||
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
||||
weighted_kl_loss = kl_loss * self.kl_loss_weight
|
||||
|
||||
return nll_loss, weighted_nll_loss, weighted_kl_loss
|
||||
|
||||
|
||||
class AdversarialLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
discriminator_factor=1.0,
|
||||
discriminator_start=50001,
|
||||
generator_factor=0.5,
|
||||
generator_loss_type="non-saturating",
|
||||
):
|
||||
super().__init__()
|
||||
self.discriminator_factor = discriminator_factor
|
||||
self.discriminator_start = discriminator_start
|
||||
self.generator_factor = generator_factor
|
||||
self.generator_loss_type = generator_loss_type
|
||||
|
||||
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
|
||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] # SCH: TODO: debug added creat
|
||||
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[
|
||||
0
|
||||
] # SCH: TODO: debug added create_graph=True
|
||||
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
||||
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
||||
d_weight = d_weight * self.generator_factor
|
||||
return d_weight
|
||||
|
||||
def forward(
|
||||
self,
|
||||
fake_logits,
|
||||
nll_loss,
|
||||
last_layer,
|
||||
global_step,
|
||||
is_training=True,
|
||||
):
|
||||
# NOTE: following MAGVIT to allow non_saturating
|
||||
assert self.generator_loss_type in ["hinge", "vanilla", "non-saturating"]
|
||||
|
||||
if self.generator_loss_type == "hinge":
|
||||
gen_loss = -torch.mean(fake_logits)
|
||||
elif self.generator_loss_type == "non-saturating":
|
||||
gen_loss = torch.mean(
|
||||
sigmoid_cross_entropy_with_logits(labels=torch.ones_like(fake_logits), logits=fake_logits)
|
||||
)
|
||||
else:
|
||||
raise ValueError("Generator loss {} not supported".format(self.generator_loss_type))
|
||||
|
||||
if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
|
||||
try:
|
||||
d_weight = self.calculate_adaptive_weight(nll_loss, gen_loss, last_layer)
|
||||
except RuntimeError:
|
||||
assert not is_training
|
||||
d_weight = torch.tensor(0.0)
|
||||
else:
|
||||
d_weight = torch.tensor(0.0)
|
||||
|
||||
disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
|
||||
weighted_gen_loss = d_weight * disc_factor * gen_loss
|
||||
|
||||
return weighted_gen_loss
|
||||
|
||||
|
||||
class LeCamEMA:
|
||||
def __init__(self, ema_real=0.0, ema_fake=0.0, decay=0.999, dtype=torch.bfloat16, device="cpu"):
|
||||
self.decay = decay
|
||||
self.ema_real = torch.tensor(ema_real).to(device, dtype)
|
||||
self.ema_fake = torch.tensor(ema_fake).to(device, dtype)
|
||||
|
||||
def update(self, ema_real, ema_fake):
|
||||
self.ema_real = self.ema_real * self.decay + ema_real * (1 - self.decay)
|
||||
self.ema_fake = self.ema_fake * self.decay + ema_fake * (1 - self.decay)
|
||||
|
||||
def get(self):
|
||||
return self.ema_real, self.ema_fake
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
discriminator_factor=1.0,
|
||||
discriminator_start=50001,
|
||||
discriminator_loss_type="non-saturating",
|
||||
lecam_loss_weight=None,
|
||||
gradient_penalty_loss_weight=None, # SCH: following MAGVIT config.vqgan.grad_penalty_cost
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert discriminator_loss_type in ["hinge", "vanilla", "non-saturating"]
|
||||
self.discriminator_factor = discriminator_factor
|
||||
self.discriminator_start = discriminator_start
|
||||
self.lecam_loss_weight = lecam_loss_weight
|
||||
self.gradient_penalty_loss_weight = gradient_penalty_loss_weight
|
||||
self.discriminator_loss_type = discriminator_loss_type
|
||||
|
||||
def forward(
|
||||
self,
|
||||
real_logits,
|
||||
fake_logits,
|
||||
global_step,
|
||||
lecam_ema_real=None,
|
||||
lecam_ema_fake=None,
|
||||
real_video=None,
|
||||
split="train",
|
||||
):
|
||||
if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
|
||||
disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
|
||||
|
||||
if self.discriminator_loss_type == "hinge":
|
||||
disc_loss = hinge_d_loss(real_logits, fake_logits)
|
||||
elif self.discriminator_loss_type == "non-saturating":
|
||||
if real_logits is not None:
|
||||
real_loss = sigmoid_cross_entropy_with_logits(
|
||||
labels=torch.ones_like(real_logits), logits=real_logits
|
||||
)
|
||||
else:
|
||||
real_loss = 0.0
|
||||
if fake_logits is not None:
|
||||
fake_loss = sigmoid_cross_entropy_with_logits(
|
||||
labels=torch.zeros_like(fake_logits), logits=fake_logits
|
||||
)
|
||||
else:
|
||||
fake_loss = 0.0
|
||||
disc_loss = 0.5 * (torch.mean(real_loss) + torch.mean(fake_loss))
|
||||
elif self.discriminator_loss_type == "vanilla":
|
||||
disc_loss = vanilla_d_loss(real_logits, fake_logits)
|
||||
else:
|
||||
raise ValueError(f"Unknown GAN loss '{self.discriminator_loss_type}'.")
|
||||
|
||||
weighted_d_adversarial_loss = disc_factor * disc_loss
|
||||
|
||||
else:
|
||||
weighted_d_adversarial_loss = 0
|
||||
|
||||
lecam_loss = torch.tensor(0.0)
|
||||
if self.lecam_loss_weight is not None and self.lecam_loss_weight > 0.0:
|
||||
real_pred = torch.mean(real_logits)
|
||||
fake_pred = torch.mean(fake_logits)
|
||||
lecam_loss = lecam_reg(real_pred, fake_pred, lecam_ema_real, lecam_ema_fake)
|
||||
lecam_loss = lecam_loss * self.lecam_loss_weight
|
||||
|
||||
gradient_penalty = torch.tensor(0.0)
|
||||
if self.gradient_penalty_loss_weight is not None and self.gradient_penalty_loss_weight > 0.0:
|
||||
assert real_video is not None
|
||||
gradient_penalty = gradient_penalty_fn(real_video, real_logits)
|
||||
# gradient_penalty = r1_penalty(real_video, real_logits) # MAGVIT uses r1 penalty
|
||||
gradient_penalty *= self.gradient_penalty_loss_weight
|
||||
|
||||
# discriminator_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty
|
||||
|
||||
# log = {
|
||||
# "{}/d_adversarial_loss".format(split): weighted_d_adversarial_loss.detach().mean(),
|
||||
# "{}/lecam_loss".format(split): lecam_loss.detach().mean(),
|
||||
# "{}/gradient_penalty".format(split): gradient_penalty.detach().mean(),
|
||||
# }
|
||||
|
||||
return (weighted_d_adversarial_loss, lecam_loss, gradient_penalty)
|
||||
|
||||
|
||||
@MODELS.register_module("VAE_MAGVIT_V2")
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ from tqdm import tqdm
|
|||
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets import prepare_dataloader, save_sample
|
||||
from opensora.models.vae.vae_3d_v2 import AdversarialLoss, DiscriminatorLoss, LeCamEMA, VEALoss, pad_at_dim
|
||||
from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VEALoss
|
||||
from opensora.models.vae.vae_3d import LeCamEMA, pad_at_dim
|
||||
from opensora.registry import DATASETS, MODELS, build_module
|
||||
from opensora.utils.config_utils import parse_configs
|
||||
from opensora.utils.misc import to_torch_dtype
|
||||
|
|
@ -27,12 +28,6 @@ def main():
|
|||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# if coordinator.world_size > 1:
|
||||
# set_sequence_parallel_group(dist.group.WORLD)
|
||||
# enable_sequence_parallelism = True
|
||||
# else:
|
||||
# enable_sequence_parallelism = False
|
||||
|
||||
# ======================================================
|
||||
# 2. runtime variables
|
||||
# ======================================================
|
||||
|
|
|
|||
|
|
@ -21,7 +21,8 @@ from opensora.acceleration.parallel_states import (
|
|||
)
|
||||
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
|
||||
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
|
||||
from opensora.models.vae.vae_3d_v2 import AdversarialLoss, DiscriminatorLoss, LeCamEMA, VEALoss, pad_at_dim
|
||||
from opensora.models.vae.vae_3d import LeCamEMA, pad_at_dim
|
||||
from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VEALoss
|
||||
from opensora.registry import DATASETS, MODELS, build_module
|
||||
from opensora.utils.ckpt_utils import create_logger, load_json, save_json
|
||||
from opensora.utils.config_utils import (
|
||||
|
|
|
|||
Loading…
Reference in a new issue