mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-10 10:31:23 +02:00
762 lines
28 KiB
Python
762 lines
28 KiB
Python
import base64
|
||
import json
|
||
import os
|
||
import re
|
||
import shutil
|
||
import tempfile
|
||
from io import BytesIO
|
||
|
||
import cv2
|
||
import numpy as np
|
||
import pandas as pd
|
||
import torch
|
||
from torchvision import transforms
|
||
|
||
from opensora.datasets import IMG_FPS
|
||
from opensora.datasets.utils import read_from_path
|
||
|
||
|
||
def prepare_multi_resolution_info(info_type, batch_size, image_size, num_frames, fps, device, dtype):
    """Build per-sample conditioning tensors describing the target resolution.

    Returns a dict of tensors, each repeated `batch_size` times, whose keys
    depend on `info_type`: {} for None, {"ar", "hw"} for "PixArtMS", and
    {"height", "width", "num_frames", "ar", "fps"} for "STDiT2"/"OpenSora".
    Raises NotImplementedError for any other info_type.
    """
    if info_type is None:
        return {}

    if info_type == "PixArtMS":
        # hw: (B, 2) height/width rows; ar: (B, 1) aspect ratios
        hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(batch_size, 1)
        ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(batch_size, 1)
        return dict(ar=ar, hw=hw)

    if info_type in ("STDiT2", "OpenSora"):
        def per_sample(value):
            # One scalar, repeated once per sample: shape (B,)
            return torch.tensor([value], device=device, dtype=dtype).repeat(batch_size)

        # single-frame generation is treated as an image with a fixed fps tag
        effective_fps = fps if num_frames > 1 else IMG_FPS
        return dict(
            height=per_sample(image_size[0]),
            width=per_sample(image_size[1]),
            num_frames=per_sample(num_frames),
            ar=per_sample(image_size[0] / image_size[1]),
            fps=per_sample(effective_fps),
        )

    raise NotImplementedError
|
||
|
||
|
||
def load_prompts(prompt_path, start_idx=None, end_idx=None, csv_ref_column_name=None):
    """Load prompts from a .txt (one per line) or .csv (column "text") file.

    Args:
        prompt_path: path ending in ".txt" or ".csv".
        start_idx / end_idx: optional slice applied to the prompt list
            (NOT applied on the csv-with-references path, matching the
            original behavior).
        csv_ref_column_name: csv input only — when given, also return that
            column as reference paths.

    Returns:
        list of prompts, or (prompts, reference_paths) when
        csv_ref_column_name is provided for a csv file.

    Raises:
        ValueError: if the extension is neither .txt nor .csv (the original
            raised a bare, message-less ValueError).
    """
    if prompt_path.endswith(".txt"):
        with open(prompt_path, "r") as f:
            prompts = [line.strip() for line in f.readlines()]
    elif prompt_path.endswith(".csv"):
        df = pd.read_csv(prompt_path)
        prompts = df["text"].tolist()
        if csv_ref_column_name is not None:
            assert (
                csv_ref_column_name in df
            ), f"column {csv_ref_column_name} for reference paths not found in {prompt_path}"
            reference_paths = df[csv_ref_column_name]
            # NOTE: start_idx/end_idx are intentionally not applied here.
            return (prompts, reference_paths)
    else:
        raise ValueError(f"unsupported prompt file (expected .txt or .csv): {prompt_path}")
    prompts = prompts[start_idx:end_idx]
    return prompts
|
||
|
||
|
||
def get_save_path_name(
    save_dir,
    sample_name=None,  # prefix
    sample_idx=None,  # sample index
    prompt=None,  # used prompt
    prompt_as_path=False,  # use prompt as path
    num_sample=1,  # number of samples to generate for one prompt
    k=None,  # kth sample
):
    """Compose the output path for one generated sample.

    The name is `<sample_name><suffix>` under save_dir, where the suffix is
    either the prompt text (prompt_as_path=True) or a zero-padded sample
    index; "-<k>" is appended when several samples share one prompt.
    """
    if sample_name is None:
        sample_name = "" if prompt_as_path else "sample"
    suffix = prompt if prompt_as_path else f"_{sample_idx:04d}"
    path = os.path.join(save_dir, f"{sample_name}{suffix}")
    return path if num_sample == 1 else f"{path}-{k}"
|
||
|
||
|
||
def transform_aes(aes):
    """Map a numeric aesthetic score to its textual bucket.

    Scores below 4 are normally filtered out upstream; anything >= 6.5 is
    "excellent".
    """
    buckets = (
        (4, "terrible"),
        (4.5, "very poor"),
        (5, "poor"),
        (5.5, "fair"),
        (6, "good"),
        (6.5, "very good"),
    )
    for upper, label in buckets:
        if aes < upper:
            return label
    return "excellent"
|
||
|
||
|
||
def transform_motion(motion):
    """Map a numeric motion-strength score to its textual bucket.

    Values below 0.3 are normally filtered out upstream; anything >= 20 is
    "extremely high".
    """
    for upper, label in (
        (0.5, "very low"),
        (2, "low"),
        (5, "fair"),
        (10, "high"),
        (20, "very high"),
    ):
        if motion < upper:
            return label
    return "extremely high"
|
||
|
||
|
||
def append_score_to_prompts(prompts, aes=None, flow=None, camera_motion=None):
    """Append aesthetic / motion / camera-motion descriptors to each prompt.

    Numeric `aes`/`flow` values are converted to their textual buckets once
    up front; non-numeric values are assumed to already be text. (The
    original converted inside the loop, rebinding the argument on the first
    prompt that lacked the marker — same final output, but fragile.)
    A descriptor is skipped for prompts that already contain it.

    Returns a new list; the input list is not modified.
    """
    if aes is not None:
        try:
            aes = transform_aes(float(aes))
        except ValueError:
            pass  # already in text format
    if flow is not None:
        try:
            flow = transform_motion(float(flow))
        except ValueError:
            pass  # already in text format

    new_prompts = []
    for prompt in prompts:
        new_prompt = prompt
        if aes is not None and "aesthetic score is" not in prompt:
            new_prompt = f"{new_prompt} The aesthetic score is {aes}."
        if flow is not None and "motion strength is" not in prompt:
            new_prompt = f"{new_prompt} The motion strength is {flow}."
        if camera_motion is not None and "camera motion:" not in prompt:
            new_prompt = f"{new_prompt} camera motion: {camera_motion}."
        new_prompts.append(new_prompt)
    print("processed prompt:\n", new_prompts)
    return new_prompts
|
||
|
||
|
||
def extract_json_from_prompts(prompts, reference, mask_strategy):
    """Strip a trailing JSON object from each prompt, moving its
    "reference_path" / "mask_strategy" fields into the parallel lists.

    `reference` and `mask_strategy` are pre-sized lists updated in place;
    a prompt may contain at most one "{...}" suffix.
    """
    ret_prompts = []
    for i, prompt in enumerate(prompts):
        parts = re.split(r"(?=[{])", prompt)
        assert len(parts) <= 2, f"Invalid prompt: {prompt}"
        ret_prompts.append(parts[0])
        if len(parts) == 2:
            for key, value in json.loads(parts[1]).items():
                assert key in ["reference_path", "mask_strategy"], f"Invalid key: {key}"
                if key == "reference_path":
                    reference[i] = value
                else:
                    mask_strategy[i] = value
    return ret_prompts, reference, mask_strategy
|
||
|
||
|
||
def extract_json_from_prompts_new(prompts_with_json):
    """Split prompts of the form "<text> {json}" into parallel lists.

    Bug fix: the original never appended placeholders to `reference` /
    `mask_strategy`, so any prompt carrying JSON raised IndexError on
    `reference[i] = ...`, and prompts without JSON left the lists shorter
    than `prompts`. Each prompt now gets a "" placeholder that the JSON
    fields overwrite ("" is the no-reference / no-strategy sentinel used by
    collect_references_batch and parse_mask_strategy).

    Returns:
        (prompts, reference, mask_strategy): three equal-length lists.
    """
    prompts = []
    reference = []
    mask_strategy = []
    for i, prompt in enumerate(prompts_with_json):
        parts = re.split(r"(?=[{])", prompt)
        assert len(parts) <= 2, f"Invalid prompt: {prompt}"
        prompts.append(parts[0])
        reference.append("")
        mask_strategy.append("")
        if len(parts) > 1:
            additional_info = json.loads(parts[1])
            for key in additional_info:
                assert key in ["reference_path", "mask_strategy"], f"Invalid key: {key}"
                if key == "reference_path":
                    reference[i] = additional_info[key]
                elif key == "mask_strategy":
                    mask_strategy[i] = additional_info[key]
    return prompts, reference, mask_strategy
|
||
|
||
|
||
def collect_references_batch(reference_paths, vae, image_size):
    """Load and VAE-encode the reference clips for each sample in a batch.

    Args:
        reference_paths: one string per sample; each is a ";"-separated list
            of media paths, "" meaning "no reference".
        vae: autoencoder exposing ``micro_frame_size``, ``temporal_overlap``,
            ``encode``, ``device`` and ``dtype``.
        image_size: target size passed to ``read_from_path`` with the
            "resize_crop" transform.

    Returns:
        Per-sample lists of encoded latents ([] when no reference), i.e.
        refs_x laid out as [batch, ref_num, C, T, H, W].
    """
    refs_x = []  # refs_x: [batch, ref_num, C, T, H, W]
    for reference_path in reference_paths:
        if reference_path == "":
            refs_x.append([])
            continue
        ref_path = reference_path.split(";")
        ref = []
        for r_path in ref_path:
            r = read_from_path(r_path, image_size, transform_name="resize_crop")

            # need to ensure r has length accepted by vae
            actual_t = r.size(1)
            if vae.micro_frame_size is None:
                # no temporal micro-batching: trim to 4k+1 frames
                target_t = (actual_t - 1) // 4 * 4 + 1
            elif not vae.temporal_overlap:
                # trim to a whole number of micro-frame chunks
                target_t = actual_t // vae.micro_frame_size * vae.micro_frame_size
            else:
                # overlapping chunks share one frame: k*(m-1)+1 frames
                target_t = (actual_t - 1) // (vae.micro_frame_size - 1) * (vae.micro_frame_size - 1) + 1
            r = r[:, :target_t]

            # encode one clip at a time: add/remove the batch dim around encode
            r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype))
            r_x = r_x.squeeze(0)
            ref.append(r_x)
        refs_x.append(ref)
    return refs_x
|
||
|
||
|
||
def extract_images_from_ref_paths(reference_paths, image_size):
    """Read reference media as raw pixel tensors (no VAE encoding).

    Each entry of reference_paths is a ";"-separated path list, "" meaning
    "no reference". Every loaded tensor has its temporal dimension (dim 1)
    squeezed out before being collected.

    Returns per-sample lists of tensors (empty list for "" entries).
    """
    refs_images = []
    for entry in reference_paths:
        if entry == "":
            refs_images.append([])
            continue
        loaded = [
            read_from_path(path, image_size, transform_name="resize_crop").squeeze(1)
            for path in entry.split(";")
        ]
        refs_images.append(loaded)
    return refs_images
|
||
|
||
|
||
def extract_prompts_loop(prompts, num_loop):
    """Pick the prompt text that is active at loop `num_loop`.

    A prompt may encode per-loop text as "|0|text a|2|text b|...", where
    each |n| marks the first loop the following text applies to. Prompts
    without the "|0|" prefix are returned unchanged.
    """
    ret_prompts = []
    for prompt in prompts:
        if not prompt.startswith("|0|"):
            ret_prompts.append(prompt)
            continue
        fields = prompt.split("|")[1:]
        # expand into one entry per loop, then index by the current loop
        schedule = []
        for idx in range(0, len(fields), 2):
            begin = int(fields[idx])
            end = int(fields[idx + 2]) if idx + 2 < len(fields) else num_loop + 1
            schedule.extend([fields[idx + 1]] * (end - begin))
        ret_prompts.append(schedule[num_loop])
    return ret_prompts
|
||
|
||
|
||
def split_prompt(prompt_text):
    """Parse a multi-loop prompt into (texts, loop indices).

    Prompts look like "|0| a beautiful day |1| a sunny day |2| a rainy day":
    each |n| gives the loop index at which the following (stripped) text
    starts. A plain prompt is returned as ([prompt_text], None).
    """
    if not prompt_text.startswith("|0|"):
        return [prompt_text], None
    fields = prompt_text.split("|")[1:]
    pairs = range(0, len(fields), 2)
    loop_idx = [int(fields[i]) for i in pairs]
    text_list = [fields[i + 1].strip() for i in pairs]
    return text_list, loop_idx
|
||
|
||
|
||
def merge_prompt(text_list, loop_idx_list=None):
    """Inverse of split_prompt: rebuild the "|idx|text" encoded prompt, or
    return the single text when no loop indices are given."""
    if loop_idx_list is None:
        return text_list[0]
    return "".join(f"|{idx}|{text}" for idx, text in zip(loop_idx_list, text_list))
|
||
|
||
|
||
# Defaults for the six ","-separated fields of one mask-strategy group:
# loop id, ref id, ref start, target start, length, edit ratio.
MASK_DEFAULT = ["0", "0", "0", "0", "1", "0"]


def parse_mask_strategy(mask_strategy):
    """Parse a ";"-separated mask-strategy string into typed groups.

    Each group is "loop,ref_id,ref_start,target_start,length,edit_ratio";
    missing trailing fields fall back to MASK_DEFAULT. Returns a list of
    [int, int, int, int, int, float] groups ([] for empty or None input).
    """
    if not mask_strategy:
        return []
    parsed = []
    for group_str in mask_strategy.split(";"):
        fields = group_str.split(",")
        assert 1 <= len(fields) <= 6, f"Invalid mask strategy: {group_str}"
        fields = fields + MASK_DEFAULT[len(fields):]
        parsed.append([int(v) for v in fields[:5]] + [float(fields[5])])
    return parsed
|
||
|
||
|
||
def find_nearest_point(value, point, max_value):
    """Snap `value` to a multiple of `point`, rounding down by default and
    up only when the remainder exceeds half of `point` AND the rounded-up
    result stays at least one `point` below max_value."""
    quotient, remainder = divmod(value, point)
    if remainder > point / 2 and quotient < max_value // point - 1:
        quotient += 1
    return quotient * point
|
||
|
||
|
||
def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None):
    """Splice reference latent frames into `z` per each sample's strategy.

    Args:
        z: latent batch [B, C, T, H, W]; modified in place.
        refs_x: per-sample lists of reference latents, each [C, T, H, W].
        mask_strategys: per-sample strategy strings (see parse_mask_strategy).
        loop_i: current loop index; only groups with this loop id apply.
        align: optional frame granularity to snap start indices to.

    Returns:
        [B, T] float mask of edit ratios (1.0 = frame untouched), or None
        when `mask_strategys` is empty.
    """
    masks = []
    no_mask = True
    # NOTE(review): no_mask is cleared for every entry, even an empty
    # strategy string, so None is returned only for an empty input list.
    for i, mask_strategy in enumerate(mask_strategys):
        no_mask = False
        mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device)
        mask_strategy = parse_mask_strategy(mask_strategy)
        for mst in mask_strategy:
            loop_id, m_id, m_ref_start, m_target_start, m_length, edit_ratio = mst
            if loop_id != loop_i:
                continue
            ref = refs_x[i][m_id]

            # negative starts count back from the end
            if m_ref_start < 0:
                # ref: [C, T, H, W]
                m_ref_start = ref.shape[1] + m_ref_start
            if m_target_start < 0:
                # z: [B, C, T, H, W]
                m_target_start = z.shape[2] + m_target_start
            if align is not None:
                # snap both starts onto the alignment grid
                m_ref_start = find_nearest_point(m_ref_start, align, ref.shape[1])
                m_target_start = find_nearest_point(m_target_start, align, z.shape[2])
            # clamp so the copied span fits inside both tensors
            m_length = min(m_length, z.shape[2] - m_target_start, ref.shape[1] - m_ref_start)
            z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length]
            mask[m_target_start : m_target_start + m_length] = edit_ratio
        masks.append(mask)
    if no_mask:
        return None
    masks = torch.stack(masks)
    return masks
|
||
|
||
|
||
def append_generated(
    vae, generated_video, refs_x, mask_strategy, loop_i, condition_frame_length, condition_frame_edit, is_latent=False
):
    """Append the latest generation as an extra reference for the next loop.

    Encodes `generated_video` with the VAE (skipped when `is_latent`),
    appends the per-sample latents to `refs_x`, and extends each sample's
    mask-strategy string so the last `condition_frame_length` frames of the
    new reference condition the head of the next loop.

    Fix: when a sample's reference list was None, the original rebound only
    refs_x[j], then evaluated len(refs) on the still-None local while
    formatting the strategy string (TypeError). The fresh list is now bound
    to `refs` as well.

    Returns (refs_x, mask_strategy); both are also mutated in place.
    """
    ref_x = vae.encode(generated_video) if not is_latent else generated_video

    for j, refs in enumerate(refs_x):
        if refs is None:
            refs = refs_x[j] = [ref_x[j]]
        else:
            refs.append(ref_x[j])
        if mask_strategy[j] is None or mask_strategy[j] == "":
            mask_strategy[j] = ""
        else:
            mask_strategy[j] += ";"  # separate the new group from existing ones
        mask_strategy[
            j
        ] += f"{loop_i},{len(refs)-1},-{condition_frame_length},0,{condition_frame_length},{condition_frame_edit}"
    return refs_x, mask_strategy
|
||
|
||
|
||
def dframe_to_frame(num):
    """Convert a latent ("dframe") count to a pixel-frame count: each group
    of 5 latent frames corresponds to 17 video frames."""
    assert num % 5 == 0, f"Invalid num: {num}"
    groups = num // 5
    return groups * 17
|
||
|
||
|
||
# Shared, lazily-initialized OpenAI client (created in get_openai_response).
OPENAI_CLIENT = None
# Cached refinement system prompt; built from REFINE_PROMPTS_PATH on first use.
REFINE_PROMPTS = None
# REFINE_PROMPTS_PATH = "assets/texts/t2v_pllava.txt"
REFINE_PROMPTS_PATH = "assets/texts/t2v_demo.txt"
# System-prompt template for refining a user prompt; "{}" receives the example prompts.
REFINE_PROMPTS_TEMPLATE = """
You need to refine user's input prompt. The user's input prompt is used for video generation task. You need to refine the user's prompt to make it more suitable for the task. Here are some examples of refined prompts:
{}

The refined prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The video should not have any scene transitions and must be on the same scene. The refined prompt should be in English.
"""
# Cached prompt-generation system prompt (same example file, different template).
RANDOM_PROMPTS = None
# System-prompt template for generating a prompt from scratch; "{}" receives the examples.
RANDOM_PROMPTS_TEMPLATE = """
You need to generate one input prompt for video generation task. The prompt should be suitable for the task. Here are some examples of refined prompts:
{}

The prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The video should not have any scene transitions and must be on the same scene. The prompt should be in English.
"""
|
||
|
||
|
||
def get_openai_response(sys_prompt, usr_prompt, model="gpt-4o"):
    """Send one system+user exchange to the OpenAI chat API and return the
    reply text. Creates the shared module-level client on first use,
    reading OPENAI_API_KEY from the environment."""
    global OPENAI_CLIENT
    if OPENAI_CLIENT is None:
        from openai import OpenAI

        OPENAI_CLIENT = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    messages = [
        {"role": "system", "content": sys_prompt},  # context for the model
        {"role": "user", "content": usr_prompt},  # message to respond to
    ]
    completion = OPENAI_CLIENT.chat.completions.create(model=model, messages=messages)
    return completion.choices[0].message.content
|
||
|
||
|
||
def get_random_prompt_by_openai():
    """Ask the OpenAI API for one example video-generation prompt, caching
    the system prompt (built from the example file) on first call."""
    global RANDOM_PROMPTS
    if RANDOM_PROMPTS is None:
        RANDOM_PROMPTS = RANDOM_PROMPTS_TEMPLATE.format("\n".join(load_prompts(REFINE_PROMPTS_PATH)))
    return get_openai_response(RANDOM_PROMPTS, "Generate one example.")
|
||
|
||
|
||
def refine_prompt_by_openai(prompt):
    """Refine a user prompt with the OpenAI API, caching the refinement
    system prompt (built from the example file) on first call."""
    global REFINE_PROMPTS
    if REFINE_PROMPTS is None:
        REFINE_PROMPTS = REFINE_PROMPTS_TEMPLATE.format("\n".join(load_prompts(REFINE_PROMPTS_PATH)))
    return get_openai_response(REFINE_PROMPTS, prompt)
|
||
|
||
|
||
def has_openai_key():
    """Return True when the OPENAI_API_KEY environment variable is set."""
    return os.environ.get("OPENAI_API_KEY") is not None
|
||
|
||
|
||
def refine_prompts_by_openai(prompts):
    """Refine every prompt via OpenAI; an empty prompt is replaced by a
    freshly generated one. Any failure falls back to the original prompt."""
    refined = []
    for original in prompts:
        try:
            if original.strip():
                result = refine_prompt_by_openai(original)
                print(f"[Info] Refine prompt: {original} -> {result}")
            else:
                result = get_random_prompt_by_openai()
                print(f"[Info] Empty prompt detected, generate random prompt: {result}")
            refined.append(result)
        except Exception as e:
            print(f"[Warning] Failed to refine prompt: {original} due to {e}")
            refined.append(original)
    return refined
|
||
|
||
|
||
# System prompts for GPT image-to-caption requests (see get_caption).
# Fix: corrected the "this vide" -> "this video" typo in both GPT-facing strings.
FIRST_FRAME_PROMPT_TEMPLATE_WITH_INFO = """
Given the first frame of the video, describe this video and its style in a very detailed manner. Some information about the video is:
'{}'.

Describe the video and its style in a very detailed manner. Pay attention to all objects in the video. You must describe what the human character is doing with action in the video, for instance, talk, walk, blink, laugh, sing or anything else that involves movements in the video. Your description must make it easy for this video to have human movements, instead of being motionless. The description should be useful for AI to re-generate the video. The description should be no more than six sentences.

"""

FIRST_FRAME_PROMPT_TEMPLATE = """
Given the first frame of the video, you need to generate one input prompt for video generation task. The prompt should be suitable for generating a video using the given image as the first frame.

Describe the video and its style in a very detailed manner. Pay attention to all objects in the video. You must describe what the human character is doing with action in the video, for instance, talk, walk, blink, laugh, sing or anything else that involves movements in the video. Your description must make it easy for this video to have human movements, instead of being motionless. The description should be useful for AI to re-generate the video. The description should be no more than six sentences.
"""
|
||
|
||
|
||
def to_base64(image_tensor):
    """Encode an image tensor as a base64 JPEG string (the format expected
    by the OpenAI image_url payload in get_caption)."""
    pil_image = transforms.ToPILImage()(image_tensor)
    buffer = BytesIO()
    pil_image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||
|
||
|
||
# Boilerplate openers emitted by caption models; stripped from captions so
# the remaining text starts with the actual content.
LLAVA_PREFIX = [
    "The video shows ",
    "The video captures ",
    "The video features ",
    "The video depicts ",
    "The video presents ",
    "The video features ",
    "The video is ",
    "In the video, ",
    "The image shows ",
    "The image captures ",
    "The image features ",
    "The image depicts ",
    "The image presents ",
    "The image features ",
    "The image is ",
    "The image portrays ",
    "In the image, ",
]


def remove_caption_prefix(caption):
    """Strip a leading boilerplate prefix (matched as-is or lowercased) and
    re-capitalize the remaining text.

    Fix: the capitalization step is now guarded so a caption consisting of
    nothing but a prefix returns "" instead of raising IndexError on
    caption[0].
    """
    for prefix in LLAVA_PREFIX:
        if caption.startswith(prefix) or caption.startswith(prefix.lower()):
            caption = caption[len(prefix) :].strip()
            if caption and caption[0].islower():
                caption = caption[0].upper() + caption[1:]
            return caption
    return caption
|
||
|
||
|
||
def get_caption(frame, prompt):
    """Caption a base64-encoded JPEG frame with GPT-4o.

    Args:
        frame: base64 JPEG string (see to_base64).
        prompt: system prompt guiding the caption style.

    Returns the caption, newline-flattened, with boilerplate openers
    stripped and " image " replaced by " video ".

    Fix: the request was issued on the undefined name `client` (NameError);
    it now uses the lazily created module-level OPENAI_CLIENT, mirroring
    get_openai_response.
    """
    global OPENAI_CLIENT
    if OPENAI_CLIENT is None:
        from openai import OpenAI

        OPENAI_CLIENT = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    response = OPENAI_CLIENT.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=[
            {"role": "system", "content": prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}", "detail": "low"}},
                ],
            },
        ],
        max_tokens=300,
        top_p=0.1,
    )
    caption = response.choices[0].message.content
    caption = caption.replace("\n", " ")
    # normalize: drop the boilerplate opener and speak of a "video", not an "image"
    caption = remove_caption_prefix(caption).replace(" image ", " video ")
    return caption
|
||
|
||
|
||
def refine_batched_prompts_with_images(prompts, images):
    """Refine each (prompt, image) pair via GPT captioning; empty prompts
    are replaced by freshly generated ones. Any failure falls back to the
    original prompt."""
    refined = []
    for prompt, image in zip(prompts, images):
        try:
            if prompt.strip():
                result = refine_prompt_with_image(prompt, image)
                print(f"[Info] Refine prompt: {prompt} -> {result}")
            else:
                result = get_random_prompt_with_image(image)
                print(f"[Info] Empty prompt detected, generate random prompt: {result}")
            refined.append(result)
        except Exception as e:
            print(f"[Warning] Failed to refine prompt: {prompt} due to {e}")
            refined.append(prompt)
    return refined
|
||
|
||
|
||
def refine_prompt_with_image(prompt, image):
    """Refine `prompt` using its first frame `image` via GPT captioning;
    returns the prompt unchanged when no OpenAI key is configured."""
    # check api keys
    if not has_openai_key():
        print("no openai api key found, prompt not refined")
        return prompt
    os.environ.get("OPENAI_API_KEY")  # kept from the original; value unused
    frame = to_base64(image)
    return get_caption(frame, FIRST_FRAME_PROMPT_TEMPLATE_WITH_INFO.format(prompt))
|
||
|
||
|
||
def get_random_prompt_with_image(image):
    """Generate a video-generation prompt from an image via GPT captioning.

    Fix: the no-key branch returned the undefined name `prompt` (NameError,
    since this function has no such parameter); it now returns "" to signal
    that no prompt could be generated.
    """
    # check api keys
    if not has_openai_key():
        print("no openai api key found, prompt not refined")
        return ""

    frame = to_base64(image)
    return get_caption(frame, FIRST_FRAME_PROMPT_TEMPLATE)
|
||
|
||
|
||
def add_watermark(
    input_video_path, watermark_image_path="./assets/images/watermark/watermark.png", output_video_path=None
):
    """Overlay a watermark image onto a video with ffmpeg.

    The watermark is scaled (scale2ref, 10% of video height) and overlaid;
    the output defaults to "<input>_watermark.mp4". Returns True when the
    ffmpeg process exits successfully.

    Fix: the three paths are now shell-quoted so names containing spaces or
    shell metacharacters neither break the command nor allow injection.
    """
    import shlex

    if output_video_path is None:
        output_video_path = input_video_path.replace(".mp4", "_watermark.mp4")
    cmd = (
        f"ffmpeg -y -i {shlex.quote(input_video_path)} -i {shlex.quote(watermark_image_path)} "
        f'-filter_complex "[1][0]scale2ref=oh*mdar:ih*0.1[logo][video];[video][logo]overlay" '
        f"{shlex.quote(output_video_path)}"
    )
    exit_code = os.system(cmd)
    is_success = exit_code == 0
    return is_success
|
||
|
||
|
||
def super_resolution(input_video_path, sr=2):
    """Upscale a video with Real-ESRGAN and copy the result next to the input.

    Fix: the upscale factor passed to the tool was hard-coded to ``-s 2``
    even though `sr` was used for the output file name; the command now uses
    ``-s {sr}``. The temp directory is also cleaned up even when the copy
    fails.

    Returns the path of the upscaled copy ("<name>_sr<sr>.mp4").
    """
    temp_dir = tempfile.TemporaryDirectory()
    try:
        cmd = (
            f"python -m tools.repair.inference_realesrgan_video -n RealESRGAN_x4plus "
            f"-s {sr} -i {input_video_path} -o {temp_dir.name}"
        )
        os.system(cmd)
        stem = os.path.basename(input_video_path).split(".")[0]
        output_video_path = os.path.join(temp_dir.name, stem + "_out.mp4")
        dst_video_path = os.path.join(os.path.dirname(input_video_path), stem + f"_sr{sr:.0f}.mp4")
        shutil.copyfile(output_video_path, dst_video_path)
    finally:
        temp_dir.cleanup()
    return dst_video_path
|
||
|
||
|
||
def deflicker_video_local_brightness(input_video_path, output_video_path, smoothing_window=30, block_size=32):
    """Reduce brightness flicker by smoothing per-tile brightness over time.

    Splits each frame into block_size x block_size tiles, tracks each tile's
    mean gray level across frames, smooths those series with a moving
    average of length smoothing_window, then rescales every tile of every
    frame toward its smoothed brightness and writes the result (h264, same
    fps/size) to output_video_path.

    NOTE(review): a tile with mean brightness 0 would divide by zero in the
    correction step — confirm inputs never contain fully black tiles.
    """
    # open the input video
    cap = cv2.VideoCapture(input_video_path)

    # basic stream properties
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # configure the output codec/container
    fourcc = cv2.VideoWriter_fourcc(*"h264")
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

    # per-frame record of each tile's mean brightness
    local_brightness_list = []

    # read all frames and measure local brightness
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)

        # convert to grayscale
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

        # split into tiles and record each tile's mean brightness
        local_brightness = []
        for y in range(0, height, block_size):
            for x in range(0, width, block_size):
                block = gray[y : y + block_size, x : x + block_size]
                mean_brightness = np.mean(block)
                local_brightness.append(mean_brightness)

        local_brightness_list.append(local_brightness)

    # smooth the per-tile brightness series over time
    local_brightness_array = np.array(local_brightness_list)

    # pad both ends so the moving average keeps the original length
    assert total_frames % 2 == 1, "The number of frames should be odd."
    pad_width = smoothing_window // 2
    padded_local_brightness = np.pad(local_brightness_array, ((pad_width, pad_width - 1), (0, 0)), mode="edge")

    # smoothed series, same shape as local_brightness_array
    smoothed_local_brightness = np.zeros_like(local_brightness_array)

    # smooth each tile's series with a moving-average convolution
    for i in range(local_brightness_array.shape[1]):
        # "valid" convolution over the padded series restores the original length
        smoothed_brightness = np.convolve(
            padded_local_brightness[:, i], np.ones(smoothing_window) / smoothing_window, mode="valid"
        )
        smoothed_local_brightness[:, i] = smoothed_brightness

    # apply the local brightness correction to every frame
    for i in range(len(frames)):
        frame = frames[i].copy()
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

        # rescale tile by tile toward the smoothed brightness
        idx = 0
        for y in range(0, height, block_size):
            for x in range(0, width, block_size):
                block = gray[y : y + block_size, x : x + block_size]
                current_brightness = np.mean(block)
                brightness_ratio = smoothed_local_brightness[i, idx] / current_brightness
                frame[y : y + block_size, x : x + block_size] = np.clip(
                    frame[y : y + block_size, x : x + block_size] * brightness_ratio, 0, 255
                ).astype(np.uint8)
                idx += 1

        out.write(frame)

    # release resources
    cap.release()
    out.release()
    print(f"Deflickered video saved as {output_video_path}")
|
||
|
||
|
||
def deflicker(input_video_path):
    """Run local-brightness deflickering, writing "<name>_deflicker.mp4".

    NOTE(review): this returns the ORIGINAL input path, not the
    "_deflicker" output written by deflicker_video_local_brightness —
    confirm whether callers are meant to receive the unprocessed video.
    """
    deflicker_video_local_brightness(input_video_path, input_video_path.replace(".mp4", "_deflicker.mp4"))
    return input_video_path
|
||
|
||
|
||
# bytes per gibibyte, used for the memory report below
GIGABYTE = 1024**3


def print_memory_usage(prefix: str, device: torch.device) -> None:
    """Print peak CUDA memory (allocated and reserved) for `device`,
    each line prefixed with `prefix`. Synchronizes the device first so the
    numbers reflect all queued work."""
    torch.cuda.synchronize()
    max_memory_allocated = torch.cuda.max_memory_allocated(device)
    max_memory_reserved = torch.cuda.max_memory_reserved(device)
    print(f"{prefix}: max memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB")
    print(f"{prefix}: max memory reserved: {max_memory_reserved / GIGABYTE:.4f} GB")
|
||
|
||
|
||
def easy_data(csv):
    """Debug helper: build a VariableVideoTextDataset over `csv` (resize_crop
    transform) and return the sample keyed "0-113-360-640"."""
    from opensora.registry import DATASETS, build_module

    cfg = {
        "type": "VariableVideoTextDataset",
        "transform_name": "resize_crop",
        "data_path": csv,
    }
    dataset = build_module(cfg, DATASETS)
    return dataset["0-113-360-640"]
|
||
|
||
|
||
def prep_ref_and_mask(cond_type, condition_frame_length, refs, target_shape, loop, device, dtype):
    """
    prepare the mask_index and reference for the 1st loop

    Input:
        cond_type: None, "v2v_head", "i2v_head", "i2v_loop" or "i2v_tail"
        condition_frame_length: number of latent frames to condition on
            ("v2v_head" only; clamped to the shortest reference and latent_t)
        refs: per-sample lists of reference latents, each [C, T, H, W]
        target_shape: latent shape [B, C, T, H, W] of the generation target
        loop: total number of loops to do
        device / dtype: placement of the returned reference tensor

    Returns:
        ref: zeros of target_shape with reference frames written at the
            mask_index positions
        mask_index: latent frame indices that are conditioned on
    """
    latent_t = target_shape[2]

    if cond_type is None:
        mask_index = []

    elif cond_type == "v2v_head":
        min_ref_length = min([ref[0].shape[1] for ref in refs])
        condition_frame_length = min(
            min(condition_frame_length, min_ref_length), latent_t
        )  # ensure condition frame is no more than generated length
        mask_index = [i for i in range(condition_frame_length)]
    elif cond_type == "i2v_head" or cond_type == "i2v_loop":
        mask_index = [0]  # update mask on last frame for i2v_loop
    elif cond_type == "i2v_tail":
        if loop == 1:
            mask_index = [-1]  # update mask to be positive later
        else:
            mask_index = []  # cond on last frame in final loop
    else:
        raise NotImplementedError

    # prep ref in correct shape
    ref = torch.zeros(target_shape, device=device, dtype=dtype)

    if len(mask_index) > 0:
        b = target_shape[0]
        for b_i in range(b):
            ref[b_i, :, mask_index] = refs[b_i][0][:, mask_index].unsqueeze(0)

    # get finalized mask_index, except i2v_tail and i2v_loop intermediate loops will update later
    if cond_type == "i2v_loop" and loop <= 1:
        b = target_shape[0]
        for b_i in range(b):
            # NOTE(review): this checks len(refs[0]) while the sibling
            # prep_ref_and_update_mask_in_loop checks len(refs[b_i]) —
            # confirm whether per-sample reference counts can differ.
            if len(refs[0]) == 1:  # if only 1 ref, use last frame
                ref[b_i, :, -1] = refs[b_i][0][:, -1].unsqueeze(0)
            else:
                ref[b_i, :, -1] = refs[b_i][1][:, 0].unsqueeze(0)  # CHANGED TO USE IMAGE
        mask_index.append(latent_t - 1)
    if cond_type == "i2v_tail" and loop <= 1:
        mask_index = [latent_t - 1]

    return ref, mask_index
|
||
|
||
|
||
def prep_ref_and_update_mask_in_loop(
    cond_type, condition_frame_length, samples, refs, target_shape, is_last_loop, device, dtype
):
    """Build the reference tensor and mask_index for loops after the first.

    The last condition_frame_length latent frames of the previous
    generation (`samples`) are copied to the head of a zero reference
    tensor of target_shape; for "i2v_loop" (and "i2v_tail" on the last
    loop) the final frame is additionally conditioned on the stored
    reference.

    Returns (ref, mask_index), analogous to prep_ref_and_mask.
    """
    latent_t = target_shape[2]
    # cond frames from last generation
    loop_cond_index = [i for i in range(-condition_frame_length, 0)]

    # get ref in correct shape
    ref = torch.zeros(target_shape, device=device, dtype=dtype)
    ref_cut = samples[:, :, loop_cond_index].to(device=device, dtype=dtype)
    mask_index = [i for i in range(condition_frame_length)]
    ref[:, :, mask_index] = ref_cut

    # NOTE(review): `and` binds tighter than `or`, so this reads as
    # i2v_loop OR (i2v_tail AND is_last_loop) — i2v_loop enters on every
    # loop regardless of is_last_loop; confirm that is intended.
    if cond_type == "i2v_loop" or cond_type == "i2v_tail" and is_last_loop:
        b = target_shape[0]
        for b_i in range(b):
            if len(refs[b_i]) == 1:  # if only 1 reference passed, use last frame
                ref[b_i, :, -1] = refs[b_i][0][:, -1].unsqueeze(0).to(device=device, dtype=dtype)
            else:  # use the last frame (either video or image) of second reference
                ref[b_i, :, -1] = refs[b_i][1][:, 0].unsqueeze(0).to(device=device, dtype=dtype)

        mask_index.append(latent_t - 1)  # mask_index for final loop

    return ref, mask_index
|