mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-20 09:22:22 +02:00
finish discriminator arch
use einops.rearrange in vae model_utils add get_latent_size in vae_3d inference vae_3d WIP
This commit is contained in:
parent
ebb3dc4d59
commit
b6c39873ee
|
|
@ -18,10 +18,6 @@ sp_size = 1
|
|||
kl_weight = 0.000001
|
||||
perceptual_weight = 1.0
|
||||
|
||||
|
||||
# training
|
||||
wandb = True
|
||||
|
||||
# Define model
|
||||
|
||||
model = dict(
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from typing import Any
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import ml_collections
|
||||
|
||||
# TODO: torch.nn.init.xavier_uniform_
|
||||
# default_kernel_init = nn.initializers.xavier_uniform()
|
||||
|
|
@ -19,7 +18,6 @@ class ResBlock(nn.Module):
|
|||
in_channels,
|
||||
filters,
|
||||
activation_fn,
|
||||
input_dim, # x.shape[-1], TODO
|
||||
num_groups=32,
|
||||
device="cpu",
|
||||
dtype=torch.bfloat16,
|
||||
|
|
@ -29,13 +27,12 @@ class ResBlock(nn.Module):
|
|||
self.filters = filters
|
||||
self.activation_fn = activation_fn
|
||||
|
||||
# TODO: figure out the input_dim
|
||||
|
||||
self.conv1 = nn.Conv3d(in_channels, self.filters, (3,3,3)) # need to init to xavier_uniform
|
||||
self.norm1 = nn.GroupNorm(num_groups, self.filters, device=device, dtype=dtype)
|
||||
self.avg_pool_with_t = nn.AvgPool3d((2,2,2))
|
||||
# SCH: NOTE: although paper says conv (X->Y, Y->Y), original code implementation is (X->X, X->Y), we follow code
|
||||
self.conv1 = nn.Conv3d(in_channels, in_channels, (3,3,3)) # TODO: need to init to xavier_uniform
|
||||
self.norm1 = nn.GroupNorm(num_groups, in_channels, device=device, dtype=dtype)
|
||||
self.avg_pool_with_t = nn.AvgPool3d((2,2,2),count_include_pad=False)
|
||||
self.conv2 = nn.Conv3d(in_channels, self.filters,(1,1,1), use_bias=False) # need to init to xavier_uniform
|
||||
self.conv3 = nn.Conv3d(input_dim, self.filters, (3,3,3)) # need to init to xavier_uniform
|
||||
self.conv3 = nn.Conv3d(in_channels, self.filters, (3,3,3)) # need to init to xavier_uniform
|
||||
self.norm2 = nn.GroupNorm(num_groups, self.filters, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -59,7 +56,7 @@ class StyleGANDiscriminator(nn.Module):
|
|||
self,
|
||||
config,
|
||||
image_size,
|
||||
input_dim, # x.shape[-1]
|
||||
num_frames,
|
||||
discriminator_in_channels = 3,
|
||||
discriminator_filters = 64,
|
||||
discriminator_channel_multipliers = (2,4,4,4,4),
|
||||
|
|
@ -69,16 +66,11 @@ class StyleGANDiscriminator(nn.Module):
|
|||
):
|
||||
self.config = config
|
||||
self.dtype = dtype
|
||||
|
||||
self.input_size = image_size
|
||||
self.filters = discriminator_filters
|
||||
|
||||
self.activation_fn = nn.LeakyReLu(negative_slope=0.2)
|
||||
|
||||
self.channel_multipliers = discriminator_channel_multipliers
|
||||
|
||||
|
||||
|
||||
self.conv1 = nn.Conv3d(discriminator_in_channels, self.filters, (3, 3, 3)) # need to init to xavier_uniform
|
||||
|
||||
prev_filters = self.filters # record in_channels
|
||||
|
|
@ -86,7 +78,7 @@ class StyleGANDiscriminator(nn.Module):
|
|||
self.res_block_list = []
|
||||
for i in range(self.num_blocks):
|
||||
filters = self.filters * self.channel_multipliers[i]
|
||||
self.res_block_list.append(ResBlock(prev_filters, filters, self.activation_fn)) # TODO
|
||||
self.res_block_list.append(ResBlock(prev_filters, filters, self.activation_fn))
|
||||
prev_filters = filters # update in_channels
|
||||
|
||||
self.conv2 = nn.Conv3d(prev_filters, prev_filters, (3,3,3)) # need to init to xavier_uniform
|
||||
|
|
@ -94,6 +86,10 @@ class StyleGANDiscriminator(nn.Module):
|
|||
self.norm1 = nn.GroupNorm(num_groups, prev_filters, dtype=dtype, device=device)
|
||||
|
||||
# TODO: what is the in_features
|
||||
scale_factor = 2 ** len(self.num_blocks)
|
||||
time_scaled = num_frames / scale_factor
|
||||
image_scaled = image_size / scale_factor
|
||||
in_features = prev_filters * time_scaled * image_scaled * image_scaled # (C*T*W*H)
|
||||
self.linear1 = nn.Linear(in_features, prev_filters, device=device, dtype=dtype) # need to init to xavier_uniform
|
||||
self.linear2 = nn.Linear(prev_filters, 1, device=device, dtype=dtype) # need to init to xavier_uniform
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import numpy as np
|
|||
import torch
|
||||
from taming.modules.losses.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
|
||||
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
|
||||
|
||||
|
|
@ -105,16 +105,22 @@ class VEA3DLoss(nn.Module):
|
|||
if self.perceptual_weight > 0: # NOTE: need in_channels == 3 in order to use!
|
||||
assert inputs.size(1) == 3, f"using vgg16 that requires 3 input channels but got {inputs.size(1)}"
|
||||
# SCH: transform to [(B,T), C, H, W] shape for percetual loss over each frame
|
||||
permutated_input = torch.permute(inputs, (0, 2, 1, 3, 4)) # [B, C, T, H, W] --> [B, T, C, H, W]
|
||||
permutated_rec = torch.permute(reconstructions, (0, 2, 1, 3, 4))
|
||||
data_shape = permutated_input.size()
|
||||
p_loss = self.perceptual_loss(
|
||||
permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(),
|
||||
permutated_rec.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous()
|
||||
)
|
||||
B = inputs.shape[0]
|
||||
inputs = rearrange(inputs,"B C T H W -> (B T) C H W")
|
||||
reconstructions = rearrange(reconstructions, "B C T H W -> (B T) C H W")
|
||||
# permutated_input = torch.permute(inputs, (0, 2, 1, 3, 4)) # [B, C, T, H, W] --> [B, T, C, H, W]
|
||||
# permutated_rec = torch.permute(reconstructions, (0, 2, 1, 3, 4))
|
||||
# data_shape = permutated_input.size()
|
||||
# p_loss = self.perceptual_loss(
|
||||
# permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(),
|
||||
# permutated_rec.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous()
|
||||
# )
|
||||
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
|
||||
# SCH: shape back p_loss
|
||||
permuted_p_loss = torch.permute(p_loss.reshape(data_shape[0], data_shape[1], 1, 1, 1), (0,2,1,3,4))
|
||||
rec_loss = rec_loss + self.perceptual_weight * permuted_p_loss
|
||||
# permuted_p_loss = torch.permute(p_loss.reshape(data_shape[0], data_shape[1], 1, 1, 1), (0,2,1,3,4))
|
||||
# rec_loss = rec_loss + self.perceptual_weight * permuted_p_loss
|
||||
p_loss = rearrange(p_loss, "(B T) C H W -> B C T H W", B=B)
|
||||
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
||||
|
||||
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
||||
weighted_nll_loss = nll_loss
|
||||
|
|
|
|||
|
|
@ -415,6 +415,16 @@ class VAE_3D(nn.Module):
|
|||
self.quant_conv = nn.Conv3d(latent_embed_dim, 2*kl_embed_dim, 1)
|
||||
self.post_quant_conv = nn.Conv3d(kl_embed_dim, latent_embed_dim, 1)
|
||||
|
||||
image_down = 2 ** len(temporal_downsample)
|
||||
t_down = 2 ** len([x for x in temporal_downsample if x == True])
|
||||
self.patch_size = (t_down, image_down, image_down)
|
||||
|
||||
def get_latent_size(self, input_size):
|
||||
for i in range(len(input_size)):
|
||||
assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
|
||||
input_size = [input_size[i] // self.patch_size[i] for i in range(3)]
|
||||
return input_size
|
||||
|
||||
def encode(
|
||||
self,
|
||||
x,
|
||||
|
|
|
|||
105
scripts/inference-vae.py
Normal file
105
scripts/inference-vae.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
import os
|
||||
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from mmengine.runner import set_random_seed
|
||||
|
||||
from opensora.acceleration.parallel_states import set_sequence_parallel_group
|
||||
from opensora.datasets import save_sample
|
||||
from opensora.registry import MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.config_utils import parse_configs
|
||||
from opensora.utils.misc import to_torch_dtype
|
||||
|
||||
|
||||
def main():
|
||||
# ======================================================
|
||||
# 1. cfg and init distributed env
|
||||
# ======================================================
|
||||
cfg = parse_configs(training=False)
|
||||
print(cfg)
|
||||
|
||||
# init distributed
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
if coordinator.world_size > 1:
|
||||
set_sequence_parallel_group(dist.group.WORLD)
|
||||
enable_sequence_parallelism = True
|
||||
else:
|
||||
enable_sequence_parallelism = False
|
||||
|
||||
# ======================================================
|
||||
# 2. runtime variables
|
||||
# ======================================================
|
||||
torch.set_grad_enabled(False)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
dtype = to_torch_dtype(cfg.dtype)
|
||||
set_random_seed(seed=cfg.seed)
|
||||
prompts = cfg.prompt
|
||||
|
||||
# ======================================================
|
||||
# 3. build model & load weights
|
||||
# ======================================================
|
||||
# 3.1. build model
|
||||
input_size = (cfg.num_frames, *cfg.image_size)
|
||||
vae = build_module(cfg.vae, MODELS)
|
||||
latent_size = vae.get_latent_size(input_size)
|
||||
|
||||
# 3.2. move to device & eval
|
||||
vae = vae.to(device, dtype).eval()
|
||||
|
||||
# # 3.3. build scheduler
|
||||
# scheduler = build_module(cfg.scheduler, SCHEDULERS)
|
||||
|
||||
# 3.4. support for multi-resolution
|
||||
model_args = dict()
|
||||
if cfg.multi_resolution:
|
||||
image_size = cfg.image_size
|
||||
hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
|
||||
ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
|
||||
model_args["data_info"] = dict(ar=ar, hw=hw)
|
||||
|
||||
# ======================================================
|
||||
# 4. inference
|
||||
# ======================================================
|
||||
sample_idx = 0
|
||||
save_dir = cfg.save_dir
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
|
||||
# 4.1. batch generation
|
||||
|
||||
# TODO: read input, sample, then decode ??? =========
|
||||
for i in range(0, len(prompts), cfg.batch_size):
|
||||
# 4.2 sample in hidden space
|
||||
|
||||
z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)
|
||||
|
||||
# 4.3. diffusion sampling
|
||||
samples = scheduler.sample(
|
||||
model,
|
||||
text_encoder,
|
||||
z=z,
|
||||
prompts=batch_prompts,
|
||||
device=device,
|
||||
additional_args=model_args,
|
||||
)
|
||||
|
||||
|
||||
samples = vae.decode(samples.to(dtype))
|
||||
|
||||
if coordinator.is_master():
|
||||
for idx, sample in enumerate(samples):
|
||||
print(f"Prompt: {batch_prompts[idx]}")
|
||||
save_path = os.path.join(save_dir, f"sample_{sample_idx}")
|
||||
save_sample(sample, fps=cfg.fps, save_path=save_path)
|
||||
sample_idx += 1
|
||||
|
||||
# TODO: read input, sample, then decode ??? =========
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in a new issue