[feat] enable freezing y_embedder

zhengzangw 2024-05-17 06:40:44 +00:00
parent a24149e265
commit a016325a8e
7 changed files with 274 additions and 295 deletions
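In short: STDiT3 gains a `freeze_y_embedder` config flag (sets `requires_grad = False` on the caption embedder's parameters, default `False`) and a `skip_y_embedder` flag (forward accepts precomputed text features), the text-embedding logic is factored into an `encode_text` helper, and the ad-hoc feature dump in the old training script is replaced by a dedicated extraction script plus a `FeatureSaver` utility. A minimal sketch of opting in from a model config, mirroring the config hunk below:

```python
model = dict(
    type="STDiT3-XL/2",
    qk_norm=True,
    enable_flash_attn=True,
    enable_layernorm_kernel=True,
    freeze_y_embedder=True,  # new in this commit; defaults to False
)
```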


@@ -156,11 +156,11 @@ conda activate opensora
 # install torch, torchvision and xformers
 pip install -r requirements/requirements_cu121.txt

-# install this project
+# download the repo
 git clone https://github.com/hpcaitech/Open-Sora
 cd Open-Sora

-# install Open-Sora
+# the default installation is for inference only
 pip install -v .
 ```


@@ -0,0 +1,60 @@
# Dataset settings
dataset = dict(
    type="VariableVideoTextDataset",
    transform_name="resize_crop",
)

# webvid
bucket_config = {  # 12s/it
    "144p": {1: (1.0, 475), 51: (1.0, 51), 102: ((1.0, 0.33), 27), 204: ((1.0, 0.1), 13), 408: ((1.0, 0.1), 6)},
    # ---
    "256": {1: (0.4, 297), 51: (0.5, 20), 102: ((0.5, 0.33), 10), 204: ((0.5, 0.1), 5), 408: ((0.5, 0.1), 2)},
    "240p": {1: (0.3, 297), 51: (0.4, 20), 102: ((0.4, 0.33), 10), 204: ((0.4, 0.1), 5), 408: ((0.4, 0.1), 2)},
    # ---
    "360p": {1: (0.2, 141), 51: (0.15, 8), 102: ((0.15, 0.33), 4), 204: ((0.15, 0.1), 2), 408: ((0.15, 0.1), 1)},
    "512": {1: (0.1, 141)},
    # ---
    "480p": {1: (0.1, 89)},
    # ---
    "720p": {1: (0.05, 36)},
    "1024": {1: (0.05, 36)},
    # ---
    "1080p": {1: (0.1, 5)},
    # ---
    "2048": {1: (0.1, 5)},
}

# Acceleration settings
num_workers = 8
num_bucket_build_workers = 16
dtype = "bf16"
seed = 42
outputs = "outputs"
wandb = False

# Model settings
model = dict(
    type="STDiT3-XL/2",
    from_pretrained=None,
    qk_norm=True,
    enable_flash_attn=True,
    enable_layernorm_kernel=True,
)
vae = dict(
    type="OpenSoraVAE_V1_2",
    from_pretrained="pretrained_models/vae-pipeline",
    micro_frame_size=17,
    micro_batch_size=4,
)
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=300,
    shardformer=True,
    local_files_only=True,
)

# feature extraction settings
save_text_features = True
save_compressed_text_features = True
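For readers of this config: each `bucket_config` entry maps a resolution bucket to `{num_frames: (sampling_probability, batch_size)}`; the tuple form of the probability is resolved inside `opensora/datasets/bucket.py` and is not reproduced here. A small illustrative read, with made-up variable names:

```python
# Illustration only; not code from this commit.
bucket_config = {"240p": {51: (0.4, 20)}}

prob, batch_size = bucket_config["240p"][51]
# A 51-frame clip that fits the 240p bucket is kept with probability 0.4
# and batched 20 clips at a time.
```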


@@ -59,6 +59,7 @@ model = dict(
     qk_norm=True,
     enable_flash_attn=True,
     enable_layernorm_kernel=True,
+    freeze_y_embedder=True,
 )
 vae = dict(
     type="OpenSoraVAE_V1_2",


@@ -169,6 +169,8 @@ class STDiT3Config(PretrainedConfig):
         enable_layernorm_kernel=False,
         enable_sequence_parallelism=False,
         only_train_temporal=False,
+        freeze_y_embedder=False,
+        skip_y_embedder=False,
         **kwargs,
     ):
         self.input_size = input_size
@@ -189,6 +191,8 @@ class STDiT3Config(PretrainedConfig):
         self.enable_layernorm_kernel = enable_layernorm_kernel
         self.enable_sequence_parallelism = enable_sequence_parallelism
         self.only_train_temporal = only_train_temporal
+        self.freeze_y_embedder = freeze_y_embedder
+        self.skip_y_embedder = skip_y_embedder
         super().__init__(**kwargs)
@@ -284,6 +288,10 @@ class STDiT3(PreTrainedModel):
                 for param in block.parameters():
                     param.requires_grad = True
+
+        if config.freeze_y_embedder:
+            for param in self.y_embedder.parameters():
+                param.requires_grad = False

     def initialize_weights(self):
         # Initialize transformer layers:
         def _basic_init(module):
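Freezing via `requires_grad` composes with the optimizer setup used elsewhere in the repo (see the deleted script at the end, which builds `HybridAdam` over `filter(lambda p: p.requires_grad, model.parameters())`): frozen y_embedder weights simply never reach the optimizer. A toy sketch of that interaction, with illustrative module names:

```python
import torch
import torch.nn as nn

# Stand-in for STDiT3's structure; names are illustrative.
model = nn.ModuleDict({"y_embedder": nn.Linear(8, 8), "blocks": nn.Linear(8, 8)})
for param in model["y_embedder"].parameters():
    param.requires_grad = False  # what config.freeze_y_embedder triggers

# Only still-trainable parameters are handed to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)  # y_embedder is excluded
```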
@@ -319,6 +327,19 @@ class STDiT3(PreTrainedModel):
         W = W // self.patch_size[2]
         return (T, H, W)

+    def encode_text(self, y, mask=None):
+        y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]
+        if mask is not None:
+            if mask.shape[0] != y.shape[0]:
+                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
+            mask = mask.squeeze(1).squeeze(1)
+            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, self.hidden_size)
+            y_lens = mask.sum(dim=1).tolist()
+        else:
+            y_lens = [y.shape[2]] * y.shape[0]
+            y = y.squeeze(1).view(1, -1, self.hidden_size)
+        return y, y_lens
+
     def forward(self, x, timestep, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs):
         dtype = self.x_embedder.proj.weight.dtype
         B = x.size(0)
@@ -348,16 +369,10 @@ class STDiT3(PreTrainedModel):
         t0_mlp = self.t_block(t0)

         # === get y embed ===
-        y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]
-        if mask is not None:
-            if mask.shape[0] != y.shape[0]:
-                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
-            mask = mask.squeeze(1).squeeze(1)
-            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, self.hidden_size)
-            y_lens = mask.sum(dim=1).tolist()
+        if self.config.skip_y_embedder:
+            y_lens = mask.tolist()
         else:
-            y_lens = [y.shape[2]] * y.shape[0]
-            y = y.squeeze(1).view(1, -1, self.hidden_size)
+            y, y_lens = self.encode_text(y, mask)

         # === get x embed ===
         x = self.x_embedder(x)  # [B, N, C]
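With `skip_y_embedder=True`, `forward` expects `y` to already be the flattened output of `encode_text` and `mask` to carry per-sample token lengths (the only thing done with it is `mask.tolist()`). A hedged sketch of the intended two-stage use; shapes and variable names are assumptions:

```python
# Stage 1 (offline): run the embedder once, e.g. during feature extraction.
y_feat, y_lens = model.encode_text(t5_y, t5_mask)  # y_feat: [1, sum(y_lens), C]

# Stage 2 (training with skip_y_embedder=True): pass the precomputed features
# and the token lengths where the attention mask used to go.
output = model(x, timestep, y=y_feat, mask=torch.tensor(y_lens))
```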


@@ -378,3 +378,23 @@ class Timer:
         self.end_time = time.time()
         if self.log:
             print(f"Elapsed time for {self.name}: {self.elapsed_time:.2f} s")
+
+
+def get_tensor_memory(tensor, human_readable=True):
+    size = tensor.element_size() * tensor.nelement()
+    if human_readable:
+        size = format_numel_str(size)
+    return size
+
+
+class FeatureSaver:
+    def __init__(self, save_dir, bin_size=10):
+        self.save_dir = save_dir
+        self.bin_size = bin_size
+        self.data_list = []
+        self.cnt = 0
+
+    def update(self, data):
+        self.data_list.append(data)
+        self.cnt += 1
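`update` only accumulates; nothing visible in this hunk writes `data_list` to disk, though `bin_size` and the save pattern in the old script (ten batches per `.bin` via `torch.save`) suggest a flush step. A sketch of what such a flush could look like; the `save` method and file naming below are assumptions, not recovered code:

```python
import os
import torch

class FeatureSaverSketch:
    """Assumed behavior: buffer bin_size batches, then dump one .bin file."""

    def __init__(self, save_dir, bin_size=10):
        self.save_dir = save_dir
        self.bin_size = bin_size
        self.data_list = []
        self.cnt = 0      # batches seen
        self.bin_cnt = 0  # bins written (assumption)

    def update(self, data):
        self.data_list.append(data)
        self.cnt += 1
        if len(self.data_list) >= self.bin_size:
            self.save()

    def save(self):
        # File naming is hypothetical; the deleted script used data/{cnt}.bin.
        save_path = os.path.join(self.save_dir, f"{self.bin_cnt}.bin")
        torch.save(self.data_list, save_path)
        self.data_list = []
        self.bin_cnt += 1
```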


@@ -0,0 +1,167 @@
import os
from pprint import pformat

import torch
import torch.distributed as dist
from tqdm import tqdm

from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
from opensora.datasets.utils import collate_fn_ignore_none
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import FeatureSaver, create_logger, format_numel_str, get_model_numel, to_torch_dtype

DEFAULT_DATASET_NAME = "VideoTextDataset"


def main():
    torch.set_grad_enabled(False)
    # ======================================================
    # 1. configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=False)

    # == device and dtype ==
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg.get("dtype", "bf16"))

    # == colossalai init distributed training ==
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg_dtype = cfg.get("dtype", "fp32")
    assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # == init logger, tensorboard & wandb ==
    logger = create_logger()
    logger.info("Configuration:\n %s", pformat(cfg.to_dict()))

    # ======================================================
    # 2. build dataset and dataloader
    # ======================================================
    logger.info("Building dataset...")
    # == build dataset ==
    dataset = build_module(cfg.dataset, DATASETS)
    logger.info("Dataset contains %s samples.", len(dataset))

    # == build dataloader ==
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.get("batch_size", None),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
        collate_fn=collate_fn_ignore_none,
    )
    if cfg.dataset.type == DEFAULT_DATASET_NAME:
        dataloader = prepare_dataloader(**dataloader_args)
        total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
        logger.info("Total batch size: %s", total_batch_size)
        num_steps_per_epoch = len(dataloader)
    else:
        dataloader = prepare_variable_dataloader(
            bucket_config=cfg.get("bucket_config", None),
            num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
            **dataloader_args,
        )
        num_steps_per_epoch = dataloader.batch_sampler.get_num_batch()

    # ======================================================
    # 3. build model
    # ======================================================
    logger.info("Building models...")
    # == build text-encoder and vae ==
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype)
    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()

    # == build diffusion model ==
    input_size = (dataset.num_frames, *dataset.image_size)
    latent_size = vae.get_latent_size(input_size)
    model = (
        build_module(
            cfg.model,
            MODELS,
            input_size=latent_size,
            in_channels=vae.out_channels,
            caption_channels=text_encoder.output_dim,
            model_max_length=text_encoder.model_max_length,
        )
        .to(device, dtype)
        .train()
    )
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        "[Diffusion] Trainable model params: %s, Total model params: %s",
        format_numel_str(model_numel_trainable),
        format_numel_str(model_numel),
    )

    # =======================================================
    # 4. distributed training preparation with colossalai
    # =======================================================
    # == global variables ==
    start_step = sampler_start_idx = 0
    logger.info("Training for with %s steps per epoch", num_steps_per_epoch)

    if cfg.dataset.type == DEFAULT_DATASET_NAME:
        dataloader.sampler.set_start_index(sampler_start_idx)

    # =======================================================
    # 5. training loop
    # =======================================================
    dist.barrier()
    for epoch in range(1):
        # == set dataloader to new epoch ==
        if cfg.dataset.type == DEFAULT_DATASET_NAME:
            dataloader.sampler.set_epoch(epoch)
        dataloader_iter = iter(dataloader)
        logger.info("Beginning epoch %s...", epoch)

        # == training loop in an epoch ==
        assert cfg.get("save_dir", None) is not None, "Please specify the save_dir in the config file."
        os.makedirs(cfg.save_dir, exist_ok=True)
        saver = FeatureSaver(cfg.save_dir)
        save_text_features = cfg.get("save_text_features", False)
        save_compressed_text_features = cfg.get("save_compressed_text_features", False)
        with tqdm(
            enumerate(dataloader_iter, start=start_step),
            desc=f"Epoch {epoch}",
            initial=start_step,
            total=num_steps_per_epoch,
        ) as pbar:
            for step, batch in pbar:
                x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
                y = batch.pop("text")
                x = vae.encode(x).cpu()  # [B, C, T, H/P, W/P]
                fps = batch["fps"].to(dtype)
                batch_dict = {"x": x, "fps": fps}

                if save_text_features:
                    text_infos = text_encoder.encode(y)
                    y_feat = text_infos["y"]
                    y_mask = text_infos["mask"]
                    if not save_compressed_text_features:
                        y_feat = y_feat.cpu()
                        y_mask = y_mask.cpu()
                    else:
                        y_feat, y_mask = model.encode_text(y_feat, y_mask)
                        y_feat = y_feat.cpu()
                        y_mask = torch.tensor(y_mask)
                        breakpoint()
                    batch_dict.update({"y": y_feat, "mask": y_mask})

                saver.update(batch_dict)


if __name__ == "__main__":
    main()
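A hedged example of consuming the dumps this script produces; the path and bin naming follow the `FeatureSaver` sketch above and are assumptions, but each stored record carries exactly the keys set in `batch_dict`:

```python
import torch

records = torch.load("outputs/features/0.bin")  # hypothetical path
for rec in records:
    x = rec["x"]      # VAE latents on CPU, [B, C, T, H/P, W/P]
    fps = rec["fps"]
    # Present when save_text_features=True in the config:
    y, mask = rec.get("y"), rec.get("mask")
```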


@@ -1,284 +0,0 @@
import os
from copy import deepcopy
from datetime import timedelta
from pprint import pformat

import torch
import torch.distributed as dist
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm

import wandb
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
from opensora.datasets.utils import collate_fn_ignore_none
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
from opensora.utils.misc import (
    all_reduce_mean,
    create_logger,
    create_tensorboard_writer,
    format_numel_str,
    get_model_numel,
    requires_grad,
    to_torch_dtype,
)
from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema

DEFAULT_DATASET_NAME = "VideoTextDataset"


def main():
    # ======================================================
    # 1. configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=False)

    # == device and dtype ==
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg.get("dtype", "bf16"))

    # == colossalai init distributed training ==
    # NOTE: A very large timeout is set to avoid some processes exit early
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    set_seed(cfg.get("seed", 1024))
    coordinator = DistCoordinator()
    device = get_current_device()

    # == init exp_dir ==
    exp_name, exp_dir = define_experiment_workspace(cfg)
    coordinator.block_all()
    if coordinator.is_master():
        os.makedirs(exp_dir, exist_ok=True)
        save_training_config(cfg.to_dict(), exp_dir)
    coordinator.block_all()

    # == init logger, tensorboard & wandb ==
    logger = create_logger(exp_dir)
    logger.info("Experiment directory created at %s", exp_dir)
    logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
    if coordinator.is_master():
        tb_writer = create_tensorboard_writer(exp_dir)
        if cfg.get("wandb", False):
            wandb.init(project="minisora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")

    # == init ColossalAI booster ==
    plugin = create_colossalai_plugin(
        plugin=cfg.get("plugin", "zero2"),
        dtype=cfg_dtype,
        grad_clip=cfg.get("grad_clip", 0),
        sp_size=cfg.get("sp_size", 1),
    )
    booster = Booster(plugin=plugin)

    # ======================================================
    # 2. build dataset and dataloader
    # ======================================================
    logger.info("Building dataset...")
    # == build dataset ==
    dataset = build_module(cfg.dataset, DATASETS)
    logger.info("Dataset contains %s samples.", len(dataset))

    # == build dataloader ==
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.get("batch_size", None),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
        collate_fn=collate_fn_ignore_none,
    )
    if cfg.dataset.type == DEFAULT_DATASET_NAME:
        dataloader = prepare_dataloader(**dataloader_args)
        total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
        logger.info("Total batch size: %s", total_batch_size)
        num_steps_per_epoch = len(dataloader)
        sampler_to_io = None
    else:
        dataloader = prepare_variable_dataloader(
            bucket_config=cfg.get("bucket_config", None),
            num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
            **dataloader_args,
        )
        num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
        sampler_to_io = None if cfg.get("start_from_scratch ", False) else dataloader.batch_sampler

    # ======================================================
    # 3. build model
    # ======================================================
    logger.info("Building models...")
    # == build text-encoder and vae ==
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype)
    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()

    # == build diffusion model ==
    input_size = (dataset.num_frames, *dataset.image_size)
    latent_size = vae.get_latent_size(input_size)
    model = (
        build_module(
            cfg.model,
            MODELS,
            input_size=latent_size,
            in_channels=vae.out_channels,
            caption_channels=text_encoder.output_dim,
            model_max_length=text_encoder.model_max_length,
        )
        .to(device, dtype)
        .train()
    )
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        "[Diffusion] Trainable model params: %s, Total model params: %s",
        format_numel_str(model_numel_trainable),
        format_numel_str(model_numel),
    )

    # == build ema for diffusion model ==
    ema = deepcopy(model).to(torch.float32).to(device)
    requires_grad(ema, False)
    ema_shape_dict = record_model_param_shape(ema)
    ema.eval()
    update_ema(ema, model, decay=0, sharded=False)

    # == setup loss function, build scheduler ==
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # == setup optimizer ==
    optimizer = HybridAdam(
        filter(lambda p: p.requires_grad, model.parameters()),
        adamw_mode=True,
        lr=cfg.get("lr", 1e-4),
        weight_decay=cfg.get("weight_decay", 0),
        eps=cfg.get("adam_eps", 1e-8),
    )
    lr_scheduler = None

    # == additional preparation ==
    if cfg.get("grad_checkpoint", False):
        set_grad_checkpoint(model)
    if cfg.get("mask_ratios", None) is not None:
        mask_generator = MaskGenerator(cfg.mask_ratios)

    # =======================================================
    # 4. distributed training preparation with colossalai
    # =======================================================
    logger.info("Preparing for distributed training...")
    # == boosting ==
    # NOTE: we set dtype first to make initialization of model consistent with the dtype; then reset it to the fp32 as we make diffusion scheduler in fp32
    torch.set_default_dtype(dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    torch.set_default_dtype(torch.float)
    logger.info("Boosting model for distributed training")

    # == global variables ==
    cfg_epochs = cfg.get("epochs", 1000)
    start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
    running_loss = 0.0
    logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)

    # == resume ==
    if cfg.get("load", None) is not None:
        logger.info("Loading checkpoint")
        ret = load(
            booster,
            cfg.load,
            model=model,
            ema=ema,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            sampler=sampler_to_io,
        )
        if not cfg.get("start_from_scratch ", False):
            start_epoch, start_step, sampler_start_idx = ret
        logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
    if cfg.dataset.type == DEFAULT_DATASET_NAME:
        dataloader.sampler.set_start_index(sampler_start_idx)
    model_sharding(ema)

    # =======================================================
    # 5. training loop
    # =======================================================
    dist.barrier()
    for epoch in range(1):
        # == set dataloader to new epoch ==
        if cfg.dataset.type == DEFAULT_DATASET_NAME:
            dataloader.sampler.set_epoch(epoch)
        dataloader_iter = iter(dataloader)
        logger.info("Beginning epoch %s...", epoch)

        # == training loop in an epoch ==
        save_root = f'/mnt/nfs-207/sora_data/webvid-10M/feat'
        data_list = []
        cnt = 0
        max_len = 10
        with tqdm(
            enumerate(dataloader_iter, start=start_step),
            desc=f"Epoch {epoch}",
            disable=not coordinator.is_master(),
            initial=start_step,
            total=num_steps_per_epoch,
        ) as pbar:
            for step, batch in pbar:
                x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
                y = batch.pop("text")

                # == visual and text encoding ==
                with torch.no_grad():
                    # Prepare visual inputs
                    x = vae.encode(x)  # [B, C, T, H/P, W/P]
                    # Prepare text inputs

                model_args = {}
                # == video meta info ==
                for k, v in batch.items():
                    model_args[k] = v.to(device, dtype)

                # keys to save: ['x', 'y', 'mask', 'fps']
                data_i = {
                    'x': x.detach().cpu(),
                    'fps': model_args['fps'].detach().cpu(),
                }
                save_text_feat = [True, False]
                if save_text_feat:
                    with torch.no_grad():
                        model_args_text = text_encoder.encode(y)
                    data_i.update({
                        'y': model_args_text['y'].detach().cpu(),
                        'mask': model_args_text['mask'].detach().cpu(),
                    })
                    breakpoint()
                else:
                    data_i['text'] = y
                data_list.append(data_i)
                if len(data_list) == max_len:
                    save_path = os.path.join(save_root, f'data/{cnt}.bin')
                    torch.save(data_list, save_path)
                    data_list = []
                    cnt += 1


if __name__ == "__main__":
    main()