mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
allow loading of discriminator in inference
This commit is contained in:
parent
6581556997
commit
1c3e910fc6
|
|
@ -15,7 +15,7 @@ from torchvision.models import VGG16_Weights
|
|||
from taming.modules.losses.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
import os
|
||||
# from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
|
|
@ -1206,27 +1206,27 @@ class DiscriminatorLoss(nn.Module):
|
|||
def VAE_MAGVIT_V2(from_pretrained=None, **kwargs):
    """Build a MAGVIT-v2 style 3D VAE, optionally loading pretrained weights.

    Args:
        from_pretrained: checkpoint path (``.pt``/``.pth`` file or sharded
            checkpoint directory); ``None`` skips weight loading.
        **kwargs: forwarded verbatim to the ``VAE_3D_V2`` constructor.

    Returns:
        The constructed (and possibly weight-loaded) VAE module.
    """
    model = VAE_3D_V2(**kwargs)
    if from_pretrained is not None:
        # model_name="model" selects the generator sub-checkpoint, so the same
        # checkpoint directory can also hold a "discriminator" entry.
        load_checkpoint(model, from_pretrained, model_name="model")
    return model
|
||||
|
||||
@MODELS.register_module("DISCRIMINATOR_3D")
def DISCRIMINATOR_3D(from_pretrained=None, inflate_from_2d=False, use_pretrained=True, **kwargs):
    """Build a 3D StyleGAN-style discriminator, optionally loading weights.

    Args:
        from_pretrained: checkpoint path (file or sharded directory); ``None``
            skips weight loading entirely.
        inflate_from_2d: if True, load a 2D image-pretrained checkpoint and
            inflate it to 3D via ``load_checkpoint_with_inflation``.
        use_pretrained: if False, ignore ``from_pretrained`` and keep the
            fresh Xavier-initialized weights (useful at inference time when
            the discriminator state is restored elsewhere).
        **kwargs: forwarded to the ``StyleGANDiscriminatorBlur`` constructor.

    Returns:
        The constructed (and possibly weight-loaded) discriminator module.
    """
    model = StyleGANDiscriminatorBlur(**kwargs).apply(xavier_uniform_weight_init)
    # model = StyleGANDiscriminator(**kwargs).apply(xavier_uniform_weight_init) # SCH: DEBUG: to change back
    # model = NLayerDiscriminator3D(input_nc=3, n_layers=3,).apply(n_layer_disc_weights_init)
    if from_pretrained is not None:
        if use_pretrained:
            if inflate_from_2d:
                load_checkpoint_with_inflation(model, from_pretrained)
            else:
                # model_name="discriminator" selects the discriminator
                # sub-checkpoint from a combined VAE/discriminator directory.
                load_checkpoint(model, from_pretrained, model_name="discriminator")
                print(f"loading from:{use_pretrained}")
        else:
            print(f"discriminator use_pretrained={use_pretrained}, initializing new discriminator")

    return model
|
||||
|
||||
|
||||
|
||||
def load_checkpoint_with_inflation(model, ckpt_path):
|
||||
"""
|
||||
pre-train using image, then inflate to 3D videos
|
||||
|
|
|
|||
|
|
@ -80,9 +80,9 @@ def download_model(model_name):
|
|||
return model
|
||||
|
||||
|
||||
def load_from_sharded_state_dict(model, ckpt_path, model_name="model"):
    """Load a sharded (ColossalAI) checkpoint into ``model``.

    Args:
        model: module to receive the weights.
        ckpt_path: checkpoint root directory.
        model_name: sub-directory under ``ckpt_path`` holding the shards
            (e.g. "model" for the generator, "discriminator" for the
            discriminator). Defaults to "model" for backward compatibility.
    """
    ckpt_io = GeneralCheckpointIO()
    ckpt_io.load_model(model, os.path.join(ckpt_path, model_name))
|
||||
|
||||
def model_sharding(model: torch.nn.Module):
|
||||
global_rank = dist.get_rank()
|
||||
|
|
@ -203,14 +203,14 @@ def create_logger(logging_dir):
|
|||
return logger
|
||||
|
||||
|
||||
def load_checkpoint(model, ckpt_path, save_as_pt=True, model_name="model"):
    """Load weights into ``model`` from a ``.pt``/``.pth`` file or a sharded dir.

    Args:
        model: module to receive the weights.
        ckpt_path: path to a ``.pt``/``.pth`` state-dict file, or a directory
            containing a sharded checkpoint.
        save_as_pt: after a sharded load, also dump a consolidated
            ``model_ckpt.pt`` next to the shards for faster future loads.
        model_name: which sub-checkpoint to load from a sharded directory
            ("model" or "discriminator"); ignored for ``.pt``/``.pth`` files.
    """
    if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
        state_dict = find_model(ckpt_path)
        # strict=False: tolerate architecture drift; surface the differences.
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        print(f"Missing keys: {missing_keys}")
        print(f"Unexpected keys: {unexpected_keys}")
    elif os.path.isdir(ckpt_path):
        load_from_sharded_state_dict(model, ckpt_path, model_name)
        if save_as_pt:
            save_path = os.path.join(ckpt_path, "model_ckpt.pt")
            torch.save(model.state_dict(), save_path)
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ def parse_args(training=False):
|
|||
def merge_args(cfg, args, training=False):
|
||||
if args.ckpt_path is not None:
|
||||
cfg.model["from_pretrained"] = args.ckpt_path
|
||||
if cfg.get("discriminator") is not None:
|
||||
cfg.discriminator["from_pretrained"] = args.ckpt_path
|
||||
args.ckpt_path = None
|
||||
|
||||
for k, v in vars(args).items():
|
||||
|
|
|
|||
Loading…
Reference in a new issue