disable non-master print

This commit is contained in:
Shen-Chenhui 2024-04-23 14:43:02 +08:00
parent 8548b88ca8
commit d565477680
4 changed files with 46 additions and 5 deletions

View file

@ -65,7 +65,7 @@ discriminator_loss_type="non-saturating"
generator_loss_type="non-saturating"
# discriminator_loss_type="hinge"
# generator_loss_type="hinge"
discriminator_start = 50000 # 50000 NOTE: change to correct val, debug use -1 for now
discriminator_start = 100 # 8k data / (8*32) = 31 steps per epoch, use around 3 epochs
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
ema_decay = 0.999 # ema decay factor for generator
@ -84,7 +84,7 @@ magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
epochs = 200
log_every = 1
ckpt_every = 1000
ckpt_every = 50
load = None
batch_size = 32

View file

@ -2,7 +2,48 @@ import torch
import torch.nn as nn
from torchvision import models
from collections import namedtuple
from taming.util import get_ckpt_path
import os, hashlib
import requests
from tqdm import tqdm
# Download URL for each supported checkpoint name (passed to `download`).
URL_MAP = {
"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}
# File name, relative to the caller-supplied root directory, under which
# each checkpoint is stored locally.
CKPT_MAP = {
"vgg_lpips": "vgg.pth"
}
# Expected MD5 hex digest of each checkpoint file; compared against the
# digest of the file on disk to validate it.
MD5_MAP = {
"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}
def md5_hash(path, chunk_size=1 << 20):
    """Return the hexadecimal MD5 digest of the file at ``path``.

    The file is hashed incrementally so that large checkpoint files do not
    have to be loaded into memory in one piece.

    Args:
        path: Path of the file to hash.
        chunk_size: Number of bytes read per iteration (default 1 MiB).

    Returns:
        The MD5 digest as a lowercase hex string.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel stops as soon as read() returns b"" (EOF).
        for block in iter(lambda: f.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
def download(url, local_path, chunk_size=1024):
    """Stream ``url`` to ``local_path`` while showing a tqdm progress bar.

    Args:
        url: HTTP(S) URL to fetch.
        local_path: Destination file path; its parent directory is created
            if it does not exist yet.
        chunk_size: Number of bytes requested per streamed chunk.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status, so
            that an error page is never written to disk as if it were the
            requested file.
    """
    # A bare file name has an empty directory part, and os.makedirs("")
    # would raise, so only create the directory when there is one.
    directory = os.path.dirname(local_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        # The header may be missing; the progress total is then 0.
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Advance by the bytes actually received; the last
                        # chunk is usually shorter than chunk_size.
                        pbar.update(len(data))
def get_ckpt_path(name, root, check=False):
    """Return the local path of checkpoint ``name``, downloading it if needed.

    The checkpoint is looked up under ``root`` using the file name from
    ``CKPT_MAP``. It is downloaded from ``URL_MAP[name]`` when the file is
    missing, or when ``check`` is true and the MD5 digest of the existing file
    differs from ``MD5_MAP[name]``. A freshly downloaded file is always
    verified against the expected digest.

    Args:
        name: Checkpoint key; must be present in ``URL_MAP``.
        root: Directory in which the checkpoint file is stored.
        check: When true, also validate an already existing file by MD5.

    Returns:
        The path of the checkpoint file on disk.
    """
    assert name in URL_MAP
    ckpt_file = os.path.join(root, CKPT_MAP[name])

    # Fast path: the file is already present and either trusted or verified.
    if os.path.exists(ckpt_file):
        if not check or md5_hash(ckpt_file) == MD5_MAP[name]:
            return ckpt_file

    source_url = URL_MAP[name]
    print("Downloading {} model from {} to {}".format(name, source_url, ckpt_file))
    download(source_url, ckpt_file)

    # Verify the fresh download; the actual digest is the assertion message.
    actual_md5 = md5_hash(ckpt_file)
    assert actual_md5 == MD5_MAP[name], actual_md5
    return ckpt_file
class LPIPS(nn.Module):

View file

@ -12,7 +12,7 @@ from einops import rearrange, repeat, pack, unpack
import torch.nn.functional as F
import torchvision
from torchvision.models import VGG16_Weights
from taming.modules.losses.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
from opensora.models.vae.lpips import LPIPS # local copy of LPIPS adapted from https://github.com/CompVis/taming-transformers
from torch import nn
import math
import os

View file

@ -53,7 +53,6 @@ def main():
# 1. args & cfg
# ======================================================
cfg = parse_configs(training=True)
print(cfg)
exp_name, exp_dir = create_experiment_workspace(cfg)
save_training_config(cfg._cfg_dict, exp_dir)
@ -73,6 +72,7 @@ def main():
if not coordinator.is_master():
logger = create_logger(None)
else:
print(cfg)
logger = create_logger(exp_dir)
logger.info(f"Experiment directory created at {exp_dir}")