Allow loading of the discriminator checkpoint during inference

This commit is contained in:
Shen-Chenhui 2024-04-19 14:19:29 +08:00
parent 6581556997
commit 1c3e910fc6
3 changed files with 11 additions and 9 deletions

View file

@ -15,7 +15,7 @@ from torchvision.models import VGG16_Weights
from taming.modules.losses.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
from torch import nn
import math
import os
# from diffusers.models.modeling_utils import ModelMixin
@ -1206,27 +1206,27 @@ class DiscriminatorLoss(nn.Module):
def VAE_MAGVIT_V2(from_pretrained=None, **kwargs):
    """Build a VAE_3D_V2 model, optionally loading pretrained weights.

    Args:
        from_pretrained: checkpoint path (file or sharded-checkpoint
            directory); ``None`` skips loading entirely.
        **kwargs: forwarded verbatim to the ``VAE_3D_V2`` constructor.

    Returns:
        The constructed (and possibly weight-loaded) model.
    """
    model = VAE_3D_V2(**kwargs)
    if from_pretrained is not None:
        # Sharded checkpoints keep the VAE weights under the "model"
        # sub-directory (see load_checkpoint's model_name argument).
        load_checkpoint(model, from_pretrained, model_name="model")
    return model
@MODELS.register_module("DISCRIMINATOR_3D")
def DISCRIMINATOR_3D(from_pretrained=None, inflate_from_2d=False, use_pretrained=True, **kwargs):
    """Build a StyleGAN-style 3D discriminator, optionally loading weights.

    Args:
        from_pretrained: checkpoint path (file or sharded-checkpoint
            directory); ``None`` skips loading.
        inflate_from_2d: if True, inflate image-pretrained 2D weights
            into the 3D model instead of a plain load.
        use_pretrained: if False, ignore ``from_pretrained`` and keep the
            freshly (Xavier-uniform) initialized weights.
        **kwargs: forwarded to the ``StyleGANDiscriminatorBlur`` constructor.

    Returns:
        The constructed (and possibly weight-loaded) discriminator.
    """
    model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init)
    if from_pretrained is not None:
        if use_pretrained:
            if inflate_from_2d:
                load_checkpoint_with_inflation(model, from_pretrained)
            else:
                # Discriminator weights live under the "discriminator"
                # sub-directory of a sharded checkpoint.
                load_checkpoint(model, from_pretrained, model_name="discriminator")
            # BUG FIX: the committed code printed `use_pretrained` (a bool)
            # in the "loading from:" message; the loaded path is what the
            # message is about.
            print(f"loading from:{from_pretrained}")
        else:
            print(f"discriminator use_pretrained={use_pretrained}, initializing new discriminator")
    return model
def load_checkpoint_with_inflation(model, ckpt_path):
"""
pre-train using image, then inflate to 3D videos

View file

@ -80,9 +80,9 @@ def download_model(model_name):
return model
def load_from_sharded_state_dict(model, ckpt_path, model_name="model"):
    """Load a sharded checkpoint into ``model`` in place.

    Args:
        model: module to receive the weights.
        ckpt_path: root directory of the sharded checkpoint.
        model_name: sub-directory of ``ckpt_path`` holding this model's
            shards (e.g. "model" for the VAE, "discriminator" for the
            discriminator) — added so inference can load the discriminator.
    """
    ckpt_io = GeneralCheckpointIO()
    ckpt_io.load_model(model, os.path.join(ckpt_path, model_name))
def model_sharding(model: torch.nn.Module):
global_rank = dist.get_rank()
@ -203,14 +203,14 @@ def create_logger(logging_dir):
return logger
def load_checkpoint(model, ckpt_path, save_as_pt=True, model_name="model"):
    """Load weights into ``model`` from a .pt/.pth file or a sharded directory.

    Args:
        model: module to receive the weights (mutated in place).
        ckpt_path: path to a ``.pt``/``.pth`` state-dict file, or a
            sharded-checkpoint directory.
        save_as_pt: when loading from a directory, also dump the merged
            state dict to ``<ckpt_path>/model_ckpt.pt``.
        model_name: sub-directory of a sharded ``ckpt_path`` holding this
            model's shards (e.g. "model" or "discriminator").
    """
    # str.endswith accepts a tuple of suffixes — one call, no `or` chain.
    if ckpt_path.endswith((".pt", ".pth")):
        state_dict = find_model(ckpt_path)
        # strict=False: report key mismatches instead of raising.
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        print(f"Missing keys: {missing_keys}")
        print(f"Unexpected keys: {unexpected_keys}")
    elif os.path.isdir(ckpt_path):
        load_from_sharded_state_dict(model, ckpt_path, model_name)
        # NOTE(review): indentation was lost in the diff rendering — assuming
        # the save-as-.pt step belongs inside the directory branch; confirm
        # against the original source file.
        if save_as_pt:
            save_path = os.path.join(ckpt_path, "model_ckpt.pt")
            torch.save(model.state_dict(), save_path)

View file

@ -47,6 +47,8 @@ def parse_args(training=False):
def merge_args(cfg, args, training=False):
if args.ckpt_path is not None:
cfg.model["from_pretrained"] = args.ckpt_path
if cfg.get("discriminator") is not None:
cfg.discriminator["from_pretrained"] = args.ckpt_path
args.ckpt_path = None
for k, v in vars(args).items():