mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
[wip] eval loss and search bs
This commit is contained in:
parent
811ac1055b
commit
bb3c23577c
|
|
@ -14,7 +14,13 @@ bucket_config = {
|
|||
# # ---
|
||||
# "240p": {1: (100, 20), 51: (24, 5), 102: (12, 4), 204: (4, 2), 408: (2, 1)},
|
||||
# ---
|
||||
"512": {1: (60, 100), 51: (12, 4), 102: (6, 2), 204: (2, 1), 408: (1, 0)},
|
||||
"512": {
|
||||
# 1: (141, 0),
|
||||
51: (9, 4),
|
||||
102: (6, 2),
|
||||
204: (2, 1),
|
||||
# 408: (1, 0),
|
||||
},
|
||||
# ---
|
||||
# "480p": {1: (40, 10), 51: (6, 2), 102: (3, 2), 204: (1, 1)},
|
||||
# # ---
|
||||
|
|
|
|||
49
configs/opensora-v1-2/misc/eval_loss.py
Normal file
49
configs/opensora-v1-2/misc/eval_loss.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
num_workers = 8
|
||||
dtype = "bf16"
|
||||
seed = 42
|
||||
num_eval_timesteps = 10
|
||||
|
||||
# Dataset settings
|
||||
dataset = dict(
|
||||
type="VariableVideoTextDataset",
|
||||
transform_name="resize_crop",
|
||||
)
|
||||
|
||||
# just occupy the space.... actually in evaluation we will create dataset for different resolutions
|
||||
bucket_config = { # 20s/it
|
||||
"144p": {1: (1.0, 100), 51: (1.0, 30), 102: ((1.0, 0.33), 20), 204: ((1.0, 0.1), 8), 408: ((1.0, 0.1), 4)},
|
||||
# ---
|
||||
"240p": {1: (0.3, 100), 51: (0.4, 24), 102: ((0.4, 0.33), 12), 204: ((0.4, 0.1), 4), 408: ((0.4, 0.1), 2)},
|
||||
# ---
|
||||
"360p": {1: (0.2, 60), 51: (0.15, 12), 102: ((0.15, 0.33), 6), 204: ((0.15, 0.1), 2), 408: ((0.15, 0.1), 1)},
|
||||
# ---
|
||||
"480p": {1: (0.1, 40), 51: (0.3, 6), 102: (0.3, 3), 204: (0.3, 1), 408: (0.0, None)},
|
||||
# ---
|
||||
"720p": {1: (0.05, 20), 51: (0.3, 2), 102: (0.3, 1), 204: (0.0, None)},
|
||||
# ---
|
||||
"1080p": {1: (0.1, 10)},
|
||||
# ---
|
||||
"2048": {1: (0.1, 5)},
|
||||
}
|
||||
|
||||
# Model settings
|
||||
model = dict(
|
||||
type="STDiT3-XL/2",
|
||||
from_pretrained=None,
|
||||
qk_norm=True,
|
||||
enable_flash_attn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
vae = dict(
|
||||
type="OpenSoraVAE_V1_2",
|
||||
from_pretrained="pretrained_models/vae-pipeline",
|
||||
micro_frame_size=17,
|
||||
micro_batch_size=4,
|
||||
)
|
||||
text_encoder = dict(
|
||||
type="t5",
|
||||
from_pretrained="DeepFloyd/t5-v1_1-xxl",
|
||||
model_max_length=300,
|
||||
local_files_only=True,
|
||||
)
|
||||
scheduler = dict(type="rflow")
|
||||
|
|
@ -1,90 +0,0 @@
|
|||
# Dataset settings
|
||||
dataset = dict(
|
||||
type="VideoTextDataset",
|
||||
data_path=None,
|
||||
num_frames=16,
|
||||
frame_interval=3,
|
||||
image_size=(256, 256),
|
||||
transform_name="resize_crop",
|
||||
) # just occupy the space.... actually in evaluation we will create dataset for different resolutions
|
||||
eval_config = { # 2s/it
|
||||
"144p": {1: (0.5, 40), 34: (1.0, 10), 51: (1.0, 10), 102: (1.0, 5), 204: (1.0, 2)},
|
||||
# # ---
|
||||
"256": {1: (0.6, 40), 34: (0.5, 10), 51: (0.5, 5), 68: (0.5, 5), 136: (0.0, 4)},
|
||||
"240p": {1: (0.6, 40), 34: (0.5, 10), 51: (0.5, 5), 68: (0.5, 5), 136: (0.0, 4)},
|
||||
# # ---
|
||||
"360p": {1: (0.5, 20), 34: (0.2, 8), 102: (0.0, 4)},
|
||||
"512": {1: (0.5, 10), 34: (0.2, 8), 102: (0.0, 4)},
|
||||
# ---
|
||||
"480p": {1: (0.2, 10), 17: (0.3, 5), 68: (0.0, 2)},,
|
||||
# ---
|
||||
"720p": {1: (0.1, 5)},
|
||||
"1024": {1: (0.1, 4)},
|
||||
# ---
|
||||
"1080p": {1: (0.1, 2)},
|
||||
}
|
||||
grad_checkpoint = False # determine batch size
|
||||
batch_size = None
|
||||
|
||||
# Acceleration settings
|
||||
num_workers = 8
|
||||
num_bucket_build_workers = 16
|
||||
dtype = "bf16"
|
||||
plugin = "zero2"
|
||||
|
||||
# Model settings
|
||||
model = dict(
|
||||
type="STDiT3-XL/2",
|
||||
from_pretrained=None,
|
||||
qk_norm=True,
|
||||
enable_flash_attn=True,
|
||||
enable_layernorm_kernel=True,
|
||||
)
|
||||
vae = dict(
|
||||
type="VideoAutoencoderPipeline",
|
||||
from_pretrained="/home/zhengzangwei/projs/Open-Sora-dev/pretrained_models/vae-v3",
|
||||
micro_frame_size=17,
|
||||
vae_2d=dict(
|
||||
type="VideoAutoencoderKL",
|
||||
from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
|
||||
subfolder="vae",
|
||||
micro_batch_size=4,
|
||||
local_files_only=True,
|
||||
),
|
||||
vae_temporal=dict(
|
||||
type="VAE_Temporal_SD",
|
||||
from_pretrained=None,
|
||||
),
|
||||
)
|
||||
text_encoder = dict(
|
||||
type="t5",
|
||||
from_pretrained="DeepFloyd/t5-v1_1-xxl",
|
||||
model_max_length=300,
|
||||
shardformer=False,
|
||||
local_files_only=True,
|
||||
)
|
||||
scheduler = dict(
|
||||
type="rflow",
|
||||
use_discrete_timesteps=False,
|
||||
use_timestep_transform=True,
|
||||
sample_method="logit-normal",
|
||||
)
|
||||
|
||||
|
||||
# Log settings
|
||||
seed = 42
|
||||
outputs = "outputs/eval_loss"
|
||||
epochs = 1000
|
||||
log_every = 10
|
||||
ckpt_every = 500
|
||||
|
||||
# optimization settings
|
||||
load = None
|
||||
grad_clip = 1.0
|
||||
lr = 1e-4
|
||||
ema_decay = 0.99
|
||||
adam_eps = 1e-15
|
||||
|
||||
# eval
|
||||
num_eval_samples = 40 # num eval samples per (res, num_frames, ar, t)
|
||||
num_eval_timesteps = 20
|
||||
|
|
@ -1,176 +0,0 @@
|
|||
import os
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
from copy import deepcopy
|
||||
from datetime import timedelta
|
||||
from pprint import pformat
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from mmengine.runner import set_random_seed
|
||||
from torch.utils.data import DataLoader as Dataloader
|
||||
import random
|
||||
import json
|
||||
|
||||
from opensora.acceleration.checkpoint import set_grad_checkpoint
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets import prepare_dataloader, prepare_variable_dataloader
|
||||
from opensora.datasets.aspect import *
|
||||
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
|
||||
from opensora.utils.config_utils import (
|
||||
create_tensorboard_writer,
|
||||
define_experiment_workspace,
|
||||
parse_configs,
|
||||
save_training_config,
|
||||
)
|
||||
from opensora.utils.misc import (
|
||||
all_reduce_mean,
|
||||
create_logger,
|
||||
format_numel_str,
|
||||
get_model_numel,
|
||||
requires_grad,
|
||||
to_torch_dtype,
|
||||
)
|
||||
from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema
|
||||
|
||||
DEFAULT_DATASET_NAME = "VideoTextDataset"
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
# ======================================================
|
||||
# 1. configs & runtime variables
|
||||
# ======================================================
|
||||
# == parse configs ==
|
||||
|
||||
cfg = parse_configs(training=True)
|
||||
device = torch.device("cuda")
|
||||
|
||||
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
|
||||
cfg_dtype = cfg.get("dtype", "bf16")
|
||||
assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
|
||||
dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
|
||||
set_random_seed(seed=cfg.seed)
|
||||
|
||||
# == init exp_dir ==
|
||||
exp_name, exp_dir = define_experiment_workspace(cfg)
|
||||
os.makedirs(exp_dir, exist_ok=False)
|
||||
|
||||
|
||||
# == init logger, build public models ==
|
||||
# logger = create_logger(exp_dir)
|
||||
|
||||
text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
|
||||
vae = build_module(cfg.vae, MODELS).to(device, dtype)
|
||||
vae.eval()
|
||||
model = build_module(
|
||||
cfg.model,
|
||||
MODELS,
|
||||
input_size=(None, None, None),
|
||||
in_channels=vae.out_channels,
|
||||
caption_channels=text_encoder.output_dim,
|
||||
model_max_length=text_encoder.model_max_length,
|
||||
).to(device, dtype)
|
||||
model.eval()
|
||||
|
||||
scheduler = build_module(cfg.scheduler, SCHEDULERS)
|
||||
|
||||
if cfg.get("mask_ratios", None) is not None:
|
||||
mask_generator = MaskGenerator(cfg.mask_ratios)
|
||||
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
# start evaluation, prepare a dataset everytime in the loop
|
||||
eval_config = cfg.get("eval_config")
|
||||
assert eval_config is not None, "eval_config is required for evaluation"
|
||||
|
||||
evaluation_losses = {}
|
||||
for res, v in eval_config.items():
|
||||
loss_res = {}
|
||||
for num_frames, (_,batch_size) in v.items():
|
||||
loss_frame = {}
|
||||
# for each resolution, there may be different aspect ratios(image_size), can be found in datasets/aspect.py
|
||||
with tqdm(ASPECT_RATIOS[res][1].items(), desc=f"Resolution {res} num_frames {num_frames}") as pbar:
|
||||
for ar, img_size in pbar:
|
||||
# == build dataset ==
|
||||
dataset = build_module({
|
||||
"type": DEFAULT_DATASET_NAME,
|
||||
"num_frames": num_frames,
|
||||
"data_path": cfg.dataset['data_path'],
|
||||
"frame_interval": 1,
|
||||
"image_size": img_size,
|
||||
"transform_name":"resize_crop",
|
||||
}, DATASETS)
|
||||
# == build dataloader ==
|
||||
seed=cfg.get("seed", 1024)
|
||||
def seed_worker(worker_id):
|
||||
worker_seed = seed
|
||||
np.random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
dataloader_args = dict(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=cfg.get("num_workers", 4),
|
||||
# seed=cfg.get("seed", 1024),
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
pin_memory=True,
|
||||
worker_init_fn=seed_worker,
|
||||
)
|
||||
dataloader = Dataloader(
|
||||
**dataloader_args,
|
||||
)
|
||||
# dataloader.sampler.set_start_index(0)
|
||||
dataloader_iter = iter(dataloader)
|
||||
num_steps_per_t = cfg.num_eval_samples // batch_size
|
||||
loss_ar = {}
|
||||
|
||||
for t in range(0, scheduler.num_timesteps, scheduler.num_timesteps//cfg.num_eval_timesteps):
|
||||
# save key = (res, num_frames, ar, t), value = loss finally
|
||||
loss_t = None
|
||||
for estep in range(num_steps_per_t):
|
||||
batch = next(dataloader_iter)
|
||||
x = batch.pop("video").to(device, dtype)
|
||||
y = batch.pop("text")
|
||||
x = vae.encode(x)
|
||||
model_args = text_encoder.encode(y)
|
||||
model_args["x_mask"] = None
|
||||
|
||||
mask = None
|
||||
if cfg.get("mask_ratios", None) is not None:
|
||||
mask = mask_generator.get_masks(x)
|
||||
model_args["x_mask"] = mask
|
||||
|
||||
# == video meta info ==
|
||||
for k, v in batch.items():
|
||||
model_args[k] = v.to(device, dtype)
|
||||
|
||||
|
||||
# add height, width and num_frame since they are not in batch meta info
|
||||
model_args["height"] = torch.tensor([img_size[0]], device=device, dtype=dtype)
|
||||
model_args["width"] = torch.tensor([img_size[1]], device=device, dtype=dtype)
|
||||
model_args['num_frames'] = torch.tensor([num_frames], device=device, dtype=dtype)
|
||||
|
||||
# == diffusion loss computation ==
|
||||
timestep = torch.tensor([t]*x.shape[0], device=device, dtype=dtype)
|
||||
loss_dict = scheduler.training_losses(model, x, model_args, mask=mask, t=timestep)
|
||||
losses = loss_dict["loss"] # (batch_size)
|
||||
loss_t = losses if loss_t is None else torch.cat([loss_t, losses], dim=0)
|
||||
# save the avg loss for this tuple(res, num_frames, ar, t)
|
||||
loss_ar[t] = loss_t.mean().item(), loss_t.std().item()
|
||||
pbar.set_postfix({"ar": ar, "t": t, "loss": loss_ar[t][0]})
|
||||
loss_frame[ar] = loss_ar
|
||||
|
||||
loss_res[num_frames] = loss_frame
|
||||
evaluation_losses[res] = loss_res
|
||||
with open(os.path.join(exp_dir, "evaluation_losses.json"), "w") as f:
|
||||
json.dump(evaluation_losses, f)
|
||||
|
||||
# save the evaluation_losses
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -229,7 +229,7 @@ class OpenSoraVAE_V1_2(VideoAutoencoderPipeline):
|
|||
micro_batch_size=4,
|
||||
micro_frame_size=17,
|
||||
from_pretrained=None,
|
||||
local_files_only=True,
|
||||
local_files_only=False,
|
||||
freeze_vae_2d=False,
|
||||
cal_loss=False,
|
||||
):
|
||||
|
|
|
|||
164
scripts/misc/eval_loss.py
Normal file
164
scripts/misc/eval_loss.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
from pprint import pformat
|
||||
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from mmengine.runner import set_random_seed
|
||||
from tqdm import tqdm
|
||||
|
||||
from opensora.acceleration.parallel_states import get_data_parallel_group
|
||||
from opensora.datasets import prepare_variable_dataloader
|
||||
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.config_utils import parse_configs
|
||||
from opensora.utils.misc import create_logger, to_torch_dtype
|
||||
from opensora.utils.train_utils import MaskGenerator
|
||||
|
||||
|
||||
def main():
|
||||
torch.set_grad_enabled(False)
|
||||
# ======================================================
|
||||
# configs & runtime variables
|
||||
# ======================================================
|
||||
# == parse configs ==
|
||||
cfg = parse_configs(training=False)
|
||||
|
||||
# == device and dtype ==
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
cfg_dtype = cfg.get("dtype", "fp32")
|
||||
assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
|
||||
dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# == device and dtype ==
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
cfg_dtype = cfg.get("dtype", "fp32")
|
||||
assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}"
|
||||
dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# == init distributed env ==
|
||||
colossalai.launch_from_torch({})
|
||||
DistCoordinator()
|
||||
set_random_seed(seed=cfg.get("seed", 1024))
|
||||
|
||||
# == init logger ==
|
||||
logger = create_logger()
|
||||
logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
|
||||
|
||||
# ======================================================
|
||||
# build model & load weights
|
||||
# ======================================================
|
||||
logger.info("Building models...")
|
||||
# == build text-encoder and vae ==
|
||||
text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
|
||||
vae = build_module(cfg.vae, MODELS).to(device, dtype).eval()
|
||||
|
||||
# == build diffusion model ==
|
||||
input_size = (None, None, None)
|
||||
latent_size = vae.get_latent_size(input_size)
|
||||
model = (
|
||||
build_module(
|
||||
cfg.model,
|
||||
MODELS,
|
||||
input_size=latent_size,
|
||||
in_channels=vae.out_channels,
|
||||
caption_channels=text_encoder.output_dim,
|
||||
model_max_length=text_encoder.model_max_length,
|
||||
)
|
||||
.to(device, dtype)
|
||||
.eval()
|
||||
)
|
||||
text_encoder.y_embedder = model.y_embedder # HACK: for classifier-free guidance
|
||||
|
||||
# == build scheduler ==
|
||||
scheduler = build_module(cfg.scheduler, SCHEDULERS)
|
||||
|
||||
if cfg.get("mask_ratios", None) is not None:
|
||||
mask_generator = MaskGenerator(cfg.mask_ratios)
|
||||
|
||||
# ======================================================
|
||||
# inference
|
||||
# ======================================================
|
||||
# start evaluation, prepare a dataset everytime in the loop
|
||||
bucket_config = cfg.bucket_config
|
||||
assert bucket_config is not None, "bucket_config is required for evaluation"
|
||||
|
||||
def build_dataset(resolution, num_frames, batch_size):
|
||||
bucket_config = {resolution: {num_frames: (1.0, batch_size)}}
|
||||
dataset = build_module(cfg.dataset, DATASETS)
|
||||
dataloader_args = dict(
|
||||
dataset=dataset,
|
||||
batch_size=None,
|
||||
num_workers=cfg.num_workers,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
)
|
||||
dataloader = prepare_variable_dataloader(
|
||||
bucket_config=bucket_config,
|
||||
**dataloader_args,
|
||||
)
|
||||
num_batch = dataloader.batch_sampler.get_num_batch()
|
||||
num_steps_per_epoch = num_batch // dist.get_world_size()
|
||||
return dataloader, num_steps_per_epoch, num_batch
|
||||
|
||||
evaluation_losses = {}
|
||||
for res, t_bucket in bucket_config.items():
|
||||
for num_frames, (_, batch_size) in t_bucket.items():
|
||||
if batch_size is None:
|
||||
continue
|
||||
logger.info("Evaluating resolution: %s, num_frames: %s", res, num_frames)
|
||||
dataloader, num_steps_per_epoch, num_batch = build_dataset(res, num_frames, batch_size)
|
||||
if num_batch == 0:
|
||||
logger.warning("No data for resolution: %s, num_frames: %s", res, num_frames)
|
||||
continue
|
||||
|
||||
evaluation_t_losses = []
|
||||
for t in torch.linspace(0, scheduler.num_timesteps, cfg.get("num_eval_timesteps", 10)):
|
||||
loss_t = 0.0
|
||||
num_samples = 0
|
||||
dataloader_iter = iter(dataloader)
|
||||
for _ in tqdm(range(num_steps_per_epoch), desc=f"res: {res}, num_frames: {num_frames}, t: {t:.2f}"):
|
||||
batch = next(dataloader_iter)
|
||||
x = batch.pop("video").to(device, dtype)
|
||||
y = batch.pop("text")
|
||||
x = vae.encode(x)
|
||||
model_args = text_encoder.encode(y)
|
||||
|
||||
# == mask ==
|
||||
mask = None
|
||||
if cfg.get("mask_ratios", None) is not None:
|
||||
mask = mask_generator.get_masks(x)
|
||||
model_args["x_mask"] = mask
|
||||
|
||||
# == video meta info ==
|
||||
for k, v in batch.items():
|
||||
model_args[k] = v.to(device, dtype)
|
||||
|
||||
# == diffusion loss computation ==
|
||||
timestep = torch.tensor([t] * x.shape[0], device=device, dtype=dtype)
|
||||
loss_dict = scheduler.training_losses(model, x, model_args, mask=mask, t=timestep)
|
||||
losses = loss_dict["loss"] # (batch_size)
|
||||
num_samples += x.shape[0]
|
||||
loss_t += losses.sum().item()
|
||||
loss_t /= num_samples
|
||||
evaluation_t_losses.append(loss_t)
|
||||
logger.info("resolution: %s, num_frames: %s, timestep: %.2f, loss: %.4f", res, num_frames, t, loss_t)
|
||||
|
||||
evaluation_losses[(res, num_frames)] = sum(evaluation_t_losses) / len(evaluation_t_losses)
|
||||
logger.info(
|
||||
"Evaluation losses for resolution: %s, num_frames: %s, loss: %s\n %s",
|
||||
res,
|
||||
num_frames,
|
||||
evaluation_losses[(res, num_frames)],
|
||||
evaluation_t_losses,
|
||||
)
|
||||
logger.info("Evaluation losses: %s", evaluation_losses)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -150,8 +150,8 @@ def main():
|
|||
dataset=dataset,
|
||||
batch_size=None,
|
||||
num_workers=cfg.num_workers,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
pin_memory=True,
|
||||
process_group=get_data_parallel_group(),
|
||||
)
|
||||
|
|
@ -159,9 +159,11 @@ def main():
|
|||
bucket_config=bucket_config,
|
||||
**dataloader_args,
|
||||
)
|
||||
dataloader_iter = iter(dataloader)
|
||||
num_batch = dataloader.batch_sampler.get_num_batch()
|
||||
num_steps_per_epoch = num_batch // dist.get_world_size()
|
||||
|
||||
dataloader_iter = iter(dataloader)
|
||||
|
||||
return dataloader_iter, num_steps_per_epoch, num_batch
|
||||
|
||||
def train(resolution, num_frames, batch_size, warmup_steps=5, active_steps=5):
|
||||
|
|
@ -256,17 +258,19 @@ def main():
|
|||
mid = (lower_bound + upper_bound) // 2
|
||||
try:
|
||||
step_time = train(resolution, num_frames, mid)
|
||||
target_batch_size, target_step_time = mid, step_time
|
||||
if step_time < 0: # no data
|
||||
return 0, 0
|
||||
if ref_step_time is not None:
|
||||
dis = abs(target_step_time - ref_step_time)
|
||||
if dis < min_dis:
|
||||
min_dis = dis
|
||||
if step_time < ref_step_time:
|
||||
lower_bound = mid + 1
|
||||
dis = abs(step_time - ref_step_time)
|
||||
if dis < min_dis:
|
||||
target_batch_size, target_step_time = mid, step_time
|
||||
min_dis = dis
|
||||
else:
|
||||
upper_bound = mid - 1
|
||||
else:
|
||||
target_batch_size, target_step_time = mid, step_time
|
||||
lower_bound = mid + 1
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
|
|
|||
Loading…
Reference in a new issue