From ace37cf7c7d6971b0d6f0d4fcfc344240d35a570 Mon Sep 17 00:00:00 2001
From: zhengzangw
Date: Tue, 30 Apr 2024 08:13:20 +0000
Subject: [PATCH] refactor

---
 configs/vae/inference/17x256x256.py |  79 ----------
 configs/vae/inference/1x256x256.py  |  80 ----------
 configs/vae/inference/image.py      |  42 +++++
 configs/vae/inference/video.py      |  42 +++++
 configs/vae/train/17x256x256.py     |  74 ---------
 configs/vae/train/1x256x256.py      |  71 ---------
 configs/vae/train/image.py          |  58 +++++++
 configs/vae/train/video.py          |  58 +++++++
 opensora/models/vae/losses.py       |   6 +-
 opensora/models/vae/vae.py          |  55 ++++++-
 opensora/models/vae/vae_temporal.py |  54 ++-----
 opensora/utils/config_utils.py      |   4 +
 scripts/inference-vae.py            | 186 ++++------------------
 scripts/train-vae.py                | 233 ++++------------------
 14 files changed, 328 insertions(+), 714 deletions(-)
 delete mode 100644 configs/vae/inference/17x256x256.py
 delete mode 100644 configs/vae/inference/1x256x256.py
 create mode 100644 configs/vae/inference/image.py
 create mode 100644 configs/vae/inference/video.py
 delete mode 100644 configs/vae/train/17x256x256.py
 delete mode 100644 configs/vae/train/1x256x256.py
 create mode 100644 configs/vae/train/image.py
 create mode 100644 configs/vae/train/video.py

diff --git a/configs/vae/inference/17x256x256.py b/configs/vae/inference/17x256x256.py
deleted file mode 100644
index 25c9ff0..0000000
--- a/configs/vae/inference/17x256x256.py
+++ /dev/null
@@ -1,79 +0,0 @@
-num_frames = 16
-image_size = (256, 256)
-fps = 24 // 3
-max_test_samples = None
-
-# Define dataset
-dataset = dict(
-    type="VideoTextDataset",
-    data_path=None,
-    num_frames=num_frames,
-    frame_interval=1,
-    image_size=image_size,
-)
-
-# Define acceleration
-num_workers = 4
-dtype = "bf16"
-grad_checkpoint = True
-plugin = "zero2"
-sp_size = 1
-
-
-# Define model
-vae_2d = dict(
-    type="VideoAutoencoderKL",
-    from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
-    subfolder="vae",
-    micro_batch_size=4,
-    local_files_only=True,
-)
-
-model = dict(
-    type="VAE_Temporal_SD",
-)
-
-# discriminator = dict(
-#     type="DISCRIMINATOR_3D",
-#     image_size=image_size,
-#     num_frames=num_frames,
-#     in_channels=3,
-#     filters=128,
-#     channel_multipliers=(2, 4, 4, 4, 4),
-#     # channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
-# )
-
-
-# loss weights
-logvar_init = 0.0
-kl_loss_weight = 0.000001
-perceptual_loss_weight = 0.1  # use vgg is not None and more than 0
-discriminator_factor = 1.0  # for discriminator adversarial loss
-# discriminator_loss_weight = 0.5 # for generator adversarial loss
-generator_factor = 0.1  # for generator adversarial loss
-lecam_loss_weight = None  # NOTE: not clear in MAGVIT what is the weight
-discriminator_loss_type = "non-saturating"
-generator_loss_type = "non-saturating"
-discriminator_start = 2500  # 50000 NOTE: change to correct val, debug use -1 for now
-gradient_penalty_loss_weight = None  # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
-ema_decay = 0.999  # ema decay factor for generator
-
-
-# Others
-seed = 42
-save_dir = "samples/samples_vae"
-wandb = False
-
-# Training
-""" NOTE:
-magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
-==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
-3-6 epochs for pexel, from pexel observation its correct
-"""
-
-
-batch_size = 1
-lr = 1e-4
-grad_clip = 1.0
-
-calc_loss = True
diff --git a/configs/vae/inference/1x256x256.py b/configs/vae/inference/1x256x256.py
deleted file mode 100644
index e9f8a5e..0000000
--- a/configs/vae/inference/1x256x256.py
+++ /dev/null
@@ -1,80 +0,0 @@
-num_frames = 1
-# image_size = (256, 256)
-image_size = (1024, 1024)
-fps = 24 // 3
-max_test_samples = None
-
-# Define dataset
-dataset = dict(
-    type="VideoTextDataset",
-    data_path=None,
-    num_frames=num_frames,
-    frame_interval=1,
-    image_size=image_size,
-)
-
-# Define acceleration
-num_workers = 4
-dtype = "bf16"
-grad_checkpoint = True
-plugin = "zero2"
-sp_size = 1
-
-
-# Define model
-vae_2d = dict(
-    type="VideoAutoencoderKL",
-    from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
-    subfolder="vae",
-    micro_batch_size=4,
-    local_files_only=True,
-)
-
-model = dict(
-    type="VAE_Temporal_SD",
-)
-
-# discriminator = dict(
-#     type="DISCRIMINATOR_3D",
-#     image_size=image_size,
-#     num_frames=num_frames,
-#     in_channels=3,
-#     filters=128,
-#     channel_multipliers=(2, 4, 4, 4, 4),
-#     # channel_multipliers = (2,4,4), #(2,4,4,4,4) # (2,4,4,4) for 64x64 resolution
-# )
-
-
-# loss weights
-logvar_init = 0.0
-kl_loss_weight = 0.000001
-perceptual_loss_weight = 0.1  # use vgg is not None and more than 0
-discriminator_factor = 1.0  # for discriminator adversarial loss
-# discriminator_loss_weight = 0.5 # for generator adversarial loss
-generator_factor = 0.1  # for generator adversarial loss
-lecam_loss_weight = None  # NOTE: not clear in MAGVIT what is the weight
-discriminator_loss_type = "non-saturating"
-generator_loss_type = "non-saturating"
-discriminator_start = 2500  # 50000 NOTE: change to correct val, debug use -1 for now
-gradient_penalty_loss_weight = None  # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
-ema_decay = 0.999  # ema decay factor for generator
-
-
-# Others
-seed = 42
-save_dir = "samples/samples_vae"
-wandb = False
-
-# Training
-""" NOTE:
-magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
-==> ours num_frams = 16, reso = 256, so samples (K) * epochs ~ [500 - 1200],
-3-6 epochs for pexel, from pexel observation its correct
-"""
-
-
-batch_size = 1
-lr = 1e-4
-grad_clip = 1.0
-
-calc_loss = True
diff --git a/configs/vae/inference/image.py b/configs/vae/inference/image.py
new file mode 100644
index 0000000..d83348a
--- /dev/null
+++ b/configs/vae/inference/image.py
@@ -0,0 +1,42 @@
+num_frames = 1
+frame_interval = 1
+fps = 24
+image_size = (256, 256)
+
+# Define dataset
+dataset = dict(
+    type="VideoTextDataset",
+    data_path=None,
+    num_frames=num_frames,
+    frame_interval=1,
+    image_size=image_size,
+)
+num_workers = 4
+max_test_samples = None
+
+# Define model
+model = dict(
+    type="VideoAutoencoderPipeline",
+    freeze_vae_2d=True,
+    vae_2d=dict(
+        type="VideoAutoencoderKL",
+        from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
+        subfolder="vae",
+        micro_batch_size=4,
+        local_files_only=True,
+    ),
+    vae_temporal=dict(
+        type="VAE_Temporal_SD",
+        from_pretrained=None,
+    ),
+)
+dtype = "bf16"
+
+# loss weights
+perceptual_loss_weight = 0.1  # VGG perceptual loss, used only when not None and > 0
+kl_loss_weight = 1e-6
+
+# Others
+batch_size = 1
+seed = 42
+save_dir = "samples/vae_image"
diff --git a/configs/vae/inference/video.py b/configs/vae/inference/video.py
new file mode 100644
index 0000000..70fc49d
--- /dev/null
+++ b/configs/vae/inference/video.py
@@ -0,0 +1,42 @@
+num_frames = 17
+frame_interval = 1
+fps = 24
+image_size = (256, 256)
+
+# Define dataset
+dataset = dict(
+    type="VideoTextDataset",
+    data_path=None,
+    num_frames=num_frames,
+    frame_interval=1,
+    image_size=image_size,
+)
+num_workers = 4
+max_test_samples = None
+
+# Define model
+model = dict(
+    type="VideoAutoencoderPipeline",
+    from_pretrained=None,
+    vae_2d=dict(
+        type="VideoAutoencoderKL",
+        from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
+        subfolder="vae",
+        micro_batch_size=4,
+        local_files_only=True,
+    ),
+    vae_temporal=dict(
+        type="VAE_Temporal_SD",
+        from_pretrained=None,
+    ),
+)
+dtype = "bf16"
+
+# loss weights
+perceptual_loss_weight = 0.1  # VGG perceptual loss, used only when not None and > 0
+kl_loss_weight = 1e-6
+
+# Others
+batch_size = 1
+seed = 42
+save_dir = "samples/vae_video"
diff --git a/configs/vae/train/17x256x256.py b/configs/vae/train/17x256x256.py
deleted file mode 100644
index 46e3157..0000000
--- a/configs/vae/train/17x256x256.py
+++ /dev/null
@@ -1,74 +0,0 @@
-num_frames = 17
-image_size = (256, 256)
-
-# Define dataset
-dataset = dict(
-    type="VideoTextDataset",
-    data_path=None,
-    num_frames=num_frames,
-    frame_interval=1,
-    image_size=image_size,
-)
-
-# Define acceleration
-num_workers = 16
-dtype = "bf16"
-grad_checkpoint = True
-plugin = "zero2"
-sp_size = 1
-
-# latest
-vae_2d = dict(
-    type="VideoAutoencoderKL",
-    from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
-    subfolder="vae",
-    micro_batch_size=4,
-    local_files_only=True,
-)
-
-model = dict(
-    type="VAE_Temporal_SD",
-)
-
-
-# discriminator = dict(
-#     type="DISCRIMINATOR_3D",
-#     image_size=image_size,  # NOTE: here image size is different
-#     num_frames=num_frames,
-#     in_channels=3,
-#     filters=128,
-#     use_pretrained=True,  # NOTE: set to False only if we want to disable load
-#     channel_multipliers=(2, 4, 4, 4, 4),  # (2,4,4,4) for 64x64 resolution
-#     # channel_multipliers=(2, 4, 4),  # since on intermediate layer dimension ofs z
-# )
-
-
-# loss weights
-logvar_init = 0.0
-kl_loss_weight = 0.000001
-perceptual_loss_weight = 0.1  # use vgg is not None and more than 0
-discriminator_factor = 1.0  # for discriminator adversarial loss
-generator_factor = 0.1  # SCH: generator adversarial loss, MAGVIT v2 uses 0.1
-lecam_loss_weight = None  # NOTE: MAVGIT v2 use 0.001
-discriminator_loss_type = "non-saturating"
-generator_loss_type = "non-saturating"
-# discriminator_loss_type="hinge"
-# generator_loss_type="hinge"
-discriminator_start = 2000  # 5000 # 8k data / (8*1) = 1000 steps per epoch
-gradient_penalty_loss_weight = None  # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
-ema_decay = 0.999  # ema decay factor for generator
-
-
-# Others
-seed = 42
-outputs = "outputs"
-wandb = False
-
-epochs = 100
-log_every = 1
-ckpt_every = 1000
-load = None
-
-batch_size = 1
-lr = 1e-5
-grad_clip = 1.0
diff --git a/configs/vae/train/1x256x256.py b/configs/vae/train/1x256x256.py
deleted file mode 100644
index d48df09..0000000
--- a/configs/vae/train/1x256x256.py
+++ /dev/null
@@ -1,71 +0,0 @@
-num_frames = 1
-image_size = (256, 256)
-
-# Define dataset
-dataset = dict(
-    type="VideoTextDataset",
-    data_path=None,
-    num_frames=num_frames,
-    frame_interval=1,
-    image_size=image_size,
-)
-
-# Define acceleration
-num_workers = 16
-dtype = "bf16"
-grad_checkpoint = False
-plugin = "zero2"
-sp_size = 1
-
-# latest
-vae_2d = dict(
-    type="VideoAutoencoderKL",
-    from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
-    subfolder="vae",
-    micro_batch_size=4,
-    local_files_only=True,
-)
-
-model = dict(
-    type="VAE_Temporal_SD",
-)
-
-# discriminator = dict(
-#     type="DISCRIMINATOR_3D",
-#     image_size=image_size,  # NOTE: here image size is different
-#     num_frames=num_frames,
-#     in_channels=3,
-#     filters=128,
-#     use_pretrained=True,  # NOTE: set to False only if we want to disable load
-#     channel_multipliers=(2, 4, 4, 4, 4),  # (2,4,4,4) for 64x64 resolution
-#     # channel_multipliers=(2, 4, 4),  # since on intermediate layer dimension ofs z
-# )
-
-
-# loss weights
-logvar_init = 0.0
-kl_loss_weight = 0.000001
-perceptual_loss_weight = 0.1  # use vgg is not None and more than 0
-discriminator_factor = 1.0  # for discriminator adversarial loss
-generator_factor = 0.1  # generator adversarial loss, MAGVIT v2 uses 0.1
-lecam_loss_weight = None  # NOTE: MAVGIT v2 use 0.001
-discriminator_loss_type = "non-saturating"
-generator_loss_type = "non-saturating"
-discriminator_start = 2000
-gradient_penalty_loss_weight = None  # MAGVIT uses 10, opensora plan doesn't use
-ema_decay = 0.999  # ema decay factor for generator
-
-
-# Others
-seed = 42
-outputs = "outputs"
-wandb = False
-
-epochs = 100
-log_every = 1
-ckpt_every = 1000
-load = None
-
-batch_size = 4
-lr = 1e-5
-grad_clip = 1.0
diff --git a/configs/vae/train/image.py b/configs/vae/train/image.py
new file mode 100644
index 0000000..8541c25
--- /dev/null
+++ b/configs/vae/train/image.py
@@ -0,0 +1,58 @@
+num_frames = 1
+image_size = (256, 256)
+
+# Define dataset
+dataset = dict(
+    type="VideoTextDataset",
+    data_path=None,
+    num_frames=num_frames,
+    frame_interval=1,
+    image_size=image_size,
+)
+
+# Define acceleration
+num_workers = 16
+dtype = "bf16"
+grad_checkpoint = True
+plugin = "zero2"
+
+# Define model
+model = dict(
+    type="VideoAutoencoderPipeline",
+    freeze_vae_2d=True,
+    from_pretrained=None,
+    vae_2d=dict(
+        type="VideoAutoencoderKL",
+        from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
+        subfolder="vae",
+        micro_batch_size=4,
+        local_files_only=True,
+    ),
+    vae_temporal=dict(
+        type="VAE_Temporal_SD",
+        from_pretrained=None,
+    ),
+)
+
+# loss weights
+perceptual_loss_weight = 0.1  # VGG perceptual loss, used only when not None and > 0
+kl_loss_weight = 1e-6
+
+mixed_image_ratio = 0.1
+use_real_rec_loss = False
+use_z_rec_loss = True
+use_image_identity_loss = True
+
+# Others
+seed = 42
+outputs = "outputs"
+wandb = False
+
+epochs = 100
+log_every = 1
+ckpt_every = 1000
+load = None
+
+batch_size = 1
+lr = 1e-5
+grad_clip = 1.0
diff --git a/configs/vae/train/video.py b/configs/vae/train/video.py
new file mode 100644
index 0000000..453a226
--- /dev/null
+++ b/configs/vae/train/video.py
@@ -0,0 +1,58 @@
+num_frames = 17
+image_size = (256, 256)
+
+# Define dataset
+dataset = dict(
+    type="VideoTextDataset",
+    data_path=None,
+    num_frames=num_frames,
+    frame_interval=1,
+    image_size=image_size,
+)
+
+# Define acceleration
+num_workers = 16
+dtype = "bf16"
+grad_checkpoint = True
+plugin = "zero2"
+
+# Define model
+model = dict(
+    type="VideoAutoencoderPipeline",
+    freeze_vae_2d=True,
+    from_pretrained=None,
+    vae_2d=dict(
+        type="VideoAutoencoderKL",
+        from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
+        subfolder="vae",
+        micro_batch_size=4,
+        local_files_only=True,
+    ),
+    vae_temporal=dict(
+        type="VAE_Temporal_SD",
+        from_pretrained=None,
+    ),
+)
+
+# loss weights
+perceptual_loss_weight = 0.1  # VGG perceptual loss, used only when not None and > 0
+kl_loss_weight = 1e-6
+
+mixed_image_ratio = 0.1
+use_real_rec_loss = True
+use_z_rec_loss = False
+use_image_identity_loss = False
+
+# Others
+seed = 42
+outputs = "outputs"
+wandb = False
+
+epochs = 100
+log_every = 1
+ckpt_every = 1000
+load = None
+
+batch_size = 1
+lr = 1e-5
+grad_clip = 1.0
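Note on the four replacement configs above: they are plain Python modules whose top-level assignments become config fields, and the train/image vs. train/video pair differ only in num_frames and the three use_*_loss switches. A minimal sketch of consuming one, assuming mmengine's Config (mmengine is already a dependency of the scripts below; the repo wraps loading in parse_configs):

    from mmengine.config import Config

    # Field names come from configs/vae/train/video.py added above;
    # Config.fromfile executes the file and exposes its globals as fields.
    cfg = Config.fromfile("configs/vae/train/video.py")
    assert cfg.model["type"] == "VideoAutoencoderPipeline"
    assert cfg.dataset["num_frames"] == 17
    assert cfg.use_real_rec_loss and not cfg.use_z_rec_loss
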
diff --git a/opensora/models/vae/losses.py b/opensora/models/vae/losses.py
index 68d555f..e87403c 100644
--- a/opensora/models/vae/losses.py
+++ b/opensora/models/vae/losses.py
@@ -94,13 +94,10 @@ class VAELoss(nn.Module):
         if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
             # handle channels
             channels = video.shape[1]
-            assert channels in {1, 3, 4}
+            assert channels in {1, 3}
             if channels == 1:
                 input_vgg_input = repeat(video, "b 1 h w -> b c h w", c=3)
                 recon_vgg_input = repeat(recon_video, "b 1 h w -> b c h w", c=3)
-            elif channels == 4:  # SCH: take the first 3 for perceptual loss calc
-                input_vgg_input = video[:, :3]
-                recon_vgg_input = recon_video[:, :3]
             else:
                 input_vgg_input = video
                 recon_vgg_input = recon_video
@@ -109,6 +106,7 @@
             recon_loss = recon_loss + self.perceptual_loss_weight * perceptual_loss

         nll_loss = recon_loss / torch.exp(self.logvar) + self.logvar
+        weighted_nll_loss = nll_loss
         if nll_weights is not None:
             weighted_nll_loss = nll_weights * nll_loss
diff --git a/opensora/models/vae/vae.py b/opensora/models/vae/vae.py
index 277ae51..c54aa60 100644
--- a/opensora/models/vae/vae.py
+++ b/opensora/models/vae/vae.py
@@ -3,15 +3,20 @@
 import torch.nn as nn
 from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
 from einops import rearrange

-from opensora.registry import MODELS
+from opensora.registry import MODELS, build_module
+from opensora.utils.ckpt_utils import load_checkpoint


 @MODELS.register_module()
 class VideoAutoencoderKL(nn.Module):
-    def __init__(self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False, subfolder=None):
+    def __init__(
+        self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False, subfolder=None
+    ):
         super().__init__()
         self.module = AutoencoderKL.from_pretrained(
-            from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only,
+            from_pretrained,
+            cache_dir=cache_dir,
+            local_files_only=local_files_only,
             subfolder=subfolder,
         )
         self.out_channels = self.module.config.latent_channels
@@ -107,3 +112,47 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module):
     @property
     def dtype(self):
         return next(self.parameters()).dtype
+
+
+@MODELS.register_module()
+class VideoAutoencoderPipeline(nn.Module):
+    def __init__(self, vae_2d=None, vae_temporal=None, freeze_vae_2d=False, from_pretrained=None):
+        super().__init__()
+        self.spatial_vae = build_module(vae_2d, MODELS)
+        self.temporal_vae = build_module(vae_temporal, MODELS)
+        if from_pretrained is not None:
+            load_checkpoint(self, from_pretrained)
+        if freeze_vae_2d:
+            for param in self.spatial_vae.parameters():
+                param.requires_grad = False
+
+    def encode(self, x, training=True):
+        x_z = self.spatial_vae.encode(x)
+        posterior = self.temporal_vae.encode(x_z)
+        z = posterior.sample()
+        if training:
+            return z, posterior, x_z
+        return z
+
+    def decode(self, z, num_frames=None, training=True):
+        x_z = self.temporal_vae.decode(z, num_frames=num_frames)
+        x = self.spatial_vae.decode(x_z)
+        if training:
+            return x, x_z
+        return x
+
+    def forward(self, x):
+        z, posterior, x_z = self.encode(x, training=True)
+        x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2], training=True)
+        return x_rec, x_z_rec, z, posterior, x_z
+
+    def get_latent_size(self, input_size):
+        return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size))
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
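For orientation, the new VideoAutoencoderPipeline chains the (optionally frozen) 2D spatial VAE with the trainable temporal VAE, and the training flag controls whether intermediate latents are returned. A hedged usage sketch, not part of the patch (tensor sizes illustrative; assumes the pretrained 2D VAE weights are downloadable and VAE_Temporal_SD is registered as in vae_temporal.py):

    import torch
    from opensora.registry import MODELS, build_module

    # Mirrors the model dicts used by the new configs above.
    cfg = dict(
        type="VideoAutoencoderPipeline",
        freeze_vae_2d=True,
        vae_2d=dict(
            type="VideoAutoencoderKL",
            from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
            subfolder="vae",
            micro_batch_size=4,
        ),
        vae_temporal=dict(type="VAE_Temporal_SD", from_pretrained=None),
    )
    model = build_module(cfg, MODELS).eval()

    x = torch.randn(1, 3, 17, 256, 256)  # [B, C, T, H, W]
    with torch.no_grad():
        z = model.encode(x, training=False)  # training=False returns only the sampled latent
        x_rec = model.decode(z, num_frames=x.size(2), training=False)
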
diff --git a/opensora/models/vae/vae_temporal.py b/opensora/models/vae/vae_temporal.py
index 1c34924..3e8b399 100644
--- a/opensora/models/vae/vae_temporal.py
+++ b/opensora/models/vae/vae_temporal.py
@@ -1,7 +1,5 @@
-import functools
 from typing import Tuple, Union

-import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
@@ -82,8 +80,6 @@ class ResBlock(nn.Module):
         activation_fn=nn.SiLU,
         use_conv_shortcut=False,
         num_groups=32,
-        device="cpu",
-        dtype=torch.bfloat16,
     ):
         super().__init__()
         self.in_channels = in_channels
@@ -92,9 +88,9 @@ class ResBlock(nn.Module):
         self.use_conv_shortcut = use_conv_shortcut

         # SCH: MAGVIT uses GroupNorm by default
-        self.norm1 = nn.GroupNorm(num_groups, in_channels, device=device, dtype=dtype)
+        self.norm1 = nn.GroupNorm(num_groups, in_channels)
         self.conv1 = conv_fn(in_channels, self.filters, kernel_size=(3, 3, 3), bias=False)
-        self.norm2 = nn.GroupNorm(num_groups, self.filters, device=device, dtype=dtype)
+        self.norm2 = nn.GroupNorm(num_groups, self.filters)
         self.conv2 = conv_fn(self.filters, self.filters, kernel_size=(3, 3, 3), bias=False)
         if in_channels != filters:
             if self.use_conv_shortcut:
@@ -103,8 +99,6 @@
             self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(1, 1, 1), bias=False)

     def forward(self, x):
-        # device, dtype = x.device, x.dtype
-        # input_dim = x.shape[1]
         residual = x
         x = self.norm1(x)
         x = self.activate(x)
@@ -140,8 +134,6 @@ class Encoder(nn.Module):
         temporal_downsample=(False, True, True),
         num_groups=32,  # for nn.GroupNorm
         activation_fn="swish",
-        device="cpu",
-        dtype=torch.bfloat16,
     ):
         super().__init__()
         self.filters = filters
@@ -154,18 +146,12 @@
         self.activation_fn = get_activation_fn(activation_fn)
         self.activate = self.activation_fn()

-        self.conv_fn = functools.partial(
-            CausalConv3d,
-            dtype=dtype,
-            device=device,
-        )
+        self.conv_fn = CausalConv3d
         self.block_args = dict(
             conv_fn=self.conv_fn,
-            dtype=dtype,
             activation_fn=self.activation_fn,
             use_conv_shortcut=False,
             num_groups=self.num_groups,
-            device=device,
         )

         # first layer conv
@@ -174,8 +160,6 @@
             filters,
             kernel_size=(3, 3, 3),
             bias=False,
-            dtype=dtype,
-            device=device,
         )

         # ResBlocks and conv downsample
@@ -214,13 +198,9 @@
             prev_filters = filters  # update in_channels

         # MAGVIT uses Group Normalization
-        self.norm1 = nn.GroupNorm(
-            self.num_groups, prev_filters, dtype=dtype, device=device
-        )  # separate channels into 32 groups
+        self.norm1 = nn.GroupNorm(self.num_groups, prev_filters)

-        self.conv2 = self.conv_fn(
-            prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), dtype=dtype, device=device, padding="same"
-        )
+        self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), padding="same")

     def forward(self, x):
         x = self.conv_in(x)
@@ -252,8 +232,6 @@ class Decoder(nn.Module):
         temporal_downsample=(False, True, True),
         num_groups=32,  # for nn.GroupNorm
         activation_fn="swish",
-        device="cpu",
-        dtype=torch.bfloat16,
     ):
         super().__init__()
         self.filters = filters
@@ -267,18 +245,12 @@
         self.activation_fn = get_activation_fn(activation_fn)
         self.activate = self.activation_fn()

-        self.conv_fn = functools.partial(
-            CausalConv3d,
-            dtype=dtype,
-            device=device,
-        )
+        self.conv_fn = CausalConv3d
         self.block_args = dict(
             conv_fn=self.conv_fn,
             activation_fn=self.activation_fn,
             use_conv_shortcut=False,
             num_groups=self.num_groups,
-            device=device,
-            dtype=dtype,
         )

         filters = self.filters * self.channel_multipliers[-1]
@@ -323,9 +295,9 @@
                 nn.Identity(prev_filters),
             )
-        self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, device=device, dtype=dtype)
+        self.norm1 = nn.GroupNorm(self.num_groups, prev_filters)

-        self.conv_out = self.conv_fn(filters, in_out_channels, 3, dtype=dtype, device=device)
+        self.conv_out = self.conv_fn(filters, in_out_channels, 3)

     def forward(self, x):
         x = self.conv1(x)
@@ -364,8 +336,6 @@ class VAE_Temporal(nn.Module):
         temporal_downsample=(True, True, False),
         num_groups=32,  # for nn.GroupNorm
         activation_fn="swish",
-        device="cpu",
-        dtype=torch.bfloat16,
     ):
         super().__init__()

@@ -383,12 +353,10 @@
             temporal_downsample=temporal_downsample,
             num_groups=num_groups,  # for nn.GroupNorm
             activation_fn=activation_fn,
-            device=device,
-            dtype=dtype,
         )
-        self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1, device=device, dtype=dtype)
+        self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)

-        self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1, device=device, dtype=dtype)
+        self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)
         self.decoder = Decoder(
             in_out_channels=in_out_channels,
             latent_embed_dim=latent_embed_dim,
@@ -398,8 +366,6 @@
             temporal_downsample=temporal_downsample,
             num_groups=num_groups,  # for nn.GroupNorm
             activation_fn=activation_fn,
-            device=device,
-            dtype=dtype,
         )

     def get_latent_size(self, input_size):
diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py
index 96638ef..8b9bf65 100644
--- a/opensora/utils/config_utils.py
+++ b/opensora/utils/config_utils.py
@@ -79,6 +79,10 @@ def merge_args(cfg, args, training=False):
     if args.data_path is not None:
         cfg.dataset["data_path"] = args.data_path
         args.data_path = None
+    if not training and args.image_size is not None and "dataset" in cfg:
+        cfg.dataset["image_size"] = args.image_size
+    if not training and args.num_frames is not None and "dataset" in cfg:
+        cfg.dataset["num_frames"] = args.num_frames
     if not training and args.cfg_scale is not None:
         cfg.scheduler["cfg_scale"] = args.cfg_scale
         args.cfg_scale = None
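The two merge_args branches added above are inference-only overrides: a CLI-supplied image size or frame count now wins over whatever the config file sets, which is what makes the single image.py/video.py configs reusable across resolutions. In effect (a behavioral sketch with a stand-in args namespace, not the script's real argument parser):

    from types import SimpleNamespace

    dataset = {"image_size": (256, 256), "num_frames": 17}  # as in the new configs
    args = SimpleNamespace(image_size=(512, 512), num_frames=33)
    training = False

    if not training and args.image_size is not None:
        dataset["image_size"] = args.image_size  # CLI value replaces config value
    if not training and args.num_frames is not None:
        dataset["num_frames"] = args.num_frames

    assert dataset == {"image_size": (512, 512), "num_frames": 33}
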
diff --git a/scripts/inference-vae.py b/scripts/inference-vae.py
index 4d42533..fc9a952 100644
--- a/scripts/inference-vae.py
+++ b/scripts/inference-vae.py
@@ -7,7 +7,7 @@
 from colossalai.cluster import DistCoordinator
 from mmengine.runner import set_random_seed
 from tqdm import tqdm

-from opensora.acceleration.parallel_states import get_data_parallel_group, set_sequence_parallel_group
+from opensora.acceleration.parallel_states import get_data_parallel_group
 from opensora.datasets import prepare_dataloader, save_sample
 from opensora.models.vae.losses import VAELoss
 from opensora.registry import DATASETS, MODELS, build_module
@@ -27,11 +27,6 @@ def main():
         use_dist = True
         colossalai.launch_from_torch({})
         coordinator = DistCoordinator()
-
-        if coordinator.world_size > 1:
-            set_sequence_parallel_group(dist.group.WORLD)
-        else:
-            pass
     else:
         use_dist = False

@@ -59,88 +54,40 @@
         process_group=get_data_parallel_group(),
     )
     print(f"Dataset contains {len(dataset):,} videos ({cfg.dataset.data_path})")
-    total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
+    total_batch_size = cfg.batch_size * dist.get_world_size()
    print(f"Total batch size: {total_batch_size}")

     # ======================================================
     # 4. build model & load weights
     # ======================================================
-    # 3.1. build model
-    if cfg.get("vae_2d", None) is not None:
-        vae_2d = build_module(cfg.vae_2d, MODELS)
-        vae_2d.to(device, dtype).eval()
-    model = build_module(
-        cfg.model,
-        MODELS,
-        device=device,
-        dtype=dtype,
-    )
-    # discriminator = build_module(cfg.discriminator, MODELS, device=device)
-
-    # 3.2. move to device & eval
-    # discriminator = discriminator.to(device, dtype).eval()
-
-    # 3.4. support for multi-resolution
-    # model_args = dict()
-    # if cfg.multi_resolution:
-    #     image_size = cfg.dataset.image_size
-    #     hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
-    #     ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
-    #     model_args["data_info"] = dict(ar=ar, hw=hw)
+    # 4.1. build model
+    model = build_module(cfg.model, MODELS)
+    model.to(device, dtype).eval()

     # ======================================================
-    # 4. inference
+    # 5. inference
     # ======================================================
     save_dir = cfg.save_dir
     os.makedirs(save_dir, exist_ok=True)

-    # 4.1. batch generation
-    # define loss function
-    if cfg.calc_loss:
-        vae_loss_fn = VAELoss(
-            logvar_init=cfg.logvar_init,
-            perceptual_loss_weight=cfg.perceptual_loss_weight,
-            kl_loss_weight=cfg.kl_loss_weight,
-            device=device,
-            dtype=dtype,
-        )
-
-    # adversarial_loss_fn = AdversarialLoss(
-    #     discriminator_factor=cfg.discriminator_factor,
-    #     discriminator_start=cfg.discriminator_start,
-    #     generator_factor=cfg.generator_factor,
-    #     generator_loss_type=cfg.generator_loss_type,
-    # )
-
-    # disc_loss_fn = DiscriminatorLoss(
-    #     discriminator_factor=cfg.discriminator_factor,
-    #     discriminator_start=cfg.discriminator_start,
-    #     discriminator_loss_type=cfg.discriminator_loss_type,
-    #     lecam_loss_weight=cfg.lecam_loss_weight,
-    #     gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
-    # )
-
-    # # LeCam EMA for discriminator
-
-    # lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
-
-    running_loss = 0.0
-    running_nll = 0.0
-    loss_steps = 0
-
-    # disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
-    # if cfg.dataset.num_frames % disc_time_downsample_factor != 0:
-    #     disc_time_padding = disc_time_downsample_factor - cfg.dataset.num_frames % disc_time_downsample_factor
-    # else:
-    #     disc_time_padding = 0
+    vae_loss_fn = VAELoss(
+        logvar_init=cfg.get("logvar_init", 0.0),
+        perceptual_loss_weight=cfg.perceptual_loss_weight,
+        kl_loss_weight=cfg.kl_loss_weight,
+        device=device,
+        dtype=dtype,
+    )

+    # get total number of steps
     total_steps = len(dataloader)
     if cfg.max_test_samples is not None:
         total_steps = min(int(cfg.max_test_samples // cfg.batch_size), total_steps)
         print(f"limiting test dataset to {int(cfg.max_test_samples//cfg.batch_size) * cfg.batch_size}")
     dataloader_iter = iter(dataloader)

+    running_loss = running_nll = 0.0
+    loss_steps = 0
     with tqdm(
         range(total_steps),
         disable=not coordinator.is_master(),
@@ -151,95 +98,28 @@
             batch = next(dataloader_iter)
             x = batch["video"].to(device, dtype)  # [B, C, T, H, W]

-            # ===== Spatial VAE =====
-            if cfg.get("vae_2d", None) is not None:
-                x_z = vae_2d.encode(x)
-                x_z_debug = vae_2d.decode(x_z)
+            # ===== VAE =====
+            z, posterior, x_z = model.encode(x, training=True)
+            x_rec, _ = model.decode(z, num_frames=x.size(2))
+            x_ref = model.spatial_vae.decode(x_z)

-            # ====== VAE ======
-            x_z_rec, posterior, z = model(x_z)
-            x_rec = vae_2d.decode(x_z_rec)
-
-            if cfg.calc_loss:
-                # simple nll loss
-                nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
-
-                # fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
-                # fake_logits = discriminator(fake_video.contiguous())
-                # adversarial_loss = adversarial_loss_fn(
-                #     fake_logits,
-                #     nll_loss,
-                #     vae.get_last_layer(),
-                #     cfg.discriminator_start + 1,  # Hack to use discriminator
-                #     is_training=vae.training,
-                # )
-
-                # vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss
-                vae_loss = weighted_nll_loss + weighted_kl_loss
-
-                # # ====== Discriminator Loss ======
-                # real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2)
-                # fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
-
-                # if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
-                #     real_video = real_video.requires_grad_()
-                #     real_logits = discriminator(
-                #         real_video.contiguous()
-                #     )  # SCH: not detached for now for gradient_penalty calculation
-                # else:
-                #     real_logits = discriminator(real_video.contiguous().detach())
-
-                # fake_logits = discriminator(fake_video.contiguous().detach())
-
-                # lecam_ema_real, lecam_ema_fake = lecam_ema.get()
-                # weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
-                #     real_logits,
-                #     fake_logits,
-                #     cfg.discriminator_start + 1,  # Hack to use discriminator
-                #     lecam_ema_real=lecam_ema_real,
-                #     lecam_ema_fake=lecam_ema_fake,
-                #     real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None,
-                # )
-
-                # disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
-
-                loss_steps += 1
-                # running_disc_loss = disc_loss.item() / loss_steps + running_disc_loss * ((loss_steps - 1) / loss_steps)
-                running_loss = vae_loss.item() / loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
-                running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
-
-            # ===== Spatial VAE =====
+            # loss calculation
+            nll_loss, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
+            vae_loss = weighted_nll_loss + weighted_kl_loss
+            loss_steps += 1
+            running_loss = vae_loss.item() / loss_steps + running_loss * ((loss_steps - 1) / loss_steps)
+            running_nll = nll_loss.item() / loss_steps + running_nll * ((loss_steps - 1) / loss_steps)
             if not use_dist or coordinator.is_master():
-                for idx in range(len(x)):
+                for idx, vid in enumerate(x):
                     pos = step * cfg.batch_size + idx
-                    save_path = os.path.join(save_dir, f"sample_{pos}")
-                    save_sample(x[idx], fps=cfg.fps, save_path=save_path + "_original")
-                    save_sample(x_rec[idx], fps=cfg.fps, save_path=save_path + "_pipeline")
-                    if cfg.get("vae_2d", None) is not None:
-                        save_sample(x_z_debug[idx], fps=cfg.fps, save_path=save_path + "_2d")
+                    save_path = os.path.join(save_dir, f"sample_{pos:03d}")
+                    save_sample(vid, fps=cfg.fps, save_path=save_path + "_ori")
+                    save_sample(x_rec[idx], fps=cfg.fps, save_path=save_path + "_rec")
+                    save_sample(x_ref[idx], fps=cfg.fps, save_path=save_path + "_ref")

-            # if cfg.get("use_pipeline") == True:
-            #     for idx, (sample_original, sample_pipeline, sample_2d) in enumerate(
-            #         zip(video, recon_video, recon_2d)
-            #     ):
-            #         pos = step * cfg.batch_size + idx
-            #         save_path = os.path.join(save_dir, f"sample_{pos}")
-            #         save_sample(sample_original, fps=cfg.fps, save_path=save_path + "_original")
-            #         save_sample(sample_2d, fps=cfg.fps, save_path=save_path + "_2d")
-            #         save_sample(sample_pipeline, fps=cfg.fps, save_path=save_path + "_pipeline")
-
-            # else:
-            #     for idx, (original, recon) in enumerate(zip(video, recon_video)):
-            #         pos = step * cfg.batch_size + idx
-            #         save_path = os.path.join(save_dir, f"sample_{pos}")
-            #         save_sample(original, fps=cfg.fps, save_path=save_path + "_original")
-            #         save_sample(recon, fps=cfg.fps, save_path=save_path + "_recon")
-
-    if cfg.calc_loss:
-        print("test vae loss:", running_loss)
-        print("test nll loss:", running_nll)
-        # print("test disc loss:", running_disc_loss)
+    print("test vae loss:", running_loss)
+    print("test nll loss:", running_nll)

 if __name__ == "__main__":
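A side note on the loss reporting kept in the inference script: running_loss and running_nll are incremental means, so the printed values equal the plain average over all evaluated batches. A quick check of the identity used in the loop (illustrative numbers):

    vals = [0.50, 0.70, 0.60]
    mean = 0.0
    for n, v in enumerate(vals, start=1):
        # same update as running_loss above: new_mean = v/n + old_mean*(n-1)/n
        mean = v / n + mean * ((n - 1) / n)
    assert abs(mean - sum(vals) / len(vals)) < 1e-12
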
diff --git a/scripts/train-vae.py b/scripts/train-vae.py
index 32f870b..48b1b37 100644
--- a/scripts/train-vae.py
+++ b/scripts/train-vae.py
@@ -14,14 +14,9 @@
 from tqdm import tqdm

 import wandb
 from opensora.acceleration.checkpoint import set_grad_checkpoint
-from opensora.acceleration.parallel_states import (
-    get_data_parallel_group,
-    set_data_parallel_group,
-    set_sequence_parallel_group,
-)
-from opensora.acceleration.plugin import ZeroSeqParallelPlugin
+from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group
 from opensora.datasets import prepare_dataloader
-from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VAELoss
+from opensora.models.vae.losses import VAELoss
 from opensora.registry import DATASETS, MODELS, build_module
 from opensora.utils.ckpt_utils import create_logger, load_json, save_json
 from opensora.utils.config_utils import (
@@ -78,16 +73,6 @@
             max_norm=cfg.grad_clip,
         )
         set_data_parallel_group(dist.group.WORLD)
-    elif cfg.plugin == "zero2-seq":
-        plugin = ZeroSeqParallelPlugin(
-            sp_size=cfg.sp_size,
-            stage=2,
-            precision=cfg.dtype,
-            initial_scale=2**16,
-            max_norm=cfg.grad_clip,
-        )
-        set_sequence_parallel_group(plugin.sp_group)
-        set_data_parallel_group(plugin.dp_group)
     else:
         raise ValueError(f"Unknown plugin {cfg.plugin}")
     booster = Booster(plugin=plugin)
@@ -110,83 +95,40 @@
     )  # TODO: use plugin's prepare dataloader
     dataloader = prepare_dataloader(**dataloader_args)
-    total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
+    total_batch_size = cfg.batch_size * dist.get_world_size()
     logger.info(f"Total batch size: {total_batch_size}")

     # ======================================================
     # 4. build model
     # ======================================================
     # 4.1. build model
-    if cfg.get("vae_2d", None) is not None:
-        vae_2d = build_module(cfg.vae_2d, MODELS)
-        vae_2d.to(device, dtype).eval()
-
-    model = build_module(
-        cfg.model,
-        MODELS,
-        device=device,
-        dtype=dtype,
-    )
+    model = build_module(cfg.model, MODELS)
+    model.to(device, dtype)
     model_numel, model_numel_trainable = get_model_numel(model)
     logger.info(
         f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}"
     )
-    # discriminator = build_module(cfg.discriminator, MODELS, device=device)
-    # discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator)
-    # logger.info(
-    #     f"Trainable discriminator params: {format_numel_str(discriminator_numel_trainable)}, Total model params: {format_numel_str(discriminator_numel)}"
-    # )
-
-    # # LeCam Initialization
-    # lecam_ema = LeCamEMA(decay=cfg.ema_decay, dtype=dtype, device=device)
-
-    # 4.3. move to device
-    model = model.to(device, dtype)
-    # discriminator = discriminator.to(device, dtype)
-
     # 4.4 loss functions
     vae_loss_fn = VAELoss(
-        logvar_init=cfg.logvar_init,
+        logvar_init=cfg.get("logvar_init", 0.0),
         perceptual_loss_weight=cfg.perceptual_loss_weight,
         kl_loss_weight=cfg.kl_loss_weight,
         device=device,
         dtype=dtype,
     )
-    adversarial_loss_fn = AdversarialLoss(
-        discriminator_factor=cfg.discriminator_factor,
-        discriminator_start=cfg.discriminator_start,
-        generator_factor=cfg.generator_factor,
-        generator_loss_type=cfg.generator_loss_type,
-    )
-
-    disc_loss_fn = DiscriminatorLoss(
-        discriminator_factor=cfg.discriminator_factor,
-        discriminator_start=cfg.discriminator_start,
-        discriminator_loss_type=cfg.discriminator_loss_type,
-        lecam_loss_weight=cfg.lecam_loss_weight,
-        gradient_penalty_loss_weight=cfg.gradient_penalty_loss_weight,
-    )
-
     # 4.5. setup optimizer
     # vae optimizer
     optimizer = HybridAdam(
         filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True
     )
     lr_scheduler = None
-    # disc optimizer
-    # disc_optimizer = HybridAdam(
-    #     filter(lambda p: p.requires_grad, discriminator.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True
-    # )
-    # disc_lr_scheduler = None

     # 4.6. prepare for training
     if cfg.grad_checkpoint:
         set_grad_checkpoint(model)
-        # set_grad_checkpoint(discriminator)
     model.train()
-    # discriminator.train()

     # =======================================================
     # 5. boost model for distributed training with colossalai
@@ -203,11 +145,6 @@
     logger.info("Boost model for distributed training")
     num_steps_per_epoch = len(dataloader)

-    # discriminator, disc_optimizer, _, _, disc_lr_scheduler = booster.boost(
-    #     model=discriminator, optimizer=disc_optimizer, lr_scheduler=disc_lr_scheduler
-    # )
-    # logger.info("Boost discriminator for distributed training")
-
     # =======================================================
     # 6. training loop
     # =======================================================
@@ -221,18 +158,6 @@
         booster.load_model(model, os.path.join(cfg.load, "model"))
         booster.load_optimizer(optimizer, os.path.join(cfg.load, "optimizer"))
-        # booster.load_model(discriminator, os.path.join(cfg.load, "discriminator"))
-        # booster.load_optimizer(disc_optimizer, os.path.join(cfg.load, "disc_optimizer"))
-
-        # LeCam EMA for discriminator
-        # lecam_path = os.path.join(cfg.load, "lecam_states.json")
-        # if cfg.lecam_loss_weight is not None and os.path.exists(lecam_path):
-        #     lecam_state = load_json(lecam_path)
-        #     lecam_ema_real, lecam_ema_fake = lecam_state["lecam_ema_real"], lecam_state["lecam_ema_fake"]
-        #     lecam_ema = LeCamEMA(
-        #         decay=cfg.ema_decay, ema_real=lecam_ema_real, ema_fake=lecam_ema_fake, dtype=dtype, device=device
-        #     )
-
         running_states = load_json(os.path.join(cfg.load, "running_states.json"))
         dist.barrier()
         start_epoch, start_step, sampler_start_idx = (
@@ -246,15 +171,6 @@
     dataloader.sampler.set_start_index(sampler_start_idx)

     # 6.3. training loop
-
-    # calculate discriminator_time_padding
-    # disc_time_downsample_factor = 2 ** len(cfg.discriminator.channel_multipliers)
-    # if cfg.dataset.num_frames % disc_time_downsample_factor != 0:
-    #     disc_time_padding = disc_time_downsample_factor - cfg.dataset.num_frames % disc_time_downsample_factor
-    # else:
-    #     disc_time_padding = 0
-    # video_contains_first_frame = cfg.video_contains_first_frame
-
     for epoch in range(start_epoch, cfg.epochs):
         dataloader.sampler.set_epoch(epoch)
         dataloader_iter = iter(dataloader)
@@ -269,112 +185,41 @@
         ) as pbar:
             for step, batch in pbar:
                 x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
-                if random.random() < 0.5:
+                if random.random() < cfg.get("mixed_image_ratio", 0.0):
                     x = x[:, :, :1, :, :]

-                # ===== Spatial VAE =====
-                if cfg.get("vae_2d", None) is not None:
-                    with torch.no_grad():
-                        x_z = vae_2d.encode(x)
-                        vae_2d.decode(x_z)
-
-                # ====== VAE ======
-                x_z_rec, posterior, z = model(x_z)
-                x_rec = vae_2d.decode(x_z_rec)
+                # ===== VAE =====
+                x_rec, x_z_rec, z, posterior, x_z = model(x)

                 # ====== Generator Loss ======
-                # simple nll loss
-                _, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
-                # _, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior)
-                # _, debug_loss, _ = vae_loss_fn(x, x_z_debug, posterior)
-                # _, image_identity_loss, _ = vae_loss_fn(x_z, z, posterior)
+                vae_loss = torch.tensor(0.0, device=device, dtype=dtype)
+                log_dict = {}
+                if cfg.get("use_real_rec_loss", False):
+                    _, weighted_nll_loss, weighted_kl_loss = vae_loss_fn(x, x_rec, posterior)
+                    vae_loss += weighted_nll_loss + weighted_kl_loss
+                    log_dict["kl_loss"] = weighted_kl_loss.item()
+                    log_dict["nll_loss"] = weighted_nll_loss.item()
+                if cfg.get("use_z_rec_loss", False):
+                    _, weighted_z_nll_loss, _ = vae_loss_fn(x_z, x_z_rec, posterior)
+                    vae_loss += weighted_z_nll_loss
+                    log_dict["z_nll_loss"] = weighted_z_nll_loss.item()
+                if cfg.get("use_image_identity_loss", False):
+                    _, image_identity_loss, _ = vae_loss_fn(x_z, z, posterior)
+                    vae_loss += image_identity_loss
+                    log_dict["image_identity_loss"] = image_identity_loss.item()

-                # adversarial_loss = torch.tensor(0.0)
-                # adversarial loss
-                # if global_step > cfg.discriminator_start:
-                #     # padded videos for GAN
-                #     fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
-                #     fake_logits = discriminator(fake_video.contiguous())
-                #     adversarial_loss = adversarial_loss_fn(
-                #         fake_logits,
-                #         nll_loss,
-                #         vae.module.get_last_layer(),
-                #         global_step,
-                #         is_training=vae.training,
-                #     )
-
-                # vae_loss = weighted_nll_loss + weighted_kl_loss + adversarial_loss + weighted_z_nll_loss
-                # vae_loss = weighted_nll_loss + weighted_kl_loss + weighted_z_nll_loss + image_identity_loss
-                # vae_loss = weighted_z_nll_loss + image_identity_loss
-                # vae_loss = weighted_nll_loss + weighted_kl_loss + weighted_z_nll_loss
-                # vae_loss = weighted_z_nll_loss
-                vae_loss = weighted_nll_loss + weighted_kl_loss
-                optimizer.zero_grad()
                 # Backward & update
                 booster.backward(loss=vae_loss, optimizer=optimizer)
-                # # NOTE: clip gradients? this is done in Open-Sora-Plan
-                # torch.nn.utils.clip_grad_norm_(vae.parameters(), 1)  # NOTE: done by grad_clip
                 optimizer.step()
+                optimizer.zero_grad()

                 # Log loss values:
-                all_reduce_mean(vae_loss)  # NOTE: this is to get average loss for logging
+                all_reduce_mean(vae_loss)
                 running_loss += vae_loss.item()
                 global_step = epoch * num_steps_per_epoch + step
                 log_step += 1
                 acc_step += 1

-                # ====== Discriminator Loss ======
-                # if global_step > cfg.discriminator_start:
-                #     # if video_contains_first_frame:
-                #     # Since we don't have enough T frames, pad anyways
-                #     real_video = pad_at_dim(video, (disc_time_padding, 0), value=0.0, dim=2)
-                #     fake_video = pad_at_dim(recon_video, (disc_time_padding, 0), value=0.0, dim=2)
-
-                #     if cfg.gradient_penalty_loss_weight is not None and cfg.gradient_penalty_loss_weight > 0.0:
-                #         real_video = real_video.requires_grad_()
-                #         real_logits = discriminator(
-                #             real_video.contiguous()
-                #         )  # SCH: not detached for now for gradient_penalty calculation
-                #     else:
-                #         real_logits = discriminator(real_video.contiguous().detach())
-
-                #     fake_logits = discriminator(fake_video.contiguous().detach())
-
-                #     lecam_ema_real, lecam_ema_fake = lecam_ema.get()
-
-                #     weighted_d_adversarial_loss, lecam_loss, gradient_penalty_loss = disc_loss_fn(
-                #         real_logits,
-                #         fake_logits,
-                #         global_step,
-                #         lecam_ema_real=lecam_ema_real,
-                #         lecam_ema_fake=lecam_ema_fake,
-                #         real_video=real_video if cfg.gradient_penalty_loss_weight is not None else None,
-                #     )
-                #     disc_loss = weighted_d_adversarial_loss + lecam_loss + gradient_penalty_loss
-                #     if cfg.lecam_loss_weight is not None:
-                #         ema_real = torch.mean(real_logits.clone().detach()).to(device, dtype)
-                #         ema_fake = torch.mean(fake_logits.clone().detach()).to(device, dtype)
-                #         all_reduce_mean(ema_real)
-                #         all_reduce_mean(ema_fake)
-                #         lecam_ema.update(ema_real, ema_fake)
-
-                #     disc_optimizer.zero_grad()
-                #     # Backward & update
-                #     booster.backward(loss=disc_loss, optimizer=disc_optimizer)
-                #     # # NOTE: TODO: clip gradients? this is done in Open-Sora-Plan
-                #     # torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1)  # NOTE: done by grad_clip
-                #     disc_optimizer.step()
-
-                #     # Log loss values:
-                #     all_reduce_mean(disc_loss)
-                #     running_disc_loss += disc_loss.item()
-                # else:
-                #     disc_loss = torch.tensor(0.0)
-                #     weighted_d_adversarial_loss = torch.tensor(0.0)
-                #     lecam_loss = torch.tensor(0.0)
-                #     gradient_penalty_loss = torch.tensor(0.0)

                 # Log to tensorboard
                 if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
                     avg_loss = running_loss / log_step
@@ -393,16 +238,8 @@
                                 "num_samples": global_step * total_batch_size,
                                 "epoch": epoch,
                                 "loss": vae_loss.item(),
-                                "kl_loss": weighted_kl_loss.item(),
-                                # "gen_adv_loss": adversarial_loss.item(),
-                                # "disc_loss": disc_loss.item(),
-                                # "lecam_loss": lecam_loss.item(),
-                                # "r1_grad_penalty": gradient_penalty_loss.item(),
-                                "nll_loss": weighted_nll_loss.item(),
-                                # "z_nll_loss": weighted_z_nll_loss.item(),
-                                # "image_identity_loss": image_identity_loss.item(),
-                                # "debug_loss": debug_loss.item(),
                                 "avg_loss": avg_loss,
+                                **log_dict,
                             },
                             step=global_step,
                         )
@@ -412,38 +249,22 @@
                     save_dir = os.path.join(exp_dir, f"epoch{epoch}-global_step{global_step+1}")
                     os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)  # already handled in booster save_model
                     booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
-                    # booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
                     booster.save_optimizer(
                         optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096
                     )
-                    # booster.save_optimizer(
-                    #     disc_optimizer, os.path.join(save_dir, "disc_optimizer"), shard=True, size_per_shard=4096
-                    # )
-
                     running_states = {
                         "epoch": epoch,
                         "step": step + 1,
                         "global_step": global_step + 1,
                         "sample_start_index": (step + 1) * cfg.batch_size,
                     }
-
-                    # lecam_ema_real, lecam_ema_fake = lecam_ema.get()
-                    # lecam_state = {
-                    #     "lecam_ema_real": lecam_ema_real.item(),
-                    #     "lecam_ema_fake": lecam_ema_fake.item(),
-                    # }
                     if coordinator.is_master():
                         save_json(running_states, os.path.join(save_dir, "running_states.json"))
-                        # if cfg.lecam_loss_weight is not None:
-                        #     save_json(lecam_state, os.path.join(save_dir, "lecam_states.json"))
                     dist.barrier()
-
                     logger.info(
                         f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
                     )

-        # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
-
         # the continue epochs are not resumed, so we need to reset the sampler start index and start step
         dataloader.sampler.set_start_index(0)
         start_step = 0
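One behavioral change in the loop above is easy to miss: the hard-coded 50% chance of truncating a video batch to one frame is now the configurable mixed_image_ratio (0.1 in both new train configs), so image-style batches are mixed into video training at a controlled rate. A compact sketch of that batching rule (shapes illustrative):

    import random

    import torch

    mixed_image_ratio = 0.1  # value from configs/vae/train/{image,video}.py
    x = torch.randn(4, 3, 17, 256, 256)  # [B, C, T, H, W]
    if random.random() < mixed_image_ratio:
        x = x[:, :, :1, :, :]  # keep only the first frame: an image batch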