mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-11 20:28:26 +02:00
22 lines
610 B
Python
22 lines
610 B
Python
import torch
|
|
|
|
from opensora.registry import MODELS
|
|
|
|
|
|
@MODELS.register_module("classes")
|
|
class ClassEncoder:
|
|
def __init__(self, num_classes, model_max_length=None, device="cuda", dtype=torch.float):
|
|
self.num_classes = num_classes
|
|
self.y_embedder = None
|
|
|
|
self.model_max_length = model_max_length
|
|
self.output_dim = None
|
|
self.device = device
|
|
self.tokenize_fn = None
|
|
|
|
def encode(self, input_ids, attention_mask=None):
|
|
return dict(y=input_ids.to(self.device))
|
|
|
|
def null(self, n):
|
|
return torch.tensor([self.num_classes] * n).to(self.device)
|