mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 12:49:38 +02:00
[feat] enable freezing y_embedder
This commit is contained in:
parent
a24149e265
commit
a016325a8e
|
|
@ -156,11 +156,11 @@ conda activate opensora
|
|||
# install torch, torchvision and xformers
|
||||
pip install -r requirements/requirements_cu121.txt
|
||||
|
||||
# install this project
|
||||
# download the repo
|
||||
git clone https://github.com/hpcaitech/Open-Sora
|
||||
cd Open-Sora
|
||||
|
||||
# the default installation is for inference only
|
||||
# install Open-Sora
|
||||
pip install -v .
|
||||
```
|
||||
|
||||
|
|
|
|||
60
configs/opensora-v1-2/misc/extract.py
Normal file
60
configs/opensora-v1-2/misc/extract.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
# Configuration for feature extraction (configs/opensora-v1-2/misc/extract.py).

# -- Dataset -----------------------------------------------------------------
dataset = dict(
    type="VariableVideoTextDataset",
    transform_name="resize_crop",
)

# -- Bucket configuration (webvid) -------------------------------------------
# Maps resolution -> {num_frames: (keep_prob, batch_size)}; a (p, q) tuple as
# keep_prob means probabilities are interpolated per the bucket sampler.
bucket_config = {  # 12s/it
    "144p": {1: (1.0, 475), 51: (1.0, 51), 102: ((1.0, 0.33), 27), 204: ((1.0, 0.1), 13), 408: ((1.0, 0.1), 6)},
    # ---
    "256": {1: (0.4, 297), 51: (0.5, 20), 102: ((0.5, 0.33), 10), 204: ((0.5, 0.1), 5), 408: ((0.5, 0.1), 2)},
    "240p": {1: (0.3, 297), 51: (0.4, 20), 102: ((0.4, 0.33), 10), 204: ((0.4, 0.1), 5), 408: ((0.4, 0.1), 2)},
    # ---
    "360p": {1: (0.2, 141), 51: (0.15, 8), 102: ((0.15, 0.33), 4), 204: ((0.15, 0.1), 2), 408: ((0.15, 0.1), 1)},
    "512": {1: (0.1, 141)},
    # ---
    "480p": {1: (0.1, 89)},
    # ---
    "720p": {1: (0.05, 36)},
    "1024": {1: (0.05, 36)},
    # ---
    "1080p": {1: (0.1, 5)},
    # ---
    "2048": {1: (0.1, 5)},
}

# -- Acceleration ------------------------------------------------------------
num_workers = 8
num_bucket_build_workers = 16
dtype = "bf16"
seed = 42
outputs = "outputs"
wandb = False


# -- Model -------------------------------------------------------------------
model = dict(
    type="STDiT3-XL/2",
    from_pretrained=None,
    qk_norm=True,
    enable_flash_attn=True,
    enable_layernorm_kernel=True,
)
vae = dict(
    type="OpenSoraVAE_V1_2",
    from_pretrained="pretrained_models/vae-pipeline",
    micro_frame_size=17,
    micro_batch_size=4,
)
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=300,
    shardformer=True,
    local_files_only=True,
)

# -- Feature extraction ------------------------------------------------------
save_text_features = True
save_compressed_text_features = True
|
|
@ -59,6 +59,7 @@ model = dict(
|
|||
qk_norm=True,
|
||||
enable_flash_attn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
freeze_y_embedder=True,
|
||||
)
|
||||
vae = dict(
|
||||
type="OpenSoraVAE_V1_2",
|
||||
|
|
|
|||
|
|
@ -169,6 +169,8 @@ class STDiT3Config(PretrainedConfig):
|
|||
enable_layernorm_kernel=False,
|
||||
enable_sequence_parallelism=False,
|
||||
only_train_temporal=False,
|
||||
freeze_y_embedder=False,
|
||||
skip_y_embedder=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.input_size = input_size
|
||||
|
|
@ -189,6 +191,8 @@ class STDiT3Config(PretrainedConfig):
|
|||
self.enable_layernorm_kernel = enable_layernorm_kernel
|
||||
self.enable_sequence_parallelism = enable_sequence_parallelism
|
||||
self.only_train_temporal = only_train_temporal
|
||||
self.freeze_y_embedder = freeze_y_embedder
|
||||
self.skip_y_embedder = skip_y_embedder
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
|
|
@ -284,6 +288,10 @@ class STDiT3(PreTrainedModel):
|
|||
for param in block.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
if config.freeze_y_embedder:
|
||||
for param in self.y_embedder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize transformer layers:
|
||||
def _basic_init(module):
|
||||
|
|
@ -319,6 +327,19 @@ class STDiT3(PreTrainedModel):
|
|||
W = W // self.patch_size[2]
|
||||
return (T, H, W)
|
||||
|
||||
    def encode_text(self, y, mask=None):
        """Embed caption features and pack them for cross-attention.

        Args:
            y: caption features; ``y_embedder`` maps them to [B, 1, N_token, C].
            mask: optional token-validity mask; nonzero entries mark real tokens.

        Returns:
            Tuple ``(y, y_lens)`` where ``y`` is flattened to
            [1, total_valid_tokens, C] and ``y_lens`` lists the per-sample token
            counts (all ``N_token`` when ``mask`` is None).
        """
        y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]
        if mask is not None:
            # Tile the mask when y carries more rows than mask — presumably
            # duplicated samples (e.g. classifier-free guidance); TODO confirm.
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            # Keep only valid tokens, concatenated into a single batch row.
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, self.hidden_size)
            y_lens = mask.sum(dim=1).tolist()
        else:
            # No mask: every token is kept; record the full length per sample.
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, self.hidden_size)
        return y, y_lens
|
||||
|
||||
def forward(self, x, timestep, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs):
|
||||
dtype = self.x_embedder.proj.weight.dtype
|
||||
B = x.size(0)
|
||||
|
|
@ -348,16 +369,10 @@ class STDiT3(PreTrainedModel):
|
|||
t0_mlp = self.t_block(t0)
|
||||
|
||||
# === get y embed ===
|
||||
y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
|
||||
if mask is not None:
|
||||
if mask.shape[0] != y.shape[0]:
|
||||
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
|
||||
mask = mask.squeeze(1).squeeze(1)
|
||||
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, self.hidden_size)
|
||||
y_lens = mask.sum(dim=1).tolist()
|
||||
if self.config.skip_y_embedder:
|
||||
y_lens = mask.tolist()
|
||||
else:
|
||||
y_lens = [y.shape[2]] * y.shape[0]
|
||||
y = y.squeeze(1).view(1, -1, self.hidden_size)
|
||||
y, y_lens = self.encode_text(y, mask)
|
||||
|
||||
# === get x embed ===
|
||||
x = self.x_embedder(x) # [B, N, C]
|
||||
|
|
|
|||
|
|
@ -378,3 +378,23 @@ class Timer:
|
|||
self.end_time = time.time()
|
||||
if self.log:
|
||||
print(f"Elapsed time for {self.name}: {self.elapsed_time:.2f} s")
|
||||
|
||||
|
||||
def get_tensor_memory(tensor, human_readable=True):
    """Return the memory footprint of ``tensor``.

    Args:
        tensor: any torch tensor.
        human_readable: when True, return the byte count formatted by
            ``format_numel_str``; otherwise return the raw byte count as an int.
    """
    num_bytes = tensor.nelement() * tensor.element_size()
    return format_numel_str(num_bytes) if human_readable else num_bytes
|
||||
|
||||
|
||||
class FeatureSaver:
|
||||
    def __init__(self, save_dir, bin_size=10):
        """Accumulate feature batches for saving under ``save_dir``.

        Args:
            save_dir: directory where feature bins are written; the caller is
                expected to have created it.
            bin_size: number of batches per saved bin — presumably consumed by
                a save/flush method outside this view; TODO confirm.
        """
        self.save_dir = save_dir
        self.bin_size = bin_size

        # Pending batch dicts and a running count of updates.
        self.data_list = []
        self.cnt = 0
|
||||
|
||||
    def update(self, data):
        """Queue one batch of features and bump the counter.

        NOTE(review): no flush happens here — confirm that a save method
        elsewhere honors ``bin_size`` and writes ``data_list`` to disk.
        """
        self.data_list.append(data)
        self.cnt += 1
|
||||
|
|
|
|||
167
scripts/misc/extract_feat.py
Normal file
167
scripts/misc/extract_feat.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
import os
|
||||
from pprint import pformat
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from tqdm import tqdm
|
||||
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
|
||||
from opensora.datasets.utils import collate_fn_ignore_none
|
||||
from opensora.registry import DATASETS, MODELS, build_module
|
||||
from opensora.utils.config_utils import parse_configs
|
||||
from opensora.utils.misc import FeatureSaver, create_logger, format_numel_str, get_model_numel, to_torch_dtype
|
||||
|
||||
DEFAULT_DATASET_NAME = "VideoTextDataset"
|
||||
|
||||
|
||||
def main():
    """Extract VAE latents (and optionally text features) for a dataset.

    Runs a single gradient-free pass over the configured dataloader, encoding
    each video with the VAE (and each caption with the text encoder when
    ``save_text_features`` is set), and streams the results to disk through a
    ``FeatureSaver`` rooted at ``cfg.save_dir``.
    """
    torch.set_grad_enabled(False)
    # ======================================================
    # 1. configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=False)

    # == device and dtype ==
    # (The original draft resolved cfg_dtype/dtype twice with different
    # defaults; collapsed into a single resolution with the wider assert.)
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg_dtype)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # NOTE(review): dist.barrier()/dist.get_world_size() are used below but no
    # dist.init_process_group call is visible in this script — confirm the
    # launcher (or parse_configs) initializes the process group.

    # == init logger ==
    logger = create_logger()
    logger.info("Configuration:\n %s", pformat(cfg.to_dict()))

    # ======================================================
    # 2. build dataset and dataloader
    # ======================================================
    logger.info("Building dataset...")
    # == build dataset ==
    dataset = build_module(cfg.dataset, DATASETS)
    logger.info("Dataset contains %s samples.", len(dataset))

    # == build dataloader ==
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.get("batch_size", None),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
        collate_fn=collate_fn_ignore_none,
    )
    if cfg.dataset.type == DEFAULT_DATASET_NAME:
        dataloader = prepare_dataloader(**dataloader_args)
        total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
        logger.info("Total batch size: %s", total_batch_size)
        num_steps_per_epoch = len(dataloader)
    else:
        # Variable-resolution datasets use the bucketed sampler.
        dataloader = prepare_variable_dataloader(
            bucket_config=cfg.get("bucket_config", None),
            num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
            **dataloader_args,
        )
        num_steps_per_epoch = dataloader.batch_sampler.get_num_batch()

    # ======================================================
    # 3. build model
    # ======================================================
    logger.info("Building models...")
    # == build text-encoder and vae ==
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype)
    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()

    # == build diffusion model (used only for its text-compression path) ==
    input_size = (dataset.num_frames, *dataset.image_size)
    latent_size = vae.get_latent_size(input_size)
    model = (
        build_module(
            cfg.model,
            MODELS,
            input_size=latent_size,
            in_channels=vae.out_channels,
            caption_channels=text_encoder.output_dim,
            model_max_length=text_encoder.model_max_length,
        )
        .to(device, dtype)
        # NOTE(review): extraction runs under no_grad, yet .train() keeps
        # dropout active — confirm .eval() is not intended here.
        .train()
    )
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        "[Diffusion] Trainable model params: %s, Total model params: %s",
        format_numel_str(model_numel_trainable),
        format_numel_str(model_numel),
    )

    # =======================================================
    # 4. distributed training preparation with colossalai
    # =======================================================
    # == global variables ==
    start_step = sampler_start_idx = 0
    logger.info("Training with %s steps per epoch", num_steps_per_epoch)

    if cfg.dataset.type == DEFAULT_DATASET_NAME:
        dataloader.sampler.set_start_index(sampler_start_idx)

    # =======================================================
    # 5. training loop
    # =======================================================
    dist.barrier()
    for epoch in range(1):  # single pass over the data
        # == set dataloader to new epoch ==
        if cfg.dataset.type == DEFAULT_DATASET_NAME:
            dataloader.sampler.set_epoch(epoch)
        dataloader_iter = iter(dataloader)
        logger.info("Beginning epoch %s...", epoch)

        # == extraction loop in an epoch ==
        assert cfg.get("save_dir", None) is not None, "Please specify the save_dir in the config file."
        os.makedirs(cfg.save_dir, exist_ok=True)
        saver = FeatureSaver(cfg.save_dir)
        save_text_features = cfg.get("save_text_features", False)
        save_compressed_text_features = cfg.get("save_compressed_text_features", False)

        with tqdm(
            enumerate(dataloader_iter, start=start_step),
            desc=f"Epoch {epoch}",
            initial=start_step,
            total=num_steps_per_epoch,
        ) as pbar:
            for step, batch in pbar:
                x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
                y = batch.pop("text")

                x = vae.encode(x).cpu()  # [B, C, T, H/P, W/P]
                fps = batch["fps"].to(dtype)
                batch_dict = {"x": x, "fps": fps}

                if save_text_features:
                    text_infos = text_encoder.encode(y)
                    y_feat = text_infos["y"]
                    y_mask = text_infos["mask"]
                    if not save_compressed_text_features:
                        y_feat = y_feat.cpu()
                        y_mask = y_mask.cpu()
                    else:
                        # Compressed form: keep only valid tokens plus their
                        # per-sample lengths, via the model's text path.
                        # (Removed a leftover breakpoint() debugging call here.)
                        y_feat, y_mask = model.encode_text(y_feat, y_mask)
                        y_feat = y_feat.cpu()
                        y_mask = torch.tensor(y_mask)
                    batch_dict.update({"y": y_feat, "mask": y_mask})

                saver.update(batch_dict)
||||
# Script entry point: run the feature-extraction pass.
if __name__ == "__main__":
    main()
|
||||
|
|
@ -1,284 +0,0 @@
|
|||
import os
|
||||
from copy import deepcopy
|
||||
from datetime import timedelta
|
||||
from pprint import pformat
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device, set_seed
|
||||
from tqdm import tqdm
|
||||
|
||||
import wandb
|
||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
|
||||
from opensora.datasets.utils import collate_fn_ignore_none
|
||||
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
|
||||
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
|
||||
from opensora.utils.misc import (
|
||||
all_reduce_mean,
|
||||
create_logger,
|
||||
create_tensorboard_writer,
|
||||
format_numel_str,
|
||||
get_model_numel,
|
||||
requires_grad,
|
||||
to_torch_dtype,
|
||||
)
|
||||
from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema
|
||||
|
||||
DEFAULT_DATASET_NAME = "VideoTextDataset"
|
||||
|
||||
|
||||
def main():
|
||||
# ======================================================
|
||||
# 1. configs & runtime variables
|
||||
# ======================================================
|
||||
# == parse configs ==
|
||||
cfg = parse_configs(training=False)
|
||||
|
||||
# == device and dtype ==
|
||||
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
|
||||
cfg_dtype = cfg.get("dtype", "bf16")
|
||||
assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
|
||||
dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
|
||||
|
||||
# == colossalai init distributed training ==
|
||||
# NOTE: A very large timeout is set to avoid some processes exit early
|
||||
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
|
||||
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
|
||||
set_seed(cfg.get("seed", 1024))
|
||||
coordinator = DistCoordinator()
|
||||
device = get_current_device()
|
||||
|
||||
# == init exp_dir ==
|
||||
exp_name, exp_dir = define_experiment_workspace(cfg)
|
||||
coordinator.block_all()
|
||||
if coordinator.is_master():
|
||||
os.makedirs(exp_dir, exist_ok=True)
|
||||
save_training_config(cfg.to_dict(), exp_dir)
|
||||
coordinator.block_all()
|
||||
|
||||
# == init logger, tensorboard & wandb ==
|
||||
logger = create_logger(exp_dir)
|
||||
logger.info("Experiment directory created at %s", exp_dir)
|
||||
logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
|
||||
if coordinator.is_master():
|
||||
tb_writer = create_tensorboard_writer(exp_dir)
|
||||
if cfg.get("wandb", False):
|
||||
wandb.init(project="minisora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")
|
||||
|
||||
# == init ColossalAI booster ==
|
||||
plugin = create_colossalai_plugin(
|
||||
plugin=cfg.get("plugin", "zero2"),
|
||||
dtype=cfg_dtype,
|
||||
grad_clip=cfg.get("grad_clip", 0),
|
||||
sp_size=cfg.get("sp_size", 1),
|
||||
)
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
# ======================================================
|
||||
# 2. build dataset and dataloader
|
||||
# ======================================================
|
||||
logger.info("Building dataset...")
|
||||
# == build dataset ==
|
||||
dataset = build_module(cfg.dataset, DATASETS)
|
||||
logger.info("Dataset contains %s samples.", len(dataset))
|
||||
|
||||
# == build dataloader ==
|
||||
dataloader_args = dict(
|
||||
dataset=dataset,
|
||||
batch_size=cfg.get("batch_size", None),
|
||||
num_workers=cfg.get("num_workers", 4),
|
||||
seed=cfg.get("seed", 1024),
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
collate_fn=collate_fn_ignore_none,
|
||||
)
|
||||
if cfg.dataset.type == DEFAULT_DATASET_NAME:
|
||||
dataloader = prepare_dataloader(**dataloader_args)
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
|
||||
logger.info("Total batch size: %s", total_batch_size)
|
||||
num_steps_per_epoch = len(dataloader)
|
||||
sampler_to_io = None
|
||||
else:
|
||||
dataloader = prepare_variable_dataloader(
|
||||
bucket_config=cfg.get("bucket_config", None),
|
||||
num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
|
||||
**dataloader_args,
|
||||
)
|
||||
num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
|
||||
sampler_to_io = None if cfg.get("start_from_scratch ", False) else dataloader.batch_sampler
|
||||
|
||||
# ======================================================
|
||||
# 3. build model
|
||||
# ======================================================
|
||||
logger.info("Building models...")
|
||||
# == build text-encoder and vae ==
|
||||
text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype)
|
||||
vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
|
||||
|
||||
# == build diffusion model ==
|
||||
input_size = (dataset.num_frames, *dataset.image_size)
|
||||
latent_size = vae.get_latent_size(input_size)
|
||||
model = (
|
||||
build_module(
|
||||
cfg.model,
|
||||
MODELS,
|
||||
input_size=latent_size,
|
||||
in_channels=vae.out_channels,
|
||||
caption_channels=text_encoder.output_dim,
|
||||
model_max_length=text_encoder.model_max_length,
|
||||
)
|
||||
.to(device, dtype)
|
||||
.train()
|
||||
)
|
||||
model_numel, model_numel_trainable = get_model_numel(model)
|
||||
logger.info(
|
||||
"[Diffusion] Trainable model params: %s, Total model params: %s",
|
||||
format_numel_str(model_numel_trainable),
|
||||
format_numel_str(model_numel),
|
||||
)
|
||||
|
||||
# == build ema for diffusion model ==
|
||||
ema = deepcopy(model).to(torch.float32).to(device)
|
||||
requires_grad(ema, False)
|
||||
ema_shape_dict = record_model_param_shape(ema)
|
||||
ema.eval()
|
||||
update_ema(ema, model, decay=0, sharded=False)
|
||||
|
||||
# == setup loss function, build scheduler ==
|
||||
scheduler = build_module(cfg.scheduler, SCHEDULERS)
|
||||
|
||||
# == setup optimizer ==
|
||||
optimizer = HybridAdam(
|
||||
filter(lambda p: p.requires_grad, model.parameters()),
|
||||
adamw_mode=True,
|
||||
lr=cfg.get("lr", 1e-4),
|
||||
weight_decay=cfg.get("weight_decay", 0),
|
||||
eps=cfg.get("adam_eps", 1e-8),
|
||||
)
|
||||
lr_scheduler = None
|
||||
|
||||
# == additional preparation ==
|
||||
if cfg.get("grad_checkpoint", False):
|
||||
set_grad_checkpoint(model)
|
||||
if cfg.get("mask_ratios", None) is not None:
|
||||
mask_generator = MaskGenerator(cfg.mask_ratios)
|
||||
|
||||
# =======================================================
|
||||
# 4. distributed training preparation with colossalai
|
||||
# =======================================================
|
||||
logger.info("Preparing for distributed training...")
|
||||
# == boosting ==
|
||||
# NOTE: we set dtype first to make initialization of model consistent with the dtype; then reset it to the fp32 as we make diffusion scheduler in fp32
|
||||
torch.set_default_dtype(dtype)
|
||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=dataloader,
|
||||
)
|
||||
torch.set_default_dtype(torch.float)
|
||||
logger.info("Boosting model for distributed training")
|
||||
|
||||
# == global variables ==
|
||||
cfg_epochs = cfg.get("epochs", 1000)
|
||||
start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
|
||||
running_loss = 0.0
|
||||
logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)
|
||||
|
||||
# == resume ==
|
||||
if cfg.get("load", None) is not None:
|
||||
logger.info("Loading checkpoint")
|
||||
ret = load(
|
||||
booster,
|
||||
cfg.load,
|
||||
model=model,
|
||||
ema=ema,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
sampler=sampler_to_io,
|
||||
)
|
||||
if not cfg.get("start_from_scratch ", False):
|
||||
start_epoch, start_step, sampler_start_idx = ret
|
||||
logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
|
||||
if cfg.dataset.type == DEFAULT_DATASET_NAME:
|
||||
dataloader.sampler.set_start_index(sampler_start_idx)
|
||||
|
||||
model_sharding(ema)
|
||||
|
||||
# =======================================================
|
||||
# 5. training loop
|
||||
# =======================================================
|
||||
dist.barrier()
|
||||
for epoch in range(1):
|
||||
# == set dataloader to new epoch ==
|
||||
if cfg.dataset.type == DEFAULT_DATASET_NAME:
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
dataloader_iter = iter(dataloader)
|
||||
logger.info("Beginning epoch %s...", epoch)
|
||||
|
||||
# == training loop in an epoch ==
|
||||
save_root = f'/mnt/nfs-207/sora_data/webvid-10M/feat'
|
||||
data_list = []
|
||||
cnt = 0
|
||||
max_len = 10
|
||||
|
||||
with tqdm(
|
||||
enumerate(dataloader_iter, start=start_step),
|
||||
desc=f"Epoch {epoch}",
|
||||
disable=not coordinator.is_master(),
|
||||
initial=start_step,
|
||||
total=num_steps_per_epoch,
|
||||
) as pbar:
|
||||
for step, batch in pbar:
|
||||
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
|
||||
y = batch.pop("text")
|
||||
|
||||
# == visual and text encoding ==
|
||||
with torch.no_grad():
|
||||
# Prepare visual inputs
|
||||
x = vae.encode(x) # [B, C, T, H/P, W/P]
|
||||
# Prepare text inputs
|
||||
|
||||
model_args = {}
|
||||
# == video meta info ==
|
||||
for k, v in batch.items():
|
||||
model_args[k] = v.to(device, dtype)
|
||||
|
||||
# keys to save: ['x', 'y', 'mask', 'fps']
|
||||
data_i = {
|
||||
'x': x.detach().cpu(),
|
||||
'fps': model_args['fps'].detach().cpu(),
|
||||
}
|
||||
|
||||
save_text_feat = [True, False]
|
||||
if save_text_feat:
|
||||
with torch.no_grad():
|
||||
model_args_text = text_encoder.encode(y)
|
||||
data_i.update({
|
||||
'y': model_args_text['y'].detach().cpu(),
|
||||
'mask': model_args_text['mask'].detach().cpu(),
|
||||
})
|
||||
breakpoint()
|
||||
else:
|
||||
data_i['text'] = y
|
||||
|
||||
data_list.append(data_i)
|
||||
|
||||
if len(data_list) == max_len:
|
||||
save_path = os.path.join(save_root, f'data/{cnt}.bin')
|
||||
torch.save(data_list, save_path)
|
||||
|
||||
data_list = []
|
||||
cnt += 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in a new issue