mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 13:14:44 +02:00
Exp/image mix (#20)
* [exp] image mixed training
* [exp] add batch info
* [exp] launch
* update num_step_per_epoch
* [feat] verify image mix training
This commit is contained in:
parent
f15ef17937
commit
7612d22fc6
|
|
@ -1,14 +1,15 @@
|
|||
num_frames = 1
|
||||
frame_interval = 1
|
||||
image_size = (256, 256)
|
||||
|
||||
# Define dataset
|
||||
root = None
|
||||
data_path = "CSV_PATH"
|
||||
use_image_transform = True
|
||||
num_workers = 4
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=1,
|
||||
frame_interval=1,
|
||||
image_size=(256, 256),
|
||||
transform_name="center",
|
||||
)
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = False
|
||||
plugin = "zero2"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
num_frames = 4
|
||||
num_frames = 1
|
||||
fps = 24 // 3
|
||||
image_size = (704, 1472)
|
||||
# image_size = (1358, 680)
|
||||
image_size = (2160, 3840)
|
||||
multi_resolution = "STDiT2"
|
||||
|
||||
# Define model
|
||||
|
|
@ -4,15 +4,22 @@ dataset = dict(
|
|||
data_path=None,
|
||||
num_frames=None,
|
||||
frame_interval=3,
|
||||
image_size=(None, None, None),
|
||||
image_size=(None, None),
|
||||
transform_name="resize_crop",
|
||||
)
|
||||
bucket_config = {
|
||||
"256": {1: (1.0, 64)}, # 32 ok, 64 broken
|
||||
# "256": {1: (1.0, 256)}, # 4.5s/it
|
||||
# "512": {1: (1.0, 96)}, # 4.7s/it
|
||||
# "512": {1: (1.0, 128)}, # 6.3s/it
|
||||
# "480p": {1: (1.0, 50)}, # 4.0s/it
|
||||
# "1024": {1: (1.0, 32)}, # 6.8s/it
|
||||
# "1024": {1: (1.0, 20)}, # 4.3s/it
|
||||
# "1080p": {1: (1.0, 16)}, # 8.6s/it
|
||||
"1080p": {1: (1.0, 8)}, # 4.4s/it
|
||||
}
|
||||
|
||||
# Define acceleration
|
||||
num_workers = 0
|
||||
num_workers = 4
|
||||
dtype = "bf16"
|
||||
grad_checkpoint = True
|
||||
plugin = "zero2"
|
||||
|
|
@ -4,15 +4,15 @@ dataset = dict(
|
|||
data_path=None,
|
||||
num_frames=None,
|
||||
frame_interval=3,
|
||||
image_size=(None, None, None),
|
||||
image_size=(None, None),
|
||||
transform_name="resize_crop",
|
||||
)
|
||||
bucket_config = {
|
||||
"256": {1: (1.0, 128)},
|
||||
"512": {1: (1.0, 2)},
|
||||
"480p": {1: (1.0, 2)},
|
||||
"1024": {1: (1.0, 2)},
|
||||
"1080p": {1: (1.0, 2)},
|
||||
bucket_config = { # 6s/it
|
||||
"256": {1: (1.0, 256)},
|
||||
"512": {1: (1.0, 80)},
|
||||
"480p": {1: (1.0, 52)},
|
||||
"1024": {1: (1.0, 20)},
|
||||
"1080p": {1: (1.0, 8)},
|
||||
}
|
||||
|
||||
# Define acceleration
|
||||
|
|
@ -53,7 +53,7 @@ wandb = False
|
|||
|
||||
epochs = 1000
|
||||
log_every = 10
|
||||
ckpt_every = 1000
|
||||
ckpt_every = 500
|
||||
load = None
|
||||
|
||||
batch_size = 10 # only for logging
|
||||
|
|
|
|||
|
|
@ -3,10 +3,12 @@ import math
|
|||
# Ours
|
||||
|
||||
|
||||
def get_h_w(a, ts):
|
||||
def get_h_w(a, ts, eps=1e-4):
|
||||
h = (ts * a) ** 0.5
|
||||
h = h + eps
|
||||
h = math.ceil(h) if math.ceil(h) % 2 == 0 else math.floor(h)
|
||||
w = h / a
|
||||
w = w + eps
|
||||
w = math.ceil(w) if math.ceil(w) % 2 == 0 else math.floor(w)
|
||||
return h, w
|
||||
|
||||
|
|
@ -93,11 +95,11 @@ ASPECT_RATIO_720P = {
|
|||
"0.67": (784, 1176),
|
||||
"0.75": (832, 1110),
|
||||
"1.00": (960, 960),
|
||||
"1.33": (1108, 831),
|
||||
"1.33": (1108, 832),
|
||||
"1.50": (1176, 784),
|
||||
"1.78": (1280, 720),
|
||||
"1.89": (1320, 698),
|
||||
"2.00": (1358, 679),
|
||||
"2.00": (1358, 680),
|
||||
"2.08": (1386, 666),
|
||||
}
|
||||
|
||||
|
|
@ -116,7 +118,7 @@ ASPECT_RATIO_480P = {
|
|||
"1.52": (790, 520),
|
||||
"1.78": (854, 480),
|
||||
"1.92": (888, 462),
|
||||
"2.00": (906, 453),
|
||||
"2.00": (906, 454),
|
||||
"2.10": (928, 442),
|
||||
}
|
||||
|
||||
|
|
@ -130,14 +132,14 @@ ASPECT_RATIO_360P = {
|
|||
"0.54": (470, 870),
|
||||
"0.56": (480, 854), # base
|
||||
"0.62": (506, 810),
|
||||
"0.67": (522, 783),
|
||||
"0.67": (522, 784),
|
||||
"0.75": (554, 738),
|
||||
"1.00": (640, 640),
|
||||
"1.33": (740, 555),
|
||||
"1.50": (784, 522),
|
||||
"1.78": (854, 480),
|
||||
"1.89": (880, 466),
|
||||
"2.00": (906, 453),
|
||||
"2.00": (906, 454),
|
||||
"2.08": (924, 444),
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from pprint import pprint
|
|||
from typing import Iterator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import DistributedSampler
|
||||
|
||||
from .bucket import Bucket
|
||||
|
|
@ -29,6 +30,27 @@ class VariableVideoBatchSampler(DistributedSampler):
|
|||
self.bucket = Bucket(bucket_config)
|
||||
self.verbose = verbose
|
||||
self.last_micro_batch_access_index = 0
|
||||
self.approximate_num_batch = None
|
||||
|
||||
def get_num_batch(self) -> int:
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.seed + self.epoch)
|
||||
bucket_sample_dict = OrderedDict()
|
||||
|
||||
# group by bucket
|
||||
# each data sample is put into a bucket with a similar image/video size
|
||||
for i in range(len(self.dataset)):
|
||||
t, h, w = self.dataset.get_data_info(i)
|
||||
bucket_id = self.bucket.get_bucket_id(t, h, w, self.dataset.frame_interval, g)
|
||||
if bucket_id is None:
|
||||
continue
|
||||
if bucket_id not in bucket_sample_dict:
|
||||
bucket_sample_dict[bucket_id] = []
|
||||
bucket_sample_dict[bucket_id].append(i)
|
||||
|
||||
# calculate the number of batches
|
||||
self._print_bucket_info(bucket_sample_dict)
|
||||
return self.approximate_num_batch
|
||||
|
||||
def __iter__(self) -> Iterator[List[int]]:
|
||||
g = torch.Generator()
|
||||
|
|
@ -155,8 +177,9 @@ class VariableVideoBatchSampler(DistributedSampler):
|
|||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
def _print_bucket_info(self, bucket_sample_dict: dict) -> None:
|
||||
def _print_bucket_info(self, bucket_sample_dict: dict, verbose=True) -> None:
|
||||
total_samples = 0
|
||||
num_batch = 0
|
||||
num_dict = {}
|
||||
num_aspect_dict = defaultdict(int)
|
||||
num_hwt_dict = defaultdict(int)
|
||||
|
|
@ -166,13 +189,17 @@ class VariableVideoBatchSampler(DistributedSampler):
|
|||
num_dict[k] = size
|
||||
num_aspect_dict[k[-1]] += size
|
||||
num_hwt_dict[k[:-1]] += size
|
||||
print(f"Total training samples: {total_samples}, num buckets: {len(num_dict)}")
|
||||
print("Bucket samples:")
|
||||
pprint(num_dict)
|
||||
print("Bucket samples by HxWxT:")
|
||||
pprint(num_hwt_dict)
|
||||
print("Bucket samples by aspect ratio:")
|
||||
pprint(num_aspect_dict)
|
||||
num_batch += size // self.bucket.get_batch_size(k[:-1])
|
||||
if dist.get_rank() == 0 and verbose:
|
||||
print(f"Total training samples: {total_samples}, num buckets: {len(num_dict)}")
|
||||
print("Bucket samples:")
|
||||
pprint(num_dict)
|
||||
print("Bucket samples by HxWxT:")
|
||||
pprint(num_hwt_dict)
|
||||
print("Bucket samples by aspect ratio:")
|
||||
pprint(num_aspect_dict)
|
||||
print(f"Number of batches: {num_batch}")
|
||||
self.approximate_num_batch = num_batch
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
super().set_epoch(epoch)
|
||||
|
|
|
|||
|
|
@ -53,9 +53,9 @@ class VideoAutoencoderKL(nn.Module):
|
|||
def get_latent_size(self, input_size):
|
||||
latent_size = []
|
||||
for i in range(3):
|
||||
assert (
|
||||
input_size[i] is None or input_size[i] % self.patch_size[i] == 0
|
||||
), "Input size must be divisible by patch size"
|
||||
# assert (
|
||||
# input_size[i] is None or input_size[i] % self.patch_size[i] == 0
|
||||
# ), "Input size must be divisible by patch size"
|
||||
latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
|
||||
return latent_size
|
||||
|
||||
|
|
@ -89,9 +89,9 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module):
|
|||
def get_latent_size(self, input_size):
|
||||
latent_size = []
|
||||
for i in range(3):
|
||||
assert (
|
||||
input_size[i] is None or input_size[i] % self.patch_size[i] == 0
|
||||
), "Input size must be divisible by patch size"
|
||||
# assert (
|
||||
# input_size[i] is None or input_size[i] % self.patch_size[i] == 0
|
||||
# ), "Input size must be divisible by patch size"
|
||||
latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
|
||||
return latent_size
|
||||
|
||||
|
|
|
|||
|
|
@ -37,8 +37,6 @@ def main():
|
|||
# 1. args & cfg
|
||||
# ======================================================
|
||||
cfg = parse_configs(training=True)
|
||||
print("Training configuration:")
|
||||
pprint(cfg._cfg_dict)
|
||||
exp_name, exp_dir = create_experiment_workspace(cfg)
|
||||
save_training_config(cfg._cfg_dict, exp_dir)
|
||||
|
||||
|
|
@ -58,6 +56,8 @@ def main():
|
|||
if not coordinator.is_master():
|
||||
logger = create_logger(None)
|
||||
else:
|
||||
print("Training configuration:")
|
||||
pprint(cfg._cfg_dict)
|
||||
logger = create_logger(exp_dir)
|
||||
logger.info(f"Experiment directory created at {exp_dir}")
|
||||
|
||||
|
|
@ -173,13 +173,16 @@ def main():
|
|||
dataloader=dataloader,
|
||||
)
|
||||
torch.set_default_dtype(torch.float)
|
||||
num_steps_per_epoch = len(dataloader)
|
||||
logger.info("Boost model for distributed training")
|
||||
if cfg.dataset.type == "VariableVideoTextDataset":
|
||||
num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
|
||||
else:
|
||||
num_steps_per_epoch = len(dataloader)
|
||||
|
||||
# =======================================================
|
||||
# 6. training loop
|
||||
# =======================================================
|
||||
start_epoch = start_step = log_step = sampler_start_idx = 0
|
||||
start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
|
||||
running_loss = 0.0
|
||||
sampler_to_io = dataloader.batch_sampler if cfg.dataset.type == "VariableVideoTextDataset" else None
|
||||
# 6.1. resume training
|
||||
|
|
@ -254,6 +257,7 @@ def main():
|
|||
running_loss += loss.item()
|
||||
global_step = epoch * num_steps_per_epoch + step
|
||||
log_step += 1
|
||||
acc_step += 1
|
||||
|
||||
# Log to tensorboard
|
||||
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
|
||||
|
|
@ -269,6 +273,7 @@ def main():
|
|||
"epoch": epoch,
|
||||
"loss": loss.item(),
|
||||
"avg_loss": avg_loss,
|
||||
"acc_step": acc_step,
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue