diff --git a/configs/vae_3d/train/16x256x256.py b/configs/vae_3d/train/16x256x256.py
index 54120c3..41f6af5 100644
--- a/configs/vae_3d/train/16x256x256.py
+++ b/configs/vae_3d/train/16x256x256.py
@@ -18,10 +18,6 @@ sp_size = 1
 
 kl_weight = 0.000001
 perceptual_weight = 1.0
-
-# training
-wandb = True
-
 
 # Define model
 model = dict(
diff --git a/opensora/models/vae/discriminator_3d.py b/opensora/models/vae/discriminator_3d.py
index 53fc388..4c78a81 100644
--- a/opensora/models/vae/discriminator_3d.py
+++ b/opensora/models/vae/discriminator_3d.py
@@ -6,7 +6,6 @@ from typing import Any
 
 import torch
 import torch.nn as nn
-import ml_collections
 
 # TODO: torch.nn.init.xavier_uniform_
 # default_kernel_init = nn.initializers.xavier_uniform()
@@ -19,7 +18,6 @@ class ResBlock(nn.Module):
         in_channels,
         filters,
         activation_fn,
-        input_dim, # x.shape[-1], TODO
         num_groups=32,
         device="cpu",
         dtype=torch.bfloat16,
@@ -29,13 +27,12 @@ class ResBlock(nn.Module):
         self.filters = filters
         self.activation_fn = activation_fn
 
-        # TODO: figure out the input_dim
-
-        self.conv1 = nn.Conv3d(in_channels, self.filters, (3,3,3)) # need to init to xavier_uniform
-        self.norm1 = nn.GroupNorm(num_groups, self.filters, device=device, dtype=dtype)
-        self.avg_pool_with_t = nn.AvgPool3d((2,2,2))
+        # SCH: NOTE: although paper says conv (X->Y, Y->Y), original code implementation is (X->X, X->Y), we follow code
+        self.conv1 = nn.Conv3d(in_channels, in_channels, (3,3,3)) # TODO: need to init to xavier_uniform
+        self.norm1 = nn.GroupNorm(num_groups, in_channels, device=device, dtype=dtype)
+        self.avg_pool_with_t = nn.AvgPool3d((2,2,2),count_include_pad=False)
         self.conv2 = nn.Conv3d(in_channels, self.filters,(1,1,1), use_bias=False) # need to init to xavier_uniform
-        self.conv3 = nn.Conv3d(input_dim, self.filters, (3,3,3)) # need to init to xavier_uniform
+        self.conv3 = nn.Conv3d(in_channels, self.filters, (3,3,3)) # need to init to xavier_uniform
         self.norm2 = nn.GroupNorm(num_groups, self.filters, device=device, dtype=dtype)
 
     def forward(self, x):
@@ -59,7 +56,7 @@ class StyleGANDiscriminator(nn.Module):
         self,
         config,
         image_size,
-        input_dim, # x.shape[-1]
+        num_frames,
         discriminator_in_channels = 3,
         discriminator_filters = 64,
         discriminator_channel_multipliers = (2,4,4,4,4),
@@ -69,16 +66,11 @@ class StyleGANDiscriminator(nn.Module):
     ):
         self.config = config
         self.dtype = dtype
-
         self.input_size = image_size
         self.filters = discriminator_filters
-
         self.activation_fn = nn.LeakyReLu(negative_slope=0.2)
-
         self.channel_multipliers = discriminator_channel_multipliers
 
-
-
         self.conv1 = nn.Conv3d(discriminator_in_channels, self.filters, (3, 3, 3)) # need to init to xavier_uniform
 
         prev_filters = self.filters # record in_channels
@@ -86,7 +78,7 @@ class StyleGANDiscriminator(nn.Module):
         self.res_block_list = []
         for i in range(self.num_blocks):
             filters = self.filters * self.channel_multipliers[i]
-            self.res_block_list.append(ResBlock(prev_filters, filters, self.activation_fn)) # TODO
+            self.res_block_list.append(ResBlock(prev_filters, filters, self.activation_fn))
             prev_filters = filters # update in_channels
 
         self.conv2 = nn.Conv3d(prev_filters, prev_filters, (3,3,3)) # need to init to xavier_uniform
@@ -94,6 +86,10 @@ class StyleGANDiscriminator(nn.Module):
         self.norm1 = nn.GroupNorm(num_groups, prev_filters, dtype=dtype, device=device)
 
         # TODO: what is the in_features
+        scale_factor = 2 ** self.num_blocks # num_blocks is an int (see range(self.num_blocks)); one factor of 2 per ResBlock
+        time_scaled = num_frames // scale_factor # floor division: nn.Linear requires an integer in_features
+        image_scaled = image_size // scale_factor # assumes image_size is a single int (square input) -- TODO confirm
+        in_features = prev_filters * time_scaled * image_scaled * image_scaled # (C*T*W*H)
         self.linear1 = nn.Linear(in_features, prev_filters, device=device, dtype=dtype) # need to init to xavier_uniform
         self.linear2 = nn.Linear(prev_filters, 1, device=device, dtype=dtype) # need to init to xavier_uniform
 
diff --git a/opensora/models/vae/model_utils.py b/opensora/models/vae/model_utils.py
index 8ad0bd9..979caa1 100644
--- a/opensora/models/vae/model_utils.py
+++ b/opensora/models/vae/model_utils.py
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 from taming.modules.losses.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
 from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
-
+from einops import rearrange
 
 """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
 
@@ -105,16 +105,22 @@ class VEA3DLoss(nn.Module):
         if self.perceptual_weight > 0: # NOTE: need in_channels == 3 in order to use!
             assert inputs.size(1) == 3, f"using vgg16 that requires 3 input channels but got {inputs.size(1)}"
             # SCH: transform to [(B,T), C, H, W] shape for percetual loss over each frame
-            permutated_input = torch.permute(inputs, (0, 2, 1, 3, 4)) # [B, C, T, H, W] --> [B, T, C, H, W]
-            permutated_rec = torch.permute(reconstructions, (0, 2, 1, 3, 4))
-            data_shape = permutated_input.size()
-            p_loss = self.perceptual_loss(
-                permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(),
-                permutated_rec.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous()
-            )
+            B = inputs.shape[0]
+            inputs = rearrange(inputs,"B C T H W -> (B T) C H W")
+            reconstructions = rearrange(reconstructions, "B C T H W -> (B T) C H W")
+            # permutated_input = torch.permute(inputs, (0, 2, 1, 3, 4)) # [B, C, T, H, W] --> [B, T, C, H, W]
+            # permutated_rec = torch.permute(reconstructions, (0, 2, 1, 3, 4))
+            # data_shape = permutated_input.size()
+            # p_loss = self.perceptual_loss(
+            #     permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(),
+            #     permutated_rec.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous()
+            # )
+            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
             # SCH: shape back p_loss
-            permuted_p_loss = torch.permute(p_loss.reshape(data_shape[0], data_shape[1], 1, 1, 1), (0,2,1,3,4))
-            rec_loss = rec_loss + self.perceptual_weight * permuted_p_loss
+            # permuted_p_loss = torch.permute(p_loss.reshape(data_shape[0], data_shape[1], 1, 1, 1), (0,2,1,3,4))
+            # rec_loss = rec_loss + self.perceptual_weight * permuted_p_loss
+            p_loss = rearrange(p_loss, "(B T) C H W -> B C T H W", B=B)
+            rec_loss = rec_loss + self.perceptual_weight * p_loss
 
         nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
         weighted_nll_loss = nll_loss
diff --git a/opensora/models/vae/vae_3d.py b/opensora/models/vae/vae_3d.py
index b9891a3..2ad556a 100644
--- a/opensora/models/vae/vae_3d.py
+++ b/opensora/models/vae/vae_3d.py
@@ -415,6 +415,16 @@ class VAE_3D(nn.Module):
         self.quant_conv = nn.Conv3d(latent_embed_dim, 2*kl_embed_dim, 1)
         self.post_quant_conv = nn.Conv3d(kl_embed_dim, latent_embed_dim, 1)
 
+        image_down = 2 ** len(temporal_downsample)
+        t_down = 2 ** len([x for x in temporal_downsample if x == True])
+        self.patch_size = (t_down, image_down, image_down)
+
+    def get_latent_size(self, input_size):
+        for i in range(len(input_size)):
+            assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
+        input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
+        return input_size
+
     def encode(
         self,
         x,
diff --git a/scripts/inference-vae.py b/scripts/inference-vae.py
new file mode 100644
index 0000000..fe111b8
--- /dev/null
+++ b/scripts/inference-vae.py
@@ -0,0 +1,105 @@
+import os
+
+import colossalai
+import torch
+import torch.distributed as dist
+from colossalai.cluster import DistCoordinator
+from mmengine.runner import set_random_seed
+
+from opensora.acceleration.parallel_states import set_sequence_parallel_group
+from opensora.datasets import save_sample
+from opensora.registry import MODELS, SCHEDULERS, build_module
+from opensora.utils.config_utils import parse_configs
+from opensora.utils.misc import to_torch_dtype
+
+
+def main():
+    # ======================================================
+    # 1. cfg and init distributed env
+    # ======================================================
+    cfg = parse_configs(training=False)
+    print(cfg)
+
+    # init distributed
+    colossalai.launch_from_torch({})
+    coordinator = DistCoordinator()
+
+    if coordinator.world_size > 1:
+        set_sequence_parallel_group(dist.group.WORLD)
+        enable_sequence_parallelism = True
+    else:
+        enable_sequence_parallelism = False
+
+    # ======================================================
+    # 2. runtime variables
+    # ======================================================
+    torch.set_grad_enabled(False)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype = to_torch_dtype(cfg.dtype)
+    set_random_seed(seed=cfg.seed)
+    prompts = cfg.prompt
+
+    # ======================================================
+    # 3. build model & load weights
+    # ======================================================
+    # 3.1. build model
+    input_size = (cfg.num_frames, *cfg.image_size)
+    vae = build_module(cfg.vae, MODELS)
+    latent_size = vae.get_latent_size(input_size)
+
+    # 3.2. move to device & eval
+    vae = vae.to(device, dtype).eval()
+
+    # # 3.3. build scheduler
+    # scheduler = build_module(cfg.scheduler, SCHEDULERS)
+
+    # 3.4. support for multi-resolution
+    model_args = dict()
+    if cfg.multi_resolution:
+        image_size = cfg.image_size
+        hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
+        ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
+        model_args["data_info"] = dict(ar=ar, hw=hw)
+
+    # ======================================================
+    # 4. inference
+    # ======================================================
+    sample_idx = 0
+    save_dir = cfg.save_dir
+    os.makedirs(save_dir, exist_ok=True)
+
+
+    # 4.1. batch generation
+
+    # TODO: read input, sample, then decode ??? =========
+    for i in range(0, len(prompts), cfg.batch_size):
+        # 4.2 sample in hidden space
+
+        z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)
+
+        # 4.3. diffusion sampling
+        samples = scheduler.sample(
+            model,
+            text_encoder,
+            z=z,
+            prompts=batch_prompts,
+            device=device,
+            additional_args=model_args,
+        )
+
+
+        samples = vae.decode(samples.to(dtype))
+
+        if coordinator.is_master():
+            for idx, sample in enumerate(samples):
+                print(f"Prompt: {batch_prompts[idx]}")
+                save_path = os.path.join(save_dir, f"sample_{sample_idx}")
+                save_sample(sample, fps=cfg.fps, save_path=save_path)
+                sample_idx += 1
+
+    # TODO: read input, sample, then decode ??? =========
+
+if __name__ == "__main__":
+    main()