mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
format and some fix (#8)
This commit is contained in:
parent
a0bdaced4e
commit
f9f539f07e
|
|
@ -4,7 +4,7 @@ The Open-Sora project welcomes any constructive contribution from the community
|
|||
|
||||
## Development Environment Setup
|
||||
|
||||
To contribute to Open-Sora, we would like to first guide you to set up a proper development environment so that you can better implement your code. You can install this library from source with the `editable` flag (`-e`, for development mode) so that your change to the source code will be reflected in runtime without re-installation.
|
||||
To contribute to Open-Sora, we would like to first guide you to set up a proper development environment so that you can better implement your code. You can install this library from source with the `editable` flag (`-e`, for development mode) so that your change to the source code will be reflected in runtime without re-installation.
|
||||
|
||||
You can refer to the [Installation Section](./README.md#installation) and replace `pip install -v .` with `pip install -v -e .`.
|
||||
|
||||
|
|
|
|||
4
LICENSE
4
LICENSE
|
|
@ -313,7 +313,7 @@
|
|||
such as asking that all changes be marked or described.
|
||||
Although not required by our licenses, you are encouraged to
|
||||
respect those requests where reasonable. More_considerations
|
||||
for the public:
|
||||
for the public:
|
||||
wiki.creativecommons.org/Considerations_for_licensees
|
||||
|
||||
=======================================================================
|
||||
|
|
@ -677,5 +677,3 @@
|
|||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -118,7 +118,7 @@ conda create -n opensora python=3.10
|
|||
conda activate opensora
|
||||
|
||||
# install torch
|
||||
# the command below is for CUDA 12.1, choose install commands from
|
||||
# the command below is for CUDA 12.1, choose install commands from
|
||||
# https://pytorch.org/get-started/locally/ based on your own CUDA version
|
||||
pip install torch torchvision
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ bucket_config = {
|
|||
}
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 0
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ dtype = "fp16"
|
|||
prompt_path = None
|
||||
prompt = [
|
||||
"Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff's edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.",
|
||||
"In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave."
|
||||
"In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave.",
|
||||
]
|
||||
|
||||
loop = 10
|
||||
|
|
|
|||
|
|
@ -25,13 +25,13 @@ scheduler = dict(
|
|||
type="iddpm",
|
||||
num_sampling_steps=100,
|
||||
cfg_scale=7.0,
|
||||
cfg_channel=3, # or None
|
||||
cfg_channel=3, # or None
|
||||
)
|
||||
dtype = "fp16"
|
||||
|
||||
# Condition
|
||||
prompt_path = "./assets/texts/t2v_samples.txt"
|
||||
prompt = None # prompt has higher priority than prompt_path
|
||||
prompt = None # prompt has higher priority than prompt_path
|
||||
|
||||
# Others
|
||||
batch_size = 1
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ model = dict(
|
|||
time_scale=1.0,
|
||||
enable_flashattn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
from_pretrained="PRETRAINED_MODEL"
|
||||
from_pretrained="PRETRAINED_MODEL",
|
||||
)
|
||||
vae = dict(
|
||||
type="VideoAutoencoderKL",
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@
|
|||
conda create -n opensora python=3.10
|
||||
|
||||
# install torch
|
||||
# the command below is for CUDA 12.1, choose install commands from
|
||||
# the command below is for CUDA 12.1, choose install commands from
|
||||
# https://pytorch.org/get-started/locally/ based on your own CUDA version
|
||||
pip3 install torch torchvision
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import math
|
||||
|
||||
|
||||
# Ours
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -33,19 +33,11 @@ class Bucket:
|
|||
# wrap config with OrderedDict
|
||||
bucket_probs = OrderedDict()
|
||||
bucket_bs = OrderedDict()
|
||||
bucket_names = sorted(
|
||||
bucket_config.keys(), key=lambda x: ASPECT_RATIOS[x][0], reverse=True
|
||||
)
|
||||
bucket_names = sorted(bucket_config.keys(), key=lambda x: ASPECT_RATIOS[x][0], reverse=True)
|
||||
for key in bucket_names:
|
||||
bucket_time_names = sorted(
|
||||
bucket_config[key].keys(), key=lambda x: x, reverse=True
|
||||
)
|
||||
bucket_probs[key] = OrderedDict(
|
||||
{k: bucket_config[key][k][0] for k in bucket_time_names}
|
||||
)
|
||||
bucket_bs[key] = OrderedDict(
|
||||
{k: bucket_config[key][k][1] for k in bucket_time_names}
|
||||
)
|
||||
bucket_time_names = sorted(bucket_config[key].keys(), key=lambda x: x, reverse=True)
|
||||
bucket_probs[key] = OrderedDict({k: bucket_config[key][k][0] for k in bucket_time_names})
|
||||
bucket_bs[key] = OrderedDict({k: bucket_config[key][k][1] for k in bucket_time_names})
|
||||
|
||||
# first level: HW
|
||||
num_bucket = 0
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from torch.distributed.distributed_c10d import _get_default_group
|
|||
from torch.utils.data import DataLoader, Dataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from .bucket import Bucket
|
||||
from .sampler import DistributedVariableVideoSampler, VariableVideoBatchSampler
|
||||
|
||||
|
||||
|
|
@ -98,38 +97,6 @@ def prepare_dataloader(
|
|||
)
|
||||
|
||||
|
||||
class _VariableVideoBatchSampler(torch.utils.data.BatchSampler):
|
||||
def __init__(self, sampler, batch_size, drop_last, dataset, buckect_config):
|
||||
self.sampler = sampler
|
||||
self.dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
self.drop_last = drop_last
|
||||
self.bucket = Bucket(buckect_config)
|
||||
self.frame_interval = self.dataset.frame_interval
|
||||
self.bucket.info_bucket(self.dataset, self.frame_interval)
|
||||
|
||||
def __iter__(self):
|
||||
for idx in self.sampler:
|
||||
T, H, W = self.dataset.get_data_info(idx)
|
||||
bucket_id = self.bucket.get_bucket_id(T, H, W, self.frame_interval)
|
||||
if bucket_id is None:
|
||||
continue
|
||||
rT, rH, rW = self.bucket.get_thw(bucket_id)
|
||||
self.dataset.set_data_info(idx, rT, rH, rW)
|
||||
buffer = self.bucket[bucket_id]
|
||||
buffer.append(idx)
|
||||
if len(buffer) >= self.bucket.get_batch_size(bucket_id):
|
||||
yield buffer
|
||||
self.bucket.set_empty(bucket_id)
|
||||
|
||||
for k1, v1 in self.bucket.bucket.items():
|
||||
for k2, v2 in v1.items():
|
||||
for k3, buffer in v2.items():
|
||||
if len(buffer) > 0 and not self.drop_last:
|
||||
yield buffer
|
||||
self.bucket.set_empty((k1, k2, k3))
|
||||
|
||||
|
||||
def prepare_variable_dataloader(
|
||||
dataset,
|
||||
batch_size,
|
||||
|
|
|
|||
|
|
@ -39,6 +39,16 @@ class VideoTextDataset(torch.utils.data.Dataset):
|
|||
"video": get_transforms_video(transform_name, image_size),
|
||||
}
|
||||
|
||||
def _print_data_number(self):
|
||||
num_videos = 0
|
||||
num_images = 0
|
||||
for path in self.data["path"]:
|
||||
if self.get_type(path) == "video":
|
||||
num_videos += 1
|
||||
else:
|
||||
num_images += 1
|
||||
print(f"Dataset contains {num_videos} videos and {num_images} images.")
|
||||
|
||||
def get_type(self, path):
|
||||
ext = os.path.splitext(path)[-1].lower()
|
||||
if ext.lower() in VID_EXTENSIONS:
|
||||
|
|
@ -148,7 +158,6 @@ class VariableVideoTextDataset(VideoTextDataset):
|
|||
return {"video": video, "text": text, "num_frames": num_frames, "height": height, "width": width, "ar": ar}
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.getitem(index)
|
||||
for _ in range(10):
|
||||
try:
|
||||
return self.getitem(index)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import math
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from collections import OrderedDict, defaultdict
|
||||
from pprint import pprint
|
||||
from typing import Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
|
@ -43,9 +44,7 @@ class DistributedVariableVideoSampler(DistributedSampler):
|
|||
# group by bucket
|
||||
for i in range(len(self.dataset)):
|
||||
t, h, w = self.dataset.get_data_info(i)
|
||||
bucket_id = self.bucket.get_bucket_id(
|
||||
t, h, w, self.dataset.frame_interval, g
|
||||
)
|
||||
bucket_id = self.bucket.get_bucket_id(t, h, w, self.dataset.frame_interval, g)
|
||||
if bucket_id is None:
|
||||
continue
|
||||
real_t, real_h, real_w = self.bucket.get_thw(bucket_id)
|
||||
|
|
@ -56,12 +55,8 @@ class DistributedVariableVideoSampler(DistributedSampler):
|
|||
# shuffle
|
||||
if self.shuffle:
|
||||
# sort buckets
|
||||
bucket_indices = torch.randperm(
|
||||
len(bucket_sample_dict), generator=g
|
||||
).tolist()
|
||||
bucket_order = {
|
||||
k: bucket_indices[i] for i, k in enumerate(bucket_sample_dict)
|
||||
}
|
||||
bucket_indices = torch.randperm(len(bucket_sample_dict), generator=g).tolist()
|
||||
bucket_order = {k: bucket_indices[i] for i, k in enumerate(bucket_sample_dict)}
|
||||
# sort samples in each bucket
|
||||
for k, v in bucket_sample_dict.items():
|
||||
sample_indices = torch.randperm(len(v), generator=g).tolist()
|
||||
|
|
@ -90,11 +85,7 @@ class DistributedVariableVideoSampler(DistributedSampler):
|
|||
if self.verbose:
|
||||
self._print_bucket_info(bucket_sample_dict)
|
||||
if self.shuffle:
|
||||
bucket_sample_dict = OrderedDict(
|
||||
sorted(
|
||||
bucket_sample_dict.items(), key=lambda item: bucket_order[item[0]]
|
||||
)
|
||||
)
|
||||
bucket_sample_dict = OrderedDict(sorted(bucket_sample_dict.items(), key=lambda item: bucket_order[item[0]]))
|
||||
# iterate
|
||||
found_last_bucket = self.last_bucket_id is None
|
||||
for k, v in bucket_sample_dict.items():
|
||||
|
|
@ -126,13 +117,21 @@ class DistributedVariableVideoSampler(DistributedSampler):
|
|||
def _print_bucket_info(self, bucket_sample_dict: dict) -> None:
|
||||
total_samples = 0
|
||||
num_dict = {}
|
||||
num_aspect_dict = defaultdict(int)
|
||||
num_hwt_dict = defaultdict(int)
|
||||
for k, v in bucket_sample_dict.items():
|
||||
size = len(v) * self.num_replicas
|
||||
total_samples += size
|
||||
num_dict[k] = size
|
||||
print(
|
||||
f"Total training samples: {total_samples}, num buckets: {len(num_dict)}, bucket samples: {num_dict}"
|
||||
)
|
||||
num_aspect_dict[k[-1]] += size
|
||||
num_hwt_dict[k[:-1]] += size
|
||||
print(f"Total training samples: {total_samples}, num buckets: {len(num_dict)}")
|
||||
print("Bucket samples:")
|
||||
pprint(num_dict)
|
||||
print("Bucket samples by HxWxT:")
|
||||
pprint(num_hwt_dict)
|
||||
print("Bucket samples by aspect ratio:")
|
||||
pprint(num_aspect_dict)
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
# users must ensure bucket config is the same
|
||||
|
|
@ -175,9 +174,7 @@ class VariableVideoBatchSampler(Sampler[List[int]]):
|
|||
cur_sample_indices = [sample_idx]
|
||||
else:
|
||||
cur_sample_indices.append(sample_idx)
|
||||
if len(cur_sample_indices) > 0 and (
|
||||
not self.drop_last or len(cur_sample_indices) == cur_batch_size
|
||||
):
|
||||
if len(cur_sample_indices) > 0 and (not self.drop_last or len(cur_sample_indices) == cur_batch_size):
|
||||
yield cur_sample_indices
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import numbers
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
|
|
@ -146,8 +144,8 @@ def center_crop_arr(pil_image, image_size):
|
|||
|
||||
|
||||
def resize_crop_to_fill(pil_image, image_size):
|
||||
w, h = pil_image.size # PIL is (W, H)
|
||||
th, tw = image_size
|
||||
w, h = pil_image.size # PIL is (W, H)
|
||||
th, tw = image_size
|
||||
rh, rw = th / h, tw / w
|
||||
if rh > rw:
|
||||
sh, sw = th, int(w * rh)
|
||||
|
|
|
|||
|
|
@ -23,10 +23,7 @@ import xformers.ops
|
|||
from einops import rearrange
|
||||
from timm.models.vision_transformer import Mlp
|
||||
|
||||
from opensora.acceleration.communications import (
|
||||
all_to_all,
|
||||
split_forward_gather_backward,
|
||||
)
|
||||
from opensora.acceleration.communications import all_to_all, split_forward_gather_backward
|
||||
from opensora.acceleration.parallel_states import get_sequence_parallel_group
|
||||
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
|
|
@ -568,7 +565,6 @@ class CaptionEmbedder(nn.Module):
|
|||
self.register_buffer(
|
||||
"y_embedding",
|
||||
torch.randn(token_num, in_channels) / in_channels**0.5,
|
||||
persistent=False,
|
||||
)
|
||||
self.uncond_prob = uncond_prob
|
||||
|
||||
|
|
|
|||
|
|
@ -23,14 +23,12 @@
|
|||
|
||||
|
||||
import html
|
||||
import os
|
||||
import re
|
||||
import urllib.parse as ul
|
||||
|
||||
import ftfy
|
||||
import torch
|
||||
from bs4 import BeautifulSoup
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from opensora.registry import MODELS
|
||||
|
|
|
|||
|
|
@ -53,7 +53,9 @@ class VideoAutoencoderKL(nn.Module):
|
|||
def get_latent_size(self, input_size):
|
||||
latent_size = []
|
||||
for i in range(3):
|
||||
assert input_size[i] is None or input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
|
||||
assert (
|
||||
input_size[i] is None or input_size[i] % self.patch_size[i] == 0
|
||||
), "Input size must be divisible by patch size"
|
||||
latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
|
||||
return latent_size
|
||||
|
||||
|
|
@ -87,7 +89,9 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module):
|
|||
def get_latent_size(self, input_size):
|
||||
latent_size = []
|
||||
for i in range(3):
|
||||
assert input_size[i] is None or input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size"
|
||||
assert (
|
||||
input_size[i] is None or input_size[i] % self.patch_size[i] == 0
|
||||
), "Input size must be divisible by patch size"
|
||||
latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
|
||||
return latent_size
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import operator
|
|||
import os
|
||||
from typing import Tuple
|
||||
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
|
@ -55,9 +54,7 @@ def find_model(model_name):
|
|||
model = reparameter(model, model_name)
|
||||
return model
|
||||
else: # Load a custom DiT checkpoint:
|
||||
assert os.path.isfile(
|
||||
model_name
|
||||
), f"Could not find DiT checkpoint at {model_name}"
|
||||
assert os.path.isfile(model_name), f"Could not find DiT checkpoint at {model_name}"
|
||||
checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
|
||||
if "pos_embed_temporal" in checkpoint:
|
||||
del checkpoint["pos_embed_temporal"]
|
||||
|
|
@ -93,9 +90,7 @@ def model_sharding(model: torch.nn.Module):
|
|||
for _, param in model.named_parameters():
|
||||
padding_size = (world_size - param.numel() % world_size) % world_size
|
||||
if padding_size > 0:
|
||||
padding_param = torch.nn.functional.pad(
|
||||
param.data.view(-1), [0, padding_size]
|
||||
)
|
||||
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
|
||||
else:
|
||||
padding_param = param.data.view(-1)
|
||||
splited_params = padding_param.split(padding_param.numel() // world_size)
|
||||
|
|
@ -125,9 +120,7 @@ def model_gathering(model: torch.nn.Module, model_shape_dict: dict):
|
|||
dist.all_gather(all_params, param.data, group=dist.group.WORLD)
|
||||
if int(global_rank) == 0:
|
||||
all_params = torch.cat(all_params)
|
||||
param.data = remove_padding(all_params, model_shape_dict[name]).view(
|
||||
model_shape_dict[name]
|
||||
)
|
||||
param.data = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name])
|
||||
dist.barrier()
|
||||
|
||||
|
||||
|
|
@ -164,9 +157,7 @@ def save(
|
|||
torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
|
||||
model_sharding(ema)
|
||||
|
||||
booster.save_optimizer(
|
||||
optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096
|
||||
)
|
||||
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
|
||||
if lr_scheduler is not None:
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
running_states = {
|
||||
|
|
@ -194,9 +185,7 @@ def load(
|
|||
booster.load_model(model, os.path.join(load_dir, "model"))
|
||||
# ema is not boosted, so we don't use booster.load_model
|
||||
# ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt")))
|
||||
ema.load_state_dict(
|
||||
torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu"))
|
||||
)
|
||||
ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")))
|
||||
booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
|
||||
if lr_scheduler is not None:
|
||||
booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
|
||||
|
|
|
|||
|
|
@ -47,14 +47,13 @@ def merge_args(cfg, args, training=False):
|
|||
if args.ckpt_path is not None:
|
||||
cfg.model["from_pretrained"] = args.ckpt_path
|
||||
args.ckpt_path = None
|
||||
|
||||
|
||||
for k, v in vars(args).items():
|
||||
if k in cfg and v is not None:
|
||||
cfg[k] = v
|
||||
|
||||
if not training:
|
||||
# Inference only
|
||||
# Inference only
|
||||
if "reference_path" not in cfg:
|
||||
cfg["reference_path"] = None
|
||||
if "loop" not in cfg:
|
||||
|
|
@ -63,7 +62,7 @@ def merge_args(cfg, args, training=False):
|
|||
assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"
|
||||
cfg["prompt"] = load_prompts(cfg["prompt_path"])
|
||||
else:
|
||||
# Training only
|
||||
# Training only
|
||||
if args.data_path is not None:
|
||||
cfg.dataset["data_path"] = args.data_path
|
||||
args.data_path = None
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from copy import deepcopy
|
||||
from pprint import pprint
|
||||
|
||||
import colossalai
|
||||
import torch
|
||||
|
|
@ -20,26 +21,14 @@ from opensora.acceleration.parallel_states import (
|
|||
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
|
||||
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
|
||||
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.ckpt_utils import (
|
||||
create_logger,
|
||||
load,
|
||||
model_sharding,
|
||||
record_model_param_shape,
|
||||
save,
|
||||
)
|
||||
from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save
|
||||
from opensora.utils.config_utils import (
|
||||
create_experiment_workspace,
|
||||
create_tensorboard_writer,
|
||||
parse_configs,
|
||||
save_training_config,
|
||||
)
|
||||
from opensora.utils.misc import (
|
||||
all_reduce_mean,
|
||||
format_numel_str,
|
||||
get_model_numel,
|
||||
requires_grad,
|
||||
to_torch_dtype,
|
||||
)
|
||||
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
|
||||
from opensora.utils.train_utils import MaskGenerator, update_ema
|
||||
|
||||
|
||||
|
|
@ -48,7 +37,8 @@ def main():
|
|||
# 1. args & cfg
|
||||
# ======================================================
|
||||
cfg = parse_configs(training=True)
|
||||
print(cfg)
|
||||
print("Training configuration:")
|
||||
pprint(cfg._cfg_dict)
|
||||
exp_name, exp_dir = create_experiment_workspace(cfg)
|
||||
save_training_config(cfg._cfg_dict, exp_dir)
|
||||
|
||||
|
|
@ -115,12 +105,10 @@ def main():
|
|||
if cfg.bucket_config is None:
|
||||
dataloader = prepare_dataloader(**dataloader_args)
|
||||
else:
|
||||
dataloader = prepare_variable_dataloader(
|
||||
bucket_config=cfg.bucket_config, **dataloader_args
|
||||
)
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
|
||||
logger.info(f"Dataset contains {len(dataset):,} videos ({dataset.data_path})")
|
||||
logger.info(f"Total batch size: {total_batch_size}")
|
||||
dataloader = prepare_variable_dataloader(bucket_config=cfg.bucket_config, **dataloader_args)
|
||||
if cfg.dataset.type == "VideoTextDataset":
|
||||
total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
|
||||
logger.info(f"Total batch size: {total_batch_size}")
|
||||
|
||||
# ======================================================
|
||||
# 4. build model
|
||||
|
|
@ -193,11 +181,7 @@ def main():
|
|||
# =======================================================
|
||||
start_epoch = start_step = log_step = sampler_start_idx = 0
|
||||
running_loss = 0.0
|
||||
sampler_to_io = (
|
||||
dataloader.batch_sampler
|
||||
if cfg.dataset.type == "VariableVideoTextDataset"
|
||||
else None
|
||||
)
|
||||
sampler_to_io = dataloader.batch_sampler if cfg.dataset.type == "VariableVideoTextDataset" else None
|
||||
# 6.1. resume training
|
||||
if cfg.load is not None:
|
||||
logger.info("Loading checkpoint")
|
||||
|
|
@ -210,12 +194,8 @@ def main():
|
|||
cfg.load,
|
||||
sampler=sampler_to_io,
|
||||
)
|
||||
logger.info(
|
||||
f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}"
|
||||
)
|
||||
logger.info(
|
||||
f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch"
|
||||
)
|
||||
logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
|
||||
logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
|
||||
|
||||
if cfg.dataset.type == "VideoTextDataset":
|
||||
dataloader.sampler.set_start_index(sampler_start_idx)
|
||||
|
|
@ -257,12 +237,8 @@ def main():
|
|||
model_args[k] = v.to(device, dtype)
|
||||
|
||||
# Diffusion
|
||||
t = torch.randint(
|
||||
0, scheduler.num_timesteps, (x.shape[0],), device=device
|
||||
)
|
||||
loss_dict = scheduler.training_losses(
|
||||
model, x, t, model_args, mask=mask
|
||||
)
|
||||
t = torch.randint(0, scheduler.num_timesteps, (x.shape[0],), device=device)
|
||||
loss_dict = scheduler.training_losses(model, x, t, model_args, mask=mask)
|
||||
|
||||
# Backward & update
|
||||
loss = loss_dict["loss"].mean()
|
||||
|
|
@ -282,9 +258,7 @@ def main():
|
|||
# Log to tensorboard
|
||||
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
|
||||
avg_loss = running_loss / log_step
|
||||
pbar.set_postfix(
|
||||
{"loss": avg_loss, "step": step, "global_step": global_step}
|
||||
)
|
||||
pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
|
||||
running_loss = 0
|
||||
log_step = 0
|
||||
writer.add_scalar("loss", loss.item(), global_step)
|
||||
|
|
@ -292,7 +266,6 @@ def main():
|
|||
wandb.log(
|
||||
{
|
||||
"iter": global_step,
|
||||
"num_samples": global_step * total_batch_size,
|
||||
"epoch": epoch,
|
||||
"loss": loss.item(),
|
||||
"avg_loss": avg_loss,
|
||||
|
|
|
|||
|
|
@ -69,16 +69,24 @@ def run_cross_attention(rank, world_size):
|
|||
# create model
|
||||
torch.manual_seed(1024)
|
||||
set_sequence_parallel_group(dist.group.WORLD)
|
||||
seq_parallel_attention = SeqParallelMultiHeadCrossAttention(
|
||||
d_model=256,
|
||||
num_heads=4,
|
||||
).cuda().to(torch.bfloat16)
|
||||
seq_parallel_attention = (
|
||||
SeqParallelMultiHeadCrossAttention(
|
||||
d_model=256,
|
||||
num_heads=4,
|
||||
)
|
||||
.cuda()
|
||||
.to(torch.bfloat16)
|
||||
)
|
||||
|
||||
torch.manual_seed(1024)
|
||||
attention = MultiHeadCrossAttention(
|
||||
d_model=256,
|
||||
num_heads=4,
|
||||
).cuda().to(torch.bfloat16)
|
||||
attention = (
|
||||
MultiHeadCrossAttention(
|
||||
d_model=256,
|
||||
num_heads=4,
|
||||
)
|
||||
.cuda()
|
||||
.to(torch.bfloat16)
|
||||
)
|
||||
|
||||
# make sure the weights are the same
|
||||
for p1, p2 in zip(seq_parallel_attention.parameters(), attention.parameters()):
|
||||
|
|
@ -128,7 +136,9 @@ def run_cross_attention(rank, world_size):
|
|||
|
||||
# # check grad
|
||||
for p1, p2 in zip(seq_parallel_attention.named_parameters(), attention.named_parameters()):
|
||||
assert torch.allclose(p1[1].grad, p2[1].grad, rtol=1e-3, atol=1e-4), f"\n{p1[0]}\nvs\n{p2[0]}:\n{p1[1].grad}\nvs\n{p2[1].grad}"
|
||||
assert torch.allclose(
|
||||
p1[1].grad, p2[1].grad, rtol=1e-3, atol=1e-4
|
||||
), f"\n{p1[0]}\nvs\n{p2[0]}:\n{p1[1].grad}\nvs\n{p2[1].grad}"
|
||||
|
||||
# # check input grad
|
||||
assert torch.allclose(x.grad, seq_x.grad, atol=1e-7), f"{x.grad}\nvs\n{seq_x.grad}"
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvi
|
|||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
First, install LLaVA according to their [official instructions](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#install). We use the `liuhaotian/llava-v1.6-34b` model for captioning, which can be download [here](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b).
|
||||
First, install LLaVA according to their [official instructions](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#install). We use the `liuhaotian/llava-v1.6-34b` model for captioning, which can be download [here](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b).
|
||||
|
||||
### Usage
|
||||
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ head -n 10 DATA1.csv
|
|||
wc -l DATA1.csv
|
||||
```
|
||||
|
||||
Additionally, Ww provide `csvutils.py` to manage the CSV files.
|
||||
Additionally, Ww provide `csvutils.py` to manage the CSV files.
|
||||
|
||||
### Requirement
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import os
|
|||
import pandas as pd
|
||||
from torchvision.datasets import ImageNet
|
||||
|
||||
|
||||
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
|
||||
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# Scene Detection and Video Split
|
||||
|
||||
Raw videos from the Internet may be too long for training.
|
||||
Raw videos from the Internet may be too long for training.
|
||||
Thus, we detect scenes in raw videos and split them into short clips based on the scenes.
|
||||
First prepare the video processing packages.
|
||||
```bash
|
||||
|
|
|
|||
Loading…
Reference in a new issue