This commit is contained in:
Shen-Chenhui 2024-04-01 15:38:29 +08:00
parent 996e6fb180
commit 036b427b00

View file

@ -67,8 +67,8 @@ class LPIPS(nn.Module):
class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None, None])
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None, None])
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
def forward(self, inp):
return (inp - self.shift) / self.scale