config lc gp

This commit is contained in:
Shen-Chenhui 2024-04-24 10:03:58 +08:00
parent 48e3a84ba6
commit 988bc3bb65
4 changed files with 194 additions and 9 deletions

View file

@ -0,0 +1,93 @@
num_frames = 16
frame_interval = 3
image_size = (128, 128)
use_pipeline = True
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
video_contains_first_frame = False
# Define acceleration
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
# Define model
vae_2d = dict(
type="VideoAutoencoderKL",
from_pretrained="stabilityai/sd-vae-ft-ema",
# SDXL
)
model = dict(
type="VAE_MAGVIT_V2",
in_out_channels = 4,
latent_embed_dim = 256,
filters = 128,
num_res_blocks = 4,
channel_multipliers = (1, 2, 2, 4),
temporal_downsample = (False, True, True),
num_groups = 32, # for nn.GroupNorm
kl_embed_dim = 64,
activation_fn = 'swish',
separate_first_frame_encoding = False,
disable_space = True,
encoder_double_z = False,
custom_conv_padding = None
)
discriminator = dict(
type="DISCRIMINATOR_3D",
image_size = (16, 16), # NOTE: here image size is different
num_frames = num_frames,
in_channels = 4,
filters = 128,
use_pretrained=True, # NOTE: set to False only if we want to fresh train using a different discriminator!
# channel_multipliers = (2,4,4,4,4), # (2,4,4,4) for 64x64 resolution
channel_multipliers= (2,4,4) # since on intermediate layer, 16 x 16 x 16 dimension z
)
# loss weights
logvar_init=0.0
kl_loss_weight = 0.000001
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
discriminator_factor = 1.0 # for discriminator adversarial loss
# discriminator_loss_weight = 0.5 # for generator adversarial loss
generator_factor = 0.1 # SCH: generator adversarial loss, MAGVIT v2 uses 0.1
lecam_loss_weight = 0.001 # NOTE: MAVGIT v2 use 0.001
# discriminator_loss_type="non-saturating"
# generator_loss_type="non-saturating"
discriminator_loss_type="hinge"
generator_loss_type="hinge"
discriminator_start = 1000 # 50000 NOTE: change to correct val, debug use -1 for now
gradient_penalty_loss_weight = 10 # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
ema_decay = 0.999 # ema decay factor for generator
# Others
seed = 42
outputs = "outputs"
wandb = False
# Training
''' NOTE:
magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
3-6 epochs for pexel, from pexel observation its correct
'''
epochs = 2000
log_every = 1
ckpt_every = 1000
load = None
batch_size = 4
lr = 1e-4
grad_clip = 1.0

View file

@ -0,0 +1,93 @@
num_frames = 16
frame_interval = 3
image_size = (128, 128)
use_pipeline = True
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
video_contains_first_frame = False
# Define acceleration
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
# Define model
vae_2d = dict(
type="VideoAutoencoderKL",
from_pretrained="stabilityai/sd-vae-ft-ema",
# SDXL
)
model = dict(
type="VAE_MAGVIT_V2",
in_out_channels = 4,
latent_embed_dim = 256,
filters = 128,
num_res_blocks = 4,
channel_multipliers = (1, 2, 2, 4),
temporal_downsample = (False, True, True),
num_groups = 32, # for nn.GroupNorm
kl_embed_dim = 64,
activation_fn = 'swish',
separate_first_frame_encoding = False,
disable_space = True,
encoder_double_z = False,
custom_conv_padding = None
)
discriminator = dict(
type="DISCRIMINATOR_3D",
image_size = (16, 16), # NOTE: here image size is different
num_frames = num_frames,
in_channels = 4,
filters = 128,
use_pretrained=True, # NOTE: set to False only if we want to fresh train using a different discriminator!
# channel_multipliers = (2,4,4,4,4), # (2,4,4,4) for 64x64 resolution
channel_multipliers= (2,4,4) # since on intermediate layer, 16 x 16 x 16 dimension z
)
# loss weights
logvar_init=0.0
kl_loss_weight = 0.000001
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
discriminator_factor = 1.0 # for discriminator adversarial loss
# discriminator_loss_weight = 0.5 # for generator adversarial loss
generator_factor = 0.1 # SCH: generator adversarial loss, MAGVIT v2 uses 0.1
lecam_loss_weight = 0.001 # NOTE: MAVGIT v2 use 0.001
# discriminator_loss_type="non-saturating"
# generator_loss_type="non-saturating"
discriminator_loss_type="hinge"
generator_loss_type="hinge"
discriminator_start = 1000 # 50000 NOTE: change to correct val, debug use -1 for now
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
ema_decay = 0.999 # ema decay factor for generator
# Others
seed = 42
outputs = "outputs"
wandb = False
# Training
''' NOTE:
magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
3-6 epochs for pexel, from pexel observation its correct
'''
epochs = 2000
log_every = 1
ckpt_every = 1000
load = None
batch_size = 4
lr = 1e-4
grad_clip = 1.0

View file

@ -1251,8 +1251,7 @@ class DiscriminatorLoss(nn.Module):
else:
weighted_d_adversarial_loss = 0
lecam_loss = 0.0
lecam_loss = torch.tensor(0.0)
if self.lecam_loss_weight is not None and self.lecam_loss_weight > 0.0:
real_pred = torch.mean(real_logits)
fake_pred = torch.mean(fake_logits)
@ -1261,24 +1260,22 @@ class DiscriminatorLoss(nn.Module):
lecam_ema_fake)
lecam_loss = lecam_loss * self.lecam_loss_weight
gradient_penalty = 0.0
gradient_penalty = torch.tensor(0.0)
if self.gradient_penalty_loss_weight is not None and self.gradient_penalty_loss_weight > 0.0:
assert real_video is not None
# gradient_penalty = gradient_penalty_fn(real_video, real_logits)
gradient_penalty = r1_penalty(real_video, real_logits) # MAGVIT uses r1 penalty
gradient_penalty *= self.gradient_penalty_loss_weight
discriminator_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty
# discriminator_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty
# log = {
# "{}/discriminator_loss".format(split): discriminator_loss.clone().detach().mean(),
# "{}/d_adversarial_loss".format(split): weighted_d_adversarial_loss.detach().mean(),
# "{}/lecam_loss".format(split): lecam_loss.detach().mean(),
# "{}/gradient_penalty".format(split): gradient_penalty.detach().mean(),
# }
return discriminator_loss
return (weighted_d_adversarial_loss, lecam_loss, gradient_penalty)
@MODELS.register_module("VAE_MAGVIT_V2")
def VAE_MAGVIT_V2(from_pretrained=None, **kwargs):

View file

@ -388,7 +388,7 @@ def main():
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
disc_loss = disc_loss_fn(
weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
real_logits,
fake_logits,
global_step,
@ -396,7 +396,7 @@ def main():
lecam_ema_fake = lecam_ema_fake,
real_video = real_video if cfg.gradient_penalty_loss_weight is not None else None,
)
disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
if cfg.ema_decay is not None:
# SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc.
# lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(real_logits.clone().detach())
@ -440,6 +440,8 @@ def main():
"kl_loss": weighted_kl_loss.item(),
"gen_adv_loss": adversarial_loss.item(),
"disc_loss": disc_loss.item(),
"lecam_loss": lecam_loss.item(),
"r1_grad_penalty": gradient_penalty_loss.item(),
"avg_loss": avg_loss,
},
step=global_step,