mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
75 lines
1.4 KiB
Python
75 lines
1.4 KiB
Python
# Shared sample geometry; both values are forwarded to the dataset below.
num_frames = 17
image_size = (256, 256)

# Define dataset
dataset = dict(
    type="VideoTextDataset",
    # No default data location is set here.
    # NOTE(review): presumably supplied by the launcher/CLI — confirm.
    data_path=None,
    num_frames=num_frames,
    frame_interval=1,
    image_size=image_size,
)

# Define acceleration
num_workers = 16
dtype = "bf16"
grad_checkpoint = True
# NOTE(review): "zero2" presumably selects a ZeRO stage-2 training plugin —
# confirm against the training script that consumes this config.
plugin = "zero2"

# Define model
# The pipeline is configured from two nested sub-configs: a 2D VAE
# (`vae_2d`) and a temporal VAE (`vae_temporal`).
model = dict(
    type="VideoAutoencoderPipeline",
    freeze_vae_2d=False,
    # No checkpoint is specified for the pipeline as a whole.
    from_pretrained=None,
    vae_2d=dict(
        type="VideoAutoencoderKL",
        from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
        subfolder="vae",
        # NOTE(review): presumably forwarded to the Hugging Face loader, in
        # which case the weights must already be in the local cache — confirm.
        local_files_only=True,
    ),
    vae_temporal=dict(
        type="VAE_Temporal_SD",
        # No checkpoint is specified for the temporal VAE.
        from_pretrained=None,
    ),
)

# Define discriminator
discriminator = dict(
    type="NLayerDiscriminator",
    # NOTE(review): absolute path inside one user's home directory; it will
    # not exist on other machines. Override it with the local checkpoint
    # location before training.
    from_pretrained="/home/shenchenhui/opensoraplan-v1.0.0-discriminator.pt",
    input_nc=3,
    n_layers=3,
    use_actnorm=False,
)

# Discriminator hyper-parameters (TODO)
discriminator_factor = 1
# NOTE(review): -1 presumably means the discriminator is active from the
# first step — confirm against the training loop.
discriminator_start = -1
generator_factor = 0.5
generator_loss_type = "hinge"
discriminator_loss_type = "hinge"
# No weight is set for the two terms below.
# NOTE(review): None presumably disables them — confirm.
lecam_loss_weight = None
gradient_penalty_loss_weight = None

# Loss weights
perceptual_loss_weight = 0.1  # VGG is used when this is not None and greater than 0
kl_loss_weight = 1e-6

# NOTE(review): presumably the fraction of samples treated as single images
# rather than videos — confirm against the training loop.
mixed_image_ratio = 0.2
use_real_rec_loss = True
use_z_rec_loss = False
use_image_identity_loss = False

# Others
seed = 42
outputs = "outputs"
wandb = False

epochs = 100
log_every = 1
ckpt_every = 1000
# No checkpoint to resume from is set here.
load = None

batch_size = 1
lr = 1e-5
grad_clip = 1.0