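"""Batch text-to-video sampling script for Open-Sora.

A typical launch looks like the following (the config and checkpoint paths
are illustrative, not fixed by this file):

    torchrun --standalone --nproc_per_node 1 scripts/inference.py \
        configs/opensora/inference/16x256x256.py --ckpt-path <CKPT_PATH>
"""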
import os

import colossalai
import torch
import torch.distributed as dist
from colossalai.cluster import DistCoordinator
from mmengine.runner import set_random_seed

from opensora.acceleration.parallel_states import set_sequence_parallel_group
from opensora.datasets import save_sample
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype


def main():
    # ======================================================
    # 1. cfg and init distributed env
    # ======================================================
    cfg = parse_configs(training=False)
    print(cfg)

    # init distributed
    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()

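    # With more than one process, all ranks join a single sequence-parallel
    # group: the video's token sequence is sharded across ranks inside the
    # model's attention layers, so every rank cooperates on the same batch.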
    if coordinator.world_size > 1:
        set_sequence_parallel_group(dist.group.WORLD)
        enable_sequence_parallelism = True
    else:
        enable_sequence_parallelism = False

    # ======================================================
    # 2. runtime variables
    # ======================================================
    torch.set_grad_enabled(False)
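    # TF32 trades a few mantissa bits for large matmul/conv speedups on
    # Ampere and newer GPUs; sampling quality is typically unaffected.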
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = to_torch_dtype(cfg.dtype)
    set_random_seed(seed=cfg.seed)
    prompts = cfg.prompt

    # ======================================================
    # 3. build model & load weights
    # ======================================================
    # 3.1. build model
    input_size = (cfg.num_frames, *cfg.image_size)
    vae = build_module(cfg.vae, MODELS)
    latent_size = vae.get_latent_size(input_size)
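    # The diffusion transformer operates in the VAE's latent space, so its
    # input size is the (T, H, W) video shape after VAE downsampling.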
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)  # T5 must be fp32
    model = build_module(
        cfg.model,
        MODELS,
        input_size=latent_size,
        in_channels=vae.out_channels,
        caption_channels=text_encoder.output_dim,
        model_max_length=text_encoder.model_max_length,
        dtype=dtype,
        enable_sequence_parallelism=enable_sequence_parallelism,
    )
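    # The text encoder borrows the diffusion model's learned null-caption
    # embedding (y_embedder) to build the unconditional branch used by
    # classifier-free guidance.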
    text_encoder.y_embedder = model.y_embedder  # hack for classifier-free guidance

    # 3.2. move to device & eval
    vae = vae.to(device, dtype).eval()
    model = model.to(device, dtype).eval()

    # 3.3. build scheduler
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

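    # Multi-resolution checkpoints expect extra conditioning: PixArtMS-style
    # models take a data_info dict of (hw, ar), while STDiT2 takes separate
    # height/width/num_frames/ar tensors, one entry per batch element.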
    # 3.4. support for multi-resolution
    model_args = dict()
    if cfg.multi_resolution == "PixArtMS":
        image_size = cfg.image_size
        hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
        ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
        model_args["data_info"] = dict(ar=ar, hw=hw)
    elif cfg.multi_resolution == "STDiT2":
        image_size = cfg.image_size
        height = torch.tensor([image_size[0]], device=device, dtype=dtype).repeat(cfg.batch_size)
        width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size)
        num_frames = torch.tensor([cfg.num_frames], device=device, dtype=dtype).repeat(cfg.batch_size)
        ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size)
        model_args["height"] = height
        model_args["width"] = width
        model_args["num_frames"] = num_frames
        model_args["ar"] = ar

    # ======================================================
    # 4. inference
    # ======================================================
    sample_idx = 0
    save_dir = cfg.save_dir
    os.makedirs(save_dir, exist_ok=True)

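    # Prompts are consumed in chunks of cfg.batch_size; the last chunk may
    # be smaller, which is why z is shaped by len(batch_prompts) below.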
    # 4.1. batch generation
    for i in range(0, len(prompts), cfg.batch_size):
        # 4.2. sample in hidden space
        batch_prompts = prompts[i : i + cfg.batch_size]
        z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)

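        # Starting from Gaussian noise z in latent space, scheduler.sample
        # runs the reverse-diffusion loop conditioned on the prompts, and
        # vae.decode maps the denoised latents back to pixel space.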
        # 4.3. diffusion sampling
        samples = scheduler.sample(
            model,
            text_encoder,
            z=z,
            prompts=batch_prompts,
            device=device,
            additional_args=model_args,
        )
        samples = vae.decode(samples.to(dtype))

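        # Under sequence parallelism every rank ends up with the same decoded
        # samples, so only the master rank writes them to disk.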
        if coordinator.is_master():
            for idx, sample in enumerate(samples):
                print(f"Prompt: {batch_prompts[idx]}")
                save_path = os.path.join(save_dir, f"sample_{sample_idx}")
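                # save_sample chooses the container itself (an image for a
                # single frame, a video otherwise), so save_path carries no
                # file extension.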
                save_sample(sample, fps=cfg.fps, save_path=save_path)
                sample_idx += 1


if __name__ == "__main__":
    main()