[feat] update inference config

zhengzangw 2024-05-14 05:40:17 +00:00
parent 3a01cb440d
commit e1e379d898
16 changed files with 261 additions and 321 deletions

View file

@@ -1,4 +1,5 @@
image_size = (240, 426)
resolution = "240p"
aspect_ratio = "9:16"
num_frames = 51
fps = 24
frame_interval = 1
@@ -30,22 +31,10 @@ model = dict(
enable_layernorm_kernel=True,
)
vae = dict(
type="VideoAutoencoderPipeline",
type="OpenSoraVAE_V1_2",
from_pretrained="pretrained_models/vae-pipeline",
shift=(-0.10, 0.34, 0.27, 0.98),
scale=(3.85, 2.32, 2.33, 3.06),
micro_frame_size=17,
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=4,
local_files_only=True,
),
vae_temporal=dict(
type="VAE_Temporal_SD",
from_pretrained=None,
),
micro_batch_size=4,
)
text_encoder = dict(
type="t5",

View file

@@ -36,22 +36,10 @@ model = dict(
enable_layernorm_kernel=True,
)
vae = dict(
type="VideoAutoencoderPipeline",
type="OpenSoraVAE_V1_2",
from_pretrained="pretrained_models/vae-pipeline",
shift=(-0.10, 0.34, 0.27, 0.98),
scale=(3.85, 2.32, 2.33, 3.06),
micro_frame_size=17,
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=4,
local_files_only=True,
),
vae_temporal=dict(
type="VAE_Temporal_SD",
from_pretrained=None,
),
micro_batch_size=4,
)
text_encoder = dict(
type="t5",

View file

@@ -1,88 +0,0 @@
# Dataset settings
dataset = dict(
type="VariableVideoTextDataset",
transform_name="resize_crop",
frame_interval=1,
)
bucket_config = { # 20s/it
"1024": {1: (1.0, 1)},
}
grad_checkpoint = True
batch_size = None
# Acceleration settings
num_workers = 8
num_bucket_build_workers = 16
dtype = "bf16"
plugin = "zero2"
# Model settings
model = dict(
type="STDiT3-XL/2",
from_pretrained=None,
qk_norm=True,
enable_flash_attn=True,
enable_layernorm_kernel=True,
)
vae = dict(
type="VideoAutoencoderPipeline",
from_pretrained="pretrained_models/vae-pipeline",
micro_frame_size=17,
shift=(-0.10, 0.34, 0.27, 0.98),
scale=(3.85, 2.32, 2.33, 3.06),
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=4,
local_files_only=True,
),
vae_temporal=dict(
type="VAE_Temporal_SD",
from_pretrained=None,
),
)
text_encoder = dict(
type="t5",
from_pretrained="DeepFloyd/t5-v1_1-xxl",
model_max_length=300,
shardformer=True,
local_files_only=True,
)
scheduler = dict(
type="rflow",
use_discrete_timesteps=False,
use_timestep_transform=False,
# sample_method="logit-normal",
)
# Mask settings
# mask_ratios = {
# "random": 0.1,
# "intepolate": 0.01,
# "quarter_random": 0.01,
# "quarter_head": 0.01,
# "quarter_tail": 0.01,
# "quarter_head_tail": 0.01,
# "image_random": 0.05,
# "image_head": 0.1,
# "image_tail": 0.05,
# "image_head_tail": 0.05,
# }
# Log settings
seed = 42
outputs = "outputs"
wandb = False
epochs = 1000
log_every = 10
ckpt_every = 500
# optimization settings
load = None
grad_clip = 1.0
lr = 2e-4
ema_decay = 0.99
adam_eps = 1e-15

View file

@@ -61,22 +61,10 @@ model = dict(
enable_layernorm_kernel=True,
)
vae = dict(
type="VideoAutoencoderPipeline",
type="OpenSoraVAE_V1_2",
from_pretrained="pretrained_models/vae-pipeline",
shift=(-0.10, 0.34, 0.27, 0.98),
scale=(3.85, 2.32, 2.33, 3.06),
micro_frame_size=17,
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=4,
local_files_only=True,
),
vae_temporal=dict(
type="VAE_Temporal_SD",
from_pretrained=None,
),
micro_batch_size=4,
)
text_encoder = dict(
type="t5",

View file

@@ -1,44 +1,32 @@
num_frames = 1
frame_interval = 1
fps = 24
image_size = (256, 256)
num_frames = 1
dtype = "bf16"
batch_size = 1
seed = 42
save_dir = "samples/vae_video"
cal_stats = True
log_stats_every = 100
# Define dataset
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=num_frames,
frame_interval=1,
image_size=image_size,
)
num_samples = 100
num_workers = 4
max_test_samples = None
# Define model
model = dict(
type="VideoAutoencoderPipeline",
freeze_vae_2d=True,
type="OpenSoraVAE_V1_2",
from_pretrained="pretrained_models/vae-pipeline",
micro_frame_size=None,
micro_batch_size=4,
cal_loss=True,
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=4,
local_files_only=True,
),
vae_temporal=dict(
type="VAE_Temporal_SD",
from_pretrained=None,
),
)
dtype = "bf16"
# loss weights
perceptual_loss_weight = 0.1  # VGG perceptual loss is used only when this is not None and > 0
kl_loss_weight = 1e-6
calc_std = True
# Others
batch_size = 1
seed = 42
save_dir = "samples/vae_image"

View file

@@ -1,45 +1,32 @@
num_frames = 17
frame_interval = 1
fps = 24
image_size = (256, 256)
num_frames = 17
dtype = "bf16"
batch_size = 1
seed = 42
save_dir = "samples/vae_video"
cal_stats = True
log_stats_every = 100
# Define dataset
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=num_frames,
frame_interval=1,
image_size=image_size,
)
num_samples = 100
num_workers = 4
max_test_samples = None
# Define model
model = dict(
type="VideoAutoencoderPipeline",
from_pretrained=None,
type="OpenSoraVAE_V1_2",
from_pretrained="pretrained_models/vae-pipeline",
micro_frame_size=None,
micro_batch_size=4,
cal_loss=True,
micro_frame_size=16,
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=4,
local_files_only=True,
),
vae_temporal=dict(
type="VAE_Temporal_SD",
from_pretrained=None,
),
)
dtype = "bf16"
# loss weights
perceptual_loss_weight = 0.1  # VGG perceptual loss is used only when this is not None and > 0
kl_loss_weight = 1e-6
calc_std = True
# Others
batch_size = 1
seed = 42
save_dir = "samples/vae_video"

View file

View file

@@ -1,28 +1,7 @@
import math
# computation
AR = (
3 / 8,
9 / 21,
0.48,
1 / 2,
9 / 17,
1 / 1.85,
9 / 16,
5 / 8,
2 / 3,
3 / 4,
1 / 1,
4 / 3,
3 / 2,
16 / 9,
17 / 9,
2 / 1,
1 / 0.48,
)
AR_fraction = (0.375, 0.43, 0.48, 0.50, 0.53, 0.54, 0.56, 0.62, 0.67, 0.75, 1, 1.33, 1.50, 1.78, 1.89, 2, 2.08)
def get_h_w(a, ts, eps=1e-4):
h = (ts * a) ** 0.5
h = h + eps
@@ -33,11 +12,39 @@ def get_h_w(a, ts, eps=1e-4):
return h, w
def get_aspect_ratios_dict(ts=360 * 640, ars=AR):
def get_aspect_ratios_dict(ars, ts=360 * 640):
est = {f"{a:.2f}": get_h_w(a, ts) for a in ars}
return est
def get_ar(ratio):
h, w = ratio.split(":")
return int(h) / int(w)
ASPECT_RATIO_MAP = {
"3:8": "0.38",
"9:21": "0.43",
"12:25": "0.48",
"1:2": "0.50",
"9:17": "0.53",
"27:50": "0.54",
"9:16": "0.56",
"5:8": "0.62",
"2:3": "0.67",
"3:4": "0.75",
"1:1": "1.00",
"4:3": "1.33",
"3:2": "1.50",
"16:9": "1.78",
"17:9": "1.89",
"2:1": "2.00",
"50:27": "2.08",
}
AR = [get_ar(ratio) for ratio in ASPECT_RATIO_MAP.keys()]
# computed from above code
# S = 8294400
ASPECT_RATIO_4K = {
@@ -454,3 +461,10 @@ ASPECT_RATIOS = {
def get_num_pixels(name):
return ASPECT_RATIOS[name][0]
def get_image_size(resolution, ar_ratio):
ar_key = ASPECT_RATIO_MAP[ar_ratio]
rs_dict = ASPECT_RATIOS[resolution][1]
assert ar_key in rs_dict, f"Aspect ratio {ar_ratio} not found for resolution {resolution}"
return rs_dict[ar_key]
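A quick self-check of how these pieces compose (a sketch; it assumes the "240p" bucket is present in ASPECT_RATIOS, as the configs above use it): get_ar converts an "h:w" string to a float, and ASPECT_RATIO_MAP pins the two-decimal string key used to index the per-resolution bucket tables.

assert f'{get_ar("9:16"):.2f}' == ASPECT_RATIO_MAP["9:16"]  # both are "0.56"
h, w = get_image_size("240p", "9:16")  # looks up ASPECT_RATIOS["240p"][1]["0.56"]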

View file

@@ -19,6 +19,7 @@ def prepare_dataloader(
pin_memory=False,
num_workers=0,
process_group: Optional[ProcessGroup] = None,
distributed=True,
**kwargs,
):
r"""
@@ -43,13 +44,16 @@
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
process_group = process_group or _get_default_group()
sampler = StatefulDistributedSampler(
dataset,
num_replicas=process_group.size(),
rank=process_group.rank(),
shuffle=shuffle,
)
if distributed:
process_group = process_group or _get_default_group()
sampler = StatefulDistributedSampler(
dataset,
num_replicas=process_group.size(),
rank=process_group.rank(),
shuffle=shuffle,
)
else:
sampler = None
# Deterministic dataloader
def seed_worker(worker_id):
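With the new distributed flag, a single-process run no longer needs a process group or the StatefulDistributedSampler. A usage sketch mirroring the call the VAE inference script makes below (dataset assumed to be any torch Dataset):

dataloader = prepare_dataloader(
    dataset,
    batch_size=1,
    num_workers=4,
    shuffle=False,
    pin_memory=True,
    distributed=False,  # sampler=None; plain sequential loading
)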

View file

@@ -220,3 +220,39 @@ class VideoAutoencoderPipeline(nn.Module):
@property
def dtype(self):
return next(self.parameters()).dtype
@MODELS.register_module()
class OpenSoraVAE_V1_2(VideoAutoencoderPipeline):
def __init__(
self,
micro_batch_size=4,
micro_frame_size=17,
from_pretrained=None,
local_files_only=True,
freeze_vae_2d=False,
cal_loss=False,
):
vae_2d = dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=micro_batch_size,
local_files_only=local_files_only,
)
vae_temporal = dict(
type="VAE_Temporal_SD",
from_pretrained=None,
)
shift = (-0.10, 0.34, 0.27, 0.98)
scale = (3.85, 2.32, 2.33, 3.06)
super().__init__(
vae_2d,
vae_temporal,
from_pretrained,
freeze_vae_2d=freeze_vae_2d,
cal_loss=cal_loss,
micro_frame_size=micro_frame_size,
shift=shift,
scale=scale,
)
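This wrapper is what lets the configs above shrink: the 2D/temporal sub-VAE configs and the latent shift/scale constants move from every config file into the class itself. A sketch of how the registry consumes the new block, following the build_module(cfg.vae, MODELS) pattern the scripts already use:

from opensora.registry import MODELS, build_module

vae_cfg = dict(
    type="OpenSoraVAE_V1_2",
    from_pretrained="pretrained_models/vae-pipeline",
    micro_batch_size=4,
)
vae = build_module(vae_cfg, MODELS)  # expands to the old nested VideoAutoencoderPipeline setup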

View file

@@ -24,10 +24,8 @@ def parse_args(training=False):
)
parser.add_argument("--batch-size", default=None, type=int, help="batch size")
parser.add_argument("--outputs", default=None, type=str, help="the dir to save model weights")
parser.add_argument("--flash-attn", default=None, action=argparse.BooleanOptionalAction, help="enable flash attn")
parser.add_argument(
"--layernorm-kernel", default=None, action=argparse.BooleanOptionalAction, help="enable layernorm kernel"
)
parser.add_argument("--flash-attn", default=None, type=str2bool, help="enable flash attention")
parser.add_argument("--layernorm-kernel", default=None, type=str2bool, help="enable layernorm kernel")
# ======================================================
# Inference
@@ -51,6 +49,9 @@ def parse_args(training=False):
parser.add_argument("--num-frames", default=None, type=int, help="number of frames")
parser.add_argument("--fps", default=None, type=int, help="fps")
parser.add_argument("--image-size", default=None, type=int, nargs=2, help="image size")
parser.add_argument("--frame-interval", default=None, type=int, help="frame interval")
parser.add_argument("--resolution", default=None, type=str, help="multi resolution")
parser.add_argument("--aspect-ratio", default=None, type=float, help="aspect ratio")
# hyperparameters
parser.add_argument("--num-sampling-steps", default=None, type=int, help="sampling steps")
@@ -143,3 +144,14 @@ def define_experiment_workspace(cfg, get_last_workspace=False):
def save_training_config(cfg, experiment_dir):
with open(f"{experiment_dir}/config.txt", "w") as f:
json.dump(cfg, f, indent=4)
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
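Unlike argparse.BooleanOptionalAction's paired --flag/--no-flag form, str2bool takes an explicit value, which keeps None available as "not set" while still accepting yes/no-style input. A quick check, assuming str2bool as defined above is in scope:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--flash-attn", default=None, type=str2bool)
assert parser.parse_args(["--flash-attn", "yes"]).flash_attn is True
assert parser.parse_args(["--flash-attn", "0"]).flash_attn is False
assert parser.parse_args([]).flash_attn is None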

View file

@@ -132,7 +132,9 @@ def find_nearest_point(value, point, max_value):
def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None):
masks = []
no_mask = True
for i, mask_strategy in enumerate(mask_strategys):
no_mask = False
mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device)
mask_strategy = parse_mask_strategy(mask_strategy)
for mst in mask_strategy:
@@ -154,6 +156,8 @@ def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None):
z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length]
mask[m_target_start : m_target_start + m_length] = edit_ratio
masks.append(mask)
if no_mask:
return None
masks = torch.stack(masks)
return masks

View file

@@ -26,6 +26,13 @@ def is_main_process():
return not is_distributed() or dist.get_rank() == 0
def get_world_size():
if is_distributed():
return dist.get_world_size()
else:
return 1
def create_logger(logging_dir=None):
"""
Create a logger that writes to a log file and stdout.

View file

@@ -10,6 +10,7 @@ from tqdm import tqdm
from opensora.acceleration.parallel_states import set_sequence_parallel_group
from opensora.datasets import save_sample
from opensora.datasets.aspect import get_image_size
from opensora.models.text_encoder.t5 import text_preprocessing
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
@@ -52,7 +53,7 @@ def main():
else:
coordinator = None
enable_sequence_parallelism = False
set_random_seed(seed=cfg.seed)
set_random_seed(seed=cfg.get("seed", 1024))
# == init logger ==
logger = create_logger()
@@ -68,8 +69,19 @@
text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
# == prepare video size ==
image_size = cfg.get("image_size", None)
if image_size is None:
resolution = cfg.get("resolution", None)
aspect_ratio = cfg.get("aspect_ratio", None)
assert (
resolution is not None and aspect_ratio is not None
), "resolution and aspect_ratio must be provided if image_size is not provided"
image_size = get_image_size(resolution, aspect_ratio)
num_frames = cfg.num_frames
# == build diffusion model ==
input_size = (cfg.get("num_frames", None), *cfg.get("image_size", (None, None)))
input_size = (num_frames, *image_size)
latent_size = vae.get_latent_size(input_size)
model = (
build_module(
@@ -106,10 +118,8 @@ def main():
assert len(mask_strategy) == len(prompts), "Length of mask_strategy must be the same as prompts"
# == prepare arguments ==
image_size = cfg.image_size
num_frames = cfg.num_frames
fps = cfg.fps
save_fps = cfg.fps // cfg.get("frame_interval", 1)
save_fps = fps // cfg.get("frame_interval", 1)
multi_resolution = cfg.get("multi_resolution", None)
batch_size = cfg.get("batch_size", 1)
num_sample = cfg.get("num_sample", 1)

View file

@@ -1,144 +1,145 @@
import os
from pprint import pformat
import colossalai
import torch
import torch.distributed as dist
from colossalai.cluster import DistCoordinator
from mmengine.runner import set_random_seed
from tqdm import tqdm
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader
from opensora.datasets import prepare_dataloader, save_sample
from opensora.models.vae.losses import VAELoss
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype
from opensora.utils.misc import create_logger, get_world_size, is_distributed, is_main_process, to_torch_dtype
def main():
# ======================================================
# 1. cfg and init distributed env
# ======================================================
cfg = parse_configs(training=False)
print(cfg)
# init distributed
if os.environ.get("WORLD_SIZE", None):
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
else:
pass
# ======================================================
# 2. runtime variables
# ======================================================
torch.set_grad_enabled(False)
# ======================================================
# configs & runtime variables
# ======================================================
# == parse configs ==
cfg = parse_configs(training=False)
# == device and dtype ==
device = "cuda" if torch.cuda.is_available() else "cpu"
cfg_dtype = cfg.get("dtype", "fp32")
assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = to_torch_dtype(cfg.dtype)
set_random_seed(seed=cfg.seed)
# == init distributed env ==
if is_distributed():
colossalai.launch_from_torch({})
set_random_seed(seed=cfg.get("seed", 1024))
# == init logger ==
logger = create_logger()
logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
verbose = cfg.get("verbose", 1)
# ======================================================
# 3. build dataset and dataloader
# build dataset and dataloader
# ======================================================
logger.info("Building reconstruction dataset...")
dataset = build_module(cfg.dataset, DATASETS)
batch_size = cfg.get("batch_size", 1)
dataloader = prepare_dataloader(
dataset,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
batch_size=batch_size,
num_workers=cfg.get("num_workers", 4),
shuffle=False,
drop_last=True,
drop_last=False,
pin_memory=True,
process_group=get_data_parallel_group(),
distributed=is_distributed(),
)
print(f"Dataset contains {len(dataset):,} videos ({cfg.dataset.data_path})")
total_batch_size = cfg.batch_size * dist.get_world_size()
print(f"Total batch size: {total_batch_size}")
logger.info("Dataset %s contains %s videos.", cfg.dataset.data_path, len(dataset))
total_batch_size = batch_size * get_world_size()
logger.info("Total batch size: %s", total_batch_size)
total_steps = len(dataloader)
if cfg.get("num_samples", None) is not None:
total_steps = min(int(cfg.num_samples // cfg.batch_size), total_steps)
logger.info("limiting test dataset to %s", int(cfg.num_samples // cfg.batch_size) * cfg.batch_size)
dataiter = iter(dataloader)
# ======================================================
# 4. build model & load weights
# build model & loss
# ======================================================
# 4.1. build model
model = build_module(cfg.model, MODELS)
model.to(device, dtype).eval()
# ======================================================
# 5. inference
# ======================================================
cfg.save_dir
# define loss function
logger.info("Building models...")
model = build_module(cfg.model, MODELS).to(device, dtype).eval()
vae_loss_fn = VAELoss(
logvar_init=cfg.get("logvar_init", 0.0),
perceptual_loss_weight=cfg.perceptual_loss_weight,
kl_loss_weight=cfg.kl_loss_weight,
perceptual_loss_weight=cfg.get("perceptual_loss_weight", 0.1),
kl_loss_weight=cfg.get("kl_loss_weight", 1e-6),
device=device,
dtype=dtype,
)
# get total number of steps
total_steps = len(dataloader)
if cfg.max_test_samples is not None:
total_steps = min(int(cfg.max_test_samples // cfg.batch_size), total_steps)
print(f"limiting test dataset to {int(cfg.max_test_samples//cfg.batch_size) * cfg.batch_size}")
dataloader_iter = iter(dataloader)
# ======================================================
# inference
# ======================================================
# == global variables ==
running_loss = running_nll = running_nll_z = 0.0
loss_steps = 0
calc_std = cfg.get("calc_std", False)
if calc_std:
running_sum = 0.0
running_sum_c = torch.zeros(model.out_channels, dtype=torch.float, device=device)
running_var = 0.0
running_var_c = torch.zeros(model.out_channels, dtype=torch.float, device=device)
cal_stats = cfg.get("cal_stats", False)
if cal_stats:
num_samples = 0
running_sum = running_var = 0.0
running_sum_c = torch.zeros(model.out_channels, dtype=torch.float, device=device)
running_var_c = torch.zeros(model.out_channels, dtype=torch.float, device=device)
# prepare arguments
save_fps = cfg.get("fps", 24) // cfg.get("frame_interval", 1)
# Iter over the dataset
with tqdm(
range(total_steps),
disable=not coordinator.is_master(),
disable=not is_main_process() or verbose < 1,
total=total_steps,
initial=0,
) as pbar:
for step in pbar:
batch = next(dataloader_iter)
batch = next(dataiter)
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
# == vae encoding & decoding ===
z, posterior, x_z = model.encode(x)
x_rec, x_z_rec = model.decode(z, num_frames=x.size(2))
x_ref = model.spatial_vae.decode(x_z)
# == check z shape ==
input_size = x.shape[2:]
latent_size = model.get_latent_size(input_size)
assert list(z.shape[2:]) == latent_size, f"z shape: {z.shape}, latent_size: {latent_size}"
# ===== VAE =====
z, posterior, x_z = model.encode(x)
# calc std
if calc_std:
# == calculate stats ==
if cal_stats:
num_samples += 1
running_sum += z.mean().item()
running_var += (z - running_sum / num_samples).pow(2).mean().item()
running_sum_c += z.mean(dim=(0, 2, 3, 4)).float()
running_var_c += (
(z - running_sum_c[None, :, None, None, None] / num_samples).pow(2).mean(dim=(0, 2, 3, 4)).float()
)
pbar.set_postfix(
{
"mean": running_sum / num_samples,
"std": (running_var / num_samples) ** 0.5,
}
)
if num_samples % 100 == 0:
print(
" mean_c ",
if verbose >= 1:
pbar.set_postfix(
{
"mean": running_sum / num_samples,
"std": (running_var / num_samples) ** 0.5,
}
)
if num_samples % cfg.get("log_stats_every", 100) == 0:
logger.info(
"VAE feature per channel stats: mean %s, var %s",
(running_sum_c / num_samples).cpu().tolist(),
"std_c ",
(running_var_c / num_samples).sqrt().cpu().tolist(),
)
assert list(z.shape[2:]) == latent_size, f"z shape: {z.shape}, latent_size: {latent_size}"
x_rec, x_z_rec = model.decode(z, num_frames=x.size(2))
model.spatial_vae.decode(x_z)
# loss calculation
# == loss calculation ==
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
nll_loss_z, _, _ = vae_loss_fn(x_z, x_z_rec, posterior, no_perceptual=True)
vae_loss = weighted_nll_loss + weighted_kl_loss
@@ -147,24 +148,24 @@ def main():
running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
running_nll_z = nll_loss_z.item() / loss_steps + running_nll_z * ((loss_steps - 1) / loss_steps)
# if not use_dist or coordinator.is_master():
# ori_dir = f"{save_dir}_ori"
# rec_dir = f"{save_dir}_rec"
# ref_dir = f"{save_dir}_ref"
# os.makedirs(ori_dir, exist_ok=True)
# os.makedirs(rec_dir, exist_ok=True)
# os.makedirs(ref_dir, exist_ok=True)
# for idx, vid in enumerate(x):
# pos = step * cfg.batch_size + idx
# save_sample(vid, fps=cfg.fps, save_path=f"{ori_dir}/{pos:03d}")
# save_sample(x_rec[idx], fps=cfg.fps, save_path=f"{rec_dir}/{pos:03d}")
# save_sample(x_ref[idx], fps=cfg.fps, save_path=f"{ref_dir}/{pos:03d}")
# == save samples ==
save_dir = cfg.get("save_dir", None)
if is_main_process() and save_dir is not None:
ori_dir = f"{save_dir}_ori"
rec_dir = f"{save_dir}_rec"
ref_dir = f"{save_dir}_spatial"
os.makedirs(ori_dir, exist_ok=True)
os.makedirs(rec_dir, exist_ok=True)
os.makedirs(ref_dir, exist_ok=True)
for idx, vid in enumerate(x):
pos = step * cfg.batch_size + idx
save_sample(vid, fps=save_fps, save_path=f"{ori_dir}/{pos:03d}", verbose=verbose >= 2)
save_sample(x_rec[idx], fps=save_fps, save_path=f"{rec_dir}/{pos:03d}", verbose=verbose >= 2)
save_sample(x_ref[idx], fps=save_fps, save_path=f"{ref_dir}/{pos:03d}", verbose=verbose >= 2)
print("test vae loss:", running_loss)
print("test nll loss:", running_nll)
print("test nll_z loss:", running_nll_z)
if calc_std:
print("z std:", running_std_sum / num_samples)
logger.info("VAE loss: %s", running_loss)
logger.info("VAE nll loss: %s", running_nll)
logger.info("VAE nll_z loss: %s", running_nll_z)
if __name__ == "__main__":

View file

@@ -230,12 +230,12 @@ def main():
for epoch in range(start_epoch, cfg_epochs):
# == set dataloader to new epoch ==
dataloader.sampler.set_epoch(epoch)
dataloader_iter = iter(dataloader)
dataiter = iter(dataloader)
logger.info("Beginning epoch %s...", epoch)
# == training loop in an epoch ==
with tqdm(
enumerate(dataloader_iter, start=start_step),
enumerate(dataiter, start=start_step),
desc=f"Epoch {epoch}",
disable=not coordinator.is_master(),
total=num_steps_per_epoch,
@@ -253,7 +253,7 @@ def main():
length = random.randint(1, x.size(2))
x = x[:, :, :length, :, :]
# == vae encoding ===
# == vae encoding & decoding ===
x_rec, x_z_rec, z, posterior, x_z = model(x)
# == loss initialization ==