This commit is contained in:
Shen-Chenhui 2024-04-16 15:00:31 +08:00
parent af1b1e484d
commit afd3f823d4
3 changed files with 275 additions and 183 deletions

View file

@ -59,12 +59,10 @@ kl_loss_weight = 0.000001
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
discriminator_factor = 1.0
discriminator_loss_weight = 0.5
lecam_loss_weight = 0 # TODO: not clear in MAGVIT what is the weight
discriminator_loss="hinge"
discriminator_start = -1 # 50000 TODO: change to correct val, debug use -1 for now
gradient_penalty_loss_weight = 10 # SCH: following MAGVIT config.vqgan.grad_penalty_cost
lecam_loss_weight = None # TODO: not clear in MAGVIT what is the weight
discriminator_loss="non-saturating"
discriminator_start = 50000 # 50000 TODO: change to correct val, debug use -1 for now
gradient_penalty_loss_weight = None # 10 # SCH: following MAGVIT config.vqgan.grad_penalty_cost, 10
ema_decay = 0.999 # ema decay factor for generator

View file

@ -58,36 +58,9 @@ def pick_video_frame(video, frame_indices):
def exists(v):
    """True iff *v* carries a value (i.e. is not None)."""
    return not (v is None)
def hinge_discr_loss(fake, real):
    """Hinge discriminator loss: mean of relu(1 + fake) + relu(1 - real).

    Penalizes real logits below +1 and fake logits above -1.
    """
    fake_term = (1 + fake).clamp(min=0)
    real_term = (1 - real).clamp(min=0)
    return (fake_term + real_term).mean()
def hinge_d_loss(logits_real, logits_fake):
    """Standard hinge GAN discriminator loss.

    0.5 * (mean(relu(1 - real)) + mean(relu(1 + fake))); drives real
    logits above +1 and fake logits below -1.
    """
    real_term = F.relu(1.0 - logits_real).mean()
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
def vanilla_d_loss(logits_real, logits_fake):
    """Vanilla (non-saturating) GAN discriminator loss.

    Binary cross-entropy expressed via softplus:
    0.5 * (mean(softplus(-real)) + mean(softplus(fake))).
    """
    real_term = F.softplus(-logits_real).mean()
    fake_term = F.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)
# ============== Generator Adversarial Loss Functions ==============
# TODO: verify if this is correct implementation
def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred):
"""Lecam loss for data-efficient and stable GAN training.
Described in https://arxiv.org/abs/2104.03310
Args:
real_pred: Prediction (scalar) for the real samples.
fake_pred: Prediction for the fake samples.
ema_real_pred: EMA prediction (scalar) for the real samples.
ema_fake_pred: EMA prediction for the fake samples.
Returns:
Lecam regularization loss (scalar).
"""
assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0
lecam_loss = np.mean(np.power(nn.ReLU(real_pred - ema_fake_pred), 2))
lecam_loss += np.mean(np.power(nn.ReLU(ema_real_pred - fake_pred), 2))
@ -107,6 +80,32 @@ def gradient_penalty_fn(images, output):
gradients = rearrange(gradients, 'b ... -> b (...)')
return ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
# ============== Discriminator Adversarial Loss Functions ==============
def hinge_d_loss(logits_real, logits_fake):
    """Hinge GAN discriminator loss (average of the real and fake hinges)."""
    loss_on_real = torch.clamp(1.0 - logits_real, min=0.0).mean()
    loss_on_fake = torch.clamp(1.0 + logits_fake, min=0.0).mean()
    return (loss_on_real + loss_on_fake) * 0.5
def vanilla_d_loss(logits_real, logits_fake):
    """Non-saturating GAN discriminator loss via the softplus identity
    -log(sigmoid(x)) == softplus(-x)."""
    softplus = torch.nn.functional.softplus
    return 0.5 * (softplus(-logits_real).mean() + softplus(logits_fake).mean())
# from MAGVIT, used in place of hinge_d_loss
def sigmoid_cross_entropy_with_logits(labels, logits):
    """Numerically stable element-wise sigmoid cross-entropy.

    Implements the standard stable formulation
        max(x, 0) - x * z + log(1 + exp(-|x|))
    where x = logits and z = labels, avoiding overflow for large |x|.
    Returns a tensor of the same shape as *logits*.
    """
    # clamp(min=0) == where(x >= 0, x, 0); -|x| == where(x >= 0, -x, x)
    return logits.clamp(min=0) - logits * labels + torch.log1p(torch.exp(-logits.abs()))
# Weight initialiser intended for use with nn.Module.apply(): applies
# Xavier-uniform init (with ReLU gain) to Conv3d and Linear weights.
# NOTE(review): biases are not initialised in the lines visible here —
# confirm whether bias handling follows elsewhere.
def xavier_uniform_weight_init(m):
    if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
@ -389,21 +388,28 @@ class StyleGANDiscriminatorBlur(nn.Module):
self.apply(xavier_uniform_weight_init)
def forward(self, x):
    """Run the discriminator trunk and return per-sample logits.

    Pipeline: conv1 -> activation -> stacked residual down-blocks ->
    conv2 -> norm1 -> activation -> flatten -> linear1 -> activation ->
    linear2.

    Removed: commented-out debug ``print`` lines left over from development.

    Args:
        x: input batch; flattened to [B, -1] before the linear head
           (presumably [B, C * T * H * W] — confirm against the conv stack).

    Returns:
        Discriminator logits, one row per batch element.
    """
    x = self.conv1(x)
    x = self.activation_fn(x)
    # Progressive (downsampling) residual blocks.
    for i in range(self.num_blocks):
        x = self.res_block_list[i](x)
    x = self.conv2(x)
    x = self.norm1(x)
    x = self.activation_fn(x)
    # Flatten every non-batch dimension before the linear head.
    x = x.reshape((x.shape[0], -1))
    x = self.linear1(x)
    x = self.activation_fn(x)
    x = self.linear2(x)
    return x
class Encoder(nn.Module):
@ -499,20 +505,26 @@ class Encoder(nn.Module):
# NOTE: moved to VAE for separate first frame processing
# x = self.conv1(x)
# print("encoder:", x.size())
for i in range(self.num_blocks):
for j in range(self.num_res_blocks):
x = self.block_res_blocks[i][j](x)
# print("encoder:", x.size())
if i < self.num_blocks - 1:
x = self.conv_blocks[i](x)
# print("encoder:", x.size())
for i in range(self.num_res_blocks):
x = self.res_blocks[i](x)
# print("encoder:", x.size())
x = self.norm1(x)
x = self.activate(x)
x = self.conv2(x)
# print("encoder:", x.size())
return x
class Decoder(nn.Module):
@ -620,19 +632,22 @@ class Decoder(nn.Module):
**kwargs,
):
# dtype, device = x.dtype, x.device
x = self.conv1(x)
# print("decoder:", x.size())
for i in range(self.num_res_blocks):
x = self.res_blocks[i](x)
# print("decoder:", x.size())
for i in reversed(range(self.num_blocks)): # reverse here to make decoder symmetric with encoder
for j in range(self.num_res_blocks):
x = self.block_res_blocks[i][j](x)
# print("decoder:", x.size())
if i > 0:
t_stride = 2 if self.temporal_downsample[i - 1] else 1
# SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
x = self.conv_blocks[i-1](x)
x = rearrange(x, "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)", ts=t_stride, hs=2, ws=2)
# print("decoder:", x.size())
x = self.norm1(x)
x = self.activate(x)
@ -744,19 +759,30 @@ class VAE_3D_V2(nn.Module):
video = pad_at_dim(video, (self.time_padding, 0), value = 0., dim = 2)
video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])]
# print("pre-encoder:", video.size())
# NOTE: moved encoder conv1 here for separate first frame encoding
if encode_first_frame_separately:
pad, first_frame, video = unpack(video, video_packed_shape, 'b c * h w')
first_frame = self.conv_in_first_frame(first_frame)
video = self.conv_in(video)
# print("pre-encoder:", video.size())
if encode_first_frame_separately:
video, _ = pack([first_frame, video], 'b c * h w')
video = pad_at_dim(video, (self.time_padding, 0), dim = 2)
encoded_feature = self.encoder(video)
# print("after encoder:", encoded_feature.size())
# NOTE: TODO: do we include this before gaussian distri? or go directly to Gaussian distribution
moments = self.quant_conv(encoded_feature).to(video.dtype)
posterior = model_utils.DiagonalGaussianDistribution(moments)
# print("after encoder moments:", moments.size())
return posterior
def decode(
@ -767,8 +793,12 @@ class VAE_3D_V2(nn.Module):
# dtype = z.dtype
decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
z = self.post_quant_conv(z)
# print("pre decoder, post quant conv:", z.size())
dec = self.decoder(z)
# print("post decoder:", dec.size())
# SCH: moved decoder last conv layer here for separate first frame decoding
if decode_first_frame_separately:
@ -782,6 +812,8 @@ class VAE_3D_V2(nn.Module):
if video_contains_first_frame:
video = video[:, :, self.time_padding:]
# print("conv out:", video.size())
return video
def get_last_layer(self):
@ -971,24 +1003,19 @@ class DiscriminatorLoss(nn.Module):
self,
discriminator_factor = 1.0,
discriminator_start = 50001,
discriminator_loss="hinge",
discriminator_loss="non-saturating",
lecam_loss_weight=0,
gradient_penalty_loss_weight=10, # SCH: following MAGVIT config.vqgan.grad_penalty_cost
):
super().__init__()
assert discriminator_loss in ["hinge", "vanilla"]
assert discriminator_loss in ["hinge", "vanilla", "non-saturating"]
self.discriminator_factor = discriminator_factor
self.discriminator_start = discriminator_start
self.lecam_loss_weight = lecam_loss_weight
self.gradient_penalty_loss_weight = gradient_penalty_loss_weight
if discriminator_loss == "hinge":
self.disc_loss_fn = hinge_d_loss
elif discriminator_loss == "vanilla":
self.disc_loss_fn = vanilla_d_loss
else:
raise ValueError(f"Unknown GAN loss '{discriminator_loss}'.")
self.discriminator_loss_type = discriminator_loss
def forward(
self,
@ -1002,11 +1029,33 @@ class DiscriminatorLoss(nn.Module):
):
if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
weighted_d_adversarial_loss = disc_factor * self.disc_loss_fn(real_logits, fake_logits)
if self.discriminator_loss_type == "hinge":
disc_loss = hinge_d_loss(real_logits, fake_logits)
elif self.discriminator_loss_type == "non-saturating":
if real_logits is not None:
real_loss = sigmoid_cross_entropy_with_logits(
labels=torch.ones_like(real_logits), logits=real_logits
)
else:
real_loss = 0.0
if fake_logits is not None:
fake_loss = sigmoid_cross_entropy_with_logits(
labels=torch.zeros_like(fake_logits), logits=fake_logits)
else:
fake_loss = 0.0
disc_loss = 0.5 * (torch.mean(real_loss) + torch.mean(fake_loss))
elif self.discriminator_loss_type == "vanilla":
disc_loss = vanilla_d_loss(real_logits, fake_logits)
else:
raise ValueError(f"Unknown GAN loss '{self.discriminator_loss_type}'.")
weighted_d_adversarial_loss = disc_factor * disc_loss
else:
weighted_d_adversarial_loss = 0
lecam_loss = 0.0
if self.lecam_loss_weight is not None and self.lecam_loss_weight > 0.0:
real_pred = np.mean(real_logits.clone().detach())
fake_pred = np.mean(fake_logits.clone().detach())

View file

@ -37,6 +37,14 @@ from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_num
from opensora.utils.train_utils import update_ema, MaskGenerator
from opensora.models.vae.vae_3d_v2 import VEALoss, DiscriminatorLoss, AdversarialLoss, pad_at_dim
# efficiency
# from torch.profiler import profile, record_function, ProfilerActivity
def trace_handler(p):
    """torch.profiler ``on_trace_ready`` callback: print the five ops with
    the highest self CUDA time."""
    table = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=5)
    print(table)
def main():
# ======================================================
@ -226,7 +234,9 @@ def main():
discriminator_factor = cfg.discriminator_factor,
discriminator_start = cfg.discriminator_start,
discriminator_loss = cfg.discriminator_loss,
)
lecam_loss_weight = cfg.lecam_loss_weight,
gradient_penalty_loss_weight = cfg.gradient_penalty_loss_weight,
)
# 6.3. training loop
@ -238,11 +248,15 @@ def main():
lecam_ema_real = torch.tensor(0.0)
lecam_ema_fake = torch.tensor(0.0)
for epoch in range(start_epoch, cfg.epochs):
dataloader.sampler.set_epoch(epoch)
dataloader_iter = iter(dataloader)
logger.info(f"Beginning epoch {epoch}...")
with tqdm(
range(start_step, num_steps_per_epoch),
desc=f"Epoch {epoch}",
@ -250,151 +264,182 @@ def main():
total=num_steps_per_epoch,
initial=start_step,
) as pbar:
for step in pbar:
# SCH: calc global step at the start
global_step = epoch * num_steps_per_epoch + step
batch = next(dataloader_iter)
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
# support for image or video inputs
assert x.ndim in {4, 5}, f"received input of {x.ndim} dimensions" # either image or video
assert x.shape[-2:] == cfg.image_size, f"received input size {x.shape[-2:]}, but config image size is {cfg.image_size}"
is_image = x.ndim == 4
if is_image:
video = rearrange(x, 'b c ... -> b c 1 ...')
video_contains_first_frame = True
else:
video = x
# ====== VAE ======
optimizer.zero_grad()
recon_video, posterior = vae(
video,
video_contains_first_frame = video_contains_first_frame,
)
# ====== Generator Loss ======
# simple nll loss
nll_loss = nll_loss_fn(
video,
recon_video,
posterior,
split = "train"
)
vae_loss = nll_loss
# adversarial loss
if global_step > cfg.discriminator_start:
# padded videos for GAN
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
fake_logits = discriminator(fake_video.contiguous())
adversarial_loss = adversarial_loss_fn(
fake_logits,
nll_loss,
vae.module.get_last_layer(),
global_step,
is_training = vae.training,
)
vae_loss += adversarial_loss
# Backward & update
booster.backward(loss=vae_loss, optimizer=optimizer)
# NOTE: clip gradients? this is done in Open-Sora-Plan
torch.nn.utils.clip_grad_norm_(vae.parameters(), 1)
optimizer.step()
# Log loss values:
all_reduce_mean(vae_loss)
running_loss += vae_loss.item()
# with profile(
# activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
# schedule=torch.profiler.schedule(
# wait=1,
# warmup=1,
# active=2,
# repeat=2,
# ),
# on_trace_ready=torch.profiler.tensorboard_trace_handler('/home/shenchenhui/log'),
# with_stack=True,
# record_shapes=True,
# profile_memory=True,
# ) as p: # trace efficiency
for step in pbar:
# ====== Discriminator Loss ======
if global_step > cfg.discriminator_start:
disc_optimizer.zero_grad()
# if video_contains_first_frame:
# Since we don't have enough T frames, pad anyways
real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
real_video = real_video.requires_grad_()
# with profile(
# activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
# with_stack=True,
# ) as p: # trace efficiency
real_logits = discriminator(real_video.contiguous()) # SCH: not detached for now for gradient_penalty calculation
fake_logits = discriminator(fake_video.contiguous().detach())
disc_loss = disc_loss_fn(
real_logits,
fake_logits,
global_step,
lecam_ema_real = lecam_ema_real,
lecam_ema_fake = lecam_ema_fake,
real_video = real_video
)
# SCH: calc global step at the start
global_step = epoch * num_steps_per_epoch + step
batch = next(dataloader_iter)
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
if cfg.ema_decay is not None:
# SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc.
lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(real_logits.clone().detach())
lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(fake_logits.clone().detach())
# support for image or video inputs
assert x.ndim in {4, 5}, f"received input of {x.ndim} dimensions" # either image or video
assert x.shape[-2:] == cfg.image_size, f"received input size {x.shape[-2:]}, but config image size is {cfg.image_size}"
is_image = x.ndim == 4
if is_image:
video = rearrange(x, 'b c ... -> b c 1 ...')
video_contains_first_frame = True
else:
video = x
# Backward & update
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
# NOTE: TODO: clip gradients? this is done in Open-Sora-Plan
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)
disc_optimizer.step()
# Log loss values:
all_reduce_mean(disc_loss)
running_disc_loss += disc_loss.item()
log_step += 1
# Log to tensorboard
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
avg_loss = running_loss / log_step
avg_disc_loss = running_disc_loss / log_step
pbar.set_postfix({"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step})
running_loss = 0
log_step = 0
writer.add_scalar("loss", vae_loss.item(), global_step)
if cfg.wandb:
wandb.log(
{
"iter": global_step,
"num_samples": global_step * total_batch_size,
"epoch": epoch,
"loss": vae_loss.item(),
"disc_loss": disc_loss.item(),
"avg_loss": avg_loss,
},
step=global_step,
# ====== VAE ======
optimizer.zero_grad()
recon_video, posterior = vae(
video,
video_contains_first_frame = video_contains_first_frame,
)
# Save checkpoint
if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
booster.save_optimizer(disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096)
# ====== Generator Loss ======
# simple nll loss
nll_loss = nll_loss_fn(
video,
recon_video,
posterior,
split = "train"
)
vae_loss = nll_loss
if lr_scheduler is not None:
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
if disc_lr_scheduler is not None:
booster.save_lr_scheduler(disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler"))
# adversarial loss
if global_step > cfg.discriminator_start:
# padded videos for GAN
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
fake_logits = discriminator(fake_video) # TODO: take out contiguous?
adversarial_loss = adversarial_loss_fn(
fake_logits,
nll_loss,
vae.module.get_last_layer(),
global_step,
is_training = vae.training,
)
vae_loss += adversarial_loss
# Backward & update
booster.backward(loss=vae_loss, optimizer=optimizer)
# NOTE: clip gradients? this is done in Open-Sora-Plan
torch.nn.utils.clip_grad_norm_(vae.parameters(), 1)
optimizer.step()
# Log loss values:
all_reduce_mean(vae_loss)
running_loss += vae_loss.item()
running_states = {
"epoch": epoch,
"step": step+1,
"global_step": global_step+1,
"sample_start_index": (step+1) * cfg.batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
dist.barrier()
logger.info(
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
)
# ====== Discriminator Loss ======
if global_step > cfg.discriminator_start:
disc_optimizer.zero_grad()
# if video_contains_first_frame:
# Since we don't have enough T frames, pad anyways
real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
real_video = real_video.requires_grad_()
real_logits = discriminator(real_video.contiguous()) # SCH: not detached for now for gradient_penalty calculation
if cfg.gradient_penalty_loss_weight is None:
real_logits = real_logits.detach()
fake_logits = discriminator(fake_video.contiguous().detach())
disc_loss = disc_loss_fn(
real_logits,
fake_logits,
global_step,
lecam_ema_real = lecam_ema_real,
lecam_ema_fake = lecam_ema_fake,
real_video = real_video if cfg.gradient_penalty_loss_weight is not None else None,
)
if cfg.ema_decay is not None:
# SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc.
lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(real_logits.clone().detach())
lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(fake_logits.clone().detach())
# Backward & update
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
# NOTE: TODO: clip gradients? this is done in Open-Sora-Plan
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)
disc_optimizer.step()
# Log loss values:
all_reduce_mean(disc_loss)
running_disc_loss += disc_loss.item()
log_step += 1
# Log to tensorboard
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
avg_loss = running_loss / log_step
avg_disc_loss = running_disc_loss / log_step
pbar.set_postfix({"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step})
running_loss = 0
log_step = 0
writer.add_scalar("loss", vae_loss.item(), global_step)
if cfg.wandb:
wandb.log(
{
"iter": global_step,
"num_samples": global_step * total_batch_size,
"epoch": epoch,
"loss": vae_loss.item(),
"disc_loss": disc_loss.item(),
"avg_loss": avg_loss,
},
step=global_step,
)
# Save checkpoint
if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
booster.save_optimizer(disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096)
if lr_scheduler is not None:
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
if disc_lr_scheduler is not None:
booster.save_lr_scheduler(disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler"))
running_states = {
"epoch": epoch,
"step": step+1,
"global_step": global_step+1,
"sample_start_index": (step+1) * cfg.batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
dist.barrier()
logger.info(
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
)
# p.step()
# print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(0)