mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
refactor
This commit is contained in:
parent
41e276f5ef
commit
ace37cf7c7
|
|
@ -1,79 +0,0 @@
|
|||
num_frames = 16
|
||||
image_size = (256, 256)
|
||||
fps = 24 // 3
|
||||
max_test_samples = None
|
||||
|
||||
# Define dataset
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=num_frames,
|
||||
frame_interval=1,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
sp_size = 1
|
||||
|
||||
|
||||
# Define model
|
||||
vae_2d = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type="VAE_Temporal_SD",
|
||||
)
|
||||
|
||||
# discriminator = dict(
|
||||
# type="DISCRIMINATOR_3D",
|
||||
# image_size=image_size,
|
||||
# num_frames=num_frames,
|
||||
# in_channels=3,
|
||||
# filters=128,
|
||||
# channel_multipliers=(2, 4, 4, 4, 4),
|
||||
# # channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
|
||||
# )
|
||||
|
||||
|
||||
# loss weights
|
||||
logvar_init = 0.0
|
||||
kl_loss_weight = 0.000001
|
||||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
discriminator_factor = 1.0 # for discriminator adversarial loss
|
||||
# discriminator_loss_weight = 0.5 # for generator adversarial loss
|
||||
generator_factor = 0.1 # for generator adversarial loss
|
||||
lecam_loss_weight = None # NOTE: the weight used in MAGVIT is not clear
|
||||
discriminator_loss_type = "non-saturating"
|
||||
generator_loss_type = "non-saturating"
|
||||
discriminator_start = 2500 # 50000 NOTE: change to correct val, debug use -1 for now
|
||||
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
|
||||
ema_decay = 0.999 # ema decay factor for generator
|
||||
|
||||
|
||||
# Others
|
||||
seed = 42
|
||||
save_dir = "samples/samples_vae"
|
||||
wandb = False
|
||||
|
||||
# Training
|
||||
""" NOTE:
|
||||
magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
|
||||
==> ours num_frames = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
|
||||
3-6 epochs for pexel, from pexel observation it's correct
|
||||
"""
|
||||
|
||||
|
||||
batch_size = 1
|
||||
lr = 1e-4
|
||||
grad_clip = 1.0
|
||||
|
||||
calc_loss = True
|
||||
|
|
@ -1,80 +0,0 @@
|
|||
num_frames = 1
|
||||
# image_size = (256, 256)
|
||||
image_size = (1024, 1024)
|
||||
fps = 24 // 3
|
||||
max_test_samples = None
|
||||
|
||||
# Define dataset
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=num_frames,
|
||||
frame_interval=1,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
sp_size = 1
|
||||
|
||||
|
||||
# Define model
|
||||
vae_2d = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type="VAE_Temporal_SD",
|
||||
)
|
||||
|
||||
# discriminator = dict(
|
||||
# type="DISCRIMINATOR_3D",
|
||||
# image_size=image_size,
|
||||
# num_frames=num_frames,
|
||||
# in_channels=3,
|
||||
# filters=128,
|
||||
# channel_multipliers=(2, 4, 4, 4, 4),
|
||||
# # channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
|
||||
# )
|
||||
|
||||
|
||||
# loss weights
|
||||
logvar_init = 0.0
|
||||
kl_loss_weight = 0.000001
|
||||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
discriminator_factor = 1.0 # for discriminator adversarial loss
|
||||
# discriminator_loss_weight = 0.5 # for generator adversarial loss
|
||||
generator_factor = 0.1 # for generator adversarial loss
|
||||
lecam_loss_weight = None # NOTE: the weight used in MAGVIT is not clear
|
||||
discriminator_loss_type = "non-saturating"
|
||||
generator_loss_type = "non-saturating"
|
||||
discriminator_start = 2500 # 50000 NOTE: change to correct val, debug use -1 for now
|
||||
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
|
||||
ema_decay = 0.999 # ema decay factor for generator
|
||||
|
||||
|
||||
# Others
|
||||
seed = 42
|
||||
save_dir = "samples/samples_vae"
|
||||
wandb = False
|
||||
|
||||
# Training
|
||||
""" NOTE:
|
||||
magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
|
||||
==> ours num_frames = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
|
||||
3-6 epochs for pexel, from pexel observation it's correct
|
||||
"""
|
||||
|
||||
|
||||
batch_size = 1
|
||||
lr = 1e-4
|
||||
grad_clip = 1.0
|
||||
|
||||
calc_loss = True
|
||||
42
configs/vae/inference/image.py
Normal file
42
configs/vae/inference/image.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
num_frames = 1
|
||||
frame_interval = 1
|
||||
fps = 24
|
||||
image_size = (256, 256)
|
||||
|
||||
# Define dataset
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=num_frames,
|
||||
frame_interval=1,
|
||||
image_size=image_size,
|
||||
)
|
||||
num_workers = 4
|
||||
max_test_samples = None
|
||||
|
||||
# Define model
|
||||
model = dict(
|
||||
type="VideoAutoencoderPipeline",
|
||||
freeze_vae_2d=True,
|
||||
vae_2d=dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
),
|
||||
vae_temporal=dict(
|
||||
type="VAE_Temporal_SD",
|
||||
from_pretrained=None,
|
||||
),
|
||||
)
|
||||
dtype = "bf16"
|
||||
|
||||
# loss weights
|
||||
perceptual_loss_weight = 0.1 # VGG perceptual loss is used only if this is not None and greater than 0
|
||||
kl_loss_weight = 1e-6
|
||||
|
||||
# Others
|
||||
batch_size = 1
|
||||
seed = 42
|
||||
save_dir = "samples/vae_image"
|
||||
42
configs/vae/inference/video.py
Normal file
42
configs/vae/inference/video.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
num_frames = 17
|
||||
frame_interval = 1
|
||||
fps = 24
|
||||
image_size = (256, 256)
|
||||
|
||||
# Define dataset
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=num_frames,
|
||||
frame_interval=1,
|
||||
image_size=image_size,
|
||||
)
|
||||
num_workers = 4
|
||||
max_test_samples = None
|
||||
|
||||
# Define model
|
||||
model = dict(
|
||||
type="VideoAutoencoderPipeline",
|
||||
from_pretrained=None,
|
||||
vae_2d=dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
),
|
||||
vae_temporal=dict(
|
||||
type="VAE_Temporal_SD",
|
||||
from_pretrained=None,
|
||||
),
|
||||
)
|
||||
dtype = "bf16"
|
||||
|
||||
# loss weights
|
||||
perceptual_loss_weight = 0.1 # VGG perceptual loss is used only if this is not None and greater than 0
|
||||
kl_loss_weight = 1e-6
|
||||
|
||||
# Others
|
||||
batch_size = 1
|
||||
seed = 42
|
||||
save_dir = "samples/vae_video"
|
||||
|
|
@ -1,74 +0,0 @@
|
|||
num_frames = 17
|
||||
image_size = (256, 256)
|
||||
|
||||
# Define dataset
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=num_frames,
|
||||
frame_interval=1,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 16
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
sp_size = 1
|
||||
|
||||
# latest
|
||||
vae_2d = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type="VAE_Temporal_SD",
|
||||
)
|
||||
|
||||
|
||||
# discriminator = dict(
|
||||
# type="DISCRIMINATOR_3D",
|
||||
# image_size=image_size, # NOTE: here image size is different
|
||||
# num_frames=num_frames,
|
||||
# in_channels=3,
|
||||
# filters=128,
|
||||
# use_pretrained=True, # NOTE: set to False only if we want to disable load
|
||||
# channel_multipliers=(2, 4, 4, 4, 4), # (2,4,4,4) for 64x64 resolution
|
||||
# # channel_multipliers=(2, 4, 4), # since on intermediate layer dimension of z
|
||||
# )
|
||||
|
||||
|
||||
# loss weights
|
||||
logvar_init = 0.0
|
||||
kl_loss_weight = 0.000001
|
||||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
discriminator_factor = 1.0 # for discriminator adversarial loss
|
||||
generator_factor = 0.1 # SCH: generator adversarial loss, MAGVIT v2 uses 0.1
|
||||
lecam_loss_weight = None # NOTE: MAGVIT v2 uses 0.001
|
||||
discriminator_loss_type = "non-saturating"
|
||||
generator_loss_type = "non-saturating"
|
||||
# discriminator_loss_type="hinge"
|
||||
# generator_loss_type="hinge"
|
||||
discriminator_start = 2000 # 5000 # 8k data / (8*1) = 1000 steps per epoch
|
||||
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
|
||||
ema_decay = 0.999 # ema decay factor for generator
|
||||
|
||||
|
||||
# Others
|
||||
seed = 42
|
||||
outputs = "outputs"
|
||||
wandb = False
|
||||
|
||||
epochs = 100
|
||||
log_every = 1
|
||||
ckpt_every = 1000
|
||||
load = None
|
||||
|
||||
batch_size = 1
|
||||
lr = 1e-5
|
||||
grad_clip = 1.0
|
||||
|
|
@ -1,71 +0,0 @@
|
|||
num_frames = 1
|
||||
image_size = (256, 256)
|
||||
|
||||
# Define dataset
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=num_frames,
|
||||
frame_interval=1,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 16
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = False
|
||||
plugin = "zero2"
|
||||
sp_size = 1
|
||||
|
||||
# latest
|
||||
vae_2d = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type="VAE_Temporal_SD",
|
||||
)
|
||||
|
||||
# discriminator = dict(
|
||||
# type="DISCRIMINATOR_3D",
|
||||
# image_size=image_size, # NOTE: here image size is different
|
||||
# num_frames=num_frames,
|
||||
# in_channels=3,
|
||||
# filters=128,
|
||||
# use_pretrained=True, # NOTE: set to False only if we want to disable load
|
||||
# channel_multipliers=(2, 4, 4, 4, 4), # (2,4,4,4) for 64x64 resolution
|
||||
# # channel_multipliers=(2, 4, 4), # since on intermediate layer dimension of z
|
||||
# )
|
||||
|
||||
|
||||
# loss weights
|
||||
logvar_init = 0.0
|
||||
kl_loss_weight = 0.000001
|
||||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
discriminator_factor = 1.0 # for discriminator adversarial loss
|
||||
generator_factor = 0.1 # generator adversarial loss, MAGVIT v2 uses 0.1
|
||||
lecam_loss_weight = None # NOTE: MAGVIT v2 uses 0.001
|
||||
discriminator_loss_type = "non-saturating"
|
||||
generator_loss_type = "non-saturating"
|
||||
discriminator_start = 2000
|
||||
gradient_penalty_loss_weight = None # MAGVIT uses 10, opensora plan doesn't use
|
||||
ema_decay = 0.999 # ema decay factor for generator
|
||||
|
||||
|
||||
# Others
|
||||
seed = 42
|
||||
outputs = "outputs"
|
||||
wandb = False
|
||||
|
||||
epochs = 100
|
||||
log_every = 1
|
||||
ckpt_every = 1000
|
||||
load = None
|
||||
|
||||
batch_size = 4
|
||||
lr = 1e-5
|
||||
grad_clip = 1.0
|
||||
58
configs/vae/train/image.py
Normal file
58
configs/vae/train/image.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
num_frames = 1
|
||||
image_size = (256, 256)
|
||||
|
||||
# Define dataset
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=num_frames,
|
||||
frame_interval=1,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 16
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
|
||||
# Define model
|
||||
model = dict(
|
||||
type="VideoAutoencoderPipeline",
|
||||
freeze_vae_2d=True,
|
||||
from_pretrained=None,
|
||||
vae_2d=dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
),
|
||||
vae_temporal=dict(
|
||||
type="VAE_Temporal_SD",
|
||||
from_pretrained=None,
|
||||
),
|
||||
)
|
||||
|
||||
# loss weights
|
||||
perceptual_loss_weight = 0.1 # VGG perceptual loss is used only if this is not None and greater than 0
|
||||
kl_loss_weight = 1e-6
|
||||
|
||||
mixed_image_ratio = 0.1
|
||||
use_real_rec_loss = False
|
||||
use_z_rec_loss = True
|
||||
use_image_identity_loss = True
|
||||
|
||||
# Others
|
||||
seed = 42
|
||||
outputs = "outputs"
|
||||
wandb = False
|
||||
|
||||
epochs = 100
|
||||
log_every = 1
|
||||
ckpt_every = 1000
|
||||
load = None
|
||||
|
||||
batch_size = 1
|
||||
lr = 1e-5
|
||||
grad_clip = 1.0
|
||||
58
configs/vae/train/video.py
Normal file
58
configs/vae/train/video.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
num_frames = 17
|
||||
image_size = (256, 256)
|
||||
|
||||
# Define dataset
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=num_frames,
|
||||
frame_interval=1,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 16
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
|
||||
# Define model
|
||||
model = dict(
|
||||
type="VideoAutoencoderPipeline",
|
||||
freeze_vae_2d=True,
|
||||
from_pretrained=None,
|
||||
vae_2d=dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
),
|
||||
vae_temporal=dict(
|
||||
type="VAE_Temporal_SD",
|
||||
from_pretrained=None,
|
||||
),
|
||||
)
|
||||
|
||||
# loss weights
|
||||
perceptual_loss_weight = 0.1 # VGG perceptual loss is used only if this is not None and greater than 0
|
||||
kl_loss_weight = 1e-6
|
||||
|
||||
mixed_image_ratio = 0.1
|
||||
use_real_rec_loss = True
|
||||
use_z_rec_loss = False
|
||||
use_image_identity_loss = False
|
||||
|
||||
# Others
|
||||
seed = 42
|
||||
outputs = "outputs"
|
||||
wandb = False
|
||||
|
||||
epochs = 100
|
||||
log_every = 1
|
||||
ckpt_every = 1000
|
||||
load = None
|
||||
|
||||
batch_size = 1
|
||||
lr = 1e-5
|
||||
grad_clip = 1.0
|
||||
|
|
@ -94,13 +94,10 @@ class VAELoss(nn.Module):
|
|||
if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
|
||||
# handle channels
|
||||
channels = video.shape[1]
|
||||
assert channels in {1, 3, 4}
|
||||
assert channels in {1, 3}
|
||||
if channels == 1:
|
||||
input_vgg_input = repeat(video, "b 1 h w -> b c h w", c=3)
|
||||
recon_vgg_input = repeat(recon_video, "b 1 h w -> b c h w", c=3)
|
||||
elif channels == 4: # SCH: take the first 3 for perceptual loss calc
|
||||
input_vgg_input = video[:, :3]
|
||||
recon_vgg_input = recon_video[:, :3]
|
||||
else:
|
||||
input_vgg_input = video
|
||||
recon_vgg_input = recon_video
|
||||
|
|
@ -109,6 +106,7 @@ class VAELoss(nn.Module):
|
|||
recon_loss = recon_loss + self.perceptual_loss_weight * perceptual_loss
|
||||
|
||||
nll_loss = recon_loss / torch.exp(self.logvar) + self.logvar
|
||||
|
||||
weighted_nll_loss = nll_loss
|
||||
if nll_weights is not None:
|
||||
weighted_nll_loss = nll_weights * nll_loss
|
||||
|
|
|
|||
|
|
@ -3,15 +3,20 @@ import torch.nn as nn
|
|||
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
|
||||
from einops import rearrange
|
||||
|
||||
from opensora.registry import MODELS
|
||||
from opensora.registry import MODELS, build_module
|
||||
from opensora.utils.ckpt_utils import load_checkpoint
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class VideoAutoencoderKL(nn.Module):
|
||||
def __init__(self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False, subfolder=None):
|
||||
def __init__(
|
||||
self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False, subfolder=None
|
||||
):
|
||||
super().__init__()
|
||||
self.module = AutoencoderKL.from_pretrained(
|
||||
from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only,
|
||||
from_pretrained,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
self.out_channels = self.module.config.latent_channels
|
||||
|
|
@ -107,3 +112,47 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module):
|
|||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class VideoAutoencoderPipeline(nn.Module):
|
||||
def __init__(self, vae_2d=None, vae_temporal=None, freeze_vae_2d=False, from_pretrained=None):
|
||||
super().__init__()
|
||||
self.spatial_vae = build_module(vae_2d, MODELS)
|
||||
self.temporal_vae = build_module(vae_temporal, MODELS)
|
||||
if from_pretrained is not None:
|
||||
load_checkpoint(self, from_pretrained)
|
||||
if freeze_vae_2d:
|
||||
for param in self.spatial_vae.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def encode(self, x, training=True):
|
||||
x_z = self.spatial_vae.encode(x)
|
||||
posterior = self.temporal_vae.encode(x_z)
|
||||
z = posterior.sample()
|
||||
if training:
|
||||
return z, posterior, x_z
|
||||
return z
|
||||
|
||||
def decode(self, z, num_frames=None, training=True):
|
||||
x_z = self.temporal_vae.decode(z, num_frames=num_frames)
|
||||
x = self.spatial_vae.decode(x_z)
|
||||
if training:
|
||||
return x, x_z
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
z, posterior, x_z = self.encode(x, training=True)
|
||||
x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2], training=True)
|
||||
return x_rec, x_z_rec, z, posterior, x_z
|
||||
|
||||
def get_latent_size(self, input_size):
|
||||
return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size))
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
import functools
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
|
@ -82,8 +80,6 @@ class ResBlock(nn.Module):
|
|||
activation_fn=nn.SiLU,
|
||||
use_conv_shortcut=False,
|
||||
num_groups=32,
|
||||
device="cpu",
|
||||
dtype=torch.bfloat16,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
|
@ -92,9 +88,9 @@ class ResBlock(nn.Module):
|
|||
self.use_conv_shortcut = use_conv_shortcut
|
||||
|
||||
# SCH: MAGVIT uses GroupNorm by default
|
||||
self.norm1 = nn.GroupNorm(num_groups, in_channels, device=device, dtype=dtype)
|
||||
self.norm1 = nn.GroupNorm(num_groups, in_channels)
|
||||
self.conv1 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
|
||||
self.norm2 = nn.GroupNorm(num_groups, self.filters, device=device, dtype=dtype)
|
||||
self.norm2 = nn.GroupNorm(num_groups, self.filters)
|
||||
self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False)
|
||||
if in_channels != filters:
|
||||
if self.use_conv_shortcut:
|
||||
|
|
@ -103,8 +99,6 @@ class ResBlock(nn.Module):
|
|||
self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(1, 1, 1), bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
# device, dtype = x.device, x.dtype
|
||||
# input_dim = x.shape[1]
|
||||
residual = x
|
||||
x = self.norm1(x)
|
||||
x = self.activate(x)
|
||||
|
|
@ -140,8 +134,6 @@ class Encoder(nn.Module):
|
|||
temporal_downsample=(False, True, True),
|
||||
num_groups=32, # for nn.GroupNorm
|
||||
activation_fn="swish",
|
||||
device="cpu",
|
||||
dtype=torch.bfloat16,
|
||||
):
|
||||
super().__init__()
|
||||
self.filters = filters
|
||||
|
|
@ -154,18 +146,12 @@ class Encoder(nn.Module):
|
|||
|
||||
self.activation_fn = get_activation_fn(activation_fn)
|
||||
self.activate = self.activation_fn()
|
||||
self.conv_fn = functools.partial(
|
||||
CausalConv3d,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.conv_fn = CausalConv3d
|
||||
self.block_args = dict(
|
||||
conv_fn=self.conv_fn,
|
||||
dtype=dtype,
|
||||
activation_fn=self.activation_fn,
|
||||
use_conv_shortcut=False,
|
||||
num_groups=self.num_groups,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# first layer conv
|
||||
|
|
@ -174,8 +160,6 @@ class Encoder(nn.Module):
|
|||
filters,
|
||||
kernel_size=(3, 3, 3),
|
||||
bias=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# ResBlocks and conv downsample
|
||||
|
|
@ -214,13 +198,9 @@ class Encoder(nn.Module):
|
|||
prev_filters = filters # update in_channels
|
||||
|
||||
# MAGVIT uses Group Normalization
|
||||
self.norm1 = nn.GroupNorm(
|
||||
self.num_groups, prev_filters, dtype=dtype, device=device
|
||||
) # separate <prev_filters> channels into 32 groups
|
||||
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters)
|
||||
|
||||
self.conv2 = self.conv_fn(
|
||||
prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), dtype=dtype, device=device, padding="same"
|
||||
)
|
||||
self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), padding="same")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
|
|
@ -252,8 +232,6 @@ class Decoder(nn.Module):
|
|||
temporal_downsample=(False, True, True),
|
||||
num_groups=32, # for nn.GroupNorm
|
||||
activation_fn="swish",
|
||||
device="cpu",
|
||||
dtype=torch.bfloat16,
|
||||
):
|
||||
super().__init__()
|
||||
self.filters = filters
|
||||
|
|
@ -267,18 +245,12 @@ class Decoder(nn.Module):
|
|||
|
||||
self.activation_fn = get_activation_fn(activation_fn)
|
||||
self.activate = self.activation_fn()
|
||||
self.conv_fn = functools.partial(
|
||||
CausalConv3d,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.conv_fn = CausalConv3d
|
||||
self.block_args = dict(
|
||||
conv_fn=self.conv_fn,
|
||||
activation_fn=self.activation_fn,
|
||||
use_conv_shortcut=False,
|
||||
num_groups=self.num_groups,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
filters = self.filters * self.channel_multipliers[-1]
|
||||
|
|
@ -323,9 +295,9 @@ class Decoder(nn.Module):
|
|||
nn.Identity(prev_filters),
|
||||
)
|
||||
|
||||
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, device=device, dtype=dtype)
|
||||
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters)
|
||||
|
||||
self.conv_out = self.conv_fn(filters, in_out_channels, 3, dtype=dtype, device=device)
|
||||
self.conv_out = self.conv_fn(filters, in_out_channels, 3)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
|
|
@ -364,8 +336,6 @@ class VAE_Temporal(nn.Module):
|
|||
temporal_downsample=(True, True, False),
|
||||
num_groups=32, # for nn.GroupNorm
|
||||
activation_fn="swish",
|
||||
device="cpu",
|
||||
dtype=torch.bfloat16,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
|
@ -383,12 +353,10 @@ class VAE_Temporal(nn.Module):
|
|||
temporal_downsample=temporal_downsample,
|
||||
num_groups=num_groups, # for nn.GroupNorm
|
||||
activation_fn=activation_fn,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1, device=device, dtype=dtype)
|
||||
self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)
|
||||
|
||||
self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1, device=device, dtype=dtype)
|
||||
self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)
|
||||
self.decoder = Decoder(
|
||||
in_out_channels=in_out_channels,
|
||||
latent_embed_dim=latent_embed_dim,
|
||||
|
|
@ -398,8 +366,6 @@ class VAE_Temporal(nn.Module):
|
|||
temporal_downsample=temporal_downsample,
|
||||
num_groups=num_groups, # for nn.GroupNorm
|
||||
activation_fn=activation_fn,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def get_latent_size(self, input_size):
|
||||
|
|
|
|||
|
|
@ -79,6 +79,10 @@ def merge_args(cfg, args, training=False):
|
|||
if args.data_path is not None:
|
||||
cfg.dataset["data_path"] = args.data_path
|
||||
args.data_path = None
|
||||
if not training and args.image_size is not None and "dataset" in cfg:
|
||||
cfg.dataset["image_size"] = args.image_size
|
||||
if not training and args.num_frames is not None and "dataset" in cfg:
|
||||
cfg.dataset["num_frames"] = args.num_frames
|
||||
if not training and args.cfg_scale is not None:
|
||||
cfg.scheduler["cfg_scale"] = args.cfg_scale
|
||||
args.cfg_scale = None
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from colossalai.cluster import DistCoordinator
|
|||
from mmengine.runner import set_random_seed
|
||||
from tqdm import tqdm
|
||||
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group, set_sequence_parallel_group
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets import prepare_dataloader, save_sample
|
||||
from opensora.models.vae.losses import VAELoss
|
||||
from opensora.registry import DATASETS, MODELS, build_module
|
||||
|
|
@ -27,11 +27,6 @@ def main():
|
|||
use_dist = True
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
if coordinator.world_size > 1:
|
||||
set_sequence_parallel_group(dist.group.WORLD)
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
use_dist = False
|
||||
|
||||
|
|
@ -59,88 +54,40 @@ def main():
|
|||
process_group=get_data_parallel_group(),
|
||||
)
|
||||
print(f"Dataset contains {len(dataset):,} videos ({cfg.dataset.data_path})")
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size()
|
||||
print(f"Total batch size: {total_batch_size}")
|
||||
|
||||
# ======================================================
|
||||
# 4. build model & load weights
|
||||
# ======================================================
|
||||
# 3.1. build model
|
||||
if cfg.get("vae_2d", None) is not None:
|
||||
vae_2d = build_module(cfg.vae_2d, MODELS)
|
||||
vae_2d.to(device, dtype).eval()
|
||||
model = build_module(
|
||||
cfg.model,
|
||||
MODELS,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
# discriminator = build_module(cfg.discriminator, MODELS, device=device)
|
||||
|
||||
# 3.2. move to device & eval
|
||||
# discriminator = discriminator.to(device, dtype).eval()
|
||||
|
||||
# 3.4. support for multi-resolution
|
||||
# model_args = dict()
|
||||
# if cfg.multi_resolution:
|
||||
# image_size = cfg.dataset.image_size
|
||||
# hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
|
||||
# ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
|
||||
# model_args["data_info"] = dict(ar=ar, hw=hw)
|
||||
# 4.1. build model
|
||||
model = build_module(cfg.model, MODELS)
|
||||
model.to(device, dtype).eval()
|
||||
|
||||
# ======================================================
|
||||
# 4. inference
|
||||
# 5. inference
|
||||
# ======================================================
|
||||
save_dir = cfg.save_dir
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# 4.1. batch generation
|
||||
|
||||
# define loss function
|
||||
if cfg.calc_loss:
|
||||
vae_loss_fn = VAELoss(
|
||||
logvar_init=cfg.logvar_init,
|
||||
perceptual_loss_weight=cfg.perceptual_loss_weight,
|
||||
kl_loss_weight=cfg.kl_loss_weight,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# adversarial_loss_fn = AdversarialLoss(
|
||||
# discriminator_factor=cfg.discriminator_factor,
|
||||
# discriminator_start=cfg.discriminator_start,
|
||||
# generator_factor=cfg.generator_factor,
|
||||
# generator_loss_type=cfg.generator_loss_type,
|
||||
# )
|
||||
|
||||
# disc_loss_fn = DiscriminatorLoss(
|
||||
# discriminator_factor=cfg.discriminator_factor,
|
||||
# discriminator_start=cfg.discriminator_start,
|
||||
# discriminator_loss_type=cfg.discriminator_loss_type,
|
||||
# lecam_loss_weight=cfg.lecam_loss_weight,
|
||||
# gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
|
||||
# )
|
||||
|
||||
# # LeCam EMA for discriminator
|
||||
|
||||
# lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
|
||||
|
||||
running_loss = 0.0
|
||||
running_nll = 0.0
|
||||
loss_steps = 0
|
||||
|
||||
# disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
|
||||
# if cfg.dataset.num_frames % disc_time_downsample_factor != 0:
|
||||
# disc_time_padding = disc_time_downsample_factor - cfg.dataset.num_frames % disc_time_downsample_factor
|
||||
# else:
|
||||
# disc_time_padding = 0
|
||||
vae_loss_fn = VAELoss(
|
||||
logvar_init=cfg.get("logvar_init", 0.0),
|
||||
perceptual_loss_weight=cfg.perceptual_loss_weight,
|
||||
kl_loss_weight=cfg.kl_loss_weight,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# get total number of steps
|
||||
total_steps = len(dataloader)
|
||||
if cfg.max_test_samples is not None:
|
||||
total_steps = min(int(cfg.max_test_samples // cfg.batch_size), total_steps)
|
||||
print(f"limiting test dataset to {int(cfg.max_test_samples//cfg.batch_size) * cfg.batch_size}")
|
||||
dataloader_iter = iter(dataloader)
|
||||
|
||||
running_loss = running_nll = 0.0
|
||||
loss_steps = 0
|
||||
with tqdm(
|
||||
range(total_steps),
|
||||
disable=not coordinator.is_master(),
|
||||
|
|
@ -151,95 +98,28 @@ def main():
|
|||
batch = next(dataloader_iter)
|
||||
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
|
||||
|
||||
# ===== Spatial VAE =====
|
||||
if cfg.get("vae_2d", None) is not None:
|
||||
x_z = vae_2d.encode(x)
|
||||
x_z_debug = vae_2d.decode(x_z)
|
||||
# ===== VAE =====
|
||||
z, posterior, x_z = model.encode(x, training=True)
|
||||
x_rec, _ = model.decode(z, num_frames=x.size(2))
|
||||
x_ref = model.spatial_vae.decode(x_z)
|
||||
|
||||
# ====== VAE ======
|
||||
x_z_rec, posterior, z = model(x_z)
|
||||
x_rec = vae_2d.decode(x_z_rec)
|
||||
|
||||
if cfg.calc_loss:
|
||||
# simple nll loss
|
||||
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
|
||||
|
||||
# fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
# fake_logits = discriminator(fake_video.contiguous())
|
||||
# adversarial_loss = adversarial_loss_fn(
|
||||
# fake_logits,
|
||||
# nll_loss,
|
||||
# vae.get_last_layer(),
|
||||
# cfg.discriminator_start + 1, # Hack to use discriminator
|
||||
# is_training=vae.training,
|
||||
# )
|
||||
|
||||
# vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
|
||||
vae_loss = weighted_nll_loss + weighted_kl_loss
|
||||
|
||||
# # ====== Discriminator Loss ======
|
||||
# real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
# fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
|
||||
# if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
|
||||
# real_video = real_video.requires_grad_()
|
||||
# real_logits = discriminator(
|
||||
# real_video.contiguous()
|
||||
# ) # SCH: not detached for now for gradient_penalty calculation
|
||||
# else:
|
||||
# real_logits = discriminator(real_video.contiguous().detach())
|
||||
|
||||
# fake_logits = discriminator(fake_video.contiguous().detach())
|
||||
|
||||
# lecam_ema_real, lecam_ema_fake = lecam_ema.get()
|
||||
# weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
|
||||
# real_logits,
|
||||
# fake_logits,
|
||||
# cfg.discriminator_start + 1, # Hack to use discriminator
|
||||
# lecam_ema_real=lecam_ema_real,
|
||||
# lecam_ema_fake=lecam_ema_fake,
|
||||
# real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None,
|
||||
# )
|
||||
|
||||
# disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
|
||||
|
||||
loss_steps += 1
|
||||
# running_disc_loss = disc_loss.item() / loss_steps + running_disc_loss * ((loss_steps - 1) / loss_steps)
|
||||
running_loss = vae_loss.item() / loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
|
||||
running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
|
||||
|
||||
# ===== Spatial VAE =====
|
||||
# loss calculation
|
||||
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
|
||||
vae_loss = weighted_nll_loss + weighted_kl_loss
|
||||
loss_steps += 1
|
||||
running_loss = vae_loss.item() / loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
|
||||
running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
|
||||
|
||||
if not use_dist or coordinator.is_master():
|
||||
for idx in range(len(x)):
|
||||
for idx, vid in enumerate(x):
|
||||
pos = step * cfg.batch_size + idx
|
||||
save_path = os.path.join(save_dir, f"sample_{pos}")
|
||||
save_sample(x[idx], fps=cfg.fps, save_path=save_path + "_original")
|
||||
save_sample(x_rec[idx], fps=cfg.fps, save_path=save_path + "_pipeline")
|
||||
if cfg.get("vae_2d", None) is not None:
|
||||
save_sample(x_z_debug[idx], fps=cfg.fps, save_path=save_path + "_2d")
|
||||
save_path = os.path.join(save_dir, f"sample_{pos:03d}")
|
||||
save_sample(vid, fps=cfg.fps, save_path=save_path + "_ori")
|
||||
save_sample(x_rec[idx], fps=cfg.fps, save_path=save_path + "_rec")
|
||||
save_sample(x_ref[idx], fps=cfg.fps, save_path=save_path + "_ref")
|
||||
|
||||
# if cfg.get("use_pipeline") == True:
|
||||
# for idx, (sample_original, sample_pipeline, sample_2d) in enumerate(
|
||||
# zip(video, recon_video, recon_2d)
|
||||
# ):
|
||||
# pos = step * cfg.batch_size + idx
|
||||
# save_path = os.path.join(save_dir, f"sample_{pos}")
|
||||
# save_sample(sample_original, fps=cfg.fps, save_path=save_path + "_original")
|
||||
# save_sample(sample_2d, fps=cfg.fps, save_path=save_path + "_2d")
|
||||
# save_sample(sample_pipeline, fps=cfg.fps, save_path=save_path + "_pipeline")
|
||||
|
||||
# else:
|
||||
# for idx, (original, recon) in enumerate(zip(video, recon_video)):
|
||||
# pos = step * cfg.batch_size + idx
|
||||
# save_path = os.path.join(save_dir, f"sample_{pos}")
|
||||
# save_sample(original, fps=cfg.fps, save_path=save_path + "_original")
|
||||
# save_sample(recon, fps=cfg.fps, save_path=save_path + "_recon")
|
||||
|
||||
if cfg.calc_loss:
|
||||
print("test vae loss:", running_loss)
|
||||
print("test nll loss:", running_nll)
|
||||
# print("test disc loss:", running_disc_loss)
|
||||
print("test vae loss:", running_loss)
|
||||
print("test nll loss:", running_nll)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -14,14 +14,9 @@ from tqdm import tqdm
|
|||
|
||||
import wandb
|
||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||
from opensora.acceleration.parallel_states import (
|
||||
get_data_parallel_group,
|
||||
set_data_parallel_group,
|
||||
set_sequence_parallel_group,
|
||||
)
|
||||
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group
|
||||
from opensora.datasets import prepare_dataloader
|
||||
from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VAELoss
|
||||
from opensora.models.vae.losses import VAELoss
|
||||
from opensora.registry import DATASETS, MODELS, build_module
|
||||
from opensora.utils.ckpt_utils import create_logger, load_json, save_json
|
||||
from opensora.utils.config_utils import (
|
||||
|
|
@ -78,16 +73,6 @@ def main():
|
|||
max_norm=cfg.grad_clip,
|
||||
)
|
||||
set_data_parallel_group(dist.group.WORLD)
|
||||
elif cfg.plugin == "zero2-seq":
|
||||
plugin = ZeroSeqParallelPlugin(
|
||||
sp_size=cfg.sp_size,
|
||||
stage=2,
|
||||
precision=cfg.dtype,
|
||||
initial_scale=2**16,
|
||||
max_norm=cfg.grad_clip,
|
||||
)
|
||||
set_sequence_parallel_group(plugin.sp_group)
|
||||
set_data_parallel_group(plugin.dp_group)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {cfg.plugin}")
|
||||
booster = Booster(plugin=plugin)
|
||||
|
|
@ -110,83 +95,40 @@ def main():
|
|||
)
|
||||
# TODO: use plugin's prepare dataloader
|
||||
dataloader = prepare_dataloader(**dataloader_args)
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size()
|
||||
logger.info(f"Total batch size: {total_batch_size}")
|
||||
|
||||
# ======================================================
|
||||
# 4. build model
|
||||
# ======================================================
|
||||
# 4.1. build model
|
||||
if cfg.get("vae_2d", None) is not None:
|
||||
vae_2d = build_module(cfg.vae_2d, MODELS)
|
||||
vae_2d.to(device, dtype).eval()
|
||||
|
||||
model = build_module(
|
||||
cfg.model,
|
||||
MODELS,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
model = build_module(cfg.model, MODELS)
|
||||
model.to(device, dtype)
|
||||
model_numel, model_numel_trainable = get_model_numel(model)
|
||||
logger.info(
|
||||
f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}"
|
||||
)
|
||||
|
||||
# discriminator = build_module(cfg.discriminator, MODELS, device=device)
|
||||
# discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator)
|
||||
# logger.info(
|
||||
# f"Trainable discriminator params: {format_numel_str(discriminator_numel_trainable)}, Total model params: {format_numel_str(discriminator_numel)}"
|
||||
# )
|
||||
|
||||
# # LeCam Initialization
|
||||
# lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
|
||||
|
||||
# 4.3. move to device
|
||||
model = model.to(device, dtype)
|
||||
# discriminator = discriminator.to(device, dtype)
|
||||
|
||||
# 4.4 loss functions
|
||||
vae_loss_fn = VAELoss(
|
||||
logvar_init=cfg.logvar_init,
|
||||
logvar_init=cfg.get("logvar_init", 0.0),
|
||||
perceptual_loss_weight=cfg.perceptual_loss_weight,
|
||||
kl_loss_weight=cfg.kl_loss_weight,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
adversarial_loss_fn = AdversarialLoss(
|
||||
discriminator_factor=cfg.discriminator_factor,
|
||||
discriminator_start=cfg.discriminator_start,
|
||||
generator_factor=cfg.generator_factor,
|
||||
generator_loss_type=cfg.generator_loss_type,
|
||||
)
|
||||
|
||||
disc_loss_fn = DiscriminatorLoss(
|
||||
discriminator_factor=cfg.discriminator_factor,
|
||||
discriminator_start=cfg.discriminator_start,
|
||||
discriminator_loss_type=cfg.discriminator_loss_type,
|
||||
lecam_loss_weight=cfg.lecam_loss_weight,
|
||||
gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
|
||||
)
|
||||
|
||||
# 4.5. setup optimizer
|
||||
# vae optimizer
|
||||
optimizer = HybridAdam(
|
||||
filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True
|
||||
)
|
||||
lr_scheduler = None
|
||||
# disc optimizer
|
||||
# disc_optimizer = HybridAdam(
|
||||
# filter(lambda p: p.requires_grad, discriminator.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True
|
||||
# )
|
||||
# disc_lr_scheduler = None
|
||||
|
||||
# 4.6. prepare for training
|
||||
if cfg.grad_checkpoint:
|
||||
set_grad_checkpoint(model)
|
||||
# set_grad_checkpoint(discriminator)
|
||||
model.train()
|
||||
# discriminator.train()
|
||||
|
||||
# =======================================================
|
||||
# 5. boost model for distributed training with colossalai
|
||||
|
|
@ -203,11 +145,6 @@ def main():
|
|||
logger.info("Boost model for distributed training")
|
||||
num_steps_per_epoch = len(dataloader)
|
||||
|
||||
# discriminator, disc_optimizer, _, _, disc_lr_scheduler = booster.boost(
|
||||
# model=discriminator, optimizer=disc_optimizer, lr_scheduler=disc_lr_scheduler
|
||||
# )
|
||||
# logger.info("Boost discriminator for distributed training")
|
||||
|
||||
# =======================================================
|
||||
# 6. training loop
|
||||
# =======================================================
|
||||
|
|
@ -221,18 +158,6 @@ def main():
|
|||
booster.load_model(model, os.path.join(cfg.load, "model"))
|
||||
booster.load_optimizer(optimizer, os.path.join(cfg.load, "optimizer"))
|
||||
|
||||
# booster.load_model(discriminator, os.path.join(cfg.load, "discriminator"))
|
||||
# booster.load_optimizer(disc_optimizer, os.path.join(cfg.load, "disc_optimizer"))
|
||||
|
||||
# LeCam EMA for discriminator
|
||||
# lecam_path = os.path.join(cfg.load, "lecam_states.json")
|
||||
# if cfg.lecam_loss_weight is not None and os.path.exists(lecam_path):
|
||||
# lecam_state = load_json(lecam_path)
|
||||
# lecam_ema_real, lecam_ema_fake = lecam_state["lecam_ema_real"], lecam_state["lecam_ema_fake"]
|
||||
# lecam_ema = LeCamEMA(
|
||||
# decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device
|
||||
# )
|
||||
|
||||
running_states = load_json(os.path.join(cfg.load, "running_states.json"))
|
||||
dist.barrier()
|
||||
start_epoch, start_step, sampler_start_idx = (
|
||||
|
|
@ -246,15 +171,6 @@ def main():
|
|||
dataloader.sampler.set_start_index(sampler_start_idx)
|
||||
|
||||
# 6.3. training loop
|
||||
|
||||
# calculate discriminator_time_padding
|
||||
# disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
|
||||
# if cfg.dataset.num_frames % disc_time_downsample_factor != 0:
|
||||
# disc_time_padding = disc_time_downsample_factor - cfg.dataset.num_frames % disc_time_downsample_factor
|
||||
# else:
|
||||
# disc_time_padding = 0
|
||||
# video_contains_first_frame = cfg.video_contains_first_frame
|
||||
|
||||
for epoch in range(start_epoch, cfg.epochs):
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
dataloader_iter = iter(dataloader)
|
||||
|
|
@ -269,112 +185,41 @@ def main():
|
|||
) as pbar:
|
||||
for step, batch in pbar:
|
||||
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
|
||||
if random.random() < 0.5:
|
||||
if random.random() < cfg.get("mixed_image_ratio", 0.0):
|
||||
x = x[:, :, :1, :, :]
|
||||
|
||||
# ===== Spatial VAE =====
|
||||
if cfg.get("vae_2d", None) is not None:
|
||||
with torch.no_grad():
|
||||
x_z = vae_2d.encode(x)
|
||||
vae_2d.decode(x_z)
|
||||
|
||||
# ====== VAE ======
|
||||
x_z_rec, posterior, z = model(x_z)
|
||||
x_rec = vae_2d.decode(x_z_rec)
|
||||
# ===== VAE =====
|
||||
x_rec, x_z_rec, z, posterior, x_z = model(x)
|
||||
|
||||
# ====== Generator Loss ======
|
||||
# simple nll loss
|
||||
_, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
|
||||
# _, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior)
|
||||
# _, debug_loss, _ = vae_loss_fn(x, x_z_debug, posterior)
|
||||
# _, image_identity_loss, _ = vae_loss_fn(x_z, z, posterior)
|
||||
vae_loss = torch.tensor(0.0, device=device, dtype=dtype)
|
||||
log_dict = {}
|
||||
if cfg.get("use_real_rec_loss", False):
|
||||
_, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
|
||||
vae_loss += weighted_nll_loss + weighted_kl_loss
|
||||
log_dict["kl_loss"] = weighted_kl_loss.item()
|
||||
log_dict["nll_loss"] = weighted_nll_loss.item()
|
||||
if cfg.get("use_z_rec_loss", False):
|
||||
_, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior)
|
||||
vae_loss += weighted_z_nll_loss
|
||||
log_dict["z_nll_loss"] = weighted_z_nll_loss.item()
|
||||
if cfg.get("use_image_identity_loss", False):
|
||||
_, image_identity_loss, _ = vae_loss_fn(x_z, z, posterior)
|
||||
vae_loss += image_identity_loss
|
||||
log_dict["image_identity_loss"] = image_identity_loss.item()
|
||||
|
||||
# adversarial_loss = torch.tensor(0.0)
|
||||
# adversarial loss
|
||||
# if global_step > cfg.discriminator_start:
|
||||
# # padded videos for GAN
|
||||
# fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
# fake_logits = discriminator(fake_video.contiguous())
|
||||
# adversarial_loss = adversarial_loss_fn(
|
||||
# fake_logits,
|
||||
# nll_loss,
|
||||
# vae.module.get_last_layer(),
|
||||
# global_step,
|
||||
# is_training=vae.training,
|
||||
# )
|
||||
|
||||
# vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss + weighted_z_nll_loss
|
||||
# vae_loss = weighted_nll_loss + weighted_kl_loss + weighted_z_nll_loss + image_identity_loss
|
||||
# vae_loss = weighted_z_nll_loss + image_identity_loss
|
||||
# vae_loss = weighted_nll_loss + weighted_kl_loss + weighted_z_nll_loss
|
||||
# vae_loss = weighted_z_nll_loss
|
||||
vae_loss = weighted_nll_loss + weighted_kl_loss
|
||||
|
||||
optimizer.zero_grad()
|
||||
# Backward & update
|
||||
booster.backward(loss=vae_loss, optimizer=optimizer)
|
||||
# # NOTE: clip gradients? this is done in Open-Sora-Plan
|
||||
# torch.nn.utils.clip_grad_norm_(vae.parameters(), 1) # NOTE: done by grad_clip
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Log loss values:
|
||||
all_reduce_mean(vae_loss) # NOTE: this is to get average loss for logging
|
||||
all_reduce_mean(vae_loss)
|
||||
running_loss += vae_loss.item()
|
||||
global_step = epoch * num_steps_per_epoch + step
|
||||
log_step += 1
|
||||
acc_step += 1
|
||||
|
||||
# ====== Discriminator Loss ======
|
||||
# if global_step > cfg.discriminator_start:
|
||||
# # if video_contains_first_frame:
|
||||
# # Since we don't have enough T frames, pad anyways
|
||||
# real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
# fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
|
||||
# if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
|
||||
# real_video = real_video.requires_grad_()
|
||||
# real_logits = discriminator(
|
||||
# real_video.contiguous()
|
||||
# ) # SCH: not detached for now for gradient_penalty calculation
|
||||
# else:
|
||||
# real_logits = discriminator(real_video.contiguous().detach())
|
||||
|
||||
# fake_logits = discriminator(fake_video.contiguous().detach())
|
||||
|
||||
# lecam_ema_real, lecam_ema_fake = lecam_ema.get()
|
||||
|
||||
# weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
|
||||
# real_logits,
|
||||
# fake_logits,
|
||||
# global_step,
|
||||
# lecam_ema_real=lecam_ema_real,
|
||||
# lecam_ema_fake=lecam_ema_fake,
|
||||
# real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None,
|
||||
# )
|
||||
# disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
|
||||
# if cfg.lecam_loss_weight is not None:
|
||||
# ema_real = torch.mean(real_logits.clone().detach()).to(device, dtype)
|
||||
# ema_fake = torch.mean(fake_logits.clone().detach()).to(device, dtype)
|
||||
# all_reduce_mean(ema_real)
|
||||
# all_reduce_mean(ema_fake)
|
||||
# lecam_ema.update(ema_real, ema_fake)
|
||||
|
||||
# disc_optimizer.zero_grad()
|
||||
# # Backward & update
|
||||
# booster.backward(loss=disc_loss, optimizer=disc_optimizer)
|
||||
# # # NOTE: TODO: clip gradients? this is done in Open-Sora-Plan
|
||||
# # torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1) # NOTE: done by grad_clip
|
||||
# disc_optimizer.step()
|
||||
|
||||
# # Log loss values:
|
||||
# all_reduce_mean(disc_loss)
|
||||
# running_disc_loss += disc_loss.item()
|
||||
# else:
|
||||
# disc_loss = torch.tensor(0.0)
|
||||
# weighted_d_adversarial_loss = torch.tensor(0.0)
|
||||
# lecam_loss = torch.tensor(0.0)
|
||||
# gradient_penalty_loss = torch.tensor(0.0)
|
||||
|
||||
# Log to tensorboard
|
||||
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
|
||||
avg_loss = running_loss / log_step
|
||||
|
|
@ -393,16 +238,8 @@ def main():
|
|||
"num_samples": global_step * total_batch_size,
|
||||
"epoch": epoch,
|
||||
"loss": vae_loss.item(),
|
||||
"kl_loss": weighted_kl_loss.item(),
|
||||
# "gen_adv_loss": adversarial_loss.item(),
|
||||
# "disc_loss": disc_loss.item(),
|
||||
# "lecam_loss": lecam_loss.item(),
|
||||
# "r1_grad_penalty": gradient_penalty_loss.item(),
|
||||
"nll_loss": weighted_nll_loss.item(),
|
||||
# "z_nll_loss": weighted_z_nll_loss.item(),
|
||||
# "image_identity_loss": image_identity_loss.item(),
|
||||
# "debug_loss": debug_loss.item(),
|
||||
"avg_loss": avg_loss,
|
||||
**log_dict,
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
|
|
@ -412,38 +249,22 @@ def main():
|
|||
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
|
||||
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) # already handled in booster save_model
|
||||
booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
|
||||
# booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
|
||||
booster.save_optimizer(
|
||||
optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096
|
||||
)
|
||||
# booster.save_optimizer(
|
||||
# disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096
|
||||
# )
|
||||
|
||||
running_states = {
|
||||
"epoch": epoch,
|
||||
"step": step + 1,
|
||||
"global_step": global_step + 1,
|
||||
"sample_start_index": (step + 1) * cfg.batch_size,
|
||||
}
|
||||
|
||||
# lecam_ema_real, lecam_ema_fake = lecam_ema.get()
|
||||
# lecam_state = {
|
||||
# "lecam_ema_real": lecam_ema_real.item(),
|
||||
# "lecam_ema_fake": lecam_ema_fake.item(),
|
||||
# }
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
# if cfg.lecam_loss_weight is not None:
|
||||
# save_json(lecam_state, os.path.join(save_dir, "lecam_states.json"))
|
||||
dist.barrier()
|
||||
|
||||
logger.info(
|
||||
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
|
||||
)
|
||||
|
||||
# print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
|
||||
|
||||
# only the first epoch after a resume starts mid-way; for subsequent epochs, reset the sampler start index and start step to 0
|
||||
dataloader.sampler.set_start_index(0)
|
||||
start_step = 0
|
||||
|
|
|
|||
Loading…
Reference in a new issue