mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
[fix] vae dtype
This commit is contained in:
parent
137d0ac223
commit
9fd4554f55
29
README.md
29
README.md
|
|
@ -146,28 +146,15 @@ Other useful documents and links are listed below.
|
|||
|
||||
### Install from Source
|
||||
|
||||
For CUDA 12.1, you can install the dependencies with the following commands. Otherwise, please refer to [Installation](docs/installation.md) for more instructions.
|
||||
|
||||
```bash
|
||||
# create a virtual env
|
||||
# create a virtual env and activate (conda as an example)
|
||||
conda create -n opensora python=3.10
|
||||
# activate virtual environment
|
||||
conda activate opensora
|
||||
|
||||
# install torch
|
||||
# the command below is for CUDA 12.1, choose install commands from
|
||||
# https://pytorch.org/get-started/locally/ based on your own CUDA version
|
||||
pip install torch torchvision
|
||||
|
||||
# install flash attention (optional)
|
||||
# set enable_flash_attn=False in config to avoid using flash attention
|
||||
pip install packaging ninja
|
||||
pip install flash-attn --no-build-isolation
|
||||
|
||||
# install apex (optional)
|
||||
# set enable_layernorm_kernel=False in config to avoid using apex
|
||||
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" git+https://github.com/NVIDIA/apex.git
|
||||
|
||||
# install xformers
|
||||
pip install -U xformers --index-url https://download.pytorch.org/whl/cu121
|
||||
# install torch, torchvision and xformers
|
||||
pip install -r requirements/requirements_cu121.txt
|
||||
|
||||
# install this project
|
||||
git clone https://github.com/hpcaitech/Open-Sora
|
||||
|
|
@ -249,6 +236,12 @@ Since Open-Sora 1.1 supports inference with dynamic input size, you can pass the
|
|||
python scripts/inference.py configs/opensora-v1-1/inference/sample.py --prompt "A beautiful sunset over the city" --num-frames 32 --image-size 480 854
|
||||
```
|
||||
|
||||
If your installation do not contain `apex` and `flash-attn`, you need to disable them in the config file, or via the folowing command.
|
||||
|
||||
```bash
|
||||
python scripts/inference.py configs/opensora-v1-1/inference/sample.py --prompt "A beautiful sunset over the city" --num-frames 32 --image-size 480 854 --layernorm-kernel False --flash-attn False
|
||||
```
|
||||
|
||||
See [here](docs/commands.md#inference-with-open-sora-11) for more instructions including text-to-image, image-to-video, video-to-video, and infinite time generation.
|
||||
|
||||
### Open-Sora 1.0 Command Line Inference
|
||||
|
|
|
|||
40
docs/installation.md
Normal file
40
docs/installation.md
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
# Installation
|
||||
|
||||
Requirements are listed in `requirements` folder.
|
||||
|
||||
## Different CUDA versions
|
||||
|
||||
You need to mannually install `torch`, `torchvision` and `xformers` for different CUDA versions.
|
||||
|
||||
```bash
|
||||
# install torch (>=2.1 is recommended)
|
||||
# the command below is for CUDA 12.1, choose install commands from
|
||||
# https://pytorch.org/get-started/locally/ based on your own CUDA version
|
||||
pip install torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# install xformers
|
||||
# the command below is for CUDA 12.1, choose install commands from
|
||||
# https://github.com/facebookresearch/xformers?tab=readme-ov-file#installing-xformers based on your own CUDA version
|
||||
pip install xformers --index-url https://download.pytorch.org/whl/cu121
|
||||
```
|
||||
|
||||
```bash
|
||||
# install flash attention (optional)
|
||||
# set enable_flash_attn=False in config to avoid using flash attention
|
||||
pip install packaging ninja
|
||||
pip install flash-attn --no-build-isolation
|
||||
|
||||
# install apex (optional)
|
||||
# set enable_layernorm_kernel=False in config to avoid using apex
|
||||
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" git+https://github.com/NVIDIA/apex.git
|
||||
```
|
||||
|
||||
gdown
|
||||
pre-commit
|
||||
pyarrow
|
||||
tensorboard
|
||||
transformers
|
||||
wandb
|
||||
pandarallel
|
||||
gradio
|
||||
spaces
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
from .acceleration import *
|
||||
from .datasets import *
|
||||
from .models import *
|
||||
from .registry import *
|
||||
|
|
@ -5,7 +5,6 @@ from typing import Iterator, List, Optional
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from pandarallel import pandarallel
|
||||
from torch.utils.data import Dataset, DistributedSampler
|
||||
|
||||
from opensora.utils.misc import format_numel_str, get_logger
|
||||
|
|
@ -81,6 +80,8 @@ class VariableVideoBatchSampler(DistributedSampler):
|
|||
def group_by_bucket(self) -> dict:
|
||||
bucket_sample_dict = OrderedDict()
|
||||
|
||||
from pandarallel import pandarallel
|
||||
|
||||
pandarallel.initialize(nb_workers=self.num_bucket_build_workers, progress_bar=False)
|
||||
get_logger().info("Building buckets...")
|
||||
bucket_ids = self.dataset.data.parallel_apply(
|
||||
|
|
|
|||
|
|
@ -143,12 +143,16 @@ class VideoAutoencoderPipeline(nn.Module):
|
|||
param.requires_grad = False
|
||||
|
||||
self.out_channels = self.temporal_vae.out_channels
|
||||
self.scale = torch.tensor(scale).cuda()
|
||||
self.shift = torch.tensor(shift).cuda()
|
||||
if len(self.scale.shape) > 0:
|
||||
self.scale = self.scale[None, :, None, None, None]
|
||||
if len(self.shift.shape) > 0:
|
||||
self.shift = self.shift[None, :, None, None, None]
|
||||
|
||||
# normalization parameters
|
||||
scale = torch.tensor(scale)
|
||||
shift = torch.tensor(shift)
|
||||
if len(scale.shape) > 0:
|
||||
scale = scale[None, :, None, None, None]
|
||||
if len(shift.shape) > 0:
|
||||
shift = shift[None, :, None, None, None]
|
||||
self.register_buffer("scale", scale)
|
||||
self.register_buffer("shift", shift)
|
||||
|
||||
def encode(self, x):
|
||||
x_z = self.spatial_vae.encode(x)
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ def parse_args(training=False):
|
|||
parser.add_argument("--flash-attn", default=None, type=str2bool, help="enable flash attention")
|
||||
parser.add_argument("--layernorm-kernel", default=None, type=str2bool, help="enable layernorm kernel")
|
||||
parser.add_argument("--resolution", default=None, type=str, help="multi resolution")
|
||||
parser.add_argument("--data-path", default=None, type=str, help="path to data csv")
|
||||
|
||||
# ======================================================
|
||||
# Inference
|
||||
|
|
@ -41,7 +42,6 @@ def parse_args(training=False):
|
|||
parser.add_argument("--prompt-as-path", action="store_true", help="use prompt as path to save samples")
|
||||
parser.add_argument("--verbose", default=None, type=int, help="verbose level")
|
||||
|
||||
parser.add_argument("--data-path", default=None, type=str, help="path to data csv")
|
||||
# prompt
|
||||
parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file")
|
||||
parser.add_argument("--prompt", default=None, type=str, nargs="+", help="prompt list")
|
||||
|
|
@ -69,7 +69,6 @@ def parse_args(training=False):
|
|||
parser.add_argument("--lr", default=None, type=float, help="learning rate")
|
||||
parser.add_argument("--wandb", default=None, type=bool, help="enable wandb")
|
||||
parser.add_argument("--load", default=None, type=str, help="path to continue training")
|
||||
parser.add_argument("--data-path", default=None, type=str, help="path to data csv")
|
||||
parser.add_argument("--start-from-scratch", action="store_true", help="start training from scratch")
|
||||
|
||||
return parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from typing import Tuple
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
# ======================================================
|
||||
# Logging
|
||||
|
|
@ -72,6 +71,8 @@ def print_0(*args, **kwargs):
|
|||
|
||||
|
||||
def create_tensorboard_writer(exp_dir):
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
tensorboard_dir = f"{exp_dir}/tensorboard"
|
||||
os.makedirs(tensorboard_dir, exist_ok=True)
|
||||
writer = SummaryWriter(tensorboard_dir)
|
||||
|
|
@ -349,3 +350,31 @@ def transpose(x):
|
|||
|
||||
def all_exists(paths):
|
||||
return all(os.path.exists(path) for path in paths)
|
||||
|
||||
|
||||
# ======================================================
|
||||
# Profile
|
||||
# ======================================================
|
||||
|
||||
|
||||
class Timer:
|
||||
def __init__(self, name, log=True):
|
||||
self.name = name
|
||||
self.start_time = None
|
||||
self.end_time = None
|
||||
self.log = log
|
||||
|
||||
@property
|
||||
def elapsed_time(self):
|
||||
return self.end_time - self.start_time
|
||||
|
||||
def __enter__(self):
|
||||
torch.cuda.synchronize()
|
||||
self.start_time = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
torch.cuda.synchronize()
|
||||
self.end_time = time.time()
|
||||
if self.log:
|
||||
print(f"Elapsed time for {self.name}: {self.elapsed_time:.2f} s")
|
||||
|
|
|
|||
|
|
@ -1,19 +0,0 @@
|
|||
colossalai
|
||||
accelerate
|
||||
diffusers
|
||||
ftfy
|
||||
gdown
|
||||
mmengine
|
||||
pandas
|
||||
pre-commit
|
||||
pyarrow
|
||||
av
|
||||
tensorboard
|
||||
timm
|
||||
tqdm
|
||||
transformers
|
||||
wandb
|
||||
rotary_embedding_torch
|
||||
pandarallel
|
||||
gradio
|
||||
spaces
|
||||
3
requirements/requirements-cu121.txt
Normal file
3
requirements/requirements-cu121.txt
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
torch==2.2.2 --index-url https://download.pytorch.org/whl/cu121
|
||||
torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cu121
|
||||
xformers==0.0.25.post1 --index-url https://download.pytorch.org/whl/cu121
|
||||
3
requirements/requirements-fast.txt
Normal file
3
requirements/requirements-fast.txt
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
packaging
|
||||
ninja
|
||||
flash-attn --no-build-isolation
|
||||
9
requirements/requirements.txt
Normal file
9
requirements/requirements.txt
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
colossalai>=0.3.7
|
||||
mmengine>=0.10.3
|
||||
pandas>=2.2.2
|
||||
timm==0.9.16
|
||||
rotary_embedding_torch==0.5.3
|
||||
ftfy>=6.2.0 # for t5
|
||||
diffusers==0.27.2 # for vae
|
||||
accelerate==0.29.2 # for t5
|
||||
av>=12.0.0
|
||||
340
scripts/misc/profile_train.py
Normal file
340
scripts/misc/profile_train.py
Normal file
|
|
@ -0,0 +1,340 @@
|
|||
import os
|
||||
from copy import deepcopy
|
||||
from datetime import timedelta
|
||||
from pprint import pformat
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device, set_seed
|
||||
from tqdm import tqdm
|
||||
|
||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
|
||||
from opensora.datasets.utils import collate_fn_ignore_none
|
||||
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
|
||||
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
|
||||
from opensora.utils.misc import (
|
||||
Timer,
|
||||
all_reduce_mean,
|
||||
create_logger,
|
||||
create_tensorboard_writer,
|
||||
format_numel_str,
|
||||
get_model_numel,
|
||||
requires_grad,
|
||||
to_torch_dtype,
|
||||
)
|
||||
from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema
|
||||
|
||||
DEFAULT_DATASET_NAME = "VideoTextDataset"
|
||||
|
||||
|
||||
def main():
|
||||
# ======================================================
|
||||
# 1. configs & runtime variables
|
||||
# ======================================================
|
||||
# == parse configs ==
|
||||
cfg = parse_configs(training=True)
|
||||
|
||||
# == device and dtype ==
|
||||
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
|
||||
cfg_dtype = cfg.get("dtype", "bf16")
|
||||
assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
|
||||
dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
|
||||
|
||||
# == colossalai init distributed training ==
|
||||
# NOTE: A very large timeout is set to avoid some processes exit early
|
||||
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
|
||||
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
|
||||
set_seed(cfg.get("seed", 1024))
|
||||
coordinator = DistCoordinator()
|
||||
device = get_current_device()
|
||||
|
||||
# == init exp_dir ==
|
||||
exp_name, exp_dir = define_experiment_workspace(cfg)
|
||||
coordinator.block_all()
|
||||
if coordinator.is_master():
|
||||
os.makedirs(exp_dir, exist_ok=True)
|
||||
save_training_config(cfg.to_dict(), exp_dir)
|
||||
coordinator.block_all()
|
||||
|
||||
# == init logger, tensorboard & wandb ==
|
||||
logger = create_logger(exp_dir)
|
||||
logger.info("Experiment directory created at %s", exp_dir)
|
||||
logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
|
||||
if coordinator.is_master():
|
||||
tb_writer = create_tensorboard_writer(exp_dir)
|
||||
if cfg.get("wandb", False):
|
||||
wandb.init(project="minisora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")
|
||||
|
||||
# == init ColossalAI booster ==
|
||||
plugin = create_colossalai_plugin(
|
||||
plugin=cfg.get("plugin", "zero2"),
|
||||
dtype=cfg_dtype,
|
||||
grad_clip=cfg.get("grad_clip", 0),
|
||||
sp_size=cfg.get("sp_size", 1),
|
||||
)
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
# ======================================================
|
||||
# 2. build dataset and dataloader
|
||||
# ======================================================
|
||||
logger.info("Building dataset...")
|
||||
# == build dataset ==
|
||||
dataset = build_module(cfg.dataset, DATASETS)
|
||||
logger.info("Dataset contains %s samples.", len(dataset))
|
||||
|
||||
# == build dataloader ==
|
||||
dataloader_args = dict(
|
||||
dataset=dataset,
|
||||
batch_size=cfg.get("batch_size", None),
|
||||
num_workers=cfg.get("num_workers", 4),
|
||||
seed=cfg.get("seed", 1024),
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
collate_fn=collate_fn_ignore_none,
|
||||
)
|
||||
if cfg.dataset.type == DEFAULT_DATASET_NAME:
|
||||
dataloader = prepare_dataloader(**dataloader_args)
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
|
||||
logger.info("Total batch size: %s", total_batch_size)
|
||||
num_steps_per_epoch = len(dataloader)
|
||||
sampler_to_io = None
|
||||
else:
|
||||
dataloader = prepare_variable_dataloader(
|
||||
bucket_config=cfg.get("bucket_config", None),
|
||||
num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
|
||||
**dataloader_args,
|
||||
)
|
||||
num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
|
||||
sampler_to_io = None if cfg.get("start_from_scratch ", False) else dataloader.batch_sampler
|
||||
|
||||
# ======================================================
|
||||
# 3. build model
|
||||
# ======================================================
|
||||
logger.info("Building models...")
|
||||
# == build text-encoder and vae ==
|
||||
text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
|
||||
vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
|
||||
|
||||
# == build diffusion model ==
|
||||
input_size = (dataset.num_frames, *dataset.image_size)
|
||||
latent_size = vae.get_latent_size(input_size)
|
||||
model = (
|
||||
build_module(
|
||||
cfg.model,
|
||||
MODELS,
|
||||
input_size=latent_size,
|
||||
in_channels=vae.out_channels,
|
||||
caption_channels=text_encoder.output_dim,
|
||||
model_max_length=text_encoder.model_max_length,
|
||||
)
|
||||
.to(device, dtype)
|
||||
.train()
|
||||
)
|
||||
model_numel, model_numel_trainable = get_model_numel(model)
|
||||
logger.info(
|
||||
"[Diffusion] Trainable model params: %s, Total model params: %s",
|
||||
format_numel_str(model_numel_trainable),
|
||||
format_numel_str(model_numel),
|
||||
)
|
||||
|
||||
# == build ema for diffusion model ==
|
||||
ema = deepcopy(model).to(torch.float32).to(device)
|
||||
requires_grad(ema, False)
|
||||
ema_shape_dict = record_model_param_shape(ema)
|
||||
ema.eval()
|
||||
update_ema(ema, model, decay=0, sharded=False)
|
||||
|
||||
# == setup loss function, build scheduler ==
|
||||
scheduler = build_module(cfg.scheduler, SCHEDULERS)
|
||||
|
||||
# == setup optimizer ==
|
||||
optimizer = HybridAdam(
|
||||
filter(lambda p: p.requires_grad, model.parameters()),
|
||||
adamw_mode=True,
|
||||
lr=cfg.get("lr", 1e-4),
|
||||
weight_decay=cfg.get("weight_decay", 0),
|
||||
eps=cfg.get("adam_eps", 1e-8),
|
||||
)
|
||||
lr_scheduler = None
|
||||
|
||||
# == additional preparation ==
|
||||
if cfg.get("grad_checkpoint", False):
|
||||
set_grad_checkpoint(model)
|
||||
if cfg.get("mask_ratios", None) is not None:
|
||||
mask_generator = MaskGenerator(cfg.mask_ratios)
|
||||
|
||||
# =======================================================
|
||||
# 4. distributed training preparation with colossalai
|
||||
# =======================================================
|
||||
logger.info("Preparing for distributed training...")
|
||||
# == boosting ==
|
||||
# NOTE: we set dtype first to make initialization of model consistent with the dtype; then reset it to the fp32 as we make diffusion scheduler in fp32
|
||||
torch.set_default_dtype(dtype)
|
||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=dataloader,
|
||||
)
|
||||
torch.set_default_dtype(torch.float)
|
||||
logger.info("Boosting model for distributed training")
|
||||
|
||||
# == global variables ==
|
||||
cfg_epochs = cfg.get("epochs", 1000)
|
||||
start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
|
||||
running_loss = 0.0
|
||||
logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)
|
||||
|
||||
# == resume ==
|
||||
if cfg.get("load", None) is not None:
|
||||
logger.info("Loading checkpoint")
|
||||
ret = load(
|
||||
booster,
|
||||
cfg.load,
|
||||
model=model,
|
||||
ema=ema,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
sampler=sampler_to_io,
|
||||
)
|
||||
if not cfg.get("start_from_scratch ", False):
|
||||
start_epoch, start_step, sampler_start_idx = ret
|
||||
logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
|
||||
if cfg.dataset.type == DEFAULT_DATASET_NAME:
|
||||
dataloader.sampler.set_start_index(sampler_start_idx)
|
||||
|
||||
model_sharding(ema)
|
||||
|
||||
# =======================================================
|
||||
# 5. training loop
|
||||
# =======================================================
|
||||
dist.barrier()
|
||||
for epoch in range(start_epoch, cfg_epochs):
|
||||
# == set dataloader to new epoch ==
|
||||
if cfg.dataset.type == DEFAULT_DATASET_NAME:
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
dataloader_iter = iter(dataloader)
|
||||
logger.info("Beginning epoch %s...", epoch)
|
||||
|
||||
# == training loop in an epoch ==
|
||||
with tqdm(
|
||||
enumerate(dataloader_iter, start=start_step),
|
||||
desc=f"Epoch {epoch}",
|
||||
disable=not coordinator.is_master(),
|
||||
initial=start_step,
|
||||
total=num_steps_per_epoch,
|
||||
) as pbar:
|
||||
for step, batch in pbar:
|
||||
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
|
||||
y = batch.pop("text")
|
||||
|
||||
# == visual and text encoding ==
|
||||
with torch.no_grad():
|
||||
# Prepare visual inputs
|
||||
with Timer("VAE"):
|
||||
x = vae.encode(x) # [B, C, T, H/P, W/P]
|
||||
# Prepare text inputs
|
||||
with Timer("Text Encoder"):
|
||||
model_args = text_encoder.encode(y)
|
||||
|
||||
# == mask ==
|
||||
mask = None
|
||||
if cfg.get("mask_ratios", None) is not None:
|
||||
mask = mask_generator.get_masks(x)
|
||||
model_args["x_mask"] = mask
|
||||
|
||||
# == video meta info ==
|
||||
for k, v in batch.items():
|
||||
model_args[k] = v.to(device, dtype)
|
||||
|
||||
# == diffusion loss computation ==
|
||||
with Timer("Forward"):
|
||||
loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)
|
||||
|
||||
# == backward & update ==
|
||||
with Timer("Backward"):
|
||||
loss = loss_dict["loss"].mean()
|
||||
booster.backward(loss=loss, optimizer=optimizer)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# == update EMA ==
|
||||
update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))
|
||||
|
||||
# == update log info ==
|
||||
all_reduce_mean(loss)
|
||||
running_loss += loss.item()
|
||||
global_step = epoch * num_steps_per_epoch + step
|
||||
log_step += 1
|
||||
acc_step += 1
|
||||
|
||||
# == logging ==
|
||||
if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
|
||||
avg_loss = running_loss / log_step
|
||||
# progress bar
|
||||
pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
|
||||
# tensorboard
|
||||
tb_writer.add_scalar("loss", loss.item(), global_step)
|
||||
# wandb
|
||||
if cfg.get("wandb", False):
|
||||
wandb.log(
|
||||
{
|
||||
"iter": global_step,
|
||||
"epoch": epoch,
|
||||
"loss": loss.item(),
|
||||
"avg_loss": avg_loss,
|
||||
"acc_step": acc_step,
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
running_loss = 0.0
|
||||
log_step = 0
|
||||
|
||||
# == checkpoint saving ==
|
||||
ckpt_every = cfg.get("ckpt_every", 0)
|
||||
if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
|
||||
model_gathering(ema, ema_shape_dict)
|
||||
save(
|
||||
booster,
|
||||
exp_dir,
|
||||
model=model,
|
||||
ema=ema,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
sampler=sampler_to_io,
|
||||
epoch=epoch,
|
||||
step=step + 1,
|
||||
global_step=global_step + 1,
|
||||
batch_size=cfg.get("batch_size", None),
|
||||
)
|
||||
if dist.get_rank() == 0:
|
||||
model_sharding(ema)
|
||||
logger.info(
|
||||
"Saved checkpoint at epoch %s step %s global_step %s to %s",
|
||||
epoch,
|
||||
step + 1,
|
||||
global_step + 1,
|
||||
exp_dir,
|
||||
)
|
||||
|
||||
# NOTE: the continue epochs are not resumed, so we need to reset the sampler start index and start step
|
||||
if cfg.dataset.type == DEFAULT_DATASET_NAME:
|
||||
dataloader.sampler.set_start_index(0)
|
||||
else:
|
||||
dataloader.batch_sampler.set_epoch(epoch + 1)
|
||||
logger.info("Epoch done, recomputing batch sampler")
|
||||
start_step = 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -5,13 +5,13 @@ from pprint import pformat
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device, set_seed
|
||||
from tqdm import tqdm
|
||||
|
||||
import wandb
|
||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
|
||||
|
|
@ -120,7 +120,7 @@ def main():
|
|||
# ======================================================
|
||||
logger.info("Building models...")
|
||||
# == build text-encoder and vae ==
|
||||
text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
|
||||
text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype)
|
||||
vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
|
||||
|
||||
# == build diffusion model ==
|
||||
|
|
|
|||
30
setup.py
30
setup.py
|
|
@ -3,7 +3,7 @@ from typing import List
|
|||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def fetch_requirements(path) -> List[str]:
|
||||
def fetch_requirements(paths) -> List[str]:
|
||||
"""
|
||||
This function reads the requirements file.
|
||||
|
||||
|
|
@ -13,8 +13,13 @@ def fetch_requirements(path) -> List[str]:
|
|||
Returns:
|
||||
The lines in the requirements file.
|
||||
"""
|
||||
with open(path, "r") as fd:
|
||||
return [r.strip() for r in fd.readlines()]
|
||||
if not isinstance(paths, list):
|
||||
paths = [paths]
|
||||
requirements = []
|
||||
for path in paths:
|
||||
with open(path, "r") as fd:
|
||||
requirements += [r.strip() for r in fd.readlines()]
|
||||
return requirements
|
||||
|
||||
|
||||
def fetch_readme() -> str:
|
||||
|
|
@ -34,10 +39,17 @@ setup(
|
|||
packages=find_packages(
|
||||
exclude=(
|
||||
"assets",
|
||||
"cache",
|
||||
"configs",
|
||||
"docs",
|
||||
"eval",
|
||||
"evaluation_results",
|
||||
"gradio",
|
||||
"logs",
|
||||
"notebooks",
|
||||
"outputs",
|
||||
"pretrained_models",
|
||||
"samples",
|
||||
"scripts",
|
||||
"tests",
|
||||
"tools",
|
||||
|
|
@ -48,7 +60,14 @@ setup(
|
|||
long_description=fetch_readme(),
|
||||
long_description_content_type="text/markdown",
|
||||
license="Apache Software License 2.0",
|
||||
install_requires=fetch_requirements("requirements.txt"),
|
||||
url="https://github.com/hpcaitech/Open-Sora",
|
||||
project_urls={
|
||||
"Bug Tracker": "https://github.com/hpcaitech/Open-Sora/issues",
|
||||
"Examples": "https://hpcaitech.github.io/Open-Sora/",
|
||||
"Documentation": "https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file",
|
||||
"Github": "https://github.com/hpcaitech/Open-Sora",
|
||||
},
|
||||
install_requires=fetch_requirements("requirements/requirements.txt"),
|
||||
python_requires=">=3.6",
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
|
|
@ -57,4 +76,7 @@ setup(
|
|||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: System :: Distributed Computing",
|
||||
],
|
||||
extras_require={
|
||||
"fast": ["flash-attn --no-build-isolation"],
|
||||
},
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue