2024-03-28 08:12:20 +01:00
import functools
from typing import Any , Dict , Tuple , Type , Union , Sequence , Optional
from absl import logging
import torch
import torch . nn as nn
import numpy as np
from numpy import typing as nptyping
from opensora . models . vae import model_utils
from opensora . registry import MODELS
2024-04-05 10:01:24 +02:00
from opensora . utils . ckpt_utils import load_checkpoint
2024-03-28 08:12:20 +01:00
""" Encoder and Decoder structures with 3D CNNs. """
"""
NOTE :
removed LayerNorm since not used in this arch
GroupNorm : flax uses default ` epsilon = 1e-06 ` , whereas torch uses ` eps = 1e-05 `
for average pool and upsample , input shape needs to be [ N , C , T , H , W ] - - > if not , adjust the scale factors accordingly
! ! ! opensora read video into [ B , C , T , H , W ] format output
TODO :
check data dimensions format
"""
class ResBlock(nn.Module):
    """Residual block: two (GroupNorm -> activation -> 3x3x3 conv) stages plus a skip path.

    When the channel count of the incoming tensor differs from ``filters``,
    the skip path is projected through ``conv3`` before the two paths are
    summed; otherwise the input is added back unchanged.

    Args:
        in_out_channels: Channel count of the tensor fed to the block.
        filters: Channel count produced by the block.
        conv_fn: Factory called as ``conv_fn(in_ch, out_ch, kernel_size=..., bias=...)``
            to build every convolution of the block.
        activation_fn: Zero-argument factory for the activation module.
        use_conv_shortcut: Build ``conv3`` with a 3x3x3 kernel when true,
            with a 1x1x1 kernel otherwise.
        num_groups: Group count handed to both ``nn.GroupNorm`` layers.
        device: Device on which the GroupNorm layers are created.
        dtype: Dtype with which the GroupNorm layers are created.
    """

    def __init__(
        self,
        in_out_channels,  # SCH: added
        filters,
        conv_fn,
        activation_fn=nn.ReLU,
        use_conv_shortcut=False,
        num_groups=32,
        device="cpu",
        dtype=torch.bfloat16,
    ):
        super().__init__()
        self.filters = filters
        self.activate = activation_fn()
        self.use_conv_shortcut = use_conv_shortcut

        # SCH: MAGVIT uses GroupNorm by default
        norm_kwargs = dict(device=device, dtype=dtype)
        conv_kwargs = dict(kernel_size=(3, 3, 3), bias=False)

        self.norm1 = nn.GroupNorm(num_groups, in_out_channels, **norm_kwargs)
        self.conv1 = conv_fn(in_out_channels, filters, **conv_kwargs)
        self.norm2 = nn.GroupNorm(num_groups, filters, **norm_kwargs)
        self.conv2 = conv_fn(filters, filters, **conv_kwargs)

        # Skip-path projection; only its kernel size depends on the flag.
        shortcut_kernel = (3, 3, 3) if use_conv_shortcut else (1, 1, 1)
        self.conv3 = conv_fn(in_out_channels, filters, kernel_size=shortcut_kernel, bias=False)

    def forward(self, x):
        """Apply the block to ``x`` and return main path + skip path."""
        # Every layer is moved to the input's device/dtype right before use.
        target = (x.device, x.dtype)

        out = x
        for norm, conv in ((self.norm1, self.conv1), (self.norm2, self.conv2)):
            out = norm.to(*target)(out)
            out = self.activate(out)
            out = conv.to(*target)(out)

        skip = x
        if x.shape[1] != self.filters:
            # Channel mismatch: project the skip path so the sum is well defined.
            skip = self.conv3.to(*target)(skip)
        return out + skip
def _get_selected_flags(total_len: int, select_len: int, suffix: bool) -> np.ndarray:
    """Return a boolean mask of length ``total_len`` with ``select_len`` entries set.

    Args:
        total_len: Length of the returned mask.
        select_len: Number of entries to mark ``True``; must lie in
            ``[0, total_len]``.
        suffix: Mark the last ``select_len`` entries when true, the first
            ``select_len`` entries otherwise.

    Returns:
        A 1-D ``numpy`` array of dtype ``bool``.

    Raises:
        AssertionError: if ``select_len`` is negative or exceeds ``total_len``.
    """
    assert 0 <= select_len <= total_len
    selected = np.zeros(total_len, dtype=bool)
    if not suffix:
        selected[:select_len] = True
    else:
        # Slice from an explicit start index: ``selected[-0:]`` would be the
        # whole array and wrongly select everything when ``select_len`` is 0.
        selected[total_len - select_len:] = True
    return selected
class Encoder(nn.Module):
    """Encoder Blocks.

    Layout: a 3x3x3 stem convolution, then ``len(channel_multipliers)`` stages
    of ``num_res_blocks`` residual blocks with a downsampling step between
    consecutive stages, a final group of ``num_res_blocks`` residual blocks,
    and a GroupNorm -> activation -> 1x1x1 convolution head that emits
    ``latent_embed_dim`` channels.

    Input is expected as [N, C, T, H, W] (see the module notes).

    Args:
        filters: Base channel count; stage ``i`` uses
            ``filters * channel_multipliers[i]`` channels.
        num_res_blocks: Residual blocks per stage (and in the final group).
        channel_multipliers: Per-stage channel multipliers.
        temporal_downsample: Per-transition flags telling whether the
            downsampling step also halves T. An ``int`` k is expanded to
            "first k transitions".
        num_groups: Group count for every ``nn.GroupNorm``.
        in_out_channels: Channel count of the input tensor.
        latent_embed_dim: Channel count of the output tensor.
        conv_downsample: Downsample with strided convolutions when true,
            with average pooling otherwise.
        custom_conv_padding: When not ``None`` the convolutions are built
            with ``padding='valid'`` instead of ``'same'``.
        activation_fn: ``'relu'`` or ``'swish'``.
        device: Device on which layers are created.
        dtype: Dtype with which layers are created.

    Raises:
        NotImplementedError: if ``activation_fn`` is not a supported name.
    """

    def __init__(self,
                 filters=64,
                 num_res_blocks=2,
                 channel_multipliers=(1, 2, 2, 4),
                 temporal_downsample=(True, True, False),
                 num_groups=32,  # for nn.GroupNorm
                 in_out_channels=3,  # SCH: added, in_channels at the start
                 latent_embed_dim=256,  # num channels for latent vector
                 conv_downsample=False,
                 custom_conv_padding=None,
                 activation_fn='swish',
                 device="cpu",
                 dtype=torch.bfloat16,
                 ):
        super().__init__()
        self.filters = filters
        self.num_res_blocks = num_res_blocks
        self.channel_multipliers = channel_multipliers
        self.temporal_downsample = temporal_downsample
        self.num_groups = num_groups
        if isinstance(self.temporal_downsample, int):
            self.temporal_downsample = _get_selected_flags(
                len(self.channel_multipliers) - 1, self.temporal_downsample, False)
        self.embedding_dim = latent_embed_dim
        self.conv_downsample = conv_downsample
        self.custom_conv_padding = custom_conv_padding

        if activation_fn == 'relu':
            self.activation_fn = nn.ReLU
        elif activation_fn == 'swish':
            self.activation_fn = nn.SiLU
        else:
            raise NotImplementedError(f'Unknown activation: {activation_fn}')
        self.activate = self.activation_fn()

        self.conv_fn = functools.partial(
            nn.Conv3d,
            dtype=dtype,
            padding='valid' if self.custom_conv_padding is not None else 'same',  # SCH: lower letter for pytorch
            device=device,
        )
        self.block_args = dict(
            conv_fn=self.conv_fn,
            dtype=dtype,
            activation_fn=self.activation_fn,
            use_conv_shortcut=False,
            num_groups=self.num_groups,
            device=device,
        )

        self.conv1 = self.conv_fn(in_out_channels, self.filters, kernel_size=(3, 3, 3), bias=False)

        # ResBlocks and conv downsample.
        # nn.ModuleList (not a plain list) so the sub-modules are registered:
        # only registered modules show up in parameters()/state_dict() and
        # follow .to()/.cuda() calls made on the parent.
        self.block_res_blocks = nn.ModuleList()
        self.num_blocks = len(self.channel_multipliers)
        self.conv_blocks = nn.ModuleList()

        filters = self.filters
        prev_filters = filters  # record for in_channels
        for i in range(self.num_blocks):
            # resblock handling
            filters = self.filters * self.channel_multipliers[i]  # SCH: determine the number out_channels
            block_items = nn.ModuleList()
            for _ in range(self.num_res_blocks):
                block_items.append(ResBlock(prev_filters, filters, **self.block_args))
                prev_filters = filters  # update in_channels
            self.block_res_blocks.append(block_items)

            if i < self.num_blocks - 1 and self.conv_downsample:
                # conv blocks handling
                t_stride = 2 if self.temporal_downsample[i] else 1
                t_pad = 1 if self.temporal_downsample[i] else 0
                # The explicit padding must go to the conv (it overrides the
                # string padding baked into conv_fn): PyTorch rejects
                # padding='same' on strided convolutions.
                # NOTE(review): with t_stride == 1 the temporal kernel of 4
                # and t_pad == 0 shrink T by 3 - confirm this is intended.
                self.conv_blocks.append(
                    self.conv_fn(
                        prev_filters,
                        filters,
                        kernel_size=(4, 4, 4),
                        stride=(t_stride, 2, 2),
                        padding=(t_pad, 1, 1),
                    )
                )  # SCH: should be same in_channel and out_channel
                prev_filters = filters  # update in_channels

        # NOTE: downsample, dimensions T, H, W
        self.avg_pool_with_t = nn.AvgPool3d((2, 2, 2), count_include_pad=False)
        self.avg_pool = nn.AvgPool3d((1, 2, 2), count_include_pad=False)

        # last layer res block
        self.res_blocks = nn.ModuleList()
        for _ in range(self.num_res_blocks):
            self.res_blocks.append(ResBlock(prev_filters, filters, **self.block_args))
            prev_filters = filters  # update in_channels

        # MAGVIT uses Group Normalization
        self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, dtype=dtype, device=device)
        self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1))

    def forward(self, x):
        """Encode ``x`` and return a tensor with ``latent_embed_dim`` channels."""
        # Layers are moved to the input's device/dtype right before use, so the
        # encoder follows whatever tensor it is given.
        dtype, device = x.dtype, x.device

        x = self.conv1.to(device, dtype)(x)
        for i in range(self.num_blocks):
            for j in range(self.num_res_blocks):
                x = self.block_res_blocks[i][j](x)
            if i < self.num_blocks - 1:
                if self.conv_downsample:
                    x = self.conv_blocks[i].to(device, dtype)(x)
                elif self.temporal_downsample[i]:
                    x = self.avg_pool_with_t(x)
                else:
                    x = self.avg_pool(x)

        for i in range(self.num_res_blocks):
            x = self.res_blocks[i](x)

        x = self.norm1.to(device, dtype)(x)
        x = self.activate(x)
        x = self.conv2.to(device, dtype)(x)
        return x
class Decoder(nn.Module):
    """Decoder Blocks.

    Mirrors the encoder: a 3x3x3 stem convolution from ``latent_embed_dim``
    channels, a first group of ``num_res_blocks`` residual blocks, then the
    stages of ``channel_multipliers`` walked from last to first with an
    upsampling step between consecutive stages, and a
    GroupNorm -> activation -> 3x3x3 convolution head that emits
    ``in_out_channels`` channels.

    Input is expected as [N, C, T, H, W] (see the module notes).

    Args:
        latent_embed_dim: Channel count of the input tensor.
        filters: Base channel count; stage ``i`` uses
            ``filters * channel_multipliers[i]`` channels.
        in_out_channels: Channel count of the output tensor.
        num_res_blocks: Residual blocks per stage (and in the first group).
        channel_multipliers: Per-stage channel multipliers.
        temporal_downsample: Per-transition flags telling whether the
            matching upsampling step also doubles T. An ``int`` k is
            expanded to "first k transitions".
        num_groups: Group count for every ``nn.GroupNorm``.
        upsample: ``"deconv"`` (transposed convolution) or
            ``"nearest+conv"`` (``nn.Upsample`` followed by a convolution).
        custom_conv_padding: When not ``None`` the convolutions are built
            with ``padding='valid'`` instead of ``'same'``; not supported
            together with ``"deconv"``.
        activation_fn: ``'relu'`` or ``'swish'``.
        device: Device on which layers are created.
        dtype: Dtype with which layers are created.

    Raises:
        NotImplementedError: if ``activation_fn`` or ``upsample`` is not a
            supported name.
    """

    def __init__(self,
                 latent_embed_dim=256,
                 filters=64,
                 in_out_channels=4,
                 num_res_blocks=2,
                 channel_multipliers=(1, 2, 2, 4),
                 temporal_downsample=(True, True, False),
                 num_groups=32,  # for nn.GroupNorm
                 upsample="nearest+conv",  # options: "deconv", "nearest+conv"
                 custom_conv_padding=None,
                 activation_fn='swish',
                 device="cpu",
                 dtype=torch.bfloat16,
                 ):
        super().__init__()
        self.output_dim = in_out_channels
        self.embedding_dim = latent_embed_dim
        self.filters = filters
        self.num_res_blocks = num_res_blocks
        self.channel_multipliers = channel_multipliers
        self.temporal_downsample = temporal_downsample
        self.num_groups = num_groups
        if isinstance(self.temporal_downsample, int):
            self.temporal_downsample = _get_selected_flags(
                len(self.channel_multipliers) - 1, self.temporal_downsample, False)
        self.upsample = upsample
        self.custom_conv_padding = custom_conv_padding

        if activation_fn == 'relu':
            self.activation_fn = nn.ReLU
        elif activation_fn == 'swish':
            self.activation_fn = nn.SiLU
        else:
            raise NotImplementedError(f'Unknown activation: {activation_fn}')
        self.activate = self.activation_fn()

        self.conv_fn = functools.partial(
            nn.Conv3d,
            dtype=dtype,
            padding='valid' if self.custom_conv_padding is not None else 'same',  # SCH: lower letter for pytorch
            device=device,
        )
        # Created on the same device as every other layer of the decoder.
        self.conv_t_fn = functools.partial(nn.ConvTranspose3d, dtype=dtype, device=device)

        self.block_args = dict(
            conv_fn=self.conv_fn,
            activation_fn=self.activation_fn,
            use_conv_shortcut=False,
            num_groups=self.num_groups,
            device=device,
            dtype=dtype,
        )
        self.num_blocks = len(self.channel_multipliers)

        filters = self.filters * self.channel_multipliers[-1]
        self.conv1 = self.conv_fn(self.embedding_dim, filters, kernel_size=(3, 3, 3), bias=True)

        # last layer res block
        # nn.ModuleList (not a plain list) so the blocks are registered: only
        # registered modules show up in parameters()/state_dict() and follow
        # .to()/.cuda() calls made on the parent.
        self.res_blocks = nn.ModuleList(
            [ResBlock(filters, filters, **self.block_args) for _ in range(self.num_res_blocks)]
        )

        # NOTE: upsample, dimensions T, H, W
        self.upsampler_with_t = nn.Upsample(scale_factor=(2, 2, 2))
        self.upsampler = nn.Upsample(scale_factor=(1, 2, 2))

        # ResBlocks and conv upsample
        prev_filters = filters  # SCH: in_channels
        stages = []
        up_convs = []
        # SCH: reverse to keep track of the in_channels, but append also in a reverse direction
        for i in reversed(range(self.num_blocks)):
            filters = self.filters * self.channel_multipliers[i]
            # resblock handling
            block_items = []
            for _ in range(self.num_res_blocks):
                block_items.append(ResBlock(prev_filters, filters, **self.block_args))
                prev_filters = filters  # SCH: update in_channels
            stages.insert(0, nn.ModuleList(block_items))  # SCH: append in front

            # conv blocks handling
            if i > 0:
                t_stride = 2 if self.temporal_downsample[i - 1] else 1
                t_kernel = 4 if self.temporal_downsample[i - 1] else 3  # SCH: hack to keep dimension same
                if self.upsample == "deconv":
                    assert self.custom_conv_padding is None, ('Custom padding not implemented for ConvTranspose')
                    # SCH: append in front
                    up_convs.insert(
                        0,
                        self.conv_t_fn(
                            prev_filters,
                            filters,
                            kernel_size=(t_kernel, 4, 4),
                            stride=(t_stride, 2, 2),
                            padding=1,
                        ),
                    )
                    prev_filters = filters  # SCH: update in_channels
                elif self.upsample == 'nearest+conv':
                    # SCH: append in front
                    up_convs.insert(0, self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3)))
                    prev_filters = filters  # SCH: update in_channels
                else:
                    raise NotImplementedError(f'Unknown upsampler: {self.upsample}')

        # Index k of both containers refers to stage k / transition k, exactly
        # as forward() addresses them.
        self.block_res_blocks = nn.ModuleList(stages)
        self.conv_blocks = nn.ModuleList(up_convs)

        self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, device=device, dtype=dtype)
        self.conv2 = self.conv_fn(prev_filters, self.output_dim, kernel_size=(3, 3, 3))

    def forward(
        self,
        x,
        **kwargs,
    ):
        """Decode ``x`` and return a tensor with ``in_out_channels`` channels.

        Extra keyword arguments are accepted and ignored.
        """
        # Layers are moved to the input's device/dtype right before use, so the
        # decoder follows whatever tensor it is given.
        dtype, device = x.dtype, x.device

        x = self.conv1.to(device, dtype)(x)
        for i in range(self.num_res_blocks):
            x = self.res_blocks[i](x)

        for i in reversed(range(self.num_blocks)):  # reverse here to make decoder symmetric with encoder
            for j in range(self.num_res_blocks):
                x = self.block_res_blocks[i][j](x)
            if i > 0:
                if self.upsample == 'deconv':
                    assert self.custom_conv_padding is None, ('Custom padding not implemented for ConvTranspose')
                    x = self.conv_blocks[i - 1].to(device, dtype)(x)
                elif self.upsample == 'nearest+conv':
                    if self.temporal_downsample[i - 1]:
                        x = self.upsampler_with_t(x)
                    else:
                        x = self.upsampler(x)
                    x = self.conv_blocks[i - 1].to(device, dtype)(x)
                else:
                    raise NotImplementedError(f'Unknown upsampler: {self.upsample}')

        x = self.norm1.to(device, dtype)(x)
        x = self.activate(x)
        x = self.conv2.to(device, dtype)(x)
        return x
@MODELS.register_module()
class VAE_3D(nn.Module):
    """The 3D VAE.

    Wraps an :class:`Encoder` and a :class:`Decoder` with two 1x1x1
    convolutions: ``quant_conv`` maps encoder features to ``2 * kl_embed_dim``
    channels from which the posterior is built, and ``post_quant_conv`` maps a
    ``kl_embed_dim``-channel latent back to ``latent_embed_dim`` channels for
    the decoder.

    Args:
        latent_embed_dim: Channel count of the encoder output / decoder input.
        filters, num_res_blocks, channel_multipliers, temporal_downsample,
        num_groups, custom_conv_padding, activation_fn, in_out_channels:
            Forwarded unchanged to both the encoder and the decoder.
        conv_downsample: Forwarded to the encoder.
        upsample: Forwarded to the decoder.
        kl_embed_dim: Channel count of the latent ``z``.
        device: Device on which layers are created.
        dtype: A ``torch.dtype`` or one of the strings ``"bf16"``, ``"fp16"``,
            ``"fp32"``.

    Raises:
        NotImplementedError: if ``dtype`` is an unrecognised string.
    """

    def __init__(
        self,
        latent_embed_dim=256,
        filters=64,
        num_res_blocks=2,
        channel_multipliers=(1, 2, 2, 4),
        temporal_downsample=(True, True, False),
        num_groups=32,  # for nn.GroupNorm
        conv_downsample=False,
        upsample="nearest+conv",  # options: "deconv", "nearest+conv"
        custom_conv_padding=None,
        activation_fn='swish',
        in_out_channels=4,
        kl_embed_dim=64,
        device="cpu",
        dtype="bf16",
    ):
        super().__init__()

        if isinstance(dtype, str):
            dtype_by_name = {
                "bf16": torch.bfloat16,
                "fp16": torch.float16,
                "fp32": torch.float32,
            }
            if dtype not in dtype_by_name:
                raise NotImplementedError(f'dtype: {dtype}')
            dtype = dtype_by_name[dtype]

        # Model Initialization
        self.encoder = Encoder(
            filters=filters,
            num_res_blocks=num_res_blocks,
            channel_multipliers=channel_multipliers,
            temporal_downsample=temporal_downsample,
            num_groups=num_groups,  # for nn.GroupNorm
            in_out_channels=in_out_channels,
            latent_embed_dim=latent_embed_dim,
            conv_downsample=conv_downsample,
            custom_conv_padding=custom_conv_padding,
            activation_fn=activation_fn,
            device=device,
            dtype=dtype,
        )
        self.decoder = Decoder(
            latent_embed_dim=latent_embed_dim,
            filters=filters,
            in_out_channels=in_out_channels,
            num_res_blocks=num_res_blocks,
            channel_multipliers=channel_multipliers,
            temporal_downsample=temporal_downsample,
            num_groups=num_groups,  # for nn.GroupNorm
            upsample=upsample,  # options: "deconv", "nearest+conv"
            custom_conv_padding=custom_conv_padding,
            activation_fn=activation_fn,
            device=device,
            dtype=dtype,
        )

        # Built with the same device/dtype as the encoder and decoder so the
        # freshly constructed model is usable without an extra .to() call.
        self.quant_conv = nn.Conv3d(latent_embed_dim, 2 * kl_embed_dim, 1, device=device, dtype=dtype)
        self.post_quant_conv = nn.Conv3d(kl_embed_dim, latent_embed_dim, 1, device=device, dtype=dtype)

        # One downsampling step sits between each pair of consecutive stages;
        # each halves H and W, and halves T only where the flag is set.
        num_down = len(channel_multipliers) - 1
        if isinstance(temporal_downsample, int):
            # Integer form: the first ``temporal_downsample`` steps are temporal.
            num_t_down = int(temporal_downsample)
        else:
            num_t_down = sum(1 for flag in temporal_downsample[:num_down] if flag)
        image_down = 2 ** num_down
        t_down = 2 ** num_t_down
        self.patch_size = (t_down, image_down, image_down)

    def get_latent_size(self, input_size):
        """Return ``input_size`` divided element-wise by ``patch_size``.

        Args:
            input_size: Sequence ``(T, H, W)``.

        Returns:
            A list of three ints.

        Raises:
            AssertionError: if an entry is not divisible by the matching
                patch-size entry.
        """
        for i in range(len(input_size)):
            assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
        input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
        return input_size

    def encode(
        self,
        x,
    ):
        """Encode ``x`` and return the posterior built from ``quant_conv``'s output."""
        encoded_feature = self.encoder(x)
        moments = self.quant_conv(encoded_feature).to(x.dtype)
        posterior = model_utils.DiagonalGaussianDistribution(moments)
        return posterior

    def decode(
        self,
        z,
    ):
        """Decode the latent ``z``; the decoder input keeps ``z``'s dtype."""
        dtype = z.dtype
        z = self.post_quant_conv(z).to(dtype)
        dec = self.decoder(z)
        return dec

    def forward(
        self,
        input,
        sample_posterior=True,
    ):
        """Encode ``input``, pick a latent, decode it.

        Args:
            input: Tensor handed to :meth:`encode`.
            sample_posterior: Use ``posterior.sample()`` when true,
                ``posterior.mode()`` otherwise.

        Returns:
            ``(reconstruction, posterior)``.
        """
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior
@MODELS.register_module("VAE_3D_B")
def VAE_3D_B(from_pretrained=None, **kwargs):
    """Build a ``VAE_3D`` from ``kwargs``, optionally loading a checkpoint.

    Args:
        from_pretrained: When not ``None``, passed to ``load_checkpoint``
            together with the freshly built model.
        **kwargs: Forwarded unchanged to the ``VAE_3D`` constructor.

    Returns:
        The constructed model.
    """
    vae = VAE_3D(**kwargs)
    if from_pretrained is None:
        return vae
    load_checkpoint(vae, from_pretrained)
    return vae