diff --git a/opensora/models/text_encoder/t5.py b/opensora/models/text_encoder/t5.py index aaf2ecf..3a021d7 100644 --- a/opensora/models/text_encoder/t5.py +++ b/opensora/models/text_encoder/t5.py @@ -161,6 +161,7 @@ class T5Encoder: self.model_max_length = model_max_length self.output_dim = self.t5.model.config.d_model + self.dtype = dtype if shardformer: self.shardformer_t5() @@ -183,7 +184,7 @@ class T5Encoder: ) shard_former = ShardFormer(shard_config=shard_config) optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy()) - self.t5.model = optim_model.half() + self.t5.model = optim_model.to(self.dtype) # ensure the weights are frozen requires_grad(self.t5.model, False)