mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
updated
This commit is contained in:
parent
e827b78000
commit
41e276f5ef
79
configs/vae/inference/17x256x256.py
Normal file
79
configs/vae/inference/17x256x256.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
num_frames = 16
|
||||
image_size = (256, 256)
|
||||
fps = 24 // 3
|
||||
max_test_samples = None
|
||||
|
||||
# Define dataset
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=num_frames,
|
||||
frame_interval=1,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
sp_size = 1
|
||||
|
||||
|
||||
# Define model
|
||||
vae_2d = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type="VAE_Temporal_SD",
|
||||
)
|
||||
|
||||
# discriminator = dict(
|
||||
# type="DISCRIMINATOR_3D",
|
||||
# image_size=image_size,
|
||||
# num_frames=num_frames,
|
||||
# in_channels=3,
|
||||
# filters=128,
|
||||
# channel_multipliers=(2, 4, 4, 4, 4),
|
||||
# # channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
|
||||
# )
|
||||
|
||||
|
||||
# loss weights
|
||||
logvar_init = 0.0
|
||||
kl_loss_weight = 0.000001
|
||||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
discriminator_factor = 1.0 # for discriminator adversarial loss
|
||||
# discriminator_loss_weight = 0.5 # for generator adversarial loss
|
||||
generator_factor = 0.1 # for generator adversarial loss
|
||||
lecam_loss_weight = None # NOTE: not clear in MAGVIT what is the weight
|
||||
discriminator_loss_type = "non-saturating"
|
||||
generator_loss_type = "non-saturating"
|
||||
discriminator_start = 2500 # 50000 NOTE: change to correct val, debug use -1 for now
|
||||
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
|
||||
ema_decay = 0.999 # ema decay factor for generator
|
||||
|
||||
|
||||
# Others
|
||||
seed = 42
|
||||
save_dir = "samples/samples_vae"
|
||||
wandb = False
|
||||
|
||||
# Training
|
||||
""" NOTE:
|
||||
magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
|
||||
==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
|
||||
3-6 epochs for pexel, from pexel observation its correct
|
||||
"""
|
||||
|
||||
|
||||
batch_size = 1
|
||||
lr = 1e-4
|
||||
grad_clip = 1.0
|
||||
|
||||
calc_loss = True
|
||||
|
|
@ -1,18 +1,18 @@
|
|||
num_frames = 1
|
||||
image_size = (256, 256)
|
||||
# image_size = (256, 256)
|
||||
image_size = (1024, 1024)
|
||||
fps = 24 // 3
|
||||
max_test_samples = None
|
||||
|
||||
# Define dataset
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=num_frames,
|
||||
frame_interval=1,
|
||||
image_size=image_size,
|
||||
get_text=False,
|
||||
)
|
||||
|
||||
fps = 24 // 3
|
||||
is_vae = True
|
||||
max_test_samples = -1
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
|
|
@ -20,12 +20,8 @@ grad_checkpoint = True
|
|||
plugin = "zero2"
|
||||
sp_size = 1
|
||||
|
||||
use_pipeline = True
|
||||
video_contains_first_frame = True
|
||||
|
||||
|
||||
# Define model
|
||||
|
||||
vae_2d = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
|
|
@ -35,31 +31,18 @@ vae_2d = dict(
|
|||
)
|
||||
|
||||
model = dict(
|
||||
type="VAE_MAGVIT_V2",
|
||||
in_out_channels=4,
|
||||
latent_embed_dim=64,
|
||||
filters=128,
|
||||
num_res_blocks=4,
|
||||
channel_multipliers=(1, 2, 2, 4),
|
||||
temporal_downsample=(False, True, True),
|
||||
num_groups=32, # for nn.GroupNorm
|
||||
kl_embed_dim=4,
|
||||
activation_fn="swish",
|
||||
separate_first_frame_encoding=False,
|
||||
disable_space=True,
|
||||
custom_conv_padding=None,
|
||||
encoder_double_z=True,
|
||||
type="VAE_Temporal_SD",
|
||||
)
|
||||
|
||||
discriminator = dict(
|
||||
type="DISCRIMINATOR_3D",
|
||||
image_size=image_size,
|
||||
num_frames=num_frames,
|
||||
in_channels=3,
|
||||
filters=128,
|
||||
channel_multipliers=(2, 4, 4, 4, 4),
|
||||
# channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
|
||||
)
|
||||
# discriminator = dict(
|
||||
# type="DISCRIMINATOR_3D",
|
||||
# image_size=image_size,
|
||||
# num_frames=num_frames,
|
||||
# in_channels=3,
|
||||
# filters=128,
|
||||
# channel_multipliers=(2, 4, 4, 4, 4),
|
||||
# # channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
|
||||
# )
|
||||
|
||||
|
||||
# loss weights
|
||||
|
|
@ -79,7 +62,7 @@ ema_decay = 0.999 # ema decay factor for generator
|
|||
|
||||
# Others
|
||||
seed = 42
|
||||
save_dir = "samples/samples_pixabay_17"
|
||||
save_dir = "samples/samples_vae"
|
||||
wandb = False
|
||||
|
||||
# Training
|
||||
|
|
|
|||
|
|
@ -64,18 +64,11 @@ seed = 42
|
|||
outputs = "outputs"
|
||||
wandb = False
|
||||
|
||||
# Training
|
||||
""" NOTE:
|
||||
magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
|
||||
==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
|
||||
3-6 epochs for pexel, from pexel observation its correct
|
||||
"""
|
||||
|
||||
epochs = 100
|
||||
log_every = 1
|
||||
ckpt_every = 1000
|
||||
load = None
|
||||
|
||||
batch_size = 4
|
||||
batch_size = 1
|
||||
lr = 1e-5
|
||||
grad_clip = 1.0
|
||||
|
|
|
|||
|
|
@ -61,13 +61,6 @@ seed = 42
|
|||
outputs = "outputs"
|
||||
wandb = False
|
||||
|
||||
# Training
|
||||
""" NOTE:
|
||||
magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
|
||||
==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
|
||||
3-6 epochs for pexel, from pexel observation its correct
|
||||
"""
|
||||
|
||||
epochs = 100
|
||||
log_every = 1
|
||||
ckpt_every = 1000
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
from .discriminator import DISCRIMINATOR_3D
|
||||
from .vae import VideoAutoencoderKL, VideoAutoencoderKLTemporalDecoder
|
||||
from .vae_3d import VAE_Temporal
|
||||
from .vae_temporal import VAE_Temporal
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ class LPIPS(nn.Module):
|
|||
param.requires_grad = False
|
||||
|
||||
def load_from_pretrained(self, name="vgg_lpips"):
|
||||
ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips", root="pretrained_models")
|
||||
ckpt = get_ckpt_path(name, "pretrained_models/taming/modules/autoencoder/lpips")
|
||||
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
|
||||
# print("loaded pretrained LPIPS loss from {}".format(ckpt))
|
||||
|
||||
|
|
|
|||
|
|
@ -24,23 +24,12 @@ def is_odd(n):
|
|||
return not divisible_by(n, 2)
|
||||
|
||||
|
||||
def pad_at_dim(t, pad, dim=-1, value=0.0):
|
||||
def pad_at_dim(t, pad, dim=-1):
|
||||
dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
|
||||
zeros = (0, 0) * dims_from_right
|
||||
return F.pad(t, (*zeros, *pad), mode="replicate")
|
||||
|
||||
|
||||
def pick_video_frame(video, frame_indices):
|
||||
"""get frame_indices from the video of [B, C, T, H, W] and return images of [B, C, H, W]"""
|
||||
batch, device = video.shape[0], video.device
|
||||
video = rearrange(video, "b c f ... -> b f c ...")
|
||||
batch_indices = torch.arange(batch, device=device)
|
||||
batch_indices = rearrange(batch_indices, "b -> b 1")
|
||||
images = video[batch_indices, frame_indices]
|
||||
images = rearrange(images, "b 1 c ... -> b c ...")
|
||||
return images
|
||||
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
|
|
@ -381,7 +370,7 @@ class VAE_Temporal(nn.Module):
|
|||
super().__init__()
|
||||
|
||||
self.time_downsample_factor = 2 ** sum(temporal_downsample)
|
||||
self.time_padding = self.time_downsample_factor - 1
|
||||
# self.time_padding = self.time_downsample_factor - 1
|
||||
self.patch_size = (self.time_downsample_factor, 1, 1)
|
||||
|
||||
# NOTE: following MAGVIT, conv in bias=False in encoder first conv
|
||||
|
|
@ -420,16 +409,18 @@ class VAE_Temporal(nn.Module):
|
|||
return input_size
|
||||
|
||||
def encode(self, x):
|
||||
x = pad_at_dim(x, (self.time_padding, 0), dim=2)
|
||||
time_padding = self.time_downsample_factor - x.shape[2] % self.time_downsample_factor
|
||||
x = pad_at_dim(x, (time_padding, 0), dim=2)
|
||||
encoded_feature = self.encoder(x)
|
||||
moments = self.quant_conv(encoded_feature).to(x.dtype)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z):
|
||||
def decode(self, z, num_frames=None):
|
||||
time_padding = self.time_downsample_factor - num_frames % self.time_downsample_factor
|
||||
z = self.post_quant_conv(z)
|
||||
x = self.decoder(z)
|
||||
x = x[:, :, self.time_padding :]
|
||||
x = x[:, :, time_padding:]
|
||||
return x
|
||||
|
||||
def forward(self, x, sample_posterior=True):
|
||||
|
|
@ -438,7 +429,7 @@ class VAE_Temporal(nn.Module):
|
|||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
recon_video = self.decode(z)
|
||||
recon_video = self.decode(z, num_frames=x.shape[2])
|
||||
return recon_video, posterior, z
|
||||
|
||||
|
||||
|
|
@ -264,6 +264,7 @@ def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model"):
|
|||
print(f"Unexpected keys: {unexpected_keys}")
|
||||
elif os.path.isdir(ckpt_path):
|
||||
load_from_sharded_state_dict(model, ckpt_path, model_name)
|
||||
print(f"Model checkpoint loaded from {ckpt_path}")
|
||||
if save_as_pt:
|
||||
save_path = os.path.join(ckpt_path, model_name + "_ckpt.pt")
|
||||
torch.save(model.state_dict(), save_path)
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ def merge_args(cfg, args, training=False):
|
|||
if cfg.get("discriminator") is not None:
|
||||
cfg.discriminator["from_pretrained"] = args.ckpt_path
|
||||
args.ckpt_path = None
|
||||
if (training or cfg.get("is_vae", False)) and args.data_path is not None:
|
||||
if args.data_path is not None:
|
||||
cfg.dataset["data_path"] = args.data_path
|
||||
args.data_path = None
|
||||
if not training and args.cfg_scale is not None:
|
||||
|
|
@ -106,9 +106,8 @@ def merge_args(cfg, args, training=False):
|
|||
if "prompt_as_path" not in cfg:
|
||||
cfg["prompt_as_path"] = False
|
||||
# - Prompt handling
|
||||
if not "is_vae" in cfg and ("prompt" not in cfg or cfg["prompt"] is None):
|
||||
if "prompt" not in cfg or cfg["prompt"] is None:
|
||||
assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"
|
||||
if "prompt" not in cfg or cfg["prompt"] is None:
|
||||
if ("prompt" not in cfg or cfg["prompt"] is None) and cfg.get("prompt_path", None) is not None:
|
||||
cfg["prompt"] = load_prompts(cfg["prompt_path"])
|
||||
if args.start_index is not None and args.end_index is not None:
|
||||
cfg["prompt"] = cfg["prompt"][args.start_index : args.end_index]
|
||||
|
|
|
|||
|
|
@ -4,13 +4,12 @@ import colossalai
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.utils import get_current_device
|
||||
from mmengine.runner import set_random_seed
|
||||
from tqdm import tqdm
|
||||
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group, set_sequence_parallel_group
|
||||
from opensora.datasets import prepare_dataloader, save_sample
|
||||
from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, LeCamEMA, VAELoss
|
||||
from opensora.models.vae.vae_3d import pad_at_dim
|
||||
from opensora.models.vae.losses import VAELoss
|
||||
from opensora.registry import DATASETS, MODELS, build_module
|
||||
from opensora.utils.config_utils import parse_configs
|
||||
from opensora.utils.misc import to_torch_dtype
|
||||
|
|
@ -24,21 +23,32 @@ def main():
|
|||
print(cfg)
|
||||
|
||||
# init distributed
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
if os.environ.get("WORLD_SIZE", None):
|
||||
use_dist = True
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
if coordinator.world_size > 1:
|
||||
set_sequence_parallel_group(dist.group.WORLD)
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
use_dist = False
|
||||
|
||||
# ======================================================
|
||||
# 2. runtime variables
|
||||
# ======================================================
|
||||
torch.set_grad_enabled(False)
|
||||
device = get_current_device()
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
dtype = to_torch_dtype(cfg.dtype)
|
||||
set_random_seed(seed=cfg.seed)
|
||||
|
||||
# ======================================================
|
||||
# 3. build dataset and dataloader
|
||||
# ======================================================
|
||||
dataset = build_module(cfg.dataset, DATASETS)
|
||||
|
||||
dataloader = prepare_dataloader(
|
||||
dataset,
|
||||
batch_size=cfg.batch_size,
|
||||
|
|
@ -49,7 +59,6 @@ def main():
|
|||
process_group=get_data_parallel_group(),
|
||||
)
|
||||
print(f"Dataset contains {len(dataset):,} videos ({cfg.dataset.data_path})")
|
||||
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
|
||||
print(f"Total batch size: {total_batch_size}")
|
||||
|
||||
|
|
@ -57,25 +66,27 @@ def main():
|
|||
# 4. build model & load weights
|
||||
# ======================================================
|
||||
# 3.1. build model
|
||||
if cfg.get("use_pipeline") == True:
|
||||
# use 2D VAE, then temporal VAE
|
||||
if cfg.get("vae_2d", None) is not None:
|
||||
vae_2d = build_module(cfg.vae_2d, MODELS)
|
||||
vae = build_module(cfg.model, MODELS, device=device)
|
||||
discriminator = build_module(cfg.discriminator, MODELS, device=device)
|
||||
vae_2d.to(device, dtype).eval()
|
||||
model = build_module(
|
||||
cfg.model,
|
||||
MODELS,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
# discriminator = build_module(cfg.discriminator, MODELS, device=device)
|
||||
|
||||
# 3.2. move to device & eval
|
||||
if cfg.get("use_pipeline") == True:
|
||||
vae_2d.to(device, dtype).eval()
|
||||
vae = vae.to(device, dtype).eval()
|
||||
discriminator = discriminator.to(device, dtype).eval()
|
||||
# discriminator = discriminator.to(device, dtype).eval()
|
||||
|
||||
# 3.4. support for multi-resolution
|
||||
model_args = dict()
|
||||
if cfg.multi_resolution:
|
||||
image_size = cfg.dataset.image_size
|
||||
hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
|
||||
ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
|
||||
model_args["data_info"] = dict(ar=ar, hw=hw)
|
||||
# model_args = dict()
|
||||
# if cfg.multi_resolution:
|
||||
# image_size = cfg.dataset.image_size
|
||||
# hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
|
||||
# ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
|
||||
# model_args["data_info"] = dict(ar=ar, hw=hw)
|
||||
|
||||
# ======================================================
|
||||
# 4. inference
|
||||
|
|
@ -95,46 +106,43 @@ def main():
|
|||
dtype=dtype,
|
||||
)
|
||||
|
||||
adversarial_loss_fn = AdversarialLoss(
|
||||
discriminator_factor=cfg.discriminator_factor,
|
||||
discriminator_start=cfg.discriminator_start,
|
||||
generator_factor=cfg.generator_factor,
|
||||
generator_loss_type=cfg.generator_loss_type,
|
||||
)
|
||||
# adversarial_loss_fn = AdversarialLoss(
|
||||
# discriminator_factor=cfg.discriminator_factor,
|
||||
# discriminator_start=cfg.discriminator_start,
|
||||
# generator_factor=cfg.generator_factor,
|
||||
# generator_loss_type=cfg.generator_loss_type,
|
||||
# )
|
||||
|
||||
disc_loss_fn = DiscriminatorLoss(
|
||||
discriminator_factor=cfg.discriminator_factor,
|
||||
discriminator_start=cfg.discriminator_start,
|
||||
discriminator_loss_type=cfg.discriminator_loss_type,
|
||||
lecam_loss_weight=cfg.lecam_loss_weight,
|
||||
gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
|
||||
)
|
||||
# disc_loss_fn = DiscriminatorLoss(
|
||||
# discriminator_factor=cfg.discriminator_factor,
|
||||
# discriminator_start=cfg.discriminator_start,
|
||||
# discriminator_loss_type=cfg.discriminator_loss_type,
|
||||
# lecam_loss_weight=cfg.lecam_loss_weight,
|
||||
# gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
|
||||
# )
|
||||
|
||||
# LeCam EMA for discriminator
|
||||
# # LeCam EMA for discriminator
|
||||
|
||||
lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
|
||||
# lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
|
||||
|
||||
running_loss = 0.0
|
||||
running_nll = 0.0
|
||||
running_disc_loss = 0.0
|
||||
loss_steps = 0
|
||||
|
||||
disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
|
||||
if cfg.dataset.num_frames % disc_time_downsample_factor != 0:
|
||||
disc_time_padding = disc_time_downsample_factor - cfg.dataset.num_frames % disc_time_downsample_factor
|
||||
else:
|
||||
disc_time_padding = 0
|
||||
video_contains_first_frame = cfg.video_contains_first_frame
|
||||
# disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
|
||||
# if cfg.dataset.num_frames % disc_time_downsample_factor != 0:
|
||||
# disc_time_padding = disc_time_downsample_factor - cfg.dataset.num_frames % disc_time_downsample_factor
|
||||
# else:
|
||||
# disc_time_padding = 0
|
||||
|
||||
total_steps = len(dataloader)
|
||||
if cfg.max_test_samples > 0:
|
||||
if cfg.max_test_samples is not None:
|
||||
total_steps = min(int(cfg.max_test_samples // cfg.batch_size), total_steps)
|
||||
print(f"limiting test dataset to {int(cfg.max_test_samples//cfg.batch_size) * cfg.batch_size}")
|
||||
dataloader_iter = iter(dataloader)
|
||||
|
||||
with tqdm(
|
||||
range(total_steps),
|
||||
# desc=f"Avg Loss: {running_loss}",
|
||||
disable=not coordinator.is_master(),
|
||||
total=total_steps,
|
||||
initial=0,
|
||||
|
|
@ -142,95 +150,96 @@ def main():
|
|||
for step in pbar:
|
||||
batch = next(dataloader_iter)
|
||||
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
|
||||
video = x
|
||||
|
||||
# ===== Spatial VAE =====
|
||||
if cfg.get("use_pipeline") == True:
|
||||
with torch.no_grad():
|
||||
video_enc_spatial = vae_2d.encode(video)
|
||||
if cfg.get("vae_2d", None) is not None:
|
||||
x_z = vae_2d.encode(x)
|
||||
x_z_debug = vae_2d.decode(x_z)
|
||||
|
||||
recon_dec_spatial, posterior = vae(
|
||||
video_enc_spatial, video_contains_first_frame=video_contains_first_frame
|
||||
)
|
||||
|
||||
recon_video = vae_2d.decode(recon_dec_spatial)
|
||||
recon_2d = vae_2d.decode(video_enc_spatial)
|
||||
|
||||
else:
|
||||
recon_video, posterior = vae(video, video_contains_first_frame=video_contains_first_frame)
|
||||
# ====== VAE ======
|
||||
x_z_rec, posterior, z = model(x_z)
|
||||
x_rec = vae_2d.decode(x_z_rec)
|
||||
|
||||
if cfg.calc_loss:
|
||||
# ====== Calc Loss ======
|
||||
# simple nll loss
|
||||
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(video, recon_video, posterior)
|
||||
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
|
||||
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
fake_logits = discriminator(fake_video.contiguous())
|
||||
adversarial_loss = adversarial_loss_fn(
|
||||
fake_logits,
|
||||
nll_loss,
|
||||
vae.get_last_layer(),
|
||||
cfg.discriminator_start + 1, # Hack to use discriminator
|
||||
is_training=vae.training,
|
||||
)
|
||||
# fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
# fake_logits = discriminator(fake_video.contiguous())
|
||||
# adversarial_loss = adversarial_loss_fn(
|
||||
# fake_logits,
|
||||
# nll_loss,
|
||||
# vae.get_last_layer(),
|
||||
# cfg.discriminator_start + 1, # Hack to use discriminator
|
||||
# is_training=vae.training,
|
||||
# )
|
||||
|
||||
vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
|
||||
# vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
|
||||
vae_loss = weighted_nll_loss + weighted_kl_loss
|
||||
|
||||
# ====== Discriminator Loss ======
|
||||
real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
# # ====== Discriminator Loss ======
|
||||
# real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
# fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
|
||||
if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
|
||||
real_video = real_video.requires_grad_()
|
||||
real_logits = discriminator(
|
||||
real_video.contiguous()
|
||||
) # SCH: not detached for now for gradient_penalty calculation
|
||||
else:
|
||||
real_logits = discriminator(real_video.contiguous().detach())
|
||||
# if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
|
||||
# real_video = real_video.requires_grad_()
|
||||
# real_logits = discriminator(
|
||||
# real_video.contiguous()
|
||||
# ) # SCH: not detached for now for gradient_penalty calculation
|
||||
# else:
|
||||
# real_logits = discriminator(real_video.contiguous().detach())
|
||||
|
||||
fake_logits = discriminator(fake_video.contiguous().detach())
|
||||
# fake_logits = discriminator(fake_video.contiguous().detach())
|
||||
|
||||
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
|
||||
weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
|
||||
real_logits,
|
||||
fake_logits,
|
||||
cfg.discriminator_start + 1, # Hack to use discriminator
|
||||
lecam_ema_real=lecam_ema_real,
|
||||
lecam_ema_fake=lecam_ema_fake,
|
||||
real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None,
|
||||
)
|
||||
# lecam_ema_real, lecam_ema_fake = lecam_ema.get()
|
||||
# weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
|
||||
# real_logits,
|
||||
# fake_logits,
|
||||
# cfg.discriminator_start + 1, # Hack to use discriminator
|
||||
# lecam_ema_real=lecam_ema_real,
|
||||
# lecam_ema_fake=lecam_ema_fake,
|
||||
# real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None,
|
||||
# )
|
||||
|
||||
disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
|
||||
# disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
|
||||
|
||||
loss_steps += 1
|
||||
running_disc_loss = disc_loss.item() / loss_steps + running_disc_loss * ((loss_steps - 1) / loss_steps)
|
||||
# running_disc_loss = disc_loss.item() / loss_steps + running_disc_loss * ((loss_steps - 1) / loss_steps)
|
||||
running_loss = vae_loss.item() / loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
|
||||
running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
|
||||
|
||||
# ===== Spatial VAE =====
|
||||
|
||||
if coordinator.is_master():
|
||||
if cfg.get("use_pipeline") == True:
|
||||
for idx, (sample_original, sample_pipeline, sample_2d) in enumerate(
|
||||
zip(video, recon_video, recon_2d)
|
||||
):
|
||||
pos = step * cfg.batch_size + idx
|
||||
save_path = os.path.join(save_dir, f"sample_{pos}")
|
||||
save_sample(sample_original, fps=cfg.fps, save_path=save_path + "_original")
|
||||
save_sample(sample_2d, fps=cfg.fps, save_path=save_path + "_2d")
|
||||
save_sample(sample_pipeline, fps=cfg.fps, save_path=save_path + "_pipeline")
|
||||
if not use_dist or coordinator.is_master():
|
||||
for idx in range(len(x)):
|
||||
pos = step * cfg.batch_size + idx
|
||||
save_path = os.path.join(save_dir, f"sample_{pos}")
|
||||
save_sample(x[idx], fps=cfg.fps, save_path=save_path + "_original")
|
||||
save_sample(x_rec[idx], fps=cfg.fps, save_path=save_path + "_pipeline")
|
||||
if cfg.get("vae_2d", None) is not None:
|
||||
save_sample(x_z_debug[idx], fps=cfg.fps, save_path=save_path + "_2d")
|
||||
|
||||
else:
|
||||
for idx, (original, recon) in enumerate(zip(video, recon_video)):
|
||||
pos = step * cfg.batch_size + idx
|
||||
save_path = os.path.join(save_dir, f"sample_{pos}")
|
||||
save_sample(original, fps=cfg.fps, save_path=save_path + "_original")
|
||||
save_sample(recon, fps=cfg.fps, save_path=save_path + "_recon")
|
||||
# if cfg.get("use_pipeline") == True:
|
||||
# for idx, (sample_original, sample_pipeline, sample_2d) in enumerate(
|
||||
# zip(video, recon_video, recon_2d)
|
||||
# ):
|
||||
# pos = step * cfg.batch_size + idx
|
||||
# save_path = os.path.join(save_dir, f"sample_{pos}")
|
||||
# save_sample(sample_original, fps=cfg.fps, save_path=save_path + "_original")
|
||||
# save_sample(sample_2d, fps=cfg.fps, save_path=save_path + "_2d")
|
||||
# save_sample(sample_pipeline, fps=cfg.fps, save_path=save_path + "_pipeline")
|
||||
|
||||
# else:
|
||||
# for idx, (original, recon) in enumerate(zip(video, recon_video)):
|
||||
# pos = step * cfg.batch_size + idx
|
||||
# save_path = os.path.join(save_dir, f"sample_{pos}")
|
||||
# save_sample(original, fps=cfg.fps, save_path=save_path + "_original")
|
||||
# save_sample(recon, fps=cfg.fps, save_path=save_path + "_recon")
|
||||
|
||||
if cfg.calc_loss:
|
||||
print("test vae loss:", running_loss)
|
||||
print("test nll loss:", running_nll)
|
||||
print("test disc loss:", running_disc_loss)
|
||||
# print("test disc loss:", running_disc_loss)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import random
|
||||
from datetime import timedelta
|
||||
from pprint import pprint
|
||||
|
||||
|
|
@ -268,11 +269,14 @@ def main():
|
|||
) as pbar:
|
||||
for step, batch in pbar:
|
||||
x = batch["video"].to(device, dtype) # [B, C, T, H, W]
|
||||
if random.random() < 0.5:
|
||||
x = x[:, :, :1, :, :]
|
||||
|
||||
# ===== Spatial VAE =====
|
||||
if cfg.get("vae_2d", None) is not None:
|
||||
with torch.no_grad():
|
||||
x_z = vae_2d.encode(x)
|
||||
vae_2d.decode(x_z)
|
||||
|
||||
# ====== VAE ======
|
||||
x_z_rec, posterior, z = model(x_z)
|
||||
|
|
@ -281,7 +285,8 @@ def main():
|
|||
# ====== Generator Loss ======
|
||||
# simple nll loss
|
||||
_, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
|
||||
_, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior)
|
||||
# _, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior)
|
||||
# _, debug_loss, _ = vae_loss_fn(x, x_z_debug, posterior)
|
||||
# _, image_identity_loss, _ = vae_loss_fn(x_z, z, posterior)
|
||||
|
||||
# adversarial_loss = torch.tensor(0.0)
|
||||
|
|
@ -300,7 +305,10 @@ def main():
|
|||
|
||||
# vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss + weighted_z_nll_loss
|
||||
# vae_loss = weighted_nll_loss + weighted_kl_loss + weighted_z_nll_loss + image_identity_loss
|
||||
vae_loss = weighted_nll_loss + weighted_kl_loss + weighted_z_nll_loss
|
||||
# vae_loss = weighted_z_nll_loss + image_identity_loss
|
||||
# vae_loss = weighted_nll_loss + weighted_kl_loss + weighted_z_nll_loss
|
||||
# vae_loss = weighted_z_nll_loss
|
||||
vae_loss = weighted_nll_loss + weighted_kl_loss
|
||||
|
||||
optimizer.zero_grad()
|
||||
# Backward & update
|
||||
|
|
@ -391,8 +399,9 @@ def main():
|
|||
# "lecam_loss": lecam_loss.item(),
|
||||
# "r1_grad_penalty": gradient_penalty_loss.item(),
|
||||
"nll_loss": weighted_nll_loss.item(),
|
||||
"z_nll_loss": weighted_z_nll_loss.item(),
|
||||
# "z_nll_loss": weighted_z_nll_loss.item(),
|
||||
# "image_identity_loss": image_identity_loss.item(),
|
||||
# "debug_loss": debug_loss.item(),
|
||||
"avg_loss": avg_loss,
|
||||
},
|
||||
step=global_step,
|
||||
|
|
|
|||
Loading…
Reference in a new issue