From b8d99d7a8b65d0e07513df412c5c1e098e489235 Mon Sep 17 00:00:00 2001 From: Shen Chenhui Date: Mon, 13 May 2024 11:58:06 +0800 Subject: [PATCH] fixed (#98) Co-authored-by: Shen-Chenhui --- configs/opensora-v1-2/inference/sample.py | 4 +++- opensora/models/vae/vae.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/configs/opensora-v1-2/inference/sample.py b/configs/opensora-v1-2/inference/sample.py index c46e4dd..65ccc1b 100644 --- a/configs/opensora-v1-2/inference/sample.py +++ b/configs/opensora-v1-2/inference/sample.py @@ -20,7 +20,9 @@ model = dict( vae = dict( type="VideoAutoencoderPipeline", from_pretrained="pretrained_models/vae-v3", - scale=2.5, + # scale=2.5, + shift=(-0.10, 0.34, 0.27, 0.98), + scale=(3.85, 2.32, 2.33, 3.06), micro_frame_size=17, vae_2d=dict( type="VideoAutoencoderKL", diff --git a/opensora/models/vae/vae.py b/opensora/models/vae/vae.py index 8ac8729..2115e63 100644 --- a/opensora/models/vae/vae.py +++ b/opensora/models/vae/vae.py @@ -171,7 +171,7 @@ class VideoAutoencoderPipeline(nn.Module): def decode(self, z, num_frames=None): if not self.cal_loss: - z = z * self.scale + self.shift + z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype) if self.micro_frame_size is None: x_z = self.temporal_vae.decode(z, num_frames=num_frames)