zhengzangw 2024-04-30 08:13:20 +00:00
parent 41e276f5ef
commit ace37cf7c7
14 changed files with 328 additions and 714 deletions

View file

@@ -1,79 +0,0 @@
num_frames = 16
image_size = (256, 256)
fps = 24 // 3
max_test_samples = None
# Define dataset
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=num_frames,
frame_interval=1,
image_size=image_size,
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
# Define model
vae_2d = dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=4,
local_files_only=True,
)
model = dict(
type="VAE_Temporal_SD",
)
# discriminator = dict(
# type="DISCRIMINATOR_3D",
# image_size=image_size,
# num_frames=num_frames,
# in_channels=3,
# filters=128,
# channel_multipliers=(2, 4, 4, 4, 4),
# # channel_multipliers = (2, 4, 4) # (2, 4, 4, 4, 4) is the value above; (2, 4, 4, 4) for 64x64 resolution
# )
# loss weights
logvar_init = 0.0
kl_loss_weight = 0.000001
perceptual_loss_weight = 0.1 # the VGG perceptual loss is used if this is not None and > 0
discriminator_factor = 1.0 # for discriminator adversarial loss
# discriminator_loss_weight = 0.5 # for generator adversarial loss
generator_factor = 0.1 # for generator adversarial loss
lecam_loss_weight = None # NOTE: MAGVIT does not specify this weight
discriminator_loss_type = "non-saturating"
generator_loss_type = "non-saturating"
discriminator_start = 2500 # 50000 NOTE: change to the correct value; use -1 while debugging
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10; Open-Sora-Plan doesn't use it
ema_decay = 0.999 # ema decay factor for generator
# Others
seed = 42
save_dir = "samples/samples_vae"
wandb = False
# Training
""" NOTE:
magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
3-6 epochs for pexel, from pexel observation its correct
"""
batch_size = 1
lr = 1e-4
grad_clip = 1.0
calc_loss = True

View file

@@ -1,80 +0,0 @@
num_frames = 1
# image_size = (256, 256)
image_size = (1024, 1024)
fps = 24 // 3
max_test_samples = None
# Define dataset
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=num_frames,
frame_interval=1,
image_size=image_size,
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
# Define model
vae_2d = dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=4,
local_files_only=True,
)
model = dict(
type="VAE_Temporal_SD",
)
# discriminator = dict(
# type="DISCRIMINATOR_3D",
# image_size=image_size,
# num_frames=num_frames,
# in_channels=3,
# filters=128,
# channel_multipliers=(2, 4, 4, 4, 4),
# # channel_multipliers = (2, 4, 4) # (2, 4, 4, 4, 4) is the value above; (2, 4, 4, 4) for 64x64 resolution
# )
# loss weights
logvar_init = 0.0
kl_loss_weight = 0.000001
perceptual_loss_weight = 0.1 # the VGG perceptual loss is used if this is not None and > 0
discriminator_factor = 1.0 # for discriminator adversarial loss
# discriminator_loss_weight = 0.5 # for generator adversarial loss
generator_factor = 0.1 # for generator adversarial loss
lecam_loss_weight = None # NOTE: MAGVIT does not specify this weight
discriminator_loss_type = "non-saturating"
generator_loss_type = "non-saturating"
discriminator_start = 2500 # 50000 NOTE: change to the correct value; use -1 while debugging
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10; Open-Sora-Plan doesn't use it
ema_decay = 0.999 # ema decay factor for generator
# Others
seed = 42
save_dir = "samples/samples_vae"
wandb = False
# Training
""" NOTE:
magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
3-6 epochs for pexel, from pexel observation its correct
"""
batch_size = 1
lr = 1e-4
grad_clip = 1.0
calc_loss = True

View file

@@ -0,0 +1,42 @@
num_frames = 1
frame_interval = 1
fps = 24
image_size = (256, 256)
# Define dataset
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=num_frames,
frame_interval=1,
image_size=image_size,
)
num_workers = 4
max_test_samples = None
# Define model
model = dict(
type="VideoAutoencoderPipeline",
freeze_vae_2d=True,
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=4,
local_files_only=True,
),
vae_temporal=dict(
type="VAE_Temporal_SD",
from_pretrained=None,
),
)
dtype = "bf16"
# loss weights
perceptual_loss_weight = 0.1 # the VGG perceptual loss is used if this is not None and > 0
kl_loss_weight = 1e-6
# Others
batch_size = 1
seed = 42
save_dir = "samples/vae_image"

View file

@@ -0,0 +1,42 @@
num_frames = 17
frame_interval = 1
fps = 24
image_size = (256, 256)
# Define dataset
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=num_frames,
frame_interval=1,
image_size=image_size,
)
num_workers = 4
max_test_samples = None
# Define model
model = dict(
type="VideoAutoencoderPipeline",
from_pretrained=None,
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=4,
local_files_only=True,
),
vae_temporal=dict(
type="VAE_Temporal_SD",
from_pretrained=None,
),
)
dtype = "bf16"
# loss weights
perceptual_loss_weight = 0.1 # the VGG perceptual loss is used if this is not None and > 0
kl_loss_weight = 1e-6
# Others
batch_size = 1
seed = 42
save_dir = "samples/vae_video"

View file

@@ -1,74 +0,0 @@
num_frames = 17
image_size = (256, 256)
# Define dataset
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=num_frames,
frame_interval=1,
image_size=image_size,
)
# Define acceleration
num_workers = 16
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
# latest
vae_2d = dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=4,
local_files_only=True,
)
model = dict(
type="VAE_Temporal_SD",
)
# discriminator = dict(
# type="DISCRIMINATOR_3D",
# image_size=image_size, # NOTE: the image size here is different
# num_frames=num_frames,
# in_channels=3,
# filters=128,
# use_pretrained=True, # NOTE: set to False only to disable loading
# channel_multipliers=(2, 4, 4, 4, 4), # (2, 4, 4, 4) for 64x64 resolution
# # channel_multipliers=(2, 4, 4), # since it acts on the intermediate-layer dimension of z
# )
# loss weights
logvar_init = 0.0
kl_loss_weight = 0.000001
perceptual_loss_weight = 0.1 # the VGG perceptual loss is used if this is not None and > 0
discriminator_factor = 1.0 # for discriminator adversarial loss
generator_factor = 0.1 # SCH: generator adversarial loss, MAGVIT v2 uses 0.1
lecam_loss_weight = None # NOTE: MAGVIT v2 uses 0.001
discriminator_loss_type = "non-saturating"
generator_loss_type = "non-saturating"
# discriminator_loss_type="hinge"
# generator_loss_type="hinge"
discriminator_start = 2000 # 5000 # 8k data / (8*1) = 1000 steps per epoch
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10; Open-Sora-Plan doesn't use it
ema_decay = 0.999 # ema decay factor for generator
# Others
seed = 42
outputs = "outputs"
wandb = False
epochs = 100
log_every = 1
ckpt_every = 1000
load = None
batch_size = 1
lr = 1e-5
grad_clip = 1.0

View file

@@ -1,71 +0,0 @@
num_frames = 1
image_size = (256, 256)
# Define dataset
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=num_frames,
frame_interval=1,
image_size=image_size,
)
# Define acceleration
num_workers = 16
dtype = "bf16"
grad_checkpoint = False
plugin = "zero2"
sp_size = 1
# latest
vae_2d = dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=4,
local_files_only=True,
)
model = dict(
type="VAE_Temporal_SD",
)
# discriminator = dict(
# type="DISCRIMINATOR_3D",
# image_size=image_size, # NOTE: the image size here is different
# num_frames=num_frames,
# in_channels=3,
# filters=128,
# use_pretrained=True, # NOTE: set to False only to disable loading
# channel_multipliers=(2, 4, 4, 4, 4), # (2, 4, 4, 4) for 64x64 resolution
# # channel_multipliers=(2, 4, 4), # since it acts on the intermediate-layer dimension of z
# )
# loss weights
logvar_init = 0.0
kl_loss_weight = 0.000001
perceptual_loss_weight = 0.1 # the VGG perceptual loss is used if this is not None and > 0
discriminator_factor = 1.0 # for discriminator adversarial loss
generator_factor = 0.1 # generator adversarial loss, MAGVIT v2 uses 0.1
lecam_loss_weight = None # NOTE: MAGVIT v2 uses 0.001
discriminator_loss_type = "non-saturating"
generator_loss_type = "non-saturating"
discriminator_start = 2000
gradient_penalty_loss_weight = None # MAGVIT uses 10; Open-Sora-Plan doesn't use it
ema_decay = 0.999 # ema decay factor for generator
# Others
seed = 42
outputs = "outputs"
wandb = False
epochs = 100
log_every = 1
ckpt_every = 1000
load = None
batch_size = 4
lr = 1e-5
grad_clip = 1.0

View file

@@ -0,0 +1,58 @@
num_frames = 1
image_size = (256, 256)
# Define dataset
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=num_frames,
frame_interval=1,
image_size=image_size,
)
# Define acceleration
num_workers = 16
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
# Define model
model = dict(
type="VideoAutoencoderPipeline",
freeze_vae_2d=True,
from_pretrained=None,
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=4,
local_files_only=True,
),
vae_temporal=dict(
type="VAE_Temporal_SD",
from_pretrained=None,
),
)
# loss weights
perceptual_loss_weight = 0.1 # the VGG perceptual loss is used if this is not None and > 0
kl_loss_weight = 1e-6
mixed_image_ratio = 0.1
use_real_rec_loss = False
use_z_rec_loss = True
use_image_identity_loss = True
# Others
seed = 42
outputs = "outputs"
wandb = False
epochs = 100
log_every = 1
ckpt_every = 1000
load = None
batch_size = 1
lr = 1e-5
grad_clip = 1.0

View file

@@ -0,0 +1,58 @@
num_frames = 17
image_size = (256, 256)
# Define dataset
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=num_frames,
frame_interval=1,
image_size=image_size,
)
# Define acceleration
num_workers = 16
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
# Define model
model = dict(
type="VideoAutoencoderPipeline",
freeze_vae_2d=True,
from_pretrained=None,
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=4,
local_files_only=True,
),
vae_temporal=dict(
type="VAE_Temporal_SD",
from_pretrained=None,
),
)
# loss weights
perceptual_loss_weight = 0.1 # the VGG perceptual loss is used if this is not None and > 0
kl_loss_weight = 1e-6
mixed_image_ratio = 0.1
use_real_rec_loss = True
use_z_rec_loss = False
use_image_identity_loss = False
# Others
seed = 42
outputs = "outputs"
wandb = False
epochs = 100
log_every = 1
ckpt_every = 1000
load = None
batch_size = 1
lr = 1e-5
grad_clip = 1.0
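
The three use_* flags above gate separate reconstruction targets in the training loop later in this diff. A self-contained sketch of that selection logic, with placeholder tensors standing in for the real VAELoss outputs:

import torch

# Placeholders for the losses computed by VAELoss in the training loop.
weighted_nll_loss = weighted_kl_loss = torch.tensor(0.1)  # pixel-space recon + KL
weighted_z_nll_loss = torch.tensor(0.1)  # reconstruction of the 2D latent x_z
image_identity_loss = torch.tensor(0.1)  # z vs. x_z identity term for images

use_real_rec_loss, use_z_rec_loss, use_image_identity_loss = True, False, False
vae_loss = torch.tensor(0.0)
if use_real_rec_loss:  # reconstruct the pixel video x
    vae_loss = vae_loss + weighted_nll_loss + weighted_kl_loss
if use_z_rec_loss:  # reconstruct the 2D VAE latent x_z
    vae_loss = vae_loss + weighted_z_nll_loss
if use_image_identity_loss:  # push the temporal latent toward x_z on images
    vae_loss = vae_loss + image_identity_loss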

View file

@@ -94,13 +94,10 @@ class VAELoss(nn.Module):
if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
# handle channels
channels = video.shape[1]
assert channels in {1, 3, 4}
assert channels in {1, 3}
if channels == 1:
input_vgg_input = repeat(video, "b 1 h w -> b c h w", c=3)
recon_vgg_input = repeat(recon_video, "b 1 h w -> b c h w", c=3)
elif channels == 4: # SCH: take the first 3 for perceptual loss calc
input_vgg_input = video[:, :3]
recon_vgg_input = recon_video[:, :3]
else:
input_vgg_input = video
recon_vgg_input = recon_video
@@ -109,6 +106,7 @@ class VAELoss(nn.Module):
recon_loss = recon_loss + self.perceptual_loss_weight * perceptual_loss
nll_loss = recon_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if nll_weights is not None:
weighted_nll_loss = nll_weights * nll_loss
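
For context, the nll_loss line above is a Gaussian negative log-likelihood with a learnable scalar log-variance (logvar_init = 0.0 in the configs). A minimal numeric sketch:

import torch

logvar = torch.zeros(())  # logvar_init = 0.0, so exp(logvar) = 1 at the start
recon_loss = torch.tensor(0.25)  # e.g. L1 recon + weighted perceptual term
nll_loss = recon_loss / torch.exp(logvar) + logvar
# with logvar = 0 this reduces to the plain reconstruction loss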

View file

@@ -3,15 +3,20 @@ import torch.nn as nn
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from einops import rearrange
from opensora.registry import MODELS
from opensora.registry import MODELS, build_module
from opensora.utils.ckpt_utils import load_checkpoint
@MODELS.register_module()
class VideoAutoencoderKL(nn.Module):
def __init__(self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False, subfolder=None):
def __init__(
self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False, subfolder=None
):
super().__init__()
self.module = AutoencoderKL.from_pretrained(
from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only,
from_pretrained,
cache_dir=cache_dir,
local_files_only=local_files_only,
subfolder=subfolder,
)
self.out_channels = self.module.config.latent_channels
@@ -107,3 +112,47 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module):
@property
def dtype(self):
return next(self.parameters()).dtype
@MODELS.register_module()
class VideoAutoencoderPipeline(nn.Module):
def __init__(self, vae_2d=None, vae_temporal=None, freeze_vae_2d=False, from_pretrained=None):
super().__init__()
self.spatial_vae = build_module(vae_2d, MODELS)
self.temporal_vae = build_module(vae_temporal, MODELS)
if from_pretrained is not None:
load_checkpoint(self, from_pretrained)
if freeze_vae_2d:
for param in self.spatial_vae.parameters():
param.requires_grad = False
def encode(self, x, training=True):
x_z = self.spatial_vae.encode(x)
posterior = self.temporal_vae.encode(x_z)
z = posterior.sample()
if training:
return z, posterior, x_z
return z
def decode(self, z, num_frames=None, training=True):
x_z = self.temporal_vae.decode(z, num_frames=num_frames)
x = self.spatial_vae.decode(x_z)
if training:
return x, x_z
return x
def forward(self, x):
z, posterior, x_z = self.encode(x, training=True)
x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2], training=True)
return x_rec, x_z_rec, z, posterior, x_z
def get_latent_size(self, input_size):
return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size))
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
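
A hedged usage sketch for the new VideoAutoencoderPipeline; shapes are illustrative ([B, C, T, H, W] as in the eval script), and vae is assumed to be a pipeline instance built via build_module as sketched earlier:

import torch

x = torch.randn(1, 3, 17, 256, 256)  # [B, C, T, H, W], matching the 17-frame config
# training=True also returns the posterior and the intermediate 2D latent x_z
z, posterior, x_z = vae.encode(x, training=True)
x_rec, x_z_rec = vae.decode(z, num_frames=x_z.shape[2], training=True)
# inference path returns only the tensors:
#   z = vae.encode(x, training=False); x = vae.decode(z, num_frames=17, training=False)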

View file

@@ -1,7 +1,5 @@
import functools
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
@@ -82,8 +80,6 @@ class ResBlock(nn.Module):
activation_fn=nn.SiLU,
use_conv_shortcut=False,
num_groups=32,
device="cpu",
dtype=torch.bfloat16,
):
super().__init__()
self.in_channels = in_channels
@@ -92,9 +88,9 @@
self.use_conv_shortcut = use_conv_shortcut
# SCH: MAGVIT uses GroupNorm by default
self.norm1 = nn.GroupNorm(num_groups, in_channels, device=device, dtype=dtype)
self.norm1 = nn.GroupNorm(num_groups, in_channels)
self.conv1 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
self.norm2 = nn.GroupNorm(num_groups, self.filters, device=device, dtype=dtype)
self.norm2 = nn.GroupNorm(num_groups, self.filters)
self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False)
if in_channels != filters:
if self.use_conv_shortcut:
@@ -103,8 +99,6 @@
self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(1, 1, 1), bias=False)
def forward(self, x):
# device, dtype = x.device, x.dtype
# input_dim = x.shape[1]
residual = x
x = self.norm1(x)
x = self.activate(x)
@@ -140,8 +134,6 @@ class Encoder(nn.Module):
temporal_downsample=(False, True, True),
num_groups=32, # for nn.GroupNorm
activation_fn="swish",
device="cpu",
dtype=torch.bfloat16,
):
super().__init__()
self.filters = filters
@@ -154,18 +146,12 @@
self.activation_fn = get_activation_fn(activation_fn)
self.activate = self.activation_fn()
self.conv_fn = functools.partial(
CausalConv3d,
dtype=dtype,
device=device,
)
self.conv_fn = CausalConv3d
self.block_args = dict(
conv_fn=self.conv_fn,
dtype=dtype,
activation_fn=self.activation_fn,
use_conv_shortcut=False,
num_groups=self.num_groups,
device=device,
)
# first layer conv
@@ -174,8 +160,6 @@
filters,
kernel_size=(3, 3, 3),
bias=False,
dtype=dtype,
device=device,
)
# ResBlocks and conv downsample
@@ -214,13 +198,9 @@
prev_filters = filters # update in_channels
# MAGVIT uses Group Normalization
self.norm1 = nn.GroupNorm(
self.num_groups, prev_filters, dtype=dtype, device=device
) # separate <prev_filters> channels into 32 groups
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters)
self.conv2 = self.conv_fn(
prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), dtype=dtype, device=device, padding="same"
)
self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), padding="same")
def forward(self, x):
x = self.conv_in(x)
@@ -252,8 +232,6 @@ class Decoder(nn.Module):
temporal_downsample=(False, True, True),
num_groups=32, # for nn.GroupNorm
activation_fn="swish",
device="cpu",
dtype=torch.bfloat16,
):
super().__init__()
self.filters = filters
@@ -267,18 +245,12 @@
self.activation_fn = get_activation_fn(activation_fn)
self.activate = self.activation_fn()
self.conv_fn = functools.partial(
CausalConv3d,
dtype=dtype,
device=device,
)
self.conv_fn = CausalConv3d
self.block_args = dict(
conv_fn=self.conv_fn,
activation_fn=self.activation_fn,
use_conv_shortcut=False,
num_groups=self.num_groups,
device=device,
dtype=dtype,
)
filters = self.filters * self.channel_multipliers[-1]
@@ -323,9 +295,9 @@
nn.Identity(prev_filters),
)
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, device=device, dtype=dtype)
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters)
self.conv_out = self.conv_fn(filters, in_out_channels, 3, dtype=dtype, device=device)
self.conv_out = self.conv_fn(filters, in_out_channels, 3)
def forward(self, x):
x = self.conv1(x)
@@ -364,8 +336,6 @@ class VAE_Temporal(nn.Module):
temporal_downsample=(True, True, False),
num_groups=32, # for nn.GroupNorm
activation_fn="swish",
device="cpu",
dtype=torch.bfloat16,
):
super().__init__()
@@ -383,12 +353,10 @@
temporal_downsample=temporal_downsample,
num_groups=num_groups, # for nn.GroupNorm
activation_fn=activation_fn,
device=device,
dtype=dtype,
)
self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1, device=device, dtype=dtype)
self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)
self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1, device=device, dtype=dtype)
self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)
self.decoder = Decoder(
in_out_channels=in_out_channels,
latent_embed_dim=latent_embed_dim,
@@ -398,8 +366,6 @@
temporal_downsample=temporal_downsample,
num_groups=num_groups, # for nn.GroupNorm
activation_fn=activation_fn,
device=device,
dtype=dtype,
)
def get_latent_size(self, input_size):

View file

@@ -79,6 +79,10 @@ def merge_args(cfg, args, training=False):
if args.data_path is not None:
cfg.dataset["data_path"] = args.data_path
args.data_path = None
if not training and args.image_size is not None and "dataset" in cfg:
cfg.dataset["image_size"] = args.image_size
if not training and args.num_frames is not None and "dataset" in cfg:
cfg.dataset["num_frames"] = args.num_frames
if not training and args.cfg_scale is not None:
cfg.scheduler["cfg_scale"] = args.cfg_scale
args.cfg_scale = None
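
A small sketch of what the two new inference-only branches do, using a dummy namespace rather than the real CLI:

from argparse import Namespace

cfg_dataset = {"image_size": (256, 256), "num_frames": 1}
args = Namespace(image_size=(512, 512), num_frames=17)
training = False
# same effect as the new branches in merge_args above:
if not training and args.image_size is not None:
    cfg_dataset["image_size"] = args.image_size
if not training and args.num_frames is not None:
    cfg_dataset["num_frames"] = args.num_frames
assert cfg_dataset == {"image_size": (512, 512), "num_frames": 17}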

View file

@@ -7,7 +7,7 @@ from colossalai.cluster import DistCoordinator
from mmengine.runner import set_random_seed
from tqdm import tqdm
from opensora.acceleration.parallel_states import get_data_parallel_group, set_sequence_parallel_group
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader, save_sample
from opensora.models.vae.losses import VAELoss
from opensora.registry import DATASETS, MODELS, build_module
@@ -27,11 +27,6 @@ def main():
use_dist = True
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
if coordinator.world_size > 1:
set_sequence_parallel_group(dist.group.WORLD)
else:
pass
else:
use_dist = False
@@ -59,88 +54,40 @@ def main():
process_group=get_data_parallel_group(),
)
print(f"Dataset contains {len(dataset):,} videos ({cfg.dataset.data_path})")
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
total_batch_size = cfg.batch_size * dist.get_world_size()
print(f"Total batch size: {total_batch_size}")
# ======================================================
# 4. build model & load weights
# ======================================================
# 3.1. build model
if cfg.get("vae_2d", None) is not None:
vae_2d = build_module(cfg.vae_2d, MODELS)
vae_2d.to(device, dtype).eval()
model = build_module(
cfg.model,
MODELS,
device=device,
dtype=dtype,
)
# discriminator = build_module(cfg.discriminator, MODELS, device=device)
# 3.2. move to device & eval
# discriminator = discriminator.to(device, dtype).eval()
# 3.4. support for multi-resolution
# model_args = dict()
# if cfg.multi_resolution:
# image_size = cfg.dataset.image_size
# hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
# ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
# model_args["data_info"] = dict(ar=ar, hw=hw)
# 4.1. build model
model = build_module(cfg.model, MODELS)
model.to(device, dtype).eval()
# ======================================================
# 4. inference
# 5. inference
# ======================================================
save_dir = cfg.save_dir
os.makedirs(save_dir, exist_ok=True)
# 4.1. batch generation
# define loss function
if cfg.calc_loss:
vae_loss_fn = VAELoss(
logvar_init=cfg.logvar_init,
perceptual_loss_weight=cfg.perceptual_loss_weight,
kl_loss_weight=cfg.kl_loss_weight,
device=device,
dtype=dtype,
)
# adversarial_loss_fn = AdversarialLoss(
# discriminator_factor=cfg.discriminator_factor,
# discriminator_start=cfg.discriminator_start,
# generator_factor=cfg.generator_factor,
# generator_loss_type=cfg.generator_loss_type,
# )
# disc_loss_fn = DiscriminatorLoss(
# discriminator_factor=cfg.discriminator_factor,
# discriminator_start=cfg.discriminator_start,
# discriminator_loss_type=cfg.discriminator_loss_type,
# lecam_loss_weight=cfg.lecam_loss_weight,
# gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
# )
# # LeCam EMA for discriminator
# lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
running_loss = 0.0
running_nll = 0.0
loss_steps = 0
# disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
# if cfg.dataset.num_frames % disc_time_downsample_factor != 0:
# disc_time_padding = disc_time_downsample_factor - cfg.dataset.num_frames % disc_time_downsample_factor
# else:
# disc_time_padding = 0
vae_loss_fn = VAELoss(
logvar_init=cfg.get("logvar_init", 0.0),
perceptual_loss_weight=cfg.perceptual_loss_weight,
kl_loss_weight=cfg.kl_loss_weight,
device=device,
dtype=dtype,
)
# get total number of steps
total_steps = len(dataloader)
if cfg.max_test_samples is not None:
total_steps = min(int(cfg.max_test_samples // cfg.batch_size), total_steps)
print(f"limiting test dataset to {int(cfg.max_test_samples//cfg.batch_size) * cfg.batch_size}")
dataloader_iter = iter(dataloader)
running_loss = running_nll = 0.0
loss_steps = 0
with tqdm(
range(total_steps),
disable=not coordinator.is_master(),
@@ -151,95 +98,28 @@ def main():
batch = next(dataloader_iter)
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
# ===== Spatial VAE =====
if cfg.get("vae_2d", None) is not None:
x_z = vae_2d.encode(x)
x_z_debug = vae_2d.decode(x_z)
# ===== VAE =====
z, posterior, x_z = model.encode(x, training=True)
x_rec, _ = model.decode(z, num_frames=x.size(2))
x_ref = model.spatial_vae.decode(x_z)
# ====== VAE ======
x_z_rec, posterior, z = model(x_z)
x_rec = vae_2d.decode(x_z_rec)
if cfg.calc_loss:
# simple nll loss
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
# fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
# fake_logits = discriminator(fake_video.contiguous())
# adversarial_loss = adversarial_loss_fn(
# fake_logits,
# nll_loss,
# vae.get_last_layer(),
# cfg.discriminator_start + 1, # Hack to use discriminator
# is_training=vae.training,
# )
# vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
vae_loss = weighted_nll_loss + weighted_kl_loss
# # ====== Discriminator Loss ======
# real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2)
# fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
# if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
# real_video = real_video.requires_grad_()
# real_logits = discriminator(
# real_video.contiguous()
# ) # SCH: not detached for now for gradient_penalty calculation
# else:
# real_logits = discriminator(real_video.contiguous().detach())
# fake_logits = discriminator(fake_video.contiguous().detach())
# lecam_ema_real, lecam_ema_fake = lecam_ema.get()
# weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
# real_logits,
# fake_logits,
# cfg.discriminator_start + 1, # Hack to use discriminator
# lecam_ema_real=lecam_ema_real,
# lecam_ema_fake=lecam_ema_fake,
# real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None,
# )
# disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
loss_steps += 1
# running_disc_loss = disc_loss.item() / loss_steps + running_disc_loss * ((loss_steps - 1) / loss_steps)
running_loss = vae_loss.item() / loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
# ===== Spatial VAE =====
# loss calculation
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
vae_loss = weighted_nll_loss + weighted_kl_loss
loss_steps += 1
running_loss = vae_loss.item() / loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
if not use_dist or coordinator.is_master():
for idx in range(len(x)):
for idx, vid in enumerate(x):
pos = step * cfg.batch_size + idx
save_path = os.path.join(save_dir, f"sample_{pos}")
save_sample(x[idx], fps=cfg.fps, save_path=save_path + "_original")
save_sample(x_rec[idx], fps=cfg.fps, save_path=save_path + "_pipeline")
if cfg.get("vae_2d", None) is not None:
save_sample(x_z_debug[idx], fps=cfg.fps, save_path=save_path + "_2d")
save_path = os.path.join(save_dir, f"sample_{pos:03d}")
save_sample(vid, fps=cfg.fps, save_path=save_path + "_ori")
save_sample(x_rec[idx], fps=cfg.fps, save_path=save_path + "_rec")
save_sample(x_ref[idx], fps=cfg.fps, save_path=save_path + "_ref")
# if cfg.get("use_pipeline") == True:
# for idx, (sample_original, sample_pipeline, sample_2d) in enumerate(
# zip(video, recon_video, recon_2d)
# ):
# pos = step * cfg.batch_size + idx
# save_path = os.path.join(save_dir, f"sample_{pos}")
# save_sample(sample_original, fps=cfg.fps, save_path=save_path + "_original")
# save_sample(sample_2d, fps=cfg.fps, save_path=save_path + "_2d")
# save_sample(sample_pipeline, fps=cfg.fps, save_path=save_path + "_pipeline")
# else:
# for idx, (original, recon) in enumerate(zip(video, recon_video)):
# pos = step * cfg.batch_size + idx
# save_path = os.path.join(save_dir, f"sample_{pos}")
# save_sample(original, fps=cfg.fps, save_path=save_path + "_original")
# save_sample(recon, fps=cfg.fps, save_path=save_path + "_recon")
if cfg.calc_loss:
print("test vae loss:", running_loss)
print("test nll loss:", running_nll)
# print("test disc loss:", running_disc_loss)
print("test vae loss:", running_loss)
print("test nll loss:", running_nll)
if __name__ == "__main__":

View file

@@ -14,14 +14,9 @@ from tqdm import tqdm
import wandb
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import (
get_data_parallel_group,
set_data_parallel_group,
set_sequence_parallel_group,
)
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group
from opensora.datasets import prepare_dataloader
from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VAELoss
from opensora.models.vae.losses import VAELoss
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.ckpt_utils import create_logger, load_json, save_json
from opensora.utils.config_utils import (
@@ -78,16 +73,6 @@ def main():
max_norm=cfg.grad_clip,
)
set_data_parallel_group(dist.group.WORLD)
elif cfg.plugin == "zero2-seq":
plugin = ZeroSeqParallelPlugin(
sp_size=cfg.sp_size,
stage=2,
precision=cfg.dtype,
initial_scale=2**16,
max_norm=cfg.grad_clip,
)
set_sequence_parallel_group(plugin.sp_group)
set_data_parallel_group(plugin.dp_group)
else:
raise ValueError(f"Unknown plugin {cfg.plugin}")
booster = Booster(plugin=plugin)
@@ -110,83 +95,40 @@ def main():
)
# TODO: use plugin's prepare dataloader
dataloader = prepare_dataloader(**dataloader_args)
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
total_batch_size = cfg.batch_size * dist.get_world_size()
logger.info(f"Total batch size: {total_batch_size}")
# ======================================================
# 4. build model
# ======================================================
# 4.1. build model
if cfg.get("vae_2d", None) is not None:
vae_2d = build_module(cfg.vae_2d, MODELS)
vae_2d.to(device, dtype).eval()
model = build_module(
cfg.model,
MODELS,
device=device,
dtype=dtype,
)
model = build_module(cfg.model, MODELS)
model.to(device, dtype)
model_numel, model_numel_trainable = get_model_numel(model)
logger.info(
f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}"
)
# discriminator = build_module(cfg.discriminator, MODELS, device=device)
# discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator)
# logger.info(
# f"Trainable discriminator params: {format_numel_str(discriminator_numel_trainable)}, Total model params: {format_numel_str(discriminator_numel)}"
# )
# # LeCam Initialization
# lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
# 4.3. move to device
model = model.to(device, dtype)
# discriminator = discriminator.to(device, dtype)
# 4.4 loss functions
vae_loss_fn = VAELoss(
logvar_init=cfg.logvar_init,
logvar_init=cfg.get("logvar_init", 0.0),
perceptual_loss_weight=cfg.perceptual_loss_weight,
kl_loss_weight=cfg.kl_loss_weight,
device=device,
dtype=dtype,
)
adversarial_loss_fn = AdversarialLoss(
discriminator_factor=cfg.discriminator_factor,
discriminator_start=cfg.discriminator_start,
generator_factor=cfg.generator_factor,
generator_loss_type=cfg.generator_loss_type,
)
disc_loss_fn = DiscriminatorLoss(
discriminator_factor=cfg.discriminator_factor,
discriminator_start=cfg.discriminator_start,
discriminator_loss_type=cfg.discriminator_loss_type,
lecam_loss_weight=cfg.lecam_loss_weight,
gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
)
# 4.5. setup optimizer
# vae optimizer
optimizer = HybridAdam(
filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True
)
lr_scheduler = None
# disc optimizer
# disc_optimizer = HybridAdam(
# filter(lambda p: p.requires_grad, discriminator.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True
# )
# disc_lr_scheduler = None
# 4.6. prepare for training
if cfg.grad_checkpoint:
set_grad_checkpoint(model)
# set_grad_checkpoint(discriminator)
model.train()
# discriminator.train()
# =======================================================
# 5. boost model for distributed training with colossalai
@@ -203,11 +145,6 @@ def main():
logger.info("Boost model for distributed training")
num_steps_per_epoch = len(dataloader)
# discriminator, disc_optimizer, _, _, disc_lr_scheduler = booster.boost(
# model=discriminator, optimizer=disc_optimizer, lr_scheduler=disc_lr_scheduler
# )
# logger.info("Boost discriminator for distributed training")
# =======================================================
# 6. training loop
# =======================================================
@@ -221,18 +158,6 @@ def main():
booster.load_model(model, os.path.join(cfg.load, "model"))
booster.load_optimizer(optimizer, os.path.join(cfg.load, "optimizer"))
# booster.load_model(discriminator, os.path.join(cfg.load, "discriminator"))
# booster.load_optimizer(disc_optimizer, os.path.join(cfg.load, "disc_optimizer"))
# LeCam EMA for discriminator
# lecam_path = os.path.join(cfg.load, "lecam_states.json")
# if cfg.lecam_loss_weight is not None and os.path.exists(lecam_path):
# lecam_state = load_json(lecam_path)
# lecam_ema_real, lecam_ema_fake = lecam_state["lecam_ema_real"], lecam_state["lecam_ema_fake"]
# lecam_ema = LeCamEMA(
# decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device
# )
running_states = load_json(os.path.join(cfg.load, "running_states.json"))
dist.barrier()
start_epoch, start_step, sampler_start_idx = (
@ -246,15 +171,6 @@ def main():
dataloader.sampler.set_start_index(sampler_start_idx)
# 6.3. training loop
# calculate discriminator_time_padding
# disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
# if cfg.dataset.num_frames % disc_time_downsample_factor != 0:
# disc_time_padding = disc_time_downsample_factor - cfg.dataset.num_frames % disc_time_downsample_factor
# else:
# disc_time_padding = 0
# video_contains_first_frame = cfg.video_contains_first_frame
for epoch in range(start_epoch, cfg.epochs):
dataloader.sampler.set_epoch(epoch)
dataloader_iter = iter(dataloader)
@@ -269,112 +185,41 @@ def main():
) as pbar:
for step, batch in pbar:
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
if random.random() < 0.5:
if random.random() < cfg.get("mixed_image_ratio", 0.0):
x = x[:, :, :1, :, :]
# ===== Spatial VAE =====
if cfg.get("vae_2d", None) is not None:
with torch.no_grad():
x_z = vae_2d.encode(x)
vae_2d.decode(x_z)
# ====== VAE ======
x_z_rec, posterior, z = model(x_z)
x_rec = vae_2d.decode(x_z_rec)
# ===== VAE =====
x_rec, x_z_rec, z, posterior, x_z = model(x)
# ====== Generator Loss ======
# simple nll loss
_, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
# _, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior)
# _, debug_loss, _ = vae_loss_fn(x, x_z_debug, posterior)
# _, image_identity_loss, _ = vae_loss_fn(x_z, z, posterior)
vae_loss = torch.tensor(0.0, device=device, dtype=dtype)
log_dict = {}
if cfg.get("use_real_rec_loss", False):
_, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
vae_loss += weighted_nll_loss + weighted_kl_loss
log_dict["kl_loss"] = weighted_kl_loss.item()
log_dict["nll_loss"] = weighted_nll_loss.item()
if cfg.get("use_z_rec_loss", False):
_, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior)
vae_loss += weighted_z_nll_loss
log_dict["z_nll_loss"] = weighted_z_nll_loss.item()
if cfg.get("use_image_identity_loss", False):
_, image_identity_loss, _ = vae_loss_fn(x_z, z, posterior)
vae_loss += image_identity_loss
log_dict["image_identity_loss"] = image_identity_loss.item()
# adversarial_loss = torch.tensor(0.0)
# adversarial loss
# if global_step > cfg.discriminator_start:
# # padded videos for GAN
# fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
# fake_logits = discriminator(fake_video.contiguous())
# adversarial_loss = adversarial_loss_fn(
# fake_logits,
# nll_loss,
# vae.module.get_last_layer(),
# global_step,
# is_training=vae.training,
# )
# vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss + weighted_z_nll_loss
# vae_loss = weighted_nll_loss + weighted_kl_loss + weighted_z_nll_loss + image_identity_loss
# vae_loss = weighted_z_nll_loss + image_identity_loss
# vae_loss = weighted_nll_loss + weighted_kl_loss + weighted_z_nll_loss
# vae_loss = weighted_z_nll_loss
vae_loss = weighted_nll_loss + weighted_kl_loss
optimizer.zero_grad()
# Backward & update
booster.backward(loss=vae_loss, optimizer=optimizer)
# # NOTE: clip gradients? this is done in Open-Sora-Plan
# torch.nn.utils.clip_grad_norm_(vae.parameters(), 1) # NOTE: done by grad_clip
optimizer.step()
optimizer.zero_grad()
# Log loss values:
all_reduce_mean(vae_loss) # NOTE: this is to get average loss for logging
all_reduce_mean(vae_loss)
running_loss += vae_loss.item()
global_step = epoch * num_steps_per_epoch + step
log_step += 1
acc_step += 1
# ====== Discriminator Loss ======
# if global_step > cfg.discriminator_start:
# # if video_contains_first_frame:
# # Since we don't have enough T frames, pad anyways
# real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2)
# fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
# if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
# real_video = real_video.requires_grad_()
# real_logits = discriminator(
# real_video.contiguous()
# ) # SCH: not detached for now for gradient_penalty calculation
# else:
# real_logits = discriminator(real_video.contiguous().detach())
# fake_logits = discriminator(fake_video.contiguous().detach())
# lecam_ema_real, lecam_ema_fake = lecam_ema.get()
# weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
# real_logits,
# fake_logits,
# global_step,
# lecam_ema_real=lecam_ema_real,
# lecam_ema_fake=lecam_ema_fake,
# real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None,
# )
# disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
# if cfg.lecam_loss_weight is not None:
# ema_real = torch.mean(real_logits.clone().detach()).to(device, dtype)
# ema_fake = torch.mean(fake_logits.clone().detach()).to(device, dtype)
# all_reduce_mean(ema_real)
# all_reduce_mean(ema_fake)
# lecam_ema.update(ema_real, ema_fake)
# disc_optimizer.zero_grad()
# # Backward & update
# booster.backward(loss=disc_loss, optimizer=disc_optimizer)
# # # NOTE: TODO: clip gradients? this is done in Open-Sora-Plan
# # torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1) # NOTE: done by grad_clip
# disc_optimizer.step()
# # Log loss values:
# all_reduce_mean(disc_loss)
# running_disc_loss += disc_loss.item()
# else:
# disc_loss = torch.tensor(0.0)
# weighted_d_adversarial_loss = torch.tensor(0.0)
# lecam_loss = torch.tensor(0.0)
# gradient_penalty_loss = torch.tensor(0.0)
# Log to tensorboard
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
avg_loss = running_loss / log_step
@@ -393,16 +238,8 @@ def main():
"num_samples": global_step * total_batch_size,
"epoch": epoch,
"loss": vae_loss.item(),
"kl_loss": weighted_kl_loss.item(),
# "gen_adv_loss": adversarial_loss.item(),
# "disc_loss": disc_loss.item(),
# "lecam_loss": lecam_loss.item(),
# "r1_grad_penalty": gradient_penalty_loss.item(),
"nll_loss": weighted_nll_loss.item(),
# "z_nll_loss": weighted_z_nll_loss.item(),
# "image_identity_loss": image_identity_loss.item(),
# "debug_loss": debug_loss.item(),
"avg_loss": avg_loss,
**log_dict,
},
step=global_step,
)
@@ -412,38 +249,22 @@ def main():
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) # already handled in booster save_model
booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
# booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
booster.save_optimizer(
optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096
)
# booster.save_optimizer(
# disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096
# )
running_states = {
"epoch": epoch,
"step": step + 1,
"global_step": global_step + 1,
"sample_start_index": (step + 1) * cfg.batch_size,
}
# lecam_ema_real, lecam_ema_fake = lecam_ema.get()
# lecam_state = {
# "lecam_ema_real": lecam_ema_real.item(),
# "lecam_ema_fake": lecam_ema_fake.item(),
# }
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
# if cfg.lecam_loss_weight is not None:
# save_json(lecam_state, os.path.join(save_dir, "lecam_states.json"))
dist.barrier()
logger.info(
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
)
# print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
# subsequent epochs are not resumed, so reset the sampler start index and start step
dataloader.sampler.set_start_index(0)
start_step = 0
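
Finally, a self-contained sketch of the mixed_image_ratio trick used in this training loop: with that probability, a video batch is truncated to its first frame so the temporal VAE also trains on single images (shapes are illustrative):

import random

import torch

mixed_image_ratio = 0.1
x = torch.randn(1, 3, 17, 256, 256)  # [B, C, T, H, W]
if random.random() < mixed_image_ratio:
    x = x[:, :, :1, :, :]  # keep only the first frame -> an image batch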