[feat] vae micro frame
commit f77370fff6 (parent c4f17b41fc)
@@ -48,7 +48,6 @@ model = dict(
 vae = dict(
     type="VideoAutoencoderPipeline",
     from_pretrained="pretrained_models/vae-v1",
-    training=False,
     vae_2d=dict(
         type="VideoAutoencoderKL",
         from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",

@@ -18,6 +18,7 @@ max_test_samples = None
 model = dict(
     type="VideoAutoencoderPipeline",
     freeze_vae_2d=True,
+    cal_loss=True,
     vae_2d=dict(
         type="VideoAutoencoderKL",
         from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",

@@ -18,6 +18,8 @@ max_test_samples = None
 model = dict(
     type="VideoAutoencoderPipeline",
     from_pretrained=None,
+    cal_loss=True,
+    micro_frame_size=16,
     vae_2d=dict(
         type="VideoAutoencoderKL",
         from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",

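Note on the new option: micro_frame_size=16 makes the temporal VAE process the video in 16-frame chunks along the time axis instead of encoding all T frames at once, which bounds activation memory for long clips. A minimal sketch of the chunking idea (the helper name and encode_fn are illustrative, not the repo's API):

    import torch

    def encode_in_micro_frames(x_z, encode_fn, micro_frame_size=16):
        # x_z: [B, C, T, H, W]; encode_fn maps one chunk to its latent.
        z_list = []
        for i in range(0, x_z.shape[2], micro_frame_size):
            chunk = x_z[:, :, i : i + micro_frame_size]  # last chunk may be shorter
            z_list.append(encode_fn(chunk))
        return torch.cat(z_list, dim=2)  # re-join along the time axis
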
@@ -21,6 +21,7 @@ model = dict(
     type="VideoAutoencoderPipeline",
     freeze_vae_2d=True,
     from_pretrained=None,
+    cal_loss=True,
     vae_2d=dict(
         type="VideoAutoencoderKL",
         from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",

@@ -21,6 +21,7 @@ model = dict(
     type="VideoAutoencoderPipeline",
     freeze_vae_2d=False,
     from_pretrained=None,
+    cal_loss=True,
     vae_2d=dict(
         type="VideoAutoencoderKL",
         from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",

@@ -21,6 +21,7 @@ model = dict(
     type="VideoAutoencoderPipeline",
     freeze_vae_2d=False,
     from_pretrained=None,
+    cal_loss=True,
     vae_2d=dict(
         type="VideoAutoencoderKL",
         from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",

@@ -37,18 +38,18 @@ discriminator = dict(
     type="NLayerDiscriminator",
     from_pretrained="/home/shenchenhui/opensoraplan-v1.0.0-discriminator.pt",
     input_nc=3,
     n_layers=3,
     use_actnorm=False,
 )
 
 # discriminator hyper-parames TODO
-discriminator_factor=1
-discriminator_start=-1
-generator_factor=0.5
-generator_loss_type="hinge"
-discriminator_loss_type="hinge"
-lecam_loss_weight=None
-gradient_penalty_loss_weight=None
+discriminator_factor = 1
+discriminator_start = -1
+generator_factor = 0.5
+generator_loss_type = "hinge"
+discriminator_loss_type = "hinge"
+lecam_loss_weight = None
+gradient_penalty_loss_weight = None
 
 # loss weights
 perceptual_loss_weight = 0.1  # use vgg is not None and more than 0

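For context, the "hinge" settings above conventionally refer to the hinge GAN objectives; a sketch under that assumption (function names here are illustrative, not the repo's):

    import torch
    import torch.nn.functional as F

    def hinge_d_loss(logits_real, logits_fake):
        # Discriminator: push real logits above +1 and fake logits below -1.
        return 0.5 * (torch.mean(F.relu(1.0 - logits_real)) + torch.mean(F.relu(1.0 + logits_fake)))

    def hinge_g_loss(logits_fake):
        # Generator: raise the discriminator's score on generated samples.
        return -torch.mean(logits_fake)
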
@@ -118,11 +118,21 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module):
 
 @MODELS.register_module()
 class VideoAutoencoderPipeline(nn.Module):
-    def __init__(self, vae_2d=None, vae_temporal=None, freeze_vae_2d=False, from_pretrained=None, training=True):
+    def __init__(
+        self,
+        vae_2d=None,
+        vae_temporal=None,
+        from_pretrained=None,
+        freeze_vae_2d=False,
+        cal_loss=False,
+        micro_frame_size=None,
+    ):
         super().__init__()
         self.spatial_vae = build_module(vae_2d, MODELS)
         self.temporal_vae = build_module(vae_temporal, MODELS)
-        self.training = training
+        self.cal_loss = cal_loss
+        self.micro_frame_size = micro_frame_size
+        self.micro_z_frame_size = self.temporal_vae.get_latent_size([micro_frame_size, None, None])[0]
 
         if from_pretrained is not None:
             load_checkpoint(self, from_pretrained)

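micro_z_frame_size caches how many latent frames one micro batch of micro_frame_size input frames yields, so decode() can step through z in matching slices. A toy illustration, assuming a 4x temporal downsampling factor (the real value comes from temporal_vae.get_latent_size):

    def toy_get_latent_size(input_size, time_downsample=4):
        # input_size: [T, H, W]; only the temporal entry is transformed here.
        t = input_size[0]
        return [(t + time_downsample - 1) // time_downsample, input_size[1], input_size[2]]

    micro_z_frame_size = toy_get_latent_size([16, None, None])[0]  # -> 4 under this assumption
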
@@ -135,28 +145,63 @@ class VideoAutoencoderPipeline(nn.Module):
 
     def encode(self, x):
         x_z = self.spatial_vae.encode(x)
-        posterior = self.temporal_vae.encode(x_z)
-        z = posterior.sample()
-        if self.training:
+
+        if self.micro_frame_size is None:
+            posterior = self.temporal_vae.encode(x_z)
+            z = posterior.sample()
+        else:
+            z_list = []
+            for i in range(0, x_z.shape[2], self.micro_frame_size):
+                x_z_bs = x_z[:, :, i : i + self.micro_frame_size]
+                posterior = self.temporal_vae.encode(x_z_bs)
+                z_list.append(posterior.sample())
+            z = torch.cat(z_list, dim=2)
+
+        if self.cal_loss:
             return z, posterior, x_z
-        return z / self.scale
+        else:
+            return z / self.scale
 
     def decode(self, z, num_frames=None):
-        if not self.training:
+        if not self.cal_loss:
             z = z * self.scale
-        x_z = self.temporal_vae.decode(z, num_frames=num_frames)
-        x = self.spatial_vae.decode(x_z)
-        if self.training:
+
+        if self.micro_frame_size is None:
+            x_z = self.temporal_vae.decode(z, num_frames=num_frames)
+            x = self.spatial_vae.decode(x_z)
+        else:
+            x_z_list = []
+            for i in range(0, z.size(2), self.micro_z_frame_size):
+                z_bs = z[:, :, i : i + self.micro_z_frame_size]
+                x_z_bs = self.temporal_vae.decode(z_bs, num_frames=min(self.micro_frame_size, num_frames))
+                x_z_list.append(x_z_bs)
+                num_frames -= self.micro_frame_size
+            x_z = torch.cat(x_z_list, dim=2)
+            x = self.spatial_vae.decode(x_z)
+
+        if self.cal_loss:
             return x, x_z
-        return x
+        else:
+            return x
 
     def forward(self, x):
+        assert self.cal_loss, "This method is only available when cal_loss is True"
         z, posterior, x_z = self.encode(x)
         x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2])
         return x_rec, x_z_rec, z, posterior, x_z
 
     def get_latent_size(self, input_size):
-        return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size))
+        if self.micro_frame_size is None:
+            return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size))
+        else:
+            sub_input_size = [self.micro_frame_size, input_size[1], input_size[2]]
+            sub_latent_size = self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(sub_input_size))
+            sub_latent_size[0] = sub_latent_size[0] * (input_size[0] // self.micro_frame_size)
+            remain_temporal_size = [input_size[0] % self.micro_frame_size, None, None]
+            if remain_temporal_size[0] > 0:
+                remain_size = self.temporal_vae.get_latent_size(remain_temporal_size)
+                sub_latent_size[0] += remain_size[0]
+            return sub_latent_size
 
     def get_temporal_last_layer(self):
         return self.temporal_vae.decoder.conv_out.conv.weight

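Taken together: encode() slices the spatial latent x_z every micro_frame_size frames, while decode() walks z in micro_z_frame_size steps and decrements num_frames so the final, possibly shorter, chunk is decoded at the right length. A standalone sketch of the decode walk (decode_fn and the 4x latent ratio are assumptions, not the repo's API):

    import torch

    def decode_in_micro_frames(z, decode_fn, micro_frame_size=16, micro_z_frame_size=4, num_frames=32):
        # z: [B, C, Tz, H, W]; decode_fn(chunk, num_frames) returns pixel frames.
        x_list = []
        for i in range(0, z.size(2), micro_z_frame_size):
            z_bs = z[:, :, i : i + micro_z_frame_size]
            x_list.append(decode_fn(z_bs, num_frames=min(micro_frame_size, num_frames)))
            num_frames -= micro_frame_size  # frames still owed to later chunks
        return torch.cat(x_list, dim=2)
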
@@ -96,9 +96,12 @@ def main():
     for step in pbar:
         batch = next(dataloader_iter)
         x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
+        input_size = x.shape[2:]
+        latent_size = model.get_latent_size(input_size)
 
         # ===== VAE =====
-        z, posterior, x_z = model.encode(x, training=True)
+        z, posterior, x_z = model.encode(x)
+        assert list(z.shape[2:]) == latent_size, f"z shape: {z.shape}, latent_size: {latent_size}"
         x_rec, x_z_rec = model.decode(z, num_frames=x.size(2))
         x_ref = model.spatial_vae.decode(x_z)
 
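The new assert cross-checks encode() against get_latent_size(). A worked example of the micro-frame arithmetic, with micro_frame_size=16, an assumed 4x temporal compression, and an illustrative T=34:

    micro_frame_size = 16
    micro_z = 4                                # latent frames per full chunk (assumed)
    T = 34                                     # illustrative input length
    full = (T // micro_frame_size) * micro_z   # 2 full chunks -> 8 latent frames
    remain = -(-(T % micro_frame_size) // 4)   # ceil(2 / 4) -> 1 latent frame
    print(full + remain)                       # 9 expected latent frames
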
@@ -8,7 +8,7 @@ OUTPUT_PATH=/home/data/sora_data/pixart-sigma-generated/raw
 CMD="python scripts/inference.py configs/pixart/inference/1x2048MS.py"
 LOG_BASE=logs/sample/generate
 NUM_PER_GPU=10000
-N_LAUNCH=6
+N_LAUNCH=2
 NUM_START=$(($N_LAUNCH * $NUM_PER_GPU * 8))
 
 CUDA_VISIBLE_DEVICES=0 $CMD --prompt-path $TEXT_PATH --save-dir $OUTPUT_PATH --start-index $(($NUM_START + $NUM_PER_GPU * 0)) --end-index $(($NUM_START + $NUM_PER_GPU * 1)) --image-size 2048 2048 --verbose 1 --batch-size 2 >${LOG_BASE}_${N_LAUNCH}_1.log 2>&1 &
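The N_LAUNCH change only shifts the prompt window for this batch of jobs; the same offset arithmetic in Python for clarity:

    NUM_PER_GPU = 10000
    N_LAUNCH = 2
    NUM_START = N_LAUNCH * NUM_PER_GPU * 8  # 160000 (was 480000 with N_LAUNCH=6)
    # GPU k then covers [NUM_START + k * NUM_PER_GPU, NUM_START + (k + 1) * NUM_PER_GPU)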