diff --git a/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py b/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py
index 86f8711..31af832 100644
--- a/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py
+++ b/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py
@@ -1,21 +1,20 @@
-num_frames = 16
-frame_interval = 3
-image_size = (128, 128)
-use_pipeline = True
-
-# Define dataset
-root = None
-data_path = "CSV_PATH"
-use_image_transform = False
-num_workers = 4
-video_contains_first_frame = False
+dataset = dict(
+    type="VideoTextDataset",
+    data_path=None,
+    num_frames=16,
+    frame_interval=3,
+    image_size=(128, 128),
+)

 # Define acceleration
+num_workers = 4
 dtype = "bf16"
 grad_checkpoint = True
 plugin = "zero2"
 sp_size = 1
+use_pipeline = True
+video_contains_first_frame = False

 # Define model
 vae_2d = dict(
@@ -23,50 +22,49 @@ vae_2d = dict(
     from_pretrained="stabilityai/sd-vae-ft-ema",  # SDXL
 )
-
 model = dict(
     type="VAE_MAGVIT_V2",
-    in_out_channels = 4,
-    latent_embed_dim = 4,
-    filters = 128,
-    num_res_blocks = 4,
-    channel_multipliers = (1, 2, 2, 4),
-    temporal_downsample = (False, True, True),
-    num_groups = 32, # for nn.GroupNorm
-    kl_embed_dim = 4,
-    activation_fn = 'swish',
-    separate_first_frame_encoding = False,
-    disable_space = True,
-    encoder_double_z = True,
-    custom_conv_padding = None
+    in_out_channels=4,
+    latent_embed_dim=4,
+    filters=128,
+    num_res_blocks=4,
+    channel_multipliers=(1, 2, 2, 4),
+    temporal_downsample=(False, True, True),
+    num_groups=32,  # for nn.GroupNorm
+    kl_embed_dim=4,
+    activation_fn="swish",
+    separate_first_frame_encoding=False,
+    disable_space=True,
+    encoder_double_z=True,
+    custom_conv_padding=None,
 )

 discriminator = dict(
     type="DISCRIMINATOR_3D",
-    image_size = (16, 16), # NOTE: here image size is different
-    num_frames = num_frames,
-    in_channels = 4,
-    filters = 128,
-    use_pretrained=True, # NOTE: set to False only if we want to disable load
+    image_size=(16, 16),  # NOTE: here image size is different
+    num_frames=16,
+    in_channels=4,
+    filters=128,
+    use_pretrained=True,  # NOTE: set to False only if we want to disable load
     # channel_multipliers = (2,4,4,4,4), # (2,4,4,4) for 64x64 resolution
-    channel_multipliers= (2,4,4) # since on intermediate layer dimension ofs z
+    channel_multipliers=(2, 4, 4),  # since the input is the intermediate-layer z (16x16), fewer stages are needed
 )

-# loss weights
-logvar_init=0.0
+# loss weights
+logvar_init = 0.0
 kl_loss_weight = 0.000001
-perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
-discriminator_factor = 1.0 # for discriminator adversarial loss
-generator_factor = 0.1 # SCH: generator adversarial loss, MAGVIT v2 uses 0.1
-lecam_loss_weight = None # NOTE: MAVGIT v2 use 0.001
-discriminator_loss_type="non-saturating"
-generator_loss_type="non-saturating"
+perceptual_loss_weight = 0.1  # used only when vgg is not None and the weight is > 0
+discriminator_factor = 1.0  # for discriminator adversarial loss
+generator_factor = 0.1  # SCH: generator adversarial loss, MAGVIT v2 uses 0.1
+lecam_loss_weight = None  # NOTE: MAGVIT v2 uses 0.001
+discriminator_loss_type = "non-saturating"
+generator_loss_type = "non-saturating"
 # discriminator_loss_type="hinge"
 # generator_loss_type="hinge"
-discriminator_start = 100 # 8k data / (8*32) = 31 steps per epoch, use around 3 epochs
-gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
+discriminator_start = 100  # 8k data / (8*32) = 31 steps per epoch, use around 3 epochs
+gradient_penalty_loss_weight = None  # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
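+# NOTE: the adversarial terms above are gated by adopt_weight() in vae_3d_v2.py,
+# so the discriminator/generator losses contribute 0 until global_step reaches
+# discriminator_start.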
 ema_decay = 0.999  # ema decay factor for generator

@@ -76,11 +74,11 @@ outputs = "outputs"
 wandb = False

 # Training
-''' NOTE:
+""" NOTE:
 magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
-==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200], 
+==> ours num_frames = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
 3-6 epochs for pexel, from pexel observation it's correct
-'''
+"""
 epochs = 200
 log_every = 1

diff --git a/opensora/models/vae/model_utils.py b/opensora/models/vae/model_utils.py
index 421dfea..c00f886 100644
--- a/opensora/models/vae/model_utils.py
+++ b/opensora/models/vae/model_utils.py
@@ -1,12 +1,9 @@
-import functools
-import math
-from typing import Any, Optional, Sequence, Type
-
-import torch.nn as nn
 import numpy as np
 import torch
-from taming.modules.losses.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
-from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
+import torch.nn as nn
+
+# from taming.modules.losses.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
+# from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
 from einops import rearrange

 """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""

@@ -17,23 +14,23 @@ from einops import rearrange
 # if norm_type == 'LN':
 #     # supply a few args with partial function and pass the rest of the args when this norm_fn is called
 #     norm_fn = functools.partial(nn.LayerNorm, dtype=dtype)
-# elif norm_type == 'GN': # 
+# elif norm_type == 'GN':  #
 #     norm_fn = functools.partial(nn.GroupNorm, dtype=dtype)
 # elif norm_type is None:
 #     norm_fn = lambda: (lambda x: x)
 # else:
 #     raise NotImplementedError(f'norm_type: {norm_type}')
 # return norm_fn
-    
+

 class DiagonalGaussianDistribution(object):
     def __init__(
-        self, 
-        parameters, 
+        self,
+        parameters,
         deterministic=False,
     ):
         self.parameters = parameters
-        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) # SCH: channels dim
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)  # SCH: channels dim
         self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
         self.deterministic = deterministic
         self.std = torch.exp(0.5 * self.logvar)
@@ -48,40 +45,39 @@ class DiagonalGaussianDistribution(object):

     def kl(self, other=None):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         else:
-            if other is None: # SCH: assumes other is a standard normal distribution
-                return 0.5 * torch.sum(torch.pow(self.mean, 2)
-                                       + self.var - 1.0 - self.logvar,
-                                       dim=[1, 2, 3, 4])
+            if other is None:  # SCH: assumes other is a standard normal distribution
+                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3, 4])
             else:
                 return 0.5 * torch.sum(
                     torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3, 4])
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2, 3, 4],
+                )

-    def nll(self, sample, dims=[1,2,3,4]): # TODO: what does this do?
+    def nll(self, sample, dims=[1, 2, 3, 4]):
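+        """Negative log-likelihood of `sample` under N(mean, var), summed over `dims`."""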
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(
-            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
-            dim=dims)
+        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

-    def mode(self): # SCH: used for vae inference?
+    def mode(self):  # SCH: used for vae inference?
         return self.mean
-    
+

 class VEA3DLoss(nn.Module):
     def __init__(
         self,
-        # disc_start, 
-        logvar_init=0.0, 
-        kl_weight=1.0, 
+        # disc_start,
+        logvar_init=0.0,
+        kl_weight=1.0,
         pixelloss_weight=1.0,
-        perceptual_weight=0.1, 
+        perceptual_weight=0.1,
         disc_loss="hinge",
-
     ):
         super().__init__()
         assert disc_loss in ["hinge", "vanilla"]
@@ -92,28 +88,27 @@ class VEA3DLoss(nn.Module):
         # output log variance
         self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

-
     def forward(
         self,
         inputs,
         reconstructions,
         posteriors,
         # optimizer_idx,
-        # global_step, 
+        # global_step,
         weights=None,
     ):
         rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())

-        if self.perceptual_weight > 0: # NOTE: need in_channels == 3 in order to use!
+        if self.perceptual_weight > 0:  # NOTE: need in_channels == 3 in order to use!
             assert inputs.size(1) == 3, f"using vgg16 that requires 3 input channels but got {inputs.size(1)}"
             # SCH: transform to [(B,T), C, H, W] shape for perceptual loss over each frame
             B = inputs.shape[0]
-            inputs = rearrange(inputs,"B C T H W -> (B T) C H W")
+            inputs = rearrange(inputs, "B C T H W -> (B T) C H W")
             reconstructions = rearrange(reconstructions, "B C T H W -> (B T) C H W")
             # permutated_input = torch.permute(inputs, (0, 2, 1, 3, 4)) # [B, C, T, H, W] --> [B, T, C, H, W]
             # permutated_rec = torch.permute(reconstructions, (0, 2, 1, 3, 4))
             # data_shape = permutated_input.size()
             # p_loss = self.perceptual_loss(
-            #     permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(), 
+            #     permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(),
             #     permutated_rec.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous()
             # )
             p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
@@ -126,32 +121,32 @@ class VEA3DLoss(nn.Module):
         nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
         weighted_nll_loss = nll_loss
         if weights is not None:
-            weighted_nll_loss = weights*nll_loss
+            weighted_nll_loss = weights * nll_loss
         weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
         nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
         kl_loss = posteriors.kl()
         kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

-        loss = weighted_nll_loss + self.kl_weight * kl_loss # TODO: add discriminator loss later
+        loss = weighted_nll_loss + self.kl_weight * kl_loss  # TODO: add discriminator loss later

         return loss

+
 class VEA3DLossWithDiscriminator(nn.Module):
     def __init__(
         self,
-        # disc_start, 
-        logvar_init=0.0, 
-        kl_weight=1.0, 
+        # disc_start,
+        logvar_init=0.0,
+        kl_weight=1.0,
         pixelloss_weight=1.0,
-        disc_num_layers=3, 
-        disc_in_channels=3, 
-        disc_factor=1.0, 
+        disc_num_layers=3,
+        disc_in_channels=3,
+        disc_factor=1.0,
         disc_weight=1.0,
-        perceptual_weight=1.0, 
-        use_actnorm=False, 
+        perceptual_weight=1.0,
+        use_actnorm=False,
         disc_conditional=False,
         disc_loss="hinge",
-
     ):
         super().__init__()
         assert disc_loss in ["hinge", "vanilla"]
@@ -185,53 +180,53 @@ class VEA3DLossWithDiscriminator(nn.Module):
         #     d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
         #     d_weight = d_weight * self.discriminator_weight
         #     return d_weight
-    
+
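+    # NOTE: the commented-out block above is the standard VQGAN adaptive weight:
+    # d_weight = ||grad(nll_loss)|| / (||grad(g_loss)|| + 1e-4) w.r.t. the last
+    # layer, clamped to [0, 1e4]; a live version exists in vae_3d_v2.AdversarialLoss.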
     def forward(
         self,
         inputs,
         reconstructions,
         posteriors,
         # optimizer_idx,
-        # global_step, 
-        last_layer=None, 
-        cond=None, 
+        # global_step,
+        last_layer=None,
+        cond=None,
         split="train",
         weights=None,
     ):
         rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())

-        if self.perceptual_weight > 0: # NOTE: need in_channels == 3 in order to use!
-            assert inputs.size(1) == 3, f"using vgg16 that requires 3 input channels but got {inputs.size(1)} "
+        if self.perceptual_weight > 0:  # NOTE: need in_channels == 3 in order to use!
+            assert inputs.size(1) == 3, f"using vgg16 that requires 3 input channels but got {inputs.size(1)} "
             # SCH: transform to [(B,T), C, H, W] shape for perceptual loss over each frame
-            permutated_input = torch.permute(inputs, (0, 2, 1, 3, 4)) # [B, C, T, H, W] --> [B, T, C, H, W]
+            permutated_input = torch.permute(inputs, (0, 2, 1, 3, 4))  # [B, C, T, H, W] --> [B, T, C, H, W]
             permutated_rec = torch.permute(reconstructions, (0, 2, 1, 3, 4))
             data_shape = permutated_input.size()
             p_loss = self.perceptual_loss(
-                permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(), 
-                permutated_rec.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous()
+                permutated_input.reshape(-1, data_shape[-3], data_shape[-2], data_shape[-1]).contiguous(),
+                permutated_rec.reshape(-1, data_shape[-3], data_shape[-2], data_shape[-1]).contiguous(),
             )
             # SCH: shape back p_loss
-            permuted_p_loss = torch.permute(p_loss.reshape(data_shape[0], data_shape[1], 1, 1, 1), (0,2,1,3,4))
+            permuted_p_loss = torch.permute(p_loss.reshape(data_shape[0], data_shape[1], 1, 1, 1), (0, 2, 1, 3, 4))
             rec_loss = rec_loss + self.perceptual_weight * permuted_p_loss

         nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
         weighted_nll_loss = nll_loss
         if weights is not None:
-            weighted_nll_loss = weights*nll_loss
+            weighted_nll_loss = weights * nll_loss
         weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
         nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
         kl_loss = posteriors.kl()
         kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

-        loss = weighted_nll_loss + self.kl_weight * kl_loss # TODO: add discriminator loss later
+        loss = weighted_nll_loss + self.kl_weight * kl_loss  # TODO: add discriminator loss later

-        # log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 
+        # log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
         #     "{}/logvar".format(split): self.logvar.detach(),
-        #     "{}/kl_loss".format(split): kl_loss.detach().mean(), 
+        #     "{}/kl_loss".format(split): kl_loss.detach().mean(),
         #     "{}/nll_loss".format(split): nll_loss.detach().mean(),
         #     "{}/rec_loss".format(split): rec_loss.detach().mean(),
         #     # "{}/d_weight".format(split): d_weight.detach(),
         #     # "{}/disc_factor".format(split): torch.tensor(disc_factor),
         #     # "{}/g_loss".format(split): g_loss.detach().mean(),
         #     }
-        
-        return loss
\ No newline at end of file
+
+        return loss
diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py
index 85b5e4a..c220b7b 100644
--- a/opensora/models/vae/vae_3d_v2.py
+++ b/opensora/models/vae/vae_3d_v2.py
@@ -1,21 +1,19 @@
 import functools
-from typing import Any, Dict, Tuple, Type, Union, Sequence, Optional
-from absl import logging
-import torch
-import torch.nn as nn
-import numpy as np
-from numpy import typing as nptyping
-from opensora.models.vae import model_utils
-from opensora.registry import MODELS
-from opensora.utils.ckpt_utils import load_checkpoint, find_model
-from einops import rearrange, repeat, pack, unpack
-import torch.nn.functional as F
-import torchvision
-from torchvision.models import VGG16_Weights
-from opensora.models.vae.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
-from torch import nn
 import math
-import os
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import pack, rearrange, repeat, unpack
+from torch import nn
+
+from opensora.models.vae import model_utils
+from opensora.models.vae.lpips import LPIPS  # need to pip install https://github.com/CompVis/taming-transformers
+from opensora.registry import MODELS
+from opensora.utils.ckpt_utils import find_model, load_checkpoint
+

 # from diffusers.models.modeling_utils import ModelMixin

@@ -30,58 +28,68 @@
 NOTE: !!! opensora reads video into [B,C,T,H,W] format output
 """

-def cast_tuple(t, length = 1):
+
+
+def cast_tuple(t, length=1):
     return t if isinstance(t, tuple) else ((t,) * length)

+
 def divisible_by(num, den):
     return (num % den) == 0

+
 def is_odd(n):
     return not divisible_by(n, 2)

-def pad_at_dim(t, pad, dim = -1, value = 0.):
-    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
-    zeros = ((0, 0) * dims_from_right)
-    return F.pad(t, (*zeros, *pad), value = value)
+
+def pad_at_dim(t, pad, dim=-1, value=0.0):
+    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = (0, 0) * dims_from_right
+    return F.pad(t, (*zeros, *pad), value=value)

+
 def pick_video_frame(video, frame_indices):
     """get frame_indices from the video of [B, C, T, H, W] and return images of [B, C, H, W]"""
     batch, device = video.shape[0], video.device
-    video = rearrange(video, 'b c f ... -> b f c ...')
-    batch_indices = torch.arange(batch, device = device)
-    batch_indices = rearrange(batch_indices, 'b -> b 1')
+    video = rearrange(video, "b c f ... -> b f c ...")
+    batch_indices = torch.arange(batch, device=device)
+    batch_indices = rearrange(batch_indices, "b -> b 1")
     images = video[batch_indices, frame_indices]
-    images = rearrange(images, 'b 1 c ... -> b c ...')
+    images = rearrange(images, "b 1 c ... -> b c ...")
     return images

+
 def exists(v):
     return v is not None

+
 # ============== Generator Adversarial Loss Functions ==============
 def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred):
-  assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0
-  lecam_loss = torch.mean(torch.pow(nn.ReLU()(real_pred - ema_fake_pred), 2))
-  lecam_loss += torch.mean(torch.pow(nn.ReLU()(ema_real_pred - fake_pred), 2))
-  return lecam_loss
+    assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0
+    lecam_loss = torch.mean(torch.pow(nn.ReLU()(real_pred - ema_fake_pred), 2))
+    lecam_loss += torch.mean(torch.pow(nn.ReLU()(ema_real_pred - fake_pred), 2))
+    return lecam_loss
+
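+# NOTE: lecam_reg above is LeCam regularization (cf. Tseng et al., 2021): it
+# penalizes the discriminator's mean real/fake predictions for drifting past
+# the EMAs of the opposite side, bounding D and stabilizing GAN training on
+# limited data.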
 # Open-Sora-Plan
 # Very bad, do not use
 def r1_penalty(real_img, real_pred):
     """R1 regularization for discriminator. The core idea is to
-    penalize the gradient on real data alone: when the
-    generator distribution produces the true data distribution
-    and the discriminator is equal to 0 on the data manifold, the
-    gradient penalty ensures that the discriminator cannot create
-    a non-zero gradient orthogonal to the data manifold without
-    suffering a loss in the GAN game. 
+    penalize the gradient on real data alone: when the
+    generator distribution produces the true data distribution
+    and the discriminator is equal to 0 on the data manifold, the
+    gradient penalty ensures that the discriminator cannot create
+    a non-zero gradient orthogonal to the data manifold without
+    suffering a loss in the GAN game.

-    Ref:
-    Eq. 9 in Which training methods for GANs do actually converge.
-    """
+    Ref:
+    Eq. 9 in Which training methods for GANs do actually converge.
+    """
     grad_real = torch.autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
     grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
     return grad_penalty

+
 # Open-Sora-Plan
 # Implementation as described by https://arxiv.org/abs/1704.00028
 # TODO: check out the code
 def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
@@ -101,7 +109,7 @@ def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
     alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))

     # interpolate between real_data and fake_data
-    interpolates = alpha * real_data + (1. - alpha) * fake_data
+    interpolates = alpha * real_data + (1.0 - alpha) * fake_data
     interpolates = torch.autograd.Variable(interpolates, requires_grad=True)

     disc_interpolates = discriminator(interpolates)
@@ -111,12 +119,13 @@ def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
         grad_outputs=torch.ones_like(disc_interpolates),
         create_graph=True,
         retain_graph=True,
-        only_inputs=True)[0]
+        only_inputs=True,
+    )[0]

     if weight is not None:
         gradients = gradients * weight

-    gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
+    gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

     if weight is not None:
         gradients_penalty /= torch.mean(weight)
@@ -126,16 +135,16 @@ def gradient_penalty_fn(images, output):
     # batch_size = images.shape[0]
     gradients = torch.autograd.grad(
-        outputs = output,
-        inputs = images,
-        grad_outputs = torch.ones(output.size(), device = images.device),
-        create_graph = True,
-        retain_graph = True,
-        only_inputs = True
+        outputs=output,
+        inputs=images,
+        grad_outputs=torch.ones(output.size(), device=images.device),
+        create_graph=True,
+        retain_graph=True,
+        only_inputs=True,
     )[0]

-    gradients = rearrange(gradients, 'b ... -> b (...)')
-    return ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
+    gradients = rearrange(gradients, "b ... -> b (...)")
-> b (...)") + return ((gradients.norm(2, dim=1) - 1) ** 2).mean() # ============== Discriminator Adversarial Loss Functions ============== @@ -145,31 +154,32 @@ def hinge_d_loss(logits_real, logits_fake): d_loss = 0.5 * (loss_real + loss_fake) return d_loss + def vanilla_d_loss(logits_real, logits_fake): d_loss = 0.5 * ( - torch.mean(torch.nn.functional.softplus(-logits_real)) + - torch.mean(torch.nn.functional.softplus(logits_fake))) + torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake)) + ) return d_loss + # from MAGVIT, used in place hof hinge_d_loss def sigmoid_cross_entropy_with_logits(labels, logits): # The final formulation is: max(x, 0) - x * z + log(1 + exp(-abs(x))) zeros = torch.zeros_like(logits, dtype=logits.dtype) - condition = (logits >= zeros) + condition = logits >= zeros relu_logits = torch.where(condition, logits, zeros) neg_abs_logits = torch.where(condition, -logits, logits) return relu_logits - logits * labels + torch.log1p(torch.exp(neg_abs_logits)) - - def xavier_uniform_weight_init(m): if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear): - nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu')) + nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain("relu")) if m.bias is not None: nn.init.zeros_(m.bias) # print("initialized module to xavier_uniform:", m) + def Sequential(*modules): modules = [*filter(exists, modules)] @@ -178,26 +188,28 @@ def Sequential(*modules): return nn.Sequential(*modules) + def SameConv2d(dim_in, dim_out, kernel_size): kernel_size = cast_tuple(kernel_size, 2) padding = [k // 2 for k in kernel_size] - return nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding) + return nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, padding=padding) -def adopt_weight(weight, global_step, threshold=0, value=0.): +def adopt_weight(weight, global_step, threshold=0, value=0.0): if global_step < threshold: weight = value return weight + class CausalConv3d(nn.Module): def __init__( self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], - pad_mode = 'constant', - strides = None, # allow custom stride - **kwargs + pad_mode="constant", + strides=None, # allow custom stride + **kwargs, ): super().__init__() kernel_size = cast_tuple(kernel_size, 3) @@ -206,8 +218,8 @@ class CausalConv3d(nn.Module): assert is_odd(height_kernel_size) and is_odd(width_kernel_size) - dilation = kwargs.pop('dilation', 1) - stride = strides[0] if strides is not None else kwargs.pop('stride', 1) + dilation = kwargs.pop("dilation", 1) + stride = strides[0] if strides is not None else kwargs.pop("stride", 1) self.pad_mode = pad_mode time_pad = dilation * (time_kernel_size - 1) + (1 - stride) @@ -222,33 +234,33 @@ class CausalConv3d(nn.Module): # if padding == "same" and not all([pad == 1 for pad in padding]): # padding = "valid" dilation = (dilation, 1, 1) - self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs) + self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) def forward(self, x): - pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant' + pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant" - x = F.pad(x, self.time_causal_padding, mode = pad_mode) + x = F.pad(x, self.time_causal_padding, mode=pad_mode) return self.conv(x) class ResBlock(nn.Module): def __init__( - self, - in_channels, # SCH: added - filters, - conv_fn, - 

 class ResBlock(nn.Module):
     def __init__(
-        self, 
-        in_channels, # SCH: added
-        filters, 
-        conv_fn, 
-        activation_fn=nn.SiLU, 
-        use_conv_shortcut=False, 
-        num_groups=32,
-        device="cpu", 
-        dtype=torch.bfloat16, 
+        self,
+        in_channels,  # SCH: added
+        filters,
+        conv_fn,
+        activation_fn=nn.SiLU,
+        use_conv_shortcut=False,
+        num_groups=32,
+        device="cpu",
+        dtype=torch.bfloat16,
     ):
         super().__init__()
         self.in_channels = in_channels
         self.filters = filters
         self.activate = activation_fn()
         self.use_conv_shortcut = use_conv_shortcut
-        
+
         # SCH: MAGVIT uses GroupNorm by default
         self.norm1 = nn.GroupNorm(num_groups, in_channels, device=device, dtype=dtype)
         self.conv1 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
@@ -256,11 +268,10 @@ class ResBlock(nn.Module):
         self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False)
         if in_channels != filters:
             if self.use_conv_shortcut:
-                self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False) 
+                self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
             else:
                 self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(1, 1, 1), bias=False)

-
     def forward(self, x):
         # device, dtype = x.device, x.dtype
         # input_dim = x.shape[1]
@@ -271,79 +282,88 @@ class ResBlock(nn.Module):
         x = self.norm2(x)
         x = self.activate(x)
         x = self.conv2(x)
-        if self.in_channels != self.filters: # SCH: ResBlock X->Y
+        if self.in_channels != self.filters:  # SCH: ResBlock X->Y
             residual = self.conv3(residual)
-        return x + residual 
-    
+        return x + residual
+
+
 # SCH: own implementation modified on top of: discriminator with anti-aliased downsampling (blurpool Zhang et al.)
 class BlurPool3D(nn.Module):
     def __init__(
-        self, 
-        channels, 
-        pad_type='reflect', 
-        filt_size=3, 
-        stride=2, 
-        pad_off=0,
-        device="cpu", 
-        dtype=torch.bfloat16, 
+        self,
+        channels,
+        pad_type="reflect",
+        filt_size=3,
+        stride=2,
+        pad_off=0,
+        device="cpu",
+        dtype=torch.bfloat16,
     ):
         super(BlurPool3D, self).__init__()
         self.filt_size = filt_size
         self.pad_off = pad_off
         self.pad_sizes = [
-            int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)),
-            int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)),
-            int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))
+            int(1.0 * (filt_size - 1) / 2),
+            int(np.ceil(1.0 * (filt_size - 1) / 2)),
+            int(1.0 * (filt_size - 1) / 2),
+            int(np.ceil(1.0 * (filt_size - 1) / 2)),
+            int(1.0 * (filt_size - 1) / 2),
+            int(np.ceil(1.0 * (filt_size - 1) / 2)),
         ]
-        self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes]
+        self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
         self.stride = stride
-        self.off = int((self.stride-1)/2.)
+        self.off = int((self.stride - 1) / 2.0)
         self.channels = channels

-        if(self.filt_size==1):
-            a = np.array([1.,])
-        elif(self.filt_size==2):
-            a = np.array([1., 1.])
-        elif(self.filt_size==3):
-            a = np.array([1., 2., 1.])
-        elif(self.filt_size==4):
-            a = np.array([1., 3., 3., 1.])
-        elif(self.filt_size==5):
-            a = np.array([1., 4., 6., 4., 1.])
-        elif(self.filt_size==6):
-            a = np.array([1., 5., 10., 10., 5., 1.])
-        elif(self.filt_size==7):
-            a = np.array([1., 6., 15., 20., 15., 6., 1.])
+        if self.filt_size == 1:
+            a = np.array(
+                [
+                    1.0,
+                ]
+            )
+        elif self.filt_size == 2:
+            a = np.array([1.0, 1.0])
+        elif self.filt_size == 3:
+            a = np.array([1.0, 2.0, 1.0])
+        elif self.filt_size == 4:
+            a = np.array([1.0, 3.0, 3.0, 1.0])
+        elif self.filt_size == 5:
+            a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
+        elif self.filt_size == 6:
+            a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
+        elif self.filt_size == 7:
+            a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])

-        filt_2d = a[:,None]*a[None,:]
+        filt_2d = a[:, None] * a[None, :]
         filt_3d = torch.Tensor(a[:, None, None] * filt_2d[None, :, :]).to(device, dtype)
-        
-        filt = filt_3d/torch.sum(filt_3d) # SCH: modified to it 3D
-        self.register_buffer('filt', filt[None,None,:,:,:].repeat((self.channels,1,1,1,1)))
+
+        filt = filt_3d / torch.sum(filt_3d)  # SCH: modified it to 3D
+        self.register_buffer("filt", filt[None, None, :, :, :].repeat((self.channels, 1, 1, 1, 1)))

         self.pad = get_pad_layer(pad_type)(self.pad_sizes)

     def forward(self, inp):
-        if(self.filt_size==1):
-            if(self.pad_off==0):
-                return inp[:,:,::self.stride,::self.stride]
+        if self.filt_size == 1:
+            if self.pad_off == 0:
+                return inp[:, :, :: self.stride, :: self.stride]
             else:
-                return self.pad(inp)[:,:,::self.stride,::self.stride]
+                return self.pad(inp)[:, :, :: self.stride, :: self.stride]
         else:
             return F.conv3d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
-    
+

 def get_pad_layer(pad_type):
-    if(pad_type in ['refl','reflect']):
+    if pad_type in ["refl", "reflect"]:
         PadLayer = nn.ReflectionPad3d
-    elif(pad_type in ['repl','replicate']):
+    elif pad_type in ["repl", "replicate"]:
         PadLayer = nn.ReplicationPad3d
-    elif(pad_type=='zero'):
+    elif pad_type == "zero":
         PadLayer = nn.ZeroPad3d
     else:
-        print('Pad type [%s] not recognized'%pad_type)
+        print("Pad type [%s] not recognized" % pad_type)
     return PadLayer
-    
+
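+# NOTE: BlurPool3D above implements anti-aliased downsampling (Zhang, "Making
+# Convolutional Networks Shift-Invariant Again"): a normalized separable
+# binomial filter (e.g. [1, 2, 1] for filt_size=3) applied depthwise
+# (groups=channels) together with the stride, instead of plain strided sampling.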

 class ResBlockDown(nn.Module):
     """3D StyleGAN ResBlock for D."""

@@ -362,13 +382,19 @@ class ResBlockDown(nn.Module):
         self.activation_fn = activation_fn

         # SCH: NOTE: although paper says conv (X->Y, Y->Y), original code implementation is (X->X, X->Y), we follow code
-        self.conv1 = nn.Conv3d(in_channels, in_channels, (3,3,3), padding=1, device=device, dtype=dtype) # NOTE: init to xavier_uniform
+        self.conv1 = nn.Conv3d(
+            in_channels, in_channels, (3, 3, 3), padding=1, device=device, dtype=dtype
+        )  # NOTE: init to xavier_uniform
         self.norm1 = nn.GroupNorm(num_groups, in_channels, device=device, dtype=dtype)
         self.blur = BlurPool3D(in_channels, device=device, dtype=dtype)
-        self.conv2 = nn.Conv3d(in_channels, self.filters,(1,1,1), bias=False, device=device, dtype=dtype) # NOTE: init to xavier_uniform
-        self.conv3 = nn.Conv3d(in_channels, self.filters, (3,3,3), padding=1, device=device, dtype=dtype) # NOTE: init to xavier_uniform
+        self.conv2 = nn.Conv3d(
+            in_channels, self.filters, (1, 1, 1), bias=False, device=device, dtype=dtype
+        )  # NOTE: init to xavier_uniform
+        self.conv3 = nn.Conv3d(
+            in_channels, self.filters, (3, 3, 3), padding=1, device=device, dtype=dtype
+        )  # NOTE: init to xavier_uniform
         self.norm2 = nn.GroupNorm(num_groups, self.filters, device=device, dtype=dtype)

         # self.apply(xavier_uniform_weight_init)
@@ -393,14 +419,16 @@
 # SCH: taken from Open Sora Plan
 def n_layer_disc_weights_init(m):
     classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
+    if classname.find("Conv") != -1:
         nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm') != -1:
+    elif classname.find("BatchNorm") != -1:
         nn.init.normal_(m.weight.data, 1.0, 0.02)
         nn.init.constant_(m.bias.data, 0)

+
 class NLayerDiscriminator3D(nn.Module):
     """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs."""
+
     def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False):
         """
         Construct a 3D PatchGAN discriminator
@@ -428,46 +456,60 @@ class NLayerDiscriminator3D(nn.Module):
         nf_mult_prev = 1
         for n in range(1, n_layers):  # gradually increase the number of filters
             nf_mult_prev = nf_mult
-            nf_mult = min(2 ** n, 8)
+            nf_mult = min(2**n, 8)
             sequence += [
-                nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=(1,2,2), padding=padw, bias=use_bias),
+                nn.Conv3d(
+                    ndf * nf_mult_prev,
+                    ndf * nf_mult,
+                    kernel_size=(kw, kw, kw),
+                    stride=(1, 2, 2),
+                    padding=padw,
+                    bias=use_bias,
+                ),
                 norm_layer(ndf * nf_mult),
-                nn.LeakyReLU(0.2, True)
+                nn.LeakyReLU(0.2, True),
             ]

         nf_mult_prev = nf_mult
-        nf_mult = min(2 ** n_layers, 8)
+        nf_mult = min(2**n_layers, 8)
         sequence += [
-            nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias),
+            nn.Conv3d(
+                ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias
+            ),
             norm_layer(ndf * nf_mult),
-            nn.LeakyReLU(0.2, True)
+            nn.LeakyReLU(0.2, True),
         ]

-        sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
+        sequence += [
+            nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
+        ]  # output 1 channel prediction map
         self.main = nn.Sequential(*sequence)

     def forward(self, input):
         """Standard forward."""
         return self.main(input)
-    
+

 class StyleGANDiscriminatorBlur(nn.Module):
     """StyleGAN Discriminator."""
+
     """
-    SCH: NOTE: 
+    SCH: NOTE:
     this discriminator requires the num_frames to be fixed during training;
-    in case we pre-train with image then train on video, this disciminator's Linear layer would have to be re-trained! 
+    in case we pre-train with image then train on video, this discriminator's Linear layer would have to be re-trained!
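+    (concretely, linear1 takes in_features = C * T * H * W after downsampling,
+    which is fixed by num_frames and image_size at construction time)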
""" + def __init__( self, - image_size = (128, 128), - num_frames = 17, - in_channels = 3, - filters = 128, - channel_multipliers = (2,4,4,4,4), + image_size=(128, 128), + num_frames=17, + in_channels=3, + filters=128, + channel_multipliers=(2, 4, 4, 4, 4), num_groups=32, - dtype = torch.bfloat16, + dtype=torch.bfloat16, device="cpu", - ): + ): super().__init__() self.dtype = dtype @@ -476,42 +518,53 @@ class StyleGANDiscriminatorBlur(nn.Module): self.activation_fn = nn.LeakyReLU(negative_slope=0.2) self.channel_multipliers = channel_multipliers - self.conv1 = nn.Conv3d(in_channels, self.filters, (3, 3, 3), padding=1, device=device, dtype=dtype) # NOTE: init to xavier_uniform + self.conv1 = nn.Conv3d( + in_channels, self.filters, (3, 3, 3), padding=1, device=device, dtype=dtype + ) # NOTE: init to xavier_uniform - prev_filters = self.filters # record in_channels + prev_filters = self.filters # record in_channels self.num_blocks = len(self.channel_multipliers) self.res_block_list = [] for i in range(self.num_blocks): filters = self.filters * self.channel_multipliers[i] - self.res_block_list.append(ResBlockDown(prev_filters, filters, self.activation_fn, device=device, dtype=dtype).apply(xavier_uniform_weight_init)) - prev_filters = filters # update in_channels + self.res_block_list.append( + ResBlockDown(prev_filters, filters, self.activation_fn, device=device, dtype=dtype).apply( + xavier_uniform_weight_init + ) + ) + prev_filters = filters # update in_channels - self.conv2 = nn.Conv3d(prev_filters, prev_filters, (3,3,3), padding=1, device=device, dtype=dtype) # NOTE: init to xavier_uniform + self.conv2 = nn.Conv3d( + prev_filters, prev_filters, (3, 3, 3), padding=1, device=device, dtype=dtype + ) # NOTE: init to xavier_uniform # torch.nn.init.xavier_uniform_(self.conv2.weight) self.norm1 = nn.GroupNorm(num_groups, prev_filters, dtype=dtype, device=device) - scale_factor = 2 ** self.num_blocks - if num_frames % scale_factor != 0: # SCH: NOTE: has first frame which would be padded before usage + scale_factor = 2**self.num_blocks + if num_frames % scale_factor != 0: # SCH: NOTE: has first frame which would be padded before usage time_scaled = num_frames // scale_factor + 1 else: time_scaled = num_frames / scale_factor - - assert self.input_size[0] % scale_factor == 0, f"image width {self.input_size[0]} is not divisible by scale factor {scale_factor}" - assert self.input_size[1] % scale_factor == 0, f"image height {self.input_size[1]} is not divisible by scale factor {scale_factor}" + + assert ( + self.input_size[0] % scale_factor == 0 + ), f"image width {self.input_size[0]} is not divisible by scale factor {scale_factor}" + assert ( + self.input_size[1] % scale_factor == 0 + ), f"image height {self.input_size[1]} is not divisible by scale factor {scale_factor}" w_scaled, h_scaled = self.input_size[0] / scale_factor, self.input_size[1] / scale_factor in_features = int(prev_filters * time_scaled * w_scaled * h_scaled) # (C*T*W*H) - self.linear1 = nn.Linear(in_features, prev_filters, device=device, dtype=dtype) # NOTE: init to xavier_uniform - self.linear2 = nn.Linear(prev_filters, 1, device=device, dtype=dtype) # NOTE: init to xavier_uniform + self.linear1 = nn.Linear(in_features, prev_filters, device=device, dtype=dtype) # NOTE: init to xavier_uniform + self.linear2 = nn.Linear(prev_filters, 1, device=device, dtype=dtype) # NOTE: init to xavier_uniform # self.apply(xavier_uniform_weight_init) def forward(self, x): - x = self.conv1(x) # print("discriminator aft conv:", x.size()) x = 
self.activation_fn(x) - + for i in range(self.num_blocks): x = self.res_block_list[i](x) # print("discriminator resblock down:", x.size()) @@ -520,7 +573,7 @@ class StyleGANDiscriminatorBlur(nn.Module): # print("discriminator aft conv2:", x.size()) x = self.norm1(x) x = self.activation_fn(x) - x = x.reshape((x.shape[0], -1)) # SCH: [B, (C * T * W * H)] ? + x = x.reshape((x.shape[0], -1)) # SCH: [B, (C * T * W * H)] ? # print("discriminator reshape:", x.size()) x = self.linear1(x) @@ -530,21 +583,24 @@ class StyleGANDiscriminatorBlur(nn.Module): x = self.linear2(x) # print("discriminator aft linear2:", x.size()) return x - + + class Encoder(nn.Module): """Encoder Blocks.""" - def __init__(self, - filters = 128, - num_res_blocks = 4, - channel_multipliers = (1, 2, 2, 4), - temporal_downsample = (False, True, True), - num_groups = 32, # for nn.GroupNorm + + def __init__( + self, + filters=128, + num_res_blocks=4, + channel_multipliers=(1, 2, 2, 4), + temporal_downsample=(False, True, True), + num_groups=32, # for nn.GroupNorm # in_out_channels = 3, # SCH: added, in_channels at the start - latent_embed_dim = 512, # num channels for latent vector - # conv_downsample = False, - disable_spatial_downsample = False, # for vae pipeline - custom_conv_padding = None, - activation_fn = 'swish', + latent_embed_dim=512, # num channels for latent vector + # conv_downsample = False, + disable_spatial_downsample=False, # for vae pipeline + custom_conv_padding=None, + activation_fn="swish", device="cpu", dtype=torch.bfloat16, ): @@ -554,15 +610,15 @@ class Encoder(nn.Module): self.channel_multipliers = channel_multipliers self.temporal_downsample = temporal_downsample self.num_groups = num_groups - + self.embedding_dim = latent_embed_dim self.disable_spatial_downsample = disable_spatial_downsample # self.conv_downsample = conv_downsample self.custom_conv_padding = custom_conv_padding - if activation_fn == 'relu': + if activation_fn == "relu": self.activation_fn = nn.ReLU - elif activation_fn == 'swish': + elif activation_fn == "swish": self.activation_fn = nn.SiLU else: raise NotImplementedError @@ -574,7 +630,7 @@ class Encoder(nn.Module): dtype=dtype, device=device, ) - + self.block_args = dict( conv_fn=self.conv_fn, dtype=dtype, @@ -583,7 +639,7 @@ class Encoder(nn.Module): num_groups=self.num_groups, device=device, ) - + # NOTE: moved to VAE for separate first frame processing # self.conv1 = self.conv_fn(in_out_channels, self.filters, kernel_size=(3, 3, 3), bias=False) @@ -593,33 +649,38 @@ class Encoder(nn.Module): self.conv_blocks = [] filters = self.filters - prev_filters = filters # record for in_channels + prev_filters = filters # record for in_channels for i in range(self.num_blocks): # resblock handling - filters = self.filters * self.channel_multipliers[i] # SCH: determine the number out_channels + filters = self.filters * self.channel_multipliers[i] # SCH: determine the number out_channels block_items = [] for _ in range(self.num_res_blocks): block_items.append(ResBlock(prev_filters, filters, **self.block_args)) - prev_filters = filters # update in_channels + prev_filters = filters # update in_channels self.block_res_blocks.append(block_items) - - if i < self.num_blocks - 1: # SCH: T-Causal Conv 3x3x3, 128->128, stride t x stride s x stride s + + if i < self.num_blocks - 1: # SCH: T-Causal Conv 3x3x3, 128->128, stride t x stride s x stride s t_stride = 2 if self.temporal_downsample[i] else 1 s_stride = 2 if not self.disable_spatial_downsample else 1 - 
self.conv_blocks.append(self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride))) # SCH: should be same in_channel and out_channel - prev_filters = filters # update in_channels - + self.conv_blocks.append( + self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride)) + ) # SCH: should be same in_channel and out_channel + prev_filters = filters # update in_channels # last layer res block self.res_blocks = [] for _ in range(self.num_res_blocks): self.res_blocks.append(ResBlock(prev_filters, filters, **self.block_args)) - prev_filters = filters # update in_channels + prev_filters = filters # update in_channels # MAGVIT uses Group Normalization - self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, dtype=dtype, device=device) # SCH: separate channels into 32 groups + self.norm1 = nn.GroupNorm( + self.num_groups, prev_filters, dtype=dtype, device=device + ) # SCH: separate channels into 32 groups - self.conv2 = nn.Conv3d(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), dtype=dtype, device=device, padding="same") + self.conv2 = nn.Conv3d( + prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), dtype=dtype, device=device, padding="same" + ) def forward(self, x): # dtype, device = x.dtype, x.device @@ -637,7 +698,6 @@ class Encoder(nn.Module): if i < self.num_blocks - 1: x = self.conv_blocks[i](x) # print("encoder:", x.size()) - for i in range(self.num_res_blocks): x = self.res_blocks[i](x) @@ -648,21 +708,24 @@ class Encoder(nn.Module): x = self.conv2(x) # print("encoder:", x.size()) return x - + + class Decoder(nn.Module): """Decoder Blocks.""" - def __init__(self, - latent_embed_dim = 512, - filters = 128, + + def __init__( + self, + latent_embed_dim=512, + filters=128, # in_out_channels = 4, - num_res_blocks = 4, - channel_multipliers = (1, 2, 2, 4), - temporal_downsample = (False, True, True), - num_groups = 32, # for nn.GroupNorm + num_res_blocks=4, + channel_multipliers=(1, 2, 2, 4), + temporal_downsample=(False, True, True), + num_groups=32, # for nn.GroupNorm # upsample = "nearest+conv", # options: "deconv", "nearest+conv" - disable_spatial_upsample = False, # for vae pipeline - custom_conv_padding = None, - activation_fn = 'swish', + disable_spatial_upsample=False, # for vae pipeline + custom_conv_padding=None, + activation_fn="swish", device="cpu", dtype=torch.bfloat16, ): @@ -674,19 +737,19 @@ class Decoder(nn.Module): self.channel_multipliers = channel_multipliers self.temporal_downsample = temporal_downsample self.num_groups = num_groups - + # self.upsample = upsample - self.s_stride = 1 if disable_spatial_upsample else 2 # spatial stride + self.s_stride = 1 if disable_spatial_upsample else 2 # spatial stride self.custom_conv_padding = custom_conv_padding # self.norm_type = self.config.vqvae.norm_type # self.num_remat_block = self.config.vqvae.get('num_dec_remat_blocks', 0) - if activation_fn == 'relu': + if activation_fn == "relu": self.activation_fn = nn.ReLU - elif activation_fn == 'swish': - self.activation_fn = nn.SiLU + elif activation_fn == "swish": + self.activation_fn = nn.SiLU else: - raise NotImplementedError + raise NotImplementedError self.activate = self.activation_fn() self.conv_fn = functools.partial( @@ -722,34 +785,36 @@ class Decoder(nn.Module): # self.upsampler = nn.Upsample(scale_factor=(1,2,2)) # ResBlocks and conv upsample - prev_filters = filters # SCH: in_channels + prev_filters = filters # SCH: in_channels self.block_res_blocks = [] self.num_blocks = 
len(self.channel_multipliers) self.conv_blocks = [] # SCH: reverse to keep track of the in_channels, but append also in a reverse direction - for i in reversed(range(self.num_blocks)): + for i in reversed(range(self.num_blocks)): filters = self.filters * self.channel_multipliers[i] # resblock handling block_items = [] for _ in range(self.num_res_blocks): block_items.append(ResBlock(prev_filters, filters, **self.block_args)) - prev_filters = filters # SCH: update in_channels - self.block_res_blocks.insert(0, block_items) # SCH: append in front - + prev_filters = filters # SCH: update in_channels + self.block_res_blocks.insert(0, block_items) # SCH: append in front + # conv blocks with upsampling if i > 0: t_stride = 2 if self.temporal_downsample[i - 1] else 1 # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2 - self.conv_blocks.insert(0, - self.conv_fn(prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3,3,3)) + self.conv_blocks.insert( + 0, + self.conv_fn( + prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3, 3, 3) + ), ) - + self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, device=device, dtype=dtype) # NOTE: moved to VAE for separate first frame processing # self.conv2 = self.conv_fn(prev_filters, self.output_dim, kernel_size=(3, 3, 3)) - def forward( self, x, @@ -762,42 +827,49 @@ class Decoder(nn.Module): for i in range(self.num_res_blocks): x = self.res_blocks[i](x) # print("decoder:", x.size()) - for i in reversed(range(self.num_blocks)): # reverse here to make decoder symmetric with encoder + for i in reversed(range(self.num_blocks)): # reverse here to make decoder symmetric with encoder for j in range(self.num_res_blocks): x = self.block_res_blocks[i][j](x) # print("decoder:", x.size()) if i > 0: t_stride = 2 if self.temporal_downsample[i - 1] else 1 # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2 - x = self.conv_blocks[i-1](x) - x = rearrange(x, "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)", ts=t_stride, hs=self.s_stride, ws=self.s_stride) + x = self.conv_blocks[i - 1](x) + x = rearrange( + x, + "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)", + ts=t_stride, + hs=self.s_stride, + ws=self.s_stride, + ) # print("decoder:", x.size()) x = self.norm1(x) x = self.activate(x) # NOTE: moved to VAE for separate first frame processing - # x = self.conv2(x) + # x = self.conv2(x) return x @MODELS.register_module() -class VAE_3D_V2(nn.Module): # , ModelMixin - """The 3D VAE """ +class VAE_3D_V2(nn.Module): # , ModelMixin + """The 3D VAE""" + def __init__( - self, - latent_embed_dim = 256, - filters = 128, - num_res_blocks = 2, - separate_first_frame_encoding = False, - channel_multipliers = (1, 2, 2, 4), - temporal_downsample = (True, True, False), - num_groups = 32, # for nn.GroupNorm - disable_space = False, - custom_conv_padding = None, - activation_fn = 'swish', - in_out_channels = 4, - kl_embed_dim = 64, - encoder_double_z = True, + self, + latent_embed_dim=256, + filters=128, + num_res_blocks=2, + separate_first_frame_encoding=False, + channel_multipliers=(1, 2, 2, 4), + temporal_downsample=(True, True, False), + num_groups=32, # for nn.GroupNorm + disable_space=False, + custom_conv_padding=None, + activation_fn="swish", + in_out_channels=4, + kl_embed_dim=64, + encoder_double_z=True, device="cpu", dtype="bf16", ): @@ -809,12 +881,11 @@ class VAE_3D_V2(nn.Module): # , ModelMixin elif dtype == "fp16": dtype = torch.float16 else: 
- raise NotImplementedError(f'dtype: {dtype}') - + raise NotImplementedError(f"dtype: {dtype}") # ==== Model Params ==== # self.image_size = cast_tuple(image_size, 2) - self.time_downsample_factor = 2**sum(temporal_downsample) + self.time_downsample_factor = 2 ** sum(temporal_downsample) self.time_padding = self.time_downsample_factor - 1 self.separate_first_frame_encoding = separate_first_frame_encoding @@ -825,50 +896,52 @@ class VAE_3D_V2(nn.Module): # , ModelMixin # ==== Model Initialization ==== # encoder & decoder first and last conv layer - # SCH: NOTE: following MAGVIT, conv in bias=False in encoder first conv - self.conv_in = CausalConv3d(in_out_channels, filters, kernel_size=(3, 3, 3), bias=False, dtype=dtype, device=device) + # SCH: NOTE: following MAGVIT, conv in bias=False in encoder first conv + self.conv_in = CausalConv3d( + in_out_channels, filters, kernel_size=(3, 3, 3), bias=False, dtype=dtype, device=device + ) self.conv_in_first_frame = nn.Identity() self.conv_out_first_frame = nn.Identity() if separate_first_frame_encoding: - self.conv_in_first_frame = SameConv2d(in_out_channels, filters, (3,3)) - self.conv_out_first_frame = SameConv2d(filters, in_out_channels, (3,3)) + self.conv_in_first_frame = SameConv2d(in_out_channels, filters, (3, 3)) + self.conv_out_first_frame = SameConv2d(filters, in_out_channels, (3, 3)) self.conv_out = CausalConv3d(filters, in_out_channels, 3, dtype=dtype, device=device) self.encoder = Encoder( - filters = filters, - num_res_blocks=num_res_blocks, - channel_multipliers=channel_multipliers, + filters=filters, + num_res_blocks=num_res_blocks, + channel_multipliers=channel_multipliers, temporal_downsample=temporal_downsample, - num_groups = num_groups, # for nn.GroupNorm + num_groups=num_groups, # for nn.GroupNorm # in_out_channels = in_out_channels, - latent_embed_dim = latent_embed_dim * 2 if encoder_double_z else latent_embed_dim, - # conv_downsample = conv_downsample, + latent_embed_dim=latent_embed_dim * 2 if encoder_double_z else latent_embed_dim, + # conv_downsample = conv_downsample, disable_spatial_downsample=disable_space, - custom_conv_padding = custom_conv_padding, - activation_fn = activation_fn, - device = device, - dtype = dtype, + custom_conv_padding=custom_conv_padding, + activation_fn=activation_fn, + device=device, + dtype=dtype, ) self.decoder = Decoder( - latent_embed_dim = latent_embed_dim, - filters = filters, - # in_out_channels = in_out_channels, - num_res_blocks = num_res_blocks, - channel_multipliers = channel_multipliers, - temporal_downsample = temporal_downsample, - num_groups = num_groups, # for nn.GroupNorm + latent_embed_dim=latent_embed_dim, + filters=filters, + # in_out_channels = in_out_channels, + num_res_blocks=num_res_blocks, + channel_multipliers=channel_multipliers, + temporal_downsample=temporal_downsample, + num_groups=num_groups, # for nn.GroupNorm # upsample = upsample, # options: "deconv", "nearest+conv" disable_spatial_upsample=disable_space, - custom_conv_padding = custom_conv_padding, - activation_fn = activation_fn, - device = device, - dtype = dtype, + custom_conv_padding=custom_conv_padding, + activation_fn=activation_fn, + device=device, + dtype=dtype, ) - + if encoder_double_z: - self.quant_conv = nn.Conv3d(2*latent_embed_dim, 2*kl_embed_dim, 1, device=device, dtype=dtype) + self.quant_conv = nn.Conv3d(2 * latent_embed_dim, 2 * kl_embed_dim, 1, device=device, dtype=dtype) else: - self.quant_conv = nn.Conv3d(latent_embed_dim, 2*kl_embed_dim, 1, device=device, dtype=dtype) + self.quant_conv = 
nn.Conv3d(latent_embed_dim, 2 * kl_embed_dim, 1, device=device, dtype=dtype) self.post_quant_conv = nn.Conv3d(kl_embed_dim, latent_embed_dim, 1, device=device, dtype=dtype) def get_latent_size(self, input_size): @@ -876,33 +949,33 @@ class VAE_3D_V2(nn.Module): # , ModelMixin assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size" input_size = [input_size[i] // self.patch_size[i] for i in range(3)] return input_size - + def encode( self, video, - video_contains_first_frame = True, + video_contains_first_frame=True, ): encode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame # whether to pad video or not if video_contains_first_frame: video_len = video.shape[2] - video = pad_at_dim(video, (self.time_padding, 0), value = 0., dim = 2) + video = pad_at_dim(video, (self.time_padding, 0), value=0.0, dim=2) video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])] # print("pre-encoder:", video.size()) # NOTE: moved encoder conv1 here for separate first frame encoding if encode_first_frame_separately: - pad, first_frame, video = unpack(video, video_packed_shape, 'b c * h w') + pad, first_frame, video = unpack(video, video_packed_shape, "b c * h w") first_frame = self.conv_in_first_frame(first_frame) video = self.conv_in(video) # print("pre-encoder:", video.size()) if encode_first_frame_separately: - video, _ = pack([first_frame, video], 'b c * h w') - video = pad_at_dim(video, (self.time_padding, 0), dim = 2) + video, _ = pack([first_frame, video], "b c * h w") + video = pad_at_dim(video, (self.time_padding, 0), dim=2) encoded_feature = self.encoder(video) @@ -915,16 +988,15 @@ class VAE_3D_V2(nn.Module): # , ModelMixin # print("after encoder moments:", moments.size()) return posterior - + def decode( self, z, - video_contains_first_frame = True, - ): + video_contains_first_frame=True, + ): # dtype = z.dtype decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame - z = self.post_quant_conv(z) # print("pre decoder, post quant conv:", z.size()) @@ -933,25 +1005,28 @@ class VAE_3D_V2(nn.Module): # , ModelMixin # SCH: moved decoder last conv layer here for separate first frame decoding if decode_first_frame_separately: - left_pad, dec_ff, dec = dec[:, :, :self.time_padding], dec[:, :, self.time_padding], dec[:, :, (self.time_padding + 1):] + left_pad, dec_ff, dec = ( + dec[:, :, : self.time_padding], + dec[:, :, self.time_padding], + dec[:, :, (self.time_padding + 1) :], + ) out = self.conv_out(dec) outff = self.conv_out_first_frame(dec_ff) - video, _ = pack([outff, out], 'b c * h w') + video, _ = pack([outff, out], "b c * h w") else: video = self.conv_out(dec) # if video were padded, remove padding if video_contains_first_frame: - video = video[:, :, self.time_padding:] + video = video[:, :, self.time_padding :] # print("conv out:", video.size()) return video - + def get_last_layer(self): # CausalConv3d wraps the conv return self.conv_out.conv.weight - - + # def parameters(self): # return [ # *self.conv_in.parameters(), @@ -966,22 +1041,22 @@ class VAE_3D_V2(nn.Module): # , ModelMixin # def disc_parameters(self): # return self.discriminator.parameters() - + def forward( self, video, sample_posterior=True, - video_contains_first_frame = True, + video_contains_first_frame=True, # split = "train", - - ): - + ): batch, channels, frames = video.shape[:3] - assert divisible_by(frames - int(video_contains_first_frame), self.time_downsample_factor), 
f'number of frames {frames} minus the first frame ({frames - int(video_contains_first_frame)}) must be divisible by the total downsample factor across time {self.time_downsample_factor}' + assert divisible_by( + frames - int(video_contains_first_frame), self.time_downsample_factor + ), f"number of frames {frames} minus the first frame ({frames - int(video_contains_first_frame)}) must be divisible by the total downsample factor across time {self.time_downsample_factor}" posterior = self.encode( video, - video_contains_first_frame = video_contains_first_frame, + video_contains_first_frame=video_contains_first_frame, ) if sample_posterior: @@ -989,12 +1064,8 @@ class VAE_3D_V2(nn.Module): # , ModelMixin else: z = posterior.mode() + recon_video = self.decode(z, video_contains_first_frame=video_contains_first_frame) - recon_video = self.decode( - z, - video_contains_first_frame = video_contains_first_frame - ) - return recon_video, posterior @@ -1002,12 +1073,12 @@ class VEALoss(nn.Module): def __init__( self, logvar_init=0.0, - perceptual_loss_weight = 0.1, - kl_loss_weight = 0.000001, + perceptual_loss_weight=0.1, + kl_loss_weight=0.000001, # vgg=None, # vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT, - device = "cpu", - dtype = "bf16" + device="cpu", + dtype="bf16", ): super().__init__() @@ -1017,8 +1088,8 @@ class VEALoss(nn.Module): elif dtype == "fp16": dtype = torch.float16 else: - raise NotImplementedError(f'dtype: {dtype}') - + raise NotImplementedError(f"dtype: {dtype}") + # KL Loss self.kl_loss_weight = kl_loss_weight # Perceptual Loss @@ -1034,7 +1105,6 @@ class VEALoss(nn.Module): # ) # vgg.classifier = Sequential(*vgg.classifier[:-2]) # self.vgg = vgg.to(device, dtype).eval() # SCH: added eval - def forward( self, @@ -1042,8 +1112,8 @@ class VEALoss(nn.Module): recon_video, posterior, nll_weights=None, - split = "train", - ): + split="train", + ): video = rearrange(video, "b c t h w -> (b t) c h w").contiguous() recon_video = rearrange(recon_video, "b c t h w -> (b t) c h w").contiguous() @@ -1052,13 +1122,13 @@ class VEALoss(nn.Module): # perceptual loss if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0: - # handle channels + # handle channels channels = video.shape[1] - assert channels in {1,3,4} + assert channels in {1, 3, 4} if channels == 1: - input_vgg_input = repeat(video, 'b 1 h w -> b c h w', c = 3) - recon_vgg_input = repeat(recon_video, 'b 1 h w -> b c h w', c = 3) - elif channels == 4: # SCH: take the first 3 for perceptual loss calc + input_vgg_input = repeat(video, "b 1 h w -> b c h w", c=3) + recon_vgg_input = repeat(recon_video, "b 1 h w -> b c h w", c=3) + elif channels == 4: # SCH: take the first 3 for perceptual loss calc input_vgg_input = video[:, :3] recon_vgg_input = recon_video[:, :3] else: @@ -1067,7 +1137,7 @@ class VEALoss(nn.Module): perceptual_loss = self.perceptual_loss_fn(input_vgg_input, recon_vgg_input) recon_loss = recon_loss + self.perceptual_loss_weight * perceptual_loss - + nll_loss = recon_loss / torch.exp(self.logvar) + self.logvar weighted_nll_loss = nll_loss if nll_weights is not None: @@ -1118,14 +1188,14 @@ class VEALoss(nn.Module): # } # return nll_loss, weighted_kl_loss - + class AdversarialLoss(nn.Module): def __init__( self, - discriminator_factor = 1.0, - discriminator_start = 50001, - generator_factor = 0.5, + discriminator_factor=1.0, + discriminator_start=50001, + generator_factor=0.5, generator_loss_type="non-saturating", ): super().__init__() @@ -1133,10 +1203,12 @@ class 
AdversarialLoss(nn.Module):
         self.discriminator_start = discriminator_start
         self.generator_factor = generator_factor
         self.generator_loss_type = generator_loss_type
-    
+
     def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
-        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] # SCH: TODO: debug added creat
-        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] # SCH: TODO: debug added create_graph=True
+        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]  # SCH: TODO: debug added create_graph=True
+        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[
+            0
+        ]  # SCH: TODO: debug added create_graph=True
         d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
         d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
         d_weight = d_weight * self.generator_factor
@@ -1148,8 +1220,8 @@ class AdversarialLoss(nn.Module):
         nll_loss,
         last_layer,
         global_step,
-        is_training = True,
-        ):
+        is_training=True,
+    ):
         # NOTE: following MAGVIT to allow non_saturating
         assert self.generator_loss_type in ["hinge", "vanilla", "non-saturating"]
@@ -1157,15 +1229,13 @@ class AdversarialLoss(nn.Module):
             gen_loss = -torch.mean(fake_logits)
         elif self.generator_loss_type == "non-saturating":
             gen_loss = torch.mean(
-                sigmoid_cross_entropy_with_logits(
-                    labels=torch.ones_like(fake_logits), logits=fake_logits)
-            )
+                sigmoid_cross_entropy_with_logits(labels=torch.ones_like(fake_logits), logits=fake_logits)
+            )
         else:
             raise ValueError("Generator loss {} not supported".format(self.generator_loss_type))

-
         if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
-            try: 
+            try:
                 d_weight = self.calculate_adaptive_weight(nll_loss, gen_loss, last_layer)
             except RuntimeError:
                 assert not is_training
@@ -1177,55 +1247,49 @@ class AdversarialLoss(nn.Module):
         weighted_gen_loss = d_weight * disc_factor * gen_loss
         return weighted_gen_loss
-    
+

+
 class LeCamEMA:
-    def __init__(
-        self, 
-        ema_real = 0.0,
-        ema_fake = 0.0,
-        decay=0.999,
-        dtype=torch.bfloat16,
-        device="cpu"
-    ):
+    def __init__(self, ema_real=0.0, ema_fake=0.0, decay=0.999, dtype=torch.bfloat16, device="cpu"):
         self.decay = decay
         self.ema_real = torch.tensor(ema_real).to(device, dtype)
         self.ema_fake = torch.tensor(ema_fake).to(device, dtype)
-    
+
     def update(self, ema_real, ema_fake):
-        self.ema_real = self.ema_real * self.decay + ema_real * (1-self.decay)
-        self.ema_fake = self.ema_fake * self.decay + ema_fake * (1-self.decay)
+        self.ema_real = self.ema_real * self.decay + ema_real * (1 - self.decay)
+        self.ema_fake = self.ema_fake * self.decay + ema_fake * (1 - self.decay)

     def get(self):
         return self.ema_real, self.ema_fake
-    
+

+
 class DiscriminatorLoss(nn.Module):
     def __init__(
         self,
-        discriminator_factor = 1.0,
-        discriminator_start = 50001,
+        discriminator_factor=1.0,
+        discriminator_start=50001,
         discriminator_loss_type="non-saturating",
         lecam_loss_weight=None,
-        gradient_penalty_loss_weight=None, # SCH: following MAGVIT config.vqgan.grad_penalty_cost
+        gradient_penalty_loss_weight=None,  # SCH: following MAGVIT config.vqgan.grad_penalty_cost
     ):
         super().__init__()
-
         assert discriminator_loss_type in ["hinge", "vanilla", "non-saturating"]
         self.discriminator_factor = discriminator_factor
         self.discriminator_start = discriminator_start
         self.lecam_loss_weight = lecam_loss_weight
         self.gradient_penalty_loss_weight = gradient_penalty_loss_weight
         self.discriminator_loss_type = discriminator_loss_type
-
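+    # NOTE: "non-saturating" here follows MAGVIT: BCE with logits via
+    # sigmoid_cross_entropy_with_logits (real -> 1, fake -> 0), rather than the
+    # hinge formulation.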
None, - lecam_ema_fake = None, - real_video = None, - split = "train", + lecam_ema_real=None, + lecam_ema_fake=None, + real_video=None, + split="train", ): if self.discriminator_factor is not None and self.discriminator_factor > 0.0: disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start) @@ -1241,7 +1305,8 @@ class DiscriminatorLoss(nn.Module): real_loss = 0.0 if fake_logits is not None: fake_loss = sigmoid_cross_entropy_with_logits( - labels=torch.zeros_like(fake_logits), logits=fake_logits) + labels=torch.zeros_like(fake_logits), logits=fake_logits + ) else: fake_loss = 0.0 disc_loss = 0.5 * (torch.mean(real_loss) + torch.mean(fake_loss)) @@ -1259,18 +1324,16 @@ class DiscriminatorLoss(nn.Module): if self.lecam_loss_weight is not None and self.lecam_loss_weight > 0.0: real_pred = torch.mean(real_logits) fake_pred = torch.mean(fake_logits) - lecam_loss = lecam_reg(real_pred, fake_pred, - lecam_ema_real, - lecam_ema_fake) + lecam_loss = lecam_reg(real_pred, fake_pred, lecam_ema_real, lecam_ema_fake) lecam_loss = lecam_loss * self.lecam_loss_weight - + gradient_penalty = torch.tensor(0.0) if self.gradient_penalty_loss_weight is not None and self.gradient_penalty_loss_weight > 0.0: assert real_video is not None gradient_penalty = gradient_penalty_fn(real_video, real_logits) # gradient_penalty = r1_penalty(real_video, real_logits) # MAGVIT uses r1 penalty gradient_penalty *= self.gradient_penalty_loss_weight - + # discriminator_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty # log = { @@ -1281,6 +1344,7 @@ class DiscriminatorLoss(nn.Module): return (weighted_d_adversarial_loss, lecam_loss, gradient_penalty) + @MODELS.register_module("VAE_MAGVIT_V2") def VAE_MAGVIT_V2(from_pretrained=None, **kwargs): model = VAE_3D_V2(**kwargs) @@ -1288,11 +1352,12 @@ def VAE_MAGVIT_V2(from_pretrained=None, **kwargs): load_checkpoint(model, from_pretrained, model_name="model") return model + @MODELS.register_module("DISCRIMINATOR_3D") def DISCRIMINATOR_3D(from_pretrained=None, inflate_from_2d=False, use_pretrained=True, **kwargs): model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init) - # model = StyleGANDiscriminator(**kwargs).apply(xavier_uniform_weight_init) # SCH: DEBUG: to change back - # model = NLayerDiscriminator3D(input_nc=3, n_layers=3,).apply(n_layer_disc_weights_init) + # model = StyleGANDiscriminator(**kwargs).apply(xavier_uniform_weight_init) # SCH: DEBUG: to change back + # model = NLayerDiscriminator3D(input_nc=3, n_layers=3,).apply(n_layer_disc_weights_init) if from_pretrained is not None: if use_pretrained: if inflate_from_2d: @@ -1302,28 +1367,27 @@ def DISCRIMINATOR_3D(from_pretrained=None, inflate_from_2d=False, use_pretrained print(f"loaded discriminator") else: print(f"discriminator use_pretrained={use_pretrained}, initializing new discriminator") - + return model def load_checkpoint_with_inflation(model, ckpt_path): """ - pre-train using image, then inflate to 3D videos + pre-train using image, then inflate to 3D videos """ if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"): state_dict = find_model(ckpt_path) with torch.no_grad(): for key in state_dict: if key in model: - # central inflation - if state_dict[key].size() == model[key][:, :, 0, :, :].size(): - # temporal dimension - val = torch.zeros_like(model[key]) - centre = int(model[key].size(2) // 2) - val[:, :, centre, :, :] = state_dict[key] + # central inflation + if state_dict[key].size() == model[key][:, :, 0, :, :].size(): + # 
temporal dimension + val = torch.zeros_like(model[key]) + centre = int(model[key].size(2) // 2) + val[:, :, centre, :, :] = state_dict[key] missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) print(f"Missing keys: {missing_keys}") print(f"Unexpected keys: {unexpected_keys}") else: - load_checkpoint(model, ckpt_path) # use the default function - \ No newline at end of file + load_checkpoint(model, ckpt_path) # use the default function diff --git a/scripts/train-vae-v2.py b/scripts/train-vae-v2.py index 4967935..c65c3b3 100644 --- a/scripts/train-vae-v2.py +++ b/scripts/train-vae-v2.py @@ -1,10 +1,8 @@ -from copy import deepcopy +import os +from glob import glob import colossalai import torch -import torch.nn as nn -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler import torch.distributed as dist import wandb from colossalai.booster import Booster @@ -12,11 +10,8 @@ from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device -from tqdm import tqdm -import os from einops import rearrange -import numpy as np -from glob import glob +from tqdm import tqdm from opensora.acceleration.checkpoint import set_grad_checkpoint from opensora.acceleration.parallel_states import ( @@ -25,28 +20,17 @@ from opensora.acceleration.parallel_states import ( set_sequence_parallel_group, ) from opensora.acceleration.plugin import ZeroSeqParallelPlugin -from opensora.datasets import DatasetFromCSV, get_transforms_image, get_transforms_video, prepare_dataloader -from opensora.registry import MODELS, SCHEDULERS, build_module -from opensora.utils.ckpt_utils import create_logger, load_json, save_json, load, model_sharding, record_model_param_shape, save +from opensora.datasets import prepare_dataloader, prepare_variable_dataloader +from opensora.models.vae.vae_3d_v2 import AdversarialLoss, DiscriminatorLoss, LeCamEMA, VEALoss, pad_at_dim +from opensora.registry import DATASETS, MODELS, build_module +from opensora.utils.ckpt_utils import create_logger, load_json, save_json from opensora.utils.config_utils import ( create_experiment_workspace, create_tensorboard_writer, parse_configs, save_training_config, ) -from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype -from opensora.utils.train_utils import update_ema, MaskGenerator -from opensora.models.vae.vae_3d_v2 import VEALoss, DiscriminatorLoss, AdversarialLoss, LeCamEMA, pad_at_dim - - - -# efficiency -# from torch.profiler import profile, record_function, ProfilerActivity - -def trace_handler(p): - output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=5) - print(output) - # p.export_chrome_trace("/home/shenchenhui/Open-Sora-dev/outputs/traces/trace_" + str(p.step_num) + ".json") +from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, to_torch_dtype def main(): @@ -59,20 +43,20 @@ def main(): # 2. runtime variables & colossalai launch # ====================================================== assert torch.cuda.is_available(), "Training currently requires at least one GPU." - + # 2.1. 
colossalai init distributed training
     colossalai.launch_from_torch({})
     coordinator = DistCoordinator()
 
     exp_dir = None
-    if coordinator.is_master(): # only create directory for master
+    if coordinator.is_master():  # only create directory for master
         exp_name, exp_dir = create_experiment_workspace(cfg)
         save_training_config(cfg._cfg_dict, exp_dir)
     dist.barrier()
 
     # get exp dir for non-master process
     if exp_dir is None:
-        experiment_index = len(glob(f"{cfg.outputs}/*"))-1
+        experiment_index = len(glob(f"{cfg.outputs}/*")) - 1
         model_name = cfg.model["type"].replace("/", "-")
         exp_name = f"{experiment_index:03d}-F{cfg.num_frames}S{cfg.frame_interval}-{model_name}"
         exp_dir = f"{cfg.outputs}/{exp_name}"
@@ -123,31 +107,30 @@ def main():
     # ======================================================
     # 3. build dataset and dataloader
     # ======================================================
-    dataset = DatasetFromCSV(
-        cfg.data_path,
-        transform=(
-            get_transforms_video(cfg.image_size[0])
-            if not cfg.use_image_transform
-            else get_transforms_image(cfg.image_size[0])
-        ),
-        num_frames=cfg.num_frames,
-        frame_interval=cfg.frame_interval,
-        root=cfg.root,
-    )
-
-    dataloader = prepare_dataloader(
-        dataset,
+    dataset = build_module(cfg.dataset, DATASETS)
+    logger.info(f"Dataset contains {len(dataset)} samples.")
+    dataloader_args = dict(
+        dataset=dataset,
         batch_size=cfg.batch_size,
         num_workers=cfg.num_workers,
+        seed=cfg.seed,
         shuffle=True,
         drop_last=True,
         pin_memory=True,
         process_group=get_data_parallel_group(),
     )
-    logger.info(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})")
-
-    total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
-    logger.info(f"Total batch size: {total_batch_size}")
+    # TODO: use plugin's prepare dataloader
+    if cfg.bucket_config is None:
+        dataloader = prepare_dataloader(**dataloader_args)
+    else:
+        dataloader = prepare_variable_dataloader(
+            bucket_config=cfg.bucket_config,
+            num_bucket_build_workers=cfg.num_bucket_build_workers,
+            **dataloader_args,
+        )
+    if cfg.dataset.type == "VideoTextDataset":
+        total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
+        logger.info(f"Total batch size: {total_batch_size}")
 
     # ======================================================
     # 4. build model
@@ -163,7 +146,8 @@ def main():
     logger.info(
         f"Trainable vae params: {format_numel_str(vae_numel_trainable)}, Total model params: {format_numel_str(vae_numel)}"
     )
-    
+
     discriminator = build_module(cfg.discriminator, MODELS, device=device)
     discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator)
     logger.info(
@@ -175,12 +159,11 @@ def main():
 
     # 4.3. move to device
     if cfg.get("use_pipeline") == True:
-        vae_2d.to(device, dtype).eval() # eval mode, not training!
+        vae_2d.to(device, dtype).eval()  # eval mode, not training!
     vae = vae.to(device, dtype)
     discriminator = discriminator.to(device, dtype)
 
-
     # 4.5. setup optimizer
     # vae optimizer
     optimizer = HybridAdam(
@@ -200,7 +183,6 @@ def main():
     vae.train()
     discriminator.train()
 
-
     # =======================================================
     # 5. 
boost model for distributed training with colossalai # ======================================================= @@ -212,13 +194,11 @@ def main(): num_steps_per_epoch = len(dataloader) logger.info("Boost vae for distributed training") - discriminator, disc_optimizer, _, _, disc_lr_scheduler = booster.boost( model=discriminator, optimizer=disc_optimizer, lr_scheduler=disc_lr_scheduler ) logger.info("Boost discriminator for distributed training") - # ======================================================= # 6. training loop # ======================================================= @@ -226,7 +206,6 @@ def main(): running_loss = 0.0 running_disc_loss = 0.0 - # 6.1. resume training if cfg.load is not None: logger.info("Loading checkpoint") @@ -244,11 +223,17 @@ def main(): if cfg.lecam_loss_weight is not None and os.path.exists(lecam_path): lecam_state = load_json(lecam_path) lecam_ema_real, lecam_ema_fake = lecam_state["lecam_ema_real"], lecam_state["lecam_ema_fake"] - lecam_ema = LeCamEMA(decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device) + lecam_ema = LeCamEMA( + decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device + ) running_states = load_json(os.path.join(cfg.load, "running_states.json")) dist.barrier() - start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"] + start_epoch, start_step, sampler_start_idx = ( + running_states["epoch"], + running_states["step"], + running_states["sample_start_index"], + ) logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}") logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch") @@ -257,25 +242,25 @@ def main(): # 6.2 Define loss functions vae_loss_fn = VEALoss( logvar_init=cfg.logvar_init, - perceptual_loss_weight = cfg.perceptual_loss_weight, - kl_loss_weight = cfg.kl_loss_weight, + perceptual_loss_weight=cfg.perceptual_loss_weight, + kl_loss_weight=cfg.kl_loss_weight, device=device, dtype=dtype, ) adversarial_loss_fn = AdversarialLoss( - discriminator_factor = cfg.discriminator_factor, - discriminator_start = cfg.discriminator_start, - generator_factor = cfg.generator_factor, - generator_loss_type = cfg.generator_loss_type, + discriminator_factor=cfg.discriminator_factor, + discriminator_start=cfg.discriminator_start, + generator_factor=cfg.generator_factor, + generator_loss_type=cfg.generator_loss_type, ) disc_loss_fn = DiscriminatorLoss( - discriminator_factor = cfg.discriminator_factor, - discriminator_start = cfg.discriminator_start, - discriminator_loss_type = cfg.discriminator_loss_type, - lecam_loss_weight = cfg.lecam_loss_weight, - gradient_penalty_loss_weight = cfg.gradient_penalty_loss_weight, + discriminator_factor=cfg.discriminator_factor, + discriminator_start=cfg.discriminator_start, + discriminator_loss_type=cfg.discriminator_loss_type, + lecam_loss_weight=cfg.lecam_loss_weight, + gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight, ) # 6.3. 
training loop
@@ -288,14 +273,11 @@ def main():
     disc_time_padding = 0
     video_contains_first_frame = cfg.video_contains_first_frame
 
-
     for epoch in range(start_epoch, cfg.epochs):
         dataloader.sampler.set_epoch(epoch)
         dataloader_iter = iter(dataloader)
         logger.info(f"Beginning epoch {epoch}...")
-
-
         with tqdm(
             range(start_step, num_steps_per_epoch),
             desc=f"Epoch {epoch}",
@@ -303,7 +285,6 @@ def main():
             total=num_steps_per_epoch,
             initial=start_step,
         ) as pbar:
-
             # with profile(
             #     activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             #     schedule=torch.profiler.schedule(
@@ -317,199 +298,203 @@ def main():
             #     record_shapes=True,
             #     profile_memory=True,
             # ) as p: # trace efficiency
-
-            for step in pbar:
-                # with profile(
-                #     activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
-                #     with_stack=True,
-                # ) as p: # trace efficiency
+            for step in pbar:
+                # with profile(
+                #     activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                #     with_stack=True,
+                # ) as p: # trace efficiency
 
-                # SCH: calc global step at the start
-                global_step = epoch * num_steps_per_epoch + step
-                
-                batch = next(dataloader_iter)
-                x = batch["video"].to(device, dtype) # [B, C, T, H, W]
+                # SCH: calc global step at the start
+                global_step = epoch * num_steps_per_epoch + step
 
-                # supprt for image or video inputs
-                assert x.ndim in {4, 5}, f"received input of {x.ndim} dimensions" # either image or video
-                assert x.shape[-2:] == cfg.image_size, f"received input size {x.shape[-2:]}, but config image size is {cfg.image_size}"
-                is_image = x.ndim == 4
-                if is_image:
-                    video = rearrange(x, 'b c ... -> b c 1 ...')
-                    video_contains_first_frame = True
-                else:
-                    video = x
-                
-                # ===== Spatial VAE =====
-                if cfg.get("use_pipeline") == True:
-                    with torch.no_grad():
-                        video = vae_2d.encode(video)
+                batch = next(dataloader_iter)
+                x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
 
-                # ====== VAE ======
-                recon_video, posterior = vae(
-                    video,
-                    video_contains_first_frame = video_contains_first_frame,
-                )
+                # support for image or video inputs
+                assert x.ndim in {4, 5}, f"received input of {x.ndim} dimensions"  # either image or video
+                assert (
+                    x.shape[-2:] == cfg.image_size
+                ), f"received input size {x.shape[-2:]}, but config image size is {cfg.image_size}"
+                is_image = x.ndim == 4
+                if is_image:
+                    video = rearrange(x, "b c ... -> b c 1 ...")
+                    video_contains_first_frame = True
+                else:
+                    video = x
 
-                # ====== Generator Loss ======
-                # simple nll loss
-                nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(
-                    video,
-                    recon_video,
-                    posterior,
-                    split = "train"
-                )
+                # ===== Spatial VAE =====
+                if cfg.get("use_pipeline") == True:
+                    with torch.no_grad():
+                        video = vae_2d.encode(video)
 
-                adversarial_loss = torch.tensor(0.0)
-                # adversarial loss
-                if global_step > cfg.discriminator_start:
-                    # padded videos for GAN
-                    fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
-                    fake_logits = discriminator(fake_video.contiguous())
-                    adversarial_loss = adversarial_loss_fn(
-                        fake_logits,
-                        nll_loss,
-                        vae.module.get_last_layer(),
-                        global_step,
-                        is_training = vae.training,
-                    )
-                    
-                vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
+                # ====== VAE ======
+                recon_video, posterior = vae(
+                    video,
+                    video_contains_first_frame=video_contains_first_frame,
+                )
 
-                optimizer.zero_grad()
-                # Backward & update
-                booster.backward(loss=vae_loss, optimizer=optimizer)
-                # # NOTE: clip gradients? 
this is done in Open-Sora-Plan - # torch.nn.utils.clip_grad_norm_(vae.parameters(), 1) # NOTE: done by grad_clip - optimizer.step() + # ====== Generator Loss ====== + # simple nll loss + nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn( + video, recon_video, posterior, split="train" + ) - # Log loss values: - all_reduce_mean(vae_loss) # NOTE: this is to get average loss for logging - running_loss += vae_loss.item() - + adversarial_loss = torch.tensor(0.0) + # adversarial loss + if global_step > cfg.discriminator_start: + # padded videos for GAN + fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2) + fake_logits = discriminator(fake_video.contiguous()) + adversarial_loss = adversarial_loss_fn( + fake_logits, + nll_loss, + vae.module.get_last_layer(), + global_step, + is_training=vae.training, + ) - # ====== Discriminator Loss ====== - if global_step > cfg.discriminator_start: - # if video_contains_first_frame: - # Since we don't have enough T frames, pad anyways - real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2) - fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2) + vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss - if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0: - real_video = real_video.requires_grad_() - real_logits = discriminator(real_video.contiguous()) # SCH: not detached for now for gradient_penalty calculation - else: - real_logits = discriminator(real_video.contiguous().detach()) + optimizer.zero_grad() + # Backward & update + booster.backward(loss=vae_loss, optimizer=optimizer) + # # NOTE: clip gradients? this is done in Open-Sora-Plan + # torch.nn.utils.clip_grad_norm_(vae.parameters(), 1) # NOTE: done by grad_clip + optimizer.step() - fake_logits = discriminator(fake_video.contiguous().detach()) + # Log loss values: + all_reduce_mean(vae_loss) # NOTE: this is to get average loss for logging + running_loss += vae_loss.item() + # ====== Discriminator Loss ====== + if global_step > cfg.discriminator_start: + # if video_contains_first_frame: + # Since we don't have enough T frames, pad anyways + real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2) + fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2) - lecam_ema_real, lecam_ema_fake = lecam_ema.get() + if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0: + real_video = real_video.requires_grad_() + real_logits = discriminator( + real_video.contiguous() + ) # SCH: not detached for now for gradient_penalty calculation + else: + real_logits = discriminator(real_video.contiguous().detach()) - weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn( - real_logits, - fake_logits, - global_step, - lecam_ema_real = lecam_ema_real, - lecam_ema_fake = lecam_ema_fake, - real_video = real_video if cfg.gradient_penalty_loss_weight is not None else None, - ) - disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss - if cfg.lecam_loss_weight is not None: - ema_real = torch.mean(real_logits.clone().detach()).to(device, dtype) - ema_fake = torch.mean(fake_logits.clone().detach()).to(device, dtype) - all_reduce_mean(ema_real) - all_reduce_mean(ema_fake) - lecam_ema.update(ema_real, ema_fake) + fake_logits = discriminator(fake_video.contiguous().detach()) - disc_optimizer.zero_grad() - # Backward & update - booster.backward(loss=disc_loss, optimizer=disc_optimizer) - 
# # NOTE: TODO: clip gradients? this is done in Open-Sora-Plan - # torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1) # NOTE: done by grad_clip - disc_optimizer.step() + lecam_ema_real, lecam_ema_fake = lecam_ema.get() - # Log loss values: - all_reduce_mean(disc_loss) - running_disc_loss += disc_loss.item() - else: - disc_loss = torch.tensor(0.0) - weighted_d_adversarial_loss = torch.tensor(0.0) - lecam_loss = torch.tensor(0.0) - gradient_penalty_loss = torch.tensor(0.0) + weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn( + real_logits, + fake_logits, + global_step, + lecam_ema_real=lecam_ema_real, + lecam_ema_fake=lecam_ema_fake, + real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None, + ) + disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss + if cfg.lecam_loss_weight is not None: + ema_real = torch.mean(real_logits.clone().detach()).to(device, dtype) + ema_fake = torch.mean(fake_logits.clone().detach()).to(device, dtype) + all_reduce_mean(ema_real) + all_reduce_mean(ema_fake) + lecam_ema.update(ema_real, ema_fake) - log_step += 1 + disc_optimizer.zero_grad() + # Backward & update + booster.backward(loss=disc_loss, optimizer=disc_optimizer) + # # NOTE: TODO: clip gradients? this is done in Open-Sora-Plan + # torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1) # NOTE: done by grad_clip + disc_optimizer.step() - # Log to tensorboard - if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0: - avg_loss = running_loss / log_step - avg_disc_loss = running_disc_loss / log_step - pbar.set_postfix({"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step}) - running_loss = 0 - log_step = 0 - running_disc_loss = 0 - writer.add_scalar("loss", vae_loss.item(), global_step) - if cfg.wandb: - wandb.log( - { - "iter": global_step, - "num_samples": global_step * total_batch_size, - "epoch": epoch, - "loss": vae_loss.item(), - "kl_loss": weighted_kl_loss.item(), - "gen_adv_loss": adversarial_loss.item(), - "disc_loss": disc_loss.item(), - "lecam_loss": lecam_loss.item(), - "r1_grad_penalty": gradient_penalty_loss.item(), - "nll_loss": weighted_nll_loss.item(), - "avg_loss": avg_loss, - }, - step=global_step, - ) + # Log loss values: + all_reduce_mean(disc_loss) + running_disc_loss += disc_loss.item() + else: + disc_loss = torch.tensor(0.0) + weighted_d_adversarial_loss = torch.tensor(0.0) + lecam_loss = torch.tensor(0.0) + gradient_penalty_loss = torch.tensor(0.0) - # Save checkpoint - if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0: - save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}") - os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) # already handled in booster save_model - booster.save_model(vae, os.path.join(save_dir, "model"), shard=True) - booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True) - booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096) - booster.save_optimizer(disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096) + log_step += 1 - if lr_scheduler is not None: - booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) - if disc_lr_scheduler is not None: - booster.save_lr_scheduler(disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler")) - - running_states = { + # Log to tensorboard + if coordinator.is_master() and (global_step + 1) % 
cfg.log_every == 0: + avg_loss = running_loss / log_step + avg_disc_loss = running_disc_loss / log_step + pbar.set_postfix( + {"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step} + ) + running_loss = 0 + log_step = 0 + running_disc_loss = 0 + writer.add_scalar("loss", vae_loss.item(), global_step) + if cfg.wandb: + wandb.log( + { + "iter": global_step, + "num_samples": global_step * total_batch_size, "epoch": epoch, - "step": step+1, - "global_step": global_step+1, - "sample_start_index": (step+1) * cfg.batch_size, - } - - lecam_ema_real, lecam_ema_fake = lecam_ema.get() - lecam_state = { - "lecam_ema_real": lecam_ema_real.item(), - "lecam_ema_fake": lecam_ema_fake.item(), - } - if coordinator.is_master(): - save_json(running_states, os.path.join(save_dir, "running_states.json")) - if cfg.lecam_loss_weight is not None: - save_json(lecam_state, os.path.join(save_dir, "lecam_states.json")) - dist.barrier() + "loss": vae_loss.item(), + "kl_loss": weighted_kl_loss.item(), + "gen_adv_loss": adversarial_loss.item(), + "disc_loss": disc_loss.item(), + "lecam_loss": lecam_loss.item(), + "r1_grad_penalty": gradient_penalty_loss.item(), + "nll_loss": weighted_nll_loss.item(), + "avg_loss": avg_loss, + }, + step=global_step, + ) - logger.info( - f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}" - ) + # Save checkpoint + if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0: + save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}") + os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) # already handled in booster save_model + booster.save_model(vae, os.path.join(save_dir, "model"), shard=True) + booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True) + booster.save_optimizer( + optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096 + ) + booster.save_optimizer( + disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096 + ) - # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) - + if lr_scheduler is not None: + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) + if disc_lr_scheduler is not None: + booster.save_lr_scheduler(disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler")) + + running_states = { + "epoch": epoch, + "step": step + 1, + "global_step": global_step + 1, + "sample_start_index": (step + 1) * cfg.batch_size, + } + + lecam_ema_real, lecam_ema_fake = lecam_ema.get() + lecam_state = { + "lecam_ema_real": lecam_ema_real.item(), + "lecam_ema_fake": lecam_ema_fake.item(), + } + if coordinator.is_master(): + save_json(running_states, os.path.join(save_dir, "running_states.json")) + if cfg.lecam_loss_weight is not None: + save_json(lecam_state, os.path.join(save_dir, "lecam_states.json")) + dist.barrier() + + logger.info( + f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}" + ) + + # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) # the continue epochs are not resumed, so we need to reset the sampler start index and start step dataloader.sampler.set_start_index(0) start_step = 0 + if __name__ == "__main__": main()
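
For reference, the generator-side objective assembled in this patch reduces to a Gaussian negative log-likelihood with a single learned log-variance, exactly as in VEALoss: nll = recon / exp(logvar) + logvar, plus a small weighted KL term. A minimal sketch follows; the reconstruction term is assumed to be L1 and the normalization a per-batch sum divided by batch size (both are assumptions, since recon_loss is computed outside the lines shown in this diff):

import torch
import torch.nn as nn

# Sketch of the VEALoss core. `logvar` mirrors self.logvar (logvar_init=0.0);
# the L1 reconstruction and batch-mean normalization are assumptions.
logvar = nn.Parameter(torch.zeros(()))

def vae_loss_sketch(video, recon_video, kl, kl_loss_weight=1e-6):
    recon_loss = torch.abs(video - recon_video)  # assumed L1; not shown in the diff
    nll_loss = recon_loss / torch.exp(logvar) + logvar  # learned-variance Gaussian NLL
    weighted_nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
    weighted_kl_loss = kl_loss_weight * torch.sum(kl) / kl.shape[0]
    return nll_loss, weighted_nll_loss, weighted_kl_loss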
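calculate_adaptive_weight balances the reconstruction and adversarial gradients at the decoder's last layer: d_weight = ||grad(nll)|| / (||grad(g_loss)|| + 1e-4), clamped to [0, 1e4] and scaled by generator_factor, so the GAN term can never swamp the reconstruction term (and the RuntimeError branch covers eval mode, where no graph is available). A self-contained toy check of the same arithmetic, using a stand-in linear layer where training uses vae.module.get_last_layer():

import torch
import torch.nn as nn

# Toy reproduction of AdversarialLoss.calculate_adaptive_weight.
layer = nn.Linear(4, 4)
x = torch.randn(2, 4)
nll_loss = layer(x).abs().mean()  # stands in for the reconstruction NLL
g_loss = (layer(x) ** 2).mean()   # stands in for the generator GAN loss

nll_grads = torch.autograd.grad(nll_loss, layer.weight, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, layer.weight, retain_graph=True)[0]
d_weight = torch.clamp(torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4), 0.0, 1e4).detach()
print(d_weight)  # larger when reconstruction gradients dominate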
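sigmoid_cross_entropy_with_logits is called by both the non-saturating generator and discriminator losses but is not defined in this diff. A likely stand-in is the numerically stable identity max(x, 0) - x*z + log(1 + exp(-|x|)); the repo's own helper may differ, and torch.nn.functional.binary_cross_entropy_with_logits(logits, labels, reduction="none") computes the same quantity:

import torch

def sigmoid_cross_entropy_with_logits(labels, logits):
    # Stable elementwise BCE with logits, following the TensorFlow identity:
    # max(x, 0) - x * z + log(1 + exp(-|x|))
    zeros = torch.zeros_like(logits)
    cond = logits >= zeros
    relu_logits = torch.where(cond, logits, zeros)
    neg_abs_logits = torch.where(cond, -logits, logits)
    return relu_logits - logits * labels + torch.log1p(torch.exp(neg_abs_logits))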
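Likewise, adopt_weight and lecam_reg are used but not shown. adopt_weight is presumably the usual step-threshold gate (as in taming-transformers), which is what lets discriminator_start switch the adversarial terms on mid-training, and lecam_reg the EMA-anchored regularizer from Tseng et al. (2021) whose running statistics the LeCamEMA class tracks; both sketches are assumptions about helpers outside this diff:

import torch

def adopt_weight(weight, global_step, threshold=0, value=0.0):
    # Gate a loss weight on training progress: `value` before `threshold` steps.
    return weight if global_step >= threshold else value

def lecam_reg(real_pred, fake_pred, ema_real, ema_fake):
    # LeCam regularizer: pull current discriminator outputs toward EMA
    # statistics of the opposite class, limiting discriminator drift.
    return torch.mean(torch.pow(torch.relu(real_pred - ema_fake), 2)) + torch.mean(
        torch.pow(torch.relu(ema_real - fake_pred), 2)
    )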
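pad_at_dim, used to pad real and reconstructed videos along the temporal axis before the discriminator, is also imported from vae_3d_v2 without being shown. A plausible implementation (an assumption; signatures follow the call sites pad_at_dim(t, (left, right), value=0.0, dim=2)):

import torch.nn.functional as F

def pad_at_dim(t, pad, dim=-1, value=0.0):
    # Pad `t` by `pad = (left, right)` along `dim` only, building the
    # flat F.pad spec from the right-most dimension inward.
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value=value)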
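One functional remark on load_checkpoint_with_inflation: the central-inflation branch builds `val` but never writes it back into `state_dict`, and `key in model` / `model[key]` are not valid operations on an nn.Module, so the inflation as written is a no-op before load_state_dict. A corrected sketch of the intended central inflation, written against the model's state dict (the helper name and guard order are otherwise hypothetical):

import torch

def inflate_2d_to_3d(model, state_dict):
    # Central inflation: place each pretrained 2D kernel at the centre
    # temporal index of a zero-initialised 3D kernel (other temporal slices
    # stay zero), then load non-strictly so unrelated keys pass through.
    model_sd = model.state_dict()
    for key, weight_2d in state_dict.items():
        if key in model_sd and model_sd[key].dim() == 5:
            if weight_2d.size() == model_sd[key][:, :, 0, :, :].size():
                val = torch.zeros_like(model_sd[key])
                centre = model_sd[key].size(2) // 2
                val[:, :, centre, :, :] = weight_2d
                state_dict[key] = val  # the write-back missing in the patch
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    return missing_keys, unexpected_keys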