mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
[wip] feat training
This commit is contained in:
parent
80f9ecbde3
commit
066b0c9bb3
52
configs/opensora-v1-2/train/stage1_feat.py
Normal file
52
configs/opensora-v1-2/train/stage1_feat.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
# Dataset settings
|
||||
dataset = dict(type="BatchFeatureDataset")
|
||||
grad_checkpoint = True
|
||||
num_workers = 4
|
||||
|
||||
# Acceleration settings
|
||||
dtype = "bf16"
|
||||
plugin = "zero2"
|
||||
|
||||
# Model settings
|
||||
model = dict(
|
||||
type="STDiT3-XL/2",
|
||||
from_pretrained=None,
|
||||
qk_norm=True,
|
||||
enable_flash_attn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
freeze_y_embedder=True,
|
||||
)
|
||||
scheduler = dict(
|
||||
type="rflow",
|
||||
use_timestep_transform=True,
|
||||
sample_method="logit-normal",
|
||||
)
|
||||
|
||||
# Mask settings
|
||||
mask_ratios = {
|
||||
"random": 0.2,
|
||||
"intepolate": 0.01,
|
||||
"quarter_random": 0.01,
|
||||
"quarter_head": 0.01,
|
||||
"quarter_tail": 0.01,
|
||||
"quarter_head_tail": 0.01,
|
||||
"image_random": 0.05,
|
||||
"image_head": 0.1,
|
||||
"image_tail": 0.05,
|
||||
"image_head_tail": 0.05,
|
||||
}
|
||||
|
||||
# Log settings
|
||||
seed = 42
|
||||
outputs = "outputs"
|
||||
wandb = False
|
||||
epochs = 1000
|
||||
log_every = 10
|
||||
ckpt_every = 1
|
||||
|
||||
# optimization settings
|
||||
load = None
|
||||
grad_clip = 1.0
|
||||
lr = 2e-4
|
||||
ema_decay = 0.99
|
||||
adam_eps = 1e-15
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
# Dataset settings
|
||||
dataset = dict(
|
||||
type="BatchDataset",
|
||||
type="BatchFeatureDataset",
|
||||
)
|
||||
|
||||
grad_checkpoint = True
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
from .datasets import IMG_FPS, VariableVideoTextDataset, VideoTextDataset, BatchDataset
|
||||
from .datasets import IMG_FPS, BatchFeatureDataset, VariableVideoTextDataset, VideoTextDataset
|
||||
from .utils import get_transforms_image, get_transforms_video, is_img, is_vid, save_sample
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from torch.distributed import ProcessGroup
|
|||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .datasets import VariableVideoTextDataset, VideoTextDataset
|
||||
from .datasets import BatchFeatureDataset, VariableVideoTextDataset, VideoTextDataset
|
||||
from .sampler import BatchDistributedSampler, StatefulDistributedSampler, VariableVideoBatchSampler
|
||||
|
||||
|
||||
|
|
@ -56,6 +56,7 @@ def prepare_dataloader(
|
|||
worker_init_fn=get_seed_worker(seed),
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn_default,
|
||||
**_kwargs,
|
||||
),
|
||||
batch_sampler,
|
||||
|
|
@ -77,105 +78,29 @@ def prepare_dataloader(
|
|||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn_default,
|
||||
**_kwargs,
|
||||
),
|
||||
sampler,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset type: {type(dataset)}")
|
||||
|
||||
|
||||
def prepare_variable_dataloader(
|
||||
dataset,
|
||||
batch_size,
|
||||
bucket_config,
|
||||
shuffle=False,
|
||||
seed=1024,
|
||||
drop_last=False,
|
||||
pin_memory=False,
|
||||
num_workers=0,
|
||||
process_group=None,
|
||||
num_bucket_build_workers=1,
|
||||
**kwargs,
|
||||
):
|
||||
_kwargs = kwargs.copy()
|
||||
process_group = process_group or _get_default_group()
|
||||
batch_sampler = VariableVideoBatchSampler(
|
||||
dataset,
|
||||
bucket_config,
|
||||
num_replicas=process_group.size(),
|
||||
rank=process_group.rank(),
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
drop_last=drop_last,
|
||||
verbose=True,
|
||||
num_bucket_build_workers=num_bucket_build_workers,
|
||||
)
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker(worker_id):
|
||||
worker_seed = seed
|
||||
np.random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
return torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
worker_init_fn=seed_worker,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs,
|
||||
)
|
||||
|
||||
|
||||
def build_batch_dataloader(
|
||||
dataset,
|
||||
# batch_size=1,
|
||||
# shuffle=False,
|
||||
seed=1024,
|
||||
# drop_last=False,
|
||||
pin_memory=False,
|
||||
num_workers=0,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
distributed=True,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
||||
`torch.utils.data.DataLoader` and `BatchDistributedSampler`.
|
||||
|
||||
batch_size must be 1; shuffle is not supported so far
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
if distributed:
|
||||
process_group = process_group or _get_default_group()
|
||||
elif isinstance(dataset, BatchFeatureDataset):
|
||||
sampler = BatchDistributedSampler(
|
||||
dataset,
|
||||
num_replicas=process_group.size(),
|
||||
rank=process_group.rank(),
|
||||
)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
sampler=sampler,
|
||||
worker_init_fn=get_seed_worker(seed),
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn_batch,
|
||||
**_kwargs,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
sampler = None
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker(worker_id):
|
||||
worker_seed = seed
|
||||
np.random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
sampler=sampler,
|
||||
worker_init_fn=seed_worker,
|
||||
# drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs,
|
||||
)
|
||||
raise ValueError(f"Unsupported dataset type: {type(dataset)}")
|
||||
|
||||
|
||||
def collate_fn_default(batch):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
from glob import glob
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
@ -185,16 +186,11 @@ class VariableVideoTextDataset(VideoTextDataset):
|
|||
return ret
|
||||
|
||||
def __getitem__(self, index):
|
||||
try:
|
||||
return self.getitem(index)
|
||||
except Exception:
|
||||
# we return None here in case of errorneous data
|
||||
# the collate function will handle it
|
||||
return None
|
||||
return self.getitem(index)
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class BatchDataset(torch.utils.data.Dataset):
|
||||
class BatchFeatureDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
The dataset is composed of multiple .bin files.
|
||||
Each .bin file is a list of batch data (like a buffer). All .bin files have the same length.
|
||||
|
|
@ -203,16 +199,15 @@ class BatchDataset(torch.utils.data.Dataset):
|
|||
Avoid loading the same .bin on two difference GPUs, i.e., one .bin is assigned to one GPU only.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# self.meta = read_file(data_path)
|
||||
# self.path_list = self.meta['path'].tolist()
|
||||
self.path_list = [f"/mnt/nfs-207/sora_data/webvid-10M/feat_text/data/{idx}.bin" for idx in range(5)]
|
||||
def __init__(self, data_path=None):
|
||||
self.path_list = sorted(glob(data_path + "/**/*.bin"))
|
||||
|
||||
self._len_buffer = len(torch.load(self.path_list[0]))
|
||||
self._num_buffers = len(self.path_list)
|
||||
self.num_samples = self.len_buffer * len(self.path_list)
|
||||
|
||||
self.cur_file_idx = -1
|
||||
self.cur_buffer = None
|
||||
|
||||
@property
|
||||
def num_buffers(self):
|
||||
|
|
@ -224,10 +219,9 @@ class BatchDataset(torch.utils.data.Dataset):
|
|||
|
||||
def _load_buffer(self, idx):
|
||||
file_idx = idx // self.len_buffer
|
||||
if file_idx == self.cur_file_idx:
|
||||
return
|
||||
self.cur_file_idx = file_idx
|
||||
self.cur_buffer = torch.load(self.path_list[file_idx])
|
||||
if file_idx != self.cur_file_idx:
|
||||
self.cur_file_idx = file_idx
|
||||
self.cur_buffer = torch.load(self.path_list[file_idx])
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
|
|
|||
|
|
@ -1,6 +1,3 @@
|
|||
import math
|
||||
import warnings
|
||||
|
||||
from collections import OrderedDict, defaultdict
|
||||
from pprint import pformat
|
||||
from typing import Iterator, List, Optional
|
||||
|
|
@ -291,13 +288,15 @@ class VariableVideoBatchSampler(DistributedSampler):
|
|||
|
||||
class BatchDistributedSampler(DistributedSampler):
|
||||
"""
|
||||
Used with BatchDataset;
|
||||
Used with BatchFeatureDataset;
|
||||
Suppose len_buffer == 5, num_buffers == 6, #GPUs == 3, then
|
||||
| buffer {i} | buffer {i+1}
|
||||
------ | ------------------- | -------------------
|
||||
rank 0 | 0, 1, 2, 3, 4, | 5, 6, 7, 8, 9
|
||||
rank 1 | 10, 11, 12, 13, 14, | 15, 16, 17, 18, 19
|
||||
rank 2 | 20, 21, 22, 23, 24, | 25, 26, 27, 28, 29
|
||||
"""
|
||||
|
||||
def __iter__(self):
|
||||
num_buffers = self.dataset.num_buffers
|
||||
len_buffer = self.dataset.len_buffer
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ def build_module(module, builder, **kwargs):
|
|||
Returns:
|
||||
Any: The built module.
|
||||
"""
|
||||
if module is None:
|
||||
return None
|
||||
if isinstance(module, dict):
|
||||
cfg = deepcopy(module)
|
||||
for k, v in kwargs.items():
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import torch.distributed as dist
|
|||
from tqdm import tqdm
|
||||
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group, set_data_parallel_group
|
||||
from opensora.datasets.dataloader import collate_fn_default, prepare_dataloader
|
||||
from opensora.datasets.dataloader import prepare_dataloader
|
||||
from opensora.registry import DATASETS, MODELS, build_module
|
||||
from opensora.utils.config_utils import parse_configs, save_training_config
|
||||
from opensora.utils.misc import FeatureSaver, Timer, create_logger, format_numel_str, get_model_numel, to_torch_dtype
|
||||
|
|
@ -60,7 +60,6 @@ def main():
|
|||
drop_last=True,
|
||||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
collate_fn=collate_fn_default,
|
||||
)
|
||||
dataloader, sampler = prepare_dataloader(
|
||||
bucket_config=cfg.get("bucket_config", None),
|
||||
|
|
|
|||
|
|
@ -14,8 +14,7 @@ from tqdm import tqdm
|
|||
|
||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets import prepare_dataloader
|
||||
from opensora.datasets.dataloader import collate_fn_default
|
||||
from opensora.datasets.dataloader import prepare_dataloader
|
||||
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
|
||||
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
|
||||
|
|
@ -99,7 +98,6 @@ def main():
|
|||
drop_last=True,
|
||||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
collate_fn=collate_fn_default,
|
||||
)
|
||||
dataloader, sampler = prepare_dataloader(
|
||||
bucket_config=cfg.get("bucket_config", None),
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from tqdm import tqdm
|
|||
|
||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets.dataloader import collate_fn_default, prepare_dataloader
|
||||
from opensora.datasets.dataloader import prepare_dataloader
|
||||
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
|
||||
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
|
||||
|
|
@ -95,7 +95,6 @@ def main():
|
|||
drop_last=True,
|
||||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
collate_fn=collate_fn_default,
|
||||
)
|
||||
dataloader, sampler = prepare_dataloader(
|
||||
bucket_config=cfg.get("bucket_config", None),
|
||||
|
|
@ -109,12 +108,15 @@ def main():
|
|||
# ======================================================
|
||||
logger.info("Building models...")
|
||||
# == build text-encoder and vae ==
|
||||
text_encoder = build_module(cfg.text_encoder, MODELS, device=device, dtype=dtype)
|
||||
vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
|
||||
text_encoder = build_module(cfg.get("text_encoder", None), MODELS, device=device, dtype=dtype)
|
||||
vae = build_module(cfg.get("vae", None), MODELS).to(device, dtype).eval()
|
||||
|
||||
# == build diffusion model ==
|
||||
input_size = (dataset.num_frames, *dataset.image_size)
|
||||
latent_size = vae.get_latent_size(input_size)
|
||||
if vae is not None:
|
||||
latent_size = vae.get_latent_size(input_size)
|
||||
else:
|
||||
latent_size = (None, None, None)
|
||||
model = (
|
||||
build_module(
|
||||
cfg.model,
|
||||
|
|
@ -221,6 +223,7 @@ def main():
|
|||
for step, batch in pbar:
|
||||
x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
|
||||
y = batch.pop("text")
|
||||
breakpoint()
|
||||
|
||||
# == visual and text encoding ==
|
||||
with torch.no_grad():
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from tqdm import tqdm
|
|||
|
||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets.dataloader import build_batch_dataloader, collate_fn_batch
|
||||
from opensora.datasets.dataloader import build_batch_dataloader
|
||||
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
|
||||
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
|
||||
|
|
@ -98,7 +98,6 @@ def main():
|
|||
# drop_last=True,
|
||||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
collate_fn=collate_fn_batch,
|
||||
)
|
||||
dataloader = build_batch_dataloader(**dataloader_args)
|
||||
num_steps_per_epoch = len(dataset) // dist.get_world_size()
|
||||
|
|
|
|||
Loading…
Reference in a new issue