[feat] image generation

This commit is contained in:
zhengzangw 2024-05-04 10:10:54 +00:00
parent b26beabeb7
commit 4543bc217c
11 changed files with 68 additions and 20 deletions

View file

@ -145,7 +145,7 @@ def read_from_path(path, image_size, transform_name="center"):
return read_image_from_path(path, image_size=image_size, transform_name=transform_name)
def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1), force_video=False):
def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1), force_video=False, verbose=True):
"""
Args:
x (Tensor): shape [C, T, H, W]
@ -165,7 +165,8 @@ def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1), f
x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
write_video(save_path, x, fps=fps, video_codec="h264")
print(f"Saved to {save_path}")
if verbose:
print(f"Saved to {save_path}")
return save_path

View file

@ -22,6 +22,7 @@ class DPM_SOLVER:
device,
additional_args=None,
mask=None,
progress=True,
):
assert mask is None, "mask is not supported in dpm-solver"
n = len(prompts)
@ -38,7 +39,14 @@ class DPM_SOLVER:
cfg_scale=self.cfg_scale,
model_kwargs=model_args,
)
samples = dpms.sample(z, steps=self.num_sampling_steps, order=2, skip_type="time_uniform", method="multistep")
samples = dpms.sample(
z,
steps=self.num_sampling_steps,
order=2,
skip_type="time_uniform",
method="multistep",
progress=progress,
)
return samples

View file

@ -1255,6 +1255,7 @@ class DPM_Solver:
atol=0.0078,
rtol=0.05,
return_intermediate=False,
progress=True,
):
"""
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
@ -1414,7 +1415,8 @@ class DPM_Solver:
t_prev_list.append(t)
model_prev_list.append(self.model_fn(x, t))
# Compute the remaining values by `order`-th order multistep DPM-Solver.
for step in tqdm(range(order, steps + 1)):
progress_fn = tqdm if progress else lambda x: x
for step in progress_fn(range(order, steps + 1)):
t = timesteps[step]
# We only use lower order for steps < 10
if lower_order_final and steps < 10:

View file

@ -61,6 +61,7 @@ class IDDPM(SpacedDiffusion):
device,
additional_args=None,
mask=None,
progress=True,
):
n = len(prompts)
z = torch.cat([z, z], 0)
@ -76,7 +77,7 @@ class IDDPM(SpacedDiffusion):
z,
clip_denoised=False,
model_kwargs=model_args,
progress=True,
progress=progress,
device=device,
mask=mask,
)

View file

@ -43,6 +43,7 @@ class RFLOW:
additional_args=None,
mask=None,
guidance_scale=None,
progress=True,
):
assert mask is None, "mask is not supported in rectified flow inference yet"
# if no specific guidance scale is provided, use the default scale when initializing the scheduler
@ -68,7 +69,8 @@ class RFLOW:
for t in timesteps
]
for i, t in tqdm(enumerate(timesteps)):
progress_wrap = tqdm if progress else (lambda x: x)
for i, t in progress_wrap(enumerate(timesteps)):
z_in = torch.cat([z, z], 0)
pred = model(z_in, torch.tensor([t] * z_in.shape[0], device=device), **model_args).chunk(2, dim=1)[0]
pred_cond, pred_uncond = pred.chunk(2, dim=0)

View file

@ -37,6 +37,7 @@ def parse_args(training=False):
parser.add_argument("--end-index", default=None, type=int, help="end index for sample name")
parser.add_argument("--num-sample", default=None, type=int, help="number of samples to generate for one prompt")
parser.add_argument("--prompt-as-path", action="store_true", help="use prompt as path to save samples")
parser.add_argument("--verbose", default=None, type=int, help="verbose level")
# prompt
parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file")

View file

@ -217,7 +217,7 @@ def main():
# ======================================================
# 4. inference
# ======================================================
sample_idx = 0
sample_idx = cfg.get("start_index", 0)
if cfg.sample_name is not None:
sample_name = cfg.sample_name
elif cfg.prompt_as_path:

View file

@ -5,6 +5,7 @@ import torch
import torch.distributed as dist
from colossalai.cluster import DistCoordinator
from mmengine.runner import set_random_seed
from tqdm import tqdm
from opensora.acceleration.parallel_states import set_sequence_parallel_group
from opensora.datasets import IMG_FPS, save_sample
@ -19,6 +20,7 @@ def main():
# 1. cfg and init distributed env
# ======================================================
cfg = parse_configs(training=False)
verbose = cfg.get("verbose", 2)
print(cfg)
# init distributed
@ -99,7 +101,7 @@ def main():
# ======================================================
# 4. inference
# ======================================================
sample_idx = 0
sample_idx = cfg.get("start_index", 0)
if cfg.sample_name is not None:
sample_name = cfg.sample_name
elif cfg.prompt_as_path:
@ -110,7 +112,8 @@ def main():
os.makedirs(save_dir, exist_ok=True)
# 4.1. batch generation
for i in range(0, len(prompts), cfg.batch_size):
progress_wrap = tqdm if verbose == 1 else (lambda x: x)
for i in progress_wrap(range(0, len(prompts), cfg.batch_size)):
# 4.2 sample in hidden space
batch_prompts_raw = prompts[i : i + cfg.batch_size]
batch_prompts = [text_preprocessing(prompt) for prompt in batch_prompts_raw]
@ -152,13 +155,15 @@ def main():
prompts=batch_prompts,
device=device,
additional_args=model_args,
progress=verbose >= 2,
)
samples = vae.decode(samples.to(dtype))
# 4.4. save samples
if not use_dist or coordinator.is_master():
for idx, sample in enumerate(samples):
print(f"Prompt: {batch_prompts_raw[idx]}")
if verbose >= 2:
print(f"Prompt: {batch_prompts_raw[idx]}")
if cfg.prompt_as_path:
sample_name_suffix = batch_prompts_raw[idx]
else:
@ -166,7 +171,12 @@ def main():
save_path = os.path.join(save_dir, f"{sample_name}{sample_name_suffix}")
if cfg.num_sample != 1:
save_path = f"{save_path}-{k}"
save_sample(sample, fps=cfg.fps // cfg.frame_interval, save_path=save_path)
save_sample(
sample,
fps=cfg.fps // cfg.frame_interval,
save_path=save_path,
verbose=verbose >= 2,
)
sample_idx += 1

21
scripts/misc/generate.sh Normal file
View file

@ -0,0 +1,21 @@
#!/bin/bash
# Launch one inference job per GPU, each generating a disjoint slice of the
# prompt file at a fixed resolution, logging to its own file.
set -x
set -e

TEXT_PATH=/home/data/sora_data/pixart-sigma-generated/text.txt
OUTPUT_PATH=/home/data/sora_data/pixart-sigma-generated/raw
CMD="python scripts/inference.py configs/pixart/inference/1x2048MS.py"
LOG_BASE=logs/sample/generate

NUM_PER_GPU=10000
N_LAUNCH=6
# Global offset: each launch round covers 8 GPUs' worth of prompts.
NUM_START=$(($N_LAUNCH * $NUM_PER_GPU * 8))

# One "H W" resolution per GPU index 0..7 (mix of square and portrait/landscape
# aspect ratios; index 7 repeats 2048x2048 like index 0).
IMAGE_SIZES=("2048 2048" "1408 2816" "2816 1408" "1664 2304" "2304 1664" "1536 2560" "2560 1536" "2048 2048")

for i in "${!IMAGE_SIZES[@]}"; do
    START=$(($NUM_START + $NUM_PER_GPU * $i))
    END=$(($NUM_START + $NUM_PER_GPU * ($i + 1)))
    # Log suffix is 1-based to match GPU i's historical naming (_1 .. _8).
    CUDA_VISIBLE_DEVICES=$i $CMD --prompt-path $TEXT_PATH --save-dir $OUTPUT_PATH --start-index $START --end-index $END --image-size ${IMAGE_SIZES[$i]} --verbose 1 --batch-size 2 >${LOG_BASE}_${N_LAUNCH}_$(($i + 1)).log 2>&1 &
done

View file

@ -7,7 +7,7 @@ Numpy modules for Net2Net
Written by Kyunghyun Paeng
"""
import numpy as np
def net2net(teach_param, stu_param):
# teach param with shape (a, b)
@ -24,26 +24,29 @@ def net2net(teach_param, stu_param):
assert len(teach_param_shape) == len(stu_param_shape), "teach_param and stu_param must have same dimension"
if len(teach_param_shape) == 1:
stu_param[:teach_param_shape[0]] = teach_param
stu_param[: teach_param_shape[0]] = teach_param
elif len(teach_param_shape) == 2:
stu_param[:teach_param_shape[0], :teach_param_shape[1]] = teach_param
stu_param[: teach_param_shape[0], : teach_param_shape[1]] = teach_param
else:
breakpoint()
if stu_param.shape != stu_param_shape:
stu_param = stu_param.reshape(stu_param_shape)
return stu_param
if __name__ == '__main__':
""" Net2Net Class Test """
from opensora.models.pixart import PixArt_Sigma_XL_2, PixArt_1B_2
if __name__ == "__main__":
"""Net2Net Class Test"""
import torch
from opensora.models.pixart import PixArt_1B_2
model = PixArt_1B_2(no_temporal_pos_emb=True, space_scale=4, enable_flashattn=True, enable_layernorm_kernel=True)
print("load model done")
ckpt = torch.load('/home/zhouyukun/projs/opensora/pretrained_models/PixArt-Sigma-XL-2-2K-MS.pth')
ckpt = torch.load("/home/zhouyukun/projs/opensora/pretrained_models/PixArt-Sigma-XL-2-2K-MS.pth")
print("load ckpt done")
ckpt = ckpt["state_dict"]

View file

@ -1 +0,0 @@
colossalai run --nproc_per_node 8 scripts/train.py configs/pixart/train/1x2048x2048.py --data-path /home/zhaowangbo/data/csv/image-v1_1_ext_noempty_rcp_clean_info.csv