mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-20 09:22:22 +02:00
Merge branch 'dev/v1.2' of github.com:hpcaitech/Open-Sora-dev into dev/v1.2
This commit is contained in:
commit
dc577e98b9
|
|
@ -22,7 +22,7 @@ model = dict(
|
|||
)
|
||||
vae = dict(
|
||||
type="OpenSoraVAE_V1_2",
|
||||
from_pretrained="/mnt/jfs/sora_checkpoints/vae-pipeline",
|
||||
from_pretrained="hpcai-tech/OpenSora-VAE-v1.2",
|
||||
micro_frame_size=17,
|
||||
micro_batch_size=4,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
|
||||
|
|
@ -5,6 +6,7 @@ from einops import rearrange
|
|||
|
||||
from opensora.registry import MODELS, build_module
|
||||
from opensora.utils.ckpt_utils import load_checkpoint
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
|
@ -115,9 +117,9 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module):
|
|||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class VideoAutoencoderPipeline(nn.Module):
|
||||
class VideoAutoencoderPipelineConfig(PretrainedConfig):
|
||||
model_type = "VideoAutoencoderPipeline"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae_2d=None,
|
||||
|
|
@ -128,25 +130,43 @@ class VideoAutoencoderPipeline(nn.Module):
|
|||
micro_frame_size=None,
|
||||
shift=0.0,
|
||||
scale=1.0,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.spatial_vae = build_module(vae_2d, MODELS)
|
||||
self.temporal_vae = build_module(vae_temporal, MODELS)
|
||||
self.vae_2d = vae_2d
|
||||
self.vae_temporal = vae_temporal
|
||||
self.from_pretrained = from_pretrained
|
||||
self.freeze_vae_2d = freeze_vae_2d
|
||||
self.cal_loss = cal_loss
|
||||
self.micro_frame_size = micro_frame_size
|
||||
self.micro_z_frame_size = self.temporal_vae.get_latent_size([micro_frame_size, None, None])[0]
|
||||
self.shift = shift
|
||||
self.scale = scale
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if from_pretrained is not None:
|
||||
load_checkpoint(self, from_pretrained)
|
||||
if freeze_vae_2d:
|
||||
|
||||
@MODELS.register_module()
|
||||
class VideoAutoencoderPipeline(PreTrainedModel):
|
||||
config_class = VideoAutoencoderPipelineConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: VideoAutoencoderPipelineConfig
|
||||
):
|
||||
super().__init__(config=config)
|
||||
self.spatial_vae = build_module(config.vae_2d, MODELS)
|
||||
self.temporal_vae = build_module(config.vae_temporal, MODELS)
|
||||
self.cal_loss = config.cal_loss
|
||||
self.micro_frame_size = config.micro_frame_size
|
||||
self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0]
|
||||
|
||||
if config.freeze_vae_2d:
|
||||
for param in self.spatial_vae.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.out_channels = self.temporal_vae.out_channels
|
||||
|
||||
# normalization parameters
|
||||
scale = torch.tensor(scale)
|
||||
shift = torch.tensor(shift)
|
||||
scale = torch.tensor(config.scale)
|
||||
shift = torch.tensor(config.shift)
|
||||
if len(scale.shape) > 0:
|
||||
scale = scale[None, :, None, None, None]
|
||||
if len(shift.shape) > 0:
|
||||
|
|
@ -225,38 +245,44 @@ class VideoAutoencoderPipeline(nn.Module):
|
|||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class OpenSoraVAE_V1_2(VideoAutoencoderPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
micro_batch_size=4,
|
||||
micro_frame_size=17,
|
||||
from_pretrained=None,
|
||||
local_files_only=False,
|
||||
freeze_vae_2d=False,
|
||||
cal_loss=False,
|
||||
):
|
||||
vae_2d = dict(
|
||||
def OpenSoraVAE_V1_2(
|
||||
micro_batch_size=4,
|
||||
micro_frame_size=17,
|
||||
from_pretrained=None,
|
||||
local_files_only=False,
|
||||
freeze_vae_2d=False,
|
||||
cal_loss=False,
|
||||
):
|
||||
vae_2d = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=micro_batch_size,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
vae_temporal = dict(
|
||||
type="VAE_Temporal_SD",
|
||||
from_pretrained=None,
|
||||
)
|
||||
shift = (-0.10, 0.34, 0.27, 0.98)
|
||||
scale = (3.85, 2.32, 2.33, 3.06)
|
||||
super().__init__(
|
||||
vae_2d,
|
||||
vae_temporal,
|
||||
from_pretrained,
|
||||
freeze_vae_2d=freeze_vae_2d,
|
||||
cal_loss=cal_loss,
|
||||
micro_frame_size=micro_frame_size,
|
||||
shift=shift,
|
||||
scale=scale,
|
||||
)
|
||||
vae_temporal = dict(
|
||||
type="VAE_Temporal_SD",
|
||||
from_pretrained=None,
|
||||
)
|
||||
shift = (-0.10, 0.34, 0.27, 0.98)
|
||||
scale = (3.85, 2.32, 2.33, 3.06)
|
||||
kwargs = dict(
|
||||
vae_2d=vae_2d,
|
||||
vae_temporal=vae_temporal,
|
||||
freeze_vae_2d=freeze_vae_2d,
|
||||
cal_loss=cal_loss,
|
||||
micro_frame_size=micro_frame_size,
|
||||
shift=shift,
|
||||
scale=scale
|
||||
)
|
||||
|
||||
if from_pretrained is not None and not os.path.isdir(from_pretrained):
|
||||
model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs)
|
||||
else:
|
||||
config = VideoAutoencoderPipelineConfig(**kwargs)
|
||||
model = VideoAutoencoderPipeline(config)
|
||||
|
||||
if from_pretrained:
|
||||
load_checkpoint(model, from_pretrained)
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -145,9 +145,9 @@ def download_model(model_name=None, local_path=None, url=None):
|
|||
return model
|
||||
|
||||
|
||||
def load_from_sharded_state_dict(model, ckpt_path, model_name="model"):
|
||||
def load_from_sharded_state_dict(model, ckpt_path, model_name="model", strict=False):
|
||||
ckpt_io = GeneralCheckpointIO()
|
||||
ckpt_io.load_model(model, os.path.join(ckpt_path, model_name))
|
||||
ckpt_io.load_model(model, os.path.join(ckpt_path, model_name), strict=strict)
|
||||
|
||||
|
||||
def model_sharding(model: torch.nn.Module):
|
||||
|
|
@ -187,14 +187,14 @@ def record_model_param_shape(model: torch.nn.Module) -> dict:
|
|||
return param_shape
|
||||
|
||||
|
||||
def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model"):
|
||||
def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model", strict=False):
|
||||
if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
|
||||
state_dict = find_model(ckpt_path, model=model)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)
|
||||
get_logger().info("Missing keys: %s", missing_keys)
|
||||
get_logger().info("Unexpected keys: %s", unexpected_keys)
|
||||
elif os.path.isdir(ckpt_path):
|
||||
load_from_sharded_state_dict(model, ckpt_path, model_name)
|
||||
load_from_sharded_state_dict(model, ckpt_path, model_name, strict=strict)
|
||||
get_logger().info("Model checkpoint loaded from %s", ckpt_path)
|
||||
if save_as_pt:
|
||||
save_path = os.path.join(ckpt_path, model_name + "_ckpt.pt")
|
||||
|
|
|
|||
Loading…
Reference in a new issue