diff --git a/scripts/inference-vae.py b/scripts/inference-vae.py index 365c4a2..afbc06d 100644 --- a/scripts/inference-vae.py +++ b/scripts/inference-vae.py @@ -20,6 +20,12 @@ from opensora.acceleration.parallel_states import ( from tqdm import tqdm from opensora.models.vae.model_utils import VEA3DLoss +# DEBUG +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from opensora.acceleration.plugin import ZeroSeqParallelPlugin + + def main(): # ====================================================== @@ -49,6 +55,31 @@ def main(): set_random_seed(seed=cfg.seed) + # 2.3 DEBUG: USE BOOSTER + # 2.3. initialize ColossalAI booster + if cfg.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, + precision=cfg.dtype, + initial_scale=2**16, + max_norm=cfg.grad_clip, + ) + set_data_parallel_group(dist.group.WORLD) + elif cfg.plugin == "zero2-seq": + plugin = ZeroSeqParallelPlugin( + sp_size=cfg.sp_size, + stage=2, + precision=cfg.dtype, + initial_scale=2**16, + max_norm=cfg.grad_clip, + ) + set_sequence_parallel_group(plugin.sp_group) + set_data_parallel_group(plugin.dp_group) + else: + raise ValueError(f"Unknown plugin {cfg.plugin}") + booster = Booster(plugin=plugin) + + # ====================================================== # 3. build dataset and dataloader # ====================================================== @@ -108,6 +139,15 @@ def main(): os.makedirs(save_dir, exist_ok=True) + ### TODO: DEBUG, USE booster + torch.set_default_dtype(dtype) + vae, _, _, dataloader, _ = booster.boost( + model=vae, dataloader=dataloader + ) + # load model using booster + booster.load_model(vae, os.path.join(cfg.ckpt_path, "model")) + + # 4.1. batch generation # define loss function diff --git a/scripts/train-vae.py b/scripts/train-vae.py index 6966b71..0d3ac2f 100644 --- a/scripts/train-vae.py +++ b/scripts/train-vae.py @@ -186,151 +186,151 @@ def main(): dataloader.sampler.set_start_index(sampler_start_idx) - # # define loss function - # loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype) - - - # # 6.2. training loop - # for epoch in range(start_epoch, cfg.epochs): - # dataloader.sampler.set_epoch(epoch) - # dataloader_iter = iter(dataloader) - # logger.info(f"Beginning epoch {epoch}...") - - # with tqdm( - # range(start_step, num_steps_per_epoch), - # desc=f"Epoch {epoch}", - # disable=not coordinator.is_master(), - # total=num_steps_per_epoch, - # initial=start_step, - # ) as pbar: - # for step in pbar: - # batch = next(dataloader_iter) - # x = batch["video"].to(device, dtype) # [B, C, T, H, W] - - # # loss = vae.get_loss(x) - # reconstructions, posterior = vae(x) - # loss = loss_function(x, reconstructions, posterior) - - # # Backward & update - # booster.backward(loss=loss, optimizer=optimizer) - # optimizer.step() - # optimizer.zero_grad() - - # # Log loss values: - # all_reduce_mean(loss) - # running_loss += loss.item() - # global_step = epoch * num_steps_per_epoch + step - # log_step += 1 - - # # Log to tensorboard - # if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0: - # avg_loss = running_loss / log_step - # pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step}) - # running_loss = 0 - # log_step = 0 - # writer.add_scalar("loss", loss.item(), global_step) - # if cfg.wandb: - # wandb.log( - # { - # "iter": global_step, - # "num_samples": global_step * total_batch_size, - # "epoch": epoch, - # "loss": loss.item(), - # "avg_loss": avg_loss, - # }, - # step=global_step, - # ) - - # # Save checkpoint - # if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0: - # save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}") - # os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) - # # TODO: save in model? - # booster.save_model(vae, os.path.join(save_dir, "model"), shard=True) - # booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096) - # if lr_scheduler is not None: - # booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) - # running_states = { - # "epoch": epoch, - # "step": step+1, - # "global_step": global_step+1, - # "sample_start_index": (step+1) * cfg.batch_size, - # } - # if coordinator.is_master(): - # save_json(running_states, os.path.join(save_dir, "running_states.json")) - # dist.barrier() - # logger.info( - # f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}" - # ) - - # # the continue epochs are not resumed, so we need to reset the sampler start index and start step - # dataloader.sampler.set_start_index(0) - # start_step = 0 - - # DEBUG inference - - # 4.1. batch generation - # define loss function loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype) - running_loss = 0.0 - loss_steps = 0 - from opensora.datasets import save_sample - # get data again - print("loading test data...") - dataset = DatasetFromCSV( - cfg.data_path, - # TODO: change transforms - transform=( - get_transforms_video(cfg.image_size[0]) - if not cfg.use_image_transform - else get_transforms_image(cfg.image_size[0]) - ), - num_frames=cfg.num_frames, - frame_interval=cfg.frame_interval, - root=cfg.root, - ) + # 6.2. training loop + for epoch in range(start_epoch, cfg.epochs): + dataloader.sampler.set_epoch(epoch) + dataloader_iter = iter(dataloader) + logger.info(f"Beginning epoch {epoch}...") - dataloader = prepare_dataloader( - dataset, - batch_size=cfg.batch_size, - num_workers=cfg.num_workers, - shuffle=False, - drop_last=True, - pin_memory=True, - process_group=get_data_parallel_group(), - ) - print(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})") + with tqdm( + range(start_step, num_steps_per_epoch), + desc=f"Epoch {epoch}", + disable=not coordinator.is_master(), + total=num_steps_per_epoch, + initial=start_step, + ) as pbar: + for step in pbar: + batch = next(dataloader_iter) + x = batch["video"].to(device, dtype) # [B, C, T, H, W] - total_steps = len(dataloader) - dataloader_iter = iter(dataloader) + # loss = vae.get_loss(x) + reconstructions, posterior = vae(x) + loss = loss_function(x, reconstructions, posterior) - print("total steps:", total_steps) + # Backward & update + booster.backward(loss=loss, optimizer=optimizer) + optimizer.step() + optimizer.zero_grad() - with tqdm( - range(total_steps), - # desc=f"Avg Loss: {running_loss}", - disable=not coordinator.is_master(), - total=total_steps, - initial=0, - ) as pbar: - for step in pbar: - batch = next(dataloader_iter) - x = batch["video"].to(device, dtype) # [B, C, T, H, W] - reconstructions, posterior = vae(x) - loss = loss_function(x, reconstructions, posterior) - loss_steps += 1 - running_loss = loss.item()/ loss_steps + running_loss * ((loss_steps - 1) / loss_steps) + # Log loss values: + all_reduce_mean(loss) + running_loss += loss.item() + global_step = epoch * num_steps_per_epoch + step + log_step += 1 - if coordinator.is_master(): - for idx, sample in enumerate(reconstructions): - pos = step * cfg.batch_size + idx - save_path = os.path.join("outputs/debug", f"sample_{pos}") - save_sample(sample, fps=8, save_path=save_path) + # Log to tensorboard + if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0: + avg_loss = running_loss / log_step + pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step}) + running_loss = 0 + log_step = 0 + writer.add_scalar("loss", loss.item(), global_step) + if cfg.wandb: + wandb.log( + { + "iter": global_step, + "num_samples": global_step * total_batch_size, + "epoch": epoch, + "loss": loss.item(), + "avg_loss": avg_loss, + }, + step=global_step, + ) - print("test loss:", running_loss) + # Save checkpoint + if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0: + save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}") + os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) + # TODO: save in model? + booster.save_model(vae, os.path.join(save_dir, "model"), shard=True) + booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096) + if lr_scheduler is not None: + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) + running_states = { + "epoch": epoch, + "step": step+1, + "global_step": global_step+1, + "sample_start_index": (step+1) * cfg.batch_size, + } + if coordinator.is_master(): + save_json(running_states, os.path.join(save_dir, "running_states.json")) + dist.barrier() + logger.info( + f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}" + ) + + # the continue epochs are not resumed, so we need to reset the sampler start index and start step + dataloader.sampler.set_start_index(0) + start_step = 0 + + # # DEBUG inference + + # # 4.1. batch generation + + # # define loss function + # loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype) + # running_loss = 0.0 + # loss_steps = 0 + + # from opensora.datasets import save_sample + + # # get data again + # print("loading test data...") + # dataset = DatasetFromCSV( + # cfg.data_path, + # # TODO: change transforms + # transform=( + # get_transforms_video(cfg.image_size[0]) + # if not cfg.use_image_transform + # else get_transforms_image(cfg.image_size[0]) + # ), + # num_frames=cfg.num_frames, + # frame_interval=cfg.frame_interval, + # root=cfg.root, + # ) + + # dataloader = prepare_dataloader( + # dataset, + # batch_size=cfg.batch_size, + # num_workers=cfg.num_workers, + # shuffle=False, + # drop_last=True, + # pin_memory=True, + # process_group=get_data_parallel_group(), + # ) + # print(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})") + + # total_steps = len(dataloader) + # dataloader_iter = iter(dataloader) + + # print("total steps:", total_steps) + + # with tqdm( + # range(total_steps), + # # desc=f"Avg Loss: {running_loss}", + # disable=not coordinator.is_master(), + # total=total_steps, + # initial=0, + # ) as pbar: + # for step in pbar: + # batch = next(dataloader_iter) + # x = batch["video"].to(device, dtype) # [B, C, T, H, W] + # reconstructions, posterior = vae(x) + # loss = loss_function(x, reconstructions, posterior) + # loss_steps += 1 + # running_loss = loss.item()/ loss_steps + running_loss * ((loss_steps - 1) / loss_steps) + + # # if coordinator.is_master(): + # # for idx, sample in enumerate(reconstructions): + # # pos = step * cfg.batch_size + idx + # # save_path = os.path.join("outputs/debug", f"sample_{pos}") + # # save_sample(sample, fps=8, save_path=save_path) + + # print("test loss:", running_loss) if __name__ == "__main__": main()