[wip] eval loss and search bs

zhengzangw 2024-05-15 12:21:07 +00:00
parent 811ac1055b
commit bb3c23577c
7 changed files with 232 additions and 275 deletions


@@ -14,7 +14,13 @@ bucket_config = {
    # # ---
    # "240p": {1: (100, 20), 51: (24, 5), 102: (12, 4), 204: (4, 2), 408: (2, 1)},
    # ---
    "512": {1: (60, 100), 51: (12, 4), 102: (6, 2), 204: (2, 1), 408: (1, 0)},
    "512": {
        # 1: (141, 0),
        51: (9, 4),
        102: (6, 2),
        204: (2, 1),
        # 408: (1, 0),
    },
    # ---
    # "480p": {1: (40, 10), 51: (6, 2), 102: (3, 2), 204: (1, 1)},
    # # ---


@@ -0,0 +1,49 @@
num_workers = 8
dtype = "bf16"
seed = 42
num_eval_timesteps = 10

# Dataset settings
dataset = dict(
    type="VariableVideoTextDataset",
    transform_name="resize_crop",
)
# placeholder only: evaluation builds a fresh dataloader per (resolution, num_frames) bucket
bucket_config = {  # 20s/it
    "144p": {1: (1.0, 100), 51: (1.0, 30), 102: ((1.0, 0.33), 20), 204: ((1.0, 0.1), 8), 408: ((1.0, 0.1), 4)},
    # ---
    "240p": {1: (0.3, 100), 51: (0.4, 24), 102: ((0.4, 0.33), 12), 204: ((0.4, 0.1), 4), 408: ((0.4, 0.1), 2)},
    # ---
    "360p": {1: (0.2, 60), 51: (0.15, 12), 102: ((0.15, 0.33), 6), 204: ((0.15, 0.1), 2), 408: ((0.15, 0.1), 1)},
    # ---
    "480p": {1: (0.1, 40), 51: (0.3, 6), 102: (0.3, 3), 204: (0.3, 1), 408: (0.0, None)},
    # ---
    "720p": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)},
    # ---
    "1080p": {1: (0.1, 10)},
    # ---
    "2048": {1: (0.1, 5)},
}

# Model settings
model = dict(
    type="STDiT3-XL/2",
    from_pretrained=None,
    qk_norm=True,
    enable_flash_attn=True,
    enable_layernorm_kernel=True,
)
vae = dict(
    type="OpenSoraVAE_V1_2",
    from_pretrained="pretrained_models/vae-pipeline",
    micro_frame_size=17,
    micro_batch_size=4,
)
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=300,
    local_files_only=True,
)
scheduler = dict(type="rflow")

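Each bucket entry in the config above maps num_frames to a (sampling probability, batch size) pair, where the probability may itself be a pair. Judging from the consumption loop in scripts/misc/eval_loss.py further below, evaluation ignores the probability element and reads only the batch size, skipping entries whose batch size is None. A minimal sketch of that pattern:

# minimal sketch: how eval_loss.py walks a bucket_config
# (the probability element is ignored; a None batch size disables a bucket)
bucket_config = {
    "240p": {1: (0.3, 100), 51: (0.4, 24)},
    "720p": {204: (0.0, None)},
}
for res, t_bucket in bucket_config.items():
    for num_frames, (_, batch_size) in t_bucket.items():
        if batch_size is None:
            continue
        print(f"evaluate res={res}, num_frames={num_frames}, batch_size={batch_size}")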

@@ -1,90 +0,0 @@
# Dataset settings
dataset = dict(
    type="VideoTextDataset",
    data_path=None,
    num_frames=16,
    frame_interval=3,
    image_size=(256, 256),
    transform_name="resize_crop",
)  # placeholder only: evaluation builds a fresh dataset per resolution
eval_config = {  # 2s/it
    "144p": {1: (0.5, 40), 34: (1.0, 10), 51: (1.0, 10), 102: (1.0, 5), 204: (1.0, 2)},
    # # ---
    "256": {1: (0.6, 40), 34: (0.5, 10), 51: (0.5, 5), 68: (0.5, 5), 136: (0.0, 4)},
    "240p": {1: (0.6, 40), 34: (0.5, 10), 51: (0.5, 5), 68: (0.5, 5), 136: (0.0, 4)},
    # # ---
    "360p": {1: (0.5, 20), 34: (0.2, 8), 102: (0.0, 4)},
    "512": {1: (0.5, 10), 34: (0.2, 8), 102: (0.0, 4)},
    # ---
    "480p": {1: (0.2, 10), 17: (0.3, 5), 68: (0.0, 2)},
    # ---
    "720p": {1: (0.1, 5)},
    "1024": {1: (0.1, 4)},
    # ---
    "1080p": {1: (0.1, 2)},
}
grad_checkpoint = False  # affects the feasible batch size
batch_size = None

# Acceleration settings
num_workers = 8
num_bucket_build_workers = 16
dtype = "bf16"
plugin = "zero2"

# Model settings
model = dict(
    type="STDiT3-XL/2",
    from_pretrained=None,
    qk_norm=True,
    enable_flash_attn=True,
    enable_layernorm_kernel=True,
)
vae = dict(
    type="VideoAutoencoderPipeline",
    from_pretrained="/home/zhengzangwei/projs/Open-Sora-dev/pretrained_models/vae-v3",
    micro_frame_size=17,
    vae_2d=dict(
        type="VideoAutoencoderKL",
        from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
        subfolder="vae",
        micro_batch_size=4,
        local_files_only=True,
    ),
    vae_temporal=dict(
        type="VAE_Temporal_SD",
        from_pretrained=None,
    ),
)
text_encoder = dict(
    type="t5",
    from_pretrained="DeepFloyd/t5-v1_1-xxl",
    model_max_length=300,
    shardformer=False,
    local_files_only=True,
)
scheduler = dict(
    type="rflow",
    use_discrete_timesteps=False,
    use_timestep_transform=True,
    sample_method="logit-normal",
)

# Log settings
seed = 42
outputs = "outputs/eval_loss"
epochs = 1000
log_every = 10
ckpt_every = 500

# Optimization settings
load = None
grad_clip = 1.0
lr = 1e-4
ema_decay = 0.99
adam_eps = 1e-15

# Evaluation settings
num_eval_samples = 40  # eval samples per (res, num_frames, ar, t) tuple
num_eval_timesteps = 20

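Both configs select the "rflow" (rectified flow) scheduler. For intuition about the quantity being measured, a generic rectified-flow training loss is sketched below under textbook conventions; this is not this repo's implementation, whose timestep transform and sign convention may differ:

import torch

def rectified_flow_loss(model, x0, noise, t, num_timesteps=1000):
    # t: scalar timestep in [0, num_timesteps]
    tau = t / num_timesteps
    x_t = (1 - tau) * x0 + tau * noise  # interpolate data and noise
    target = noise - x0                 # constant velocity d x_t / d tau
    return torch.mean((model(x_t, t) - target) ** 2)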

@@ -1,176 +0,0 @@
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import json
import random
from copy import deepcopy
from datetime import timedelta
from pprint import pformat

import numpy as np
import torch
from mmengine.runner import set_random_seed
from torch.utils.data import DataLoader
from tqdm import tqdm

from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
from opensora.datasets.aspect import *
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import (
    create_tensorboard_writer,
    define_experiment_workspace,
    parse_configs,
    save_training_config,
)
from opensora.utils.misc import (
    all_reduce_mean,
    create_logger,
    format_numel_str,
    get_model_numel,
    requires_grad,
    to_torch_dtype,
)
from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema

DEFAULT_DATASET_NAME = "VideoTextDataset"


@torch.no_grad()
def main():
    # ======================================================
    # 1. configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=True)

    device = torch.device("cuda")
    assert torch.cuda.is_available(), "Evaluation currently requires at least one GPU."
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg_dtype)
    set_random_seed(seed=cfg.seed)

    # == init exp_dir ==
    exp_name, exp_dir = define_experiment_workspace(cfg)
    os.makedirs(exp_dir, exist_ok=False)

    # == init logger, build public models ==
    # logger = create_logger(exp_dir)
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
    vae = build_module(cfg.vae, MODELS).to(device, dtype)
    vae.eval()
    model = build_module(
        cfg.model,
        MODELS,
        input_size=(None, None, None),
        in_channels=vae.out_channels,
        caption_channels=text_encoder.output_dim,
        model_max_length=text_encoder.model_max_length,
    ).to(device, dtype)
    model.eval()
    scheduler = build_module(cfg.scheduler, SCHEDULERS)
    if cfg.get("mask_ratios", None) is not None:
        mask_generator = MaskGenerator(cfg.mask_ratios)
    torch.set_default_dtype(dtype)
    torch.set_default_dtype(torch.float)

    # start evaluation; a fresh dataset is built for every (res, num_frames, ar) combination in the loop
    eval_config = cfg.get("eval_config")
    assert eval_config is not None, "eval_config is required for evaluation"
    evaluation_losses = {}
    for res, v in eval_config.items():
        loss_res = {}
        for num_frames, (_, batch_size) in v.items():
            loss_frame = {}
            # each resolution maps to several aspect ratios (image sizes); see datasets/aspect.py
            with tqdm(ASPECT_RATIOS[res][1].items(), desc=f"Resolution {res} num_frames {num_frames}") as pbar:
                for ar, img_size in pbar:
                    # == build dataset ==
                    dataset = build_module(
                        {
                            "type": DEFAULT_DATASET_NAME,
                            "num_frames": num_frames,
                            "data_path": cfg.dataset["data_path"],
                            "frame_interval": 1,
                            "image_size": img_size,
                            "transform_name": "resize_crop",
                        },
                        DATASETS,
                    )

                    # == build dataloader ==
                    seed = cfg.get("seed", 1024)

                    def seed_worker(worker_id):
                        worker_seed = seed
                        np.random.seed(worker_seed)
                        torch.manual_seed(worker_seed)
                        random.seed(worker_seed)

                    dataloader_args = dict(
                        dataset=dataset,
                        batch_size=batch_size,
                        num_workers=cfg.get("num_workers", 4),
                        # seed=cfg.get("seed", 1024),
                        shuffle=True,
                        drop_last=True,
                        pin_memory=True,
                        worker_init_fn=seed_worker,
                    )
                    dataloader = DataLoader(**dataloader_args)
                    # dataloader.sampler.set_start_index(0)
                    dataloader_iter = iter(dataloader)

                    num_steps_per_t = cfg.num_eval_samples // batch_size
                    loss_ar = {}
                    for t in range(0, scheduler.num_timesteps, scheduler.num_timesteps // cfg.num_eval_timesteps):
                        # the final result is keyed by (res, num_frames, ar, t) with the loss as value
                        loss_t = None
                        for estep in range(num_steps_per_t):
                            batch = next(dataloader_iter)
                            x = batch.pop("video").to(device, dtype)
                            y = batch.pop("text")
                            x = vae.encode(x)
                            model_args = text_encoder.encode(y)
                            model_args["x_mask"] = None
                            mask = None
                            if cfg.get("mask_ratios", None) is not None:
                                mask = mask_generator.get_masks(x)
                                model_args["x_mask"] = mask
                            # == video meta info ==
                            for k, v_meta in batch.items():
                                model_args[k] = v_meta.to(device, dtype)
                            # add height, width and num_frames since they are not in the batch meta info
                            model_args["height"] = torch.tensor([img_size[0]], device=device, dtype=dtype)
                            model_args["width"] = torch.tensor([img_size[1]], device=device, dtype=dtype)
                            model_args["num_frames"] = torch.tensor([num_frames], device=device, dtype=dtype)
                            # == diffusion loss computation ==
                            timestep = torch.tensor([t] * x.shape[0], device=device, dtype=dtype)
                            loss_dict = scheduler.training_losses(model, x, model_args, mask=mask, t=timestep)
                            losses = loss_dict["loss"]  # shape: (batch_size,)
                            loss_t = losses if loss_t is None else torch.cat([loss_t, losses], dim=0)
                        # save the mean/std loss for this (res, num_frames, ar, t) tuple
                        loss_ar[t] = loss_t.mean().item(), loss_t.std().item()
                        pbar.set_postfix({"ar": ar, "t": t, "loss": loss_ar[t][0]})
                    loss_frame[ar] = loss_ar
            loss_res[num_frames] = loss_frame
        evaluation_losses[res] = loss_res

    # save the evaluation losses
    with open(os.path.join(exp_dir, "evaluation_losses.json"), "w") as f:
        json.dump(evaluation_losses, f)


if __name__ == "__main__":
    main()

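One detail of the deleted script worth noting: its seed_worker seeds every dataloader worker with the same value, so all workers draw identical random streams. The standard PyTorch recipe derives a distinct but reproducible seed per worker, shown here for contrast:

import random

import numpy as np
import torch

def seed_worker(worker_id):
    # torch.initial_seed() already differs per worker; fold it into
    # numpy and random so each worker gets a distinct reproducible stream
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)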

@@ -229,7 +229,7 @@ class OpenSoraVAE_V1_2(VideoAutoencoderPipeline):
        micro_batch_size=4,
        micro_frame_size=17,
        from_pretrained=None,
        local_files_only=True,
        local_files_only=False,
        freeze_vae_2d=False,
        cal_loss=False,
    ):

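The flipped default matters on a cold cache: with local_files_only=False, Hugging Face from_pretrained may download missing weights instead of raising. A hypothetical diffusers-style call for illustration only (the repo actually routes loading through its own build_module):

from diffusers import AutoencoderKL

vae_2d = AutoencoderKL.from_pretrained(
    "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
    subfolder="vae",
    local_files_only=False,  # allow downloading when the file is not cached
)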
scripts/misc/eval_loss.py (new file)

@@ -0,0 +1,164 @@
from pprint import pformat

import colossalai
import torch
import torch.distributed as dist
from colossalai.cluster import DistCoordinator
from mmengine.runner import set_random_seed
from tqdm import tqdm

from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets import prepare_variable_dataloader
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import create_logger, to_torch_dtype
from opensora.utils.train_utils import MaskGenerator


def main():
    torch.set_grad_enabled(False)

    # ======================================================
    # configs & runtime variables
    # ======================================================
    # == parse configs ==
    cfg = parse_configs(training=False)

    # == device and dtype ==
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg_dtype = cfg.get("dtype", "fp32")
    assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg_dtype)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # == init distributed env ==
    colossalai.launch_from_torch({})
    DistCoordinator()
    set_random_seed(seed=cfg.get("seed", 1024))

    # == init logger ==
    logger = create_logger()
    logger.info("Evaluation configuration:\n %s", pformat(cfg.to_dict()))

    # ======================================================
    # build model & load weights
    # ======================================================
    logger.info("Building models...")
    # == build text-encoder and vae ==
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
    vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()

    # == build diffusion model ==
    input_size = (None, None, None)
    latent_size = vae.get_latent_size(input_size)
    model = (
        build_module(
            cfg.model,
            MODELS,
            input_size=latent_size,
            in_channels=vae.out_channels,
            caption_channels=text_encoder.output_dim,
            model_max_length=text_encoder.model_max_length,
        )
        .to(device, dtype)
        .eval()
    )
    text_encoder.y_embedder = model.y_embedder  # HACK: for classifier-free guidance

    # == build scheduler ==
    scheduler = build_module(cfg.scheduler, SCHEDULERS)
    if cfg.get("mask_ratios", None) is not None:
        mask_generator = MaskGenerator(cfg.mask_ratios)

    # ======================================================
    # inference
    # ======================================================
    # start evaluation; a fresh dataloader is built for every (resolution, num_frames) bucket in the loop
    bucket_config = cfg.bucket_config
    assert bucket_config is not None, "bucket_config is required for evaluation"

    def build_dataset(resolution, num_frames, batch_size):
        bucket_config = {resolution: {num_frames: (1.0, batch_size)}}
        dataset = build_module(cfg.dataset, DATASETS)
        dataloader_args = dict(
            dataset=dataset,
            batch_size=None,
            num_workers=cfg.num_workers,
            shuffle=False,
            drop_last=False,
            pin_memory=True,
            process_group=get_data_parallel_group(),
        )
        dataloader = prepare_variable_dataloader(
            bucket_config=bucket_config,
            **dataloader_args,
        )
        num_batch = dataloader.batch_sampler.get_num_batch()
        num_steps_per_epoch = num_batch // dist.get_world_size()
        return dataloader, num_steps_per_epoch, num_batch

    evaluation_losses = {}
    for res, t_bucket in bucket_config.items():
        for num_frames, (_, batch_size) in t_bucket.items():
            if batch_size is None:
                continue
            logger.info("Evaluating resolution: %s, num_frames: %s", res, num_frames)
            dataloader, num_steps_per_epoch, num_batch = build_dataset(res, num_frames, batch_size)
            if num_batch == 0:
                logger.warning("No data for resolution: %s, num_frames: %s", res, num_frames)
                continue

            evaluation_t_losses = []
            for t in torch.linspace(0, scheduler.num_timesteps, cfg.get("num_eval_timesteps", 10)):
                loss_t = 0.0
                num_samples = 0
                dataloader_iter = iter(dataloader)
                for _ in tqdm(range(num_steps_per_epoch), desc=f"res: {res}, num_frames: {num_frames}, t: {t:.2f}"):
                    batch = next(dataloader_iter)
                    x = batch.pop("video").to(device, dtype)
                    y = batch.pop("text")
                    x = vae.encode(x)
                    model_args = text_encoder.encode(y)

                    # == mask ==
                    mask = None
                    if cfg.get("mask_ratios", None) is not None:
                        mask = mask_generator.get_masks(x)
                        model_args["x_mask"] = mask

                    # == video meta info ==
                    for k, v in batch.items():
                        model_args[k] = v.to(device, dtype)

                    # == diffusion loss computation ==
                    timestep = torch.tensor([t] * x.shape[0], device=device, dtype=dtype)
                    loss_dict = scheduler.training_losses(model, x, model_args, mask=mask, t=timestep)
                    losses = loss_dict["loss"]  # shape: (batch_size,)
                    num_samples += x.shape[0]
                    loss_t += losses.sum().item()
                loss_t /= num_samples
                evaluation_t_losses.append(loss_t)
                logger.info("resolution: %s, num_frames: %s, timestep: %.2f, loss: %.4f", res, num_frames, t, loss_t)

            evaluation_losses[(res, num_frames)] = sum(evaluation_t_losses) / len(evaluation_t_losses)
            logger.info(
                "Evaluation losses for resolution: %s, num_frames: %s, loss: %s\n %s",
                res,
                num_frames,
                evaluation_losses[(res, num_frames)],
                evaluation_t_losses,
            )

    logger.info("Evaluation losses: %s", evaluation_losses)


if __name__ == "__main__":
    main()

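Unlike the deleted script, eval_loss.py only logs the final dictionary. Anyone persisting it should note the keys are (res, num_frames) tuples, which json.dump cannot serialize directly; a small hypothetical helper:

import json

def save_losses(evaluation_losses, path="evaluation_losses.json"):
    # stringify tuple keys, e.g. ("240p", 51) -> "240p-51"
    serializable = {f"{res}-{nf}": loss for (res, nf), loss in evaluation_losses.items()}
    with open(path, "w") as f:
        json.dump(serializable, f, indent=2)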

@@ -150,8 +150,8 @@ def main():
            dataset=dataset,
            batch_size=None,
            num_workers=cfg.num_workers,
            shuffle=True,
            drop_last=True,
            shuffle=False,
            drop_last=False,
            pin_memory=True,
            process_group=get_data_parallel_group(),
        )
@@ -159,9 +159,11 @@ def main():
            bucket_config=bucket_config,
            **dataloader_args,
        )
        dataloader_iter = iter(dataloader)
        num_batch = dataloader.batch_sampler.get_num_batch()
        num_steps_per_epoch = num_batch // dist.get_world_size()
        dataloader_iter = iter(dataloader)
        return dataloader_iter, num_steps_per_epoch, num_batch

    def train(resolution, num_frames, batch_size, warmup_steps=5, active_steps=5):
@@ -256,17 +258,19 @@ def main():
        mid = (lower_bound + upper_bound) // 2
        try:
            step_time = train(resolution, num_frames, mid)
            target_batch_size, target_step_time = mid, step_time
            if step_time < 0:  # no data
                return 0, 0
            if ref_step_time is not None:
                dis = abs(target_step_time - ref_step_time)
                if dis < min_dis:
                    min_dis = dis
                if step_time < ref_step_time:
                    lower_bound = mid + 1
                    dis = abs(step_time - ref_step_time)
                    if dis < min_dis:
                        target_batch_size, target_step_time = mid, step_time
                        min_dis = dis
                else:
                    upper_bound = mid - 1
            else:
                target_batch_size, target_step_time = mid, step_time
                lower_bound = mid + 1
        except Exception:
            traceback.print_exc()
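
The reworked search above only commits target_batch_size when a measurement improves the distance to ref_step_time, and bails out early when a bucket has no data. A self-contained sketch of the same binary search, assuming a hypothetical measure(batch_size) callable that returns the step time (negative when there is no data):

def search_batch_size(measure, ref_step_time, lower_bound=1, upper_bound=256):
    target_bs, target_time, min_dis = 0, 0.0, float("inf")
    while lower_bound <= upper_bound:
        mid = (lower_bound + upper_bound) // 2
        step_time = measure(mid)
        if step_time < 0:  # no data for this bucket
            return 0, 0
        if ref_step_time is not None:
            # remember the batch size whose step time is closest to the reference
            dis = abs(step_time - ref_step_time)
            if dis < min_dis:
                target_bs, target_time, min_dis = mid, step_time, dis
            if step_time < ref_step_time:
                lower_bound = mid + 1
            else:
                upper_bound = mid - 1
        else:
            # no reference: take the largest batch size that still runs
            target_bs, target_time = mid, step_time
            lower_bound = mid + 1
    return target_bs, target_time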