Open-Sora/opensora/models/vae/vae_3d.py
import functools
import math
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, repeat, unpack
from .utils import DiagonalGaussianDistribution
from .lpips import LPIPS  # requires taming-transformers (https://github.com/CompVis/taming-transformers)
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import find_model, load_checkpoint
# from diffusers.models.modeling_utils import ModelMixin
"""Encoder and Decoder stuctures with 3D CNNs."""
"""
NOTE:
removed LayerNorm since not used in this arch
GroupNorm: flax uses default `epsilon=1e-06`, whereas torch uses `eps=1e-05`
for average pool and upsample, input shape needs to be [N,C,T,H,W] --> if not, adjust the scale factors accordingly
!!! opensora read video into [B,C,T,H,W] format output
"""
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
def divisible_by(num, den):
return (num % den) == 0
def is_odd(n):
return not divisible_by(n, 2)
def pad_at_dim(t, pad, dim=-1, value=0.0):
dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = (0, 0) * dims_from_right
return F.pad(t, (*zeros, *pad), value=value)
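# NOTE: minimal usage sketch (not part of the original model). It assumes a
# [B, C, T, H, W] tensor and shows how `pad_at_dim` left-pads only the temporal
# axis, mirroring the causal padding applied before encoding; sizes are illustrative.
def _example_pad_at_dim():
    x = torch.randn(1, 3, 5, 4, 4)  # [B, C, T, H, W]
    y = pad_at_dim(x, (3, 0), dim=2)  # pad 3 frames on the left of T, none on the right
    assert y.shape == (1, 3, 8, 4, 4)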
def pick_video_frame(video, frame_indices):
"""get frame_indices from the video of [B, C, T, H, W] and return images of [B, C, H, W]"""
batch, device = video.shape[0], video.device
video = rearrange(video, "b c f ... -> b f c ...")
batch_indices = torch.arange(batch, device=device)
batch_indices = rearrange(batch_indices, "b -> b 1")
images = video[batch_indices, frame_indices]
images = rearrange(images, "b 1 c ... -> b c ...")
return images
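# NOTE: minimal usage sketch (not part of the original model). `frame_indices`
# must carry one index per batch element with a trailing singleton dim so it
# broadcasts against the [B, 1] batch indices; shapes here are illustrative.
def _example_pick_video_frame():
    video = torch.randn(2, 3, 8, 16, 16)  # [B, C, T, H, W]
    frame_indices = torch.tensor([[0], [5]])  # [B, 1], one frame index per batch item
    images = pick_video_frame(video, frame_indices)
    assert images.shape == (2, 3, 16, 16)  # [B, C, H, W]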
def exists(v):
return v is not None
# ============== Generator Adversarial Loss Functions ==============
def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred):
assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0
lecam_loss = torch.mean(torch.pow(nn.ReLU()(real_pred - ema_fake_pred), 2))
lecam_loss += torch.mean(torch.pow(nn.ReLU()(ema_real_pred - fake_pred), 2))
return lecam_loss
# Open-Sora-Plan
# Very bad, do not use
def r1_penalty(real_img, real_pred):
"""R1 regularization for discriminator. The core idea is to
penalize the gradient on real data alone: when the
generator distribution produces the true data distribution
and the discriminator is equal to 0 on the data manifold, the
gradient penalty ensures that the discriminator cannot create
a non-zero gradient orthogonal to the data manifold without
suffering a loss in the GAN game.
Ref:
Eq. 9 in Which training methods for GANs do actually converge.
"""
grad_real = torch.autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
return grad_penalty
# Open-Sora-Plan
# Implementation as described in https://arxiv.org/abs/1704.00028  # TODO: check out the code
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
"""Calculate gradient penalty for wgan-gp.
Args:
discriminator (nn.Module): Network for the discriminator.
real_data (Tensor): Real input data.
fake_data (Tensor): Fake input data.
weight (Tensor): Weight tensor. Default: None.
Returns:
Tensor: A tensor for gradient penalty.
"""
batch_size = real_data.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1, device=real_data.device, dtype=real_data.dtype)
    # interpolate between real_data and fake_data
    interpolates = alpha * real_data + (1.0 - alpha) * fake_data
    interpolates = interpolates.detach().requires_grad_(True)
disc_interpolates = discriminator(interpolates)
gradients = torch.autograd.grad(
outputs=disc_interpolates,
inputs=interpolates,
grad_outputs=torch.ones_like(disc_interpolates),
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
if weight is not None:
gradients = gradients * weight
gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
if weight is not None:
gradients_penalty /= torch.mean(weight)
return gradients_penalty
def gradient_penalty_fn(images, output):
# batch_size = images.shape[0]
gradients = torch.autograd.grad(
outputs=output,
inputs=images,
grad_outputs=torch.ones(output.size(), device=images.device),
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
gradients = rearrange(gradients, "b ... -> b (...)")
return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
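# NOTE: minimal sanity-check sketch (not part of the original training loop),
# using a toy linear discriminator on 4D image inputs. The input must have
# requires_grad=True so the penalty can differentiate through it.
def _example_gradient_penalty():
    images = torch.randn(2, 3, 4, 4, requires_grad=True)
    toy_disc = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, 1))
    output = toy_disc(images)
    gp = gradient_penalty_fn(images, output)  # pushes ||grad||_2 towards 1
    assert gp.ndim == 0 and gp.item() >= 0.0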
# ============== Discriminator Adversarial Loss Functions ==============
def hinge_d_loss(logits_real, logits_fake):
loss_real = torch.mean(F.relu(1.0 - logits_real))
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def vanilla_d_loss(logits_real, logits_fake):
d_loss = 0.5 * (
torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake))
)
return d_loss
# from MAGVIT, used in place of hinge_d_loss
def sigmoid_cross_entropy_with_logits(labels, logits):
# The final formulation is: max(x, 0) - x * z + log(1 + exp(-abs(x)))
zeros = torch.zeros_like(logits, dtype=logits.dtype)
condition = logits >= zeros
relu_logits = torch.where(condition, logits, zeros)
neg_abs_logits = torch.where(condition, -logits, logits)
return relu_logits - logits * labels + torch.log1p(torch.exp(neg_abs_logits))
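# NOTE: sanity-check sketch (not part of the original code): the stable
# formulation above should match PyTorch's built-in BCE-with-logits elementwise.
def _example_stable_bce_equivalence():
    logits = torch.tensor([-3.0, 0.0, 3.0])
    labels = torch.tensor([0.0, 1.0, 1.0])
    ours = sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    ref = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
    assert torch.allclose(ours, ref, atol=1e-6)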
def xavier_uniform_weight_init(m):
if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain("relu"))
if m.bias is not None:
nn.init.zeros_(m.bias)
# print("initialized module to xavier_uniform:", m)
def Sequential(*modules):
modules = [*filter(exists, modules)]
if len(modules) == 0:
return nn.Identity()
return nn.Sequential(*modules)
def SameConv2d(dim_in, dim_out, kernel_size):
kernel_size = cast_tuple(kernel_size, 2)
padding = [k // 2 for k in kernel_size]
return nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, padding=padding)
def adopt_weight(weight, global_step, threshold=0, value=0.0):
if global_step < threshold:
weight = value
return weight
class CausalConv3d(nn.Module):
def __init__(
self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int, int, int]],
pad_mode="constant",
strides=None, # allow custom stride
**kwargs,
):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
dilation = kwargs.pop("dilation", 1)
stride = strides[0] if strides is not None else kwargs.pop("stride", 1)
self.pad_mode = pad_mode
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
self.time_pad = time_pad
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
stride = strides if strides is not None else (stride, 1, 1)
# padding = kwargs.pop('padding', 0)
# if padding == "same" and not all([pad == 1 for pad in padding]):
# padding = "valid"
dilation = (dilation, 1, 1)
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant"
x = F.pad(x, self.time_causal_padding, mode=pad_mode)
return self.conv(x)
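# NOTE: minimal sanity-check sketch (not part of the original model). Left-only
# temporal padding keeps the sequence length and makes the conv causal: output
# frame t depends only on input frames <= t. Shapes/dtypes are illustrative.
def _example_causal_conv3d():
    conv = CausalConv3d(1, 1, kernel_size=3, dtype=torch.float32)
    x = torch.randn(1, 1, 8, 4, 4)
    y = conv(x)
    assert y.shape == x.shape  # T preserved by the (time_pad, 0) padding
    x2 = x.clone()
    x2[:, :, -1] += 1.0  # perturb only the last frame
    y2 = conv(x2)
    assert torch.allclose(y[:, :, :-1], y2[:, :, :-1])  # earlier outputs unchanged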
class ResBlock(nn.Module):
def __init__(
self,
in_channels, # SCH: added
filters,
conv_fn,
activation_fn=nn.SiLU,
use_conv_shortcut=False,
num_groups=32,
device="cpu",
dtype=torch.bfloat16,
):
super().__init__()
self.in_channels = in_channels
self.filters = filters
self.activate = activation_fn()
self.use_conv_shortcut = use_conv_shortcut
# SCH: MAGVIT uses GroupNorm by default
self.norm1 = nn.GroupNorm(num_groups, in_channels, device=device, dtype=dtype)
self.conv1 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
self.norm2 = nn.GroupNorm(num_groups, self.filters, device=device, dtype=dtype)
self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False)
if in_channels != filters:
if self.use_conv_shortcut:
self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
else:
self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(1, 1, 1), bias=False)
def forward(self, x):
# device, dtype = x.device, x.dtype
# input_dim = x.shape[1]
residual = x
x = self.norm1(x)
x = self.activate(x)
x = self.conv1(x)
x = self.norm2(x)
x = self.activate(x)
x = self.conv2(x)
if self.in_channels != self.filters: # SCH: ResBlock X->Y
residual = self.conv3(residual)
return x + residual
# SCH: own implementation modified on top of: discriminator with anti-aliased downsampling (blurpool Zhang et al.)
class BlurPool3D(nn.Module):
def __init__(
self,
channels,
pad_type="reflect",
filt_size=3,
stride=2,
pad_off=0,
device="cpu",
dtype=torch.bfloat16,
):
super(BlurPool3D, self).__init__()
self.filt_size = filt_size
self.pad_off = pad_off
self.pad_sizes = [
int(1.0 * (filt_size - 1) / 2),
int(np.ceil(1.0 * (filt_size - 1) / 2)),
int(1.0 * (filt_size - 1) / 2),
int(np.ceil(1.0 * (filt_size - 1) / 2)),
int(1.0 * (filt_size - 1) / 2),
int(np.ceil(1.0 * (filt_size - 1) / 2)),
]
self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
self.stride = stride
self.off = int((self.stride - 1) / 2.0)
self.channels = channels
        # binomial (Pascal's triangle) filter coefficients for anti-aliasing
        binomial_coeffs = {
            1: [1.0],
            2: [1.0, 1.0],
            3: [1.0, 2.0, 1.0],
            4: [1.0, 3.0, 3.0, 1.0],
            5: [1.0, 4.0, 6.0, 4.0, 1.0],
            6: [1.0, 5.0, 10.0, 10.0, 5.0, 1.0],
            7: [1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0],
        }
        if self.filt_size not in binomial_coeffs:
            raise ValueError(f"filt_size {self.filt_size} not supported (must be in 1..7)")
        a = np.array(binomial_coeffs[self.filt_size])
        filt_2d = a[:, None] * a[None, :]
        filt_3d = torch.Tensor(a[:, None, None] * filt_2d[None, :, :]).to(device, dtype)
        filt = filt_3d / torch.sum(filt_3d)  # SCH: modified to 3D
self.register_buffer("filt", filt[None, None, :, :, :].repeat((self.channels, 1, 1, 1, 1)))
self.pad = get_pad_layer(pad_type)(self.pad_sizes)
    def forward(self, inp):
        if self.filt_size == 1:
            # identity filter: plain strided subsampling over T, H, W
            if self.pad_off == 0:
                return inp[:, :, :: self.stride, :: self.stride, :: self.stride]
            else:
                return self.pad(inp)[:, :, :: self.stride, :: self.stride, :: self.stride]
        else:
            return F.conv3d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
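# NOTE: minimal shape sketch (not part of the original model): the default
# 3-tap binomial filter blurs, then subsamples T, H, W by the stride of 2.
def _example_blurpool3d():
    pool = BlurPool3D(channels=4, dtype=torch.float32)
    x = torch.randn(1, 4, 8, 8, 8)
    y = pool(x)
    assert y.shape == (1, 4, 4, 4, 4)  # all three spatiotemporal dims halved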
def get_pad_layer(pad_type):
    if pad_type in ["refl", "reflect"]:
        PadLayer = nn.ReflectionPad3d
    elif pad_type in ["repl", "replicate"]:
        PadLayer = nn.ReplicationPad3d
    elif pad_type == "zero":
        PadLayer = nn.ZeroPad3d
    else:
        raise ValueError("Pad type [%s] not recognized" % pad_type)
    return PadLayer
class ResBlockDown(nn.Module):
"""3D StyleGAN ResBlock for D."""
def __init__(
self,
in_channels,
filters,
activation_fn,
num_groups=32,
device="cpu",
dtype=torch.bfloat16,
):
super().__init__()
self.filters = filters
self.activation_fn = activation_fn
# SCH: NOTE: although paper says conv (X->Y, Y->Y), original code implementation is (X->X, X->Y), we follow code
self.conv1 = nn.Conv3d(
in_channels, in_channels, (3, 3, 3), padding=1, device=device, dtype=dtype
) # NOTE: init to xavier_uniform
self.norm1 = nn.GroupNorm(num_groups, in_channels, device=device, dtype=dtype)
self.blur = BlurPool3D(in_channels, device=device, dtype=dtype)
self.conv2 = nn.Conv3d(
in_channels, self.filters, (1, 1, 1), bias=False, device=device, dtype=dtype
) # NOTE: init to xavier_uniform
self.conv3 = nn.Conv3d(
in_channels, self.filters, (3, 3, 3), padding=1, device=device, dtype=dtype
) # NOTE: init to xavier_uniform
self.norm2 = nn.GroupNorm(num_groups, self.filters, device=device, dtype=dtype)
# self.apply(xavier_uniform_weight_init)
def forward(self, x):
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = self.activation_fn(x)
residual = self.blur(residual)
residual = self.conv2(residual)
x = self.blur(x)
x = self.conv3(x)
x = self.norm2(x)
x = self.activation_fn(x)
out = (residual + x) / math.sqrt(2)
return out
# SCH: taken from Open Sora Plan
def n_layer_disc_weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
class NLayerDiscriminator3D(nn.Module):
"""Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs."""
def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False):
"""
Construct a 3D PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input volumes
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
use_actnorm (bool) -- flag to use actnorm instead of batchnorm
"""
super(NLayerDiscriminator3D, self).__init__()
if not use_actnorm:
norm_layer = nn.BatchNorm3d
else:
raise NotImplementedError("Not implemented.")
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func != nn.BatchNorm3d
else:
use_bias = norm_layer != nn.BatchNorm3d
kw = 4
padw = 1
sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
sequence += [
nn.Conv3d(
ndf * nf_mult_prev,
ndf * nf_mult,
kernel_size=(kw, kw, kw),
stride=(1, 2, 2),
padding=padw,
bias=use_bias,
),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True),
]
nf_mult_prev = nf_mult
nf_mult = min(2**n_layers, 8)
sequence += [
nn.Conv3d(
ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias
),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True),
]
sequence += [
nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
] # output 1 channel prediction map
self.main = nn.Sequential(*sequence)
def forward(self, input):
"""Standard forward."""
return self.main(input)
class StyleGANDiscriminatorBlur(nn.Module):
    """StyleGAN Discriminator.

    SCH: NOTE:
    this discriminator requires num_frames to be fixed during training;
    in case we pre-train with images and then train on videos, this discriminator's
    Linear layer would have to be re-trained!
    """
def __init__(
self,
image_size=(128, 128),
num_frames=17,
in_channels=3,
filters=128,
channel_multipliers=(2, 4, 4, 4, 4),
num_groups=32,
dtype=torch.bfloat16,
device="cpu",
):
super().__init__()
self.dtype = dtype
self.input_size = cast_tuple(image_size, 2)
self.filters = filters
self.activation_fn = nn.LeakyReLU(negative_slope=0.2)
self.channel_multipliers = channel_multipliers
self.conv1 = nn.Conv3d(
in_channels, self.filters, (3, 3, 3), padding=1, device=device, dtype=dtype
) # NOTE: init to xavier_uniform
prev_filters = self.filters # record in_channels
self.num_blocks = len(self.channel_multipliers)
self.res_block_list = nn.ModuleList([])
for i in range(self.num_blocks):
filters = self.filters * self.channel_multipliers[i]
self.res_block_list.append(
ResBlockDown(prev_filters, filters, self.activation_fn, device=device, dtype=dtype).apply(
xavier_uniform_weight_init
)
)
prev_filters = filters # update in_channels
self.conv2 = nn.Conv3d(
prev_filters, prev_filters, (3, 3, 3), padding=1, device=device, dtype=dtype
) # NOTE: init to xavier_uniform
# torch.nn.init.xavier_uniform_(self.conv2.weight)
self.norm1 = nn.GroupNorm(num_groups, prev_filters, dtype=dtype, device=device)
scale_factor = 2**self.num_blocks
        if num_frames % scale_factor != 0:  # SCH: NOTE: includes a first frame that is padded before usage
            time_scaled = num_frames // scale_factor + 1
        else:
            time_scaled = num_frames // scale_factor
        assert (
            self.input_size[0] % scale_factor == 0
        ), f"image width {self.input_size[0]} is not divisible by scale factor {scale_factor}"
        assert (
            self.input_size[1] % scale_factor == 0
        ), f"image height {self.input_size[1]} is not divisible by scale factor {scale_factor}"
        w_scaled, h_scaled = self.input_size[0] // scale_factor, self.input_size[1] // scale_factor
        in_features = int(prev_filters * time_scaled * w_scaled * h_scaled)  # (C*T*W*H)
self.linear1 = nn.Linear(in_features, prev_filters, device=device, dtype=dtype) # NOTE: init to xavier_uniform
self.linear2 = nn.Linear(prev_filters, 1, device=device, dtype=dtype) # NOTE: init to xavier_uniform
# self.apply(xavier_uniform_weight_init)
def forward(self, x):
x = self.conv1(x)
# print("discriminator aft conv:", x.size())
x = self.activation_fn(x)
for i in range(self.num_blocks):
x = self.res_block_list[i](x)
# print("discriminator resblock down:", x.size())
x = self.conv2(x)
# print("discriminator aft conv2:", x.size())
x = self.norm1(x)
x = self.activation_fn(x)
x = x.reshape((x.shape[0], -1)) # SCH: [B, (C * T * W * H)] ?
# print("discriminator reshape:", x.size())
x = self.linear1(x)
# print("discriminator aft linear1:", x.size())
x = self.activation_fn(x)
x = self.linear2(x)
# print("discriminator aft linear2:", x.size())
return x
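# NOTE: minimal shape sketch (not part of the original configs; the tiny sizes
# below are illustrative). With two blocks the scale factor is 4, so num_frames=5
# (4k + 1, first frame included) flattens to exactly the in_features the Linear
# layer expects; any other T/H/W would break that layer, as the class NOTE warns.
def _example_stylegan_disc_blur():
    disc = StyleGANDiscriminatorBlur(
        image_size=(16, 16),
        num_frames=5,
        filters=32,
        channel_multipliers=(2, 4),
        dtype=torch.float32,
    )
    logits = disc(torch.randn(1, 3, 5, 16, 16))
    assert logits.shape == (1, 1)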
class Encoder(nn.Module):
"""Encoder Blocks."""
def __init__(
self,
filters=128,
num_res_blocks=4,
channel_multipliers=(1, 2, 2, 4),
temporal_downsample=(False, True, True),
num_groups=32, # for nn.GroupNorm
# in_out_channels = 3, # SCH: added, in_channels at the start
latent_embed_dim=512, # num channels for latent vector
# conv_downsample = False,
disable_spatial_downsample=False, # for vae pipeline
custom_conv_padding=None,
activation_fn="swish",
device="cpu",
dtype=torch.bfloat16,
):
super().__init__()
self.filters = filters
self.num_res_blocks = num_res_blocks
self.channel_multipliers = channel_multipliers
self.temporal_downsample = temporal_downsample
self.num_groups = num_groups
self.embedding_dim = latent_embed_dim
self.disable_spatial_downsample = disable_spatial_downsample
# self.conv_downsample = conv_downsample
self.custom_conv_padding = custom_conv_padding
if activation_fn == "relu":
self.activation_fn = nn.ReLU
elif activation_fn == "swish":
self.activation_fn = nn.SiLU
else:
raise NotImplementedError
self.activate = self.activation_fn()
self.conv_fn = functools.partial(
CausalConv3d,
# padding='valid' if self.custom_conv_padding is not None else 'same', # SCH: lower letter for pytorch
dtype=dtype,
device=device,
)
self.block_args = dict(
conv_fn=self.conv_fn,
dtype=dtype,
activation_fn=self.activation_fn,
use_conv_shortcut=False,
num_groups=self.num_groups,
device=device,
)
# NOTE: moved to VAE for separate first frame processing
# self.conv1 = self.conv_fn(in_out_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
# ResBlocks and conv downsample
self.block_res_blocks = nn.ModuleList([])
self.num_blocks = len(self.channel_multipliers)
self.conv_blocks = nn.ModuleList([])
filters = self.filters
prev_filters = filters # record for in_channels
for i in range(self.num_blocks):
# resblock handling
filters = self.filters * self.channel_multipliers[i] # SCH: determine the number out_channels
block_items = nn.ModuleList([])
for _ in range(self.num_res_blocks):
block_items.append(ResBlock(prev_filters, filters, **self.block_args))
prev_filters = filters # update in_channels
self.block_res_blocks.append(block_items)
if i < self.num_blocks - 1: # SCH: T-Causal Conv 3x3x3, 128->128, stride t x stride s x stride s
t_stride = 2 if self.temporal_downsample[i] else 1
s_stride = 2 if not self.disable_spatial_downsample else 1
self.conv_blocks.append(
self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride))
) # SCH: should be same in_channel and out_channel
prev_filters = filters # update in_channels
# last layer res block
self.res_blocks = nn.ModuleList([])
for _ in range(self.num_res_blocks):
self.res_blocks.append(ResBlock(prev_filters, filters, **self.block_args))
prev_filters = filters # update in_channels
# MAGVIT uses Group Normalization
self.norm1 = nn.GroupNorm(
self.num_groups, prev_filters, dtype=dtype, device=device
) # SCH: separate <prev_filters> channels into 32 groups
self.conv2 = nn.Conv3d(
prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), dtype=dtype, device=device, padding="same"
)
def forward(self, x):
# dtype, device = x.dtype, x.device
# NOTE: moved to VAE for separate first frame processing
# x = self.conv1(x)
# print("encoder:", x.size())
for i in range(self.num_blocks):
for j in range(self.num_res_blocks):
x = self.block_res_blocks[i][j](x)
# print("encoder:", x.size())
if i < self.num_blocks - 1:
x = self.conv_blocks[i](x)
# print("encoder:", x.size())
for i in range(self.num_res_blocks):
x = self.res_blocks[i](x)
# print("encoder:", x.size())
x = self.norm1(x)
x = self.activate(x)
x = self.conv2(x)
# print("encoder:", x.size())
return x
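# NOTE: minimal shape sketch (not part of the original configs). The encoder's
# first conv lives in the VAE (`conv_in`), so the input here already has
# `filters` channels; the tiny config below is illustrative only.
def _example_encoder_shapes():
    enc = Encoder(
        filters=32,
        num_res_blocks=1,
        channel_multipliers=(1, 2),
        temporal_downsample=(False,),
        latent_embed_dim=8,
        dtype=torch.float32,
    )
    x = torch.randn(1, 32, 5, 16, 16)  # post-conv_in features
    z = enc(x)
    assert z.shape == (1, 8, 5, 8, 8)  # spatial halved once, time untouched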
class Decoder(nn.Module):
"""Decoder Blocks."""
def __init__(
self,
latent_embed_dim=512,
filters=128,
# in_out_channels = 4,
num_res_blocks=4,
channel_multipliers=(1, 2, 2, 4),
temporal_downsample=(False, True, True),
num_groups=32, # for nn.GroupNorm
# upsample = "nearest+conv", # options: "deconv", "nearest+conv"
disable_spatial_upsample=False, # for vae pipeline
custom_conv_padding=None,
activation_fn="swish",
device="cpu",
dtype=torch.bfloat16,
):
super().__init__()
# self.output_dim = in_out_channels
self.embedding_dim = latent_embed_dim
self.filters = filters
self.num_res_blocks = num_res_blocks
self.channel_multipliers = channel_multipliers
self.temporal_downsample = temporal_downsample
self.num_groups = num_groups
# self.upsample = upsample
self.s_stride = 1 if disable_spatial_upsample else 2 # spatial stride
self.custom_conv_padding = custom_conv_padding
# self.norm_type = self.config.vqvae.norm_type
# self.num_remat_block = self.config.vqvae.get('num_dec_remat_blocks', 0)
if activation_fn == "relu":
self.activation_fn = nn.ReLU
elif activation_fn == "swish":
self.activation_fn = nn.SiLU
else:
raise NotImplementedError
self.activate = self.activation_fn()
self.conv_fn = functools.partial(
CausalConv3d,
dtype=dtype,
# padding='valid' if self.custom_conv_padding is not None else 'same', # SCH: lower letter for pytorch
device=device,
)
self.block_args = dict(
conv_fn=self.conv_fn,
activation_fn=self.activation_fn,
use_conv_shortcut=False,
num_groups=self.num_groups,
device=device,
dtype=dtype,
)
self.num_blocks = len(self.channel_multipliers)
filters = self.filters * self.channel_multipliers[-1]
self.conv1 = self.conv_fn(self.embedding_dim, filters, kernel_size=(3, 3, 3), bias=True)
# last layer res block
self.res_blocks = nn.ModuleList([])
for _ in range(self.num_res_blocks):
self.res_blocks.append(ResBlock(filters, filters, **self.block_args))
# TODO: do I need to add adaptive GroupNorm in between each block?
# # NOTE: upsample, dimensions T, H, W
# self.upsampler_with_t = nn.Upsample(scale_factor=(2,2,2))
# self.upsampler = nn.Upsample(scale_factor=(1,2,2))
# ResBlocks and conv upsample
prev_filters = filters # SCH: in_channels
self.block_res_blocks = nn.ModuleList([])
self.num_blocks = len(self.channel_multipliers)
self.conv_blocks = nn.ModuleList([])
# SCH: reverse to keep track of the in_channels, but append also in a reverse direction
for i in reversed(range(self.num_blocks)):
filters = self.filters * self.channel_multipliers[i]
# resblock handling
block_items = nn.ModuleList([])
for _ in range(self.num_res_blocks):
block_items.append(ResBlock(prev_filters, filters, **self.block_args))
prev_filters = filters # SCH: update in_channels
self.block_res_blocks.insert(0, block_items) # SCH: append in front
# conv blocks with upsampling
if i > 0:
t_stride = 2 if self.temporal_downsample[i - 1] else 1
# SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
self.conv_blocks.insert(
0,
self.conv_fn(
prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3, 3, 3)
),
)
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, device=device, dtype=dtype)
# NOTE: moved to VAE for separate first frame processing
# self.conv2 = self.conv_fn(prev_filters, self.output_dim, kernel_size=(3, 3, 3))
def forward(
self,
x,
**kwargs,
):
# dtype, device = x.dtype, x.device
x = self.conv1(x)
# print("decoder:", x.size())
for i in range(self.num_res_blocks):
x = self.res_blocks[i](x)
# print("decoder:", x.size())
for i in reversed(range(self.num_blocks)): # reverse here to make decoder symmetric with encoder
for j in range(self.num_res_blocks):
x = self.block_res_blocks[i][j](x)
# print("decoder:", x.size())
if i > 0:
t_stride = 2 if self.temporal_downsample[i - 1] else 1
# SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
x = self.conv_blocks[i - 1](x)
x = rearrange(
x,
"B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)",
ts=t_stride,
hs=self.s_stride,
ws=self.s_stride,
)
# print("decoder:", x.size())
x = self.norm1(x)
x = self.activate(x)
# NOTE: moved to VAE for separate first frame processing
# x = self.conv2(x)
return x
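# NOTE: minimal sketch (not part of the original model) of the depth-to-space
# upsampling used in Decoder.forward: a conv widens the channels by ts*hs*ws,
# then `rearrange` folds them back into time/height/width. Sizes are illustrative.
def _example_depth_to_space():
    ts = hs = ws = 2
    x = torch.randn(1, 8 * ts * hs * ws, 4, 8, 8)  # conv output with widened channels
    y = rearrange(x, "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)", ts=ts, hs=hs, ws=ws)
    assert y.shape == (1, 8, 8, 16, 16)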
@MODELS.register_module()
class VAE_3D_V2(nn.Module): # , ModelMixin
"""The 3D VAE"""
def __init__(
self,
latent_embed_dim=256,
filters=128,
num_res_blocks=2,
separate_first_frame_encoding=False,
channel_multipliers=(1, 2, 2, 4),
temporal_downsample=(True, True, False),
num_groups=32, # for nn.GroupNorm
disable_space=False,
custom_conv_padding=None,
activation_fn="swish",
in_out_channels=4,
kl_embed_dim=64,
encoder_double_z=True,
device="cpu",
dtype="bf16",
):
super().__init__()
        if isinstance(dtype, str):
if dtype == "bf16":
dtype = torch.bfloat16
elif dtype == "fp16":
dtype = torch.float16
else:
raise NotImplementedError(f"dtype: {dtype}")
# ==== Model Params ====
# self.image_size = cast_tuple(image_size, 2)
self.time_downsample_factor = 2 ** sum(temporal_downsample)
self.time_padding = self.time_downsample_factor - 1
self.separate_first_frame_encoding = separate_first_frame_encoding
        image_down = 2 ** len(temporal_downsample)
        t_down = 2 ** sum(temporal_downsample)  # same as self.time_downsample_factor
        self.patch_size = (t_down, image_down, image_down)
# ==== Model Initialization ====
# encoder & decoder first and last conv layer
# SCH: NOTE: following MAGVIT, conv in bias=False in encoder first conv
self.conv_in = CausalConv3d(
in_out_channels, filters, kernel_size=(3, 3, 3), bias=False, dtype=dtype, device=device
)
self.conv_in_first_frame = nn.Identity()
self.conv_out_first_frame = nn.Identity()
if separate_first_frame_encoding:
self.conv_in_first_frame = SameConv2d(in_out_channels, filters, (3, 3))
self.conv_out_first_frame = SameConv2d(filters, in_out_channels, (3, 3))
self.conv_out = CausalConv3d(filters, in_out_channels, 3, dtype=dtype, device=device)
self.encoder = Encoder(
filters=filters,
num_res_blocks=num_res_blocks,
channel_multipliers=channel_multipliers,
temporal_downsample=temporal_downsample,
num_groups=num_groups, # for nn.GroupNorm
# in_out_channels = in_out_channels,
latent_embed_dim=latent_embed_dim * 2 if encoder_double_z else latent_embed_dim,
# conv_downsample = conv_downsample,
disable_spatial_downsample=disable_space,
custom_conv_padding=custom_conv_padding,
activation_fn=activation_fn,
device=device,
dtype=dtype,
)
self.decoder = Decoder(
latent_embed_dim=latent_embed_dim,
filters=filters,
# in_out_channels = in_out_channels,
num_res_blocks=num_res_blocks,
channel_multipliers=channel_multipliers,
temporal_downsample=temporal_downsample,
num_groups=num_groups, # for nn.GroupNorm
# upsample = upsample, # options: "deconv", "nearest+conv"
disable_spatial_upsample=disable_space,
custom_conv_padding=custom_conv_padding,
activation_fn=activation_fn,
device=device,
dtype=dtype,
)
if encoder_double_z:
self.quant_conv = nn.Conv3d(2 * latent_embed_dim, 2 * kl_embed_dim, 1, device=device, dtype=dtype)
else:
self.quant_conv = nn.Conv3d(latent_embed_dim, 2 * kl_embed_dim, 1, device=device, dtype=dtype)
self.post_quant_conv = nn.Conv3d(kl_embed_dim, latent_embed_dim, 1, device=device, dtype=dtype)
def get_latent_size(self, input_size):
for i in range(len(input_size)):
assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
return input_size
def encode(
self,
video,
video_contains_first_frame=True,
):
encode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
# whether to pad video or not
if video_contains_first_frame:
video_len = video.shape[2]
video = pad_at_dim(video, (self.time_padding, 0), value=0.0, dim=2)
video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])]
# print("pre-encoder:", video.size())
# NOTE: moved encoder conv1 here for separate first frame encoding
if encode_first_frame_separately:
pad, first_frame, video = unpack(video, video_packed_shape, "b c * h w")
first_frame = self.conv_in_first_frame(first_frame)
video = self.conv_in(video)
# print("pre-encoder:", video.size())
if encode_first_frame_separately:
video, _ = pack([first_frame, video], "b c * h w")
video = pad_at_dim(video, (self.time_padding, 0), dim=2)
encoded_feature = self.encoder(video)
# print("after encoder:", encoded_feature.size())
        # NOTE: TODO: do we include this conv before the Gaussian distribution, or go to it directly?
moments = self.quant_conv(encoded_feature).to(video.dtype)
posterior = DiagonalGaussianDistribution(moments)
# print("after encoder moments:", moments.size())
return posterior
def decode(
self,
z,
video_contains_first_frame=True,
):
# dtype = z.dtype
decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
z = self.post_quant_conv(z)
# print("pre decoder, post quant conv:", z.size())
dec = self.decoder(z)
# print("post decoder:", dec.size())
# SCH: moved decoder last conv layer here for separate first frame decoding
if decode_first_frame_separately:
left_pad, dec_ff, dec = (
dec[:, :, : self.time_padding],
dec[:, :, self.time_padding],
dec[:, :, (self.time_padding + 1) :],
)
out = self.conv_out(dec)
outff = self.conv_out_first_frame(dec_ff)
video, _ = pack([outff, out], "b c * h w")
else:
video = self.conv_out(dec)
# if video were padded, remove padding
if video_contains_first_frame:
video = video[:, :, self.time_padding :]
# print("conv out:", video.size())
return video
def get_last_layer(self):
# CausalConv3d wraps the conv
return self.conv_out.conv.weight
def forward(
self,
video,
sample_posterior=True,
video_contains_first_frame=True,
# split = "train",
):
batch, channels, frames = video.shape[:3]
assert divisible_by(
frames - int(video_contains_first_frame), self.time_downsample_factor
), f"number of frames {frames} minus the first frame ({frames - int(video_contains_first_frame)}) must be divisible by the total downsample factor across time {self.time_downsample_factor}"
posterior = self.encode(
video,
video_contains_first_frame=video_contains_first_frame,
)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
recon_video = self.decode(z, video_contains_first_frame=video_contains_first_frame)
return recon_video, posterior
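# NOTE: minimal round-trip sketch (not part of the original configs; the tiny
# sizes are illustrative, and it assumes the standard DiagonalGaussianDistribution
# sampling behavior). With temporal_downsample=(True,), the input must satisfy
# (frames - 1) % time_downsample_factor == 0, e.g. 5 frames here.
def _example_tiny_vae_roundtrip():
    vae = VAE_3D_V2(
        latent_embed_dim=8,
        filters=32,
        num_res_blocks=1,
        channel_multipliers=(1, 2),
        temporal_downsample=(True,),
        in_out_channels=3,
        kl_embed_dim=4,
        dtype=torch.float32,
    )
    video = torch.randn(1, 3, 5, 16, 16)  # [B, C, T, H, W]
    recon, posterior = vae(video)
    assert recon.shape == video.shape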
class VEALoss(nn.Module):
def __init__(
self,
logvar_init=0.0,
perceptual_loss_weight=0.1,
kl_loss_weight=0.000001,
# vgg=None,
# vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
device="cpu",
dtype="bf16",
):
super().__init__()
        if isinstance(dtype, str):
if dtype == "bf16":
dtype = torch.bfloat16
elif dtype == "fp16":
dtype = torch.float16
else:
raise NotImplementedError(f"dtype: {dtype}")
# KL Loss
self.kl_loss_weight = kl_loss_weight
# Perceptual Loss
self.perceptual_loss_fn = LPIPS().eval().to(device, dtype)
self.perceptual_loss_weight = perceptual_loss_weight
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
def forward(
self,
video,
recon_video,
posterior,
nll_weights=None,
split="train",
):
video = rearrange(video, "b c t h w -> (b t) c h w").contiguous()
recon_video = rearrange(recon_video, "b c t h w -> (b t) c h w").contiguous()
# reconstruction loss
recon_loss = torch.abs(video - recon_video)
# perceptual loss
if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
# handle channels
channels = video.shape[1]
assert channels in {1, 3, 4}
if channels == 1:
input_vgg_input = repeat(video, "b 1 h w -> b c h w", c=3)
recon_vgg_input = repeat(recon_video, "b 1 h w -> b c h w", c=3)
elif channels == 4: # SCH: take the first 3 for perceptual loss calc
input_vgg_input = video[:, :3]
recon_vgg_input = recon_video[:, :3]
else:
input_vgg_input = video
recon_vgg_input = recon_video
perceptual_loss = self.perceptual_loss_fn(input_vgg_input, recon_vgg_input)
recon_loss = recon_loss + self.perceptual_loss_weight * perceptual_loss
nll_loss = recon_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if nll_weights is not None:
weighted_nll_loss = nll_weights * nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
# KL Loss
weighted_kl_loss = 0
if self.kl_loss_weight is not None and self.kl_loss_weight > 0.0:
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
weighted_kl_loss = kl_loss * self.kl_loss_weight
return nll_loss, weighted_nll_loss, weighted_kl_loss
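# NOTE: worked sketch (not part of the original code) of the learnable-logvar
# weighting above: nll = recon / exp(logvar) + logvar is minimized when
# exp(logvar) equals the reconstruction error, auto-balancing the loss scale.
def _example_logvar_weighting():
    recon = torch.tensor(0.5)
    at_optimum = recon / torch.exp(torch.log(recon)) + torch.log(recon)  # = 1 + log(recon)
    at_zero = recon / torch.exp(torch.tensor(0.0)) + 0.0  # plain reconstruction value
    assert at_optimum < at_zero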
class AdversarialLoss(nn.Module):
def __init__(
self,
discriminator_factor=1.0,
discriminator_start=50001,
generator_factor=0.5,
generator_loss_type="non-saturating",
):
super().__init__()
self.discriminator_factor = discriminator_factor
self.discriminator_start = discriminator_start
self.generator_factor = generator_factor
self.generator_loss_type = generator_loss_type
    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
        # SCH: TODO: debug whether create_graph=True needs to be added here
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.generator_factor
return d_weight
def forward(
self,
fake_logits,
nll_loss,
last_layer,
global_step,
is_training=True,
):
        # NOTE: following MAGVIT to allow non_saturating
        assert self.generator_loss_type in ["hinge", "non-saturating"]  # "vanilla" is not implemented below
if self.generator_loss_type == "hinge":
gen_loss = -torch.mean(fake_logits)
elif self.generator_loss_type == "non-saturating":
gen_loss = torch.mean(
sigmoid_cross_entropy_with_logits(labels=torch.ones_like(fake_logits), logits=fake_logits)
)
else:
raise ValueError("Generator loss {} not supported".format(self.generator_loss_type))
if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
try:
d_weight = self.calculate_adaptive_weight(nll_loss, gen_loss, last_layer)
except RuntimeError:
assert not is_training
d_weight = torch.tensor(0.0)
else:
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
weighted_gen_loss = d_weight * disc_factor * gen_loss
return weighted_gen_loss
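# NOTE: minimal sketch (not part of the original training loop) of the adaptive
# weight: it balances generator-loss gradients against NLL gradients at the
# decoder's last layer, using a toy linear layer here for illustration.
def _example_adaptive_weight():
    toy_last_layer = nn.Linear(4, 4)
    out = toy_last_layer(torch.randn(2, 4))
    nll_loss = out.pow(2).mean()
    gen_loss = out.abs().mean()
    adv = AdversarialLoss()
    d_weight = adv.calculate_adaptive_weight(nll_loss, gen_loss, toy_last_layer.weight)
    assert d_weight.ndim == 0 and not d_weight.requires_grad  # detached scalar scale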
class LeCamEMA:
def __init__(self, ema_real=0.0, ema_fake=0.0, decay=0.999, dtype=torch.bfloat16, device="cpu"):
self.decay = decay
self.ema_real = torch.tensor(ema_real).to(device, dtype)
self.ema_fake = torch.tensor(ema_fake).to(device, dtype)
def update(self, ema_real, ema_fake):
self.ema_real = self.ema_real * self.decay + ema_real * (1 - self.decay)
self.ema_fake = self.ema_fake * self.decay + ema_fake * (1 - self.decay)
def get(self):
return self.ema_real, self.ema_fake
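# NOTE: minimal sketch (not part of the original training loop) tying LeCamEMA
# to lecam_reg above: the EMA trackers smooth the discriminator's mean real/fake
# predictions, and the regularizer penalizes deviations from them.
def _example_lecam():
    ema = LeCamEMA(decay=0.9, dtype=torch.float32)
    real_pred, fake_pred = torch.tensor(0.8), torch.tensor(-0.5)
    ema.update(real_pred, fake_pred)
    ema_real, ema_fake = ema.get()
    reg = lecam_reg(real_pred, fake_pred, ema_real, ema_fake)
    assert reg.ndim == 0 and reg.item() >= 0.0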
class DiscriminatorLoss(nn.Module):
def __init__(
self,
discriminator_factor=1.0,
discriminator_start=50001,
discriminator_loss_type="non-saturating",
lecam_loss_weight=None,
gradient_penalty_loss_weight=None, # SCH: following MAGVIT config.vqgan.grad_penalty_cost
):
super().__init__()
assert discriminator_loss_type in ["hinge", "vanilla", "non-saturating"]
self.discriminator_factor = discriminator_factor
self.discriminator_start = discriminator_start
self.lecam_loss_weight = lecam_loss_weight
self.gradient_penalty_loss_weight = gradient_penalty_loss_weight
self.discriminator_loss_type = discriminator_loss_type
def forward(
self,
real_logits,
fake_logits,
global_step,
lecam_ema_real=None,
lecam_ema_fake=None,
real_video=None,
split="train",
):
if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
if self.discriminator_loss_type == "hinge":
disc_loss = hinge_d_loss(real_logits, fake_logits)
elif self.discriminator_loss_type == "non-saturating":
if real_logits is not None:
real_loss = sigmoid_cross_entropy_with_logits(
labels=torch.ones_like(real_logits), logits=real_logits
)
else:
real_loss = 0.0
if fake_logits is not None:
fake_loss = sigmoid_cross_entropy_with_logits(
labels=torch.zeros_like(fake_logits), logits=fake_logits
)
else:
fake_loss = 0.0
disc_loss = 0.5 * (torch.mean(real_loss) + torch.mean(fake_loss))
elif self.discriminator_loss_type == "vanilla":
disc_loss = vanilla_d_loss(real_logits, fake_logits)
else:
raise ValueError(f"Unknown GAN loss '{self.discriminator_loss_type}'.")
weighted_d_adversarial_loss = disc_factor * disc_loss
else:
weighted_d_adversarial_loss = 0
lecam_loss = torch.tensor(0.0)
if self.lecam_loss_weight is not None and self.lecam_loss_weight > 0.0:
real_pred = torch.mean(real_logits)
fake_pred = torch.mean(fake_logits)
lecam_loss = lecam_reg(real_pred, fake_pred, lecam_ema_real, lecam_ema_fake)
lecam_loss = lecam_loss * self.lecam_loss_weight
gradient_penalty = torch.tensor(0.0)
if self.gradient_penalty_loss_weight is not None and self.gradient_penalty_loss_weight > 0.0:
assert real_video is not None
gradient_penalty = gradient_penalty_fn(real_video, real_logits)
# gradient_penalty = r1_penalty(real_video, real_logits) # MAGVIT uses r1 penalty
gradient_penalty *= self.gradient_penalty_loss_weight
# discriminator_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty
# log = {
# "{}/d_adversarial_loss".format(split): weighted_d_adversarial_loss.detach().mean(),
# "{}/lecam_loss".format(split): lecam_loss.detach().mean(),
# "{}/gradient_penalty".format(split): gradient_penalty.detach().mean(),
# }
return (weighted_d_adversarial_loss, lecam_loss, gradient_penalty)
@MODELS.register_module("VAE_MAGVIT_V2")
def VAE_MAGVIT_V2(from_pretrained=None, **kwargs):
model = VAE_3D_V2(**kwargs)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained, model_name="model")
return model
@MODELS.register_module("DISCRIMINATOR_3D")
def DISCRIMINATOR_3D(from_pretrained=None, inflate_from_2d=False, use_pretrained=True, **kwargs):
model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init)
if from_pretrained is not None:
if use_pretrained:
if inflate_from_2d:
load_checkpoint_with_inflation(model, from_pretrained)
else:
load_checkpoint(model, from_pretrained, model_name="discriminator")
print(f"loaded discriminator")
else:
print(f"discriminator use_pretrained={use_pretrained}, initializing new discriminator")
return model
@MODELS.register_module("N_Layer_DISCRIMINATOR_3D")
def DISCRIMINATOR_3D(from_pretrained=None, inflate_from_2d=False, use_pretrained=True, **kwargs):
model = NLayerDiscriminator3D(input_nc=3, n_layers=3,).apply(n_layer_disc_weights_init)
if from_pretrained is not None:
if use_pretrained:
if inflate_from_2d:
load_checkpoint_with_inflation(model, from_pretrained)
else:
load_checkpoint(model, from_pretrained, model_name="discriminator")
print(f"loaded discriminator")
else:
print(f"discriminator use_pretrained={use_pretrained}, initializing new discriminator")
return model
def load_checkpoint_with_inflation(model, ckpt_path):
    """
    Pre-train using images, then inflate the 2D weights to 3D for videos.
    Central inflation: the 2D kernel is placed at the temporal centre of a
    zero-initialized 3D kernel.
    """
    if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
        state_dict = find_model(ckpt_path)
        model_state_dict = model.state_dict()
        with torch.no_grad():
            for key in state_dict:
                if key in model_state_dict and model_state_dict[key].ndim == 5:
                    # central inflation along the temporal dimension
                    if state_dict[key].size() == model_state_dict[key][:, :, 0, :, :].size():
                        val = torch.zeros_like(model_state_dict[key])
                        centre = int(model_state_dict[key].size(2) // 2)
                        val[:, :, centre, :, :] = state_dict[key]
                        state_dict[key] = val
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        print(f"Missing keys: {missing_keys}")
        print(f"Unexpected keys: {unexpected_keys}")
    else:
        load_checkpoint(model, ckpt_path)  # use the default loading function
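# NOTE: minimal sketch (not part of the original code) of central inflation for
# a single weight: the 2D [out, in, kh, kw] kernel is placed at the temporal
# centre of a zeroed 3D [out, in, kt, kh, kw] kernel. Sizes are illustrative.
def _example_central_inflation():
    w2d = torch.randn(8, 4, 3, 3)
    w3d = torch.zeros(8, 4, 3, 3, 3)
    centre = w3d.size(2) // 2
    w3d[:, :, centre] = w2d  # other temporal slices stay zero
    assert torch.equal(w3d[:, :, centre], w2d) and w3d[:, :, 0].abs().sum() == 0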