Open-Sora/opensora/utils/inference_utils.py
Zheng Zangwei (Alex Zheng) f1c6b8b88e open-sora v1.3 code upload (#786)
Co-authored-by: gxyes <gxynoz@gmail.com>
2025-02-20 16:50:24 +08:00

762 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import json
import os
import re
import shutil
import tempfile
from io import BytesIO
import cv2
import numpy as np
import pandas as pd
import torch
from torchvision import transforms
from opensora.datasets import IMG_FPS
from opensora.datasets.utils import read_from_path
def prepare_multi_resolution_info(info_type, batch_size, image_size, num_frames, fps, device, dtype):
    """Build the batched conditioning tensors describing output resolution/fps.

    Returns a dict keyed by the model family named in ``info_type``:
    ``None`` -> empty dict, ``"PixArtMS"`` -> {ar, hw}, ``"STDiT2"``/``"OpenSora"``
    -> {height, width, num_frames, ar, fps}.  Every tensor is repeated along
    the batch dimension.
    """
    if info_type is None:
        return {}

    def batched(values):
        # one scalar per sample, repeated across the batch -> shape (batch_size,)
        return torch.tensor(values, device=device, dtype=dtype).repeat(batch_size)

    if info_type == "PixArtMS":
        hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(batch_size, 1)
        ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(batch_size, 1)
        return dict(ar=ar, hw=hw)

    if info_type in ("STDiT2", "OpenSora"):
        # still images are tagged with a sentinel fps so the model can tell them apart
        effective_fps = fps if num_frames > 1 else IMG_FPS
        return dict(
            height=batched([image_size[0]]),
            width=batched([image_size[1]]),
            num_frames=batched([num_frames]),
            ar=batched([image_size[0] / image_size[1]]),
            fps=batched([effective_fps]),
        )

    raise NotImplementedError
def load_prompts(prompt_path, start_idx=None, end_idx=None, csv_ref_column_name=None):
    """Load prompts from a ``.txt`` (one per line) or ``.csv`` (column "text") file.

    Args:
        prompt_path: path ending in ``.txt`` or ``.csv``.
        start_idx / end_idx: optional slice applied to the loaded prompts
            (and, when requested, to the reference paths).
        csv_ref_column_name: csv only — if given, also return that column as
            a list of reference paths.

    Returns:
        list of prompts, or ``(prompts, reference_paths)`` when
        ``csv_ref_column_name`` is provided.

    Raises:
        ValueError: for unsupported file extensions.
    """
    if prompt_path.endswith(".txt"):
        with open(prompt_path, "r") as f:
            prompts = [line.strip() for line in f.readlines()]
    elif prompt_path.endswith(".csv"):
        df = pd.read_csv(prompt_path)
        prompts = df["text"].tolist()
        if csv_ref_column_name is not None:
            assert (
                csv_ref_column_name in df
            ), f"column {csv_ref_column_name} for reference paths not found in {prompt_path}"
            reference_paths = df[csv_ref_column_name].tolist()
            # BUGFIX: the original returned early without applying the
            # start/end slice, so prompts and references fell out of sync with
            # the non-reference code path; slice both lists identically here.
            return (prompts[start_idx:end_idx], reference_paths[start_idx:end_idx])
    else:
        raise ValueError(f"unsupported prompt file: {prompt_path}")
    prompts = prompts[start_idx:end_idx]
    return prompts
def get_save_path_name(
    save_dir,
    sample_name=None,  # prefix
    sample_idx=None,  # sample index
    prompt=None,  # used prompt
    prompt_as_path=False,  # use prompt as path
    num_sample=1,  # number of samples to generate for one prompt
    k=None,  # kth sample
):
    """Compose the output path (without extension) for one generated sample."""
    if prompt_as_path:
        prefix = sample_name if sample_name is not None else ""
        suffix = prompt
    else:
        prefix = sample_name if sample_name is not None else "sample"
        suffix = f"_{sample_idx:04d}"
    path = os.path.join(save_dir, f"{prefix}{suffix}")
    # disambiguate multiple samples generated for the same prompt
    return path if num_sample == 1 else f"{path}-{k}"
# (upper bound, label) buckets for aesthetic scores; scanned in order.
_AES_BUCKETS = [
    (4.0, "terrible"),
    (4.5, "very poor"),
    (5.0, "poor"),
    (5.5, "fair"),
    (6.0, "good"),
    (6.5, "very good"),
]


def transform_aes(aes):
    """Map a numeric aesthetic score to its textual bucket.

    Scores below 4 are filtered out upstream but still map to "terrible".
    """
    for upper, label in _AES_BUCKETS:
        if aes < upper:
            return label
    return "excellent"
# (upper bound, label) buckets for motion strength; scanned in order.
_MOTION_BUCKETS = [
    (0.5, "very low"),
    (2, "low"),
    (5, "fair"),
    (10, "high"),
    (20, "very high"),
]


def transform_motion(motion):
    """Map a motion-strength value to its textual bucket.

    Values below 0.3 are filtered out upstream but still map to "very low".
    """
    for upper, label in _MOTION_BUCKETS:
        if motion < upper:
            return label
    return "extremely high"
def append_score_to_prompts(prompts, aes=None, flow=None, camera_motion=None):
    """Append aesthetic / motion / camera-motion score sentences to each prompt.

    Numeric ``aes``/``flow`` values are translated to their textual buckets
    via transform_aes/transform_motion; values already given as text are used
    verbatim.  A score sentence is only added when the prompt does not already
    contain one.

    Returns:
        A new list of prompts with the score sentences appended.
    """
    # Convert once, before the loop.  The original rebound ``aes``/``flow``
    # inside the loop, so every prompt after the first depended on float()
    # raising ValueError on the already-converted text — fragile and wasteful.
    if aes is not None:
        try:
            aes = transform_aes(float(aes))
        except ValueError:
            pass  # already in text format
    if flow is not None:
        try:
            flow = transform_motion(float(flow))
        except ValueError:
            pass  # already in text format
    new_prompts = []
    for prompt in prompts:
        new_prompt = prompt
        if aes is not None and "aesthetic score is" not in prompt:
            new_prompt = f"{new_prompt} The aesthetic score is {aes}."
        if flow is not None and "motion strength is" not in prompt:
            new_prompt = f"{new_prompt} The motion strength is {flow}."
        if camera_motion is not None and "camera motion:" not in prompt:
            new_prompt = f"{new_prompt} camera motion: {camera_motion}."
        new_prompts.append(new_prompt)
    print("processed prompt:\n", new_prompts)
    return new_prompts
def extract_json_from_prompts(prompts, reference, mask_strategy):
    """Strip a trailing JSON blob from each prompt, moving its
    ``reference_path`` / ``mask_strategy`` entries into the parallel lists
    (which are modified in place and also returned)."""
    stripped = []
    for idx, prompt in enumerate(prompts):
        pieces = re.split(r"(?=[{])", prompt)
        assert len(pieces) <= 2, f"Invalid prompt: {prompt}"
        stripped.append(pieces[0])
        if len(pieces) == 2:
            extra = json.loads(pieces[1])
            for key, value in extra.items():
                assert key in ("reference_path", "mask_strategy"), f"Invalid key: {key}"
                if key == "reference_path":
                    reference[idx] = value
                else:
                    mask_strategy[idx] = value
    return stripped, reference, mask_strategy
def extract_json_from_prompts_new(prompts_with_json):
    """Split each prompt into plain text plus optional JSON metadata.

    Unlike :func:`extract_json_from_prompts`, this builds the ``reference``
    and ``mask_strategy`` lists from scratch, one entry per prompt
    (defaulting to "").

    Returns:
        (prompts, reference, mask_strategy) — three aligned lists.
    """
    prompts = []
    reference = []
    mask_strategy = []
    for prompt in prompts_with_json:
        parts = re.split(r"(?=[{])", prompt)
        assert len(parts) <= 2, f"Invalid prompt: {prompt}"
        prompts.append(parts[0])
        # BUGFIX: the lists started empty, so the original's indexed
        # assignment (reference[i] = ...) raised IndexError whenever a prompt
        # carried JSON; grow the lists with a default entry per prompt instead.
        reference.append("")
        mask_strategy.append("")
        if len(parts) > 1:
            additional_info = json.loads(parts[1])
            for key in additional_info:
                assert key in ["reference_path", "mask_strategy"], f"Invalid key: {key}"
                if key == "reference_path":
                    reference[-1] = additional_info[key]
                elif key == "mask_strategy":
                    mask_strategy[-1] = additional_info[key]
    return prompts, reference, mask_strategy
def collect_references_batch(reference_paths, vae, image_size):
    """Encode each ";"-separated reference clip into VAE latents.

    Returns refs_x with layout [batch, ref_num, C, T, H, W]; an empty path
    yields an empty list for that batch entry.
    """
    refs_x = []
    for reference_path in reference_paths:
        if reference_path == "":
            refs_x.append([])
            continue
        encoded = []
        for r_path in reference_path.split(";"):
            clip = read_from_path(r_path, image_size, transform_name="resize_crop")
            # trim the clip to a frame count the VAE can consume
            actual_t = clip.size(1)
            if vae.micro_frame_size is None:
                target_t = (actual_t - 1) // 4 * 4 + 1
            elif not vae.temporal_overlap:
                target_t = actual_t // vae.micro_frame_size * vae.micro_frame_size
            else:
                target_t = (actual_t - 1) // (vae.micro_frame_size - 1) * (vae.micro_frame_size - 1) + 1
            clip = clip[:, :target_t]
            latent = vae.encode(clip.unsqueeze(0).to(vae.device, vae.dtype)).squeeze(0)
            encoded.append(latent)
        refs_x.append(encoded)
    return refs_x
def extract_images_from_ref_paths(reference_paths, image_size):
    """Load raw reference frames (no VAE encoding).

    Returns refs_images with layout [batch, ref_num, ...]; an empty path
    yields an empty list for that batch entry.
    """
    refs_images = []
    for reference_path in reference_paths:
        if reference_path == "":
            refs_images.append([])
            continue
        loaded = [
            read_from_path(p, image_size, transform_name="resize_crop").squeeze(1)
            for p in reference_path.split(";")
        ]
        refs_images.append(loaded)
    return refs_images
def extract_prompts_loop(prompts, num_loop):
    """Pick the text active at loop ``num_loop`` from each "|idx|text|idx|text..." prompt.

    Prompts without the leading "|0|" marker are passed through unchanged.
    """
    selected = []
    for prompt in prompts:
        if prompt.startswith("|0|"):
            fields = prompt.split("|")[1:]
            per_loop = []
            for j in range(0, len(fields), 2):
                begin = int(fields[j])
                text = fields[j + 1]
                # the final segment extends through the requested loop index
                end = int(fields[j + 2]) if j + 2 < len(fields) else num_loop + 1
                per_loop.extend([text] * (end - begin))
            prompt = per_loop[num_loop]
        selected.append(prompt)
    return selected
def split_prompt(prompt_text):
    """Parse "|0| text |1| text ..." prompts into (texts, loop_indices).

    Example: "|0| a beautiful day |1| a sunny day" ->
    (["a beautiful day", "a sunny day"], [0, 1]).
    Plain prompts come back as ([prompt_text], None).
    """
    if not prompt_text.startswith("|0|"):
        return [prompt_text], None
    fields = prompt_text.split("|")[1:]
    # even positions carry the loop start index, odd positions the text
    loop_idx = [int(fields[j]) for j in range(0, len(fields), 2)]
    text_list = [fields[j + 1].strip() for j in range(0, len(fields), 2)]
    return text_list, loop_idx
def merge_prompt(text_list, loop_idx_list=None):
    """Inverse of split_prompt: re-join texts with their "|idx|" markers.

    Without loop indices, returns the single text unchanged.
    """
    if loop_idx_list is None:
        return text_list[0]
    return "".join(f"|{idx}|{text}" for idx, text in zip(loop_idx_list, text_list))
# Default mask group fields: loop_id, ref_id, ref_start, target_start, length, edit_ratio
MASK_DEFAULT = ["0", "0", "0", "0", "1", "0"]


def parse_mask_strategy(mask_strategy):
    """Parse a ";"-separated mask-strategy string into typed groups.

    Each group is "loop_id,ref_id,ref_start,target_start,length,edit_ratio";
    missing trailing fields take their MASK_DEFAULT values.  Returns a list of
    [int, int, int, int, int, float] groups (empty for "" or None).
    """
    if not mask_strategy:
        return []
    parsed = []
    for raw in mask_strategy.split(";"):
        fields = raw.split(",")
        assert 1 <= len(fields) <= 6, f"Invalid mask strategy: {raw}"
        fields = fields + MASK_DEFAULT[len(fields):]
        parsed.append([int(v) for v in fields[:5]] + [float(fields[5])])
    return parsed
def find_nearest_point(value, point, max_value):
    """Snap ``value`` down to a multiple of ``point``, rounding up past the
    halfway mark — but never beyond the last full multiple below ``max_value``."""
    quotient, remainder = divmod(value, point)
    if remainder > point / 2 and quotient < max_value // point - 1:
        quotient += 1
    return quotient * point
def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None):
    """Copy reference latents into ``z`` according to per-sample mask strategies.

    Args:
        z: [B, C, T, H, W] latent batch, modified in place.
        refs_x: per-sample lists of reference latents, each [C, T, H, W].
        mask_strategys: per-sample strategy strings (see parse_mask_strategy).
        loop_i: only groups whose loop_id matches are applied.
        align: optional granularity for snapping start offsets.

    Returns:
        [B, T] tensor of per-frame edit ratios (1.0 = untouched), or None
        when ``mask_strategys`` is empty.
    """
    if len(mask_strategys) == 0:
        return None
    masks = []
    for sample_i, strategy_str in enumerate(mask_strategys):
        frame_mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device)
        for group in parse_mask_strategy(strategy_str):
            loop_id, ref_id, ref_start, target_start, length, edit_ratio = group
            if loop_id != loop_i:
                continue
            ref = refs_x[sample_i][ref_id]
            # negative offsets count from the end (ref: [C, T, H, W], z: [B, C, T, H, W])
            if ref_start < 0:
                ref_start += ref.shape[1]
            if target_start < 0:
                target_start += z.shape[2]
            if align is not None:
                ref_start = find_nearest_point(ref_start, align, ref.shape[1])
                target_start = find_nearest_point(target_start, align, z.shape[2])
            # clamp so neither side is indexed out of range
            length = min(length, z.shape[2] - target_start, ref.shape[1] - ref_start)
            z[sample_i, :, target_start : target_start + length] = ref[:, ref_start : ref_start + length]
            frame_mask[target_start : target_start + length] = edit_ratio
        masks.append(frame_mask)
    return torch.stack(masks)
def append_generated(
    vae, generated_video, refs_x, mask_strategy, loop_i, condition_frame_length, condition_frame_edit, is_latent=False
):
    """Register the previous loop's output as a conditioning reference.

    Appends the (VAE-encoded, unless ``is_latent``) video to each sample's
    reference list and extends the matching mask-strategy string so the next
    loop is conditioned on the last ``condition_frame_length`` frames.

    Returns:
        The updated (refs_x, mask_strategy), both mutated in place.
    """
    latent = generated_video if is_latent else vae.encode(generated_video)
    for idx, refs in enumerate(refs_x):
        if refs is None:
            # BUGFIX: the original assigned refs_x[idx] but kept using the
            # local ``refs`` (still None), so len(refs) below raised TypeError.
            refs = [latent[idx]]
            refs_x[idx] = refs
        else:
            refs.append(latent[idx])
        prefix = "" if not mask_strategy[idx] else mask_strategy[idx] + ";"
        mask_strategy[idx] = (
            prefix
            + f"{loop_i},{len(refs)-1},-{condition_frame_length},0,{condition_frame_length},{condition_frame_edit}"
        )
    return refs_x, mask_strategy
def dframe_to_frame(num):
    """Convert a latent-frame count (multiple of 5) to raw video frames:
    each chunk of 5 latent frames corresponds to 17 decoded frames."""
    chunks, remainder = divmod(num, 5)
    assert remainder == 0, f"Invalid num: {num}"
    return chunks * 17
# Lazily-initialized OpenAI client, shared by the API helpers below.
OPENAI_CLIENT = None
# Cached system prompt for refine_prompt_by_openai (built from the example file on first use).
REFINE_PROMPTS = None
# REFINE_PROMPTS_PATH = "assets/texts/t2v_pllava.txt"
REFINE_PROMPTS_PATH = "assets/texts/t2v_demo.txt"
# System-prompt template for refining a user prompt; "{}" is filled with
# example prompts loaded from REFINE_PROMPTS_PATH.
REFINE_PROMPTS_TEMPLATE = """
You need to refine user's input prompt. The user's input prompt is used for video generation task. You need to refine the user's prompt to make it more suitable for the task. Here are some examples of refined prompts:
{}
The refined prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The video should not have any scene transitions and must be on the same scene. The refined prompt should be in English.
"""
# Cached system prompt for get_random_prompt_by_openai (built on first use).
RANDOM_PROMPTS = None
# System-prompt template for generating a prompt from scratch; "{}" is filled
# with the same example prompts.
RANDOM_PROMPTS_TEMPLATE = """
You need to generate one input prompt for video generation task. The prompt should be suitable for the task. Here are some examples of refined prompts:
{}
The prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The video should not have any scene transitions and must be on the same scene. The prompt should be in English.
"""
def get_openai_response(sys_prompt, usr_prompt, model="gpt-4o"):
    """Send one system+user exchange to the OpenAI chat API and return the reply text.

    The client is created lazily from the OPENAI_API_KEY environment variable
    and cached in the module-level OPENAI_CLIENT.
    """
    global OPENAI_CLIENT
    if OPENAI_CLIENT is None:
        from openai import OpenAI

        OPENAI_CLIENT = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
    messages = [
        {"role": "system", "content": sys_prompt},  # context for the model
        {"role": "user", "content": usr_prompt},  # the actual request
    ]
    completion = OPENAI_CLIENT.chat.completions.create(model=model, messages=messages)
    return completion.choices[0].message.content
def get_random_prompt_by_openai():
    """Ask the OpenAI API to invent one example video-generation prompt."""
    global RANDOM_PROMPTS
    if RANDOM_PROMPTS is None:
        # build and cache the system prompt from the bundled example prompts
        examples = load_prompts(REFINE_PROMPTS_PATH)
        RANDOM_PROMPTS = RANDOM_PROMPTS_TEMPLATE.format("\n".join(examples))
    return get_openai_response(RANDOM_PROMPTS, "Generate one example.")
def refine_prompt_by_openai(prompt):
    """Refine a user prompt for video generation via the OpenAI API."""
    global REFINE_PROMPTS
    if REFINE_PROMPTS is None:
        # build and cache the system prompt from the bundled example prompts
        examples = load_prompts(REFINE_PROMPTS_PATH)
        REFINE_PROMPTS = REFINE_PROMPTS_TEMPLATE.format("\n".join(examples))
    return get_openai_response(REFINE_PROMPTS, prompt)
def has_openai_key():
    """Return True when an OPENAI_API_KEY environment variable is set."""
    return os.environ.get("OPENAI_API_KEY") is not None
def refine_prompts_by_openai(prompts):
    """Best-effort batch refinement; any failure falls back to the original prompt.

    Empty prompts are replaced with a randomly generated one.
    """
    refined = []
    for prompt in prompts:
        try:
            if not prompt.strip():
                result = get_random_prompt_by_openai()
                print(f"[Info] Empty prompt detected, generate random prompt: {result}")
            else:
                result = refine_prompt_by_openai(prompt)
                print(f"[Info] Refine prompt: {prompt} -> {result}")
            refined.append(result)
        except Exception as e:
            # never let a refinement failure break generation
            print(f"[Warning] Failed to refine prompt: {prompt} due to {e}")
            refined.append(prompt)
    return refined
# Captioning system prompt used when the user supplied extra info ("{}").
# BUGFIX: "this vide to have" -> "this video to have" (typo in both templates).
FIRST_FRAME_PROMPT_TEMPLATE_WITH_INFO = """
Given the first frame of the video, describe this video and its style in a very detailed manner. Some information about the video is:
'{}'.
Describe the video and its style in a very detailed manner. Pay attention to all objects in the video. You must describe what the human character is doing with action in the video, for instance, talk, walk, blink, laugh, sing or anything else that involves movements in the video. Your description must make it easy for this video to have human movements, instead of being motionless. The description should be useful for AI to re-generate the video. The description should be no more than six sentences.
"""
# Captioning system prompt used when no extra info is available.
FIRST_FRAME_PROMPT_TEMPLATE = """
Given the first frame of the video, you need to generate one input prompt for video generation task. The prompt should be suitable for generating a video using the given image as the first frame.
Describe the video and its style in a very detailed manner. Pay attention to all objects in the video. You must describe what the human character is doing with action in the video, for instance, talk, walk, blink, laugh, sing or anything else that involves movements in the video. Your description must make it easy for this video to have human movements, instead of being motionless. The description should be useful for AI to re-generate the video. The description should be no more than six sentences.
"""
def to_base64(image_tensor):
    """JPEG-encode an image tensor and return it as a base64 string."""
    pil_image = transforms.ToPILImage()(image_tensor)
    buffer = BytesIO()
    pil_image.save(buffer, format="JPEG")
    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode("utf-8")
# Boilerplate openers that vision-language models prepend to captions;
# remove_caption_prefix strips the first one that matches.
LLAVA_PREFIX = [
    "The video shows ",
    "The video captures ",
    "The video features ",
    "The video depicts ",
    "The video presents ",
    "The video features ",
    "The video is ",
    "In the video, ",
    "The image shows ",
    "The image captures ",
    "The image features ",
    "The image depicts ",
    "The image presents ",
    "The image features ",
    "The image is ",
    "The image portrays ",
    "In the image, ",
]


def remove_caption_prefix(caption):
    """Strip a leading "The video/image ..." boilerplate phrase (case-tolerant)
    and re-capitalize the remaining text."""
    for prefix in LLAVA_PREFIX:
        if not (caption.startswith(prefix) or caption.startswith(prefix.lower())):
            continue
        trimmed = caption[len(prefix):].strip()
        if trimmed[0].islower():
            trimmed = trimmed[0].upper() + trimmed[1:]
        return trimmed
    return caption
def get_caption(frame, prompt):
    """Caption a base64-encoded JPEG frame with GPT-4o.

    Args:
        frame: base64 string of the JPEG image (see to_base64).
        prompt: system prompt steering the caption.

    Returns:
        A single-line caption with model boilerplate prefixes removed and
        " image " rewritten to " video ".
    """
    global OPENAI_CLIENT
    if OPENAI_CLIENT is None:
        from openai import OpenAI

        OPENAI_CLIENT = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
    # BUGFIX: the request was issued through an undefined name ``client``
    # (NameError at runtime); use the cached module-level OPENAI_CLIENT.
    response = OPENAI_CLIENT.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=[
            {"role": "system", "content": prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}", "detail": "low"}},
                ],
            },
        ],
        max_tokens=300,
        top_p=0.1,
    )
    caption = response.choices[0].message.content
    caption = caption.replace("\n", " ")
    caption = remove_caption_prefix(caption).replace(" image ", " video ")
    return caption
def refine_batched_prompts_with_images(prompts, images):
    """Refine each (prompt, first-frame image) pair; failures fall back to the
    original prompt, and empty prompts are generated from the image alone."""
    refined = []
    for prompt, image in zip(prompts, images):
        try:
            if not prompt.strip():
                result = get_random_prompt_with_image(image)
                print(f"[Info] Empty prompt detected, generate random prompt: {result}")
            else:
                result = refine_prompt_with_image(prompt, image)
                print(f"[Info] Refine prompt: {prompt} -> {result}")
            refined.append(result)
        except Exception as e:
            # never let a refinement failure break generation
            print(f"[Warning] Failed to refine prompt: {prompt} due to {e}")
            refined.append(prompt)
    return refined
def refine_prompt_with_image(prompt, image):
    """Refine ``prompt`` using the video's first frame via GPT captioning.

    Falls back to the unmodified prompt when no OpenAI API key is configured.
    """
    if not has_openai_key():
        print("no openai api key found, prompt not refined")
        return prompt
    encoded = to_base64(image)
    return get_caption(encoded, FIRST_FRAME_PROMPT_TEMPLATE_WITH_INFO.format(prompt))
def get_random_prompt_with_image(image):
    """Generate a video-generation prompt from the first-frame ``image`` via
    GPT captioning.

    Returns:
        The generated caption, or an empty string when no OpenAI API key is
        configured.
    """
    if not has_openai_key():
        # BUGFIX: the original did ``return prompt`` here, but this function
        # has no ``prompt`` variable — it raised NameError.  Return an empty
        # prompt so the caller's fallback path still works.
        print("no openai api key found, prompt not refined")
        return ""
    frame = to_base64(image)
    return get_caption(frame, FIRST_FRAME_PROMPT_TEMPLATE)
def add_watermark(
    input_video_path, watermark_image_path="./assets/images/watermark/watermark.png", output_video_path=None
):
    """Overlay a watermark image onto a video with ffmpeg.

    The watermark is scaled to 10% of the video height and composited on top.
    When ``output_video_path`` is None, "_watermark" is appended to the input
    file name.

    Returns:
        True when the ffmpeg process exits successfully, False otherwise.
    """
    import shlex

    if output_video_path is None:
        output_video_path = input_video_path.replace(".mp4", "_watermark.mp4")
    # quote the paths so filenames with spaces or shell metacharacters
    # cannot break (or inject into) the shell command
    cmd = (
        f"ffmpeg -y -i {shlex.quote(input_video_path)} -i {shlex.quote(watermark_image_path)}"
        ' -filter_complex "[1][0]scale2ref=oh*mdar:ih*0.1[logo][video];[video][logo]overlay"'
        f" {shlex.quote(output_video_path)}"
    )
    exit_code = os.system(cmd)
    return exit_code == 0
def super_resolution(input_video_path, sr=2):
    """Upscale a video with Real-ESRGAN (via the tools.repair CLI).

    Writes a "<stem>_sr{sr}.mp4" copy next to the input and returns its path.
    """
    stem = os.path.basename(input_video_path).split(".")[0]
    with tempfile.TemporaryDirectory() as work_dir:
        os.system(
            f"python -m tools.repair.inference_realesrgan_video -n RealESRGAN_x4plus -s 2 -i {input_video_path} -o {work_dir}"
        )
        upscaled_path = os.path.join(work_dir, stem + "_out.mp4")
        dst_video_path = os.path.join(os.path.dirname(input_video_path), stem + f"_sr{sr:.0f}.mp4")
        shutil.copyfile(upscaled_path, dst_video_path)
    return dst_video_path
def deflicker_video_local_brightness(input_video_path, output_video_path, smoothing_window=30, block_size=32):
    """Reduce brightness flicker by smoothing per-block brightness over time.

    Reads the whole video into memory, measures the mean brightness of each
    ``block_size``x``block_size`` tile per frame, smooths each tile's
    brightness curve with a ``smoothing_window``-frame moving average, then
    rescales every tile of every frame toward its smoothed brightness and
    writes the result to ``output_video_path``.

    NOTE(review): the correction divides by the current tile brightness, so an
    all-black tile would divide by zero — presumably never hit in practice;
    confirm with real inputs.
    """
    # Open the input video
    cap = cv2.VideoCapture(input_video_path)
    # Basic stream properties
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Configure the output writer (same size and fps, h264 codec)
    fourcc = cv2.VideoWriter_fourcc(*"h264")
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
    # Per-frame list of tile brightness measurements
    local_brightness_list = []
    # Read all frames and measure each tile's brightness
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
        # Convert to grayscale
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        # Split the frame into tiles and record each tile's mean brightness
        local_brightness = []
        for y in range(0, height, block_size):
            for x in range(0, width, block_size):
                block = gray[y : y + block_size, x : x + block_size]
                mean_brightness = np.mean(block)
                local_brightness.append(mean_brightness)
        local_brightness_list.append(local_brightness)
    # Smooth the per-tile brightness curves over time
    local_brightness_array = np.array(local_brightness_list)
    # Pad both ends so the moving average preserves the original length
    assert total_frames % 2 == 1, "The number of frames should be odd."
    pad_width = smoothing_window // 2
    padded_local_brightness = np.pad(local_brightness_array, ((pad_width, pad_width - 1), (0, 0)), mode="edge")
    # Output array for the smoothed curves, same shape as local_brightness_array
    smoothed_local_brightness = np.zeros_like(local_brightness_array)
    # Smooth each tile's brightness curve independently
    for i in range(local_brightness_array.shape[1]):
        # Moving average implemented as a convolution with a box kernel
        smoothed_brightness = np.convolve(
            padded_local_brightness[:, i], np.ones(smoothing_window) / smoothing_window, mode="valid"
        )
        smoothed_local_brightness[:, i] = smoothed_brightness
    # Apply the per-tile brightness correction to every frame
    for i in range(len(frames)):
        frame = frames[i].copy()
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        # Rescale each tile toward its smoothed brightness
        idx = 0
        for y in range(0, height, block_size):
            for x in range(0, width, block_size):
                block = gray[y : y + block_size, x : x + block_size]
                current_brightness = np.mean(block)
                brightness_ratio = smoothed_local_brightness[i, idx] / current_brightness
                frame[y : y + block_size, x : x + block_size] = np.clip(
                    frame[y : y + block_size, x : x + block_size] * brightness_ratio, 0, 255
                ).astype(np.uint8)
                idx += 1
        out.write(frame)
    # Release resources
    cap.release()
    out.release()
    print(f"Deflickered video saved as {output_video_path}")
def deflicker(input_video_path):
    # Writes a deflickered copy next to the input ("<name>_deflicker.mp4").
    # NOTE(review): returns the ORIGINAL path, not the deflickered one, so the
    # caller keeps using the unprocessed video — confirm this is intentional.
    deflicker_video_local_brightness(input_video_path, input_video_path.replace(".mp4", "_deflicker.mp4"))
    return input_video_path
# bytes per gigabyte
GIGABYTE = 1024**3


def print_memory_usage(prefix: str, device: torch.device):
    """Log peak CUDA memory (allocated and reserved) for ``device``, in GB."""
    torch.cuda.synchronize()
    peak_allocated = torch.cuda.max_memory_allocated(device) / GIGABYTE
    peak_reserved = torch.cuda.max_memory_reserved(device) / GIGABYTE
    print(f"{prefix}: max memory allocated: {peak_allocated:.4f} GB")
    print(f"{prefix}: max memory reserved: {peak_reserved:.4f} GB")
def easy_data(csv):
    """Debug helper: build a VariableVideoTextDataset from ``csv`` and fetch
    the sample keyed "0-113-360-640"."""
    from opensora.registry import DATASETS, build_module

    cfg = {
        "type": "VariableVideoTextDataset",
        "transform_name": "resize_crop",
        "data_path": csv,
    }
    dataset = build_module(cfg, DATASETS)
    return dataset["0-113-360-640"]
def prep_ref_and_mask(cond_type, condition_frame_length, refs, target_shape, loop, device, dtype):
    """
    Prepare the conditioning reference latent and mask_index for the 1st loop.

    Args:
        cond_type: None, "v2v_head", "i2v_head", "i2v_loop" or "i2v_tail".
        condition_frame_length: number of latent frames to condition on (v2v_head).
        refs: per-batch-sample lists of reference latents, each [C, T, H, W].
        target_shape: shape [B, C, T, H, W] of the latent to be generated.
        loop: total number of loops to do.
        device / dtype: placement for the returned reference tensor.

    Returns:
        (ref, mask_index): a zero latent with the conditioning frames copied
        in, plus the list of conditioned latent-frame indices.
    """
    latent_t = target_shape[2]
    if cond_type is None:
        mask_index = []
    elif cond_type == "v2v_head":
        min_ref_length = min([ref[0].shape[1] for ref in refs])
        condition_frame_length = min(
            min(condition_frame_length, min_ref_length), latent_t
        )  # ensure condition frame is no more than generated length
        mask_index = [i for i in range(condition_frame_length)]
    elif cond_type == "i2v_head" or cond_type == "i2v_loop":
        mask_index = [0]  # i2v_loop additionally conditions on the last frame below
    elif cond_type == "i2v_tail":
        if loop == 1:
            mask_index = [-1]  # rewritten to the positive index latent_t - 1 below
        else:
            mask_index = []  # condition on the last frame only in the final loop
    else:
        raise NotImplementedError
    # copy the selected frames of the first reference into a zero latent
    ref = torch.zeros(target_shape, device=device, dtype=dtype)
    if len(mask_index) > 0:
        b = target_shape[0]
        for b_i in range(b):
            ref[b_i, :, mask_index] = refs[b_i][0][:, mask_index].unsqueeze(0)
    # get finalized mask_index, except i2v_tail and i2v_loop intermediate loops will update later
    if cond_type == "i2v_loop" and loop <= 1:
        b = target_shape[0]
        for b_i in range(b):
            if len(refs[0]) == 1:  # if only 1 ref, use its last frame
                # NOTE(review): the length check reads refs[0] but the data
                # access reads refs[b_i] — looks like it should be per-sample;
                # confirm against multi-sample callers.
                ref[b_i, :, -1] = refs[b_i][0][:, -1].unsqueeze(0)
            else:
                ref[b_i, :, -1] = refs[b_i][1][:, 0].unsqueeze(0)  # CHANGED TO USE IMAGE
        mask_index.append(latent_t - 1)
    if cond_type == "i2v_tail" and loop <= 1:
        mask_index = [latent_t - 1]
    return ref, mask_index
def prep_ref_and_update_mask_in_loop(
    cond_type, condition_frame_length, samples, refs, target_shape, is_last_loop, device, dtype
):
    """
    Prepare the reference latent and mask_index for loops after the first,
    conditioning on the tail of the previous loop's ``samples``.

    Returns:
        (ref, mask_index): a zero latent with the conditioning frames copied
        in, plus the list of conditioned latent-frame indices.
    """
    latent_t = target_shape[2]
    # condition on the last condition_frame_length frames of the previous generation
    loop_cond_index = [i for i in range(-condition_frame_length, 0)]
    # place those frames at the head of a fresh zero latent
    ref = torch.zeros(target_shape, device=device, dtype=dtype)
    ref_cut = samples[:, :, loop_cond_index].to(device=device, dtype=dtype)
    mask_index = [i for i in range(condition_frame_length)]
    ref[:, :, mask_index] = ref_cut
    # NOTE(review): `and` binds tighter than `or`, so this condition reads as
    # i2v_loop (on EVERY loop) OR (i2v_tail AND last loop) — confirm that
    # i2v_loop is really meant to re-condition the last frame each loop rather
    # than (i2v_loop or i2v_tail) and is_last_loop.
    if cond_type == "i2v_loop" or cond_type == "i2v_tail" and is_last_loop:
        b = target_shape[0]
        for b_i in range(b):
            if len(refs[b_i]) == 1:  # if only 1 reference passed, use last frame
                ref[b_i, :, -1] = refs[b_i][0][:, -1].unsqueeze(0).to(device=device, dtype=dtype)
            else:  # use the last frame (either video or image) of second reference
                ref[b_i, :, -1] = refs[b_i][1][:, 0].unsqueeze(0).to(device=device, dtype=dtype)
        mask_index.append(latent_t - 1)  # mask_index for final loop
    return ref, mask_index