mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-17 14:25:07 +02:00
enable pipeline vae training
This commit is contained in:
parent
aba1de6eb3
commit
246eddecdd
|
|
@ -25,7 +25,7 @@ vae_2d = dict(
|
|||
|
||||
model = dict(
|
||||
type="VAE_MAGVIT_V2",
|
||||
in_out_channels = 3,
|
||||
in_out_channels = 4,
|
||||
latent_embed_dim = 256,
|
||||
filters = 128,
|
||||
num_res_blocks = 4,
|
||||
|
|
@ -44,7 +44,7 @@ discriminator = dict(
|
|||
type="DISCRIMINATOR_3D",
|
||||
image_size = image_size,
|
||||
num_frames = num_frames,
|
||||
in_channels = 3,
|
||||
in_channels = 4,
|
||||
filters = 128,
|
||||
channel_multipliers = (2,4,4,4,4), # (2,4,4,4) for 64x64 resolution
|
||||
)
|
||||
|
|
|
|||
|
|
@ -932,6 +932,8 @@ class VAE_3D_V2(nn.Module): # , ModelMixin
|
|||
video_contains_first_frame = video_contains_first_frame
|
||||
)
|
||||
|
||||
breakpoint()
|
||||
|
||||
return recon_video, posterior
|
||||
|
||||
|
||||
|
|
@ -991,13 +993,18 @@ class VEALoss(nn.Module):
|
|||
if self.perceptual_loss_weight is not None and self.perceptual_loss_weight > 0.0:
|
||||
# handle channels
|
||||
channels = video.shape[1]
|
||||
assert channels in {1,3,4}
|
||||
if channels == 1:
|
||||
input_vgg_input = repeat(input_vgg_input, 'b 1 h w -> b c h w', c = 3)
|
||||
recon_vgg_input = repeat(recon_vgg_input, 'b 1 h w -> b c h w', c = 3)
|
||||
input_vgg_input = repeat(video, 'b 1 h w -> b c h w', c = 3)
|
||||
recon_vgg_input = repeat(recon_video, 'b 1 h w -> b c h w', c = 3)
|
||||
elif channels == 4: # SCH: take the first 3 for perceptual loss calc
|
||||
input_vgg_input = input_vgg_input[:, :3]
|
||||
recon_vgg_input = recon_vgg_input[:, :3]
|
||||
perceptual_loss = self.perceptual_loss_fn(video, recon_video)
|
||||
input_vgg_input = video[:, :3]
|
||||
recon_vgg_input = recon_video[:, :3]
|
||||
else:
|
||||
input_vgg_input = video
|
||||
recon_vgg_input = recon_video
|
||||
|
||||
perceptual_loss = self.perceptual_loss_fn(input_vgg_input, recon_vgg_input)
|
||||
recon_loss = recon_loss + self.perceptual_loss_weight * perceptual_loss
|
||||
|
||||
nll_loss = recon_loss / torch.exp(self.logvar) + self.logvar
|
||||
|
|
@ -1223,7 +1230,6 @@ def load_checkpoint_with_inflation(model, ckpt_path):
|
|||
"""
|
||||
if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
|
||||
state_dict = find_model(ckpt_path)
|
||||
breakpoint() # NOTE: need to manually check before first use
|
||||
with torch.no_grad():
|
||||
for key in state_dict:
|
||||
if key in model:
|
||||
|
|
|
|||
|
|
@ -314,13 +314,11 @@ def main():
|
|||
video_contains_first_frame = True
|
||||
else:
|
||||
video = x
|
||||
|
||||
|
||||
# ===== Spatial VAE =====
|
||||
if cfg.get("use_pipeline") == True:
|
||||
with torch.no_grad():
|
||||
video = vae.encode(video)
|
||||
|
||||
breakpoint()
|
||||
video = vae_2d.encode(video)
|
||||
|
||||
# ====== VAE ======
|
||||
recon_video, posterior = vae(
|
||||
|
|
|
|||
Loading…
Reference in a new issue