This commit is contained in:
Shen-Chenhui 2024-04-29 20:12:38 +08:00
parent c584743d50
commit da5f42d488
2 changed files with 4 additions and 4 deletions

View file

@ -681,7 +681,7 @@ class Decoder(nn.Module):
# conv blocks with upsampling
if i > 0:
if self.temporal_downsample[i]:
if self.temporal_downsample[i-1]:
t_stride = 2 if self.temporal_downsample[i - 1] else 1
# SCH: T-Causal Conv 3x3x3, f -> (t_stride * 2 * 2) * f, depth to space t_stride x 2 x 2
self.conv_blocks.insert(
@ -697,7 +697,7 @@ class Decoder(nn.Module):
prev_filters
),
)
self.norm1 = nn.GroupNorm(self.num_groups, prev_filters, device=device, dtype=dtype)

View file

@ -9,8 +9,8 @@ from tqdm import tqdm
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader, save_sample
from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VEALoss
from opensora.models.vae.vae_3d import LeCamEMA, pad_at_dim
from opensora.models.vae.losses import AdversarialLoss, DiscriminatorLoss, VEALoss, LeCamEMA
from opensora.models.vae.vae_3d import pad_at_dim
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype