Merge branch 'dev/v1.2' of github.com:hpcaitech/Open-Sora-dev into dev/v1.2

This commit is contained in:
zhengzangw 2024-06-17 07:01:35 +00:00
commit dc577e98b9
3 changed files with 72 additions and 46 deletions

View file

@ -22,7 +22,7 @@ model = dict(
)
vae = dict(
type="OpenSoraVAE_V1_2",
from_pretrained="/mnt/jfs/sora_checkpoints/vae-pipeline",
from_pretrained="hpcai-tech/OpenSora-VAE-v1.2",
micro_frame_size=17,
micro_batch_size=4,
)

View file

@ -1,3 +1,4 @@
import os
import torch
import torch.nn as nn
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
@ -5,6 +6,7 @@ from einops import rearrange
from opensora.registry import MODELS, build_module
from opensora.utils.ckpt_utils import load_checkpoint
from transformers import PretrainedConfig, PreTrainedModel
@MODELS.register_module()
@ -115,9 +117,9 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module):
def dtype(self):
return next(self.parameters()).dtype
@MODELS.register_module()
class VideoAutoencoderPipeline(nn.Module):
class VideoAutoencoderPipelineConfig(PretrainedConfig):
model_type = "VideoAutoencoderPipeline"
def __init__(
self,
vae_2d=None,
@ -128,25 +130,43 @@ class VideoAutoencoderPipeline(nn.Module):
micro_frame_size=None,
shift=0.0,
scale=1.0,
**kwargs
):
super().__init__()
self.spatial_vae = build_module(vae_2d, MODELS)
self.temporal_vae = build_module(vae_temporal, MODELS)
self.vae_2d = vae_2d
self.vae_temporal = vae_temporal
self.from_pretrained = from_pretrained
self.freeze_vae_2d = freeze_vae_2d
self.cal_loss = cal_loss
self.micro_frame_size = micro_frame_size
self.micro_z_frame_size = self.temporal_vae.get_latent_size([micro_frame_size, None, None])[0]
self.shift = shift
self.scale = scale
super().__init__(**kwargs)
if from_pretrained is not None:
load_checkpoint(self, from_pretrained)
if freeze_vae_2d:
@MODELS.register_module()
class VideoAutoencoderPipeline(PreTrainedModel):
config_class = VideoAutoencoderPipelineConfig
def __init__(
self,
config: VideoAutoencoderPipelineConfig
):
super().__init__(config=config)
self.spatial_vae = build_module(config.vae_2d, MODELS)
self.temporal_vae = build_module(config.vae_temporal, MODELS)
self.cal_loss = config.cal_loss
self.micro_frame_size = config.micro_frame_size
self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0]
if config.freeze_vae_2d:
for param in self.spatial_vae.parameters():
param.requires_grad = False
self.out_channels = self.temporal_vae.out_channels
# normalization parameters
scale = torch.tensor(scale)
shift = torch.tensor(shift)
scale = torch.tensor(config.scale)
shift = torch.tensor(config.shift)
if len(scale.shape) > 0:
scale = scale[None, :, None, None, None]
if len(shift.shape) > 0:
@ -225,38 +245,44 @@ class VideoAutoencoderPipeline(nn.Module):
def dtype(self):
return next(self.parameters()).dtype
@MODELS.register_module()
class OpenSoraVAE_V1_2(VideoAutoencoderPipeline):
def __init__(
self,
micro_batch_size=4,
micro_frame_size=17,
from_pretrained=None,
local_files_only=False,
freeze_vae_2d=False,
cal_loss=False,
):
vae_2d = dict(
def OpenSoraVAE_V1_2(
micro_batch_size=4,
micro_frame_size=17,
from_pretrained=None,
local_files_only=False,
freeze_vae_2d=False,
cal_loss=False,
):
vae_2d = dict(
type="VideoAutoencoderKL",
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
subfolder="vae",
micro_batch_size=micro_batch_size,
local_files_only=local_files_only,
)
vae_temporal = dict(
type="VAE_Temporal_SD",
from_pretrained=None,
)
shift = (-0.10, 0.34, 0.27, 0.98)
scale = (3.85, 2.32, 2.33, 3.06)
super().__init__(
vae_2d,
vae_temporal,
from_pretrained,
freeze_vae_2d=freeze_vae_2d,
cal_loss=cal_loss,
micro_frame_size=micro_frame_size,
shift=shift,
scale=scale,
)
vae_temporal = dict(
type="VAE_Temporal_SD",
from_pretrained=None,
)
shift = (-0.10, 0.34, 0.27, 0.98)
scale = (3.85, 2.32, 2.33, 3.06)
kwargs = dict(
vae_2d=vae_2d,
vae_temporal=vae_temporal,
freeze_vae_2d=freeze_vae_2d,
cal_loss=cal_loss,
micro_frame_size=micro_frame_size,
shift=shift,
scale=scale
)
if from_pretrained is not None and not os.path.isdir(from_pretrained):
model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs)
else:
config = VideoAutoencoderPipelineConfig(**kwargs)
model = VideoAutoencoderPipeline(config)
if from_pretrained:
load_checkpoint(model, from_pretrained)
return model

View file

@ -145,9 +145,9 @@ def download_model(model_name=None, local_path=None, url=None):
return model
def load_from_sharded_state_dict(model, ckpt_path, model_name="model"):
def load_from_sharded_state_dict(model, ckpt_path, model_name="model", strict=False):
ckpt_io = GeneralCheckpointIO()
ckpt_io.load_model(model, os.path.join(ckpt_path, model_name))
ckpt_io.load_model(model, os.path.join(ckpt_path, model_name), strict=strict)
def model_sharding(model: torch.nn.Module):
@ -187,14 +187,14 @@ def record_model_param_shape(model: torch.nn.Module) -> dict:
return param_shape
def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model"):
def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model", strict=False):
if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
state_dict = find_model(ckpt_path, model=model)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)
get_logger().info("Missing keys: %s", missing_keys)
get_logger().info("Unexpected keys: %s", unexpected_keys)
elif os.path.isdir(ckpt_path):
load_from_sharded_state_dict(model, ckpt_path, model_name)
load_from_sharded_state_dict(model, ckpt_path, model_name, strict=strict)
get_logger().info("Model checkpoint loaded from %s", ckpt_path)
if save_as_pt:
save_path = os.path.join(ckpt_path, model_name + "_ckpt.pt")