[wip] debug vae

zhengzangw 2024-04-26 07:27:26 +00:00
parent a4652a8aef
commit 478b585024
4 changed files with 745 additions and 703 deletions


@@ -1,21 +1,20 @@
num_frames = 16
frame_interval = 3
image_size = (128, 128)
use_pipeline = True
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
video_contains_first_frame = False
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=16,
frame_interval=3,
image_size=(128, 128),
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
use_pipeline = True
video_contains_first_frame = False
# Define model
vae_2d = dict(
@@ -23,50 +22,49 @@ vae_2d = dict(
from_pretrained="stabilityai/sd-vae-ft-ema",
# SDXL
)
model = dict(
type="VAE_MAGVIT_V2",
in_out_channels = 4,
latent_embed_dim = 4,
filters = 128,
num_res_blocks = 4,
channel_multipliers = (1, 2, 2, 4),
temporal_downsample = (False, True, True),
num_groups = 32, # for nn.GroupNorm
kl_embed_dim = 4,
activation_fn = 'swish',
separate_first_frame_encoding = False,
disable_space = True,
encoder_double_z = True,
custom_conv_padding = None
in_out_channels=4,
latent_embed_dim=4,
filters=128,
num_res_blocks=4,
channel_multipliers=(1, 2, 2, 4),
temporal_downsample=(False, True, True),
num_groups=32, # for nn.GroupNorm
kl_embed_dim=4,
activation_fn="swish",
separate_first_frame_encoding=False,
disable_space=True,
encoder_double_z=True,
custom_conv_padding=None,
)
discriminator = dict(
type="DISCRIMINATOR_3D",
image_size = (16, 16), # NOTE: here image size is different
num_frames = num_frames,
in_channels = 4,
filters = 128,
use_pretrained=True, # NOTE: set to False only if we want to disable load
image_size=(16, 16), # NOTE: here image size is different
num_frames=16,
in_channels=4,
filters=128,
use_pretrained=True, # NOTE: set to False only if we want to disable load
# channel_multipliers = (2,4,4,4,4), # (2,4,4,4) for 64x64 resolution
channel_multipliers= (2,4,4) # since it operates on the intermediate-layer dimensions of z
channel_multipliers=(2, 4, 4), # since it operates on the intermediate-layer dimensions of z
)
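Why image_size=(16, 16) here: in_channels=4 matches latent_embed_dim, so the discriminator appears to consume VAE latents, and (16, 16) would be the latent resolution of 128x128 inputs. A quick sanity-check sketch, assuming one 2x spatial downsample between the four channel stages and one 2x temporal downsample per True entry in temporal_downsample (both assumptions, not confirmed by this diff):

image_size_in = 128
spatial_factor = 2 ** (len((1, 2, 2, 4)) - 1)    # assumed: 3 spatial downsamples -> 8x
latent_hw = image_size_in // spatial_factor      # 128 // 8 = 16
temporal_factor = 2 ** sum((False, True, True))  # assumed: 2 temporal downsamples -> 4x
latent_frames = 16 // temporal_factor            # 16 frames -> 4 latent frames
print(latent_hw, latent_frames)                  # 16 4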
# loss weights
logvar_init=0.0
# loss weights
logvar_init = 0.0
kl_loss_weight = 0.000001
perceptual_loss_weight = 0.1 # uses VGG; applied only when not None and > 0
discriminator_factor = 1.0 # for discriminator adversarial loss
generator_factor = 0.1 # SCH: generator adversarial loss, MAGVIT v2 uses 0.1
lecam_loss_weight = None # NOTE: MAGVIT v2 uses 0.001
discriminator_loss_type="non-saturating"
generator_loss_type="non-saturating"
perceptual_loss_weight = 0.1 # uses VGG; applied only when not None and > 0
discriminator_factor = 1.0 # for discriminator adversarial loss
generator_factor = 0.1 # SCH: generator adversarial loss, MAGVIT v2 uses 0.1
lecam_loss_weight = None # NOTE: MAGVIT v2 uses 0.001
discriminator_loss_type = "non-saturating"
generator_loss_type = "non-saturating"
# discriminator_loss_type="hinge"
# generator_loss_type="hinge"
discriminator_start = 100 # 8k data / (8*32) = 31 steps per epoch, use around 3 epochs
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
discriminator_start = 100 # 8k data / (8*32) = 31 steps per epoch, use around 3 epochs
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
ema_decay = 0.999 # ema decay factor for generator
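Both loss types are set to "non-saturating". For reference, a minimal sketch of the softplus (logistic) losses that name usually denotes; the actual implementations live in opensora.models.vae.vae_3d_v2, which this diff does not show, so treat this as a sketch of the named technique rather than the repo's code:

import torch.nn.functional as F

def non_saturating_d_loss(real_logits, fake_logits):
    # discriminator: push real logits up, fake logits down
    return F.softplus(-real_logits).mean() + F.softplus(fake_logits).mean()

def non_saturating_g_loss(fake_logits):
    # generator: push fake logits up; gradients do not saturate when D is confident
    return F.softplus(-fake_logits).mean()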
@@ -76,11 +74,11 @@ outputs = "outputs"
wandb = False
# Training
''' NOTE:
""" NOTE:
magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
==> ours num_frames = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
==> ours num_frames = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
3-6 epochs for pexel; from pexel observation it's correct
'''
"""
epochs = 200
log_every = 1
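Reading the note above as samples-in-thousands times epochs landing in [500, 1200]: with the 8k-clip figure quoted in the discriminator_start comment, that budget works out to roughly 62-150 epochs, inside the epochs = 200 ceiling. A quick sketch of the arithmetic (the dataset size is an assumption carried over from that comment):

samples_k = 8  # assumed 8k clips, per the discriminator_start comment
budget_low_k, budget_high_k = 500, 1200
print(budget_low_k / samples_k, budget_high_k / samples_k)  # 62.5 150.0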


@@ -1,12 +1,9 @@
import functools
import math
from typing import Any, Optional, Sequence, Type
import torch.nn as nn
import numpy as np
import torch
from taming.modules.losses.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
import torch.nn as nn
# from taming.modules.losses.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
# from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from einops import rearrange
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
@@ -17,23 +14,23 @@ from einops import rearrange
# if norm_type == 'LN':
# # supply a few args with partial function and pass the rest of the args when this norm_fn is called
# norm_fn = functools.partial(nn.LayerNorm, dtype=dtype)
# elif norm_type == 'GN': #
# elif norm_type == 'GN': #
# norm_fn = functools.partial(nn.GroupNorm, dtype=dtype)
# elif norm_type is None:
# norm_fn = lambda: (lambda x: x)
# else:
# raise NotImplementedError(f'norm_type: {norm_type}')
# return norm_fn
class DiagonalGaussianDistribution(object):
def __init__(
self,
parameters,
self,
parameters,
deterministic=False,
):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) # SCH: channels dim
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) # SCH: channels dim
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
@@ -48,40 +45,39 @@ class DiagonalGaussianDistribution(object):
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
return torch.Tensor([0.0])
else:
if other is None: # SCH: assumes other is a standard normal distribution
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3, 4])
if other is None: # SCH: assumes other is a standard normal distribution
return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3, 4])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3, 4])
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3, 4],
)
def nll(self, sample, dims=[1,2,3,4]): # TODO: what does this do?
def nll(self, sample, dims=[1, 2, 3, 4]): # TODO: what does this do?
if self.deterministic:
return torch.Tensor([0.])
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
def mode(self): # SCH: used for vae inference?
def mode(self): # SCH: used for vae inference?
return self.mean
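For reference (and answering the TODO above): kl() with other=None is the closed-form KL divergence from the diagonal Gaussian to a standard normal, and nll() is the Gaussian negative log-likelihood of a sample, both summed over the (C, T, H, W) dims exactly as in the code:

D_{KL}(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I)) = \tfrac{1}{2} \sum \left( \mu^2 + \sigma^2 - 1 - \log \sigma^2 \right)

-\log p(x) = \tfrac{1}{2} \sum \left( \log 2\pi + \log \sigma^2 + \frac{(x - \mu)^2}{\sigma^2} \right)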
class VEA3DLoss(nn.Module):
def __init__(
self,
# disc_start,
logvar_init=0.0,
kl_weight=1.0,
# disc_start,
logvar_init=0.0,
kl_weight=1.0,
pixelloss_weight=1.0,
perceptual_weight=0.1,
perceptual_weight=0.1,
disc_loss="hinge",
):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
@@ -92,28 +88,27 @@ class VEA3DLoss(nn.Module):
# output log variance
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
def forward(
self,
inputs,
reconstructions,
posteriors,
# optimizer_idx,
# global_step,
# global_step,
weights=None,
):
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
if self.perceptual_weight > 0: # NOTE: need in_channels == 3 in order to use!
if self.perceptual_weight > 0: # NOTE: need in_channels == 3 in order to use!
assert inputs.size(1) == 3, f"using vgg16 that requires 3 input channels but got {inputs.size(1)}"
# SCH: transform to [(B,T), C, H, W] shape for perceptual loss over each frame
B = inputs.shape[0]
inputs = rearrange(inputs,"B C T H W -> (B T) C H W")
inputs = rearrange(inputs, "B C T H W -> (B T) C H W")
reconstructions = rearrange(reconstructions, "B C T H W -> (B T) C H W")
# permutated_input = torch.permute(inputs, (0, 2, 1, 3, 4)) # [B, C, T, H, W] --> [B, T, C, H, W]
# permutated_rec = torch.permute(reconstructions, (0, 2, 1, 3, 4))
# data_shape = permutated_input.size()
# p_loss = self.perceptual_loss(
# permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(),
# permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(),
# permutated_rec.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous()
# )
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
@@ -126,32 +121,32 @@ class VEA3DLoss(nn.Module):
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights*nll_loss
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
kl_loss = posteriors.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
loss = weighted_nll_loss + self.kl_weight * kl_loss # TODO: add discriminator loss later
loss = weighted_nll_loss + self.kl_weight * kl_loss # TODO: add discriminator loss later
return loss
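The nll_loss above is the learned-variance reconstruction objective from taming-transformers: with a single learnable scalar s = logvar, the per-element loss is

\mathcal{L}_{nll} = \frac{|x - \hat{x}| + w_p \, \mathcal{L}_{LPIPS}}{e^{s}} + s, \qquad \mathcal{L} = \bar{\mathcal{L}}_{nll} + w_{kl} \, \bar{D}_{KL}

where the bars denote sums divided by the batch size, as in the code. Minimizing over s lets the model trade reconstruction precision against the s penalty automatically.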
class VEA3DLossWithDiscriminator(nn.Module):
def __init__(
self,
# disc_start,
logvar_init=0.0,
kl_weight=1.0,
# disc_start,
logvar_init=0.0,
kl_weight=1.0,
pixelloss_weight=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_num_layers=3,
disc_in_channels=3,
disc_factor=1.0,
disc_weight=1.0,
perceptual_weight=1.0,
use_actnorm=False,
perceptual_weight=1.0,
use_actnorm=False,
disc_conditional=False,
disc_loss="hinge",
):
super().__init__()
assert disc_loss in ["hinge", "vanilla"]
@@ -185,53 +180,53 @@ class VEA3DLossWithDiscriminator(nn.Module):
# d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
# d_weight = d_weight * self.discriminator_weight
# return d_weight
def forward(
self,
inputs,
reconstructions,
posteriors,
# optimizer_idx,
# global_step,
last_layer=None,
cond=None,
# global_step,
last_layer=None,
cond=None,
split="train",
weights=None,
):
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
if self.perceptual_weight > 0: # NOTE: need in_channels == 3 in order to use!
assert inputs.size(1) == 3, f"using vgg16 that requires 3 input channels but got {inputs.size(1)} "
if self.perceptual_weight > 0: # NOTE: need in_channels == 3 in order to use!
assert inputs.size(1) == 3, f"using vgg16 that requires 3 input channels but got {inputs.size(1)} "
# SCH: transform to [(B,T), C, H, W] shape for perceptual loss over each frame
permutated_input = torch.permute(inputs, (0, 2, 1, 3, 4)) # [B, C, T, H, W] --> [B, T, C, H, W]
permutated_input = torch.permute(inputs, (0, 2, 1, 3, 4)) # [B, C, T, H, W] --> [B, T, C, H, W]
permutated_rec = torch.permute(reconstructions, (0, 2, 1, 3, 4))
data_shape = permutated_input.size()
p_loss = self.perceptual_loss(
permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(),
permutated_rec.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous()
permutated_input.reshape(-1, data_shape[-3], data_shape[-2], data_shape[-1]).contiguous(),
permutated_rec.reshape(-1, data_shape[-3], data_shape[-2], data_shape[-1]).contiguous(),
)
# SCH: shape back p_loss
permuted_p_loss = torch.permute(p_loss.reshape(data_shape[0], data_shape[1], 1, 1, 1), (0,2,1,3,4))
permuted_p_loss = torch.permute(p_loss.reshape(data_shape[0], data_shape[1], 1, 1, 1), (0, 2, 1, 3, 4))
rec_loss = rec_loss + self.perceptual_weight * permuted_p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights*nll_loss
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
kl_loss = posteriors.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
loss = weighted_nll_loss + self.kl_weight * kl_loss # TODO: add discriminator loss later
loss = weighted_nll_loss + self.kl_weight * kl_loss # TODO: add discriminator loss later
# log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
# log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
# "{}/logvar".format(split): self.logvar.detach(),
# "{}/kl_loss".format(split): kl_loss.detach().mean(),
# "{}/kl_loss".format(split): kl_loss.detach().mean(),
# "{}/nll_loss".format(split): nll_loss.detach().mean(),
# "{}/rec_loss".format(split): rec_loss.detach().mean(),
# # "{}/d_weight".format(split): d_weight.detach(),
# # "{}/disc_factor".format(split): torch.tensor(disc_factor),
# # "{}/g_loss".format(split): g_loss.detach().mean(),
# }
return loss
return loss
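The commented-out d_weight fragment near the top of this class is the adaptive generator weight from taming-transformers (also used by MAGVIT): the adversarial term is rescaled so that its gradient at the decoder's last layer matches the reconstruction gradient. A sketch reconstructed from the two surviving lines (an assumption, since the full function is not in this diff):

import torch

def calculate_adaptive_weight(nll_loss, g_loss, last_layer, discriminator_weight=1.0):
    # gradient norms of both losses w.r.t. the decoder's last layer
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
    return d_weight * discriminator_weight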

File diff suppressed because it is too large.


@@ -1,10 +1,8 @@
from copy import deepcopy
import os
from glob import glob
import colossalai
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
@@ -12,11 +10,8 @@ from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from tqdm import tqdm
import os
from einops import rearrange
import numpy as np
from glob import glob
from tqdm import tqdm
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import (
@@ -25,28 +20,17 @@ from opensora.acceleration.parallel_states import (
set_sequence_parallel_group,
)
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
from opensora.datasets import DatasetFromCSV, get_transforms_image, get_transforms_video, prepare_dataloader
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import create_logger, load_json, save_json, load, model_sharding, record_model_param_shape, save
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
from opensora.models.vae.vae_3d_v2 import AdversarialLoss, DiscriminatorLoss, LeCamEMA, VEALoss, pad_at_dim
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.ckpt_utils import create_logger, load_json, save_json
from opensora.utils.config_utils import (
create_experiment_workspace,
create_tensorboard_writer,
parse_configs,
save_training_config,
)
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
from opensora.utils.train_utils import update_ema, MaskGenerator
from opensora.models.vae.vae_3d_v2 import VEALoss, DiscriminatorLoss, AdversarialLoss, LeCamEMA, pad_at_dim
# efficiency
# from torch.profiler import profile, record_function, ProfilerActivity
def trace_handler(p):
output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=5)
print(output)
# p.export_chrome_trace("/home/shenchenhui/Open-Sora-dev/outputs/traces/trace_" + str(p.step_num) + ".json")
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, to_torch_dtype
def main():
@@ -59,20 +43,20 @@ def main():
# 2. runtime variables & colossalai launch
# ======================================================
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
# 2.1. colossalai init distributed training
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
exp_dir = None
if coordinator.is_master(): # only create directory for master
if coordinator.is_master(): # only create directory for master
exp_name, exp_dir = create_experiment_workspace(cfg)
save_training_config(cfg._cfg_dict, exp_dir)
dist.barrier()
# get exp dir for non-master process
if exp_dir is None:
experiment_index = len(glob(f"{cfg.outputs}/*"))-1
experiment_index = len(glob(f"{cfg.outputs}/*")) - 1
model_name = cfg.model["type"].replace("/", "-")
exp_name = f"{experiment_index:03d}-F{cfg.num_frames}S{cfg.frame_interval}-{model_name}"
exp_dir = f"{cfg.outputs}/{exp_name}"
@@ -123,31 +107,30 @@ def main():
# ======================================================
# 3. build dataset and dataloader
# ======================================================
dataset = DatasetFromCSV(
cfg.data_path,
transform=(
get_transforms_video(cfg.image_size[0])
if not cfg.use_image_transform
else get_transforms_image(cfg.image_size[0])
),
num_frames=cfg.num_frames,
frame_interval=cfg.frame_interval,
root=cfg.root,
)
dataloader = prepare_dataloader(
dataset,
dataset = build_module(cfg.dataset, DATASETS)
logger.info(f"Dataset contains {len(dataset)} samples.")
dataloader_args = dict(
dataset=dataset,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
seed=cfg.seed,
shuffle=True,
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
)
logger.info(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})")
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
logger.info(f"Total batch size: {total_batch_size}")
# TODO: use plugin's prepare dataloader
if cfg.bucket_config is None:
dataloader = prepare_dataloader(**dataloader_args)
else:
dataloader = prepare_variable_dataloader(
bucket_config=cfg.bucket_config,
num_bucket_build_workers=cfg.num_bucket_build_workers,
**dataloader_args,
)
if cfg.dataset.type == "VideoTextDataset":
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
logger.info(f"Total batch size: {total_batch_size}")
# ======================================================
# 4. build model
@@ -163,7 +146,8 @@ def main():
logger.info(
f"Trainable vae params: {format_numel_str(vae_numel_trainable)}, Total model params: {format_numel_str(vae_numel)}"
)
breakpoint()
discriminator = build_module(cfg.discriminator, MODELS, device=device)
discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator)
logger.info(
@@ -175,12 +159,11 @@ def main():
# 4.3. move to device
if cfg.get("use_pipeline") == True:
vae_2d.to(device, dtype).eval() # eval mode, not training!
vae_2d.to(device, dtype).eval() # eval mode, not training!
vae = vae.to(device, dtype)
discriminator = discriminator.to(device, dtype)
# 4.5. setup optimizer
# vae optimizer
optimizer = HybridAdam(
@@ -200,7 +183,6 @@ def main():
vae.train()
discriminator.train()
# =======================================================
# 5. boost model for distributed training with colossalai
# =======================================================
@@ -212,13 +194,11 @@ def main():
num_steps_per_epoch = len(dataloader)
logger.info("Boost vae for distributed training")
discriminator, disc_optimizer, _, _, disc_lr_scheduler = booster.boost(
model=discriminator, optimizer=disc_optimizer, lr_scheduler=disc_lr_scheduler
)
logger.info("Boost discriminator for distributed training")
# =======================================================
# 6. training loop
# =======================================================
@@ -226,7 +206,6 @@ def main():
running_loss = 0.0
running_disc_loss = 0.0
# 6.1. resume training
if cfg.load is not None:
logger.info("Loading checkpoint")
@@ -244,11 +223,17 @@ def main():
if cfg.lecam_loss_weight is not None and os.path.exists(lecam_path):
lecam_state = load_json(lecam_path)
lecam_ema_real, lecam_ema_fake = lecam_state["lecam_ema_real"], lecam_state["lecam_ema_fake"]
lecam_ema = LeCamEMA(decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device)
lecam_ema = LeCamEMA(
decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device
)
running_states = load_json(os.path.join(cfg.load, "running_states.json"))
dist.barrier()
start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"]
start_epoch, start_step, sampler_start_idx = (
running_states["epoch"],
running_states["step"],
running_states["sample_start_index"],
)
logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
@@ -257,25 +242,25 @@ def main():
# 6.2 Define loss functions
vae_loss_fn = VEALoss(
logvar_init=cfg.logvar_init,
perceptual_loss_weight = cfg.perceptual_loss_weight,
kl_loss_weight = cfg.kl_loss_weight,
perceptual_loss_weight=cfg.perceptual_loss_weight,
kl_loss_weight=cfg.kl_loss_weight,
device=device,
dtype=dtype,
)
adversarial_loss_fn = AdversarialLoss(
discriminator_factor = cfg.discriminator_factor,
discriminator_start = cfg.discriminator_start,
generator_factor = cfg.generator_factor,
generator_loss_type = cfg.generator_loss_type,
discriminator_factor=cfg.discriminator_factor,
discriminator_start=cfg.discriminator_start,
generator_factor=cfg.generator_factor,
generator_loss_type=cfg.generator_loss_type,
)
disc_loss_fn = DiscriminatorLoss(
discriminator_factor = cfg.discriminator_factor,
discriminator_start = cfg.discriminator_start,
discriminator_loss_type = cfg.discriminator_loss_type,
lecam_loss_weight = cfg.lecam_loss_weight,
gradient_penalty_loss_weight = cfg.gradient_penalty_loss_weight,
discriminator_factor=cfg.discriminator_factor,
discriminator_start=cfg.discriminator_start,
discriminator_loss_type=cfg.discriminator_loss_type,
lecam_loss_weight=cfg.lecam_loss_weight,
gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
)
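lecam_loss_weight enables LeCam regularization (Tseng et al., 2021), which the loop below feeds with EMA logit statistics via LeCamEMA. A minimal sketch of the usual formulation, offered as an assumption about what DiscriminatorLoss computes (its implementation is in vae_3d_v2 and not shown in this diff):

import torch

def lecam_reg(real_logits, fake_logits, ema_real, ema_fake, weight=0.001):
    # anchor real logits toward the EMA of fake logits, and vice versa
    reg = torch.mean(torch.relu(real_logits - ema_fake) ** 2)
    reg = reg + torch.mean(torch.relu(ema_real - fake_logits) ** 2)
    return weight * reg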
# 6.3. training loop
@@ -288,14 +273,11 @@ def main():
disc_time_padding = 0
video_contains_first_frame = cfg.video_contains_first_frame
for epoch in range(start_epoch, cfg.epochs):
dataloader.sampler.set_epoch(epoch)
dataloader_iter = iter(dataloader)
logger.info(f"Beginning epoch {epoch}...")
with tqdm(
range(start_step, num_steps_per_epoch),
desc=f"Epoch {epoch}",
@@ -303,7 +285,6 @@ def main():
total=num_steps_per_epoch,
initial=start_step,
) as pbar:
# with profile(
# activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
# schedule=torch.profiler.schedule(
@@ -317,199 +298,203 @@ def main():
# record_shapes=True,
# profile_memory=True,
# ) as p: # trace efficiency
for step in pbar:
# with profile(
# activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
# with_stack=True,
# ) as p: # trace efficiency
for step in pbar:
# with profile(
# activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
# with_stack=True,
# ) as p: # trace efficiency
# SCH: calc global step at the start
global_step = epoch * num_steps_per_epoch + step
batch = next(dataloader_iter)
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
# SCH: calc global step at the start
global_step = epoch * num_steps_per_epoch + step
# support for image or video inputs
assert x.ndim in {4, 5}, f"received input of {x.ndim} dimensions" # either image or video
assert x.shape[-2:] == cfg.image_size, f"received input size {x.shape[-2:]}, but config image size is {cfg.image_size}"
is_image = x.ndim == 4
if is_image:
video = rearrange(x, 'b c ... -> b c 1 ...')
video_contains_first_frame = True
else:
video = x
# ===== Spatial VAE =====
if cfg.get("use_pipeline") == True:
with torch.no_grad():
video = vae_2d.encode(video)
batch = next(dataloader_iter)
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
# ====== VAE ======
recon_video, posterior = vae(
video,
video_contains_first_frame = video_contains_first_frame,
)
# support for image or video inputs
assert x.ndim in {4, 5}, f"received input of {x.ndim} dimensions" # either image or video
assert (
x.shape[-2:] == cfg.image_size
), f"received input size {x.shape[-2:]}, but config image size is {cfg.image_size}"
is_image = x.ndim == 4
if is_image:
video = rearrange(x, "b c ... -> b c 1 ...")
video_contains_first_frame = True
else:
video = x
# ====== Generator Loss ======
# simple nll loss
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(
video,
recon_video,
posterior,
split = "train"
)
# ===== Spatial VAE =====
if cfg.get("use_pipeline") == True:
with torch.no_grad():
video = vae_2d.encode(video)
adversarial_loss = torch.tensor(0.0)
# adversarial loss
if global_step > cfg.discriminator_start:
# padded videos for GAN
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
fake_logits = discriminator(fake_video.contiguous())
adversarial_loss = adversarial_loss_fn(
fake_logits,
nll_loss,
vae.module.get_last_layer(),
global_step,
is_training = vae.training,
)
vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
# ====== VAE ======
recon_video, posterior = vae(
video,
video_contains_first_frame=video_contains_first_frame,
)
optimizer.zero_grad()
# Backward & update
booster.backward(loss=vae_loss, optimizer=optimizer)
# # NOTE: clip gradients? this is done in Open-Sora-Plan
# torch.nn.utils.clip_grad_norm_(vae.parameters(), 1) # NOTE: done by grad_clip
optimizer.step()
# ====== Generator Loss ======
# simple nll loss
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(
video, recon_video, posterior, split="train"
)
# Log loss values:
all_reduce_mean(vae_loss) # NOTE: this is to get average loss for logging
running_loss += vae_loss.item()
adversarial_loss = torch.tensor(0.0)
# adversarial loss
if global_step > cfg.discriminator_start:
# padded videos for GAN
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
fake_logits = discriminator(fake_video.contiguous())
adversarial_loss = adversarial_loss_fn(
fake_logits,
nll_loss,
vae.module.get_last_layer(),
global_step,
is_training=vae.training,
)
# ====== Discriminator Loss ======
if global_step > cfg.discriminator_start:
# if video_contains_first_frame:
# Since we don't have enough T frames, pad anyways
real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
real_video = real_video.requires_grad_()
real_logits = discriminator(real_video.contiguous()) # SCH: not detached for now for gradient_penalty calculation
else:
real_logits = discriminator(real_video.contiguous().detach())
optimizer.zero_grad()
# Backward & update
booster.backward(loss=vae_loss, optimizer=optimizer)
# # NOTE: clip gradients? this is done in Open-Sora-Plan
# torch.nn.utils.clip_grad_norm_(vae.parameters(), 1) # NOTE: done by grad_clip
optimizer.step()
fake_logits = discriminator(fake_video.contiguous().detach())
# Log loss values:
all_reduce_mean(vae_loss) # NOTE: this is to get average loss for logging
running_loss += vae_loss.item()
# ====== Discriminator Loss ======
if global_step > cfg.discriminator_start:
# if video_contains_first_frame:
# Since we don't have enough T frames, pad anyways
real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2)
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
real_video = real_video.requires_grad_()
real_logits = discriminator(
real_video.contiguous()
) # SCH: not detached for now for gradient_penalty calculation
else:
real_logits = discriminator(real_video.contiguous().detach())
weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
real_logits,
fake_logits,
global_step,
lecam_ema_real = lecam_ema_real,
lecam_ema_fake = lecam_ema_fake,
real_video = real_video if cfg.gradient_penalty_loss_weight is not None else None,
)
disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
if cfg.lecam_loss_weight is not None:
ema_real = torch.mean(real_logits.clone().detach()).to(device, dtype)
ema_fake = torch.mean(fake_logits.clone().detach()).to(device, dtype)
all_reduce_mean(ema_real)
all_reduce_mean(ema_fake)
lecam_ema.update(ema_real, ema_fake)
fake_logits = discriminator(fake_video.contiguous().detach())
disc_optimizer.zero_grad()
# Backward & update
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
# # NOTE: TODO: clip gradients? this is done in Open-Sora-Plan
# torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1) # NOTE: done by grad_clip
disc_optimizer.step()
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
# Log loss values:
all_reduce_mean(disc_loss)
running_disc_loss += disc_loss.item()
else:
disc_loss = torch.tensor(0.0)
weighted_d_adversarial_loss = torch.tensor(0.0)
lecam_loss = torch.tensor(0.0)
gradient_penalty_loss = torch.tensor(0.0)
weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
real_logits,
fake_logits,
global_step,
lecam_ema_real=lecam_ema_real,
lecam_ema_fake=lecam_ema_fake,
real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None,
)
disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
if cfg.lecam_loss_weight is not None:
ema_real = torch.mean(real_logits.clone().detach()).to(device, dtype)
ema_fake = torch.mean(fake_logits.clone().detach()).to(device, dtype)
all_reduce_mean(ema_real)
all_reduce_mean(ema_fake)
lecam_ema.update(ema_real, ema_fake)
log_step += 1
disc_optimizer.zero_grad()
# Backward & update
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
# # NOTE: TODO: clip gradients? this is done in Open-Sora-Plan
# torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1) # NOTE: done by grad_clip
disc_optimizer.step()
# Log to tensorboard
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
avg_loss = running_loss / log_step
avg_disc_loss = running_disc_loss / log_step
pbar.set_postfix({"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step})
running_loss = 0
log_step = 0
running_disc_loss = 0
writer.add_scalar("loss", vae_loss.item(), global_step)
if cfg.wandb:
wandb.log(
{
"iter": global_step,
"num_samples": global_step * total_batch_size,
"epoch": epoch,
"loss": vae_loss.item(),
"kl_loss": weighted_kl_loss.item(),
"gen_adv_loss": adversarial_loss.item(),
"disc_loss": disc_loss.item(),
"lecam_loss": lecam_loss.item(),
"r1_grad_penalty": gradient_penalty_loss.item(),
"nll_loss": weighted_nll_loss.item(),
"avg_loss": avg_loss,
},
step=global_step,
)
# Log loss values:
all_reduce_mean(disc_loss)
running_disc_loss += disc_loss.item()
else:
disc_loss = torch.tensor(0.0)
weighted_d_adversarial_loss = torch.tensor(0.0)
lecam_loss = torch.tensor(0.0)
gradient_penalty_loss = torch.tensor(0.0)
# Save checkpoint
if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) # already handled in booster save_model
booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
booster.save_optimizer(disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096)
log_step += 1
if lr_scheduler is not None:
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
if disc_lr_scheduler is not None:
booster.save_lr_scheduler(disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler"))
running_states = {
# Log to tensorboard
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
avg_loss = running_loss / log_step
avg_disc_loss = running_disc_loss / log_step
pbar.set_postfix(
{"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step}
)
running_loss = 0
log_step = 0
running_disc_loss = 0
writer.add_scalar("loss", vae_loss.item(), global_step)
if cfg.wandb:
wandb.log(
{
"iter": global_step,
"num_samples": global_step * total_batch_size,
"epoch": epoch,
"step": step+1,
"global_step": global_step+1,
"sample_start_index": (step+1) * cfg.batch_size,
}
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
lecam_state = {
"lecam_ema_real": lecam_ema_real.item(),
"lecam_ema_fake": lecam_ema_fake.item(),
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
if cfg.lecam_loss_weight is not None:
save_json(lecam_state, os.path.join(save_dir, "lecam_states.json"))
dist.barrier()
"loss": vae_loss.item(),
"kl_loss": weighted_kl_loss.item(),
"gen_adv_loss": adversarial_loss.item(),
"disc_loss": disc_loss.item(),
"lecam_loss": lecam_loss.item(),
"r1_grad_penalty": gradient_penalty_loss.item(),
"nll_loss": weighted_nll_loss.item(),
"avg_loss": avg_loss,
},
step=global_step,
)
logger.info(
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
)
# Save checkpoint
if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) # already handled in booster save_model
booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
booster.save_optimizer(
optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096
)
booster.save_optimizer(
disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096
)
# print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
if lr_scheduler is not None:
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
if disc_lr_scheduler is not None:
booster.save_lr_scheduler(disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler"))
running_states = {
"epoch": epoch,
"step": step + 1,
"global_step": global_step + 1,
"sample_start_index": (step + 1) * cfg.batch_size,
}
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
lecam_state = {
"lecam_ema_real": lecam_ema_real.item(),
"lecam_ema_fake": lecam_ema_fake.item(),
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
if cfg.lecam_loss_weight is not None:
save_json(lecam_state, os.path.join(save_dir, "lecam_states.json"))
dist.barrier()
logger.info(
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
)
# print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
# the following epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(0)
start_step = 0
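One more note on the discriminator branch above: real_video.requires_grad_() is called so that a gradient penalty can be computed w.r.t. real inputs, and the config comment says MAGVIT uses a weight of 10. A minimal sketch of the R1 penalty this typically denotes (an assumption; the repo's DiscriminatorLoss lives in vae_3d_v2 and is not shown in this diff):

import torch

def r1_gradient_penalty(real_logits, real_video, weight=10.0):
    # R1: squared gradient norm of D's output w.r.t. real inputs
    grads = torch.autograd.grad(
        outputs=real_logits.sum(), inputs=real_video, create_graph=True
    )[0]
    penalty = grads.pow(2).reshape(grads.shape[0], -1).sum(1).mean()
    return weight * penalty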
if __name__ == "__main__":
main()