From 9fd4554f55d776310d2950322ccf4082c50ca0e2 Mon Sep 17 00:00:00 2001
From: zhengzangw
Date: Thu, 16 May 2024 08:50:24 +0000
Subject: [PATCH] [fix] vae dtype

---
 README.md                           |  29 +--
 docs/installation.md                |  40 ++++
 opensora/__init__.py                |   4 -
 opensora/datasets/sampler.py        |   3 +-
 opensora/models/vae/vae.py          |  16 +-
 opensora/utils/config_utils.py      |   3 +-
 opensora/utils/misc.py              |  31 ++-
 requirements.txt                    |  19 --
 requirements/requirements-cu121.txt |   3 +
 requirements/requirements-fast.txt  |   3 +
 requirements/requirements.txt       |   9 +
 scripts/misc/profile_train.py       | 340 ++++++++++++++++++++++++++++
 scripts/train.py                    |   4 +-
 setup.py                            |  30 ++-
 14 files changed, 477 insertions(+), 57 deletions(-)
 create mode 100644 docs/installation.md
 delete mode 100644 requirements.txt
 create mode 100644 requirements/requirements-cu121.txt
 create mode 100644 requirements/requirements-fast.txt
 create mode 100644 requirements/requirements.txt
 create mode 100644 scripts/misc/profile_train.py

diff --git a/README.md b/README.md
index 6a03d64..60a900a 100644
--- a/README.md
+++ b/README.md
@@ -146,28 +146,15 @@ Other useful documents and links are listed below.
 
 ### Install from Source
 
+For CUDA 12.1, you can install the dependencies with the following commands. Otherwise, please refer to [Installation](docs/installation.md) for more instructions.
+
 ```bash
-# create a virtual env
+# create a virtual env and activate (conda as an example)
 conda create -n opensora python=3.10
-# activate virtual environment
 conda activate opensora
 
-# install torch
-# the command below is for CUDA 12.1, choose install commands from
-# https://pytorch.org/get-started/locally/ based on your own CUDA version
-pip install torch torchvision
-
-# install flash attention (optional)
-# set enable_flash_attn=False in config to avoid using flash attention
-pip install packaging ninja
-pip install flash-attn --no-build-isolation
-
-# install apex (optional)
-# set enable_layernorm_kernel=False in config to avoid using apex
-pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" git+https://github.com/NVIDIA/apex.git
-
-# install xformers
-pip install -U xformers --index-url https://download.pytorch.org/whl/cu121
+# install torch, torchvision and xformers
+pip install -r requirements/requirements-cu121.txt
 
 # install this project
 git clone https://github.com/hpcaitech/Open-Sora
@@ -249,6 +236,12 @@ Since Open-Sora 1.1 supports inference with dynamic input size, you can pass the
 python scripts/inference.py configs/opensora-v1-1/inference/sample.py --prompt "A beautiful sunset over the city" --num-frames 32 --image-size 480 854
 ```
 
+If your installation does not include `apex` and `flash-attn`, you need to disable them in the config file or via the following command.
+
+```bash
+python scripts/inference.py configs/opensora-v1-1/inference/sample.py --prompt "A beautiful sunset over the city" --num-frames 32 --image-size 480 854 --layernorm-kernel False --flash-attn False
+```
+
 See [here](docs/commands.md#inference-with-open-sora-11) for more instructions including text-to-image, image-to-video, video-to-video, and infinite time generation.
 
 ### Open-Sora 1.0 Command Line Inference
diff --git a/docs/installation.md b/docs/installation.md
new file mode 100644
index 0000000..d74b923
--- /dev/null
+++ b/docs/installation.md
@@ -0,0 +1,40 @@
+# Installation
+
+Requirements are listed in the `requirements` folder.
+
+## Different CUDA versions
+
+You need to manually install `torch`, `torchvision` and `xformers` to match your CUDA version.
+
+```bash
+# install torch (>=2.1 is recommended)
+# the command below is for CUDA 12.1, choose install commands from
+# https://pytorch.org/get-started/locally/ based on your own CUDA version
+pip install torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cu121
+
+# install xformers
+# the command below is for CUDA 12.1, choose install commands from
+# https://github.com/facebookresearch/xformers?tab=readme-ov-file#installing-xformers based on your own CUDA version
+pip install xformers --index-url https://download.pytorch.org/whl/cu121
+```
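+
+After these steps, a quick sanity check (a minimal sketch, assuming the commands above succeeded) confirms that the GPU build of torch is active:
+
+```bash
+python -c "import torch, torchvision, xformers; print(torch.__version__, torch.cuda.is_available())"
+```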
+
+```bash
+# install flash attention (optional)
+# set enable_flash_attn=False in config to avoid using flash attention
+pip install packaging ninja
+pip install flash-attn --no-build-isolation
+
+# install apex (optional)
+# set enable_layernorm_kernel=False in config to avoid using apex
+pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" git+https://github.com/NVIDIA/apex.git
+```
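+
+To check whether the optional kernels are importable (a minimal sketch; both packages are optional, and the import simply fails if they are missing):
+
+```bash
+python -c "import flash_attn" && echo "flash-attn OK" || echo "flash-attn missing"
+python -c "import apex" && echo "apex OK" || echo "apex missing"
+```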
+
+## Other dependencies
+
+The packages below are used by auxiliary features such as data downloading, logging, and the Gradio demo; install them with `pip` as needed.
+
+- gdown
+- pre-commit
+- pyarrow
+- tensorboard
+- transformers
+- wandb
+- pandarallel
+- gradio
+- spaces
diff --git a/opensora/__init__.py b/opensora/__init__.py
index a3175b2..e69de29 100644
--- a/opensora/__init__.py
+++ b/opensora/__init__.py
@@ -1,4 +0,0 @@
-from .acceleration import *
-from .datasets import *
-from .models import *
-from .registry import *
diff --git a/opensora/datasets/sampler.py b/opensora/datasets/sampler.py
index 9b974c4..7fd4d77 100644
--- a/opensora/datasets/sampler.py
+++ b/opensora/datasets/sampler.py
@@ -5,7 +5,6 @@ from typing import Iterator, List, Optional
 
 import torch
 import torch.distributed as dist
-from pandarallel import pandarallel
 from torch.utils.data import Dataset, DistributedSampler
 
 from opensora.utils.misc import format_numel_str, get_logger
@@ -81,6 +80,8 @@ class VariableVideoBatchSampler(DistributedSampler):
 
     def group_by_bucket(self) -> dict:
         bucket_sample_dict = OrderedDict()
+        # lazy import so pandarallel is only required when bucketing is used
+        from pandarallel import pandarallel
+
         pandarallel.initialize(nb_workers=self.num_bucket_build_workers, progress_bar=False)
         get_logger().info("Building buckets...")
         bucket_ids = self.dataset.data.parallel_apply(
diff --git a/opensora/models/vae/vae.py b/opensora/models/vae/vae.py
index da43c18..f5769a2 100644
--- a/opensora/models/vae/vae.py
+++ b/opensora/models/vae/vae.py
@@ -143,12 +143,16 @@ class VideoAutoencoderPipeline(nn.Module):
                 param.requires_grad = False
 
         self.out_channels = self.temporal_vae.out_channels
-        self.scale = torch.tensor(scale).cuda()
-        self.shift = torch.tensor(shift).cuda()
-        if len(self.scale.shape) > 0:
-            self.scale = self.scale[None, :, None, None, None]
-        if len(self.shift.shape) > 0:
-            self.shift = self.shift[None, :, None, None, None]
+
+        # normalization parameters, registered as buffers so that
+        # .to(device, dtype) moves them together with the rest of the module
+        scale = torch.tensor(scale)
+        shift = torch.tensor(shift)
+        if len(scale.shape) > 0:
+            scale = scale[None, :, None, None, None]
+        if len(shift.shape) > 0:
+            shift = shift[None, :, None, None, None]
+        self.register_buffer("scale", scale)
+        self.register_buffer("shift", shift)
 
     def encode(self, x):
         x_z = self.spatial_vae.encode(x)
diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py
index 0fd558f..755e8ac 100644
--- a/opensora/utils/config_utils.py
+++ b/opensora/utils/config_utils.py
@@ -27,6 +27,7 @@ def parse_args(training=False):
     parser.add_argument("--flash-attn", default=None, type=str2bool, help="enable flash attention")
     parser.add_argument("--layernorm-kernel", default=None, type=str2bool, help="enable layernorm kernel")
     parser.add_argument("--resolution", default=None, type=str, help="multi resolution")
+    parser.add_argument("--data-path", default=None, type=str, help="path to data csv")
 
     # ======================================================
     # Inference
@@ -41,7 +42,6 @@ def parse_args(training=False):
     parser.add_argument("--prompt-as-path", action="store_true", help="use prompt as path to save samples")
     parser.add_argument("--verbose", default=None, type=int, help="verbose level")
 
-    parser.add_argument("--data-path", default=None, type=str, help="path to data csv")
     # prompt
     parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file")
     parser.add_argument("--prompt", default=None, type=str, nargs="+", help="prompt list")
@@ -69,7 +69,6 @@ def parse_args(training=False):
     parser.add_argument("--lr", default=None, type=float, help="learning rate")
     parser.add_argument("--wandb", default=None, type=bool, help="enable wandb")
     parser.add_argument("--load", default=None, type=str, help="path to continue training")
-    parser.add_argument("--data-path", default=None, type=str, help="path to data csv")
     parser.add_argument("--start-from-scratch", action="store_true", help="start training from scratch")
 
     return parser.parse_args()
diff --git a/opensora/utils/misc.py b/opensora/utils/misc.py
index 08b63de..ae4e3fd 100644
--- a/opensora/utils/misc.py
+++ b/opensora/utils/misc.py
@@ -11,7 +11,6 @@ from typing import Tuple
 
 import numpy as np
 import torch
 import torch.distributed as dist
-from torch.utils.tensorboard import SummaryWriter
 
 # ======================================================
 # Logging
@@ -72,6 +71,8 @@ def print_0(*args, **kwargs):
 
 
 def create_tensorboard_writer(exp_dir):
+    from torch.utils.tensorboard import SummaryWriter
+
     tensorboard_dir = f"{exp_dir}/tensorboard"
     os.makedirs(tensorboard_dir, exist_ok=True)
     writer = SummaryWriter(tensorboard_dir)
@@ -349,3 +350,31 @@ def transpose(x):
 
 def all_exists(paths):
     return all(os.path.exists(path) for path in paths)
+
+
+# ======================================================
+# Profile
+# ======================================================
+
+
+class Timer:
+    """Context manager that measures the wall-clock time of a block,
+    synchronizing CUDA before and after so that asynchronous GPU work
+    is fully accounted for."""
+
+    def __init__(self, name, log=True):
+        self.name = name
+        self.start_time = None
+        self.end_time = None
+        self.log = log
+
+    @property
+    def elapsed_time(self):
+        return self.end_time - self.start_time
+
+    def __enter__(self):
+        torch.cuda.synchronize()
+        self.start_time = time.time()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        torch.cuda.synchronize()
+        self.end_time = time.time()
+        if self.log:
+            print(f"Elapsed time for {self.name}: {self.elapsed_time:.2f} s")
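+
+
+# Example usage of Timer (an illustrative sketch; any CUDA workload fits):
+#
+#     with Timer("matmul"):
+#         torch.matmul(a, b)
+#
+# prints "Elapsed time for matmul: ... s" when the block exits.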
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index b3bbba6..0000000
--- a/requirements.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-colossalai
-accelerate
-diffusers
-ftfy
-gdown
-mmengine
-pandas
-pre-commit
-pyarrow
-av
-tensorboard
-timm
-tqdm
-transformers
-wandb
-rotary_embedding_torch
-pandarallel
-gradio
-spaces
diff --git a/requirements/requirements-cu121.txt b/requirements/requirements-cu121.txt
new file mode 100644
index 0000000..362381d
--- /dev/null
+++ b/requirements/requirements-cu121.txt
@@ -0,0 +1,3 @@
+# pip requirements files only support --index-url on its own line
+--index-url https://download.pytorch.org/whl/cu121
+torch==2.2.2
+torchvision==0.17.2
+xformers==0.0.25.post1
diff --git a/requirements/requirements-fast.txt b/requirements/requirements-fast.txt
new file mode 100644
index 0000000..0c02389
--- /dev/null
+++ b/requirements/requirements-fast.txt
@@ -0,0 +1,3 @@
+packaging
+ninja
+# flash-attn must be installed without build isolation: pip install flash-attn --no-build-isolation
+flash-attn
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
new file mode 100644
index 0000000..e124aca
--- /dev/null
+++ b/requirements/requirements.txt
@@ -0,0 +1,9 @@
+colossalai>=0.3.7
+mmengine>=0.10.3
+pandas>=2.2.2
+timm==0.9.16
+rotary_embedding_torch==0.5.3
+ftfy>=6.2.0 # for t5
+diffusers==0.27.2 # for vae
+accelerate==0.29.2 # for t5
+av>=12.0.0
diff --git a/scripts/misc/profile_train.py b/scripts/misc/profile_train.py
new file mode 100644
index 0000000..4ec7139
--- /dev/null
+++ b/scripts/misc/profile_train.py
@@ -0,0 +1,340 @@
+import os
+from copy import deepcopy
+from datetime import timedelta
+from pprint import pformat
+
+import torch
+import torch.distributed as dist
+import wandb
+from colossalai.booster import Booster
+from colossalai.cluster import DistCoordinator
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device, set_seed
+from tqdm import tqdm
+
+from opensora.acceleration.checkpoint import set_grad_checkpoint
+from opensora.acceleration.parallel_states import get_data_parallel_group
+from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
+from opensora.datasets.utils import collate_fn_ignore_none
+from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
+from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
+from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
+from opensora.utils.misc import (
+    Timer,
+    all_reduce_mean,
+    create_logger,
+    create_tensorboard_writer,
+    format_numel_str,
+    get_model_numel,
+    requires_grad,
+    to_torch_dtype,
+)
+from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema
+
+DEFAULT_DATASET_NAME = "VideoTextDataset"
+
+
+def main():
+    # ======================================================
+    # 1. configs & runtime variables
+    # ======================================================
+    # == parse configs ==
+    cfg = parse_configs(training=True)
+
+    # == device and dtype ==
+    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
+    cfg_dtype = cfg.get("dtype", "bf16")
+    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
+    dtype = to_torch_dtype(cfg_dtype)
+
+    # == colossalai init distributed training ==
+    # NOTE: A very large timeout is set to avoid some processes exiting early
+    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
+    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
+    set_seed(cfg.get("seed", 1024))
+    coordinator = DistCoordinator()
+    device = get_current_device()
+
+    # == init exp_dir ==
+    exp_name, exp_dir = define_experiment_workspace(cfg)
+    coordinator.block_all()
+    if coordinator.is_master():
+        os.makedirs(exp_dir, exist_ok=True)
+        save_training_config(cfg.to_dict(), exp_dir)
+    coordinator.block_all()
+
+    # == init logger, tensorboard & wandb ==
+    logger = create_logger(exp_dir)
+    logger.info("Experiment directory created at %s", exp_dir)
+    logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
+    if coordinator.is_master():
+        tb_writer = create_tensorboard_writer(exp_dir)
+        if cfg.get("wandb", False):
+            wandb.init(project="minisora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")
+
+    # == init ColossalAI booster ==
+    plugin = create_colossalai_plugin(
+        plugin=cfg.get("plugin", "zero2"),
+        dtype=cfg_dtype,
+        grad_clip=cfg.get("grad_clip", 0),
+        sp_size=cfg.get("sp_size", 1),
+    )
+    booster = Booster(plugin=plugin)
+
+    # ======================================================
+    # 2. build dataset and dataloader
+    # ======================================================
+    logger.info("Building dataset...")
+    # == build dataset ==
+    dataset = build_module(cfg.dataset, DATASETS)
+    logger.info("Dataset contains %s samples.", len(dataset))
+
+    # == build dataloader ==
+    dataloader_args = dict(
+        dataset=dataset,
+        batch_size=cfg.get("batch_size", None),
+        num_workers=cfg.get("num_workers", 4),
+        seed=cfg.get("seed", 1024),
+        shuffle=True,
+        drop_last=True,
+        pin_memory=True,
+        process_group=get_data_parallel_group(),
+        collate_fn=collate_fn_ignore_none,
+    )
+    if cfg.dataset.type == DEFAULT_DATASET_NAME:
+        dataloader = prepare_dataloader(**dataloader_args)
+        total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.get("sp_size", 1)
+        logger.info("Total batch size: %s", total_batch_size)
+        num_steps_per_epoch = len(dataloader)
+        sampler_to_io = None
+    else:
+        dataloader = prepare_variable_dataloader(
+            bucket_config=cfg.get("bucket_config", None),
+            num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
+            **dataloader_args,
+        )
+        num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
+        sampler_to_io = None if cfg.get("start_from_scratch", False) else dataloader.batch_sampler
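+    # NOTE: VideoTextDataset goes through the plain dataloader; other dataset
+    # types use the variable (bucketed) dataloader, whose sampler state is kept
+    # for checkpoint resumption unless start_from_scratch is set.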
+
+    # ======================================================
+    # 3. build model
+    # ======================================================
+    logger.info("Building models...")
+    # == build text-encoder and vae ==
+    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
+    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
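+    # NOTE: the .to(device, dtype) cast above also moves the VAE's scale/shift
+    # normalization buffers registered in opensora/models/vae/vae.py, keeping
+    # them in the same dtype as the rest of the VAE.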
+
+    # == build diffusion model ==
+    input_size = (dataset.num_frames, *dataset.image_size)
+    latent_size = vae.get_latent_size(input_size)
+    model = (
+        build_module(
+            cfg.model,
+            MODELS,
+            input_size=latent_size,
+            in_channels=vae.out_channels,
+            caption_channels=text_encoder.output_dim,
+            model_max_length=text_encoder.model_max_length,
+        )
+        .to(device, dtype)
+        .train()
+    )
+    model_numel, model_numel_trainable = get_model_numel(model)
+    logger.info(
+        "[Diffusion] Trainable model params: %s, Total model params: %s",
+        format_numel_str(model_numel_trainable),
+        format_numel_str(model_numel),
+    )
+
+    # == build ema for diffusion model ==
+    ema = deepcopy(model).to(torch.float32).to(device)
+    requires_grad(ema, False)
+    ema_shape_dict = record_model_param_shape(ema)
+    ema.eval()
+    update_ema(ema, model, decay=0, sharded=False)
+
+    # == setup loss function, build scheduler ==
+    scheduler = build_module(cfg.scheduler, SCHEDULERS)
+
+    # == setup optimizer ==
+    optimizer = HybridAdam(
+        filter(lambda p: p.requires_grad, model.parameters()),
+        adamw_mode=True,
+        lr=cfg.get("lr", 1e-4),
+        weight_decay=cfg.get("weight_decay", 0),
+        eps=cfg.get("adam_eps", 1e-8),
+    )
+    lr_scheduler = None
+
+    # == additional preparation ==
+    if cfg.get("grad_checkpoint", False):
+        set_grad_checkpoint(model)
+    if cfg.get("mask_ratios", None) is not None:
+        mask_generator = MaskGenerator(cfg.mask_ratios)
+
+    # =======================================================
+    # 4. distributed training preparation with colossalai
+    # =======================================================
+    logger.info("Preparing for distributed training...")
+    # == boosting ==
+    # NOTE: set the default dtype first so the model is initialized in the
+    # training dtype, then reset it to fp32 because the diffusion scheduler
+    # is kept in fp32
+    torch.set_default_dtype(dtype)
+    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        dataloader=dataloader,
+    )
+    torch.set_default_dtype(torch.float)
+    logger.info("Boosting model for distributed training")
+
+    # == global variables ==
+    cfg_epochs = cfg.get("epochs", 1000)
+    start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
+    running_loss = 0.0
+    logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)
+
+    # == resume ==
+    if cfg.get("load", None) is not None:
+        logger.info("Loading checkpoint")
+        ret = load(
+            booster,
+            cfg.load,
+            model=model,
+            ema=ema,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            sampler=sampler_to_io,
+        )
+        if not cfg.get("start_from_scratch", False):
+            start_epoch, start_step, sampler_start_idx = ret
+        logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
+    if cfg.dataset.type == DEFAULT_DATASET_NAME:
+        dataloader.sampler.set_start_index(sampler_start_idx)
+
+    model_sharding(ema)
+
+    # =======================================================
+    # 5. training loop
+    # =======================================================
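+    # The Timer blocks below report per-step timings for the VAE, the text
+    # encoder, and the forward and backward passes, which is what this
+    # profiling script is for.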
+    dist.barrier()
+    for epoch in range(start_epoch, cfg_epochs):
+        # == set dataloader to new epoch ==
+        if cfg.dataset.type == DEFAULT_DATASET_NAME:
+            dataloader.sampler.set_epoch(epoch)
+        dataloader_iter = iter(dataloader)
+        logger.info("Beginning epoch %s...", epoch)
+
+        # == training loop in an epoch ==
+        with tqdm(
+            enumerate(dataloader_iter, start=start_step),
+            desc=f"Epoch {epoch}",
+            disable=not coordinator.is_master(),
+            initial=start_step,
+            total=num_steps_per_epoch,
+        ) as pbar:
+            for step, batch in pbar:
+                x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W]
+                y = batch.pop("text")
+
+                # == visual and text encoding ==
+                with torch.no_grad():
+                    # Prepare visual inputs
+                    with Timer("VAE"):
+                        x = vae.encode(x)  # [B, C, T, H/P, W/P]
+                    # Prepare text inputs
+                    with Timer("Text Encoder"):
+                        model_args = text_encoder.encode(y)
+
+                # == mask ==
+                mask = None
+                if cfg.get("mask_ratios", None) is not None:
+                    mask = mask_generator.get_masks(x)
+                    model_args["x_mask"] = mask
+
+                # == video meta info ==
+                for k, v in batch.items():
+                    model_args[k] = v.to(device, dtype)
+
+                # == diffusion loss computation ==
+                with Timer("Forward"):
+                    loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)
+
+                # == backward & update ==
+                with Timer("Backward"):
+                    loss = loss_dict["loss"].mean()
+                    booster.backward(loss=loss, optimizer=optimizer)
+                    optimizer.step()
+                    optimizer.zero_grad()
+
+                # == update EMA ==
+                update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999))
+
+                # == update log info ==
+                all_reduce_mean(loss)
+                running_loss += loss.item()
+                global_step = epoch * num_steps_per_epoch + step
+                log_step += 1
+                acc_step += 1
+
+                # == logging ==
+                if coordinator.is_master() and (global_step + 1) % cfg.get("log_every", 1) == 0:
+                    avg_loss = running_loss / log_step
+                    # progress bar
+                    pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
+                    # tensorboard
+                    tb_writer.add_scalar("loss", loss.item(), global_step)
+                    # wandb
+                    if cfg.get("wandb", False):
+                        wandb.log(
+                            {
+                                "iter": global_step,
+                                "epoch": epoch,
+                                "loss": loss.item(),
+                                "avg_loss": avg_loss,
+                                "acc_step": acc_step,
+                            },
+                            step=global_step,
+                        )
+
+                    running_loss = 0.0
+                    log_step = 0
+
+                # == checkpoint saving ==
+                ckpt_every = cfg.get("ckpt_every", 0)
+                if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
+                    model_gathering(ema, ema_shape_dict)
+                    save(
+                        booster,
+                        exp_dir,
+                        model=model,
+                        ema=ema,
+                        optimizer=optimizer,
+                        lr_scheduler=lr_scheduler,
+                        sampler=sampler_to_io,
+                        epoch=epoch,
+                        step=step + 1,
+                        global_step=global_step + 1,
+                        batch_size=cfg.get("batch_size", None),
+                    )
+                    if dist.get_rank() == 0:
+                        model_sharding(ema)
+                    logger.info(
+                        "Saved checkpoint at epoch %s step %s global_step %s to %s",
+                        epoch,
+                        step + 1,
+                        global_step + 1,
+                        exp_dir,
+                    )
+
+        # NOTE: subsequent epochs are not resumed, so reset the sampler start
+        # index and start step
+        if cfg.dataset.type == DEFAULT_DATASET_NAME:
+            dataloader.sampler.set_start_index(0)
+        else:
+            dataloader.batch_sampler.set_epoch(epoch + 1)
+            logger.info("Epoch done, recomputing batch sampler")
+        start_step = 0
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/train.py b/scripts/train.py
index e5c668a..f5a526c 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -5,13 +5,13 @@ from pprint import pformat
 
 import torch
 import torch.distributed as dist
+import wandb
 from colossalai.booster import Booster
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device, set_seed
 from tqdm import tqdm
 
-import wandb
 from opensora.acceleration.checkpoint import set_grad_checkpoint
 from opensora.acceleration.parallel_states import get_data_parallel_group
 from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
@@ -120,7 +120,7 @@ def main():
     # ======================================================
     logger.info("Building models...")
     # == build text-encoder and vae ==
-    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
+    text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype)
     vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
 
     # == build diffusion model ==
diff --git a/setup.py b/setup.py
index 5750345..fea582c 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from typing import List
 from setuptools import find_packages, setup
 
 
-def fetch_requirements(path) -> List[str]:
+def fetch_requirements(paths) -> List[str]:
     """
     This function reads the requirements file.
 
@@ -13,8 +13,13 @@ def fetch_requirements(path) -> List[str]:
     Returns:
         The lines in the requirements file.
     """
-    with open(path, "r") as fd:
-        return [r.strip() for r in fd.readlines()]
+    if not isinstance(paths, list):
+        paths = [paths]
+    requirements = []
+    for path in paths:
+        with open(path, "r") as fd:
+            requirements += [r.strip() for r in fd.readlines()]
+    return requirements
 
 
 def fetch_readme() -> str:
@@ -34,10 +39,17 @@ setup(
     packages=find_packages(
         exclude=(
            "assets",
+            "cache",
            "configs",
            "docs",
+            "eval",
+            "evaluation_results",
+            "gradio",
+            "logs",
+            "notebooks",
            "outputs",
            "pretrained_models",
+            "samples",
            "scripts",
            "tests",
            "tools",
@@ -48,7 +60,14 @@ setup(
     long_description=fetch_readme(),
     long_description_content_type="text/markdown",
     license="Apache Software License 2.0",
-    install_requires=fetch_requirements("requirements.txt"),
+    url="https://github.com/hpcaitech/Open-Sora",
+    project_urls={
+        "Bug Tracker": "https://github.com/hpcaitech/Open-Sora/issues",
+        "Examples": "https://hpcaitech.github.io/Open-Sora/",
+        "Documentation": "https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file",
+        "Github": "https://github.com/hpcaitech/Open-Sora",
+    },
+    install_requires=fetch_requirements("requirements/requirements.txt"),
     python_requires=">=3.6",
     classifiers=[
         "Programming Language :: Python :: 3",
@@ -57,4 +76,7 @@ setup(
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
         "Topic :: System :: Distributed Computing",
     ],
+    extras_require={
+        # installed via: pip install "opensora[fast]"
+        "fast": ["packaging", "ninja", "flash-attn"],
+    },
 )