mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-20 09:22:22 +02:00
debug
This commit is contained in:
parent
036b427b00
commit
6f460c1d05
|
|
@ -302,7 +302,9 @@ class VEA3DLossWithPerceptualLoss(nn.Module):
|
|||
permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(),
|
||||
permutated_rec.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous()
|
||||
)
|
||||
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
||||
# SCH: shape back p_loss
|
||||
permuted_p_loss = torch.permute(p_loss.reshape(data_shape[0], data_shape[1], 1, 1, 1), (0,2,1,3,4))
|
||||
rec_loss = rec_loss + self.perceptual_weight * permuted_p_loss
|
||||
|
||||
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
||||
weighted_nll_loss = nll_loss
|
||||
|
|
|
|||
Loading…
Reference in a new issue