mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-17 22:56:10 +02:00
add disable space in vae v2
This commit is contained in:
parent
e241ecbbe1
commit
8bf4a0fa77
|
|
@ -30,6 +30,7 @@ model = dict(
|
|||
kl_embed_dim = 64,
|
||||
activation_fn = 'swish',
|
||||
separate_first_frame_encoding = False,
|
||||
disable_space = True,
|
||||
custom_conv_padding = None
|
||||
)
|
||||
|
||||
|
|
@ -40,7 +41,7 @@ discriminator = dict(
|
|||
num_frames = num_frames,
|
||||
in_channels = 3,
|
||||
filters = 128,
|
||||
channel_multipliers = (2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
|
||||
channel_multipliers = (2,4,4,4,4), # (2,4,4,4) for 64x64 resolution
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
83
configs/vae_magvit_v2/train/pipeline_16x128x128.py
Normal file
83
configs/vae_magvit_v2/train/pipeline_16x128x128.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
num_frames = 16
|
||||
frame_interval = 3
|
||||
image_size = (128, 128)
|
||||
|
||||
# Define dataset
|
||||
root = None
|
||||
data_path = "CSV_PATH"
|
||||
use_image_transform = False
|
||||
num_workers = 4
|
||||
video_contains_first_frame = False
|
||||
|
||||
# Define acceleration
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
sp_size = 1
|
||||
|
||||
|
||||
# Define model
|
||||
|
||||
model = dict(
|
||||
type="VAE_MAGVIT_V2",
|
||||
in_out_channels = 3,
|
||||
latent_embed_dim = 256,
|
||||
filters = 128,
|
||||
num_res_blocks = 4,
|
||||
channel_multipliers = (1, 2, 2, 4),
|
||||
temporal_downsample = (False, True, True),
|
||||
num_groups = 32, # for nn.GroupNorm
|
||||
kl_embed_dim = 64,
|
||||
activation_fn = 'swish',
|
||||
separate_first_frame_encoding = False,
|
||||
custom_conv_padding = None
|
||||
)
|
||||
|
||||
|
||||
discriminator = dict(
|
||||
type="DISCRIMINATOR_3D",
|
||||
image_size = image_size,
|
||||
num_frames = num_frames,
|
||||
in_channels = 3,
|
||||
filters = 128,
|
||||
channel_multipliers = (2,4,4,4,4), # (2,4,4,4) for 64x64 resolution
|
||||
)
|
||||
|
||||
|
||||
# loss weights
|
||||
logvar_init=0.0
|
||||
kl_loss_weight = 0.000001
|
||||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
discriminator_factor = 1.0 # for discriminator adversarial loss
|
||||
# discriminator_loss_weight = 0.5 # for generator adversarial loss
|
||||
generator_factor = 0.1 # SCH: generator adversarial loss, MAGVIT v2 uses 0.1
|
||||
lecam_loss_weight = None # NOTE: MAVGIT v2 use 0.001
|
||||
# discriminator_loss_type="non-saturating"
|
||||
# generator_loss_type="non-saturating"
|
||||
discriminator_loss_type="hinge"
|
||||
generator_loss_type="hinge"
|
||||
discriminator_start = 30000 # 50000 NOTE: change to correct val, debug use -1 for now
|
||||
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
|
||||
ema_decay = 0.999 # ema decay factor for generator
|
||||
|
||||
|
||||
# Others
|
||||
seed = 42
|
||||
outputs = "outputs"
|
||||
wandb = False
|
||||
|
||||
# Training
|
||||
''' NOTE:
|
||||
magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
|
||||
==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
|
||||
3-6 epochs for pexel, from pexel observation its correct
|
||||
'''
|
||||
|
||||
epochs = 10
|
||||
log_every = 1
|
||||
ckpt_every = 1000
|
||||
load = None
|
||||
|
||||
batch_size = 4
|
||||
lr = 1e-4
|
||||
grad_clip = 1.0
|
||||
|
|
@ -7,7 +7,7 @@ import numpy as np
|
|||
from numpy import typing as nptyping
|
||||
from opensora.models.vae import model_utils
|
||||
from opensora.registry import MODELS
|
||||
from opensora.utils.ckpt_utils import load_checkpoint
|
||||
from opensora.utils.ckpt_utils import load_checkpoint, find_model
|
||||
from einops import rearrange, repeat, pack, unpack
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
|
@ -483,6 +483,7 @@ class Encoder(nn.Module):
|
|||
# in_out_channels = 3, # SCH: added, in_channels at the start
|
||||
latent_embed_dim = 512, # num channels for latent vector
|
||||
# conv_downsample = False,
|
||||
disable_spatial_downsample = False, # for vae pipeline
|
||||
custom_conv_padding = None,
|
||||
activation_fn = 'swish',
|
||||
device="cpu",
|
||||
|
|
@ -496,6 +497,7 @@ class Encoder(nn.Module):
|
|||
self.num_groups = num_groups
|
||||
|
||||
self.embedding_dim = latent_embed_dim
|
||||
self.disable_spatial_downsample = disable_spatial_downsample
|
||||
# self.conv_downsample = conv_downsample
|
||||
self.custom_conv_padding = custom_conv_padding
|
||||
|
||||
|
|
@ -542,9 +544,10 @@ class Encoder(nn.Module):
|
|||
prev_filters = filters # update in_channels
|
||||
self.block_res_blocks.append(block_items)
|
||||
|
||||
if i < self.num_blocks - 1: # SCH: T-Causal Conv 3x3x3, 128->128, stride t x 2 x 2
|
||||
if i < self.num_blocks - 1: # SCH: T-Causal Conv 3x3x3, 128->128, stride t x stride s x stride s
|
||||
t_stride = 2 if self.temporal_downsample[i] else 1
|
||||
self.conv_blocks.append(self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, 2, 2))) # SCH: should be same in_channel and out_channel
|
||||
s_stride = 2 if not self.disable_spatial_downsample else 1
|
||||
self.conv_blocks.append(self.conv_fn(prev_filters, filters, kernel_size=(3, 3, 3), strides=(t_stride, s_stride, s_stride))) # SCH: should be same in_channel and out_channel
|
||||
prev_filters = filters # update in_channels
|
||||
|
||||
|
||||
|
|
@ -598,6 +601,7 @@ class Decoder(nn.Module):
|
|||
temporal_downsample = (False, True, True),
|
||||
num_groups = 32, # for nn.GroupNorm
|
||||
# upsample = "nearest+conv", # options: "deconv", "nearest+conv"
|
||||
disable_spatial_upsample = False, # for vae pipeline
|
||||
custom_conv_padding = None,
|
||||
activation_fn = 'swish',
|
||||
device="cpu",
|
||||
|
|
@ -613,6 +617,7 @@ class Decoder(nn.Module):
|
|||
self.num_groups = num_groups
|
||||
|
||||
# self.upsample = upsample
|
||||
self.s_stride = 1 if self.disable_spatial_upsample else 2 # spatial stride
|
||||
self.custom_conv_padding = custom_conv_padding
|
||||
# self.norm_type = self.config.vqvae.norm_type
|
||||
# self.num_remat_block = self.config.vqvae.get('num_dec_remat_blocks', 0)
|
||||
|
|
@ -677,7 +682,7 @@ class Decoder(nn.Module):
|
|||
t_stride = 2 if self.temporal_downsample[i - 1] else 1
|
||||
# SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
|
||||
self.conv_blocks.insert(0,
|
||||
self.conv_fn(prev_filters, prev_filters * t_stride * 4, kernel_size=(3,3,3))
|
||||
self.conv_fn(prev_filters, prev_filters * t_stride * self.s_stride * self.s_stride, kernel_size=(3,3,3))
|
||||
)
|
||||
|
||||
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, device=device, dtype=dtype)
|
||||
|
|
@ -706,7 +711,7 @@ class Decoder(nn.Module):
|
|||
t_stride = 2 if self.temporal_downsample[i - 1] else 1
|
||||
# SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
|
||||
x = self.conv_blocks[i-1](x)
|
||||
x = rearrange(x, "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)", ts=t_stride, hs=2, ws=2)
|
||||
x = rearrange(x, "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)", ts=t_stride, hs=self.s_stride, ws=self.s_stride)
|
||||
# print("decoder:", x.size())
|
||||
|
||||
x = self.norm1(x)
|
||||
|
|
@ -728,6 +733,7 @@ class VAE_3D_V2(nn.Module): # , ModelMixin
|
|||
channel_multipliers = (1, 2, 2, 4),
|
||||
temporal_downsample = (True, True, False),
|
||||
num_groups = 32, # for nn.GroupNorm
|
||||
disable_space = False,
|
||||
custom_conv_padding = None,
|
||||
activation_fn = 'swish',
|
||||
in_out_channels = 4,
|
||||
|
|
@ -777,6 +783,7 @@ class VAE_3D_V2(nn.Module): # , ModelMixin
|
|||
# in_out_channels = in_out_channels,
|
||||
latent_embed_dim = latent_embed_dim,
|
||||
# conv_downsample = conv_downsample,
|
||||
disable_spatial_downsample=disable_space,
|
||||
custom_conv_padding = custom_conv_padding,
|
||||
activation_fn = activation_fn,
|
||||
device = device,
|
||||
|
|
@ -791,6 +798,7 @@ class VAE_3D_V2(nn.Module): # , ModelMixin
|
|||
temporal_downsample = temporal_downsample,
|
||||
num_groups = num_groups, # for nn.GroupNorm
|
||||
# upsample = upsample, # options: "deconv", "nearest+conv"
|
||||
disable_spatial_upsample=disable_space,
|
||||
custom_conv_padding = custom_conv_padding,
|
||||
activation_fn = activation_fn,
|
||||
device = device,
|
||||
|
|
@ -1196,10 +1204,38 @@ def VAE_MAGVIT_V2(from_pretrained=None, **kwargs):
|
|||
return model
|
||||
|
||||
@MODELS.register_module("DISCRIMINATOR_3D")
|
||||
def DISCRIMINATOR_3D(from_pretrained=None, **kwargs):
|
||||
def DISCRIMINATOR_3D(from_pretrained=None, inflate_from_2d=False, **kwargs):
|
||||
model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init)
|
||||
# model = StyleGANDiscriminator(**kwargs).apply(xavier_uniform_weight_init) # SCH: DEBUG: to change back
|
||||
# model = NLayerDiscriminator3D(input_nc=3, n_layers=3,).apply(n_layer_disc_weights_init)
|
||||
if from_pretrained is not None:
|
||||
load_checkpoint(model, from_pretrained)
|
||||
return model
|
||||
if inflate_from_2d:
|
||||
load_checkpoint_with_inflation(model, from_pretrained)
|
||||
else:
|
||||
load_checkpoint(model, from_pretrained)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
def load_checkpoint_with_inflation(model, ckpt_path):
|
||||
"""
|
||||
pre-train using image, then inflate to 3D videos
|
||||
"""
|
||||
if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
|
||||
state_dict = find_model(ckpt_path)
|
||||
breakpoint() # NOTE: need to manually check before first use
|
||||
with torch.no_grad():
|
||||
for key in state_dict:
|
||||
if key in model:
|
||||
# central inflation
|
||||
if state_dict[key].size() == model[key][:, :, 0, :, :].size():
|
||||
# temporal dimension
|
||||
val = torch.zeros_like(model[key])
|
||||
centre = int(model[key].size(2) // 2)
|
||||
val[:, :, centre, :, :] = state_dict[key]
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
print(f"Missing keys: {missing_keys}")
|
||||
print(f"Unexpected keys: {unexpected_keys}")
|
||||
else:
|
||||
load_checkpoint(model, ckpt_path) # use the default function
|
||||
|
||||
Loading…
Reference in a new issue