add lecam

This commit is contained in:
Shen-Chenhui 2024-04-23 09:22:05 +08:00
parent 36d63f4100
commit f4fe9b5eca
3 changed files with 118 additions and 15 deletions

View file

@ -0,0 +1,82 @@
# ---------------------------------------------------------------------------
# Video sampling
# ---------------------------------------------------------------------------
num_frames = 16
frame_interval = 3
image_size = (128, 128)

# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
root = None
data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
video_contains_first_frame = False

# ---------------------------------------------------------------------------
# Acceleration
# ---------------------------------------------------------------------------
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
model = {
    "type": "VAE_MAGVIT_V2",
    "in_out_channels": 3,
    "latent_embed_dim": 4,
    "filters": 128,
    "num_res_blocks": 4,
    "channel_multipliers": (1, 2, 2, 4),
    "temporal_downsample": (False, True, True),
    "num_groups": 32,  # for nn.GroupNorm
    "kl_embed_dim": 4,
    "activation_fn": "swish",
    "separate_first_frame_encoding": False,
    "disable_space": False,
    "custom_conv_padding": None,
}

discriminator = {
    "type": "DISCRIMINATOR_3D",
    "image_size": image_size,
    "num_frames": num_frames,
    "in_channels": 3,
    "filters": 128,
    "channel_multipliers": (2, 4, 4, 4, 4),  # use (2, 4, 4, 4) for 64x64 resolution
}

# ---------------------------------------------------------------------------
# Loss weights
# ---------------------------------------------------------------------------
logvar_init = 0.0
kl_loss_weight = 1e-6
perceptual_loss_weight = 0.1  # VGG is used when this is not None and greater than 0
discriminator_factor = 1.0  # for discriminator adversarial loss
# discriminator_loss_weight = 0.5  # for generator adversarial loss
generator_factor = 0.1  # SCH: generator adversarial loss, MAGVIT v2 uses 0.1
lecam_loss_weight = 0.001  # NOTE: MAGVIT v2 uses 0.001
discriminator_loss_type = "non-saturating"
generator_loss_type = "non-saturating"
# NOTE: change to the correct value (50000 was noted); -1 was used while debugging.
discriminator_start = 1000
# SCH: MAGVIT uses 10, Open-Sora-Plan does not use a gradient penalty.
gradient_penalty_loss_weight = None
ema_decay = 0.999  # EMA decay for LeCam

# ---------------------------------------------------------------------------
# Others
# ---------------------------------------------------------------------------
seed = 42
outputs = "outputs"
wandb = False

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
# NOTE:
#   MAGVIT uses about (# samples in K) * epochs ~ 2-5 K with num_frames = 4
#   and reso = 128.
#   ==> ours: num_frames = 16, reso = 256, so (# samples in K) * epochs is
#   roughly in [500, 1200], i.e. 3-6 epochs for pexel; from the pexel
#   observation this looks correct.
#   NOTE(review): the note says reso = 256 while image_size above is
#   (128, 128) -- confirm which one the estimate refers to.
epochs = 2000
log_every = 1
ckpt_every = 1000
load = None
batch_size = 4
lr = 1e-4
grad_clip = 1.0

View file

@ -58,7 +58,6 @@ def exists(v):
return v is not None
# ============== Generator Adversarial Loss Functions ==============
# TODO: verify if this is correct implementation
def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred):
assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0
lecam_loss = np.mean(np.power(nn.ReLU(real_pred - ema_fake_pred), 2))
@ -1120,8 +1119,19 @@ class AdversarialLoss(nn.Module):
return weighted_gen_loss
class LeCamEMA:
    """Keep exponential moving averages of two scalar statistics ("real" and "fake").

    Both averages start at ``torch.tensor(0.0)``. Every call to :meth:`update`
    blends a new observation into each average as
    ``ema = ema * decay + value * (1 - decay)``.

    Args:
        decay: Fraction of the previous average that is kept on each update.
            It is stored as given and only used inside :meth:`update`.
    """

    def __init__(self, decay=0.999):
        self.decay = decay
        self.ema_real = torch.tensor(0.0)
        self.ema_fake = torch.tensor(0.0)

    def update(self, ema_real, ema_fake):
        """Blend the latest observations into the running averages.

        Despite their names, the arguments are the *new* values to fold in,
        not the running averages themselves.

        Args:
            ema_real: Latest value for the "real" statistic.
            ema_fake: Latest value for the "fake" statistic.
        """
        keep = self.decay
        # Run the blend without autograd: the stored averages are running
        # statistics, and tracking them would chain every update's graph onto
        # the previous one whenever a caller passes a tensor requiring grad.
        with torch.no_grad():
            self.ema_real = self.ema_real * keep + ema_real * (1 - keep)
            self.ema_fake = self.ema_fake * keep + ema_fake * (1 - keep)

    def get(self):
        """Return the current ``(ema_real, ema_fake)`` pair."""
        return self.ema_real, self.ema_fake
class DiscriminatorLoss(nn.Module):
def __init__(
self,

View file

@ -35,7 +35,7 @@ from opensora.utils.config_utils import (
)
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
from opensora.utils.train_utils import update_ema, MaskGenerator
from opensora.models.vae.vae_3d_v2 import VEALoss, DiscriminatorLoss, AdversarialLoss, pad_at_dim
from opensora.models.vae.vae_3d_v2 import VEALoss, DiscriminatorLoss, AdversarialLoss, LeCamEMA, pad_at_dim
@ -78,7 +78,9 @@ def main():
writer = create_tensorboard_writer(exp_dir)
if cfg.wandb:
wandb.init(project="opensora-vae", name=exp_name, config=cfg._cfg_dict)
# wandb.init(project="opensora-vae", name=exp_name, config=cfg._cfg_dict)
# NOTE: here we use the outputs folder name to store running records of different experiments (since frequent interruption)
wandb.init(project="opensora-vae", name=cfg.outputs, config=cfg._cfg_dict)
# 2.3. initialize ColossalAI booster
if cfg.plugin == "zero2":
@ -261,10 +263,10 @@ def main():
disc_time_padding = 0
video_contains_first_frame = cfg.video_contains_first_frame
lecam_ema_real = torch.tensor(0.0)
lecam_ema_fake = torch.tensor(0.0)
# lecam_ema_real = torch.tensor(0.0)
# lecam_ema_fake = torch.tensor(0.0)
lecam_ema = LeCamEMA(decay=cfg.ema_decay)
for epoch in range(start_epoch, cfg.epochs):
dataloader.sampler.set_epoch(epoch)
@ -357,12 +359,12 @@ def main():
optimizer.zero_grad()
# Backward & update
booster.backward(loss=vae_loss, optimizer=optimizer)
# NOTE: clip gradients? this is done in Open-Sora-Plan
torch.nn.utils.clip_grad_norm_(vae.parameters(), 1)
# # NOTE: clip gradients? this is done in Open-Sora-Plan
# torch.nn.utils.clip_grad_norm_(vae.parameters(), 1) # NOTE: done by grad_clip
optimizer.step()
# Log loss values:
all_reduce_mean(vae_loss)
all_reduce_mean(vae_loss) # NOTE: this is to get average loss for logging
running_loss += vae_loss.item()
@ -381,6 +383,10 @@ def main():
real_logits = discriminator(real_video.contiguous().detach())
fake_logits = discriminator(fake_video.contiguous().detach())
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
disc_loss = disc_loss_fn(
real_logits,
fake_logits,
@ -392,14 +398,19 @@ def main():
if cfg.ema_decay is not None:
# SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc.
lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(real_logits.clone().detach())
lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(fake_logits.clone().detach())
# lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(real_logits.clone().detach())
# lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(fake_logits.clone().detach())
ema_real = torch.mean(real_logits.clone().detach())
ema_fake = torch.mean(fake_logits.clone().detach())
all_reduce_mean(ema_real)
all_reduce_mean(ema_fake)
lecam_ema.update(ema_real, ema_fake)
disc_optimizer.zero_grad()
# Backward & update
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
# NOTE: TODO: clip gradients? this is done in Open-Sora-Plan
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)
# # NOTE: TODO: clip gradients? this is done in Open-Sora-Plan
# torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1) # NOTE: done by grad_clip
disc_optimizer.step()
# Log loss values: