diff --git a/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py b/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py
index dab0c98..b356c46 100644
--- a/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py
+++ b/configs/vae_magvit_v2/train/16x128x128_pixabay_8_GPU_pipeline.py
@@ -65,7 +65,7 @@ discriminator_loss_type="non-saturating"
 generator_loss_type="non-saturating"
 # discriminator_loss_type="hinge"
 # generator_loss_type="hinge"
-discriminator_start = 50000 # 50000 NOTE: change to correct val, debug use -1 for now
+discriminator_start = 100 # 8k data / (8*32) = 31 steps per epoch, use around 3 epochs
 gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
 
 ema_decay = 0.999  # ema decay factor for generator
@@ -84,7 +84,7 @@ magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
 
 epochs = 200
 log_every = 1
-ckpt_every = 1000
+ckpt_every = 50
 load = None
 
 batch_size = 32
diff --git a/opensora/models/vae/lpips.py b/opensora/models/vae/lpips.py
index 3dde500..b22cd0d 100644
--- a/opensora/models/vae/lpips.py
+++ b/opensora/models/vae/lpips.py
@@ -2,7 +2,48 @@ import torch
 import torch.nn as nn
 from torchvision import models
 from collections import namedtuple
-from taming.util import get_ckpt_path
+import os, hashlib
+import requests
+from tqdm import tqdm
+
+URL_MAP = {
+    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
+}
+
+CKPT_MAP = {
+    "vgg_lpips": "vgg.pth"
+}
+
+MD5_MAP = {
+    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
+}
+
+def md5_hash(path):
+    with open(path, "rb") as f:
+        content = f.read()
+    return hashlib.md5(content).hexdigest()
+
+def download(url, local_path, chunk_size=1024):
+    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+    with requests.get(url, stream=True) as r:
+        total_size = int(r.headers.get("content-length", 0))
+        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+            with open(local_path, "wb") as f:
+                for data in r.iter_content(chunk_size=chunk_size):
+                    if data:
+                        f.write(data)
+                        pbar.update(len(data))
+
+def get_ckpt_path(name, root, check=False):
+    assert name in URL_MAP
+    path = os.path.join(root, CKPT_MAP[name])
+    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+        download(URL_MAP[name], path)
+        md5 = md5_hash(path)
+        assert md5 == MD5_MAP[name], md5
+    return path
+
 
 
 class LPIPS(nn.Module):
diff --git a/opensora/models/vae/vae_3d_v2.py b/opensora/models/vae/vae_3d_v2.py
index 39e687a..628c013 100644
--- a/opensora/models/vae/vae_3d_v2.py
+++ b/opensora/models/vae/vae_3d_v2.py
@@ -12,7 +12,7 @@ from einops import rearrange, repeat, pack, unpack
 import torch.nn.functional as F
 import torchvision
 from torchvision.models import VGG16_Weights
-from taming.modules.losses.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
+from opensora.models.vae.lpips import LPIPS # vendored LPIPS; taming-transformers is no longer required
 from torch import nn
 import math
 import os
diff --git a/scripts/train-vae-v2.py b/scripts/train-vae-v2.py
index 0b19998..0f5993c 100644
--- a/scripts/train-vae-v2.py
+++ b/scripts/train-vae-v2.py
@@ -53,7 +53,6 @@ def main():
     # ======================================================
     # 1. args & cfg
     # ======================================================
     cfg = parse_configs(training=True)
-    print(cfg)
     exp_name, exp_dir = create_experiment_workspace(cfg)
     save_training_config(cfg._cfg_dict, exp_dir)
@@ -73,6 +72,7 @@ def main():
 
     if not coordinator.is_master():
         logger = create_logger(None)
     else:
+        print(cfg)
         logger = create_logger(exp_dir)
         logger.info(f"Experiment directory created at {exp_dir}")