# Open-Sora/scripts/misc/extract_feat.py
# Source snapshot: 2024-05-17 10:05:49 +00:00 (168 lines, 6.4 KiB, Python)
import os
from pprint import pformat
import colossalai
import torch
from tqdm import tqdm
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_variable_dataloader
from opensora.datasets.utils import collate_fn_ignore_none
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.config_utils import parse_configs, save_training_config
from opensora.utils.misc import FeatureSaver, Timer, create_logger, format_numel_str, get_model_numel, to_torch_dtype
def main():
    """Extract and save dataset features for diffusion-model training.

    Encodes every video in the configured dataset with the VAE (and, optionally,
    every caption with the text encoder) and streams the resulting tensors to
    disk in fixed-size bins via ``FeatureSaver``. All options come from the
    parsed config file; the script takes no function arguments and returns None.
    """
    # Pure inference: never build autograd graphs anywhere below.
    torch.set_grad_enabled(False)

    # ======================================================
    # 1. configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=False)

    # == device and dtype ==
    assert torch.cuda.is_available(), "Feature extraction currently requires at least one GPU."
    device = "cuda"
    # Only mixed-precision dtypes are accepted; fp32 is rejected up front.
    # (The original file carried a second, contradictory copy of this stanza
    # that also allowed fp32 — it was unreachable because this assert runs
    # first, so the duplicate has been removed.)
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg_dtype)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # == colossalai init distributed training ==
    colossalai.launch_from_torch({})

    # == init logger ==
    logger = create_logger()
    logger.info("Configuration:\n %s", pformat(cfg.to_dict()))

    # ======================================================
    # 2. build dataset and dataloader
    # ======================================================
    logger.info("Building dataset...")
    # == build dataset ==
    dataset = build_module(cfg.dataset, DATASETS)
    logger.info("Dataset contains %s samples.", len(dataset))

    # == build dataloader ==
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.get("batch_size", None),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
        collate_fn=collate_fn_ignore_none,
    )
    dataloader = prepare_variable_dataloader(
        bucket_config=cfg.get("bucket_config", None),
        num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
        **dataloader_args,
    )
    num_batch = dataloader.batch_sampler.get_num_batch()

    # ======================================================
    # 3. build models
    # ======================================================
    logger.info("Building models...")
    # == build text-encoder and vae ==
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype)
    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()

    # == build diffusion model ==
    # Latent size is derived from the dataset's (num_frames, H, W) input shape.
    input_size = (dataset.num_frames, *dataset.image_size)
    latent_size = vae.get_latent_size(input_size)
    model = (
        build_module(
            cfg.model,
            MODELS,
            input_size=latent_size,
            in_channels=vae.out_channels,
            caption_channels=text_encoder.output_dim,
            model_max_length=text_encoder.model_max_length,
        )
        .to(device, dtype)
        # NOTE(review): .train() inside a no-grad extraction script looks odd
        # (dropout/BN would be in training mode) — confirm .eval() isn't intended.
        .train()
    )
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        "[Diffusion] Trainable model params: %s, Total model params: %s",
        format_numel_str(model_numel_trainable),
        format_numel_str(model_numel),
    )

    # =======================================================
    # 4. extraction loop
    # =======================================================
    # == global variables ==
    bin_size = cfg.bin_size  # number of batches saved per output bin
    save_text_features = cfg.get("save_text_features", False)
    save_compressed_text_features = cfg.get("save_compressed_text_features", False)

    # == number of bins ==
    num_bin = num_batch // bin_size
    logger.info("Number of batches: %s", num_batch)
    logger.info("Bin size: %s", bin_size)
    logger.info("Number of bins: %s", num_bin)

    # == resume from a specific bin index ==
    start_index = cfg.get("start_index", 0)
    end_index = cfg.get("end_index", num_bin)
    # Fast-forward the sampler so restarted jobs skip already-processed batches.
    dataloader.batch_sampler.load_state_dict({"last_micro_batch_access_index": start_index})
    num_bin_to_process = min(num_bin, end_index) - start_index
    logger.info("Start index: %s", start_index)
    logger.info("End index: %s", end_index)
    logger.info("Number of batches to process: %s", num_bin_to_process)

    # == create save directory ==
    assert cfg.get("save_dir", None) is not None, "Please specify the save_dir in the config file."
    save_dir = os.path.join(cfg.save_dir, f"s{start_index}_e{end_index}")
    os.makedirs(save_dir, exist_ok=True)
    save_training_config(cfg.to_dict(), save_dir)
    logger.info("Saving features to %s", save_dir)
    saver = FeatureSaver(save_dir, bin_size, start_bin=start_index)

    # == extraction loop over all selected batches ==
    dataloader_iter = iter(dataloader)
    log_time = cfg.get("log_time", False)
    for _ in tqdm(range(0, num_bin_to_process * bin_size)):
        with Timer("step", log=log_time):
            with Timer("data loading", log=log_time):
                batch = next(dataloader_iter)
                x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
                y = batch.pop("text")

            with Timer("vae", log=log_time):
                x = vae.encode(x)
            with Timer("feature to cpu", log=log_time):
                x = x.cpu()

            fps = batch["fps"].to(dtype)
            batch_dict = {"x": x, "fps": fps}

            if save_text_features:
                with Timer("text", log=log_time):
                    text_infos = text_encoder.encode(y)
                    y_feat = text_infos["y"]
                    y_mask = text_infos["mask"]
                    if save_compressed_text_features:
                        y_feat, y_mask = model.encode_text(y_feat, y_mask)
                        # Normalize mask to a tensor; encode_text presumably may
                        # return a plain list — TODO confirm against the model API.
                        y_mask = torch.tensor(y_mask)
                with Timer("feature to cpu", log=log_time):
                    y_feat = y_feat.cpu()
                    y_mask = y_mask.cpu()
                batch_dict.update({"y": y_feat, "mask": y_mask})

            saver.update(batch_dict)
# Script entry point: run extraction only when executed directly, not on import.
if __name__ == "__main__":
    main()