update inference z

This commit is contained in:
Zangwei Zheng 2024-03-23 20:28:34 +08:00
parent 7d27f5553e
commit 98e62a7c57
8 changed files with 254 additions and 37 deletions

View file

@ -18,11 +18,12 @@ vae = dict(
)
text_encoder = dict(
type="t5",
from_pretrained="./pretrained_models/t5_ckpts",
from_pretrained="DeepFloyd/t5-v1_1-xxl",
model_max_length=120,
)
scheduler = dict(
type="iddpm",
# type="dpm-solver",
num_sampling_steps=100,
cfg_scale=7.0,
)

View file

@ -1,2 +1,2 @@
from .datasets import DatasetFromCSV, get_transforms_image, get_transforms_video
from .utils import prepare_dataloader, save_sample
from .datasets import DatasetFromCSV
from .utils import get_transforms_image, get_transforms_video, prepare_dataloader, save_sample

View file

@ -4,35 +4,10 @@ import os
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from . import video_transforms
from .utils import center_crop_arr, VID_EXTENSIONS
def get_transforms_video(resolution=256):
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
# video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(resolution),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform_video
def get_transforms_image(image_size=256):
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
# transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform
from .utils import VID_EXTENSIONS
class DatasetFromCSV(torch.utils.data.Dataset):

View file

@ -3,17 +3,46 @@ from typing import Iterator, Optional
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from torchvision.io import write_video
from torchvision.utils import save_image
from . import video_transforms
VID_EXTENSIONS = ("mp4", "avi", "mov", "mkv")
def get_transforms_video(resolution=256):
transform_video = transforms.Compose(
[
video_transforms.ToTensorVideo(), # TCHW
# video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(resolution),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform_video
def get_transforms_image(image_size=256):
transform = transforms.Compose(
[
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
# transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
return transform
def read_image_from_path(path, transform=None, num_frames=1, image_size=256):
image = pil_loader(path)
if transform is None:

View file

@ -17,13 +17,12 @@ class DMP_SOLVER:
self,
model,
text_encoder,
z_size,
z,
prompts,
device,
additional_args=None,
):
n = len(prompts)
z = torch.randn(n, *z_size, device=device)
model_args = text_encoder.encode(prompts)
y = model_args.pop("y")
null_y = text_encoder.null(n)

View file

@ -54,13 +54,12 @@ class IDDPM(SpacedDiffusion):
self,
model,
text_encoder,
z_size,
z,
prompts,
device,
additional_args=None,
):
n = len(prompts)
z = torch.randn(n, *z_size, device=device)
z = torch.cat([z, z], 0)
model_args = text_encoder.encode(prompts)
y_null = text_encoder.null(n)

View file

@ -1,16 +1,16 @@
import os
import torch
import colossalai
import torch
import torch.distributed as dist
from colossalai.cluster import DistCoordinator
from mmengine.runner import set_random_seed
from opensora.acceleration.parallel_states import set_sequence_parallel_group
from opensora.datasets import save_sample
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype
from opensora.acceleration.parallel_states import set_sequence_parallel_group
from colossalai.cluster import DistCoordinator
def main():
@ -82,12 +82,18 @@ def main():
sample_idx = 0
save_dir = cfg.save_dir
os.makedirs(save_dir, exist_ok=True)
# 4.1. batch generation
for i in range(0, len(prompts), cfg.batch_size):
# 4.2 sample in hidden space
batch_prompts = prompts[i : i + cfg.batch_size]
z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)
# 4.3. diffusion sampling
samples = scheduler.sample(
model,
text_encoder,
z_size=(vae.out_channels, *latent_size),
z=z,
prompts=batch_prompts,
device=device,
additional_args=model_args,

View file

@ -0,0 +1,208 @@
import os
import colossalai
import torch
import torch.distributed as dist
from colossalai.cluster import DistCoordinator
from mmengine.runner import set_random_seed
from opensora.acceleration.parallel_states import set_sequence_parallel_group
from opensora.datasets import save_sample
from opensora.datasets.utils import read_from_path
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype
def collect_references_batch(reference_paths, vae, image_size):
refs_x = []
for reference_path in reference_paths:
ref_path = reference_path.split(";")
ref = []
for r_path in ref_path:
r = read_from_path(r_path, image_size)
r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype))
r_x = r_x.squeeze(0)
ref.append(r_x)
refs_x.append(ref)
# refs_x: [batch, ref_num, C, T, H, W]
return refs_x
def apply_mask_strategy(z, refs_x, mask_strategys, loop_i):
masks = []
for i, mask_strategy in enumerate(mask_strategys):
mask_strategy = mask_strategy.split(";")
mask = torch.ones(z.shape[2], dtype=torch.bool, device=z.device)
for mst in mask_strategy:
loop_id, m_id, m_ref_start, m_length, m_target_start = mst.split(",")
loop_id = int(loop_id)
if loop_id != loop_i:
continue
m_id = int(m_id)
m_ref_start = int(m_ref_start)
m_length = int(m_length)
m_target_start = int(m_target_start)
ref = refs_x[i][m_id] # [C, T, H, W]
if m_ref_start < 0:
m_ref_start = ref.shape[1] + m_ref_start
if m_target_start < 0:
# z: [B, C, T, H, W]
m_target_start = z.shape[2] + m_target_start
z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length]
mask[m_target_start : m_target_start + m_length] = 0
masks.append(mask)
masks = torch.stack(masks)
return masks
def process_prompts(prompts, num_loop):
ret_prompts = []
for prompt in prompts:
if prompt.startswith("|0|"):
prompt_list = prompt.split("|")[1:]
text_list = []
for i in range(0, len(prompt_list), 2):
start_loop = int(prompt_list[i])
text = prompt_list[i + 1]
end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop
text_list.extend([text] * (end_loop - start_loop))
assert len(text_list) == num_loop
ret_prompts.append(text_list)
else:
ret_prompts.append([prompt] * num_loop)
return ret_prompts
def main():
# ======================================================
# 1. cfg and init distributed env
# ======================================================
cfg = parse_configs(training=False)
print(cfg)
# init distributed
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
if coordinator.world_size > 1:
set_sequence_parallel_group(dist.group.WORLD)
enable_sequence_parallelism = True
else:
enable_sequence_parallelism = False
# ======================================================
# 2. runtime variables
# ======================================================
torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = to_torch_dtype(cfg.dtype)
set_random_seed(seed=cfg.seed)
prompts = cfg.prompt
# ======================================================
# 3. build model & load weights
# ======================================================
# 3.1. build model
input_size = (cfg.num_frames, *cfg.image_size)
vae = build_module(cfg.vae, MODELS)
latent_size = vae.get_latent_size(input_size)
text_encoder = build_module(cfg.text_encoder, MODELS, device=device) # T5 must be fp32
model = build_module(
cfg.model,
MODELS,
input_size=latent_size,
in_channels=vae.out_channels,
caption_channels=text_encoder.output_dim,
model_max_length=text_encoder.model_max_length,
dtype=dtype,
enable_sequence_parallelism=enable_sequence_parallelism,
)
text_encoder.y_embedder = model.y_embedder # hack for classifier-free guidance
# 3.2. move to device & eval
vae = vae.to(device, dtype).eval()
model = model.to(device, dtype).eval()
# 3.3. build scheduler
scheduler = build_module(cfg.scheduler, SCHEDULERS)
# 3.4. support for multi-resolution
model_args = dict()
if cfg.multi_resolution:
image_size = cfg.image_size
hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
model_args["data_info"] = dict(ar=ar, hw=hw)
# 3.5 reference
if cfg.reference_path is not None:
assert len(cfg.reference_path) == len(prompts)
assert len(cfg.reference_path) == len(cfg.mask_strategy)
# ======================================================
# 4. inference
# ======================================================
sample_idx = 0
save_dir = cfg.save_dir
os.makedirs(save_dir, exist_ok=True)
# 4.1. batch generation
for i in range(0, len(prompts), cfg.batch_size):
batch_prompts_loops = process_prompts(prompts[i : i + cfg.batch_size], cfg.loop)
video_clips = []
# 4.2. load reference videos & images
if cfg.reference_path is not None:
refs_x = collect_references_batch(cfg.reference_path[i : i + cfg.batch_size], vae, cfg.image_size[0])
mask_strategy = cfg.mask_strategy[i : i + cfg.batch_size]
# 4.3. long video generation
for loop_i in range(cfg.loop):
# 4.4 sample in hidden space
batch_prompts = [prompt[loop_i] for prompt in batch_prompts_loops]
z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)
# 4.5. apply mask strategy
masks = None
if cfg.reference_path is not None:
if loop_i > 0:
ref_x = vae.encode(video_clips[-1])
for j, refs in enumerate(refs_x):
refs.append(ref_x[j])
mask_strategy[
j
] += f";{loop_i},{len(refs)-1},-{cfg.condition_frame_length},{cfg.condition_frame_length},0"
masks = apply_mask_strategy(z, refs_x, mask_strategy, loop_i)
# 4.6. diffusion sampling
samples = scheduler.sample(
model,
text_encoder,
z=z,
prompts=batch_prompts,
device=device,
additional_args=model_args,
mask=masks, # scheduler must support mask
)
samples = vae.decode(samples.to(dtype))
video_clips.append(samples)
# 4.7. save video
if loop_i == cfg.loop - 1:
if coordinator.is_master():
for idx in range(len(video_clips[0])):
video_clips_i = [video_clips[0][idx]] + [
video_clips[i][idx][:, cfg.condition_frame_length :] for i in range(1, cfg.loop)
]
video = torch.cat(video_clips_i, dim=1)
print(f"Prompt: {prompts[i + idx]}")
save_path = os.path.join(save_dir, f"sample_{sample_idx}")
save_sample(video, fps=cfg.fps, save_path=save_path)
sample_idx += 1
if __name__ == "__main__":
main()