mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 12:49:38 +02:00
[fix] t5 dtype
This commit is contained in:
parent
b8b0398eaa
commit
137d0ac223
|
|
@ -161,6 +161,7 @@ class T5Encoder:
|
|||
|
||||
self.model_max_length = model_max_length
|
||||
self.output_dim = self.t5.model.config.d_model
|
||||
self.dtype = dtype
|
||||
|
||||
if shardformer:
|
||||
self.shardformer_t5()
|
||||
|
|
@ -183,7 +184,7 @@ class T5Encoder:
|
|||
)
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy())
|
||||
self.t5.model = optim_model.half()
|
||||
self.t5.model = optim_model.to(self.dtype)
|
||||
|
||||
# ensure the weights are frozen
|
||||
requires_grad(self.t5.model, False)
|
||||
|
|
|
|||
Loading…
Reference in a new issue