mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
168 lines
6.4 KiB
Python
168 lines
6.4 KiB
Python
import os
|
|
from pprint import pformat
|
|
|
|
import colossalai
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
from opensora.acceleration.parallel_states import get_data_parallel_group
|
|
from opensora.datasets import prepare_variable_dataloader
|
|
from opensora.datasets.utils import collate_fn_ignore_none
|
|
from opensora.registry import DATASETS, MODELS, build_module
|
|
from opensora.utils.config_utils import parse_configs, save_training_config
|
|
from opensora.utils.misc import FeatureSaver, Timer, create_logger, format_numel_str, get_model_numel, to_torch_dtype
|
|
|
|
|
|
def main():
|
|
torch.set_grad_enabled(False)
|
|
# ======================================================
|
|
# 1. configs & runtime variables
|
|
# ======================================================
|
|
# == parse configs ==
|
|
cfg = parse_configs(training=False)
|
|
|
|
# == device and dtype ==
|
|
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
|
|
cfg_dtype = cfg.get("dtype", "bf16")
|
|
assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
|
|
dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
|
|
|
|
# == colossalai init distributed training ==
|
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
cfg_dtype = cfg.get("dtype", "fp32")
|
|
assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
|
|
dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
|
|
torch.backends.cuda.matmul.allow_tf32 = True
|
|
torch.backends.cudnn.allow_tf32 = True
|
|
|
|
colossalai.launch_from_torch({})
|
|
|
|
# == init logger, tensorboard & wandb ==
|
|
logger = create_logger()
|
|
logger.info("Configuration:\n %s", pformat(cfg.to_dict()))
|
|
|
|
# ======================================================
|
|
# 2. build dataset and dataloader
|
|
# ======================================================
|
|
logger.info("Building dataset...")
|
|
# == build dataset ==
|
|
dataset = build_module(cfg.dataset, DATASETS)
|
|
logger.info("Dataset contains %s samples.", len(dataset))
|
|
|
|
# == build dataloader ==
|
|
dataloader_args = dict(
|
|
dataset=dataset,
|
|
batch_size=cfg.get("batch_size", None),
|
|
num_workers=cfg.get("num_workers", 4),
|
|
seed=cfg.get("seed", 1024),
|
|
shuffle=True,
|
|
drop_last=True,
|
|
pin_memory=True,
|
|
process_group=get_data_parallel_group(),
|
|
collate_fn=collate_fn_ignore_none,
|
|
)
|
|
dataloader = prepare_variable_dataloader(
|
|
bucket_config=cfg.get("bucket_config", None),
|
|
num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
|
|
**dataloader_args,
|
|
)
|
|
num_batch = dataloader.batch_sampler.get_num_batch()
|
|
|
|
# ======================================================
|
|
# 3. build model
|
|
# ======================================================
|
|
logger.info("Building models...")
|
|
# == build text-encoder and vae ==
|
|
text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype)
|
|
vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
|
|
|
|
# == build diffusion model ==
|
|
input_size = (dataset.num_frames, *dataset.image_size)
|
|
latent_size = vae.get_latent_size(input_size)
|
|
model = (
|
|
build_module(
|
|
cfg.model,
|
|
MODELS,
|
|
input_size=latent_size,
|
|
in_channels=vae.out_channels,
|
|
caption_channels=text_encoder.output_dim,
|
|
model_max_length=text_encoder.model_max_length,
|
|
)
|
|
.to(device, dtype)
|
|
.train()
|
|
)
|
|
model_numel, model_numel_trainable = get_model_numel(model)
|
|
logger.info(
|
|
"[Diffusion] Trainable model params: %s, Total model params: %s",
|
|
format_numel_str(model_numel_trainable),
|
|
format_numel_str(model_numel),
|
|
)
|
|
|
|
# =======================================================
|
|
# 5. training loop
|
|
# =======================================================
|
|
# == global variables ==
|
|
bin_size = cfg.bin_size
|
|
save_text_features = cfg.get("save_text_features", False)
|
|
save_compressed_text_features = cfg.get("save_compressed_text_features", False)
|
|
|
|
# == number of bins ==
|
|
num_bin = num_batch // bin_size
|
|
logger.info("Number of batches: %s", num_batch)
|
|
logger.info("Bin size: %s", bin_size)
|
|
logger.info("Number of bins: %s", num_bin)
|
|
|
|
# resume from a specific batch index
|
|
start_index = cfg.get("start_index", 0)
|
|
end_index = cfg.get("end_index", num_bin)
|
|
dataloader.batch_sampler.load_state_dict({"last_micro_batch_access_index": start_index})
|
|
num_bin_to_process = min(num_bin, end_index) - start_index
|
|
logger.info("Start index: %s", start_index)
|
|
logger.info("End index: %s", end_index)
|
|
logger.info("Number of batches to process: %s", num_bin_to_process)
|
|
|
|
# create save directory
|
|
assert cfg.get("save_dir", None) is not None, "Please specify the save_dir in the config file."
|
|
save_dir = os.path.join(cfg.save_dir, f"s{start_index}_e{end_index}")
|
|
os.makedirs(save_dir, exist_ok=True)
|
|
save_training_config(cfg.to_dict(), save_dir)
|
|
logger.info("Saving features to %s", save_dir)
|
|
|
|
saver = FeatureSaver(save_dir, bin_size, start_bin=start_index)
|
|
|
|
# == training loop in an epoch ==
|
|
dataloader_iter = iter(dataloader)
|
|
log_time = cfg.get("log_time", False)
|
|
for i in tqdm(range(0, num_bin_to_process * bin_size)):
|
|
with Timer("step", log=log_time):
|
|
with Timer("data loading", log=log_time):
|
|
batch = next(dataloader_iter)
|
|
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
|
|
y = batch.pop("text")
|
|
|
|
with Timer("vae", log=log_time):
|
|
x = vae.encode(x)
|
|
with Timer("feature to cpu", log=log_time):
|
|
x = x.cpu()
|
|
fps = batch["fps"].to(dtype)
|
|
batch_dict = {"x": x, "fps": fps}
|
|
|
|
if save_text_features:
|
|
with Timer("text", log=log_time):
|
|
text_infos = text_encoder.encode(y)
|
|
y_feat = text_infos["y"]
|
|
y_mask = text_infos["mask"]
|
|
if save_compressed_text_features:
|
|
y_feat, y_mask = model.encode_text(y_feat, y_mask)
|
|
y_mask = torch.tensor(y_mask)
|
|
with Timer("feature to cpu", log=log_time):
|
|
y_feat = y_feat.cpu()
|
|
y_mask = y_mask.cpu()
|
|
batch_dict.update({"y": y_feat, "mask": y_mask})
|
|
|
|
saver.update(batch_dict)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|