This commit is contained in:
Shen-Chenhui 2024-04-01 16:06:47 +08:00
parent 036b427b00
commit 6f460c1d05

View file

@ -302,7 +302,9 @@ class VEA3DLossWithPerceptualLoss(nn.Module):
permutated_input.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous(),
permutated_rec.reshape(-1, data_shape[-3], data_shape[-2],data_shape[-1]).contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
# SCH: shape back p_loss
permuted_p_loss = torch.permute(p_loss.reshape(data_shape[0], data_shape[1], 1, 1, 1), (0,2,1,3,4))
rec_loss = rec_loss + self.perceptual_weight * permuted_p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss