[wip] feat training

This commit is contained in:
zhengzangw 2024-05-21 05:45:06 +00:00
parent 80f9ecbde3
commit 066b0c9bb3
11 changed files with 94 additions and 123 deletions

View file

@ -0,0 +1,52 @@
# ---- Dataset ----
dataset = {"type": "BatchFeatureDataset"}
grad_checkpoint = True
num_workers = 4

# ---- Acceleration ----
dtype = "bf16"
plugin = "zero2"

# ---- Model ----
model = {
    "type": "STDiT3-XL/2",
    "from_pretrained": None,
    "qk_norm": True,
    "enable_flash_attn": True,
    "enable_layernorm_kernel": True,
    "freeze_y_embedder": True,
}
scheduler = {
    "type": "rflow",
    "use_timestep_transform": True,
    "sample_method": "logit-normal",
}

# ---- Mask ratios ----
# NOTE(review): "intepolate" looks like a typo for "interpolate", but the key is
# looked up by name by the mask generator — renaming it here alone would
# silently drop that ratio. Fix in both places together if at all.
mask_ratios = {
    "random": 0.2,
    "intepolate": 0.01,
    "quarter_random": 0.01,
    "quarter_head": 0.01,
    "quarter_tail": 0.01,
    "quarter_head_tail": 0.01,
    "image_random": 0.05,
    "image_head": 0.1,
    "image_tail": 0.05,
    "image_head_tail": 0.05,
}

# ---- Logging / checkpointing ----
seed = 42
outputs = "outputs"
wandb = False
epochs = 1000
log_every = 10
ckpt_every = 1

# ---- Optimization ----
load = None
grad_clip = 1.0
lr = 2e-4
ema_decay = 0.99
adam_eps = 1e-15

View file

@ -1,6 +1,6 @@
# Dataset settings
dataset = dict(
type="BatchDataset",
type="BatchFeatureDataset",
)
grad_checkpoint = True

View file

@ -1,2 +1,2 @@
from .datasets import IMG_FPS, VariableVideoTextDataset, VideoTextDataset, BatchDataset
from .datasets import IMG_FPS, BatchFeatureDataset, VariableVideoTextDataset, VideoTextDataset
from .utils import get_transforms_image, get_transforms_video, is_img, is_vid, save_sample

View file

@ -8,7 +8,7 @@ from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import DataLoader
from .datasets import VariableVideoTextDataset, VideoTextDataset
from .datasets import BatchFeatureDataset, VariableVideoTextDataset, VideoTextDataset
from .sampler import BatchDistributedSampler, StatefulDistributedSampler, VariableVideoBatchSampler
@ -56,6 +56,7 @@ def prepare_dataloader(
worker_init_fn=get_seed_worker(seed),
pin_memory=pin_memory,
num_workers=num_workers,
collate_fn=collate_fn_default,
**_kwargs,
),
batch_sampler,
@ -77,105 +78,29 @@ def prepare_dataloader(
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
collate_fn=collate_fn_default,
**_kwargs,
),
sampler,
)
else:
raise ValueError(f"Unsupported dataset type: {type(dataset)}")
def prepare_variable_dataloader(
    dataset,
    batch_size,
    bucket_config,
    shuffle=False,
    seed=1024,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group=None,
    num_bucket_build_workers=1,
    **kwargs,
):
    """Build a distributed DataLoader driven by a ``VariableVideoBatchSampler``.

    The bucket sampler groups samples by resolution/length according to
    ``bucket_config``; batching is therefore entirely sampler-driven
    (``batch_size`` is accepted for signature compatibility but batching is
    controlled by the bucket config).

    Args:
        dataset: source dataset.
        batch_size: unused here; kept for interface compatibility.
        bucket_config: bucket specification forwarded to the sampler.
        shuffle: whether the sampler shuffles within buckets.
        seed: fixed seed applied to every dataloader worker for determinism.
        drop_last: forwarded to the sampler.
        pin_memory: forwarded to ``DataLoader``.
        num_workers: forwarded to ``DataLoader``.
        process_group: distributed group; defaults to the global default group.
        num_bucket_build_workers: parallelism for bucket construction.
        **kwargs: extra ``DataLoader`` keyword arguments.

    Returns:
        torch.utils.data.DataLoader: loader wired to the bucket batch sampler.
    """
    extra_kwargs = dict(kwargs)
    group = process_group or _get_default_group()

    bucket_sampler = VariableVideoBatchSampler(
        dataset,
        bucket_config,
        num_replicas=group.size(),
        rank=group.rank(),
        shuffle=shuffle,
        seed=seed,
        drop_last=drop_last,
        verbose=True,
        num_bucket_build_workers=num_bucket_build_workers,
    )

    def _init_worker(worker_id):
        # Every worker is seeded with the same fixed value so runs are reproducible.
        np.random.seed(seed)
        torch.manual_seed(seed)
        random.seed(seed)

    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=bucket_sampler,
        worker_init_fn=_init_worker,
        pin_memory=pin_memory,
        num_workers=num_workers,
        **extra_kwargs,
    )
# NOTE(review): this span is a rendered diff hunk — lines removed by the commit
# and lines added by it appear interleaved with no +/- markers, and all source
# indentation was stripped by the renderer. The code below is therefore not
# runnable as shown; comments mark which pieces appear to belong to which side.
# Read it against the underlying commit before editing.
def build_batch_dataloader(
dataset,
# batch_size=1,
# shuffle=False,
seed=1024,
# drop_last=False,
pin_memory=False,
num_workers=0,
process_group: Optional[ProcessGroup] = None,
distributed=True,
**kwargs,
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `BatchDistributedSampler`.
``batch_size`` is fixed to 1; shuffling is not supported so far.
"""
_kwargs = kwargs.copy()
if distributed:
process_group = process_group or _get_default_group()
# NOTE(review): the ``elif`` below cannot chain off the ``if distributed``
# above — it appears to come from the new ``prepare_dataloader`` dispatch
# (added lines), not from this function's original body.
elif isinstance(dataset, BatchFeatureDataset):
sampler = BatchDistributedSampler(
dataset,
num_replicas=process_group.size(),
rank=process_group.rank(),
)
# batch_size=1 because each dataset item is already a pre-collated batch
# pulled from a .bin buffer; collate_fn_batch just unwraps it.
return DataLoader(
dataset,
batch_size=1,
sampler=sampler,
worker_init_fn=get_seed_worker(seed),
pin_memory=pin_memory,
num_workers=num_workers,
collate_fn=collate_fn_batch,
**_kwargs,
)
else:
raise NotImplementedError
# NOTE(review): everything from here down looks like the removed (old) tail of
# build_batch_dataloader — a fixed-seed worker initializer plus a second
# DataLoader construction — unreachable in either version as rendered.
sampler = None
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=1,
sampler=sampler,
worker_init_fn=seed_worker,
# drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
raise ValueError(f"Unsupported dataset type: {type(dataset)}")
def collate_fn_default(batch):

View file

@ -1,4 +1,5 @@
import os
from glob import glob
import numpy as np
import torch
@ -185,16 +186,11 @@ class VariableVideoTextDataset(VideoTextDataset):
return ret
def __getitem__(self, index):
try:
return self.getitem(index)
except Exception:
# we return None here in case of erroneous data
# the collate function will handle it
return None
return self.getitem(index)
@DATASETS.register_module()
class BatchDataset(torch.utils.data.Dataset):
class BatchFeatureDataset(torch.utils.data.Dataset):
"""
The dataset is composed of multiple .bin files.
Each .bin file is a list of batch data (like a buffer). All .bin files have the same length.
@ -203,16 +199,15 @@ class BatchDataset(torch.utils.data.Dataset):
Avoid loading the same .bin on two different GPUs, i.e., one .bin is assigned to one GPU only.
"""
def __init__(self):
# self.meta = read_file(data_path)
# self.path_list = self.meta['path'].tolist()
self.path_list = [f"/mnt/nfs-207/sora_data/webvid-10M/feat_text/data/{idx}.bin" for idx in range(5)]
def __init__(self, data_path=None):
    """Index every ``*.bin`` buffer file found recursively under ``data_path``.

    Each ``.bin`` file holds a list of pre-collated batches; all files are
    assumed to hold the same number of batches (the buffer length).

    Args:
        data_path (str): root directory of the pre-extracted feature dataset.

    Raises:
        ValueError: if ``data_path`` is not provided.
    """
    if data_path is None:
        # Fail with a clear message instead of ``TypeError`` from ``None + str``.
        raise ValueError("BatchFeatureDataset requires a data_path")
    # recursive=True is required for "**" to actually match any number of
    # directory levels; without it glob treats "**" like a plain "*" and
    # only files exactly one level below data_path are found.
    self.path_list = sorted(glob(os.path.join(data_path, "**", "*.bin"), recursive=True))
    # Peek at the first buffer to learn how many batches each file holds.
    self._len_buffer = len(torch.load(self.path_list[0]))
    self._num_buffers = len(self.path_list)
    self.num_samples = self.len_buffer * len(self.path_list)
    # Cache of the most recently loaded buffer: only one file is resident at a time.
    self.cur_file_idx = -1
    self.cur_buffer = None
@property
def num_buffers(self):
@ -224,10 +219,9 @@ class BatchDataset(torch.utils.data.Dataset):
# NOTE(review): diff overlay — both the pre-change and post-change bodies of
# ``_load_buffer`` appear below with indentation stripped; only one of the two
# guard styles exists in either real version of the file.
def _load_buffer(self, idx):
# Map a global sample index to the .bin file that contains it.
file_idx = idx // self.len_buffer
# (old version) early-return when the requested buffer is already resident:
if file_idx == self.cur_file_idx:
return
self.cur_file_idx = file_idx
self.cur_buffer = torch.load(self.path_list[file_idx])
# (new version) same cache check, rewritten as a guarded reload:
if file_idx != self.cur_file_idx:
self.cur_file_idx = file_idx
self.cur_buffer = torch.load(self.path_list[file_idx])
def __len__(self):
    # Total sample count across all buffers — computed once in __init__ as
    # (samples per .bin file) * (number of .bin files).
    return self.num_samples

View file

@ -1,6 +1,3 @@
import math
import warnings
from collections import OrderedDict, defaultdict
from pprint import pformat
from typing import Iterator, List, Optional
@ -291,13 +288,15 @@ class VariableVideoBatchSampler(DistributedSampler):
class BatchDistributedSampler(DistributedSampler):
"""
Used with BatchDataset;
Used with BatchFeatureDataset;
Suppose len_buffer == 5, num_buffers == 6, #GPUs == 3, then
| buffer {i} | buffer {i+1}
------ | ------------------- | -------------------
rank 0 | 0, 1, 2, 3, 4, | 5, 6, 7, 8, 9
rank 1 | 10, 11, 12, 13, 14, | 15, 16, 17, 18, 19
rank 2 | 20, 21, 22, 23, 24, | 25, 26, 27, 28, 29
"""
def __iter__(self):
num_buffers = self.dataset.num_buffers
len_buffer = self.dataset.len_buffer

View file

@ -15,6 +15,8 @@ def build_module(module, builder, **kwargs):
Returns:
Any: The built module.
"""
if module is None:
return None
if isinstance(module, dict):
cfg = deepcopy(module)
for k, v in kwargs.items():

View file

@ -7,7 +7,7 @@ import torch.distributed as dist
from tqdm import tqdm
from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group
from opensora.datasets.dataloader import collate_fn_default, prepare_dataloader
from opensora.datasets.dataloader import prepare_dataloader
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.config_utils import parse_configs, save_training_config
from opensora.utils.misc import FeatureSaver, Timer, create_logger, format_numel_str, get_model_numel, to_torch_dtype
@ -60,7 +60,6 @@ def main():
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
collate_fn=collate_fn_default,
)
dataloader, sampler = prepare_dataloader(
bucket_config=cfg.get("bucket_config", None),

View file

@ -14,8 +14,7 @@ from tqdm import tqdm
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader
from opensora.datasets.dataloader import collate_fn_default
from opensora.datasets.dataloader import prepare_dataloader
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
@ -99,7 +98,6 @@ def main():
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
collate_fn=collate_fn_default,
)
dataloader, sampler = prepare_dataloader(
bucket_config=cfg.get("bucket_config", None),

View file

@ -14,7 +14,7 @@ from tqdm import tqdm
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets.dataloader import collate_fn_default, prepare_dataloader
from opensora.datasets.dataloader import prepare_dataloader
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
@ -95,7 +95,6 @@ def main():
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
collate_fn=collate_fn_default,
)
dataloader, sampler = prepare_dataloader(
bucket_config=cfg.get("bucket_config", None),
@ -109,12 +108,15 @@ def main():
# ======================================================
logger.info("Building models...")
# == build text-encoder and vae ==
text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype)
vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
text_encoder = build_module(cfg.get("text_encoder", None), MODELS, device=device, dtype=dtype)
vae = build_module(cfg.get("vae", None), MODELS).to(device, dtype).eval()
# == build diffusion model ==
input_size = (dataset.num_frames, *dataset.image_size)
latent_size = vae.get_latent_size(input_size)
if vae is not None:
latent_size = vae.get_latent_size(input_size)
else:
latent_size = (None, None, None)
model = (
build_module(
cfg.model,
@ -221,6 +223,7 @@ def main():
for step, batch in pbar:
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
y = batch.pop("text")
breakpoint()
# == visual and text encoding ==
with torch.no_grad():

View file

@ -14,7 +14,7 @@ from tqdm import tqdm
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets.dataloader import build_batch_dataloader, collate_fn_batch
from opensora.datasets.dataloader import build_batch_dataloader
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
@ -98,7 +98,6 @@ def main():
# drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
collate_fn=collate_fn_batch,
)
dataloader = build_batch_dataloader(**dataloader_args)
num_steps_per_epoch = len(dataset) // dist.get_world_size()