From f81d8648cf96cafbdf5087398a0716efd004b2ff Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Mon, 29 Apr 2024 17:53:03 +0800 Subject: [PATCH] move out losses to different file --- opensora/models/vae/losses.py | 273 ++++++++++++++++++++++++++ opensora/models/vae/vae_3d.py | 348 +--------------------------------- scripts/inference-vae.py | 9 +- scripts/train-vae.py | 3 +- 4 files changed, 279 insertions(+), 354 deletions(-) create mode 100644 opensora/models/vae/losses.py diff --git a/opensora/models/vae/losses.py b/opensora/models/vae/losses.py new file mode 100644 index 0000000..0561e70 --- /dev/null +++ b/opensora/models/vae/losses.py @@ -0,0 +1,273 @@ +import torch.nn as nn +import torch +import torch.nn.functional as F +from .lpips import LPIPS +from einops import rearrange, repeat + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake)) + ) + return d_loss + + +# from MAGVIT, used in place of hinge_d_loss +def sigmoid_cross_entropy_with_logits(labels, logits): + # The final formulation is: max(x, 0) - x * z + log(1 + exp(-abs(x))) + zeros = torch.zeros_like(logits, dtype=logits.dtype) + condition = logits >= zeros + relu_logits = torch.where(condition, logits, zeros) + neg_abs_logits = torch.where(condition, -logits, logits) + return relu_logits - logits * labels + torch.log1p(torch.exp(neg_abs_logits)) + +def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred): + assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0 + lecam_loss = torch.mean(torch.pow(nn.ReLU()(real_pred - ema_fake_pred), 2)) + lecam_loss += torch.mean(torch.pow(nn.ReLU()(ema_real_pred - fake_pred), 2)) + return lecam_loss + +def 
gradient_penalty_fn(images, output): + gradients = torch.autograd.grad( + outputs=output, + inputs=images, + grad_outputs=torch.ones(output.size(), device=images.device), + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + + gradients = rearrange(gradients, "b ... -> b (...)") + return ((gradients.norm(2, dim=1) - 1) ** 2).mean() + +class VEALoss(nn.Module): + def __init__( + self, + logvar_init=0.0, + perceptual_loss_weight=0.1, + kl_loss_weight=0.000001, + device="cpu", + dtype="bf16", + ): + super().__init__() + + if type(dtype) == str: + if dtype == "bf16": + dtype = torch.bfloat16 + elif dtype == "fp16": + dtype = torch.float16 + else: + raise NotImplementedError(f"dtype: {dtype}") + + # KL Loss + self.kl_loss_weight = kl_loss_weight + # Perceptual Loss + self.perceptual_loss_fn = LPIPS().eval().to(device, dtype) + self.perceptual_loss_weight = perceptual_loss_weight + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + + def forward( + self, + video, + recon_video, + posterior, + nll_weights=None, + split="train", + ): + video = rearrange(video, "b c t h w -> (b t) c h w").contiguous() + recon_video = rearrange(recon_video, "b c t h w -> (b t) c h w").contiguous() + + # reconstruction loss + recon_loss = torch.abs(video - recon_video) + + # perceptual loss + if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0: + # handle channels + channels = video.shape[1] + assert channels in {1, 3, 4} + if channels == 1: + input_vgg_input = repeat(video, "b 1 h w -> b c h w", c=3) + recon_vgg_input = repeat(recon_video, "b 1 h w -> b c h w", c=3) + elif channels == 4: # SCH: take the first 3 for perceptual loss calc + input_vgg_input = video[:, :3] + recon_vgg_input = recon_video[:, :3] + else: + input_vgg_input = video + recon_vgg_input = recon_video + + perceptual_loss = self.perceptual_loss_fn(input_vgg_input, recon_vgg_input) + recon_loss = recon_loss + self.perceptual_loss_weight * perceptual_loss + + 
nll_loss = recon_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if nll_weights is not None: + weighted_nll_loss = nll_weights * nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + + + # KL Loss + weighted_kl_loss = 0 + if self.kl_loss_weight is not None and self.kl_loss_weight > 0.0: + kl_loss = posterior.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + weighted_kl_loss = kl_loss * self.kl_loss_weight + + return nll_loss, weighted_nll_loss, weighted_kl_loss + +def adopt_weight(weight, global_step, threshold=0, value=0.0): + if global_step < threshold: + weight = value + return weight + +class AdversarialLoss(nn.Module): + def __init__( + self, + discriminator_factor=1.0, + discriminator_start=50001, + generator_factor=0.5, + generator_loss_type="non-saturating", + ): + super().__init__() + self.discriminator_factor = discriminator_factor + self.discriminator_start = discriminator_start + self.generator_factor = generator_factor + self.generator_loss_type = generator_loss_type + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer): + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[ + 0 + ] + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.generator_factor + return d_weight + + def forward( + self, + fake_logits, + nll_loss, + last_layer, + global_step, + is_training=True, + ): + # NOTE: following MAGVIT to allow non_saturating + assert self.generator_loss_type in ["hinge", "vanilla", "non-saturating"] + + if self.generator_loss_type == "hinge": + gen_loss = -torch.mean(fake_logits) + elif self.generator_loss_type == "non-saturating": + gen_loss = torch.mean( + sigmoid_cross_entropy_with_logits(labels=torch.ones_like(fake_logits), 
logits=fake_logits) + ) + else: + raise ValueError("Generator loss {} not supported".format(self.generator_loss_type)) + + if self.discriminator_factor is not None and self.discriminator_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight(nll_loss, gen_loss, last_layer) + except RuntimeError: + assert not is_training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start) + weighted_gen_loss = d_weight * disc_factor * gen_loss + + return weighted_gen_loss + + +class LeCamEMA: + def __init__(self, ema_real=0.0, ema_fake=0.0, decay=0.999, dtype=torch.bfloat16, device="cpu"): + self.decay = decay + self.ema_real = torch.tensor(ema_real).to(device, dtype) + self.ema_fake = torch.tensor(ema_fake).to(device, dtype) + + def update(self, ema_real, ema_fake): + self.ema_real = self.ema_real * self.decay + ema_real * (1 - self.decay) + self.ema_fake = self.ema_fake * self.decay + ema_fake * (1 - self.decay) + + def get(self): + return self.ema_real, self.ema_fake + + +class DiscriminatorLoss(nn.Module): + def __init__( + self, + discriminator_factor=1.0, + discriminator_start=50001, + discriminator_loss_type="non-saturating", + lecam_loss_weight=None, + gradient_penalty_loss_weight=None, # SCH: following MAGVIT config.vqgan.grad_penalty_cost + ): + super().__init__() + + assert discriminator_loss_type in ["hinge", "vanilla", "non-saturating"] + self.discriminator_factor = discriminator_factor + self.discriminator_start = discriminator_start + self.lecam_loss_weight = lecam_loss_weight + self.gradient_penalty_loss_weight = gradient_penalty_loss_weight + self.discriminator_loss_type = discriminator_loss_type + + def forward( + self, + real_logits, + fake_logits, + global_step, + lecam_ema_real=None, + lecam_ema_fake=None, + real_video=None, + split="train", + ): + if self.discriminator_factor is not None and self.discriminator_factor > 0.0: + 
disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start) + + if self.discriminator_loss_type == "hinge": + disc_loss = hinge_d_loss(real_logits, fake_logits) + elif self.discriminator_loss_type == "non-saturating": + if real_logits is not None: + real_loss = sigmoid_cross_entropy_with_logits( + labels=torch.ones_like(real_logits), logits=real_logits + ) + else: + real_loss = 0.0 + if fake_logits is not None: + fake_loss = sigmoid_cross_entropy_with_logits( + labels=torch.zeros_like(fake_logits), logits=fake_logits + ) + else: + fake_loss = 0.0 + disc_loss = 0.5 * (torch.mean(real_loss) + torch.mean(fake_loss)) + elif self.discriminator_loss_type == "vanilla": + disc_loss = vanilla_d_loss(real_logits, fake_logits) + else: + raise ValueError(f"Unknown GAN loss '{self.discriminator_loss_type}'.") + + weighted_d_adversarial_loss = disc_factor * disc_loss + + else: + weighted_d_adversarial_loss = 0 + + lecam_loss = torch.tensor(0.0) + if self.lecam_loss_weight is not None and self.lecam_loss_weight > 0.0: + real_pred = torch.mean(real_logits) + fake_pred = torch.mean(fake_logits) + lecam_loss = lecam_reg(real_pred, fake_pred, lecam_ema_real, lecam_ema_fake) + lecam_loss = lecam_loss * self.lecam_loss_weight + + gradient_penalty = torch.tensor(0.0) + if self.gradient_penalty_loss_weight is not None and self.gradient_penalty_loss_weight > 0.0: + assert real_video is not None + gradient_penalty = gradient_penalty_fn(real_video, real_logits) + gradient_penalty *= self.gradient_penalty_loss_weight + + return (weighted_d_adversarial_loss, lecam_loss, gradient_penalty) \ No newline at end of file diff --git a/opensora/models/vae/vae_3d.py b/opensora/models/vae/vae_3d.py index 5b74ed0..c05cfde 100644 --- a/opensora/models/vae/vae_3d.py +++ b/opensora/models/vae/vae_3d.py @@ -6,10 +6,9 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from einops import pack, rearrange, repeat, unpack +from 
einops import pack, rearrange, unpack from .utils import DiagonalGaussianDistribution -from .lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers from opensora.registry import MODELS from opensora.utils.ckpt_utils import find_model, load_checkpoint @@ -62,115 +61,6 @@ def exists(v): return v is not None -# ============== Generator Adversarial Loss Functions ============== -def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred): - assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0 - lecam_loss = torch.mean(torch.pow(nn.ReLU()(real_pred - ema_fake_pred), 2)) - lecam_loss += torch.mean(torch.pow(nn.ReLU()(ema_real_pred - fake_pred), 2)) - return lecam_loss - - -# Open-Sora-Plan -# Very bad, do not use -def r1_penalty(real_img, real_pred): - """R1 regularization for discriminator. The core idea is to - penalize the gradient on real data alone: when the - generator distribution produces the true data distribution - and the discriminator is equal to 0 on the data manifold, the - gradient penalty ensures that the discriminator cannot create - a non-zero gradient orthogonal to the data manifold without - suffering a loss in the GAN game. - - Ref: - Eq. 9 in Which training methods for GANs do actually converge. - """ - grad_real = torch.autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] - grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() - return grad_penalty - - -# Open-Sora-Plan -# Implementation as described by https://arxiv.org/abs/1704.00028 # TODO: checkout the codes -def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): - """Calculate gradient penalty for wgan-gp. - - Args: - discriminator (nn.Module): Network for the discriminator. - real_data (Tensor): Real input data. - fake_data (Tensor): Fake input data. - weight (Tensor): Weight tensor. Default: None. - - Returns: - Tensor: A tensor for gradient penalty. 
- """ - - batch_size = real_data.size(0) - alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) - - # interpolate between real_data and fake_data - interpolates = alpha * real_data + (1.0 - alpha) * fake_data - interpolates = torch.autograd.Variable(interpolates, requires_grad=True) - - disc_interpolates = discriminator(interpolates) - gradients = torch.autograd.grad( - outputs=disc_interpolates, - inputs=interpolates, - grad_outputs=torch.ones_like(disc_interpolates), - create_graph=True, - retain_graph=True, - only_inputs=True, - )[0] - - if weight is not None: - gradients = gradients * weight - - gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() - if weight is not None: - gradients_penalty /= torch.mean(weight) - - return gradients_penalty - - -def gradient_penalty_fn(images, output): - # batch_size = images.shape[0] - gradients = torch.autograd.grad( - outputs=output, - inputs=images, - grad_outputs=torch.ones(output.size(), device=images.device), - create_graph=True, - retain_graph=True, - only_inputs=True, - )[0] - - gradients = rearrange(gradients, "b ... 
-> b (...)") - return ((gradients.norm(2, dim=1) - 1) ** 2).mean() - - -# ============== Discriminator Adversarial Loss Functions ============== -def hinge_d_loss(logits_real, logits_fake): - loss_real = torch.mean(F.relu(1.0 - logits_real)) - loss_fake = torch.mean(F.relu(1.0 + logits_fake)) - d_loss = 0.5 * (loss_real + loss_fake) - return d_loss - - -def vanilla_d_loss(logits_real, logits_fake): - d_loss = 0.5 * ( - torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake)) - ) - return d_loss - - -# from MAGVIT, used in place hof hinge_d_loss -def sigmoid_cross_entropy_with_logits(labels, logits): - # The final formulation is: max(x, 0) - x * z + log(1 + exp(-abs(x))) - zeros = torch.zeros_like(logits, dtype=logits.dtype) - condition = logits >= zeros - relu_logits = torch.where(condition, logits, zeros) - neg_abs_logits = torch.where(condition, -logits, logits) - return relu_logits - logits * labels + torch.log1p(torch.exp(neg_abs_logits)) - - def xavier_uniform_weight_init(m): if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain("relu")) @@ -194,10 +84,7 @@ def SameConv2d(dim_in, dim_out, kernel_size): return nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, padding=padding) -def adopt_weight(weight, global_step, threshold=0, value=0.0): - if global_step < threshold: - weight = value - return weight + class CausalConv3d(nn.Module): @@ -1053,238 +940,7 @@ class VAE_3D_V2(nn.Module): # , ModelMixin return recon_video, posterior -class VEALoss(nn.Module): - def __init__( - self, - logvar_init=0.0, - perceptual_loss_weight=0.1, - kl_loss_weight=0.000001, - # vgg=None, - # vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT, - device="cpu", - dtype="bf16", - ): - super().__init__() - if type(dtype) == str: - if dtype == "bf16": - dtype = torch.bfloat16 - elif dtype == "fp16": - dtype = torch.float16 - else: - raise NotImplementedError(f"dtype: 
{dtype}") - - # KL Loss - self.kl_loss_weight = kl_loss_weight - # Perceptual Loss - self.perceptual_loss_fn = LPIPS().eval().to(device, dtype) - self.perceptual_loss_weight = perceptual_loss_weight - self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) - - - def forward( - self, - video, - recon_video, - posterior, - nll_weights=None, - split="train", - ): - video = rearrange(video, "b c t h w -> (b t) c h w").contiguous() - recon_video = rearrange(recon_video, "b c t h w -> (b t) c h w").contiguous() - - # reconstruction loss - recon_loss = torch.abs(video - recon_video) - - # perceptual loss - if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0: - # handle channels - channels = video.shape[1] - assert channels in {1, 3, 4} - if channels == 1: - input_vgg_input = repeat(video, "b 1 h w -> b c h w", c=3) - recon_vgg_input = repeat(recon_video, "b 1 h w -> b c h w", c=3) - elif channels == 4: # SCH: take the first 3 for perceptual loss calc - input_vgg_input = video[:, :3] - recon_vgg_input = recon_video[:, :3] - else: - input_vgg_input = video - recon_vgg_input = recon_video - - perceptual_loss = self.perceptual_loss_fn(input_vgg_input, recon_vgg_input) - recon_loss = recon_loss + self.perceptual_loss_weight * perceptual_loss - - nll_loss = recon_loss / torch.exp(self.logvar) + self.logvar - weighted_nll_loss = nll_loss - if nll_weights is not None: - weighted_nll_loss = nll_weights * nll_loss - weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] - nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] - - - # KL Loss - weighted_kl_loss = 0 - if self.kl_loss_weight is not None and self.kl_loss_weight > 0.0: - kl_loss = posterior.kl() - kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] - weighted_kl_loss = kl_loss * self.kl_loss_weight - - return nll_loss, weighted_nll_loss, weighted_kl_loss - - -class AdversarialLoss(nn.Module): - def __init__( - self, - discriminator_factor=1.0, - 
discriminator_start=50001, - generator_factor=0.5, - generator_loss_type="non-saturating", - ): - super().__init__() - self.discriminator_factor = discriminator_factor - self.discriminator_start = discriminator_start - self.generator_factor = generator_factor - self.generator_loss_type = generator_loss_type - - def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer): - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] # SCH: TODO: debug added creat - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[ - 0 - ] # SCH: TODO: debug added create_graph=True - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() - d_weight = d_weight * self.generator_factor - return d_weight - - def forward( - self, - fake_logits, - nll_loss, - last_layer, - global_step, - is_training=True, - ): - # NOTE: following MAGVIT to allow non_saturating - assert self.generator_loss_type in ["hinge", "vanilla", "non-saturating"] - - if self.generator_loss_type == "hinge": - gen_loss = -torch.mean(fake_logits) - elif self.generator_loss_type == "non-saturating": - gen_loss = torch.mean( - sigmoid_cross_entropy_with_logits(labels=torch.ones_like(fake_logits), logits=fake_logits) - ) - else: - raise ValueError("Generator loss {} not supported".format(self.generator_loss_type)) - - if self.discriminator_factor is not None and self.discriminator_factor > 0.0: - try: - d_weight = self.calculate_adaptive_weight(nll_loss, gen_loss, last_layer) - except RuntimeError: - assert not is_training - d_weight = torch.tensor(0.0) - else: - d_weight = torch.tensor(0.0) - - disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start) - weighted_gen_loss = d_weight * disc_factor * gen_loss - - return weighted_gen_loss - - -class LeCamEMA: - def __init__(self, ema_real=0.0, ema_fake=0.0, decay=0.999, dtype=torch.bfloat16, device="cpu"): - self.decay = decay - 
self.ema_real = torch.tensor(ema_real).to(device, dtype) - self.ema_fake = torch.tensor(ema_fake).to(device, dtype) - - def update(self, ema_real, ema_fake): - self.ema_real = self.ema_real * self.decay + ema_real * (1 - self.decay) - self.ema_fake = self.ema_fake * self.decay + ema_fake * (1 - self.decay) - - def get(self): - return self.ema_real, self.ema_fake - - -class DiscriminatorLoss(nn.Module): - def __init__( - self, - discriminator_factor=1.0, - discriminator_start=50001, - discriminator_loss_type="non-saturating", - lecam_loss_weight=None, - gradient_penalty_loss_weight=None, # SCH: following MAGVIT config.vqgan.grad_penalty_cost - ): - super().__init__() - - assert discriminator_loss_type in ["hinge", "vanilla", "non-saturating"] - self.discriminator_factor = discriminator_factor - self.discriminator_start = discriminator_start - self.lecam_loss_weight = lecam_loss_weight - self.gradient_penalty_loss_weight = gradient_penalty_loss_weight - self.discriminator_loss_type = discriminator_loss_type - - def forward( - self, - real_logits, - fake_logits, - global_step, - lecam_ema_real=None, - lecam_ema_fake=None, - real_video=None, - split="train", - ): - if self.discriminator_factor is not None and self.discriminator_factor > 0.0: - disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start) - - if self.discriminator_loss_type == "hinge": - disc_loss = hinge_d_loss(real_logits, fake_logits) - elif self.discriminator_loss_type == "non-saturating": - if real_logits is not None: - real_loss = sigmoid_cross_entropy_with_logits( - labels=torch.ones_like(real_logits), logits=real_logits - ) - else: - real_loss = 0.0 - if fake_logits is not None: - fake_loss = sigmoid_cross_entropy_with_logits( - labels=torch.zeros_like(fake_logits), logits=fake_logits - ) - else: - fake_loss = 0.0 - disc_loss = 0.5 * (torch.mean(real_loss) + torch.mean(fake_loss)) - elif self.discriminator_loss_type == "vanilla": - disc_loss = 
vanilla_d_loss(real_logits, fake_logits) - else: - raise ValueError(f"Unknown GAN loss '{self.discriminator_loss_type}'.") - - weighted_d_adversarial_loss = disc_factor * disc_loss - - else: - weighted_d_adversarial_loss = 0 - - lecam_loss = torch.tensor(0.0) - if self.lecam_loss_weight is not None and self.lecam_loss_weight > 0.0: - real_pred = torch.mean(real_logits) - fake_pred = torch.mean(fake_logits) - lecam_loss = lecam_reg(real_pred, fake_pred, lecam_ema_real, lecam_ema_fake) - lecam_loss = lecam_loss * self.lecam_loss_weight - - gradient_penalty = torch.tensor(0.0) - if self.gradient_penalty_loss_weight is not None and self.gradient_penalty_loss_weight > 0.0: - assert real_video is not None - gradient_penalty = gradient_penalty_fn(real_video, real_logits) - # gradient_penalty = r1_penalty(real_video, real_logits) # MAGVIT uses r1 penalty - gradient_penalty *= self.gradient_penalty_loss_weight - - # discriminator_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty - - # log = { - # "{}/d_adversarial_loss".format(split): weighted_d_adversarial_loss.detach().mean(), - # "{}/lecam_loss".format(split): lecam_loss.detach().mean(), - # "{}/gradient_penalty".format(split): gradient_penalty.detach().mean(), - # } - - return (weighted_d_adversarial_loss, lecam_loss, gradient_penalty) @MODELS.register_module("VAE_MAGVIT_V2") diff --git a/scripts/inference-vae.py b/scripts/inference-vae.py index 2fd64fe..f40bbe7 100644 --- a/scripts/inference-vae.py +++ b/scripts/inference-vae.py @@ -10,7 +10,8 @@ from tqdm import tqdm from opensora.acceleration.parallel_states import get_data_parallel_group from opensora.datasets import prepare_dataloader, save_sample -from opensora.models.vae.vae_3d_v2 import AdversarialLoss, DiscriminatorLoss, LeCamEMA, VEALoss, pad_at_dim +from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VEALoss +from opensora.models.vae.vae_3d import LeCamEMA, pad_at_dim from opensora.registry import DATASETS, MODELS, 
build_module from opensora.utils.config_utils import parse_configs from opensora.utils.misc import to_torch_dtype @@ -27,12 +28,6 @@ def main(): colossalai.launch_from_torch({}) coordinator = DistCoordinator() - # if coordinator.world_size > 1: - # set_sequence_parallel_group(dist.group.WORLD) - # enable_sequence_parallelism = True - # else: - # enable_sequence_parallelism = False - # ====================================================== # 2. runtime variables # ====================================================== diff --git a/scripts/train-vae.py b/scripts/train-vae.py index e8a06f6..387e28a 100644 --- a/scripts/train-vae.py +++ b/scripts/train-vae.py @@ -21,7 +21,8 @@ from opensora.acceleration.parallel_states import ( ) from opensora.acceleration.plugin import ZeroSeqParallelPlugin from opensora.datasets import prepare_dataloader, prepare_variable_dataloader -from opensora.models.vae.vae_3d_v2 import AdversarialLoss, DiscriminatorLoss, LeCamEMA, VEALoss, pad_at_dim +from opensora.models.vae.vae_3d import LeCamEMA, pad_at_dim +from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VEALoss from opensora.registry import DATASETS, MODELS, build_module from opensora.utils.ckpt_utils import create_logger, load_json, save_json from opensora.utils.config_utils import (