From afd3f823d43c628910ade62b7a2fc5d2061669a2 Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Tue, 16 Apr 2024 15:00:31 +0800 Subject: [PATCH] debug --- configs/vae_magvit_v2/train/17x128x128.py | 10 +- opensora/models/vae/vae_3d_v2.py | 131 ++++++--- scripts/train-vae-v2.py | 317 ++++++++++++---------- 3 files changed, 275 insertions(+), 183 deletions(-) diff --git a/configs/vae_magvit_v2/train/17x128x128.py b/configs/vae_magvit_v2/train/17x128x128.py index de54065..8684ba2 100644 --- a/configs/vae_magvit_v2/train/17x128x128.py +++ b/configs/vae_magvit_v2/train/17x128x128.py @@ -59,12 +59,10 @@ kl_loss_weight = 0.000001 perceptual_loss_weight = 0.1 # use vgg is not None and more than 0 discriminator_factor = 1.0 discriminator_loss_weight = 0.5 -lecam_loss_weight = 0 # TODO: not clear in MAGVIT what is the weight -discriminator_loss="hinge" -discriminator_start = -1 # 50000 TODO: change to correct val, debug use -1 for now - - -gradient_penalty_loss_weight = 10 # SCH: following MAGVIT config.vqgan.grad_penalty_cost +lecam_loss_weight = None # TODO: not clear in MAGVIT what is the weight +discriminator_loss="non-saturating" +discriminator_start = 50000 # 50000 TODO: change to correct val, debug use -1 for now +gradient_penalty_loss_weight = None # 10 # SCH: following MAGVIT config.vqgan.grad_penalty_cost, 10 ema_decay = 0.999 # ema decay factor for generator diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py index a1b261b..6a8ec42 100644 --- a/opensora/models/vae/vae_3d_v2.py +++ b/opensora/models/vae/vae_3d_v2.py @@ -58,36 +58,9 @@ def pick_video_frame(video, frame_indices): def exists(v): return v is not None -def hinge_discr_loss(fake, real): - return (F.relu(1 + fake) + F.relu(1 - real)).mean() - -def hinge_d_loss(logits_real, logits_fake): - loss_real = torch.mean(F.relu(1. - logits_real)) - loss_fake = torch.mean(F.relu(1. + logits_fake)) - d_loss = 0.5 * (loss_real + loss_fake) - return d_loss - -def vanilla_d_loss(logits_real, logits_fake): - d_loss = 0.5 * ( - torch.mean(torch.nn.functional.softplus(-logits_real)) + - torch.mean(torch.nn.functional.softplus(logits_fake))) - return d_loss - +# ============== Generator Adversarial Loss Functions ============== # TODO: verify if this is correct implementation def lecam_reg(real_pred, fake_pred, ema_real_pred, ema_fake_pred): - """Lecam loss for data-efficient and stable GAN training. - - Described in https://arxiv.org/abs/2104.03310 - - Args: - real_pred: Prediction (scalar) for the real samples. - fake_pred: Prediction for the fake samples. - ema_real_pred: EMA prediction (scalar) for the real samples. - ema_fake_pred: EMA prediction for the fake samples. - - Returns: - Lecam regularization loss (scalar). - """ assert real_pred.ndim == 0 and ema_fake_pred.ndim == 0 lecam_loss = np.mean(np.power(nn.ReLU(real_pred - ema_fake_pred), 2)) lecam_loss += np.mean(np.power(nn.ReLU(ema_real_pred - fake_pred), 2)) @@ -107,6 +80,32 @@ def gradient_penalty_fn(images, output): gradients = rearrange(gradients, 'b ... -> b (...)') return ((gradients.norm(2, dim = 1) - 1) ** 2).mean() + +# ============== Discriminator Adversarial Loss Functions ============== +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + +# from MAGVIT, used in place hof hinge_d_loss +def sigmoid_cross_entropy_with_logits(labels, logits): + # The final formulation is: max(x, 0) - x * z + log(1 + exp(-abs(x))) + zeros = torch.zeros_like(logits, dtype=logits.dtype) + condition = (logits >= zeros) + relu_logits = torch.where(condition, logits, zeros) + neg_abs_logits = torch.where(condition, -logits, logits) + return relu_logits - logits * labels + torch.log1p(torch.exp(neg_abs_logits)) + + + + def xavier_uniform_weight_init(m): if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu')) @@ -389,21 +388,28 @@ class StyleGANDiscriminatorBlur(nn.Module): self.apply(xavier_uniform_weight_init) def forward(self, x): - + x = self.conv1(x) + # print("discriminator aft conv:", x.size()) x = self.activation_fn(x) for i in range(self.num_blocks): x = self.res_block_list[i](x) + # print("discriminator resblock down:", x.size()) x = self.conv2(x) + # print("discriminator aft conv2:", x.size()) x = self.norm1(x) x = self.activation_fn(x) x = x.reshape((x.shape[0], -1)) # SCH: [B, (C * T * W * H)] ? + # print("discriminator reshape:", x.size()) x = self.linear1(x) + # print("discriminator aft linear1:", x.size()) + x = self.activation_fn(x) x = self.linear2(x) + # print("discriminator aft linear2:", x.size()) return x class Encoder(nn.Module): @@ -499,20 +505,26 @@ class Encoder(nn.Module): # NOTE: moved to VAE for separate first frame processing # x = self.conv1(x) + # print("encoder:", x.size()) + for i in range(self.num_blocks): for j in range(self.num_res_blocks): x = self.block_res_blocks[i][j](x) + # print("encoder:", x.size()) if i < self.num_blocks - 1: x = self.conv_blocks[i](x) + # print("encoder:", x.size()) for i in range(self.num_res_blocks): x = self.res_blocks[i](x) + # print("encoder:", x.size()) x = self.norm1(x) x = self.activate(x) x = self.conv2(x) + # print("encoder:", x.size()) return x class Decoder(nn.Module): @@ -620,19 +632,22 @@ class Decoder(nn.Module): **kwargs, ): # dtype, device = x.dtype, x.device + x = self.conv1(x) + # print("decoder:", x.size()) for i in range(self.num_res_blocks): x = self.res_blocks[i](x) + # print("decoder:", x.size()) for i in reversed(range(self.num_blocks)): # reverse here to make decoder symmetric with encoder for j in range(self.num_res_blocks): x = self.block_res_blocks[i][j](x) - + # print("decoder:", x.size()) if i > 0: t_stride = 2 if self.temporal_downsample[i - 1] else 1 # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2 x = self.conv_blocks[i-1](x) x = rearrange(x, "B (C ts hs ws) T H W -> B C (T ts) (H hs) (W ws)", ts=t_stride, hs=2, ws=2) - + # print("decoder:", x.size()) x = self.norm1(x) x = self.activate(x) @@ -744,19 +759,30 @@ class VAE_3D_V2(nn.Module): video = pad_at_dim(video, (self.time_padding, 0), value = 0., dim = 2) video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])] + # print("pre-encoder:", video.size()) + # NOTE: moved encoder conv1 here for separate first frame encoding if encode_first_frame_separately: pad, first_frame, video = unpack(video, video_packed_shape, 'b c * h w') first_frame = self.conv_in_first_frame(first_frame) video = self.conv_in(video) + + # print("pre-encoder:", video.size()) + if encode_first_frame_separately: video, _ = pack([first_frame, video], 'b c * h w') video = pad_at_dim(video, (self.time_padding, 0), dim = 2) encoded_feature = self.encoder(video) + # print("after encoder:", encoded_feature.size()) + + # NOTE: TODO: do we include this before gaussian distri? or go directly to Gaussian distribution moments = self.quant_conv(encoded_feature).to(video.dtype) posterior = model_utils.DiagonalGaussianDistribution(moments) + + # print("after encoder moments:", moments.size()) + return posterior def decode( @@ -767,8 +793,12 @@ class VAE_3D_V2(nn.Module): # dtype = z.dtype decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame + z = self.post_quant_conv(z) + # print("pre decoder, post quant conv:", z.size()) + dec = self.decoder(z) + # print("post decoder:", dec.size()) # SCH: moved decoder last conv layer here for separate first frame decoding if decode_first_frame_separately: @@ -782,6 +812,8 @@ class VAE_3D_V2(nn.Module): if video_contains_first_frame: video = video[:, :, self.time_padding:] + # print("conv out:", video.size()) + return video def get_last_layer(self): @@ -971,24 +1003,19 @@ class DiscriminatorLoss(nn.Module): self, discriminator_factor = 1.0, discriminator_start = 50001, - discriminator_loss="hinge", + discriminator_loss="non-saturating", lecam_loss_weight=0, gradient_penalty_loss_weight=10, # SCH: following MAGVIT config.vqgan.grad_penalty_cost ): super().__init__() - assert discriminator_loss in ["hinge", "vanilla"] + assert discriminator_loss in ["hinge", "vanilla", "non-saturating"] self.discriminator_factor = discriminator_factor self.discriminator_start = discriminator_start self.lecam_loss_weight = lecam_loss_weight self.gradient_penalty_loss_weight = gradient_penalty_loss_weight - - if discriminator_loss == "hinge": - self.disc_loss_fn = hinge_d_loss - elif discriminator_loss == "vanilla": - self.disc_loss_fn = vanilla_d_loss - else: - raise ValueError(f"Unknown GAN loss '{discriminator_loss}'.") + self.discriminator_loss_type = discriminator_loss + def forward( self, @@ -1002,11 +1029,33 @@ class DiscriminatorLoss(nn.Module): ): if self.discriminator_factor is not None and self.discriminator_factor > 0.0: disc_factor = adopt_weight(self.discriminator_factor, global_step, threshold=self.discriminator_start) - weighted_d_adversarial_loss = disc_factor * self.disc_loss_fn(real_logits, fake_logits) + + if self.discriminator_loss_type == "hinge": + disc_loss = hinge_d_loss(real_logits, fake_logits) + elif self.discriminator_loss_type == "non-saturating": + if real_logits is not None: + real_loss = sigmoid_cross_entropy_with_logits( + labels=torch.ones_like(real_logits), logits=real_logits + ) + else: + real_loss = 0.0 + if fake_logits is not None: + fake_loss = sigmoid_cross_entropy_with_logits( + labels=torch.zeros_like(fake_logits), logits=fake_logits) + else: + fake_loss = 0.0 + disc_loss = 0.5 * (torch.mean(real_loss) + torch.mean(fake_loss)) + elif self.discriminator_loss_type == "vanilla": + disc_loss = vanilla_d_loss(real_logits, fake_logits) + else: + raise ValueError(f"Unknown GAN loss '{self.discriminator_loss_type}'.") + + weighted_d_adversarial_loss = disc_factor * disc_loss else: weighted_d_adversarial_loss = 0 lecam_loss = 0.0 + if self.lecam_loss_weight is not None and self.lecam_loss_weight > 0.0: real_pred = np.mean(real_logits.clone().detach()) fake_pred = np.mean(fake_logits.clone().detach()) diff --git a/scripts/train-vae-v2.py b/scripts/train-vae-v2.py index 1aba922..ec7ebbf 100644 --- a/scripts/train-vae-v2.py +++ b/scripts/train-vae-v2.py @@ -37,6 +37,14 @@ from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_num from opensora.utils.train_utils import update_ema, MaskGenerator from opensora.models.vae.vae_3d_v2 import VEALoss, DiscriminatorLoss, AdversarialLoss, pad_at_dim +# efficiency +# from torch.profiler import profile, record_function, ProfilerActivity + +def trace_handler(p): + output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=5) + print(output) + # p.export_chrome_trace("/home/shenchenhui/Open-Sora-dev/outputs/traces/trace_" + str(p.step_num) + ".json") + def main(): # ====================================================== @@ -226,7 +234,9 @@ def main(): discriminator_factor = cfg.discriminator_factor, discriminator_start = cfg.discriminator_start, discriminator_loss = cfg.discriminator_loss, - ) + lecam_loss_weight = cfg.lecam_loss_weight, + gradient_penalty_loss_weight = cfg.gradient_penalty_loss_weight, + ) # 6.3. training loop @@ -238,11 +248,15 @@ def main(): lecam_ema_real = torch.tensor(0.0) lecam_ema_fake = torch.tensor(0.0) + + for epoch in range(start_epoch, cfg.epochs): dataloader.sampler.set_epoch(epoch) dataloader_iter = iter(dataloader) logger.info(f"Beginning epoch {epoch}...") + + with tqdm( range(start_step, num_steps_per_epoch), desc=f"Epoch {epoch}", @@ -250,151 +264,182 @@ def main(): total=num_steps_per_epoch, initial=start_step, ) as pbar: - for step in pbar: - - # SCH: calc global step at the start - global_step = epoch * num_steps_per_epoch + step - batch = next(dataloader_iter) - x = batch["video"].to(device, dtype) # [B, C, T, H, W] - - # supprt for image or video inputs - assert x.ndim in {4, 5}, f"received input of {x.ndim} dimensions" # either image or video - assert x.shape[-2:] == cfg.image_size, f"received input size {x.shape[-2:]}, but config image size is {cfg.image_size}" - is_image = x.ndim == 4 - if is_image: - video = rearrange(x, 'b c ... -> b c 1 ...') - video_contains_first_frame = True - else: - video = x - - # ====== VAE ====== - optimizer.zero_grad() - recon_video, posterior = vae( - video, - video_contains_first_frame = video_contains_first_frame, - ) - - # ====== Generator Loss ====== - # simple nll loss - nll_loss = nll_loss_fn( - video, - recon_video, - posterior, - split = "train" - ) - vae_loss = nll_loss - - # adversarial loss - if global_step > cfg.discriminator_start: - # padded videos for GAN - fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2) - fake_logits = discriminator(fake_video.contiguous()) - adversarial_loss = adversarial_loss_fn( - fake_logits, - nll_loss, - vae.module.get_last_layer(), - global_step, - is_training = vae.training, - ) - vae_loss += adversarial_loss - # Backward & update - booster.backward(loss=vae_loss, optimizer=optimizer) - # NOTE: clip gradients? this is done in Open-Sora-Plan - torch.nn.utils.clip_grad_norm_(vae.parameters(), 1) - optimizer.step() - # Log loss values: - all_reduce_mean(vae_loss) - running_loss += vae_loss.item() + # with profile( + # activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + # schedule=torch.profiler.schedule( + # wait=1, + # warmup=1, + # active=2, + # repeat=2, + # ), + # on_trace_ready=torch.profiler.tensorboard_trace_handler('/home/shenchenhui/log'), + # with_stack=True, + # record_shapes=True, + # profile_memory=True, + # ) as p: # trace efficiency + for step in pbar: - - # ====== Discriminator Loss ====== - if global_step > cfg.discriminator_start: - disc_optimizer.zero_grad() - # if video_contains_first_frame: - # Since we don't have enough T frames, pad anyways - real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2) - fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2) - if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0: - real_video = real_video.requires_grad_() + # with profile( + # activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + # with_stack=True, + # ) as p: # trace efficiency - real_logits = discriminator(real_video.contiguous()) # SCH: not detached for now for gradient_penalty calculation - fake_logits = discriminator(fake_video.contiguous().detach()) - disc_loss = disc_loss_fn( - real_logits, - fake_logits, - global_step, - lecam_ema_real = lecam_ema_real, - lecam_ema_fake = lecam_ema_fake, - real_video = real_video - ) + # SCH: calc global step at the start + global_step = epoch * num_steps_per_epoch + step + + batch = next(dataloader_iter) + x = batch["video"].to(device, dtype) # [B, C, T, H, W] - if cfg.ema_decay is not None: - # SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc. - lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(real_logits.clone().detach()) - lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(fake_logits.clone().detach()) + # supprt for image or video inputs + assert x.ndim in {4, 5}, f"received input of {x.ndim} dimensions" # either image or video + assert x.shape[-2:] == cfg.image_size, f"received input size {x.shape[-2:]}, but config image size is {cfg.image_size}" + is_image = x.ndim == 4 + if is_image: + video = rearrange(x, 'b c ... -> b c 1 ...') + video_contains_first_frame = True + else: + video = x - # Backward & update - booster.backward(loss=disc_loss, optimizer=disc_optimizer) - # NOTE: TODO: clip gradients? this is done in Open-Sora-Plan - torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1) - - disc_optimizer.step() - - # Log loss values: - all_reduce_mean(disc_loss) - running_disc_loss += disc_loss.item() - - log_step += 1 - - # Log to tensorboard - if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0: - avg_loss = running_loss / log_step - avg_disc_loss = running_disc_loss / log_step - pbar.set_postfix({"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step}) - running_loss = 0 - log_step = 0 - writer.add_scalar("loss", vae_loss.item(), global_step) - if cfg.wandb: - wandb.log( - { - "iter": global_step, - "num_samples": global_step * total_batch_size, - "epoch": epoch, - "loss": vae_loss.item(), - "disc_loss": disc_loss.item(), - "avg_loss": avg_loss, - }, - step=global_step, + # ====== VAE ====== + optimizer.zero_grad() + recon_video, posterior = vae( + video, + video_contains_first_frame = video_contains_first_frame, ) - # Save checkpoint - if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0: - save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}") - os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) - booster.save_model(vae, os.path.join(save_dir, "model"), shard=True) - booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True) - booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096) - booster.save_optimizer(disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096) + # ====== Generator Loss ====== + # simple nll loss + nll_loss = nll_loss_fn( + video, + recon_video, + posterior, + split = "train" + ) + vae_loss = nll_loss - if lr_scheduler is not None: - booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) - if disc_lr_scheduler is not None: - booster.save_lr_scheduler(disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler")) + # adversarial loss + if global_step > cfg.discriminator_start: + # padded videos for GAN + fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2) + fake_logits = discriminator(fake_video) # TODO: take out contiguous? + adversarial_loss = adversarial_loss_fn( + fake_logits, + nll_loss, + vae.module.get_last_layer(), + global_step, + is_training = vae.training, + ) + vae_loss += adversarial_loss + # Backward & update + booster.backward(loss=vae_loss, optimizer=optimizer) + # NOTE: clip gradients? this is done in Open-Sora-Plan + torch.nn.utils.clip_grad_norm_(vae.parameters(), 1) + optimizer.step() + # Log loss values: + all_reduce_mean(vae_loss) + running_loss += vae_loss.item() + - running_states = { - "epoch": epoch, - "step": step+1, - "global_step": global_step+1, - "sample_start_index": (step+1) * cfg.batch_size, - } - if coordinator.is_master(): - save_json(running_states, os.path.join(save_dir, "running_states.json")) - dist.barrier() - logger.info( - f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}" - ) + + # ====== Discriminator Loss ====== + if global_step > cfg.discriminator_start: + disc_optimizer.zero_grad() + # if video_contains_first_frame: + # Since we don't have enough T frames, pad anyways + + real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2) + fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2) + + if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0: + real_video = real_video.requires_grad_() + + real_logits = discriminator(real_video.contiguous()) # SCH: not detached for now for gradient_penalty calculation + + if cfg.gradient_penalty_loss_weight is None: + real_logits = real_logits.detach() + + fake_logits = discriminator(fake_video.contiguous().detach()) + disc_loss = disc_loss_fn( + real_logits, + fake_logits, + global_step, + lecam_ema_real = lecam_ema_real, + lecam_ema_fake = lecam_ema_fake, + real_video = real_video if cfg.gradient_penalty_loss_weight is not None else None, + ) + + if cfg.ema_decay is not None: + # SCH: TODO: is this written properly like this for moving average? e.g. distributed training etc. + lecam_ema_real = lecam_ema_real * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(real_logits.clone().detach()) + lecam_ema_fake = lecam_ema_fake * cfg.ema_decay + (1 - cfg.ema_decay) * torch.mean(fake_logits.clone().detach()) + + # Backward & update + booster.backward(loss=disc_loss, optimizer=disc_optimizer) + # NOTE: TODO: clip gradients? this is done in Open-Sora-Plan + torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1) + + disc_optimizer.step() + + # Log loss values: + all_reduce_mean(disc_loss) + running_disc_loss += disc_loss.item() + + log_step += 1 + + # Log to tensorboard + if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0: + avg_loss = running_loss / log_step + avg_disc_loss = running_disc_loss / log_step + pbar.set_postfix({"loss": avg_loss, "disc_loss": avg_disc_loss, "step": step, "global_step": global_step}) + running_loss = 0 + log_step = 0 + writer.add_scalar("loss", vae_loss.item(), global_step) + if cfg.wandb: + wandb.log( + { + "iter": global_step, + "num_samples": global_step * total_batch_size, + "epoch": epoch, + "loss": vae_loss.item(), + "disc_loss": disc_loss.item(), + "avg_loss": avg_loss, + }, + step=global_step, + ) + + # Save checkpoint + if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0: + save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}") + os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) + booster.save_model(vae, os.path.join(save_dir, "model"), shard=True) + booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True) + booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096) + booster.save_optimizer(disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096) + + if lr_scheduler is not None: + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) + if disc_lr_scheduler is not None: + booster.save_lr_scheduler(disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler")) + + running_states = { + "epoch": epoch, + "step": step+1, + "global_step": global_step+1, + "sample_start_index": (step+1) * cfg.batch_size, + } + if coordinator.is_master(): + save_json(running_states, os.path.join(save_dir, "running_states.json")) + dist.barrier() + logger.info( + f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}" + ) + + # p.step() + + # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + # the continue epochs are not resumed, so we need to reset the sampler start index and start step dataloader.sampler.set_start_index(0)