[fix] t5 dtype

This commit is contained in:
zhengzangw 2024-05-16 08:33:01 +00:00
parent b8b0398eaa
commit 137d0ac223

View file

@@ -161,6 +161,7 @@ class T5Encoder:
self.model_max_length = model_max_length
self.output_dim = self.t5.model.config.d_model
self.dtype = dtype
if shardformer:
self.shardformer_t5()
@@ -183,7 +184,7 @@ class T5Encoder:
)
shard_former = ShardFormer(shard_config=shard_config)
optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy())
self.t5.model = optim_model.half()
self.t5.model = optim_model.to(self.dtype)
# ensure the weights are frozen
requires_grad(self.t5.model, False)