Open-Sora/opensora/models/vae_v1_3/modules/updownsample.py
Zheng Zangwei (Alex Zheng) f1c6b8b88e open-sora v1.3 code upload (#786)
Co-authored-by: gxyes <gxynoz@gmail.com>
2025-02-20 16:50:24 +08:00

209 lines
7 KiB
Python

# modified from
# https://github.com/bornfly-detachment/asymmetric_magvitv2/blob/main/models/modules/updownsample.py
import logging
import torch
import torch.nn as nn
from opensora.models.vae_v1_3.utils import video_to_image
logpy = logging.getLogger(__name__)
from .conv import CausalConv3dPlainAR
class Downsample2D(nn.Module):
    """Reduce the spatial size of a frame batch by a factor of two.

    With ``with_conv`` the input is zero-padded by one column on the right and
    one row at the bottom, then passed through a 3x3, stride-2 ``Conv2d``.
    Otherwise a 2x2 / stride-2 average pool is used.

    Args:
        in_channels: channel count of the input (also the output channel count).
        with_conv: use the learned convolution instead of average pooling.
        micro_batch_size_2d: stored on the module; not read by ``forward`` in
            this block (presumably consumed by ``video_to_image`` — TODO confirm).
    """

    def __init__(self, in_channels, with_conv, micro_batch_size_2d=None):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
        self.micro_batch_size_2d = micro_batch_size_2d

    @video_to_image
    def forward(self, x, is_training=False):
        # ``is_training`` is accepted for interface parity with the 3-D modules; unused here.
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # (left, right, top, bottom) = (0, 1, 0, 1): pad only the trailing edge of
        # width and height so the unpadded stride-2 conv covers the last pixel.
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class Downsample3D(nn.Module):
    """Downsample a 5-D video tensor with a causal strided conv or average pooling.

    Args:
        in_channels: channel count of the input (preserved in the output).
        with_conv: when True, apply a learned ``CausalConv3dPlainAR`` with
            kernel size 3 and the given stride; otherwise use ``avg_pool3d``.
        stride: downsampling factor, an int or a per-axis sequence over the
            last three dims. Both paths use it, so they reduce the tensor by
            the same factor (the pooling path used to be hard-coded to 2 on
            every axis, ignoring ``stride``).
    """

    def __init__(self, in_channels, with_conv, stride):
        super().__init__()
        self.with_conv = with_conv
        # Remembered so the non-conv path honours the requested factor too.
        self.stride = stride
        if self.with_conv:
            self.conv = CausalConv3dPlainAR(in_channels, in_channels, kernel_size=3, stride=stride)

    def forward(self, x, is_training=False):
        # ``is_training`` is accepted for interface parity; it is not used here.
        if self.with_conv:
            return self.conv(x)
        # kernel == stride gives non-overlapping average windows of the requested size;
        # identical to the previous behaviour when stride == 2.
        return torch.nn.functional.avg_pool3d(x, kernel_size=self.stride, stride=self.stride)
class Res3DBlockUpsample(nn.Module):
    """Residual block of two causal 3x3x3 convolutions with optional GroupNorm.

    ``forward`` computes ``act(norm2(conv2(act(norm1(conv1(x))))) + shortcut(x))``.
    The shortcut is ``x`` itself, or a 1x1x1 ``CausalConv3dPlainAR`` projection
    (followed by GroupNorm when ``with_norm``). The projection is used when the
    channel count changes or ``down_sampling`` is set.

    Args:
        input_filters: channel count of the input tensor.
        num_filters: channel count of the output tensor.
        base_filters: stored on the module; not used by this block.
        down_sampling_stride: stride of the shortcut projection when
            ``down_sampling`` is True.
        down_sampling: enable the strided shortcut projection.
        down_sampling_temporal: accepted for interface compatibility; no effect.
        is_real_3d: accepted for interface compatibility; no effect.
        with_norm: insert ``GroupNorm(32, num_filters)`` after each conv, which
            requires ``num_filters`` to be divisible by 32.
    """

    def __init__(
        self,
        input_filters,
        num_filters,
        base_filters,
        down_sampling_stride,
        down_sampling=False,
        down_sampling_temporal=None,
        is_real_3d=True,
        with_norm=True,
    ):
        super(Res3DBlockUpsample, self).__init__()
        self.num_filters = num_filters
        self.base_filters = base_filters
        self.input_filters = input_filters
        self.with_norm = with_norm
        # ``is_real_3d`` / ``down_sampling_temporal`` previously chose between two
        # identical assignments, so only ``down_sampling`` decides the stride.
        self.down_sampling_stride = down_sampling_stride if down_sampling else [1, 1, 1]
        self.down_sampling = down_sampling
        self.act = nn.SiLU()

        # conv1 consumes the block input, so its in-channels must be input_filters.
        # (Using num_filters here crashed whenever input_filters != num_filters;
        # weight shapes are unchanged when the two are equal.)
        self.conv1 = CausalConv3dPlainAR(input_filters, num_filters, kernel_size=[3, 3, 3], stride=[1, 1, 1])
        if self.with_norm:
            self.norm1 = nn.GroupNorm(32, num_filters)
        self.conv2 = CausalConv3dPlainAR(num_filters, num_filters, kernel_size=[3, 3, 3], stride=[1, 1, 1])
        if self.with_norm:
            self.norm2 = nn.GroupNorm(32, num_filters)

        if num_filters != input_filters or down_sampling:
            # NOTE(review): only this shortcut is strided; conv1/conv2 use stride 1.
            # Unless CausalConv3dPlainAR changes the shape itself, the residual add
            # will presumably mismatch for a stride > 1 — TODO confirm before
            # enabling down_sampling.
            self.conv3 = CausalConv3dPlainAR(
                input_filters, num_filters, kernel_size=[1, 1, 1], stride=self.down_sampling_stride
            )
            if self.with_norm:
                self.norm3 = nn.GroupNorm(32, num_filters)

    def _enable_tiled_conv3d(self, tile_size=16, tiled_dim=None, num_tiles=None):
        """Forward the tiling configuration to every causal conv in the block."""
        self.conv1._enable_tiled_conv3d(tile_size=tile_size, tiled_dim=tiled_dim, num_tiles=num_tiles)
        self.conv2._enable_tiled_conv3d(tile_size=tile_size, tiled_dim=tiled_dim, num_tiles=num_tiles)
        # conv3 only exists when a shortcut projection was built in __init__.
        if hasattr(self, "conv3"):
            self.conv3._enable_tiled_conv3d(tile_size=tile_size, tiled_dim=tiled_dim, num_tiles=num_tiles)

    def forward(self, x, is_training=False):
        identity = x

        out = self.conv1(x)
        if self.with_norm:
            out = self.norm1(out)
        out = self.act(out)

        out = self.conv2(out)
        if self.with_norm:
            out = self.norm2(out)

        # Same condition that created conv3/norm3 in __init__.
        if self.down_sampling or self.num_filters != self.input_filters:
            identity = self.conv3(identity)
            if self.with_norm:
                identity = self.norm3(identity)

        # Both forms add in place into ``out``; kept distinct to mirror the
        # is_training switch used by the sibling modules.
        if is_training:
            out += identity
        else:
            out.add_(identity)
        out = self.act(out)
        return out
class Upsample3D(nn.Module):
    """Nearest-neighbour 3-D upsampling followed by a residual refinement block.

    The last three dims of a 5-D ``(b, c, t, h, w)`` input are each multiplied by
    ``scale_factor`` with nearest-neighbour interpolation. The upsampled tensor
    ``u`` is then refined as ``u + Res3DBlockUpsample(u)``.

    Args:
        in_channels: channel count of the input (preserved in the output).
        with_conv: stored on the module but not consulted; the residual block
            is always built and applied.
        scale_factor: integer factor applied to each of the last three dims.
    """

    def __init__(self, in_channels, with_conv, scale_factor=2):
        super().__init__()
        self.with_conv = with_conv
        self.scale_factor = scale_factor
        self.conv3d = Res3DBlockUpsample(
            input_filters=in_channels,
            num_filters=in_channels,
            base_filters=in_channels,
            down_sampling_stride=(1, 1, 1),
            down_sampling=False,
        )

    def _split_by_channel(self, x, split_size):
        """Split ``x`` along dim 1 into chunks of ``split_size`` channels."""
        return torch.split(x, split_size, dim=1)

    def _split_by_batch(self, x, split_size):
        """Split ``x`` along dim 0 into chunks of ``split_size`` samples."""
        return torch.split(x, split_size, dim=0)

    def _enable_tiled_conv3d(self, tile_size=16, tiled_dim=None, num_tiles=None):
        """Forward the tiling configuration to the residual block."""
        self.conv3d._enable_tiled_conv3d(tile_size=tile_size, tiled_dim=tiled_dim, num_tiles=num_tiles)

    def _upsample(self, x):
        """Nearest-neighbour resize of the last three dims by ``scale_factor``."""
        t, h, w = x.shape[2:]
        size = (t * self.scale_factor, h * self.scale_factor, w * self.scale_factor)
        return torch.nn.functional.interpolate(x, size, mode="nearest")

    def forward(self, x, is_split=False, is_training=False):
        # Unpacking enforces a 5-D input, as before.
        b, c, _, _, _ = x.shape

        if is_split and not is_training:
            # Work in channel chunks (presumably to limit peak memory of the
            # interpolation — TODO confirm). max(1, ...) guards c < 8, where
            # c // 8 == 0 would make torch.split raise.
            split_size = max(1, c // 8)
            chunks = self._split_by_channel(x, split_size)
            x = torch.cat([self._upsample(chunk) for chunk in chunks], dim=1)
            identity = x
            # Run the residual block on two batch halves when the batch is
            # large and evenly divisible.
            if b > 2 and b % 2 == 0:
                halves = self._split_by_batch(x, b // 2)
                x = torch.cat([self.conv3d(half) for half in halves], dim=0)
            else:
                x = self.conv3d(x)
        else:
            x = self._upsample(x)
            identity = x
            x = self.conv3d(x)

        # Residual connection; both forms add in place into ``x``.
        if is_training:
            x += identity
        else:
            x.add_(identity)
        return x
class Upsample2D(nn.Module):
    """Double the spatial size with nearest-neighbour interpolation.

    When ``with_conv`` is set, the upsampled result is passed through a
    shape-preserving 3x3 ``Conv2d`` (stride 1, padding 1).

    Args:
        in_channels: channel count of the input (also the output channel count).
        with_conv: apply the 3x3 convolution after upsampling.
        micro_batch_size_2d: stored on the module; not read by ``forward`` in
            this block (presumably consumed by ``video_to_image`` — TODO confirm).
    """

    def __init__(self, in_channels, with_conv, micro_batch_size_2d=None):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.micro_batch_size_2d = micro_batch_size_2d

    @video_to_image
    def forward(self, x, is_split=False, is_training=False):
        # ``is_split`` / ``is_training`` mirror the 3-D upsampler's signature; unused here.
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled