# modified from
# https://github.com/bornfly-detachment/asymmetric_magvitv2/blob/main/models/modules/updownsample.py
import logging
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from opensora.models.vae_v1_3.utils import video_to_image
|
|
|
|
logpy = logging.getLogger(__name__)
|
|
from .conv import CausalConv3dPlainAR
|
|
|
|
|
|
class Downsample2D(nn.Module):
    """2x spatial downsampling via a strided 3x3 conv or 2x2 average pooling."""

    def __init__(self, in_channels, with_conv, micro_batch_size_2d=None):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # padding=0 here: the asymmetric (right/bottom) pad is applied in forward().
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
        self.micro_batch_size_2d = micro_batch_size_2d

    @video_to_image
    def forward(self, x, is_training=False):
        """Halve the spatial resolution of x; `is_training` is accepted for API parity."""
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # Pad one pixel on the right and bottom so the stride-2 conv halves H and W exactly.
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
|
|
|
|
|
class Downsample3D(nn.Module):
    """Downsampling for 5D (N, C, T, H, W) tensors: causal strided conv or 2x avg pooling."""

    def __init__(self, in_channels, with_conv, stride):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = CausalConv3dPlainAR(in_channels, in_channels, kernel_size=3, stride=stride)

    def forward(self, x, is_training=False):
        """Downsample x; the pooling fallback ignores `stride` and always pools 2x."""
        if not self.with_conv:
            return torch.nn.functional.avg_pool3d(x, kernel_size=2, stride=2)
        return self.conv(x)
|
|
|
|
|
|
class Res3DBlockUpsample(nn.Module):
    """Residual 3D block built from causal convolutions.

    Computes conv -> norm -> act -> conv -> norm, adds a (possibly projected)
    skip connection, and applies a final activation. The 1x1x1 projection
    ``conv3`` is created only when the channel count changes or down-sampling
    is requested.

    Args:
        input_filters: channel count of the skip/identity input.
        num_filters: channel count produced by the block (and expected by
            ``conv1``, so the main path input must already have this many
            channels unless it equals ``input_filters``).
        base_filters: stored but unused by this block; kept for API parity.
        down_sampling_stride: stride used by the 1x1x1 projection when
            ``down_sampling`` is True.
        down_sampling: whether the projection applies ``down_sampling_stride``.
        down_sampling_temporal: kept for API parity; the original code assigned
            the same stride on both branches of its check, so it has no effect.
        is_real_3d: kept for API parity; unused by this block.
        with_norm: add ``GroupNorm(32, .)`` after each conv (``num_filters``
            must be divisible by 32 — GroupNorm requirement).
    """

    def __init__(
        self,
        input_filters,
        num_filters,
        base_filters,
        down_sampling_stride,
        down_sampling=False,
        down_sampling_temporal=None,
        is_real_3d=True,
        with_norm=True,
    ):
        super().__init__()
        self.num_filters = num_filters
        self.base_filters = base_filters
        self.input_filters = input_filters
        self.with_norm = with_norm

        # The original branched on (is_real_3d and down_sampling_temporal) but
        # assigned the identical value on both branches, so the condition is dropped.
        self.down_sampling_stride = down_sampling_stride if down_sampling else [1, 1, 1]
        self.down_sampling = down_sampling

        self.act = nn.SiLU()

        self.conv1 = CausalConv3dPlainAR(num_filters, num_filters, kernel_size=[3, 3, 3], stride=[1, 1, 1])
        if self.with_norm:
            self.norm1 = nn.GroupNorm(32, num_filters)
        self.conv2 = CausalConv3dPlainAR(num_filters, num_filters, kernel_size=[3, 3, 3], stride=[1, 1, 1])
        if self.with_norm:
            self.norm2 = nn.GroupNorm(32, num_filters)

        # Skip-path projection, needed whenever channels or stride change.
        if num_filters != input_filters or down_sampling:
            self.conv3 = CausalConv3dPlainAR(
                input_filters, num_filters, kernel_size=[1, 1, 1], stride=self.down_sampling_stride
            )
            if self.with_norm:
                self.norm3 = nn.GroupNorm(32, num_filters)

    def _enable_tiled_conv3d(self, tile_size=16, tiled_dim=None, num_tiles=None):
        """Propagate tiled-conv settings to every causal convolution in the block."""
        self.conv1._enable_tiled_conv3d(tile_size=tile_size, tiled_dim=tiled_dim, num_tiles=num_tiles)
        self.conv2._enable_tiled_conv3d(tile_size=tile_size, tiled_dim=tiled_dim, num_tiles=num_tiles)
        if hasattr(self, "conv3"):
            self.conv3._enable_tiled_conv3d(tile_size=tile_size, tiled_dim=tiled_dim, num_tiles=num_tiles)

    def forward(self, x, is_training=False):
        """Apply the residual block to x (expected layout N, C, T, H, W).

        Returns act(main_path(x) + skip(x)).
        """
        identity = x

        out = self.conv1(x)
        if self.with_norm:
            out = self.norm1(out)
        out = self.act(out)

        out = self.conv2(out)
        if self.with_norm:
            out = self.norm2(out)

        # Project the skip connection when the main path changed its shape.
        if self.down_sampling or self.num_filters != self.input_filters:
            identity = self.conv3(identity)
            if self.with_norm:
                identity = self.norm3(identity)

        # NOTE(review): `+=` on tensors is also an in-place add, so both
        # branches behave the same here — confirm the train/eval split is intended.
        if is_training:
            out += identity
        else:
            out.add_(identity)
        out = self.act(out)
        return out
|
|
|
|
|
|
class Upsample3D(nn.Module):
    """Nearest-neighbor 3D upsampling followed by a residual causal-conv refinement.

    Upsamples T, H and W by ``scale_factor``, then applies a Res3DBlockUpsample
    whose output gets the interpolated tensor added back (residual around the
    conv only). ``with_conv`` is stored for interface parity; the conv path is
    always built and used.
    """

    def __init__(self, in_channels, with_conv, scale_factor=2):
        super().__init__()
        self.with_conv = with_conv
        self.scale_factor = scale_factor

        self.conv3d = Res3DBlockUpsample(
            input_filters=in_channels,
            num_filters=in_channels,
            base_filters=in_channels,
            down_sampling_stride=(1, 1, 1),
            down_sampling=False,
        )

    def _split_by_channel(self, x, split_size):
        """Split x into chunks of `split_size` channels (dim 1); returns views, no copy."""
        return torch.split(x, split_size, dim=1)

    def _split_by_batch(self, x, split_size):
        """Split x into chunks of `split_size` samples (dim 0); returns views, no copy."""
        return torch.split(x, split_size, dim=0)

    def _interpolate_nearest(self, x):
        """Upsample all three spatio-temporal dims of x (N, C, T, H, W) by scale_factor."""
        return torch.nn.functional.interpolate(
            x,
            (
                x.shape[2] * self.scale_factor,
                x.shape[3] * self.scale_factor,
                x.shape[4] * self.scale_factor,
            ),
            mode="nearest",
        )

    def _enable_tiled_conv3d(self, tile_size=16, tiled_dim=None, num_tiles=None):
        self.conv3d._enable_tiled_conv3d(tile_size=tile_size, tiled_dim=tiled_dim, num_tiles=num_tiles)

    def forward(self, x, is_split=False, is_training=False):
        """Upsample x (N, C, T, H, W) by scale_factor in T, H and W.

        When ``is_split`` is set at inference time, the interpolation runs on
        channel chunks and the conv on batch halves to reduce peak memory.
        """
        b, c, t, h, w = x.shape
        if is_split and not is_training:
            # Memory-saving inference path: interpolate channel-wise chunks.
            # max(..., 1) guards c < 8, where the original split size of 0 raised.
            chunk_channels = max(c // 8, 1)
            chunks = self._split_by_channel(x, chunk_channels)
            x = torch.cat([self._interpolate_nearest(chunk) for chunk in chunks], dim=1)
            identity = x

            # Run the conv on two batch halves when the batch splits evenly.
            if b > 2 and b % 2 == 0:
                halves = self._split_by_batch(x, b // 2)
                x = torch.cat([self.conv3d(half) for half in halves], dim=0)
            else:
                x = self.conv3d(x)

            # NOTE(review): is_training is always False on this path, so only the
            # in-place add runs; kept as in the original. `+=` on tensors is also
            # in-place, so the two branches behave the same anyway.
            if is_training:
                x += identity
            else:
                x.add_(identity)
            return x

        x = self._interpolate_nearest(x)
        identity = x
        x = self.conv3d(x)
        if is_training:
            x += identity
        else:
            x.add_(identity)
        return x
|
|
|
|
|
|
class Upsample2D(nn.Module):
    """2x nearest-neighbor spatial upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv, micro_batch_size_2d=None):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.micro_batch_size_2d = micro_batch_size_2d

    @video_to_image
    def forward(self, x, is_split=False, is_training=False):
        """Double H and W of x; `is_split` / `is_training` are accepted for API parity."""
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
|