mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 13:14:44 +02:00
add disc
This commit is contained in:
parent
4f915a51dd
commit
b35b933841
65
configs/vae/train/video_disc.py
Normal file
65
configs/vae/train/video_disc.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
# Training config for the VAE + video discriminator ("disc") run.

# ---- Data -------------------------------------------------------------
num_frames = 17
image_size = (256, 256)

dataset = dict(
    type="VideoTextDataset",
    data_path=None,  # set on the command line / override
    num_frames=num_frames,
    frame_interval=1,
    image_size=image_size,
)

# ---- Acceleration -----------------------------------------------------
num_workers = 16
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"

# ---- Model ------------------------------------------------------------
model = dict(
    type="VideoAutoencoderPipeline",
    freeze_vae_2d=False,
    from_pretrained=None,
    vae_2d=dict(
        type="VideoAutoencoderKL",
        from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
        subfolder="vae",
        local_files_only=True,
    ),
    vae_temporal=dict(
        type="VAE_Temporal_SD",
        from_pretrained=None,
    ),
)

discriminator = dict(
    type="NLayerDiscriminator",
    # NOTE(review): machine-specific absolute path — override per environment.
    from_pretrained="/home/shenchenhui/opensoraplan-v1.0.0-discriminator.pt",
    input_nc=3,
    n_layers=3,
    use_actnorm=False,
)

# ---- Loss weights -----------------------------------------------------
perceptual_loss_weight = 0.1  # use vgg is not None and more than 0
kl_loss_weight = 1e-6

mixed_image_ratio = 0.2
use_real_rec_loss = True
use_z_rec_loss = False
use_image_identity_loss = False

# ---- Misc -------------------------------------------------------------
seed = 42
outputs = "outputs"
wandb = False

epochs = 100
log_every = 1
ckpt_every = 1000
load = None

batch_size = 1
lr = 1e-5
grad_clip = 1.0
|
||||
|
|
@ -148,6 +148,62 @@ class ResBlockDown(nn.Module):
|
|||
out = (residual + x) / math.sqrt(2)
|
||||
return out
|
||||
|
||||
@MODELS.register_module()
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix.

    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, from_pretrained=None):
        """Construct a PatchGAN discriminator.

        Parameters:
            input_nc (int)     -- the number of channels in input images
            ndf (int)          -- the number of filters in the first conv layer
            n_layers (int)     -- the number of strided conv layers in the discriminator
            use_actnorm (bool) -- accepted for config compatibility but currently
                                  unused; BatchNorm2d is always used here
            from_pretrained    -- optional checkpoint to load via ``load_checkpoint``
        """
        super().__init__()

        norm_layer = nn.BatchNorm2d

        # BatchNorm2d has affine parameters, so a conv bias directly before it
        # would be redundant; only use bias for other norm layers.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4  # kernel size
        padw = 1  # padding
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters, capped at 8x ndf
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        # One more conv block at stride 1 before the prediction head.
        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        # Output a 1-channel patch prediction (logit) map.
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.main = nn.Sequential(*sequence)

        if from_pretrained is not None:
            load_checkpoint(self, from_pretrained)

    def forward(self, input):
        """Standard forward: returns a per-patch logit map for *input* images."""
        return self.main(input)
|
||||
|
||||
|
||||
|
||||
class NLayerDiscriminator3D(nn.Module):
|
||||
"""Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs."""
|
||||
|
|
|
|||
|
|
@ -109,6 +109,15 @@ def main():
|
|||
f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}"
|
||||
)
|
||||
|
||||
if cfg.get("discriminator", False) != False:
|
||||
discriminator = build_module(cfg.discriminator, MODELS)
|
||||
discriminator.to(device, dtype)
|
||||
discriminator_numel, discriminator_numel_trainable = get_model_numel(discriminator)
|
||||
logger.info(
|
||||
f"Trainable model params: {format_numel_str(discriminator_numel_trainable)}, Total model params: {format_numel_str(discriminator_numel)}"
|
||||
)
|
||||
breakpoint()
|
||||
|
||||
# 4.4 loss functions
|
||||
vae_loss_fn = VAELoss(
|
||||
logvar_init=cfg.get("logvar_init", 0.0),
|
||||
|
|
|
|||
Loading…
Reference in a new issue