From 8cb6d2f0bdde2db41767ba0b97b1eaaf0791744d Mon Sep 17 00:00:00 2001 From: Zangwei Zheng Date: Wed, 17 Apr 2024 17:05:08 +0800 Subject: [PATCH] [feat] update eval --- opensora/utils/config_utils.py | 7 +++- scripts/inference.py | 20 ++++++---- scripts/misc/sample.sh | 16 ++++++++ scripts/{ => misc}/search_bs.py | 65 ++++++++------------------------- 4 files changed, 50 insertions(+), 58 deletions(-) create mode 100644 scripts/misc/sample.sh rename scripts/{ => misc}/search_bs.py (88%) diff --git a/opensora/utils/config_utils.py b/opensora/utils/config_utils.py index d6f1c01..2efb0cc 100644 --- a/opensora/utils/config_utils.py +++ b/opensora/utils/config_utils.py @@ -30,9 +30,12 @@ def parse_args(training=False): # Inference # ====================================================== if not training: + # output + parser.add_argument("--save-dir", default=None, type=str, help="path to save generated samples") + parser.add_argument("--sample-name", default=None, type=str, help="sample name, default is sample_idx") + # prompt parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file") - parser.add_argument("--save-dir", default=None, type=str, help="path to save generated samples") # image/video parser.add_argument("--num-frames", default=None, type=int, help="number of frames") @@ -71,6 +74,8 @@ def merge_args(cfg, args, training=False): if "prompt" not in cfg or cfg["prompt"] is None: assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided" cfg["prompt"] = load_prompts(cfg["prompt_path"]) + if "sample_name" not in cfg: + cfg["sample_name"] = None else: # Training only if args.data_path is not None: diff --git a/scripts/inference.py b/scripts/inference.py index 76d1a59..bf2f8e1 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -22,13 +22,18 @@ def main(): print(cfg) # init distributed - colossalai.launch_from_torch({}) - coordinator = DistCoordinator() + if os.environ.get("WORLD_SIZE", None): + use_dist = True + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() - if coordinator.world_size > 1: - set_sequence_parallel_group(dist.group.WORLD) - enable_sequence_parallelism = True + if coordinator.world_size > 1: + set_sequence_parallel_group(dist.group.WORLD) + enable_sequence_parallelism = True + else: + enable_sequence_parallelism = False else: + use_dist = False enable_sequence_parallelism = False # ====================================================== @@ -91,6 +96,7 @@ def main(): # 4. inference # ====================================================== sample_idx = 0 + sample_name = cfg.sample_name if cfg.sample_name is not None else "sample" save_dir = cfg.save_dir os.makedirs(save_dir, exist_ok=True) @@ -112,10 +118,10 @@ def main(): ) samples = vae.decode(samples.to(dtype)) - if coordinator.is_master(): + if not use_dist or coordinator.is_master(): for idx, sample in enumerate(samples): print(f"Prompt: {batch_prompts[idx]}") - save_path = os.path.join(save_dir, f"sample_{sample_idx}") + save_path = os.path.join(save_dir, f"{sample_name}_{sample_idx}") save_sample(sample, fps=cfg.fps, save_path=save_path) sample_idx += 1 diff --git a/scripts/misc/sample.sh b/scripts/misc/sample.sh new file mode 100644 index 0000000..ceba65a --- /dev/null +++ b/scripts/misc/sample.sh @@ -0,0 +1,16 @@ +set -x; + +CUDA_VISIBLE_DEVICES=7 +CMD="python scripts/inference.py configs/opensora-v1-1/inference/sample.py" +CKPT="~/lishenggui/epoch0-global_step8500" +OUTPUT="./outputs/samples_s1_8500" + +# 1. image +# 1.1 1024x1024 +eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2i_samples.txt --save-dir $OUTPUT --num-frames 1 --image-size 1024 1024 --sample-name pixart_1024x1024_1 + +# 1.2 512x512 + +# 1.3 240x426 + +# 1.4 720p multi-resolution diff --git a/scripts/search_bs.py b/scripts/misc/search_bs.py similarity index 88% rename from scripts/search_bs.py rename to scripts/misc/search_bs.py index af224e9..d6789df 100644 --- a/scripts/search_bs.py +++ b/scripts/misc/search_bs.py @@ -25,12 +25,7 @@ from opensora.datasets import prepare_variable_dataloader from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module from opensora.utils.ckpt_utils import model_sharding from opensora.utils.config_utils import merge_args, parse_configs -from opensora.utils.misc import ( - format_numel_str, - get_model_numel, - requires_grad, - to_torch_dtype, -) +from opensora.utils.misc import format_numel_str, get_model_numel, requires_grad, to_torch_dtype from opensora.utils.train_utils import MaskGenerator, update_ema @@ -66,9 +61,7 @@ class BColors: def parse_configs(): parser = argparse.ArgumentParser() parser.add_argument("config", help="model config file path") - parser.add_argument( - "-o", "--output", help="output config file path", default="output_config.py" - ) + parser.add_argument("-o", "--output", help="output config file path", default="output_config.py") parser.add_argument("--seed", default=42, type=int, help="generation seed") parser.add_argument( @@ -76,24 +69,14 @@ def parse_configs(): type=str, help="path to model ckpt; will overwrite cfg.ckpt_path if specified", ) - parser.add_argument( - "--data-path", default=None, type=str, help="path to data csv", required=True - ) + parser.add_argument("--data-path", default=None, type=str, help="path to data csv", required=True) parser.add_argument("--warmup-steps", default=1, type=int, help="warmup steps") parser.add_argument("--active-steps", default=1, type=int, help="active steps") - parser.add_argument( - "--base-resolution", default="240p", type=str, help="base resolution" - ) + parser.add_argument("--base-resolution", default="240p", type=str, help="base resolution") parser.add_argument("--base-frames", default=128, type=int, help="base frames") - parser.add_argument( - "--batch-size-start", default=2, type=int, help="batch size start" - ) - parser.add_argument( - "--batch-size-end", default=256, type=int, help="batch size end" - ) - parser.add_argument( - "--batch-size-step", default=2, type=int, help="batch size step" - ) + parser.add_argument("--batch-size-start", default=2, type=int, help="batch size start") + parser.add_argument("--batch-size-end", default=256, type=int, help="batch size end") + parser.add_argument("--batch-size-step", default=2, type=int, help="batch size step") args = parser.parse_args() cfg = Config.fromfile(args.config) cfg = merge_args(cfg, args, training=True) @@ -116,9 +99,7 @@ def main(): # ====================================================== cfg, args = parse_configs() print(cfg) - assert ( - cfg.dataset.type == "VariableVideoTextDataset" - ), "Only VariableVideoTextDataset is supported" + assert cfg.dataset.type == "VariableVideoTextDataset", "Only VariableVideoTextDataset is supported" # ====================================================== # 2. runtime variables & colossalai launch @@ -223,10 +204,7 @@ def main(): model_sharding(ema) buckets = [ - (res, f) - for res, d in cfg.bucket_config.items() - for f, (p, bs) in d.items() - if bs is not None and p > 0.0 + (res, f) for res, d in cfg.bucket_config.items() for f, (p, bs) in d.items() if bs is not None and p > 0.0 ] output_bucket_cfg = deepcopy(cfg.bucket_config) # find the base batch size @@ -248,15 +226,11 @@ def main(): optimizer, ema, ) - update_bucket_config_bs( - output_bucket_cfg, args.base_resolution, args.base_frames, base_batch_size - ) + update_bucket_config_bs(output_bucket_cfg, args.base_resolution, args.base_frames, base_batch_size) coordinator.print_on_master( f"{BColors.OKBLUE}Base resolution: {args.base_resolution}, Base frames: {args.base_frames}, Batch size: {base_batch_size}, Base step time: {base_step_time}{BColors.ENDC}" ) - result_table = [ - f"{args.base_resolution}, {args.base_frames}, {base_batch_size}, {base_step_time:.2f}" - ] + result_table = [f"{args.base_resolution}, {args.base_frames}, {base_batch_size}, {base_step_time:.2f}"] for resolution, frames in buckets: try: batch_size, step_time = benchmark( @@ -280,9 +254,7 @@ def main(): f"{BColors.OKBLUE}Resolution: {resolution}, Frames: {frames}, Batch size: {batch_size}, Step time: {step_time}{BColors.ENDC}" ) update_bucket_config_bs(output_bucket_cfg, resolution, frames, batch_size) - result_table.append( - f"{resolution}, {frames}, {batch_size}, {step_time:.2f}" - ) + result_table.append(f"{resolution}, {frames}, {batch_size}, {step_time:.2f}") except RuntimeError: pass result_table = "\n".join(result_table) @@ -367,10 +339,7 @@ def benchmark( raise RuntimeError("No valid batch size found") if target_step_time is None: # find the fastest batch size - throughputs = [ - batch_size / step_time - for step_time, batch_size in zip(step_times, batch_sizes) - ] + throughputs = [batch_size / step_time for step_time, batch_size in zip(step_times, batch_sizes)] max_throughput = max(throughputs) target_batch_size = batch_sizes[throughputs.index(max_throughput)] step_time = step_times[throughputs.index(max_throughput)] @@ -419,13 +388,9 @@ def train( **dataloader_args, ) dataloader_iter = iter(dataloader) - num_steps_per_epoch = ( - dataloader.batch_sampler.get_num_batch() // dist.get_world_size() - ) + num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size() - assert ( - num_steps_per_epoch >= total_steps - ), f"num_steps_per_epoch={num_steps_per_epoch} < total_steps={total_steps}" + assert num_steps_per_epoch >= total_steps, f"num_steps_per_epoch={num_steps_per_epoch} < total_steps={total_steps}" duration = 0 # this is essential for the first iteration after OOM optimizer._grad_store.reset_all_gradients()