mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
add local_files_only
This commit is contained in:
parent
dce2ef4a1c
commit
8bf34f4418
|
|
@ -47,12 +47,14 @@ vae = dict(
|
|||
type="VideoAutoencoderKL",
|
||||
from_pretrained="stabilityai/sd-vae-ft-ema",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
)
|
||||
text_encoder = dict(
|
||||
type="t5",
|
||||
from_pretrained="DeepFloyd/t5-v1_1-xxl",
|
||||
model_max_length=200,
|
||||
shardformer=True,
|
||||
local_files_only=True,
|
||||
)
|
||||
scheduler = dict(
|
||||
type="iddpm-speed",
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ class T5Embedder:
|
|||
torch_dtype=None,
|
||||
use_offload_folder=None,
|
||||
model_max_length=120,
|
||||
local_files_only=False,
|
||||
):
|
||||
self.device = torch.device(device)
|
||||
self.torch_dtype = torch_dtype or torch.bfloat16
|
||||
|
|
@ -99,8 +100,17 @@ class T5Embedder:
|
|||
self.hf_token = hf_token
|
||||
|
||||
assert from_pretrained in self.available_models
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(from_pretrained, cache_dir=cache_dir)
|
||||
self.model = T5EncoderModel.from_pretrained(from_pretrained, cache_dir=cache_dir, **t5_model_kwargs).eval()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
from_pretrained,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
self.model = T5EncoderModel.from_pretrained(
|
||||
from_pretrained,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
**t5_model_kwargs,
|
||||
).eval()
|
||||
self.model_max_length = model_max_length
|
||||
|
||||
def get_text_embeddings(self, texts):
|
||||
|
|
@ -134,6 +144,7 @@ class T5Encoder:
|
|||
dtype=torch.float,
|
||||
cache_dir=None,
|
||||
shardformer=False,
|
||||
local_files_only=False,
|
||||
):
|
||||
assert from_pretrained is not None, "Please specify the path to the T5 model"
|
||||
|
||||
|
|
@ -143,6 +154,7 @@ class T5Encoder:
|
|||
from_pretrained=from_pretrained,
|
||||
cache_dir=cache_dir,
|
||||
model_max_length=model_max_length,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
self.t5.model.to(dtype=dtype)
|
||||
self.y_embedder = None
|
||||
|
|
|
|||
|
|
@ -8,9 +8,11 @@ from opensora.registry import MODELS
|
|||
|
||||
@MODELS.register_module()
|
||||
class VideoAutoencoderKL(nn.Module):
|
||||
def __init__(self, from_pretrained=None, micro_batch_size=None, cache_dir=None):
|
||||
def __init__(self, from_pretrained=None, micro_batch_size=None, cache_dir=None, local_files_only=False):
|
||||
super().__init__()
|
||||
self.module = AutoencoderKL.from_pretrained(from_pretrained, cache_dir=cache_dir)
|
||||
self.module = AutoencoderKL.from_pretrained(
|
||||
from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only
|
||||
)
|
||||
self.out_channels = self.module.config.latent_channels
|
||||
self.patch_size = (1, 8, 8)
|
||||
self.micro_batch_size = micro_batch_size
|
||||
|
|
@ -70,9 +72,11 @@ class VideoAutoencoderKL(nn.Module):
|
|||
|
||||
@MODELS.register_module()
|
||||
class VideoAutoencoderKLTemporalDecoder(nn.Module):
|
||||
def __init__(self, from_pretrained=None, cache_dir=None):
|
||||
def __init__(self, from_pretrained=None, cache_dir=None, local_files_only=False):
|
||||
super().__init__()
|
||||
self.module = AutoencoderKLTemporalDecoder.from_pretrained(from_pretrained, cache_dir=cache_dir)
|
||||
self.module = AutoencoderKLTemporalDecoder.from_pretrained(
|
||||
from_pretrained, cache_dir=cache_dir, local_files_only=local_files_only
|
||||
)
|
||||
self.out_channels = self.module.config.latent_channels
|
||||
self.patch_size = (1, 8, 8)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue