format and some fix (#8)

This commit is contained in:
Zheng Zangwei (Alex Zheng) 2024-03-30 13:34:19 +08:00 committed by GitHub
parent a0bdaced4e
commit f9f539f07e
25 changed files with 94 additions and 166 deletions

View file

@ -4,7 +4,7 @@ The Open-Sora project welcomes any constructive contribution from the community
## Development Environment Setup
To contribute to Open-Sora, we would like to first guide you to set up a proper development environment so that you can better implement your code. You can install this library from source with the `editable` flag (`-e`, for development mode) so that your change to the source code will be reflected in runtime without re-installation.
To contribute to Open-Sora, we would like to first guide you to set up a proper development environment so that you can better implement your code. You can install this library from source with the `editable` flag (`-e`, for development mode) so that your change to the source code will be reflected in runtime without re-installation.
You can refer to the [Installation Section](./README.md#installation) and replace `pip install -v .` with `pip install -v -e .`.

View file

@ -313,7 +313,7 @@
such as asking that all changes be marked or described.
Although not required by our licenses, you are encouraged to
respect those requests where reasonable. More_considerations
for the public:
for the public:
wiki.creativecommons.org/Considerations_for_licensees
=======================================================================
@ -677,5 +677,3 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View file

@ -118,7 +118,7 @@ conda create -n opensora python=3.10
conda activate opensora
# install torch
# the command below is for CUDA 12.1, choose install commands from
# the command below is for CUDA 12.1, choose install commands from
# https://pytorch.org/get-started/locally/ based on your own CUDA version
pip install torch torchvision

View file

@ -15,7 +15,7 @@ bucket_config = {
}
# Define acceleration
num_workers = 0
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"

View file

@ -35,7 +35,7 @@ dtype = "fp16"
prompt_path = None
prompt = [
"Drone view of waves crashing against the rugged cliffs along Big Surs garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliffs edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff's edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.",
"In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave."
"In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave.",
]
loop = 10

View file

@ -25,13 +25,13 @@ scheduler = dict(
type="iddpm",
num_sampling_steps=100,
cfg_scale=7.0,
cfg_channel=3, # or None
cfg_channel=3, # or None
)
dtype = "fp16"
# Condition
prompt_path = "./assets/texts/t2v_samples.txt"
prompt = None # prompt has higher priority than prompt_path
prompt = None # prompt has higher priority than prompt_path
# Others
batch_size = 1

View file

@ -9,7 +9,7 @@ model = dict(
time_scale=1.0,
enable_flashattn=True,
enable_layernorm_kernel=True,
from_pretrained="PRETRAINED_MODEL"
from_pretrained="PRETRAINED_MODEL",
)
vae = dict(
type="VideoAutoencoderKL",

View file

@ -91,7 +91,7 @@
conda create -n opensora python=3.10
# install torch
# the command below is for CUDA 12.1, choose install commands from
# the command below is for CUDA 12.1, choose install commands from
# https://pytorch.org/get-started/locally/ based on your own CUDA version
pip3 install torch torchvision

View file

@ -1,6 +1,5 @@
import math
# Ours

View file

@ -33,19 +33,11 @@ class Bucket:
# wrap config with OrderedDict
bucket_probs = OrderedDict()
bucket_bs = OrderedDict()
bucket_names = sorted(
bucket_config.keys(), key=lambda x: ASPECT_RATIOS[x][0], reverse=True
)
bucket_names = sorted(bucket_config.keys(), key=lambda x: ASPECT_RATIOS[x][0], reverse=True)
for key in bucket_names:
bucket_time_names = sorted(
bucket_config[key].keys(), key=lambda x: x, reverse=True
)
bucket_probs[key] = OrderedDict(
{k: bucket_config[key][k][0] for k in bucket_time_names}
)
bucket_bs[key] = OrderedDict(
{k: bucket_config[key][k][1] for k in bucket_time_names}
)
bucket_time_names = sorted(bucket_config[key].keys(), key=lambda x: x, reverse=True)
bucket_probs[key] = OrderedDict({k: bucket_config[key][k][0] for k in bucket_time_names})
bucket_bs[key] = OrderedDict({k: bucket_config[key][k][1] for k in bucket_time_names})
# first level: HW
num_bucket = 0

View file

@ -8,7 +8,6 @@ from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from .bucket import Bucket
from .sampler import DistributedVariableVideoSampler, VariableVideoBatchSampler
@ -98,38 +97,6 @@ def prepare_dataloader(
)
class _VariableVideoBatchSampler(torch.utils.data.BatchSampler):
def __init__(self, sampler, batch_size, drop_last, dataset, buckect_config):
self.sampler = sampler
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
self.bucket = Bucket(buckect_config)
self.frame_interval = self.dataset.frame_interval
self.bucket.info_bucket(self.dataset, self.frame_interval)
def __iter__(self):
for idx in self.sampler:
T, H, W = self.dataset.get_data_info(idx)
bucket_id = self.bucket.get_bucket_id(T, H, W, self.frame_interval)
if bucket_id is None:
continue
rT, rH, rW = self.bucket.get_thw(bucket_id)
self.dataset.set_data_info(idx, rT, rH, rW)
buffer = self.bucket[bucket_id]
buffer.append(idx)
if len(buffer) >= self.bucket.get_batch_size(bucket_id):
yield buffer
self.bucket.set_empty(bucket_id)
for k1, v1 in self.bucket.bucket.items():
for k2, v2 in v1.items():
for k3, buffer in v2.items():
if len(buffer) > 0 and not self.drop_last:
yield buffer
self.bucket.set_empty((k1, k2, k3))
def prepare_variable_dataloader(
dataset,
batch_size,

View file

@ -39,6 +39,16 @@ class VideoTextDataset(torch.utils.data.Dataset):
"video": get_transforms_video(transform_name, image_size),
}
def _print_data_number(self):
num_videos = 0
num_images = 0
for path in self.data["path"]:
if self.get_type(path) == "video":
num_videos += 1
else:
num_images += 1
print(f"Dataset contains {num_videos} videos and {num_images} images.")
def get_type(self, path):
ext = os.path.splitext(path)[-1].lower()
if ext.lower() in VID_EXTENSIONS:
@ -148,7 +158,6 @@ class VariableVideoTextDataset(VideoTextDataset):
return {"video": video, "text": text, "num_frames": num_frames, "height": height, "width": width, "ar": ar}
def __getitem__(self, index):
return self.getitem(index)
for _ in range(10):
try:
return self.getitem(index)

View file

@ -1,6 +1,7 @@
import math
import warnings
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from pprint import pprint
from typing import Iterator, List, Optional, Tuple
import torch
@ -43,9 +44,7 @@ class DistributedVariableVideoSampler(DistributedSampler):
# group by bucket
for i in range(len(self.dataset)):
t, h, w = self.dataset.get_data_info(i)
bucket_id = self.bucket.get_bucket_id(
t, h, w, self.dataset.frame_interval, g
)
bucket_id = self.bucket.get_bucket_id(t, h, w, self.dataset.frame_interval, g)
if bucket_id is None:
continue
real_t, real_h, real_w = self.bucket.get_thw(bucket_id)
@ -56,12 +55,8 @@ class DistributedVariableVideoSampler(DistributedSampler):
# shuffle
if self.shuffle:
# sort buckets
bucket_indices = torch.randperm(
len(bucket_sample_dict), generator=g
).tolist()
bucket_order = {
k: bucket_indices[i] for i, k in enumerate(bucket_sample_dict)
}
bucket_indices = torch.randperm(len(bucket_sample_dict), generator=g).tolist()
bucket_order = {k: bucket_indices[i] for i, k in enumerate(bucket_sample_dict)}
# sort samples in each bucket
for k, v in bucket_sample_dict.items():
sample_indices = torch.randperm(len(v), generator=g).tolist()
@ -90,11 +85,7 @@ class DistributedVariableVideoSampler(DistributedSampler):
if self.verbose:
self._print_bucket_info(bucket_sample_dict)
if self.shuffle:
bucket_sample_dict = OrderedDict(
sorted(
bucket_sample_dict.items(), key=lambda item: bucket_order[item[0]]
)
)
bucket_sample_dict = OrderedDict(sorted(bucket_sample_dict.items(), key=lambda item: bucket_order[item[0]]))
# iterate
found_last_bucket = self.last_bucket_id is None
for k, v in bucket_sample_dict.items():
@ -126,13 +117,21 @@ class DistributedVariableVideoSampler(DistributedSampler):
def _print_bucket_info(self, bucket_sample_dict: dict) -> None:
total_samples = 0
num_dict = {}
num_aspect_dict = defaultdict(int)
num_hwt_dict = defaultdict(int)
for k, v in bucket_sample_dict.items():
size = len(v) * self.num_replicas
total_samples += size
num_dict[k] = size
print(
f"Total training samples: {total_samples}, num buckets: {len(num_dict)}, bucket samples: {num_dict}"
)
num_aspect_dict[k[-1]] += size
num_hwt_dict[k[:-1]] += size
print(f"Total training samples: {total_samples}, num buckets: {len(num_dict)}")
print("Bucket samples:")
pprint(num_dict)
print("Bucket samples by HxWxT:")
pprint(num_hwt_dict)
print("Bucket samples by aspect ratio:")
pprint(num_aspect_dict)
def state_dict(self) -> dict:
# users must ensure bucket config is the same
@ -175,9 +174,7 @@ class VariableVideoBatchSampler(Sampler[List[int]]):
cur_sample_indices = [sample_idx]
else:
cur_sample_indices.append(sample_idx)
if len(cur_sample_indices) > 0 and (
not self.drop_last or len(cur_sample_indices) == cur_batch_size
):
if len(cur_sample_indices) > 0 and (not self.drop_last or len(cur_sample_indices) == cur_batch_size):
yield cur_sample_indices
def state_dict(self) -> dict:

View file

@ -1,5 +1,3 @@
import numbers
import numpy as np
import torch
import torchvision
@ -146,8 +144,8 @@ def center_crop_arr(pil_image, image_size):
def resize_crop_to_fill(pil_image, image_size):
w, h = pil_image.size # PIL is (W, H)
th, tw = image_size
w, h = pil_image.size # PIL is (W, H)
th, tw = image_size
rh, rw = th / h, tw / w
if rh > rw:
sh, sw = th, int(w * rh)

View file

@ -23,10 +23,7 @@ import xformers.ops
from einops import rearrange
from timm.models.vision_transformer import Mlp
from opensora.acceleration.communications import (
all_to_all,
split_forward_gather_backward,
)
from opensora.acceleration.communications import all_to_all, split_forward_gather_backward
from opensora.acceleration.parallel_states import get_sequence_parallel_group
approx_gelu = lambda: nn.GELU(approximate="tanh")
@ -568,7 +565,6 @@ class CaptionEmbedder(nn.Module):
self.register_buffer(
"y_embedding",
torch.randn(token_num, in_channels) / in_channels**0.5,
persistent=False,
)
self.uncond_prob = uncond_prob

View file

@ -23,14 +23,12 @@
import html
import os
import re
import urllib.parse as ul
import ftfy
import torch
from bs4 import BeautifulSoup
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, T5EncoderModel
from opensora.registry import MODELS

View file

@ -53,7 +53,9 @@ class VideoAutoencoderKL(nn.Module):
def get_latent_size(self, input_size):
latent_size = []
for i in range(3):
assert input_size[i] is None or input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
assert (
input_size[i] is None or input_size[i] % self.patch_size[i] == 0
), "Input size must be divisible by patch size"
latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
return latent_size
@ -87,7 +89,9 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module):
def get_latent_size(self, input_size):
latent_size = []
for i in range(3):
assert input_size[i] is None or input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
assert (
input_size[i] is None or input_size[i] % self.patch_size[i] == 0
), "Input size must be divisible by patch size"
latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
return latent_size

View file

@ -5,7 +5,6 @@ import operator
import os
from typing import Tuple
import colossalai
import torch
import torch.distributed as dist
import torch.nn as nn
@ -55,9 +54,7 @@ def find_model(model_name):
model = reparameter(model, model_name)
return model
else: # Load a custom DiT checkpoint:
assert os.path.isfile(
model_name
), f"Could not find DiT checkpoint at {model_name}"
assert os.path.isfile(model_name), f"Could not find DiT checkpoint at {model_name}"
checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
if "pos_embed_temporal" in checkpoint:
del checkpoint["pos_embed_temporal"]
@ -93,9 +90,7 @@ def model_sharding(model: torch.nn.Module):
for _, param in model.named_parameters():
padding_size = (world_size - param.numel() % world_size) % world_size
if padding_size > 0:
padding_param = torch.nn.functional.pad(
param.data.view(-1), [0, padding_size]
)
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
else:
padding_param = param.data.view(-1)
splited_params = padding_param.split(padding_param.numel() // world_size)
@ -125,9 +120,7 @@ def model_gathering(model: torch.nn.Module, model_shape_dict: dict):
dist.all_gather(all_params, param.data, group=dist.group.WORLD)
if int(global_rank) == 0:
all_params = torch.cat(all_params)
param.data = remove_padding(all_params, model_shape_dict[name]).view(
model_shape_dict[name]
)
param.data = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name])
dist.barrier()
@ -164,9 +157,7 @@ def save(
torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
model_sharding(ema)
booster.save_optimizer(
optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096
)
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
if lr_scheduler is not None:
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
running_states = {
@ -194,9 +185,7 @@ def load(
booster.load_model(model, os.path.join(load_dir, "model"))
# ema is not boosted, so we don't use booster.load_model
# ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt")))
ema.load_state_dict(
torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu"))
)
ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")))
booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
if lr_scheduler is not None:
booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))

View file

@ -47,14 +47,13 @@ def merge_args(cfg, args, training=False):
if args.ckpt_path is not None:
cfg.model["from_pretrained"] = args.ckpt_path
args.ckpt_path = None
for k, v in vars(args).items():
if k in cfg and v is not None:
cfg[k] = v
if not training:
# Inference only
# Inference only
if "reference_path" not in cfg:
cfg["reference_path"] = None
if "loop" not in cfg:
@ -63,7 +62,7 @@ def merge_args(cfg, args, training=False):
assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"
cfg["prompt"] = load_prompts(cfg["prompt_path"])
else:
# Training only
# Training only
if args.data_path is not None:
cfg.dataset["data_path"] = args.data_path
args.data_path = None

View file

@ -1,4 +1,5 @@
from copy import deepcopy
from pprint import pprint
import colossalai
import torch
@ -20,26 +21,14 @@ from opensora.acceleration.parallel_states import (
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import (
create_logger,
load,
model_sharding,
record_model_param_shape,
save,
)
from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import (
create_experiment_workspace,
create_tensorboard_writer,
parse_configs,
save_training_config,
)
from opensora.utils.misc import (
all_reduce_mean,
format_numel_str,
get_model_numel,
requires_grad,
to_torch_dtype,
)
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
from opensora.utils.train_utils import MaskGenerator, update_ema
@ -48,7 +37,8 @@ def main():
# 1. args & cfg
# ======================================================
cfg = parse_configs(training=True)
print(cfg)
print("Training configuration:")
pprint(cfg._cfg_dict)
exp_name, exp_dir = create_experiment_workspace(cfg)
save_training_config(cfg._cfg_dict, exp_dir)
@ -115,12 +105,10 @@ def main():
if cfg.bucket_config is None:
dataloader = prepare_dataloader(**dataloader_args)
else:
dataloader = prepare_variable_dataloader(
bucket_config=cfg.bucket_config, **dataloader_args
)
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
logger.info(f"Dataset contains {len(dataset):,} videos ({dataset.data_path})")
logger.info(f"Total batch size: {total_batch_size}")
dataloader = prepare_variable_dataloader(bucket_config=cfg.bucket_config, **dataloader_args)
if cfg.dataset.type == "VideoTextDataset":
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
logger.info(f"Total batch size: {total_batch_size}")
# ======================================================
# 4. build model
@ -193,11 +181,7 @@ def main():
# =======================================================
start_epoch = start_step = log_step = sampler_start_idx = 0
running_loss = 0.0
sampler_to_io = (
dataloader.batch_sampler
if cfg.dataset.type == "VariableVideoTextDataset"
else None
)
sampler_to_io = dataloader.batch_sampler if cfg.dataset.type == "VariableVideoTextDataset" else None
# 6.1. resume training
if cfg.load is not None:
logger.info("Loading checkpoint")
@ -210,12 +194,8 @@ def main():
cfg.load,
sampler=sampler_to_io,
)
logger.info(
f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}"
)
logger.info(
f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch"
)
logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
if cfg.dataset.type == "VideoTextDataset":
dataloader.sampler.set_start_index(sampler_start_idx)
@ -257,12 +237,8 @@ def main():
model_args[k] = v.to(device, dtype)
# Diffusion
t = torch.randint(
0, scheduler.num_timesteps, (x.shape[0],), device=device
)
loss_dict = scheduler.training_losses(
model, x, t, model_args, mask=mask
)
t = torch.randint(0, scheduler.num_timesteps, (x.shape[0],), device=device)
loss_dict = scheduler.training_losses(model, x, t, model_args, mask=mask)
# Backward & update
loss = loss_dict["loss"].mean()
@ -282,9 +258,7 @@ def main():
# Log to tensorboard
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
avg_loss = running_loss / log_step
pbar.set_postfix(
{"loss": avg_loss, "step": step, "global_step": global_step}
)
pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
running_loss = 0
log_step = 0
writer.add_scalar("loss", loss.item(), global_step)
@ -292,7 +266,6 @@ def main():
wandb.log(
{
"iter": global_step,
"num_samples": global_step * total_batch_size,
"epoch": epoch,
"loss": loss.item(),
"avg_loss": avg_loss,

View file

@ -69,16 +69,24 @@ def run_cross_attention(rank, world_size):
# create model
torch.manual_seed(1024)
set_sequence_parallel_group(dist.group.WORLD)
seq_parallel_attention = SeqParallelMultiHeadCrossAttention(
d_model=256,
num_heads=4,
).cuda().to(torch.bfloat16)
seq_parallel_attention = (
SeqParallelMultiHeadCrossAttention(
d_model=256,
num_heads=4,
)
.cuda()
.to(torch.bfloat16)
)
torch.manual_seed(1024)
attention = MultiHeadCrossAttention(
d_model=256,
num_heads=4,
).cuda().to(torch.bfloat16)
attention = (
MultiHeadCrossAttention(
d_model=256,
num_heads=4,
)
.cuda()
.to(torch.bfloat16)
)
# make sure the weights are the same
for p1, p2 in zip(seq_parallel_attention.parameters(), attention.parameters()):
@ -128,7 +136,9 @@ def run_cross_attention(rank, world_size):
# # check grad
for p1, p2 in zip(seq_parallel_attention.named_parameters(), attention.named_parameters()):
assert torch.allclose(p1[1].grad, p2[1].grad, rtol=1e-3, atol=1e-4), f"\n{p1[0]}\nvs\n{p2[0]}:\n{p1[1].grad}\nvs\n{p2[1].grad}"
assert torch.allclose(
p1[1].grad, p2[1].grad, rtol=1e-3, atol=1e-4
), f"\n{p1[0]}\nvs\n{p2[0]}:\n{p1[1].grad}\nvs\n{p2[1].grad}"
# # check input grad
assert torch.allclose(x.grad, seq_x.grad, atol=1e-7), f"{x.grad}\nvs\n{seq_x.grad}"

View file

@ -39,7 +39,7 @@ conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvi
pip install flash-attn --no-build-isolation
```
First, install LLaVA according to their [official instructions](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#install). We use the `liuhaotian/llava-v1.6-34b` model for captioning, which can be downloaded [here](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b).
First, install LLaVA according to their [official instructions](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#install). We use the `liuhaotian/llava-v1.6-34b` model for captioning, which can be downloaded [here](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b).
### Usage

View file

@ -60,7 +60,7 @@ head -n 10 DATA1.csv
wc -l DATA1.csv
```
Additionally, we provide `csvutils.py` to manage the CSV files.
Additionally, we provide `csvutils.py` to manage the CSV files.
### Requirement

View file

@ -4,7 +4,6 @@ import os
import pandas as pd
from torchvision.datasets import ImageNet
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")

View file

@ -1,6 +1,6 @@
# Scene Detection and Video Split
Raw videos from the Internet may be too long for training.
Raw videos from the Internet may be too long for training.
Thus, we detect scenes in raw videos and split them into short clips based on the scenes.
First prepare the video processing packages.
```bash