From 137d0ac22348b9d0bee0bd7b66ce2a1ee7a12153 Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Thu, 16 May 2024 08:33:01 +0000 Subject: [PATCH] [fix] t5 dtype --- opensora/models/text_encoder/t5.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/opensora/models/text_encoder/t5.py b/opensora/models/text_encoder/t5.py index aaf2ecf..3a021d7 100644 --- a/opensora/models/text_encoder/t5.py +++ b/opensora/models/text_encoder/t5.py @@ -161,6 +161,7 @@ class T5Encoder: self.model_max_length = model_max_length self.output_dim = self.t5.model.config.d_model + self.dtype = dtype if shardformer: self.shardformer_t5() @@ -183,7 +184,7 @@ class T5Encoder: ) shard_former = ShardFormer(shard_config=shard_config) optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy()) - self.t5.model = optim_model.half() + self.t5.model = optim_model.to(self.dtype) # ensure the weights are frozen requires_grad(self.t5.model, False)