diff --git a/opensora/models/vae/vae_3d.py b/opensora/models/vae/vae_3d.py index d4a47ad..4bed5ad 100644 --- a/opensora/models/vae/vae_3d.py +++ b/opensora/models/vae/vae_3d.py @@ -681,7 +681,7 @@ class Decoder(nn.Module): # conv blocks with upsampling if i > 0: - if self.temporal_downsample[i]: + if self.temporal_downsample[i-1]: t_stride = 2 if self.temporal_downsample[i - 1] else 1 # SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2 self.conv_blocks.insert( @@ -697,7 +697,7 @@ class Decoder(nn.Module): prev_filters ), ) - + self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, device=device, dtype=dtype) diff --git a/scripts/inference-vae.py b/scripts/inference-vae.py index 63030db..b6a4e52 100644 --- a/scripts/inference-vae.py +++ b/scripts/inference-vae.py @@ -9,8 +9,8 @@ from tqdm import tqdm from opensora.acceleration.parallel_states import get_data_parallel_group from opensora.datasets import prepare_dataloader, save_sample -from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VEALoss -from opensora.models.vae.vae_3d import LeCamEMA, pad_at_dim +from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VEALoss, LeCamEMA +from opensora.models.vae.vae_3d import pad_at_dim from opensora.registry import DATASETS, MODELS, build_module from opensora.utils.config_utils import parse_configs from opensora.utils.misc import to_torch_dtype