mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-09 09:04:09 +02:00
debug
This commit is contained in:
parent
4741d952d8
commit
dbe1c4cf2f
|
|
@ -114,12 +114,11 @@ class CausalConv3d(nn.Module):
|
|||
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
|
||||
|
||||
stride = strides if strides is not None else (stride, 1, 1)
|
||||
padding = kwargs.pop('padding', 0)
|
||||
|
||||
if padding == "same" and not all([pad == 1 for pad in padding]):
|
||||
padding = "valid"
|
||||
# padding = kwargs.pop('padding', 0)
|
||||
# if padding == "same" and not all([pad == 1 for pad in padding]):
|
||||
# padding = "valid"
|
||||
dilation = (dilation, 1, 1)
|
||||
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, padding=padding, **kwargs)
|
||||
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'
|
||||
|
|
@ -372,7 +371,7 @@ class Encoder(nn.Module):
|
|||
|
||||
self.conv_fn = functools.partial(
|
||||
CausalConv3d,
|
||||
padding='valid' if self.custom_conv_padding is not None else 'same', # SCH: lower letter for pytorch
|
||||
# padding='valid' if self.custom_conv_padding is not None else 'same', # SCH: lower letter for pytorch
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
|
@ -485,7 +484,7 @@ class Decoder(nn.Module):
|
|||
self.conv_fn = functools.partial(
|
||||
CausalConv3d,
|
||||
dtype=dtype,
|
||||
padding='valid' if self.custom_conv_padding is not None else 'same', # SCH: lower letter for pytorch
|
||||
# padding='valid' if self.custom_conv_padding is not None else 'same', # SCH: lower letter for pytorch
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue