mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
disable non-master print
This commit is contained in:
parent
8548b88ca8
commit
d565477680
|
|
@ -65,7 +65,7 @@ discriminator_loss_type="non-saturating"
|
|||
generator_loss_type="non-saturating"
|
||||
# discriminator_loss_type="hinge"
|
||||
# generator_loss_type="hinge"
|
||||
discriminator_start = 50000 # 50000 NOTE: change to correct val, debug use -1 for now
|
||||
discriminator_start = 100 # 8k data / (8*32) = 31 steps per epoch, use around 3 epochs
|
||||
gradient_penalty_loss_weight = None # 10 # SCH: MAGVIT uses 10, opensora plan doesn't use
|
||||
ema_decay = 0.999 # ema decay factor for generator
|
||||
|
||||
|
|
@ -84,7 +84,7 @@ magvit uses about # samples (K) * epochs ~ 2-5 K, num_frames = 4, reso = 128
|
|||
|
||||
epochs = 200
|
||||
log_every = 1
|
||||
ckpt_every = 1000
|
||||
ckpt_every = 50
|
||||
load = None
|
||||
|
||||
batch_size = 32
|
||||
|
|
|
|||
|
|
@ -2,7 +2,48 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
from collections import namedtuple
|
||||
from taming.util import get_ckpt_path
|
||||
import os, hashlib
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
# Checkpoint name -> URL the weights are downloaded from (read by get_ckpt_path).
URL_MAP = {
|
||||
"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
|
||||
}
|
||||
|
||||
# Checkpoint name -> file name the weights are stored under inside `root`.
CKPT_MAP = {
|
||||
"vgg_lpips": "vgg.pth"
|
||||
}
|
||||
|
||||
# Checkpoint name -> expected MD5 hex digest of the checkpoint file.
MD5_MAP = {
|
||||
"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
|
||||
}
|
||||
|
||||
def md5_hash(path, chunk_size=1 << 20):
    """Return the hexadecimal MD5 digest of the file at ``path``.

    The file is hashed in fixed-size chunks so memory use stays bounded
    regardless of how large the file is.

    Args:
        path: Path of the file to hash.
        chunk_size: Number of bytes read per iteration.

    Returns:
        The MD5 digest as a lowercase hexadecimal string.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel stops at the first empty read, i.e. at EOF.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
|
||||
|
||||
def download(url, local_path, chunk_size=1024):
    """Stream ``url`` to ``local_path`` while showing a tqdm progress bar.

    Args:
        url: HTTP(S) address to fetch.
        local_path: Destination file; missing parent directories are created.
        chunk_size: Number of bytes requested per streamed chunk.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    parent_dir = os.path.dirname(local_path)
    # os.makedirs("") raises, so skip it when the target has no directory part.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with requests.get(url, stream=True) as r:
        # Fail before opening the destination so an HTTP error page is never
        # written to disk in place of the requested file.
        r.raise_for_status()
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Advance by the bytes actually received; the final
                        # chunk is usually shorter than chunk_size.
                        pbar.update(len(data))
|
||||
|
||||
def get_ckpt_path(name, root, check=False):
    """Return the local path of checkpoint ``name``, downloading it if needed.

    The checkpoint is downloaded when the file is missing, or when ``check``
    is true and the existing file's MD5 digest differs from the expected one.
    A freshly downloaded file is always verified against ``MD5_MAP``.

    Args:
        name: Key into ``URL_MAP`` / ``CKPT_MAP`` / ``MD5_MAP``.
        root: Directory in which the checkpoint file is stored.
        check: When true, verify the checksum of an already existing file and
            download again on mismatch.

    Returns:
        The path of the checkpoint file under ``root``.

    Raises:
        AssertionError: If ``name`` is unknown, or if the downloaded file's
            MD5 digest does not match the expected value (the exception
            argument is the digest that was actually computed).
    """
    # Explicit raises instead of ``assert`` statements so the checks are not
    # stripped under ``python -O``; AssertionError is kept so existing callers
    # see the same exception type.
    if name not in URL_MAP:
        raise AssertionError("unknown checkpoint name: {!r}".format(name))
    path = os.path.join(root, CKPT_MAP[name])
    expected_md5 = MD5_MAP[name]
    needs_download = not os.path.exists(path) or (check and md5_hash(path) != expected_md5)
    if needs_download:
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        if md5 != expected_md5:
            raise AssertionError(md5)
    return path
|
||||
|
||||
|
||||
|
||||
class LPIPS(nn.Module):
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from einops import rearrange, repeat, pack, unpack
|
|||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from torchvision.models import VGG16_Weights
|
||||
from taming.modules.losses.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
|
||||
from opensora.models.vae.lpips import LPIPS # need to pip install https://github.com/CompVis/taming-transformers
|
||||
from torch import nn
|
||||
import math
|
||||
import os
|
||||
|
|
|
|||
|
|
@ -53,7 +53,6 @@ def main():
|
|||
# 1. args & cfg
|
||||
# ======================================================
|
||||
cfg = parse_configs(training=True)
|
||||
print(cfg)
|
||||
exp_name, exp_dir = create_experiment_workspace(cfg)
|
||||
save_training_config(cfg._cfg_dict, exp_dir)
|
||||
|
||||
|
|
@ -73,6 +72,7 @@ def main():
|
|||
if not coordinator.is_master():
|
||||
logger = create_logger(None)
|
||||
else:
|
||||
print(cfg)
|
||||
logger = create_logger(exp_dir)
|
||||
logger.info(f"Experiment directory created at {exp_dir}")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue