Move losses out into a separate file

This commit is contained in:
Shen-Chenhui 2024-04-29 17:53:03 +08:00
parent 4d69671663
commit f81d8648cf
4 changed files with 279 additions and 354 deletions

View file

@ -0,0 +1,273 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
from .lpips import LPIPS
from einops import rearrange, repeat
def hinge_d_loss(logits_real, logits_fake):
    """Hinge loss for the discriminator.

    Real logits are penalised when they fall below +1 and fake logits when
    they rise above -1; the two mean penalties are averaged.

    Args:
        logits_real: Discriminator outputs for real samples.
        logits_fake: Discriminator outputs for generated samples.

    Returns:
        Scalar tensor ``0.5 * (mean(relu(1 - real)) + mean(relu(1 + fake)))``.
    """
    real_term = F.relu(1.0 - logits_real).mean()
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
def vanilla_d_loss(logits_real, logits_fake):
    """Softplus-based (vanilla GAN) loss for the discriminator.

    Args:
        logits_real: Discriminator outputs for real samples.
        logits_fake: Discriminator outputs for generated samples.

    Returns:
        Scalar tensor ``0.5 * (mean(softplus(-real)) + mean(softplus(fake)))``.
    """
    real_term = F.softplus(-logits_real).mean()
    fake_term = F.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)
# from MAGVIT, used in place of hinge_d_loss
def sigmoid_cross_entropy_with_logits(labels, logits):
    """Element-wise sigmoid cross-entropy, written in a numerically stable form.

    With ``x = logits`` and ``z = labels`` the result is
    ``max(x, 0) - x * z + log(1 + exp(-|x|))``, so ``exp`` is only ever
    evaluated on non-positive values.

    Args:
        labels: Target tensor, broadcastable against ``logits``.
        logits: Raw (pre-sigmoid) scores.

    Returns:
        Tensor of per-element losses (no reduction is applied).
    """
    zero = torch.zeros_like(logits, dtype=logits.dtype)
    is_non_negative = logits >= zero
    positive_part = torch.where(is_non_negative, logits, zero)  # max(x, 0)
    minus_abs = torch.where(is_non_negative, -logits, logits)  # -|x|
    log_term = torch.log1p(torch.exp(minus_abs))
    return positive_part - logits * labels + log_term
def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred):
    """Regularisation term tying current predictions to running averages.

    Computes ``relu(real_pred - ema_fake_pred) ** 2 +
    relu(ema_real_pred - fake_pred) ** 2``.

    Args:
        real_pred: 0-dim tensor, current statistic on real samples.
        fake_pred: Current statistic on generated samples.
        ema_real_pred: Running average of the real statistic.
        ema_fake_pred: 0-dim tensor, running average of the fake statistic.

    Returns:
        Scalar tensor with the sum of the two squared hinge terms.

    Raises:
        AssertionError: If ``real_pred`` or ``ema_fake_pred`` is not 0-dim.
    """
    assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0
    real_term = torch.mean(torch.pow(F.relu(real_pred - ema_fake_pred), 2))
    fake_term = torch.mean(torch.pow(F.relu(ema_real_pred - fake_pred), 2))
    # In-place accumulation, as before, so the result keeps real_term's dtype.
    return real_term.add_(fake_term)
def gradient_penalty_fn(images, output):
    """Penalty pushing the per-sample gradient norm of ``output`` towards 1.

    Args:
        images: Tensor ``output`` was computed from. It must require grad and
            be part of ``output``'s autograd graph; dimension 0 is treated as
            the batch dimension.
        output: Result of the forward pass on ``images``.

    Returns:
        Scalar tensor: batch mean of ``(||d output / d images||_2 - 1) ** 2``.
    """
    gradients = torch.autograd.grad(
        outputs=output,
        inputs=images,
        # ones_like ties dtype and device to `output` instead of assuming
        # float32 on `images.device`.
        grad_outputs=torch.ones_like(output),
        create_graph=True,  # the penalty itself is back-propagated through
        retain_graph=True,
        only_inputs=True,
    )[0]
    # Collapse all non-batch dimensions so the norm is taken per sample.
    gradients = gradients.reshape(gradients.shape[0], -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
class VEALoss(nn.Module):
    """Reconstruction and KL losses for training a video VAE.

    The reconstruction term is a per-element L1 distance, optionally augmented
    with an LPIPS perceptual distance, and scaled by a learnable scalar
    log-variance. The KL term is obtained from ``posterior.kl()``.

    Args:
        logvar_init: Initial value of the learnable log-variance.
        perceptual_loss_weight: Weight of the LPIPS term; ``None`` or a
            non-positive value disables it.
        kl_loss_weight: Weight of the KL term; ``None`` or a non-positive
            value disables it.
        device: Device the LPIPS network is moved to.
        dtype: ``torch.dtype`` for the LPIPS network, or one of the string
            aliases ``"bf16"``, ``"fp16"``, ``"fp32"``.

    Raises:
        NotImplementedError: If ``dtype`` is an unrecognised string.
    """

    # String aliases accepted for the ``dtype`` argument.
    _DTYPE_ALIASES = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }

    def __init__(
        self,
        logvar_init=0.0,
        perceptual_loss_weight=0.1,
        kl_loss_weight=0.000001,
        device="cpu",
        dtype="bf16",
    ):
        super().__init__()
        dtype = self._resolve_dtype(dtype)

        # KL Loss
        self.kl_loss_weight = kl_loss_weight

        # Perceptual Loss
        self.perceptual_loss_fn = LPIPS().eval().to(device, dtype)
        self.perceptual_loss_weight = perceptual_loss_weight

        # Learnable scalar log-variance used to scale the reconstruction term.
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

    @staticmethod
    def _resolve_dtype(dtype):
        """Map a string alias to a ``torch.dtype``; pass other values through."""
        if not isinstance(dtype, str):
            return dtype
        try:
            return VEALoss._DTYPE_ALIASES[dtype]
        except KeyError:
            raise NotImplementedError(f"dtype: {dtype}") from None

    @staticmethod
    def _as_three_channels(frames):
        """Return a 3-channel version of ``(n, c, h, w)`` frames with c in {1, 3, 4}."""
        channels = frames.shape[1]
        if channels == 1:
            return repeat(frames, "b 1 h w -> b c h w", c=3)
        if channels == 4:  # SCH: take the first 3 for perceptual loss calc
            return frames[:, :3]
        return frames

    def forward(
        self,
        video,
        recon_video,
        posterior,
        nll_weights=None,
        split="train",
    ):
        """Compute the NLL and KL terms.

        Args:
            video: Ground-truth tensor laid out as ``(b, c, t, h, w)``.
            recon_video: Reconstruction with the same layout as ``video``.
            posterior: Object exposing ``kl()``; only used when the KL term
                is enabled.
            nll_weights: Optional multiplier applied to the per-element NLL.
            split: Unused by this method.

        Returns:
            Tuple ``(nll_loss, weighted_nll_loss, weighted_kl_loss)``.
            ``weighted_kl_loss`` is the int ``0`` when the KL term is disabled.
        """
        # Fold time into the batch axis so frames are handled as 2D images.
        video = rearrange(video, "b c t h w -> (b t) c h w").contiguous()
        recon_video = rearrange(recon_video, "b c t h w -> (b t) c h w").contiguous()

        # reconstruction loss
        recon_loss = torch.abs(video - recon_video)

        # perceptual loss
        if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
            assert video.shape[1] in {1, 3, 4}
            perceptual_loss = self.perceptual_loss_fn(
                self._as_three_channels(video),
                self._as_three_channels(recon_video),
            )
            # presumably LPIPS returns one value per frame that broadcasts
            # over the element-wise L1 map; confirm against the LPIPS module
            recon_loss = recon_loss + self.perceptual_loss_weight * perceptual_loss

        nll_loss = recon_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss if nll_weights is None else nll_weights * nll_loss
        # Normalise by the leading dimension, i.e. the number of frames (b * t).
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]

        # KL Loss
        weighted_kl_loss = 0
        if self.kl_loss_weight is not None and self.kl_loss_weight > 0.0:
            kl_loss = posterior.kl()
            kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
            weighted_kl_loss = kl_loss * self.kl_loss_weight

        return nll_loss, weighted_nll_loss, weighted_kl_loss
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Return ``value`` while ``global_step < threshold``, otherwise ``weight``."""
    return value if global_step < threshold else weight
class AdversarialLoss(nn.Module):
    """Generator-side adversarial loss with an adaptive weight.

    Args:
        discriminator_factor: Overall factor on the adversarial term; ``None``
            or a non-positive value zeroes the adaptive weight.
        discriminator_start: Global step before which the factor is replaced
            by 0 (see ``adopt_weight``).
        generator_factor: Multiplier applied to the adaptive weight.
        generator_loss_type: ``"hinge"`` or ``"non-saturating"``.
    """

    def __init__(
        self,
        discriminator_factor=1.0,
        discriminator_start=50001,
        generator_factor=0.5,
        generator_loss_type="non-saturating",
    ):
        super().__init__()
        self.discriminator_factor = discriminator_factor
        self.discriminator_start = discriminator_start
        self.generator_factor = generator_factor
        self.generator_loss_type = generator_loss_type

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
        """Ratio of NLL to generator gradient norms at ``last_layer``.

        The ratio is clamped to ``[0, 1e4]``, detached from the graph and
        multiplied by ``generator_factor``.
        """
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        # 1e-4 guards against division by a vanishing generator gradient.
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight * self.generator_factor

    def forward(
        self,
        fake_logits,
        nll_loss,
        last_layer,
        global_step,
        is_training=True,
    ):
        """Return the weighted generator loss.

        Args:
            fake_logits: Discriminator outputs for generated samples.
            nll_loss: Reconstruction loss used to balance the gradients.
            last_layer: Tensor both losses are differentiated against.
            global_step: Current step, compared with ``discriminator_start``.
            is_training: When ``False``, a failing gradient computation is
                tolerated and the adaptive weight falls back to 0.

        Raises:
            ValueError: If ``generator_loss_type`` has no implementation
                (this includes ``"vanilla"``).
        """
        # NOTE: following MAGVIT to allow non_saturating
        if self.generator_loss_type == "hinge":
            gen_loss = -torch.mean(fake_logits)
        elif self.generator_loss_type == "non-saturating":
            gen_loss = torch.mean(
                sigmoid_cross_entropy_with_logits(labels=torch.ones_like(fake_logits), logits=fake_logits)
            )
        else:
            raise ValueError("Generator loss {} not supported".format(self.generator_loss_type))

        if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, gen_loss, last_layer)
            except RuntimeError:
                # autograd.grad fails when no graph was recorded; that is
                # only acceptable outside training.
                assert not is_training
                d_weight = torch.tensor(0.0)
        else:
            d_weight = torch.tensor(0.0)

        disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
        return d_weight * disc_factor * gen_loss
class LeCamEMA:
    """Exponential moving averages of a "real" and a "fake" statistic.

    Args:
        ema_real: Initial value of the real average.
        ema_fake: Initial value of the fake average.
        decay: Weight kept from the previous average on each update.
        dtype: Dtype the stored averages are cast to.
        device: Device the stored averages live on.

    NOTE(review): with the default bfloat16 state and ``decay=0.999`` a single
    update may be smaller than bfloat16 resolution; consider float32 state.
    """

    def __init__(self, ema_real=0.0, ema_fake=0.0, decay=0.999, dtype=torch.bfloat16, device="cpu"):
        self.decay = decay
        self.ema_real = torch.tensor(ema_real).to(device, dtype)
        self.ema_fake = torch.tensor(ema_fake).to(device, dtype)

    @staticmethod
    def _detached(value):
        """Strip autograd history from tensors; leave plain numbers untouched."""
        return value.detach() if isinstance(value, torch.Tensor) else value

    def update(self, ema_real, ema_fake):
        """Blend new statistics into the running averages.

        Inputs are detached first: the averages are bookkeeping state, and
        keeping their graph would chain every past step's graph together.
        """
        ema_real = self._detached(ema_real)
        ema_fake = self._detached(ema_fake)
        self.ema_real = self.ema_real * self.decay + ema_real * (1 - self.decay)
        self.ema_fake = self.ema_fake * self.decay + ema_fake * (1 - self.decay)

    def get(self):
        """Return ``(ema_real, ema_fake)``."""
        return self.ema_real, self.ema_fake
class DiscriminatorLoss(nn.Module):
    """Discriminator loss: adversarial term plus optional regularisers.

    Args:
        discriminator_factor: Factor on the adversarial term; ``None`` or a
            non-positive value disables it.
        discriminator_start: Global step before which the factor is replaced
            by 0 (see ``adopt_weight``).
        discriminator_loss_type: ``"hinge"``, ``"vanilla"`` or
            ``"non-saturating"``.
        lecam_loss_weight: Weight of the LeCam term; ``None`` or a
            non-positive value disables it.
        gradient_penalty_loss_weight: Weight of the gradient penalty; ``None``
            or a non-positive value disables it.
    """

    def __init__(
        self,
        discriminator_factor=1.0,
        discriminator_start=50001,
        discriminator_loss_type="non-saturating",
        lecam_loss_weight=None,
        gradient_penalty_loss_weight=None,  # SCH: following MAGVIT config.vqgan.grad_penalty_cost
    ):
        super().__init__()

        assert discriminator_loss_type in ["hinge", "vanilla", "non-saturating"]
        self.discriminator_factor = discriminator_factor
        self.discriminator_start = discriminator_start
        self.lecam_loss_weight = lecam_loss_weight
        self.gradient_penalty_loss_weight = gradient_penalty_loss_weight
        self.discriminator_loss_type = discriminator_loss_type

    def _adversarial_loss(self, real_logits, fake_logits):
        """Unweighted adversarial loss for the configured loss type."""
        if self.discriminator_loss_type == "hinge":
            return hinge_d_loss(real_logits, fake_logits)
        if self.discriminator_loss_type == "vanilla":
            return vanilla_d_loss(real_logits, fake_logits)
        if self.discriminator_loss_type == "non-saturating":
            # Reduce each side where it is computed, so a missing side
            # contributes a plain 0.0 instead of reaching torch.mean.
            real_loss = 0.0
            if real_logits is not None:
                real_loss = torch.mean(
                    sigmoid_cross_entropy_with_logits(labels=torch.ones_like(real_logits), logits=real_logits)
                )
            fake_loss = 0.0
            if fake_logits is not None:
                fake_loss = torch.mean(
                    sigmoid_cross_entropy_with_logits(labels=torch.zeros_like(fake_logits), logits=fake_logits)
                )
            return 0.5 * (real_loss + fake_loss)
        raise ValueError(f"Unknown GAN loss '{self.discriminator_loss_type}'.")

    def forward(
        self,
        real_logits,
        fake_logits,
        global_step,
        lecam_ema_real=None,
        lecam_ema_fake=None,
        real_video=None,
        split="train",
    ):
        """Compute the three discriminator loss components.

        Args:
            real_logits: Discriminator outputs for real samples.
            fake_logits: Discriminator outputs for generated samples.
            global_step: Current step, compared with ``discriminator_start``.
            lecam_ema_real: Running average passed to ``lecam_reg``.
            lecam_ema_fake: Running average passed to ``lecam_reg``.
            real_video: Input ``real_logits`` was computed from; required
                when the gradient penalty is enabled.
            split: Unused by this method.

        Returns:
            Tuple ``(weighted_d_adversarial_loss, lecam_loss,
            gradient_penalty)``. The first item is the int ``0`` when the
            adversarial term is disabled.
        """
        if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
            disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
            weighted_d_adversarial_loss = disc_factor * self._adversarial_loss(real_logits, fake_logits)
        else:
            weighted_d_adversarial_loss = 0

        lecam_loss = torch.tensor(0.0)
        if self.lecam_loss_weight is not None and self.lecam_loss_weight > 0.0:
            real_pred = torch.mean(real_logits)
            fake_pred = torch.mean(fake_logits)
            lecam_loss = lecam_reg(real_pred, fake_pred, lecam_ema_real, lecam_ema_fake)
            lecam_loss = lecam_loss * self.lecam_loss_weight

        gradient_penalty = torch.tensor(0.0)
        if self.gradient_penalty_loss_weight is not None and self.gradient_penalty_loss_weight > 0.0:
            assert real_video is not None
            gradient_penalty = gradient_penalty_fn(real_video, real_logits)
            gradient_penalty *= self.gradient_penalty_loss_weight

        return (weighted_d_adversarial_loss, lecam_loss, gradient_penalty)

View file

@ -6,10 +6,9 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, repeat, unpack
from einops import pack, rearrange, unpack
from .utils import DiagonalGaussianDistribution
from .lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import find_model, load_checkpoint
@ -62,115 +61,6 @@ def exists(v):
return v is not None
# ============== Generator Adversarial Loss Functions ==============
def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred):
    """Regularisation term tying current predictions to running averages.

    Computes ``relu(real_pred - ema_fake_pred) ** 2 +
    relu(ema_real_pred - fake_pred) ** 2``.

    Args:
        real_pred: 0-dim tensor, current statistic on real samples.
        fake_pred: Current statistic on generated samples.
        ema_real_pred: Running average of the real statistic.
        ema_fake_pred: 0-dim tensor, running average of the fake statistic.

    Returns:
        Scalar tensor with the sum of the two squared hinge terms.

    Raises:
        AssertionError: If ``real_pred`` or ``ema_fake_pred`` is not 0-dim.
    """
    assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0
    real_term = torch.mean(torch.pow(F.relu(real_pred - ema_fake_pred), 2))
    fake_term = torch.mean(torch.pow(F.relu(ema_real_pred - fake_pred), 2))
    # In-place accumulation, as before, so the result keeps real_term's dtype.
    return real_term.add_(fake_term)
# Open-Sora-Plan
# Very bad, do not use
def r1_penalty(real_img, real_pred):
    """R1 regularization for the discriminator.

    The core idea is to penalize the gradient on real data alone: when the
    generator distribution produces the true data distribution and the
    discriminator is equal to 0 on the data manifold, the gradient penalty
    ensures that the discriminator cannot create a non-zero gradient
    orthogonal to the data manifold without suffering a loss in the GAN game.

    Ref:
        Eq. 9 in "Which training methods for GANs do actually converge".

    Args:
        real_img: Real input that requires grad; dimension 0 is the batch.
        real_pred: Discriminator output computed from ``real_img``.

    Returns:
        Scalar tensor: batch mean of the squared gradient norm.
    """
    grad_real = torch.autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
    # reshape rather than view: the gradient is not guaranteed to be
    # contiguous, and view raises on layouts it cannot express.
    per_sample = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1)
    return per_sample.mean()
# Open-Sora-Plan
# Implementation as described by https://arxiv.org/abs/1704.00028 # TODO: checkout the codes
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
    """Calculate gradient penalty for wgan-gp.

    The discriminator is evaluated on random per-sample interpolations
    between real and fake data, and the L2 norm of its input gradient
    (taken over all non-batch dimensions) is pushed towards 1.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data; dimension 0 is the batch.
        fake_data (Tensor): Fake input data, same shape as ``real_data``.
        weight (Tensor): Weight tensor multiplied into the gradients.
            Default: None.

    Returns:
        Tensor: A tensor for gradient penalty.
    """
    batch_size = real_data.size(0)
    # One coefficient per sample, shaped (b, 1, ..., 1) so it broadcasts over
    # inputs of any rank (images and videos alike).
    alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_data.device).to(real_data.dtype)

    # interpolate between real_data and fake_data; detach so the interpolation
    # is a fresh leaf that gradients can be taken against.
    interpolates = (alpha * real_data + (1.0 - alpha) * fake_data).detach().requires_grad_(True)

    disc_interpolates = discriminator(interpolates)
    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    if weight is not None:
        gradients = gradients * weight

    # Flatten first so the norm is per sample rather than per spatial position.
    gradients = gradients.reshape(batch_size, -1)
    gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    if weight is not None:
        gradients_penalty /= torch.mean(weight)

    return gradients_penalty
def gradient_penalty_fn(images, output):
    """Penalty pushing the per-sample gradient norm of ``output`` towards 1.

    Args:
        images: Tensor ``output`` was computed from. It must require grad and
            be part of ``output``'s autograd graph; dimension 0 is treated as
            the batch dimension.
        output: Result of the forward pass on ``images``.

    Returns:
        Scalar tensor: batch mean of ``(||d output / d images||_2 - 1) ** 2``.
    """
    gradients = torch.autograd.grad(
        outputs=output,
        inputs=images,
        # ones_like ties dtype and device to `output` instead of assuming
        # float32 on `images.device`.
        grad_outputs=torch.ones_like(output),
        create_graph=True,  # the penalty itself is back-propagated through
        retain_graph=True,
        only_inputs=True,
    )[0]
    # Collapse all non-batch dimensions so the norm is taken per sample.
    gradients = gradients.reshape(gradients.shape[0], -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
# ============== Discriminator Adversarial Loss Functions ==============
def hinge_d_loss(logits_real, logits_fake):
    """Hinge loss for the discriminator.

    Real logits are penalised when they fall below +1 and fake logits when
    they rise above -1; the two mean penalties are averaged.

    Args:
        logits_real: Discriminator outputs for real samples.
        logits_fake: Discriminator outputs for generated samples.

    Returns:
        Scalar tensor ``0.5 * (mean(relu(1 - real)) + mean(relu(1 + fake)))``.
    """
    real_term = F.relu(1.0 - logits_real).mean()
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
def vanilla_d_loss(logits_real, logits_fake):
    """Softplus-based (vanilla GAN) loss for the discriminator.

    Args:
        logits_real: Discriminator outputs for real samples.
        logits_fake: Discriminator outputs for generated samples.

    Returns:
        Scalar tensor ``0.5 * (mean(softplus(-real)) + mean(softplus(fake)))``.
    """
    real_term = F.softplus(-logits_real).mean()
    fake_term = F.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)
# from MAGVIT, used in place of hinge_d_loss
def sigmoid_cross_entropy_with_logits(labels, logits):
    """Element-wise sigmoid cross-entropy, written in a numerically stable form.

    With ``x = logits`` and ``z = labels`` the result is
    ``max(x, 0) - x * z + log(1 + exp(-|x|))``, so ``exp`` is only ever
    evaluated on non-positive values.

    Args:
        labels: Target tensor, broadcastable against ``logits``.
        logits: Raw (pre-sigmoid) scores.

    Returns:
        Tensor of per-element losses (no reduction is applied).
    """
    zero = torch.zeros_like(logits, dtype=logits.dtype)
    is_non_negative = logits >= zero
    positive_part = torch.where(is_non_negative, logits, zero)  # max(x, 0)
    minus_abs = torch.where(is_non_negative, -logits, logits)  # -|x|
    log_term = torch.log1p(torch.exp(minus_abs))
    return positive_part - logits * labels + log_term
def xavier_uniform_weight_init(m):
if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain("relu"))
@ -194,10 +84,7 @@ def SameConv2d(dim_in, dim_out, kernel_size):
return nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, padding=padding)
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Return ``value`` while ``global_step < threshold``, otherwise ``weight``."""
    return value if global_step < threshold else weight
class CausalConv3d(nn.Module):
@ -1053,238 +940,7 @@ class VAE_3D_V2(nn.Module): # , ModelMixin
return recon_video, posterior
class VEALoss(nn.Module):
    """Reconstruction and KL losses for training a video VAE.

    The reconstruction term is a per-element L1 distance, optionally augmented
    with an LPIPS perceptual distance, and scaled by a learnable scalar
    log-variance. The KL term is obtained from ``posterior.kl()``.

    Args:
        logvar_init: Initial value of the learnable log-variance.
        perceptual_loss_weight: Weight of the LPIPS term; ``None`` or a
            non-positive value disables it.
        kl_loss_weight: Weight of the KL term; ``None`` or a non-positive
            value disables it.
        device: Device the LPIPS network is moved to.
        dtype: ``torch.dtype`` for the LPIPS network, or one of the string
            aliases ``"bf16"``, ``"fp16"``, ``"fp32"``.

    Raises:
        NotImplementedError: If ``dtype`` is an unrecognised string.
    """

    # String aliases accepted for the ``dtype`` argument.
    _DTYPE_ALIASES = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }

    def __init__(
        self,
        logvar_init=0.0,
        perceptual_loss_weight=0.1,
        kl_loss_weight=0.000001,
        device="cpu",
        dtype="bf16",
    ):
        super().__init__()
        dtype = self._resolve_dtype(dtype)

        # KL Loss
        self.kl_loss_weight = kl_loss_weight

        # Perceptual Loss
        self.perceptual_loss_fn = LPIPS().eval().to(device, dtype)
        self.perceptual_loss_weight = perceptual_loss_weight

        # Learnable scalar log-variance used to scale the reconstruction term.
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

    @staticmethod
    def _resolve_dtype(dtype):
        """Map a string alias to a ``torch.dtype``; pass other values through."""
        if not isinstance(dtype, str):
            return dtype
        try:
            return VEALoss._DTYPE_ALIASES[dtype]
        except KeyError:
            raise NotImplementedError(f"dtype: {dtype}") from None

    @staticmethod
    def _as_three_channels(frames):
        """Return a 3-channel version of ``(n, c, h, w)`` frames with c in {1, 3, 4}."""
        channels = frames.shape[1]
        if channels == 1:
            return repeat(frames, "b 1 h w -> b c h w", c=3)
        if channels == 4:  # SCH: take the first 3 for perceptual loss calc
            return frames[:, :3]
        return frames

    def forward(
        self,
        video,
        recon_video,
        posterior,
        nll_weights=None,
        split="train",
    ):
        """Compute the NLL and KL terms.

        Args:
            video: Ground-truth tensor laid out as ``(b, c, t, h, w)``.
            recon_video: Reconstruction with the same layout as ``video``.
            posterior: Object exposing ``kl()``; only used when the KL term
                is enabled.
            nll_weights: Optional multiplier applied to the per-element NLL.
            split: Unused by this method.

        Returns:
            Tuple ``(nll_loss, weighted_nll_loss, weighted_kl_loss)``.
            ``weighted_kl_loss`` is the int ``0`` when the KL term is disabled.
        """
        # Fold time into the batch axis so frames are handled as 2D images.
        video = rearrange(video, "b c t h w -> (b t) c h w").contiguous()
        recon_video = rearrange(recon_video, "b c t h w -> (b t) c h w").contiguous()

        # reconstruction loss
        recon_loss = torch.abs(video - recon_video)

        # perceptual loss
        if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
            assert video.shape[1] in {1, 3, 4}
            perceptual_loss = self.perceptual_loss_fn(
                self._as_three_channels(video),
                self._as_three_channels(recon_video),
            )
            # presumably LPIPS returns one value per frame that broadcasts
            # over the element-wise L1 map; confirm against the LPIPS module
            recon_loss = recon_loss + self.perceptual_loss_weight * perceptual_loss

        nll_loss = recon_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss if nll_weights is None else nll_weights * nll_loss
        # Normalise by the leading dimension, i.e. the number of frames (b * t).
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]

        # KL Loss
        weighted_kl_loss = 0
        if self.kl_loss_weight is not None and self.kl_loss_weight > 0.0:
            kl_loss = posterior.kl()
            kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
            weighted_kl_loss = kl_loss * self.kl_loss_weight

        return nll_loss, weighted_nll_loss, weighted_kl_loss
class AdversarialLoss(nn.Module):
    """Generator-side adversarial loss with an adaptive weight.

    Args:
        discriminator_factor: Overall factor on the adversarial term; ``None``
            or a non-positive value zeroes the adaptive weight.
        discriminator_start: Global step before which the factor is replaced
            by 0 (see ``adopt_weight``).
        generator_factor: Multiplier applied to the adaptive weight.
        generator_loss_type: ``"hinge"`` or ``"non-saturating"``.
    """

    def __init__(
        self,
        discriminator_factor=1.0,
        discriminator_start=50001,
        generator_factor=0.5,
        generator_loss_type="non-saturating",
    ):
        super().__init__()
        self.discriminator_factor = discriminator_factor
        self.discriminator_start = discriminator_start
        self.generator_factor = generator_factor
        self.generator_loss_type = generator_loss_type

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
        """Ratio of NLL to generator gradient norms at ``last_layer``.

        The ratio is clamped to ``[0, 1e4]``, detached from the graph and
        multiplied by ``generator_factor``.
        """
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        # 1e-4 guards against division by a vanishing generator gradient.
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight * self.generator_factor

    def forward(
        self,
        fake_logits,
        nll_loss,
        last_layer,
        global_step,
        is_training=True,
    ):
        """Return the weighted generator loss.

        Args:
            fake_logits: Discriminator outputs for generated samples.
            nll_loss: Reconstruction loss used to balance the gradients.
            last_layer: Tensor both losses are differentiated against.
            global_step: Current step, compared with ``discriminator_start``.
            is_training: When ``False``, a failing gradient computation is
                tolerated and the adaptive weight falls back to 0.

        Raises:
            ValueError: If ``generator_loss_type`` has no implementation
                (this includes ``"vanilla"``).
        """
        # NOTE: following MAGVIT to allow non_saturating
        if self.generator_loss_type == "hinge":
            gen_loss = -torch.mean(fake_logits)
        elif self.generator_loss_type == "non-saturating":
            gen_loss = torch.mean(
                sigmoid_cross_entropy_with_logits(labels=torch.ones_like(fake_logits), logits=fake_logits)
            )
        else:
            raise ValueError("Generator loss {} not supported".format(self.generator_loss_type))

        if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, gen_loss, last_layer)
            except RuntimeError:
                # autograd.grad fails when no graph was recorded; that is
                # only acceptable outside training.
                assert not is_training
                d_weight = torch.tensor(0.0)
        else:
            d_weight = torch.tensor(0.0)

        disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
        return d_weight * disc_factor * gen_loss
class LeCamEMA:
    """Exponential moving averages of a "real" and a "fake" statistic.

    Args:
        ema_real: Initial value of the real average.
        ema_fake: Initial value of the fake average.
        decay: Weight kept from the previous average on each update.
        dtype: Dtype the stored averages are cast to.
        device: Device the stored averages live on.

    NOTE(review): with the default bfloat16 state and ``decay=0.999`` a single
    update may be smaller than bfloat16 resolution; consider float32 state.
    """

    def __init__(self, ema_real=0.0, ema_fake=0.0, decay=0.999, dtype=torch.bfloat16, device="cpu"):
        self.decay = decay
        self.ema_real = torch.tensor(ema_real).to(device, dtype)
        self.ema_fake = torch.tensor(ema_fake).to(device, dtype)

    @staticmethod
    def _detached(value):
        """Strip autograd history from tensors; leave plain numbers untouched."""
        return value.detach() if isinstance(value, torch.Tensor) else value

    def update(self, ema_real, ema_fake):
        """Blend new statistics into the running averages.

        Inputs are detached first: the averages are bookkeeping state, and
        keeping their graph would chain every past step's graph together.
        """
        ema_real = self._detached(ema_real)
        ema_fake = self._detached(ema_fake)
        self.ema_real = self.ema_real * self.decay + ema_real * (1 - self.decay)
        self.ema_fake = self.ema_fake * self.decay + ema_fake * (1 - self.decay)

    def get(self):
        """Return ``(ema_real, ema_fake)``."""
        return self.ema_real, self.ema_fake
class DiscriminatorLoss(nn.Module):
    """Discriminator loss: adversarial term plus optional regularisers.

    Args:
        discriminator_factor: Factor on the adversarial term; ``None`` or a
            non-positive value disables it.
        discriminator_start: Global step before which the factor is replaced
            by 0 (see ``adopt_weight``).
        discriminator_loss_type: ``"hinge"``, ``"vanilla"`` or
            ``"non-saturating"``.
        lecam_loss_weight: Weight of the LeCam term; ``None`` or a
            non-positive value disables it.
        gradient_penalty_loss_weight: Weight of the gradient penalty; ``None``
            or a non-positive value disables it.
    """

    def __init__(
        self,
        discriminator_factor=1.0,
        discriminator_start=50001,
        discriminator_loss_type="non-saturating",
        lecam_loss_weight=None,
        gradient_penalty_loss_weight=None,  # SCH: following MAGVIT config.vqgan.grad_penalty_cost
    ):
        super().__init__()

        assert discriminator_loss_type in ["hinge", "vanilla", "non-saturating"]
        self.discriminator_factor = discriminator_factor
        self.discriminator_start = discriminator_start
        self.lecam_loss_weight = lecam_loss_weight
        self.gradient_penalty_loss_weight = gradient_penalty_loss_weight
        self.discriminator_loss_type = discriminator_loss_type

    def _adversarial_loss(self, real_logits, fake_logits):
        """Unweighted adversarial loss for the configured loss type."""
        if self.discriminator_loss_type == "hinge":
            return hinge_d_loss(real_logits, fake_logits)
        if self.discriminator_loss_type == "vanilla":
            return vanilla_d_loss(real_logits, fake_logits)
        if self.discriminator_loss_type == "non-saturating":
            # Reduce each side where it is computed, so a missing side
            # contributes a plain 0.0 instead of reaching torch.mean.
            real_loss = 0.0
            if real_logits is not None:
                real_loss = torch.mean(
                    sigmoid_cross_entropy_with_logits(labels=torch.ones_like(real_logits), logits=real_logits)
                )
            fake_loss = 0.0
            if fake_logits is not None:
                fake_loss = torch.mean(
                    sigmoid_cross_entropy_with_logits(labels=torch.zeros_like(fake_logits), logits=fake_logits)
                )
            return 0.5 * (real_loss + fake_loss)
        raise ValueError(f"Unknown GAN loss '{self.discriminator_loss_type}'.")

    def forward(
        self,
        real_logits,
        fake_logits,
        global_step,
        lecam_ema_real=None,
        lecam_ema_fake=None,
        real_video=None,
        split="train",
    ):
        """Compute the three discriminator loss components.

        Args:
            real_logits: Discriminator outputs for real samples.
            fake_logits: Discriminator outputs for generated samples.
            global_step: Current step, compared with ``discriminator_start``.
            lecam_ema_real: Running average passed to ``lecam_reg``.
            lecam_ema_fake: Running average passed to ``lecam_reg``.
            real_video: Input ``real_logits`` was computed from; required
                when the gradient penalty is enabled.
            split: Unused by this method.

        Returns:
            Tuple ``(weighted_d_adversarial_loss, lecam_loss,
            gradient_penalty)``. The first item is the int ``0`` when the
            adversarial term is disabled.
        """
        if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
            disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
            weighted_d_adversarial_loss = disc_factor * self._adversarial_loss(real_logits, fake_logits)
        else:
            weighted_d_adversarial_loss = 0

        lecam_loss = torch.tensor(0.0)
        if self.lecam_loss_weight is not None and self.lecam_loss_weight > 0.0:
            real_pred = torch.mean(real_logits)
            fake_pred = torch.mean(fake_logits)
            lecam_loss = lecam_reg(real_pred, fake_pred, lecam_ema_real, lecam_ema_fake)
            lecam_loss = lecam_loss * self.lecam_loss_weight

        gradient_penalty = torch.tensor(0.0)
        if self.gradient_penalty_loss_weight is not None and self.gradient_penalty_loss_weight > 0.0:
            assert real_video is not None
            gradient_penalty = gradient_penalty_fn(real_video, real_logits)
            gradient_penalty *= self.gradient_penalty_loss_weight

        return (weighted_d_adversarial_loss, lecam_loss, gradient_penalty)
@MODELS.register_module("VAE_MAGVIT_V2")

View file

@ -10,7 +10,8 @@ from tqdm import tqdm
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader, save_sample
from opensora.models.vae.vae_3d_v2 import AdversarialLoss, DiscriminatorLoss, LeCamEMA, VEALoss, pad_at_dim
from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VEALoss
from opensora.models.vae.vae_3d import LeCamEMA, pad_at_dim
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype
@ -27,12 +28,6 @@ def main():
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
# if coordinator.world_size > 1:
# set_sequence_parallel_group(dist.group.WORLD)
# enable_sequence_parallelism = True
# else:
# enable_sequence_parallelism = False
# ======================================================
# 2. runtime variables
# ======================================================

View file

@ -21,7 +21,8 @@ from opensora.acceleration.parallel_states import (
)
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
from opensora.models.vae.vae_3d_v2 import AdversarialLoss, DiscriminatorLoss, LeCamEMA, VEALoss, pad_at_dim
from opensora.models.vae.vae_3d import LeCamEMA, pad_at_dim
from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VEALoss
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.ckpt_utils import create_logger, load_json, save_json
from opensora.utils.config_utils import (