[fix] EMA checkpoint save: re-shard EMA weights on rank 0 only after saving

This commit is contained in:
zhengzangw 2024-05-12 05:15:06 +00:00
parent 9be806803b
commit 74dfaa879c
4 changed files with 45 additions and 32 deletions

View file

@@ -24,6 +24,7 @@ dataset = dict(
# # ---
# "2048": {1: (0.1, 5)},
# }
# webvid
bucket_config = { # 20s/it
"144p": {1: (1.0, 100), 51: (1.0, 30), 102: ((1.0, 0.33), 20), 204: ((1.0, 0.1), 8), 408: ((1.0, 0.1), 4)},
@@ -32,16 +33,16 @@ bucket_config = { # 20s/it
"240p": {1: (0.3, 100), 51: (0.4, 24), 102: ((0.4, 0.33), 12), 204: ((0.4, 0.1), 4), 408: ((0.4, 0.1), 2)},
# ---
"360p": {1: (0.2, 60), 51: (0.15, 12), 102: ((0.15, 0.33), 6), 204: ((0.15, 0.1), 2), 408: ((0.15, 0.1), 1)},
# "512": {1: (0.1, 60), 51: (0.3, 12), 102: (0.3, 6), 204: (0.3, 2), 408: (0.3, 1)},
# # ---
# "480p": {1: (0.1, 40), 51: (0.3, 6), 102: (0.3, 3), 204: (0.3, 1), 408: (0.0, None)},
# # ---
# "720p": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)},
# "1024": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)},
# # ---
# "1080p": {1: (0.1, 10)},
# # ---
# "2048": {1: (0.1, 5)},
"512": {1: (0.1, 60), 51: (0.3, 12), 102: (0.3, 6), 204: (0.3, 2), 408: (0.3, 1)},
# ---
"480p": {1: (0.1, 40), 51: (0.3, 6), 102: (0.3, 3), 204: (0.3, 1), 408: (0.0, None)},
# ---
"720p": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)},
"1024": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)},
# ---
"1080p": {1: (0.1, 10)},
# ---
"2048": {1: (0.1, 5)},
}
grad_checkpoint = True

View file

@@ -1,6 +1,7 @@
import collections
import importlib
import logging
import os
import time
from collections import OrderedDict
from collections.abc import Sequence
@@ -16,19 +17,30 @@ import torch.distributed as dist
# ======================================================
def create_logger(logging_dir):
def is_distributed():
    """Return True when running under a distributed launcher (WORLD_SIZE env var is set)."""
    return "WORLD_SIZE" in os.environ


def is_main_process():
    """Return True on the main process: rank 0 when distributed, always otherwise."""
    if not is_distributed():
        return True
    return dist.get_rank() == 0
def create_logger(logging_dir=None):
"""
Create a logger that writes to a log file and stdout.
"""
if dist.get_rank() == 0: # real logger
if is_main_process(): # real logger
additional_args = dict()
if logging_dir is not None:
additional_args["handlers"] = [
logging.StreamHandler(),
logging.FileHandler(f"{logging_dir}/log.txt"),
]
logging.basicConfig(
level=logging.INFO,
format="[\033[34m%(asctime)s\033[0m] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[
logging.StreamHandler(),
logging.FileHandler(f"{logging_dir}/log.txt"),
],
**additional_args,
)
logger = logging.getLogger(__name__)
else: # dummy logger (does nothing)

View file

@@ -12,7 +12,7 @@ from opensora.datasets import IMG_FPS, save_sample
from opensora.models.text_encoder.t5 import text_preprocessing
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype
from opensora.utils.misc import create_logger, is_distributed, is_main_process, to_torch_dtype
def main():
@@ -31,29 +31,28 @@ def main():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
verbose = cfg.get("verbose", 2)
print(cfg)
# init distributed
if os.environ.get("WORLD_SIZE", None):
use_dist = True
# == init distributed env ==
if is_distributed():
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
if coordinator.world_size > 1:
enable_sequence_parallelism = coordinator.world_size > 1
if enable_sequence_parallelism:
set_sequence_parallel_group(dist.group.WORLD)
enable_sequence_parallelism = True
else:
enable_sequence_parallelism = False
else:
use_dist = False
enable_sequence_parallelism = False
set_random_seed(seed=cfg.seed)
# == init logger ==
create_logger()
verbose = cfg.get("verbose", 1)
breakpoint()
print(cfg)
# ======================================================
# 2. runtime variables
# ======================================================
set_random_seed(seed=cfg.seed)
prompts = cfg.prompt
# ======================================================
@@ -167,7 +166,7 @@ def main():
samples = vae.decode(samples.to(dtype), num_frames=cfg.num_frames)
# 4.4. save samples
if not use_dist or coordinator.is_master():
if is_main_process():
for idx, sample in enumerate(samples):
if verbose >= 2:
print(f"Prompt: {batch_prompts_raw[idx]}")

View file

@@ -313,7 +313,8 @@ def main():
global_step=global_step + 1,
batch_size=cfg.get("batch_size", None),
)
model_sharding(ema)
if dist.get_rank() == 0:
model_sharding(ema)
logger.info(
"Saved checkpoint at epoch %s step %s global_step %s to %s",
epoch,