mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-07 04:41:22 +02:00
* upload v2.0 * update docs * [hotfix] fit latest fa3 (#802) * update readme * update readme * update readme * update train readme * update readme * update readme: motion score * cleaning video dc ae WIP * update config * add dependency functions * undo cleaning * use latest dcae * complete high compression training * update hcae config * cleaned up vae * update ae.md * further cleanup * update vae & ae paths * align naming of ae * [hotfix] fix ring attn bwd for fa3 (#803) * train ae default without wandb * update config * update evaluation results * added hcae report * update readme * update readme demo * update readme demo * update readme gif * display demo directly in readme * update paper * delete files --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: Shen-Chenhui <shen_chenhui@u.nus.edu> Co-authored-by: wuxiwen <wuxiwen.simon@gmail.com>
75 lines
2.9 KiB
Python
75 lines
2.9 KiB
Python
from colossalai.shardformer import ShardConfig, ShardFormer
|
|
from torch import Tensor, nn
|
|
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
|
|
|
from opensora.acceleration.shardformer.policy.t5_encoder import T5EncoderPolicy
|
|
from opensora.registry import MODELS
|
|
|
|
|
|
@MODELS.register_module("text_embedder")
class HFEmbedder(nn.Module):
    """Frozen Hugging Face text encoder (CLIP or T5) wrapped as an ``nn.Module``.

    The encoder family is picked from the checkpoint identifier: when
    ``from_pretrained`` contains the substring ``"openai"`` a CLIP text model
    is loaded and its ``pooler_output`` is returned by :meth:`forward`;
    otherwise a T5 encoder is loaded and ``last_hidden_state`` is returned.
    The loaded model is put in eval mode with gradients disabled.
    """

    def __init__(self, from_pretrained: str, max_length: int, shardformer: bool = False, **hf_kwargs):
        """Load the tokenizer and text encoder, then freeze the encoder.

        Args:
            from_pretrained: Checkpoint name or path handed to the Hugging
                Face ``from_pretrained`` loaders. A value containing
                ``"openai"`` selects CLIP; anything else selects T5.
            max_length: Maximum token length; passed to the tokenizer at
                construction time and again on every :meth:`forward` call.
            shardformer: If true, pass the T5 encoder through
                ``shardformer_t5``. Not supported for CLIP.
            **hf_kwargs: Extra keyword arguments forwarded to the model's
                ``from_pretrained`` call (not to the tokenizer's).

        Raises:
            AssertionError: If ``shardformer`` is requested for a CLIP
                checkpoint.
        """
        super().__init__()
        self.is_clip = "openai" in from_pretrained
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        # Fail fast, before any weights are loaded. Raised explicitly (rather
        # than via an ``assert`` statement) so the check survives ``python -O``;
        # the exception type and message are the same as the former assertion.
        if self.is_clip and shardformer:
            raise AssertionError("Shardformer is not supported for CLIP")

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(from_pretrained, max_length=max_length)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(from_pretrained, **hf_kwargs)
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                from_pretrained, max_length=max_length, legacy=True
            )
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(from_pretrained, **hf_kwargs)
            if shardformer:
                self.hf_module = shardformer_t5(self.hf_module)

        # The encoder is used for inference only: eval mode, no gradients.
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str], added_tokens: int = 0, seq_align: int = 1) -> Tensor:
        """Tokenize ``text`` and run it through the frozen encoder.

        Prompts are truncated and padded to ``self.max_length``. If
        ``added_tokens`` plus the tokenized length is not a multiple of
        ``seq_align``, the token ids are right-padded with the tokenizer's
        pad token id until it is.

        Args:
            text: Prompts to encode.
            added_tokens: Number of extra tokens counted together with the
                text tokens when computing the alignment padding.
            seq_align: Positive integer the combined length
                (``added_tokens`` + text tokens) must be a multiple of.

        Returns:
            The ``pooler_output`` entry of the model output for CLIP, or the
            ``last_hidden_state`` entry for T5.

        Raises:
            ValueError: If ``seq_align`` is smaller than 1.
        """
        # A zero alignment would divide by zero and a negative one would yield
        # a negative pad width (i.e. silently crop tokens), so reject both.
        if seq_align < 1:
            raise ValueError(f"seq_align must be a positive integer, got {seq_align}")

        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = batch_encoding["input_ids"]

        # Right-pad so that (added_tokens + sequence length) is a multiple of seq_align.
        remainder = (added_tokens + input_ids.shape[1]) % seq_align
        if remainder != 0:
            input_ids = nn.functional.pad(
                input_ids, (0, seq_align - remainder), value=self.tokenizer.pad_token_id
            )

        # No attention mask is supplied to the encoder.
        outputs = self.hf_module(
            input_ids=input_ids.to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
|
|
|
|
|
|
def shardformer_t5(t5: T5EncoderModel) -> T5EncoderModel:
    """Optimize a T5 encoder with ColossalAI's ShardFormer and ``T5EncoderPolicy``.

    The shard configuration keeps tensor parallelism disabled and sets
    ``enable_jit_fused=True``. The optimized model is cast back to the dtype
    of the incoming shared embedding weight, switched to eval mode and has
    gradients disabled.

    Args:
        t5: T5 encoder model to be optimized.

    Returns:
        The optimized model, in eval mode with ``requires_grad`` turned off.
    """
    # Capture the dtype up front so it can be re-applied after optimization.
    original_dtype = t5.shared.weight.dtype

    former = ShardFormer(
        shard_config=ShardConfig(enable_tensor_parallelism=False, enable_jit_fused=True)
    )
    # optimize() returns a pair; only the first element (the model) is used.
    sharded, _ = former.optimize(t5, policy=T5EncoderPolicy())

    sharded = sharded.to(original_dtype)
    sharded = sharded.eval()
    return sharded.requires_grad_(False)
|