Merge branch 'dev/v1.0.1' of github.com:hpcaitech/Open-Sora-dev into dev/Rflow

This commit is contained in:
tianyi 2024-04-15 19:42:45 +08:00
commit 62d96a2ced
39 changed files with 936 additions and 1139 deletions

View file

@ -22,7 +22,7 @@ scheduler = dict(
num_sampling_steps=20,
cfg_scale=4.0,
)
dtype = "fp16"
dtype = "bf16"
# Others
batch_size = 2

View file

@ -22,7 +22,7 @@ scheduler = dict(
num_sampling_steps=20,
cfg_scale=4.0,
)
dtype = "fp16"
dtype = "bf16"
# Others
batch_size = 2

View file

@ -23,7 +23,7 @@ scheduler = dict(
num_sampling_steps=20,
cfg_scale=4.0,
)
dtype = "fp16"
dtype = "bf16"
# Others
batch_size = 2

View file

@ -1,16 +1,16 @@
num_frames = 16
frame_interval = 3
image_size = (256, 256)
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=16,
frame_interval=3,
image_size=(256, 256),
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = False
grad_checkpoint = True
plugin = "zero2"
sp_size = 1

View file

@ -21,7 +21,7 @@ scheduler = dict(
num_sampling_steps=20,
cfg_scale=4.0,
)
dtype = "fp16"
dtype = "bf16"
# Others
batch_size = 2

View file

@ -22,7 +22,7 @@ scheduler = dict(
num_sampling_steps=20,
cfg_scale=4.0,
)
dtype = "fp16"
dtype = "bf16"
# Others
batch_size = 2

View file

@ -1,14 +1,14 @@
num_frames = 16
frame_interval = 3
image_size = (256, 256)
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=16,
frame_interval=3,
image_size=(256, 256),
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"

View file

@ -0,0 +1,70 @@
# Stage training config: STDiT2-XL/2 with variable-resolution bucket sampling.
# NOTE: keys in this file are read by name by the config parser — do not rename.

# Define dataset
dataset = dict(
type="VariableVideoTextDataset",
data_path=None,
num_frames=None,
frame_interval=3,
image_size=(None, None),
transform_name="resize_crop",
)
# Bucket table: resolution -> {num_frames: (value_a, value_b)}.
# NOTE(review): presumably (keep/sampling probability, batch size) per bucket,
# with (0.0, None) disabling a bucket — confirm against the bucket sampler.
bucket_config = { # 6s/it
"240p": {16: (1.0, 17), 32: (1.0, 9), 64: (1.0, 4), 128: (1.0, 2)},
"256": {1: (1.0, 254)},
"512": {1: (0.5, 86)},
"480p": {1: (0.4, 54), 16: (0.4, 4), 32: (0.0, None)},
"720p": {16: (0.1, 2), 32: (0.0, None)}, # No examples now
"1024": {1: (0.3, 20)},
"1080p": {1: (0.4, 8)},
}
# Per-strategy masking probabilities; the five entries sum to 1.0.
mask_ratios = {
"mask_no": 0.9,
"mask_random": 0.06,
"mask_head": 0.01,
"mask_tail": 0.01,
"mask_head_tail": 0.02,
}
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
# Define model
model = dict(
type="STDiT2-XL/2",
from_pretrained=None,
input_sq_size=512, # pretrained model is trained on 512x512
enable_flashattn=True,
enable_layernorm_kernel=True,
)
vae = dict(
type="VideoAutoencoderKL",
from_pretrained="stabilityai/sd-vae-ft-ema",
micro_batch_size=4,
)
text_encoder = dict(
type="t5",
from_pretrained="DeepFloyd/t5-v1_1-xxl",
model_max_length=200,
shardformer=True,
)
scheduler = dict(
type="iddpm-speed",
timestep_respacing="",
)
# Others
seed = 42
outputs = "outputs"
wandb = False
epochs = 1000
log_every = 10
ckpt_every = 500
load = None
# batch_size is None because the bucket table above fixes per-bucket batch sizes.
batch_size = None
lr = 2e-5
grad_clip = 1.0

View file

@ -28,7 +28,7 @@ scheduler = dict(
num_sampling_steps=100,
cfg_scale=7.0,
)
dtype = "fp16"
dtype = "bf16"
# Condition
prompt_path = None

View file

@ -29,7 +29,7 @@ scheduler = dict(
num_sampling_steps=100,
cfg_scale=7.0,
)
dtype = "fp16"
dtype = "bf16"
# Condition
prompt_path = None

View file

@ -29,7 +29,7 @@ scheduler = dict(
num_sampling_steps=100,
cfg_scale=7.0,
)
dtype = "fp16"
dtype = "bf16"
# Condition
prompt_path = None

View file

@ -29,7 +29,7 @@ scheduler = dict(
num_sampling_steps=100,
cfg_scale=7.0,
)
dtype = "fp16"
dtype = "bf16"
# Condition
prompt_path = None

View file

@ -27,7 +27,7 @@ scheduler = dict(
cfg_scale=7.0,
cfg_channel=3, # or None
)
dtype = "fp16"
dtype = "bf16"
# Condition
prompt_path = "./assets/texts/t2v_samples.txt"

View file

@ -26,7 +26,7 @@ scheduler = dict(
num_sampling_steps=100,
cfg_scale=7.0,
)
dtype = "fp16"
dtype = "bf16"
# Others
batch_size = 2

View file

@ -26,7 +26,7 @@ scheduler = dict(
num_sampling_steps=100,
cfg_scale=7.0,
)
dtype = "fp16"
dtype = "bf16"
# Others
batch_size = 1

View file

@ -1,14 +1,14 @@
num_frames = 16
frame_interval = 3
image_size = (256, 256)
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=16,
frame_interval=3,
image_size=(256, 256),
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
@ -23,7 +23,13 @@ model = dict(
enable_flashattn=True,
enable_layernorm_kernel=True,
)
mask_ratios = [0.7, 0.15, 0.05, 0.05, 0.05]
mask_ratios = {
"mask_no": 0.7,
"mask_random": 0.15,
"mask_head": 0.05,
"mask_tail": 0.05,
"mask_head_tail": 0.05,
}
vae = dict(
type="VideoAutoencoderKL",
from_pretrained="stabilityai/sd-vae-ft-ema",

View file

@ -23,7 +23,13 @@ model = dict(
enable_flashattn=True,
enable_layernorm_kernel=True,
)
mask_ratios = [0.5, 0.29, 0.07, 0.07, 0.07]
mask_ratios = {
"mask_no": 0.5,
"mask_random": 0.29,
"mask_head": 0.07,
"mask_tail": 0.07,
"mask_head_tail": 0.07,
}
vae = dict(
type="VideoAutoencoderKL",
from_pretrained="stabilityai/sd-vae-ft-ema",

View file

@ -1,14 +1,14 @@
num_frames = 16
frame_interval = 3
image_size = (256, 256)
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=16,
frame_interval=3,
image_size=(256, 256),
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"

View file

@ -1,55 +0,0 @@
# Training config: STDiT-XL/2 at fixed 512x512, 16 frames.
# NOTE: keys in this file are read by name by the config parser — do not rename.
num_frames = 16
frame_interval = 3
image_size = (512, 512)
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
# Define acceleration
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
# Define model
model = dict(
type="STDiT-XL/2",
space_scale=1.0,
time_scale=1.0,
from_pretrained=None,
enable_flashattn=True,
enable_layernorm_kernel=True,
)
# Legacy list form of masking probabilities; the five entries sum to 1.0.
# NOTE(review): presumably ordered [no, random, head, tail, head_tail] as in
# the dict-form configs elsewhere in this commit — confirm.
mask_ratios = [0.7, 0.15, 0.05, 0.05, 0.05]
vae = dict(
type="VideoAutoencoderKL",
from_pretrained="stabilityai/sd-vae-ft-ema",
micro_batch_size=128,
)
text_encoder = dict(
type="t5",
from_pretrained="DeepFloyd/t5-v1_1-xxl",
model_max_length=120,
shardformer=True,
)
scheduler = dict(
type="iddpm",
timestep_respacing="",
)
# Others
seed = 42
outputs = "outputs"
wandb = False
epochs = 1000
log_every = 10
ckpt_every = 1000
load = None
batch_size = 8
lr = 2e-5
grad_clip = 1.0

View file

@ -1,16 +1,16 @@
num_frames = 16
frame_interval = 3
image_size = (512, 512)
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=16,
frame_interval=3,
image_size=(512, 512),
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = False
grad_checkpoint = True
plugin = "zero2"
sp_size = 1

View file

@ -1,12 +1,18 @@
num_frames = 360
frame_interval = 1
image_size = (512, 512)
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = False
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=360,
frame_interval=3,
image_size=(512, 512),
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
# Define acceleration
dtype = "bf16"

View file

@ -1,17 +1,17 @@
num_frames = 64
frame_interval = 2
image_size = (512, 512)
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=16,
frame_interval=3,
image_size=(512, 512),
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2-seq"
plugin = "zero2"
sp_size = 2
# Define model

View file

@ -1,14 +1,14 @@
num_frames = 64
frame_interval = 2
image_size = (512, 512)
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=64,
frame_interval=3,
image_size=(512, 512),
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"

View file

@ -23,7 +23,7 @@ scheduler = dict(
num_sampling_steps=20,
cfg_scale=7.0,
)
dtype = "fp16"
dtype = "bf16"
# Others
batch_size = 2

View file

@ -25,7 +25,7 @@ scheduler = dict(
num_sampling_steps=20,
cfg_scale=7.0,
)
dtype = "fp16"
dtype = "bf16"
# Others
batch_size = 2

View file

@ -24,7 +24,7 @@ scheduler = dict(
num_sampling_steps=20,
cfg_scale=7.0,
)
dtype = "fp16"
dtype = "bf16"
# Others
batch_size = 2

View file

@ -24,7 +24,7 @@ scheduler = dict(
num_sampling_steps=20,
cfg_scale=7.0,
)
dtype = "fp16"
dtype = "bf16"
# prompt_path = "./assets/texts/t2i_samples.txt"
prompt = [

View file

@ -1,16 +1,16 @@
num_frames = 16
frame_interval = 3
image_size = (256, 256)
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=16,
frame_interval=3,
image_size=(256, 256),
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = False
grad_checkpoint = True
plugin = "zero2"
sp_size = 1

View file

@ -1,14 +1,14 @@
num_frames = 1
frame_interval = 1
image_size = (512, 512)
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = True
num_workers = 4
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=1,
frame_interval=3,
image_size=(512, 512),
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"

View file

@ -1,19 +1,20 @@
num_frames = 64
frame_interval = 2
image_size = (512, 512)
# Define dataset
root = None
data_path = "CSV_PATH"
use_image_transform = False
num_workers = 4
dataset = dict(
type="VideoTextDataset",
data_path=None,
num_frames=64,
frame_interval=3,
image_size=(256, 256),
)
# Define acceleration
num_workers = 4
dtype = "bf16"
grad_checkpoint = True
plugin = "zero2"
sp_size = 1
# Define model
model = dict(
type="PixArt-XL/2",

View file

@ -1,46 +0,0 @@
import argparse
import json

import pandas as pd
from pandarallel import pandarallel


def process_csv(original_csv_path, original_video_format="mp4"):
    """Attach per-video captions to a CSV of video paths.

    For every row, the video path with its extension swapped to ``.json``
    is read and that file's ``caption`` field is stored in a new
    ``caption`` column.  The result is written next to the input as
    ``<name>_caption.csv``.
    """
    pandarallel.initialize(progress_bar=True)
    # '<name>.csv' -> '<name>_caption.csv'
    caption_file = original_csv_path.replace(".csv", "_caption.csv")

    frame = pd.read_csv(original_csv_path)
    # Placeholder value; overwritten row-by-row below.
    frame["caption"] = "None"

    def fill_caption(row):
        # Sibling JSON file shares the video's basename.
        json_path = row["path"].replace(original_video_format, "json")
        with open(json_path, "r") as fp:
            row["caption"] = json.load(fp)["caption"]
        return row

    # Row-wise parallel apply: read each JSON and extract its caption.
    frame = frame.parallel_apply(fill_caption, axis=1)

    frame.to_csv(caption_file, index=False)
    print(f"New CSV file with captions is saved as {caption_file}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process a CSV file to add video captions.")
    parser.add_argument("csv_path", type=str, help="The path to the original CSV file.")
    parser.add_argument("video_format", nargs="?", default="mp4", help="The original video format (default: mp4).")
    args = parser.parse_args()
    process_csv(args.csv_path, args.video_format)

File diff suppressed because it is too large Load diff

View file

@ -21,7 +21,10 @@
# T5: https://github.com/google-research/text-to-text-transfer-transformer
# --------------------------------------------------------
import html
import re
import ftfy
import torch
from transformers import AutoTokenizer, T5EncoderModel
@ -96,12 +99,8 @@ class T5Embedder:
self.hf_token = hf_token
assert from_pretrained in self.available_models
self.tokenizer = AutoTokenizer.from_pretrained(
from_pretrained, cache_dir=cache_dir
)
self.model = T5EncoderModel.from_pretrained(
from_pretrained, cache_dir=cache_dir, **t5_model_kwargs
).eval()
self.tokenizer = AutoTokenizer.from_pretrained(from_pretrained, cache_dir=cache_dir)
self.model = T5EncoderModel.from_pretrained(from_pretrained, cache_dir=cache_dir, **t5_model_kwargs).eval()
self.model_max_length = model_max_length
def get_text_embeddings(self, texts):
@ -185,3 +184,142 @@ class T5Encoder:
def null(self, n):
null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
return null_y
def basic_clean(text):
    """Repair mojibake via ftfy, unescape HTML entities (twice, to handle
    double-encoded input), and trim surrounding whitespace."""
    repaired = ftfy.fix_text(text)
    unescaped = html.unescape(html.unescape(repaired))
    return unescaped.strip()
# Runs of junk punctuation/symbols; clean_caption squashes each run to one space.
# Rewritten as a single raw string: the old concatenation of non-raw pieces
# ("\)", "\(", ...) relied on invalid escape sequences (a SyntaxWarning on
# modern Python). The compiled character class is unchanged:
# {# ® • © ™ & @ · º ½ ¾ ¿ ¡ § ~ ) ( ] [ } { | \ / *}.
BAD_PUNCT_REGEX = re.compile(
    r"[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\\/\*]{1,}"
)  # noqa
def clean_caption(caption):
    """Aggressively clean a raw caption for T5 training.

    Strips URLs, HTML markup, @-mentions, CJK blocks, ids/hashes,
    filenames and junk punctuation, and normalizes quotes, dashes and
    whitespace (DeepFloyd/PixArt-style pipeline).

    Returns the cleaned, stripped caption (possibly empty).
    """
    import urllib.parse as ul

    from bs4 import BeautifulSoup

    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @<nickname>
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################
    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # unify quote characters to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    # BUGFIX: this class lost its contents to a mojibake round-trip — the
    # source read r"[]", an invalid empty character class that raises
    # re.error on every call. Restored to the upstream curly-apostrophe
    # normalization (U+2018/U+2019 -> ').
    caption = re.sub(r"[\u2018\u2019]", "'", caption)

    # &quot;
    caption = re.sub(r"&quot;?", "", caption)
    # &amp
    caption = re.sub(r"&amp", "", caption)

    # ip addresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # \n
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    #
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(BAD_PUNCT_REGEX, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = basic_clean(caption)

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # BUGFIX: the original bare `caption.strip()` discarded its result;
    # assign it so the trim actually happens before the anchored patterns below.
    caption = caption.strip()

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
def text_preprocessing(text, use_text_preprocessing: bool = True):
    """Normalize a caption for the T5 encoder.

    With *use_text_preprocessing* enabled, run the full training-time
    cleaning pipeline — clean_caption is applied twice on purpose, so the
    second pass can catch artifacts exposed by the first.  Otherwise just
    lowercase and trim.
    """
    if not use_text_preprocessing:
        return text.lower().strip()
    # The exact text cleaning as was in the training stage:
    for _ in range(2):
        text = clean_caption(text)
    return text

View file

@ -169,11 +169,12 @@ def save(
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
if lr_scheduler is not None:
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
sampler_start_idx = step * batch_size if batch_size is not None else None
running_states = {
"epoch": epoch,
"step": step,
"global_step": global_step,
"sample_start_index": step * batch_size,
"sample_start_index": sampler_start_idx,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))

View file

@ -9,6 +9,7 @@ from mmengine.runner import set_random_seed
from opensora.acceleration.parallel_states import set_sequence_parallel_group
from opensora.datasets import save_sample
from opensora.datasets.utils import read_from_path
from opensora.models.text_encoder.t5 import text_preprocessing
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype
@ -65,11 +66,13 @@ def process_prompts(prompts, num_loop):
for i in range(0, len(prompt_list), 2):
start_loop = int(prompt_list[i])
text = prompt_list[i + 1]
text = text_preprocessing(text)
end_loop = int(prompt_list[i + 2]) if i + 2 < len(prompt_list) else num_loop
text_list.extend([text] * (end_loop - start_loop))
assert len(text_list) == num_loop
ret_prompts.append(text_list)
else:
prompt = text_preprocessing(prompt)
ret_prompts.append([prompt] * num_loop)
return ret_prompts

View file

@ -8,6 +8,7 @@ from mmengine.runner import set_random_seed
from opensora.acceleration.parallel_states import set_sequence_parallel_group
from opensora.datasets import save_sample
from opensora.models.text_encoder.t5 import text_preprocessing
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype
@ -97,6 +98,7 @@ def main():
for i in range(0, len(prompts), cfg.batch_size):
# 4.2 sample in hidden space
batch_prompts = prompts[i : i + cfg.batch_size]
batch_prompts = [text_preprocessing(prompt) for prompt in batch_prompts]
z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)
# 4.3. diffusion sampling

View file

@ -10,6 +10,7 @@ from glob import glob
import cv2
import numpy as np
import pandas as pd
import torchvision
from tqdm import tqdm
from .utils import IMG_EXTENSIONS
@ -31,6 +32,8 @@ def apply(df, func, **kwargs):
return df.progress_apply(func, **kwargs)
TRAIN_COLUMNS = ["path", "text", "num_frames", "fps", "height", "width", "aspect_ratio", "resolution", "text_len"]
# ======================================================
# --info
# ======================================================
@ -71,14 +74,15 @@ def get_info(path):
def get_video_info(path):
import torchvision
vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
num_frames, height, width = vframes.shape[0], vframes.shape[1], vframes.shape[2]
aspect_ratio = height / width
fps = np.nan
resolution = height * width
return num_frames, height, width, aspect_ratio, fps, resolution
try:
vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
num_frames, height, width = vframes.shape[0], vframes.shape[2], vframes.shape[3]
aspect_ratio = height / width
fps = np.nan
resolution = height * width
return num_frames, height, width, aspect_ratio, fps, resolution
except:
return 0, 0, 0, np.nan, np.nan, np.nan
# ======================================================
@ -134,13 +138,13 @@ CMOTION_TEXT = {
}
CMOTION_PROBS = {
# hard-coded probabilities
"static": 0.1,
"dynamic": 0.25,
"static": 1.0,
"dynamic": 1.0,
"unknown": 0.0,
"zoom in": 0.8,
"zoom in": 1.0,
"zoom out": 1.0,
"pan left": 1.0,
"pan right": 0.1,
"pan right": 1.0,
"tilt up": 1.0,
"tilt down": 1.0,
"pan/tilt": 1.0,
@ -356,6 +360,9 @@ def read_file(input_path):
def save_file(data, output_path):
output_dir = os.path.dirname(output_path)
if not os.path.exists(output_dir) and output_dir != "":
os.makedirs(output_dir)
if output_path.endswith(".csv"):
return data.to_csv(output_path, index=False)
elif output_path.endswith(".parquet"):
@ -410,6 +417,12 @@ def main(args):
data = pd.merge(data, data_new[cols_to_use], on="path", how="inner")
print(f"Intersection number of samples: {len(data)}.")
# train columns
if args.train_column:
all_columns = data.columns
columns_to_drop = all_columns.difference(TRAIN_COLUMNS)
data = data.drop(columns=columns_to_drop)
# get output path
output_path = get_output_path(args, input_name)
@ -469,20 +482,18 @@ def main(args):
data["path"] = apply(data["path"], lambda x: os.path.relpath(x, args.relpath))
if args.abspath is not None:
data["path"] = apply(data["path"], lambda x: os.path.join(args.abspath, x))
if args.merge_cmotion:
data["text"] = apply(data, lambda x: merge_cmotion(x["text"], x["cmotion"]), axis=1)
if args.refine_llm_caption:
assert "text" in data.columns
data["text"] = apply(data["text"], remove_caption_prefix)
if args.unescape:
assert "text" in data.columns
data["text"] = apply(data["text"], html.unescape)
if args.clean_caption:
assert "text" in data.columns
data["text"] = apply(
data["text"],
partial(text_preprocessing, use_text_preprocessing=True),
)
if args.merge_cmotion:
data["text"] = apply(data, lambda x: merge_cmotion(x["text"], x["cmotion"]), axis=1)
if args.count_num_token is not None:
assert "text" in data.columns
data["text_len"] = apply(data["text"], lambda x: len(tokenizer(x)["input_ids"]))
@ -505,7 +516,11 @@ def main(args):
assert "num_frames" in data.columns
data = data[data["num_frames"] <= args.fmax]
if args.hwmax is not None:
assert "resolution" in data.columns
if "resolution" not in data.columns:
height = data["height"]
width = data["width"]
data["resolution"] = height * width
breakpoint()
data = data[data["resolution"] <= args.hwmax]
if args.aesmin is not None:
assert "aes" in data.columns
@ -549,6 +564,7 @@ def parse_args():
parser.add_argument(
"--intersection", type=str, default=None, help="keep the paths in csv from the dataset and merge columns"
)
parser.add_argument("--train-column", action="store_true", help="only keep the train column")
# IO-related
parser.add_argument("--info", action="store_true", help="get the basic information of each video and image")
@ -578,7 +594,6 @@ def parse_args():
parser.add_argument(
"--clean-caption", action="store_true", help="modify the caption according to T5 pipeline to suit training"
)
parser.add_argument("--unescape", action="store_true", help="unescape the caption")
parser.add_argument("--merge-cmotion", action="store_true", help="merge the camera motion to the caption")
parser.add_argument(
"--count-num-token", type=str, choices=["t5"], default=None, help="Count the number of tokens in the caption"
@ -641,8 +656,6 @@ def get_output_path(args, input_name):
# caption processing
if args.refine_llm_caption:
name += "_llm"
if args.unescape:
name += "_unescape"
if args.clean_caption:
name += "_clean"
if args.merge_cmotion:

View file

@ -18,6 +18,7 @@ def extract_frames(
points=None,
backend="opencv",
return_length=False,
num_frames=None,
):
"""
Args:
@ -34,7 +35,10 @@ def extract_frames(
import av
container = av.open(video_path)
total_frames = container.streams.video[0].frames
if num_frames is not None:
total_frames = num_frames
else:
total_frames = container.streams.video[0].frames
if points is not None:
frame_inds = [int(p * total_frames) for p in points]
@ -56,8 +60,10 @@ def extract_frames(
import decord
container = decord.VideoReader(video_path, num_threads=1)
total_frames = len(container)
# avg_fps = container.get_avg_fps()
if num_frames is not None:
total_frames = num_frames
else:
total_frames = len(container)
if points is not None:
frame_inds = [int(p * total_frames) for p in points]
@ -73,7 +79,10 @@ def extract_frames(
elif backend == "opencv":
cap = cv2.VideoCapture(video_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if num_frames is not None:
total_frames = num_frames
else:
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if points is not None:
frame_inds = [int(p * total_frames) for p in points]
@ -90,7 +99,6 @@ def extract_frames(
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = Image.fromarray(frame)
except Exception as e:
breakpoint()
print(f"Error reading frame {video_path}: {e}")
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))

View file

@ -45,10 +45,14 @@ class VideoTextDataset(torch.utils.data.Dataset):
if not is_video(path):
images = [pil_loader(path)]
else:
images = extract_frames(sample["path"], points=self.points, backend="opencv")
num_frames = None
if "num_frames" in sample:
num_frames = sample["num_frames"]
images = extract_frames(sample["path"], points=self.points, backend="opencv", num_frames=num_frames)
images = [self.transform(img) for img in images]
images = torch.stack(images)
return dict(index=index, images=images)
ret = dict(index=index, images=images)
return ret
def __len__(self):
return len(self.data)