Open-Sora/opensora/models/text_encoder/classes.py
Zheng Zangwei (Alex Zheng) f1c6b8b88e open-sora v1.3 code upload (#786)
Co-authored-by: gxyes <gxynoz@gmail.com>
2025-02-20 16:50:24 +08:00

22 lines
610 B
Python

import torch
from opensora.registry import MODELS
@MODELS.register_module("classes")
class ClassEncoder:
    """Pass-through "text encoder" for class-label conditioning.

    Instead of embedding text, ``encode`` forwards the given labels unchanged
    (only moved to ``device``) under the ``"y"`` key, and ``null`` produces a
    batch of labels equal to ``num_classes``.

    ``y_embedder``, ``output_dim`` and ``tokenize_fn`` are set to ``None`` and
    never assigned anything else by this class.
    NOTE(review): presumably they exist so callers can treat this object like
    the other registered text encoders — confirm against the callers.
    """

    def __init__(self, num_classes, model_max_length=None, device="cuda", dtype=torch.float):
        """
        Args:
            num_classes: Number of classes. Also used as the label value
                returned by :meth:`null`.
            model_max_length: Stored as-is; not otherwise used by this class.
            device: Device that labels are moved to (``encode``) or created on
                (``null``).
            dtype: Accepted but unused — labels are integer indices.
                Presumably kept for constructor-signature compatibility with
                other encoders/configs; TODO confirm.
        """
        self.num_classes = num_classes
        self.y_embedder = None
        self.model_max_length = model_max_length
        self.output_dim = None
        self.device = device
        self.tokenize_fn = None

    def encode(self, input_ids, attention_mask=None):
        """Wrap the given labels for the model without transforming them.

        Args:
            input_ids: Label tensor (anything with a ``.to(device)`` method).
                Presumably integer class indices — TODO confirm with callers.
            attention_mask: Ignored; accepted for interface compatibility.

        Returns:
            ``{"y": input_ids}`` with ``input_ids`` moved to ``self.device``.
        """
        return dict(y=input_ids.to(self.device))

    def null(self, n):
        """Return ``n`` copies of the null class label.

        The label value is ``num_classes``.
        NOTE(review): presumably this is the extra "no class" index used for
        unconditional inputs — confirm with the class embedder.

        Args:
            n: Number of labels to produce (presumably the batch size).

        Returns:
            A 1-D ``torch.long`` tensor of shape ``(n,)`` on ``self.device``
            with every entry equal to ``num_classes``.
        """
        # Explicit dtype: building from a Python list lets torch infer the dtype,
        # and an empty list (n == 0) infers float32 rather than int64.
        # Allocating directly on the target device also skips a host->device copy.
        return torch.full((n,), self.num_classes, dtype=torch.long, device=self.device)