From f77370fff6979bc97e22f95cc19ff18bb608eadc Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Sun, 5 May 2024 15:25:17 +0000 Subject: [PATCH] [feat] vae micro frame --- configs/opensora-v1-2/train/adapt-vae.py | 1 - configs/vae/inference/image.py | 1 + configs/vae/inference/video.py | 2 + configs/vae/train/image.py | 1 + configs/vae/train/video.py | 1 + configs/vae/train/video_disc.py | 17 +++--- opensora/models/vae/vae.py | 69 +++++++++++++++++++----- scripts/inference-vae.py | 5 +- scripts/misc/generate.sh | 2 +- 9 files changed, 76 insertions(+), 23 deletions(-) diff --git a/configs/opensora-v1-2/train/adapt-vae.py b/configs/opensora-v1-2/train/adapt-vae.py index eb99249..dbf470d 100644 --- a/configs/opensora-v1-2/train/adapt-vae.py +++ b/configs/opensora-v1-2/train/adapt-vae.py @@ -48,7 +48,6 @@ model = dict( vae = dict( type="VideoAutoencoderPipeline", from_pretrained="pretrained_models/vae-v1", - training=False, vae_2d=dict( type="VideoAutoencoderKL", from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", diff --git a/configs/vae/inference/image.py b/configs/vae/inference/image.py index d83348a..476e565 100644 --- a/configs/vae/inference/image.py +++ b/configs/vae/inference/image.py @@ -18,6 +18,7 @@ max_test_samples = None model = dict( type="VideoAutoencoderPipeline", freeze_vae_2d=True, + cal_loss=True, vae_2d=dict( type="VideoAutoencoderKL", from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", diff --git a/configs/vae/inference/video.py b/configs/vae/inference/video.py index 70fc49d..37e3ffe 100644 --- a/configs/vae/inference/video.py +++ b/configs/vae/inference/video.py @@ -18,6 +18,8 @@ max_test_samples = None model = dict( type="VideoAutoencoderPipeline", from_pretrained=None, + cal_loss=True, + micro_frame_size=16, vae_2d=dict( type="VideoAutoencoderKL", from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", diff --git a/configs/vae/train/image.py b/configs/vae/train/image.py index d20602c..46621dc 100644 --- a/configs/vae/train/image.py +++ b/configs/vae/train/image.py @@ -21,6 +21,7 @@ model = dict( type="VideoAutoencoderPipeline", freeze_vae_2d=True, from_pretrained=None, + cal_loss=True, vae_2d=dict( type="VideoAutoencoderKL", from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", diff --git a/configs/vae/train/video.py b/configs/vae/train/video.py index 02bbb24..dea41e5 100644 --- a/configs/vae/train/video.py +++ b/configs/vae/train/video.py @@ -21,6 +21,7 @@ model = dict( type="VideoAutoencoderPipeline", freeze_vae_2d=False, from_pretrained=None, + cal_loss=True, vae_2d=dict( type="VideoAutoencoderKL", from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", diff --git a/configs/vae/train/video_disc.py b/configs/vae/train/video_disc.py index 9967c5f..7af989d 100644 --- a/configs/vae/train/video_disc.py +++ b/configs/vae/train/video_disc.py @@ -21,6 +21,7 @@ model = dict( type="VideoAutoencoderPipeline", freeze_vae_2d=False, from_pretrained=None, + cal_loss=True, vae_2d=dict( type="VideoAutoencoderKL", from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", @@ -37,18 +38,18 @@ discriminator = dict( type="NLayerDiscriminator", from_pretrained="/home/shenchenhui/opensoraplan-v1.0.0-discriminator.pt", input_nc=3, - n_layers=3, + n_layers=3, use_actnorm=False, ) # discriminator hyper-parames TODO -discriminator_factor=1 -discriminator_start=-1 -generator_factor=0.5 -generator_loss_type="hinge" -discriminator_loss_type="hinge" -lecam_loss_weight=None -gradient_penalty_loss_weight=None +discriminator_factor = 1 +discriminator_start = -1 +generator_factor = 0.5 +generator_loss_type = "hinge" +discriminator_loss_type = "hinge" +lecam_loss_weight = None +gradient_penalty_loss_weight = None # loss weights perceptual_loss_weight = 0.1 # use vgg is not None and more than 0 diff --git a/opensora/models/vae/vae.py b/opensora/models/vae/vae.py index 2dcc25f..2fd4d57 100644 --- a/opensora/models/vae/vae.py +++ b/opensora/models/vae/vae.py @@ -118,11 +118,21 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module): @MODELS.register_module() class VideoAutoencoderPipeline(nn.Module): - def __init__(self, vae_2d=None, vae_temporal=None, freeze_vae_2d=False, from_pretrained=None, training=True): + def __init__( + self, + vae_2d=None, + vae_temporal=None, + from_pretrained=None, + freeze_vae_2d=False, + cal_loss=False, + micro_frame_size=None, + ): super().__init__() self.spatial_vae = build_module(vae_2d, MODELS) self.temporal_vae = build_module(vae_temporal, MODELS) - self.training = training + self.cal_loss = cal_loss + self.micro_frame_size = micro_frame_size + self.micro_z_frame_size = self.temporal_vae.get_latent_size([micro_frame_size, None, None])[0] if from_pretrained is not None: load_checkpoint(self, from_pretrained) @@ -135,28 +145,63 @@ class VideoAutoencoderPipeline(nn.Module): def encode(self, x): x_z = self.spatial_vae.encode(x) - posterior = self.temporal_vae.encode(x_z) - z = posterior.sample() - if self.training: + + if self.micro_frame_size is None: + posterior = self.temporal_vae.encode(x_z) + z = posterior.sample() + else: + z_list = [] + for i in range(0, x_z.shape[2], self.micro_frame_size): + x_z_bs = x_z[:, :, i : i + self.micro_frame_size] + posterior = self.temporal_vae.encode(x_z_bs) + z_list.append(posterior.sample()) + z = torch.cat(z_list, dim=2) + + if self.cal_loss: return z, posterior, x_z - return z / self.scale + else: + return z / self.scale def decode(self, z, num_frames=None): - if not self.training: + if not self.cal_loss: z = z * self.scale - x_z = self.temporal_vae.decode(z, num_frames=num_frames) - x = self.spatial_vae.decode(x_z) - if self.training: + + if self.micro_frame_size is None: + x_z = self.temporal_vae.decode(z, num_frames=num_frames) + x = self.spatial_vae.decode(x_z) + else: + x_z_list = [] + for i in range(0, z.size(2), self.micro_z_frame_size): + z_bs = z[:, :, i : i + self.micro_z_frame_size] + x_z_bs = self.temporal_vae.decode(z_bs, num_frames=min(self.micro_frame_size, num_frames)) + x_z_list.append(x_z_bs) + num_frames -= self.micro_frame_size + x_z = torch.cat(x_z_list, dim=2) + x = self.spatial_vae.decode(x_z) + + if self.cal_loss: return x, x_z - return x + else: + return x def forward(self, x): + assert self.cal_loss, "This method is only available when cal_loss is True" z, posterior, x_z = self.encode(x) x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2]) return x_rec, x_z_rec, z, posterior, x_z def get_latent_size(self, input_size): - return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size)) + if self.micro_frame_size is None: + return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size)) + else: + sub_input_size = [self.micro_frame_size, input_size[1], input_size[2]] + sub_latent_size = self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(sub_input_size)) + sub_latent_size[0] = sub_latent_size[0] * (input_size[0] // self.micro_frame_size) + remain_temporal_size = [input_size[0] % self.micro_frame_size, None, None] + if remain_temporal_size[0] > 0: + remain_size = self.temporal_vae.get_latent_size(remain_temporal_size) + sub_latent_size[0] += remain_size[0] + return sub_latent_size def get_temporal_last_layer(self): return self.temporal_vae.decoder.conv_out.conv.weight diff --git a/scripts/inference-vae.py b/scripts/inference-vae.py index 06c1867..78ddc47 100644 --- a/scripts/inference-vae.py +++ b/scripts/inference-vae.py @@ -96,9 +96,12 @@ def main(): for step in pbar: batch = next(dataloader_iter) x = batch["video"].to(device, dtype) # [B, C, T, H, W] + input_size = x.shape[2:] + latent_size = model.get_latent_size(input_size) # ===== VAE ===== - z, posterior, x_z = model.encode(x, training=True) + z, posterior, x_z = model.encode(x) + assert list(z.shape[2:]) == latent_size, f"z shape: {z.shape}, latent_size: {latent_size}" x_rec, x_z_rec = model.decode(z, num_frames=x.size(2)) x_ref = model.spatial_vae.decode(x_z) diff --git a/scripts/misc/generate.sh b/scripts/misc/generate.sh index 6056d20..33888e5 100644 --- a/scripts/misc/generate.sh +++ b/scripts/misc/generate.sh @@ -8,7 +8,7 @@ OUTPUT_PATH=/home/data/sora_data/pixart-sigma-generated/raw CMD="python scripts/inference.py configs/pixart/inference/1x2048MS.py" LOG_BASE=logs/sample/generate NUM_PER_GPU=10000 -N_LAUNCH=6 +N_LAUNCH=2 NUM_START=$(($N_LAUNCH * $NUM_PER_GPU * 8)) CUDA_VISIBLE_DEVICES=0 $CMD --prompt-path $TEXT_PATH --save-dir $OUTPUT_PATH --start-index $(($NUM_START + $NUM_PER_GPU * 0)) --end-index $(($NUM_START + $NUM_PER_GPU * 1)) --image-size 2048 2048 --verbose 1 --batch-size 2 >${LOG_BASE}_${N_LAUNCH}_1.log 2>&1 &