mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 13:14:44 +02:00
debug
This commit is contained in:
parent
af1b1e484d
commit
afd3f823d4
|
|
@ -59,12 +59,10 @@ kl_loss_weight = 0.000001
|
|||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
discriminator_factor = 1.0
|
||||
discriminator_loss_weight = 0.5
|
||||
lecam_loss_weight = 0 # TODO: not clear in MAGVIT what is the weight
|
||||
discriminator_loss="hinge"
|
||||
discriminator_start = -1 # 50000 TODO: change to correct val, debug use -1 for now
|
||||
|
||||
|
||||
gradient_penalty_loss_weight = 10 # SCH: following MAGVIT config.vqgan.grad_penalty_cost
|
||||
lecam_loss_weight = None # TODO: not clear in MAGVIT what is the weight
|
||||
discriminator_loss="non-saturating"
|
||||
discriminator_start = 50000 # 50000 TODO: change to correct val, debug use -1 for now
|
||||
gradient_penalty_loss_weight = None # 10 # SCH: following MAGVIT config.vqgan.grad_penalty_cost, 10
|
||||
ema_decay = 0.999 # ema decay factor for generator
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -58,36 +58,9 @@ def pick_video_frame(video, frame_indices):
|
|||
def exists(v):
    """Report whether ``v`` holds a value, i.e. is anything other than ``None``."""
    if v is None:
        return False
    return True
|
||||
|
||||
def hinge_discr_loss(fake, real):
    """Hinge loss for the discriminator, averaged over all elements.

    Args:
        fake: discriminator outputs for generated samples.
        real: discriminator outputs for real samples.

    Returns:
        The mean of ``relu(1 + fake) + relu(1 - real)``.
    """
    fake_term = F.relu(1 + fake)
    real_term = F.relu(1 - real)
    return (fake_term + real_term).mean()
|
||||
|
||||
def hinge_d_loss(logits_real, logits_fake):
    """Hinge discriminator loss: half the sum of the real and fake hinge terms.

    Args:
        logits_real: discriminator outputs for real samples.
        logits_fake: discriminator outputs for generated samples.

    Returns:
        ``0.5 * (mean(relu(1 - logits_real)) + mean(relu(1 + logits_fake)))``.
    """
    real_term = F.relu(1. - logits_real).mean()
    fake_term = F.relu(1. + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
|
||||
|
||||
def vanilla_d_loss(logits_real, logits_fake):
    """Softplus-based discriminator loss averaged over real and fake logits.

    Args:
        logits_real: discriminator outputs for real samples.
        logits_fake: discriminator outputs for generated samples.

    Returns:
        ``0.5 * (mean(softplus(-logits_real)) + mean(softplus(logits_fake)))``.
    """
    real_term = torch.nn.functional.softplus(-logits_real).mean()
    fake_term = torch.nn.functional.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)
|
||||
|
||||
# ============== Generator Adversarial Loss Functions ==============
|
||||
# TODO: verify if this is correct implementation
|
||||
def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred):
|
||||
"""Lecam loss for data-efficient and stable GAN training.
|
||||
|
||||
Described in https://arxiv.org/abs/2104.03310
|
||||
|
||||
Args:
|
||||
real_pred: Prediction (scalar) for the real samples.
|
||||
fake_pred: Prediction for the fake samples.
|
||||
ema_real_pred: EMA prediction (scalar) for the real samples.
|
||||
ema_fake_pred: EMA prediction for the fake samples.
|
||||
|
||||
Returns:
|
||||
Lecam regularization loss (scalar).
|
||||
"""
|
||||
assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0
|
||||
lecam_loss = np.mean(np.power(nn.ReLU(real_pred - ema_fake_pred), 2))
|
||||
lecam_loss += np.mean(np.power(nn.ReLU(ema_real_pred - fake_pred), 2))
|
||||
|
|
@ -107,6 +80,32 @@ def gradient_penalty_fn(images, output):
|
|||
gradients = rearrange(gradients, 'b ... -> b (...)')
|
||||
return ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
|
||||
|
||||
|
||||
# ============== Discriminator Adversarial Loss Functions ==============
|
||||
def hinge_d_loss(logits_real, logits_fake):
    """Hinge discriminator loss: half the sum of the real and fake hinge terms.

    Args:
        logits_real: discriminator outputs for real samples.
        logits_fake: discriminator outputs for generated samples.

    Returns:
        ``0.5 * (mean(relu(1 - logits_real)) + mean(relu(1 + logits_fake)))``.
    """
    real_term = F.relu(1. - logits_real).mean()
    fake_term = F.relu(1. + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
|
||||
|
||||
def vanilla_d_loss(logits_real, logits_fake):
    """Softplus-based discriminator loss averaged over real and fake logits.

    Args:
        logits_real: discriminator outputs for real samples.
        logits_fake: discriminator outputs for generated samples.

    Returns:
        ``0.5 * (mean(softplus(-logits_real)) + mean(softplus(logits_fake)))``.
    """
    real_term = torch.nn.functional.softplus(-logits_real).mean()
    fake_term = torch.nn.functional.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)
|
||||
|
||||
# from MAGVIT, used in place of hinge_d_loss
|
||||
def sigmoid_cross_entropy_with_logits(labels, logits):
    """Element-wise sigmoid cross entropy computed directly from logits.

    Evaluates ``max(x, 0) - x * z + log(1 + exp(-abs(x)))`` with
    ``x = logits`` and ``z = labels``. No reduction is applied.

    Args:
        labels: target values, combined element-wise with ``logits``.
        logits: raw (pre-sigmoid) scores.

    Returns:
        Tensor of per-element losses.
    """
    zero = torch.zeros_like(logits, dtype=logits.dtype)
    is_non_negative = logits >= zero
    # max(x, 0)
    positive_part = torch.where(is_non_negative, logits, zero)
    # -abs(x): the exponent is never positive, so exp() cannot overflow
    minus_abs = torch.where(is_non_negative, -logits, logits)
    log_term = torch.log1p(torch.exp(minus_abs))
    return positive_part - logits * labels + log_term
|
||||
|
||||
|
||||
|
||||
|
||||
def xavier_uniform_weight_init(m):
|
||||
if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
|
||||
|
|
@ -389,21 +388,28 @@ class StyleGANDiscriminatorBlur(nn.Module):
|
|||
self.apply(xavier_uniform_weight_init)
|
||||
|
||||
def forward(self, x):
    """Run the discriminator on ``x`` and return the output of ``linear2``.

    Order of operations: conv1 -> activation -> ``num_blocks`` residual
    blocks -> conv2 -> norm1 -> activation -> flatten all but the first
    dimension -> linear1 -> activation -> linear2.
    """
    x = self.activation_fn(self.conv1(x))

    for block_idx in range(self.num_blocks):
        x = self.res_block_list[block_idx](x)

    x = self.activation_fn(self.norm1(self.conv2(x)))

    # Keep dim 0 and collapse the rest. SCH: [B, (C * T * W * H)] ?
    x = x.reshape((x.shape[0], -1))

    x = self.activation_fn(self.linear1(x))
    return self.linear2(x)
|
||||
|
||||
class Encoder(nn.Module):
|
||||
|
|
@ -499,20 +505,26 @@ class Encoder(nn.Module):
|
|||
# NOTE: moved to VAE for separate first frame processing
|
||||
# x = self.conv1(x)
|
||||
|
||||
# print("encoder:", x.size())
|
||||
|
||||
for i in range(self.num_blocks):
|
||||
for j in range(self.num_res_blocks):
|
||||
x = self.block_res_blocks[i][j](x)
|
||||
# print("encoder:", x.size())
|
||||
|
||||
if i < self.num_blocks - 1:
|
||||
x = self.conv_blocks[i](x)
|
||||
# print("encoder:", x.size())
|
||||
|
||||
|
||||
for i in range(self.num_res_blocks):
|
||||
x = self.res_blocks[i](x)
|
||||
# print("encoder:", x.size())
|
||||
|
||||
x = self.norm1(x)
|
||||
x = self.activate(x)
|
||||
x = self.conv2(x)
|
||||
# print("encoder:", x.size())
|
||||
return x
|
||||
|
||||
class Decoder(nn.Module):
|
||||
|
|
@ -620,19 +632,22 @@ class Decoder(nn.Module):
|
|||
**kwargs,
|
||||
):
|
||||
# dtype, device = x.dtype, x.device
|
||||
|
||||
x = self.conv1(x)
|
||||
# print("decoder:", x.size())
|
||||
for i in range(self.num_res_blocks):
|
||||
x = self.res_blocks[i](x)
|
||||
# print("decoder:", x.size())
|
||||
for i in reversed(range(self.num_blocks)): # reverse here to make decoder symmetric with encoder
|
||||
for j in range(self.num_res_blocks):
|
||||
x = self.block_res_blocks[i][j](x)
|
||||
|
||||
# print("decoder:", x.size())
|
||||
if i > 0:
|
||||
t_stride = 2 if self.temporal_downsample[i - 1] else 1
|
||||
# SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
|
||||
x = self.conv_blocks[i-1](x)
|
||||
x = rearrange(x, "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)", ts=t_stride, hs=2, ws=2)
|
||||
|
||||
# print("decoder:", x.size())
|
||||
|
||||
x = self.norm1(x)
|
||||
x = self.activate(x)
|
||||
|
|
@ -744,19 +759,30 @@ class VAE_3D_V2(nn.Module):
|
|||
video = pad_at_dim(video, (self.time_padding, 0), value = 0., dim = 2)
|
||||
video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])]
|
||||
|
||||
# print("pre-encoder:", video.size())
|
||||
|
||||
# NOTE: moved encoder conv1 here for separate first frame encoding
|
||||
if encode_first_frame_separately:
|
||||
pad, first_frame, video = unpack(video, video_packed_shape, 'b c * h w')
|
||||
first_frame = self.conv_in_first_frame(first_frame)
|
||||
video = self.conv_in(video)
|
||||
|
||||
# print("pre-encoder:", video.size())
|
||||
|
||||
if encode_first_frame_separately:
|
||||
video, _ = pack([first_frame, video], 'b c * h w')
|
||||
video = pad_at_dim(video, (self.time_padding, 0), dim = 2)
|
||||
|
||||
encoded_feature = self.encoder(video)
|
||||
|
||||
# print("after encoder:", encoded_feature.size())
|
||||
|
||||
# NOTE: TODO: do we include this before gaussian distri? or go directly to Gaussian distribution
|
||||
moments = self.quant_conv(encoded_feature).to(video.dtype)
|
||||
posterior = model_utils.DiagonalGaussianDistribution(moments)
|
||||
|
||||
# print("after encoder moments:", moments.size())
|
||||
|
||||
return posterior
|
||||
|
||||
def decode(
|
||||
|
|
@ -767,8 +793,12 @@ class VAE_3D_V2(nn.Module):
|
|||
# dtype = z.dtype
|
||||
decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
|
||||
|
||||
|
||||
z = self.post_quant_conv(z)
|
||||
# print("pre decoder, post quant conv:", z.size())
|
||||
|
||||
dec = self.decoder(z)
|
||||
# print("post decoder:", dec.size())
|
||||
|
||||
# SCH: moved decoder last conv layer here for separate first frame decoding
|
||||
if decode_first_frame_separately:
|
||||
|
|
@ -782,6 +812,8 @@ class VAE_3D_V2(nn.Module):
|
|||
if video_contains_first_frame:
|
||||
video = video[:, :, self.time_padding:]
|
||||
|
||||
# print("conv out:", video.size())
|
||||
|
||||
return video
|
||||
|
||||
def get_last_layer(self):
|
||||
|
|
@ -971,24 +1003,19 @@ class DiscriminatorLoss(nn.Module):
|
|||
self,
|
||||
discriminator_factor = 1.0,
|
||||
discriminator_start = 50001,
|
||||
discriminator_loss="hinge",
|
||||
discriminator_loss="non-saturating",
|
||||
lecam_loss_weight=0,
|
||||
gradient_penalty_loss_weight=10, # SCH: following MAGVIT config.vqgan.grad_penalty_cost
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert discriminator_loss in ["hinge", "vanilla"]
|
||||
assert discriminator_loss in ["hinge", "vanilla", "non-saturating"]
|
||||
self.discriminator_factor = discriminator_factor
|
||||
self.discriminator_start = discriminator_start
|
||||
self.lecam_loss_weight = lecam_loss_weight
|
||||
self.gradient_penalty_loss_weight = gradient_penalty_loss_weight
|
||||
|
||||
if discriminator_loss == "hinge":
|
||||
self.disc_loss_fn = hinge_d_loss
|
||||
elif discriminator_loss == "vanilla":
|
||||
self.disc_loss_fn = vanilla_d_loss
|
||||
else:
|
||||
raise ValueError(f"Unknown GAN loss '{discriminator_loss}'.")
|
||||
self.discriminator_loss_type = discriminator_loss
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
@ -1002,11 +1029,33 @@ class DiscriminatorLoss(nn.Module):
|
|||
):
|
||||
if self.discriminator_factor is not None and self.discriminator_factor > 0.0:
|
||||
disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start)
|
||||
weighted_d_adversarial_loss = disc_factor * self.disc_loss_fn(real_logits, fake_logits)
|
||||
|
||||
if self.discriminator_loss_type == "hinge":
|
||||
disc_loss = hinge_d_loss(real_logits, fake_logits)
|
||||
elif self.discriminator_loss_type == "non-saturating":
|
||||
if real_logits is not None:
|
||||
real_loss = sigmoid_cross_entropy_with_logits(
|
||||
labels=torch.ones_like(real_logits), logits=real_logits
|
||||
)
|
||||
else:
|
||||
real_loss = 0.0
|
||||
if fake_logits is not None:
|
||||
fake_loss = sigmoid_cross_entropy_with_logits(
|
||||
labels=torch.zeros_like(fake_logits), logits=fake_logits)
|
||||
else:
|
||||
fake_loss = 0.0
|
||||
disc_loss = 0.5 * (torch.mean(real_loss) + torch.mean(fake_loss))
|
||||
elif self.discriminator_loss_type == "vanilla":
|
||||
disc_loss = vanilla_d_loss(real_logits, fake_logits)
|
||||
else:
|
||||
raise ValueError(f"Unknown GAN loss '{self.discriminator_loss_type}'.")
|
||||
|
||||
weighted_d_adversarial_loss = disc_factor * disc_loss
|
||||
else:
|
||||
weighted_d_adversarial_loss = 0
|
||||
|
||||
lecam_loss = 0.0
|
||||
|
||||
if self.lecam_loss_weight is not None and self.lecam_loss_weight > 0.0:
|
||||
real_pred = np.mean(real_logits.clone().detach())
|
||||
fake_pred = np.mean(fake_logits.clone().detach())
|
||||
|
|
|
|||
|
|
@ -37,6 +37,14 @@ from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_num
|
|||
from opensora.utils.train_utils import update_ema, MaskGenerator
|
||||
from opensora.models.vae.vae_3d_v2 import VEALoss, DiscriminatorLoss, AdversarialLoss, pad_at_dim
|
||||
|
||||
# efficiency
|
||||
# from torch.profiler import profile, record_function, ProfilerActivity
|
||||
|
||||
def trace_handler(p):
    """Print a short summary table from a profiler object.

    Args:
        p: object exposing ``key_averages()``, whose result provides
            ``table(sort_by=..., row_limit=...)``.

    The table is sorted by ``self_cuda_time_total`` and limited to 5 rows.
    """
    averaged = p.key_averages()
    print(averaged.table(sort_by="self_cuda_time_total", row_limit=5))
|
||||
|
||||
|
||||
def main():
|
||||
# ======================================================
|
||||
|
|
@ -226,7 +234,9 @@ def main():
|
|||
discriminator_factor = cfg.discriminator_factor,
|
||||
discriminator_start = cfg.discriminator_start,
|
||||
discriminator_loss = cfg.discriminator_loss,
|
||||
)
|
||||
lecam_loss_weight = cfg.lecam_loss_weight,
|
||||
gradient_penalty_loss_weight = cfg.gradient_penalty_loss_weight,
|
||||
)
|
||||
|
||||
# 6.3. training loop
|
||||
|
||||
|
|
@ -238,11 +248,15 @@ def main():
|
|||
lecam_ema_real = torch.tensor(0.0)
|
||||
lecam_ema_fake = torch.tensor(0.0)
|
||||
|
||||
|
||||
|
||||
for epoch in range(start_epoch, cfg.epochs):
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
dataloader_iter = iter(dataloader)
|
||||
logger.info(f"Beginning epoch {epoch}...")
|
||||
|
||||
|
||||
|
||||
with tqdm(
|
||||
range(start_step, num_steps_per_epoch),
|
||||
desc=f"Epoch {epoch}",
|
||||
|
|
@ -250,151 +264,182 @@ def main():
|
|||
total=num_steps_per_epoch,
|
||||
initial=start_step,
|
||||
) as pbar:
|
||||
for step in pbar:
|
||||
|
||||
# SCH: calc global step at the start
|
||||
global_step = epoch * num_steps_per_epoch + step
|
||||
|
||||
batch = next(dataloader_iter)
|
||||
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
|
||||
|
||||
# support for image or video inputs
|
||||
assert x.ndim in {4, 5}, f"received input of {x.ndim} dimensions" # either image or video
|
||||
assert x.shape[-2:] == cfg.image_size, f"received input size {x.shape[-2:]}, but config image size is {cfg.image_size}"
|
||||
is_image = x.ndim == 4
|
||||
if is_image:
|
||||
video = rearrange(x, 'b c ... -> b c 1 ...')
|
||||
video_contains_first_frame = True
|
||||
else:
|
||||
video = x
|
||||
|
||||
# ====== VAE ======
|
||||
optimizer.zero_grad()
|
||||
recon_video, posterior = vae(
|
||||
video,
|
||||
video_contains_first_frame = video_contains_first_frame,
|
||||
)
|
||||
|
||||
# ====== Generator Loss ======
|
||||
# simple nll loss
|
||||
nll_loss = nll_loss_fn(
|
||||
video,
|
||||
recon_video,
|
||||
posterior,
|
||||
split = "train"
|
||||
)
|
||||
vae_loss = nll_loss
|
||||
|
||||
# adversarial loss
|
||||
if global_step > cfg.discriminator_start:
|
||||
# padded videos for GAN
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
fake_logits = discriminator(fake_video.contiguous())
|
||||
adversarial_loss = adversarial_loss_fn(
|
||||
fake_logits,
|
||||
nll_loss,
|
||||
vae.module.get_last_layer(),
|
||||
global_step,
|
||||
is_training = vae.training,
|
||||
)
|
||||
vae_loss += adversarial_loss
|
||||
# Backward & update
|
||||
booster.backward(loss=vae_loss, optimizer=optimizer)
|
||||
# NOTE: clip gradients? this is done in Open-Sora-Plan
|
||||
torch.nn.utils.clip_grad_norm_(vae.parameters(), 1)
|
||||
optimizer.step()
|
||||
# Log loss values:
|
||||
all_reduce_mean(vae_loss)
|
||||
running_loss += vae_loss.item()
|
||||
# with profile(
|
||||
# activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
# schedule=torch.profiler.schedule(
|
||||
# wait=1,
|
||||
# warmup=1,
|
||||
# active=2,
|
||||
# repeat=2,
|
||||
# ),
|
||||
# on_trace_ready=torch.profiler.tensorboard_trace_handler('/home/shenchenhui/log'),
|
||||
# with_stack=True,
|
||||
# record_shapes=True,
|
||||
# profile_memory=True,
|
||||
# ) as p: # trace efficiency
|
||||
|
||||
for step in pbar:
|
||||
|
||||
|
||||
# ====== Discriminator Loss ======
|
||||
if global_step > cfg.discriminator_start:
|
||||
disc_optimizer.zero_grad()
|
||||
# if video_contains_first_frame:
|
||||
# Since we don't have enough T frames, pad anyways
|
||||
real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
|
||||
real_video = real_video.requires_grad_()
|
||||
# with profile(
|
||||
# activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
# with_stack=True,
|
||||
# ) as p: # trace efficiency
|
||||
|
||||
real_logits = discriminator(real_video.contiguous()) # SCH: not detached for now for gradient_penalty calculation
|
||||
fake_logits = discriminator(fake_video.contiguous().detach())
|
||||
disc_loss = disc_loss_fn(
|
||||
real_logits,
|
||||
fake_logits,
|
||||
global_step,
|
||||
lecam_ema_real = lecam_ema_real,
|
||||
lecam_ema_fake = lecam_ema_fake,
|
||||
real_video = real_video
|
||||
)
|
||||
# SCH: calc global step at the start
|
||||
global_step = epoch * num_steps_per_epoch + step
|
||||
|
||||
batch = next(dataloader_iter)
|
||||
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
|
||||
|
||||
if cfg.ema_decay is not None:
|
||||
# SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc.
|
||||
lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(real_logits.clone().detach())
|
||||
lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(fake_logits.clone().detach())
|
||||
# support for image or video inputs
|
||||
assert x.ndim in {4, 5}, f"received input of {x.ndim} dimensions" # either image or video
|
||||
assert x.shape[-2:] == cfg.image_size, f"received input size {x.shape[-2:]}, but config image size is {cfg.image_size}"
|
||||
is_image = x.ndim == 4
|
||||
if is_image:
|
||||
video = rearrange(x, 'b c ... -> b c 1 ...')
|
||||
video_contains_first_frame = True
|
||||
else:
|
||||
video = x
|
||||
|
||||
# Backward & update
|
||||
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
|
||||
# NOTE: TODO: clip gradients? this is done in Open-Sora-Plan
|
||||
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)
|
||||
|
||||
disc_optimizer.step()
|
||||
|
||||
# Log loss values:
|
||||
all_reduce_mean(disc_loss)
|
||||
running_disc_loss += disc_loss.item()
|
||||
|
||||
log_step += 1
|
||||
|
||||
# Log to tensorboard
|
||||
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
|
||||
avg_loss = running_loss / log_step
|
||||
avg_disc_loss = running_disc_loss / log_step
|
||||
pbar.set_postfix({"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step})
|
||||
running_loss = 0
|
||||
log_step = 0
|
||||
writer.add_scalar("loss", vae_loss.item(), global_step)
|
||||
if cfg.wandb:
|
||||
wandb.log(
|
||||
{
|
||||
"iter": global_step,
|
||||
"num_samples": global_step * total_batch_size,
|
||||
"epoch": epoch,
|
||||
"loss": vae_loss.item(),
|
||||
"disc_loss": disc_loss.item(),
|
||||
"avg_loss": avg_loss,
|
||||
},
|
||||
step=global_step,
|
||||
# ====== VAE ======
|
||||
optimizer.zero_grad()
|
||||
recon_video, posterior = vae(
|
||||
video,
|
||||
video_contains_first_frame = video_contains_first_frame,
|
||||
)
|
||||
|
||||
# Save checkpoint
|
||||
if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
|
||||
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
|
||||
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
|
||||
booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
|
||||
booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
|
||||
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
|
||||
booster.save_optimizer(disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096)
|
||||
# ====== Generator Loss ======
|
||||
# simple nll loss
|
||||
nll_loss = nll_loss_fn(
|
||||
video,
|
||||
recon_video,
|
||||
posterior,
|
||||
split = "train"
|
||||
)
|
||||
vae_loss = nll_loss
|
||||
|
||||
if lr_scheduler is not None:
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
if disc_lr_scheduler is not None:
|
||||
booster.save_lr_scheduler(disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler"))
|
||||
# adversarial loss
|
||||
if global_step > cfg.discriminator_start:
|
||||
# padded videos for GAN
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
fake_logits = discriminator(fake_video) # TODO: take out contiguous?
|
||||
adversarial_loss = adversarial_loss_fn(
|
||||
fake_logits,
|
||||
nll_loss,
|
||||
vae.module.get_last_layer(),
|
||||
global_step,
|
||||
is_training = vae.training,
|
||||
)
|
||||
vae_loss += adversarial_loss
|
||||
# Backward & update
|
||||
booster.backward(loss=vae_loss, optimizer=optimizer)
|
||||
# NOTE: clip gradients? this is done in Open-Sora-Plan
|
||||
torch.nn.utils.clip_grad_norm_(vae.parameters(), 1)
|
||||
optimizer.step()
|
||||
# Log loss values:
|
||||
all_reduce_mean(vae_loss)
|
||||
running_loss += vae_loss.item()
|
||||
|
||||
|
||||
running_states = {
|
||||
"epoch": epoch,
|
||||
"step": step+1,
|
||||
"global_step": global_step+1,
|
||||
"sample_start_index": (step+1) * cfg.batch_size,
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
dist.barrier()
|
||||
logger.info(
|
||||
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
|
||||
)
|
||||
|
||||
# ====== Discriminator Loss ======
|
||||
if global_step > cfg.discriminator_start:
|
||||
disc_optimizer.zero_grad()
|
||||
# if video_contains_first_frame:
|
||||
# Since we don't have enough T frames, pad anyways
|
||||
|
||||
real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
|
||||
if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
|
||||
real_video = real_video.requires_grad_()
|
||||
|
||||
real_logits = discriminator(real_video.contiguous()) # SCH: not detached for now for gradient_penalty calculation
|
||||
|
||||
if cfg.gradient_penalty_loss_weight is None:
|
||||
real_logits = real_logits.detach()
|
||||
|
||||
fake_logits = discriminator(fake_video.contiguous().detach())
|
||||
disc_loss = disc_loss_fn(
|
||||
real_logits,
|
||||
fake_logits,
|
||||
global_step,
|
||||
lecam_ema_real = lecam_ema_real,
|
||||
lecam_ema_fake = lecam_ema_fake,
|
||||
real_video = real_video if cfg.gradient_penalty_loss_weight is not None else None,
|
||||
)
|
||||
|
||||
if cfg.ema_decay is not None:
|
||||
# SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc.
|
||||
lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(real_logits.clone().detach())
|
||||
lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(fake_logits.clone().detach())
|
||||
|
||||
# Backward & update
|
||||
booster.backward(loss=disc_loss, optimizer=disc_optimizer)
|
||||
# NOTE: TODO: clip gradients? this is done in Open-Sora-Plan
|
||||
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)
|
||||
|
||||
disc_optimizer.step()
|
||||
|
||||
# Log loss values:
|
||||
all_reduce_mean(disc_loss)
|
||||
running_disc_loss += disc_loss.item()
|
||||
|
||||
log_step += 1
|
||||
|
||||
# Log to tensorboard
|
||||
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
|
||||
avg_loss = running_loss / log_step
|
||||
avg_disc_loss = running_disc_loss / log_step
|
||||
pbar.set_postfix({"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step})
|
||||
running_loss = 0
|
||||
log_step = 0
|
||||
writer.add_scalar("loss", vae_loss.item(), global_step)
|
||||
if cfg.wandb:
|
||||
wandb.log(
|
||||
{
|
||||
"iter": global_step,
|
||||
"num_samples": global_step * total_batch_size,
|
||||
"epoch": epoch,
|
||||
"loss": vae_loss.item(),
|
||||
"disc_loss": disc_loss.item(),
|
||||
"avg_loss": avg_loss,
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
# Save checkpoint
|
||||
if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
|
||||
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
|
||||
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
|
||||
booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
|
||||
booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
|
||||
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
|
||||
booster.save_optimizer(disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096)
|
||||
|
||||
if lr_scheduler is not None:
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
if disc_lr_scheduler is not None:
|
||||
booster.save_lr_scheduler(disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler"))
|
||||
|
||||
running_states = {
|
||||
"epoch": epoch,
|
||||
"step": step+1,
|
||||
"global_step": global_step+1,
|
||||
"sample_start_index": (step+1) * cfg.batch_size,
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
dist.barrier()
|
||||
logger.info(
|
||||
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
|
||||
)
|
||||
|
||||
# p.step()
|
||||
|
||||
# print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
|
||||
|
||||
|
||||
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
|
||||
dataloader.sampler.set_start_index(0)
|
||||
|
|
|
|||
Loading…
Reference in a new issue