finish discriminator arch

use einops.rearrange in vae model_utils
add get_latent_size in vae_3d
inference vae_3d WIP
This commit is contained in:
Shen-Chenhui 2024-04-03 10:29:01 +08:00
parent ebb3dc4d59
commit b6c39873ee
5 changed files with 142 additions and 29 deletions

View file

@ -18,10 +18,6 @@ sp_size = 1
kl_weight = 0.000001
perceptual_weight = 1.0
# training
wandb = True
# Define model
model = dict(

View file

@ -6,7 +6,6 @@ from typing import Any
import torch
import torch.nn as nn
import ml_collections
# TODO: torch.nn.init.xavier_uniform_
# default_kernel_init = nn.initializers.xavier_uniform()
@ -19,7 +18,6 @@ class ResBlock(nn.Module):
in_channels,
filters,
activation_fn,
input_dim, # x.shape[-1], TODO
num_groups=32,
device="cpu",
dtype=torch.bfloat16,
@ -29,13 +27,12 @@ class ResBlock(nn.Module):
self.filters = filters
self.activation_fn = activation_fn
# TODO: figure out the input_dim
self.conv1 = nn.Conv3d(in_channels, self.filters, (3,3,3)) # need to init to xavier_uniform
self.norm1 = nn.GroupNorm(num_groups, self.filters, device=device, dtype=dtype)
self.avg_pool_with_t = nn.AvgPool3d((2,2,2))
# SCH: NOTE: although paper says conv (X->Y, Y->Y), original code implementation is (X->X, X->Y), we follow code
self.conv1 = nn.Conv3d(in_channels, in_channels, (3,3,3)) # TODO: need to init to xavier_uniform
self.norm1 = nn.GroupNorm(num_groups, in_channels, device=device, dtype=dtype)
self.avg_pool_with_t = nn.AvgPool3d((2,2,2),count_include_pad=False)
self.conv2 = nn.Conv3d(in_channels, self.filters,(1,1,1), use_bias=False) # need to init to xavier_uniform
self.conv3 = nn.Conv3d(input_dim, self.filters, (3,3,3)) # need to init to xavier_uniform
self.conv3 = nn.Conv3d(in_channels, self.filters, (3,3,3)) # need to init to xavier_uniform
self.norm2 = nn.GroupNorm(num_groups, self.filters, device=device, dtype=dtype)
def forward(self, x):
@ -59,7 +56,7 @@ class StyleGANDiscriminator(nn.Module):
self,
config,
image_size,
input_dim, # x.shape[-1]
num_frames,
discriminator_in_channels = 3,
discriminator_filters = 64,
discriminator_channel_multipliers = (2,4,4,4,4),
@ -69,16 +66,11 @@ class StyleGANDiscriminator(nn.Module):
):
self.config = config
self.dtype = dtype
self.input_size = image_size
self.filters = discriminator_filters
self.activation_fn = nn.LeakyReLu(negative_slope=0.2)
self.channel_multipliers = discriminator_channel_multipliers
self.conv1 = nn.Conv3d(discriminator_in_channels, self.filters, (3, 3, 3)) # need to init to xavier_uniform
prev_filters = self.filters # record in_channels
@ -86,7 +78,7 @@ class StyleGANDiscriminator(nn.Module):
self.res_block_list = []
for i in range(self.num_blocks):
filters = self.filters * self.channel_multipliers[i]
self.res_block_list.append(ResBlock(prev_filters, filters, self.activation_fn)) # TODO
self.res_block_list.append(ResBlock(prev_filters, filters, self.activation_fn))
prev_filters = filters # update in_channels
self.conv2 = nn.Conv3d(prev_filters, prev_filters, (3,3,3)) # need to init to xavier_uniform
@ -94,6 +86,10 @@ class StyleGANDiscriminator(nn.Module):
self.norm1 = nn.GroupNorm(num_groups, prev_filters, dtype=dtype, device=device)
# TODO: what is the in_features
scale_factor = 2 ** len(self.num_blocks)
time_scaled = num_frames / scale_factor
image_scaled = image_size / scale_factor
in_features = prev_filters * time_scaled * image_scaled * image_scaled # (C*T*W*H)
self.linear1 = nn.Linear(in_features, prev_filters, device=device, dtype=dtype) # need to init to xavier_uniform
self.linear2 = nn.Linear(prev_filters, 1, device=device, dtype=dtype) # need to init to xavier_uniform

View file

@ -7,7 +7,7 @@ import numpy as np
import torch
from taming.modules.losses.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from einops import rearrange
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
@ -105,16 +105,22 @@ class VEA3DLoss(nn.Module):
if self.perceptual_weight > 0: # NOTE: need in_channels == 3 in order to use!
assert inputs.size(1) == 3, f"using vgg16 that requires 3 input channels but got {inputs.size(1)}"
# SCH: transform to [(B,T), C, H, W] shape for perceptual loss over each frame
permutated_input = torch.permute(inputs, (0, 2, 1, 3, 4)) # [B, C, T, H, W] --> [B, T, C, H, W]
permutated_rec = torch.permute(reconstructions, (0, 2, 1, 3, 4))
data_shape = permutated_input.size()
p_loss = self.perceptual_loss(
permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(),
permutated_rec.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous()
)
B = inputs.shape[0]
inputs = rearrange(inputs,"B C T H W -> (B T) C H W")
reconstructions = rearrange(reconstructions, "B C T H W -> (B T) C H W")
# permutated_input = torch.permute(inputs, (0, 2, 1, 3, 4)) # [B, C, T, H, W] --> [B, T, C, H, W]
# permutated_rec = torch.permute(reconstructions, (0, 2, 1, 3, 4))
# data_shape = permutated_input.size()
# p_loss = self.perceptual_loss(
# permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(),
# permutated_rec.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous()
# )
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
# SCH: reshape p_loss back to [B, C, T, H, W]
permuted_p_loss = torch.permute(p_loss.reshape(data_shape[0], data_shape[1], 1, 1, 1), (0,2,1,3,4))
rec_loss = rec_loss + self.perceptual_weight * permuted_p_loss
# permuted_p_loss = torch.permute(p_loss.reshape(data_shape[0], data_shape[1], 1, 1, 1), (0,2,1,3,4))
# rec_loss = rec_loss + self.perceptual_weight * permuted_p_loss
p_loss = rearrange(p_loss, "(B T) C H W -> B C T H W", B=B)
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss

View file

@ -415,6 +415,16 @@ class VAE_3D(nn.Module):
self.quant_conv = nn.Conv3d(latent_embed_dim, 2*kl_embed_dim, 1)
self.post_quant_conv = nn.Conv3d(kl_embed_dim, latent_embed_dim, 1)
image_down = 2 ** len(temporal_downsample)
t_down = 2 ** len([x for x in temporal_downsample if x == True])
self.patch_size = (t_down, image_down, image_down)
def get_latent_size(self, input_size):
    """Return the latent-space size that corresponds to ``input_size``.

    Every entry of ``input_size`` is floor-divided by the matching entry of
    ``self.patch_size``, which ``__init__`` sets to
    ``(t_down, image_down, image_down)``.

    Args:
        input_size: Sequence with one integer per entry of
            ``self.patch_size`` (presumably ordered (T, H, W) to match the
            patch-size layout -- confirm against callers).

    Returns:
        A list with one integer per dimension: the input size divided by the
        patch size.

    Raises:
        AssertionError: If ``input_size`` has a different number of entries
            than ``self.patch_size``, or if any entry is not divisible by
            its patch size.
    """
    # Validate the rank explicitly: the original checked ``len(input_size)``
    # entries but always produced three, so a wrong-length input surfaced as
    # an unrelated IndexError instead of a clear message.
    assert len(input_size) == len(self.patch_size), (
        f"Expected {len(self.patch_size)} input dimensions, got {len(input_size)}"
    )
    for size, patch in zip(input_size, self.patch_size):
        assert size % patch == 0, "Input size must be divisible by patch size"
    return [size // patch for size, patch in zip(input_size, self.patch_size)]
def encode(
self,
x,

105
scripts/inference-vae.py Normal file
View file

@ -0,0 +1,105 @@
import os
import colossalai
import torch
import torch.distributed as dist
from colossalai.cluster import DistCoordinator
from mmengine.runner import set_random_seed
from opensora.acceleration.parallel_states import set_sequence_parallel_group
from opensora.datasets import save_sample
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype
def main():
    """Run VAE inference from a parsed config (work in progress).

    Steps, as implemented here:
      1. Parse the config and initialise the distributed environment.
      2. Disable gradients, enable TF32, pick device/dtype and seed the RNG.
      3. Build the VAE from ``cfg.vae``, derive the latent size from
         ``(cfg.num_frames, *cfg.image_size)`` and move the VAE to the device.
      4. For each batch of prompts, draw a random latent, run the sampling
         step, decode with the VAE and (on the master rank) save each sample
         under ``cfg.save_dir``.

    NOTE(review): the sampling step still references ``scheduler``, ``model``
    and ``text_encoder``, none of which are created in this function, so the
    loop cannot run end-to-end until the WIP section is finished.
    """
    # ======================================================
    # 1. cfg and init distributed env
    # ======================================================
    cfg = parse_configs(training=False)
    print(cfg)

    # init distributed
    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()

    if coordinator.world_size > 1:
        set_sequence_parallel_group(dist.group.WORLD)
        enable_sequence_parallelism = True
    else:
        # NOTE(review): this flag is not consumed anywhere in this function yet.
        enable_sequence_parallelism = False

    # ======================================================
    # 2. runtime variables
    # ======================================================
    torch.set_grad_enabled(False)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = to_torch_dtype(cfg.dtype)
    set_random_seed(seed=cfg.seed)
    prompts = cfg.prompt

    # ======================================================
    # 3. build model & load weights
    # ======================================================
    # 3.1. build model
    input_size = (cfg.num_frames, *cfg.image_size)
    vae = build_module(cfg.vae, MODELS)
    latent_size = vae.get_latent_size(input_size)

    # 3.2. move to device & eval
    vae = vae.to(device, dtype).eval()

    # # 3.3. build scheduler
    # scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # 3.4. support for multi-resolution
    model_args = dict()
    if cfg.multi_resolution:
        image_size = cfg.image_size
        hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
        ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
        model_args["data_info"] = dict(ar=ar, hw=hw)

    # ======================================================
    # 4. inference
    # ======================================================
    sample_idx = 0
    save_dir = cfg.save_dir
    os.makedirs(save_dir, exist_ok=True)

    # 4.1. batch generation
    # TODO: read input, sample, then decode ??? =========
    for i in range(0, len(prompts), cfg.batch_size):
        # Prompts handled in this iteration. Previously `batch_prompts` was
        # read below without ever being assigned, raising NameError.
        batch_prompts = prompts[i : i + cfg.batch_size]

        # 4.2 sample in hidden space
        z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)

        # 4.3. diffusion sampling
        # FIXME(review): `scheduler`, `model` and `text_encoder` are not
        # defined in this function (the scheduler build in 3.3 is commented
        # out); this call fails until the WIP section is completed.
        samples = scheduler.sample(
            model,
            text_encoder,
            z=z,
            prompts=batch_prompts,
            device=device,
            additional_args=model_args,
        )
        samples = vae.decode(samples.to(dtype))

        # Only the master rank writes results to disk.
        if coordinator.is_master():
            for idx, sample in enumerate(samples):
                print(f"Prompt: {batch_prompts[idx]}")
                save_path = os.path.join(save_dir, f"sample_{sample_idx}")
                save_sample(sample, fps=cfg.fps, save_path=save_path)
                sample_idx += 1
    # TODO: read input, sample, then decode ??? =========
if __name__ == "__main__":
main()