Open-Sora/opensora/models/text/conditioner.py

from colossalai.shardformer import ShardConfig, ShardFormer
from torch import Tensor, nn
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

from opensora.acceleration.shardformer.policy.t5_encoder import T5EncoderPolicy
from opensora.registry import MODELS


@MODELS.register_module("text_embedder")
class HFEmbedder(nn.Module):
    """Text embedder wrapping a Hugging Face CLIP or T5 text encoder."""

    def __init__(self, from_pretrained: str, max_length: int, shardformer: bool = False, **hf_kwargs):
        super().__init__()
        self.is_clip = "openai" in from_pretrained
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(from_pretrained, max_length=max_length)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(from_pretrained, **hf_kwargs)
            assert not shardformer, "Shardformer is not supported for CLIP"
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                from_pretrained, max_length=max_length, legacy=True
            )
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(from_pretrained, **hf_kwargs)
            if shardformer:
                self.hf_module = shardformer_t5(self.hf_module)

        self.hf_module = self.hf_module.eval().requires_grad_(False)

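    # The representation returned by forward() follows self.output_key: the pooled
    # "pooler_output" for CLIP checkpoints (paths containing "openai") and the
    # per-token "last_hidden_state" for T5 checkpoints.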
    def forward(self, text: list[str], added_tokens: int = 0, seq_align: int = 1) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        # Pad input_ids so that (added_tokens + seq_len) is a multiple of seq_align.
        seq_len = batch_encoding["input_ids"].shape[1]
        if (added_tokens + seq_len) % seq_align != 0:
            num_pad_tokens = seq_align - (added_tokens + seq_len) % seq_align
            batch_encoding["input_ids"] = nn.functional.pad(
                batch_encoding["input_ids"], (0, num_pad_tokens), value=self.tokenizer.pad_token_id
            )
        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
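
# Hypothetical usage sketch (the checkpoint name and prompt below are illustrative
# assumptions, not values taken from Open-Sora configs): with max_length=300 and
# seq_align=8, the 300 prompt tokens are padded to 304 so the length divides by 8.
#
#   t5 = HFEmbedder("google/t5-v1_1-base", max_length=300)
#   feats = t5(["a panda surfing at sunset"], seq_align=8)  # [1, 304, hidden_size]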


def shardformer_t5(t5: T5EncoderModel) -> T5EncoderModel:
    """
    Shardformer for T5 model

    Args:
        t5: T5 model to be optimized

    Returns:
        optimized T5 model
    """
    dtype = t5.shared.weight.dtype
    shard_config = ShardConfig(
        enable_tensor_parallelism=False,
        enable_jit_fused=True,
    )
    shard_former = ShardFormer(shard_config=shard_config)
    optim_model, _ = shard_former.optimize(t5, policy=T5EncoderPolicy())
    optim_model = optim_model.to(dtype).eval().requires_grad_(False)
    return optim_model
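

# Minimal smoke-test sketch for the shardformer path (assumptions: colossalai is
# installed and "google/t5-v1_1-base" is a stand-in checkpoint; shardformer_t5 is
# reached through the `shardformer` flag and applies to T5 encoders only).
if __name__ == "__main__":
    embedder = HFEmbedder("google/t5-v1_1-base", max_length=120, shardformer=True)
    feats = embedder(["a short prompt"], seq_align=16)
    print(feats.shape)  # (1, 128, hidden_size): 120 tokens padded up to 128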