debug inference code

This commit is contained in:
Shen-Chenhui 2024-04-08 15:40:14 +08:00
parent 4b27448a49
commit fa0ca3983e
2 changed files with 109 additions and 73 deletions

View file

@ -11,11 +11,11 @@ WANDB_API_KEY=<wandb_api_key> CUDA_VISIBLE_DEVICES=<n> torchrun --master_port=<p
### 2. Inference
```yaml
CUDA_VISIBLE_DEVICES=6 torchrun --standalone --nnodes=1 --nproc_per_node=1 scripts/inference-vae.py configs/vae_3d/inference/16x256x256.py --ckpt-path /home/shenchenhui/Open-Sora-dev/outputs/train_pexel_028/epoch3-global_step20000/ --data-path /home/shenchenhui/data/pexels/debug.csv
CUDA_VISIBLE_DEVICES=6 torchrun --standalone --nnodes=1 --nproc_per_node=1 scripts/inference-vae.py configs/vae_3d/inference/16x256x256.py --ckpt-path /home/shenchenhui/Open-Sora-dev/outputs/train_pexel_028/epoch3-global_step20000/ --data-path /home/shenchenhui/data/pexels/debug.csv --save-dir outputs/pexel
CUDA_VISIBLE_DEVICES=6 torchrun --standalone --nnodes=1 --nproc_per_node=1 scripts/inference-vae.py configs/vae_3d/inference/16x256x256.py --ckpt-path /home/shenchenhui/Open-Sora-dev/outputs/004-F16S3-VAE_3D_B/epoch0-global_step1000 --data-path /home/shenchenhui/data/pexels/debug.csv
CUDA_VISIBLE_DEVICES=6 torchrun --standalone --nnodes=1 --nproc_per_node=1 scripts/inference-vae.py configs/vae_3d/inference/16x256x256.py --ckpt-path /home/shenchenhui/Open-Sora-dev/outputs/004-F16S3-VAE_3D_B/epoch0-global_step1000 --data-path /home/shenchenhui/data/pexels/debug.csv --save-dir outputs/pexel
CUDA_VISIBLE_DEVICES=6 torchrun --standalone --nnodes=1 --nproc_per_node=1 scripts/inference-vae.py configs/vae_3d/inference/16x256x256.py --ckpt-path /home/shenchenhui/Open-Sora-dev/outputs/004-F16S3-VAE_3D_B/epoch0-global_step2000 --data-path /home/shenchenhui/data/pexels/debug.csv
CUDA_VISIBLE_DEVICES=6 torchrun --standalone --nnodes=1 --nproc_per_node=1 scripts/inference-vae.py configs/vae_3d/inference/16x256x256.py --ckpt-path /home/shenchenhui/Open-Sora-dev/outputs/004-F16S3-VAE_3D_B/epoch0-global_step2000 --data-path /home/shenchenhui/data/pexels/debug.csv --save-dir outputs/pexel
# debug on the same 8 samples
@ -23,7 +23,7 @@ WANDB_API_KEY=<wandb_api_key> CUDA_VISIBLE_DEVICES=5 to
CUDA_VISIBLE_DEVICES=6 torchrun --standalone --nnodes=1 --nproc_per_node=1 scripts/inference-vae.py configs/vae_3d/inference/16x256x256.py --ckpt-path /home/shenchenhui/Open-Sora-dev/outputs/006-F16S3-VAE_3D_B/epoch49-global_step50 --data-path /home/shenchenhui/data/pexels/debug.csv
# resume training debug
WANDB_API_KEY=<wandb_api_key> CUDA_VISIBLE_DEVICES=5 torchrun --master_port=29530 --nnodes=1 --nproc_per_node=1 scripts/train-vae.py configs/vae_3d/train/16x256x256.py --data-path /home/shenchenhui/data/pexels/debug.csv --wandb True --load /home/shenchenhui/Open-Sora-dev/outputs/006-F16S3-VAE_3D_B/epoch49-global_step50
WANDB_API_KEY=<wandb_api_key> CUDA_VISIBLE_DEVICES=5 torchrun --master_port=29530 --nnodes=1 --nproc_per_node=1 scripts/train-vae.py configs/vae_3d/train/16x256x256.py --data-path /home/shenchenhui/data/pexels/debug.csv --load /home/shenchenhui/Open-Sora-dev/outputs/006-F16S3-VAE_3D_B/epoch49-global_step50 --wandb True
```
```yaml

View file

@ -190,83 +190,119 @@ def main():
loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)
# 6.2. training loop
for epoch in range(start_epoch, cfg.epochs):
dataloader.sampler.set_epoch(epoch)
dataloader_iter = iter(dataloader)
logger.info(f"Beginning epoch {epoch}...")
# # 6.2. training loop
# for epoch in range(start_epoch, cfg.epochs):
# dataloader.sampler.set_epoch(epoch)
# dataloader_iter = iter(dataloader)
# logger.info(f"Beginning epoch {epoch}...")
with tqdm(
range(start_step, num_steps_per_epoch),
desc=f"Epoch {epoch}",
disable=not coordinator.is_master(),
total=num_steps_per_epoch,
initial=start_step,
) as pbar:
for step in pbar:
batch = next(dataloader_iter)
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
# with tqdm(
# range(start_step, num_steps_per_epoch),
# desc=f"Epoch {epoch}",
# disable=not coordinator.is_master(),
# total=num_steps_per_epoch,
# initial=start_step,
# ) as pbar:
# for step in pbar:
# batch = next(dataloader_iter)
# x = batch["video"].to(device, dtype) # [B, C, T, H, W]
# loss = vae.get_loss(x)
reconstructions, posterior = vae(x)
loss = loss_function(x, reconstructions, posterior)
# # loss = vae.get_loss(x)
# reconstructions, posterior = vae(x)
# loss = loss_function(x, reconstructions, posterior)
# Backward & update
booster.backward(loss=loss, optimizer=optimizer)
optimizer.step()
optimizer.zero_grad()
# # Backward & update
# booster.backward(loss=loss, optimizer=optimizer)
# optimizer.step()
# optimizer.zero_grad()
# Log loss values:
all_reduce_mean(loss)
running_loss += loss.item()
global_step = epoch * num_steps_per_epoch + step
log_step += 1
# # Log loss values:
# all_reduce_mean(loss)
# running_loss += loss.item()
# global_step = epoch * num_steps_per_epoch + step
# log_step += 1
# Log to tensorboard
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
avg_loss = running_loss / log_step
pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
running_loss = 0
log_step = 0
writer.add_scalar("loss", loss.item(), global_step)
if cfg.wandb:
wandb.log(
{
"iter": global_step,
"num_samples": global_step * total_batch_size,
"epoch": epoch,
"loss": loss.item(),
"avg_loss": avg_loss,
},
step=global_step,
)
# # Log to tensorboard
# if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
# avg_loss = running_loss / log_step
# pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
# running_loss = 0
# log_step = 0
# writer.add_scalar("loss", loss.item(), global_step)
# if cfg.wandb:
# wandb.log(
# {
# "iter": global_step,
# "num_samples": global_step * total_batch_size,
# "epoch": epoch,
# "loss": loss.item(),
# "avg_loss": avg_loss,
# },
# step=global_step,
# )
# Save checkpoint
if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
# TODO: save in model?
booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
if lr_scheduler is not None:
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
running_states = {
"epoch": epoch,
"step": step+1,
"global_step": global_step+1,
"sample_start_index": (step+1) * cfg.batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
dist.barrier()
logger.info(
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
)
# # Save checkpoint
# if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
# save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
# os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
# # TODO: save in model?
# booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
# booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
# if lr_scheduler is not None:
# booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
# running_states = {
# "epoch": epoch,
# "step": step+1,
# "global_step": global_step+1,
# "sample_start_index": (step+1) * cfg.batch_size,
# }
# if coordinator.is_master():
# save_json(running_states, os.path.join(save_dir, "running_states.json"))
# dist.barrier()
# logger.info(
# f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
# )
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(0)
start_step = 0
# # the continue epochs are not resumed, so we need to reset the sampler start index and start step
# dataloader.sampler.set_start_index(0)
# start_step = 0
# DEBUG inference
# 4.1. batch generation
# define loss function
loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)
running_loss = 0.0
loss_steps = 0
from opensora.datasets import save_sample
total_steps = len(dataloader)
dataloader_iter = iter(dataloader)
with tqdm(
range(total_steps),
# desc=f"Avg Loss: {running_loss}",
disable=not coordinator.is_master(),
total=total_steps,
initial=0,
) as pbar:
for step in pbar:
batch = next(dataloader_iter)
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
reconstructions, posterior = vae(x)
loss = loss_function(x, reconstructions, posterior)
loss_steps += 1
running_loss = loss.item()/ loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
if coordinator.is_master():
for idx, sample in enumerate(reconstructions):
pos = step * cfg.batch_size + idx
save_path = os.path.join("outputs/debug", f"sample_{pos}")
save_sample(sample, fps=cfg.fps, save_path=save_path)
print("test loss:", running_loss)
# Script entry point: run the (debug-inference) driver only when executed
# directly, not when imported as a module.
if __name__ == "__main__":
    main()