diff --git a/configs/opensora-v1-2/inference/sample.py b/configs/opensora-v1-2/inference/sample.py index f07587c..5a93076 100644 --- a/configs/opensora-v1-2/inference/sample.py +++ b/configs/opensora-v1-2/inference/sample.py @@ -22,7 +22,7 @@ model = dict( ) vae = dict( type="OpenSoraVAE_V1_2", - from_pretrained="/mnt/jfs/sora_checkpoints/vae-pipeline", + from_pretrained="hpcai-tech/OpenSora-VAE-v1.2", micro_frame_size=17, micro_batch_size=4, ) diff --git a/opensora/models/vae/vae.py b/opensora/models/vae/vae.py index f5769a2..f9823d9 100644 --- a/opensora/models/vae/vae.py +++ b/opensora/models/vae/vae.py @@ -1,3 +1,4 @@ +import os import torch import torch.nn as nn from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder @@ -5,6 +6,7 @@ from einops import rearrange from opensora.registry import MODELS, build_module from opensora.utils.ckpt_utils import load_checkpoint +from transformers import PretrainedConfig, PreTrainedModel @MODELS.register_module() @@ -115,9 +117,9 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module): def dtype(self): return next(self.parameters()).dtype - -@MODELS.register_module() -class VideoAutoencoderPipeline(nn.Module): +class VideoAutoencoderPipelineConfig(PretrainedConfig): + model_type = "VideoAutoencoderPipeline" + def __init__( self, vae_2d=None, @@ -128,25 +130,43 @@ class VideoAutoencoderPipeline(nn.Module): micro_frame_size=None, shift=0.0, scale=1.0, + **kwargs ): - super().__init__() - self.spatial_vae = build_module(vae_2d, MODELS) - self.temporal_vae = build_module(vae_temporal, MODELS) + self.vae_2d = vae_2d + self.vae_temporal = vae_temporal + self.from_pretrained = from_pretrained + self.freeze_vae_2d = freeze_vae_2d self.cal_loss = cal_loss self.micro_frame_size = micro_frame_size - self.micro_z_frame_size = self.temporal_vae.get_latent_size([micro_frame_size, None, None])[0] + self.shift = shift + self.scale = scale + super().__init__(**kwargs) - if from_pretrained is not None: - load_checkpoint(self, from_pretrained) - if freeze_vae_2d: + +@MODELS.register_module() +class VideoAutoencoderPipeline(PreTrainedModel): + config_class = VideoAutoencoderPipelineConfig + + def __init__( + self, + config: VideoAutoencoderPipelineConfig + ): + super().__init__(config=config) + self.spatial_vae = build_module(config.vae_2d, MODELS) + self.temporal_vae = build_module(config.vae_temporal, MODELS) + self.cal_loss = config.cal_loss + self.micro_frame_size = config.micro_frame_size + self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0] + + if config.freeze_vae_2d: for param in self.spatial_vae.parameters(): param.requires_grad = False self.out_channels = self.temporal_vae.out_channels # normalization parameters - scale = torch.tensor(scale) - shift = torch.tensor(shift) + scale = torch.tensor(config.scale) + shift = torch.tensor(config.shift) if len(scale.shape) > 0: scale = scale[None, :, None, None, None] if len(shift.shape) > 0: @@ -225,38 +245,44 @@ class VideoAutoencoderPipeline(nn.Module): def dtype(self): return next(self.parameters()).dtype - @MODELS.register_module() -class OpenSoraVAE_V1_2(VideoAutoencoderPipeline): - def __init__( - self, - micro_batch_size=4, - micro_frame_size=17, - from_pretrained=None, - local_files_only=False, - freeze_vae_2d=False, - cal_loss=False, - ): - vae_2d = dict( +def OpenSoraVAE_V1_2( + micro_batch_size=4, + micro_frame_size=17, + from_pretrained=None, + local_files_only=False, + freeze_vae_2d=False, + cal_loss=False, +): + vae_2d = dict( type="VideoAutoencoderKL", from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", subfolder="vae", micro_batch_size=micro_batch_size, local_files_only=local_files_only, ) - vae_temporal = dict( - type="VAE_Temporal_SD", - from_pretrained=None, - ) - shift = (-0.10, 0.34, 0.27, 0.98) - scale = (3.85, 2.32, 2.33, 3.06) - super().__init__( - vae_2d, - vae_temporal, - from_pretrained, - freeze_vae_2d=freeze_vae_2d, - cal_loss=cal_loss, - micro_frame_size=micro_frame_size, - shift=shift, - scale=scale, - ) + vae_temporal = dict( + type="VAE_Temporal_SD", + from_pretrained=None, + ) + shift = (-0.10, 0.34, 0.27, 0.98) + scale = (3.85, 2.32, 2.33, 3.06) + kwargs = dict( + vae_2d=vae_2d, + vae_temporal=vae_temporal, + freeze_vae_2d=freeze_vae_2d, + cal_loss=cal_loss, + micro_frame_size=micro_frame_size, + shift=shift, + scale=scale + ) + + if from_pretrained is not None and not os.path.isdir(from_pretrained): + model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs) + else: + config = VideoAutoencoderPipelineConfig(**kwargs) + model = VideoAutoencoderPipeline(config) + + if from_pretrained: + load_checkpoint(model, from_pretrained) + return model diff --git a/opensora/utils/ckpt_utils.py b/opensora/utils/ckpt_utils.py index 8ecd350..b2ac5e2 100644 --- a/opensora/utils/ckpt_utils.py +++ b/opensora/utils/ckpt_utils.py @@ -145,9 +145,9 @@ def download_model(model_name=None, local_path=None, url=None): return model -def load_from_sharded_state_dict(model, ckpt_path, model_name="model"): +def load_from_sharded_state_dict(model, ckpt_path, model_name="model", strict=False): ckpt_io = GeneralCheckpointIO() - ckpt_io.load_model(model, os.path.join(ckpt_path, model_name)) + ckpt_io.load_model(model, os.path.join(ckpt_path, model_name), strict=strict) def model_sharding(model: torch.nn.Module): @@ -187,14 +187,14 @@ def record_model_param_shape(model: torch.nn.Module) -> dict: return param_shape -def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model"): +def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model", strict=False): if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"): state_dict = find_model(ckpt_path, model=model) - missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict) get_logger().info("Missing keys: %s", missing_keys) get_logger().info("Unexpected keys: %s", unexpected_keys) elif os.path.isdir(ckpt_path): - load_from_sharded_state_dict(model, ckpt_path, model_name) + load_from_sharded_state_dict(model, ckpt_path, model_name, strict=strict) get_logger().info("Model checkpoint loaded from %s", ckpt_path) if save_as_pt: save_path = os.path.join(ckpt_path, model_name + "_ckpt.pt")