mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
[wip] debug vae
This commit is contained in:
parent
a4652a8aef
commit
478b585024
|
|
@ -1,21 +1,20 @@
|
|||
num_frames = 16
|
||||
frame_interval = 3
|
||||
image_size = (128, 128)
|
||||
use_pipeline = True
|
||||
|
||||
# Define dataset
|
||||
root = None
|
||||
data_path = "CSV_PATH"
|
||||
use_image_transform = False
|
||||
num_workers = 4
|
||||
video_contains_first_frame = False
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=16,
|
||||
frame_interval=3,
|
||||
image_size=(128, 128),
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
sp_size = 1
|
||||
|
||||
use_pipeline = True
|
||||
video_contains_first_frame = False
|
||||
|
||||
# Define model
|
||||
vae_2d = dict(
|
||||
|
|
@ -23,50 +22,49 @@ vae_2d = dict(
|
|||
from_pretrained="stabilityai/sd-vae-ft-ema",
|
||||
# SDXL
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type="VAE_MAGVIT_V2",
|
||||
in_out_channels = 4,
|
||||
latent_embed_dim = 4,
|
||||
filters = 128,
|
||||
num_res_blocks = 4,
|
||||
channel_multipliers = (1, 2, 2, 4),
|
||||
temporal_downsample = (False, True, True),
|
||||
num_groups = 32, # for nn.GroupNorm
|
||||
kl_embed_dim = 4,
|
||||
activation_fn = 'swish',
|
||||
separate_first_frame_encoding = False,
|
||||
disable_space = True,
|
||||
encoder_double_z = True,
|
||||
custom_conv_padding = None
|
||||
in_out_channels=4,
|
||||
latent_embed_dim=4,
|
||||
filters=128,
|
||||
num_res_blocks=4,
|
||||
channel_multipliers=(1, 2, 2, 4),
|
||||
temporal_downsample=(False, True, True),
|
||||
num_groups=32, # for nn.GroupNorm
|
||||
kl_embed_dim=4,
|
||||
activation_fn="swish",
|
||||
separate_first_frame_encoding=False,
|
||||
disable_space=True,
|
||||
encoder_double_z=True,
|
||||
custom_conv_padding=None,
|
||||
)
|
||||
|
||||
|
||||
discriminator = dict(
|
||||
type="DISCRIMINATOR_3D",
|
||||
image_size = (16, 16), # NOTE: here image size is different
|
||||
num_frames = num_frames,
|
||||
in_channels = 4,
|
||||
filters = 128,
|
||||
use_pretrained=True, # NOTE: set to False only if we want to disable load
|
||||
image_size=(16, 16), # NOTE: here image size is different
|
||||
num_frames=16,
|
||||
in_channels=4,
|
||||
filters=128,
|
||||
use_pretrained=True, # NOTE: set to False only if we want to disable load
|
||||
# channel_multipliers = (2,4,4,4,4), # (2,4,4,4) for 64x64 resolution
|
||||
channel_multipliers= (2,4,4) # since on intermediate layer dimension ofs z
|
||||
channel_multipliers=(2, 4, 4), # since on intermediate layer dimension ofs z
|
||||
)
|
||||
|
||||
|
||||
# loss weights
|
||||
logvar_init=0.0
|
||||
# loss weights
|
||||
logvar_init = 0.0
|
||||
kl_loss_weight = 0.000001
|
||||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
discriminator_factor = 1.0 # for discriminator adversarial loss
|
||||
generator_factor = 0.1 # SCH: generator adversarial loss, MAGVIT v2 uses 0.1
|
||||
lecam_loss_weight = None # NOTE: MAVGIT v2 use 0.001
|
||||
discriminator_loss_type="non-saturating"
|
||||
generator_loss_type="non-saturating"
|
||||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
discriminator_factor = 1.0 # for discriminator adversarial loss
|
||||
generator_factor = 0.1 # SCH: generator adversarial loss, MAGVIT v2 uses 0.1
|
||||
lecam_loss_weight = None # NOTE: MAVGIT v2 use 0.001
|
||||
discriminator_loss_type = "non-saturating"
|
||||
generator_loss_type = "non-saturating"
|
||||
# discriminator_loss_type="hinge"
|
||||
# generator_loss_type="hinge"
|
||||
discriminator_start = 100 # 8k data / (8*32) = 31 steps per epoch, use around 3 epochs
|
||||
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
|
||||
discriminator_start = 100 # 8k data / (8*32) = 31 steps per epoch, use around 3 epochs
|
||||
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
|
||||
ema_decay = 0.999 # ema decay factor for generator
|
||||
|
||||
|
||||
|
|
@ -76,11 +74,11 @@ outputs = "outputs"
|
|||
wandb = False
|
||||
|
||||
# Training
|
||||
''' NOTE:
|
||||
""" NOTE:
|
||||
magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
|
||||
==> ours num_frames = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
|
||||
==> ours num_frames = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
|
||||
3-6 epochs for pexels; observations on pexels confirm this is correct
|
||||
'''
|
||||
"""
|
||||
|
||||
epochs = 200
|
||||
log_every = 1
|
||||
|
|
|
|||
|
|
@ -1,12 +1,9 @@
|
|||
import functools
|
||||
import math
|
||||
from typing import Any, Optional, Sequence, Type
|
||||
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import torch
|
||||
from taming.modules.losses.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
|
||||
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
|
||||
import torch.nn as nn
|
||||
|
||||
# from taming.modules.losses.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
|
||||
# from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
|
||||
from einops import rearrange
|
||||
|
||||
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
|
||||
|
|
@ -17,23 +14,23 @@ from einops import rearrange
|
|||
# if norm_type == 'LN':
|
||||
# # supply a few args with partial function and pass the rest of the args when this norm_fn is called
|
||||
# norm_fn = functools.partial(nn.LayerNorm, dtype=dtype)
|
||||
# elif norm_type == 'GN': #
|
||||
# elif norm_type == 'GN': #
|
||||
# norm_fn = functools.partial(nn.GroupNorm, dtype=dtype)
|
||||
# elif norm_type is None:
|
||||
# norm_fn = lambda: (lambda x: x)
|
||||
# else:
|
||||
# raise NotImplementedError(f'norm_type: {norm_type}')
|
||||
# return norm_fn
|
||||
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
def __init__(
|
||||
self,
|
||||
parameters,
|
||||
self,
|
||||
parameters,
|
||||
deterministic=False,
|
||||
):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) # SCH: channels dim
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) # SCH: channels dim
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
|
|
@ -48,40 +45,39 @@ class DiagonalGaussianDistribution(object):
|
|||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
return torch.Tensor([0.0])
|
||||
else:
|
||||
if other is None: # SCH: assumes other is a standard normal distribution
|
||||
return 0.5 * torch.sum(torch.pow(self.mean, 2)
|
||||
+ self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3, 4])
|
||||
if other is None: # SCH: assumes other is a standard normal distribution
|
||||
return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3, 4])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3, 4])
|
||||
+ self.var / other.var
|
||||
- 1.0
|
||||
- self.logvar
|
||||
+ other.logvar,
|
||||
dim=[1, 2, 3, 4],
|
||||
)
|
||||
|
||||
def nll(self, sample, dims=[1,2,3,4]): # TODO: what does this do?
|
||||
def nll(self, sample, dims=[1, 2, 3, 4]): # TODO: what does this do?
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
return torch.Tensor([0.0])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
|
||||
|
||||
def mode(self): # SCH: used for vae inference?
|
||||
def mode(self): # SCH: used for vae inference?
|
||||
return self.mean
|
||||
|
||||
|
||||
|
||||
class VEA3DLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
# disc_start,
|
||||
logvar_init=0.0,
|
||||
kl_weight=1.0,
|
||||
# disc_start,
|
||||
logvar_init=0.0,
|
||||
kl_weight=1.0,
|
||||
pixelloss_weight=1.0,
|
||||
perceptual_weight=0.1,
|
||||
perceptual_weight=0.1,
|
||||
disc_loss="hinge",
|
||||
|
||||
):
|
||||
super().__init__()
|
||||
assert disc_loss in ["hinge", "vanilla"]
|
||||
|
|
@ -92,28 +88,27 @@ class VEA3DLoss(nn.Module):
|
|||
# output log variance
|
||||
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs,
|
||||
reconstructions,
|
||||
posteriors,
|
||||
# optimizer_idx,
|
||||
# global_step,
|
||||
# global_step,
|
||||
weights=None,
|
||||
):
|
||||
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
||||
if self.perceptual_weight > 0: # NOTE: need in_channels == 3 in order to use!
|
||||
if self.perceptual_weight > 0: # NOTE: need in_channels == 3 in order to use!
|
||||
assert inputs.size(1) == 3, f"using vgg16 that requires 3 input channels but got {inputs.size(1)}"
|
||||
# SCH: transform to [(B,T), C, H, W] shape for perceptual loss over each frame
|
||||
B = inputs.shape[0]
|
||||
inputs = rearrange(inputs,"B C T H W -> (B T) C H W")
|
||||
inputs = rearrange(inputs, "B C T H W -> (B T) C H W")
|
||||
reconstructions = rearrange(reconstructions, "B C T H W -> (B T) C H W")
|
||||
# permutated_input = torch.permute(inputs, (0, 2, 1, 3, 4)) # [B, C, T, H, W] --> [B, T, C, H, W]
|
||||
# permutated_rec = torch.permute(reconstructions, (0, 2, 1, 3, 4))
|
||||
# data_shape = permutated_input.size()
|
||||
# p_loss = self.perceptual_loss(
|
||||
# permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(),
|
||||
# permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(),
|
||||
# permutated_rec.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous()
|
||||
# )
|
||||
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
|
||||
|
|
@ -126,32 +121,32 @@ class VEA3DLoss(nn.Module):
|
|||
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
||||
weighted_nll_loss = nll_loss
|
||||
if weights is not None:
|
||||
weighted_nll_loss = weights*nll_loss
|
||||
weighted_nll_loss = weights * nll_loss
|
||||
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
||||
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
||||
kl_loss = posteriors.kl()
|
||||
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
||||
|
||||
loss = weighted_nll_loss + self.kl_weight * kl_loss # TODO: add discriminator loss later
|
||||
loss = weighted_nll_loss + self.kl_weight * kl_loss # TODO: add discriminator loss later
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class VEA3DLossWithDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
# disc_start,
|
||||
logvar_init=0.0,
|
||||
kl_weight=1.0,
|
||||
# disc_start,
|
||||
logvar_init=0.0,
|
||||
kl_weight=1.0,
|
||||
pixelloss_weight=1.0,
|
||||
disc_num_layers=3,
|
||||
disc_in_channels=3,
|
||||
disc_factor=1.0,
|
||||
disc_num_layers=3,
|
||||
disc_in_channels=3,
|
||||
disc_factor=1.0,
|
||||
disc_weight=1.0,
|
||||
perceptual_weight=1.0,
|
||||
use_actnorm=False,
|
||||
perceptual_weight=1.0,
|
||||
use_actnorm=False,
|
||||
disc_conditional=False,
|
||||
disc_loss="hinge",
|
||||
|
||||
):
|
||||
super().__init__()
|
||||
assert disc_loss in ["hinge", "vanilla"]
|
||||
|
|
@ -185,53 +180,53 @@ class VEA3DLossWithDiscriminator(nn.Module):
|
|||
# d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
||||
# d_weight = d_weight * self.discriminator_weight
|
||||
# return d_weight
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs,
|
||||
reconstructions,
|
||||
posteriors,
|
||||
# optimizer_idx,
|
||||
# global_step,
|
||||
last_layer=None,
|
||||
cond=None,
|
||||
# global_step,
|
||||
last_layer=None,
|
||||
cond=None,
|
||||
split="train",
|
||||
weights=None,
|
||||
):
|
||||
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
||||
if self.perceptual_weight > 0: # NOTE: need in_channels == 3 in order to use!
|
||||
assert inputs.size(1) == 3, f"using vgg16 that requires 3 input channels but got {inputs.size(1)} "
|
||||
if self.perceptual_weight > 0: # NOTE: need in_channels == 3 in order to use!
|
||||
assert inputs.size(1) == 3, f"using vgg16 that requires 3 input channels but got {inputs.size(1)} "
|
||||
# SCH: transform to [(B,T), C, H, W] shape for perceptual loss over each frame
|
||||
permutated_input = torch.permute(inputs, (0, 2, 1, 3, 4)) # [B, C, T, H, W] --> [B, T, C, H, W]
|
||||
permutated_input = torch.permute(inputs, (0, 2, 1, 3, 4)) # [B, C, T, H, W] --> [B, T, C, H, W]
|
||||
permutated_rec = torch.permute(reconstructions, (0, 2, 1, 3, 4))
|
||||
data_shape = permutated_input.size()
|
||||
p_loss = self.perceptual_loss(
|
||||
permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(),
|
||||
permutated_rec.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous()
|
||||
permutated_input.reshape(-1, data_shape[-3], data_shape[-2], data_shape[-1]).contiguous(),
|
||||
permutated_rec.reshape(-1, data_shape[-3], data_shape[-2], data_shape[-1]).contiguous(),
|
||||
)
|
||||
# SCH: shape back p_loss
|
||||
permuted_p_loss = torch.permute(p_loss.reshape(data_shape[0], data_shape[1], 1, 1, 1), (0,2,1,3,4))
|
||||
permuted_p_loss = torch.permute(p_loss.reshape(data_shape[0], data_shape[1], 1, 1, 1), (0, 2, 1, 3, 4))
|
||||
rec_loss = rec_loss + self.perceptual_weight * permuted_p_loss
|
||||
|
||||
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
||||
weighted_nll_loss = nll_loss
|
||||
if weights is not None:
|
||||
weighted_nll_loss = weights*nll_loss
|
||||
weighted_nll_loss = weights * nll_loss
|
||||
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
||||
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
||||
kl_loss = posteriors.kl()
|
||||
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
||||
|
||||
loss = weighted_nll_loss + self.kl_weight * kl_loss # TODO: add discriminator loss later
|
||||
loss = weighted_nll_loss + self.kl_weight * kl_loss # TODO: add discriminator loss later
|
||||
|
||||
# log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
|
||||
# log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
|
||||
# "{}/logvar".format(split): self.logvar.detach(),
|
||||
# "{}/kl_loss".format(split): kl_loss.detach().mean(),
|
||||
# "{}/kl_loss".format(split): kl_loss.detach().mean(),
|
||||
# "{}/nll_loss".format(split): nll_loss.detach().mean(),
|
||||
# "{}/rec_loss".format(split): rec_loss.detach().mean(),
|
||||
# # "{}/d_weight".format(split): d_weight.detach(),
|
||||
# # "{}/disc_factor".format(split): torch.tensor(disc_factor),
|
||||
# # "{}/g_loss".format(split): g_loss.detach().mean(),
|
||||
# }
|
||||
|
||||
return loss
|
||||
|
||||
return loss
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,10 +1,8 @@
|
|||
from copy import deepcopy
|
||||
import os
|
||||
from glob import glob
|
||||
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
from colossalai.booster import Booster
|
||||
|
|
@ -12,11 +10,8 @@ from colossalai.booster.plugin import LowLevelZeroPlugin
|
|||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from tqdm import tqdm
|
||||
|
||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||
from opensora.acceleration.parallel_states import (
|
||||
|
|
@ -25,28 +20,17 @@ from opensora.acceleration.parallel_states import (
|
|||
set_sequence_parallel_group,
|
||||
)
|
||||
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
|
||||
from opensora.datasets import DatasetFromCSV, get_transforms_image, get_transforms_video, prepare_dataloader
|
||||
from opensora.registry import MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.ckpt_utils import create_logger, load_json, save_json, load, model_sharding, record_model_param_shape, save
|
||||
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
|
||||
from opensora.models.vae.vae_3d_v2 import AdversarialLoss, DiscriminatorLoss, LeCamEMA, VEALoss, pad_at_dim
|
||||
from opensora.registry import DATASETS, MODELS, build_module
|
||||
from opensora.utils.ckpt_utils import create_logger, load_json, save_json
|
||||
from opensora.utils.config_utils import (
|
||||
create_experiment_workspace,
|
||||
create_tensorboard_writer,
|
||||
parse_configs,
|
||||
save_training_config,
|
||||
)
|
||||
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
|
||||
from opensora.utils.train_utils import update_ema, MaskGenerator
|
||||
from opensora.models.vae.vae_3d_v2 import VEALoss, DiscriminatorLoss, AdversarialLoss, LeCamEMA, pad_at_dim
|
||||
|
||||
|
||||
|
||||
# efficiency
|
||||
# from torch.profiler import profile, record_function, ProfilerActivity
|
||||
|
||||
def trace_handler(p):
|
||||
output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=5)
|
||||
print(output)
|
||||
# p.export_chrome_trace("/home/shenchenhui/Open-Sora-dev/outputs/traces/trace_" + str(p.step_num) + ".json")
|
||||
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, to_torch_dtype
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -59,20 +43,20 @@ def main():
|
|||
# 2. runtime variables & colossalai launch
|
||||
# ======================================================
|
||||
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
|
||||
|
||||
|
||||
# 2.1. colossalai init distributed training
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
exp_dir = None
|
||||
if coordinator.is_master(): # only create directory for master
|
||||
if coordinator.is_master(): # only create directory for master
|
||||
exp_name, exp_dir = create_experiment_workspace(cfg)
|
||||
save_training_config(cfg._cfg_dict, exp_dir)
|
||||
dist.barrier()
|
||||
|
||||
# get exp dir for non-master process
|
||||
if exp_dir is None:
|
||||
experiment_index = len(glob(f"{cfg.outputs}/*"))-1
|
||||
experiment_index = len(glob(f"{cfg.outputs}/*")) - 1
|
||||
model_name = cfg.model["type"].replace("/", "-")
|
||||
exp_name = f"{experiment_index:03d}-F{cfg.num_frames}S{cfg.frame_interval}-{model_name}"
|
||||
exp_dir = f"{cfg.outputs}/{exp_name}"
|
||||
|
|
@ -123,31 +107,30 @@ def main():
|
|||
# ======================================================
|
||||
# 3. build dataset and dataloader
|
||||
# ======================================================
|
||||
dataset = DatasetFromCSV(
|
||||
cfg.data_path,
|
||||
transform=(
|
||||
get_transforms_video(cfg.image_size[0])
|
||||
if not cfg.use_image_transform
|
||||
else get_transforms_image(cfg.image_size[0])
|
||||
),
|
||||
num_frames=cfg.num_frames,
|
||||
frame_interval=cfg.frame_interval,
|
||||
root=cfg.root,
|
||||
)
|
||||
|
||||
dataloader = prepare_dataloader(
|
||||
dataset,
|
||||
dataset = build_module(cfg.dataset, DATASETS)
|
||||
logger.info(f"Dataset contains {len(dataset)} samples.")
|
||||
dataloader_args = dict(
|
||||
dataset=dataset,
|
||||
batch_size=cfg.batch_size,
|
||||
num_workers=cfg.num_workers,
|
||||
seed=cfg.seed,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
)
|
||||
logger.info(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})")
|
||||
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
|
||||
logger.info(f"Total batch size: {total_batch_size}")
|
||||
# TODO: use plugin's prepare dataloader
|
||||
if cfg.bucket_config is None:
|
||||
dataloader = prepare_dataloader(**dataloader_args)
|
||||
else:
|
||||
dataloader = prepare_variable_dataloader(
|
||||
bucket_config=cfg.bucket_config,
|
||||
num_bucket_build_workers=cfg.num_bucket_build_workers,
|
||||
**dataloader_args,
|
||||
)
|
||||
if cfg.dataset.type == "VideoTextDataset":
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
|
||||
logger.info(f"Total batch size: {total_batch_size}")
|
||||
|
||||
# ======================================================
|
||||
# 4. build model
|
||||
|
|
@ -163,7 +146,8 @@ def main():
|
|||
logger.info(
|
||||
f"Trainable vae params: {format_numel_str(vae_numel_trainable)}, Total model params: {format_numel_str(vae_numel)}"
|
||||
)
|
||||
|
||||
breakpoint()
|
||||
|
||||
discriminator = build_module(cfg.discriminator, MODELS, device=device)
|
||||
discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator)
|
||||
logger.info(
|
||||
|
|
@ -175,12 +159,11 @@ def main():
|
|||
|
||||
# 4.3. move to device
|
||||
if cfg.get("use_pipeline") == True:
|
||||
vae_2d.to(device, dtype).eval() # eval mode, not training!
|
||||
vae_2d.to(device, dtype).eval() # eval mode, not training!
|
||||
|
||||
vae = vae.to(device, dtype)
|
||||
discriminator = discriminator.to(device, dtype)
|
||||
|
||||
|
||||
# 4.5. setup optimizer
|
||||
# vae optimizer
|
||||
optimizer = HybridAdam(
|
||||
|
|
@ -200,7 +183,6 @@ def main():
|
|||
vae.train()
|
||||
discriminator.train()
|
||||
|
||||
|
||||
# =======================================================
|
||||
# 5. boost model for distributed training with colossalai
|
||||
# =======================================================
|
||||
|
|
@ -212,13 +194,11 @@ def main():
|
|||
num_steps_per_epoch = len(dataloader)
|
||||
logger.info("Boost vae for distributed training")
|
||||
|
||||
|
||||
discriminator, disc_optimizer, _, _, disc_lr_scheduler = booster.boost(
|
||||
model=discriminator, optimizer=disc_optimizer, lr_scheduler=disc_lr_scheduler
|
||||
)
|
||||
logger.info("Boost discriminator for distributed training")
|
||||
|
||||
|
||||
# =======================================================
|
||||
# 6. training loop
|
||||
# =======================================================
|
||||
|
|
@ -226,7 +206,6 @@ def main():
|
|||
running_loss = 0.0
|
||||
running_disc_loss = 0.0
|
||||
|
||||
|
||||
# 6.1. resume training
|
||||
if cfg.load is not None:
|
||||
logger.info("Loading checkpoint")
|
||||
|
|
@ -244,11 +223,17 @@ def main():
|
|||
if cfg.lecam_loss_weight is not None and os.path.exists(lecam_path):
|
||||
lecam_state = load_json(lecam_path)
|
||||
lecam_ema_real, lecam_ema_fake = lecam_state["lecam_ema_real"], lecam_state["lecam_ema_fake"]
|
||||
lecam_ema = LeCamEMA(decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device)
|
||||
lecam_ema = LeCamEMA(
|
||||
decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
running_states = load_json(os.path.join(cfg.load, "running_states.json"))
|
||||
dist.barrier()
|
||||
start_epoch, start_step, sampler_start_idx = running_states["epoch"], running_states["step"], running_states["sample_start_index"]
|
||||
start_epoch, start_step, sampler_start_idx = (
|
||||
running_states["epoch"],
|
||||
running_states["step"],
|
||||
running_states["sample_start_index"],
|
||||
)
|
||||
logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
|
||||
logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
|
||||
|
||||
|
|
@ -257,25 +242,25 @@ def main():
|
|||
# 6.2 Define loss functions
|
||||
vae_loss_fn = VEALoss(
|
||||
logvar_init=cfg.logvar_init,
|
||||
perceptual_loss_weight = cfg.perceptual_loss_weight,
|
||||
kl_loss_weight = cfg.kl_loss_weight,
|
||||
perceptual_loss_weight=cfg.perceptual_loss_weight,
|
||||
kl_loss_weight=cfg.kl_loss_weight,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
adversarial_loss_fn = AdversarialLoss(
|
||||
discriminator_factor = cfg.discriminator_factor,
|
||||
discriminator_start = cfg.discriminator_start,
|
||||
generator_factor = cfg.generator_factor,
|
||||
generator_loss_type = cfg.generator_loss_type,
|
||||
discriminator_factor=cfg.discriminator_factor,
|
||||
discriminator_start=cfg.discriminator_start,
|
||||
generator_factor=cfg.generator_factor,
|
||||
generator_loss_type=cfg.generator_loss_type,
|
||||
)
|
||||
|
||||
disc_loss_fn = DiscriminatorLoss(
|
||||
discriminator_factor = cfg.discriminator_factor,
|
||||
discriminator_start = cfg.discriminator_start,
|
||||
discriminator_loss_type = cfg.discriminator_loss_type,
|
||||
lecam_loss_weight = cfg.lecam_loss_weight,
|
||||
gradient_penalty_loss_weight = cfg.gradient_penalty_loss_weight,
|
||||
discriminator_factor=cfg.discriminator_factor,
|
||||
discriminator_start=cfg.discriminator_start,
|
||||
discriminator_loss_type=cfg.discriminator_loss_type,
|
||||
lecam_loss_weight=cfg.lecam_loss_weight,
|
||||
gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
|
||||
)
|
||||
|
||||
# 6.3. training loop
|
||||
|
|
@ -288,14 +273,11 @@ def main():
|
|||
disc_time_padding = 0
|
||||
video_contains_first_frame = cfg.video_contains_first_frame
|
||||
|
||||
|
||||
for epoch in range(start_epoch, cfg.epochs):
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
dataloader_iter = iter(dataloader)
|
||||
logger.info(f"Beginning epoch {epoch}...")
|
||||
|
||||
|
||||
|
||||
with tqdm(
|
||||
range(start_step, num_steps_per_epoch),
|
||||
desc=f"Epoch {epoch}",
|
||||
|
|
@ -303,7 +285,6 @@ def main():
|
|||
total=num_steps_per_epoch,
|
||||
initial=start_step,
|
||||
) as pbar:
|
||||
|
||||
# with profile(
|
||||
# activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
# schedule=torch.profiler.schedule(
|
||||
|
|
@ -317,199 +298,203 @@ def main():
|
|||
# record_shapes=True,
|
||||
# profile_memory=True,
|
||||
# ) as p: # trace efficiency
|
||||
|
||||
for step in pbar:
|
||||
|
||||
# with profile(
|
||||
# activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
# with_stack=True,
|
||||
# ) as p: # trace efficiency
|
||||
for step in pbar:
|
||||
# with profile(
|
||||
# activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
# with_stack=True,
|
||||
# ) as p: # trace efficiency
|
||||
|
||||
# SCH: calc global step at the start
|
||||
global_step = epoch * num_steps_per_epoch + step
|
||||
|
||||
batch = next(dataloader_iter)
|
||||
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
|
||||
# SCH: calc global step at the start
|
||||
global_step = epoch * num_steps_per_epoch + step
|
||||
|
||||
# support for image or video inputs
|
||||
assert x.ndim in {4, 5}, f"received input of {x.ndim} dimensions" # either image or video
|
||||
assert x.shape[-2:] == cfg.image_size, f"received input size {x.shape[-2:]}, but config image size is {cfg.image_size}"
|
||||
is_image = x.ndim == 4
|
||||
if is_image:
|
||||
video = rearrange(x, 'b c ... -> b c 1 ...')
|
||||
video_contains_first_frame = True
|
||||
else:
|
||||
video = x
|
||||
|
||||
# ===== Spatial VAE =====
|
||||
if cfg.get("use_pipeline") == True:
|
||||
with torch.no_grad():
|
||||
video = vae_2d.encode(video)
|
||||
batch = next(dataloader_iter)
|
||||
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
|
||||
|
||||
# ====== VAE ======
|
||||
recon_video, posterior = vae(
|
||||
video,
|
||||
video_contains_first_frame = video_contains_first_frame,
|
||||
)
|
||||
# support for image or video inputs
|
||||
assert x.ndim in {4, 5}, f"received input of {x.ndim} dimensions" # either image or video
|
||||
assert (
|
||||
x.shape[-2:] == cfg.image_size
|
||||
), f"received input size {x.shape[-2:]}, but config image size is {cfg.image_size}"
|
||||
is_image = x.ndim == 4
|
||||
if is_image:
|
||||
video = rearrange(x, "b c ... -> b c 1 ...")
|
||||
video_contains_first_frame = True
|
||||
else:
|
||||
video = x
|
||||
|
||||
# ====== Generator Loss ======
|
||||
# simple nll loss
|
||||
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(
|
||||
video,
|
||||
recon_video,
|
||||
posterior,
|
||||
split = "train"
|
||||
)
|
||||
# ===== Spatial VAE =====
|
||||
if cfg.get("use_pipeline") == True:
|
||||
with torch.no_grad():
|
||||
video = vae_2d.encode(video)
|
||||
|
||||
adversarial_loss = torch.tensor(0.0)
|
||||
# adversarial loss
|
||||
if global_step > cfg.discriminator_start:
|
||||
# padded videos for GAN
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
fake_logits = discriminator(fake_video.contiguous())
|
||||
adversarial_loss = adversarial_loss_fn(
|
||||
fake_logits,
|
||||
nll_loss,
|
||||
vae.module.get_last_layer(),
|
||||
global_step,
|
||||
is_training = vae.training,
|
||||
)
|
||||
|
||||
vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
|
||||
# ====== VAE ======
|
||||
recon_video, posterior = vae(
|
||||
video,
|
||||
video_contains_first_frame=video_contains_first_frame,
|
||||
)
|
||||
|
||||
optimizer.zero_grad()
|
||||
# Backward & update
|
||||
booster.backward(loss=vae_loss, optimizer=optimizer)
|
||||
# # NOTE: clip gradients? this is done in Open-Sora-Plan
|
||||
# torch.nn.utils.clip_grad_norm_(vae.parameters(), 1) # NOTE: done by grad_clip
|
||||
optimizer.step()
|
||||
# ====== Generator Loss ======
|
||||
# simple nll loss
|
||||
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(
|
||||
video, recon_video, posterior, split="train"
|
||||
)
|
||||
|
||||
# Log loss values:
|
||||
all_reduce_mean(vae_loss) # NOTE: this is to get average loss for logging
|
||||
running_loss += vae_loss.item()
|
||||
|
||||
adversarial_loss = torch.tensor(0.0)
|
||||
# adversarial loss
|
||||
if global_step > cfg.discriminator_start:
|
||||
# padded videos for GAN
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
fake_logits = discriminator(fake_video.contiguous())
|
||||
adversarial_loss = adversarial_loss_fn(
|
||||
fake_logits,
|
||||
nll_loss,
|
||||
vae.module.get_last_layer(),
|
||||
global_step,
|
||||
is_training=vae.training,
|
||||
)
|
||||
|
||||
# ====== Discriminator Loss ======
|
||||
if global_step > cfg.discriminator_start:
|
||||
# if video_contains_first_frame:
|
||||
# Since we don't have enough T frames, pad anyways
|
||||
real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
|
||||
|
||||
if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
|
||||
real_video = real_video.requires_grad_()
|
||||
real_logits = discriminator(real_video.contiguous()) # SCH: not detached for now for gradient_penalty calculation
|
||||
else:
|
||||
real_logits = discriminator(real_video.contiguous().detach())
|
||||
optimizer.zero_grad()
|
||||
# Backward & update
|
||||
booster.backward(loss=vae_loss, optimizer=optimizer)
|
||||
# # NOTE: clip gradients? this is done in Open-Sora-Plan
|
||||
# torch.nn.utils.clip_grad_norm_(vae.parameters(), 1) # NOTE: done by grad_clip
|
||||
optimizer.step()
|
||||
|
||||
fake_logits = discriminator(fake_video.contiguous().detach())
|
||||
# Log loss values:
|
||||
all_reduce_mean(vae_loss) # NOTE: this is to get average loss for logging
|
||||
running_loss += vae_loss.item()
|
||||
|
||||
# ====== Discriminator Loss ======
|
||||
if global_step > cfg.discriminator_start:
|
||||
# if video_contains_first_frame:
|
||||
# Since we don't have enough T frames, pad anyways
|
||||
real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
|
||||
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
|
||||
if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
|
||||
real_video = real_video.requires_grad_()
|
||||
real_logits = discriminator(
|
||||
real_video.contiguous()
|
||||
) # SCH: not detached for now for gradient_penalty calculation
|
||||
else:
|
||||
real_logits = discriminator(real_video.contiguous().detach())
|
||||
|
||||
weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
|
||||
real_logits,
|
||||
fake_logits,
|
||||
global_step,
|
||||
lecam_ema_real = lecam_ema_real,
|
||||
lecam_ema_fake = lecam_ema_fake,
|
||||
real_video = real_video if cfg.gradient_penalty_loss_weight is not None else None,
|
||||
)
|
||||
disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
|
||||
if cfg.lecam_loss_weight is not None:
|
||||
ema_real = torch.mean(real_logits.clone().detach()).to(device, dtype)
|
||||
ema_fake = torch.mean(fake_logits.clone().detach()).to(device, dtype)
|
||||
all_reduce_mean(ema_real)
|
||||
all_reduce_mean(ema_fake)
|
||||
lecam_ema.update(ema_real, ema_fake)
|
||||
fake_logits = discriminator(fake_video.contiguous().detach())
|
||||
|
||||
disc_optimizer.zero_grad()
|
||||
# Backward & update
|
||||
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
|
||||
# # NOTE: TODO: clip gradients? this is done in Open-Sora-Plan
|
||||
# torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1) # NOTE: done by grad_clip
|
||||
disc_optimizer.step()
|
||||
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
|
||||
|
||||
# Log loss values:
|
||||
all_reduce_mean(disc_loss)
|
||||
running_disc_loss += disc_loss.item()
|
||||
else:
|
||||
disc_loss = torch.tensor(0.0)
|
||||
weighted_d_adversarial_loss = torch.tensor(0.0)
|
||||
lecam_loss = torch.tensor(0.0)
|
||||
gradient_penalty_loss = torch.tensor(0.0)
|
||||
weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
|
||||
real_logits,
|
||||
fake_logits,
|
||||
global_step,
|
||||
lecam_ema_real=lecam_ema_real,
|
||||
lecam_ema_fake=lecam_ema_fake,
|
||||
real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None,
|
||||
)
|
||||
disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
|
||||
if cfg.lecam_loss_weight is not None:
|
||||
ema_real = torch.mean(real_logits.clone().detach()).to(device, dtype)
|
||||
ema_fake = torch.mean(fake_logits.clone().detach()).to(device, dtype)
|
||||
all_reduce_mean(ema_real)
|
||||
all_reduce_mean(ema_fake)
|
||||
lecam_ema.update(ema_real, ema_fake)
|
||||
|
||||
log_step += 1
|
||||
disc_optimizer.zero_grad()
|
||||
# Backward & update
|
||||
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
|
||||
# # NOTE: TODO: clip gradients? this is done in Open-Sora-Plan
|
||||
# torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1) # NOTE: done by grad_clip
|
||||
disc_optimizer.step()
|
||||
|
||||
# Log to tensorboard
|
||||
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
|
||||
avg_loss = running_loss / log_step
|
||||
avg_disc_loss = running_disc_loss / log_step
|
||||
pbar.set_postfix({"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step})
|
||||
running_loss = 0
|
||||
log_step = 0
|
||||
running_disc_loss = 0
|
||||
writer.add_scalar("loss", vae_loss.item(), global_step)
|
||||
if cfg.wandb:
|
||||
wandb.log(
|
||||
{
|
||||
"iter": global_step,
|
||||
"num_samples": global_step * total_batch_size,
|
||||
"epoch": epoch,
|
||||
"loss": vae_loss.item(),
|
||||
"kl_loss": weighted_kl_loss.item(),
|
||||
"gen_adv_loss": adversarial_loss.item(),
|
||||
"disc_loss": disc_loss.item(),
|
||||
"lecam_loss": lecam_loss.item(),
|
||||
"r1_grad_penalty": gradient_penalty_loss.item(),
|
||||
"nll_loss": weighted_nll_loss.item(),
|
||||
"avg_loss": avg_loss,
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
# Log loss values:
|
||||
all_reduce_mean(disc_loss)
|
||||
running_disc_loss += disc_loss.item()
|
||||
else:
|
||||
disc_loss = torch.tensor(0.0)
|
||||
weighted_d_adversarial_loss = torch.tensor(0.0)
|
||||
lecam_loss = torch.tensor(0.0)
|
||||
gradient_penalty_loss = torch.tensor(0.0)
|
||||
|
||||
# Save checkpoint
|
||||
if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
|
||||
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
|
||||
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) # already handled in booster save_model
|
||||
booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
|
||||
booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
|
||||
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
|
||||
booster.save_optimizer(disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096)
|
||||
log_step += 1
|
||||
|
||||
if lr_scheduler is not None:
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
if disc_lr_scheduler is not None:
|
||||
booster.save_lr_scheduler(disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler"))
|
||||
|
||||
running_states = {
|
||||
# Log to tensorboard
|
||||
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
|
||||
avg_loss = running_loss / log_step
|
||||
avg_disc_loss = running_disc_loss / log_step
|
||||
pbar.set_postfix(
|
||||
{"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step}
|
||||
)
|
||||
running_loss = 0
|
||||
log_step = 0
|
||||
running_disc_loss = 0
|
||||
writer.add_scalar("loss", vae_loss.item(), global_step)
|
||||
if cfg.wandb:
|
||||
wandb.log(
|
||||
{
|
||||
"iter": global_step,
|
||||
"num_samples": global_step * total_batch_size,
|
||||
"epoch": epoch,
|
||||
"step": step+1,
|
||||
"global_step": global_step+1,
|
||||
"sample_start_index": (step+1) * cfg.batch_size,
|
||||
}
|
||||
|
||||
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
|
||||
lecam_state = {
|
||||
"lecam_ema_real": lecam_ema_real.item(),
|
||||
"lecam_ema_fake": lecam_ema_fake.item(),
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
if cfg.lecam_loss_weight is not None:
|
||||
save_json(lecam_state, os.path.join(save_dir, "lecam_states.json"))
|
||||
dist.barrier()
|
||||
"loss": vae_loss.item(),
|
||||
"kl_loss": weighted_kl_loss.item(),
|
||||
"gen_adv_loss": adversarial_loss.item(),
|
||||
"disc_loss": disc_loss.item(),
|
||||
"lecam_loss": lecam_loss.item(),
|
||||
"r1_grad_penalty": gradient_penalty_loss.item(),
|
||||
"nll_loss": weighted_nll_loss.item(),
|
||||
"avg_loss": avg_loss,
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
|
||||
)
|
||||
# Save checkpoint
|
||||
if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
|
||||
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
|
||||
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) # already handled in booster save_model
|
||||
booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
|
||||
booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
|
||||
booster.save_optimizer(
|
||||
optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096
|
||||
)
|
||||
booster.save_optimizer(
|
||||
disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096
|
||||
)
|
||||
|
||||
# print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
|
||||
|
||||
if lr_scheduler is not None:
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
if disc_lr_scheduler is not None:
|
||||
booster.save_lr_scheduler(disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler"))
|
||||
|
||||
running_states = {
|
||||
"epoch": epoch,
|
||||
"step": step + 1,
|
||||
"global_step": global_step + 1,
|
||||
"sample_start_index": (step + 1) * cfg.batch_size,
|
||||
}
|
||||
|
||||
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
|
||||
lecam_state = {
|
||||
"lecam_ema_real": lecam_ema_real.item(),
|
||||
"lecam_ema_fake": lecam_ema_fake.item(),
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
if cfg.lecam_loss_weight is not None:
|
||||
save_json(lecam_state, os.path.join(save_dir, "lecam_states.json"))
|
||||
dist.barrier()
|
||||
|
||||
logger.info(
|
||||
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
|
||||
)
|
||||
|
||||
# print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
|
||||
|
||||
# epochs after the resumed one start from the beginning, so reset the sampler start index and start step
|
||||
dataloader.sampler.set_start_index(0)
|
||||
start_step = 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue