Open-Sora/opensora/models/vae/vae_3d.py

import functools
from typing import Any, Dict, Tuple, Type, Union, Sequence, Optional
from absl import logging
import torch
import torch.nn as nn
import numpy as np
from numpy import typing as nptyping
from opensora.models.vae import model_utils
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint

"""Encoder and Decoder structures with 3D CNNs.

NOTE:
    - LayerNorm was removed since it is not used in this architecture.
    - GroupNorm: flax uses `epsilon=1e-06` by default, whereas torch uses `eps=1e-05`.
    - Average pooling and upsampling expect inputs of shape [N, C, T, H, W];
      if the layout differs, adjust the scale factors accordingly.
    - Open-Sora reads videos into [B, C, T, H, W] format.

TODO:
    - Check the data dimension format.
"""
class ResBlock(nn.Module):
def __init__(
self,
in_out_channels, # SCH: added
filters,
# norm_fn, # SCH: removed, use GN
conv_fn,
activation_fn=nn.ReLU,
use_conv_shortcut=False,
num_groups=32,
device="cpu",
dtype=torch.bfloat16,
):
super().__init__()
self.filters = filters
self.activate = activation_fn()
self.use_conv_shortcut = use_conv_shortcut
# SCH: MAGVIT uses GroupNorm by default
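        # (illustrative) to match the flax default mentioned in the module docstring,
        # the norm could be constructed with an explicit epsilon, e.g.:
        #   nn.GroupNorm(num_groups, in_out_channels, eps=1e-6, device=device, dtype=dtype)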
self.norm1 = nn.GroupNorm(num_groups, in_out_channels, device=device, dtype=dtype)
self.conv1 = conv_fn(in_out_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
self.norm2 = nn.GroupNorm(num_groups, self.filters, device=device, dtype=dtype)
self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False)
if self.use_conv_shortcut:
self.conv3 = conv_fn(in_out_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
else:
self.conv3 = conv_fn(in_out_channels, self.filters, kernel_size=(1, 1, 1), bias=False)
def forward(self, x):
device, dtype = x.device, x.dtype
input_dim = x.shape[1]
residual = x
        x = self.norm1.to(device, dtype)(x)
x = self.activate(x)
x = self.conv1.to(device,dtype)(x)
x = self.norm2.to(device, dtype)(x)
x = self.activate(x)
x = self.conv2.to(device, dtype)(x)
        if input_dim != self.filters:  # project the residual so its channel count matches the main branch
residual = self.conv3.to(device, dtype)(residual)
return x + residual
def _get_selected_flags(total_len: int, select_len: int, suffix: bool):
assert select_len <= total_len
selected = np.zeros(total_len, dtype=bool)
if not suffix:
selected[:select_len] = True
else:
selected[-select_len:] = True
return selected
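
# e.g. (illustrative):
#   _get_selected_flags(3, 2, suffix=False) -> array([ True,  True, False])
#   _get_selected_flags(3, 2, suffix=True)  -> array([False,  True,  True])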
class Encoder(nn.Module):
"""Encoder Blocks."""
def __init__(self,
filters = 64,
num_res_blocks = 2,
channel_multipliers = (1, 2, 2, 4),
temporal_downsample = (True, True, False),
num_groups = 32, # for nn.GroupNorm
in_out_channels = 3, # SCH: added, in_channels at the start
latent_embed_dim = 256, # num channels for latent vector
conv_downsample = False,
custom_conv_padding = None,
activation_fn = 'swish',
device="cpu",
dtype=torch.bfloat16,
):
super().__init__()
self.filters = filters
self.num_res_blocks = num_res_blocks
self.channel_multipliers = channel_multipliers
self.temporal_downsample = temporal_downsample
self.num_groups = num_groups
if isinstance(self.temporal_downsample, int):
self.temporal_downsample = _get_selected_flags(
len(self.channel_multipliers) - 1, self.temporal_downsample, False)
self.embedding_dim = latent_embed_dim
self.conv_downsample = conv_downsample
self.custom_conv_padding = custom_conv_padding
# self.norm_type = self.config.vqvae.norm_type
# self.num_remat_block = self.config.vqvae.get('num_enc_remat_blocks', 0)
if activation_fn == 'relu':
self.activation_fn = nn.ReLU
elif activation_fn == 'swish':
self.activation_fn = nn.SiLU
else:
raise NotImplementedError
self.activate = self.activation_fn()
self.conv_fn = functools.partial(
nn.Conv3d,
dtype=dtype,
            padding='valid' if self.custom_conv_padding is not None else 'same',  # SCH: lowercase for PyTorch
# custom_padding=self.custom_conv_padding
device=device,
)
# self.norm_fn = model_utils.get_norm_layer(
# norm_type=self.norm_type, dtype=self.dtype)
self.block_args = dict(
# norm_fn=self.norm_fn,
conv_fn=self.conv_fn,
dtype=dtype,
activation_fn=self.activation_fn,
use_conv_shortcut=False,
num_groups=self.num_groups,
device=device,
)
self.conv1 = self.conv_fn(in_out_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
# ResBlocks and conv downsample
        self.block_res_blocks = nn.ModuleList()  # use ModuleList so sub-module parameters are registered
        self.num_blocks = len(self.channel_multipliers)
        self.conv_blocks = nn.ModuleList()
filters = self.filters
prev_filters = filters # record for in_channels
for i in range(self.num_blocks):
# resblock handling
filters = self.filters * self.channel_multipliers[i] # SCH: determine the number out_channels
            block_items = nn.ModuleList()
for _ in range(self.num_res_blocks):
block_items.append(ResBlock(prev_filters, filters, **self.block_args))
prev_filters = filters # update in_channels
self.block_res_blocks.append(block_items)
if i < self.num_blocks - 1:
# conv blocks handling
if self.conv_downsample:
t_stride = 2 if self.temporal_downsample[i] else 1
t_pad = 1 if self.temporal_downsample[i] else 0
                    self.conv_blocks.append(
                        self.conv_fn(prev_filters, filters, kernel_size=(4, 4, 4), stride=(t_stride, 2, 2), padding=(t_pad, 1, 1))
                    )  # SCH: in_channels and out_channels are the same here
prev_filters = filters # update in_channels
# NOTE: downsample, dimensions T, H, W
self.avg_pool_with_t = nn.AvgPool3d((2,2,2), count_include_pad=False)
self.avg_pool = nn.AvgPool3d((1,2,2), count_include_pad=False)
# last layer res block
        self.res_blocks = nn.ModuleList()
for _ in range(self.num_res_blocks):
self.res_blocks.append(ResBlock(prev_filters, filters, **self.block_args))
prev_filters = filters # update in_channels
# MAGVIT uses Group Normalization
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, dtype=dtype, device=device) # SCH: separate <prev_filters> channels into 32 groups
self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1))
def forward(self, x):
dtype, device = x.dtype, x.device
x = self.conv1.to(device, dtype)(x)
for i in range(self.num_blocks):
for j in range(self.num_res_blocks):
x = self.block_res_blocks[i][j](x)
if i < self.num_blocks - 1:
if self.conv_downsample:
x = self.conv_blocks[i].to(device, dtype)(x)
else:
if self.temporal_downsample[i]:
x = self.avg_pool_with_t(x)
else:
x = self.avg_pool(x)
for i in range(self.num_res_blocks):
x = self.res_blocks[i](x)
x = self.norm1.to(device, dtype)(x)
x = self.activate(x)
x = self.conv2.to(device, dtype)(x)
return x
class Decoder(nn.Module):
"""Decoder Blocks."""
def __init__(self,
latent_embed_dim = 256,
filters = 64,
in_out_channels = 4,
num_res_blocks = 2,
channel_multipliers = (1, 2, 2, 4),
temporal_downsample = (True, True, False),
num_groups = 32, # for nn.GroupNorm
upsample = "nearest+conv", # options: "deconv", "nearest+conv"
custom_conv_padding = None,
activation_fn = 'swish',
device="cpu",
dtype=torch.bfloat16,
):
super().__init__()
self.output_dim = in_out_channels
self.embedding_dim = latent_embed_dim
self.filters = filters
self.num_res_blocks = num_res_blocks
self.channel_multipliers = channel_multipliers
self.temporal_downsample = temporal_downsample
self.num_groups = num_groups
if isinstance(self.temporal_downsample, int):
self.temporal_downsample = _get_selected_flags(
len(self.channel_multipliers) - 1, self.temporal_downsample, False)
self.upsample = upsample
self.custom_conv_padding = custom_conv_padding
# self.norm_type = self.config.vqvae.norm_type
# self.num_remat_block = self.config.vqvae.get('num_dec_remat_blocks', 0)
if activation_fn == 'relu':
self.activation_fn = nn.ReLU
elif activation_fn == 'swish':
self.activation_fn = nn.SiLU
else:
raise NotImplementedError
self.activate = self.activation_fn()
self.conv_fn = functools.partial(
nn.Conv3d,
dtype=dtype,
            padding='valid' if self.custom_conv_padding is not None else 'same',  # SCH: lowercase for PyTorch
# custom_padding=self.custom_conv_padding
device=device,
)
        self.conv_t_fn = functools.partial(nn.ConvTranspose3d, device=device, dtype=dtype)
# self.norm_fn = model_utils.get_norm_layer(
# norm_type=self.norm_type, dtype=dtype)
self.block_args = dict(
# norm_fn=self.norm_fn,
conv_fn=self.conv_fn,
activation_fn=self.activation_fn,
use_conv_shortcut=False,
num_groups=self.num_groups,
device=device,
dtype=dtype,
)
self.num_blocks = len(self.channel_multipliers)
filters = self.filters * self.channel_multipliers[-1]
self.conv1 = self.conv_fn(self.embedding_dim, filters, kernel_size=(3, 3, 3), bias=True)
# last layer res block
        self.res_blocks = nn.ModuleList()
for _ in range(self.num_res_blocks):
self.res_blocks.append(ResBlock(filters, filters, **self.block_args))
# NOTE: upsample, dimensions T, H, W
self.upsampler_with_t = nn.Upsample(scale_factor=(2,2,2))
self.upsampler = nn.Upsample(scale_factor=(1,2,2))
# ResBlocks and conv upsample
prev_filters = filters # SCH: in_channels
        self.block_res_blocks = nn.ModuleList()  # use ModuleList so sub-module parameters are registered
        self.conv_blocks = nn.ModuleList()
# SCH: reverse to keep track of the in_channels, but append also in a reverse direction
for i in reversed(range(self.num_blocks)):
filters = self.filters * self.channel_multipliers[i]
# resblock handling
            block_items = nn.ModuleList()
for _ in range(self.num_res_blocks):
block_items.append(ResBlock(prev_filters, filters, **self.block_args))
prev_filters = filters # SCH: update in_channels
self.block_res_blocks.insert(0, block_items) # SCH: append in front
# conv blocks handling
if i > 0:
t_stride = 2 if self.temporal_downsample[i - 1] else 1
t_kernel = 4 if self.temporal_downsample[i - 1] else 3 # SCH: hack to keep dimension same
if self.upsample == "deconv":
assert self.custom_conv_padding is None, ('Custom padding not implemented for ConvTranspose')
# SCH: append in front
self.conv_blocks.insert(0,
self.conv_t_fn(prev_filters, filters, kernel_size=(t_kernel, 4, 4), stride=(t_stride, 2, 2), padding=1))
prev_filters = filters # SCH: update in_channels
elif self.upsample == 'nearest+conv':
# SCH: append in front
self.conv_blocks.insert(0, self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3)))
prev_filters = filters # SCH: update in_channels
else:
raise NotImplementedError(f'Unknown upsampler: {self.upsample}')
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, device=device, dtype=dtype)
self.conv2 = self.conv_fn(prev_filters, self.output_dim, kernel_size=(3, 3, 3))
def forward(
self,
x,
**kwargs,
):
dtype, device = x.dtype, x.device
x = self.conv1.to(device, dtype)(x)
for i in range(self.num_res_blocks):
x = self.res_blocks[i](x)
for i in reversed(range(self.num_blocks)): # reverse here to make decoder symmetric with encoder
for j in range(self.num_res_blocks):
x = self.block_res_blocks[i][j](x)
if i > 0:
if self.upsample == 'deconv':
assert self.custom_conv_padding is None, ('Custom padding not implemented for ConvTranspose')
x = self.conv_blocks[i-1].to(device, dtype)(x)
elif self.upsample == 'nearest+conv':
if self.temporal_downsample[i - 1]:
x = self.upsampler_with_t(x)
else:
x = self.upsampler(x)
x = self.conv_blocks[i-1].to(device, dtype)(x)
else:
raise NotImplementedError(f'Unknown upsampler: {self.upsample}')
x = self.norm1.to(device, dtype)(x)
x = self.activate(x)
x = self.conv2.to(device, dtype)(x)
return x
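
# Illustrative shape flow for VAE_3D with its default config below
# (temporal_downsample=(True, True, False) -> patch_size = (4, 8, 8)):
#   video  [B, 4, 16, 256, 256]
#   encode -> posterior over latents of shape [B, kl_embed_dim=64, 4, 32, 32]
#   decode -> reconstruction of shape [B, 4, 16, 256, 256]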
@MODELS.register_module()
class VAE_3D(nn.Module):
    """The 3D VAE."""
def __init__(
self,
latent_embed_dim = 256,
filters = 64,
num_res_blocks = 2,
channel_multipliers = (1, 2, 2, 4),
temporal_downsample = (True, True, False),
num_groups = 32, # for nn.GroupNorm
conv_downsample = False,
upsample = "nearest+conv", # options: "deconv", "nearest+conv"
custom_conv_padding = None,
activation_fn = 'swish',
in_out_channels = 4,
kl_embed_dim = 64,
device="cpu",
dtype="bf16",
# precision: Any = jax.lax.Precision.DEFAULT
):
super().__init__()
        if isinstance(dtype, str):
if dtype == "bf16":
dtype = torch.bfloat16
elif dtype == "fp16":
dtype = torch.float16
else:
raise NotImplementedError(f'dtype: {dtype}')
# Model Initialization
self.encoder = Encoder(
filters=filters,
num_res_blocks=num_res_blocks,
channel_multipliers=channel_multipliers,
temporal_downsample=temporal_downsample,
num_groups = num_groups, # for nn.GroupNorm
in_out_channels = in_out_channels,
latent_embed_dim = latent_embed_dim,
conv_downsample = conv_downsample,
custom_conv_padding = custom_conv_padding,
activation_fn = activation_fn,
device=device,
dtype=dtype,
)
self.decoder = Decoder(
latent_embed_dim = latent_embed_dim,
filters = filters,
in_out_channels = in_out_channels,
num_res_blocks = num_res_blocks,
channel_multipliers = channel_multipliers,
temporal_downsample = temporal_downsample,
num_groups = num_groups, # for nn.GroupNorm
upsample = upsample, # options: "deconv", "nearest+conv"
custom_conv_padding = custom_conv_padding,
activation_fn = activation_fn,
device=device,
dtype=dtype,
)
        # create on the same device/dtype as the encoder and decoder
        self.quant_conv = nn.Conv3d(latent_embed_dim, 2 * kl_embed_dim, 1, device=device, dtype=dtype)
        self.post_quant_conv = nn.Conv3d(kl_embed_dim, latent_embed_dim, 1, device=device, dtype=dtype)
image_down = 2 ** len(temporal_downsample)
        t_down = 2 ** sum(temporal_downsample)
self.patch_size = (t_down, image_down, image_down)
def get_latent_size(self, input_size):
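        """Map an input [T, H, W] size to the latent size.

        With the default patch_size of (4, 8, 8): [16, 256, 256] -> [4, 32, 32].
        """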
for i in range(len(input_size)):
assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
return input_size
def encode(
self,
x,
):
encoded_feature = self.encoder(x)
moments = self.quant_conv(encoded_feature).to(x.dtype)
posterior = model_utils.DiagonalGaussianDistribution(moments)
return posterior
def decode(
self,
z,
):
dtype = z.dtype
z = self.post_quant_conv(z).to(dtype)
dec = self.decoder(z)
return dec
def forward(
self,
input,
sample_posterior=True,
):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
@MODELS.register_module("VAE_3D_B")
def VAE_3D_B(from_pretrained=None, **kwargs):
model = VAE_3D(**kwargs)
if from_pretrained is not None:
load_checkpoint(model, from_pretrained)
return model
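

# Minimal smoke-test sketch (illustrative, not part of the original training code).
# Assumes a CPU run in float32 and the default VAE_3D config; shapes follow the
# [B, C, T, H, W] convention noted at the top of this file.
if __name__ == "__main__":
    model = VAE_3D(device="cpu", dtype=torch.float32)
    video = torch.randn(1, 4, 8, 64, 64)  # [B, C, T, H, W], divisible by patch_size (4, 8, 8)
    recon, posterior = model(video)
    print(recon.shape)  # expected: torch.Size([1, 4, 8, 64, 64])
    print(model.get_latent_size([8, 64, 64]))  # expected: [2, 8, 8]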