zhengzangw 2024-04-30 06:19:40 +00:00
parent e827b78000
commit 41e276f5ef
11 changed files with 243 additions and 186 deletions


@@ -0,0 +1,79 @@
num_frames = 16
image_size = (256, 256)
fps = 24 // 3
max_test_samples = None
# Define dataset
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=num_frames,
frame_interval=1,
image_size=image_size,
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
# Define model
vae_2d = dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=4,
local_files_only=True,
)
model = dict(
type="VAE_Temporal_SD",
)
# discriminator = dict(
# type="DISCRIMINATOR_3D",
# image_size=image_size,
# num_frames=num_frames,
# in_channels=3,
# filters=128,
# channel_multipliers=(2, 4, 4, 4, 4),
# # channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
# )
# loss weights
logvar_init = 0.0
kl_loss_weight = 0.000001
perceptual_loss_weight = 0.1 # the VGG perceptual loss is used when this is not None and > 0
discriminator_factor = 1.0 # for discriminator adversarial loss
# discriminator_loss_weight = 0.5 # for generator adversarial loss
generator_factor = 0.1 # for generator adversarial loss
lecam_loss_weight = None # NOTE: MAGVIT does not specify this weight
discriminator_loss_type = "non-saturating"
generator_loss_type = "non-saturating"
discriminator_start = 2500 # 50000 NOTE: change to the correct value; use -1 for debugging for now
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10; Open-Sora-Plan does not use it
ema_decay = 0.999 # ema decay factor for generator
# Others
seed = 42
save_dir = "samples/samples_vae"
wandb = False
# Training
""" NOTE:
MAGVIT uses about # samples (K) * epochs ~ 2-5 K, with num_frames = 4, reso = 128
==> ours: num_frames = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
i.e. 3-6 epochs on Pexels; Pexels observations suggest this is correct
"""
batch_size = 1
lr = 1e-4
grad_clip = 1.0
calc_loss = True
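# A minimal sketch of how the weights above typically combine, assuming a
# VQGAN-style VAELoss (lpips, x, x_rec, posterior are placeholder names, not
# this repo's exact API):
#   rec_loss = (x - x_rec).abs() + perceptual_loss_weight * lpips(x, x_rec)
#   nll_loss = rec_loss / logvar.exp() + logvar  # logvar initialized to logvar_init
#   vae_loss = nll_loss.mean() + kl_loss_weight * posterior.kl().mean()
# The adversarial weights stay unused here while the discriminator is commented out.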


@@ -1,18 +1,18 @@
num_frames = 1
image_size = (256, 256)
# image_size = (256, 256)
image_size = (1024, 1024)
fps = 24 // 3
max_test_samples = None
# Define dataset
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=num_frames,
frame_interval=1,
image_size=image_size,
get_text=False,
)
fps = 24 // 3
is_vae = True
max_test_samples = -1
# Define acceleration
num_workers = 4
dtype = "bf16"
@@ -20,12 +20,8 @@ grad_checkpoint = True
plugin = "zero2"
sp_size = 1
use_pipeline = True
video_contains_first_frame = True
# Define model
vae_2d = dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
@@ -35,31 +31,18 @@ vae_2d = dict(
)
model = dict(
type="VAE_MAGVIT_V2",
in_out_channels=4,
latent_embed_dim=64,
filters=128,
num_res_blocks=4,
channel_multipliers=(1, 2, 2, 4),
temporal_downsample=(False, True, True),
num_groups=32, # for nn.GroupNorm
kl_embed_dim=4,
activation_fn="swish",
separate_first_frame_encoding=False,
disable_space=True,
custom_conv_padding=None,
encoder_double_z=True,
type="VAE_Temporal_SD",
)
discriminator = dict(
type="DISCRIMINATOR_3D",
image_size=image_size,
num_frames=num_frames,
in_channels=3,
filters=128,
channel_multipliers=(2, 4, 4, 4, 4),
# channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
)
# discriminator = dict(
# type="DISCRIMINATOR_3D",
# image_size=image_size,
# num_frames=num_frames,
# in_channels=3,
# filters=128,
# channel_multipliers=(2, 4, 4, 4, 4),
# # channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
# )
# loss weights
@@ -79,7 +62,7 @@ ema_decay = 0.999 # ema decay factor for generator
# Others
seed = 42
save_dir = "samples/samples_pixabay_17"
save_dir = "samples/samples_vae"
wandb = False
# Training


@@ -64,18 +64,11 @@ seed = 42
outputs = "outputs"
wandb = False
# Training
""" NOTE:
MAGVIT uses about # samples (K) * epochs ~ 2-5 K, with num_frames = 4, reso = 128
==> ours: num_frames = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
i.e. 3-6 epochs on Pexels; Pexels observations suggest this is correct
"""
epochs = 100
log_every = 1
ckpt_every = 1000
load = None
batch_size = 4
batch_size = 1
lr = 1e-5
grad_clip = 1.0


@@ -61,13 +61,6 @@ seed = 42
outputs = "outputs"
wandb = False
# Training
""" NOTE:
MAGVIT uses about # samples (K) * epochs ~ 2-5 K, with num_frames = 4, reso = 128
==> ours: num_frames = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
i.e. 3-6 epochs on Pexels; Pexels observations suggest this is correct
"""
epochs = 100
log_every = 1
ckpt_every = 1000


@@ -1,3 +1,3 @@
from .discriminator import DISCRIMINATOR_3D
from .vae import VideoAutoencoderKL, VideoAutoencoderKLTemporalDecoder
from .vae_3d import VAE_Temporal
from .vae_temporal import VAE_Temporal
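# With this re-export, configs resolve the temporal VAE through the model
# registry; a minimal sketch, assuming the build_module/MODELS API used
# elsewhere in this commit:
#   from opensora.registry import MODELS, build_module
#   model = build_module(dict(type="VAE_Temporal_SD"), MODELS)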


@@ -61,7 +61,7 @@ class LPIPS(nn.Module):
param.requires_grad = False
def load_from_pretrained(self, name="vgg_lpips"):
ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips", root="pretrained_models")
ckpt = get_ckpt_path(name, "pretrained_models/taming/modules/autoencoder/lpips")
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
# print("loaded pretrained LPIPS loss from {}".format(ckpt))


@@ -24,23 +24,12 @@ def is_odd(n):
return not divisible_by(n, 2)
def pad_at_dim(t, pad, dim=-1, value=0.0):
def pad_at_dim(t, pad, dim=-1):
dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = (0, 0) * dims_from_right
return F.pad(t, (*zeros, *pad), mode="replicate")
def pick_video_frame(video, frame_indices):
"""get frame_indices from the video of [B, C, T, H, W] and return images of [B, C, H, W]"""
batch, device = video.shape[0], video.device
video = rearrange(video, "b c f ... -> b f c ...")
batch_indices = torch.arange(batch, device=device)
batch_indices = rearrange(batch_indices, "b -> b 1")
images = video[batch_indices, frame_indices]
images = rearrange(images, "b 1 c ... -> b c ...")
return images
def exists(v):
return v is not None
@@ -381,7 +370,7 @@ class VAE_Temporal(nn.Module):
super().__init__()
self.time_downsample_factor = 2 ** sum(temporal_downsample)
self.time_padding = self.time_downsample_factor - 1
# self.time_padding = self.time_downsample_factor - 1
self.patch_size = (self.time_downsample_factor, 1, 1)
# NOTE: following MAGVIT, conv in bias=False in encoder first conv
@@ -420,16 +409,18 @@
return input_size
def encode(self, x):
x = pad_at_dim(x, (self.time_padding, 0), dim=2)
time_padding = self.time_downsample_factor - x.shape[2] % self.time_downsample_factor
x = pad_at_dim(x, (time_padding, 0), dim=2)
encoded_feature = self.encoder(x)
moments = self.quant_conv(encoded_feature).to(x.dtype)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
def decode(self, z, num_frames=None):
time_padding = self.time_downsample_factor - num_frames % self.time_downsample_factor
z = self.post_quant_conv(z)
x = self.decoder(z)
x = x[:, :, self.time_padding :]
x = x[:, :, time_padding:]
return x
def forward(self, x, sample_posterior=True):
@@ -438,7 +429,7 @@ class VAE_Temporal(nn.Module):
z = posterior.sample()
else:
z = posterior.mode()
recon_video = self.decode(z)
recon_video = self.decode(z, num_frames=x.shape[2])
return recon_video, posterior, z
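# A worked example of the causal time padding above (values assumed):
#   factor = 2 ** sum(temporal_downsample)  # e.g. 4
#   pad = factor - T % factor               # T=17 -> 3; T=16 -> 4 (not 0)
# encode() replicate-pads `pad` frames at the front so that T + pad is divisible
# by `factor`; decode(z, num_frames=T) recomputes the same `pad` and strips those
# frames, which is why forward() now passes num_frames=x.shape[2].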


@@ -264,6 +264,7 @@ def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model"):
print(f"Unexpected keys: {unexpected_keys}")
elif os.path.isdir(ckpt_path):
load_from_sharded_state_dict(model, ckpt_path, model_name)
print(f"Model checkpoint loaded from {ckpt_path}")
if save_as_pt:
save_path = os.path.join(ckpt_path, model_name + "_ckpt.pt")
torch.save(model.state_dict(), save_path)


@@ -76,7 +76,7 @@ def merge_args(cfg, args, training=False):
if cfg.get("discriminator") is not None:
cfg.discriminator["from_pretrained"] = args.ckpt_path
args.ckpt_path = None
if (training or cfg.get("is_vae", False)) and args.data_path is not None:
if args.data_path is not None:
cfg.dataset["data_path"] = args.data_path
args.data_path = None
if not training and args.cfg_scale is not None:
@@ -106,9 +106,8 @@ def merge_args(cfg, args, training=False):
if "prompt_as_path" not in cfg:
cfg["prompt_as_path"] = False
# - Prompt handling
if not "is_vae" in cfg and ("prompt" not in cfg or cfg["prompt"] is None):
if "prompt" not in cfg or cfg["prompt"] is None:
assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"
if "prompt" not in cfg or cfg["prompt"] is None:
if ("prompt" not in cfg or cfg["prompt"] is None) and cfg.get("prompt_path", None) is not None:
cfg["prompt"] = load_prompts(cfg["prompt_path"])
if args.start_index is not None and args.end_index is not None:
cfg["prompt"] = cfg["prompt"][args.start_index : args.end_index]


@@ -4,13 +4,12 @@ import colossalai
import torch
import torch.distributed as dist
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
from mmengine.runner import set_random_seed
from tqdm import tqdm
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.acceleration.parallel_states import get_data_parallel_group, set_sequence_parallel_group
from opensora.datasets import prepare_dataloader, save_sample
from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, LeCamEMA, VAELoss
from opensora.models.vae.vae_3d import pad_at_dim
from opensora.models.vae.losses import VAELoss
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype
@@ -24,21 +23,32 @@ def main():
print(cfg)
# init distributed
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
if os.environ.get("WORLD_SIZE", None):
use_dist = True
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
if coordinator.world_size > 1:
set_sequence_parallel_group(dist.group.WORLD)
else:
pass
else:
use_dist = False
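# Falls back to a single-process run when WORLD_SIZE is unset, so the script no
# longer requires a distributed launcher such as torchrun (assumed intent).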
# ======================================================
# 2. runtime variables
# ======================================================
torch.set_grad_enabled(False)
device = get_current_device()
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = to_torch_dtype(cfg.dtype)
set_random_seed(seed=cfg.seed)
# ======================================================
# 3. build dataset and dataloader
# ======================================================
dataset = build_module(cfg.dataset, DATASETS)
dataloader = prepare_dataloader(
dataset,
batch_size=cfg.batch_size,
@@ -49,7 +59,6 @@ def main():
process_group=get_data_parallel_group(),
)
print(f"Dataset contains {len(dataset):,} videos ({cfg.dataset.data_path})")
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
print(f"Total batch size: {total_batch_size}")
@@ -57,25 +66,27 @@
# 4. build model & load weights
# ======================================================
# 3.1. build model
if cfg.get("use_pipeline") == True:
# use 2D VAE, then temporal VAE
if cfg.get("vae_2d", None) is not None:
vae_2d = build_module(cfg.vae_2d, MODELS)
vae = build_module(cfg.model, MODELS, device=device)
discriminator = build_module(cfg.discriminator, MODELS, device=device)
vae_2d.to(device, dtype).eval()
model = build_module(
cfg.model,
MODELS,
device=device,
dtype=dtype,
)
# discriminator = build_module(cfg.discriminator, MODELS, device=device)
# 3.2. move to device & eval
if cfg.get("use_pipeline") == True:
vae_2d.to(device, dtype).eval()
vae = vae.to(device, dtype).eval()
discriminator = discriminator.to(device, dtype).eval()
# discriminator = discriminator.to(device, dtype).eval()
# 3.4. support for multi-resolution
model_args = dict()
if cfg.multi_resolution:
image_size = cfg.dataset.image_size
hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
model_args["data_info"] = dict(ar=ar, hw=hw)
# model_args = dict()
# if cfg.multi_resolution:
# image_size = cfg.dataset.image_size
# hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
# ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
# model_args["data_info"] = dict(ar=ar, hw=hw)
# ======================================================
# 4. inference
@@ -95,46 +106,43 @@
dtype=dtype,
)
adversarial_loss_fn = AdversarialLoss(
discriminator_factor=cfg.discriminator_factor,
discriminator_start=cfg.discriminator_start,
generator_factor=cfg.generator_factor,
generator_loss_type=cfg.generator_loss_type,
)
# adversarial_loss_fn = AdversarialLoss(
# discriminator_factor=cfg.discriminator_factor,
# discriminator_start=cfg.discriminator_start,
# generator_factor=cfg.generator_factor,
# generator_loss_type=cfg.generator_loss_type,
# )
disc_loss_fn = DiscriminatorLoss(
discriminator_factor=cfg.discriminator_factor,
discriminator_start=cfg.discriminator_start,
discriminator_loss_type=cfg.discriminator_loss_type,
lecam_loss_weight=cfg.lecam_loss_weight,
gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
)
# disc_loss_fn = DiscriminatorLoss(
# discriminator_factor=cfg.discriminator_factor,
# discriminator_start=cfg.discriminator_start,
# discriminator_loss_type=cfg.discriminator_loss_type,
# lecam_loss_weight=cfg.lecam_loss_weight,
# gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
# )
# LeCam EMA for discriminator
# # LeCam EMA for discriminator
lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
# lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
running_loss = 0.0
running_nll = 0.0
running_disc_loss = 0.0
loss_steps = 0
disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
if cfg.dataset.num_frames % disc_time_downsample_factor != 0:
disc_time_padding = disc_time_downsample_factor - cfg.dataset.num_frames % disc_time_downsample_factor
else:
disc_time_padding = 0
video_contains_first_frame = cfg.video_contains_first_frame
# disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
# if cfg.dataset.num_frames % disc_time_downsample_factor != 0:
# disc_time_padding = disc_time_downsample_factor - cfg.dataset.num_frames % disc_time_downsample_factor
# else:
# disc_time_padding = 0
total_steps = len(dataloader)
if cfg.max_test_samples > 0:
if cfg.max_test_samples is not None:
total_steps = min(int(cfg.max_test_samples // cfg.batch_size), total_steps)
print(f"limiting test dataset to {int(cfg.max_test_samples//cfg.batch_size) * cfg.batch_size}")
dataloader_iter = iter(dataloader)
with tqdm(
range(total_steps),
# desc=f"Avg Loss: {running_loss}",
disable=not coordinator.is_master(),
total=total_steps,
initial=0,
@@ -142,95 +150,96 @@
for step in pbar:
batch = next(dataloader_iter)
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
video = x
# ===== Spatial VAE =====
if cfg.get("use_pipeline") == True:
with torch.no_grad():
video_enc_spatial = vae_2d.encode(video)
if cfg.get("vae_2d", None) is not None:
x_z = vae_2d.encode(x)
x_z_debug = vae_2d.decode(x_z)
recon_dec_spatial, posterior = vae(
video_enc_spatial, video_contains_first_frame=video_contains_first_frame
)
recon_video = vae_2d.decode(recon_dec_spatial)
recon_2d = vae_2d.decode(video_enc_spatial)
else:
recon_video, posterior = vae(video, video_contains_first_frame=video_contains_first_frame)
# ====== VAE ======
x_z_rec, posterior, z = model(x_z)
x_rec = vae_2d.decode(x_z_rec)
if cfg.calc_loss:
# ====== Calc Loss ======
# simple nll loss
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(video, recon_video, posterior)
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
fake_logits = discriminator(fake_video.contiguous())
adversarial_loss = adversarial_loss_fn(
fake_logits,
nll_loss,
vae.get_last_layer(),
cfg.discriminator_start + 1, # Hack to use discriminator
is_training=vae.training,
)
# fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
# fake_logits = discriminator(fake_video.contiguous())
# adversarial_loss = adversarial_loss_fn(
# fake_logits,
# nll_loss,
# vae.get_last_layer(),
# cfg.discriminator_start + 1, # Hack to use discriminator
# is_training=vae.training,
# )
vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
# vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
vae_loss = weighted_nll_loss + weighted_kl_loss
# ====== Discriminator Loss ======
real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2)
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
# # ====== Discriminator Loss ======
# real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2)
# fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
real_video = real_video.requires_grad_()
real_logits = discriminator(
real_video.contiguous()
) # SCH: not detached for now for gradient_penalty calculation
else:
real_logits = discriminator(real_video.contiguous().detach())
# if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
# real_video = real_video.requires_grad_()
# real_logits = discriminator(
# real_video.contiguous()
# ) # SCH: not detached for now for gradient_penalty calculation
# else:
# real_logits = discriminator(real_video.contiguous().detach())
fake_logits = discriminator(fake_video.contiguous().detach())
# fake_logits = discriminator(fake_video.contiguous().detach())
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
real_logits,
fake_logits,
cfg.discriminator_start + 1, # Hack to use discriminator
lecam_ema_real=lecam_ema_real,
lecam_ema_fake=lecam_ema_fake,
real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None,
)
# lecam_ema_real, lecam_ema_fake = lecam_ema.get()
# weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
# real_logits,
# fake_logits,
# cfg.discriminator_start + 1, # Hack to use discriminator
# lecam_ema_real=lecam_ema_real,
# lecam_ema_fake=lecam_ema_fake,
# real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None,
# )
disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
# disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
loss_steps += 1
running_disc_loss = disc_loss.item() / loss_steps + running_disc_loss * ((loss_steps - 1) / loss_steps)
# running_disc_loss = disc_loss.item() / loss_steps + running_disc_loss * ((loss_steps - 1) / loss_steps)
running_loss = vae_loss.item() / loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
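# Both updates above are the standard incremental mean,
#   mean_n = mean_{n-1} * (n - 1) / n + x_n / n,
# so running_loss and running_nll track the average over all steps so far.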
# ===== Spatial VAE =====
if coordinator.is_master():
if cfg.get("use_pipeline") == True:
for idx, (sample_original, sample_pipeline, sample_2d) in enumerate(
zip(video, recon_video, recon_2d)
):
pos = step * cfg.batch_size + idx
save_path = os.path.join(save_dir, f"sample_{pos}")
save_sample(sample_original, fps=cfg.fps, save_path=save_path + "_original")
save_sample(sample_2d, fps=cfg.fps, save_path=save_path + "_2d")
save_sample(sample_pipeline, fps=cfg.fps, save_path=save_path + "_pipeline")
if not use_dist or coordinator.is_master():
for idx in range(len(x)):
pos = step * cfg.batch_size + idx
save_path = os.path.join(save_dir, f"sample_{pos}")
save_sample(x[idx], fps=cfg.fps, save_path=save_path + "_original")
save_sample(x_rec[idx], fps=cfg.fps, save_path=save_path + "_pipeline")
if cfg.get("vae_2d", None) is not None:
save_sample(x_z_debug[idx], fps=cfg.fps, save_path=save_path + "_2d")
else:
for idx, (original, recon) in enumerate(zip(video, recon_video)):
pos = step * cfg.batch_size + idx
save_path = os.path.join(save_dir, f"sample_{pos}")
save_sample(original, fps=cfg.fps, save_path=save_path + "_original")
save_sample(recon, fps=cfg.fps, save_path=save_path + "_recon")
# if cfg.get("use_pipeline") == True:
# for idx, (sample_original, sample_pipeline, sample_2d) in enumerate(
# zip(video, recon_video, recon_2d)
# ):
# pos = step * cfg.batch_size + idx
# save_path = os.path.join(save_dir, f"sample_{pos}")
# save_sample(sample_original, fps=cfg.fps, save_path=save_path + "_original")
# save_sample(sample_2d, fps=cfg.fps, save_path=save_path + "_2d")
# save_sample(sample_pipeline, fps=cfg.fps, save_path=save_path + "_pipeline")
# else:
# for idx, (original, recon) in enumerate(zip(video, recon_video)):
# pos = step * cfg.batch_size + idx
# save_path = os.path.join(save_dir, f"sample_{pos}")
# save_sample(original, fps=cfg.fps, save_path=save_path + "_original")
# save_sample(recon, fps=cfg.fps, save_path=save_path + "_recon")
if cfg.calc_loss:
print("test vae loss:", running_loss)
print("test nll loss:", running_nll)
print("test disc loss:", running_disc_loss)
# print("test disc loss:", running_disc_loss)
if __name__ == "__main__":


@@ -1,4 +1,5 @@
import os
import random
from datetime import timedelta
from pprint import pprint
@@ -268,11 +269,14 @@
) as pbar:
for step, batch in pbar:
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
if random.random() < 0.5:
x = x[:, :, :1, :, :]
# ===== Spatial VAE =====
if cfg.get("vae_2d", None) is not None:
with torch.no_grad():
x_z = vae_2d.encode(x)
vae_2d.decode(x_z)
# ====== VAE ======
x_z_rec, posterior, z = model(x_z)
@@ -281,7 +285,8 @@
# ====== Generator Loss ======
# simple nll loss
_, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
_, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior)
# _, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior)
# _, debug_loss, _ = vae_loss_fn(x, x_z_debug, posterior)
# _, image_identity_loss, _ = vae_loss_fn(x_z, z, posterior)
# adversarial_loss = torch.tensor(0.0)
@@ -300,7 +305,10 @@
# vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss + weighted_z_nll_loss
# vae_loss = weighted_nll_loss + weighted_kl_loss + weighted_z_nll_loss + image_identity_loss
vae_loss = weighted_nll_loss + weighted_kl_loss + weighted_z_nll_loss
# vae_loss = weighted_z_nll_loss + image_identity_loss
# vae_loss = weighted_nll_loss + weighted_kl_loss + weighted_z_nll_loss
# vae_loss = weighted_z_nll_loss
vae_loss = weighted_nll_loss + weighted_kl_loss
optimizer.zero_grad()
# Backward & update
@@ -391,8 +399,9 @@
# "lecam_loss": lecam_loss.item(),
# "r1_grad_penalty": gradient_penalty_loss.item(),
"nll_loss": weighted_nll_loss.item(),
"z_nll_loss": weighted_z_nll_loss.item(),
# "z_nll_loss": weighted_z_nll_loss.item(),
# "image_identity_loss": image_identity_loss.item(),
# "debug_loss": debug_loss.item(),
"avg_loss": avg_loss,
},
step=global_step,