[fix] vae dtype

This commit is contained in:
zhengzangw 2024-05-16 08:50:24 +00:00
parent 137d0ac223
commit 9fd4554f55
14 changed files with 477 additions and 57 deletions

View file

@@ -146,28 +146,15 @@ Other useful documents and links are listed below.
### Install from Source
For CUDA 12.1, you can install the dependencies with the following commands. Otherwise, please refer to [Installation](docs/installation.md) for more instructions.
```bash
# create a virtual env
# create a virtual env and activate it (conda as an example)
conda create -n opensora python=3.10
# activate virtual environment
conda activate opensora
# install torch
# the command below is for CUDA 12.1, choose install commands from
# https://pytorch.org/get-started/locally/ based on your own CUDA version
pip install torch torchvision
# install flash attention (optional)
# set enable_flash_attn=False in config to avoid using flash attention
pip install packaging ninja
pip install flash-attn --no-build-isolation
# install apex (optional)
# set enable_layernorm_kernel=False in config to avoid using apex
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" git+https://github.com/NVIDIA/apex.git
# install xformers
pip install -U xformers --index-url https://download.pytorch.org/whl/cu121
# install torch, torchvision and xformers
pip install -r requirements/requirements_cu121.txt
# install this project
git clone https://github.com/hpcaitech/Open-Sora
@@ -249,6 +236,12 @@ Since Open-Sora 1.1 supports inference with dynamic input size, you can pass the
python scripts/inference.py configs/opensora-v1-1/inference/sample.py --prompt "A beautiful sunset over the city" --num-frames 32 --image-size 480 854
```
If your installation does not contain `apex` and `flash-attn`, you need to disable them in the config file or via the following command.
```bash
python scripts/inference.py configs/opensora-v1-1/inference/sample.py --prompt "A beautiful sunset over the city" --num-frames 32 --image-size 480 854 --layernorm-kernel False --flash-attn False
```
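The same switches can be set in the config file instead. A minimal sketch of that route, assuming the model section of `configs/opensora-v1-1/inference/sample.py` follows the usual mmengine-style dict layout (the model type shown is illustrative):
```python
# configs/opensora-v1-1/inference/sample.py (sketch)
model = dict(
    type="STDiT2-XL/2",             # illustrative; keep whatever the config ships with
    enable_flash_attn=False,        # skip flash-attn if it is not installed
    enable_layernorm_kernel=False,  # skip the apex fused layernorm kernel
)
```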
See [here](docs/commands.md#inference-with-open-sora-11) for more instructions including text-to-image, image-to-video, video-to-video, and infinite time generation.
### Open-Sora 1.0 Command Line Inference

40
docs/installation.md Normal file
View file

@@ -0,0 +1,40 @@
# Installation
Requirements are listed in the `requirements` folder.
## Different CUDA versions
You need to manually install `torch`, `torchvision` and `xformers` to match your CUDA version.
```bash
# install torch (>=2.1 is recommended)
# the command below is for CUDA 12.1, choose install commands from
# https://pytorch.org/get-started/locally/ based on your own CUDA version
pip install torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cu121
# install xformers
# the command below is for CUDA 12.1, choose install commands from
# https://github.com/facebookresearch/xformers?tab=readme-ov-file#installing-xformers based on your own CUDA version
pip install xformers --index-url https://download.pytorch.org/whl/cu121
```
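To sanity-check the result before moving on, a quick probe (stock PyTorch, no other assumptions):
```bash
# confirm the CUDA build matches and the GPU is visible
python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"
python -c "import xformers; print(xformers.__version__)"
```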
```bash
# install flash attention (optional)
# set enable_flash_attn=False in config to avoid using flash attention
pip install packaging ninja
pip install flash-attn --no-build-isolation
# install apex (optional)
# set enable_layernorm_kernel=False in config to avoid using apex
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" git+https://github.com/NVIDIA/apex.git
```
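Both packages are optional, so before editing configs it helps to know which ones actually import. A small sketch:
```python
# probe the optional acceleration deps; if one fails to import, set the
# corresponding enable_* flag to False in your config (see comments above)
def importable(name: str) -> bool:
    try:
        __import__(name)
        return True
    except ImportError:
        return False

print("flash_attn:", importable("flash_attn"))  # controls enable_flash_attn
print("apex:", importable("apex"))              # controls enable_layernorm_kernel
```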
gdown
pre-commit
pyarrow
tensorboard
transformers
wandb
pandarallel
gradio
spaces

View file

@@ -1,4 +0,0 @@
from .acceleration import *
from .datasets import *
from .models import *
from .registry import *

View file

@@ -5,7 +5,6 @@ from typing import Iterator, List, Optional
import torch
import torch.distributed as dist
from pandarallel import pandarallel
from torch.utils.data import Dataset, DistributedSampler
from opensora.utils.misc import format_numel_str, get_logger
@@ -81,6 +80,8 @@ class VariableVideoBatchSampler(DistributedSampler):
def group_by_bucket(self) -> dict:
bucket_sample_dict = OrderedDict()
from pandarallel import pandarallel
pandarallel.initialize(nb_workers=self.num_bucket_build_workers, progress_bar=False)
get_logger().info("Building buckets...")
bucket_ids = self.dataset.data.parallel_apply(
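This hunk moves the pandarallel import from module scope into `group_by_bucket`, a lazy-import pattern: `import opensora` no longer pays pandarallel's startup cost (or hard-fails without it) unless buckets are actually built. The pattern in isolation, with hypothetical names:
```python
import pandas as pd

def build_buckets(frame: pd.DataFrame):
    # deferred import: the dependency loads only when this function runs,
    # not when the package is imported
    from pandarallel import pandarallel
    pandarallel.initialize(nb_workers=4, progress_bar=False)
    return frame.parallel_apply(lambda row: row.sum(), axis=1)
```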

View file

@@ -143,12 +143,16 @@ class VideoAutoencoderPipeline(nn.Module):
param.requires_grad = False
self.out_channels = self.temporal_vae.out_channels
self.scale = torch.tensor(scale).cuda()
self.shift = torch.tensor(shift).cuda()
if len(self.scale.shape) > 0:
self.scale = self.scale[None, :, None, None, None]
if len(self.shift.shape) > 0:
self.shift = self.shift[None, :, None, None, None]
# normalization parameters
scale = torch.tensor(scale)
shift = torch.tensor(shift)
if len(scale.shape) > 0:
scale = scale[None, :, None, None, None]
if len(shift.shape) > 0:
shift = shift[None, :, None, None, None]
self.register_buffer("scale", scale)
self.register_buffer("shift", shift)
def encode(self, x):
x_z = self.spatial_vae.encode(x)
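This hunk is the fix the commit title refers to. `scale` and `shift` were plain tensor attributes created with `.cuda()`, so the later `vae = build_module(cfg.vae, MODELS).to(device, dtype)` call converted the module's parameters but left these two tensors in fp32 on whatever device was current at construction, causing a dtype mismatch. Buffers registered with `register_buffer` follow the module through `.to()`/`.cuda()` and land in the state dict. A minimal sketch of the difference, stock PyTorch only:
```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.attr = torch.zeros(3)                   # plain attribute
        self.register_buffer("buf", torch.zeros(3))  # registered buffer

m = Demo().to(torch.bfloat16)
print(m.attr.dtype)  # torch.float32 -- module.to() ignores plain attributes
print(m.buf.dtype)   # torch.bfloat16 -- buffers follow the module
```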

View file

@@ -27,6 +27,7 @@ def parse_args(training=False):
parser.add_argument("--flash-attn", default=None, type=str2bool, help="enable flash attention")
parser.add_argument("--layernorm-kernel", default=None, type=str2bool, help="enable layernorm kernel")
parser.add_argument("--resolution", default=None, type=str, help="multi resolution")
parser.add_argument("--data-path", default=None, type=str, help="path to data csv")
# ======================================================
# Inference
@@ -41,7 +42,6 @@ def parse_args(training=False):
parser.add_argument("--prompt-as-path", action="store_true", help="use prompt as path to save samples")
parser.add_argument("--verbose", default=None, type=int, help="verbose level")
parser.add_argument("--data-path", default=None, type=str, help="path to data csv")
# prompt
parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file")
parser.add_argument("--prompt", default=None, type=str, nargs="+", help="prompt list")
@@ -69,7 +69,6 @@ def parse_args(training=False):
parser.add_argument("--lr", default=None, type=float, help="learning rate")
parser.add_argument("--wandb", default=None, type=bool, help="enable wandb")
parser.add_argument("--load", default=None, type=str, help="path to continue training")
parser.add_argument("--data-path", default=None, type=str, help="path to data csv")
parser.add_argument("--start-from-scratch", action="store_true", help="start training from scratch")
return parser.parse_args()
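Hoisting `--data-path` into the shared section (and deleting the per-mode copies) also rules out ever registering the option twice on one parser, which argparse rejects outright:
```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--data-path", default=None, type=str)
# a second registration raises immediately:
# argparse.ArgumentError: argument --data-path: conflicting option string: --data-path
parser.add_argument("--data-path", default=None, type=str)
```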

View file

@@ -11,7 +11,6 @@ from typing import Tuple
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
# ======================================================
# Logging
@@ -72,6 +71,8 @@ def print_0(*args, **kwargs):
def create_tensorboard_writer(exp_dir):
from torch.utils.tensorboard import SummaryWriter
tensorboard_dir = f"{exp_dir}/tensorboard"
os.makedirs(tensorboard_dir, exist_ok=True)
writer = SummaryWriter(tensorboard_dir)
@@ -349,3 +350,31 @@ def transpose(x):
def all_exists(paths):
return all(os.path.exists(path) for path in paths)
# ======================================================
# Profile
# ======================================================
class Timer:
def __init__(self, name, log=True):
self.name = name
self.start_time = None
self.end_time = None
self.log = log
@property
def elapsed_time(self):
return self.end_time - self.start_time
def __enter__(self):
torch.cuda.synchronize()
self.start_time = time.time()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
torch.cuda.synchronize()
self.end_time = time.time()
if self.log:
print(f"Elapsed time for {self.name}: {self.elapsed_time:.2f} s")

View file

@@ -1,19 +0,0 @@
colossalai
accelerate
diffusers
ftfy
gdown
mmengine
pandas
pre-commit
pyarrow
av
tensorboard
timm
tqdm
transformers
wandb
rotary_embedding_torch
pandarallel
gradio
spaces

View file

@@ -0,0 +1,3 @@
torch==2.2.2 --index-url https://download.pytorch.org/whl/cu121
torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cu121
xformers==0.0.25.post1 --index-url https://download.pytorch.org/whl/cu121

View file

@@ -0,0 +1,3 @@
packaging
ninja
flash-attn --no-build-isolation

View file

@@ -0,0 +1,9 @@
colossalai>=0.3.7
mmengine>=0.10.3
pandas>=2.2.2
timm==0.9.16
rotary_embedding_torch==0.5.3
ftfy>=6.2.0 # for t5
diffusers==0.27.2 # for vae
accelerate==0.29.2 # for t5
av>=12.0.0

View file

@@ -0,0 +1,340 @@
import os
from copy import deepcopy
from datetime import timedelta
from pprint import pformat
import torch
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
from opensora.datasets.utils import collate_fn_ignore_none
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
from opensora.utils.misc import (
Timer,
all_reduce_mean,
create_logger,
create_tensorboard_writer,
format_numel_str,
get_model_numel,
requires_grad,
to_torch_dtype,
)
from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema
DEFAULT_DATASET_NAME = "VideoTextDataset"
def main():
# ======================================================
# 1. configs & runtime variables
# ======================================================
# == parse configs ==
cfg = parse_configs(training=True)
# == device and dtype ==
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
cfg_dtype = cfg.get("dtype", "bf16")
assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
dtype = to_torch_dtype(cfg_dtype)
# == colossalai init distributed training ==
# NOTE: A very large timeout is set to avoid some processes exiting early
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
set_seed(cfg.get("seed", 1024))
coordinator = DistCoordinator()
device = get_current_device()
# == init exp_dir ==
exp_name, exp_dir = define_experiment_workspace(cfg)
coordinator.block_all()
if coordinator.is_master():
os.makedirs(exp_dir, exist_ok=True)
save_training_config(cfg.to_dict(), exp_dir)
coordinator.block_all()
# == init logger, tensorboard & wandb ==
logger = create_logger(exp_dir)
logger.info("Experiment directory created at %s", exp_dir)
logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
if coordinator.is_master():
tb_writer = create_tensorboard_writer(exp_dir)
if cfg.get("wandb", False):
wandb.init(project="minisora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")
# == init ColossalAI booster ==
plugin = create_colossalai_plugin(
plugin=cfg.get("plugin", "zero2"),
dtype=cfg_dtype,
grad_clip=cfg.get("grad_clip", 0),
sp_size=cfg.get("sp_size", 1),
)
booster = Booster(plugin=plugin)
# ======================================================
# 2. build dataset and dataloader
# ======================================================
logger.info("Building dataset...")
# == build dataset ==
dataset = build_module(cfg.dataset, DATASETS)
logger.info("Dataset contains %s samples.", len(dataset))
# == build dataloader ==
dataloader_args = dict(
dataset=dataset,
batch_size=cfg.get("batch_size", None),
num_workers=cfg.get("num_workers", 4),
seed=cfg.get("seed", 1024),
shuffle=True,
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
collate_fn=collate_fn_ignore_none,
)
if cfg.dataset.type == DEFAULT_DATASET_NAME:
dataloader = prepare_dataloader(**dataloader_args)
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
logger.info("Total batch size: %s", total_batch_size)
num_steps_per_epoch = len(dataloader)
sampler_to_io = None
else:
dataloader = prepare_variable_dataloader(
bucket_config=cfg.get("bucket_config", None),
num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
**dataloader_args,
)
num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
sampler_to_io = None if cfg.get("start_from_scratch", False) else dataloader.batch_sampler
# ======================================================
# 3. build model
# ======================================================
logger.info("Building models...")
# == build text-encoder and vae ==
text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
# == build diffusion model ==
input_size = (dataset.num_frames, *dataset.image_size)
latent_size = vae.get_latent_size(input_size)
model = (
build_module(
cfg.model,
MODELS,
input_size=latent_size,
in_channels=vae.out_channels,
caption_channels=text_encoder.output_dim,
model_max_length=text_encoder.model_max_length,
)
.to(device, dtype)
.train()
)
model_numel, model_numel_trainable = get_model_numel(model)
logger.info(
"[Diffusion] Trainable model params: %s, Total model params: %s",
format_numel_str(model_numel_trainable),
format_numel_str(model_numel),
)
# == build ema for diffusion model ==
ema = deepcopy(model).to(torch.float32).to(device)
requires_grad(ema, False)
ema_shape_dict = record_model_param_shape(ema)
ema.eval()
update_ema(ema, model, decay=0, sharded=False)
# == setup loss function, build scheduler ==
scheduler = build_module(cfg.scheduler, SCHEDULERS)
# == setup optimizer ==
optimizer = HybridAdam(
filter(lambda p: p.requires_grad, model.parameters()),
adamw_mode=True,
lr=cfg.get("lr", 1e-4),
weight_decay=cfg.get("weight_decay", 0),
eps=cfg.get("adam_eps", 1e-8),
)
lr_scheduler = None
# == additional preparation ==
if cfg.get("grad_checkpoint", False):
set_grad_checkpoint(model)
if cfg.get("mask_ratios", None) is not None:
mask_generator = MaskGenerator(cfg.mask_ratios)
# =======================================================
# 4. distributed training preparation with colossalai
# =======================================================
logger.info("Preparing for distributed training...")
# == boosting ==
# NOTE: we set the default dtype first so model initialization matches the target dtype; then reset it to fp32, since the diffusion scheduler runs in fp32
torch.set_default_dtype(dtype)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
dataloader=dataloader,
)
torch.set_default_dtype(torch.float)
logger.info("Boosting model for distributed training")
# == global variables ==
cfg_epochs = cfg.get("epochs", 1000)
start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
running_loss = 0.0
logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)
# == resume ==
if cfg.get("load", None) is not None:
logger.info("Loading checkpoint")
ret = load(
booster,
cfg.load,
model=model,
ema=ema,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
sampler=sampler_to_io,
)
if not cfg.get("start_from_scratch", False):
start_epoch, start_step, sampler_start_idx = ret
logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
if cfg.dataset.type == DEFAULT_DATASET_NAME:
dataloader.sampler.set_start_index(sampler_start_idx)
model_sharding(ema)
# =======================================================
# 5. training loop
# =======================================================
dist.barrier()
for epoch in range(start_epoch, cfg_epochs):
# == set dataloader to new epoch ==
if cfg.dataset.type == DEFAULT_DATASET_NAME:
dataloader.sampler.set_epoch(epoch)
dataloader_iter = iter(dataloader)
logger.info("Beginning epoch %s...", epoch)
# == training loop in an epoch ==
with tqdm(
enumerate(dataloader_iter, start=start_step),
desc=f"Epoch {epoch}",
disable=not coordinator.is_master(),
initial=start_step,
total=num_steps_per_epoch,
) as pbar:
for step, batch in pbar:
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
y = batch.pop("text")
# == visual and text encoding ==
with torch.no_grad():
# Prepare visual inputs
with Timer("VAE"):
x = vae.encode(x) # [B, C, T, H/P, W/P]
# Prepare text inputs
with Timer("Text Encoder"):
model_args = text_encoder.encode(y)
# == mask ==
mask = None
if cfg.get("mask_ratios", None) is not None:
mask = mask_generator.get_masks(x)
model_args["x_mask"] = mask
# == video meta info ==
for k, v in batch.items():
model_args[k] = v.to(device, dtype)
# == diffusion loss computation ==
with Timer("Forward"):
loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)
# == backward & update ==
with Timer("Backward"):
loss = loss_dict["loss"].mean()
booster.backward(loss=loss, optimizer=optimizer)
optimizer.step()
optimizer.zero_grad()
# == update EMA ==
update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))
# == update log info ==
all_reduce_mean(loss)
running_loss += loss.item()
global_step = epoch * num_steps_per_epoch + step
log_step += 1
acc_step += 1
# == logging ==
if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
avg_loss = running_loss / log_step
# progress bar
pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
# tensorboard
tb_writer.add_scalar("loss", loss.item(), global_step)
# wandb
if cfg.get("wandb", False):
wandb.log(
{
"iter": global_step,
"epoch": epoch,
"loss": loss.item(),
"avg_loss": avg_loss,
"acc_step": acc_step,
},
step=global_step,
)
running_loss = 0.0
log_step = 0
# == checkpoint saving ==
ckpt_every = cfg.get("ckpt_every", 0)
if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
model_gathering(ema, ema_shape_dict)
save(
booster,
exp_dir,
model=model,
ema=ema,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
sampler=sampler_to_io,
epoch=epoch,
step=step + 1,
global_step=global_step + 1,
batch_size=cfg.get("batch_size", None),
)
if dist.get_rank() == 0:
model_sharding(ema)
logger.info(
"Saved checkpoint at epoch %s step %s global_step %s to %s",
epoch,
step + 1,
global_step + 1,
exp_dir,
)
# NOTE: subsequent epochs are not resumed, so reset the sampler start index and start step
if cfg.dataset.type == DEFAULT_DATASET_NAME:
dataloader.sampler.set_start_index(0)
else:
dataloader.batch_sampler.set_epoch(epoch + 1)
logger.info("Epoch done, recomputing batch sampler")
start_step = 0
if __name__ == "__main__":
main()
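The script picks up rank and world size from the environment via `dist.init_process_group(backend="nccl")`, so it is meant to run under a distributed launcher. A hedged launch example; the new file's path is not shown in this diff, and both the path and the config name below are placeholders:
```bash
# single node, 8 GPUs; script path and config are placeholders
torchrun --standalone --nproc_per_node 8 scripts/train.py \
    configs/opensora-v1-1/train/stage1.py --data-path /path/to/data.csv
```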

View file

@@ -5,13 +5,13 @@ from pprint import pformat
import torch
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm
import wandb
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
@@ -120,7 +120,7 @@ def main():
# ======================================================
logger.info("Building models...")
# == build text-encoder and vae ==
text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype)
vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
# == build diffusion model ==

View file

@@ -3,7 +3,7 @@ from typing import List
from setuptools import find_packages, setup
def fetch_requirements(path) -> List[str]:
def fetch_requirements(paths) -> List[str]:
"""
This function reads one or more requirements files.
@@ -13,8 +13,13 @@ def fetch_requirements(path) -> List[str]:
Returns:
The lines in the requirements file.
"""
with open(path, "r") as fd:
return [r.strip() for r in fd.readlines()]
if not isinstance(paths, list):
paths = [paths]
requirements = []
for path in paths:
with open(path, "r") as fd:
requirements += [r.strip() for r in fd.readlines()]
return requirements
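With the list-accepting signature, several requirement files can be concatenated into one dependency list (the function simply returns the stripped lines of each file in order). An illustrative, hypothetical call site:
```python
# merge the base pins with the CUDA 12.1 pins added in this commit
deps = fetch_requirements([
    "requirements/requirements.txt",
    "requirements/requirements_cu121.txt",
])
```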
def fetch_readme() -> str:
@@ -34,10 +39,17 @@ setup(
packages=find_packages(
exclude=(
"assets",
"cache",
"configs",
"docs",
"eval",
"evaluation_results",
"gradio",
"logs",
"notebooks",
"outputs",
"pretrained_models",
"samples",
"scripts",
"tests",
"tools",
@@ -48,7 +60,14 @@ setup(
long_description=fetch_readme(),
long_description_content_type="text/markdown",
license="Apache Software License 2.0",
install_requires=fetch_requirements("requirements.txt"),
url="https://github.com/hpcaitech/Open-Sora",
project_urls={
"Bug Tracker": "https://github.com/hpcaitech/Open-Sora/issues",
"Examples": "https://hpcaitech.github.io/Open-Sora/",
"Documentation": "https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file",
"Github": "https://github.com/hpcaitech/Open-Sora",
},
install_requires=fetch_requirements("requirements/requirements.txt"),
python_requires=">=3.6",
classifiers=[
"Programming Language :: Python :: 3",
@@ -57,4 +76,7 @@ setup(
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: System :: Distributed Computing",
],
extras_require={
"fast": ["flash-attn --no-build-isolation"],
},
)