import torch
import torch.nn as nn
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from einops import rearrange

from opensora.registry import MODELS

# Factor the latents are multiplied by after encoding and divided by before
# decoding. Kept in one place so encode/decode cannot drift apart.
_LATENT_SCALE = 0.18215


def _latent_size(input_size, patch_size):
    """Divide the first three entries of *input_size* by *patch_size*.

    Raises:
        AssertionError: if any of the three entries is not an exact multiple
            of the corresponding patch size.
    """
    for i in range(3):
        assert input_size[i] % patch_size[i] == 0, "Input size must be divisible by patch size"
    return [input_size[i] // patch_size[i] for i in range(3)]


@MODELS.register_module()
class VideoAutoencoderKL(nn.Module):
    """Apply a pretrained image ``AutoencoderKL`` to every frame of a video.

    Frames are folded into the batch axis, run through the wrapped 2D
    autoencoder, and unfolded again.

    Args:
        from_pretrained: identifier forwarded to
            ``AutoencoderKL.from_pretrained``.
        split: optional number of chunks to cut the flattened ``(B * T)``
            frame batch into, so the autoencoder processes one chunk at a
            time. ``None`` processes all frames in a single call.
    """

    def __init__(self, from_pretrained=None, split=None):
        super().__init__()
        self.module = AutoencoderKL.from_pretrained(from_pretrained)
        self.out_channels = self.module.config.latent_channels
        # Per-axis divisors used by get_latent_size; presumably (T, H, W)
        # order -- confirm against callers.
        self.patch_size = (1, 8, 8)
        self.split = split

    def _apply_chunked(self, fn, frames):
        """Run *fn* over *frames*, optionally in ``self.split`` chunks."""
        if self.split is None:
            return fn(frames)
        # Tensor.chunk covers every frame even when the frame count is not a
        # multiple of ``split`` (the last chunk is simply smaller) and never
        # yields more chunks than there are frames. Manual ``//`` slicing
        # silently dropped the remainder in that case.
        return torch.cat([fn(part) for part in frames.chunk(self.split, dim=0)], dim=0)

    def _encode_frames(self, frames):
        """Encode a batch of frames and sample scaled latents."""
        return self.module.encode(frames).latent_dist.sample().mul_(_LATENT_SCALE)

    def _decode_frames(self, latents):
        """Undo the latent scaling and decode a batch of frame latents."""
        return self.module.decode(latents / _LATENT_SCALE).sample

    def encode(self, x):
        """Encode a video into sampled, scaled latents.

        Args:
            x: tensor laid out as (B, C, T, H, W).

        Returns:
            Latent tensor laid out as (B, C, T, H, W), with the channel and
            spatial sizes determined by the wrapped autoencoder.
        """
        # x: (B, C, T, H, W)
        B = x.shape[0]
        x = rearrange(x, "B C T H W -> (B T) C H W")
        x = self._apply_chunked(self._encode_frames, x)
        x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
        return x

    def decode(self, x):
        """Decode latents back into a video.

        Args:
            x: latent tensor laid out as (B, C, T, H, W).

        Returns:
            Decoded tensor laid out as (B, C, T, H, W).
        """
        # x: (B, C, T, H, W)
        B = x.shape[0]
        x = rearrange(x, "B C T H W -> (B T) C H W")
        x = self._apply_chunked(self._decode_frames, x)
        x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
        return x

    def get_latent_size(self, input_size):
        """Return the first three entries of *input_size* divided by ``patch_size``.

        Raises:
            AssertionError: if an entry is not divisible by its patch size.
        """
        return _latent_size(input_size, self.patch_size)


@MODELS.register_module()
class VideoAutoencoderKLTemporalDecoder(nn.Module):
    """Decode-only wrapper around ``AutoencoderKLTemporalDecoder``.

    Args:
        from_pretrained: identifier forwarded to
            ``AutoencoderKLTemporalDecoder.from_pretrained``.
    """

    def __init__(self, from_pretrained=None):
        super().__init__()
        self.module = AutoencoderKLTemporalDecoder.from_pretrained(from_pretrained)
        self.out_channels = self.module.config.latent_channels
        # Per-axis divisors used by get_latent_size; presumably (T, H, W)
        # order -- confirm against callers.
        self.patch_size = (1, 8, 8)

    def encode(self, x):
        """Not supported by this wrapper.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def decode(self, x):
        """Decode latents back into a video.

        Args:
            x: latent tensor laid out as (B, C, T, H, W).

        Returns:
            Decoded tensor laid out as (B, C, T, H, W).
        """
        B, _, T = x.shape[:3]
        x = rearrange(x, "B C T H W -> (B T) C H W")
        # The frame count is passed along because the frames have been
        # folded into the batch axis.
        x = self.module.decode(x / _LATENT_SCALE, num_frames=T).sample
        x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
        return x

    def get_latent_size(self, input_size):
        """Return the first three entries of *input_size* divided by ``patch_size``.

        Raises:
            AssertionError: if an entry is not divisible by its patch size.
        """
        return _latent_size(input_size, self.patch_size)