add local_files_only

This commit is contained in:
Zangwei Zheng 2024-04-19 16:47:26 +08:00
parent dce2ef4a1c
commit 8bf34f4418
3 changed files with 24 additions and 6 deletions

View file

@@ -47,12 +47,14 @@ vae = dict(
     type="VideoAutoencoderKL",
     from_pretrained="stabilityai/sd-vae-ft-ema",
     micro_batch_size=4,
+    local_files_only=True,
 )
 text_encoder = dict(
     type="t5",
     from_pretrained="DeepFloyd/t5-v1_1-xxl",
     model_max_length=200,
     shardformer=True,
+    local_files_only=True,
 )
 scheduler = dict(
     type="iddpm-speed",

View file

@@ -46,6 +46,7 @@ class T5Embedder:
         torch_dtype=None,
         use_offload_folder=None,
         model_max_length=120,
+        local_files_only=False,
     ):
         self.device = torch.device(device)
         self.torch_dtype = torch_dtype or torch.bfloat16
@@ -99,8 +100,17 @@ class T5Embedder:
         self.hf_token = hf_token
         assert from_pretrained in self.available_models
-        self.tokenizer = AutoTokenizer.from_pretrained(from_pretrained, cache_dir=cache_dir)
-        self.model = T5EncoderModel.from_pretrained(from_pretrained, cache_dir=cache_dir, **t5_model_kwargs).eval()
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            from_pretrained,
+            cache_dir=cache_dir,
+            local_files_only=local_files_only,
+        )
+        self.model = T5EncoderModel.from_pretrained(
+            from_pretrained,
+            cache_dir=cache_dir,
+            local_files_only=local_files_only,
+            **t5_model_kwargs,
+        ).eval()
         self.model_max_length = model_max_length

     def get_text_embeddings(self, texts):
@@ -134,6 +144,7 @@ class T5Encoder:
         dtype=torch.float,
         cache_dir=None,
         shardformer=False,
+        local_files_only=False,
     ):
         assert from_pretrained is not None, "Please specify the path to the T5 model"
@@ -143,6 +154,7 @@ class T5Encoder:
             from_pretrained=from_pretrained,
             cache_dir=cache_dir,
             model_max_length=model_max_length,
+            local_files_only=local_files_only,
         )
         self.t5.model.to(dtype=dtype)
         self.y_embedder = None

View file

@@ -8,9 +8,11 @@ from opensora.registry import MODELS

 @MODELS.register_module()
 class VideoAutoencoderKL(nn.Module):
-    def __init__(self, from_pretrained=None, micro_batch_size=None, cache_dir=None):
+    def __init__(self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False):
         super().__init__()
-        self.module = AutoencoderKL.from_pretrained(from_pretrained, cache_dir=cache_dir)
+        self.module = AutoencoderKL.from_pretrained(
+            from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only
+        )
         self.out_channels = self.module.config.latent_channels
         self.patch_size = (1, 8, 8)
         self.micro_batch_size = micro_batch_size
@@ -70,9 +72,11 @@ class VideoAutoencoderKL(nn.Module):

 @MODELS.register_module()
 class VideoAutoencoderKLTemporalDecoder(nn.Module):
-    def __init__(self, from_pretrained=None, cache_dir=None):
+    def __init__(self, from_pretrained=None, cache_dir=None, local_files_only=False):
         super().__init__()
-        self.module = AutoencoderKLTemporalDecoder.from_pretrained(from_pretrained, cache_dir=cache_dir)
+        self.module = AutoencoderKLTemporalDecoder.from_pretrained(
+            from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only
+        )
         self.out_channels = self.module.config.latent_channels
         self.patch_size = (1, 8, 8)