From e02f6286deae6cf63e7ce27c0341c83b07b19c05 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Wed, 28 Feb 2024 15:01:53 +0800
Subject: [PATCH] [hotfix] fix sample and update training script (#15)

* [hotfix] fix sample

* [hotfix] fix sample
---
 sample.py | 26 +++++++++++++++-------
 train.py  | 64 +++++++++++++++++++++++++++++++++++++++----------------
 2 files changed, 64 insertions(+), 26 deletions(-)

diff --git a/sample.py b/sample.py
index 0bb60b5..ffb4f2f 100644
--- a/sample.py
+++ b/sample.py
@@ -19,7 +19,7 @@ from transformers import AutoModel, AutoTokenizer, CLIPTextModel
 
 from open_sora.diffusion import create_diffusion
 from open_sora.modeling import DiT_models
-from open_sora.utils.data import col2video
+from open_sora.utils.data import col2video, unnormalize_video
 
 
 def main(args):
@@ -34,16 +34,20 @@ def main(args):
             .eval()
         )
         in_channels = vqvae.embedding_dim
+        w_h_factor = 4
+        t_factor = 2
     else:
         # disable VQ-VAE if not provided, just use raw video frames
         vqvae = None
         in_channels = 3
+        w_h_factor = 1
+        t_factor = 1
 
     text_model = CLIPTextModel.from_pretrained(args.text_model).to(device).eval()
     tokenizer = AutoTokenizer.from_pretrained(args.text_model)
     model = DiT_models[args.model](in_channels=in_channels).to(device).eval()
     patch_size = model.patch_size
-    # model.load_state_dict(torch.load(args.ckpt))
+    model.load_state_dict(torch.load(args.ckpt))
     diffusion = create_diffusion(str(args.num_sampling_steps))
 
     # Create sampling noise:
@@ -54,9 +58,9 @@ def main(args):
     num_frames = args.fps * args.sec
     z = torch.randn(
         1,
-        (args.height // patch_size // 4)
-        * (args.width // patch_size // 4)
-        * (num_frames // 2),
+        (args.height // patch_size // w_h_factor)
+        * (args.width // patch_size // w_h_factor)
+        * (num_frames // t_factor),
         in_channels,
         patch_size,
         patch_size,
@@ -87,7 +91,12 @@ def main(args):
     samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
     samples = col2video(
         samples.squeeze(),
-        (num_frames // 2, in_channels, args.height // 4, args.width // 4),
+        (
+            num_frames // t_factor,
+            in_channels,
+            args.height // w_h_factor,
+            args.width // w_h_factor,
+        ),
     )
     if vqvae is not None:
         # [T, C, H, W] -> [B, C, T, H, W]
@@ -98,6 +107,7 @@ def main(args):
     else:
         # [T, C, H, W] -> [T, H, W, C]
         samples = samples.permute(0, 2, 3, 1)
+    samples = unnormalize_video(samples).to(torch.uint8)
 
     write_video("sample.mp4", samples.cpu(), args.fps)
 
@@ -105,12 +115,12 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8"
+        "-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8"
     )
     parser.add_argument(
         "--text",
         type=str,
-        default="two ladies laughing by seeing some thing another lady throw dresses and keep it back by reverse motion",
+        default="a cartoon animal runs through an ice cave in a video game",
     )
     parser.add_argument("--cfg-scale", type=float, default=4.0)
     parser.add_argument("--num-sampling-steps", type=int, default=250)
diff --git a/train.py b/train.py
index 3a353c2..ad2ff74 100644
--- a/train.py
+++ b/train.py
@@ -64,6 +64,18 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
     return tensor
 
 
+def save_checkpoints(booster, model, optimizer, ema, save_path, coordinator):
+    os.makedirs(save_path, exist_ok=True)
+    booster.save_model(model, os.path.join(save_path, "model"), shard=False)
+    booster.save_optimizer(optimizer, os.path.join(save_path, "optimizer"), shard=True)
+    if coordinator.is_master():
+        ema_state_dict = ema.state_dict()
+        for k, v in ema_state_dict.items():
+            ema_state_dict[k] = v.cpu()
+        torch.save(ema_state_dict, os.path.join(save_path, "ema.pt"))
+    dist.barrier()
+
+
 #################################################################################
 #                                 Training Loop                                 #
 #################################################################################
@@ -89,7 +101,11 @@ def main(args):
 
     # Step 3: Create VQ-VAE
     if len(args.vqvae) > 0:
-        vqvae = AutoModel.from_pretrained(args.vqvae, trust_remote_code=True).to(get_current_device()).eval()
+        vqvae = (
+            AutoModel.from_pretrained(args.vqvae, trust_remote_code=True)
+            .to(get_current_device())
+            .eval()
+        )
         model_kwargs = {"in_channels": vqvae.embedding_dim}
     else:
         # disable VQ-VAE if not provided, just use raw video frames
@@ -110,7 +126,9 @@ def main(args):
         model.enable_gradient_checkpointing()
 
     # Step 5: create diffusion pipeline
-    diffusion = create_diffusion(timestep_respacing="")  # default: 1000 steps, linear noise schedule
+    diffusion = create_diffusion(
+        timestep_respacing=""
+    )  # default: 1000 steps, linear noise schedule
 
     # Step 6: setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
     opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0)
@@ -128,6 +146,10 @@ def main(args):
 
     # Step 8: setup booster
     model, opt, _, dataloader, _ = booster.boost(model, opt, dataloader=dataloader)
+    if args.load_model is not None:
+        booster.load_model(model, args.load_model)
+    if args.load_optimizer is not None:
+        booster.load_optimizer(opt, args.load_optimizer)
     logger.info(
         f"Booster init max device memory: {get_accelerator().max_memory_allocated() / 1024 ** 2:.2f} MB",
         ranks=[0],
@@ -154,7 +176,9 @@ def main(args):
                 (video_inputs.shape[0],),
                 device=video_inputs.device,
             )
-            loss_dict = diffusion.training_losses(model, video_inputs, t, batch, mask=mask)
+            loss_dict = diffusion.training_losses(
+                model, video_inputs, t, batch, mask=mask
+            )
             loss = loss_dict["loss"].mean() / args.accumulation_steps
             total_loss.add_(loss.data)
             booster.backward(loss, opt)
@@ -167,7 +191,9 @@ def main(args):
                 all_reduce_mean(total_loss)
                 pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
                 if coordinator.is_master():
-                    global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
+                    global_step = (epoch * num_steps_per_epoch) + (
+                        step + 1
+                    ) // args.accumulation_steps
                     writer.add_scalar(
                         tag="Loss",
                         scalar_value=total_loss.item(),
@@ -177,22 +203,20 @@ def main(args):
                     total_loss.zero_()
 
                 # Save DiT checkpoint:
-                if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
-                    step + 1
-                ) == len(dataloader):
-                    save_path = os.path.join(args.checkpoint_dir, f"epoch-{epoch}-step-{step}")
-                    os.makedirs(save_path, exist_ok=True)
-                    booster.save_model(model, os.path.join(save_path, "model"), shard=True)
-                    booster.save_optimizer(opt, os.path.join(save_path, "optimizer"), shard=True)
-                    if coordinator.is_master():
-                        ema_state_dict = ema.state_dict()
-                        for k, v in ema_state_dict.items():
-                            ema_state_dict[k] = v.cpu()
-                        torch.save(ema_state_dict, os.path.join(save_path, "ema.pt"))
-                    dist.barrier()
+                if args.save_interval > 0 and (
+                    (step + 1) % (args.save_interval * args.accumulation_steps) == 0
+                    or (step + 1) == len(dataloader)
+                ):
+                    save_path = os.path.join(
+                        args.checkpoint_dir, f"epoch-{epoch}-step-{step}"
+                    )
+                    save_checkpoints(booster, model, opt, ema, save_path, coordinator)
                     logger.info(f"Saved checkpoint to {save_path}", ranks=[0])
 
         get_accelerator().empty_cache()
+    final_save_path = os.path.join(args.checkpoint_dir, "final")
+    save_checkpoints(booster, model, opt, ema, final_save_path, coordinator)
+    logger.info(f"Saved checkpoint to {final_save_path}", ranks=[0])
     logger.info(
         f"Training complete, max device memory: {get_accelerator().max_memory_allocated() / 1024 ** 2:.2f} MB",
         ranks=[0],
@@ -202,7 +226,9 @@
 if __name__ == "__main__":
     # Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters).
     parser = argparse.ArgumentParser()
-    parser.add_argument("-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8")
+    parser.add_argument(
+        "-m", "--model", type=str, choices=list(DiT_models.keys()), default="DiT-S/8"
+    )
     parser.add_argument("-d", "--dataset", nargs="+", default=[])
     parser.add_argument("-v", "--video_dir", type=str, required=True)
     parser.add_argument("-e", "--epochs", type=int, default=10)
@@ -214,5 +240,7 @@ if __name__ == "__main__":
     parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
     parser.add_argument("--tensorboard_dir", type=str, default="runs")
     parser.add_argument("--vqvae", default="hpcai-tech/vqvae")
+    parser.add_argument("--load_model", default=None)
+    parser.add_argument("--load_optimizer", default=None)
     args = parser.parse_args()
     main(args)
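-- 
Notes on the changes:

The new w_h_factor/t_factor in sample.py mirror the VQ-VAE's downsampling
(4x in height and width, 2x in time) and fall back to 1 when the VQ-VAE is
disabled; the previous hard-coded // 4 and // 2 shaped the noise tensor as
if the VQ-VAE were always present. A worked example of the resulting patch
count, using illustrative values rather than the script's defaults:

    height = width = 256           # illustrative output resolution
    patch_size = 8                 # DiT-S/8
    num_frames = 16                # fps * sec
    w_h_factor, t_factor = 4, 2    # VQ-VAE enabled
    n_patches = (
        (height // patch_size // w_h_factor)
        * (width // patch_size // w_h_factor)
        * (num_frames // t_factor)
    )
    assert n_patches == 8 * 8 * 8 == 512
    # with the VQ-VAE disabled both factors are 1: 32 * 32 * 16 = 16384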
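unnormalize_video is imported from open_sora.utils.data but not defined in
this patch; write_video expects uint8 frames, so the helper must undo the
training-time normalization. A minimal sketch of such a helper, assuming
frames were normalized to [-1, 1] (the repository's actual implementation
may differ):

    import torch

    def unnormalize_video(video: torch.Tensor) -> torch.Tensor:
        # Map [-1, 1] floats back to [0, 255]; the caller casts to uint8.
        return ((video.clamp(-1, 1) + 1.0) * 127.5).round()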
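The new --load_model/--load_optimizer flags in train.py resume from the
directories written by save_checkpoints; the model is now saved unsharded
(shard=False) while the optimizer state stays sharded. A hypothetical
resume invocation (the launcher, paths, and epoch/step suffix below are
illustrative, not prescribed by this patch):

    torchrun --nproc_per_node 8 train.py \
        -m DiT-S/8 \
        -v /path/to/videos \
        --load_model checkpoints/epoch-0-step-999/model \
        --load_optimizer checkpoints/epoch-0-step-999/optimizer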