mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
[fix] ema ckpt save
This commit is contained in:
parent
9be806803b
commit
74dfaa879c
|
|
@ -24,6 +24,7 @@ dataset = dict(
|
|||
# # ---
|
||||
# "2048": {1: (0.1, 5)},
|
||||
# }
|
||||
|
||||
# webvid
|
||||
bucket_config = { # 20s/it
|
||||
"144p": {1: (1.0, 100), 51: (1.0, 30), 102: ((1.0, 0.33), 20), 204: ((1.0, 0.1), 8), 408: ((1.0, 0.1), 4)},
|
||||
|
|
@ -32,16 +33,16 @@ bucket_config = { # 20s/it
|
|||
"240p": {1: (0.3, 100), 51: (0.4, 24), 102: ((0.4, 0.33), 12), 204: ((0.4, 0.1), 4), 408: ((0.4, 0.1), 2)},
|
||||
# ---
|
||||
"360p": {1: (0.2, 60), 51: (0.15, 12), 102: ((0.15, 0.33), 6), 204: ((0.15, 0.1), 2), 408: ((0.15, 0.1), 1)},
|
||||
# "512": {1: (0.1, 60), 51: (0.3, 12), 102: (0.3, 6), 204: (0.3, 2), 408: (0.3, 1)},
|
||||
# # ---
|
||||
# "480p": {1: (0.1, 40), 51: (0.3, 6), 102: (0.3, 3), 204: (0.3, 1), 408: (0.0, None)},
|
||||
# # ---
|
||||
# "720p": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)},
|
||||
# "1024": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)},
|
||||
# # ---
|
||||
# "1080p": {1: (0.1, 10)},
|
||||
# # ---
|
||||
# "2048": {1: (0.1, 5)},
|
||||
"512": {1: (0.1, 60), 51: (0.3, 12), 102: (0.3, 6), 204: (0.3, 2), 408: (0.3, 1)},
|
||||
# ---
|
||||
"480p": {1: (0.1, 40), 51: (0.3, 6), 102: (0.3, 3), 204: (0.3, 1), 408: (0.0, None)},
|
||||
# ---
|
||||
"720p": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)},
|
||||
"1024": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)},
|
||||
# ---
|
||||
"1080p": {1: (0.1, 10)},
|
||||
# ---
|
||||
"2048": {1: (0.1, 5)},
|
||||
}
|
||||
|
||||
grad_checkpoint = True
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import collections
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Sequence
|
||||
|
|
@ -16,19 +17,30 @@ import torch.distributed as dist
|
|||
# ======================================================
|
||||
|
||||
|
||||
def create_logger(logging_dir):
|
||||
def is_distributed():
    """
    Tell whether the process was launched with a WORLD_SIZE environment variable.

    Returns:
        bool: True when ``WORLD_SIZE`` is present in the environment (any
        value, including an empty string), False when it is absent.
    """
    # Environment values are always strings, so a presence test is the same
    # as checking that the lookup did not fall back to None.
    return "WORLD_SIZE" in os.environ
|
||||
|
||||
|
||||
def is_main_process():
    """
    Tell whether the current process should act as the main one.

    Returns:
        bool: True when ``is_distributed()`` is False, otherwise whether
        ``dist.get_rank()`` equals 0.
    """
    # Outside a distributed launch the rank is never queried.
    if not is_distributed():
        return True
    # NOTE(review): dist.get_rank() presumably requires the process group to be
    # initialized already; confirm callers only reach this after the launch step.
    return dist.get_rank() == 0
|
||||
|
||||
|
||||
def create_logger(logging_dir=None):
|
||||
"""
|
||||
Create a logger that writes to a log file and stdout.
|
||||
"""
|
||||
if dist.get_rank() == 0: # real logger
|
||||
if is_main_process(): # real logger
|
||||
additional_args = dict()
|
||||
if logging_dir is not None:
|
||||
additional_args["handlers"] = [
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler(f"{logging_dir}/log.txt"),
|
||||
]
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="[\033[34m%(asctime)s\033[0m] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler(f"{logging_dir}/log.txt"),
|
||||
],
|
||||
**additional_args,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
else: # dummy logger (does nothing)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from opensora.datasets import IMG_FPS, save_sample
|
|||
from opensora.models.text_encoder.t5 import text_preprocessing
|
||||
from opensora.registry import MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.config_utils import parse_configs
|
||||
from opensora.utils.misc import to_torch_dtype
|
||||
from opensora.utils.misc import create_logger, is_distributed, is_main_process, to_torch_dtype
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -31,29 +31,28 @@ def main():
|
|||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
verbose = cfg.get("verbose", 2)
|
||||
print(cfg)
|
||||
|
||||
# init distributed
|
||||
if os.environ.get("WORLD_SIZE", None):
|
||||
use_dist = True
|
||||
# == init distributed env ==
|
||||
if is_distributed():
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
if coordinator.world_size > 1:
|
||||
enable_sequence_parallelism = coordinator.world_size > 1
|
||||
if enable_sequence_parallelism:
|
||||
set_sequence_parallel_group(dist.group.WORLD)
|
||||
enable_sequence_parallelism = True
|
||||
else:
|
||||
enable_sequence_parallelism = False
|
||||
else:
|
||||
use_dist = False
|
||||
enable_sequence_parallelism = False
|
||||
set_random_seed(seed=cfg.seed)
|
||||
|
||||
# == init logger ==
|
||||
create_logger()
|
||||
verbose = cfg.get("verbose", 1)
|
||||
breakpoint()
|
||||
|
||||
print(cfg)
|
||||
|
||||
# ======================================================
|
||||
# 2. runtime variables
|
||||
# ======================================================
|
||||
|
||||
set_random_seed(seed=cfg.seed)
|
||||
prompts = cfg.prompt
|
||||
|
||||
# ======================================================
|
||||
|
|
@ -167,7 +166,7 @@ def main():
|
|||
samples = vae.decode(samples.to(dtype), num_frames=cfg.num_frames)
|
||||
|
||||
# 4.4. save samples
|
||||
if not use_dist or coordinator.is_master():
|
||||
if is_main_process():
|
||||
for idx, sample in enumerate(samples):
|
||||
if verbose >= 2:
|
||||
print(f"Prompt: {batch_prompts_raw[idx]}")
|
||||
|
|
|
|||
|
|
@ -313,7 +313,8 @@ def main():
|
|||
global_step=global_step + 1,
|
||||
batch_size=cfg.get("batch_size", None),
|
||||
)
|
||||
model_sharding(ema)
|
||||
if dist.get_rank() == 0:
|
||||
model_sharding(ema)
|
||||
logger.info(
|
||||
"Saved checkpoint at epoch %s step %s global_step %s to %s",
|
||||
epoch,
|
||||
|
|
|
|||
Loading…
Reference in a new issue