from typing import Tuple, Union import torch.nn as nn import torch.nn.functional as F from einops import rearrange from opensora.registry import MODELS from opensora.utils.ckpt_utils import load_checkpoint from .utils import DiagonalGaussianDistribution def cast_tuple(t, length=1): return t if isinstance(t, tuple) else ((t,) * length) def divisible_by(num, den): return (num % den) == 0 def is_odd(n): return not divisible_by(n, 2) def pad_at_dim(t, pad, dim=-1): dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1) zeros = (0, 0) * dims_from_right return F.pad(t, (*zeros, *pad), mode="constant") def exists(v): return v is not None class CausalConv3d(nn.Module): def __init__( self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", strides=None, # allow custom stride **kwargs, ): super().__init__() kernel_size = cast_tuple(kernel_size, 3) time_kernel_size, height_kernel_size, width_kernel_size = kernel_size assert is_odd(height_kernel_size) and is_odd(width_kernel_size) dilation = kwargs.pop("dilation", 1) stride = strides[0] if strides is not None else kwargs.pop("stride", 1) self.pad_mode = pad_mode time_pad = dilation * (time_kernel_size - 1) + (1 - stride) height_pad = height_kernel_size // 2 width_pad = width_kernel_size // 2 self.time_pad = time_pad self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) stride = strides if strides is not None else (stride, 1, 1) dilation = (dilation, 1, 1) self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) def forward(self, x): x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) x = self.conv(x) return x class ResBlock(nn.Module): def __init__( self, in_channels, # SCH: added filters, conv_fn, activation_fn=nn.SiLU, use_conv_shortcut=False, num_groups=32, ): super().__init__() self.in_channels = in_channels self.filters = filters self.activate = activation_fn() self.use_conv_shortcut = use_conv_shortcut # SCH: MAGVIT uses GroupNorm by default self.norm1 = nn.GroupNorm(num_groups, in_channels) self.conv1 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False) self.norm2 = nn.GroupNorm(num_groups, self.filters) self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False) if in_channels != filters: if self.use_conv_shortcut: self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False) else: self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(1, 1, 1), bias=False) def forward(self, x): residual = x x = self.norm1(x) x = self.activate(x) x = self.conv1(x) x = self.norm2(x) x = self.activate(x) x = self.conv2(x) if self.in_channels != self.filters: # SCH: ResBlock X->Y residual = self.conv3(residual) return x + residual def get_activation_fn(activation): if activation == "relu": activation_fn = nn.ReLU elif activation == "swish": activation_fn = nn.SiLU else: raise NotImplementedError return activation_fn class Encoder(nn.Module): """Encoder Blocks.""" def __init__( self, in_out_channels=4, latent_embed_dim=512, # num channels for latent vector filters=128, num_res_blocks=4, channel_multipliers=(1, 2, 2, 4), temporal_downsample=(False, True, True), num_groups=32, # for nn.GroupNorm activation_fn="swish", ): super().__init__() self.filters = filters self.num_res_blocks = num_res_blocks self.num_blocks = len(channel_multipliers) self.channel_multipliers = channel_multipliers self.temporal_downsample = temporal_downsample self.num_groups = num_groups self.embedding_dim = latent_embed_dim self.activation_fn = get_activation_fn(activation_fn) self.activate = self.activation_fn() self.conv_fn = CausalConv3d self.block_args = dict( conv_fn=self.conv_fn, activation_fn=self.activation_fn, use_conv_shortcut=False, num_groups=self.num_groups, ) # first layer conv self.conv_in = self.conv_fn( in_out_channels, filters, kernel_size=(3, 3, 3), bias=False, ) # ResBlocks and conv downsample self.block_res_blocks = nn.ModuleList([]) self.conv_blocks = nn.ModuleList([]) filters = self.filters prev_filters = filters # record for in_channels for i in range(self.num_blocks): filters = self.filters * self.channel_multipliers[i] block_items = nn.ModuleList([]) for _ in range(self.num_res_blocks): block_items.append(ResBlock(prev_filters, filters, **self.block_args)) prev_filters = filters # update in_channels self.block_res_blocks.append(block_items) if i < self.num_blocks - 1: if self.temporal_downsample[i]: t_stride = 2 if self.temporal_downsample[i] else 1 s_stride = 1 self.conv_blocks.append( self.conv_fn( prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride) ) ) prev_filters = filters # update in_channels else: # if no t downsample, don't add since this does nothing for pipeline models self.conv_blocks.append(nn.Identity(prev_filters)) # Identity prev_filters = filters # update in_channels # last layer res block self.res_blocks = nn.ModuleList([]) for _ in range(self.num_res_blocks): self.res_blocks.append(ResBlock(prev_filters, filters, **self.block_args)) prev_filters = filters # update in_channels # MAGVIT uses Group Normalization self.norm1 = nn.GroupNorm(self.num_groups, prev_filters) self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), padding="same") def forward(self, x): x = self.conv_in(x) for i in range(self.num_blocks): for j in range(self.num_res_blocks): x = self.block_res_blocks[i][j](x) if i < self.num_blocks - 1: x = self.conv_blocks[i](x) for i in range(self.num_res_blocks): x = self.res_blocks[i](x) x = self.norm1(x) x = self.activate(x) x = self.conv2(x) return x class Decoder(nn.Module): """Decoder Blocks.""" def __init__( self, in_out_channels=4, latent_embed_dim=512, filters=128, num_res_blocks=4, channel_multipliers=(1, 2, 2, 4), temporal_downsample=(False, True, True), num_groups=32, # for nn.GroupNorm activation_fn="swish", ): super().__init__() self.filters = filters self.num_res_blocks = num_res_blocks self.num_blocks = len(channel_multipliers) self.channel_multipliers = channel_multipliers self.temporal_downsample = temporal_downsample self.num_groups = num_groups self.embedding_dim = latent_embed_dim self.s_stride = 1 self.activation_fn = get_activation_fn(activation_fn) self.activate = self.activation_fn() self.conv_fn = CausalConv3d self.block_args = dict( conv_fn=self.conv_fn, activation_fn=self.activation_fn, use_conv_shortcut=False, num_groups=self.num_groups, ) filters = self.filters * self.channel_multipliers[-1] prev_filters = filters # last conv self.conv1 = self.conv_fn(self.embedding_dim, filters, kernel_size=(3, 3, 3), bias=True) # last layer res block self.res_blocks = nn.ModuleList([]) for _ in range(self.num_res_blocks): self.res_blocks.append(ResBlock(filters, filters, **self.block_args)) # ResBlocks and conv upsample self.block_res_blocks = nn.ModuleList([]) self.num_blocks = len(self.channel_multipliers) self.conv_blocks = nn.ModuleList([]) # reverse to keep track of the in_channels, but append also in a reverse direction for i in reversed(range(self.num_blocks)): filters = self.filters * self.channel_multipliers[i] # resblock handling block_items = nn.ModuleList([]) for _ in range(self.num_res_blocks): block_items.append(ResBlock(prev_filters, filters, **self.block_args)) prev_filters = filters # SCH: update in_channels self.block_res_blocks.insert(0, block_items) # SCH: append in front # conv blocks with upsampling if i > 0: if self.temporal_downsample[i - 1]: t_stride = 2 if self.temporal_downsample[i - 1] else 1 # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2 self.conv_blocks.insert( 0, self.conv_fn( prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3, 3, 3) ), ) else: self.conv_blocks.insert( 0, nn.Identity(prev_filters), ) self.norm1 = nn.GroupNorm(self.num_groups, prev_filters) self.conv_out = self.conv_fn(filters, in_out_channels, 3) def forward(self, x): x = self.conv1(x) for i in range(self.num_res_blocks): x = self.res_blocks[i](x) for i in reversed(range(self.num_blocks)): for j in range(self.num_res_blocks): x = self.block_res_blocks[i][j](x) if i > 0: t_stride = 2 if self.temporal_downsample[i - 1] else 1 x = self.conv_blocks[i - 1](x) x = rearrange( x, "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)", ts=t_stride, hs=self.s_stride, ws=self.s_stride, ) x = self.norm1(x) x = self.activate(x) x = self.conv_out(x) return x @MODELS.register_module() class VAE_Temporal(nn.Module): def __init__( self, in_out_channels=4, latent_embed_dim=4, embed_dim=4, filters=128, num_res_blocks=4, channel_multipliers=(1, 2, 2, 4), temporal_downsample=(True, True, False), num_groups=32, # for nn.GroupNorm activation_fn="swish", ): super().__init__() self.time_downsample_factor = 2 ** sum(temporal_downsample) # self.time_padding = self.time_downsample_factor - 1 self.patch_size = (self.time_downsample_factor, 1, 1) self.out_channels = in_out_channels # NOTE: following MAGVIT, conv in bias=False in encoder first conv self.encoder = Encoder( in_out_channels=in_out_channels, latent_embed_dim=latent_embed_dim * 2, filters=filters, num_res_blocks=num_res_blocks, channel_multipliers=channel_multipliers, temporal_downsample=temporal_downsample, num_groups=num_groups, # for nn.GroupNorm activation_fn=activation_fn, ) self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1) self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1) self.decoder = Decoder( in_out_channels=in_out_channels, latent_embed_dim=latent_embed_dim, filters=filters, num_res_blocks=num_res_blocks, channel_multipliers=channel_multipliers, temporal_downsample=temporal_downsample, num_groups=num_groups, # for nn.GroupNorm activation_fn=activation_fn, ) def get_latent_size(self, input_size): latent_size = [] for i in range(3): if input_size[i] is None: lsize = None elif i == 0: time_padding = ( 0 if (input_size[i] % self.time_downsample_factor == 0) else self.time_downsample_factor - input_size[i] % self.time_downsample_factor ) lsize = (input_size[i] + time_padding) // self.patch_size[i] else: lsize = input_size[i] // self.patch_size[i] latent_size.append(lsize) return latent_size def encode(self, x): time_padding = ( 0 if (x.shape[2] % self.time_downsample_factor == 0) else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor ) x = pad_at_dim(x, (time_padding, 0), dim=2) encoded_feature = self.encoder(x) moments = self.quant_conv(encoded_feature).to(x.dtype) posterior = DiagonalGaussianDistribution(moments) return posterior def decode(self, z, num_frames=None): time_padding = ( 0 if (num_frames % self.time_downsample_factor == 0) else self.time_downsample_factor - num_frames % self.time_downsample_factor ) z = self.post_quant_conv(z) x = self.decoder(z) x = x[:, :, time_padding:] return x def forward(self, x, sample_posterior=True): posterior = self.encode(x) if sample_posterior: z = posterior.sample() else: z = posterior.mode() recon_video = self.decode(z, num_frames=x.shape[2]) return recon_video, posterior, z @MODELS.register_module("VAE_Temporal_SD") def VAE_Temporal_SD(from_pretrained=None, **kwargs): model = VAE_Temporal( in_out_channels=4, latent_embed_dim=4, embed_dim=4, filters=128, num_res_blocks=4, channel_multipliers=(1, 2, 2, 4), temporal_downsample=(False, True, True), **kwargs, ) if from_pretrained is not None: load_checkpoint(model, from_pretrained) return model