# Open-Sora/scripts/inference-vae.py
import os

import colossalai
import torch
import torch.distributed as dist
from colossalai.cluster import DistCoordinator
from mmengine.runner import set_random_seed
from tqdm import tqdm

from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader, save_sample
from opensora.models.vae.losses import VAELoss
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype
def main():
    """Run VAE reconstruction inference over a test dataset.

    For every batch the video is encoded and decoded through the VAE,
    running NLL/KL losses are accumulated, and the master process saves
    the original, reconstructed, and spatial-VAE-reference clips under
    ``cfg.save_dir`` (suffixed ``_ori`` / ``_rec`` / ``_ref``).
    All configuration comes from ``parse_configs`` (CLI + config file).
    """
    # ======================================================
    # 1. cfg and init distributed env
    # ======================================================
    cfg = parse_configs(training=False)
    print(cfg)

    # init distributed only when launched via torchrun (WORLD_SIZE is set)
    if os.environ.get("WORLD_SIZE", None):
        use_dist = True
        colossalai.launch_from_torch({})
        coordinator = DistCoordinator()
    else:
        use_dist = False
        coordinator = None  # fix: was left undefined, crashing single-process runs below

    # only rank 0 prints the progress bar and writes samples; a lone
    # process is always its own "master"
    is_master = coordinator.is_master() if use_dist else True

    # ======================================================
    # 2. runtime variables
    # ======================================================
    torch.set_grad_enabled(False)  # inference only — no autograd graphs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = to_torch_dtype(cfg.dtype)
    set_random_seed(seed=cfg.seed)

    # ======================================================
    # 3. build dataset and dataloader
    # ======================================================
    dataset = build_module(cfg.dataset, DATASETS)
    dataloader = prepare_dataloader(
        dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        shuffle=False,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
    )
    print(f"Dataset contains {len(dataset):,} videos ({cfg.dataset.data_path})")
    # fix: dist.get_world_size() raises when no process group is initialized,
    # so fall back to 1 in single-process mode
    world_size = dist.get_world_size() if use_dist else 1
    total_batch_size = cfg.batch_size * world_size
    print(f"Total batch size: {total_batch_size}")

    # ======================================================
    # 4. build model & load weights
    # ======================================================
    # 4.1. build model
    model = build_module(cfg.model, MODELS)
    model.to(device, dtype).eval()

    # ======================================================
    # 5. inference
    # ======================================================
    save_dir = cfg.save_dir
    # define loss function
    vae_loss_fn = VAELoss(
        logvar_init=cfg.get("logvar_init", 0.0),
        perceptual_loss_weight=cfg.perceptual_loss_weight,
        kl_loss_weight=cfg.kl_loss_weight,
        device=device,
        dtype=dtype,
    )

    # get total number of steps; optionally cap by max_test_samples
    total_steps = len(dataloader)
    if cfg.max_test_samples is not None:
        total_steps = min(int(cfg.max_test_samples // cfg.batch_size), total_steps)
        print(f"limiting test dataset to {int(cfg.max_test_samples//cfg.batch_size) * cfg.batch_size}")
    dataloader_iter = iter(dataloader)

    # running averages over processed batches (incremental mean update)
    running_loss = running_nll = running_nll_z = 0.0
    loss_steps = 0

    # output dirs are loop-invariant — create them once, master only
    ori_dir = f"{save_dir}_ori"
    rec_dir = f"{save_dir}_rec"
    ref_dir = f"{save_dir}_ref"
    if is_master:
        os.makedirs(ori_dir, exist_ok=True)
        os.makedirs(rec_dir, exist_ok=True)
        os.makedirs(ref_dir, exist_ok=True)

    with tqdm(
        range(total_steps),
        disable=not is_master,  # fix: previously dereferenced an undefined coordinator when not use_dist
        total=total_steps,
        initial=0,
    ) as pbar:
        for step in pbar:
            batch = next(dataloader_iter)
            x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
            input_size = x.shape[2:]
            latent_size = model.get_latent_size(input_size)

            # ===== VAE =====
            # x_z: intermediate spatial-VAE latent video; z: final latent
            z, posterior, x_z = model.encode(x)
            assert list(z.shape[2:]) == latent_size, f"z shape: {z.shape}, latent_size: {latent_size}"
            x_rec, x_z_rec = model.decode(z, num_frames=x.size(2))
            # reference reconstruction straight from the spatial VAE latent
            x_ref = model.spatial_vae.decode(x_z)

            # loss calculation
            nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
            nll_loss_z, _, _ = vae_loss_fn(x_z, x_z_rec, posterior, no_perceptual=True)
            vae_loss = weighted_nll_loss + weighted_kl_loss

            # incremental running mean: mean_k = x_k/k + mean_{k-1}*(k-1)/k
            loss_steps += 1
            running_loss = vae_loss.item() / loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
            running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
            running_nll_z = nll_loss_z.item() / loss_steps + running_nll_z * ((loss_steps - 1) / loss_steps)

            if is_master:
                # NOTE(review): only the master rank's batches are saved; other
                # ranks' samples are dropped — confirm this is intended
                for idx, vid in enumerate(x):
                    pos = step * cfg.batch_size + idx
                    save_sample(vid, fps=cfg.fps, save_path=f"{ori_dir}/{pos:03d}")
                    save_sample(x_rec[idx], fps=cfg.fps, save_path=f"{rec_dir}/{pos:03d}")
                    save_sample(x_ref[idx], fps=cfg.fps, save_path=f"{ref_dir}/{pos:03d}")

    print("test vae loss:", running_loss)
    print("test nll loss:", running_nll)
    print("test nll_z loss:", running_nll_z)

# Script entry point: run VAE inference when executed directly.
if __name__ == "__main__":
    main()