From f489cb09158d0bf48560acd8e69375cf278ca6dd Mon Sep 17 00:00:00 2001 From: zhengzangw Date: Wed, 15 May 2024 08:07:49 +0000 Subject: [PATCH] [feat] accelerate search batch size --- configs/opensora-v1-2/misc/bs.py | 88 +++++ opensora/utils/misc.py | 12 + scripts/misc/search_bs.py | 628 ++++++++++++------------------- 3 files changed, 345 insertions(+), 383 deletions(-) create mode 100644 configs/opensora-v1-2/misc/bs.py diff --git a/configs/opensora-v1-2/misc/bs.py b/configs/opensora-v1-2/misc/bs.py new file mode 100644 index 0000000..c0ee452 --- /dev/null +++ b/configs/opensora-v1-2/misc/bs.py @@ -0,0 +1,88 @@ +# Dataset settings +dataset = dict( + type="VariableVideoTextDataset", + transform_name="resize_crop", +) + +# == Config 1 == +# base: (512, 408), 12s/it +grad_checkpoint = True +base = ("512", "408") +base_step_time = 12 +bucket_config = { + # "144p": {1: (100, 50), 51: (30, 20), 102: (20, 10), 204: (8, 4), 408: (4, 4)}, + # # --- + # "240p": {1: (100, 20), 51: (24, 5), 102: (12, 4), 204: (4, 2), 408: (2, 1)}, + # --- + "512": {1: (60, 100), 51: (12, 4), 102: (6, 2), 204: (2, 1), 408: (1, 0)}, + # --- + # "480p": {1: (40, 10), 51: (6, 2), 102: (3, 2), 204: (1, 1)}, + # # --- + # "1024": {1: (20, 10), 51: (2, 1), 102: (1, 1)}, + # # --- + # "1080p": {1: (10, 5)}, + # # --- + # "2048": {1: (5, 2)}, +} + +# Acceleration settings +num_workers = 8 +num_bucket_build_workers = 16 +dtype = "bf16" +plugin = "zero2" + +# Model settings +model = dict( + type="STDiT3-XL/2", + from_pretrained=None, + qk_norm=True, + enable_flash_attn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="OpenSoraVAE_V1_2", + from_pretrained="pretrained_models/vae-pipeline", + micro_frame_size=17, + micro_batch_size=4, +) +text_encoder = dict( + type="t5", + from_pretrained="DeepFloyd/t5-v1_1-xxl", + model_max_length=300, + shardformer=True, + local_files_only=True, +) +scheduler = dict( + type="rflow", + use_timestep_transform=True, + sample_method="logit-normal", +) + +# Mask settings +mask_ratios = { + "random": 0.2, + "intepolate": 0.01, + "quarter_random": 0.01, + "quarter_head": 0.01, + "quarter_tail": 0.01, + "quarter_head_tail": 0.01, + "image_random": 0.05, + "image_head": 0.1, + "image_tail": 0.05, + "image_head_tail": 0.05, +} + +# Log settings +seed = 42 +outputs = "outputs" +wandb = False +epochs = 1000 +log_every = 10 +ckpt_every = 500 + +# optimization settings +load = None +grad_clip = 1.0 +lr = 2e-4 +ema_decay = 0.99 +adam_eps = 1e-15 diff --git a/opensora/utils/misc.py b/opensora/utils/misc.py index 621bd73..08b63de 100644 --- a/opensora/utils/misc.py +++ b/opensora/utils/misc.py @@ -135,6 +135,18 @@ def format_time(seconds): return f +class BColors: + HEADER = "\033[95m" + OKBLUE = "\033[94m" + OKCYAN = "\033[96m" + OKGREEN = "\033[92m" + WARNING = "\033[93m" + FAIL = "\033[91m" + ENDC = "\033[0m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + + # ====================================================== # PyTorch # ====================================================== diff --git a/scripts/misc/search_bs.py b/scripts/misc/search_bs.py index 815f3b2..26a83fe 100644 --- a/scripts/misc/search_bs.py +++ b/scripts/misc/search_bs.py @@ -1,197 +1,124 @@ -import argparse import time import traceback from copy import deepcopy +from datetime import timedelta -import colossalai import torch import torch.distributed as dist from colossalai.booster import Booster -from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device -from mmengine.config import Config +from colossalai.utils import get_current_device, set_seed from tqdm import tqdm from opensora.acceleration.checkpoint import set_grad_checkpoint -from opensora.acceleration.parallel_states import ( - get_data_parallel_group, - set_data_parallel_group, - set_sequence_parallel_group, -) -from opensora.acceleration.plugin import ZeroSeqParallelPlugin +from opensora.acceleration.parallel_states import get_data_parallel_group from opensora.datasets import prepare_variable_dataloader +from opensora.datasets.aspect import get_num_frames from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module from opensora.utils.ckpt_utils import model_sharding -from opensora.utils.config_utils import merge_args, parse_configs -from opensora.utils.misc import format_numel_str, get_model_numel, requires_grad, to_torch_dtype -from opensora.utils.train_utils import MaskGenerator, update_ema +from opensora.utils.config_utils import parse_configs +from opensora.utils.misc import BColors, create_logger, format_numel_str, get_model_numel, requires_grad, to_torch_dtype +from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema - -class BColors: - HEADER = "\033[95m" - OKBLUE = "\033[94m" - OKCYAN = "\033[96m" - OKGREEN = "\033[92m" - WARNING = "\033[93m" - FAIL = "\033[91m" - ENDC = "\033[0m" - BOLD = "\033[1m" - UNDERLINE = "\033[4m" - - -# BUCKETS = [ -# ("240p", 16), -# ("240p", 32), -# ("240p", 64), -# ("240p", 128), -# ("256", 1), -# ("512", 1), -# ("480p", 1), -# ("480p", 16), -# ("480p", 32), -# ("720p", 16), -# ("720p", 32), -# ("1024", 1), -# ("1080p", 1), -# ] - - -def parse_configs(): - parser = argparse.ArgumentParser() - parser.add_argument("config", help="model config file path") - parser.add_argument("-o", "--output", help="output config file path", default="output_config.py") - - parser.add_argument("--seed", default=42, type=int, help="generation seed") - parser.add_argument( - "--ckpt-path", - type=str, - help="path to model ckpt; will overwrite cfg.ckpt_path if specified", - ) - parser.add_argument("--data-path", default=None, type=str, help="path to data csv", required=True) - parser.add_argument("--warmup-steps", default=1, type=int, help="warmup steps") - parser.add_argument("--active-steps", default=1, type=int, help="active steps") - parser.add_argument("--base-resolution", default="240p", type=str, help="base resolution") - parser.add_argument("--base-frames", default=128, type=int, help="base frames") - parser.add_argument("--batch-size-start", default=2, type=int, help="batch size start") - parser.add_argument("--batch-size-end", default=256, type=int, help="batch size end") - parser.add_argument("--batch-size-step", default=2, type=int, help="batch size step") - args = parser.parse_args() - cfg = Config.fromfile(args.config) - cfg = merge_args(cfg, args, training=True) - return cfg, args - - -def rewrite_config(cfg, resolution, num_frames, batch_size): - cfg.bucket_config = {resolution: {num_frames: (1.0, batch_size)}} - return cfg - - -def update_bucket_config_bs(bucket_config, resolution, num_frames, batch_size): - p, _ = bucket_config[resolution][num_frames] - bucket_config[resolution][num_frames] = (p, batch_size) +SEARCH_BS_PREFIX = f"{BColors.OKGREEN}[Search BS]{BColors.ENDC}" def main(): # ====================================================== - # 1. args & cfg + # configs & runtime variables # ====================================================== - cfg, args = parse_configs() - print(cfg) + # == parse configs == + cfg = parse_configs() assert cfg.dataset.type == "VariableVideoTextDataset", "Only VariableVideoTextDataset is supported" - # ====================================================== - # 2. runtime variables & colossalai launch - # ====================================================== + # == device and dtype == assert torch.cuda.is_available(), "Training currently requires at least one GPU." - assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}" + cfg_dtype = cfg.get("dtype", "bf16") + assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}" + dtype = to_torch_dtype(cfg.get("dtype", "bf16")) - # 2.1. colossalai init distributed training - colossalai.launch_from_torch({}) - coordinator = DistCoordinator() + # == colossalai init distributed training == + # NOTE: A very large timeout is set to avoid some processes exit early + dist.init_process_group(backend="nccl", timeout=timedelta(hours=24)) + torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) + set_seed(cfg.get("seed", 1024)) + DistCoordinator() device = get_current_device() - dtype = to_torch_dtype(cfg.dtype) - # 2.3. initialize ColossalAI booster - if cfg.plugin == "zero2": - plugin = LowLevelZeroPlugin( - stage=2, - precision=cfg.dtype, - initial_scale=2**16, - max_norm=cfg.grad_clip, - ) - set_data_parallel_group(dist.group.WORLD) - elif cfg.plugin == "zero2-seq": - plugin = ZeroSeqParallelPlugin( - sp_size=cfg.get("sp_size", 1), - stage=2, - precision=cfg.dtype, - initial_scale=2**16, - max_norm=cfg.grad_clip, - ) - set_sequence_parallel_group(plugin.sp_group) - set_data_parallel_group(plugin.dp_group) - else: - raise ValueError(f"Unknown plugin {cfg.plugin}") + # == init logger, tensorboard & wandb == + logger = create_logger() + + # == init ColossalAI booster == + plugin = create_colossalai_plugin( + plugin=cfg.get("plugin", "zero2"), + dtype=cfg_dtype, + grad_clip=cfg.get("grad_clip", 0), + sp_size=cfg.get("sp_size", 1), + ) booster = Booster(plugin=plugin) # ====================================================== - # 4. build model + # build model # ====================================================== - # 4.1. build model + logger.info("Building models...") + # == build text-encoder and vae == text_encoder = build_module(cfg.text_encoder, MODELS, device=device) - vae = build_module(cfg.vae, MODELS) - input_size = (cfg.dataset.num_frames, *cfg.dataset.image_size) + vae = build_module(cfg.vae, MODELS).to(device, dtype).eval() + + # == build diffusion model == + input_size = (None, None, None) latent_size = vae.get_latent_size(input_size) - model = build_module( - cfg.model, - MODELS, - input_size=latent_size, - in_channels=vae.out_channels, - caption_channels=text_encoder.output_dim, - model_max_length=text_encoder.model_max_length, - dtype=dtype, + model = ( + build_module( + cfg.model, + MODELS, + input_size=latent_size, + in_channels=vae.out_channels, + caption_channels=text_encoder.output_dim, + model_max_length=text_encoder.model_max_length, + ) + .to(device, dtype) + .train() ) model_numel, model_numel_trainable = get_model_numel(model) - coordinator.print_on_master( - f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}" + logger.info( + "[Diffusion] Trainable model params: %s, Total model params: %s", + format_numel_str(model_numel_trainable), + format_numel_str(model_numel), ) - # 4.2. create ema + # == build ema for diffusion model == ema = deepcopy(model).to(torch.float32).to(device) requires_grad(ema, False) + ema.eval() + update_ema(ema, model, decay=0, sharded=False) - # 4.3. move to device - vae = vae.to(device, dtype) - model = model.to(device, dtype) - - # 4.4. build scheduler + # == setup loss function, build scheduler == scheduler = build_module(cfg.scheduler, SCHEDULERS) - # 4.5. setup optimizer + # == setup optimizer == optimizer = HybridAdam( filter(lambda p: p.requires_grad, model.parameters()), - lr=cfg.lr, - weight_decay=0, adamw_mode=True, + lr=cfg.get("lr", 1e-4), + weight_decay=cfg.get("weight_decay", 0), + eps=cfg.get("adam_eps", 1e-8), ) lr_scheduler = None - # 4.6. prepare for training - if cfg.grad_checkpoint: + # == additional preparation == + if cfg.get("grad_checkpoint", False): set_grad_checkpoint(model) - model.train() - update_ema(ema, model, decay=0, sharded=False) - ema.eval() - if cfg.mask_ratios is not None: + if cfg.get("mask_ratios", None) is not None: mask_generator = MaskGenerator(cfg.mask_ratios) - else: - mask_generator = None # ======================================================= - # 5. boost model for distributed training with colossalai + # distributed training preparation with colossalai # ======================================================= + logger.info("Preparing for distributed training...") + # == boosting == + # NOTE: we set dtype first to make initialization of model consistent with the dtype; then reset it to the fp32 as we make diffusion scheduler in fp32 torch.set_default_dtype(dtype) model, optimizer, _, _, lr_scheduler = booster.boost( model=model, @@ -199,257 +126,192 @@ def main(): lr_scheduler=lr_scheduler, ) torch.set_default_dtype(torch.float) - coordinator.print_on_master("Boost model for distributed training") + logger.info("Boosting model for distributed training") model_sharding(ema) - buckets = [ - (res, f) for res, d in cfg.bucket_config.items() for f, (p, bs) in d.items() if bs is not None and p > 0.0 - ] + def reset_optimizer(): + # this is essential for the first iteration after OOM + optimizer._grad_store.reset_all_gradients() + optimizer._bucket_store.reset_num_elements_in_bucket() + optimizer._bucket_store.grad_to_param_mapping = dict() + optimizer._bucket_store._grad_in_bucket = dict() + optimizer._bucket_store._param_list = [] + optimizer._bucket_store._padding_size = [] + for rank in range(optimizer._bucket_store._world_size): + optimizer._bucket_store._grad_in_bucket[rank] = [] + optimizer._bucket_store.offset_list = [0] + optimizer.zero_grad() + + def build_dataset(resolution, num_frames, batch_size): + bucket_config = {resolution: {num_frames: (1.0, batch_size)}} + dataset = build_module(cfg.dataset, DATASETS) + dataloader_args = dict( + dataset=dataset, + batch_size=None, + num_workers=cfg.num_workers, + shuffle=True, + drop_last=True, + pin_memory=True, + process_group=get_data_parallel_group(), + ) + dataloader = prepare_variable_dataloader( + bucket_config=bucket_config, + **dataloader_args, + ) + dataloader_iter = iter(dataloader) + num_batch = dataloader.batch_sampler.get_num_batch() + num_steps_per_epoch = num_batch // dist.get_world_size() + return dataloader_iter, num_steps_per_epoch, num_batch + + def train(resolution, num_frames, batch_size, warmup_steps=5, active_steps=5): + logger.info( + "%s Training resolution=%s, num_frames=%s, batch_size=%s", + SEARCH_BS_PREFIX, + resolution, + num_frames, + batch_size, + ) + total_steps = warmup_steps + active_steps + dataloader_iter, num_steps_per_epoch, num_batch = build_dataset(resolution, num_frames, batch_size) + if num_batch == 0: # no data + logger.info("%s No data found for resolution=%s, num_frames=%s", SEARCH_BS_PREFIX, resolution, num_frames) + return -1 + assert ( + num_steps_per_epoch >= total_steps + ), f"num_steps_per_epoch={num_steps_per_epoch} < total_steps={total_steps}" + duration = 0 + + reset_optimizer() + for step, batch in tqdm( + enumerate(dataloader_iter), + desc=f"({resolution},{num_frames}) bs={batch_size}", + total=total_steps, + ): + if step >= total_steps: + break + if step >= warmup_steps: + start = time.time() + + x = batch.pop("video").to(device, dtype) # [B, C, T, H, W] + y = batch.pop("text") + + # == visual and text encoding == + with torch.no_grad(): + # Prepare visual inputs + x = vae.encode(x) # [B, C, T, H/P, W/P] + # Prepare text inputs + model_args = text_encoder.encode(y) + + # == mask == + mask = None + if cfg.get("mask_ratios", None) is not None: + mask = mask_generator.get_masks(x) + model_args["x_mask"] = mask + + # == video meta info == + for k, v in batch.items(): + model_args[k] = v.to(device, dtype) + + # == diffusion loss computation == + loss_dict = scheduler.training_losses(model, x, model_args, mask=mask) + + # == backward & update == + loss = loss_dict["loss"].mean() + booster.backward(loss=loss, optimizer=optimizer) + optimizer.step() + optimizer.zero_grad() + + # == update EMA == + update_ema(ema, model.module, optimizer=optimizer, decay=cfg.get("ema_decay", 0.9999)) + + # == time accumulation == + if step >= warmup_steps: + end = time.time() + duration += end - start + + avg_step_time = duration / active_steps + logger.info("%s Average step time: %.2f", SEARCH_BS_PREFIX, avg_step_time) + return avg_step_time + + # ======================================================= + # search for bucket + # ======================================================= + # == benchmark == + def benchmark(resolution, num_frames, lower_bound, upper_bound, ref_step_time=None): + logger.info( + "%s Benchmarking resolution=%s, num_frames=%s, lower_bound=%s, upper_bound=%s", + SEARCH_BS_PREFIX, + resolution, + num_frames, + lower_bound, + upper_bound, + ) + + # binary search the largest valid batch size + mid = target_batch_size = target_step_time = 0 + if ref_step_time is not None: + min_dis = float("inf") + while lower_bound <= upper_bound: + mid = (lower_bound + upper_bound) // 2 + try: + step_time = train(resolution, num_frames, mid) + target_batch_size, target_step_time = mid, step_time + if step_time < 0: # no data + return 0, 0 + if ref_step_time is not None: + dis = abs(target_step_time - ref_step_time) + if dis < min_dis: + min_dis = dis + lower_bound = mid + 1 + else: + upper_bound = mid - 1 + else: + lower_bound = mid + 1 + except Exception: + traceback.print_exc() + upper_bound = mid - 1 + + logger.info( + "%s Benchmarking result: batch_size=%s, step_time=%s", SEARCH_BS_PREFIX, target_batch_size, target_step_time + ) + return target_batch_size, target_step_time + + # == build bucket == output_bucket_cfg = deepcopy(cfg.bucket_config) - # find the base batch size - assert (args.base_resolution, args.base_frames) in buckets - del buckets[buckets.index((args.base_resolution, args.base_frames))] - base_batch_size, base_step_time = benchmark( - args, - cfg, - args.base_resolution, - args.base_frames, - device, - dtype, - booster, - vae, - text_encoder, - model, - mask_generator, - scheduler, - optimizer, - ema, - ) - update_bucket_config_bs(output_bucket_cfg, args.base_resolution, args.base_frames, base_batch_size) - coordinator.print_on_master( - f"{BColors.OKBLUE}Base resolution: {args.base_resolution}, Base frames: {args.base_frames}, Batch size: {base_batch_size}, Base step time: {base_step_time}{BColors.ENDC}" - ) - result_table = [f"{args.base_resolution}, {args.base_frames}, {base_batch_size}, {base_step_time:.2f}"] - for resolution, frames in buckets: + buckets = { + (resolution, num_frames): (max(guess_bs - variance, 1), guess_bs + variance) + for resolution, t_bucket in cfg.bucket_config.items() + for num_frames, (guess_bs, variance) in t_bucket.items() + } + + # == get base_step_time == + base_step_time = cfg.get("base_step_time", None) + result_table = [] + if base_step_time is None: + base_resolution, base_num_frames = cfg.base + base_num_frames = get_num_frames(base_num_frames) + assert ( + base_resolution, + base_num_frames, + ) in buckets, f"Base bucket {base_resolution} {base_num_frames} not found" + base_bound = buckets.pop((base_resolution, base_num_frames)) + + base_batch_size, base_step_time = benchmark(base_resolution, base_num_frames, *base_bound) + output_bucket_cfg[base_resolution][base_num_frames] = base_batch_size + result_table.append(f"{base_resolution}, {base_num_frames}, {base_batch_size}, {base_step_time:.2f}") + + # == search for other buckets == + for (resolution, frames), bounds in buckets.items(): try: - batch_size, step_time = benchmark( - args, - cfg, - resolution, - frames, - device, - dtype, - booster, - vae, - text_encoder, - model, - mask_generator, - scheduler, - optimizer, - ema, - target_step_time=base_step_time, - ) - coordinator.print_on_master( - f"{BColors.OKBLUE}Resolution: {resolution}, Frames: {frames}, Batch size: {batch_size}, Step time: {step_time}{BColors.ENDC}" - ) - update_bucket_config_bs(output_bucket_cfg, resolution, frames, batch_size) + batch_size, step_time = benchmark(resolution, frames, *bounds, ref_step_time=base_step_time) + output_bucket_cfg[resolution][frames] = batch_size result_table.append(f"{resolution}, {frames}, {batch_size}, {step_time:.2f}") except RuntimeError: pass result_table = "\n".join(result_table) - coordinator.print_on_master( - f"{BColors.OKBLUE}Resolution, Frames, Batch size, Step time\n{result_table}{BColors.ENDC}" - ) - coordinator.print_on_master(f"{BColors.OKBLUE}{output_bucket_cfg}{BColors.ENDC}") - if coordinator.is_master(): - cfg.bucket_config = output_bucket_cfg - cfg.dump(args.output) - - -def benchmark( - args, - cfg, - resolution, - num_frames, - device, - dtype, - booster, - vae, - text_encoder, - model, - mask_generator, - scheduler, - optimizer, - ema, - target_step_time=None, -): - batch_sizes = [] - step_times = [] - - def run_step(bs) -> float: - step_time = train( - args, - cfg, - resolution, - num_frames, - bs, - device, - dtype, - booster, - vae, - text_encoder, - model, - mask_generator, - scheduler, - optimizer, - ema, - ) - step_times.append(step_time) - batch_sizes.append(bs) - return step_time - - orig_bs = cfg.bucket_config[resolution][num_frames][1] - lower_bound = args.batch_size_start - upper_bound = args.batch_size_end - step_size = args.batch_size_step - if isinstance(orig_bs, tuple): - if len(orig_bs) == 1: - upper_bound = orig_bs[0] - elif len(orig_bs) == 2: - lower_bound, upper_bound = orig_bs - elif len(orig_bs) == 3: - lower_bound, upper_bound, step_size = orig_bs - batch_start_size = lower_bound - - while lower_bound < upper_bound: - mid = (lower_bound + upper_bound) // 2 - try: - step_time = run_step(mid) - lower_bound = mid + 1 - except Exception: - traceback.print_exc() - upper_bound = mid - - for batch_size in range(batch_start_size, upper_bound, step_size): - if batch_size in batch_sizes: - continue - step_time = run_step(batch_size) - if len(step_times) == 0: - raise RuntimeError("No valid batch size found") - if target_step_time is None: - # find the fastest batch size - throughputs = [batch_size / step_time for step_time, batch_size in zip(step_times, batch_sizes)] - max_throughput = max(throughputs) - target_batch_size = batch_sizes[throughputs.index(max_throughput)] - step_time = step_times[throughputs.index(max_throughput)] - else: - # find the batch size that meets the target step time - diff = [abs(t - target_step_time) for t in step_times] - closest_step_time = min(diff) - target_batch_size = batch_sizes[diff.index(closest_step_time)] - step_time = step_times[diff.index(closest_step_time)] - return target_batch_size, step_time - - -def train( - args, - cfg, - resolution, - num_frames, - batch_size, - device, - dtype, - booster, - vae, - text_encoder, - model, - mask_generator, - scheduler, - optimizer, - ema, -): - total_steps = args.warmup_steps + args.active_steps - cfg = rewrite_config(deepcopy(cfg), resolution, num_frames, batch_size) - - dataset = build_module(cfg.dataset, DATASETS) - dataset.dummy = True - dataloader_args = dict( - dataset=dataset, - batch_size=cfg.batch_size, - num_workers=cfg.num_workers, - shuffle=True, - drop_last=True, - pin_memory=True, - process_group=get_data_parallel_group(), - ) - dataloader = prepare_variable_dataloader( - bucket_config=cfg.bucket_config, - **dataloader_args, - ) - dataloader_iter = iter(dataloader) - num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size() - - assert num_steps_per_epoch >= total_steps, f"num_steps_per_epoch={num_steps_per_epoch} < total_steps={total_steps}" - duration = 0 - # this is essential for the first iteration after OOM - optimizer._grad_store.reset_all_gradients() - optimizer._bucket_store.reset_num_elements_in_bucket() - optimizer._bucket_store.grad_to_param_mapping = dict() - optimizer._bucket_store._grad_in_bucket = dict() - optimizer._bucket_store._param_list = [] - optimizer._bucket_store._padding_size = [] - for rank in range(optimizer._bucket_store._world_size): - optimizer._bucket_store._grad_in_bucket[rank] = [] - optimizer._bucket_store.offset_list = [0] - optimizer.zero_grad() - for step, batch in tqdm( - enumerate(dataloader_iter), - desc=f"{resolution}:{num_frames} bs={batch_size}", - total=total_steps, - ): - if step >= total_steps: - break - if step >= args.warmup_steps: - start = time.time() - - x = batch.pop("video").to(device, dtype) # [B, C, T, H, W] - y = batch.pop("text") - # Visual and text encoding - with torch.no_grad(): - # Prepare visual inputs - x = vae.encode(x) # [B, C, T, H/P, W/P] - # Prepare text inputs - model_args = text_encoder.encode(y) - - # Mask - if cfg.mask_ratios is not None: - mask = mask_generator.get_masks(x) - model_args["x_mask"] = mask - else: - mask = None - - # Video info - for k, v in batch.items(): - model_args[k] = v.to(device, dtype) - - # Diffusion - loss_dict = scheduler.training_losses(model, x, model_args, mask=mask) - - # Backward & update - loss = loss_dict["loss"].mean() - booster.backward(loss=loss, optimizer=optimizer) - optimizer.step() - optimizer.zero_grad() - - # Update EMA - update_ema(ema, model.module, optimizer=optimizer) - if step >= args.warmup_steps: - end = time.time() - duration += end - start - - avg_step_time = duration / args.active_steps - return avg_step_time + logger.info("%s Search result:\nResolution, Frames, Batch size, Step time\n%s", SEARCH_BS_PREFIX, result_table) + logger.info("%s Bucket searched: %s", SEARCH_BS_PREFIX, output_bucket_cfg) if __name__ == "__main__":