mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
update config
This commit is contained in:
parent
98de78910d
commit
1171e5b6f9
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -181,3 +181,4 @@ cache/
|
|||
hostfile
|
||||
gradio_cached_examples/
|
||||
wandb/
|
||||
taming/
|
||||
|
|
|
|||
|
|
@ -36,44 +36,44 @@ vae_2d = dict(
|
|||
|
||||
model = dict(
|
||||
type="VAE_MAGVIT_V2",
|
||||
in_out_channels = 4,
|
||||
latent_embed_dim = 64,
|
||||
filters = 128,
|
||||
num_res_blocks = 4,
|
||||
channel_multipliers = (1, 2, 2, 4),
|
||||
temporal_downsample = (False, True, True),
|
||||
num_groups = 32, # for nn.GroupNorm
|
||||
kl_embed_dim = 4,
|
||||
activation_fn = 'swish',
|
||||
separate_first_frame_encoding = False,
|
||||
disable_space = True,
|
||||
custom_conv_padding = None,
|
||||
encoder_double_z = True,
|
||||
in_out_channels=4,
|
||||
latent_embed_dim=64,
|
||||
filters=128,
|
||||
num_res_blocks=4,
|
||||
channel_multipliers=(1, 2, 2, 4),
|
||||
temporal_downsample=(False, True, True),
|
||||
num_groups=32, # for nn.GroupNorm
|
||||
kl_embed_dim=4,
|
||||
activation_fn="swish",
|
||||
separate_first_frame_encoding=False,
|
||||
disable_space=True,
|
||||
custom_conv_padding=None,
|
||||
encoder_double_z=True,
|
||||
)
|
||||
|
||||
discriminator = dict(
|
||||
type="DISCRIMINATOR_3D",
|
||||
image_size = (128, 128),
|
||||
num_frames = num_frames,
|
||||
in_channels = 3,
|
||||
filters = 128,
|
||||
channel_multipliers = (2,4,4,4,4),
|
||||
image_size=(128, 128),
|
||||
num_frames=num_frames,
|
||||
in_channels=3,
|
||||
filters=128,
|
||||
channel_multipliers=(2, 4, 4, 4, 4),
|
||||
# channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
|
||||
)
|
||||
|
||||
|
||||
# loss weights
|
||||
logvar_init=0.0
|
||||
# loss weights
|
||||
logvar_init = 0.0
|
||||
kl_loss_weight = 0.000001
|
||||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
discriminator_factor = 1.0 # for discriminator adversarial loss
|
||||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
discriminator_factor = 1.0 # for discriminator adversarial loss
|
||||
# discriminator_loss_weight = 0.5 # for generator adversarial loss
|
||||
generator_factor = 0.1 # for generator adversarial loss
|
||||
lecam_loss_weight = None # NOTE: not clear in MAGVIT what is the weight
|
||||
discriminator_loss_type="non-saturating"
|
||||
generator_loss_type="non-saturating"
|
||||
discriminator_start = 2500 # 50000 NOTE: change to correct val, debug use -1 for now
|
||||
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
|
||||
generator_factor = 0.1 # for generator adversarial loss
|
||||
lecam_loss_weight = None # NOTE: not clear in MAGVIT what is the weight
|
||||
discriminator_loss_type = "non-saturating"
|
||||
generator_loss_type = "non-saturating"
|
||||
discriminator_start = 2500 # 50000 NOTE: change to correct val, debug use -1 for now
|
||||
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
|
||||
ema_decay = 0.999 # ema decay factor for generator
|
||||
|
||||
|
||||
|
|
@ -83,15 +83,15 @@ save_dir = "outputs/samples_pixabay_17"
|
|||
wandb = False
|
||||
|
||||
# Training
|
||||
''' NOTE:
|
||||
""" NOTE:
|
||||
magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
|
||||
==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
|
||||
==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
|
||||
3-6 epochs for pexel, from pexel observation its correct
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
batch_size = 1
|
||||
lr = 1e-4
|
||||
grad_clip = 1.0
|
||||
|
||||
calc_loss = True
|
||||
calc_loss = True
|
||||
|
|
|
|||
|
|
@ -1,23 +1,18 @@
|
|||
num_frames = 1
|
||||
|
||||
image_size = (256, 256)
|
||||
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=num_frames,
|
||||
frame_interval=3,
|
||||
frame_interval=1,
|
||||
image_size=image_size,
|
||||
get_text=False,
|
||||
)
|
||||
|
||||
fps = 24 // 3
|
||||
is_vae = True
|
||||
|
||||
# Define dataset
|
||||
max_test_samples = -1
|
||||
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
|
|
@ -33,67 +28,70 @@ video_contains_first_frame = True
|
|||
|
||||
vae_2d = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="stabilityai/sd-vae-ft-ema",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
type="VAE_MAGVIT_V2",
|
||||
in_out_channels = 4,
|
||||
latent_embed_dim = 64,
|
||||
filters = 128,
|
||||
num_res_blocks = 4,
|
||||
channel_multipliers = (1, 2, 2, 4),
|
||||
temporal_downsample = (False, True, True),
|
||||
num_groups = 32, # for nn.GroupNorm
|
||||
kl_embed_dim = 4,
|
||||
activation_fn = 'swish',
|
||||
separate_first_frame_encoding = False,
|
||||
disable_space = True,
|
||||
custom_conv_padding = None,
|
||||
encoder_double_z = True,
|
||||
in_out_channels=4,
|
||||
latent_embed_dim=64,
|
||||
filters=128,
|
||||
num_res_blocks=4,
|
||||
channel_multipliers=(1, 2, 2, 4),
|
||||
temporal_downsample=(False, True, True),
|
||||
num_groups=32, # for nn.GroupNorm
|
||||
kl_embed_dim=4,
|
||||
activation_fn="swish",
|
||||
separate_first_frame_encoding=False,
|
||||
disable_space=True,
|
||||
custom_conv_padding=None,
|
||||
encoder_double_z=True,
|
||||
)
|
||||
|
||||
discriminator = dict(
|
||||
type="DISCRIMINATOR_3D",
|
||||
image_size = image_size,
|
||||
num_frames = num_frames,
|
||||
in_channels = 3,
|
||||
filters = 128,
|
||||
channel_multipliers = (2,4,4,4,4),
|
||||
image_size=image_size,
|
||||
num_frames=num_frames,
|
||||
in_channels=3,
|
||||
filters=128,
|
||||
channel_multipliers=(2, 4, 4, 4, 4),
|
||||
# channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
|
||||
)
|
||||
|
||||
|
||||
# loss weights
|
||||
logvar_init=0.0
|
||||
# loss weights
|
||||
logvar_init = 0.0
|
||||
kl_loss_weight = 0.000001
|
||||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
discriminator_factor = 1.0 # for discriminator adversarial loss
|
||||
perceptual_loss_weight = 0.1 # use vgg is not None and more than 0
|
||||
discriminator_factor = 1.0 # for discriminator adversarial loss
|
||||
# discriminator_loss_weight = 0.5 # for generator adversarial loss
|
||||
generator_factor = 0.1 # for generator adversarial loss
|
||||
lecam_loss_weight = None # NOTE: not clear in MAGVIT what is the weight
|
||||
discriminator_loss_type="non-saturating"
|
||||
generator_loss_type="non-saturating"
|
||||
discriminator_start = 2500 # 50000 NOTE: change to correct val, debug use -1 for now
|
||||
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
|
||||
generator_factor = 0.1 # for generator adversarial loss
|
||||
lecam_loss_weight = None # NOTE: not clear in MAGVIT what is the weight
|
||||
discriminator_loss_type = "non-saturating"
|
||||
generator_loss_type = "non-saturating"
|
||||
discriminator_start = 2500 # 50000 NOTE: change to correct val, debug use -1 for now
|
||||
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
|
||||
ema_decay = 0.999 # ema decay factor for generator
|
||||
|
||||
|
||||
# Others
|
||||
seed = 42
|
||||
save_dir = "outputs/samples_pixabay_17"
|
||||
save_dir = "samples/samples_pixabay_17"
|
||||
wandb = False
|
||||
|
||||
# Training
|
||||
''' NOTE:
|
||||
""" NOTE:
|
||||
magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
|
||||
==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
|
||||
==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
|
||||
3-6 epochs for pexel, from pexel observation its correct
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
batch_size = 1
|
||||
lr = 1e-4
|
||||
grad_clip = 1.0
|
||||
|
||||
calc_loss = True
|
||||
calc_loss = True
|
||||
|
|
|
|||
|
|
@ -48,8 +48,8 @@ model = dict(
|
|||
activation_fn="swish",
|
||||
separate_first_frame_encoding=False,
|
||||
disable_space=True,
|
||||
encoder_double_z=True,
|
||||
custom_conv_padding=None,
|
||||
encoder_double_z=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -60,8 +60,8 @@ discriminator = dict(
|
|||
in_channels=3,
|
||||
filters=128,
|
||||
use_pretrained=True, # NOTE: set to False only if we want to disable load
|
||||
channel_multipliers = (2,4,4,4,4), # (2,4,4,4) for 64x64 resolution
|
||||
# channel_multipliers=(2, 4, 4), # since on intermediate layer dimension ofs z
|
||||
channel_multipliers=(2, 4, 4, 4, 4), # (2,4,4,4) for 64x64 resolution
|
||||
# channel_multipliers=(2, 4, 4), # since on intermediate layer dimension ofs z
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -76,7 +76,7 @@ discriminator_loss_type = "non-saturating"
|
|||
generator_loss_type = "non-saturating"
|
||||
# discriminator_loss_type="hinge"
|
||||
# generator_loss_type="hinge"
|
||||
discriminator_start = 2000 # 5000 # 8k data / (8*1) = 1000 steps per epoch
|
||||
discriminator_start = 2000 # 5000 # 8k data / (8*1) = 1000 steps per epoch
|
||||
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
|
||||
ema_decay = 0.999 # ema decay factor for generator
|
||||
|
||||
|
|
|
|||
|
|
@ -256,7 +256,7 @@ def create_logger(logging_dir):
|
|||
return logger
|
||||
|
||||
|
||||
def load_checkpoint(model, ckpt_path, save_as_pt=True, model_name="model"):
|
||||
def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model"):
|
||||
if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
|
||||
state_dict = find_model(ckpt_path, model=model)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
|
|
|
|||
|
|
@ -4,25 +4,16 @@ import colossalai
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from mmengine.runner import set_random_seed
|
||||
from colossalai.utils import get_current_device
|
||||
from einops import rearrange
|
||||
from tqdm import tqdm
|
||||
|
||||
from opensora.acceleration.parallel_states import set_sequence_parallel_group
|
||||
from opensora.datasets import save_sample
|
||||
from opensora.registry import MODELS, SCHEDULERS, build_module
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets import prepare_dataloader, save_sample
|
||||
from opensora.models.vae.vae_3d_v2 import AdversarialLoss, DiscriminatorLoss, LeCamEMA, VEALoss, pad_at_dim
|
||||
from opensora.registry import DATASETS, MODELS, build_module
|
||||
from opensora.utils.config_utils import parse_configs
|
||||
from opensora.utils.misc import to_torch_dtype
|
||||
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
|
||||
from opensora.registry import DATASETS, MODELS, build_module
|
||||
from opensora.acceleration.parallel_states import (
|
||||
get_data_parallel_group,
|
||||
set_data_parallel_group,
|
||||
set_sequence_parallel_group,
|
||||
)
|
||||
from tqdm import tqdm
|
||||
from opensora.models.vae.vae_3d_v2 import VEALoss, DiscriminatorLoss, AdversarialLoss, LeCamEMA, pad_at_dim
|
||||
|
||||
from einops import rearrange
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -49,9 +40,6 @@ def main():
|
|||
device = get_current_device()
|
||||
dtype = to_torch_dtype(cfg.dtype)
|
||||
|
||||
|
||||
|
||||
|
||||
# ======================================================
|
||||
# 3. build dataset and dataloader
|
||||
# ======================================================
|
||||
|
|
@ -101,43 +89,37 @@ def main():
|
|||
save_dir = cfg.save_dir
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
|
||||
|
||||
# 4.1. batch generation
|
||||
|
||||
|
||||
# define loss function
|
||||
if cfg.calc_loss:
|
||||
vae_loss_fn = VEALoss(
|
||||
logvar_init=cfg.logvar_init,
|
||||
perceptual_loss_weight = cfg.perceptual_loss_weight,
|
||||
kl_loss_weight = cfg.kl_loss_weight,
|
||||
perceptual_loss_weight=cfg.perceptual_loss_weight,
|
||||
kl_loss_weight=cfg.kl_loss_weight,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
adversarial_loss_fn = AdversarialLoss(
|
||||
discriminator_factor = cfg.discriminator_factor,
|
||||
discriminator_start = cfg.discriminator_start,
|
||||
generator_factor = cfg.generator_factor,
|
||||
generator_loss_type = cfg.generator_loss_type,
|
||||
discriminator_factor=cfg.discriminator_factor,
|
||||
discriminator_start=cfg.discriminator_start,
|
||||
generator_factor=cfg.generator_factor,
|
||||
generator_loss_type=cfg.generator_loss_type,
|
||||
)
|
||||
|
||||
disc_loss_fn = DiscriminatorLoss(
|
||||
discriminator_factor = cfg.discriminator_factor,
|
||||
discriminator_start = cfg.discriminator_start,
|
||||
discriminator_loss_type = cfg.discriminator_loss_type,
|
||||
lecam_loss_weight = cfg.lecam_loss_weight,
|
||||
gradient_penalty_loss_weight = cfg.gradient_penalty_loss_weight,
|
||||
)
|
||||
|
||||
# LeCam EMA for discriminator
|
||||
|
||||
lecam_ema = LeCamEMA(
|
||||
decay=cfg.ema_decay, dtype=dtype, device=device
|
||||
discriminator_factor=cfg.discriminator_factor,
|
||||
discriminator_start=cfg.discriminator_start,
|
||||
discriminator_loss_type=cfg.discriminator_loss_type,
|
||||
lecam_loss_weight=cfg.lecam_loss_weight,
|
||||
gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
|
||||
)
|
||||
|
||||
|
||||
|
||||
# LeCam EMA for discriminator
|
||||
|
||||
lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
|
||||
|
||||
running_loss = 0.0
|
||||
running_nll = 0.0
|
||||
running_disc_loss = 0.0
|
||||
|
|
@ -152,7 +134,7 @@ def main():
|
|||
|
||||
total_steps = len(dataloader)
|
||||
if cfg.max_test_samples > 0:
|
||||
total_steps = min(int(cfg.max_test_samples//cfg.batch_size), total_steps)
|
||||
total_steps = min(int(cfg.max_test_samples // cfg.batch_size), total_steps)
|
||||
print(f"limiting test dataset to {int(cfg.max_test_samples//cfg.batch_size) * cfg.batch_size}")
|
||||
dataloader_iter = iter(dataloader)
|
||||
|
||||
|
|
@ -169,7 +151,7 @@ def main():
|
|||
|
||||
is_image = x.ndim == 4
|
||||
if is_image:
|
||||
video = rearrange(x, 'b c ... -> b c 1 ...')
|
||||
video = rearrange(x, "b c ... -> b c 1 ...")
|
||||
video_contains_first_frame = True
|
||||
else:
|
||||
video = x
|
||||
|
|
@ -180,98 +162,88 @@ def main():
|
|||
video_enc_spatial = vae_2d.encode(video)
|
||||
|
||||
recon_dec_spatial, posterior = vae(
|
||||
video_enc_spatial,
|
||||
video_contains_first_frame = video_contains_first_frame
|
||||
video_enc_spatial, video_contains_first_frame=video_contains_first_frame
|
||||
)
|
||||
|
||||
recon_video = vae_2d.decode(recon_dec_spatial)
|
||||
recon_2d = vae_2d.decode(video_enc_spatial)
|
||||
|
||||
else:
|
||||
recon_video, posterior = vae(
|
||||
video,
|
||||
video_contains_first_frame = video_contains_first_frame
|
||||
)
|
||||
recon_video, posterior = vae(video, video_contains_first_frame=video_contains_first_frame)
|
||||
|
||||
if cfg.calc_loss:
|
||||
# ====== Calc Loss ======
|
||||
# simple nll loss
|
||||
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(
|
||||
video,
|
||||
recon_video,
|
||||
posterior,
|
||||
split = "eval"
|
||||
)
|
||||
nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(video, recon_video, posterior, split="eval")
|
||||
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
fake_logits = discriminator(fake_video.contiguous())
|
||||
adversarial_loss = adversarial_loss_fn(
|
||||
fake_logits,
|
||||
nll_loss,
|
||||
nll_loss,
|
||||
vae.get_last_layer(),
|
||||
cfg.discriminator_start+1, # Hack to use discriminator
|
||||
is_training = vae.training,
|
||||
cfg.discriminator_start + 1, # Hack to use discriminator
|
||||
is_training=vae.training,
|
||||
)
|
||||
|
||||
|
||||
vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
|
||||
|
||||
|
||||
# ====== Discriminator Loss ======
|
||||
real_video = pad_at_dim(video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value = 0., dim = 2)
|
||||
real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
|
||||
|
||||
if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
|
||||
real_video = real_video.requires_grad_()
|
||||
real_logits = discriminator(real_video.contiguous()) # SCH: not detached for now for gradient_penalty calculation
|
||||
real_logits = discriminator(
|
||||
real_video.contiguous()
|
||||
) # SCH: not detached for now for gradient_penalty calculation
|
||||
else:
|
||||
real_logits = discriminator(real_video.contiguous().detach())
|
||||
real_logits = discriminator(real_video.contiguous().detach())
|
||||
|
||||
fake_logits = discriminator(fake_video.contiguous().detach())
|
||||
|
||||
lecam_ema_real, lecam_ema_fake = lecam_ema.get()
|
||||
weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
|
||||
real_logits,
|
||||
fake_logits,
|
||||
cfg.discriminator_start+1, # Hack to use discriminator
|
||||
lecam_ema_real = lecam_ema_real,
|
||||
lecam_ema_fake = lecam_ema_fake,
|
||||
real_video = real_video if cfg.gradient_penalty_loss_weight is not None else None,
|
||||
real_logits,
|
||||
fake_logits,
|
||||
cfg.discriminator_start + 1, # Hack to use discriminator
|
||||
lecam_ema_real=lecam_ema_real,
|
||||
lecam_ema_fake=lecam_ema_fake,
|
||||
real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None,
|
||||
)
|
||||
|
||||
disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
|
||||
|
||||
loss_steps += 1
|
||||
running_disc_loss = disc_loss.item()/loss_steps + running_disc_loss * ((loss_steps - 1) / loss_steps)
|
||||
running_loss = vae_loss.item()/ loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
|
||||
running_disc_loss = disc_loss.item() / loss_steps + running_disc_loss * ((loss_steps - 1) / loss_steps)
|
||||
running_loss = vae_loss.item() / loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
|
||||
running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
|
||||
|
||||
|
||||
# ===== Spatial VAE =====
|
||||
|
||||
|
||||
if coordinator.is_master():
|
||||
if cfg.get("use_pipeline") == True:
|
||||
for idx, (sample_original, sample_pipeline, sample_2d) in enumerate(zip(video, recon_video, recon_2d)):
|
||||
for idx, (sample_original, sample_pipeline, sample_2d) in enumerate(
|
||||
zip(video, recon_video, recon_2d)
|
||||
):
|
||||
pos = step * cfg.batch_size + idx
|
||||
save_path = os.path.join(save_dir, f"sample_{pos}")
|
||||
save_sample(sample_original, fps=cfg.fps, save_path=save_path+"_original")
|
||||
save_sample(sample_2d, fps=cfg.fps, save_path=save_path+"_2d")
|
||||
save_sample(sample_pipeline, fps=cfg.fps, save_path=save_path+"_pipeline")
|
||||
save_sample(sample_original, fps=cfg.fps, save_path=save_path + "_original")
|
||||
save_sample(sample_2d, fps=cfg.fps, save_path=save_path + "_2d")
|
||||
save_sample(sample_pipeline, fps=cfg.fps, save_path=save_path + "_pipeline")
|
||||
|
||||
else:
|
||||
for idx, (original, recon) in enumerate(zip(video, recon_video)):
|
||||
pos = step * cfg.batch_size + idx
|
||||
save_path = os.path.join(save_dir, f"sample_{pos}")
|
||||
save_sample(original, fps=cfg.fps, save_path=save_path+"_original")
|
||||
save_sample(recon, fps=cfg.fps, save_path=save_path+"_recon")
|
||||
|
||||
|
||||
save_sample(original, fps=cfg.fps, save_path=save_path + "_original")
|
||||
save_sample(recon, fps=cfg.fps, save_path=save_path + "_recon")
|
||||
|
||||
if cfg.calc_loss:
|
||||
print("test vae loss:", running_loss)
|
||||
print("test nll loss:", running_nll)
|
||||
print("test disc loss:", running_disc_loss)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue