mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
parent
11faae8db6
commit
b8d99d7a8b
|
|
@ -20,7 +20,9 @@ model = dict(
|
|||
vae = dict(
|
||||
type="VideoAutoencoderPipeline",
|
||||
from_pretrained="pretrained_models/vae-v3",
|
||||
scale=2.5,
|
||||
# scale=2.5,
|
||||
shift=(-0.10, 0.34, 0.27, 0.98),
|
||||
scale=(3.85, 2.32, 2.33, 3.06),
|
||||
micro_frame_size=17,
|
||||
vae_2d=dict(
|
||||
type="VideoAutoencoderKL",
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ class VideoAutoencoderPipeline(nn.Module):
|
|||
|
||||
def decode(self, z, num_frames=None):
|
||||
if not self.cal_loss:
|
||||
z = z * self.scale + self.shift
|
||||
z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype)
|
||||
|
||||
if self.micro_frame_size is None:
|
||||
x_z = self.temporal_vae.decode(z, num_frames=num_frames)
|
||||
|
|
|
|||
Loading…
Reference in a new issue