mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 13:14:44 +02:00
259 lines
9 KiB
Python
259 lines
9 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
|
|
from einops import rearrange
|
|
|
|
from opensora.registry import MODELS, build_module
|
|
from opensora.utils.ckpt_utils import load_checkpoint
|
|
|
|
|
|
@MODELS.register_module()
class VideoAutoencoderKL(nn.Module):
    """Apply a diffusers 2D ``AutoencoderKL`` frame by frame to video tensors.

    Videos of shape (B, C, T, H, W) are flattened to a batch of B*T images.
    Each image is encoded or decoded independently, and the result is reshaped
    back to (B, C', T, H', W').

    Encoded latents are multiplied by ``scaling_factor``. Latents are divided by
    it before decoding.
    """

    def __init__(
        self,
        from_pretrained=None,
        micro_batch_size=None,
        cache_dir=None,
        local_files_only=False,
        subfolder=None,
        scaling_factor=0.18215,
    ):
        """
        Args:
            from_pretrained: model id or path forwarded to ``AutoencoderKL.from_pretrained``.
            micro_batch_size: if not None, frames are pushed through the 2D VAE in
                chunks of this many images instead of all at once.
            cache_dir: forwarded to ``AutoencoderKL.from_pretrained``.
            local_files_only: forwarded to ``AutoencoderKL.from_pretrained``.
            subfolder: forwarded to ``AutoencoderKL.from_pretrained``.
            scaling_factor: multiplier applied to sampled latents in ``encode`` and
                divided out in ``decode``. The default is the value this wrapper
                previously hard-coded.
        """
        super().__init__()
        self.module = AutoencoderKL.from_pretrained(
            from_pretrained,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            subfolder=subfolder,
        )
        self.out_channels = self.module.config.latent_channels
        # Temporal axis untouched; spatial axes downsampled by 8.
        self.patch_size = (1, 8, 8)
        self.micro_batch_size = micro_batch_size
        self.scaling_factor = scaling_factor

    def _apply_framewise(self, x, fn):
        """Run ``fn`` on every frame of ``x`` (B, C, T, H, W), optionally in micro-batches."""
        B = x.shape[0]
        frames = rearrange(x, "B C T H W -> (B T) C H W")
        if self.micro_batch_size is None:
            out = fn(frames)
        else:
            # NOTE: cannot be used for training
            bs = self.micro_batch_size
            out = torch.cat([fn(frames[i : i + bs]) for i in range(0, frames.shape[0], bs)], dim=0)
        return rearrange(out, "(B T) C H W -> B C T H W", B=B)

    def encode(self, x):
        """Encode video ``x`` (B, C, T, H, W) into scaled latents sampled from the posterior."""
        return self._apply_framewise(
            x, lambda frames: self.module.encode(frames).latent_dist.sample().mul_(self.scaling_factor)
        )

    def decode(self, x, **kwargs):
        """Decode scaled latents ``x`` (B, C, T, H, W) back to frames.

        Extra keyword arguments are accepted for interface compatibility and ignored.
        """
        return self._apply_framewise(x, lambda frames: self.module.decode(frames / self.scaling_factor).sample)

    def get_latent_size(self, input_size):
        """Map a (T, H, W) input size to the latent size.

        Each entry is floor-divided by the matching ``patch_size`` entry. Divisibility
        is not checked, and ``None`` entries are passed through unchanged.
        """
        latent_size = []
        for i in range(3):
            latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
        return latent_size

    @property
    def device(self):
        """Device of the first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        return next(self.parameters()).dtype
|
|
|
|
|
|
@MODELS.register_module()
class VideoAutoencoderKLTemporalDecoder(nn.Module):
    """Decode-only wrapper around diffusers' ``AutoencoderKLTemporalDecoder``.

    Latents arrive as (B, C, T, H, W). They are divided by 0.18215, flattened to
    B*T images and decoded. The decoder is told the clip length T, and the output
    is reshaped back to a video tensor.
    """

    def __init__(self, from_pretrained=None, cache_dir=None, local_files_only=False):
        super().__init__()
        loader_kwargs = dict(cache_dir=cache_dir, local_files_only=local_files_only)
        self.module = AutoencoderKLTemporalDecoder.from_pretrained(from_pretrained, **loader_kwargs)
        self.out_channels = self.module.config.latent_channels
        # No temporal downsampling; 8x spatial downsampling.
        self.patch_size = (1, 8, 8)

    def encode(self, x):
        """Encoding is not provided by this wrapper."""
        raise NotImplementedError

    def decode(self, x, **kwargs):
        """Decode latents ``x`` (B, C, T, H, W); extra keyword arguments are ignored."""
        batch, _, frames = x.shape[:3]
        flat_latents = rearrange(x, "B C T H W -> (B T) C H W")
        decoded = self.module.decode(flat_latents / 0.18215, num_frames=frames).sample
        return rearrange(decoded, "(B T) C H W -> B C T H W", B=batch)

    def get_latent_size(self, input_size):
        """Floor-divide each of the three input sizes by ``patch_size``; ``None`` stays ``None``."""
        return [
            None if input_size[axis] is None else input_size[axis] // self.patch_size[axis]
            for axis in range(3)
        ]

    @property
    def device(self):
        """Device on which the first parameter lives."""
        first_param = next(self.parameters())
        return first_param.device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        first_param = next(self.parameters())
        return first_param.dtype
|
|
|
|
|
|
@MODELS.register_module()
class VideoAutoencoderPipeline(nn.Module):
    """Two-stage video VAE: a per-frame spatial VAE followed by a temporal VAE.

    ``encode`` runs ``spatial_vae`` and then ``temporal_vae``, and ``decode``
    runs them in reverse. The behaviour depends on ``cal_loss``:

    - ``cal_loss`` False: temporal latents are normalized as ``(z - shift) / scale``
      when leaving ``encode`` and de-normalized when entering ``decode``.
    - ``cal_loss`` True: raw latents plus intermediate tensors are returned instead,
      and no normalization is applied.
    """

    def __init__(
        self,
        vae_2d=None,
        vae_temporal=None,
        from_pretrained=None,
        freeze_vae_2d=False,
        cal_loss=False,
        micro_frame_size=None,
        shift=0.0,
        scale=1.0,
    ):
        """
        Args:
            vae_2d: config given to ``build_module`` for the spatial (per-frame) VAE.
            vae_temporal: config given to ``build_module`` for the temporal VAE.
            from_pretrained: optional checkpoint loaded into the whole pipeline with
                ``load_checkpoint``.
            freeze_vae_2d: if True, ``requires_grad`` is turned off for every spatial
                VAE parameter.
            cal_loss: if True, ``encode``/``decode`` return extra tensors and skip
                latent normalization.
            micro_frame_size: if not None, the temporal VAE processes the spatial
                latents in chunks of this many frames along dim 2.
            shift: latent normalization offset; a scalar, or a sequence with one
                value per entry of dim 1.
            scale: latent normalization factor; a scalar, or a sequence with one
                value per entry of dim 1.
        """
        super().__init__()
        self.spatial_vae = build_module(vae_2d, MODELS)
        self.temporal_vae = build_module(vae_temporal, MODELS)
        self.cal_loss = cal_loss
        self.micro_frame_size = micro_frame_size
        # Temporal-latent length produced by one chunk of `micro_frame_size` input frames.
        self.micro_z_frame_size = self.temporal_vae.get_latent_size([micro_frame_size, None, None])[0]

        if from_pretrained is not None:
            load_checkpoint(self, from_pretrained)
        if freeze_vae_2d:
            for param in self.spatial_vae.parameters():
                param.requires_grad = False

        self.out_channels = self.temporal_vae.out_channels

        # shift/scale are plain tensor attributes, not registered buffers.
        # They were previously created with an unconditional `.cuda()`, which had two
        # problems: it crashed on CPU-only hosts, and it pinned them to the *current*
        # CUDA device even when the model runs on a different GPU.
        # CUDA placement is kept when available; encode/decode move the tensors to the
        # latent's device at use time.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.scale = torch.tensor(scale, device=device)
        self.shift = torch.tensor(shift, device=device)
        # Non-scalar values are reshaped to broadcast along dim 1 of a 5D latent.
        if self.scale.dim() > 0:
            self.scale = self.scale[None, :, None, None, None]
        if self.shift.dim() > 0:
            self.shift = self.shift[None, :, None, None, None]

    def encode(self, x):
        """Encode video ``x`` into temporal latents.

        Returns:
            If ``cal_loss`` is True, the tuple ``(z, posterior, x_z)``:

            - ``z``: the raw sampled temporal latents.
            - ``posterior``: the temporal VAE posterior.
            - ``x_z``: the spatial latents.

            Otherwise, the normalized latents ``(z - shift) / scale``.
        """
        x_z = self.spatial_vae.encode(x)

        if self.micro_frame_size is None:
            posterior = self.temporal_vae.encode(x_z)
            z = posterior.sample()
        else:
            z_list = []
            for i in range(0, x_z.shape[2], self.micro_frame_size):
                x_z_bs = x_z[:, :, i : i + self.micro_frame_size]
                posterior = self.temporal_vae.encode(x_z_bs)
                z_list.append(posterior.sample())
            z = torch.cat(z_list, dim=2)

        if self.cal_loss:
            # NOTE(review): when micro_frame_size is set, `posterior` is only the LAST
            # chunk's posterior, so any loss computed from it ignores the other chunks.
            return z, posterior, x_z
        return (z - self.shift.to(z.device)) / self.scale.to(z.device)

    def decode(self, z, num_frames=None):
        """Decode temporal latents ``z`` back to a video.

        Args:
            z: temporal latents; chunked along dim 2 when ``micro_frame_size`` is set.
                They are normalized unless ``cal_loss`` is True.
            num_frames: number of frames to reconstruct. Forwarded to the temporal VAE,
                and required when ``micro_frame_size`` is set.

        Returns:
            ``(x, x_z)`` if ``cal_loss`` is True, otherwise ``x``. Here ``x_z`` is the
            reconstructed spatial latents.

        Raises:
            ValueError: if ``micro_frame_size`` is set and ``num_frames`` is None.
        """
        if not self.cal_loss:
            z = z * self.scale.to(device=z.device, dtype=z.dtype) + self.shift.to(device=z.device, dtype=z.dtype)

        if self.micro_frame_size is None:
            x_z = self.temporal_vae.decode(z, num_frames=num_frames)
        else:
            if num_frames is None:
                # Previously this surfaced as an opaque TypeError from min(int, None).
                raise ValueError("num_frames must be provided when micro_frame_size is set")
            x_z_list = []
            remaining = num_frames
            for i in range(0, z.size(2), self.micro_z_frame_size):
                z_bs = z[:, :, i : i + self.micro_z_frame_size]
                # Every chunk reconstructs micro_frame_size frames except possibly the last.
                x_z_list.append(self.temporal_vae.decode(z_bs, num_frames=min(self.micro_frame_size, remaining)))
                remaining -= self.micro_frame_size
            x_z = torch.cat(x_z_list, dim=2)
        x = self.spatial_vae.decode(x_z)

        if self.cal_loss:
            return x, x_z
        return x

    def forward(self, x):
        """Round-trip ``x`` through encode/decode; only valid when ``cal_loss`` is True."""
        assert self.cal_loss, "This method is only available when cal_loss is True"
        z, posterior, x_z = self.encode(x)
        x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2])
        return x_rec, x_z_rec, z, posterior, x_z

    def get_latent_size(self, input_size):
        """Latent size for a (T, H, W) input, accounting for micro-frame chunking."""
        if self.micro_frame_size is None or input_size[0] is None:
            return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size))
        # Steps:
        # 1. Compute the latent size of one full chunk.
        # 2. Multiply its temporal length by the number of full chunks.
        # 3. Add the latent length of any leftover partial chunk.
        sub_input_size = [self.micro_frame_size, input_size[1], input_size[2]]
        sub_latent_size = self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(sub_input_size))
        sub_latent_size[0] = sub_latent_size[0] * (input_size[0] // self.micro_frame_size)
        remain_temporal_size = [input_size[0] % self.micro_frame_size, None, None]
        if remain_temporal_size[0] > 0:
            remain_size = self.temporal_vae.get_latent_size(remain_temporal_size)
            sub_latent_size[0] += remain_size[0]
        return sub_latent_size

    def get_temporal_last_layer(self):
        """Weight of the temporal decoder's output conv.

        This assumes the temporal VAE exposes ``decoder.conv_out.conv``.
        """
        return self.temporal_vae.decoder.conv_out.conv.weight

    @property
    def device(self):
        """Device of the first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        return next(self.parameters()).dtype
|
|
|
|
|
|
@MODELS.register_module()
class OpenSoraVAE_V1_2(VideoAutoencoderPipeline):
    """Preset ``VideoAutoencoderPipeline`` with fixed components and constants.

    The preset combines:

    - the PixArt-Sigma SDXL image VAE as the spatial stage;
    - ``VAE_Temporal_SD`` as the temporal stage;
    - fixed 4-entry latent shift/scale constants.
    """

    def __init__(
        self,
        micro_batch_size=4,
        micro_frame_size=17,
        from_pretrained=None,
        local_files_only=True,
        freeze_vae_2d=False,
        cal_loss=False,
    ):
        spatial_cfg = {
            "type": "VideoAutoencoderKL",
            "from_pretrained": "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
            "subfolder": "vae",
            "micro_batch_size": micro_batch_size,
            "local_files_only": local_files_only,
        }
        temporal_cfg = {
            "type": "VAE_Temporal_SD",
            "from_pretrained": None,
        }
        # Normalization constants. The pipeline broadcasts them along dim 1 of the latent.
        latent_shift = (-0.10, 0.34, 0.27, 0.98)
        latent_scale = (3.85, 2.32, 2.33, 3.06)
        super().__init__(
            vae_2d=spatial_cfg,
            vae_temporal=temporal_cfg,
            from_pretrained=from_pretrained,
            freeze_vae_2d=freeze_vae_2d,
            cal_loss=cal_loss,
            micro_frame_size=micro_frame_size,
            shift=latent_shift,
            scale=latent_scale,
        )
|