Exp/image mix (#20)

* [exp] image mixed training

* [exp] add batch info

* [exp] launch

* update num_step_per_epoch

* [feat] verify image mix training
This commit is contained in:
Zheng Zangwei (Alex Zheng) 2024-04-02 13:35:09 +08:00 committed by GitHub
parent f15ef17937
commit 7612d22fc6
8 changed files with 88 additions and 45 deletions

View file

@@ -1,14 +1,15 @@
num_frames = 1
frame_interval = 1
image_size = (256, 256)
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = True
num_workers = 4
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=1,
frame_interval=1,
image_size=(256, 256),
transform_name="center",
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = False
plugin = "zero2"

View file

@@ -1,6 +1,7 @@
num_frames = 4
num_frames = 1
fps = 24 // 3
image_size = (704, 1472)
# image_size = (1358, 680)
image_size = (2160, 3840)
multi_resolution = "STDiT2"
# Define model

View file

@@ -4,15 +4,22 @@ dataset = dict(
data_path=None,
num_frames=None,
frame_interval=3,
image_size=(None, None, None),
image_size=(None, None),
transform_name="resize_crop",
)
bucket_config = {
"256": {1: (1.0, 64)}, # 32 ok, 64 broken
# "256": {1: (1.0, 256)}, # 4.5s/it
# "512": {1: (1.0, 96)}, # 4.7s/it
# "512": {1: (1.0, 128)}, # 6.3s/it
# "480p": {1: (1.0, 50)}, # 4.0s/it
# "1024": {1: (1.0, 32)}, # 6.8s/it
# "1024": {1: (1.0, 20)}, # 4.3s/it
# "1080p": {1: (1.0, 16)}, # 8.6s/it
"1080p": {1: (1.0, 8)}, # 4.4s/it
}
# Define acceleration
num_workers = 0
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"

View file

@@ -4,15 +4,15 @@ dataset = dict(
data_path=None,
num_frames=None,
frame_interval=3,
image_size=(None, None, None),
image_size=(None, None),
transform_name="resize_crop",
)
bucket_config = {
"256": {1: (1.0, 128)},
"512": {1: (1.0, 2)},
"480p": {1: (1.0, 2)},
"1024": {1: (1.0, 2)},
"1080p": {1: (1.0, 2)},
bucket_config = { # 6s/it
"256": {1: (1.0, 256)},
"512": {1: (1.0, 80)},
"480p": {1: (1.0, 52)},
"1024": {1: (1.0, 20)},
"1080p": {1: (1.0, 8)},
}
# Define acceleration
@@ -53,7 +53,7 @@ wandb = False
epochs = 1000
log_every = 10
ckpt_every = 1000
ckpt_every = 500
load = None
batch_size = 10 # only for logging

View file

@@ -3,10 +3,12 @@ import math
# Ours
def get_h_w(a, ts):
def get_h_w(a, ts, eps=1e-4):
h = (ts * a) ** 0.5
h = h + eps
h = math.ceil(h) if math.ceil(h) % 2 == 0 else math.floor(h)
w = h / a
w = w + eps
w = math.ceil(w) if math.ceil(w) % 2 == 0 else math.floor(w)
return h, w
@@ -93,11 +95,11 @@ ASPECT_RATIO_720P = {
"0.67": (784, 1176),
"0.75": (832, 1110),
"1.00": (960, 960),
"1.33": (1108, 831),
"1.33": (1108, 832),
"1.50": (1176, 784),
"1.78": (1280, 720),
"1.89": (1320, 698),
"2.00": (1358, 679),
"2.00": (1358, 680),
"2.08": (1386, 666),
}
@@ -116,7 +118,7 @@ ASPECT_RATIO_480P = {
"1.52": (790, 520),
"1.78": (854, 480),
"1.92": (888, 462),
"2.00": (906, 453),
"2.00": (906, 454),
"2.10": (928, 442),
}
@@ -130,14 +132,14 @@ ASPECT_RATIO_360P = {
"0.54": (470, 870),
"0.56": (480, 854), # base
"0.62": (506, 810),
"0.67": (522, 783),
"0.67": (522, 784),
"0.75": (554, 738),
"1.00": (640, 640),
"1.33": (740, 555),
"1.50": (784, 522),
"1.78": (854, 480),
"1.89": (880, 466),
"2.00": (906, 453),
"2.00": (906, 454),
"2.08": (924, 444),
}

View file

@@ -4,6 +4,7 @@ from pprint import pprint
from typing import Iterator, List, Optional
import torch
import torch.distributed as dist
from torch.utils.data import DistributedSampler
from .bucket import Bucket
@@ -29,6 +30,27 @@ class VariableVideoBatchSampler(DistributedSampler):
self.bucket = Bucket(bucket_config)
self.verbose = verbose
self.last_micro_batch_access_index = 0
self.approximate_num_batch = None
def get_num_batch(self) -> int:
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
bucket_sample_dict = OrderedDict()
# group by bucket
# each data sample is put into a bucket with a similar image/video size
for i in range(len(self.dataset)):
t, h, w = self.dataset.get_data_info(i)
bucket_id = self.bucket.get_bucket_id(t, h, w, self.dataset.frame_interval, g)
if bucket_id is None:
continue
if bucket_id not in bucket_sample_dict:
bucket_sample_dict[bucket_id] = []
bucket_sample_dict[bucket_id].append(i)
# calculate the number of batches
self._print_bucket_info(bucket_sample_dict)
return self.approximate_num_batch
def __iter__(self) -> Iterator[List[int]]:
g = torch.Generator()
@@ -155,8 +177,9 @@ class VariableVideoBatchSampler(DistributedSampler):
def load_state_dict(self, state_dict: dict) -> None:
self.__dict__.update(state_dict)
def _print_bucket_info(self, bucket_sample_dict: dict) -> None:
def _print_bucket_info(self, bucket_sample_dict: dict, verbose=True) -> None:
total_samples = 0
num_batch = 0
num_dict = {}
num_aspect_dict = defaultdict(int)
num_hwt_dict = defaultdict(int)
@@ -166,13 +189,17 @@ class VariableVideoBatchSampler(DistributedSampler):
num_dict[k] = size
num_aspect_dict[k[-1]] += size
num_hwt_dict[k[:-1]] += size
print(f"Total training samples: {total_samples}, num buckets: {len(num_dict)}")
print("Bucket samples:")
pprint(num_dict)
print("Bucket samples by HxWxT:")
pprint(num_hwt_dict)
print("Bucket samples by aspect ratio:")
pprint(num_aspect_dict)
num_batch += size // self.bucket.get_batch_size(k[:-1])
if dist.get_rank() == 0 and verbose:
print(f"Total training samples: {total_samples}, num buckets: {len(num_dict)}")
print("Bucket samples:")
pprint(num_dict)
print("Bucket samples by HxWxT:")
pprint(num_hwt_dict)
print("Bucket samples by aspect ratio:")
pprint(num_aspect_dict)
print(f"Number of batches: {num_batch}")
self.approximate_num_batch = num_batch
def set_epoch(self, epoch: int) -> None:
super().set_epoch(epoch)

View file

@@ -53,9 +53,9 @@ class VideoAutoencoderKL(nn.Module):
def get_latent_size(self, input_size):
latent_size = []
for i in range(3):
assert (
input_size[i] is None or input_size[i] % self.patch_size[i] == 0
), "Input size must be divisible by patch size"
# assert (
# input_size[i] is None or input_size[i] % self.patch_size[i] == 0
# ), "Input size must be divisible by patch size"
latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
return latent_size
@@ -89,9 +89,9 @@ class VideoAutoencoderKLTemporalDecoder(nn.Module):
def get_latent_size(self, input_size):
latent_size = []
for i in range(3):
assert (
input_size[i] is None or input_size[i] % self.patch_size[i] == 0
), "Input size must be divisible by patch size"
# assert (
# input_size[i] is None or input_size[i] % self.patch_size[i] == 0
# ), "Input size must be divisible by patch size"
latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
return latent_size

View file

@@ -37,8 +37,6 @@ def main():
# 1. args & cfg
# ======================================================
cfg = parse_configs(training=True)
print("Training configuration:")
pprint(cfg._cfg_dict)
exp_name, exp_dir = create_experiment_workspace(cfg)
save_training_config(cfg._cfg_dict, exp_dir)
@@ -58,6 +56,8 @@ def main():
if not coordinator.is_master():
logger = create_logger(None)
else:
print("Training configuration:")
pprint(cfg._cfg_dict)
logger = create_logger(exp_dir)
logger.info(f"Experiment directory created at {exp_dir}")
@@ -173,13 +173,16 @@ def main():
dataloader=dataloader,
)
torch.set_default_dtype(torch.float)
num_steps_per_epoch = len(dataloader)
logger.info("Boost model for distributed training")
if cfg.dataset.type == "VariableVideoTextDataset":
num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
else:
num_steps_per_epoch = len(dataloader)
# =======================================================
# 6. training loop
# =======================================================
start_epoch = start_step = log_step = sampler_start_idx = 0
start_epoch = start_step = log_step = sampler_start_idx = acc_step = 0
running_loss = 0.0
sampler_to_io = dataloader.batch_sampler if cfg.dataset.type == "VariableVideoTextDataset" else None
# 6.1. resume training
@@ -254,6 +257,7 @@ def main():
running_loss += loss.item()
global_step = epoch * num_steps_per_epoch + step
log_step += 1
acc_step += 1
# Log to tensorboard
if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
@@ -269,6 +273,7 @@ def main():
"epoch": epoch,
"loss": loss.item(),
"avg_loss": avg_loss,
"acc_step": acc_step,
},
step=global_step,
)