[feat] vae micro frame

This commit is contained in:
zhengzangw 2024-05-05 15:25:17 +00:00
parent c4f17b41fc
commit f77370fff6
9 changed files with 76 additions and 23 deletions

View file

@ -48,7 +48,6 @@ model = dict(
vae = dict(
type="VideoAutoencoderPipeline",
from_pretrained="pretrained_models/vae-v1",
training=False,
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",

View file

@ -18,6 +18,7 @@ max_test_samples = None
model = dict(
type="VideoAutoencoderPipeline",
freeze_vae_2d=True,
cal_loss=True,
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",

View file

@ -18,6 +18,8 @@ max_test_samples = None
model = dict(
type="VideoAutoencoderPipeline",
from_pretrained=None,
cal_loss=True,
micro_frame_size=16,
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",

View file

@ -21,6 +21,7 @@ model = dict(
type="VideoAutoencoderPipeline",
freeze_vae_2d=True,
from_pretrained=None,
cal_loss=True,
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",

View file

@ -21,6 +21,7 @@ model = dict(
type="VideoAutoencoderPipeline",
freeze_vae_2d=False,
from_pretrained=None,
cal_loss=True,
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",

View file

@ -21,6 +21,7 @@ model = dict(
type="VideoAutoencoderPipeline",
freeze_vae_2d=False,
from_pretrained=None,
cal_loss=True,
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
@ -37,18 +38,18 @@ discriminator = dict(
type="NLayerDiscriminator",
from_pretrained="/home/shenchenhui/opensoraplan-v1.0.0-discriminator.pt",
input_nc=3,
n_layers=3,
n_layers=3,
use_actnorm=False,
)
# discriminator hyper-parameters (TODO)
discriminator_factor=1
discriminator_start=-1
generator_factor=0.5
generator_loss_type="hinge"
discriminator_loss_type="hinge"
lecam_loss_weight=None
gradient_penalty_loss_weight=None
discriminator_factor = 1
discriminator_start = -1
generator_factor = 0.5
generator_loss_type = "hinge"
discriminator_loss_type = "hinge"
lecam_loss_weight = None
gradient_penalty_loss_weight = None
# loss weights
perceptual_loss_weight = 0.1 # VGG perceptual loss is used when this is not None and greater than 0

View file

@ -118,11 +118,21 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module):
@MODELS.register_module()
class VideoAutoencoderPipeline(nn.Module):
def __init__(self, vae_2d=None, vae_temporal=None, freeze_vae_2d=False, from_pretrained=None, training=True):
def __init__(
self,
vae_2d=None,
vae_temporal=None,
from_pretrained=None,
freeze_vae_2d=False,
cal_loss=False,
micro_frame_size=None,
):
super().__init__()
self.spatial_vae = build_module(vae_2d, MODELS)
self.temporal_vae = build_module(vae_temporal, MODELS)
self.training = training
self.cal_loss = cal_loss
self.micro_frame_size = micro_frame_size
self.micro_z_frame_size = self.temporal_vae.get_latent_size([micro_frame_size, None, None])[0]
if from_pretrained is not None:
load_checkpoint(self, from_pretrained)
@ -135,28 +145,63 @@ class VideoAutoencoderPipeline(nn.Module):
def encode(self, x):
x_z = self.spatial_vae.encode(x)
posterior = self.temporal_vae.encode(x_z)
z = posterior.sample()
if self.training:
if self.micro_frame_size is None:
posterior = self.temporal_vae.encode(x_z)
z = posterior.sample()
else:
z_list = []
for i in range(0, x_z.shape[2], self.micro_frame_size):
x_z_bs = x_z[:, :, i : i + self.micro_frame_size]
posterior = self.temporal_vae.encode(x_z_bs)
z_list.append(posterior.sample())
z = torch.cat(z_list, dim=2)
if self.cal_loss:
return z, posterior, x_z
return z / self.scale
else:
return z / self.scale
def decode(self, z, num_frames=None):
if not self.training:
if not self.cal_loss:
z = z * self.scale
x_z = self.temporal_vae.decode(z, num_frames=num_frames)
x = self.spatial_vae.decode(x_z)
if self.training:
if self.micro_frame_size is None:
x_z = self.temporal_vae.decode(z, num_frames=num_frames)
x = self.spatial_vae.decode(x_z)
else:
x_z_list = []
for i in range(0, z.size(2), self.micro_z_frame_size):
z_bs = z[:, :, i : i + self.micro_z_frame_size]
x_z_bs = self.temporal_vae.decode(z_bs, num_frames=min(self.micro_frame_size, num_frames))
x_z_list.append(x_z_bs)
num_frames -= self.micro_frame_size
x_z = torch.cat(x_z_list, dim=2)
x = self.spatial_vae.decode(x_z)
if self.cal_loss:
return x, x_z
return x
else:
return x
# Run a full encode -> decode round trip on x and return every intermediate result.
# Only usable when self.cal_loss is true: in that mode encode() returns
# (z, posterior, x_z) and decode() returns (x_rec, x_z_rec), which is what is unpacked here.
# Returns the 5-tuple (x_rec, x_z_rec, z, posterior, x_z).
def forward(self, x):
assert self.cal_loss, "This method is only available when cal_loss is True"
z, posterior, x_z = self.encode(x)
# num_frames is read from dim 2 of the spatial latent x_z rather than from the raw input x
x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2])
return x_rec, x_z_rec, z, posterior, x_z
def get_latent_size(self, input_size):
return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size))
if self.micro_frame_size is None:
return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size))
else:
sub_input_size = [self.micro_frame_size, input_size[1], input_size[2]]
sub_latent_size = self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(sub_input_size))
sub_latent_size[0] = sub_latent_size[0] * (input_size[0] // self.micro_frame_size)
remain_temporal_size = [input_size[0] % self.micro_frame_size, None, None]
if remain_temporal_size[0] > 0:
remain_size = self.temporal_vae.get_latent_size(remain_temporal_size)
sub_latent_size[0] += remain_size[0]
return sub_latent_size
# Return the weight of the temporal VAE decoder's output convolution
# (self.temporal_vae.decoder.conv_out.conv).
# NOTE(review): presumably consumed by the training loss for adaptive weighting — confirm against the caller.
def get_temporal_last_layer(self):
return self.temporal_vae.decoder.conv_out.conv.weight

View file

@ -96,9 +96,12 @@ def main():
for step in pbar:
batch = next(dataloader_iter)
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
input_size = x.shape[2:]
latent_size = model.get_latent_size(input_size)
# ===== VAE =====
z, posterior, x_z = model.encode(x, training=True)
z, posterior, x_z = model.encode(x)
assert list(z.shape[2:]) == latent_size, f"z shape: {z.shape}, latent_size: {latent_size}"
x_rec, x_z_rec = model.decode(z, num_frames=x.size(2))
x_ref = model.spatial_vae.decode(x_z)

View file

@ -8,7 +8,7 @@ OUTPUT_PATH=/home/data/sora_data/pixart-sigma-generated/raw
CMD="python scripts/inference.py configs/pixart/inference/1x2048MS.py"
LOG_BASE=logs/sample/generate
NUM_PER_GPU=10000
N_LAUNCH=6
N_LAUNCH=2
NUM_START=$(($N_LAUNCH * $NUM_PER_GPU * 8))
CUDA_VISIBLE_DEVICES=0 $CMD --prompt-path $TEXT_PATH --save-dir $OUTPUT_PATH --start-index $(($NUM_START + $NUM_PER_GPU * 0)) --end-index $(($NUM_START + $NUM_PER_GPU * 1)) --image-size 2048 2048 --verbose 1 --batch-size 2 >${LOG_BASE}_${N_LAUNCH}_1.log 2>&1 &