mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-12 05:46:22 +02:00
[feat] update eval
This commit is contained in:
parent
79dabf8bdf
commit
8cb6d2f0bd
|
|
@ -30,9 +30,12 @@ def parse_args(training=False):
|
|||
# Inference
|
||||
# ======================================================
|
||||
if not training:
|
||||
# output
|
||||
parser.add_argument("--save-dir", default=None, type=str, help="path to save generated samples")
|
||||
parser.add_argument("--sample-name", default=None, type=str, help="sample name, default is sample_idx")
|
||||
|
||||
# prompt
|
||||
parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file")
|
||||
parser.add_argument("--save-dir", default=None, type=str, help="path to save generated samples")
|
||||
|
||||
# image/video
|
||||
parser.add_argument("--num-frames", default=None, type=int, help="number of frames")
|
||||
|
|
@ -71,6 +74,8 @@ def merge_args(cfg, args, training=False):
|
|||
if "prompt" not in cfg or cfg["prompt"] is None:
|
||||
assert cfg["prompt_path"] is not None, "prompt or prompt_path must be provided"
|
||||
cfg["prompt"] = load_prompts(cfg["prompt_path"])
|
||||
if "sample_name" not in cfg:
|
||||
cfg["sample_name"] = None
|
||||
else:
|
||||
# Training only
|
||||
if args.data_path is not None:
|
||||
|
|
|
|||
|
|
@ -22,13 +22,18 @@ def main():
|
|||
print(cfg)
|
||||
|
||||
# init distributed
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
if os.environ.get("WORLD_SIZE", None):
|
||||
use_dist = True
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
if coordinator.world_size > 1:
|
||||
set_sequence_parallel_group(dist.group.WORLD)
|
||||
enable_sequence_parallelism = True
|
||||
if coordinator.world_size > 1:
|
||||
set_sequence_parallel_group(dist.group.WORLD)
|
||||
enable_sequence_parallelism = True
|
||||
else:
|
||||
enable_sequence_parallelism = False
|
||||
else:
|
||||
use_dist = False
|
||||
enable_sequence_parallelism = False
|
||||
|
||||
# ======================================================
|
||||
|
|
@ -91,6 +96,7 @@ def main():
|
|||
# 4. inference
|
||||
# ======================================================
|
||||
sample_idx = 0
|
||||
sample_name = cfg.sample_name if cfg.sample_name is not None else "sample"
|
||||
save_dir = cfg.save_dir
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
|
|
@ -112,10 +118,10 @@ def main():
|
|||
)
|
||||
samples = vae.decode(samples.to(dtype))
|
||||
|
||||
if coordinator.is_master():
|
||||
if not use_dist or coordinator.is_master():
|
||||
for idx, sample in enumerate(samples):
|
||||
print(f"Prompt: {batch_prompts[idx]}")
|
||||
save_path = os.path.join(save_dir, f"sample_{sample_idx}")
|
||||
save_path = os.path.join(save_dir, f"{sample_name}_{sample_idx}")
|
||||
save_sample(sample, fps=cfg.fps, save_path=save_path)
|
||||
sample_idx += 1
|
||||
|
||||
|
|
|
|||
16
scripts/misc/sample.sh
Normal file
16
scripts/misc/sample.sh
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
set -x;
|
||||
|
||||
CUDA_VISIBLE_DEVICES=7
|
||||
CMD="python scripts/inference.py configs/opensora-v1-1/inference/sample.py"
|
||||
CKPT="~/lishenggui/epoch0-global_step8500"
|
||||
OUTPUT="./outputs/samples_s1_8500"
|
||||
|
||||
# 1. image
|
||||
# 1.1 1024x1024
|
||||
eval $CMD --ckpt-path $CKPT --prompt-path assets/texts/t2i_samples.txt --save-dir $OUTPUT --num-frames 1 --image-size 1024 1024 --sample-name pixart_1024x1024_1
|
||||
|
||||
# 1.2 512x512
|
||||
|
||||
# 1.3 240x426
|
||||
|
||||
# 1.4 720p multi-resolution
|
||||
|
|
@ -25,12 +25,7 @@ from opensora.datasets import prepare_variable_dataloader
|
|||
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.ckpt_utils import model_sharding
|
||||
from opensora.utils.config_utils import merge_args, parse_configs
|
||||
from opensora.utils.misc import (
|
||||
format_numel_str,
|
||||
get_model_numel,
|
||||
requires_grad,
|
||||
to_torch_dtype,
|
||||
)
|
||||
from opensora.utils.misc import format_numel_str, get_model_numel, requires_grad, to_torch_dtype
|
||||
from opensora.utils.train_utils import MaskGenerator, update_ema
|
||||
|
||||
|
||||
|
|
@ -66,9 +61,7 @@ class BColors:
|
|||
def parse_configs():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("config", help="model config file path")
|
||||
parser.add_argument(
|
||||
"-o", "--output", help="output config file path", default="output_config.py"
|
||||
)
|
||||
parser.add_argument("-o", "--output", help="output config file path", default="output_config.py")
|
||||
|
||||
parser.add_argument("--seed", default=42, type=int, help="generation seed")
|
||||
parser.add_argument(
|
||||
|
|
@ -76,24 +69,14 @@ def parse_configs():
|
|||
type=str,
|
||||
help="path to model ckpt; will overwrite cfg.ckpt_path if specified",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-path", default=None, type=str, help="path to data csv", required=True
|
||||
)
|
||||
parser.add_argument("--data-path", default=None, type=str, help="path to data csv", required=True)
|
||||
parser.add_argument("--warmup-steps", default=1, type=int, help="warmup steps")
|
||||
parser.add_argument("--active-steps", default=1, type=int, help="active steps")
|
||||
parser.add_argument(
|
||||
"--base-resolution", default="240p", type=str, help="base resolution"
|
||||
)
|
||||
parser.add_argument("--base-resolution", default="240p", type=str, help="base resolution")
|
||||
parser.add_argument("--base-frames", default=128, type=int, help="base frames")
|
||||
parser.add_argument(
|
||||
"--batch-size-start", default=2, type=int, help="batch size start"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size-end", default=256, type=int, help="batch size end"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size-step", default=2, type=int, help="batch size step"
|
||||
)
|
||||
parser.add_argument("--batch-size-start", default=2, type=int, help="batch size start")
|
||||
parser.add_argument("--batch-size-end", default=256, type=int, help="batch size end")
|
||||
parser.add_argument("--batch-size-step", default=2, type=int, help="batch size step")
|
||||
args = parser.parse_args()
|
||||
cfg = Config.fromfile(args.config)
|
||||
cfg = merge_args(cfg, args, training=True)
|
||||
|
|
@ -116,9 +99,7 @@ def main():
|
|||
# ======================================================
|
||||
cfg, args = parse_configs()
|
||||
print(cfg)
|
||||
assert (
|
||||
cfg.dataset.type == "VariableVideoTextDataset"
|
||||
), "Only VariableVideoTextDataset is supported"
|
||||
assert cfg.dataset.type == "VariableVideoTextDataset", "Only VariableVideoTextDataset is supported"
|
||||
|
||||
# ======================================================
|
||||
# 2. runtime variables & colossalai launch
|
||||
|
|
@ -223,10 +204,7 @@ def main():
|
|||
model_sharding(ema)
|
||||
|
||||
buckets = [
|
||||
(res, f)
|
||||
for res, d in cfg.bucket_config.items()
|
||||
for f, (p, bs) in d.items()
|
||||
if bs is not None and p > 0.0
|
||||
(res, f) for res, d in cfg.bucket_config.items() for f, (p, bs) in d.items() if bs is not None and p > 0.0
|
||||
]
|
||||
output_bucket_cfg = deepcopy(cfg.bucket_config)
|
||||
# find the base batch size
|
||||
|
|
@ -248,15 +226,11 @@ def main():
|
|||
optimizer,
|
||||
ema,
|
||||
)
|
||||
update_bucket_config_bs(
|
||||
output_bucket_cfg, args.base_resolution, args.base_frames, base_batch_size
|
||||
)
|
||||
update_bucket_config_bs(output_bucket_cfg, args.base_resolution, args.base_frames, base_batch_size)
|
||||
coordinator.print_on_master(
|
||||
f"{BColors.OKBLUE}Base resolution: {args.base_resolution}, Base frames: {args.base_frames}, Batch size: {base_batch_size}, Base step time: {base_step_time}{BColors.ENDC}"
|
||||
)
|
||||
result_table = [
|
||||
f"{args.base_resolution}, {args.base_frames}, {base_batch_size}, {base_step_time:.2f}"
|
||||
]
|
||||
result_table = [f"{args.base_resolution}, {args.base_frames}, {base_batch_size}, {base_step_time:.2f}"]
|
||||
for resolution, frames in buckets:
|
||||
try:
|
||||
batch_size, step_time = benchmark(
|
||||
|
|
@ -280,9 +254,7 @@ def main():
|
|||
f"{BColors.OKBLUE}Resolution: {resolution}, Frames: {frames}, Batch size: {batch_size}, Step time: {step_time}{BColors.ENDC}"
|
||||
)
|
||||
update_bucket_config_bs(output_bucket_cfg, resolution, frames, batch_size)
|
||||
result_table.append(
|
||||
f"{resolution}, {frames}, {batch_size}, {step_time:.2f}"
|
||||
)
|
||||
result_table.append(f"{resolution}, {frames}, {batch_size}, {step_time:.2f}")
|
||||
except RuntimeError:
|
||||
pass
|
||||
result_table = "\n".join(result_table)
|
||||
|
|
@ -367,10 +339,7 @@ def benchmark(
|
|||
raise RuntimeError("No valid batch size found")
|
||||
if target_step_time is None:
|
||||
# find the fastest batch size
|
||||
throughputs = [
|
||||
batch_size / step_time
|
||||
for step_time, batch_size in zip(step_times, batch_sizes)
|
||||
]
|
||||
throughputs = [batch_size / step_time for step_time, batch_size in zip(step_times, batch_sizes)]
|
||||
max_throughput = max(throughputs)
|
||||
target_batch_size = batch_sizes[throughputs.index(max_throughput)]
|
||||
step_time = step_times[throughputs.index(max_throughput)]
|
||||
|
|
@ -419,13 +388,9 @@ def train(
|
|||
**dataloader_args,
|
||||
)
|
||||
dataloader_iter = iter(dataloader)
|
||||
num_steps_per_epoch = (
|
||||
dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
|
||||
)
|
||||
num_steps_per_epoch = dataloader.batch_sampler.get_num_batch() // dist.get_world_size()
|
||||
|
||||
assert (
|
||||
num_steps_per_epoch >= total_steps
|
||||
), f"num_steps_per_epoch={num_steps_per_epoch} < total_steps={total_steps}"
|
||||
assert num_steps_per_epoch >= total_steps, f"num_steps_per_epoch={num_steps_per_epoch} < total_steps={total_steps}"
|
||||
duration = 0
|
||||
# this is essential for the first iteration after OOM
|
||||
optimizer._grad_store.reset_all_gradients()
|
||||
Loading…
Reference in a new issue