Open-Sora/opensora/models/vae/vae_temporal.py
2024-04-30 08:13:20 +00:00

417 lines
14 KiB
Python

from typing import Tuple, Union
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint
from .utils import DiagonalGaussianDistribution
def cast_tuple(t, length=1):
    """Return ``t`` as a tuple.

    Args:
        t: Any value. A tuple is passed through untouched; its length is
            not checked against ``length``.
        length: Number of copies of ``t`` placed in the tuple that is built
            for non-tuple input.

    Returns:
        ``t`` itself if it is a tuple, otherwise ``(t,) * length``.
    """
    if isinstance(t, tuple):
        return t
    return (t,) * length
def divisible_by(num, den):
    """Report whether ``num`` is an exact multiple of ``den``."""
    remainder = num % den
    return remainder == 0
def is_odd(n):
    """Report whether ``n`` is not an exact multiple of two."""
    is_even = divisible_by(n, 2)
    return not is_even
def pad_at_dim(t, pad, dim=-1):
    """Replicate-pad tensor ``t`` along a single dimension.

    Args:
        t: Tensor to pad.
        pad: ``(before, after)`` padding amounts for dimension ``dim``.
        dim: Dimension to pad; negative values count from the last dimension.

    Returns:
        The result of ``F.pad`` in ``"replicate"`` mode.
    """
    # F.pad takes (before, after) pairs starting from the last dimension, so
    # every dimension to the right of ``dim`` gets an explicit (0, 0) pair.
    if dim < 0:
        trailing_dims = -dim - 1
    else:
        trailing_dims = t.ndim - dim - 1
    padding = (0, 0) * trailing_dims + tuple(pad)
    return F.pad(t, padding, mode="replicate")
def exists(v):
    """Report whether ``v`` is anything other than ``None``."""
    if v is None:
        return False
    return True
class CausalConv3d(nn.Module):
    """``nn.Conv3d`` preceded by padding that is one-sided along time.

    Height and width are padded symmetrically by ``kernel // 2``. The time
    axis is padded only at the front, by
    ``dilation * (time_kernel - 1) + (1 - time_stride)``, and never at the
    back.

    Args:
        chan_in: Input channel count passed to ``nn.Conv3d``.
        chan_out: Output channel count passed to ``nn.Conv3d``.
        kernel_size: Int or ``(time, height, width)`` tuple; the height and
            width entries must be odd.
        pad_mode: Stored on the instance as ``self.pad_mode``.
            NOTE(review): ``forward`` ignores it and always pads with
            ``"replicate"``.
        strides: Optional ``(time, height, width)`` stride tuple handed to
            ``nn.Conv3d`` unchanged. When omitted, an optional ``stride``
            keyword (default 1) is applied to the time axis only.
        **kwargs: ``dilation`` (time axis only, default 1) is consumed here;
            everything else is forwarded to ``nn.Conv3d``.
    """

    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int, int, int]],
        pad_mode="constant",
        strides=None,  # allow custom stride
        **kwargs,
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)
        k_time, k_height, k_width = kernel_size
        assert is_odd(k_height) and is_odd(k_width)

        dilation = kwargs.pop("dilation", 1)
        if strides is None:
            # Scalar stride applies to the time axis only.
            time_stride = kwargs.pop("stride", 1)
            conv_stride = (time_stride, 1, 1)
        else:
            time_stride = strides[0]
            conv_stride = strides

        self.pad_mode = pad_mode
        self.time_pad = dilation * (k_time - 1) + (1 - time_stride)
        pad_h = k_height // 2
        pad_w = k_width // 2
        # F.pad order: last dimension first -> (W, W, H, H, T_front, T_back).
        self.time_causal_padding = (pad_w, pad_w, pad_h, pad_h, self.time_pad, 0)

        self.conv = nn.Conv3d(
            chan_in,
            chan_out,
            kernel_size,
            stride=conv_stride,
            dilation=(dilation, 1, 1),
            **kwargs,
        )

    def forward(self, x):
        """Pad ``x`` with ``self.time_causal_padding`` (replicate mode), then convolve."""
        padded = F.pad(x, self.time_causal_padding, mode="replicate")
        return self.conv(padded)
class ResBlock(nn.Module):
    """Pre-activation residual block: (GroupNorm -> activation -> conv) twice, plus skip.

    Args:
        in_channels: Channel count of the input.
        filters: Channel count of the output.
        conv_fn: Callable used to build every convolution; it is invoked as
            ``conv_fn(in_ch, out_ch, kernel_size=..., bias=False)``.
        activation_fn: Zero-argument factory for the activation module.
        use_conv_shortcut: When the channel counts differ, selects a 3x3x3
            (True) or 1x1x1 (False) convolution on the skip path.
        num_groups: Number of groups for both ``nn.GroupNorm`` layers.
    """

    def __init__(
        self,
        in_channels,  # SCH: added
        filters,
        conv_fn,
        activation_fn=nn.SiLU,
        use_conv_shortcut=False,
        num_groups=32,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.filters = filters
        self.activate = activation_fn()
        self.use_conv_shortcut = use_conv_shortcut

        # SCH: MAGVIT uses GroupNorm by default
        self.norm1 = nn.GroupNorm(num_groups, in_channels)
        self.conv1 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
        self.norm2 = nn.GroupNorm(num_groups, self.filters)
        self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False)

        # A projection on the skip path exists only when the channel count changes.
        if in_channels != filters:
            shortcut_kernel = (3, 3, 3) if self.use_conv_shortcut else (1, 1, 1)
            self.conv3 = conv_fn(in_channels, self.filters, kernel_size=shortcut_kernel, bias=False)

    def forward(self, x):
        """Return the block output for ``x``: main branch plus (possibly projected) skip."""
        hidden = self.conv1(self.activate(self.norm1(x)))
        hidden = self.conv2(self.activate(self.norm2(hidden)))

        skip = x
        if self.in_channels != self.filters:  # SCH: ResBlock X->Y
            skip = self.conv3(skip)
        return hidden + skip
def get_activation_fn(activation):
    """Map an activation name to its ``torch.nn`` module class.

    Args:
        activation: ``"relu"`` or ``"swish"``.

    Returns:
        ``nn.ReLU`` for ``"relu"`` and ``nn.SiLU`` for ``"swish"``. The class
        itself is returned, not an instance.

    Raises:
        NotImplementedError: For any other value.
    """
    if activation == "relu":
        return nn.ReLU
    if activation == "swish":
        return nn.SiLU
    raise NotImplementedError
class Encoder(nn.Module):
    """Encoder Blocks.

    Layout built in ``__init__`` and applied in ``forward``:

    1. ``conv_in``: 3x3x3 causal convolution to ``filters`` channels.
    2. One stage per entry of ``channel_multipliers``, each holding
       ``num_res_blocks`` residual blocks. Every stage except the last is
       followed by a transition: a time-stride-2 causal convolution when
       ``temporal_downsample[stage]`` is truthy, otherwise ``nn.Identity``.
    3. ``res_blocks``: a final group of ``num_res_blocks`` residual blocks.
    4. GroupNorm, activation, and a 1x1x1 convolution to
       ``latent_embed_dim`` channels.

    Args:
        in_out_channels: Channel count of the input.
        latent_embed_dim: Channel count of the output.
        filters: Base channel width, scaled per stage by ``channel_multipliers``.
        num_res_blocks: Residual blocks per stage (and in the final group).
        channel_multipliers: Per-stage multiplier applied to ``filters``.
        temporal_downsample: Per-transition flags; needs at least
            ``len(channel_multipliers) - 1`` entries.
        num_groups: Group count for every ``nn.GroupNorm``.
        activation_fn: Activation name understood by ``get_activation_fn``.
    """

    def __init__(
        self,
        in_out_channels=4,
        latent_embed_dim=512,  # num channels for latent vector
        filters=128,
        num_res_blocks=4,
        channel_multipliers=(1, 2, 2, 4),
        temporal_downsample=(False, True, True),
        num_groups=32,  # for nn.GroupNorm
        activation_fn="swish",
    ):
        super().__init__()
        self.filters = filters
        self.num_res_blocks = num_res_blocks
        self.num_blocks = len(channel_multipliers)
        self.channel_multipliers = channel_multipliers
        self.temporal_downsample = temporal_downsample
        self.num_groups = num_groups
        self.embedding_dim = latent_embed_dim

        self.activation_fn = get_activation_fn(activation_fn)
        self.activate = self.activation_fn()
        self.conv_fn = CausalConv3d
        self.block_args = dict(
            conv_fn=self.conv_fn,
            activation_fn=self.activation_fn,
            use_conv_shortcut=False,
            num_groups=self.num_groups,
        )

        # first layer conv
        self.conv_in = self.conv_fn(in_out_channels, filters, kernel_size=(3, 3, 3), bias=False)

        # ResBlocks and conv downsample
        self.block_res_blocks = nn.ModuleList([])
        self.conv_blocks = nn.ModuleList([])

        last_stage = self.num_blocks - 1
        in_ch = self.filters  # channels entering the next module
        out_ch = self.filters  # channels produced by the current stage
        for stage, multiplier in enumerate(self.channel_multipliers):
            out_ch = self.filters * multiplier

            stage_blocks = nn.ModuleList([])
            for _ in range(self.num_res_blocks):
                stage_blocks.append(ResBlock(in_ch, out_ch, **self.block_args))
                in_ch = out_ch
            self.block_res_blocks.append(stage_blocks)

            if stage == last_stage:
                continue  # no transition after the final stage

            if self.temporal_downsample[stage]:
                # Halve the time axis; spatial stride stays at 1.
                transition = self.conv_fn(in_ch, out_ch, kernel_size=(3, 3, 3), strides=(2, 1, 1))
            else:
                # if no t downsample, don't add since this does nothing for pipeline models
                transition = nn.Identity(in_ch)
            self.conv_blocks.append(transition)
            in_ch = out_ch

        # last layer res block
        self.res_blocks = nn.ModuleList([])
        for _ in range(self.num_res_blocks):
            self.res_blocks.append(ResBlock(in_ch, out_ch, **self.block_args))
            in_ch = out_ch

        # MAGVIT uses Group Normalization
        self.norm1 = nn.GroupNorm(self.num_groups, in_ch)
        self.conv2 = self.conv_fn(in_ch, self.embedding_dim, kernel_size=(1, 1, 1), padding="same")

    def forward(self, x):
        """Run ``x`` through the encoder and return the ``latent_embed_dim``-channel feature."""
        x = self.conv_in(x)

        last_stage = self.num_blocks - 1
        for stage, stage_blocks in enumerate(self.block_res_blocks):
            for block in stage_blocks:
                x = block(x)
            if stage < last_stage:
                x = self.conv_blocks[stage](x)

        for block in self.res_blocks:
            x = block(x)

        return self.conv2(self.activate(self.norm1(x)))
class Decoder(nn.Module):
    """Decoder Blocks.

    Layout built in ``__init__`` and applied in ``forward``:

    1. ``conv1``: 3x3x3 causal convolution from ``latent_embed_dim`` to
       ``filters * channel_multipliers[-1]`` channels.
    2. ``res_blocks``: ``num_res_blocks`` residual blocks at that width.
    3. Stages visited from the last ``channel_multipliers`` entry down to the
       first, each holding ``num_res_blocks`` residual blocks. Every stage
       except stage 0 is followed by a transition. When
       ``temporal_downsample[stage - 1]`` is truthy the transition is a
       causal convolution that doubles the channel count, and ``forward``
       then folds the extra channels into the time axis; otherwise it is
       ``nn.Identity``.
    4. GroupNorm, activation, and a 3x3x3 causal convolution to
       ``in_out_channels`` channels.

    Modules are created in visiting order but inserted at the front of their
    lists, so list index ``i`` belongs to stage ``i``.

    Args:
        in_out_channels: Channel count of the output.
        latent_embed_dim: Channel count of the input latent.
        filters: Base channel width, scaled per stage by ``channel_multipliers``.
        num_res_blocks: Residual blocks per stage (and in the first group).
        channel_multipliers: Per-stage multiplier applied to ``filters``.
        temporal_downsample: Per-transition flags; needs at least
            ``len(channel_multipliers) - 1`` entries.
        num_groups: Group count for every ``nn.GroupNorm``.
        activation_fn: Activation name understood by ``get_activation_fn``.
    """

    def __init__(
        self,
        in_out_channels=4,
        latent_embed_dim=512,
        filters=128,
        num_res_blocks=4,
        channel_multipliers=(1, 2, 2, 4),
        temporal_downsample=(False, True, True),
        num_groups=32,  # for nn.GroupNorm
        activation_fn="swish",
    ):
        super().__init__()
        self.filters = filters
        self.num_res_blocks = num_res_blocks
        self.num_blocks = len(channel_multipliers)
        self.channel_multipliers = channel_multipliers
        self.temporal_downsample = temporal_downsample
        self.num_groups = num_groups
        self.embedding_dim = latent_embed_dim
        self.s_stride = 1  # spatial upsampling factor

        self.activation_fn = get_activation_fn(activation_fn)
        self.activate = self.activation_fn()
        self.conv_fn = CausalConv3d
        self.block_args = dict(
            conv_fn=self.conv_fn,
            activation_fn=self.activation_fn,
            use_conv_shortcut=False,
            num_groups=self.num_groups,
        )

        widest = self.filters * self.channel_multipliers[-1]

        # last conv
        self.conv1 = self.conv_fn(self.embedding_dim, widest, kernel_size=(3, 3, 3), bias=True)

        # last layer res block
        self.res_blocks = nn.ModuleList([])
        for _ in range(self.num_res_blocks):
            self.res_blocks.append(ResBlock(widest, widest, **self.block_args))

        # ResBlocks and conv upsample
        self.block_res_blocks = nn.ModuleList([])
        self.conv_blocks = nn.ModuleList([])

        in_ch = widest  # channels entering the next module
        out_ch = widest  # channels produced by the current stage
        # reverse to keep track of the in_channels, but append also in a reverse direction
        for stage in reversed(range(self.num_blocks)):
            out_ch = self.filters * self.channel_multipliers[stage]

            # resblock handling
            stage_blocks = nn.ModuleList([])
            for _ in range(self.num_res_blocks):
                stage_blocks.append(ResBlock(in_ch, out_ch, **self.block_args))
                in_ch = out_ch  # SCH: update in_channels
            self.block_res_blocks.insert(0, stage_blocks)  # SCH: append in front

            if stage == 0:
                continue  # no transition after the first stage

            # conv blocks with upsampling
            if self.temporal_downsample[stage - 1]:
                # SCH: T-Causal Conv 3x3x3, f -> (t_stride * s * s) * f, depth to space in forward
                expansion = 2 * self.s_stride * self.s_stride
                transition = self.conv_fn(in_ch, in_ch * expansion, kernel_size=(3, 3, 3))
            else:
                transition = nn.Identity(in_ch)
            self.conv_blocks.insert(0, transition)

        self.norm1 = nn.GroupNorm(self.num_groups, in_ch)
        self.conv_out = self.conv_fn(out_ch, in_out_channels, 3)

    def forward(self, x):
        """Run latent ``x`` through the decoder and return the ``in_out_channels``-channel output."""
        x = self.conv1(x)
        for block in self.res_blocks:
            x = block(x)

        for stage in reversed(range(self.num_blocks)):
            for block in self.block_res_blocks[stage]:
                x = block(x)
            if stage == 0:
                continue

            x = self.conv_blocks[stage - 1](x)
            t_stride = 2 if self.temporal_downsample[stage - 1] else 1
            # Depth-to-space: move the channel expansion onto T (and H/W via s_stride).
            x = rearrange(
                x,
                "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)",
                ts=t_stride,
                hs=self.s_stride,
                ws=self.s_stride,
            )

        return self.conv_out(self.activate(self.norm1(x)))
@MODELS.register_module()
class VAE_Temporal(nn.Module):
    """Temporal VAE built from a causal 3D ``Encoder`` and ``Decoder``.

    The time axis (dimension 2 of the input) is reduced by
    ``time_downsample_factor = 2 ** sum(temporal_downsample)``; ``patch_size``
    is ``(time_downsample_factor, 1, 1)``. A clip whose length is not a
    multiple of that factor is replicate-padded at the front before encoding,
    and the same number of leading frames is dropped again after decoding.
    A clip whose length already is a multiple is neither padded nor trimmed,
    so the latent length matches ``get_latent_size``.

    Args:
        in_out_channels: Channel count of the input and of the reconstruction.
        latent_embed_dim: Channel count fed to the decoder; the encoder emits
            twice as many channels, which ``quant_conv`` maps to the moments.
        embed_dim: Channel count of the sampled latent ``z``.
        filters: Base channel width of encoder and decoder.
        num_res_blocks: Residual blocks per stage.
        channel_multipliers: Per-stage multiplier applied to ``filters``.
        temporal_downsample: Per-transition flags enabling x2 time
            down/up-sampling.
        num_groups: Group count for every ``nn.GroupNorm``.
        activation_fn: Activation name (``"relu"`` or ``"swish"``).
    """

    def __init__(
        self,
        in_out_channels=4,
        latent_embed_dim=4,
        embed_dim=4,
        filters=128,
        num_res_blocks=4,
        channel_multipliers=(1, 2, 2, 4),
        temporal_downsample=(True, True, False),
        num_groups=32,  # for nn.GroupNorm
        activation_fn="swish",
    ):
        super().__init__()

        self.time_downsample_factor = 2 ** sum(temporal_downsample)
        self.patch_size = (self.time_downsample_factor, 1, 1)

        # NOTE: following MAGVIT, conv in bias=False in encoder first conv
        self.encoder = Encoder(
            in_out_channels=in_out_channels,
            latent_embed_dim=latent_embed_dim * 2,
            filters=filters,
            num_res_blocks=num_res_blocks,
            channel_multipliers=channel_multipliers,
            temporal_downsample=temporal_downsample,
            num_groups=num_groups,  # for nn.GroupNorm
            activation_fn=activation_fn,
        )
        self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)

        self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)
        self.decoder = Decoder(
            in_out_channels=in_out_channels,
            latent_embed_dim=latent_embed_dim,
            filters=filters,
            num_res_blocks=num_res_blocks,
            channel_multipliers=channel_multipliers,
            temporal_downsample=temporal_downsample,
            num_groups=num_groups,  # for nn.GroupNorm
            activation_fn=activation_fn,
        )

    def _time_padding(self, num_frames):
        """Return how many frames round ``num_frames`` up to a multiple of the time factor.

        The result is 0 when ``num_frames`` is already a multiple. The former
        expression ``factor - num_frames % factor`` returned ``factor`` in
        that case, i.e. one whole superfluous chunk.
        """
        return (-num_frames) % self.time_downsample_factor

    def get_latent_size(self, input_size):
        """Return the latent ``[T, H, W]`` for an input ``(T, H, W)``.

        Raises:
            AssertionError: If an entry of ``input_size`` is not divisible by
                the matching entry of ``patch_size``.
        """
        for i in range(len(input_size)):
            assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
        input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
        return input_size

    def encode(self, x):
        """Encode ``x`` (time on dimension 2) into a ``DiagonalGaussianDistribution``."""
        time_padding = self._time_padding(x.shape[2])
        if time_padding > 0:
            # Replicate the first frame so the length becomes a multiple of the factor.
            x = pad_at_dim(x, (time_padding, 0), dim=2)
        encoded_feature = self.encoder(x)
        moments = self.quant_conv(encoded_feature).to(x.dtype)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z, num_frames=None):
        """Decode latent ``z`` and drop the frames that ``encode`` padded.

        Args:
            z: Latent tensor with ``embed_dim`` channels.
            num_frames: Frame count of the original clip. When ``None`` the
                decoder output is returned untrimmed.
        """
        z = self.post_quant_conv(z)
        x = self.decoder(z)
        if num_frames is not None:
            time_padding = self._time_padding(num_frames)
            x = x[:, :, time_padding:]
        return x

    def forward(self, x, sample_posterior=True):
        """Encode ``x``, draw (or take the mode of) the latent, and reconstruct.

        Returns:
            Tuple ``(recon_video, posterior, z)``.
        """
        posterior = self.encode(x)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        recon_video = self.decode(z, num_frames=x.shape[2])
        return recon_video, posterior, z
@MODELS.register_module("VAE_Temporal_SD")
def VAE_Temporal_SD(from_pretrained=None, **kwargs):
    """Build the three-stage ``VAE_Temporal`` preset registered as ``"VAE_Temporal_SD"``.

    Args:
        from_pretrained: When not ``None``, handed to ``load_checkpoint``
            together with the freshly built model.
        **kwargs: Extra keyword arguments forwarded to ``VAE_Temporal``. They
            must not repeat any of the preset keys below.

    Returns:
        The constructed ``VAE_Temporal`` instance.
    """
    preset = dict(
        in_out_channels=4,
        latent_embed_dim=4,
        embed_dim=4,
        filters=128,
        num_res_blocks=3,
        channel_multipliers=(1, 2, 2),
        temporal_downsample=(True, True),
    )
    model = VAE_Temporal(**preset, **kwargs)
    if from_pretrained is None:
        return model
    load_checkpoint(model, from_pretrained)
    return model