This commit is contained in:
Shen-Chenhui 2024-05-02 16:32:19 +08:00
parent 4f915a51dd
commit b35b933841
3 changed files with 130 additions and 0 deletions

View file

@ -0,0 +1,65 @@
# Training configuration for the video autoencoder pipeline with an adversarial
# PatchGAN discriminator. Consumed by the training script via build_module(cfg.*).
num_frames = 17  # frames sampled per video clip
image_size = (256, 256)  # (height, width) of training frames — TODO confirm order against dataset
# Define dataset
dataset = dict(
type="VideoTextDataset",
data_path=None,  # presumably supplied at launch / CLI override — verify against trainer
num_frames=num_frames,
frame_interval=1,  # sample consecutive frames
image_size=image_size,
)
# Define acceleration
num_workers = 16  # dataloader worker processes
dtype = "bf16"  # mixed-precision dtype
grad_checkpoint = True  # activation checkpointing: trades compute for memory
plugin = "zero2"  # presumably ZeRO stage-2 distributed plugin — verify against trainer
# Define model
model = dict(
type="VideoAutoencoderPipeline",
freeze_vae_2d=False,  # spatial VAE is trained jointly, not frozen
from_pretrained=None,
vae_2d=dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
local_files_only=True,  # load weights from local cache only; no network fetch
),
vae_temporal=dict(
type="VAE_Temporal_SD",
from_pretrained=None,  # temporal VAE initialized from scratch
),
)
# Adversarial discriminator (PatchGAN).
# NOTE(review): absolute user-local checkpoint path — parameterize before sharing this config.
discriminator = dict(
type="NLayerDiscriminator",
from_pretrained="/home/shenchenhui/opensoraplan-v1.0.0-discriminator.pt",
input_nc=3,
n_layers=3,
use_actnorm=False,
)
# loss weights
perceptual_loss_weight = 0.1  # VGG perceptual loss is used when this is not None and > 0
kl_loss_weight = 1e-6
mixed_image_ratio = 0.2  # presumably fraction of still-image samples mixed into batches — TODO confirm
use_real_rec_loss = True
use_z_rec_loss = False
use_image_identity_loss = False
# Others
seed = 42
outputs = "outputs"  # output/checkpoint directory
wandb = False  # Weights & Biases logging disabled
epochs = 100
log_every = 1  # log every step
ckpt_every = 1000  # save checkpoint every 1000 steps
load = None  # resume-from checkpoint path, if any
batch_size = 1
lr = 1e-5
grad_clip = 1.0  # gradient-norm clipping threshold

View file

@ -148,6 +148,62 @@ class ResBlockDown(nn.Module):
out = (residual + x) / math.sqrt(2)
return out
@MODELS.register_module()
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator as in Pix2Pix.

    Emits a 1-channel map of per-patch real/fake logits rather than a single
    scalar; see
    https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False,
                 from_pretrained=None, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator.

        Parameters:
            input_nc (int)        -- number of channels in input images
            ndf (int)             -- number of filters in the FIRST conv layer
                                     (doubled per downsampling layer, capped at 8x)
            n_layers (int)        -- number of strided downsampling conv layers
            use_actnorm (bool)    -- accepted for config compatibility; not used
                                     by this constructor
            from_pretrained (str | None) -- checkpoint path to load after build
            norm_layer            -- normalization layer class, or a
                                     functools.partial wrapping one
                                     (default: nn.BatchNorm2d)
        """
        super().__init__()
        # BatchNorm2d has affine parameters, so a conv bias before it is redundant.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4    # kernel size
        padw = 1  # padding
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw,
                          stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        # One extra conv at stride 1 before the prediction head.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw,
                      stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        # Output a 1-channel per-patch prediction map.
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.main = nn.Sequential(*sequence)

        if from_pretrained is not None:
            load_checkpoint(self, from_pretrained)

    def forward(self, input):
        """Standard forward: return the per-patch logit map."""
        return self.main(input)
class NLayerDiscriminator3D(nn.Module):
"""Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs."""

View file

@ -109,6 +109,15 @@ def main():
f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}"
)
if cfg.get("discriminator", False) != False:
discriminator = build_module(cfg.discriminator, MODELS)
discriminator.to(device, dtype)
discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator)
logger.info(
f"Trainable model params: {format_numel_str(discriminator_numel_trainable)}, Total model params: {format_numel_str(discriminator_numel)}"
)
breakpoint()
# 4.4 loss functions
vae_loss_fn = VAELoss(
logvar_init=cfg.get("logvar_init", 0.0),