mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-18 16:49:41 +02:00
debug inference code
This commit is contained in:
parent
4b27448a49
commit
fa0ca3983e
|
|
@ -11,11 +11,11 @@ WANDB_API_KEY=<wandb_api_key> CUDA_VISIBLE_DEVICES=<n> torchrun --master_port=<p
|
|||
### 2. Inference
|
||||
|
||||
```yaml
|
||||
CUDA_VISIBLE_DEVICES=6 torchrun --standalone --nnodes=1 --nproc_per_node=1 scripts/inference-vae.py configs/vae_3d/inference/16x256x256.py --ckpt-path /home/shenchenhui/Open-Sora-dev/outputs/train_pexel_028/epoch3-global_step20000/ --data-path /home/shenchenhui/data/pexels/debug.csv
|
||||
CUDA_VISIBLE_DEVICES=6 torchrun --standalone --nnodes=1 --nproc_per_node=1 scripts/inference-vae.py configs/vae_3d/inference/16x256x256.py --ckpt-path /home/shenchenhui/Open-Sora-dev/outputs/train_pexel_028/epoch3-global_step20000/ --data-path /home/shenchenhui/data/pexels/debug.csv --save-dir outputs/pexel
|
||||
|
||||
CUDA_VISIBLE_DEVICES=6 torchrun --standalone --nnodes=1 --nproc_per_node=1 scripts/inference-vae.py configs/vae_3d/inference/16x256x256.py --ckpt-path /home/shenchenhui/Open-Sora-dev/outputs/004-F16S3-VAE_3D_B/epoch0-global_step1000 --data-path /home/shenchenhui/data/pexels/debug.csv
|
||||
CUDA_VISIBLE_DEVICES=6 torchrun --standalone --nnodes=1 --nproc_per_node=1 scripts/inference-vae.py configs/vae_3d/inference/16x256x256.py --ckpt-path /home/shenchenhui/Open-Sora-dev/outputs/004-F16S3-VAE_3D_B/epoch0-global_step1000 --data-path /home/shenchenhui/data/pexels/debug.csv --save-dir outputs/pexel
|
||||
|
||||
CUDA_VISIBLE_DEVICES=6 torchrun --standalone --nnodes=1 --nproc_per_node=1 scripts/inference-vae.py configs/vae_3d/inference/16x256x256.py --ckpt-path /home/shenchenhui/Open-Sora-dev/outputs/004-F16S3-VAE_3D_B/epoch0-global_step2000 --data-path /home/shenchenhui/data/pexels/debug.csv
|
||||
CUDA_VISIBLE_DEVICES=6 torchrun --standalone --nnodes=1 --nproc_per_node=1 scripts/inference-vae.py configs/vae_3d/inference/16x256x256.py --ckpt-path /home/shenchenhui/Open-Sora-dev/outputs/004-F16S3-VAE_3D_B/epoch0-global_step2000 --data-path /home/shenchenhui/data/pexels/debug.csv --save-dir outputs/pexel
|
||||
|
||||
|
||||
# debug on the same 8 samples
|
||||
|
|
@ -23,7 +23,7 @@ WANDB_API_KEY=7bc1ce71b2dc0b8cd40c500eb256747583f6c07e CUDA_VISIBLE_DEVICES=5 to
|
|||
CUDA_VISIBLE_DEVICES=6 torchrun --standalone --nnodes=1 --nproc_per_node=1 scripts/inference-vae.py configs/vae_3d/inference/16x256x256.py --ckpt-path /home/shenchenhui/Open-Sora-dev/outputs/006-F16S3-VAE_3D_B/epoch49-global_step50 --data-path /home/shenchenhui/data/pexels/debug.csv
|
||||
|
||||
# resume training debug
|
||||
WANDB_API_KEY=7bc1ce71b2dc0b8cd40c500eb256747583f6c07e CUDA_VISIBLE_DEVICES=5 torchrun --master_port=29530 --nnodes=1 --nproc_per_node=1 scripts/train-vae.py configs/vae_3d/train/16x256x256.py --data-path /home/shenchenhui/data/pexels/debug.csv --wandb True --load /home/shenchenhui/Open-Sora-dev/outputs/006-F16S3-VAE_3D_B/epoch49-global_step50
|
||||
WANDB_API_KEY=7bc1ce71b2dc0b8cd40c500eb256747583f6c07e CUDA_VISIBLE_DEVICES=5 torchrun --master_port=29530 --nnodes=1 --nproc_per_node=1 scripts/train-vae.py configs/vae_3d/train/16x256x256.py --data-path /home/shenchenhui/data/pexels/debug.csv --load /home/shenchenhui/Open-Sora-dev/outputs/006-F16S3-VAE_3D_B/epoch49-global_step50 --wandb True
|
||||
```
|
||||
|
||||
```yaml
|
||||
|
|
|
|||
|
|
@ -190,83 +190,119 @@ def main():
|
|||
loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)
|
||||
|
||||
|
||||
# 6.2. training loop
|
||||
for epoch in range(start_epoch, cfg.epochs):
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
dataloader_iter = iter(dataloader)
|
||||
logger.info(f"Beginning epoch {epoch}...")
|
||||
# # 6.2. training loop
|
||||
# for epoch in range(start_epoch, cfg.epochs):
|
||||
# dataloader.sampler.set_epoch(epoch)
|
||||
# dataloader_iter = iter(dataloader)
|
||||
# logger.info(f"Beginning epoch {epoch}...")
|
||||
|
||||
with tqdm(
|
||||
range(start_step, num_steps_per_epoch),
|
||||
desc=f"Epoch {epoch}",
|
||||
disable=not coordinator.is_master(),
|
||||
total=num_steps_per_epoch,
|
||||
initial=start_step,
|
||||
) as pbar:
|
||||
for step in pbar:
|
||||
batch = next(dataloader_iter)
|
||||
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
|
||||
# with tqdm(
|
||||
# range(start_step, num_steps_per_epoch),
|
||||
# desc=f"Epoch {epoch}",
|
||||
# disable=not coordinator.is_master(),
|
||||
# total=num_steps_per_epoch,
|
||||
# initial=start_step,
|
||||
# ) as pbar:
|
||||
# for step in pbar:
|
||||
# batch = next(dataloader_iter)
|
||||
# x = batch["video"].to(device, dtype) # [B, C, T, H, W]
|
||||
|
||||
# loss = vae.get_loss(x)
|
||||
reconstructions, posterior = vae(x)
|
||||
loss = loss_function(x, reconstructions, posterior)
|
||||
# # loss = vae.get_loss(x)
|
||||
# reconstructions, posterior = vae(x)
|
||||
# loss = loss_function(x, reconstructions, posterior)
|
||||
|
||||
# Backward & update
|
||||
booster.backward(loss=loss, optimizer=optimizer)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
# # Backward & update
|
||||
# booster.backward(loss=loss, optimizer=optimizer)
|
||||
# optimizer.step()
|
||||
# optimizer.zero_grad()
|
||||
|
||||
# Log loss values:
|
||||
all_reduce_mean(loss)
|
||||
running_loss += loss.item()
|
||||
global_step = epoch * num_steps_per_epoch + step
|
||||
log_step += 1
|
||||
# # Log loss values:
|
||||
# all_reduce_mean(loss)
|
||||
# running_loss += loss.item()
|
||||
# global_step = epoch * num_steps_per_epoch + step
|
||||
# log_step += 1
|
||||
|
||||
# Log to tensorboard
|
||||
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
|
||||
avg_loss = running_loss / log_step
|
||||
pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
|
||||
running_loss = 0
|
||||
log_step = 0
|
||||
writer.add_scalar("loss", loss.item(), global_step)
|
||||
if cfg.wandb:
|
||||
wandb.log(
|
||||
{
|
||||
"iter": global_step,
|
||||
"num_samples": global_step * total_batch_size,
|
||||
"epoch": epoch,
|
||||
"loss": loss.item(),
|
||||
"avg_loss": avg_loss,
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
# # Log to tensorboard
|
||||
# if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
|
||||
# avg_loss = running_loss / log_step
|
||||
# pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
|
||||
# running_loss = 0
|
||||
# log_step = 0
|
||||
# writer.add_scalar("loss", loss.item(), global_step)
|
||||
# if cfg.wandb:
|
||||
# wandb.log(
|
||||
# {
|
||||
# "iter": global_step,
|
||||
# "num_samples": global_step * total_batch_size,
|
||||
# "epoch": epoch,
|
||||
# "loss": loss.item(),
|
||||
# "avg_loss": avg_loss,
|
||||
# },
|
||||
# step=global_step,
|
||||
# )
|
||||
|
||||
# Save checkpoint
|
||||
if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
|
||||
save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
|
||||
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
|
||||
# TODO: save in model?
|
||||
booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
|
||||
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
|
||||
if lr_scheduler is not None:
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
running_states = {
|
||||
"epoch": epoch,
|
||||
"step": step+1,
|
||||
"global_step": global_step+1,
|
||||
"sample_start_index": (step+1) * cfg.batch_size,
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
dist.barrier()
|
||||
logger.info(
|
||||
f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
|
||||
)
|
||||
# # Save checkpoint
|
||||
# if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
|
||||
# save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
|
||||
# os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
|
||||
# # TODO: save in model?
|
||||
# booster.save_model(vae, os.path.join(save_dir, "model"), shard=True)
|
||||
# booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
|
||||
# if lr_scheduler is not None:
|
||||
# booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
# running_states = {
|
||||
# "epoch": epoch,
|
||||
# "step": step+1,
|
||||
# "global_step": global_step+1,
|
||||
# "sample_start_index": (step+1) * cfg.batch_size,
|
||||
# }
|
||||
# if coordinator.is_master():
|
||||
# save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
# dist.barrier()
|
||||
# logger.info(
|
||||
# f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
|
||||
# )
|
||||
|
||||
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
|
||||
dataloader.sampler.set_start_index(0)
|
||||
start_step = 0
|
||||
# # the continue epochs are not resumed, so we need to reset the sampler start index and start step
|
||||
# dataloader.sampler.set_start_index(0)
|
||||
# start_step = 0
|
||||
|
||||
# DEBUG inference
|
||||
|
||||
# 4.1. batch generation
|
||||
|
||||
# define loss function
|
||||
loss_function = VEA3DLoss(kl_weight=cfg.kl_weight, perceptual_weight=cfg.perceptual_weight).to(device, dtype)
|
||||
running_loss = 0.0
|
||||
loss_steps = 0
|
||||
|
||||
from opensora.datasets import save_sample
|
||||
|
||||
total_steps = len(dataloader)
|
||||
dataloader_iter = iter(dataloader)
|
||||
|
||||
with tqdm(
|
||||
range(total_steps),
|
||||
# desc=f"Avg Loss: {running_loss}",
|
||||
disable=not coordinator.is_master(),
|
||||
total=total_steps,
|
||||
initial=0,
|
||||
) as pbar:
|
||||
for step in pbar:
|
||||
batch = next(dataloader_iter)
|
||||
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
|
||||
reconstructions, posterior = vae(x)
|
||||
loss = loss_function(x, reconstructions, posterior)
|
||||
loss_steps += 1
|
||||
running_loss = loss.item()/ loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
|
||||
|
||||
if coordinator.is_master():
|
||||
for idx, sample in enumerate(reconstructions):
|
||||
pos = step * cfg.batch_size + idx
|
||||
save_path = os.path.join("outputs/debug", f"sample_{pos}")
|
||||
save_sample(sample, fps=cfg.fps, save_path=save_path)
|
||||
|
||||
print("test loss:", running_loss)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue