mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-10 21:01:26 +02:00
merge video edit
This commit is contained in:
parent
341b12f9bf
commit
11f1822cd2
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -174,6 +174,7 @@ samples
|
|||
logs/
|
||||
pretrained_models/
|
||||
evaluation_results/
|
||||
cache/
|
||||
*.swp
|
||||
|
||||
# Secret files
|
||||
|
|
|
|||
|
|
@ -3,3 +3,4 @@ In an ornate, historical hall, a massive tidal wave peaks and begins to crash. T
|
|||
Pirate ship in a cosmic maelstrom nebula.
|
||||
Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff's edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.
|
||||
A sad small cactus with in the Sahara desert becomes happy.
|
||||
A car driving on a road in the middle of a desert.
|
||||
|
|
|
|||
|
|
@ -6,19 +6,28 @@ multi_resolution = "STDiT2"
|
|||
|
||||
# Condition
|
||||
prompt_path = None
|
||||
prompt = None
|
||||
prompt = [
|
||||
"A car driving on a road in the middle of a desert.",
|
||||
# "A man smiling",
|
||||
# "Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff's edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.",
|
||||
# "In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave.",
|
||||
]
|
||||
|
||||
loop = 10
|
||||
loop = 1
|
||||
condition_frame_length = 4
|
||||
reference_path = [
|
||||
"assets/images/condition/cliff.png",
|
||||
"assets/images/condition/wave.png",
|
||||
"https://cdn.openai.com/tmp/s/interp/d0.mp4",
|
||||
# "https://www.comp.nus.edu.sg/~youy/index_files/yangyou3.png",
|
||||
# "assets/images/condition/cliff.png",
|
||||
# "assets/images/condition/wave.png",
|
||||
]
|
||||
# valid when reference_path is not None
|
||||
# (loop id, ref id, ref start, length, target start)
|
||||
mask_strategy = [
|
||||
"0,0,0,1,0",
|
||||
"0,0,0,1,0",
|
||||
"0,0,0,8,0,0.5",
|
||||
# "0,0,0,1,0",
|
||||
# "0,0,0,1,0",
|
||||
# "0,0,0,1,0",
|
||||
]
|
||||
|
||||
# Define model
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ text_encoder = dict(
|
|||
)
|
||||
scheduler = dict(
|
||||
type="iddpm",
|
||||
# type="dpm-solver",
|
||||
num_sampling_steps=100,
|
||||
cfg_scale=7.0,
|
||||
)
|
||||
|
|
@ -33,31 +32,16 @@ dtype = "fp16"
|
|||
# Condition
|
||||
prompt_path = None
|
||||
prompt = [
|
||||
# "Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff's edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.",
|
||||
"In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two surfers, seizing the moment, skillfully navigate the face of the wave.",
|
||||
# "put the video in space with a rainbow road",
|
||||
# "make it have dinosaurs",
|
||||
# "make it in claymation animation style",
|
||||
# "make it go underwater"
|
||||
"A car driving on a road in the middle of a desert.",
|
||||
]
|
||||
|
||||
loop = 1
|
||||
condition_frame_length = 4
|
||||
reference_path = [
|
||||
# "assets/images/condition/cliff.png",
|
||||
"assets/images/condition/wave.png",
|
||||
# "/home/zhouyukun/Open-Sora-dev/assets/videos/base.mp4",
|
||||
# "/home/zhouyukun/Open-Sora-dev/assets/videos/base.mp4",
|
||||
# "/home/zhouyukun/Open-Sora-dev/assets/videos/base.mp4",
|
||||
# "/home/zhouyukun/Open-Sora-dev/assets/videos/base.mp4",
|
||||
"https://cdn.openai.com/tmp/s/interp/d0.mp4",
|
||||
]
|
||||
mask_strategy = [
|
||||
# "0,0,0,1,0,0",
|
||||
"0,0,0,1,0,0",
|
||||
# "0,0,0,12,0,0", # 噪声率
|
||||
# "0,0,0,12,0,0", # 噪声率
|
||||
# "0,0,0,12,0,0", # 噪声率
|
||||
# "0,0,0,12,0,0", # 噪声率
|
||||
] # valid when reference_path is not None
|
||||
# (loop id, ref id, ref start, length, target start)
|
||||
|
||||
|
|
|
|||
|
|
@ -120,11 +120,11 @@ function run_video_edit() { # 23min
|
|||
|
||||
# 3.2
|
||||
eval $CMD_REF --ckpt-path $CKPT --save-dir $OUTPUT --sample-name ref_L1_128x240x426 \
|
||||
--prompt-path assets/texts/t2v_ref.txt --start-index 3 --end-index 5 \
|
||||
--prompt-path assets/texts/t2v_ref.txt --start-index 3 --end-index 6 \
|
||||
--num-frames 128 --image-size 240 426 \
|
||||
--loop 1 \
|
||||
--reference-path assets/images/condition/cliff.png "assets/images/condition/cactus-sad.png\;assets/images/condition/cactus-happy.png" \
|
||||
--mask-strategy "0,0,0,1,0\;0,0,0,1,-1" "0,0,0,1,0\;0,1,0,1,-1"
|
||||
--reference-path assets/images/condition/cliff.png "assets/images/condition/cactus-sad.png\;assets/images/condition/cactus-happy.png" https://cdn.openai.com/tmp/s/interp/d0.mp4 \
|
||||
--mask-strategy "0,0,0,1,0\;0,0,0,1,-1" "0,0,0,1,0\;0,1,0,1,-1" "0,0,0,64,0,0.5"
|
||||
}
|
||||
|
||||
# vbench has 950 samples
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import os
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import requests
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
|
@ -14,6 +16,20 @@ from . import video_transforms
|
|||
|
||||
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
|
||||
|
||||
regex = re.compile(
|
||||
r"^(?:http|ftp)s?://" # http:// or https://
|
||||
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # domain...
|
||||
r"localhost|" # localhost...
|
||||
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # ...or ip
|
||||
r"(?::\d+)?" # optional port
|
||||
r"(?:/?|[/?]\S+)$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def is_url(url):
|
||||
return re.match(regex, url) is not None
|
||||
|
||||
|
||||
def read_file(input_path):
|
||||
if input_path.endswith(".csv"):
|
||||
|
|
@ -24,6 +40,19 @@ def read_file(input_path):
|
|||
raise NotImplementedError(f"Unsupported file format: {input_path}")
|
||||
|
||||
|
||||
def download_url(input_path):
|
||||
output_dir = "cache"
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
base_name = os.path.basename(input_path)
|
||||
output_path = os.path.join(output_dir, base_name)
|
||||
img_data = requests.get(input_path).content
|
||||
with open(output_path, "wb") as handler:
|
||||
handler.write(img_data)
|
||||
print(f"URL {input_path} downloaded to {output_path}")
|
||||
return output_path
|
||||
|
||||
|
||||
def temporal_random_crop(vframes, num_frames, frame_interval):
|
||||
temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
|
||||
total_frames = len(vframes)
|
||||
|
|
@ -106,6 +135,8 @@ def read_video_from_path(path, transform=None, transform_name="center", image_si
|
|||
|
||||
|
||||
def read_from_path(path, image_size, transform_name="center"):
|
||||
if is_url(path):
|
||||
path = download_url(path)
|
||||
ext = os.path.splitext(path)[-1].lower()
|
||||
if ext.lower() in VID_EXTENSIONS:
|
||||
return read_video_from_path(path, image_size=image_size, transform_name=transform_name)
|
||||
|
|
|
|||
|
|
@ -408,15 +408,21 @@ class GaussianDiffusion:
|
|||
if mask is not None:
|
||||
if mask.shape[0] != x.shape[0]:
|
||||
mask = mask.repeat(2, 1) # HACK
|
||||
# copy unchanged x values to x0
|
||||
mask_t = (mask * len(self.betas) - 1).to(torch.int)
|
||||
|
||||
# x0: copy unchanged x values
|
||||
# x_noise: add noise to x values
|
||||
x0 = x.clone()
|
||||
mask_t = (mask * len(self.betas)).to(torch.int)
|
||||
mask_t_equall = (mask_t == t.unsqueeze(1))[:, None, :, None, None]
|
||||
mask_t_upper = (mask_t > t.unsqueeze(1))[:, None, :, None, None]
|
||||
x_noise = x0 * _extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) + torch.randn_like(x) * _extract_into_tensor(
|
||||
self.sqrt_one_minus_alphas_cumprod, t, x.shape)
|
||||
x_noise = x0 * _extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) + torch.randn_like(
|
||||
x
|
||||
) * _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
|
||||
|
||||
# active noise addition
|
||||
x = torch.where(mask_t_equall, x_noise, x)
|
||||
mask_t_equall = (mask_t == t.unsqueeze(1))[:, None, :, None, None]
|
||||
x = torch.where(mask_t_equall, x_noise, x0)
|
||||
|
||||
# create x_mask
|
||||
mask_t_upper = (mask_t > t.unsqueeze(1))[:, None, :, None, None]
|
||||
batch_size = x.shape[0]
|
||||
model_kwargs["x_mask"] = mask_t_upper.reshape(batch_size, -1).to(torch.bool)
|
||||
|
||||
|
|
|
|||
|
|
@ -34,9 +34,11 @@ def apply_mask_strategy(z, refs_x, mask_strategys, loop_i):
|
|||
masks = []
|
||||
for i, mask_strategy in enumerate(mask_strategys):
|
||||
mask_strategy = mask_strategy.split(";")
|
||||
mask = torch.ones(z.shape[2], dtype=torch.bool, device=z.device)
|
||||
mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device)
|
||||
for mst in mask_strategy:
|
||||
loop_id, m_id, m_ref_start, m_length, m_target_start = mst.split(",")
|
||||
mask_batch = mst.split(",")
|
||||
loop_id, m_id, m_ref_start, m_length, m_target_start = mask_batch[:5]
|
||||
edit_ratio = mask_batch[5] if len(mask_batch) == 6 else 0.0
|
||||
loop_id = int(loop_id)
|
||||
if loop_id != loop_i:
|
||||
continue
|
||||
|
|
@ -44,6 +46,7 @@ def apply_mask_strategy(z, refs_x, mask_strategys, loop_i):
|
|||
m_ref_start = int(m_ref_start)
|
||||
m_length = int(m_length)
|
||||
m_target_start = int(m_target_start)
|
||||
edit_ratio = float(edit_ratio)
|
||||
ref = refs_x[i][m_id] # [C, T, H, W]
|
||||
if m_ref_start < 0:
|
||||
m_ref_start = ref.shape[1] + m_ref_start
|
||||
|
|
@ -51,7 +54,7 @@ def apply_mask_strategy(z, refs_x, mask_strategys, loop_i):
|
|||
# z: [B, C, T, H, W]
|
||||
m_target_start = z.shape[2] + m_target_start
|
||||
z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length]
|
||||
mask[m_target_start : m_target_start + m_length] = 0
|
||||
mask[m_target_start : m_target_start + m_length] = edit_ratio
|
||||
masks.append(mask)
|
||||
masks = torch.stack(masks)
|
||||
return masks
|
||||
|
|
@ -69,7 +72,7 @@ def process_prompts(prompts, num_loop):
|
|||
text = text_preprocessing(text)
|
||||
end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop
|
||||
text_list.extend([text] * (end_loop - start_loop))
|
||||
assert len(text_list) == num_loop
|
||||
assert len(text_list) == num_loop, f"Prompt loop mismatch: {len(text_list)} != {num_loop}"
|
||||
ret_prompts.append(text_list)
|
||||
else:
|
||||
prompt = text_preprocessing(prompt)
|
||||
|
|
@ -161,8 +164,12 @@ def main():
|
|||
|
||||
# 3.5 reference
|
||||
if cfg.reference_path is not None:
|
||||
assert len(cfg.reference_path) == len(prompts)
|
||||
assert len(cfg.reference_path) == len(cfg.mask_strategy)
|
||||
assert len(cfg.reference_path) == len(
|
||||
prompts
|
||||
), f"Reference path mismatch: {len(cfg.reference_path)} != {len(prompts)}"
|
||||
assert len(cfg.reference_path) == len(
|
||||
cfg.mask_strategy
|
||||
), f"Mask strategy mismatch: {len(cfg.mask_strategy)} != {len(prompts)}"
|
||||
|
||||
# ======================================================
|
||||
# 4. inference
|
||||
|
|
@ -204,7 +211,6 @@ def main():
|
|||
j
|
||||
] += f";{loop_i},{len(refs)-1},-{cfg.condition_frame_length},{cfg.condition_frame_length},0"
|
||||
masks = apply_mask_strategy(z, refs_x, mask_strategy, loop_i)
|
||||
model_args["x_mask"] = masks
|
||||
|
||||
# 4.6. diffusion sampling
|
||||
old_sample_idx = sample_idx
|
||||
|
|
|
|||
|
|
@ -1,220 +0,0 @@
|
|||
import os
|
||||
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from mmengine.runner import set_random_seed
|
||||
|
||||
from opensora.acceleration.parallel_states import set_sequence_parallel_group
|
||||
from opensora.datasets import save_sample
|
||||
from opensora.datasets.utils import read_from_path
|
||||
from opensora.registry import MODELS, SCHEDULERS, build_module
|
||||
from opensora.utils.config_utils import parse_configs
|
||||
from opensora.utils.misc import to_torch_dtype
|
||||
|
||||
|
||||
def collect_references_batch(reference_paths, vae, image_size):
|
||||
refs_x = []
|
||||
for reference_path in reference_paths:
|
||||
ref_path = reference_path.split(";")
|
||||
ref = []
|
||||
for r_path in ref_path:
|
||||
r = read_from_path(r_path, image_size)
|
||||
r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype))
|
||||
r_x = r_x.squeeze(0)
|
||||
ref.append(r_x)
|
||||
refs_x.append(ref)
|
||||
# refs_x: [batch, ref_num, C, T, H, W]
|
||||
return refs_x
|
||||
|
||||
|
||||
def apply_mask_strategy(z, refs_x, mask_strategys, loop_i):
|
||||
masks = []
|
||||
for i, mask_strategy in enumerate(mask_strategys):
|
||||
mask_strategy = mask_strategy.split(";")
|
||||
mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device)
|
||||
for mst in mask_strategy:
|
||||
mask_batch = mst.split(",")
|
||||
loop_id, m_id, m_ref_start, m_length, m_target_start = mask_batch[:5]
|
||||
edit_ratio = mask_batch[5] if len(mask_batch) == 6 else 0.0
|
||||
loop_id = int(loop_id)
|
||||
if loop_id != loop_i:
|
||||
continue
|
||||
m_id = int(m_id)
|
||||
m_ref_start = int(m_ref_start)
|
||||
m_length = int(m_length)
|
||||
m_target_start = int(m_target_start)
|
||||
edit_ratio = float(edit_ratio)
|
||||
ref = refs_x[i][m_id] # [C, T, H, W]
|
||||
if m_ref_start < 0:
|
||||
m_ref_start = ref.shape[1] + m_ref_start
|
||||
if m_target_start < 0:
|
||||
# z: [B, C, T, H, W]
|
||||
m_target_start = z.shape[2] + m_target_start
|
||||
z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length]
|
||||
mask[m_target_start : m_target_start + m_length] = edit_ratio
|
||||
masks.append(mask)
|
||||
masks = torch.stack(masks)
|
||||
return masks, z
|
||||
|
||||
|
||||
def process_prompts(prompts, num_loop):
|
||||
ret_prompts = []
|
||||
for prompt in prompts:
|
||||
if prompt.startswith("|0|"):
|
||||
prompt_list = prompt.split("|")[1:]
|
||||
text_list = []
|
||||
for i in range(0, len(prompt_list), 2):
|
||||
start_loop = int(prompt_list[i])
|
||||
text = prompt_list[i + 1]
|
||||
end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop
|
||||
text_list.extend([text] * (end_loop - start_loop))
|
||||
assert len(text_list) == num_loop
|
||||
ret_prompts.append(text_list)
|
||||
else:
|
||||
ret_prompts.append([prompt] * num_loop)
|
||||
return ret_prompts
|
||||
|
||||
|
||||
def main():
|
||||
# ======================================================
|
||||
# 1. cfg and init distributed env
|
||||
# ======================================================
|
||||
cfg = parse_configs(training=False)
|
||||
print(cfg)
|
||||
|
||||
# init distributed
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
if coordinator.world_size > 1:
|
||||
set_sequence_parallel_group(dist.group.WORLD)
|
||||
enable_sequence_parallelism = True
|
||||
else:
|
||||
enable_sequence_parallelism = False
|
||||
|
||||
# ======================================================
|
||||
# 2. runtime variables
|
||||
# ======================================================
|
||||
torch.set_grad_enabled(False)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
dtype = to_torch_dtype(cfg.dtype)
|
||||
set_random_seed(seed=cfg.seed)
|
||||
prompts = cfg.prompt
|
||||
|
||||
# ======================================================
|
||||
# 3. build model & load weights
|
||||
# ======================================================
|
||||
# 3.1. build model
|
||||
input_size = (cfg.num_frames, *cfg.image_size)
|
||||
vae = build_module(cfg.vae, MODELS)
|
||||
latent_size = vae.get_latent_size(input_size)
|
||||
text_encoder = build_module(cfg.text_encoder, MODELS, device=device) # T5 must be fp32
|
||||
model = build_module(
|
||||
cfg.model,
|
||||
MODELS,
|
||||
input_size=latent_size,
|
||||
in_channels=vae.out_channels,
|
||||
caption_channels=text_encoder.output_dim,
|
||||
model_max_length=text_encoder.model_max_length,
|
||||
dtype=dtype,
|
||||
enable_sequence_parallelism=enable_sequence_parallelism,
|
||||
)
|
||||
text_encoder.y_embedder = model.y_embedder # hack for classifier-free guidance
|
||||
|
||||
# 3.2. move to device & eval
|
||||
vae = vae.to(device, dtype).eval()
|
||||
model = model.to(device, dtype).eval()
|
||||
|
||||
# 3.3. build scheduler
|
||||
scheduler = build_module(cfg.scheduler, SCHEDULERS)
|
||||
|
||||
# 3.4. support for multi-resolution
|
||||
model_args = dict()
|
||||
if cfg.multi_resolution == "PixArtMS":
|
||||
image_size = cfg.image_size
|
||||
hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
|
||||
ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1)
|
||||
model_args["data_info"] = dict(ar=ar, hw=hw)
|
||||
elif cfg.multi_resolution == "STDiT2":
|
||||
image_size = cfg.image_size
|
||||
height = torch.tensor([image_size[0]], device=device, dtype=dtype).repeat(cfg.batch_size)
|
||||
width = torch.tensor([image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size)
|
||||
num_frames = torch.tensor([cfg.num_frames], device=device, dtype=dtype).repeat(cfg.batch_size)
|
||||
ar = torch.tensor([image_size[0] / image_size[1]], device=device, dtype=dtype).repeat(cfg.batch_size)
|
||||
model_args["height"] = height
|
||||
model_args["width"] = width
|
||||
model_args["num_frames"] = num_frames
|
||||
model_args["ar"] = ar
|
||||
|
||||
# 3.5 reference
|
||||
if cfg.reference_path is not None:
|
||||
assert len(cfg.reference_path) == len(prompts)
|
||||
assert len(cfg.reference_path) == len(cfg.mask_strategy)
|
||||
|
||||
# ======================================================
|
||||
# 4. inference
|
||||
# ======================================================
|
||||
sample_idx = 0
|
||||
save_dir = cfg.save_dir
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# 4.1. batch generation
|
||||
for i in range(0, len(prompts), cfg.batch_size):
|
||||
batch_prompts_loops = process_prompts(prompts[i : i + cfg.batch_size], cfg.loop)
|
||||
video_clips = []
|
||||
|
||||
# 4.2. load reference videos & images
|
||||
if cfg.reference_path is not None:
|
||||
refs_x = collect_references_batch(cfg.reference_path[i : i + cfg.batch_size], vae, cfg.image_size)
|
||||
mask_strategy = cfg.mask_strategy[i : i + cfg.batch_size]
|
||||
|
||||
# 4.3. long video generation
|
||||
for loop_i in range(cfg.loop):
|
||||
# 4.4 sample in hidden space
|
||||
batch_prompts = [prompt[loop_i] for prompt in batch_prompts_loops]
|
||||
z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)
|
||||
|
||||
# 4.5. apply mask strategy
|
||||
masks = None
|
||||
if cfg.reference_path is not None:
|
||||
if loop_i > 0:
|
||||
ref_x = vae.encode(video_clips[-1])
|
||||
for j, refs in enumerate(refs_x):
|
||||
refs.append(ref_x[j])
|
||||
mask_strategy[
|
||||
j
|
||||
] += f";{loop_i},{len(refs)-1},-{cfg.condition_frame_length},{cfg.condition_frame_length},0,0"
|
||||
masks, z = apply_mask_strategy(z, refs_x, mask_strategy, loop_i)
|
||||
# 4.6. diffusion sampling
|
||||
samples = scheduler.sample(
|
||||
model,
|
||||
text_encoder,
|
||||
z=z,
|
||||
prompts=batch_prompts,
|
||||
device=device,
|
||||
additional_args=model_args,
|
||||
mask=masks, # scheduler must support mask
|
||||
)
|
||||
samples = vae.decode(samples.to(dtype))
|
||||
video_clips.append(samples)
|
||||
|
||||
# 4.7. save video
|
||||
if loop_i == cfg.loop - 1:
|
||||
if coordinator.is_master():
|
||||
for idx in range(len(video_clips[0])):
|
||||
video_clips_i = [video_clips[0][idx]] + [
|
||||
video_clips[i][idx][:, cfg.condition_frame_length :] for i in range(1, cfg.loop)
|
||||
]
|
||||
video = torch.cat(video_clips_i, dim=1)
|
||||
print(f"Prompt: {prompts[i + idx]}")
|
||||
save_path = os.path.join(save_dir, f"sample_{prompts[i + idx]}")
|
||||
save_sample(video, fps=cfg.fps, save_path=save_path)
|
||||
sample_idx += 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in a new issue