From 1c3e910fc60d022ca880fa8bcfad46bfbdd66d97 Mon Sep 17 00:00:00 2001 From: Shen-Chenhui Date: Fri, 19 Apr 2024 14:19:29 +0800 Subject: [PATCH] allow loading of discriminator in inference --- opensora/models/vae/vae_3d_v2.py | 10 +++++----- opensora/utils/ckpt_utils.py | 8 ++++---- opensora/utils/config_utils.py | 2 ++ 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py index efc1e51..7d7a26f 100644 --- a/opensora/models/vae/vae_3d_v2.py +++ b/opensora/models/vae/vae_3d_v2.py @@ -15,7 +15,7 @@ from torchvision.models import VGG16_Weights from taming.modules.losses.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers from torch import nn import math - +import os # from diffusers.models.modeling_utils import ModelMixin @@ -1206,27 +1206,27 @@ class DiscriminatorLoss(nn.Module): def VAE_MAGVIT_V2(from_pretrained=None, **kwargs): model = VAE_3D_V2(**kwargs) if from_pretrained is not None: - load_checkpoint(model, from_pretrained) + load_checkpoint(model, from_pretrained, model_name="model") return model @MODELS.register_module("DISCRIMINATOR_3D") def DISCRIMINATOR_3D(from_pretrained=None, inflate_from_2d=False, use_pretrained=True, **kwargs): model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init) # model = StyleGANDiscriminator(**kwargs).apply(xavier_uniform_weight_init) # SCH: DEBUG: to change back - # model = NLayerDiscriminator3D(input_nc=3, n_layers=3,).apply(n_layer_disc_weights_init) + # model = NLayerDiscriminator3D(input_nc=3, n_layers=3,).apply(n_layer_disc_weights_init) if from_pretrained is not None: if use_pretrained: if inflate_from_2d: load_checkpoint_with_inflation(model, from_pretrained) else: - load_checkpoint(model, from_pretrained) + load_checkpoint(model, from_pretrained, model_name="discriminator") + print(f"loading from: {from_pretrained}") else: print(f"discriminator use_pretrained={use_pretrained}, initializing 
new discriminator") return model - def load_checkpoint_with_inflation(model, ckpt_path): """ pre-train using image, then inflate to 3D videos diff --git a/opensora/utils/ckpt_utils.py b/opensora/utils/ckpt_utils.py index 3fdf8b7..f910009 100644 --- a/opensora/utils/ckpt_utils.py +++ b/opensora/utils/ckpt_utils.py @@ -80,9 +80,9 @@ def download_model(model_name): return model -def load_from_sharded_state_dict(model, ckpt_path): +def load_from_sharded_state_dict(model, ckpt_path, model_name="model"): ckpt_io = GeneralCheckpointIO() - ckpt_io.load_model(model, os.path.join(ckpt_path, "model")) + ckpt_io.load_model(model, os.path.join(ckpt_path, model_name)) def model_sharding(model: torch.nn.Module): global_rank = dist.get_rank() @@ -203,14 +203,14 @@ def create_logger(logging_dir): return logger -def load_checkpoint(model, ckpt_path, save_as_pt=True): +def load_checkpoint(model, ckpt_path, save_as_pt=True, model_name="model"): if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"): state_dict = find_model(ckpt_path) missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) print(f"Missing keys: {missing_keys}") print(f"Unexpected keys: {unexpected_keys}") elif os.path.isdir(ckpt_path): - load_from_sharded_state_dict(model, ckpt_path) + load_from_sharded_state_dict(model, ckpt_path, model_name) if save_as_pt: save_path = os.path.join(ckpt_path, "model_ckpt.pt") torch.save(model.state_dict(), save_path) diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py index ae5713e..66b8c5a 100644 --- a/opensora/utils/config_utils.py +++ b/opensora/utils/config_utils.py @@ -47,6 +47,8 @@ def parse_args(training=False): def merge_args(cfg, args, training=False): if args.ckpt_path is not None: cfg.model["from_pretrained"] = args.ckpt_path + if cfg.get("discriminator") is not None: + cfg.discriminator["from_pretrained"] = args.ckpt_path args.ckpt_path = None for k, v in vars(args).items():